commit 5d2e7edf78731af5b9762fdfde729297b715d448 Author: wangjing Date: Wed Aug 13 19:46:19 2025 +0800 init diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..372c13e --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +__pycache__/ + diff --git a/_C.abi3.so b/_C.abi3.so new file mode 100755 index 0000000..5661dc2 Binary files /dev/null and b/_C.abi3.so differ diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..6232b65 --- /dev/null +++ b/__init__.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""vLLM: a high-throughput and memory-efficient inference engine for LLMs""" +# The version.py should be independent library, and we always import the +# version library first. Such assumption is critical for some customization. +from .version import __version__, __version_tuple__ # isort:skip + +# The environment variables override should be imported before any other +# modules to ensure that the environment variables are set before any +# other modules are imported. +import vllm.env_override # isort:skip # noqa: F401 + +from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.llm_engine import LLMEngine +from vllm.entrypoints.llm import LLM +from vllm.executor.ray_utils import initialize_ray_cluster +from vllm.inputs import PromptType, TextPrompt, TokensPrompt +from vllm.model_executor.models import ModelRegistry +from vllm.outputs import (ClassificationOutput, ClassificationRequestOutput, + CompletionOutput, EmbeddingOutput, + EmbeddingRequestOutput, PoolingOutput, + PoolingRequestOutput, RequestOutput, ScoringOutput, + ScoringRequestOutput) +from vllm.pooling_params import PoolingParams +from vllm.sampling_params import SamplingParams + +__all__ = [ + "__version__", + "__version_tuple__", + "LLM", + "ModelRegistry", + "PromptType", + "TextPrompt", + "TokensPrompt", + "SamplingParams", + "RequestOutput", + "CompletionOutput", + "PoolingOutput", + "PoolingRequestOutput", + "EmbeddingOutput", + "EmbeddingRequestOutput", + "ClassificationOutput", + "ClassificationRequestOutput", + "ScoringOutput", + "ScoringRequestOutput", + "LLMEngine", + "EngineArgs", + "AsyncLLMEngine", + "AsyncEngineArgs", + "initialize_ray_cluster", + "PoolingParams", +] diff --git a/_custom_ops.py b/_custom_ops.py new file mode 100644 index 0000000..6ae546d --- /dev/null +++ b/_custom_ops.py @@ -0,0 +1,1912 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import contextlib +import importlib +from typing import TYPE_CHECKING, Optional, Union + +import torch +import torch.library + +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.scalar_type import ScalarType + +logger = init_logger(__name__) + +if not current_platform.is_tpu() and not current_platform.is_hpu(): + try: + import vllm._C + except ImportError as e: + logger.warning("Failed to import from vllm._C with %r", e) + +supports_moe_ops = False +with contextlib.suppress(ImportError): + import vllm._moe_C # noqa: F401 + supports_moe_ops = True + +if TYPE_CHECKING: + + def register_fake(fn): + return lambda name: fn +else: + try: + from torch.library import register_fake + except ImportError: + from torch.library import impl_abstract as register_fake + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + 
key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + torch.ops._C.paged_attention_v1( + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, + seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, + k_scale, v_scale, tp_rank, blocksparse_local_blocks, + blocksparse_vert_stride, blocksparse_block_size, + blocksparse_head_sliding_step) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + torch.ops._C.paged_attention_v2( + out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, + num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, + alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank, + blocksparse_local_blocks, blocksparse_vert_stride, + blocksparse_block_size, blocksparse_head_sliding_step) + + +def paged_attention_rocm( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + query_start_loc: Optional[torch.Tensor], + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + fp8_out_scale: Optional[torch.Tensor] = None, +) -> None: + torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query, + key_cache, value_cache, num_kv_heads, + scale, block_tables, seq_lens, + query_start_loc, block_size, max_seq_len, + alibi_slopes, kv_cache_dtype, k_scale, + v_scale, fp8_out_scale) + + +def mla_decode_kvcache_cpu( + out: torch.Tensor, + query: torch.Tensor, + kv_cache: torch.Tensor, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, +) -> None: + torch.ops._C_cpu.mla_decode_kvcache(out, query, kv_cache, scale, + block_tables, seq_lens) + + +# merge attn states ops +def merge_attn_states(output: torch.Tensor, + prefix_output: torch.Tensor, + prefix_lse: torch.Tensor, + suffix_output: torch.Tensor, + suffix_lse: torch.Tensor, + output_lse: Optional[torch.Tensor] = None) -> None: + torch.ops._C.merge_attn_states(output, output_lse, prefix_output, + prefix_lse, suffix_output, suffix_lse) + + +def convert_vertical_slash_indexes( + q_seqlens: torch.Tensor, # [BATCH, ] + kv_seqlens: torch.Tensor, # [BATCH, ] + vertical_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_V] + slash_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_S] + context_size: int, + block_size_M: int, + block_size_N: int, + causal: bool = True, +) 
-> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + batch_size = slash_indexes.size(0) + num_heads = slash_indexes.size(1) + nnz_slash = slash_indexes.size(2) + nnz_vertical = vertical_indexes.size(2) + num_rows = (context_size + block_size_M - 1) // block_size_M + + block_count = torch.zeros(batch_size, + num_heads, + num_rows, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + block_offset = torch.zeros(batch_size, + num_heads, + num_rows, + nnz_slash, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + column_count = torch.zeros(batch_size, + num_heads, + num_rows, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + column_index = torch.zeros(batch_size, + num_heads, + num_rows, + nnz_vertical, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + + torch.ops._C.convert_vertical_slash_indexes( + block_count, block_offset, column_count, column_index, q_seqlens, + kv_seqlens, vertical_indexes, slash_indexes, context_size, + block_size_M, block_size_N, causal) + return block_count, block_offset, column_count, column_index + + +def convert_vertical_slash_indexes_mergehead( + q_seqlens: torch.Tensor, # [BATCH, ] + kv_seqlens: torch.Tensor, # [BATCH, ] + vertical_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_V] + slash_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_S] + # [N_HEADS] : different head use different number of indices + vertical_indices_count: torch.Tensor, + slash_indices_count: torch.Tensor, + context_size: int, + block_size_M: int, + block_size_N: int, + causal: bool = True, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + batch_size = slash_indexes.size(0) + num_heads = slash_indexes.size(1) + nnz_slash = slash_indexes.size(2) + nnz_vertical = vertical_indexes.size(2) + num_rows = (context_size + block_size_M - 1) // block_size_M + + block_count = torch.empty(batch_size, + num_heads, + num_rows, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + block_offset = torch.empty(batch_size, + num_heads, + num_rows, + nnz_slash, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + column_count = torch.empty(batch_size, + num_heads, + num_rows, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + column_index = torch.empty(batch_size, + num_heads, + num_rows, + nnz_vertical, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + + torch.ops._C.convert_vertical_slash_indexes_mergehead( + block_count, block_offset, column_count, column_index, q_seqlens, + kv_seqlens, vertical_indexes, slash_indexes, vertical_indices_count, + slash_indices_count, context_size, block_size_M, block_size_N, causal) + return block_count, block_offset, column_count, column_index + + +# pos encoding ops +def rotary_embedding( + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor], + head_size: int, + cos_sin_cache: torch.Tensor, + is_neox: bool, +) -> None: + torch.ops._C.rotary_embedding(positions, query, key, head_size, + cos_sin_cache, is_neox) + + +def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, + key: Optional[torch.Tensor], head_size: int, + cos_sin_cache: torch.Tensor, is_neox: bool, + rot_dim: int, + cos_sin_cache_offsets: torch.Tensor) -> None: + torch.ops._C.batched_rotary_embedding(positions, query, key, head_size, + cos_sin_cache, is_neox, rot_dim, + cos_sin_cache_offsets) + +# layer norm ops +def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, + epsilon: float) -> None: + # TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input + 
input_contiguous = input.contiguous() + torch.ops._C.rms_norm(out, input_contiguous, weight, epsilon) + + +def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, + weight: torch.Tensor, epsilon: float) -> None: + torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) + + +def apply_repetition_penalties_torch( + logits: torch.Tensor, prompt_mask: torch.Tensor, + output_mask: torch.Tensor, repetition_penalties: torch.Tensor) -> None: + repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat( + 1, logits.size(1)) + # If token appears in prompt or output, apply, otherwise use 1.0 for no-op. + penalties = torch.where(prompt_mask | output_mask, repetition_penalties, + 1.0) + # If logits are positive, divide by penalty, otherwise multiply by penalty. + scaling = torch.where(logits > 0, 1.0 / penalties, penalties) + logits *= scaling + + +def apply_repetition_penalties_cuda( + logits: torch.Tensor, prompt_mask: torch.Tensor, + output_mask: torch.Tensor, repetition_penalties: torch.Tensor) -> None: + torch.ops._C.apply_repetition_penalties_(logits, prompt_mask, output_mask, + repetition_penalties) + + +def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor, + output_mask: torch.Tensor, + repetition_penalties: torch.Tensor) -> None: + """Apply repetition penalties to logits in-place. + + Args: + logits: The logits tensor of shape [num_seqs, vocab_size]. + prompt_mask: A boolean tensor indicating which tokens appear in the prompt. + output_mask: A boolean tensor indicating which tokens appear in the output. + repetition_penalties: The repetition penalties of shape (num_seqs, ). + """ + if current_platform.is_cuda() and logits.is_contiguous(): + apply_repetition_penalties_cuda(logits, prompt_mask, output_mask, + repetition_penalties) + else: + apply_repetition_penalties_torch(logits, prompt_mask, output_mask, + repetition_penalties) + + +def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int, + input_tokens: torch.Tensor, + sampled_token_ids: torch.Tensor, + input_positions: torch.Tensor, + seq_lens: torch.Tensor, slot_mapping: torch.Tensor, + block_tables: torch.Tensor) -> None: + """Advance a step on GPU for existing inputs for a multi-step runner""" + return torch.ops._C.advance_step_flashattn(num_seqs, num_queries, + block_size, input_tokens, + sampled_token_ids, + input_positions, seq_lens, + slot_mapping, block_tables) + + +def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int, + input_tokens: torch.Tensor, + sampled_token_ids: torch.Tensor, + input_positions: torch.Tensor, + seq_lens: torch.Tensor, slot_mapping: torch.Tensor, + block_tables: torch.Tensor, + paged_kv_indices: torch.Tensor, + paged_kv_indptr: torch.Tensor, + paged_kv_last_page_len: torch.Tensor, + block_table_bound: torch.Tensor) -> None: + + return torch.ops._C.advance_step_flashinfer( + num_seqs, num_queries, block_size, input_tokens, sampled_token_ids, + input_positions, seq_lens, slot_mapping, block_tables, + paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len, + block_table_bound) + + +# fused quant layer norm ops +def rms_norm_dynamic_per_token_quant( + input: torch.Tensor, + weight: torch.Tensor, + epsilon: float, + quant_dtype: torch.dtype, + scale_ub: Optional[torch.Tensor] = None, + residual: Optional[torch.Tensor] = None +) -> tuple[torch.Tensor, torch.Tensor]: + output = torch.empty_like(input, dtype=quant_dtype) + scales = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + 
dtype=torch.float32) + + torch.ops._C.rms_norm_dynamic_per_token_quant(output, input, weight, + scales, epsilon, scale_ub, + residual) + return output, scales + + +# quantization ops +# awq +def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, + zeros: torch.Tensor, split_k_iters: int, thx: int, + thy: int) -> torch.Tensor: + if envs.VLLM_USE_TRITON_AWQ: + from vllm.model_executor.layers.quantization.awq_triton import ( + awq_dequantize_triton) + return awq_dequantize_triton(qweight, scales, zeros) + return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters, + thx, thy) + + +def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, + scales: torch.Tensor, split_k_iters: int, temp_space: torch.Tensor, + dtype_bf16: bool) -> torch.Tensor: + if envs.VLLM_USE_TRITON_AWQ: + from vllm.model_executor.layers.quantization.awq_triton import ( + awq_gemm_triton) + return awq_gemm_triton(input, qweight, scales, qzeros, split_k_iters) + return torch.ops._C.awq_gemm(input, qweight, scales, qzeros, split_k_iters, + temp_space, dtype_bf16) + +# awq to gptq 4bit conversion +def awq_to_gptq_4bit(qweight: torch.Tensor) -> torch.Tensor: + if envs.VLLM_USE_TRITON_AWQ: + return qweight + return torch.ops._C.awq_to_gptq_4bit(qweight) + +# gptq +def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, + b_g_idx: torch.Tensor, use_exllama: bool, + bit: int, group_size: int, perm_space: torch.Tensor, + temp_space: torch.Tensor, dtype_bf16: bool) -> torch.Tensor: + return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, + b_g_idx, use_exllama, bit, group_size, + perm_space, temp_space, dtype_bf16) + + +if hasattr(torch.ops._C, "gptq_gemm"): + + @register_fake("_C::gptq_gemm") + def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_gptq_qzeros: torch.Tensor, + b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor, + use_exllama: bool, bit: int, + group_size: int, perm_space: torch.Tensor, + temp_space: torch.Tensor, dtype_bf16: bool) -> torch.Tensor: + return torch.empty((a.size(0), b_q_weight.size(1)), + dtype=a.dtype, + device=a.device) + + +def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, + bit: int) -> None: + torch.ops._C.gptq_shuffle(q_weight, q_perm, bit) + + +# marlin +def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int, + size_n: int, size_k: int) -> torch.Tensor: + return torch.ops._C.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m, + size_n, size_k) + + +# marlin_24 +def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_meta: torch.Tensor, b_scales: torch.Tensor, + workspace: torch.Tensor, b_q_type: ScalarType, + size_m: int, size_n: int, size_k: int) -> torch.Tensor: + return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales, + workspace, b_q_type.id, size_m, + size_n, size_k) + + +if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): + + @register_fake("_C::gptq_marlin_24_gemm") + def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_meta: torch.Tensor, b_scales: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + + @register_fake("_C::gptq_marlin_gemm") + def _gptq_marlin_gemm_fake(a: torch.Tensor, + c: Optional[torch.Tensor], + b_q_weight: torch.Tensor, + 
b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], + b_zeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + b_q_type_id: int, + size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt, + is_k_full: bool = True, + use_atomic_add: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: + return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + + @register_fake("_C::marlin_qqq_gemm") + def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + s_tok: torch.Tensor, s_ch: torch.Tensor, + s_group: torch.Tensor, workspace: torch.Tensor, + size_m: torch.SymInt, size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), + dtype=torch.float16, + device=a.device) + + @register_fake("_C::marlin_gemm") + def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + size_m: torch.SymInt, size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), + dtype=torch.float16, + device=a.device) + + @register_fake("_C::awq_dequantize") + def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor, + zeros: torch.Tensor, split_k_iters: torch.SymInt, + thx: int, thy: int) -> torch.Tensor: + in_c = qweight.size(0) + qout_c = qweight.size(1) + out_c = qout_c * 8 + return torch.empty((in_c, out_c), + dtype=scales.dtype, + device=scales.device) + + @register_fake("_C::awq_gemm") + def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor, + qzeros: torch.Tensor, scales: torch.Tensor, + split_k_iters: torch.SymInt, temp_space: torch.Tensor, + dtype_bf16: bool) -> torch.Tensor: + num_in_feats = input.size(0) + return torch.empty((split_k_iters, num_in_feats, qweight.size(1) * 8), + dtype=input.dtype, + device=input.device).sum(0) + + @register_fake("_C::aqlm_gemm") + def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor, + codebooks: torch.Tensor, scales: torch.Tensor, + codebook_partition_sizes: list[int], + bias: Optional[torch.Tensor]) -> torch.Tensor: + out_features = codes.size(0) * codebooks.size(2) + flat_input = input.reshape((-1, input.size(-1))) + flat_output = torch.empty((flat_input.size(0), out_features), + dtype=input.dtype, + device=input.device) + + output_sizes = list(input.shape) + output_sizes.pop() + output_sizes.append(-1) + return flat_output.reshape(tuple(output_sizes)) + + @register_fake("_C::aqlm_dequant") + def _aqlm_dequant_fake( + codes: torch.Tensor, codebooks: torch.Tensor, + codebook_partition_sizes: list[int]) -> torch.Tensor: + in_features = codes.size(1) * 8 + out_features = codes.size(0) + return torch.empty((out_features, in_features), + dtype=codebooks.dtype, + device=codebooks.device) + + @register_fake("_C::machete_mm") + def machete_mm_fake( + a: torch.Tensor, + # b_q Should be the tensor returned by machete_prepack_B + b_q: torch.Tensor, + b_type: ScalarType, + out_type: Optional[torch.dtype] = None, + b_group_scales: Optional[torch.Tensor] = None, + b_group_zeros: Optional[torch.Tensor] = None, + b_group_size: Optional[int] = None, + b_channel_scales: Optional[torch.Tensor] = None, + a_token_scales: Optional[torch.Tensor] = None, + schedule: Optional[str] = None, + ) -> torch.Tensor: + m = a.size(0) + n = b_q.size(1) + return torch.empty((m, n), device=a.device, dtype=a.dtype) + + @register_fake("_C::machete_prepack_B") + def machete_prepack_B_fake( + 
b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType, + group_scales_type: Optional[torch.dtype]) -> torch.Tensor: + return torch.empty_like(b_q_weight, + memory_format=torch.contiguous_format) + + +if hasattr(torch.ops._C, "allspark_w8a16_gemm"): + + @register_fake("_C::allspark_w8a16_gemm") + def _allspark_w8a16_gemm_fake(a: torch.Tensor, b_qweight: torch.Tensor, + b_scales: torch.Tensor, + b_qzeros: Optional[torch.Tensor], + n: torch.SymInt, group_size: torch.SymInt, + sm_count: torch.SymInt, + sm_version: torch.SymInt, + CUBLAS_M_THRESHOLD: torch.SymInt, + has_zp: bool, + n32k16_reorder: bool) -> torch.Tensor: + m = a.size(0) + return torch.empty((m, n), device=a.device, dtype=a.dtype) + + +if hasattr(torch.ops._C, "ggml_dequantize"): + + @register_fake("_C::ggml_dequantize") + def _ggml_dequantize_fake( + W: torch.Tensor, + quant_type: int, + m: torch.SymInt, + n: torch.SymInt, + dtype: Optional[torch.dtype] = None) -> torch.Tensor: + return torch.empty((m, n), dtype=torch.float16, device=W.device) + + @register_fake("_C::ggml_mul_mat_vec_a8") + def _ggml_mul_mat_vec_a8_fake( + W: torch.Tensor, + X: torch.Tensor, + quant_type: int, + row: torch.SymInt, + ) -> torch.Tensor: + return torch.empty((1, row), dtype=X.dtype, device=W.device) + + @register_fake("_C::ggml_mul_mat_a8") + def _ggml_mul_mat_a8_fake( + W: torch.Tensor, + X: torch.Tensor, + quant_type: int, + row: torch.SymInt, + ) -> torch.Tensor: + batch = X.size(0) + return torch.empty((batch, row), dtype=X.dtype, device=W.device) + + @register_fake("_C::ggml_moe_a8") + def _ggml_moe_a8_fake( + X: torch.Tensor, + W: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + quant_type: int, + row: torch.SymInt, + top_k: torch.SymInt, + tokens: torch.SymInt, + ) -> torch.Tensor: + tokens = X.size(0) + return torch.empty((tokens * top_k, row), + dtype=torch.float16, + device=W.device) + + +if hasattr(torch.ops._C, "ggml_moe_a8_vec"): + + @register_fake("_C::ggml_moe_a8_vec") + def _ggml_moe_a8_vec_fake( + X: torch.Tensor, + W: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + quant_type: int, + row: torch.SymInt, + tokens: torch.SymInt, + ) -> torch.Tensor: + tokens = X.size(0) + return torch.empty((tokens * top_k, row), + dtype=X.dtype, + device=W.device) + + +# cutlass +def cutlass_scaled_mm_supports_fp4(cuda_device_capability: int) -> bool: + # return torch.ops._C.cutlass_scaled_mm_supports_fp4(cuda_device_capability) + return False + +def cutlass_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor, + block_scale_a: torch.Tensor, + block_scale_b: torch.Tensor, alpha: torch.Tensor, + out_dtype: torch.dtype) -> torch.Tensor: + assert a.ndim == 2 and b.ndim == 2 + m, n = a.shape[0], b.shape[0] + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + torch.ops._C.cutlass_scaled_fp4_mm(out, a, b, block_scale_a, block_scale_b, + alpha) + return out + +def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: + # return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability) + return True + + +def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool: + # return torch.ops._C.cutlass_scaled_mm_supports_block_fp8( + # cuda_device_capability) + return True + +# Batch gemm in vllm, support w8a8 int8 quantization +def cutlass_scaled_batch_mm(a: torch.Tensor, b: torch.Tensor, + scale_a: torch.Tensor, scale_b: torch.Tensor, + out_dtype: torch.dtype, bias: Optional[torch.Tensor] = None) -> torch.Tensor: + assert 
(a.shape[0] == b.shape[0] and a.shape[2] == b.shape[1]) + out = torch.empty((a.shape[0], a.shape[1], b.shape[2]), device = a.device, dtype = out_dtype) + torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) + return out + +def cutlass_scaled_mm(a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + `cutlass_scaled_mm` implements a fused version of + `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)` + where scale_a * a and scale_b * b are implemented using numpy-style + broadcasting. + + In order to support blockwise scaling like found in DeepSeek V3 we also + support extended "group" broadcast rules. We extend the numpy-style + broadcasting rules with the following rule: + "if the extent of a dimension in the source shape is between 1 and + corresponding extent in the target shape we repeat each element along + that dimension src_shape[dim] // target_shape[dim] times consecutively" + example if we have: + a = [[1, 2], and target_shape = (2, 4) + [3, 4]] + then we would expand a to: + a = [[1, 1, 2, 2], + [3, 3, 4, 4]] + currently we only support the case: + scale_a.shape * [1, 128] == a.shape + scale_b.shape * [128, 128] == b.shape + """ + assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) + assert bias is None or bias.shape[0] == b.shape[ + 1] and bias.dtype == out_dtype + + m = a.shape[0] + n = b.shape[1] + + cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) + if current_platform.is_rocm() or not cutlass_compatible_b: + triton_scaled_mm_module = importlib.import_module( + "vllm.model_executor.layers.quantization.compressed_tensors." + "triton_scaled_mm") + triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm + return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) + + return out + + +def cutlass_scaled_mm_azp(a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + azp_adj: torch.Tensor, + azp: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + :param azp_adj: In the per-tensor case, this should include the azp. + Always per-channel. + :param azp: Only set in the per-token case. Per-token if set. + """ + assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) + assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) + assert bias is None or bias.numel( + ) == b.shape[1] and bias.dtype == out_dtype + assert azp is None or azp.numel() == a.shape[0] + + m = a.shape[0] + n = b.shape[1] + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + torch.ops._C.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, + azp, bias) + return out + + +def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool: + return torch.ops._C.cutlass_sparse_scaled_mm_supported( + cuda_device_capability) + + +def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool: + return torch.ops._C.cutlass_group_gemm_supported(cuda_device_capability) + +def cutlass_sparse_compress(a: torch.Tensor) \ + -> tuple[torch.Tensor, torch.Tensor]: + """ + Compresses a sparse matrix for use with Cutlass sparse operations. + + This function takes a dense tensor and compresses it into two components: + non-zero elements and metadata. 
The compressed representation is compatible + with Cutlass sparse kernels. + + Args: + a (torch.Tensor): + The input tensor to be compressed. Must have one of the following data types: + - `torch.int8` + - `torch.float8_e4m3fn` + - `torch.bfloat16` + - `torch.float16` + + Returns: + tuple[torch.Tensor, torch.Tensor]: + A tuple containing: + - `a_nzs` (torch.Tensor): A tensor containing non-zero elements of `a`. + - `a_meta` (torch.Tensor): A tensor containing metadata for the sparse representation. + + Raises: + ValueError: If the compression operation fails. + + Notes: + - The `a_meta` tensor has a data type of `torch.uint8`. + - Each metadata element encodes the sparsity of 4 non-zero elements (i.e., `elemsPerMetaElem = 4`). + - The shape of `a_nzs` is `(m, k // 2)`, where `m` and `k` are the dimensions of the input tensor. + - The shape of `a_meta` is `(m, k // 2 // elemsPerMetaElem)`. + """ + assert (a.dtype in [ + torch.int8, torch.float8_e4m3fn, torch.bfloat16, torch.float16 + ]) + assert (a.is_contiguous()) + + # a_meta.dtype: torch.uint8 so elemsPerMetaElem = 8b / 2b_per_nz = 4 + elemsPerMetaElem = 4 + assert (a.shape[1] % (2 * elemsPerMetaElem) == 0) + + return torch.ops._C.cutlass_sparse_compress(a) + + +def cutlass_scaled_sparse_mm( + a: torch.Tensor, + bt_nzs: torch.Tensor, + bt_meta: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Performs a scaled sparse matrix multiplication using Cutlass. + + Steps: + 1. Create a dense matrix `a` of shape (m, k) on the CUDA device: + `a = torch.randn((m, k), device='cuda')`. + + 2. Create a dense matrix `b` of shape (k, n) on the CUDA device: + `b = torch.randn((k, n), device='cuda')`. + + 3. Prune matrix `b` to 2:4 sparsity along the specified dimension: + `b = prune_to_2_4(b, dim=0)`. + + 4. Compress the transposed sparse matrix `b.t()`: + `bt_nzs, bt_meta = cutlass_sparse_compress(b.t())`. + + 5. Perform sparse matrix multiplication using the compressed matrix, + applying scaling factors for `a` and `b`, and the output data type: + `out = cutlass_scaled_sparse_mm(a, bt_nzs, bt_meta, scale_a, scale_b, out_dtype)`. + + Returns: + - The result of the scaled sparse matrix multiplication. + """ + assert (bt_nzs.shape[0] % 16 == 0 and bt_nzs.shape[1] % 16 == 0) + assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) + assert bias is None or bias.shape[0] == bt_nzs.shape[0] \ + and bias.dtype == out_dtype + + m = a.shape[0] + n = bt_nzs.shape[0] + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + torch.ops._C.cutlass_scaled_sparse_mm(out, a, bt_nzs, bt_meta, scale_a, + scale_b, bias) + + return out + + +def get_cutlass_moe_mm_data(topk_ids: torch.Tensor, + expert_offsets: torch.Tensor, + problem_sizes1: torch.Tensor, + problem_sizes2: torch.Tensor, + input_permutation: torch.Tensor, + output_permutation: torch.Tensor, + num_experts: int, + n: int, + k: int, + blockscale_offsets: Optional[torch.Tensor] = None): + """ + Prepare data necessary to perform CUTLASS grouped matrix multiplications + used in CUTLASS-based fused MoE. + + The function takes in topk_ids (token-expert mapping) and uses it to + compute: + - expert_offsets: Indices that mark at which token index each expert begins + its computation after the input is sorted with + input_permutation. 
The number of tokens computed with + expert E is expert_offsets[E + 1] - expert_offsets[E] + - problem_sizes1, problem_sizes2: MxNxK sizes of each expert's + multiplication in two grouped MMs used in + the fused MoE operation. + - input_permutation: Permutation that must be used to shuffle the input + before executing the MMs. + - output_permutation: Permutation that must be used to shuffle the output + after executing the MMs. + - blockscale_offsets: Optional argument passed for fp4 moe. Indices that + mark at which block scale index each expert begins + its computation. The number of block scale rows + computed with expert E is blockscale_offsets[E + 1] - + blockscale_offsets[E] + """ + return torch.ops._C.get_cutlass_moe_mm_data(topk_ids, expert_offsets, + problem_sizes1, problem_sizes2, + input_permutation, + output_permutation, + num_experts, n, k, + blockscale_offsets) + + +def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor): + """ + Shuffle and expand the input tensor according to the dst2src_map and store the result in output_tensor. + This is used in MoE to permute the input tensor before performing grouped matrix multiplications. + """ + num_tokens_permuted = dst2src_map.shape[0] + output_tensor = torch.empty((num_tokens_permuted, input_tensor.shape[1]), + device=input_tensor.device, + dtype=input_tensor.dtype) + torch.ops._moe_C.shuffle_rows(input_tensor, dst2src_map, output_tensor) + return output_tensor + + +def get_cutlass_pplx_moe_mm_data(expert_offsets: torch.Tensor, + problem_sizes1: torch.Tensor, + problem_sizes2: torch.Tensor, + expert_num_tokens: torch.Tensor, + num_local_experts: int, padded_m: int, n: int, + k: int): + """ + Prepare data necessary to perform CUTLASS grouped matrix multiplications + used in CUTLASS-based fused MoE. + + The function takes in expert_num_tokens (token count per expert) and + non_zero_expert_idxs (consecutive indices of experts with non-zero token + counts) and uses them to compute: + - expert_offsets: Indices that mark at which token index each expert begins + its computation. + - problem_sizes1, problem_sizes2: MxNxK sizes of each expert's + multiplication in two grouped MMs used in + the fused MoE operation. + """ + return torch.ops._C.get_cutlass_pplx_moe_mm_data( + expert_offsets, problem_sizes1, problem_sizes2, expert_num_tokens, + num_local_experts, padded_m, n, k) + + +def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, + b_tensors: torch.Tensor, a_scales: torch.Tensor, + b_scales: torch.Tensor, expert_offsets: torch.Tensor, + problem_sizes: torch.Tensor, a_strides: torch.Tensor, + b_strides: torch.Tensor, c_strides: torch.Tensor, + per_act_token: bool, per_out_ch: bool): + """ + A single grouped matrix multiplication used in CUTLASS-based fused MoE. + The function executes fp8-quantized OUT = AB matrix multiplication. + + - expert_offsets: Indices that mark at which token index each expert begins + its computation. The number of tokens computed with + expert E is expert_offsets[E + 1] - expert_offsets[E] + - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped + MMs used in the fused MoE operation. + - a/b/c_strides: The data strides passed to grouped matrix multiplication. 
+ """ + return torch.ops._C.cutlass_moe_mm(out_tensors, a_tensors, b_tensors, + a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, + c_strides, per_act_token, per_out_ch) + + +def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor, + a_scales: torch.Tensor, b_scales: torch.Tensor, + alphas: torch.Tensor, problem_sizes: torch.Tensor, + expert_offsets: torch.Tensor, sf_offsets: torch.Tensor, + out_dtype: torch.dtype, device: torch.device): + """ + An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs + the gemms for each combination based on the specified problem sizes. + + This is used as the MoE gemm during NVFP4 Quantized FusedMoE forward. + - a/b_tensors: the NVFP4 a_ptrs and b_ptrs tensors which are quantized + input and expert weights. + - a_/b_scales: The blockscales in FP8-E4M3 precision + - expert_offsets/sf_offsets: Indices that mark at which token index + each expert begins its computation. The number of tokens + computed with expert E is expert_offsets[E + 1] - + expert_offsets[E] And the sf_size per expert is + sf_offset[E+1] - sf_offset[E] + - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped + MMs used in the fused MoE operation. + """ + m_topk = a_tensors.shape[0] + n = b_tensors.shape[1] + c_shape = (m_topk, n) + c = torch.empty(c_shape, device=device, dtype=out_dtype) + torch.ops._C.cutlass_fp4_group_mm(c, a_tensors, b_tensors, a_scales, + b_scales, alphas, problem_sizes, + expert_offsets, sf_offsets) + return c.to(out_dtype) + + +# aqlm +def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor, + codebooks: torch.Tensor, scales: torch.Tensor, + codebook_partition_sizes: list[int], + bias: Optional[torch.Tensor]) -> torch.Tensor: + return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales, + codebook_partition_sizes, bias) + + +def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor, + codebook_partition_sizes: list[int]) -> torch.Tensor: + return torch.ops._C.aqlm_dequant(codes, codebooks, + codebook_partition_sizes) + + +# gptq_marlin +def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, + size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, + num_bits) + + +# gptq_marlin +def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) + + +def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, + size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + num_experts = b_q_weight.shape[0] + assert size_k % 16 == 0 + output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)), + device=b_q_weight.device, + dtype=b_q_weight.dtype) + for e in range(num_experts): + output[e] = torch.ops._C.gptq_marlin_repack(b_q_weight[e], perm[e], + size_k, size_n, num_bits) + return output + + +def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, + size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + num_experts = b_q_weight.shape[0] + assert size_k % 16 == 0 + output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)), + device=b_q_weight.device, + dtype=b_q_weight.dtype) + for e in range(num_experts): + output[e] = torch.ops._C.awq_marlin_repack(b_q_weight[e], size_k, + size_n, num_bits) + return output + + +def gptq_marlin_gemm(a: torch.Tensor, + c: Optional[torch.Tensor], + b_q_weight: torch.Tensor, + 
b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], + b_zeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool = True, + use_atomic_add: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: + return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_scales, + global_scale, b_zeros, g_idx, perm, + workspace, b_q_type.id, size_m, + size_n, size_k, is_k_full, + use_atomic_add, use_fp32_reduce, + is_zp_float) + + +# machete +def machete_supported_schedules( + a_type: torch.dtype, + b_type: ScalarType, + group_scales_type: Optional[torch.dtype], + group_zeros_type: Optional[torch.dtype] = None, + channel_scales_type: Optional[torch.dtype] = None, + token_scales_type: Optional[torch.dtype] = None, + out_type: Optional[torch.dtype] = None) -> list[str]: + return torch.ops._C.machete_supported_schedules( + a_type, b_type.id, group_scales_type, group_zeros_type, + channel_scales_type, token_scales_type, out_type) + + +def machete_mm( + a: torch.Tensor, + # b_q Should be the tensor returned by machete_prepack_B + b_q: torch.Tensor, + b_type: ScalarType, + out_type: Optional[torch.dtype] = None, + b_group_scales: Optional[torch.Tensor] = None, + b_group_zeros: Optional[torch.Tensor] = None, + b_group_size: Optional[int] = None, + b_channel_scales: Optional[torch.Tensor] = None, + a_token_scales: Optional[torch.Tensor] = None, + schedule: Optional[str] = None) -> torch.Tensor: + return torch.ops._C.machete_mm(a, b_q, b_type.id, out_type, b_group_scales, + b_group_zeros, b_group_size, + b_channel_scales, a_token_scales, schedule) + + +def machete_prepack_B( + b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType, + group_scales_type: Optional[torch.dtype]) -> torch.Tensor: + return torch.ops._C.machete_prepack_B(b_q_weight, a_type, b_type.id, + group_scales_type) + + +if hasattr(torch.ops._C, "permute_cols"): + + @register_fake("_C::permute_cols") + def _permute_cols_fake(a: torch.Tensor, + perm: torch.Tensor) -> torch.Tensor: + return torch.empty_like(a) + + +def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: + return torch.ops._C.permute_cols(a, perm) + + +# fp4 +def scaled_fp4_quant( + input: torch.Tensor, + input_global_scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to FP4 and return quantized tensor and scale. + + This function quantizes the last dimension of the given tensor `input`. For + every 16 consecutive elements, a single dynamically computed scaling factor + is shared. This scaling factor is quantized using the `input_global_scale` + and is stored in a swizzled layout (see + https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x). + + Args: + input: The input tensor to be quantized to FP4 + input_global_scale: A scalar scaling factor for the entire tensor. + + Returns: + tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4 but every + two values are packed into a uint8 and float8_e4m3 scaling factors + in the sizzled layout. 
+ """ + assert not current_platform.is_rocm() + assert input.ndim >= 1, ( + f'input.ndim needs to be >= 1, but got {input.ndim}.') + other_dims = 1 if input.ndim == 1 else -1 + input = input.reshape(other_dims, input.shape[-1]) + m, n = input.shape + block_size = 16 + device = input.device + + assert n % block_size == 0, ( + f'last dim has to be multiple of 16, but got {n}.') + assert input.dtype in (torch.float16, torch.bfloat16), ( + f'input.dtype needs to be fp16 or bf16 but got {input.dtype}.') + + # Two fp4 values will be packed into an uint8. + output = torch.empty((m, n // 2), device=device, dtype=torch.uint8) + + # We use the rounded values to store the swizzled values. Due to the + # requirement of the Tensor Core, the minimum tile is 128x4 for the scales. + # So, we first pad the scales to multiples of 128 and 4. Then, the scales + # (in float8_e4m3fn) are packed into an int32 for every 4 values. More: + # https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x + round_up = lambda x, y: (x + y - 1) // y * y + rounded_m = round_up(m, 128) + scale_n = n // block_size + rounded_n = round_up(scale_n, 4) + output_scale = torch.empty((rounded_m, rounded_n // 4), + device=device, + dtype=torch.int32) + + torch.ops._C.scaled_fp4_quant(output, input, output_scale, + input_global_scale) + output_scale = output_scale.view(torch.float8_e4m3fn) + return output, output_scale + + +def scaled_fp4_experts_quant( + input_tensor: torch.Tensor, + input_global_scale: torch.Tensor, + expert_offsets: torch.Tensor, + blockscale_offsets: torch.Tensor, + topk: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to FP4 and return quantized tensor and scale, for + packed MoE Inputs. + Args: + input_tensor: The input tensor to be quantized to FP4 + input_global_scale: A scalar scaling factor for the entire tensor. + expert_offsets: The expert offsets tensor + blockscale_offsets: The blockscale offsets tensor + Outputs: + output: The quantized tensor in FP4 + output_scales: The blockscale tensor in FP8-E4M3 + """ + assert not current_platform.is_rocm() + assert input_tensor.ndim == 2, ( + f'input.ndim needs to be == 2, but got {input_tensor.ndim}.') + + # Control the maximum number of tokens per expert supported by the + # NVFP4 MoE Expert Quantization. This is used to prevent the kernel + # from running out of memory. This value can also be increased to support + # larger models. + MAX_TOKENS_PER_EXPERT = envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE + m_numtopk, k = input_tensor.shape + + assert (m_numtopk <= MAX_TOKENS_PER_EXPERT * topk), ( + f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT(" + f"{MAX_TOKENS_PER_EXPERT})" + f" for cutlass_moe_fp4, observed m_numtopk = {m_numtopk}. 
Use" + f" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE to set this value.") + scales_k = k // 16 + padded_k = (scales_k + (4 - 1)) // 4 + + # output is uint8 and packed fp4 values + output = torch.empty(m_numtopk, + k // 2, + device=input_tensor.device, + dtype=torch.uint8) + output_scales = torch.empty(MAX_TOKENS_PER_EXPERT * topk, + padded_k, + dtype=torch.int32, + device=input_tensor.device) + torch.ops._C.scaled_fp4_experts_quant(output, output_scales, input_tensor, + input_global_scale, expert_offsets, + blockscale_offsets) + output_scales = output_scales.view(torch.float8_e4m3fn) + return output, output_scales + + +# fp8 +def scaled_fp8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + num_token_padding: Optional[int] = None, + scale_ub: Optional[torch.Tensor] = None, + use_per_token_if_dynamic: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to FP8 and return quantized tensor and scale. + + This function supports both static and dynamic quantization: If you + provide the scale, it will use static scaling and if you omit it, + the scale will be determined dynamically. The function also allows + optional padding of the output tensors for downstream kernels that + will benefit from padding. + + Args: + input: The input tensor to be quantized to FP8 + scale: Optional scaling factor for the FP8 quantization + scale_ub: Optional upper bound for scaling factor in dynamic + per token case + num_token_padding: If specified, pad the first dimension + of the output to at least this value. + use_per_token_if_dynamic: Whether to do per_tensor or per_token + in the dynamic quantization case. + + Returns: + tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + scaling factor. + """ + # This code assumes batch_dim and num_tokens are flattened + assert (input.ndim == 2) + shape: Union[tuple[int, int], torch.Size] = input.shape + # For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz + out_dtype: torch.dtype = current_platform.fp8_dtype() + if num_token_padding: + shape = (max(num_token_padding, input.shape[0]), shape[1]) + output = torch.empty(shape, device=input.device, dtype=out_dtype) + + if scale is None: + if use_per_token_if_dynamic: + scale = torch.empty((shape[0], 1), + device=input.device, + dtype=torch.float32) + torch.ops._C.dynamic_per_token_scaled_fp8_quant( + output, input, scale, scale_ub) + else: + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) + else: + # num_token_padding not implemented for this case + assert (scale.numel() == 1 or num_token_padding is None) + torch.ops._C.static_scaled_fp8_quant(output, input, scale) + + return output, scale + + +# gptq allspark +def allspark_repack_weight( + qweight: torch.Tensor, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor] = None, + has_zp: bool = False +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Rearrange qweight, scale, and zero_point(if asymmetric) to n32k16 format + for Ampere W8A16 Fused Gemm kernel + + Args: + qweight: uint8 weight tensor, original k x n format. + scale: fp16/bf16 weight scale tensor, 1 x n format. + zero_point: fp16/bf16 weight zero_point tensor, 1 x n format. + Must be provided for asymmetric quantization. + has_zp: if use symmetric quantization, has_zp = False. + if use asymmetric quantization, has_zp = True. 
+ + Returns: + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : + rearranged weight, scale, and optionally zero_point. + """ + K = qweight.shape[0] + N = qweight.shape[1] + N_32align = (N + 32 - 1) // 32 * 32 + + qweight_reorder = torch.empty((N_32align, K), + device=qweight.device, + dtype=qweight.dtype) + scale_reorder = torch.empty((1, N_32align), + device=scale.device, + dtype=scale.dtype) + zero_point_reorder = None + if has_zp: + assert zero_point is not None, ( + "zero_point must be provided for asymmetric quantization.") + zero_point_reorder = torch.empty((1, N_32align), + device=zero_point.device, + dtype=zero_point.dtype) + + torch.ops._C.rearrange_kn_weight_as_n32k16_order( + qweight, scale, zero_point, has_zp, qweight_reorder, scale_reorder, + zero_point_reorder, K, N, N_32align) + + return qweight_reorder, scale_reorder, zero_point_reorder + + +def allspark_w8a16_gemm(a: torch.Tensor, b_qweight: torch.Tensor, + b_scales: torch.Tensor, + b_qzeros: Optional[torch.Tensor], n: int, + group_size: int, sm_count: int, sm_version: int, + CUBLAS_M_THRESHOLD: int, has_zp: bool, + n32k16_reorder: bool) -> torch.Tensor: + + return torch.ops._C.allspark_w8a16_gemm(a, b_qweight, b_scales, b_qzeros, + n, group_size, sm_count, + sm_version, CUBLAS_M_THRESHOLD, + has_zp, n32k16_reorder) + + +# int8 +def scaled_int8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + azp: Optional[torch.Tensor] = None, + symmetric: bool = True +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. + + Args: + input: The input tensor to be quantized to int8. + scale: Optional scaling factor for the int8 quantization. + When not provided, we invoke dynamic-per-token quantization. + azp: Optional zero-point for the int8 quantization. + Must be provided for asymmetric quantization if `scale` is provided. + symmetric: Whether to use symmetric quantization (scale only, azp ignored). + + Returns: + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + """ + output = torch.empty_like(input, dtype=torch.int8) + if scale is not None: + # static-per-tensor quantization. + assert symmetric == ( + azp + is None), "azp must only be provided for asymmetric quantization." + torch.ops._C.static_scaled_int8_quant(output, input, scale, azp) + return output, scale, azp + + # dynamic-per-token quantization. + input_scales = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + input_azp = None if symmetric else torch.empty_like(input_scales, + dtype=torch.int32) + torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales, + input_azp) + return output, input_scales, input_azp + +def scaled_int8_quant_mask( + input: torch.Tensor, + mask: torch.Tensor, + scale: Optional[torch.Tensor] = None, + azp: Optional[torch.Tensor] = None, + symmetric: bool = True +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. + + Args: + input: The input tensor to be quantized to int8. + scale: Optional scaling factor for the int8 quantization. + When not provided, we invoke dynamic-per-token quantization. + azp: Optional zero-point for the int8 quantization. + Must be provided for asymmetric quantization if `scale` is provided. 
+ mask: mask + symmetric: Whether to use symmetric quantization (scale only, azp ignored). + + Returns: + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + """ + output = torch.empty_like(input, dtype=torch.int8) + if scale is not None: + # static-per-tensor quantization. + assert symmetric == ( + azp + is None), "azp must only be provided for asymmetric quantization." + torch.ops._C.static_scaled_int8_quant(output, input, scale, azp) + return output, scale, azp + + # dynamic-per-token quantization. + input_scales = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + input_azp = None if symmetric else torch.empty_like(input_scales, + dtype=torch.int32) + torch.ops._C.dynamic_scaled_int8_mask_quant(output, input, mask, input_scales, + input_azp) + return output, input_scales, input_azp + +def fused_silu_mul_dq_mask_quant( + input: torch.Tensor, + mask: torch.Tensor +) -> torch.Tensor: + """ + input shape [expert_num, token_num_padded, hidden_dim] + output shape [expert_num, token_num_padded, hidden_dim // 2], dtype bf16 + masked_m shape [expert_num], indicates valid tokens per expert + + implement silu_and_mul + quant + package + """ + out_stride = (input.shape[-1] // 4 + 257) // 256 * 256 + output = torch.empty((input.shape[0], input.shape[1], out_stride), device=input.device, dtype=input.dtype) + torch.ops._C.fused_silu_mul_dq_mask_quant_pack(output, input, mask) + return output + +# qqq ops +def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + s_tok: torch.Tensor, s_ch: torch.Tensor, + s_group: torch.Tensor, workspace: torch.Tensor, + size_m: int, size_n: int, size_k: int) -> torch.Tensor: + return torch.ops._C.marlin_qqq_gemm(a, b_q_weight, s_tok, s_ch, s_group, + workspace, size_m, size_n, size_k) + + +# gguf +def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, n: int, + dtype: Optional[torch.dtype]) -> torch.Tensor: + return torch.ops._C.ggml_dequantize(W, quant_type, m, n, dtype) + + +def ggml_mul_mat_vec_a8( + W: torch.Tensor, + X: torch.Tensor, + quant_type: int, + row: int, +) -> torch.Tensor: + return torch.ops._C.ggml_mul_mat_vec_a8(W, X, quant_type, row) + + +def ggml_mul_mat_a8( + W: torch.Tensor, + X: torch.Tensor, + quant_type: int, + row: int, +) -> torch.Tensor: + return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row) + + +def ggml_moe_a8( + X: torch.Tensor, + W: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + quant_type: int, + row: int, + top_k: int, + tokens: int, +) -> torch.Tensor: + return torch.ops._C.ggml_moe_a8(X, W, sorted_token_ids, expert_ids, + num_tokens_post_padded, quant_type, row, + top_k, tokens) + + +def ggml_moe_a8_vec( + X: torch.Tensor, + W: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + quant_type: int, + row: torch.SymInt, + tokens: torch.SymInt, +) -> torch.Tensor: + return torch.ops._C.ggml_moe_a8_vec(X, W, topk_ids, top_k, quant_type, row, + tokens) + + +def ggml_moe_get_block_size(quant_type: int) -> int: + return torch.ops._C.ggml_moe_get_block_size(quant_type) + + +# mamba +def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor, + bias_: Optional[torch.Tensor], + conv_states: Optional[torch.Tensor], + query_start_loc: Optional[torch.Tensor], + cache_indices: Optional[torch.Tensor], + has_initial_state: Optional[torch.Tensor], + silu_activation: bool, pad_slot_id: int): + torch.ops._C.causal_conv1d_fwd(x, weight, bias_, 
conv_states, + query_start_loc, cache_indices, + has_initial_state, silu_activation, + pad_slot_id) + + +def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor, + weight: torch.Tensor, bias_: Optional[torch.Tensor], + silu_activation: bool, + cache_seqlens: Optional[torch.Tensor], + conv_state_indices: Optional[torch.Tensor], + pad_slot_id: int): + torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_, + silu_activation, cache_seqlens, + conv_state_indices, pad_slot_id) + + +def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, + B: torch.Tensor, C: torch.Tensor, + D_: Optional[torch.Tensor], z_: Optional[torch.Tensor], + delta_bias_: Optional[torch.Tensor], + delta_softplus: bool, + query_start_loc: Optional[torch.Tensor], + cache_indices: Optional[torch.Tensor], + has_initial_state: Optional[torch.Tensor], + ssm_states: torch.Tensor, pad_slot_id: int): + torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_, delta_bias_, + delta_softplus, query_start_loc, + cache_indices, has_initial_state, + ssm_states, pad_slot_id) + + +# ROCm skinny gemms +def LLMM1(a: torch.Tensor, b: torch.Tensor, + rows_per_block: int) -> torch.Tensor: + return torch.ops._rocm_C.LLMM1(a, b, rows_per_block) + + +def wvSplitK(a: torch.Tensor, b: torch.Tensor, cu_count: int) -> torch.Tensor: + return torch.ops._rocm_C.wvSplitK(a, b, cu_count) + + +def wvSplitKQ(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype, + scale_a: torch.Tensor, scale_b: torch.Tensor, + cu_count: int) -> torch.Tensor: + out = torch.empty((b.shape[0], a.shape[0]), + dtype=out_dtype, + device=b.device) + torch.ops._rocm_C.wvSplitKQ(a, b, out, scale_a, scale_b, cu_count) + return out + + +# moe +def moe_sum(input: torch.Tensor, output: torch.Tensor): + torch.ops._moe_C.moe_sum(input, output) + + +def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, + block_size: int, sorted_token_ids: torch.Tensor, + experts_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor) -> None: + torch.ops._moe_C.moe_align_block_size(topk_ids, num_experts, block_size, + sorted_token_ids, experts_ids, + num_tokens_post_pad) + + +def sgl_moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, + block_size: int, sorted_token_ids: torch.Tensor, + experts_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor) -> None: + torch.ops._moe_C.sgl_moe_align_block_size(topk_ids, num_experts, + block_size, sorted_token_ids, + experts_ids, num_tokens_post_pad) + + +def moe_wna16_gemm(input: torch.Tensor, output: torch.Tensor, + b_qweight: torch.Tensor, b_scales: torch.Tensor, + b_qzeros: Optional[torch.Tensor], + topk_weights: Optional[torch.Tensor], + sorted_token_ids: torch.Tensor, experts_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, top_k: int, + BLOCK_SIZE_M: int, BLOCK_SIZE_N: int, BLOCK_SIZE_K: int, + bit: int) -> torch.Tensor: + if not current_platform.is_cuda(): + raise NotImplementedError( + "The optimized moe_wna16_gemm kernel is only " + "available on CUDA platforms") + torch.ops._moe_C.moe_wna16_gemm(input, output, b_qweight, b_scales, + b_qzeros, topk_weights, sorted_token_ids, + experts_ids, num_tokens_post_pad, top_k, + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, + bit) + + +def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, + token_expert_indicies: torch.Tensor, + gating_output: torch.Tensor) -> None: + torch.ops._moe_C.topk_softmax(topk_weights, topk_ids, + token_expert_indicies, gating_output) + +def fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: 
torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, mul_routed_weight: bool, top_k: int, tileConfig: int) -> None: + torch.ops._moe_C.fused_moe_kernel(A, B, C, + topk_weights, topk_ids, + sorted_token_ids, expert_ids, + num_tokens_post_padded, mul_routed_weight, top_k, tileConfig) + +def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor], + b_qweight: torch.Tensor, b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], + b_qzeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_past_padded: torch.Tensor, + topk_weights: torch.Tensor, moe_block_size: int, + top_k: int, mul_topk_weights: bool, is_ep: bool, + b_q_type: ScalarType, size_m: int, size_n: int, + size_k: int, is_k_full: bool, use_atomic_add: bool, + use_fp32_reduce: bool, + is_zp_float: bool) -> torch.Tensor: + return torch.ops._moe_C.moe_wna16_marlin_gemm( + input, output, b_qweight, b_scales, global_scale, b_qzeros, g_idx, + perm, workspace, sorted_token_ids, expert_ids, num_tokens_past_padded, + topk_weights, moe_block_size, top_k, mul_topk_weights, is_ep, + b_q_type.id, size_m, size_n, size_k, is_k_full, use_atomic_add, + use_fp32_reduce, is_zp_float) + + +if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): + + @register_fake("_moe_C::marlin_gemm_moe") + def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor, + sorted_ids: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, b_scales: torch.Tensor, + b_zero_points: torch.Tensor, g_idx: torch.Tensor, + perm: torch.Tensor, workspace: torch.Tensor, + b_q_type: ScalarType, size_m: torch.SymInt, + size_n: torch.SymInt, size_k: torch.SymInt, + is_k_full: bool, num_experts: int, topk: int, + moe_block_size: int, replicate_input: bool, + apply_weights: bool) -> torch.Tensor: + return torch.empty((size_m, topk, size_n), + dtype=a.dtype, + device=a.device) + + @register_fake("_moe_C::moe_wna16_marlin_gemm") + def moe_wna16_marlin_gemm_fake(input: torch.Tensor, + output: Optional[torch.Tensor], + b_qweight: torch.Tensor, + b_scales: torch.Tensor, + b_qzeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_past_padded: torch.Tensor, + topk_weights: torch.Tensor, + moe_block_size: int, top_k: int, + mul_topk_weights: bool, is_ep: bool, + b_q_type: ScalarType, size_m: int, + size_n: int, size_k: int, is_k_full: bool, + use_atomic_add: bool, use_fp32_reduce: bool, + is_zp_float: bool) -> torch.Tensor: + return torch.empty((size_m * top_k, size_n), + dtype=input.dtype, + device=input.device) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache, + value_cache, slot_mapping, + kv_cache_dtype, k_scale, v_scale) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + 
torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache, + value_cache, slot_mapping, + kv_cache_dtype, k_scale, + v_scale) + + +def concat_and_cache_mla( + kv_c: torch.Tensor, + k_pe: torch.Tensor, + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + scale: torch.Tensor, +) -> None: + torch.ops._C_cache_ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, + slot_mapping, kv_cache_dtype, + scale) + + +def copy_blocks(key_caches: list[torch.Tensor], + value_caches: list[torch.Tensor], + block_mapping: torch.Tensor) -> None: + torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def copy_blocks_mla(kv_caches: list[torch.Tensor], + block_mapping: torch.Tensor) -> None: + torch.ops._C_cache_ops.copy_blocks_mla(kv_caches, block_mapping) + + +def swap_blocks(src: torch.Tensor, dst: torch.Tensor, + block_mapping: torch.Tensor) -> None: + torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8(output: torch.Tensor, + input: torch.Tensor, + scale: float = 1.0, + kv_dtype: str = "fp8") -> None: + torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype) + + +def gather_cache(src_cache: torch.Tensor, + dst: torch.Tensor, + block_table: torch.Tensor, + cu_seq_lens: torch.Tensor, + batch_size: int, + seq_starts: Optional[torch.Tensor] = None) -> None: + torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table, + cu_seq_lens, batch_size, seq_starts) + + +def get_device_attribute(attribute: int, device: int) -> int: + return torch.ops._C_cuda_utils.get_device_attribute(attribute, device) + + +def get_max_shared_memory_per_block_device_attribute(device: int) -> int: + # ruff: noqa: E501 + return torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute( + device) + + +# custom ar +def init_custom_ar(ipc_tensors: list[torch.Tensor], rank_data: torch.Tensor, + rank: int, fully_connected: bool) -> int: + return torch.ops._C_custom_ar.init_custom_ar(ipc_tensors, rank_data, rank, + fully_connected) + + +def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int, + reg_buffer_sz_bytes: int) -> None: + torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer, + reg_buffer_sz_bytes) + + +def dispose(fa: int) -> None: + torch.ops._C_custom_ar.dispose(fa) + + +def meta_size() -> int: + return torch.ops._C_custom_ar.meta_size() + + +def register_buffer(fa: int, ipc_tensors: list[int]) -> None: + return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors) + + +def get_graph_buffer_ipc_meta(fa: int) -> tuple[list[int], list[int]]: + return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa) + + +def register_graph_buffers(fa: int, handles: list[list[int]], + offsets: list[list[int]]) -> None: + torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets) + + +def allocate_shared_buffer_and_handle(size: int) -> tuple[int, torch.Tensor]: + return torch.ops._C_custom_ar.allocate_shared_buffer_and_handle(size) + + +def open_mem_handle(mem_handle: torch.Tensor): + return torch.ops._C_custom_ar.open_mem_handle(mem_handle) + + +def free_shared_buffer(ptr: int) -> None: + torch.ops._C_custom_ar.free_shared_buffer(ptr) + + +def get_flash_mla_metadata( + cache_seqlens: torch.Tensor, + num_heads_per_head_k: int, + num_heads_k: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + cache_seqlens: (batch_size), dtype torch.int32. + num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k. + num_heads_k: num_heads_k. 
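+            (i.e., the number of key/value heads in the KV cache).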
+ + Return: + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. + num_splits: (batch_size + 1), dtype torch.int32. + """ + return torch.ops._C.get_flash_mla_metadata(cache_seqlens, + num_heads_per_head_k, + num_heads_k) + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + head_dim_v: int, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, + softmax_scale: Optional[float] = None, + causal: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + q: (batch_size, seq_len_q, num_heads_q, head_dim). + k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + block_table: (batch_size, max_num_blocks_per_seq), torch.int32. + cache_seqlens: (batch_size), torch.int32. + head_dim_v: Head_dim of v. + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, return by get_mla_metadata. + num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata. + softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim). + causal: bool. Whether to apply causal attention mask. + + Return: + out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. + """ + if softmax_scale is None: + softmax_scale = q.shape[-1]**(-0.5) + out, softmax_lse = torch.ops._C.flash_mla_fwd_kvcache( + q, + k_cache, + None, + head_dim_v, + cache_seqlens, + block_table, + softmax_scale, + causal, + tile_scheduler_metadata, + num_splits, + ) + return out, softmax_lse + + +def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor, + q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, + seq_lens: torch.Tensor, page_table: torch.Tensor, + scale: float) -> torch.Tensor: + torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache, + seq_lens, page_table, scale) + return out diff --git a/_ipex_ops.py b/_ipex_ops.py new file mode 100644 index 0000000..ae63e06 --- /dev/null +++ b/_ipex_ops.py @@ -0,0 +1,244 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +try: + import intel_extension_for_pytorch as ipex +except ImportError as e: + logger.warning("Import error msg: %s", e.msg) + + +class ipex_ops: + + @staticmethod + def _reshape_activation_tensor( + x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + num = x.size(0) + d = x.size(1) // 2 + x = x.reshape(num, 2, d) + x1, x2 = torch.chunk(x, chunks=2, dim=1) + x1 = x1.reshape(num, d) + x2 = x2.reshape(num, d) + return x1, x2 + + @staticmethod + def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + ipex.llm.functional.silu_and_mul(x, out) + + @staticmethod + def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + ipex.llm.functional.gelu_and_mul(x, out) + + @staticmethod + def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + ipex.llm.functional.gelu_and_mul(x, out) + + @staticmethod + def gelu_fast(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x) + + @staticmethod + def gelu_new(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x) + + @staticmethod + def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None: + ipex.llm.functional.gelu_quick(x, out) + + @staticmethod + def paged_attention_v1( + out: torch.Tensor, + query: 
torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + block_size: int, + max_context_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, + ) -> None: + assert kv_cache_dtype == "auto" + num_heads = out.size(1) + num_queries_per_tokens = num_heads // num_kv_heads + ipex.llm.modules.PagedAttention.single_query_kv_attention( + out, + query.contiguous(), + key_cache.view_as(value_cache), + value_cache, + num_queries_per_tokens, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + ) + + @staticmethod + def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + block_size: int, + max_context_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, + ) -> None: + assert kv_cache_dtype == "auto" + num_heads = out.size(1) + num_queries_per_tokens = num_heads // num_kv_heads + ipex.llm.modules.PagedAttention.single_query_kv_attention( + out, + query.contiguous(), + key_cache.view_as(value_cache), + value_cache, + num_queries_per_tokens, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + ) + + @staticmethod + def rotary_embedding( + positions: torch.Tensor, # [batch_size, seq_len] + query: torch.Tensor, # [batch_size, seq_len, num_heads*head_size] + key: torch.Tensor, # [batch_size, seq_len, num_kv_heads*head_size] + head_size: int, + cos_sin_cache: torch.Tensor, # [cos_sin_dim, rot_dim] + is_neox: bool, + ) -> None: + rot_dim = cos_sin_cache.size(1) + ipex.llm.functional.rotary_embedding_batched(positions, query, key, + head_size, cos_sin_cache, + is_neox, rot_dim) + + @staticmethod + def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, + key: torch.Tensor, head_size: int, + cos_sin_cache: torch.Tensor, is_neox: bool, + rot_dim: int, + cos_sin_cache_offsets: torch.Tensor) -> None: + ipex.llm.functional.rotary_embedding_batched(positions, query, key, + head_size, cos_sin_cache, + is_neox, rot_dim, + cos_sin_cache_offsets) + + @staticmethod + def rms_norm(input: torch.Tensor, weight: torch.Tensor, + epsilon: float) -> torch.Tensor: + return ipex.llm.functional.rms_norm(input, weight, epsilon) + + @staticmethod + def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, + weight: torch.Tensor, epsilon: float) -> None: + tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None, + epsilon, True) + input.copy_(tmp) + + @staticmethod + def varlen_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + seqlen_q: torch.Tensor, + seqlen_k: torch.Tensor, + alibi_slopes: Optional[torch.Tensor], + max_seqlen_q: int, + max_seqlen_k: int, + pdropout: float, + softmax_scale: float, + zero_tensors: bool, + is_causal: bool, + return_softmax: bool, + gen_: 
torch.Generator, + window_size_left: float, + window_size_right: float, + logits_soft_cap: float, + ) -> None: + if ipex.__version__.endswith("cpu"): + if logits_soft_cap != 0.0: + raise ValueError("IPEX CPU does not support logits_soft_cap") + assert alibi_slopes is None + assert window_size_left < 0 and window_size_right < 0 + ipex.llm.functional.varlen_attention(query.contiguous(), + key.contiguous(), + value.contiguous(), out, + seqlen_q.int(), + seqlen_k.int(), max_seqlen_q, + max_seqlen_k, pdropout, + softmax_scale, zero_tensors, + is_causal, return_softmax, + gen_) + else: # XPU build + ipex.llm.functional.varlen_attention( + query.contiguous(), key.contiguous(), value.contiguous(), out, + seqlen_q.int(), seqlen_k.int(), alibi_slopes, max_seqlen_q, + max_seqlen_k, pdropout, softmax_scale, zero_tensors, is_causal, + return_softmax, gen_, window_size_left, window_size_right, + logits_soft_cap) + + @staticmethod + def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + ) -> None: + assert kv_cache_dtype == "auto" + ipex.llm.modules.PagedAttention.reshape_and_cache( + key, value, key_cache, value_cache, slot_mapping) + + @staticmethod + def copy_blocks(key_caches: list[torch.Tensor], + value_caches: list[torch.Tensor], + block_mapping: torch.Tensor) -> None: + torch.xpu.copy_blocks( # type: ignore + key_caches, + value_caches, + block_mapping, + ) + + @staticmethod + def swap_blocks(src: torch.Tensor, dst: torch.Tensor, + block_mapping: torch.Tensor) -> None: + torch.xpu.swap_blocks(src, dst, block_mapping) # type: ignore diff --git a/_moe_C.abi3.so b/_moe_C.abi3.so new file mode 100755 index 0000000..71b28da Binary files /dev/null and b/_moe_C.abi3.so differ diff --git a/_release_info.txt b/_release_info.txt new file mode 100644 index 0000000..83dab04 --- /dev/null +++ b/_release_info.txt @@ -0,0 +1 @@ +2af1594b7dae2bd6ae835f20884d0847820aa27f \ No newline at end of file diff --git a/adapter_commons/__init__.py b/adapter_commons/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/adapter_commons/layers.py b/adapter_commons/layers.py new file mode 100644 index 0000000..9753a08 --- /dev/null +++ b/adapter_commons/layers.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass + + +@dataclass +class AdapterMapping: + # Per every token in input_ids: + index_mapping: tuple[int, ...] + # Per sampled token: + prompt_mapping: tuple[int, ...] 
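+    # Illustrative example (hypothetical values): a 3-token prompt served by
+    # adapter id 7 that samples one token would give
+    # index_mapping=(7, 7, 7) and prompt_mapping=(7,).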
+ + def __post_init__(self): + self.index_mapping = tuple(self.index_mapping) + self.prompt_mapping = tuple(self.prompt_mapping) \ No newline at end of file diff --git a/adapter_commons/models.py b/adapter_commons/models.py new file mode 100644 index 0000000..7b68588 --- /dev/null +++ b/adapter_commons/models.py @@ -0,0 +1,106 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod +from typing import Any, Callable, Optional, TypeVar + +from torch import nn + +from vllm.logger import init_logger +from vllm.utils import LRUCache + +logger = init_logger(__name__) + + +class AdapterModel(ABC): + + def __init__(self, model_id=None): + self.id = model_id + + @abstractmethod + def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs): + # Common initialization code + # Load weights or embeddings from local checkpoint + raise NotImplementedError("Subclasses must implement this method.") + + +T = TypeVar('T') + + +class AdapterLRUCache(LRUCache[int, T]): + + def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]): + super().__init__(capacity) + self.deactivate_fn = deactivate_fn + + def _on_remove(self, key: int, value: Optional[T]): + logger.debug("Removing adapter int id: %d", key) + self.deactivate_fn(key) + return super()._on_remove(key, value) + + +class AdapterModelManager(ABC): + + def __init__( + self, + model: nn.Module, + ): + """Create a AdapterModelManager and adapter for a given model. + Args: + model: the model to be adapted. + """ + self.model: nn.Module = model + self._registered_adapters: dict[int, Any] = {} + # Dict instead of a Set for compatibility with LRUCache. + self._active_adapters: dict[int, None] = {} + self.adapter_type = 'Adapter' + self._last_mapping = None + + def __len__(self) -> int: + return len(self._registered_adapters) + + @property + @abstractmethod + def adapter_slots(self) -> int: + raise NotImplementedError + + @property + @abstractmethod + def capacity(self) -> int: + raise NotImplementedError + + @abstractmethod + def activate_adapter(self, adapter_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def deactivate_adapter(self, adapter_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def add_adapter(self, adapter: Any) -> bool: + raise NotImplementedError + + @abstractmethod + def set_adapter_mapping(self, mapping: Any) -> None: + raise NotImplementedError + + @abstractmethod + def remove_adapter(self, adapter_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def remove_all_adapters(self) -> None: + raise NotImplementedError + + @abstractmethod + def get_adapter(self, adapter_id: int) -> Optional[Any]: + raise NotImplementedError + + @abstractmethod + def list_adapters(self) -> dict[int, Any]: + raise NotImplementedError + + @abstractmethod + def pin_adapter(self, adapter_id: int) -> bool: + raise NotImplementedError diff --git a/adapter_commons/request.py b/adapter_commons/request.py new file mode 100644 index 0000000..8135b54 --- /dev/null +++ b/adapter_commons/request.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod + + +class AdapterRequest(ABC): + """ + Base class for adapter requests. 
+ """ + + @property + @abstractmethod + def adapter_id(self) -> int: + raise NotImplementedError + + def __post_init__(self) -> None: + if self.adapter_id < 1: + raise ValueError(f"id must be > 0, got {self.adapter_id}") + + def __eq__(self, value: object) -> bool: + return isinstance( + value, self.__class__) and self.adapter_id == value.adapter_id + + def __hash__(self) -> int: + return hash(self.adapter_id) diff --git a/adapter_commons/utils.py b/adapter_commons/utils.py new file mode 100644 index 0000000..a1a56b6 --- /dev/null +++ b/adapter_commons/utils.py @@ -0,0 +1,93 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Callable, Optional + + +## model functions +def deactivate_adapter(adapter_id: int, active_adapters: dict[int, None], + deactivate_func: Callable) -> bool: + if adapter_id in active_adapters: + deactivate_func(adapter_id) + active_adapters.pop(adapter_id) + return True + return False + + +def add_adapter(adapter: Any, registered_adapters: dict[int, Any], + capacity: int, add_func: Callable) -> bool: + if adapter.id not in registered_adapters: + if len(registered_adapters) >= capacity: + raise RuntimeError('No free adapter slots.') + add_func(adapter) + registered_adapters[adapter.id] = adapter + return True + return False + + +def set_adapter_mapping(mapping: Any, last_mapping: Any, + set_mapping_func: Callable) -> Any: + if last_mapping != mapping: + set_mapping_func(mapping) + return mapping + return last_mapping + + +def remove_adapter(adapter_id: int, registered_adapters: dict[int, Any], + deactivate_func: Callable) -> bool: + deactivate_func(adapter_id) + return bool(registered_adapters.pop(adapter_id, None)) + + +def list_adapters(registered_adapters: dict[int, Any]) -> dict[int, Any]: + return dict(registered_adapters) + + +def get_adapter(adapter_id: int, + registered_adapters: dict[int, Any]) -> Optional[Any]: + return registered_adapters.get(adapter_id) + + +## worker functions +def set_active_adapters_worker(requests: set[Any], mapping: Optional[Any], + apply_adapters_func, + set_adapter_mapping_func) -> None: + apply_adapters_func(requests) + set_adapter_mapping_func(mapping) + + +def add_adapter_worker(adapter_request: Any, list_adapters_func, + load_adapter_func, add_adapter_func, + activate_adapter_func) -> bool: + if adapter_request.adapter_id in list_adapters_func(): + return False + loaded_adapter = load_adapter_func(adapter_request) + loaded = add_adapter_func(loaded_adapter) + activate_adapter_func(loaded_adapter.id) + return loaded + + +def apply_adapters_worker(adapter_requests: set[Any], list_adapters_func, + adapter_slots: int, remove_adapter_func, + add_adapter_func) -> None: + models_that_exist = list_adapters_func() + models_map = { + adapter_request.adapter_id: adapter_request + for adapter_request in adapter_requests if adapter_request + } + if len(models_map) > adapter_slots: + raise RuntimeError( + f"Number of requested models ({len(models_map)}) is greater " + f"than the number of GPU model slots " + f"({adapter_slots}).") + new_models = set(models_map) + models_to_add = new_models - models_that_exist + models_to_remove = models_that_exist - new_models + for adapter_id in models_to_remove: + remove_adapter_func(adapter_id) + for adapter_id in models_to_add: + add_adapter_func(models_map[adapter_id]) + + +def list_adapters_worker(adapter_manager_list_adapters_func) -> set[int]: + return set(adapter_manager_list_adapters_func()) diff --git 
a/adapter_commons/worker_manager.py b/adapter_commons/worker_manager.py new file mode 100644 index 0000000..07e85d1 --- /dev/null +++ b/adapter_commons/worker_manager.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod +from typing import Any, Optional + +import torch + + +class AbstractWorkerManager(ABC): + + def __init__(self, device: torch.device): + self.device = device + + @property + @abstractmethod + def is_enabled(self) -> bool: + raise NotImplementedError + + @abstractmethod + def set_active_adapters(self, requests: set[Any], + mapping: Optional[Any]) -> None: + raise NotImplementedError + + @abstractmethod + def add_adapter(self, adapter_request: Any) -> bool: + raise NotImplementedError + + @abstractmethod + def remove_adapter(self, adapter_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def remove_all_adapters(self) -> None: + raise NotImplementedError + + @abstractmethod + def list_adapters(self) -> set[int]: + raise NotImplementedError diff --git a/assets/__init__.py b/assets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/assets/audio.py b/assets/audio.py new file mode 100644 index 0000000..1c16230 --- /dev/null +++ b/assets/audio.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from pathlib import Path +from typing import Literal +from urllib.parse import urljoin + +import numpy.typing as npt + +from vllm.utils import PlaceholderModule + +from .base import VLLM_S3_BUCKET_URL, get_vllm_public_assets + +try: + import librosa +except ImportError: + librosa = PlaceholderModule("librosa") # type: ignore[assignment] + +ASSET_DIR = "multimodal_asset" + +AudioAssetName = Literal["winning_call", "mary_had_lamb"] + + +@dataclass(frozen=True) +class AudioAsset: + name: AudioAssetName + + @property + def filename(self) -> str: + return f"{self.name}.ogg" + + @property + def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]: + audio_path = get_vllm_public_assets(filename=self.filename, + s3_prefix=ASSET_DIR) + return librosa.load(audio_path, sr=None) + + def get_local_path(self) -> Path: + return get_vllm_public_assets(filename=self.filename, + s3_prefix=ASSET_DIR) + + @property + def url(self) -> str: + return urljoin(VLLM_S3_BUCKET_URL, f"{ASSET_DIR}/{self.name}.ogg") diff --git a/assets/base.py b/assets/base.py new file mode 100644 index 0000000..31cde43 --- /dev/null +++ b/assets/base.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from functools import lru_cache +from pathlib import Path +from typing import Optional + +import vllm.envs as envs +from vllm.connections import global_http_connection + +VLLM_S3_BUCKET_URL = "https://vllm-public-assets.s3.us-west-2.amazonaws.com" + + +def get_cache_dir() -> Path: + """Get the path to the cache for storing downloaded assets.""" + path = Path(envs.VLLM_ASSETS_CACHE) + path.mkdir(parents=True, exist_ok=True) + + return path + + +@lru_cache +def get_vllm_public_assets(filename: str, + s3_prefix: Optional[str] = None) -> Path: + """ + Download an asset file from ``s3://vllm-public-assets`` + and return the path to the downloaded file. 
+ """ + asset_directory = get_cache_dir() / "vllm_public_assets" + asset_directory.mkdir(parents=True, exist_ok=True) + + asset_path = asset_directory / filename + if not asset_path.exists(): + if s3_prefix is not None: + filename = s3_prefix + "/" + filename + global_http_connection.download_file( + f"{VLLM_S3_BUCKET_URL}/{filename}", + asset_path, + timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT) + + return asset_path diff --git a/assets/image.py b/assets/image.py new file mode 100644 index 0000000..c977242 --- /dev/null +++ b/assets/image.py @@ -0,0 +1,34 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from typing import Literal + +import torch +from PIL import Image + +from .base import get_vllm_public_assets + +VLM_IMAGES_DIR = "vision_model_images" + +ImageAssetName = Literal["stop_sign", "cherry_blossom"] + + +@dataclass(frozen=True) +class ImageAsset: + name: ImageAssetName + + @property + def pil_image(self) -> Image.Image: + image_path = get_vllm_public_assets(filename=f"{self.name}.jpg", + s3_prefix=VLM_IMAGES_DIR) + return Image.open(image_path) + + @property + def image_embeds(self) -> torch.Tensor: + """ + Image embeddings, only used for testing purposes with llava 1.5. + """ + image_path = get_vllm_public_assets(filename=f"{self.name}.pt", + s3_prefix=VLM_IMAGES_DIR) + return torch.load(image_path, map_location="cpu", weights_only=True) diff --git a/assets/video.py b/assets/video.py new file mode 100644 index 0000000..01834ae --- /dev/null +++ b/assets/video.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from functools import lru_cache +from typing import ClassVar, Literal, Optional + +import cv2 +import numpy as np +import numpy.typing as npt +from huggingface_hub import hf_hub_download +from PIL import Image + +from vllm.utils import PlaceholderModule + +from .base import get_cache_dir + +try: + import librosa +except ImportError: + librosa = PlaceholderModule("librosa") # type: ignore[assignment] + + +@lru_cache +def download_video_asset(filename: str) -> str: + """ + Download and open an image from huggingface + repo: raushan-testing-hf/videos-test + """ + video_directory = get_cache_dir() / "video-example-data" + video_directory.mkdir(parents=True, exist_ok=True) + + video_path = video_directory / filename + video_path_str = str(video_path) + if not video_path.exists(): + video_path_str = hf_hub_download( + repo_id="raushan-testing-hf/videos-test", + filename=filename, + repo_type="dataset", + cache_dir=video_directory, + ) + return video_path_str + + +def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray: + cap = cv2.VideoCapture(path) + if not cap.isOpened(): + raise ValueError(f"Could not open video file {path}") + + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + frames = [] + + num_frames = num_frames if num_frames > 0 else total_frames + frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int) + for idx in range(total_frames): + ok = cap.grab() # next img + if not ok: + break + if idx in frame_indices: # only decompress needed + ret, frame = cap.retrieve() + if ret: + frames.append(frame) + + frames = np.stack(frames) + if len(frames) < num_frames: + raise ValueError(f"Could not read enough frames from video file {path}" + f" (expected {num_frames} frames, got {len(frames)})") + return frames + + +def 
video_to_pil_images_list(path: str, + num_frames: int = -1) -> list[Image.Image]: + frames = video_to_ndarrays(path, num_frames) + return [ + Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + for frame in frames + ] + + +VideoAssetName = Literal["baby_reading"] + + +@dataclass(frozen=True) +class VideoAsset: + name: VideoAssetName + num_frames: int = -1 + + _NAME_TO_FILE: ClassVar[dict[VideoAssetName, str]] = { + "baby_reading": "sample_demo_1.mp4", + } + + @property + def filename(self) -> str: + return self._NAME_TO_FILE[self.name] + + @property + def pil_images(self) -> list[Image.Image]: + video_path = download_video_asset(self.filename) + ret = video_to_pil_images_list(video_path, self.num_frames) + return ret + + @property + def np_ndarrays(self) -> npt.NDArray: + video_path = download_video_asset(self.filename) + ret = video_to_ndarrays(video_path, self.num_frames) + return ret + + def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray: + """ + Read audio data from the video asset, used in Qwen2.5-Omni examples. + + See also: examples/offline_inference/qwen2_5_omni/only_thinker.py + """ + video_path = download_video_asset(self.filename) + return librosa.load(video_path, sr=sampling_rate)[0] diff --git a/attention/__init__.py b/attention/__init__.py new file mode 100644 index 0000000..3440405 --- /dev/null +++ b/attention/__init__.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionState, AttentionType) +from vllm.attention.layer import Attention +from vllm.attention.selector import get_attn_backend + +__all__ = [ + "Attention", + "AttentionBackend", + "AttentionMetadata", + "AttentionType", + "AttentionMetadataBuilder", + "Attention", + "AttentionState", + "get_attn_backend", +] diff --git a/attention/backends/__init__.py b/attention/backends/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/attention/backends/abstract.py b/attention/backends/abstract.py new file mode 100644 index 0000000..0ba5a5b --- /dev/null +++ b/attention/backends/abstract.py @@ -0,0 +1,308 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod +from contextlib import contextmanager +from dataclasses import dataclass, fields +from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, + Protocol, Set, Tuple, Type, TypeVar) + +import torch + +from vllm.multimodal import MultiModalPlaceholderMap + +if TYPE_CHECKING: + from vllm.worker.model_runner_base import (ModelRunnerBase, + ModelRunnerInputBase, + ModelRunnerInputBuilderBase) + + +class AttentionType: + """ + Attention type. + Use string to be compatible with `torch.compile`. + """ + # Decoder attention between previous layer Q/K/V + DECODER = "decoder" + # Encoder attention between previous layer Q/K/V for encoder-decoder + ENCODER = "encoder" + # Encoder attention between previous layer Q/K/V + ENCODER_ONLY = "encoder_only" + # Attention between dec. Q and enc. K/V for encoder-decoder + ENCODER_DECODER = "encoder_decoder" + + +class AttentionBackend(ABC): + """Abstract class for attention backends.""" + # For some attention backends, we allocate an output tensor before + # calling the custom op. When piecewise cudagraph is enabled, this + # makes sure the output tensor is allocated inside the cudagraph. 
+ accept_output_buffer: bool = False + + @staticmethod + @abstractmethod + def get_name() -> str: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_impl_cls() -> Type["AttentionImpl"]: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_state_cls() -> Type["AttentionState"]: + raise NotImplementedError + + @classmethod + def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata": + return cls.get_metadata_cls()(*args, **kwargs) + + @staticmethod + @abstractmethod + def get_builder_cls() -> Type["AttentionMetadataBuilder"]: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + raise NotImplementedError + + @staticmethod + def get_kv_cache_stride_order() -> Tuple[int, ...]: + raise NotImplementedError + + @staticmethod + @abstractmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + raise NotImplementedError + + @staticmethod + @abstractmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + raise NotImplementedError + + def advance_step(self, model_input: "ModelRunnerInputBase", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, num_seqs: int, num_queries: int) -> None: + raise NotImplementedError + + +@dataclass +class AttentionMetadata: + """Attention metadata for prefill and decode batched together.""" + # Total number of prefill requests. + num_prefills: int + # Number of prefill tokens. + num_prefill_tokens: int + # Number of decode tokens. Note that it is equivalent to the number of + # decode requests. + num_decode_tokens: int + # (num_tokens,). The indices of the token slots that input tokens will be + # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size + # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot + # in block 0, and 1st slot in block 1, respectively. + slot_mapping: torch.Tensor + + # The index maps that relate multi-modal embeddings to the corresponding + # placeholders. + # + # N.B. These aren't really related to attention and don't belong on this + # type -- this is just a temporary solution to make them available to + # `model_executable`. + multi_modal_placeholder_index_maps: Optional[Dict[ + str, MultiModalPlaceholderMap.IndexMap]] + + # Enable/disable KV scales calculation. This is so that we can disable the + # calculation until after prefill and cuda graph capture. + enable_kv_scales_calculation: bool + + @property + @abstractmethod + def prefill_metadata(self) -> Optional["AttentionMetadata"]: + """Return the attention metadata that's required to run prefill + attention.""" + pass + + @property + @abstractmethod + def decode_metadata(self) -> Optional["AttentionMetadata"]: + """Return the attention metadata that's required to run decode + attention.""" + pass + + def asdict_zerocopy(self, + skip_fields: Optional[Set[str]] = None + ) -> Dict[str, Any]: + """Similar to dataclasses.asdict, but avoids deepcopying.""" + if skip_fields is None: + skip_fields = set() + # Note that if we add dataclasses as fields, they will need + # similar handling. 
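+        # Shallow copy: field values (e.g. tensors) are shared, not cloned.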
+ return { + field.name: getattr(self, field.name) + for field in fields(self) if field.name not in skip_fields + } + + +T = TypeVar("T", bound=AttentionMetadata) + + +class AttentionState(ABC, Generic[T]): + """Holds attention backend-specific objects reused during the + lifetime of the model runner.""" + + @abstractmethod + def __init__(self, runner: "ModelRunnerBase"): + ... + + @abstractmethod + @contextmanager + def graph_capture(self, max_batch_size: int): + """Context manager used when capturing CUDA graphs.""" + yield + + @abstractmethod + def graph_clone(self, batch_size: int) -> "AttentionState[T]": + """Clone attention state to save in CUDA graph metadata.""" + ... + + @abstractmethod + def graph_capture_get_metadata_for_batch( + self, + batch_size: int, + is_encoder_decoder_model: bool = False) -> T: + """Get attention metadata for CUDA graph capture of batch_size.""" + ... + + @abstractmethod + def get_graph_input_buffers( + self, + attn_metadata: T, + is_encoder_decoder_model: bool = False) -> Dict[str, Any]: + """Get attention-specific input buffers for CUDA graph capture.""" + ... + + @abstractmethod + def prepare_graph_input_buffers( + self, + input_buffers: Dict[str, Any], + attn_metadata: T, + is_encoder_decoder_model: bool = False) -> None: + """In-place modify input buffers dict for CUDA graph replay.""" + ... + + @abstractmethod + def begin_forward(self, model_input: "ModelRunnerInputBase") -> None: + """Prepare state for forward pass.""" + ... + + +class AttentionMetadataBuilder(ABC, Generic[T]): + """Abstract class for attention metadata builders.""" + + @abstractmethod + def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None: + """Create the builder, remember some configuration and parameters.""" + raise NotImplementedError + + @abstractmethod + def prepare(self) -> None: + """Prepare for one batch.""" + raise NotImplementedError + + @abstractmethod + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int) -> T: + """Build attention metadata with on-device tensors.""" + raise NotImplementedError + + +class AttentionLayer(Protocol): + + _q_scale: torch.Tensor + _k_scale: torch.Tensor + _v_scale: torch.Tensor + _k_scale_float: float + _v_scale_float: float + _prob_scale: torch.Tensor + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + ... 
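+
+# Rough sketch of how the abstractions above fit together (descriptive, not
+# normative): a concrete AttentionBackend exposes its AttentionImpl subclass
+# via get_impl_cls() and its metadata/builder types via get_metadata_cls() /
+# get_builder_cls(); the model runner uses the builder to produce an
+# AttentionMetadata per batch, and each attention layer then runs
+# AttentionImpl.forward(layer, query, key, value, kv_cache, attn_metadata).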
+ + +class AttentionImpl(ABC, Generic[T]): + + @abstractmethod + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + ) -> None: + raise NotImplementedError + + @abstractmethod + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: T, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError + + +class MLAAttentionImpl(AttentionImpl[T], Generic[T]): + + @abstractmethod + def forward( + self, + layer: AttentionLayer, + hidden_states_or_cq: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: T, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError + + +def is_quantized_kv_cache(kv_cache_dtype: str) -> bool: + return kv_cache_dtype != "auto" diff --git a/attention/backends/blocksparse_attn.py b/attention/backends/blocksparse_attn.py new file mode 100644 index 0000000..c166351 --- /dev/null +++ b/attention/backends/blocksparse_attn.py @@ -0,0 +1,461 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import (CommonAttentionState, + CommonMetadataBuilder) +from vllm.attention.ops.blocksparse_attention.interface import ( + LocalStridedBlockSparseAttn, get_head_sliding_step) +from vllm.attention.ops.paged_attn import PagedAttention +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) + + +@dataclass +class BlocksparseParams: + max_seqlen: int + + # Num q heads per tensor-parallel rank/partition + num_heads: int # per TP partition + # Num kv heads per tensor-parallel rank/partition + num_kv_heads: int + + # block size used for blocksparse attention. + # This is the block_size used in `local_blocks`, `vert_stride`. + block_size: int + + # Number of blocks for local attention, i.e., number of + # local attended tokens / `sparse_block_size` + local_blocks: int + + # Attend to one block per every `vert_stride` blocks. + # Controlling the sparsity + vert_stride: int + """ + If to use the same vertical stride offset for all heads, + i.e., attend to the same block of tokens on all heads. + By default, it is False, i.e., attention on the non-local + blocks depends on the `head_idx`, that is on + blocks satisfying + `(block_idx + head_idx * head_sliding_step + 1) % vert_stride == 0` + where `head_sliding_step=max(1, int(vert_stride / num_total_heads))`, + `block_idx = position_id // sparse_block_size`. + See `..ops.blocksparse_attention.utils:get_sparse_attn_mask` + for more detail. + """ + homo_head: bool = False + + # If within a group, the kv offsets that each q attends is the same or no. 
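+    # (True means every q head in the same KV-head group attends to the same
+    # blocks; see __post_init__ below.)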
+ homo_head_group: bool = False + + # Decided by homo_head and homo_head group + head_sliding_step: int = field(init=False) + + # range of q heads to for a TP rank + active_head_range: Tuple = field(init=False) + + def __post_init__(self): + assert self.block_size > 0 + assert self.local_blocks >= 0 + assert self.vert_stride >= 1 + assert self.num_heads % self.num_kv_heads == 0 + + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + total_heads = tp_size * self.num_heads + total_kv_heads = tp_size * self.num_kv_heads + + if self.homo_head: + self.head_sliding_step = 0 + elif self.homo_head_group: + head_sliding_step = get_head_sliding_step(total_kv_heads, + self.vert_stride) + # negative indicates sliding along kv heads, i.e., homo q group + self.head_sliding_step = -head_sliding_step + else: + self.head_sliding_step = get_head_sliding_step( + total_heads, self.vert_stride) + + self.active_head_range = ( + tp_rank * self.num_heads, + (tp_rank + 1) * self.num_heads, + ) + + +class BlocksparseFlashAttentionBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "BLOCK_SPARSE_FLASH_ATTN" + + @staticmethod + def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]: + return BlocksparseFlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return BlocksparseFlashAttentionMetadata + + @staticmethod + def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]: + return BlocksparseFlashAttentionMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class BlocksparseFlashAttentionMetadata(AttentionMetadata): + """A copy of Metadata for FlashAttentionBackend, + to avoid having to install flash_attn. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| + + # Maximum query length in the batch. None for decoding. + max_query_len: Optional[int] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 
0 if there are prefill + # requests only. + max_decode_seq_len: int + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + # Max number of query tokens for among request in the batch. + max_decode_query_len: Optional[int] = None + + _cached_prefill_metadata: Optional[ + "BlocksparseFlashAttentionMetadata"] = None + _cached_decode_metadata: Optional[ + "BlocksparseFlashAttentionMetadata"] = None + + @property + def prefill_metadata( + self) -> Optional["BlocksparseFlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + assert self.seq_start_loc is not None + + self._cached_prefill_metadata = BlocksparseFlashAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + multi_modal_placeholder_index_maps=self. 
+ multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = BlocksparseFlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata + + +class BlocksparseFlashAttentionMetadataBuilder( + CommonMetadataBuilder[BlocksparseFlashAttentionMetadata]): + + _metadata_cls = BlocksparseFlashAttentionMetadata + + +class BlocksparseFlashAttentionImpl(AttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prompt_tokens -------------->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| + + Otherwise, the layout is as follows: + |<------------------ num_generation_tokens (M) ----------------->| + |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. 
+ + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") + assert blocksparse_params is not None + assert alibi_slopes is None, ValueError( + "Alibi not support for blocksparse flash attention.") + assert sliding_window is None, ValueError( + "sliding_window is invalid for blocksparse attention.") + assert logits_soft_cap is None, ValueError( + "logits_soft_cap is invalid for blocksparse attention.") + + if "num_heads" not in blocksparse_params: + blocksparse_params["num_heads"] = num_heads + if "num_kv_heads" not in blocksparse_params: + blocksparse_params["num_kv_heads"] = num_kv_heads or num_heads + self.blocksparse_params = BlocksparseParams(**blocksparse_params) + self.kv_cache_dtype = kv_cache_dtype + + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.alibi_slopes = alibi_slopes + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + self.local_blocks = self.blocksparse_params.local_blocks + self.vert_stride = self.blocksparse_params.vert_stride + self.sparse_block_size = self.blocksparse_params.block_size + self.head_sliding_step = self.blocksparse_params.head_sliding_step + + supported_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {supported_head_sizes}.") + + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + total_num_heads = num_heads * self.tp_size + self.bs_attn = LocalStridedBlockSparseAttn( + total_num_heads, + self.blocksparse_params.max_seqlen, + self.blocksparse_params.local_blocks, + self.blocksparse_params.vert_stride, + self.blocksparse_params.block_size, + homo_head=self.blocksparse_params.homo_head, + active_head_range=self.blocksparse_params.active_head_range, + ) + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "BlocksparseFlashAttentionImpl") + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: BlocksparseFlashAttentionMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with FlashAttention and PagedAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. 
+ query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if kv_cache.numel() > 0: + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + + PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + if prefill_meta := attn_metadata.prefill_metadata: + + # Prompt run. + # normal attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + + assert kv_cache.numel() == 0 \ + or prefill_meta.block_tables is None \ + or prefill_meta.block_tables.numel() == 0, \ + "Does not support prefix-enabled attention." + + output = self.bs_attn( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + sm_scale=self.scale, + ) + + if decode_meta := attn_metadata.decode_metadata: + # Decoding run. + output = PagedAttention.forward_decode( + query, + key_cache, + value_cache, + decode_meta.block_tables, + decode_meta.seq_lens_tensor, + self.blocksparse_params.max_seqlen, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + layer._k_scale, + layer._v_scale, + tp_rank=self.tp_rank, + blocksparse_local_blocks=self.local_blocks, + blocksparse_vert_stride=self.vert_stride, + blocksparse_block_size=self.sparse_block_size, + blocksparse_head_sliding_step=self.head_sliding_step, + ) + + assert output is not None + # Reshape the output tensor. 
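+        # [num_tokens, num_heads, head_size] -> [num_tokens, num_heads * head_size]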
+ return output.view(num_tokens, hidden_size) diff --git a/attention/backends/configs/tp8_merge.json b/attention/backends/configs/tp8_merge.json new file mode 100644 index 0000000..773051b --- /dev/null +++ b/attention/backends/configs/tp8_merge.json @@ -0,0 +1,986 @@ +[ + { + "BS": 1, + "L": 2, + "num_kv_splits": 8, + "num_stages": 2 + }, + { + "BS": 1, + "L": 4, + "num_kv_splits": 4, + "num_stages": 2 + }, + { + "BS": 1, + "L": 8, + "num_kv_splits": 8, + "num_stages": 2 + }, + { + "BS": 1, + "L": 16, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 1, + "L": 32, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 1, + "L": 64, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 1, + "L": 128, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 1, + "L": 256, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 1, + "L": 512, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 1, + "L": 1024, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 1, + "L": 2048, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 1, + "L": 4096, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 1, + "L": 8192, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 1, + "L": 16384, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 1, + "L": 32768, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 1, + "L": 65536, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 2, + "L": 2, + "num_kv_splits": 4, + "num_stages": 2 + }, + { + "BS": 2, + "L": 4, + "num_kv_splits": 8, + "num_stages": 2 + }, + { + "BS": 2, + "L": 8, + "num_kv_splits": 4, + "num_stages": 2 + }, + { + "BS": 2, + "L": 16, + "num_kv_splits": 4, + "num_stages": 2 + }, + { + "BS": 2, + "L": 32, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 2, + "L": 64, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 2, + "L": 128, + "num_kv_splits": 8, + "num_stages": 2 + }, + { + "BS": 2, + "L": 256, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 2, + "L": 512, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 2, + "L": 1024, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 2, + "L": 2048, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 2, + "L": 4096, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 2, + "L": 8192, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 2, + "L": 16384, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 2, + "L": 32768, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 2, + "L": 65536, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 4, + "L": 2, + "num_kv_splits": 8, + "num_stages": 2 + }, + { + "BS": 4, + "L": 4, + "num_kv_splits": 4, + "num_stages": 2 + }, + { + "BS": 4, + "L": 8, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 4, + "L": 16, + "num_kv_splits": 8, + "num_stages": 2 + }, + { + "BS": 4, + "L": 32, + "num_kv_splits": 13, + "num_stages": 2 + }, + { + "BS": 4, + "L": 64, + "num_kv_splits": 4, + "num_stages": 2 + }, + { + "BS": 4, + "L": 128, + "num_kv_splits": 8, + "num_stages": 2 + }, + { + "BS": 4, + "L": 256, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 4, + "L": 512, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 4, + "L": 1024, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 4, + "L": 2048, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 4, + "L": 4096, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 4, + "L": 8192, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 
4, + "L": 16384, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 4, + "L": 32768, + "num_kv_splits": 16, + "num_stages": 2 + }, + { + "BS": 8, + "L": 2, + "num_kv_splits": 4, + "num_stages": 2 + }, + { + "BS": 8, + "L": 4, + "num_kv_splits": 2, + "num_stages": 2 + }, + { + "BS": 8, + "L": 8, + "num_kv_splits": 4, + "num_stages": 2 + }, + { + "BS": 8, + "L": 16, + "num_kv_splits": 4, + "num_stages": 2 + }, + { + "BS": 8, + "L": 32, + "num_kv_splits": 8, + "num_stages": 2 + }, + { + "BS": 8, + "L": 64, + "num_kv_splits": 4, + "num_stages": 2 + }, + { + "BS": 8, + "L": 128, + "num_kv_splits": 8, + "num_stages": 2 + }, + { + "BS": 8, + "L": 256, + "num_kv_splits": 8, + "num_stages": 2 + }, + { + "BS": 8, + "L": 512, + "num_kv_splits": 13, + "num_stages": 2 + }, + { + "BS": 8, + "L": 1024, + "num_kv_splits": 13, + "num_stages": 2 + }, + { + "BS": 8, + "L": 2048, + "num_kv_splits": 13, + "num_stages": 2 + }, + { + "BS": 8, + "L": 4096, + "num_kv_splits": 13, + "num_stages": 2 + }, + { + "BS": 8, + "L": 8192, + "num_kv_splits": 13, + "num_stages": 2 + }, + { + "BS": 8, + "L": 16384, + "num_kv_splits": 13, + "num_stages": 2 + }, + { + "BS": 16, + "L": 2, + "num_kv_splits": 2, + "num_stages": 2 + }, + { + "BS": 16, + "L": 4, + "num_kv_splits": 2, + "num_stages": 2 + }, + { + "BS": 16, + "L": 8, + "num_kv_splits": 2, + "num_stages": 2 + }, + { + "BS": 16, + "L": 16, + "num_kv_splits": 2, + "num_stages": 2 + }, + { + "BS": 16, + "L": 32, + "num_kv_splits": 4, + "num_stages": 2 + }, + { + "BS": 16, + "L": 64, + "num_kv_splits": 4, + "num_stages": 2 + }, + { + "BS": 16, + "L": 128, + "num_kv_splits": 4, + "num_stages": 2 + }, + { + "BS": 16, + "L": 256, + "num_kv_splits": 6, + "num_stages": 2 + }, + { + "BS": 16, + "L": 512, + "num_kv_splits": 13, + "num_stages": 1 + }, + { + "BS": 16, + "L": 1024, + "num_kv_splits": 13, + "num_stages": 1 + }, + { + "BS": 16, + "L": 2048, + "num_kv_splits": 13, + "num_stages": 1 + }, + { + "BS": 16, + "L": 4096, + "num_kv_splits": 13, + "num_stages": 1 + }, + { + "BS": 16, + "L": 8192, + "num_kv_splits": 13, + "num_stages": 1 + }, + { + "BS": 32, + "L": 2, + "num_kv_splits": 2, + "num_stages": 2 + }, + { + "BS": 32, + "L": 4, + "num_kv_splits": 2, + "num_stages": 2 + }, + { + "BS": 32, + "L": 8, + "num_kv_splits": 2, + "num_stages": 2 + }, + { + "BS": 32, + "L": 16, + "num_kv_splits": 2, + "num_stages": 2 + }, + { + "BS": 32, + "L": 32, + "num_kv_splits": 2, + "num_stages": 2 + }, + { + "BS": 32, + "L": 64, + "num_kv_splits": 2, + "num_stages": 2 + }, + { + "BS": 32, + "L": 128, + "num_kv_splits": 3, + "num_stages": 2 + }, + { + "BS": 32, + "L": 256, + "num_kv_splits": 6, + "num_stages": 1 + }, + { + "BS": 32, + "L": 512, + "num_kv_splits": 6, + "num_stages": 1 + }, + { + "BS": 32, + "L": 1024, + "num_kv_splits": 6, + "num_stages": 1 + }, + { + "BS": 32, + "L": 2048, + "num_kv_splits": 6, + "num_stages": 1 + }, + { + "BS": 32, + "L": 4096, + "num_kv_splits": 13, + "num_stages": 1 + }, + { + "BS": 64, + "L": 2, + "num_kv_splits": 4, + "num_stages": 1 + }, + { + "BS": 64, + "L": 4, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 64, + "L": 8, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 64, + "L": 16, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 64, + "L": 32, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 64, + "L": 64, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 64, + "L": 128, + "num_kv_splits": 3, + "num_stages": 1 + }, + { + "BS": 64, + "L": 256, + "num_kv_splits": 3, + "num_stages": 1 + }, + { + 
"BS": 64, + "L": 512, + "num_kv_splits": 3, + "num_stages": 1 + }, + { + "BS": 64, + "L": 1024, + "num_kv_splits": 3, + "num_stages": 1 + }, + { + "BS": 64, + "L": 2048, + "num_kv_splits": 3, + "num_stages": 1 + }, + { + "BS": 64, + "L": 2048, + "num_kv_splits": 8, + "num_stages": 1 + }, + { + "BS": 96, + "L": 2, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 96, + "L": 4, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 96, + "L": 8, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 96, + "L": 16, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 96, + "L": 32, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 96, + "L": 64, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 96, + "L": 128, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 96, + "L": 256, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 96, + "L": 512, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 96, + "L": 1024, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 128, + "L": 2, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 128, + "L": 4, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 128, + "L": 8, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 128, + "L": 16, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 128, + "L": 32, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 128, + "L": 64, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 128, + "L": 128, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 128, + "L": 256, + "num_kv_splits": 3, + "num_stages": 1 + }, + { + "BS": 128, + "L": 512, + "num_kv_splits": 3, + "num_stages": 1 + }, + { + "BS": 128, + "L": 1024, + "num_kv_splits": 3, + "num_stages": 1 + }, + { + "BS": 256, + "L": 2, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 256, + "L": 4, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 256, + "L": 8, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 256, + "L": 16, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 256, + "L": 32, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 256, + "L": 64, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 256, + "L": 128, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 256, + "L": 256, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 256, + "L": 512, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 512, + "L": 2, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 512, + "L": 4, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 512, + "L": 8, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 512, + "L": 16, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 512, + "L": 32, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 512, + "L": 64, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 512, + "L": 128, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 512, + "L": 256, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 1024, + "L": 2, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 1024, + "L": 4, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 1024, + "L": 8, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 1024, + "L": 16, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 1024, + "L": 32, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 1024, + "L": 64, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 1024, + "L": 128, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 1536, + "L": 2, + "num_kv_splits": 2, + 
"num_stages": 1 + }, + { + "BS": 1536, + "L": 4, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 1536, + "L": 8, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 1536, + "L": 16, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 1536, + "L": 32, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 1536, + "L": 64, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 2048, + "L": 2, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 2048, + "L": 4, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 2048, + "L": 8, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 2048, + "L": 16, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 2048, + "L": 32, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 2048, + "L": 64, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 3072, + "L": 2, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 3072, + "L": 4, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 3072, + "L": 8, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 3072, + "L": 16, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 3072, + "L": 32, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 4096, + "L": 2, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 4096, + "L": 4, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 4096, + "L": 8, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 4096, + "L": 16, + "num_kv_splits": 2, + "num_stages": 1 + }, + { + "BS": 4096, + "L": 32, + "num_kv_splits": 2, + "num_stages": 1 + } +] \ No newline at end of file diff --git a/attention/backends/cpu_mla.py b/attention/backends/cpu_mla.py new file mode 100644 index 0000000..793cb87 --- /dev/null +++ b/attention/backends/cpu_mla.py @@ -0,0 +1,307 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch + +import vllm._custom_ops as ops +from vllm._ipex_ops import ipex_ops +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadataBuilder, + AttentionType, + is_quantized_kv_cache) +from vllm.attention.backends.mla.common import MLACommonImpl, MLACommonState +from vllm.attention.backends.torch_sdpa import TorchSDPAMetadata +from vllm.utils import make_tensor_with_pad +from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder + + +class CPUMLABackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "CPU_MLA" + + @staticmethod + def get_metadata_cls() -> Type["CPUMLAMetadata"]: + return CPUMLAMetadata + + @staticmethod + def get_builder_cls() -> Type["CPUMLAMetadataBuilder"]: + return CPUMLAMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["MLACommonState"]: + return MLACommonState + + @staticmethod + def get_impl_cls() -> Type["CPUMLAImpl"]: + return CPUMLAImpl + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, # assumed to be 1 for MLA + head_size: int, + ) -> Tuple[int, ...]: + return (num_blocks, block_size, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + ops.copy_blocks_mla(kv_caches, src_to_dists) + + @staticmethod + def 
get_supported_head_sizes() -> List[int]: + return [576] + + +@dataclass +class CPUMLAMetadata(TorchSDPAMetadata): + # New for MLA + # Input positions for rotrary embeddings since for MLA the rotary + # position embeddings are applied inside the attention backend + input_positions: torch.Tensor = None + + # required by MLACommonImpl + is_profile_run: bool = False + + +class CPUMLAMetadataBuilder(AttentionMetadataBuilder[CPUMLAMetadata]): + + def __init__(self, input_builder: ModelInputForCPUBuilder) -> None: + self.chunked_prefill = input_builder.chunked_prefill + self.input_builder = input_builder + assert not self.chunked_prefill, \ + "chunked prefill is currently not supported" + + def prepare(self): + self.input_data = self.input_builder.input_data + + def build(self, seq_lens, query_lens, cuda_graph_pad_size, batch_size): + input_data = self.input_data + prefill_seq_lens = seq_lens[0:input_data.num_prefills] + prefill_query_lens = query_lens[0:input_data.num_prefills] + slot_mapping = torch.tensor(input_data.slot_mapping, + dtype=torch.long, + device="cpu") + + # metadata for prefill + if input_data.num_prefills > 0: + query_lens_tensor = torch.tensor(prefill_query_lens, + dtype=torch.int32, + device="cpu") + kv_lens_tensor = torch.tensor(prefill_seq_lens, + dtype=torch.int32, + device="cpu") + query_start_loc = torch.zeros(input_data.num_prefills + 1, + dtype=torch.int32, + device="cpu") + kv_start_loc = torch.zeros(input_data.num_prefills + 1, + dtype=torch.int32, + device="cpu") + torch.cumsum(query_lens_tensor, + dim=0, + dtype=torch.int32, + out=query_start_loc[1:]) + torch.cumsum(kv_lens_tensor, + dim=0, + dtype=torch.int32, + out=kv_start_loc[1:]) + max_query_len = max(prefill_query_lens) + max_kv_len = max(prefill_seq_lens) + + # for chunked-prefill + if self.chunked_prefill: + prefill_block_tables = make_tensor_with_pad( + self.input_data.prefill_block_tables, + pad=0, + dtype=torch.int32, + device="cpu", + ) + else: + prefill_block_tables = None + + else: + query_start_loc = None + kv_start_loc = None + max_query_len = None + max_kv_len = None + prefill_block_tables = None + + # metadata for decode + if input_data.num_decode_tokens != 0: + seq_lens_tensor = torch.tensor( + input_data.seq_lens[input_data.num_prefills:], + dtype=torch.int32, + device="cpu", + ) + block_tables = make_tensor_with_pad( + self.input_data.decode_block_tables, + pad=0, + dtype=torch.int32, + device="cpu", + ) + else: + block_tables = torch.tensor([]) + seq_lens_tensor = torch.tensor( + input_data.seq_lens[:input_data.num_prefills], + dtype=torch.int32, + device="cpu", + ) + + # For multi-modal models + placeholder_index_maps = None + if len(input_data.multi_modal_inputs_list) != 0: + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + input_data.multi_modal_placeholder_maps.items() + } + + return CPUMLAMetadata( + chunked_prefill=self.chunked_prefill, + seq_lens=prefill_seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_kv_len=max_kv_len, + prefill_query_start_loc=query_start_loc, + kv_start_loc=kv_start_loc, + max_decode_seq_len=input_data.max_decode_seq_len, + num_prefills=input_data.num_prefills, + num_prefill_tokens=input_data.num_prefill_tokens, + num_decode_tokens=input_data.num_decode_tokens, + block_tables=block_tables, + prefill_block_tables=prefill_block_tables, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, + 
input_positions=torch.tensor([self.input_data.input_positions])) + + +class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + kv_sharing_target_layer_name, **mla_args) + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "CPUMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "CPUMLAImpl") + + # states is implemented. + if is_quantized_kv_cache(self.kv_cache_dtype): + raise NotImplementedError( + "CPUMLAImpl with FP8 KV cache not yet supported") + + def _forward_prefill( + self, + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: CPUMLAMetadata, # type: ignore[override] + ) -> torch.Tensor: + + prefill_metadata = attn_metadata.prefill_metadata + assert prefill_metadata is not None + + kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim + v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], + value=0) + + output = torch.empty_like(q) + ipex_ops.varlen_attention( + query=q, + key=k, + value=v_padded, + out=output, + seqlen_q=prefill_metadata.prefill_query_start_loc, + seqlen_k=prefill_metadata.prefill_query_start_loc, + max_seqlen_q=prefill_metadata.max_query_len, + max_seqlen_k=prefill_metadata.max_query_len, + pdropout=0.0, + softmax_scale=self.scale, + zero_tensors=False, + is_causal=True, + return_softmax=False, + gen_=None, + logits_soft_cap=0.0, + window_size_left=-1, + window_size_right=-1, + alibi_slopes=None, + ) + + # remove padding + output = output.view(-1, self.num_heads, + q.shape[-1])[..., :v.shape[-1]] + return output.reshape(-1, self.num_heads * v.shape[-1]) + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: CPUMLAMetadata, # type: ignore[override] + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + + decode_meta = attn_metadata.decode_metadata + assert decode_meta is not None + + q = torch.cat([q_nope, q_pe], dim=-1) + o = q.new_empty(q.shape[0], self.num_heads, self.kv_lora_rank) + + # Run MQA + ops.mla_decode_kvcache_cpu(o, q, kv_c_and_k_pe_cache, self.scale, + decode_meta.block_tables, + decode_meta.seq_lens_tensor) + return self._v_up_proj(o) diff --git a/attention/backends/dual_chunk_flash_attn.py b/attention/backends/dual_chunk_flash_attn.py new file mode 100644 index 0000000..963bccd 
--- /dev/null +++ b/attention/backends/dual_chunk_flash_attn.py @@ -0,0 +1,1498 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Attention layer with Dual chunk flash attention and sparse attention. +""" +import math +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type + +import torch +import torch.distributed +import torch.nn.functional as F + +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import AttentionLayer, AttentionType +from vllm.attention.backends.flash_attn import (FlashAttentionBackend, + FlashAttentionImpl, + FlashAttentionMetadata, + FlashAttentionMetadataBuilder) +from vllm.distributed.parallel_state import get_tensor_model_parallel_rank +from vllm.logger import init_logger +from vllm.utils import async_tensor_h2d +from vllm.vllm_flash_attn import (flash_attn_varlen_func, + flash_attn_with_kvcache, sparse_attn_func) + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUBuilder + +logger = init_logger(__name__) + + +class DualChunkFlashAttentionBackend(FlashAttentionBackend): + + accept_output_buffer: bool = False + + @staticmethod + def get_name() -> str: + return "DUAL_CHUNK_FLASH_ATTN" + + @staticmethod + def get_impl_cls() -> Type["DualChunkFlashAttentionImpl"]: + return DualChunkFlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["DualChunkFlashAttentionMetadata"]: + return DualChunkFlashAttentionMetadata + + @staticmethod + def get_builder_cls() -> Type["DualChunkFlashAttentionMetadataBuilder"]: + return DualChunkFlashAttentionMetadataBuilder + + +@dataclass +class DualChunkFlashAttentionMetadata(FlashAttentionMetadata): + # Block size of the paged kv cache. + block_size: int = 16 + + # Original max position embeddings. + original_max_position_embeddings: int = 0 + + # Chunk size + chunk_size: int = 8192 + + # Local size + local_size: int = 1024 + + # (batch_size,). The orig sequence length per sequence. + orig_seq_lens: Optional[List[int]] = None + + # orig_seq_lens stored as a tensor. + orig_seq_lens_tensor: Optional[torch.Tensor] = None + + # Length scaling factor + scaling_factor: Optional[torch.Tensor] = None + + # (batch_size,). Sequence lengths for intra attention. + seq_lens_intra: Optional[torch.Tensor] = None + + # Max sequence length for intra attention. + max_seq_len_intra: Optional[int] = None + + # (batch_size, num_blocks). Block table for intra attention. + block_tables_intra: Optional[torch.Tensor] = None + + # (batch_size,). Sequence lengths for succ attention. + seq_lens_succ: Optional[torch.Tensor] = None + + # Max sequence length for succ attention. + max_seq_len_succ: Optional[int] = None + + # (batch_size, num_blocks). Block table for succ attention. + block_tables_succ: Optional[torch.Tensor] = None + + # (batch_size,). Sequence lengths for inter attention. + seq_lens_inter: Optional[torch.Tensor] = None + + # Max sequence length for inter attention. 
+ max_seq_len_inter: Optional[int] = None + + _cached_prefill_metadata: Optional[ + "DualChunkFlashAttentionMetadata"] = None + _cached_decode_metadata: Optional["DualChunkFlashAttentionMetadata"] = None + + @property + def prefill_metadata(self) -> Optional["DualChunkFlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + prefill_metadata = super().prefill_metadata + if prefill_metadata is None: + return None + + prefill_metadata = DualChunkFlashAttentionMetadata( + **prefill_metadata.asdict_zerocopy()) + + prefill_metadata.orig_seq_lens = ( + None if self.orig_seq_lens is None else + self.orig_seq_lens[:self.num_prefills]) + prefill_metadata.orig_seq_lens_tensor = ( + None if self.orig_seq_lens_tensor is None else + self.orig_seq_lens_tensor[:self.num_prefills]) + + if self.original_max_position_embeddings > 0: + assert prefill_metadata.orig_seq_lens_tensor is not None + prefill_metadata.scaling_factor = ( + 0.1 * torch.log(prefill_metadata.orig_seq_lens_tensor / + self.original_max_position_embeddings) + + 1.0).clip(min=1) + + self._cached_prefill_metadata = prefill_metadata + return prefill_metadata + + @property + def decode_metadata(self) -> Optional["DualChunkFlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + + decode_metadata = super().decode_metadata + if decode_metadata is None: + return None + + decode_metadata = DualChunkFlashAttentionMetadata( + **decode_metadata.asdict_zerocopy()) + + decode_metadata.orig_seq_lens_tensor = ( + None if self.orig_seq_lens_tensor is None else + self.orig_seq_lens_tensor[self.num_prefills:]) + + assert decode_metadata.orig_seq_lens_tensor is not None + assert decode_metadata.block_tables is not None + + cache_seq_lens = decode_metadata.orig_seq_lens_tensor + chunk_len = self.chunk_size - self.local_size + chunk_num_curr = (cache_seq_lens - 1) // chunk_len + batch_size = decode_metadata.num_decode_tokens + + if self.original_max_position_embeddings > 0: + decode_metadata.scaling_factor = (0.1 * torch.log( + cache_seq_lens / self.original_max_position_embeddings) + + 1.0).clip(min=1) + + seq_lens_intra = cache_seq_lens - chunk_num_curr * chunk_len + max_seq_len_intra = seq_lens_intra.max().item() + decode_metadata.seq_lens_intra = seq_lens_intra + decode_metadata.max_seq_len_intra = max_seq_len_intra + + block_tables_intra = torch.zeros( + batch_size, + (max_seq_len_intra - 1) // self.block_size + 1, + dtype=decode_metadata.block_tables.dtype, + device=decode_metadata.block_tables.device, + ) + for i in range(batch_size): + st = chunk_num_curr[i] * chunk_len // self.block_size + ed = min( + st + (max_seq_len_intra - 1) // self.block_size + 1, + (cache_seq_lens[i] - 1) // self.block_size + 1, + ) + block_tables_intra[i, :ed - + st] = decode_metadata.block_tables[i, st:ed] + decode_metadata.block_tables_intra = block_tables_intra + + seq_lens_succ = (chunk_num_curr - + (chunk_num_curr - 1).clip(min=0)) * chunk_len + max_seq_len_succ = seq_lens_succ.max().item() + decode_metadata.seq_lens_succ = seq_lens_succ + decode_metadata.max_seq_len_succ = max_seq_len_succ + if max_seq_len_succ: + block_tables_succ = torch.zeros( + batch_size, + (max_seq_len_succ - 1) // self.block_size + 1, + dtype=decode_metadata.block_tables.dtype, + device=decode_metadata.block_tables.device, + ) + for i in range(batch_size): + start = ((chunk_num_curr[i] - 
1).clip(min=0) * chunk_len // + self.block_size) + end = min( + start + (max_seq_len_succ - 1) // self.block_size + 1, + (cache_seq_lens[i] - 1) // self.block_size + 1, + ) + block_tables_succ[ + i, :end - start] = decode_metadata.block_tables[i, + start:end] + decode_metadata.block_tables_succ = block_tables_succ + + seq_lens_inter = (chunk_num_curr - 1).clip(min=0) * chunk_len + max_seq_len_inter = seq_lens_inter.max().item() + decode_metadata.seq_lens_inter = seq_lens_inter + decode_metadata.max_seq_len_inter = max_seq_len_inter + + self._cached_decode_metadata = decode_metadata + return decode_metadata + + +class DualChunkFlashAttentionMetadataBuilder(FlashAttentionMetadataBuilder): + + def prepare(self): + super().prepare() + self.orig_seq_lens: List[int] = [] + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool, prefix_cache_hit: bool): + super()._add_seq_group(inter_data, chunked_prefill_enabled, + prefix_cache_hit) + for prompt_len, seq_len in zip(inter_data.prompt_lens, + inter_data.seq_lens): + self.orig_seq_lens.append(max(prompt_len, seq_len)) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + attn_metadata = super().build(seq_lens, query_lens, + cuda_graph_pad_size, batch_size) + attn_metadata = DualChunkFlashAttentionMetadata( + **attn_metadata.asdict_zerocopy()) + + device = self.runner.device + attn_metadata.orig_seq_lens = self.orig_seq_lens + attn_metadata.orig_seq_lens_tensor = async_tensor_h2d( + self.orig_seq_lens, torch.int, device, self.runner.pin_memory) + + attn_metadata.block_size = self.runner.block_size + dual_chunk_attn_config = getattr(self.runner.model_config.hf_config, + "dual_chunk_attention_config", {}) + attn_metadata.original_max_position_embeddings = \ + dual_chunk_attn_config.get("original_max_position_embeddings", 0) + attn_metadata.chunk_size = dual_chunk_attn_config.get( + "chunk_size", 8192) + attn_metadata.local_size = dual_chunk_attn_config.get( + "local_size", 1024) + + return attn_metadata + + +class DualChunkFlashAttentionImpl(FlashAttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| + Otherwise, the layout is as follows: + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + The prompts might have different lengths, while the generation tokens + always have length 1. + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| + |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. 
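+    For example (illustrative): with two prompts of lengths 3 and 5 plus two
+    decode tokens in one flattened batch, num_prefill_tokens is 8 and
+    num_decode_tokens is 2, so rows [0:8] of the query are prompt tokens and
+    rows [8:10] are decode tokens.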
+ """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + layer_idx: int = -1, + dual_chunk_attention_config: Optional[Dict[str, Any]] = None, + ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + if sliding_window is not None: + # NOTE(woosuk): flash-attn's sliding window does not work with + # paged KV cache. + raise ValueError( + "Sliding window is not supported in FlashAttention.") + + support_head_sizes = ( + DualChunkFlashAttentionBackend.get_supported_head_sizes()) + + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by FlashAttention. " + f"Supported head sizes are: {support_head_sizes}.") + + assert dual_chunk_attention_config is not None + self.chunk_size = dual_chunk_attention_config.get("chunk_size", 8192) + self.local_size = dual_chunk_attention_config.get("local_size", 1024) + self.original_max_position_embeddings = dual_chunk_attention_config.get( + "original_max_position_embeddings", 0) + self.sparse_attention_config = dual_chunk_attention_config.get( + "sparse_attention_config", None) + if not self.sparse_attention_config: + logger.warning_once("Sparse attention will not be enabled as " + "sparse attention config is not provided.") + self.sparse_attention_enabled = dual_chunk_attention_config.get( + "sparse_attention_enabled", self.sparse_attention_config + is not None) + self.sparse_attention_threshold = dual_chunk_attention_config.get( + "sparse_attention_threshold", 32768) + self.sparse_attention_last_q = dual_chunk_attention_config.get( + "sparse_attention_last_q", 64) + self.layer_idx = layer_idx + self.dual_chunk_attention_config = dual_chunk_attention_config + + if self.sparse_attention_config: + self.sparse_attention_config = { + int(i): j + for i, j in self.sparse_attention_config[ + self.layer_idx].items() + } + start_head = self.num_heads * get_tensor_model_parallel_rank() + end_head = start_head + self.num_heads + self.sparse_attention_config = [ + self.sparse_attention_config[i] + for i in range(start_head, end_head) + ] + + if self.sparse_attention_enabled: + self.arange = torch.arange(self.sparse_attention_last_q, + device="cuda") + self.last_q_mask = (self.arange[None, None, :, None] + >= self.arange[None, None, None, :]) + + def forward( # type: ignore + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: DualChunkFlashAttentionMetadata, + ) -> torch.Tensor: + """Forward pass with DualChunkFlashAttention. 
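+        Note: ``query`` arrives with five projections (q, q_succ, q_inter,
+        q_succ_critical, q_inter_critical) concatenated along the last
+        dimension; they are split apart below before attention is computed.
+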
+ Args: + query: shape = [num_tokens, num_heads * head_size] + query_succ: shape = [num_tokens, num_heads * head_size] + query_inter: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size, num_kv_heads * head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + ( + query, + query_succ, + query_inter, + query_succ_critical, + query_inter_critical, + ) = torch.split(query, query.shape[-1] // 5, dim=-1) + + assert ( + query_succ is not None and query_inter is not None + ), "query_succ and query_inter are required in Dual Chunk Attention." + + num_tokens, hidden_size = query.shape + + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + query_succ = query_succ.view(-1, self.num_heads, self.head_size) + query_inter = query_inter.view(-1, self.num_heads, self.head_size) + query_succ_critical = query_succ_critical.view(-1, self.num_heads, + self.head_size) + query_inter_critical = query_inter_critical.view( + -1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if self.original_max_position_embeddings > 0: + if prefill_meta := attn_metadata.prefill_metadata: + assert prefill_meta.scaling_factor is not None + assert prefill_meta.query_start_loc is not None + assert prefill_meta.orig_seq_lens is not None + current_start = 0 + query_start_loc_cpu = prefill_meta.query_start_loc.cpu() + for i in range(len(prefill_meta.orig_seq_lens)): + current_end = (current_start + + (query_start_loc_cpu[i + 1] - + query_start_loc_cpu[i]).item()) + key[current_start:current_end].mul_( + prefill_meta.scaling_factor[i]) + current_start = current_end + assert current_end <= attn_metadata.num_prefill_tokens + if decode_meta := attn_metadata.decode_metadata: + assert decode_meta.scaling_factor is not None + scaling_factor = decode_meta.scaling_factor + key[attn_metadata.num_prefill_tokens:].mul_( + scaling_factor.unsqueeze(-1).unsqueeze(-1)) + + if kv_cache is not None and kv_cache.numel() > 0: + key_cache = kv_cache[0] + value_cache = kv_cache[1] + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping.flatten(), + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + output = torch.empty_like(query) + + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + decode_query_succ = query_succ[num_prefill_tokens:] + decode_query_inter = query_inter[num_prefill_tokens:] + + # QKV for prefill. 
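+        # Tokens are laid out prefill-first in the flattened batch, so the
+        # slices below keep only the prompt tokens (the decode queries were
+        # taken just above).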
+ query = query[:num_prefill_tokens] + query_succ = query_succ[:num_prefill_tokens] + query_inter = query_inter[:num_prefill_tokens] + query_succ_critical = query_succ_critical[:num_prefill_tokens] + query_inter_critical = query_inter_critical[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. + if (kv_cache is None or prefill_meta.block_tables is None + or prefill_meta.block_tables.numel() == 0): + # normal attention, called during the profiling run. + out = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + ) + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out + else: + # prefix-enabled attention + assert prefill_meta.seq_lens is not None + assert prefill_meta.orig_seq_lens is not None + output[:num_prefill_tokens] = ( + self._dual_chunk_flash_attn_prefill( + q=query, + q_succ=query_succ, + q_inter=query_inter, + q_succ_critical=query_succ_critical, + q_inter_critical=query_inter_critical, + k=key_cache, + v=value_cache, + cu_seqlens_q=prefill_meta.query_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + orig_seq_lens=prefill_meta.orig_seq_lens, + scaling_factor=prefill_meta.scaling_factor, + softmax_scale=self.scale, + causal=True, + window_size=(-1, -1), + alibi_slopes=self.alibi_slopes, + block_table=prefill_meta.block_tables, + chunk_size=self.chunk_size, + local_size=self.local_size, + )) + + if decode_meta := attn_metadata.decode_metadata: + # Decoding run. + output[num_prefill_tokens:] = ( + self._dual_chunk_flash_attn_decoding( + decode_query.unsqueeze(1), + decode_query_succ.unsqueeze(1), + decode_query_inter.unsqueeze(1), + key_cache, + value_cache, + block_table=decode_meta.block_tables, + cache_seqlens=decode_meta.seq_lens_tensor, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + chunk_size=self.chunk_size, + local_size=self.local_size, + original_max_position_embeddings=self. + original_max_position_embeddings, + decode_meta=decode_meta, + ).squeeze(1)) + # Reshape the output tensor. 
+ return output.view(num_tokens, hidden_size) + + def _dual_chunk_flash_attn_prefill( + self, + q, + q_succ, + q_inter, + q_succ_critical, + q_inter_critical, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + orig_seq_lens: List[int], + scaling_factor: torch.Tensor, + softmax_scale: float, + causal: Optional[bool] = True, + window_size: Tuple[int, int] = (-1, -1), + alibi_slopes: Optional[torch.Tensor] = None, + block_table: Optional[torch.Tensor] = None, + chunk_size: int = 8192, + local_size: int = 1024, + ): + if alibi_slopes is not None: + raise ValueError( + "Dual Chunk Attention does not support alibi_slopes") + if not causal: + raise ValueError( + "Dual Chunk Attention does not support causal=False") + if window_size != (-1, -1): + raise ValueError( + "Dual Chunk Attention does not support window_size") + + cu_seqlens_q_cpu = cu_seqlens_q.cpu().tolist() + cu_seqlens_k_cpu = cu_seqlens_k.cpu().tolist() + all_outputs = [] + + for i in range(0, len(cu_seqlens_q_cpu) - 1): + qs = cu_seqlens_q_cpu[i] + qe = cu_seqlens_q_cpu[i:i + 2][-1] + ks = cu_seqlens_k_cpu[i] + ke = cu_seqlens_k_cpu[i:i + 2][-1] + + current_q = q[qs:qe] + current_q_succ = q_succ[qs:qe] + current_q_inter = q_inter[qs:qe] + current_q_succ_critical = q_succ_critical[qs:qe] + current_q_inter_critical = q_inter_critical[qs:qe] + + if block_table is None: + current_k = k[ks:ke] + current_v = v[ks:ke] + current_block_table = None + current_orig_seq_len = orig_seq_lens[i] + else: + current_block_table = block_table[i] + current_orig_seq_len = orig_seq_lens[i] + current_k = k + current_v = v + sparse_attn_enabled = (self.sparse_attention_enabled + and current_orig_seq_len + > self.sparse_attention_threshold) + + if current_q.shape[0] == 0: + continue + + if current_k.shape[0] == 0: + all_outputs.append( + torch.zeros( + (current_q.shape[0], current_q.shape[1], v.shape[2]), + device=q.device, + dtype=q.dtype, + )) + continue + + current_output = torch.empty_like(current_q) + group_size = int(current_q.size(-2) / current_k.size(-2)) + + if sparse_attn_enabled: + num_device_q_heads = current_q.size(-2) + heads_vertical_size = torch.empty(size=(num_device_q_heads, ), + dtype=torch.int32) + heads_slash_size = torch.empty(size=(num_device_q_heads, ), + dtype=torch.int32) + for head_id in range(current_q.size(-2)): + ( + ty, + vertical_size, + slash_size, + _, + ) = self.sparse_attention_config[head_id] + assert ty == "vertical_and_slash", "only support slash mode" + + if vertical_size == 30: + vertical_size += 100 + heads_vertical_size[head_id] = vertical_size + heads_slash_size[head_id] = slash_size + + current_output = self._dual_chunk_flash_attn_prefill_func( + current_q, # allheads + current_q_succ, + current_q_inter, + current_q_succ_critical, + current_q_inter_critical, + current_k, + current_v, + current_block_table, + softmax_scale, + chunk_size, + local_size, + scaling_factor[i].item(), + ke - ks, + sparse_attn_enabled=sparse_attn_enabled, + heads_vertical_size=heads_vertical_size, + heads_slash_size=heads_slash_size, + group_size=group_size) + else: + for head_id in range(current_q.size(-2)): + # (seq_len, num_heads, head_size) + current_q_head = current_q[:, head_id, :].unsqueeze(1) + current_q_succ_head = \ + current_q_succ[:, head_id, :].unsqueeze(1) + current_q_inter_head = \ + current_q_inter[:, head_id, :].unsqueeze(1) + current_q_succ_head_critical = \ + current_q_succ_critical[:, head_id, :].unsqueeze(1) + current_q_inter_head_critical = \ + current_q_inter_critical[:, head_id, :].unsqueeze(1) + if block_table is not 
None: + current_k_head = current_k[..., head_id // + group_size, :].unsqueeze(2) + current_v_head = current_v[..., head_id // + group_size, :].unsqueeze(2) + + else: + current_k_head = current_k[:, head_id, :].unsqueeze(1) + current_v_head = current_v[:, head_id, :].unsqueeze(1) + + current_out = self._dual_chunk_flash_attn_prefill_func( + current_q_head, + current_q_succ_head, + current_q_inter_head, + current_q_succ_head_critical, + current_q_inter_head_critical, + current_k_head, + current_v_head, + current_block_table, + softmax_scale, + chunk_size, + local_size, + scaling_factor[i].item(), + ke - ks, + sparse_attn_enabled=sparse_attn_enabled, + ) + current_output[:, head_id:head_id + 1, :] = current_out + all_outputs.append(current_output) + return torch.cat(all_outputs, dim=0) + + def _dual_chunk_flash_attn_prefill_func( + self, + q, + q_succ, + q_inter, + q_succ_critical, + q_inter_critical, + k, + v, + block_table, + softmax_scale: float, + chunk_size: int, + local_size: int, + scaling_factor: float, + k_length: int, + sparse_attn_enabled: Optional[bool] = True, + heads_vertical_size=None, + heads_slash_size=None, + group_size=None, + ): + flash_results = [] + chunk_len = chunk_size - local_size + + if block_table is not None: + block_size = v.shape[1] + if chunk_len % block_size != 0: + raise ValueError("chunk_len must be divisible by block_size.") + else: + block_size = 1 + + if self.original_max_position_embeddings > 0: + softmax_scale = softmax_scale * scaling_factor + + begin = k_length - q.shape[0] + while begin < k_length: + flash_per_chunk = [] + + prev_chunk_end_pos = (begin // chunk_len) * chunk_len + next_chunk_end_pos = prev_chunk_end_pos + chunk_len + end = min(next_chunk_end_pos, k_length) + qbegin = begin - (k_length - q.shape[0]) + qend = end - (k_length - q.shape[0]) + + qk_chunks = [] + q_states_intra = q[qbegin:qend] + # choose critical token + if block_table is not None: + block_tables_intra = _get_block(block_table, block_size, + prev_chunk_end_pos, end) + k_states_intra = k[block_tables_intra].view( + -1, *k.shape[-2:])[:(end - prev_chunk_end_pos)] + v_states_intra = v[block_tables_intra].view( + -1, *v.shape[-2:])[:(end - prev_chunk_end_pos)] + else: + block_tables_intra = None + k_states_intra = k[prev_chunk_end_pos:end] + v_states_intra = v[prev_chunk_end_pos:end] + + if sparse_attn_enabled: + last_q_size = min(qend - qbegin, self.sparse_attention_last_q) + _, num_device_k_heads, head_dim = k_states_intra.shape + k_states_intra = (k_states_intra.unsqueeze(2).repeat( + 1, 1, group_size, + 1).reshape(-1, num_device_k_heads * group_size, head_dim)) + v_states_intra = (v_states_intra.unsqueeze(2).repeat( + 1, 1, group_size, + 1).reshape(-1, num_device_k_heads * group_size, head_dim)) + qk_chunks.append( + (q_states_intra.transpose(0, 1)[:, -last_q_size:] * + softmax_scale) @ k_states_intra.permute(1, 2, 0)) + + if prev_chunk_end_pos - chunk_len >= 0: + q_states_succ = q_succ[qbegin:qend] + q_states_succ_critical = q_succ_critical[qbegin:qend] + if block_table is not None: + block_tables_succ = _get_block( + block_table, block_size, + prev_chunk_end_pos - chunk_len, prev_chunk_end_pos) + k_states_succ = k[block_tables_succ].view( + -1, *k.shape[-2:])[:chunk_len] + v_states_succ = v[block_tables_succ].view( + -1, *v.shape[-2:])[:chunk_len] + else: + k_states_succ = k[prev_chunk_end_pos - + chunk_len:prev_chunk_end_pos] + v_states_succ = v[prev_chunk_end_pos - + chunk_len:prev_chunk_end_pos] + + if sparse_attn_enabled: + k_states_succ = 
(k_states_succ.unsqueeze(2).repeat( + 1, 1, group_size, + 1).reshape(-1, num_device_k_heads * group_size, + head_dim)) + v_states_succ = (v_states_succ.unsqueeze(2).repeat( + 1, 1, group_size, + 1).reshape(-1, num_device_k_heads * group_size, + head_dim)) + qk_chunks.append((q_states_succ_critical.transpose( + 0, 1)[:, -last_q_size:] * softmax_scale) + @ k_states_succ.permute(1, 2, 0)) + + if prev_chunk_end_pos - chunk_len * 2 >= 0: + q_states_inter = q_inter[qbegin:qend] + q_states_inter_critical = q_inter_critical[qbegin:qend] + if block_table is not None: + block_tables_inter = _get_block( + block_table, block_size, 0, + prev_chunk_end_pos - chunk_len) + k_states_inter = k[block_tables_inter].view( + -1, *k.shape[-2:])[:(prev_chunk_end_pos - chunk_len)] + v_states_inter = v[block_tables_inter].view( + -1, *v.shape[-2:])[:(prev_chunk_end_pos - chunk_len)] + else: + k_states_inter = k[:prev_chunk_end_pos - chunk_len] + v_states_inter = v[:prev_chunk_end_pos - chunk_len] + + if sparse_attn_enabled: + k_states_inter = (k_states_inter.unsqueeze(2).repeat( + 1, 1, group_size, + 1).reshape(-1, num_device_k_heads * group_size, + head_dim)) + v_states_inter = (v_states_inter.unsqueeze(2).repeat( + 1, 1, group_size, + 1).reshape(-1, num_device_k_heads * group_size, + head_dim)) + qk_chunks.append((q_states_inter_critical.transpose( + 0, 1)[:, -last_q_size:] * softmax_scale) + @ k_states_inter.permute(1, 2, 0)) + + if sparse_attn_enabled: + reversed_qk = qk_chunks[::-1] + qk = torch.cat(reversed_qk, dim=-1) + + qk[:, :, -last_q_size:] = torch.where( + self.last_q_mask[..., -last_q_size:, + -last_q_size:].to(qk.device), + qk[:, :, -last_q_size:], -torch.inf) + qk = F.softmax(qk, dim=-1, dtype=torch.float32) + + vertical = qk.sum(-2, keepdim=True) + vertical[..., :30] = torch.inf + + # Avoid sorting by using the min/max ints to fill the indexer + # buffers. 
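+                # Unused vertical slots are padded with int32_max and unused
+                # slash slots with int32_min; the sparse-attention helper
+                # later sorts vertical indices ascending and slash indices
+                # descending, so the padding stays at the tail and never
+                # displaces real indices.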
+ int32_max = torch.iinfo(torch.int32).max + int32_min = torch.iinfo(torch.int32).min + n_heads = qk.size()[0] + max_slash_topk = torch.max(heads_slash_size).item() + max_vertical_topk = torch.max(heads_vertical_size).item() + # store each head's slash topk, vertical topk + vertical = vertical.reshape((n_heads, -1)) + # prevent out of range when prompt size < max_vertical_topk + max_vertical_topk = min(vertical.shape[-1], max_vertical_topk) + vertical_topk_buffer = torch.topk(vertical, max_vertical_topk, + -1).indices + slash_topk_buffer = torch.empty(size=(n_heads, max_slash_topk), + dtype=torch.int64, + device=qk.device) + for head_i in range(n_heads): + # (nqheads=1, lastq, k_len) + head_score = qk[head_i:head_i + 1, :, :] + slash_scores = _sum_all_diagonal_matrix(head_score) + if head_score.size(1) != 1: + # drop right up corner + slash_scores = slash_scores[..., :-last_q_size + 1] + slash_scores[..., -100:] = torch.inf + + head_slash_size = heads_slash_size[head_i] + head_slash_size = min(head_slash_size, vertical.size(-1)) + slash_topk = torch.topk(slash_scores, head_slash_size, + -1).indices + #(nheads, max_topk) + slash_topk_buffer[head_i, :head_slash_size] = slash_topk + + # reset heads topk + heads_slash_size[head_i] = head_slash_size + heads_vertical_size[head_i] = min( + heads_vertical_size[head_i], max_vertical_topk) + + # store + vertical_buffer = torch.full((n_heads, max_vertical_topk), + int32_max, + dtype=torch.int64, + device=q.device) + slash_buffer = torch.full((n_heads, max_slash_topk), + int32_min, + dtype=torch.int64, + device=q.device) + succ_vertical_buffer = torch.full((n_heads, max_vertical_topk), + int32_max, + dtype=torch.int64, + device=q.device) + succ_slash_buffer = torch.full((n_heads, max_slash_topk), + int32_min, + dtype=torch.int64, + device=q.device) + inter_vertical_buffer = torch.full( + (n_heads, max_vertical_topk), + int32_max, + dtype=torch.int64, + device=q.device) + inter_slash_buffer = torch.full((n_heads, max_slash_topk), + int32_min, + dtype=torch.int64, + device=q.device) + + vertical_size_buffer = torch.empty(size=(n_heads, ), + dtype=torch.int32, + device=q.device) + slash_sizes_buffer = torch.empty(size=(n_heads, ), + dtype=torch.int32, + device=q.device) + succ_vertical_size_buffer = torch.empty(size=(n_heads, ), + dtype=torch.int32, + device=q.device) + succ_slash_sizes_buffer = torch.empty(size=(n_heads, ), + dtype=torch.int32, + device=q.device) + inter_vertical_size_buffer = torch.empty(size=(n_heads, ), + dtype=torch.int32, + device=q.device) + inter_slash_sizes_buffer = torch.empty(size=(n_heads, ), + dtype=torch.int32, + device=q.device) + + for head_i in range(n_heads): + vertical_topk = vertical_topk_buffer[ + head_i, :heads_vertical_size[head_i]] + # intra + intra_vertical_indices = vertical_topk[ + vertical_topk >= + prev_chunk_end_pos] - prev_chunk_end_pos + if intra_vertical_indices.nelement() == 0: + intra_vertical_indices = torch.cat([ + intra_vertical_indices, + torch.arange(0, + k_states_intra.size(0), + max(1, + k_states_intra.size(0) / 5), + dtype=torch.int32, + device=intra_vertical_indices.device) + ]) + slash_topk = slash_topk_buffer[ + head_i, :heads_slash_size[head_i]] + intra_slash_indices = ( + (qk.size(-1) - 1) - + slash_topk[slash_topk >= prev_chunk_end_pos]) + # fill buffer + v_count = intra_vertical_indices.nelement() + s_count = intra_slash_indices.nelement() + vertical_size_buffer[head_i] = v_count + slash_sizes_buffer[head_i] = s_count + vertical_buffer[head_i, :v_count].copy_( + intra_vertical_indices) + 
slash_buffer[head_i, :s_count].copy_(intra_slash_indices) + # succ + if prev_chunk_end_pos - chunk_len >= 0: + succ_vertical_indices = vertical_topk[ + (vertical_topk < prev_chunk_end_pos) + & (vertical_topk >= prev_chunk_end_pos - + chunk_len)] - (prev_chunk_end_pos - chunk_len) + # TODO: support no vertical + if succ_vertical_indices.nelement() == 0: + succ_vertical_indices = torch.cat([ + succ_vertical_indices, + torch.arange( + 0, + k_states_succ.size(0), + max(1, + k_states_succ.size(0) / 5), + dtype=torch.int32, + device=intra_vertical_indices.device) + ]) + succ_slash_indices = ( + (prev_chunk_end_pos + (qend - qbegin) - 1) - + slash_topk[((slash_topk >= + (prev_chunk_end_pos - chunk_len)) & + (slash_topk < (prev_chunk_end_pos + + (qend - qbegin))))]) + if succ_slash_indices.nelement() == 0: + succ_slash_indices = torch.cat([ + succ_slash_indices, + torch.arange( + 0, + k_states_succ.size(0), + max(1, + k_states_succ.size(0) / 5), + dtype=torch.int32, + device=intra_vertical_indices.device) + ]) + # fill buffer + v_count = succ_vertical_indices.nelement() + s_count = succ_slash_indices.nelement() + succ_vertical_size_buffer[head_i] = v_count + succ_slash_sizes_buffer[head_i] = s_count + succ_vertical_buffer[head_i, :v_count].copy_( + succ_vertical_indices) + succ_slash_buffer[head_i, :s_count].copy_( + succ_slash_indices) + + if prev_chunk_end_pos - 2 * chunk_len >= 0: + inter_vertical_indices = vertical_topk[ + vertical_topk < prev_chunk_end_pos - chunk_len] + + if inter_vertical_indices.nelement() == 0: + inter_vertical_indices = torch.cat([ + inter_vertical_indices, + torch.arange( + 0, + k_states_inter.size(0), + max(1, + k_states_inter.size(0) / 5), + dtype=torch.int32, + device=intra_vertical_indices.device) + ]) + inter_slash_indices = ( + (prev_chunk_end_pos - chunk_len + + (qend - qbegin) - 1) - + slash_topk[slash_topk < (prev_chunk_end_pos - + chunk_len + + (qend - qbegin))]) + if inter_slash_indices.nelement() == 0: + inter_slash_indices = torch.cat([ + inter_slash_indices, + torch.arange( + 0, + k_states_inter.size(0), + max(1, + k_states_inter.size(0) / 5), + dtype=torch.int32, + device=intra_vertical_indices.device) + ]) + # fill buffer + v_count = inter_vertical_indices.nelement() + s_count = inter_slash_indices.nelement() + inter_vertical_size_buffer[head_i] = v_count + inter_slash_sizes_buffer[head_i] = s_count + inter_vertical_buffer[head_i, :v_count].copy_( + inter_vertical_indices) + inter_slash_buffer[head_i, :s_count].copy_( + inter_slash_indices) + else: + intra_vertical_indices, intra_slash_indices = None, None + succ_vertical_indices, succ_slash_indices = None, None + inter_vertical_indices, inter_slash_indices = None, None + + if sparse_attn_enabled: + flash_result = self._do_flash_attn( + q_states_intra, + k_states_intra, + v_states_intra, + softmax_scale=softmax_scale, + causal=True, + block_table=block_table, + stage="intra", + vertical_indices=vertical_buffer, + slash_indices=slash_buffer, + vertical_indices_count=vertical_size_buffer, + slash_indices_count=slash_sizes_buffer, + mergehead_softmax_scale=softmax_scale, + sparse_attn_enabled=sparse_attn_enabled) + else: + flash_result = self._do_flash_attn( + q_states_intra, + k_states_intra, + v_states_intra, + softmax_scale=softmax_scale, + causal=True, + block_table=block_table, + stage="intra", + vertical_indices=intra_vertical_indices, + slash_indices=intra_slash_indices, + sparse_attn_enabled=sparse_attn_enabled) + flash_per_chunk.append(flash_result) + + if prev_chunk_end_pos - chunk_len >= 0: + if 
sparse_attn_enabled: + flash_result = self._do_flash_attn( + q_states_succ, + k_states_succ, + v_states_succ, + softmax_scale=softmax_scale, + causal=False, + block_table=block_table, + stage="succ", + vertical_indices=succ_vertical_buffer, + slash_indices=succ_slash_buffer, + vertical_indices_count=succ_vertical_size_buffer, + slash_indices_count=succ_slash_sizes_buffer, + mergehead_softmax_scale=softmax_scale, + sparse_attn_enabled=sparse_attn_enabled) + else: + flash_result = self._do_flash_attn( + q_states_succ, + k_states_succ, + v_states_succ, + softmax_scale=softmax_scale, + causal=False, + block_table=block_table, + stage="succ", + vertical_indices=succ_vertical_indices, + slash_indices=succ_slash_indices, + sparse_attn_enabled=sparse_attn_enabled) + flash_per_chunk.append(flash_result) + + if prev_chunk_end_pos - chunk_len * 2 >= 0: + if sparse_attn_enabled: + flash_result = self._do_flash_attn( + q_states_inter, + k_states_inter, + v_states_inter, + softmax_scale=softmax_scale, + causal=False, + block_table=block_table, + stage="inter", + vertical_indices=inter_vertical_buffer, + slash_indices=inter_slash_buffer, + vertical_indices_count=inter_vertical_size_buffer, + slash_indices_count=inter_slash_sizes_buffer, + mergehead_softmax_scale=softmax_scale, + sparse_attn_enabled=sparse_attn_enabled) + else: + flash_result = self._do_flash_attn( + q_states_inter, + k_states_inter, + v_states_inter, + softmax_scale=softmax_scale, + causal=False, + block_table=block_table, + stage="inter", + vertical_indices=inter_vertical_indices, + slash_indices=inter_slash_indices, + sparse_attn_enabled=sparse_attn_enabled) + flash_per_chunk.append(flash_result) + + flash_results.append(flash_per_chunk) + begin = end + + attn_output = self._merge_attn_outputs(flash_results) + del flash_results + return attn_output + + def _do_flash_attn( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + softmax_scale: float, + causal: bool = True, + block_table: torch.Tensor = None, + max_seqlen_k: Optional[int] = None, + stage: str = "intra", + vertical_indices: Optional[torch.Tensor] = None, + slash_indices: Optional[torch.Tensor] = None, + vertical_indices_count: Optional[torch.Tensor] = None, + slash_indices_count: Optional[torch.Tensor] = None, + mergehead_softmax_scale: Optional[float] = None, + sparse_attn_enabled: Optional[bool] = False, + ): + if max_seqlen_k is None: + max_seqlen_k = key_states.shape[0] + + q_len = query_states.shape[0] + q_heads = query_states.shape[1] + h_dim = query_states.shape[-1] + + if sparse_attn_enabled: + assert slash_indices is not None + if stage == "intra": + assert causal + else: + assert not causal + + query_states = query_states.unsqueeze(0).transpose(1, 2) + key_states = key_states.unsqueeze(0).transpose(1, 2) + value_states = value_states.unsqueeze(0).transpose(1, 2) + + q = query_states + k = key_states + v = value_states + + if (vertical_indices_count is not None and \ + slash_indices_count is not None): + assert mergehead_softmax_scale is not None + + res, s_lse = _vertical_slash_sparse_attention( + q, + k, + v, + vertical_indices, + slash_indices, + mergehead_softmax_scale, + causal=causal, + stage=stage, + vertical_indices_count=vertical_indices_count, + slash_indices_count=slash_indices_count) + res = res.view(q_heads, q_len, + h_dim).transpose(0, 1) # (qlen,nhead,h_dim) + s_lse = s_lse.view( + q_heads, q_len, + 1).squeeze(-1).unsqueeze(0).float() # (1, nhead,qlen) + else: + res, s_lse = 
_vertical_slash_sparse_attention(q, + k, + v, + vertical_indices, + slash_indices, + softmax_scale, + causal=causal, + stage=stage) + res = res.view(q_len, q_heads, h_dim) + s_lse = s_lse.view(q_len, q_heads, 1).transpose(0, 2).float() + return res, s_lse + + output, softmax_lse = flash_attn_varlen_func( + q=query_states, + k=key_states, + v=value_states, + softmax_scale=softmax_scale, + cu_seqlens_q=torch.tensor([0, query_states.shape[0]], + dtype=torch.int32, + device=query_states.device), + max_seqlen_q=query_states.shape[0], + cu_seqlens_k=torch.tensor([0, max_seqlen_k], + dtype=torch.int32, + device=query_states.device), + max_seqlen_k=max_seqlen_k, + causal=causal, + block_table=block_table.unsqueeze(0), + return_softmax_lse=True, + ) + softmax_lse = softmax_lse.view(q_len, q_heads, 1).transpose(0, + 2).float() + return output, softmax_lse + + def _merge_attn_outputs( + self, + flash_results: List[List[Tuple[torch.Tensor, torch.Tensor]]], + return_lse: Optional[bool] = False, + ) -> torch.Tensor: + attn_outputs_all = [] + logits_all = [] + + for flash_per_chunk in flash_results: + if len(flash_per_chunk) == 1: + attn_outputs_all.append(flash_per_chunk[0][0]) + if return_lse: + logits_all.append(flash_per_chunk[0][1]) + continue + + attn_outputs = torch.stack([ + flash_attn_output[0] for flash_attn_output in flash_per_chunk + ]) + logits = torch.stack([ + flash_attn_output[1] for flash_attn_output in flash_per_chunk + ]) + logits = logits.to(torch.float32) + + if return_lse: + max_val = torch.max(logits, dim=0).values + diff = torch.abs(logits[0] - logits[1]) + log_sum_exp = max_val + torch.log1p(torch.exp(-diff)) + logits_all.append(log_sum_exp) + + max_logits = torch.max(logits, dim=0).values + stable_logits = logits - max_logits.unsqueeze(0) + lse_s = torch.exp(stable_logits).detach() + lse_sum = torch.sum(lse_s, dim=0) + lse_s /= lse_sum + attn_outputs *= lse_s.unsqueeze(-1).transpose(2, 3).squeeze(1) + attn_outputs_all.append(attn_outputs.sum(dim=0)) + + if return_lse: + return (torch.cat(attn_outputs_all, + dim=0), torch.cat(logits_all, dim=-1)) + else: + return torch.cat(attn_outputs_all, dim=0) + + def _dual_chunk_flash_attn_decoding( + self, + query: torch.Tensor, + query_succ: torch.Tensor, + query_inter: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + softmax_scale: float, + causal: bool, + alibi_slopes: Optional[torch.Tensor], + chunk_size: int, + local_size: int, + original_max_position_embeddings: int, + decode_meta: DualChunkFlashAttentionMetadata, + ): + if not causal: + raise ValueError( + "Dual Chunk Attention does not support causal=False") + + block_size = value_cache.shape[1] + chunk_len = chunk_size - local_size + if chunk_len % block_size != 0: + raise ValueError("chunk_len must be divisible by block_size.") + if original_max_position_embeddings > 0: + assert decode_meta.scaling_factor is not None + scaling_factor = decode_meta.scaling_factor + query = (query * scaling_factor.view(-1, 1, 1, 1)).to( + query.dtype + ) # possible for numerical issue, need to fused in the kernel + query_succ = (query_succ * scaling_factor.view(-1, 1, 1, 1)).to( + query.dtype) + query_inter = (query_inter * scaling_factor.view(-1, 1, 1, 1)).to( + query.dtype) + outputs_list = [] + softmax_lses_list = [] + + # intra-attention + intra_output, intra_softmax_lse = ( + self._dual_chunk_flash_attn_decoding_with_exp_sums( + query, + key_cache, + value_cache, + decode_meta.block_tables_intra, + 
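As a minimal, self-contained sketch (not part of the patch), the log-sum-exp merge performed by _merge_attn_outputs above can be expressed as a softmax over the per-query LSE values of each partial result, assuming the (q_len, n_heads, head_dim) outputs and (1, n_heads, q_len) LSE layout used in _do_flash_attn; the helper name is illustrative.

import torch

def merge_partial_attn(outs, lses):
    # outs: list of (q_len, n_heads, head_dim) partial attention outputs
    # lses: list of (1, n_heads, q_len) log-sum-exp tensors for those outputs
    logits = torch.stack(lses).float()                 # (n_parts, 1, n_heads, q_len)
    weights = torch.softmax(logits, dim=0)             # softmax across the partial results
    weights = weights.squeeze(1).permute(0, 2, 1).unsqueeze(-1)  # (n_parts, q_len, n_heads, 1)
    return (torch.stack(outs) * weights).sum(dim=0)    # (q_len, n_heads, head_dim)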
decode_meta.seq_lens_intra, + softmax_scale, + alibi_slopes, + causal=False, + )) + outputs_list.append(intra_output) + softmax_lses_list.append(intra_softmax_lse) + + # succ-attention + if decode_meta.max_seq_len_succ: + succ_output, succ_softmax_lse = ( + self._dual_chunk_flash_attn_decoding_with_exp_sums( + query_succ, + key_cache, + value_cache, + decode_meta.block_tables_succ, + decode_meta.seq_lens_succ, + softmax_scale, + alibi_slopes, + causal=False, + )) + outputs_list.append(succ_output) + softmax_lses_list.append(succ_softmax_lse) + + # inter-attention + if decode_meta.max_seq_len_inter: + inter_output, inter_softmax_lse = ( + self._dual_chunk_flash_attn_decoding_with_exp_sums( + query_inter, + key_cache, + value_cache, + block_table[:, :decode_meta.max_seq_len_inter], + decode_meta.seq_lens_inter, + softmax_scale, + alibi_slopes, + causal=False, + )) + outputs_list.append(inter_output) + softmax_lses_list.append(inter_softmax_lse) + outputs = torch.stack(outputs_list, dim=0) + del outputs_list + softmax_lses = torch.stack(softmax_lses_list, dim=0).to(torch.float32) + del softmax_lses_list + max_logits = torch.max(softmax_lses, dim=0).values + stable_logits = softmax_lses - max_logits.unsqueeze(0) + lse_s = torch.exp(stable_logits).detach() + lse_sum = torch.sum(lse_s, dim=0) + lse_s /= lse_sum + outputs *= lse_s.unsqueeze(-1).transpose(2, 3) + return outputs.sum(0) + + def _dual_chunk_flash_attn_decoding_with_exp_sums( + self, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + softmax_scale: float, + alibi_slopes: Optional[torch.Tensor], + causal: bool, + ): + out, softmax_lse = flash_attn_with_kvcache( + q=query, + k_cache=key_cache, + v_cache=value_cache, + block_table=block_table, + cache_seqlens=cache_seqlens, + softmax_scale=softmax_scale, + alibi_slopes=alibi_slopes, + causal=causal, + return_softmax_lse=True, + ) + mask = (cache_seqlens == 0) + out[mask] = 0 + softmax_lse[mask] = -float("inf") + return out, softmax_lse + + +def _vertical_slash_sparse_attention( + query: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD] + key: torch.Tensor, # [BATCH, N_HEADS, N_KV_CTX, D_HEAD] + value: torch.Tensor, # [BATCH, N_HEADS, N_KV_CTX, D_HEAD] + v_idx: torch.Tensor, # [BATCH, N_HEADS, NNZ_V] + s_idx: torch.Tensor, # [BATCH, N_HEADS, NNZ_S] + softmax_scale: float, + causal: bool = True, + stage: str = "intra", + block_size_M: int = 64, + block_size_N: int = 64, + vertical_indices_count: torch.Tensor = None, # [N_HEADS,] + slash_indices_count: torch.Tensor = None, +): + if stage == "intra": + assert causal + else: + assert not causal + + batch_size, num_heads, context_size, head_dim = query.shape + _, _, kv_seq_len, _ = key.shape + + if head_dim not in [16, 32, 64, 128, 256, 512]: + target_dim = 2**math.ceil(math.log2(head_dim)) - head_dim + query = F.pad(query, [0, target_dim, 0, 0, 0, 0, 0, 0]) + key = F.pad(key, [0, target_dim, 0, 0, 0, 0, 0, 0]) + value = F.pad(value, [0, target_dim, 0, 0, 0, 0, 0, 0]) + + v_idx = v_idx.to(torch.int32).reshape( + (batch_size, num_heads, -1)).sort(dim=-1, descending=False)[0] + s_idx = s_idx.to(torch.int32).reshape( + (batch_size, num_heads, -1)).sort(dim=-1, descending=True)[0] + q_seqlens = torch.tensor([context_size], + dtype=torch.int32, + device=query.device) + kv_seqlens = torch.tensor([kv_seq_len], + dtype=torch.int32, + device=query.device) + + if vertical_indices_count is not None and slash_indices_count is not None: + ( + block_count, + 
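The sparse kernel invoked by _vertical_slash_sparse_attention above supports only a fixed set of head sizes, so the head dimension is zero-padded on the right up to the next power of two and the padding is sliced off again before the result is returned. A small standalone sketch of that padding step (shapes assumed, not vLLM code):

import math
import torch
import torch.nn.functional as F

def pad_head_dim(x: torch.Tensor) -> torch.Tensor:
    # x: (batch, n_heads, seq_len, head_dim); zero-pad the last dim up to
    # the next power of two, e.g. 96 -> 128.
    head_dim = x.shape[-1]
    pad = 2 ** math.ceil(math.log2(head_dim)) - head_dim
    return F.pad(x, [0, pad]) if pad > 0 else x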
block_offset, + column_count, + column_index, + ) = ops.convert_vertical_slash_indexes_mergehead( + q_seqlens, kv_seqlens, v_idx, s_idx, vertical_indices_count, + slash_indices_count, context_size, block_size_M, block_size_N, + causal) + else: + ( + block_count, + block_offset, + column_count, + column_index, + ) = ops.convert_vertical_slash_indexes(q_seqlens, kv_seqlens, v_idx, + s_idx, context_size, + block_size_M, block_size_N, + causal) + + q = query.transpose(1, 2).contiguous() + k = key.transpose(1, 2).contiguous() + v = value.transpose(1, 2).contiguous() + out, lse = sparse_attn_func( + q, + k, + v, + block_count, + block_offset, + column_count, + column_index, + causal=causal, + softmax_scale=softmax_scale, + return_softmax_lse=True, + ) + out = out.transpose(1, 2).contiguous() + softmax_lse = lse.reshape(*lse.shape, 1) + return (out[..., :context_size, :head_dim], + softmax_lse[..., :context_size, :]) + + +def _sum_all_diagonal_matrix(mat: torch.tensor): + h, n, m = mat.shape + # Zero matrix used for padding + zero_mat = torch.zeros((h, n, n), device=mat.device) + # pads the matrix on left and right + mat_padded = torch.cat((zero_mat, mat, zero_mat), -1) + # Change the strides + mat_strided = mat_padded.as_strided((1, n, n + m), + (n * (2 * n + m), 2 * n + m + 1, 1)) + # Sums the resulting matrix's columns + sum_diags = torch.sum(mat_strided, 1) + return sum_diags[:, 1:] # drop left bottom corner + + +def _get_block(block_table: torch.Tensor, block_size: int, begin: int, + end: int): + begin_block = begin // block_size + end_block = (end - 1) // block_size + 1 + return block_table[begin_block:end_block] diff --git a/attention/backends/flash_attn.py b/attention/backends/flash_attn.py new file mode 100644 index 0000000..a46c0a8 --- /dev/null +++ b/attention/backends/flash_attn.py @@ -0,0 +1,1005 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Attention layer with FlashAttention.""" +from collections import defaultdict +from dataclasses import dataclass +from itertools import accumulate +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type + +import torch + +from vllm import _custom_ops as ops +# yapf conflicts with isort for this block +# yapf: disable +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionType, + is_quantized_kv_cache) +# yapf: enable +from vllm.attention.backends.utils import ( + PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping, + compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens, + get_seq_len_block_table_args, is_all_cross_attn_metadata_set, + is_all_encoder_attn_metadata_set, is_block_tables_empty) +# from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8, +# get_flash_attn_version) +from vllm.logger import init_logger +from vllm.multimodal import MultiModalPlaceholderMap +from vllm.utils import async_tensor_h2d, make_tensor_with_pad +from flash_attn import (flash_attn_varlen_func, + flash_attn_with_kvcache) +def flash_attn_supports_fp8() -> bool: + return False + +if TYPE_CHECKING: + from vllm.worker.model_runner import (ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) + +logger = init_logger(__name__) + + +class FlashAttentionBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod 
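For intuition, the diagonal sums that _sum_all_diagonal_matrix in the dual-chunk backend above computes with a padded, re-strided view can also be written as an explicit double loop; this naive sketch (assumed shapes, slow but easy to read) sums every diagonal of each (n, m) score matrix, ordered from the bottom-left diagonal to the top-right one.

import torch

def sum_all_diagonals_naive(mat: torch.Tensor) -> torch.Tensor:
    # mat: (h, n, m) -> (h, n + m - 1), one sum per diagonal.
    h, n, m = mat.shape
    out = mat.new_zeros(h, n + m - 1)
    for i in range(n):
        for j in range(m):
            out[:, (j - i) + (n - 1)] += mat[:, i, j]
    return out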
+ def get_name() -> str: + return "FLASH_ATTN" + + @staticmethod + def get_impl_cls() -> Type["FlashAttentionImpl"]: + return FlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return FlashAttentionMetadata + + @staticmethod + def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]: + return FlashAttentionMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (2, num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + src_key_cache = src_kv_cache[0] + dst_key_cache = dst_kv_cache[0] + ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + src_value_cache = src_kv_cache[1] + dst_value_cache = dst_kv_cache[1] + ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + + ops.copy_blocks(key_caches, value_caches, src_to_dists) + + +@dataclass +class FlashAttentionMetadata(AttentionMetadata): + """Metadata for FlashAttentionBackend. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + + use_cuda_graph: bool + + # Maximum query length in the batch. + max_query_len: Optional[int] = None + + # Max number of query tokens among request in the batch. 
+ max_decode_query_len: Optional[int] = None + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] = None + + _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None + _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None + + # Begin encoder attn & enc/dec cross-attn fields... + + # Encoder sequence lengths representation + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + encoder_seq_start_loc: Optional[torch.Tensor] = None + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + + @property + def is_all_encoder_attn_metadata_set(self): + ''' + All attention metadata required for encoder attention is set. + ''' + return is_all_encoder_attn_metadata_set(self) + + @property + def is_all_cross_attn_metadata_set(self): + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. + ''' + return is_all_cross_attn_metadata_set(self) + + @property + def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert ((self.seq_lens is not None) + or (self.encoder_seq_lens is not None)) + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[:self.num_prefill_tokens]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + seq_start_loc = (None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) + block_tables = (None if self.block_tables is None else + self.block_tables[:self.num_prefills]) + + self._cached_prefill_metadata = FlashAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=self. 
+ multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_query_len=0, + max_decode_seq_len=0, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + encoder_seq_start_loc=self.encoder_seq_start_loc, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[self.num_prefill_tokens:]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) + block_tables = (None if self.block_tables is None else + self.block_tables[self.num_prefills:]) + + self._cached_decode_metadata = FlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + seq_lens=None, + seq_lens_tensor=seq_lens_tensor, + max_decode_query_len=self.max_decode_query_len, + max_query_len=self.max_query_len, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + # Batch may be composed of prefill|decodes, adjust query start + # indices to refer to the start of decodes. E.g. + # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. + query_start_loc=(self.query_start_loc[self.num_prefills:] - + self.query_start_loc[self.num_prefills]) + if self.query_start_loc is not None else None, + seq_start_loc=self.seq_start_loc[self.num_prefills:] + if self.seq_start_loc is not None else None, + context_lens_tensor=None, + block_tables=block_tables, + use_cuda_graph=self.use_cuda_graph, + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + encoder_seq_start_loc=self.encoder_seq_start_loc, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + return self._cached_decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. 
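The decode_metadata property above re-bases query_start_loc so that, once the prefill entries are sliced off, the first decode token starts at offset 0. A small illustration with assumed lengths (one 3-token prefill followed by two single-token decodes):

import torch

query_start_loc = torch.tensor([0, 3, 4, 5], dtype=torch.int32)
num_prefills = 1
decode_query_start_loc = (query_start_loc[num_prefills:] -
                          query_start_loc[num_prefills])
# tensor([0, 1, 2], dtype=torch.int32): decode tokens re-indexed from zero.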
For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + assert self.use_cuda_graph + + if turn_prefills_into_decodes: + # When Multi-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes. This update reflects that + # conversion. + assert self.num_decode_tokens + self.num_prefills == num_seqs + self.num_decode_tokens += self.num_prefills + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.max_prefill_seq_len = 0 + self.max_query_len = 1 + + self.slot_mapping = self.slot_mapping[:num_seqs] + else: + assert self.seq_lens is not None + assert self.max_decode_seq_len == max(self.seq_lens) + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.num_decode_tokens == num_seqs + assert self.slot_mapping.shape == (num_seqs, ) + + assert self.seq_lens is not None + assert len(self.seq_lens) == num_seqs + assert self.seq_lens_tensor is not None + assert self.seq_lens_tensor.shape == (num_seqs, ) + assert self.max_query_len == 1 + assert self.max_prefill_seq_len == 0 + + assert self.query_start_loc is not None + assert self.query_start_loc.shape == (num_queries + 1, ) + assert self.seq_start_loc is not None + assert self.seq_start_loc.shape == (num_seqs + 1, ) + + assert self.context_lens_tensor is not None + assert self.context_lens_tensor.shape == (num_queries, ) + + assert self.block_tables is not None + assert self.block_tables.shape[0] == num_seqs + + # Update query lengths. Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + self.seq_lens[i] += 1 + self.max_decode_seq_len = max(self.seq_lens) + + ops.advance_step_flashattn(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables) + + +class FlashAttentionMetadataBuilder( + AttentionMetadataBuilder[FlashAttentionMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.input_builder = input_builder + self.runner = input_builder.runner + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + + def prepare(self): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + self.has_prefix_cache_hit = False + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool, prefix_cache_hit: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. 
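The slot mapping referred to in the docstring above assigns each token a flat index into the paged KV cache. A hedged sketch of the mapping rule (not vLLM's compute_slot_mapping helper itself): the block table supplies the physical block for the token's position, and the remainder is the offset inside that block.

def slot_for_position(block_table, pos, block_size):
    # block_table: physical block ids for one sequence, in logical order.
    block_number = block_table[pos // block_size]
    return block_number * block_size + pos % block_size

# e.g. block_table=[7, 2], block_size=16, pos=20 -> slot 2 * 16 + 4 = 36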
+ """ + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + + if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for modality, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[modality].extend( + placeholders) + + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if prefix_cache_hit: + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + block_table = block_tables[seq_id] + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + if curr_sliding_window_block == 0: + block_table = block_tables[seq_id] + else: + block_table = block_tables[seq_id][ + -curr_sliding_window_block:] + self.block_tables.append(block_table) + + # Compute slot mapping. + is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + + def _get_graph_runner_block_tables( + self, num_seqs: int, + block_tables: List[List[int]]) -> torch.Tensor: + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + max_batch_size, max_blocks = self.runner.graph_block_tables.shape + assert max_batch_size >= num_seqs + + graph_block_tables = self.runner.graph_block_tables[:num_seqs] + for i, block_table in enumerate(block_tables): + if block_table: + num_blocks = len(block_table) + if num_blocks <= max_blocks: + graph_block_tables[i, :num_blocks] = block_table + else: + # It may be possible to have more blocks allocated due + # to lookahead slots of multi-step, however, they are + # not used anyway, so can be safely ignored. + graph_block_tables[ + i, :max_blocks] = block_table[:max_blocks] + + return torch.from_numpy(graph_block_tables).to( + device=self.runner.device, non_blocking=True) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. 
+ """ + prefix_cache_hit = any([ + inter_data.prefix_cache_hit + for inter_data in self.input_builder.inter_data_list + ]) + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled, + prefix_cache_hit) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_query_len = max(query_lens) + decode_query_lens = query_lens[self.num_prefills:] + if len(decode_query_lens) > 0: + max_decode_query_len = max(decode_query_lens) + else: + max_decode_query_len = 1 + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + query_start_loc = list(accumulate(query_lens, initial=0)) + seq_start_loc = list(accumulate(seq_lens, initial=0)) + + num_seqs = len(seq_lens) + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size - self.num_prefill_tokens + block_tables = self._get_graph_runner_block_tables( + num_seqs, self.block_tables) + else: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) + query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, + device, + self.runner.pin_memory) + seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, + device, self.runner.pin_memory) + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + self.multimodal_placeholder_maps.items() + } + + return FlashAttentionMetadata( + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_decode_query_len=max_decode_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc_tensor, + seq_start_loc=seq_start_loc_tensor, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + ) + + +class FlashAttentionImpl(AttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| + + Otherwise, the layout is as follows: + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. 
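build() above derives the cumulative start locations consumed by the varlen kernels with itertools.accumulate; for example, query lengths of [4, 1, 1] yield query_start_loc = [0, 4, 5, 6], so request i owns the token range [query_start_loc[i], query_start_loc[i+1]).

from itertools import accumulate

query_lens = [4, 1, 1]
query_start_loc = list(accumulate(query_lens, initial=0))
assert query_start_loc == [0, 4, 5, 6]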
+ + |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| + |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + use_irope: bool = False, + ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") + if blocksparse_params is not None: + raise ValueError( + "FlashAttention does not support block-sparse attention.") + if use_irope: + logger.warning( + "Using irope in V0 is not supported yet, it will fall back " + "to global attention for long context.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window - 1, + 0) if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype + # self.vllm_flash_attn_version = get_flash_attn_version( + # requires_alibi=self.alibi_slopes is not None) + if is_quantized_kv_cache(self.kv_cache_dtype) and ( + not self.kv_cache_dtype.startswith("fp8") + or not flash_attn_supports_fp8()): + raise NotImplementedError( + f"FlashAttention does not support {self.kv_cache_dtype} " + "kv-cache on this device " + f"(FA supports fp8 = {flash_attn_supports_fp8()}).") + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + support_head_sizes = FlashAttentionBackend.get_supported_head_sizes() + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by FlashAttention. " + f"Supported head sizes are: {support_head_sizes}.") + self.attn_type = attn_type + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with FlashAttention. + + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + output: shape = [num_tokens, num_heads, head_size] + kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + NOTE: It in-place updates the output tensor. + NOTE: FP8 quantization, flash-attn expect the size of + {q,k,v}_descale to be (num_sequences, num_kv_heads). + We use torch's .expand() to avoid duplicating values + """ + assert output is not None, "Output tensor must be provided." + + # NOTE(woosuk): FlashAttention2 does not support FP8 KV cache. 
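flash-attn takes the sliding window as a (left, right) pair of token counts with -1 meaning unlimited, which is why the constructor above stores (sliding_window - 1, 0): a causal model with window W sees the current token plus the previous W - 1 tokens. A tiny sketch of that conversion:

def to_flash_attn_window(sliding_window):
    # None -> no windowing; otherwise attend to the last `sliding_window` tokens.
    return (sliding_window - 1, 0) if sliding_window is not None else (-1, -1)

assert to_flash_attn_window(4096) == (4095, 0)
assert to_flash_attn_window(None) == (-1, -1)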
+ if not flash_attn_supports_fp8() or output.dtype != torch.bfloat16: + assert ( + layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0), ( + "key/v_scale is only supported in FlashAttention 3 with " + "base dtype bfloat16") + + attn_type = self.attn_type + if (attn_type == AttentionType.ENCODER + and (not attn_metadata.is_all_encoder_attn_metadata_set)): + raise AttributeError("Encoder attention requires setting " + "encoder metadata attributes.") + elif (attn_type == AttentionType.ENCODER_DECODER + and (not attn_metadata.is_all_cross_attn_metadata_set)): + raise AttributeError("Encoder/decoder cross-attention " + "requires setting cross-attention " + "metadata attributes.") + + kv_cache_dtype: str = self.kv_cache_dtype + softmax_scale: float = self.scale + window_size = self.sliding_window + alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes + logits_soft_cap: Optional[float] = self.logits_soft_cap + fp8_attention = kv_cache_dtype.startswith("fp8") + + if fp8_attention and not flash_attn_supports_fp8(): + raise NotImplementedError( + "FlashAttention does not support FP8 kv-cache on this device.") + + if kv_cache.numel() > 0: + key_cache = kv_cache[0] + value_cache = kv_cache[1] + # We skip updating the KV cache under two conditions: + # a. When the Attention Type is ENCODER. In this phase, we compute + # only the encoder attention without updating the cache. + # b. When both Key and Value are None. This occurs during + # cross-attention computation in the decoding phase, where the + # KV cache is already populated with the cross-attention + # tensor. Thus, we skip cache updates during this time. + if (attn_type != AttentionType.ENCODER) and (key is not None) and ( + value is not None): + if attn_type == AttentionType.ENCODER_DECODER: + # Update cross-attention KV cache (prefill-only) + updated_slot_mapping = attn_metadata.cross_slot_mapping + else: + # Update self-attention KV cache (prefill/decode) + updated_slot_mapping = attn_metadata.slot_mapping + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory + # profiling run. + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + kv_cache[0], + kv_cache[1], + updated_slot_mapping.flatten(), # type: ignore[union-attr] + kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + if fp8_attention: + kv_cache = kv_cache.view(torch.float8_e4m3fn) + key_cache = key_cache.view(torch.float8_e4m3fn) + value_cache = value_cache.view(torch.float8_e4m3fn) + + if fp8_attention: + num_tokens, num_heads, head_size = query.shape + query, _ = ops.scaled_fp8_quant( + query.reshape( + (num_tokens, num_heads * head_size)).contiguous(), + layer._q_scale) + query = query.reshape((num_tokens, num_heads, head_size)) + + (num_prefill_query_tokens, num_prefill_kv_tokens, + num_decode_query_tokens) = \ + get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) + decode_query = query[num_prefill_query_tokens:] + decode_output = output[num_prefill_query_tokens:] + # QKV for prefill. + query = query[:num_prefill_query_tokens] + prefill_output = output[:num_prefill_query_tokens] + assert query.shape[0] == num_prefill_query_tokens + assert decode_query.shape[0] == num_decode_query_tokens + + if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. 
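A pure-PyTorch sketch of what the fused reshape_and_cache_flash call above does (ignoring the dtype and scale handling performed by the real kernel): each token's key/value row is scattered into the paged cache at the block and offset encoded by its slot index.

import torch

def reshape_and_cache_reference(key, value, key_cache, value_cache,
                                slot_mapping, block_size):
    # key/value:             (num_tokens, num_kv_heads, head_size)
    # key_cache/value_cache: (num_blocks, block_size, num_kv_heads, head_size)
    # slot_mapping:          (num_tokens,) long tensor of flat slot indices
    block_idx = slot_mapping // block_size
    block_off = slot_mapping % block_size
    key_cache[block_idx, block_off] = key
    value_cache[block_idx, block_off] = value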
+ if (kv_cache.numel() == 0 or prefill_meta.block_tables is None + or prefill_meta.block_tables.numel() == 0): + # normal attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \ + _get_query_key_seq_metadata(prefill_meta, True, attn_type) + + key = key[:num_prefill_kv_tokens] + value = value[:num_prefill_kv_tokens] + + if fp8_attention: + num_kv_tokens, num_kv_heads, head_size = key.shape + + key, _ = ops.scaled_fp8_quant( + key.reshape((num_kv_tokens, + num_kv_heads * head_size)).contiguous(), + layer._k_scale) + key = key.reshape((num_kv_tokens, num_kv_heads, head_size)) + + value, _ = ops.scaled_fp8_quant( + value.reshape((num_kv_tokens, + num_kv_heads * head_size)).contiguous(), + layer._v_scale) + value = value.reshape( + (num_kv_tokens, num_kv_heads, head_size)) + + descale_shape = (q_seq_start_loc.shape[0] - 1, key.shape[1]) + output[:num_prefill_query_tokens] = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=q_seq_start_loc, + cu_seqlens_k=k_seq_start_loc, + max_seqlen_q=q_seq_len, + max_seqlen_k=k_seq_len, + softmax_scale=softmax_scale, + causal=_get_causal_option(attn_type), + window_size=window_size, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + # out=prefill_output, + # fa_version=self.vllm_flash_attn_version, + # q_descale=layer._q_scale.expand(descale_shape), + # k_descale=layer._k_scale.expand(descale_shape), + # v_descale=layer._v_scale.expand(descale_shape), + ) + else: + # prefix-enabled attention + assert attn_type == AttentionType.DECODER, ( + "Only decoder-only models support prefix caching") + assert prefill_meta.seq_lens is not None + assert prefill_meta.query_start_loc is not None + max_seq_len = max(prefill_meta.seq_lens) + descale_shape = (prefill_meta.query_start_loc.shape[0] - 1, + key.shape[1]) + output[:num_prefill_query_tokens] = flash_attn_varlen_func( # noqa + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=prefill_meta.query_start_loc, + max_seqlen_q=prefill_meta.max_query_len, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_k=max_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + block_table=prefill_meta.block_tables, + softcap=logits_soft_cap, + # out=prefill_output, + # fa_version=self.vllm_flash_attn_version, + # q_descale=layer._q_scale.expand(descale_shape), + # k_descale=layer._k_scale.expand(descale_shape), + # v_descale=layer._v_scale.expand(descale_shape), + ) + + if decode_meta := attn_metadata.decode_metadata: + # Decoding run. + # Use flash_attn_varlen_func kernel for speculative decoding + # because different queries might have different lengths. 
+ + assert decode_meta.max_decode_query_len is not None + # use only for actual varlen decoding + if decode_meta.max_decode_query_len > 1: + assert attn_type == AttentionType.DECODER, ( + "Only decoder-only models support max_decode_query_len > 1" + ) + assert decode_meta.query_start_loc is not None + descale_shape = (decode_meta.query_start_loc.shape[0] - 1, + key.shape[1]) + output[num_prefill_query_tokens:] = flash_attn_varlen_func( + q=decode_query, + k=key_cache, + v=value_cache, + cu_seqlens_q=decode_meta.query_start_loc, + max_seqlen_q=decode_meta.max_decode_query_len, + cu_seqlens_k=decode_meta.seq_start_loc, + max_seqlen_k=decode_meta.max_decode_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + block_table=decode_meta.block_tables, + # out=decode_output, + # fa_version=self.vllm_flash_attn_version, + # q_descale=layer._q_scale.expand(descale_shape), + # k_descale=layer._k_scale.expand(descale_shape), + # v_descale=layer._v_scale.expand(descale_shape), + ) + else: + # Use flash_attn_with_kvcache for normal decoding. + ( + seq_lens_arg, + _, + block_tables_arg, + ) = get_seq_len_block_table_args(decode_meta, False, attn_type) + descale_shape = (seq_lens_arg.shape[0], key_cache.shape[-2]) + output[num_prefill_query_tokens:] = flash_attn_with_kvcache( + q=decode_query.unsqueeze(1), + k_cache=key_cache, + v_cache=value_cache, + block_table=block_tables_arg, + cache_seqlens=seq_lens_arg, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + # out=decode_output.unsqueeze(1), + # fa_version=self.vllm_flash_attn_version, + # q_descale=layer._q_scale.expand(descale_shape), + # k_descale=layer._k_scale.expand(descale_shape), + # v_descale=layer._v_scale.expand(descale_shape), + ).squeeze(1) + return output + + +def _get_query_key_seq_metadata( + attn_metadata, + is_prompt: bool, + attn_type: str, +) -> tuple: + """ + Returns sequence metadata for key and query based on the specified + attention type and whether input is a prompt. + + This function computes the starting locations and maximum sequence lengths + for key and query sequences for different attention types. + + Args: + attn_metadata: The attention metadata object + is_prompt (bool): A flag indicating if the input is a prompt + attn_type (AttentionType): The type of attention being used. + + Returns: + tuple: A tuple containing four integers: + - Starting location for the query sequence. + - Maximum sequence length for the query sequence. + - Starting location for the key sequence. + - Maximum sequence length for the key sequence. + + Raises: + AttributeError: If an invalid attention type is provided. + """ + if attn_type == AttentionType.DECODER: + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + if is_prompt: + max_seq_len = attn_metadata.max_prefill_seq_len + else: + max_seq_len = attn_metadata.max_decode_seq_len + return (attn_metadata.seq_start_loc, max_seq_len, + attn_metadata.seq_start_loc, max_seq_len) + + elif attn_type == AttentionType.ENCODER_DECODER: + # This is cross attention between the where the key + # is the precomputed encoder attention and query + # is the input sequence. + # Choose query max length based on whether it is prompt + # or not. 
+ if is_prompt: + max_seq_len = attn_metadata.max_prefill_seq_len + else: + max_seq_len = attn_metadata.max_decode_seq_len + return (attn_metadata.seq_start_loc, max_seq_len, + attn_metadata.encoder_seq_start_loc, + attn_metadata.max_encoder_seq_len) + elif attn_type == AttentionType.ENCODER: + # For encoder attention both the query and the key are same i.e the + # encoder sequence. + return (attn_metadata.encoder_seq_start_loc, + attn_metadata.max_encoder_seq_len, + attn_metadata.encoder_seq_start_loc, + attn_metadata.max_encoder_seq_len) + elif attn_type == AttentionType.ENCODER_ONLY: + assert is_prompt, "Should not have decode for encoder only model." + return (attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len, + attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +def _get_causal_option(attn_type: str) -> bool: + """ + Determine whether the given attention type is suitable for causal + attention mechanisms. + + Args: + attn_type (AttentionType): The type of attention being evaluated + + Returns: + bool: Returns `True` if the attention type is suitable for causal + attention (i.e., not encoder, encoder-only, or encoder-decoder), + otherwise returns `False`. + """ + return not (attn_type == AttentionType.ENCODER + or attn_type == AttentionType.ENCODER_ONLY + or attn_type == AttentionType.ENCODER_DECODER) diff --git a/attention/backends/flashinfer.py b/attention/backends/flashinfer.py new file mode 100644 index 0000000..5e8036b --- /dev/null +++ b/attention/backends/flashinfer.py @@ -0,0 +1,1105 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import dataclasses +import os +from collections import defaultdict +from contextlib import contextmanager +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type + +from vllm.multimodal import MultiModalPlaceholderMap + +try: + from flashinfer import BatchDecodeWithPagedKVCacheWrapper + from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper + from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper + + # from vllm.vllm_flash_attn import flash_attn_varlen_func + from flash_attn import flash_attn_varlen_func + FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 +except ImportError: + # Avoid turning these types into variables during type checking + if not TYPE_CHECKING: + BatchDecodeWithPagedKVCacheWrapper = None + CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None + BatchPrefillWithPagedKVCacheWrapper = None + FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 + +import torch + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionState, AttentionType) +from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, + compute_slot_mapping_start_idx, + is_block_tables_empty) +from vllm.attention.layer import Attention +from vllm.attention.ops.paged_attn import PagedAttention +from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.logger import init_logger +from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, + make_tensor_with_pad) + +logger = init_logger(__name__) + +if TYPE_CHECKING: + from vllm.worker.model_runner import (ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) + 
+FLASHINFER_KV_CACHE_LAYOUT: str = os.getenv("FLASHINFER_KV_CACHE_LAYOUT", + "NHD").upper() + + +class FlashInferBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "FLASHINFER" + + @staticmethod + def get_impl_cls() -> Type["FlashInferImpl"]: + return FlashInferImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return FlashInferMetadata + + @staticmethod + def get_builder_cls() -> Type["FlashInferMetadataBuilder"]: + return FlashInferMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["FlashInferState"]: + return FlashInferState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (num_blocks, 2, block_size, num_kv_heads, head_size) + + @staticmethod + def get_kv_cache_stride_order() -> Tuple[int, ...]: + cache_layout = FLASHINFER_KV_CACHE_LAYOUT + assert (cache_layout in ("NHD", "HND")) + stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, + 2, 4) + return stride_order + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [64, 128, 256] + + @staticmethod + def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype: + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + return torch.float8_e4m3fn + elif kv_cache_dtype == "fp8_e5m2": + return torch.float8_e5m2 + else: + raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") + + +@dataclass +class PerLayerParameters: + """ + Currently, FlashInfer backend only support models in which all layers share + the same values for the following hyperparameters. + """ + + window_left: int + logits_soft_cap: Optional[float] + sm_scale: float + + +def get_per_layer_parameters( + vllm_config: VllmConfig) -> Dict[str, PerLayerParameters]: + """ + Scan all attention layers and determine some hyperparameters + to use during `plan`. + """ + + layers = get_layers_from_vllm_config(vllm_config, Attention) + per_layer_params: Dict[str, PerLayerParameters] = {} + + for key, layer in layers.items(): + impl = layer.impl + assert isinstance(impl, FlashInferImpl) + + # Infer hyperparameters from the attention layer + window_size = impl.sliding_window + window_left = window_size[0] if window_size is not None else -1 + logits_soft_cap = impl.logits_soft_cap + sm_scale = impl.scale + + per_layer_params[key] = PerLayerParameters(window_left, + logits_soft_cap, sm_scale) + + return per_layer_params + + +def infer_global_hyperparameters( + per_layer_params: Dict[str, PerLayerParameters]) -> PerLayerParameters: + """ + Currently, FlashInfer backend only support models in which all layers share + the same values for the following hyperparameters: + - `window_left` + - `logits_soft_cap` + - `sm_scale` + + So this function asserts that all layers share the same values for these + hyperparameters and returns the global values. + """ + + assert len(per_layer_params) > 0, "No attention layers found in the model." 
+ + param_sets = list(per_layer_params.values()) + global_params = param_sets[0] + for params in param_sets: + assert params == global_params, ( + "FlashInfer backend currently only supports models in which all " + "layers share the same values for the following hyperparameters: " + "`window_left`, `logits_soft_cap`, `sm_scale`.") + + return global_params + + +class FlashInferState(AttentionState): + + def __init__(self, runner): + self.runner = runner + self._is_graph_capturing = False + self._workspace_buffer = None + self._decode_wrapper = None + self._prefill_wrapper = None + + # Global hyperparameters shared by all attention layers + self.global_hyperparameters: Optional[PerLayerParameters] = None + + self.vllm_config = self.runner.vllm_config + self._kv_cache_layout = None + + def _get_workspace_buffer(self): + if self._workspace_buffer is None: + self._workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.runner.device) + return self._workspace_buffer + + def get_kv_cache_layout(self): + if self._kv_cache_layout is None: + self._kv_cache_layout = FLASHINFER_KV_CACHE_LAYOUT + return self._kv_cache_layout + + def _get_prefill_wrapper(self): + if self._prefill_wrapper is None: + self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( + self._get_workspace_buffer(), self.get_kv_cache_layout()) + return self._prefill_wrapper + + def _get_decode_wrapper(self): + if self._decode_wrapper is None: + num_qo_heads = (self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config)) + num_kv_heads = self.runner.model_config.get_num_kv_heads( + self.runner.parallel_config) + use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( + num_qo_heads // num_kv_heads > 4) + self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( + self._get_workspace_buffer(), + self.get_kv_cache_layout(), + use_tensor_cores=use_tensor_cores) + return self._decode_wrapper + + @contextmanager + def graph_capture(self, max_batch_size: int): + self._is_graph_capturing = True + self._graph_decode_wrapper = None + self._graph_slot_mapping = torch.full((max_batch_size, ), + PAD_SLOT_ID, + dtype=torch.long, + device=self.runner.device) + self._graph_seq_lens = torch.ones(max_batch_size, + dtype=torch.int32, + device=self.runner.device) + self._graph_block_tables = torch.from_numpy( + self.runner.graph_block_tables).to(device=self.runner.device) + self._graph_decode_workspace_buffer = self._get_workspace_buffer() + self._graph_indices_buffer = torch.empty( + max_batch_size * self.runner.cache_config.num_gpu_blocks, + dtype=torch.int32, + device=self.runner.device) + self._graph_indptr_buffer = torch.empty(max_batch_size + 1, + dtype=torch.int32, + device=self.runner.device) + self._graph_last_page_len_buffer = torch.empty( + max_batch_size, dtype=torch.int32, device=self.runner.device) + yield + self._is_graph_capturing = False + del self._graph_slot_mapping + del self._graph_seq_lens + del self._graph_block_tables + del self._graph_decode_workspace_buffer + del self._graph_indices_buffer + del self._graph_indptr_buffer + del self._graph_last_page_len_buffer + del self._graph_decode_wrapper + + def graph_clone(self, batch_size: int): + assert self._is_graph_capturing + state = self.__class__(self.runner) + state._workspace_buffer = self._graph_decode_workspace_buffer + state._decode_wrapper = self._graph_decode_wrapper + state._prefill_wrapper = self._get_prefill_wrapper() + return state + + def graph_capture_get_metadata_for_batch( + self, 
batch_size: int, is_encoder_decoder_model: bool = False): + assert self._is_graph_capturing + _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1] + _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size] + + num_qo_heads = (self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config)) + num_kv_heads = self.runner.model_config.get_num_kv_heads( + self.runner.parallel_config) + use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( + num_qo_heads // num_kv_heads > 4) + self._graph_decode_wrapper = \ + CUDAGraphBatchDecodeWithPagedKVCacheWrapper( + self._graph_decode_workspace_buffer, _indptr_buffer, + self._graph_indices_buffer, _last_page_len_buffer, + self.get_kv_cache_layout(), + use_tensor_cores) + if self.runner.kv_cache_dtype.startswith("fp8"): + kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + self.runner.kv_cache_dtype) + else: + kv_cache_dtype = get_kv_cache_torch_dtype( + self.runner.kv_cache_dtype, self.runner.model_config.dtype) + + paged_kv_indptr_tensor_host = torch.arange(0, + batch_size + 1, + dtype=torch.int32) + paged_kv_indices_tensor_host = torch.arange(0, + batch_size, + dtype=torch.int32) + paged_kv_last_page_len_tensor_host = torch.full((batch_size, ), + self.runner.block_size, + dtype=torch.int32) + query_start_loc_host = torch.arange(0, + batch_size + 1, + dtype=torch.int32) + + global_params = infer_global_hyperparameters( + get_per_layer_parameters(self.vllm_config)) + + attn_metadata = self.runner.attn_backend.make_metadata( + num_prefills=0, + slot_mapping=self._graph_slot_mapping[:batch_size], + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + max_prefill_seq_len=0, + block_tables=self._graph_block_tables, + paged_kv_indptr=paged_kv_indptr_tensor_host, + paged_kv_indices=paged_kv_indices_tensor_host, + paged_kv_last_page_len=paged_kv_last_page_len_tensor_host, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim=self.runner.model_config.get_head_size(), + page_size=self.runner.block_size, + seq_start_loc=None, + query_start_loc=query_start_loc_host, + device=self.runner.device, + data_type=kv_cache_dtype, + q_data_type=self.runner.model_config.dtype, + use_cuda_graph=True, + decode_wrapper=self._graph_decode_wrapper, + prefill_wrapper=None, + **dataclasses.asdict(global_params), + ) + attn_metadata.begin_forward() + return attn_metadata + + def get_graph_input_buffers(self, + attn_metadata, + is_encoder_decoder_model: bool = False): + return { + "slot_mapping": attn_metadata.slot_mapping, + } + + def prepare_graph_input_buffers(self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False): + return + + def begin_forward(self, model_input): + assert not self._is_graph_capturing + state = self + use_cuda_graph = model_input.attn_metadata.use_cuda_graph + is_decode = model_input.attn_metadata.num_prefills == 0 + # In case of multistep chunked-prefill, there might be prefill requests + # scheduled while CUDA graph mode is enabled. We don't run graph in that + # case. 
+ if use_cuda_graph and is_decode: + if model_input.inputs_embeds is None: + batch_size = model_input.input_tokens.shape[0] + state = ( + self.runner.graph_runners[model_input.virtual_engine][( + batch_size, False)].attn_state) + else: + batch_size = model_input.inputs_embeds.shape[0] + state = ( + self.runner.graph_runners[model_input.virtual_engine][( + batch_size, True)].attn_state) + + model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper( + ) + model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper() + model_input.attn_metadata.begin_forward() + + +@dataclass +class FlashInferMetadata(AttentionMetadata): + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Number of query tokens for each request in the batch. + # Currently, we require that all requests have the same number of query + # tokens during the decoding phase. When speculavie decoding is enabled, + # decode_query_len might be greater than 1. In all other cases, it is 1. + decode_query_len: Optional[int] = 1 + + use_cuda_graph: bool = True + + prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None + decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None + + # Metadata for the prefill stage + seq_start_loc: Optional[torch.Tensor] = None + query_start_loc: Optional[torch.Tensor] = None + block_tables: Optional[torch.Tensor] = None + + # used for GPU in-place advance_step + seq_lens_tensor: Optional[torch.Tensor] = None + block_table_bound: Optional[torch.Tensor] = None + + # An example for paged_kv_indices, paged_kv_indptr: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + # The indptr of the paged kv cache, shape: [batch_size + 1] + paged_kv_indptr: Optional[torch.Tensor] = None + # The page indices of the paged kv cache + paged_kv_indices: Optional[torch.Tensor] = None + # The number of entries in the last page of each request in + # the paged kv cache, shape: [batch_size] + paged_kv_last_page_len: Optional[torch.Tensor] = None + # The number of query/output heads + num_qo_heads: Optional[int] = None + # The number of key/value heads + num_kv_heads: Optional[int] = None + # The dimension of the attention heads + head_dim: Optional[int] = None + # Block size of vllm + page_size: Optional[int] = None + # The data type of the paged kv cache + data_type: torch.dtype = None + # The data type of the query + q_data_type: torch.dtype = None + # FlashInfer 0.2 encourages passing host tensors + device: torch.device = torch.device("cpu") + is_profile_run: bool = False + + # The FlashInfer backend currently supports only models in which all layers + # share the same following hyperparameters: + + # The left (inclusive) window size for the attention window, when + # set to `-1`, the window size will be set to the full length of + # the sequence. Defaults to `-1`. + window_left: int = -1 + # The attention logits soft capping value (used in Gemini, Grok and + # Gemma-2, etc.), if not provided, will be set to `0`. If greater + # than 0, the logits will be capped according to formula: + # $$\texttt{logits\_soft\_cap} \times + # \mathrm{tanh}(x / \texttt{logits\_soft\_cap})$$, + # where $x$ is the input logits. 
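The capping formula in the comment above can be checked numerically. A short standalone snippet, not part of the backend, using a cap of 30.0 purely for illustration:

import torch

def soft_cap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # cap * tanh(x / cap): bounded in (-cap, cap), near-identity for |x| << cap.
    return cap * torch.tanh(logits / cap)

x = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
print(soft_cap(x, cap=30.0))  # roughly [-29.9, -1.0, 0.0, 1.0, 29.9]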
+ logits_soft_cap: Optional[float] = None + # The scale used in softmax, if not provided, will be set to + # `1.0 / sqrt(head_dim)`. + sm_scale: Optional[float] = None + + def __post_init__(self): + # Refer to + # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 + supported_head_sizes = FlashInferBackend.get_supported_head_sizes() + if self.head_dim is not None and self.head_dim \ + not in supported_head_sizes: + raise ValueError( + f"Only {supported_head_sizes} are supported for head_dim,", + f" received {self.head_dim}.") + + def begin_forward(self): + if self.num_prefill_tokens > 0: + if self.paged_kv_indices is None: + return + + assert self.prefill_wrapper is not None + assert self.query_start_loc is not None + assert self.paged_kv_indices is not None + assert self.paged_kv_indptr is not None + assert self.paged_kv_last_page_len is not None + assert self.block_table_bound is not None + assert self.seq_lens_tensor is not None + self.query_start_loc = self.query_start_loc[:self.num_prefills + 1] + batch_size = self.query_start_loc.shape[0] - 1 + assert batch_size >= 0 + # We will use flash attention for profiling to + # determine the number of blocks. Therefore, + # we don't need to prepare the input for flashinfer for profile run. + if not self.is_profile_run: + self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) + self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( + self.device) + self.block_table_bound = self.block_table_bound.to(self.device) + self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) + self.paged_kv_indices = self.paged_kv_indices.to(self.device) + self.prefill_wrapper.plan( + self.query_start_loc, + self.paged_kv_indptr[:self.num_prefills + 1], + self.paged_kv_indices, + self.paged_kv_last_page_len[:self.num_prefills], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + causal=True, + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, + kv_data_type=self.data_type) + if self.num_decode_tokens > 0: + assert self.paged_kv_indices is not None + assert self.paged_kv_indptr is not None + assert self.paged_kv_last_page_len is not None + self.paged_kv_indices = self.paged_kv_indices.to(self.device) + self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) + self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( + self.device) + # handle model warmup path + if self.block_table_bound is not None: + self.block_table_bound = self.block_table_bound.to(self.device) + if self.seq_lens_tensor is not None: + self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) + + assert self.decode_wrapper is not None + self.decode_wrapper.plan( + self.paged_kv_indptr[self.num_prefills:], + self.paged_kv_indices, + self.paged_kv_last_page_len[self.num_prefills:], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + # Disable flashinfer's pos encoding and use vllm's rope. + pos_encoding_mode="NONE", + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + sm_scale=self.sm_scale, + # kv-cache data type. + kv_data_type=self.data_type, + # query data type. + q_data_type=self.q_data_type) + + def asdict_zerocopy(self, + skip_fields: Optional[Set[str]] = None + ) -> Dict[str, Any]: + if skip_fields is None: + skip_fields = set() + # We need to skip the prefill/decode_wrapper field since it cannot be + # broadcasted with nccl when TP is enabled. 
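One detail of the `plan()` calls in `begin_forward` above that is easy to miss: because requests are ordered prefill-first, the single batch-wide `paged_kv_indptr` is simply sliced into a prefill view and a decode view. A toy illustration with made-up page counts:

# 2 prefill requests using 3 and 2 pages, then 2 decode requests using 1 page each.
num_prefills = 2
paged_kv_indptr = [0, 3, 5, 6, 7]

prefill_indptr = paged_kv_indptr[:num_prefills + 1]  # [0, 3, 5] -> prefill wrapper
decode_indptr = paged_kv_indptr[num_prefills:]       # [5, 6, 7] -> decode wrapper
assert prefill_indptr[-1] == decode_indptr[0]        # the split point is shared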
+ skip_fields.add('prefill_wrapper') + skip_fields.add('decode_wrapper') + return super().asdict_zerocopy(skip_fields) + + @property + def prefill_metadata(self) -> Optional["FlashInferMetadata"]: + if self.num_prefills == 0: + return None + return self + + @property + def decode_metadata(self) -> Optional["FlashInferMetadata"]: + if self.num_decode_tokens == 0: + return None + return self + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + + if turn_prefills_into_decodes: + # When Multi-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes. This update reflects that + # conversion. + assert self.num_decode_tokens + self.num_prefills == num_seqs + # Flashinfer doesn't support speculative decoding + chunked-prefill + # + multi-step scheduling yet. + assert self.decode_query_len == 1 + self.num_decode_tokens += self.num_prefills + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.max_prefill_seq_len = 0 + self.max_query_len = 1 + + self.slot_mapping = self.slot_mapping[:num_seqs] + else: + assert self.seq_lens_tensor is not None + + assert num_seqs > 0 + assert num_queries > 0 + assert model_input.attn_metadata is not None + assert sampled_token_ids is not None + + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + assert self.use_cuda_graph + + model_input.input_tokens[:num_queries] = sampled_token_ids.flatten() + + # Update GPU tensors + ops.advance_step_flashinfer( + num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=model_input.input_tokens, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables, + paged_kv_indices=self.paged_kv_indices, + paged_kv_indptr=self.paged_kv_indptr, + paged_kv_last_page_len=self.paged_kv_last_page_len, + block_table_bound=self.block_table_bound) + + +class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + + self.input_builder = input_builder + self.runner = input_builder.runner + + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + + # Global hyperparameters shared by all attention layers + self.global_hyperparameters: Optional[PerLayerParameters] = None + + self.vllm_config = self.runner.vllm_config + + def prepare(self): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + + # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout + # for the precise definition of the following fields. 
+ # An example: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + self.paged_kv_indices: List[int] = [] + # 0 at the beginning of paged_kv_indptr indicates the start of the + # first request’s page indices in the paged_kv_indices list. + self.paged_kv_indptr: List[int] = [0] + # paged_kv_last_page_len is the length of the last page of each request + self.paged_kv_last_page_len: List[int] = [] + self.total_blocks = 0 + self.is_profile_run: bool = False + + if self.global_hyperparameters is None: + # Infer global hyperparameters, since currently we only support + # models in which all layers share the same values for the + # following hyperparameters: + # - `window_left` + # - `logits_soft_cap` + # - `sm_scale` + inferred_params = infer_global_hyperparameters( + get_per_layer_parameters(self.vllm_config)) + self.global_hyperparameters = inferred_params + self.window_left = inferred_params.window_left + self.logits_soft_cap = inferred_params.logits_soft_cap + self.sm_scale = inferred_params.sm_scale + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. + """ + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + computed_block_nums = inter_data.computed_block_nums + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for modality, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[modality].extend( + placeholders) + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if inter_data.prefix_cache_hit: + block_table = computed_block_nums + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + block_table = block_tables[seq_id][-curr_sliding_window_block:] + self.block_tables.append(block_table) + + is_profile_run = is_block_tables_empty(block_tables) + + # Compute slot mapping. + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + + # It is not necessary to add paged_kv_indices, paged_kv_indptr, + # and paged_kv_last_page_len for profile run because we will + # create dummy inputs. 
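The example above maps directly onto the bookkeeping done in `_update_paged_kv_tensors` below. A standalone sketch, using toy block tables and sequence lengths with block_size = 16, that reproduces the documented layout:

block_size = 16
block_tables = {1: [0, 5, 8], 2: [1, 6, 7], 3: [3, 4]}   # request id -> page ids
seq_lens = {1: 40, 2: 48, 3: 17}

paged_kv_indices = []
paged_kv_indptr = [0]
paged_kv_last_page_len = []

for req_id, block_table in block_tables.items():
    seq_len = seq_lens[req_id]
    num_pages = -(-seq_len // block_size)                 # ceil division
    paged_kv_indices.extend(block_table[:num_pages])
    paged_kv_indptr.append(paged_kv_indptr[-1] + num_pages)
    paged_kv_last_page_len.append(seq_len % block_size or block_size)

assert paged_kv_indices == [0, 5, 8, 1, 6, 7, 3, 4]
assert paged_kv_indptr == [0, 3, 6, 8]
assert paged_kv_last_page_len == [8, 16, 1]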
+ if is_profile_run: + self.is_profile_run = is_profile_run + return + + block_table = block_tables[seq_id] + self._update_paged_kv_tensors(block_table, seq_len) + + def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int): + # Get the number of valid blocks based on sequence length. + # If seq_len = 16, block_size = 16, + # block_table_bound is 1 with 1 valid block. + # If seq_len = 15, block_size = 16, + # block_table_bound is 0 + 1 with 1 valid block. + self.total_blocks += len(block_table) + block_table_bound = seq_len // self.block_size + 1 \ + if seq_len % self.block_size != 0 \ + else seq_len // self.block_size + self.paged_kv_indices.extend(block_table[:block_table_bound]) + self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + + block_table_bound) + + last_page_len = seq_len % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + self.paged_kv_last_page_len.append(last_page_len) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. + """ + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + decode_query_len = max(query_lens[self.num_prefills:], default=1) + + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size - self.num_prefill_tokens + + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + input_block_tables = self.runner.graph_block_tables[:batch_size] + max_blocks = input_block_tables.shape[1] + for i, block_table in enumerate(self.block_tables): + if block_table: + num_blocks = len(block_table) + if num_blocks <= max_blocks: + input_block_tables[i, :num_blocks] = block_table + else: + # It may be possible to have more blocks allocated due + # to lookahead slots of multi-step, however, they are + # not used anyway, so can be safely ignored. 
+ input_block_tables[ + i, :max_blocks] = block_table[:max_blocks] + + block_tables = torch.from_numpy(input_block_tables).to( + device, non_blocking=True) + + last_paged_kv_indptr = self.paged_kv_indptr[-1] + self.paged_kv_indptr.extend([last_paged_kv_indptr] * + cuda_graph_pad_size) + self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size) + else: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + + assert device is not None + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device, + self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + self.multimodal_placeholder_maps.items() + } + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + if len(self.paged_kv_indptr) > 0: + # extend to the maximum number of blocks as returned by the + # scheduler + self.paged_kv_indices.extend( + [0] * (self.total_blocks - len(self.paged_kv_indices))) + paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, + device="cpu", + dtype=torch.int) + paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr, + device="cpu", + dtype=torch.int) + paged_kv_last_page_len_tensor = torch.tensor( + self.paged_kv_last_page_len, device="cpu", dtype=torch.int) + block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) - + 1, + device="cpu", + dtype=torch.int) + else: + paged_kv_indices_tensor = None + paged_kv_indptr_tensor = None + paged_kv_last_page_len_tensor = None + block_table_bound_tensor = None + + if self.runner.kv_cache_dtype.startswith("fp8"): + kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + self.runner.kv_cache_dtype) + else: + kv_cache_dtype = get_kv_cache_torch_dtype( + self.runner.kv_cache_dtype, self.runner.model_config.dtype) + + return FlashInferMetadata( + decode_query_len=decode_query_len, + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + max_prefill_seq_len=max_prefill_seq_len, + block_tables=block_tables, + paged_kv_indptr=paged_kv_indptr_tensor, + paged_kv_indices=paged_kv_indices_tensor, + paged_kv_last_page_len=paged_kv_last_page_len_tensor, + block_table_bound=block_table_bound_tensor, + seq_lens_tensor=seq_lens_tensor, + num_qo_heads=self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config), + num_kv_heads=self.runner.model_config.get_num_kv_heads( + self.runner.parallel_config), + head_dim=self.runner.model_config.get_head_size(), + page_size=self.block_size, + seq_start_loc=seq_start_loc, + query_start_loc=query_start_loc, + device=device, + data_type=kv_cache_dtype, + q_data_type=self.runner.model_config.dtype, + use_cuda_graph=use_captured_graph, + is_profile_run=self.is_profile_run, + window_left=self.window_left, + 
logits_soft_cap=self.logits_soft_cap, + sm_scale=self.sm_scale, + ) + + +class FlashInferImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + use_irope: bool = False, + ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") + if use_irope: + logger.warning_once( + "Using irope in FlashInfer is not supported yet, it will fall" + " back to global attention for long context.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window - 1, + 0) if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype + self.logits_soft_cap = logits_soft_cap + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashInferImpl") + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashInferMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + # TODO: directly write to output tensor + num_heads: int = self.num_heads + head_size: int = self.head_size + num_kv_heads: int = self.num_kv_heads + kv_cache_dtype: str = self.kv_cache_dtype + softmax_scale: float = self.scale + window_size = self.sliding_window + alibi_slopes = self.alibi_slopes + logits_soft_cap = self.logits_soft_cap + + num_tokens, hidden_size = query.shape + query = query.view(-1, num_heads, head_size) + key = key.view(-1, num_kv_heads, head_size) + value = value.view(-1, num_kv_heads, head_size) + + if kv_cache.numel() > 0: + # Use the same reshape and cache kernel as flash attention. + ops.reshape_and_cache_flash( + key, + value, + kv_cache[:, 0], + kv_cache[:, 1], + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 + # to process the cache when the kv_cache_dtype is fp8 + if kv_cache_dtype.startswith("fp8"): + torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + kv_cache_dtype) + kv_cache = kv_cache.view(torch_dtype) + + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ + f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa + assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ + f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa + query = query.contiguous( + ) # Flashinfer requires query to be contiguous + # Query for decode. KV is not needed because it is already cached. + # QKV for prefill. 
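The comment above relies on vLLM V0's batch layout, where all prefill tokens come before all decode tokens, so the slicing that follows is a single split. A minimal sketch with made-up sizes:

import torch

num_prefill_tokens, num_decode_tokens = 5, 3
num_heads, head_size = 4, 8
query = torch.randn(num_prefill_tokens + num_decode_tokens, num_heads, head_size)

decode_query = query[num_prefill_tokens:]    # handled by the decode wrapper
prefill_query = query[:num_prefill_tokens]   # handled by the prefill wrapper
assert prefill_query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens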
+ decode_query = query[num_prefill_tokens:] + query = query[:num_prefill_tokens] + + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + window_left = window_size[0] if window_size is not None else -1 + + prefill_output: Optional[torch.Tensor] = None + decode_output: Optional[torch.Tensor] = None + stride_order = FlashInferBackend.get_kv_cache_stride_order() + if prefill_meta := attn_metadata.prefill_metadata: + # We will use flash attention for prefill + # when kv_cache is not provided. + # This happens when vllm runs the profiling to + # determine the number of blocks. + if kv_cache.numel() == 0: + prefill_output = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + ) + else: + assert prefill_meta is not None + assert prefill_meta.prefill_wrapper is not None + + assert prefill_meta.prefill_wrapper._causal + assert prefill_meta.prefill_wrapper._window_left == window_left + assert prefill_meta.prefill_wrapper._logits_soft_cap == ( + logits_soft_cap or 0.0) + assert prefill_meta.prefill_wrapper._sm_scale == softmax_scale + + prefill_output = prefill_meta.prefill_wrapper.run( + query, + kv_cache.permute(*stride_order), + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + ) + if decode_meta := attn_metadata.decode_metadata: + assert decode_meta is not None + assert decode_meta.decode_wrapper is not None + + assert decode_meta.decode_wrapper._window_left == window_left + assert decode_meta.decode_wrapper._logits_soft_cap == ( + logits_soft_cap or 0.0) + assert decode_meta.decode_wrapper._sm_scale == softmax_scale + + decode_output = decode_meta.decode_wrapper.run( + decode_query, + kv_cache.permute(*stride_order), + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + ) + + if prefill_output is None and decode_output is not None: + # Decode only batch. + output, num_tokens = decode_output, num_decode_tokens + elif decode_output is None and prefill_output is not None: + # Prefill only batch. + output, num_tokens = prefill_output, num_prefill_tokens + else: + # Chunked prefill batch does not work with speculative decoding in + # FlashInfer backend, so the query length for decode should be 1. 
+ assert prefill_output is not None + assert decode_output is not None + assert decode_meta is not None + assert decode_meta.decode_query_len == 1 + decode_output = decode_output.squeeze(1) + output = torch.cat([prefill_output, decode_output], dim=0) + return output.view(num_tokens, hidden_size) diff --git a/attention/backends/flashmla.py b/attention/backends/flashmla.py new file mode 100644 index 0000000..e185d02 --- /dev/null +++ b/attention/backends/flashmla.py @@ -0,0 +1,244 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from contextlib import contextmanager +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type + +import torch + +from vllm.attention.backends.abstract import (AttentionType, + is_quantized_kv_cache) +from vllm.attention.backends.mla.common import (MLACommonBackend, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder, + MLACommonState) +from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, + get_mla_metadata, + is_flashmla_supported) + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + + +class FlashMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "FLASHMLA" + + @staticmethod + def get_impl_cls() -> Type["FlashMLAImpl"]: + return FlashMLAImpl + + @staticmethod + def get_metadata_cls() -> Type["FlashMLAMetadata"]: + return FlashMLAMetadata + + @staticmethod + def get_builder_cls() -> Type["FlashMLAMetadataBuilder"]: + return FlashMLAMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["FlashMLAState"]: + return FlashMLAState + + +@dataclass +class FlashMLAMetadata(MLACommonMetadata): + decode_tile_scheduler_metadata: Optional[Tuple[torch.Tensor, + torch.Tensor]] = None + decode_num_splits: Optional[torch.Tensor] = None + + @property + def decode_metadata(self): + decode_metadata = super().decode_metadata + # TODO: cache assignment? 
+ if decode_metadata is not None: + decode_metadata.decode_tile_scheduler_metadata=\ + self.decode_tile_scheduler_metadata + decode_metadata.decode_num_splits=\ + self.decode_num_splits + return decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + raise NotImplementedError( + "advance_step is not implemented for FlashMLA") + + +class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.num_q_heads = self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + m = super().build(seq_lens, query_lens, cuda_graph_pad_size, + batch_size) + + if m.num_decode_tokens > 0: + m.decode_tile_scheduler_metadata, m.decode_num_splits = \ + get_mla_metadata( + m.seq_lens_tensor[m.num_prefills:], + self.num_q_heads, + 1, # MQA for the decode path + ) + + return m + + +class FlashMLAState(MLACommonState[FlashMLAMetadata]): + + def __init__(self, *args, **kwds): + super().__init__(*args, **kwds) + + self.num_q_heads = self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config) + + @contextmanager + def graph_capture(self, max_batch_size: int): + # Run a dummy `get_mla_metadata` so we can get the right shapes + self._graph_decoder_tile_scheduler_metadata, \ + self._graph_decode_num_splits = get_mla_metadata( + torch.ones( + max_batch_size, dtype=torch.int32, device=self.runner.device), + self.num_q_heads, + 1, # MQA for the decode path + ) + + with super().graph_capture(max_batch_size): + yield + + del self._graph_decoder_tile_scheduler_metadata + del self._graph_decode_num_splits + + def graph_capture_get_metadata_for_batch( + self, batch_size: int, is_encoder_decoder_model: bool = False): + metadata = super().graph_capture_get_metadata_for_batch( + batch_size, is_encoder_decoder_model) + assert metadata.num_decode_tokens > 0 + + decoder_tile_scheduler_metadata, decode_num_splits = get_mla_metadata( + self._graph_seq_lens[:batch_size], + self.num_q_heads, + 1, # MQA for the decode path + ) + + self._graph_decoder_tile_scheduler_metadata.copy_( + decoder_tile_scheduler_metadata) + self._graph_decode_num_splits[:batch_size + 1].copy_(decode_num_splits) + + metadata.decode_tile_scheduler_metadata=\ + self._graph_decoder_tile_scheduler_metadata + metadata.decode_num_splits=\ + self._graph_decode_num_splits[:batch_size + 1] + + return metadata + + def get_graph_input_buffers(self, + attn_metadata, + is_encoder_decoder_model: bool = False): + input_buffers = super().get_graph_input_buffers( + attn_metadata, is_encoder_decoder_model) + input_buffers["decode_tile_scheduler_metadata"] = \ + attn_metadata.decode_metadata.decode_tile_scheduler_metadata + input_buffers["decode_num_splits"] = \ + attn_metadata.decode_metadata.decode_num_splits + + return input_buffers + + def prepare_graph_input_buffers(self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False): + super().prepare_graph_input_buffers(input_buffers, attn_metadata, + is_encoder_decoder_model) + + input_buffers["decode_tile_scheduler_metadata"].copy_( + attn_metadata.decode_metadata.decode_tile_scheduler_metadata) + input_buffers["decode_num_splits"].copy_( + 
attn_metadata.decode_metadata.decode_num_splits) + + +class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str] = None, + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + kv_sharing_target_layer_name, **mla_args) + + assert is_flashmla_supported(), \ + "FlashMLA is not supported on this device" + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "FlashMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashMLAImpl") + + if is_quantized_kv_cache(self.kv_cache_dtype): + raise NotImplementedError( + "FlashMLA with FP8 KV cache not yet supported") + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: FlashMLAMetadata, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + + decode_meta = attn_metadata.decode_metadata + assert decode_meta is not None + + q = torch.cat([q_nope, q_pe], dim=-1)\ + .unsqueeze(1) # Add seqlen dim of 1 (decode) + + o, _ = flash_mla_with_kvcache( + q=q, + k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 + block_table=decode_meta.block_tables, + cache_seqlens=decode_meta.seq_lens_tensor, + head_dim_v=self.kv_lora_rank, + tile_scheduler_metadata=decode_meta.decode_tile_scheduler_metadata, + num_splits=decode_meta.decode_num_splits, + softmax_scale=self.scale, + causal=True, + ) + + return self._v_up_proj(o) diff --git a/attention/backends/hpu_attn.py b/attention/backends/hpu_attn.py new file mode 100644 index 0000000..9bd513f --- /dev/null +++ b/attention/backends/hpu_attn.py @@ -0,0 +1,313 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +############################################################################### +# Copyright (C) 2024 Habana Labs, Ltd. 
an Intel Company +############################################################################### + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +import vllm_hpu_extension.kernels as kernels +import vllm_hpu_extension.ops as ops +from vllm_hpu_extension.flags import enabled_flags +from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, AttentionType, + is_quantized_kv_cache) +from vllm.attention.backends.utils import CommonAttentionState +from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention, + HPUPagedAttentionMetadata) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class HPUAttentionBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "HPU_ATTN" + + @staticmethod + def get_impl_cls() -> Type["HPUAttentionImpl"]: + return HPUAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return HPUAttentionMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return HPUPagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dsts: torch.Tensor, + ) -> None: + HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dsts) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dsts: torch.Tensor, + ) -> None: + HPUPagedAttention.copy_blocks(kv_caches, src_to_dsts) + + +@dataclass +class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata): + """Metadata for HPUAttentionbackend.""" + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. + is_prompt: bool + attn_bias: Optional[torch.Tensor] + seq_lens_tensor: Optional[torch.Tensor] + context_lens_tensor: Optional[torch.Tensor] + + +class HPUAttentionImpl(AttentionImpl, torch.nn.Module): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| + + Otherwise, the layout is as follows: + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. 
+ """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + max_seq_len: int = 4096, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + use_irope: bool = False, + ) -> None: + super(AttentionImpl, self).__init__() + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") + if use_irope: + logger.warning_once( + "Using irope in HPU is not supported yet, it will fall back " + "to global attention for long context.") + self.kv_cache_dtype = kv_cache_dtype + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.matmul_qk = Matmul() + self.softmax = Softmax() + self.matmul_av = Matmul() + self.batch2block_matmul = Matmul() + self.block2batch_matmul = Matmul() + self.k_cache = VLLMKVCache() + self.v_cache = VLLMKVCache() + self.fused_scaled_dot_product_attention = kernels.fsdpa() + + self.prefill_impl = 'naive' + if "flex_attention" in enabled_flags(): + self.prefill_impl = 'flex' + if "fsdpa" in enabled_flags(): + assert alibi_slopes is None, \ + 'Prefill with FusedSDPA not supported with alibi slopes!' + self.prefill_impl = 'fsdpa' + + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.sliding_window = sliding_window + self.alibi_slopes = alibi_slopes + if alibi_slopes is not None: + alibi_slopes_tensor = torch.tensor(alibi_slopes, + dtype=torch.bfloat16) + self.alibi_slopes = alibi_slopes_tensor + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + if self.prefill_impl == 'fsdpa': + assert alibi_slopes is None, \ + 'Prefill with FusedSDPA not supported with alibi slopes!' + + supported_head_sizes = HPUPagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {supported_head_sizes}.") + + self.attn_type = attn_type + if self.attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "HPUAttentionImpl") + + if is_quantized_kv_cache(self.kv_cache_dtype): + raise NotImplementedError( + "HPUAttention with FP8 KV cache not yet supported") + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: HPUAttentionMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with xFormers and PagedAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + attn_metadata: Metadata for attention. 
+ Returns: + shape = [num_tokens, num_heads * head_size] + """ + batch_size, seq_len, hidden_size = query.shape + _, seq_len_kv, _ = key.shape + + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + block_indices = attn_metadata.block_indices + block_offsets = attn_metadata.block_offsets + key_cache = None + value_cache = None + if attn_metadata.is_prompt and self.attn_type \ + is not AttentionType.ENCODER_ONLY: + key = key.unflatten(0, (block_indices.size(0), -1)) + value = value.unflatten(0, (block_indices.size(0), -1)) + if kv_cache is not None and isinstance(kv_cache, tuple): + key_cache, value_cache = HPUPagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + key_cache = self.k_cache(key, key_cache, block_indices, + block_offsets) + value_cache = self.v_cache(value, value_cache, block_indices, + block_offsets) + + if attn_metadata.is_prompt: + # Prompt run. + query_shape = (batch_size, seq_len, self.num_heads, self.head_size) + kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, + self.head_size) + + attn_bias = attn_metadata.attn_bias + if attn_bias is not None and self.alibi_slopes is not None: + position_bias = _make_alibi_bias(self.alibi_slopes, + self.num_kv_heads, + attn_bias.dtype, + attn_bias.shape[-1]) + attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1)) + attn_bias.add_(position_bias) + + block_list = attn_metadata.block_list if attn_metadata \ + and attn_metadata.block_list is not None else None + + out = ops.prompt_attention( + impl=self.prefill_impl, + query=query.view(query_shape), + key=key.view(kv_shape), + value=value.view(kv_shape), + is_causal=True, + attn_bias=attn_bias, + valid_seq_lengths=attn_metadata.seq_lens_tensor, + **self.common_attention_args(block_list, key_cache, + value_cache)) + output = out.reshape(batch_size, seq_len, hidden_size) + else: + # Decoding run. + output = HPUPagedAttention.forward_decode( + query=query, + block_mapping=attn_metadata.block_mapping, + block_bias=attn_metadata.attn_bias, + block_groups=attn_metadata.block_groups, + **self.common_attention_args(attn_metadata.block_list, + key_cache, value_cache)) + # Reshape the output tensor. + return output.view(batch_size, seq_len, hidden_size) + + def common_attention_args(self, + block_list=None, + key_cache=None, + value_cache=None): + fsdpa_op = self.fused_scaled_dot_product_attention.apply \ + if self.fused_scaled_dot_product_attention is not None else None + return { + 'scale': self.scale, + 'matmul_qk_op': self.matmul_qk, + 'matmul_av_op': self.matmul_av, + 'batch2block_matmul_op': self.batch2block_matmul, + 'block2batch_matmul_op': self.block2batch_matmul, + 'fsdpa_op': fsdpa_op, + 'keys_fetch_func': self.k_cache.fetch_from_cache, + 'values_fetch_func': self.v_cache.fetch_from_cache, + 'softmax_op': self.softmax, + 'block_list': block_list, + 'key_cache': key_cache, + 'value_cache': value_cache, + } + + +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + num_kv_heads: int, + dtype: torch.dtype, + seq_len: int, +) -> torch.Tensor: + bias = torch.arange(seq_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. 
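A standalone illustration of the relative-position matrix built in the lines that follow, using a toy sequence length and made-up slope values rather than anything from a real model:

import torch

seq_len = 4
pos = torch.arange(seq_len, dtype=torch.float32)
bias = pos[None, :] - pos[:, None]     # bias[i, j] = j - i
# tensor([[ 0.,  1.,  2.,  3.],
#         [-1.,  0.,  1.,  2.],
#         [-2., -1.,  0.,  1.],
#         [-3., -2., -1.,  0.]])

slopes = torch.tensor([0.5, 0.25])     # one ALiBi slope per head (hypothetical)
per_head_bias = bias[None, :, :] * slopes[:, None, None]
assert per_head_bias.shape == (2, seq_len, seq_len)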
+ # Calculate a matrix where each element represents ith element- jth + # element. + bias = bias[None, :] - bias[:, None] + + padded_len = (seq_len + 7) // 8 * 8 + num_heads = alibi_slopes.shape[0] + bias = torch.empty( + 1, # batch size + num_heads, + seq_len, + padded_len, + device=alibi_slopes.device, + dtype=dtype, + )[:, :, :, :seq_len].copy_(bias) + bias.mul_(alibi_slopes[:, None, None]) + if num_heads != num_kv_heads: + bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) + return bias diff --git a/attention/backends/ipex_attn.py b/attention/backends/ipex_attn.py new file mode 100644 index 0000000..5051c6a --- /dev/null +++ b/attention/backends/ipex_attn.py @@ -0,0 +1,398 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" Attention layer with torch scaled_dot_product_attention + and PagedAttention.""" +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch + +from vllm._ipex_ops import ipex_ops +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, AttentionType, + is_quantized_kv_cache) +from vllm.attention.backends.utils import CommonAttentionState +from vllm.attention.ops.paged_attn import (PagedAttention, + PagedAttentionMetadata) +from vllm.logger import init_logger + +logger = init_logger(__name__) + +_PARTITION_SIZE = 512 + + +class IpexAttnBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "IPEX" + + @staticmethod + def get_impl_cls() -> Type["IpexAttnBackendImpl"]: + return IpexAttnBackendImpl + + @staticmethod + def get_metadata_cls() -> Type["IpexAttnMetadata"]: + return IpexAttnMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + from vllm._ipex_ops import ipex_ops as ops + ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + from vllm._ipex_ops import ipex_ops as ops + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) + + +@dataclass +class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata): + """Metadata for IpexAttnBackend. + """ + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. + is_prompt: bool + slot_mapping: torch.Tensor + seq_lens: Optional[List[int]] + seqlen_q: Optional[torch.Tensor] + max_seqlen: Optional[int] + + def __post_init__(self): + # Set during the execution of the first attention op. + # It is a list because it is needed to set per prompt + # when alibi slopes is used. It is because of the limitation + # from xformer API. 
+ # will not appear in the __repr__ and __init__ + self.attn_bias: Optional[List[torch.Tensor]] = None + + @property + def prefill_metadata(self) -> Optional["IpexAttnMetadata"]: + # Currently chunked prefill is not supported + if self.num_decode_tokens == 0: + assert self.num_prefills > 0 + return self + + return None + + @property + def decode_metadata(self) -> Optional["IpexAttnMetadata"]: + # Currently chunked prefill is not supported + if self.num_prefills > 0: + assert self.num_decode_tokens == 0 + return None + + return self + + +class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + use_irope: bool = False, + ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") + if use_irope: + logger.warning_once( + "Using irope in Ipex is not supported yet, it will fall" + " back to global attention for long context.") + if blocksparse_params is not None: + raise ValueError( + "IPEX backend does not support block-sparse attention.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.need_mask = (self.sliding_window is not None) + if logits_soft_cap is None: + logits_soft_cap = -1 + self.logits_soft_cap = logits_soft_cap + + supported_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {supported_head_sizes}.") + if is_quantized_kv_cache(kv_cache_dtype): + raise NotImplementedError( + "IPEX backend does not support FP8 KV cache. " + "Please use xFormers backend instead.") + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "IpexAttnBackendImpl") + + def split_kv_cache( + self, + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + x = 1 + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, + -1, x) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) + return key_cache, value_cache + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: IpexAttnMetadata, # type: ignore + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with IPEX varlen_attention and PagedAttention. 
+ + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if kv_cache.numel() > 0: + key_cache, value_cache = self.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + ipex_ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping.flatten(), + self.kv_cache_dtype, + layer._k_scale_float, + layer._v_scale_float, + ) + + if attn_metadata.is_prompt: + assert attn_metadata.seq_lens is not None + if (kv_cache.numel() == 0 + or attn_metadata.block_tables.numel() == 0): + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=1) + value = value.repeat_interleave(self.num_queries_per_kv, + dim=1) + + if attn_metadata.attn_bias is None: + if self.sliding_window is not None: + att_masks = _make_sliding_window_bias( + attn_metadata.seq_lens, self.sliding_window, + query.dtype) # type: ignore + else: + att_masks = _make_sliding_window_bias( + attn_metadata.seq_lens, None, dtype=query.dtype) + attn_metadata.attn_bias = att_masks + + output = torch.empty( + (num_tokens, self.num_heads, self.head_size), + dtype=query.dtype, + device=query.device) + ipex_ops.varlen_attention( + query, + key, + value, + output, + attn_metadata.seqlen_q, + attn_metadata.seqlen_q, + self.alibi_slopes, + attn_metadata.max_seqlen, + attn_metadata.max_seqlen, + pdropout=0.0, + softmax_scale=self.scale, + zero_tensors=False, + is_causal=True, + return_softmax=False, + gen_=None, + window_size_left=-1, + window_size_right=-1, + logits_soft_cap=self.logits_soft_cap, + ) + else: + # prefix-enabled attention + raise RuntimeError( + "IPEX backend doesn't support prefix decoding.") + + else: + # Decoding run. + max_seq_len = attn_metadata.max_decode_seq_len + output = torch.empty_like(query) + block_size = value_cache.shape[3] + num_seqs, num_heads, head_size = query.shape + max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // + _PARTITION_SIZE) + # NOTE(woosuk): We use a simple heuristic to decide whether to use + # PagedAttention V1 or V2. If the number of partitions is 1, we use + # V1 to avoid the overhead of reduction. Also, if the number of + # sequences or heads is large, we use V1 since there is enough work + # to parallelize. + # TODO(woosuk): Tune this heuristic. + # For context len > 8192, use V2 kernel to avoid shared memory + # shortage. + use_v1 = (max_seq_len <= 8192 and + (max_num_partitions == 1 or num_seqs * num_heads > 512)) + if use_v1: + # Run PagedAttention V1. + ipex_ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + self.num_kv_heads, + self.scale, + attn_metadata.block_tables, + attn_metadata.seq_lens_tensor, + block_size, + max_seq_len, + self.alibi_slopes, + self.kv_cache_dtype, + layer._k_scale_float, + layer._v_scale_float, + ) + else: + # Run PagedAttention V2. 
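The V1/V2 selection above can be read as a standalone predicate. A sketch that mirrors the condition used in this file; it is illustrative only, not an exported vLLM helper:

_PARTITION_SIZE = 512

def use_paged_attention_v1(max_seq_len, num_seqs, num_heads):
    max_num_partitions = (max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    # V1 when the sequence fits one partition or there is already enough
    # parallel work; V2 (with its extra reduction step) otherwise.
    return max_seq_len <= 8192 and (max_num_partitions == 1
                                    or num_seqs * num_heads > 512)

assert use_paged_attention_v1(max_seq_len=256, num_seqs=1, num_heads=32)
assert not use_paged_attention_v1(max_seq_len=16384, num_seqs=64, num_heads=32)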
+ assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + ipex_ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + self.num_kv_heads, + self.scale, + attn_metadata.block_tables, + attn_metadata.seq_lens_tensor, + block_size, + max_seq_len, + self.alibi_slopes, + self.kv_cache_dtype, + layer._k_scale_float, + layer._v_scale_float, + ) + + # Reshape the output tensor. + return output.view(-1, self.num_heads * self.head_size) + + +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + dtype: torch.dtype, + seq_lens: List[int], +) -> List[torch.Tensor]: + attn_biases = [] + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + bias = bias[None, :] - bias[:, None] + + num_heads = alibi_slopes.shape[0] + bias = bias[None, :].repeat((num_heads, 1, 1)) + bias.mul_(alibi_slopes[:, None, None]) + inf_mask = torch.empty( + (1, seq_len, seq_len), + dtype=bias.dtype, + device=alibi_slopes.device).fill_(-torch.inf).triu_(diagonal=1) + attn_biases.append((bias + inf_mask).to(dtype)) + + return attn_biases + + +def _make_sliding_window_bias( + seq_lens: List[int], + window_size: Optional[int], + dtype: torch.dtype, +) -> List[torch.Tensor]: + attn_biases = [] + for seq_len in seq_lens: + tensor = torch.full( + (1, seq_len, seq_len), + dtype=dtype, + fill_value=1, + ) + shift = 0 + mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore + if window_size is not None: + mask = torch.triu(mask, diagonal=shift - window_size + 1) + mask = torch.log(mask) + attn_biases.append(mask.to(dtype)) + + return attn_biases diff --git a/attention/backends/mla/__init__.py b/attention/backends/mla/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/attention/backends/mla/common.py b/attention/backends/mla/common.py new file mode 100644 index 0000000..103fdce --- /dev/null +++ b/attention/backends/mla/common.py @@ -0,0 +1,1387 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +# MLA Common Components + +This file implements common components for MLA implementations. + +First we define: + +Sq as Q sequence length +Skv as KV sequence length + +MLA has two possible ways of computing, a data-movement friendly approach and a +compute friendly approach, we generally want to use the compute friendly +approach for "prefill" (i.e. the ratio Sq / Skv is "small", is near 1) +and the data-movement friendly approach for "decode" (i.e. the ratio +Sq / Skv is "large"). + +NOTE what we deem small and large is currently determined by if its labelled +prefill or decode by the scheduler, but this is something we should probably +tune. + +Main reference: DeepseekV2 paper, and FlashInfer Implementation +(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). + +Deepseek's MLA attention works the following way: +* Use a single latent vector to represent the per-token entry of the KV cache. +* For decode (i.e. 
the memory friendly approach) the attention "simulates" a
+multi-head attention, while the compute is similar to multi-query attention.
+
+Below is an example of both paths, assuming batch size = 1.
+
+## Further Definitions:
+
+C           Context length, `Skv - Sq`
+H           hidden size
+N           number of attention heads
+Lq          latent dimension for Q              1536 in DSV3
+Lkv         latent dimension for K/V            512 in DSV3
+P           nope dimension, no rope.            128 in DSV3
+R           rope dimension, goes through rope.  64 in DSV3
+V           V head dim.                         128 in DSV3
+
+## Vector/Matrix Definitions
+
+h_t         hidden states (input to attention)  shape [Sq, H]
+q_c         latent/compressed Q                 shape [Sq, Lq]
+q_nope      uncompressed Q (no-rope)            shape [Sq, N, P]
+q_pe        uncompressed Q (rope)               shape [Sq, N, R]
+kv_c        latent/compressed KV                shape [Skv, Lkv]
+k_pe        decoupled k position embeddings     shape [Skv, R]
+new_kv_c    new kv_c from current iter          shape [Sq, Lkv]
+new_k_pe    new k_pe from current iter          shape [Sq, R]
+cache_kv_c  cached kv_c from previous iters     shape [C, Lkv]
+cache_k_pe  cached k_pe from previous iters     shape [C, R]
+W_DQ        project h_t to q_c                  shape [H, Lq]
+W_UQ        project q_c to q_nope               shape [Lq, N * P]
+W_QR        project q_c to q_pe                 shape [Lq, N * R]
+W_DKV       project h_t to kv_c                 shape [H, Lkv]
+W_UK        project kv_c to k_nope              shape [Lkv, N, P]
+W_KR        project h_t to k_pe                 shape [H, R]
+W_UV        project kv_c to v                   shape [Lkv, N, V]
+W_O         project v to h_t                    shape [N * V, H]
+
+
+## Compute Friendly Approach (i.e. "_forward_prefill"):
+
+q_c      = h_t @ W_DQ
+q_nope   = (q_c @ W_UQ).view(Sq, N, P)
+q_pe     = RoPE(q_c @ W_QR).view(Sq, N, R)
+new_kv_c = h_t @ W_DKV
+new_k_pe = RoPE(h_t @ W_KR)
+kv_c     = torch.cat([new_kv_c, cache_kv_c], dim=0)
+k_pe     = torch.cat([new_k_pe, cache_k_pe], dim=0)
+k_nope   = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P)
+v        = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V)
+
+// MHA with QK headdim = P + R
+//  V headdim = V
+//  sdpa_o shape [Sq, N, V]
+sdpa_o = scaled_dot_product_attention(
+    torch.cat([q_nope, q_pe], dim=-1),
+    torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
+    v
+)
+return sdpa_o.reshape(Sq, N * V) @ W_O
+
+NOTE: in the actual code,
+  `kv_b_proj` is [W_UK; W_UV] concatenated per head
+  `q_b_proj`  is [W_UQ; W_QR] concatenated per head
+  `out_proj`  is W_O
+
+
+## Data-Movement Friendly Approach (i.e. "_forward_decode"):
+
+At runtime:
+q_c      = h_t @ W_DQ
+q_nope   = (q_c @ W_UQ).view(-1, N, P)
+ql_nope  = einsum("snh,lnh->snl", q_nope, W_UK)
+q_pe     = RoPE(q_c @ W_QR).view(Sq, N, R)
+new_kv_c = h_t @ W_DKV
+new_k_pe = RoPE(h_t @ W_KR)
+kv_c     = torch.cat([new_kv_c, cache_kv_c], dim=0)
+k_pe     = torch.cat([new_k_pe, cache_k_pe], dim=0)
+
+// MQA with QK headdim = Lkv + R
+//  V headdim = Lkv
+//  sdpa_o shape [Sq, N, Lkv]
+//  NOTE: this is less compute-friendly since Lkv > P
+//        but is more data-movement friendly since it's MQA vs MHA
+sdpa_o = scaled_dot_product_attention(
+    torch.cat([ql_nope, q_pe], dim=-1),
+    torch.cat([kv_c, k_pe], dim=-1),
+    kv_c
+)
+
+o = einsum("snl,lnv->snv", sdpa_o.reshape(-1, N, Lkv), W_UV)
+return o.view(-1, N * V) @ W_O
+
+
+## Chunked Prefill
+
+For chunked prefill we want to use the compute friendly algorithm. We are
+assuming a sufficiently large Sq / Skv ratio; in the future we may want to
+switch to the data-movement friendly approach if the chunk (i.e. `Sq`) is
+small.
+
+However, the compute-friendly approach can potentially run out of memory if
+Skv is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)`
+
+To mitigate this, we chunk the computation of attention with respect to the
+current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can use a
+fixed workspace size.
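+
+(Editorial note) The merge step relies on the usual log-sum-exp rule for
+combining partial attention results computed over disjoint key sets. Given
+partial outputs out_a, out_b with log-sum-exps lse_a, lse_b:
+
+lse = log(exp(lse_a) + exp(lse_b))
+out = exp(lse_a - lse) * out_a + exp(lse_b - lse) * out_b
+
+This is what `merge_attn_states` computes in the pseudocode below, so
+chunking changes the schedule but not the result (up to floating-point
+error).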
+
+The chunked prefill approach is as follows:
+
+MCC        Max chunk of context to process per iter, computed dynamically,
+           used to bound the memory usage
+
+q_c      = h_t @ W_DQ
+q_nope   = (q_c @ W_UQ).view(Sq, N, P)
+q_pe     = RoPE(q_c @ W_QR).view(Sq, N, R)
+new_kv_c = h_t @ W_DKV
+new_k_pe = RoPE(h_t @ W_KR)
+new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P)
+new_v      = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V)
+
+// MHA between queries and new KV
+//  with QK headdim = P + R
+//  V headdim = V
+//  curr_o   shape [Sq, N, V]
+//  curr_lse shape [N, Sq], this is just the order FA returns
+curr_o, curr_lse = scaled_dot_product_attention(
+    torch.cat([q_nope, q_pe], dim=-1),
+    torch.cat([new_k_nope, new_k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
+    new_v,
+    causal=True,
+    return_softmax_lse=True
+)
+
+// Compute attention with the already existing context
+for chunk_idx in range(cdiv(C, MCC)):
+    chunk_start = chunk_idx * MCC
+    chunk_end = min(chunk_start + MCC, C)
+    Sc = chunk_end - chunk_start
+    cache_kv_c_chunk   = cache_kv_c[chunk_start:chunk_end]
+    cache_k_pe_chunk   = cache_k_pe[chunk_start:chunk_end]
+    cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P)
+    cache_v_chunk      = (cache_kv_c_chunk @ W_UV).view(-1, N, V)
+
+    chunk_o, chunk_lse = scaled_dot_product_attention(
+        torch.cat([q_nope, q_pe], dim=-1),
+        torch.cat([cache_k_nope_chunk,
+                   cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)],
+                  dim=-1),
+        cache_v_chunk,
+        causal=False,
+        return_softmax_lse=True
+    )
+
+    curr_o, curr_lse = merge_attn_states(
+        suffix_output=curr_o,
+        suffix_lse=curr_lse,
+        prefix_output=chunk_o,
+        prefix_lse=chunk_lse,
+    )
+
+return curr_o @ W_O
+"""
+
+import functools
+from abc import abstractmethod
+from collections import defaultdict
+from contextlib import contextmanager
+from dataclasses import dataclass
+from itertools import accumulate
+from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple,
+                    Type, TypeVar)
+
+import torch
+
+from vllm import _custom_ops as ops
+from vllm import envs
+from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
+                                              AttentionMetadata,
+                                              AttentionMetadataBuilder,
+                                              AttentionState, MLAAttentionImpl)
+from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
+                                           compute_slot_mapping_start_idx,
+                                           is_block_tables_empty)
+from vllm.attention.ops.merge_attn_states import merge_attn_states
+# from vllm.attention.utils.fa_utils import get_flash_attn_version
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               LinearBase,
+                                               UnquantizedLinearMethod)
+from vllm.multimodal import MultiModalPlaceholderMap
+from vllm.platforms import current_platform
+from vllm.triton_utils import HAS_TRITON
+from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
+
+if HAS_TRITON:
+    from vllm.attention.ops.triton_flash_attention import triton_attention
+else:
+    triton_attention = None
+
+try:
+    from vllm.vllm_flash_attn import flash_attn_varlen_func
+    is_vllm_fa = True
+except ImportError:
+    is_vllm_fa = False
+    try:
+        # For ROCm, use upstream flash attention
+        from flash_attn import flash_attn_varlen_func
+    except ImportError:
+        flash_attn_varlen_func = None
+
+if TYPE_CHECKING:
+    from vllm.worker.model_runner import (ModelInputForGPUBuilder,
+                                          ModelInputForGPUWithSamplingMetadata)
+
+is_hip = current_platform.is_rocm()
+
+def get_flash_attn_version():
+    return None
+
+class MLACommonBackend(AttentionBackend):
+
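+    # Editorial note: the single supported head size, 576, is the width of
+    # one cached latent entry: kv_lora_rank (512 in DeepSeek-V3) plus
+    # qk_rope_head_dim (64). The cache holds one latent per token
+    # (num_kv_heads == 1), hence the flat kv-cache shape returned below.
+
+    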
@staticmethod + def get_name() -> str: + return "TRITON_MLA" + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return MLACommonMetadata + + @staticmethod + def get_builder_cls() -> Type["MLACommonMetadataBuilder"]: + return MLACommonMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["MLACommonState"]: + return MLACommonState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, # assumed to be 1 for MLA + head_size: int, + ) -> Tuple[int, ...]: + return (num_blocks, block_size, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + ops.copy_blocks_mla(kv_caches, src_to_dists) + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [576] + + +T = TypeVar("T", bound="MLACommonMetadata") + + +class MLACommonState(AttentionState, Generic[T]): + + def __init__(self, runner): + self.runner = runner + self._is_graph_capturing = False + + scheduler_config = runner.scheduler_config + self.model_config = runner.model_config + cache_config = runner.cache_config + + self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled + self.enable_prefix_caching = cache_config.enable_prefix_caching + + if self.chunked_prefill_enabled or self.enable_prefix_caching: + self.context_chunk_workspace_size = min( + # Max sure there is enough for 8 full length request or at least + # 4 pages of cache per request + max( + 8 * self.model_config.max_model_len, 4 * + scheduler_config.max_num_seqs * cache_config.block_size), + # For long-context models try not to over-allocate limiting + # kv-cache space, limiting it to 64k tokens, + # which would result in the workspace being: + # 2*(576)*(64*1024) = 144mb + # (assuming 576 MLA head dim, and fp16) + # which would result in up-projected context being + # 2*(192*128)*(64*1024) = 3gb + # (assuming 192 QK head dim, 128 heads, and fp16) + 128 * 1024) + assert self.context_chunk_workspace_size >= \ + scheduler_config.max_num_seqs * cache_config.block_size + + @contextmanager + def graph_capture(self, max_batch_size: int): + self._is_graph_capturing = True + + self._graph_slot_mapping = torch.full((max_batch_size, ), + PAD_SLOT_ID, + dtype=torch.long, + device=self.runner.device) + self._graph_seq_lens = torch.ones(max_batch_size, + dtype=torch.int32, + device=self.runner.device) + self._graph_block_tables = torch.from_numpy( + self.runner.graph_block_tables).to(device=self.runner.device) + + self._positions = torch.zeros((max_batch_size, ), + dtype=torch.long, + device=self.runner.device) + + yield + + self._is_graph_capturing = False + del self._graph_slot_mapping + del self._graph_seq_lens + del self._graph_block_tables + del self._positions + + def graph_clone(self, batch_size: int): + assert self._is_graph_capturing + return self.__class__(self.runner) + + def graph_capture_get_metadata_for_batch( + self, + batch_size: int, + is_encoder_decoder_model: bool = False) -> T: + assert self._is_graph_capturing + + attn_metadata = self.runner.attn_backend.make_metadata( + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + use_cuda_graph=True, + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=self._graph_slot_mapping[:batch_size], + 
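+            # (Editorial) Decode-only capture: CUDA graphs are used for
+            # decode batches only, so all prefill-related fields are zero or
+            # None here.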
seq_lens=None, + seq_lens_tensor=self._graph_seq_lens[:batch_size], + max_query_len=1, + max_decode_query_len=1, + max_prefill_seq_len=0, + max_decode_seq_len=self.runner.max_seq_len_to_capture, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self._graph_block_tables[:batch_size], + head_dim=self.runner.model_config.get_head_size()) + + if is_encoder_decoder_model: + raise NotImplementedError( + "MLACommonState does not support encoder/decoder yet") + + return attn_metadata + + def get_graph_input_buffers(self, + attn_metadata, + is_encoder_decoder_model: bool = False): + input_buffers = { + "slot_mapping": attn_metadata.slot_mapping, + "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, + "block_tables": attn_metadata.decode_metadata.block_tables, + } + if is_encoder_decoder_model: + raise NotImplementedError( + "MLACommonState does not support encoder/decoder yet") + + return input_buffers + + def prepare_graph_input_buffers(self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False): + input_buffers["seq_lens_tensor"].copy_( + attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) + input_buffers["block_tables"].copy_( + attn_metadata.decode_metadata.block_tables, non_blocking=True) + if is_encoder_decoder_model: + raise NotImplementedError( + "TritonMLAState does not support encoder/decoder yet") + + def begin_forward(self, model_input): + if self.chunked_prefill_enabled or self.enable_prefix_caching: + if not hasattr(self, "context_chunk_workspace"): + # not self.runner.device does not return the correct device + # for this process, (init_device sets the correct device but + # only on the Worker). The only way Ive figured out to get the + # correct device is to allocate the workspace on the first call + # to begin_forward and use the device of the input tokens + assert model_input.input_tokens is not None + self.context_chunk_workspace = torch.empty( + (self.context_chunk_workspace_size, + self.model_config.get_head_size()), + dtype=self.model_config.dtype, + device=model_input.input_tokens.device, + ) + + model_input.attn_metadata.context_chunk_workspace = \ + self.context_chunk_workspace + + +@dataclass +class MLACommonMetadata(AttentionMetadata): + """Metadata for MLACommon. + + NOTE: Please read the comment at the top of the file before trying to + understand this class + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. 
+    max_prefill_seq_len: int
+    # Maximum sequence length among decode batch. 0 if there are prefill
+    # requests only.
+    max_decode_seq_len: int
+    # (batch_size,) A tensor of context lengths (tokens that are computed
+    # so far).
+    context_lens_tensor: Optional[torch.Tensor]
+
+    # (batch_size, max_blocks_per_seq).
+    # Block addresses per sequence. (Seq id -> list of physical block)
+    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
+    # in the kv cache. Each block can contain up to block_size tokens.
+    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
+    # captured.
+    block_tables: Optional[torch.Tensor]
+
+    # Maximum query length in the batch.
+    max_query_len: Optional[int] = None
+
+    # Max number of query tokens among requests in the batch.
+    max_decode_query_len: Optional[int] = None
+
+    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
+    # the batch, used to index into subquery. E.g., if the subquery length
+    # is [4, 6], it is [0, 4, 10].
+    query_start_loc: Optional[torch.Tensor] = None
+    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
+    # the batch, used to index into sequence. E.g., if the sequence length is
+    # [4, 6], it is [0, 4, 10].
+    seq_start_loc: Optional[torch.Tensor] = None
+
+    _cached_prefill_metadata: Optional[Any] = None
+    _cached_decode_metadata: Optional[Any] = None
+
+    num_prefill_tokens: int
+
+    # The dimension of the attention heads
+    head_dim: Optional[int] = None
+
+    # Used when chunked prefill is enabled to simulate worst case workspace
+    # allocations, hopefully to avoid going OOM
+    is_profile_run: bool = False
+
+    # New for MLA (compared to FlashAttention)
+    # For chunked prefill
+    context_chunk_cu_seq_lens: Optional[torch.Tensor] = None
+    context_chunk_starts: Optional[torch.Tensor] = None
+    context_chunk_seq_tot: Optional[List[int]] = None
+    context_chunk_max_seq_lens: Optional[List[int]] = None
+    # Set by MLAAttentionState in `begin_forward` so it doesn't get broadcasted
+    context_chunk_workspace: Optional[torch.Tensor] = None
+
+    def __post_init__(self):
+        supported_head_sizes = MLACommonBackend.get_supported_head_sizes()
+        if self.head_dim is not None and self.head_dim \
+                not in supported_head_sizes:
+            raise ValueError(
+                f"Only {supported_head_sizes} are supported for head_dim, "
+                f"received {self.head_dim}.")
+
+    @property
+    def prefill_metadata(self):
+        if self.num_prefills == 0:
+            return None
+
+        if self._cached_prefill_metadata is not None:
+            return self._cached_prefill_metadata
+
+        assert self.seq_lens is not None
+        assert self.seq_lens_tensor is not None
+
+        # Compute some attn_metadata fields which default to None
+        query_start_loc = (None if self.query_start_loc is None else
+                           self.query_start_loc[:self.num_prefills + 1])
+        slot_mapping = (None if self.slot_mapping is None else
+                        self.slot_mapping[:self.num_prefill_tokens])
+        seq_lens = (None if self.seq_lens is None else
+                    self.seq_lens[:self.num_prefills])
+        seq_lens_tensor = (None if self.seq_lens_tensor is None else
+                           self.seq_lens_tensor[:self.num_prefills])
+        seq_start_loc = (None if self.seq_start_loc is None else
+                         self.seq_start_loc[:self.num_prefills + 1])
+        context_lens_tensor = (None if self.context_lens_tensor is None else
+                               self.context_lens_tensor[:self.num_prefills])
+        block_tables = (None if self.block_tables is None else
+                        self.block_tables[:self.num_prefills])
+
+        self._cached_prefill_metadata = self.__class__(
+            # Required by ModelRunner
+            use_cuda_graph=False,  # Not Attention Related
+            # 
Required by Attention Metadata + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=slot_mapping, + # Required by Attention Metadata (not used) + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + # MLACommonMetadata + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_query_len=0, + max_decode_seq_len=0, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + head_dim=self.head_dim, + is_profile_run=self.is_profile_run, + # MLACommonMetadata Chunk prefill specific + context_chunk_cu_seq_lens=self.context_chunk_cu_seq_lens, + context_chunk_starts=self.context_chunk_starts, + context_chunk_seq_tot=self.context_chunk_seq_tot, + context_chunk_max_seq_lens=self.context_chunk_max_seq_lens, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self): + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.seq_lens_tensor is not None + + # Compute some attn_metadata fields which default to None + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[self.num_prefill_tokens:]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) + block_tables = (None if self.block_tables is None else + self.block_tables[self.num_prefills:]) + + self._cached_decode_metadata = self.__class__( + # Required by ModelRunner + use_cuda_graph=self.use_cuda_graph, # Not Attention Related + # Required by Attention Metadata + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=slot_mapping, + # Required by Attention Metadata (not used) + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + # MLACommonMetadata + seq_lens=None, + seq_lens_tensor=seq_lens_tensor, + max_decode_query_len=self.max_decode_query_len, + max_query_len=self.max_query_len, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + # Batch may be composed of prefill|decodes, adjust query start + # indices to refer to the start of decodes. E.g. + # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. + query_start_loc=(self.query_start_loc[self.num_prefills:] - + self.query_start_loc[self.num_prefills]) + if self.query_start_loc is not None else None, + seq_start_loc=self.seq_start_loc[self.num_prefills:] + if self.seq_start_loc is not None else None, + context_lens_tensor=None, + block_tables=block_tables, + head_dim=self.head_dim, + is_profile_run=self.is_profile_run) + return self._cached_decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. 
For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + + if turn_prefills_into_decodes: + # When Multi-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes. This update reflects that + # conversion. + assert self.num_decode_tokens + self.num_prefills == num_seqs + self.num_decode_tokens += self.num_prefills + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.max_prefill_seq_len = 0 + self.max_query_len = 1 + + self.slot_mapping = self.slot_mapping[:num_seqs] + else: + assert self.seq_lens is not None + assert self.max_decode_seq_len == max(self.seq_lens) + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.num_decode_tokens == num_seqs + assert self.slot_mapping.shape == (num_seqs, ) + + assert self.seq_lens is not None + assert len(self.seq_lens) == num_seqs + assert self.seq_lens_tensor is not None + assert self.seq_lens_tensor.shape == (num_seqs, ) + assert self.max_query_len == 1 + assert self.max_prefill_seq_len == 0 + + assert self.query_start_loc is not None + assert self.query_start_loc.shape == (num_queries + 1, ) + assert self.seq_start_loc is not None + assert self.seq_start_loc.shape == (num_seqs + 1, ) + + assert self.context_lens_tensor is not None + assert self.context_lens_tensor.shape == (num_queries, ) + + assert self.block_tables is not None + assert self.block_tables.shape[0] == num_seqs + + # Update query lengths. Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + self.seq_lens[i] += 1 + self.max_decode_seq_len = max(self.seq_lens) + + self._ops_advance_step(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions) + + def _ops_advance_step(self, num_seqs: int, num_queries: int, + block_size: int, input_tokens: torch.Tensor, + sampled_token_ids: torch.Tensor, + input_positions: torch.Tensor) -> None: + # here we use advance_step_flashinfo to update the paged_kv_* tensors + ops.advance_step_flashattn(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables) + + +class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]): + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + BLOCK_TABLE_EXTENDER: list[list[int]] = [] + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.input_builder = input_builder + self.runner = input_builder.runner + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + self.chunked_prefill_enabled = \ + self.runner.scheduler_config.chunked_prefill_enabled + self.enable_prefix_caching = \ + self.runner.cache_config.enable_prefix_caching + + if self.chunked_prefill_enabled or self.enable_prefix_caching: + attn_state = self.input_builder.runner.attn_state + self.context_chunk_workspace_size = \ + attn_state.context_chunk_workspace_size + self.page_size = self.runner.block_size + + def prepare(self): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + 
self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + self.has_prefix_cache_hit = False + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool, prefix_cache_hit: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. + """ + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + if is_prompt: + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if prefix_cache_hit: + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + block_table = block_tables[seq_id] + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + if curr_sliding_window_block == 0: + block_table = block_tables[seq_id] + else: + block_table = block_tables[seq_id][ + -curr_sliding_window_block:] + self.block_tables.append(block_table) + + # Compute slot mapping. + is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + + def _get_graph_runner_block_tables( + self, num_seqs: int, + block_tables: List[List[int]]) -> torch.Tensor: + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + max_batch_size, max_blocks = self.runner.graph_block_tables.shape + assert max_batch_size >= num_seqs + + graph_block_tables = self.runner.graph_block_tables[:num_seqs] + for i, block_table in enumerate(block_tables): + if block_table: + num_blocks = len(block_table) + if num_blocks <= max_blocks: + graph_block_tables[i, :num_blocks] = block_table + else: + # It may be possible to have more blocks allocated due + # to lookahead slots of multi-step, however, they are + # not used anyway, so can be safely ignored. + graph_block_tables[ + i, :max_blocks] = block_table[:max_blocks] + + return torch.from_numpy(graph_block_tables).to( + device=self.runner.device, non_blocking=True) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. 
+ -1 if cuda graph is not used. + batch_size: The maybe padded batch size. + """ + prefix_cache_hit = any([ + inter_data.prefix_cache_hit + for inter_data in self.input_builder.inter_data_list + ]) + + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled, + prefix_cache_hit) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_query_len = max(query_lens) + decode_query_lens = query_lens[self.num_prefills:] + if len(decode_query_lens) > 0: + max_decode_query_len = max(decode_query_lens) + else: + max_decode_query_len = 1 + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + query_start_loc = list(accumulate(query_lens, initial=0)) + seq_start_loc = list(accumulate(seq_lens, initial=0)) + + num_seqs = len(seq_lens) + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend(self.__class__.BLOCK_TABLE_EXTENDER * + cuda_graph_pad_size) + num_decode_tokens = batch_size - self.num_prefill_tokens + + block_tables = self._get_graph_runner_block_tables( + num_seqs, self.block_tables) + else: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) + query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, + device, + self.runner.pin_memory) + seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, + device, self.runner.pin_memory) + + context_chunk_cu_seq_lens = None + context_chunk_starts = None + context_chunk_seq_tot = None + context_chunk_max_seq_lens = None + + if (self.chunked_prefill_enabled or self.enable_prefix_caching) \ + and self.num_prefills > 0 \ + and context_lens_tensor is not None \ + and context_lens_tensor[:self.num_prefills].max() > 0: + + # NOTE: it is recommend you read the `Chunked Prefill` section in + # the comment at the top of the file before trying to understand + # the following code + + num_prefills_with_context = \ + (context_lens_tensor[:self.num_prefills] > 0).sum().item() + + # currently we allocate an equal amount of workspace for each + # prefill in the batch, we could probably use a more advanced + # algorithm here and allocate more workspace to prefills with + # longer context lengths + max_context_chunk = \ + self.context_chunk_workspace_size // num_prefills_with_context + + # align max_context_chunk to page_size by rounding down, + # currently the `gather_cache` kernel cannot handle + # `context_chunk_starts` that are not aligned to page_size + max_context_chunk = round_down(max_context_chunk, self.page_size) + assert max_context_chunk > 0 + num_chunks = cdiv(context_lens_tensor.max(), max_context_chunk) + + # if `max_context_chunk = 256`, `num_chunks = 3`, and + # `num_prefills_with_context = 4`, create a tensor that looks like + # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] + context_chunk_starts = \ + torch.arange(num_chunks, device=device, dtype=torch.int32)\ + .unsqueeze(1).expand(-1, 
self.num_prefills)\ + * max_context_chunk + chunk_ends = torch.min(context_lens_tensor[:self.num_prefills]\ + .unsqueeze(0), context_chunk_starts + max_context_chunk) + chunk_seq_lens = (chunk_ends - context_chunk_starts).clamp(min=0) + _context_chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to( + torch.int32) + zero = torch.zeros(num_chunks, dtype=torch.int32, device=device)\ + .unsqueeze(-1) + context_chunk_cu_seq_lens = \ + torch.cat([zero, _context_chunk_cu_seq_lens], dim=1) + context_chunk_max_seq_lens = \ + chunk_seq_lens.max(dim=1).values.tolist() + context_chunk_seq_tot = chunk_seq_lens.sum(dim=1).tolist() + assert max(context_chunk_seq_tot) <= \ + self.context_chunk_workspace_size + + return self.runner.attn_backend.make_metadata( + # Required by ModelRunner + use_cuda_graph=use_captured_graph, # Not Attention Related + # Required by Attention Metadata + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + # Required by Attention Metadata (not used) + multi_modal_placeholder_index_maps=None, # Not Attention Related + enable_kv_scales_calculation=False, + # MLACommonMetadata + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_decode_query_len=max_decode_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc_tensor, + seq_start_loc=seq_start_loc_tensor, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + head_dim=self.runner.model_config.get_head_size(), + is_profile_run=self.runner.in_profile_run, + # MLACommonMetadata Chunk prefill specific + context_chunk_cu_seq_lens=context_chunk_cu_seq_lens, + context_chunk_starts=context_chunk_starts, + context_chunk_seq_tot=context_chunk_seq_tot, + context_chunk_max_seq_lens=context_chunk_max_seq_lens, + ) + + +class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + q_lora_rank: Optional[int], + kv_lora_rank: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + qk_head_dim: int, + v_head_dim: int, + kv_b_proj: ColumnParallelLinear, + ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing not supported in V0.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_head_dim + self.v_head_dim = v_head_dim + self.kv_b_proj = kv_b_proj + + self.triton_fa_func = triton_attention + # Handle the differences between the flash_attn_varlen from flash_attn + # and the one from vllm_flash_attn. 
The former is used on RoCM and the + # latter has an additional parameter to control FA2 vs FA3 + self.flash_attn_varlen_func = flash_attn_varlen_func + self.vllm_flash_attn_version = get_flash_attn_version() + if self.vllm_flash_attn_version is not None: + self.flash_attn_varlen_func = \ + functools.partial(flash_attn_varlen_func, + fa_version=self.vllm_flash_attn_version) + + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim for attention backends that do + # not support different headdims + # We don't need to pad V if we are on a hopper system with FA3 + self._pad_v = self.vllm_flash_attn_version is None or not ( + self.vllm_flash_attn_version == 3 + and current_platform.get_device_capability()[0] == 9) + + def _flash_attn_varlen_diff_headdims(self, q, k, v, softmax_scale, + return_softmax_lse, **kwargs): + maybe_padded_v = v + if self._pad_v: + maybe_padded_v = torch.nn.functional.pad( + v, [0, q.shape[-1] - v.shape[-1]], value=0) + + if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN \ + and not return_softmax_lse: + attn_out = self.triton_fa_func( + q, + k, + maybe_padded_v, + None, # output + kwargs["cu_seqlens_q"], + kwargs["cu_seqlens_k"], + kwargs["max_seqlen_q"], + kwargs["max_seqlen_k"], + kwargs["causal"], + softmax_scale, + None, # bias + ) + elif is_vllm_fa: + attn_out = self.flash_attn_varlen_func( + q=q, + k=k, + v=maybe_padded_v, + return_softmax_lse=return_softmax_lse, + softmax_scale=softmax_scale, + **kwargs, + ) + else: + # Use return_attn_probs instead of return_softmax_lse for RoCM + attn_out = self.flash_attn_varlen_func( + q=q, + k=k, + v=maybe_padded_v, + return_attn_probs=return_softmax_lse, + softmax_scale=softmax_scale, + **kwargs, + ) + + # Unpack the output if there is multiple results, + # triton always returns (output, softmax_lse), + # vllm_flash_attn returns (output, softmax_lse) when + # `return_softmax_lse = True` + # flash_attn (RoCM) returns (output, softmax_lse, ...) when + # `return_attn_probs = True` + rest = None + if isinstance(attn_out, tuple): + attn_out, *rest = attn_out + + # Remain consistent with old `flash_attn_varlen_func` where there + # is only one output tensor if `return_softmax_lse` is False. 
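+        # Editorial example: callers can treat this wrapper uniformly, e.g.
+        #   out = self._flash_attn_varlen_diff_headdims(
+        #       q, k, v, ..., return_softmax_lse=False)
+        #   out, lse = self._flash_attn_varlen_diff_headdims(
+        #       q, k, v, ..., return_softmax_lse=True)
+        # regardless of which flash-attention implementation was picked above.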
+ if return_softmax_lse: + assert rest is not None + return attn_out, rest[0] + return attn_out + + def _v_up_proj(self, x): + # Convert from (B, N, L) to (N, B, L) + x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) + # Multiply (N, B, L) x (N, L, V) -> (N, B, V) + x = torch.bmm(x, self.W_UV) + # Convert from (N, B, V) to (B, N * V) + return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) + + def process_weights_after_loading(self, act_dtype: torch.dtype): + + def get_layer_weight(layer): + WEIGHT_NAMES = ("weight", "qweight", "weight_packed") + for attr in WEIGHT_NAMES: + if hasattr(layer, attr): + return getattr(layer, attr) + raise AttributeError( + f"Layer '{layer}' has no recognized weight attribute:" + f" {WEIGHT_NAMES}.") + + def get_and_maybe_dequant_weights(layer: LinearBase): + if not isinstance(layer.quant_method, UnquantizedLinearMethod): + # NOTE: This should only be used offline, since it's O(N^3) + eye = torch.eye(layer.input_size_per_partition, + dtype=act_dtype, + device=get_layer_weight(layer).device) + dequant_weights = layer.quant_method.apply(layer, + eye, + bias=None) + del eye + # standardize to (output, input) + return dequant_weights.T + return layer.weight if not envs.MACA_VLLM_USE_TN_2_NN else layer.weight.T + + # we currently do not have quantized bmm's which are needed for + # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform + # the bmm's in 16-bit, the extra memory overhead of this is fairly low + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + assert kv_b_proj_weight.shape == ( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}") + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) + + W_UK, W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + # Convert from (L, N, V) to (N, L, V) + self.W_UV = W_UV.transpose(0, 1) + # Convert from (L, N, P) to (N, P, L) + self.W_UK_T = W_UK.permute(1, 2, 0) + + def _compute_prefill_context( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + ): + prefill_metadata = attn_metadata.prefill_metadata + assert prefill_metadata is not None + assert prefill_metadata.context_chunk_seq_tot is not None + assert prefill_metadata.context_chunk_cu_seq_lens is not None + assert prefill_metadata.context_chunk_starts is not None + assert prefill_metadata.context_chunk_max_seq_lens is not None + assert prefill_metadata.context_lens_tensor is not None + + output = None + iters = len(prefill_metadata.context_chunk_seq_tot) + + # Fetch from attn_metadata directly, since it late bound by + # MLAAttentionState, grabbing it directly `attn_metadata` can avoid + # any weirdness around prefill_metadata caching + assert attn_metadata.context_chunk_workspace is not None + workspace = attn_metadata.context_chunk_workspace + + for i in range(iters): + toks = prefill_metadata.context_chunk_seq_tot[i] + + ops.gather_cache( + src_cache=kv_c_and_k_pe_cache, + dst=workspace, + block_table=prefill_metadata.block_tables, + cu_seq_lens=prefill_metadata.context_chunk_cu_seq_lens[i], + batch_size=prefill_metadata.num_prefills, + seq_starts=prefill_metadata.context_chunk_starts[i], + ) + + kv_c_normed = workspace[:toks]\ + [..., :self.kv_lora_rank] + k_pe = 
workspace[:toks]\ + [..., self.kv_lora_rank:].unsqueeze(1) + + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), + dim=-1) + + attn_output, attn_softmax_lse = \ + self._flash_attn_varlen_diff_headdims( + q=q, + k=k, + v=v, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], + max_seqlen_q=prefill_metadata.max_query_len, + max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i], + softmax_scale=self.scale, + causal=False, # Context is unmasked + return_softmax_lse=True, + ) + + if output is None: + output = attn_output + output_lse = attn_softmax_lse + else: + output_tmp = torch.empty_like(output) + output_lse_tmp = torch.empty_like(output_lse) + merge_attn_states( + output=output_tmp, + output_lse=output_lse_tmp, + prefix_output=output, + prefix_lse=output_lse, + suffix_output=attn_output, + suffix_lse=attn_softmax_lse, + ) + output = output_tmp + output_lse = output_lse_tmp + + return output, output_lse + + def _forward_prefill( + self, + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + ) -> torch.Tensor: + + prefill_metadata = attn_metadata.prefill_metadata + assert prefill_metadata is not None + + has_context = prefill_metadata.context_lens_tensor is not None \ + and prefill_metadata.context_lens_tensor.max() > 0 + + kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + + output = self._flash_attn_varlen_diff_headdims( + q=q, + k=k, + v=v, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.query_start_loc, + max_seqlen_q=prefill_metadata.max_prefill_seq_len, + max_seqlen_k=prefill_metadata.max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + return_softmax_lse=has_context, + ) + + if has_context: + # ROCm flash_attn_varlen_func will return 3 objects instead of 2 + suffix_output, suffix_lse = output + context_output, context_lse = self._compute_prefill_context( \ + q, kv_c_and_k_pe_cache, attn_metadata) + + output = torch.empty_like(suffix_output) + merge_attn_states( + output=output, + prefix_output=context_output, + prefix_lse=context_lse, + suffix_output=suffix_output, + suffix_lse=suffix_lse, + ) + + # unpad if necessary + if self._pad_v: + output = output[..., :v.shape[-1]] + + return output.flatten(start_dim=-2) + + @abstractmethod + def _forward_decode( + self, + ql_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: T, + ) -> torch.Tensor: + raise NotImplementedError + + def forward( + self, + layer: AttentionLayer, + q: torch.Tensor, # query in unified attn + k_c_normed: torch.Tensor, # key in unified attn + k_pe: torch.Tensor, # value in unified attn + kv_cache: torch.Tensor, + attn_metadata: T, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if output is not None: + raise NotImplementedError( + "output is not yet supported for MLAImplBase") + + if attn_metadata.is_profile_run and \ + attn_metadata.context_chunk_workspace is not None: + # During the profile run try to simulate to worse case output size + 
# for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context` + # since this can be large + _ = torch.empty( + (attn_metadata.context_chunk_workspace.shape[0], + self.num_heads, self.qk_nope_head_dim + self.v_head_dim), + device=k_c_normed.device, + dtype=k_c_normed.dtype, + ) + + has_decode = attn_metadata.decode_metadata is not None + has_prefill = attn_metadata.prefill_metadata is not None + + num_prefill_tokens: int = attn_metadata.num_prefill_tokens + q = q.view(-1, self.num_heads, self.qk_head_dim) + + decode_q = q[num_prefill_tokens:] + + prefill_q = q[:num_prefill_tokens] + prefill_k_pe = k_pe[:num_prefill_tokens] + prefill_k_c_normed = k_c_normed[:num_prefill_tokens] + + # write the latent and rope to kv cache + if kv_cache.numel() > 0: + ops.concat_and_cache_mla( + k_c_normed, + k_pe.squeeze(1), + kv_cache, + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype=self.kv_cache_dtype, + scale=layer._k_scale, + ) + + output = torch.empty(attn_metadata.num_prefill_tokens + + attn_metadata.num_decode_tokens, + self.v_head_dim * self.num_heads, + device=q.device, + dtype=q.dtype) + if has_prefill: + output[:num_prefill_tokens] = self._forward_prefill( + prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, + attn_metadata) + + if has_decode: + decode_q_nope, decode_q_pe = decode_q.split( + [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + # Convert from (B, N, P) to (N, B, P) + decode_q_nope = decode_q_nope.transpose(0, 1) + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + decode_ql_nope = decode_ql_nope.transpose(0, 1) + + output[num_prefill_tokens:] = self._forward_decode( + decode_ql_nope, decode_q_pe, kv_cache, attn_metadata) + + return output diff --git a/attention/backends/pallas.py b/attention/backends/pallas.py new file mode 100644 index 0000000..7ad6761 --- /dev/null +++ b/attention/backends/pallas.py @@ -0,0 +1,351 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +import torch_xla.experimental.custom_kernel # Required to register custom ops. 
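+# Editorial note: the custom_kernel import above is needed for its side
+# effect: it registers the Pallas kernels used below (such as
+# torch.ops.xla.flash_attention and torch.ops.xla.paged_attention) with the
+# PyTorch dispatcher.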
+ +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, AttentionType, + is_quantized_kv_cache) +from vllm.attention.backends.utils import CommonAttentionState +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class PallasAttentionBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "PALLAS" + + @staticmethod + def get_impl_cls() -> Type["PallasAttentionBackendImpl"]: + return PallasAttentionBackendImpl + + @staticmethod + def get_metadata_cls() -> Type["PallasMetadata"]: + return PallasMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (num_kv_heads, num_blocks, block_size, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + raise RuntimeError("swap_blocks is not used for the TPU backend.") + + @torch.compile(backend="openxla") + @staticmethod + def copy_blocks( + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + src_to_dists: Tuple[torch.Tensor, torch.Tensor], + ) -> None: + src_indices, dst_indices = src_to_dists + for k_cache, v_cache in kv_caches: + torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True) + k_cache[:, dst_indices] = k_cache[:, src_indices] + torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True) + v_cache[:, dst_indices] = v_cache[:, src_indices] + + +@dataclass +class PallasMetadata(AttentionMetadata): + + # Currently, input sequences can only contain all prefills + # or all decoding. + block_tables: Optional[torch.Tensor] = None + context_lens: Optional[torch.Tensor] = None + effective_query_lens: Optional[torch.Tensor] = None + + @property + def prefill_metadata(self) -> Optional["PallasMetadata"]: + if self.num_prefills == 0: + return None + + assert self.num_decode_tokens == 0 + return self + + @property + def decode_metadata(self) -> Optional["PallasMetadata"]: + if self.num_decode_tokens == 0: + return None + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.block_tables is not None + assert self.context_lens is not None + return self + + +class PallasAttentionBackendImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + use_irope: bool = False, + ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") + if use_irope: + logger.warning_once( + "Using irope in Pallas is not supported yet, it will fall back " + "to global attention for long context.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.logits_soft_cap = logits_soft_cap + if head_size % 128 != 0: + raise NotImplementedError( + f"Head size must be a multiple of 128, found {head_size}.") + if 
alibi_slopes is not None: + raise NotImplementedError("Alibi slopes is not supported.") + if sliding_window is not None: + raise NotImplementedError("Sliding window is not supported.") + if is_quantized_kv_cache(kv_cache_dtype): + raise NotImplementedError("FP8 KV cache dtype is not supported.") + if blocksparse_params is not None: + raise NotImplementedError("Blocksparse is not supported.") + + if torch_xla.tpu.version() < 4: + raise NotImplementedError("TPU version must be 4 or higher.") + + self.megacore_mode = None + tpu_env = torch_xla.tpu.get_tpu_env() + tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None) + or tpu_env.get("TYPE", None) + or tpu_env.get("TPU_ACCELERATOR_TYPE", None)) + assert tpu_type is not None + tpu_type = tpu_type.lower() + + if (("lite" not in tpu_type) and ("v6" not in tpu_type)): + if self.num_kv_heads % 2 == 0: + self.megacore_mode = "kv_head" + else: + # NOTE(woosuk): If the batch size is not a multiple of 2, the + # megacore mode will be None. + self.megacore_mode = "batch" + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "PallasAttentionBackendImpl") + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor], + attn_metadata: PallasMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with Pallas attention. + + Args: + query: shape = [batch_size, seq_len, num_heads * head_size] + key: shape = [batch_size, seq_len, num_kv_heads * head_size] + value: shape = [batch_size, seq_len, num_kv_heads * head_size] + kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size] + kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size] + NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor + with shape [0] for profiling run. + attn_metadata: Metadata for attention. + Returns: + shape = [batch_size, seq_len, num_heads * head_size] + """ + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 + batch_size, seq_len, hidden_size = query.shape + query = query.view(batch_size, seq_len, self.num_heads, self.head_size) + key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) + value = value.view(batch_size, seq_len, self.num_kv_heads, + self.head_size) + + if kv_cache[0].numel() > 0: + slot_mapping = attn_metadata.slot_mapping + key_cache, value_cache = kv_cache + write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping) + + query = query * self.scale + if attn_metadata.num_prefills > 0: + if attn_metadata.block_tables is None: + # Prefill without paged KV cache. + assert seq_len % 16 == 0, ( + "Pallas FlashAttention kernel requires seq_len to be a " + f"multiple of 16 but got {seq_len}") + + # Handle GQA/MQA. + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, + dim=-2) + key = key.view(batch_size, seq_len, self.num_heads, + self.head_size) + value = value.repeat_interleave(self.num_queries_per_kv, + dim=-2) + value = value.view(batch_size, seq_len, self.num_heads, + self.head_size) + # FlashAttention kernel requires the input shape to be + # [batch_size, num_heads, seq_len, d_model] + # while the input is [batch_size, seq_len, num_heads, d_model]. + # Permute the input to match the required format. 
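+                # Editorial sketch: e.g. with batch_size=8, seq_len=1024,
+                # num_heads=32, head_size=128, the permutes below take
+                # [8, 1024, 32, 128] -> [8, 32, 1024, 128] into the kernel
+                # and back to [8, 1024, 32, 128] afterwards.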
+ output = torch.ops.xla.flash_attention( + query.permute(0, 2, 1, 3), + key.permute(0, 2, 1, 3), + value.permute(0, 2, 1, 3), + True, + ) + output = output.permute(0, 2, 1, 3) + else: + # Prefill with paged KV cache. + # TODO(woosuk): Tune the below knobs. + num_kv_pages_per_compute_block = 16 + num_queries_per_compute_block = 16 + assert seq_len % num_queries_per_compute_block == 0 + output = torch.ops.xla.multi_queries_paged_attention( + query, + key_cache, + value_cache, + attn_metadata.context_lens, + attn_metadata.block_tables, + attn_metadata.effective_query_lens, + num_kv_pages_per_compute_block, + num_queries_per_compute_block, + use_kernel=True, + attn_logits_soft_cap=self.logits_soft_cap, + ) + else: + # Decoding run. + assert kv_cache[0].numel() > 0 + query = query.squeeze(dim=1) + pages_per_compute_block = 16 # TODO(woosuk): Tune this value. + + assert attn_metadata.block_tables is not None + assert attn_metadata.context_lens is not None + # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire + # block table in SMEM. Therefore, if the block table is too large, + # the kernel compilation will fail. To avoid this, we split the + # batch dimension into smaller chunks and run the kernel multiple + # times. + MAX_SMEM_USAGE = 512 * 1024 + size_per_seq = 4 * attn_metadata.block_tables.shape[1] + max_num_seq = MAX_SMEM_USAGE // size_per_seq + + if batch_size <= max_num_seq: + output = paged_attention( + query, + key_cache, + value_cache, + attn_metadata.context_lens, + attn_metadata.block_tables, + pages_per_compute_block, + self.megacore_mode, + attn_logits_soft_cap=self.logits_soft_cap, + ) + else: + chunk_size = max_num_seq + # Make sure the chunk size is a multiple of 2. + chunk_size = chunk_size // 2 * 2 + num_chunks = (batch_size + chunk_size - 1) // chunk_size + + output = torch.empty_like(query) + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * chunk_size + chunk_end = chunk_start + chunk_size + # NOTE(woosuk): We skip this line because it causes Dynamo + # compilation error. Instead, we rely on the slice operation + # to handle the out-of-bound case. + # chunk_end = min(chunk_end, batch_size) + chunk_output = paged_attention( + query[chunk_start:chunk_end], + key_cache, + value_cache, + attn_metadata.context_lens[chunk_start:chunk_end], + attn_metadata.block_tables[chunk_start:chunk_end], + pages_per_compute_block, + self.megacore_mode, + attn_logits_soft_cap=self.logits_soft_cap, + ) + output[chunk_start:chunk_end] = chunk_output + + # Reshape the output tensor. 
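+        # Editorial note: at this point `output` is
+        # [batch_size, seq_len, num_heads, head_size] on the prefill paths
+        # and [batch_size, num_heads, head_size] on the decode path (where
+        # seq_len == 1), so the reshape below restores
+        # [batch_size, seq_len, hidden_size] in both cases.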
+ return output.reshape(batch_size, seq_len, hidden_size) + + +def write_to_kv_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, +) -> None: + torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True) + torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True) + + key = key.flatten(0, 2) + value = value.flatten(0, 2) + key_cache = key_cache.flatten(0, 2) + value_cache = value_cache.flatten(0, 2) + key_cache.index_copy_(0, slot_mapping, key) + value_cache.index_copy_(0, slot_mapping, value) + + +def paged_attention( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + pages_per_compute_block: int, + megacore_mode: Optional[str], + *, + attn_logits_soft_cap: Optional[float], +) -> torch.Tensor: + batch_size = query.shape[0] + if megacore_mode == "batch" and batch_size % 2 != 0: + megacore_mode = None + else: + megacore_mode = megacore_mode + + return torch.ops.xla.paged_attention( + query, + key_cache, + value_cache, + context_lens, + block_tables, + pages_per_compute_block, + megacore_mode=megacore_mode, + attn_logits_soft_cap=attn_logits_soft_cap, + ) diff --git a/attention/backends/placeholder_attn.py b/attention/backends/placeholder_attn.py new file mode 100644 index 0000000..820ddca --- /dev/null +++ b/attention/backends/placeholder_attn.py @@ -0,0 +1,400 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections import defaultdict +from dataclasses import dataclass +from itertools import accumulate +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type + +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, + AttentionMetadataBuilder) +from vllm.attention.backends.utils import CommonAttentionState +from vllm.multimodal import MultiModalPlaceholderMap + +if TYPE_CHECKING: + from vllm.worker.model_runner import (ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) +from vllm.utils import async_tensor_h2d + +# Placeholder attention backend for models like Mamba and pooling models that +# lack attention. + + +class PlaceholderAttentionBackend(AttentionBackend): + """Placeholder backend for when no attention is needed.""" + + @staticmethod + def get_name() -> str: + return "NO_ATTENTION" + + @staticmethod + def get_impl_cls() -> Type["PlaceholderAttentionImpl"]: + return PlaceholderAttentionImpl + + @staticmethod + def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]: + return PlaceholderAttentionMetadataBuilder + + @staticmethod + def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]: + return PlaceholderAttentionMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (1, 1, 1, 1, 1) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + return + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + return + + +@dataclass +class PlaceholderAttentionMetadata(AttentionMetadata): + """Attention metadata for prefill and decode batched together.""" + # (batch_size,). 
The sequence length per sequence. Sequence length means
+    # the computed tokens + new tokens. None if it is a decoding run.
+    seq_lens: Optional[List[int]]
+    # seq_lens stored as a tensor.
+    seq_lens_tensor: Optional[torch.Tensor]
+
+    # Maximum sequence length among prefill batch. 0 if there are decoding
+    # requests only.
+    max_prefill_seq_len: int
+    # Maximum sequence length among decode batch. 0 if there are prefill
+    # requests only.
+    max_decode_seq_len: int
+    # (batch_size,). A tensor of context lengths (tokens that are computed
+    # so far).
+    context_lens_tensor: Optional[torch.Tensor]
+
+    # Whether or not cuda graph is enabled.
+    # Cuda-graph is currently enabled for decoding only.
+    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
+    use_cuda_graph: bool
+
+    # Maximum query length in the batch.
+    max_query_len: Optional[int]
+
+    # Max number of query tokens among requests in the batch.
+    max_decode_query_len: Optional[int]
+
+    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
+    # the batch, used to index into subquery. E.g., if the subquery length
+    # is [4, 6], it is [0, 4, 10].
+    query_start_loc: Optional[torch.Tensor] = None
+    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
+    # the batch, used to index into sequence. E.g., if the sequence length is
+    # [4, 6], it is [0, 4, 10].
+    seq_start_loc: Optional[torch.Tensor] = None
+
+    # Placeholder.
+    block_tables: Optional[torch.Tensor] = None
+
+    _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None
+    _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None
+
+    @property
+    def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
+        if self.num_prefills == 0:
+            return None
+
+        if self._cached_prefill_metadata is not None:
+            return self._cached_prefill_metadata
+
+        # Compute some attn_metadata fields which default to None
+        query_start_loc = (None if self.query_start_loc is None else
+                           self.query_start_loc[:self.num_prefills + 1])
+        seq_lens = (None if self.seq_lens is None else
+                    self.seq_lens[:self.num_prefills])
+        seq_lens_tensor = (None if self.seq_lens_tensor is None else
+                           self.seq_lens_tensor[:self.num_prefills])
+        seq_start_loc = (None if self.seq_start_loc is None else
+                         self.seq_start_loc[:self.num_prefills + 1])
+        context_lens_tensor = (None if self.context_lens_tensor is None else
+                               self.context_lens_tensor[:self.num_prefills])
+
+        # Placeholders
+        slot_mapping = torch.empty(0)
+        block_tables = torch.empty(0)
+
+        self._cached_prefill_metadata = PlaceholderAttentionMetadata(
+            num_prefills=self.num_prefills,
+            num_prefill_tokens=self.num_prefill_tokens,
+            num_decode_tokens=0,
+            slot_mapping=slot_mapping,
+            multi_modal_placeholder_index_maps=self.
+ multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_decode_query_len=0, + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.seq_lens_tensor is not None + + # Placeholders + slot_mapping = torch.empty(0) + block_tables = torch.empty(0) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) + + self._cached_decode_metadata = PlaceholderAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + seq_lens=None, + seq_lens_tensor=seq_lens_tensor, + max_decode_query_len=self.max_decode_query_len, + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=(self.query_start_loc[self.num_prefills:] - + self.query_start_loc[self.num_prefills]) + if self.query_start_loc is not None else None, + seq_start_loc=self.seq_start_loc[self.num_prefills:] + if self.seq_start_loc is not None else None, + context_lens_tensor=None, + block_tables=block_tables, + use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + assert self.use_cuda_graph + + assert not turn_prefills_into_decodes, \ + ("Multi-Step + Chunked-Prefill is not supported for attention-free" + "models. turn_prefills_into_decodes is a " + "Multi-Step + Chunked-Prefill specific parameter.") + + assert self.seq_lens is not None + assert self.max_decode_seq_len == max(self.seq_lens) + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.num_decode_tokens == num_seqs + + assert self.seq_lens is not None + assert len(self.seq_lens) == num_seqs + assert self.seq_lens_tensor is not None + assert self.seq_lens_tensor.shape == (num_seqs, ) + assert self.max_query_len == 1 + assert self.max_prefill_seq_len == 0 + + assert self.query_start_loc is not None + assert self.query_start_loc.shape == (num_queries + 1, ) + assert self.seq_start_loc is not None + assert self.seq_start_loc.shape == (num_seqs + 1, ) + + assert self.context_lens_tensor is not None + assert self.context_lens_tensor.shape == (num_queries, ) + + # Update query lengths. 
Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + self.seq_lens[i] += 1 + self.max_decode_seq_len = max(self.seq_lens) + + # Update sequences, masking off entries greater than num_queries + device = self.seq_lens_tensor.device + mask = torch.arange(self.seq_lens_tensor.size(0), + device=device) < num_queries + self.seq_lens_tensor += mask.to(self.seq_lens_tensor.dtype) + if sampled_token_ids is not None: + model_input.input_tokens.masked_scatter_( + mask, sampled_token_ids[:num_queries]) + + +class PlaceholderAttentionMetadataBuilder( + AttentionMetadataBuilder[PlaceholderAttentionMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + + self.input_builder = input_builder + self.runner = input_builder.runner + + def prepare(self): + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + """ + is_prompt = inter_data.is_prompt + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + + if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for modality, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[modality].extend( + placeholders) + + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. + """ + + # Some input builders such as ModelInputForCPUBuilder do not have the + # "inter_data_list" attribute. + # Let's check inter_data_list exists before we reference it. 
+ if hasattr(self.input_builder, "inter_data_list"): + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_query_len = max(query_lens) + decode_query_lens = query_lens[self.num_prefills:] + if len(decode_query_lens) > 0: + max_decode_query_len = max(decode_query_lens) + else: + max_decode_query_len = 1 + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + query_start_loc = list(accumulate(query_lens, initial=0)) + seq_start_loc = list(accumulate(seq_lens, initial=0)) + + if use_captured_graph: + num_decode_tokens = batch_size - self.num_prefill_tokens + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, + device, + self.runner.pin_memory) + seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, + device, self.runner.pin_memory) + + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + self.multimodal_placeholder_maps.items() + } + + # Placeholders + slot_mapping_tensor = torch.empty(0) + block_tables = torch.empty(0) + + return PlaceholderAttentionMetadata( + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_decode_query_len=max_decode_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc_tensor, + seq_start_loc=seq_start_loc_tensor, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + ) + + +class PlaceholderAttentionImpl(AttentionImpl): + + def __init__(self, *args, **kwargs) -> None: + return + + def forward(self, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError diff --git a/attention/backends/rocm_aiter_mla.py b/attention/backends/rocm_aiter_mla.py new file mode 100644 index 0000000..1edf343 --- /dev/null +++ b/attention/backends/rocm_aiter_mla.py @@ -0,0 +1,435 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from contextlib import contextmanager +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional, Type, Union + +import torch + +import vllm._custom_ops as ops +import vllm.envs as envs +from vllm.attention.backends.mla.common import (MLACommonBackend, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder, + MLACommonState) +from vllm.attention.backends.utils import (compute_slot_mapping, + compute_slot_mapping_start_idx, + is_block_tables_empty) +from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_fwd, + get_aiter_mla_metadata) + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUBuilder + + +def is_aiter_mla_enabled() -> bool: + return envs.VLLM_ROCM_USE_AITER \ + 
and envs.VLLM_ROCM_USE_AITER_MLA + + +class AiterMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "ROCM_AITER_MLA" + + @staticmethod + def get_impl_cls() -> Type["AiterMLAImpl"]: + return AiterMLAImpl + + @staticmethod + def get_metadata_cls() -> Type["AiterMLAMetadata"]: + return AiterMLAMetadata + + @staticmethod + def get_builder_cls() -> Type["AiterMLAMetadataBuilder"]: + return AiterMLAMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["AiterMLAState"]: + return AiterMLAState + + +@dataclass +class AiterMLAMetadata(MLACommonMetadata): + # The following 5 tensors are for current version of AITER MLA + block_table_bound: Optional[torch.Tensor] = None + # The indptr of the paged kv cache, shape: [batch_size + 1] + paged_kv_indptr: Optional[torch.Tensor] = None + # The page indices of the paged kv cache + paged_kv_indices: Optional[torch.Tensor] = None + # The number of entries in the last page of each request in + # the paged kv cache, shape: [batch_size] + paged_kv_last_page_lens: Optional[torch.Tensor] = None + + # This is just to make new AITER MLA API work + # -- MTP support is not added yet. + qo_indptr: Optional[torch.Tensor] = None + + @property + def prefill_metadata(self): + prefill_metadata = super().prefill_metadata + self._cached_prefill_metadata = prefill_metadata + + if prefill_metadata is not None: + prefill_metadata.paged_kv_indptr = self.paged_kv_indptr + prefill_metadata.paged_kv_indices = self.paged_kv_indices + prefill_metadata\ + .paged_kv_last_page_lens = self.paged_kv_last_page_lens + prefill_metadata.block_table_bound = self.block_table_bound + prefill_metadata.qo_indptr = self.qo_indptr + + # update the cache + self._cached_prefill_metadata = self.__class__( + **prefill_metadata.__dict__) + + return self._cached_prefill_metadata + + @property + def decode_metadata(self): + decode_metadata = super().decode_metadata + + self._cached_decode_metadata = decode_metadata + + if decode_metadata is not None: + decode_metadata.paged_kv_indptr = self.paged_kv_indptr + decode_metadata.paged_kv_indices = self.paged_kv_indices + decode_metadata\ + .paged_kv_last_page_lens = self.paged_kv_last_page_lens + decode_metadata.block_table_bound = self.block_table_bound + decode_metadata.qo_indptr = self.qo_indptr + + # update the cache + self._cached_decode_metadata = self.__class__( + **decode_metadata.__dict__) + + return self._cached_decode_metadata + + def _ops_advance_step(self, num_seqs: int, num_queries: int, + block_size: int, input_tokens: torch.Tensor, + sampled_token_ids: torch.Tensor, + input_positions: torch.Tensor) -> None: + + ops.advance_step_flashinfer( + num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables, + paged_kv_indices=self.paged_kv_indices, + paged_kv_indptr=self.paged_kv_indptr, + paged_kv_last_page_lens=self.paged_kv_last_page_lens, + block_table_bound=self.block_table_bound) + + +class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): + BLOCK_TABLE_EXTENDER: list[list[int]] = [[]] + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + super().__init__(input_builder) + assert self.block_size == 1, "AITER MLA requires only block size 1." 
+ + def prepare(self): + super().prepare() + self.paged_kv_indices: list[int] = [] + self.paged_kv_indptr: list[int] = [0] + self.paged_kv_last_page_lens: list[int] = [] + self.total_blocks = 0 + self.qo_indptr: list[int] = [0] + + def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool, + prefix_cache_hit: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. + """ + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + if is_prompt: + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if prefix_cache_hit: + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + block_table = block_tables[seq_id] + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + if curr_sliding_window_block == 0: + block_table = block_tables[seq_id] + else: + block_table = block_tables[seq_id][ + -curr_sliding_window_block:] + self.block_tables.append(block_table) + + # Compute slot mapping. + is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + if is_profile_run: + return + + # Update paged_kv_* tensors only for non-profile run + block_table = block_tables[seq_id] + self._update_paged_kv_tensors(block_table, seq_len) + + def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int): + # Get the number of valid blocks based on sequence length. + # If seq_len = 16, block_size = 16, + # block_table_bound is 1 with 1 valid block. + # If seq_len = 15, block_size = 16, + # block_table_bound is 0 + 1 with 1 valid block. 
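+        # Equivalently, block_table_bound == ceil(seq_len / block_size), i.e.
+        # the number of pages that actually hold tokens for this sequence.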
+ self.total_blocks += len(block_table) + block_table_bound = seq_len // self.block_size + 1 \ + if seq_len % self.block_size != 0 \ + else seq_len // self.block_size + self.paged_kv_indices.extend(block_table[:block_table_bound]) + self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + + block_table_bound) + self.qo_indptr.append(self.qo_indptr[-1] + 1) + + last_page_len = seq_len % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + self.paged_kv_last_page_lens.append(last_page_len) + + def build(self, seq_lens: list[int], query_lens: list[int], + cuda_graph_pad_size: int, batch_size: int) -> AiterMLAMetadata: + metadata = super().build(seq_lens, query_lens, cuda_graph_pad_size, + batch_size) + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + if use_captured_graph: + last_paged_kv_indptr = self.paged_kv_indptr[-1] + self.paged_kv_indptr.extend([last_paged_kv_indptr] * + cuda_graph_pad_size) + self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size) + last_qo_indptr = self.qo_indptr[-1] + self.qo_indptr.extend([last_qo_indptr] * cuda_graph_pad_size) + + # For current version of AITER MLA + if len(self.paged_kv_indptr) > 0: + # extend to the maximum number of blocks as returned by the + # scheduler + self.paged_kv_indices.extend( + [0] * (self.total_blocks - len(self.paged_kv_indices))) + paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, + device=device, + dtype=torch.int) + paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr, + device=device, + dtype=torch.int) + paged_kv_last_page_lens_tensor = torch.tensor( + self.paged_kv_last_page_lens, device=device, dtype=torch.int) + block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) - + 1, + device=device, + dtype=torch.int) + + qo_indptr = torch.tensor(self.qo_indptr, + device=device, + dtype=torch.int) + else: + paged_kv_indices_tensor = None + paged_kv_indptr_tensor = None + paged_kv_last_page_lens_tensor = None + block_table_bound_tensor = None + qo_indptr = None + + metadata.paged_kv_indptr = paged_kv_indptr_tensor + metadata.paged_kv_indices = paged_kv_indices_tensor + metadata.paged_kv_last_page_lens = paged_kv_last_page_lens_tensor + metadata.block_table_bound = block_table_bound_tensor + metadata.qo_indptr = qo_indptr + + return metadata + + +class AiterMLAState(MLACommonState[AiterMLAMetadata]): + + @contextmanager + def graph_capture(self, max_batch_size: int): + kv_indices, kv_indptr, last_page_lens, qo_indptr = \ + get_aiter_mla_metadata( + max_batch_size=max_batch_size, + block_size=self.runner.block_size, + max_block_per_batch=\ + self.runner.get_max_block_per_batch(), + device=self.runner.device) + self._paged_kv_indices_tensor = kv_indices + self._paged_kv_indptr_tensor = kv_indptr + self._paged_kv_last_page_lens_tensor = last_page_lens + self._qo_indptr_tensor = qo_indptr + + with super().graph_capture(max_batch_size): + yield + + del self._paged_kv_indices_tensor + del self._paged_kv_indptr_tensor + del self._paged_kv_last_page_lens_tensor + del self._qo_indptr_tensor + + def graph_capture_get_metadata_for_batch( + self, + batch_size: int, + is_encoder_decoder_model: bool = False) -> AiterMLAMetadata: + + metadata = super().graph_capture_get_metadata_for_batch( + batch_size, is_encoder_decoder_model) + + paged_kv_indptr = self._paged_kv_indptr_tensor[:batch_size + 1] + paged_kv_indices = self._paged_kv_indices_tensor + paged_kv_last_page_lens = self._paged_kv_last_page_lens_tensor[: + batch_size] + qo_indptr = 
self._qo_indptr_tensor[:batch_size + 1] + + metadata.paged_kv_indptr = paged_kv_indptr + metadata.paged_kv_indices = paged_kv_indices + metadata.paged_kv_last_page_lens = paged_kv_last_page_lens + metadata.qo_indptr = qo_indptr + + return metadata + + def get_graph_input_buffers(self, + attn_metadata: AiterMLAMetadata, + is_encoder_decoder_model: bool = False): + input_buffers = super().get_graph_input_buffers( + attn_metadata, is_encoder_decoder_model) + input_buffers[ + 'paged_kv_indptr'] = attn_metadata.decode_metadata.paged_kv_indptr + input_buffers[ + "paged_kv_indices"] = attn_metadata.\ + decode_metadata.paged_kv_indices + input_buffers[ + "paged_kv_last_page_lens"] = attn_metadata.\ + decode_metadata.paged_kv_last_page_lens + input_buffers['qo_indptr'] = attn_metadata.qo_indptr + + return input_buffers + + def prepare_graph_input_buffers(self, + input_buffers, + attn_metadata: AiterMLAMetadata, + is_encoder_decoder_model: bool = False): + super().prepare_graph_input_buffers(input_buffers, attn_metadata, + is_encoder_decoder_model) + + num_total_blocks = attn_metadata.decode_metadata.paged_kv_indices.shape[ + 0] + input_buffers["paged_kv_indptr"].copy_( + attn_metadata.decode_metadata.paged_kv_indptr, non_blocking=True) + input_buffers["paged_kv_indices"][:num_total_blocks].copy_( + attn_metadata.decode_metadata.paged_kv_indices, non_blocking=True) + input_buffers["paged_kv_last_page_lens"].copy_( + attn_metadata.decode_metadata.paged_kv_last_page_lens, + non_blocking=True) + input_buffers["qo_indptr"].copy_( + attn_metadata.decode_metadata.qo_indptr, non_blocking=True) + + +class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + kv_sharing_target_layer_name, **mla_args) + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "Aiter MLA does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + from aiter import flash_attn_varlen_func + self.flash_attn_varlen_func = flash_attn_varlen_func + + def _flash_attn_varlen_diff_headdims( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + softmax_scale: float, return_softmax_lse: bool, + **kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]: + output = self.flash_attn_varlen_func( + q, + k, + v, + **kwargs, + ) + + return output + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: AiterMLAMetadata, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + + decode_meta = attn_metadata.decode_metadata + assert decode_meta is not None + B = q_nope.shape[0] + + q = torch.cat([q_nope, q_pe], dim=-1) + o = torch.empty(B, + self.num_heads, + self.kv_lora_rank, + dtype=q.dtype, + device=q.device) + + kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) + + aiter_mla_decode_fwd(q, kv_buffer, o, self.scale, + attn_metadata.qo_indptr, + 
attn_metadata.max_query_len,
+                             attn_metadata.paged_kv_indptr,
+                             attn_metadata.paged_kv_indices,
+                             attn_metadata.paged_kv_last_page_lens)
+
+        return self._v_up_proj(o)
diff --git a/attention/backends/rocm_flash_attn.py b/attention/backends/rocm_flash_attn.py
new file mode 100644
index 0000000..4b460dc
--- /dev/null
+++ b/attention/backends/rocm_flash_attn.py
@@ -0,0 +1,975 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Attention layer for ROCm GPUs."""
+import itertools
+from dataclasses import dataclass
+from functools import cache
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
+
+import torch
+
+import vllm.envs as envs
+from vllm import _custom_ops as ops
+from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
+                                              AttentionMetadata, AttentionType)
+from vllm.attention.backends.utils import (CommonAttentionState,
+                                           CommonMetadataBuilder)
+from vllm.attention.ops.paged_attn import (PagedAttention,
+                                           PagedAttentionMetadata)
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+from vllm.platforms.rocm import use_rocm_custom_paged_attention
+
+if TYPE_CHECKING:
+    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
+
+logger = init_logger(__name__)
+_PARTITION_SIZE_ROCM = 256
+
+
+@cache
+def is_rocm_aiter_paged_attn_enabled() -> bool:
+    return envs.VLLM_ROCM_USE_AITER_PAGED_ATTN \
+        and envs.VLLM_ROCM_USE_AITER
+
+
+@cache
+def _get_paged_attn_module() -> PagedAttention:
+    """
+    Initializes the appropriate PagedAttention module from `attention/ops`,
+    which is used as a helper function
+    by `ROCmFlashAttentionImpl` and `ROCmFlashAttentionBackend`.
+
+    The choice of attention module depends on whether
+    AITER paged attention is enabled:
+    - If enabled, `ROCmFlashAttentionImpl` uses `AITERPagedAttention`.
+    - Otherwise, it defaults to using the original `PagedAttention`.
+ """ + if is_rocm_aiter_paged_attn_enabled(): + # Import AITERPagedAttention only when the flag is enabled + from vllm.attention.ops.rocm_aiter_paged_attn import ( + AITERPagedAttention) + return AITERPagedAttention() + return PagedAttention() + + +class ROCmFlashAttentionBackend(AttentionBackend): + accept_output_buffer: bool = True + + @staticmethod + def get_name() -> str: + return "ROCM_FLASH" + + @staticmethod + def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]: + return ROCmFlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return ROCmFlashAttentionMetadata + + @staticmethod + def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]: + return ROCmFlashAttentionMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + paged_attn = _get_paged_attn_module() + return paged_attn.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + paged_attn = _get_paged_attn_module() + paged_attn.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + paged_attn = _get_paged_attn_module() + paged_attn.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): + """Metadata for FlashAttentionBackend. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| + + # Maximum query length in the batch. None for decoding. + max_query_len: Optional[int] = None + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. 
+ seq_start_loc: Optional[torch.Tensor] = None + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] = None + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] = None + + _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None + _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None + + # Begin encoder attn & enc/dec cross-attn fields... + + # Encoder sequence lengths representation + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + + @property + def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.block_tables is not None + + self._cached_prefill_metadata = ROCmFlashAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + multi_modal_placeholder_index_maps=self. + multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + # Begin encoder & cross attn fields below... 
+ encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = ROCmFlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + # Batch may be composed of prefill|decodes, adjust query start indices + # to refer to the start of decodes when the two are split apart. + # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. + if self._cached_decode_metadata.query_start_loc is not None: + qs = self._cached_decode_metadata.query_start_loc + self._cached_decode_metadata.query_start_loc = qs - qs[0] + return self._cached_decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + + assert not turn_prefills_into_decodes, \ + ("Chunked prefill is not supported with rocm_flash_attn yet." + "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill " + "specific parameter.") + + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. 
For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + assert self.use_cuda_graph + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.num_decode_tokens == num_seqs + assert self.slot_mapping.shape == (num_seqs, ) + + assert self.seq_lens is not None + assert len(self.seq_lens) == num_seqs + assert self.seq_lens_tensor is not None + assert self.seq_lens_tensor.shape == (num_seqs, ) + assert self.max_query_len == 1 + assert self.max_prefill_seq_len == 0 + assert self.max_decode_seq_len == max(self.seq_lens) + + assert self.query_start_loc is not None + assert self.query_start_loc.shape == (num_queries + 1, ) + assert self.seq_start_loc is not None + assert self.seq_start_loc.shape == (num_seqs + 1, ) + + assert self.context_lens_tensor is not None + assert self.context_lens_tensor.shape == (num_queries, ) + + assert self.block_tables is not None + assert self.block_tables.shape[0] == num_seqs + + # Update query lengths. Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + self.seq_lens[i] += 1 + self.max_decode_seq_len = max(self.seq_lens) + + ops.advance_step_flashattn(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables) + + +class ROCmFlashAttentionMetadataBuilder( + CommonMetadataBuilder[ROCmFlashAttentionMetadata]): + + _metadata_cls = ROCmFlashAttentionMetadata + + +def _make_alibi_bias(alibi_slopes: torch.Tensor, + dtype: torch.dtype, + seq_lens: Optional[List[int]], + make_attn_mask: bool = True) -> List[torch.Tensor]: + attn_biases = [] + if seq_lens: + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + bias = bias[None, :] - bias[:, None] + + num_heads = alibi_slopes.shape[0] + bias = bias[None, :].repeat( + (num_heads, 1, 1)).to(alibi_slopes.device) + bias.mul_(alibi_slopes[:, None, None]) + if make_attn_mask: + inf_mask = torch.empty( + (1, seq_len, seq_len), + dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to( + alibi_slopes.device) + attn_biases.append((bias + inf_mask).to(dtype)) + else: + attn_biases.append(bias.to(dtype)) + + return attn_biases + + +def _get_seq_len_block_table_args( + attn_metadata: ROCmFlashAttentionMetadata, + attn_type: str, +) -> tuple: + ''' + The particular choice of sequence-length + attributes which should be extracted from attn_metadata is dependent + on the type of attention operation. 
+ + Decoder attn -> select entirely decoder self-attention-related fields + Encoder/decoder cross-attn -> select encoder sequence lengths + Encoder attn -> select encoder sequence lengths fields + Encoder-only attn -> select prefill sequence lengths with + bidirectional attention + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention op + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention, encoder-only + + Returns: + + * Appropriate sequence-lengths tensors for query and key + * Appropriate max sequence-length scalar + * Causal masking flag + ''' + + if attn_type == AttentionType.ENCODER: + assert attn_metadata.encoder_seq_lens is not None + assert attn_metadata.encoder_seq_lens_tensor is not None + query_seq_start_loc = torch.tensor( + list(itertools.accumulate([0] + attn_metadata.encoder_seq_lens)), + device=attn_metadata.encoder_seq_lens_tensor.device, + dtype=attn_metadata.encoder_seq_lens_tensor.dtype) + causal_mask = False + + # No block tables associated with encoder attention + return (query_seq_start_loc, attn_metadata.max_encoder_seq_len, + query_seq_start_loc, attn_metadata.max_encoder_seq_len, + attn_metadata.encoder_seq_lens, causal_mask) + + elif attn_type == AttentionType.ENCODER_ONLY: + # For encoder-only models, we use the prefill sequence lengths + assert attn_metadata.seq_lens is not None + assert attn_metadata.seq_lens_tensor is not None + query_seq_start_loc = torch.tensor( + list(itertools.accumulate([0] + attn_metadata.seq_lens)), + device=attn_metadata.seq_lens_tensor.device, + dtype=attn_metadata.seq_lens_tensor.dtype) + max_seq_len = attn_metadata.max_prefill_seq_len + # Encoder-only models typically use bidirectional attention + causal_mask = False + + return (query_seq_start_loc, max_seq_len, query_seq_start_loc, + max_seq_len, attn_metadata.seq_lens, causal_mask) + + elif attn_type == AttentionType.DECODER: + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + assert attn_metadata.seq_lens is not None + assert attn_metadata.seq_lens_tensor is not None + query_seq_start_loc = torch.tensor( + list(itertools.accumulate([0] + attn_metadata.seq_lens)), + device=attn_metadata.seq_lens_tensor.device, + dtype=attn_metadata.seq_lens_tensor.dtype) + max_seq_len = attn_metadata.max_prefill_seq_len + causal_mask = True + + return (query_seq_start_loc, max_seq_len, query_seq_start_loc, + max_seq_len, attn_metadata.seq_lens, causal_mask) + elif attn_type == AttentionType.ENCODER_DECODER: + assert attn_metadata.seq_lens is not None + assert attn_metadata.encoder_seq_lens_tensor is not None + query_start_loc = torch.tensor( + list(itertools.accumulate([0] + attn_metadata.seq_lens)), + device=attn_metadata.encoder_seq_lens_tensor.device, + dtype=attn_metadata.encoder_seq_lens_tensor.dtype) + + assert attn_metadata.encoder_seq_lens is not None + assert attn_metadata.seq_lens_tensor is not None + key_seq_start_loc = torch.tensor( + list(itertools.accumulate([0] + attn_metadata.encoder_seq_lens)), + device=attn_metadata.seq_lens_tensor.device, + dtype=attn_metadata.seq_lens_tensor.dtype) + causal_mask = False + + # Enc/dec cross-attention KVs match encoder sequence length; + # cross-attention utilizes special "cross" block tables + return (query_start_loc, attn_metadata.max_prefill_seq_len, + key_seq_start_loc, attn_metadata.max_encoder_seq_len, + attn_metadata.seq_lens, causal_mask) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +class 
ROCmFlashAttentionImpl(AttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prompt_tokens -------------->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| + + Otherwise, the layout is as follows: + |<------------------ num_generation_tokens (M) ----------------->| + |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + + |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->| + |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + use_irope: bool = False, + ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") + if use_irope: + logger.warning_once( + "Using irope in ROCm Flash Attention is not supported yet, it " + "will fail back to global attention for long context.") + if blocksparse_params is not None: + raise ValueError( + "ROCmFlashAttention does not support blocksparse attention.") + if use_irope: + logger.warning( + "Using irope in V0 is not supported yet, it will fall back " + "to global attention for long context.") + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + self.logits_soft_cap = 0.0 + else: + self.logits_soft_cap = logits_soft_cap + self.attn_type = attn_type + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + self.paged_attn_module = _get_paged_attn_module() + supported_head_sizes = self.paged_attn_module.get_supported_head_sizes( + ) + + if head_size not in supported_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {supported_head_sizes}.") + + self.use_naive_attn = False + # NOTE: Allow for switching between Triton and CK. Defaulting to triton. + self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN + if self.use_triton_flash_attn: + if logits_soft_cap is not None: + raise ValueError( + "ROCm Triton FlashAttention does not support attention" + " logits soft capping." 
+ " please try using the ROCm CK " + "FA backend instead by setting the env var " + "`VLLM_USE_TRITON_FLASH_ATTN=0`") + + from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 + triton_attention) + self.triton_attn_func = triton_attention + logger.debug("Using Triton FA in ROCmBackend") + if self.sliding_window != (-1, -1): + logger.warning("ROCm Triton FA does not currently support " + "sliding window attention. If using half " + "precision, please try using the ROCm CK " + "FA backend instead by setting the env var " + "`VLLM_USE_TRITON_FLASH_ATTN=0`") + else: + # if not using triton, navi3x/navi21/navi10 do not use flash-attn + # either + if not current_platform.has_device_capability(90): + self.use_naive_attn = True + else: + try: + from flash_attn import flash_attn_varlen_func # noqa: F401 + self.fa_attn_func = flash_attn_varlen_func + logger.debug("Using CK FA in ROCmBackend") + except ModuleNotFoundError: + self.use_naive_attn = True + + if self.use_naive_attn: + if logits_soft_cap is not None: + raise ValueError( + "ROCm Naive FlashAttention does not support " + "attention logits soft capping.") + + self.sdpa_attn_func = _sdpa_attention + logger.debug("Using naive (SDPA) attention in ROCmBackend") + + self.aiter_kv_scales_initialized = False + + def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" + tokens, n_kv_heads, head_dim = x.shape + return (x[:, :, + None, :].expand(tokens, n_kv_heads, n_rep, + head_dim).reshape(tokens, n_kv_heads * n_rep, + head_dim)) + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: ROCmFlashAttentionMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with FlashAttention and PagedAttention. + + For decoder-only models: query, key and value must be non-None. + + For encoder/decoder models: + * ROCmFlashAttentionImpl.forward() may be invoked for both self- and + cross-attention layers. + * For self-attention: query, key and value must be non-None. + * For cross-attention: + * Query must be non-None + * During prefill, key and value must be non-None; key and value + get cached for use during decode. 
+ * During decode, key and value may be None, since: + (1) key and value tensors were cached during prefill, and + (2) cross-attention key and value tensors do not grow during + decode + + A note on how the attn_type (attention type enum) argument impacts + attention forward() behavior: + + * DECODER: normal decoder-only behavior; + use decoder self-attention block table + * ENCODER: no KV caching; pass encoder sequence + attributes (encoder_seq_lens/encoder_seq_lens_tensor/ + max_encoder_seq_len) to kernel, in lieu of decoder + sequence attributes (seq_lens/seq_lens_tensor/max_seq_len) + * ENCODER_DECODER: cross-attention behavior; + use cross-attention block table for caching KVs derived + from encoder hidden states; since KV sequence lengths + will match encoder sequence lengths, pass encoder sequence + attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/ + max_encoder_seq_len) + * ENCODER_ONLY: bidirectional attention with no KV caching; + use prefill sequence attributes + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + attn_type: Select attention type, between encoder attention, + decoder self-attention, or encoder/decoder cross- + attention. Defaults to decoder self-attention, + which is the vLLM default generally + Returns: + shape = [num_tokens, num_heads * head_size] + """ + assert output is not None, "Output tensor must be provided." + + query = query.view(-1, self.num_heads, self.head_size) + if key is not None: + assert value is not None + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + else: + assert value is None + + paged_attn = self.paged_attn_module + + # Reshaping kv tensors is required for AITER paged attention kernel + # because it works on a different tensor shape, + # when the size of one element is one byte (int8/fp8 dtypes). + # This reshaping is only required on the first forward call + # and the kv cache must not be empty. + if (is_rocm_aiter_paged_attn_enabled() and kv_cache.dtype.itemsize == 1 + and not self.aiter_kv_scales_initialized + and kv_cache.shape != torch.Size([0])): + num_blocks = kv_cache.shape[1] + block_size = kv_cache.shape[2] // (self.num_kv_heads * + self.head_size) + k_scale = torch.empty((self.num_kv_heads, num_blocks * block_size), + dtype=torch.float32, + device=kv_cache.device) + v_scale = torch.empty((self.num_kv_heads, num_blocks * block_size), + dtype=torch.float32, + device=kv_cache.device) + self.aiter_kv_scales_initialized = True + k_scale.fill_(layer._k_scale.item()) + v_scale.fill_(layer._v_scale.item()) + layer._k_scale = k_scale + layer._v_scale = v_scale + + # Only update KV cache for decoder self-attention + # and encoder-decoder cross-attention + if self.attn_type not in [ + AttentionType.ENCODER, AttentionType.ENCODER_ONLY + ] and kv_cache.numel() > 0: + key_cache, value_cache = paged_attn.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + if key is not None and value is not None: + # Reshape the input keys and values and store them in the + # cache. If kv_cache is not provided, the new key and value + # tensors are not cached. This happens during the initial + # memory profiling run. 
+ paged_attn.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping + if self.attn_type != AttentionType.ENCODER_DECODER else + attn_metadata.cross_slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + if self.attn_type != AttentionType.ENCODER: + num_prefill_tokens = attn_metadata.num_prefill_tokens + elif self.attn_type == AttentionType.ENCODER_ONLY: + # For encoder-only models, all tokens are processed in one go + num_prefill_tokens = query.shape[0] + else: + assert attn_metadata.num_encoder_tokens is not None + num_prefill_tokens = attn_metadata.num_encoder_tokens + + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + + # For encoder-only and encoder models, + # we process all tokens at once + # For decoder and encoder-decoder, + # we may need to limit key/value to prefill tokens + if key is not None and value is not None \ + and self.attn_type not in [AttentionType.ENCODER_DECODER, + AttentionType.ENCODER_ONLY]: + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. + # normal attention and DECODER + if self.attn_type == AttentionType.DECODER and ( + kv_cache.numel() == 0 or prefill_meta.block_tables is None + or prefill_meta.block_tables.numel() == 0): + (query_seq_start_loc, query_max_seq_len, key_seq_start_loc, + key_max_seq_len, seq_lens, + causal_mask) = (prefill_meta.seq_start_loc, + prefill_meta.max_prefill_seq_len, + prefill_meta.seq_start_loc, + prefill_meta.max_prefill_seq_len, + attn_metadata.seq_lens, True) + # prefix-enabled attention and ENCODER/ENCODER_DECODER + else: + (query_seq_start_loc, query_max_seq_len, key_seq_start_loc, + key_max_seq_len, seq_lens, + causal_mask) = _get_seq_len_block_table_args( + prefill_meta, self.attn_type) + # Prompt run. + if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0: + # triton attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + attn_masks = None + if self.use_triton_flash_attn: + if self.alibi_slopes is not None: + attn_masks = _make_alibi_bias( + self.alibi_slopes, + query.dtype, + seq_lens, + make_attn_mask=causal_mask) # type: ignore + use_fp8_scales = (layer._q_scale and layer._k_scale + and layer._v_scale and layer._prob_scale + and self.kv_cache_dtype == "fp8") + full_scales = ( + layer._q_scale.item(), layer._k_scale.item(), + layer._v_scale.item(), + layer._prob_scale.item()) if use_fp8_scales else None + self.triton_attn_func( + query, + key, + value, + output[:num_prefill_tokens], + query_seq_start_loc, + key_seq_start_loc, + query_max_seq_len, + key_max_seq_len, + causal_mask, + self.scale, + attn_masks[0][None] + if attn_masks is not None else None, + full_scales, + ) + elif self.use_naive_attn: + if self.num_kv_heads != self.num_heads: + # Interleave for MQA workaround. 
+ key = self.repeat_kv(key, self.num_queries_per_kv) + value = self.repeat_kv(value, self.num_queries_per_kv) + if self.alibi_slopes is not None: + attn_masks = _make_alibi_bias( + self.alibi_slopes, + query.dtype, + attn_metadata.seq_lens, + make_attn_mask=causal_mask) # type: ignore + query = query.movedim(0, query.dim() - 2) + key = key.movedim(0, key.dim() - 2) + value = value.movedim(0, value.dim() - 2) + # sdpa math backend attention + self.sdpa_attn_func( + query, + key, + value, + output[:num_prefill_tokens], + query_seq_start_loc, + num_prefill_tokens, + self.num_heads, + self.head_size, + self.scale, + attn_masks, + ) + else: + # upstream FA does not support an output arg, copy + output[:num_prefill_tokens] = self.fa_attn_func( + q=query, + k=key, + v=value, + cu_seqlens_q=query_seq_start_loc, + cu_seqlens_k=key_seq_start_loc, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=key_max_seq_len, + softmax_scale=self.scale, + causal=causal_mask, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + softcap=self.logits_soft_cap, + ) + + else: + # prefix-enabled attention - + # not applicable for encoder-only models + if self.attn_type != AttentionType.ENCODER_ONLY: + output[:num_prefill_tokens] = paged_attn.forward_prefix( + query, + key, + value, + self.kv_cache_dtype, + key_cache, + value_cache, + prefill_meta.block_tables, + prefill_meta.query_start_loc, + prefill_meta.seq_lens_tensor, + prefill_meta.max_query_len, + self.alibi_slopes, + self.sliding_window[0], + layer._k_scale, + layer._v_scale, + ) + # Skip decode phase for encoder-only models + if (decode_meta := attn_metadata.decode_metadata) and ( + self.attn_type != AttentionType.ENCODER_ONLY): + # Decoding run. + # Whether to use rocm custom paged attention or not + num_seqs, num_heads, head_size = decode_query.shape + block_size = value_cache.shape[3] + gqa_ratio = num_heads // self.num_kv_heads + use_custom = use_rocm_custom_paged_attention( + decode_query.dtype, head_size, block_size, gqa_ratio, + decode_meta.max_decode_seq_len, self.sliding_window, + self.kv_cache_dtype, self.alibi_slopes) + if use_custom: + max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type + != AttentionType.ENCODER_DECODER else + decode_meta.max_encoder_seq_len) + assert max_seq_len is not None + max_num_partitions = ( + (max_seq_len + _PARTITION_SIZE_ROCM - 1) // + _PARTITION_SIZE_ROCM) + assert _PARTITION_SIZE_ROCM % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + + query_start_loc = None + ops.paged_attention_rocm( + output[num_prefill_tokens:], + exp_sums, + max_logits, + tmp_output, + decode_query, + key_cache, + value_cache, + self.num_kv_heads, + self.scale, + decode_meta.block_tables + if self.attn_type != AttentionType.ENCODER_DECODER else + decode_meta.cross_block_tables, + decode_meta.seq_lens_tensor + if self.attn_type != AttentionType.ENCODER_DECODER else + decode_meta.encoder_seq_lens_tensor, + query_start_loc, + block_size, + max_seq_len, + self.alibi_slopes, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + else: + output[num_prefill_tokens:] = paged_attn.forward_decode( + decode_query, + key_cache, + value_cache, + decode_meta.block_tables + if self.attn_type != AttentionType.ENCODER_DECODER else + 
decode_meta.cross_block_tables, + decode_meta.seq_lens_tensor + if self.attn_type != AttentionType.ENCODER_DECODER else + decode_meta.encoder_seq_lens_tensor, + decode_meta.max_decode_seq_len + if self.attn_type != AttentionType.ENCODER_DECODER else + decode_meta.max_encoder_seq_len, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + layer._k_scale, + layer._v_scale, + ) + + # Reshape the output tensor. + return output.view(-1, self.num_heads * self.head_size) + + +def _sdpa_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + seq_lens: torch.Tensor, + num_tokens: int, + num_heads: int, + head_size: int, + scale: float, + attn_masks: Optional[List[torch.Tensor]] = None, +) -> torch.Tensor: + start = 0 + assert output.shape == (num_tokens, num_heads, head_size) + assert output.dtype == query.dtype + assert output.device == query.device + + for i, seq_len in enumerate(seq_lens): + end = start + seq_len + with torch.nn.attention.sdpa_kernel( + torch.nn.attention.SDPBackend.MATH): + sub_out = torch.nn.functional.scaled_dot_product_attention( + query[:, start:end, :], + key[:, start:end, :], + value[:, start:end, :], + dropout_p=0.0, + is_causal=attn_masks is None, + attn_mask=attn_masks[i] if attn_masks else None, + scale=scale).movedim(query.dim() - 2, 0) + output[start:end, :, :] = sub_out + start = end + + return output diff --git a/attention/backends/torch_sdpa.py b/attention/backends/torch_sdpa.py new file mode 100644 index 0000000..23231c3 --- /dev/null +++ b/attention/backends/torch_sdpa.py @@ -0,0 +1,703 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" Attention layer with torch scaled_dot_product_attention + and PagedAttention.""" +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +from torch.nn.functional import scaled_dot_product_attention + +# yapf conflicts with isort for this block +# yapf: disable +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionType, + is_quantized_kv_cache) +# yapf: enable +from vllm.attention.backends.utils import CommonAttentionState +from vllm.attention.ops.ipex_attn import PagedAttention, _use_ipex +from vllm.attention.ops.paged_attn import PagedAttentionMetadata +from vllm.logger import init_logger +from vllm.utils import make_tensor_with_pad +from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder + +logger = init_logger(__name__) + + +class TorchSDPABackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "TORCH_SDPA" + + @staticmethod + def get_impl_cls() -> Type["TorchSDPABackendImpl"]: + return TorchSDPABackendImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return TorchSDPAMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_builder_cls() -> Type["TorchSDPAMetadataBuilder"]: + return TorchSDPAMetadataBuilder + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + 
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): + """Metadata for TorchSDPABackend. + """ + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. + chunked_prefill: bool + seq_lens: Optional[List[int]] = None # For non-chunked prefill + + # For chunked prefill only + max_query_len: Optional[int] = None + max_kv_len: Optional[int] = None + prefill_query_start_loc: Optional[torch.Tensor] = None + kv_start_loc: Optional[torch.Tensor] = None + prefill_block_tables: Optional[torch.Tensor] = None + + # For V1 logits index only + query_start_loc: Optional[torch.Tensor] = None + + # Begin encoder attn & enc/dec cross-attn fields... + # Encoder sequence lengths representation + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + + def __post_init__(self): + # Set during the execution of the first attention op. + # It is a list because it is needed to set per prompt + # when alibi slopes is used. It is because of the limitation + # from xformer API. + # will not appear in the __repr__ and __init__ + self.attn_bias: Optional[List[torch.Tensor]] = None + self.encoder_attn_bias: Optional[List[torch.Tensor]] = None + self.cross_attn_bias: Optional[List[torch.Tensor]] = None + + @property + def is_all_encoder_attn_metadata_set(self): + ''' + All attention metadata required for encoder attention is set. + ''' + return ((self.encoder_seq_lens is not None) + and (self.encoder_seq_lens_tensor is not None) + and (self.max_encoder_seq_len is not None)) + + @property + def is_all_cross_attn_metadata_set(self): + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. + ''' + return (self.is_all_encoder_attn_metadata_set + and (self.cross_slot_mapping is not None) + and (self.cross_block_tables is not None)) + + @property + def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]: + if self.num_prefill_tokens == 0: + return None + return self + + @property + def decode_metadata(self) -> Optional["TorchSDPAMetadata"]: + if self.num_decode_tokens == 0: + return None + return self + + def get_seq_lens( + self, + attn_type: str, + ): + ''' + Extract appropriate sequence lengths from attention metadata + according to attention type. 
+ + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + * Appropriate sequence lengths tensor for query + * Appropriate sequence lengths tensor for key & value + ''' + + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): + seq_lens_q = self.seq_lens + seq_lens_kv = self.seq_lens + elif attn_type == AttentionType.ENCODER: + seq_lens_q = self.encoder_seq_lens + seq_lens_kv = self.encoder_seq_lens + elif attn_type == AttentionType.ENCODER_DECODER: + seq_lens_q = self.seq_lens + seq_lens_kv = self.encoder_seq_lens + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + return seq_lens_q, seq_lens_kv + + def get_attn_bias( + self, + attn_type: str, + ) -> Optional[List[torch.Tensor]]: + ''' + Extract appropriate attention bias from attention metadata + according to attention type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + * Appropriate attention bias value given the attention type + ''' + + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): + return self.attn_bias + elif attn_type == AttentionType.ENCODER: + return self.encoder_attn_bias + elif attn_type == AttentionType.ENCODER_DECODER: + return self.cross_attn_bias + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + def set_attn_bias( + self, + attn_bias: List[torch.Tensor], + attn_type: str, + ) -> None: + ''' + Update appropriate attention bias field of attention metadata, + according to attention type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_bias: The desired attention bias value + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + ''' + + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): + self.attn_bias = attn_bias + elif attn_type == AttentionType.ENCODER: + self.encoder_attn_bias = attn_bias + elif attn_type == AttentionType.ENCODER_DECODER: + self.cross_attn_bias = attn_bias + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + def get_seq_len_block_table_args( + self, + attn_type: str, + ) -> tuple: + ''' + The particular choice of sequence-length- and block-table-related + attributes which should be extracted from attn_metadata is dependent + on the type of attention operation. 
+ + Decoder attn -> select entirely decoder self-attention-related fields + Encoder/decoder cross-attn -> select encoder sequence lengths & + cross-attn block-tables fields + Encoder attn -> select encoder sequence lengths fields & no block tables + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * is_prompt: True if prefill, False otherwise + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + + * Appropriate sequence-lengths tensor + * Appropriate max sequence-length scalar + * Appropriate block tables (or None) + ''' + + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + return (self.seq_lens_tensor, self.max_decode_seq_len, + self.block_tables) + elif attn_type == AttentionType.ENCODER_DECODER: + # Enc/dec cross-attention KVs match encoder sequence length; + # cross-attention utilizes special "cross" block tables + return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, + self.cross_block_tables) + elif attn_type == AttentionType.ENCODER: + # No block tables associated with encoder attention + return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, + None) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]): + + def __init__(self, input_builder: ModelInputForCPUBuilder) -> None: + self.chunked_prefill = input_builder.chunked_prefill + self.input_builder = input_builder + + def prepare(self): + self.input_data = self.input_builder.input_data + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata: + input_data = self.input_data + prefill_seq_lens = seq_lens[0:input_data.num_prefills] + prefill_query_lens = query_lens[0:input_data.num_prefills] + slot_mapping = torch.tensor(input_data.slot_mapping, + dtype=torch.long, + device="cpu") + + # For chunked-prefill + if self.chunked_prefill and input_data.num_prefill_tokens != 0: + prefill_block_tables = make_tensor_with_pad( + self.input_data.prefill_block_tables, + pad=0, + dtype=torch.int32, + device="cpu", + ) + query_lens_tensor = torch.tensor(prefill_query_lens, + dtype=torch.int32, + device="cpu") + kv_lens_tensor = torch.tensor(prefill_seq_lens, + dtype=torch.int32, + device="cpu") + query_start_loc = torch.zeros(input_data.num_prefills + 1, + dtype=torch.int32, + device="cpu") + kv_start_loc = torch.zeros(input_data.num_prefills + 1, + dtype=torch.int32, + device="cpu") + torch.cumsum(query_lens_tensor, + dim=0, + dtype=torch.int32, + out=query_start_loc[1:]) + torch.cumsum(kv_lens_tensor, + dim=0, + dtype=torch.int32, + out=kv_start_loc[1:]) + max_query_len = max(prefill_query_lens) + max_kv_len = max(prefill_seq_lens) + else: + prefill_block_tables = None + query_start_loc = None + kv_start_loc = None + max_query_len = None + max_kv_len = None + + # For paged attention + if input_data.num_decode_tokens != 0: + seq_lens_tensor = torch.tensor( + input_data.seq_lens[input_data.num_prefills:], + dtype=torch.int32, + device="cpu", + ) + block_tables = make_tensor_with_pad( + self.input_data.decode_block_tables, + pad=0, + dtype=torch.int32, + device="cpu", + ) + else: + block_tables = torch.tensor([]) + seq_lens_tensor = torch.tensor( + input_data.seq_lens[:input_data.num_prefills], + dtype=torch.int32, + device="cpu", + ) + + # 
For multi-modal models + placeholder_index_maps = None + if len(input_data.multi_modal_inputs_list) != 0: + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + input_data.multi_modal_placeholder_maps.items() + } + + attn_metadata = TorchSDPAMetadata( + chunked_prefill=self.chunked_prefill, + seq_lens=prefill_seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_kv_len=max_kv_len, + prefill_query_start_loc=query_start_loc, + kv_start_loc=kv_start_loc, + max_decode_seq_len=input_data.max_decode_seq_len, + num_prefills=input_data.num_prefills, + num_prefill_tokens=input_data.num_prefill_tokens, + num_decode_tokens=input_data.num_decode_tokens, + block_tables=block_tables, + prefill_block_tables=prefill_block_tables, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, + ) + + return attn_metadata + + +class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + use_irope: bool = False, + ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") + if blocksparse_params is not None: + raise ValueError( + "Torch SPDA does not support block-sparse attention.") + if logits_soft_cap is not None: + logger.warning_once("Torch SPDA does not support logits soft cap. " + "Outputs may be slightly off.") + if use_irope: + logger.warning_once( + "Using irope in Torch SPDA is not supported yet, it will fall" + " back to global attention for long context.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.need_mask = (self.alibi_slopes is not None + or self.sliding_window is not None) + + supported_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {supported_head_sizes}.") + + if is_quantized_kv_cache(kv_cache_dtype) and not _use_ipex: + raise NotImplementedError( + "Torch SDPA backend FP8 KV cache requires " + "intel_extension_for_pytorch support.") + self.attn_type = attn_type + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: TorchSDPAMetadata, # type: ignore + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with torch SDPA and PagedAttention. 
+ + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + + # For warming-up + if attn_metadata is None: + return query + + attn_type = self.attn_type + if (attn_type == AttentionType.ENCODER + and (not attn_metadata.is_all_encoder_attn_metadata_set)): + raise AttributeError("Encoder attention requires setting " + "encoder metadata attributes.") + elif (attn_type == AttentionType.ENCODER_DECODER + and (not attn_metadata.is_all_cross_attn_metadata_set)): + raise AttributeError("Encoder/decoder cross-attention " + "requires setting cross-attention " + "metadata attributes.") + + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + if key is not None: + assert value is not None + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + else: + assert value is None + + if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0): + # KV-cache during decoder-self- or + # encoder-decoder-cross-attention, but not + # during encoder attention. + # + # Even if there are no new key/value pairs to cache, + # we still need to break out key_cache and value_cache + # i.e. for later use by paged attention + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + if (key is not None) and (value is not None): + if attn_type == AttentionType.ENCODER_DECODER: + # Update cross-attention KV cache (prefill-only) + # During cross-attention decode, key & value will be None, + # preventing this IF-statement branch from running + updated_slot_mapping = attn_metadata.cross_slot_mapping + else: + # Update self-attention KV cache (prefill/decode) + updated_slot_mapping = attn_metadata.slot_mapping + + PagedAttention.write_to_paged_cache( + key, value, key_cache, value_cache, updated_slot_mapping, + self.kv_cache_dtype, layer._k_scale, layer._v_scale) + + if attn_type != AttentionType.ENCODER: + # Decoder self-attention supports chunked prefill. 
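# A small sketch, with made-up counts, of the token accounting that follows:
# the V0 batch lays out all prefill tokens ahead of all decode tokens, so the
# two phases can be separated by plain slicing on the token dimension.
import torch

num_prefill_tokens, num_decode_tokens, hidden = 7, 3, 32
query = torch.randn(num_prefill_tokens + num_decode_tokens, hidden)
prefill_query = query[:num_prefill_tokens]   # prompt tokens (possibly chunked)
decode_query = query[num_prefill_tokens:]    # one new token per decoding sequence
assert prefill_query.shape[0] + decode_query.shape[0] == query.shape[0]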
+ # Encoder/decoder cross-attention requires no chunked + # prefill (100% prefill or 100% decode tokens, no mix) + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + else: + # Encoder attention - chunked prefill is not applicable; + # derive token-count from query shape & and treat them + # as 100% prefill tokens + assert attn_metadata.num_encoder_tokens is not None + num_prefill_tokens = attn_metadata.num_encoder_tokens + num_decode_tokens = 0 + + if attn_type == AttentionType.DECODER: + # Only enforce this shape-constraint for decoder + # self-attention + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + output = torch.empty_like(query) + if prefill_meta := attn_metadata.prefill_metadata: + if not prefill_meta.prefill_metadata.chunked_prefill: # type: ignore + assert attn_metadata.seq_lens is not None + self._run_sdpa_forward(output, + query, + key, + value, + prefill_meta, + attn_type=attn_type) + else: + # prefix-enabled attention + assert not self.need_mask + import intel_extension_for_pytorch.llm.modules as ipex_modules + output = torch.empty_like(query) + ipex_modules.PagedAttention.flash_attn_varlen_func( + output[:prefill_meta.num_prefill_tokens, :, :], + query[:prefill_meta.num_prefill_tokens, :, :], + key_cache, + value_cache, + prefill_meta.prefill_query_start_loc, + prefill_meta.kv_start_loc, + prefill_meta.max_query_len, + prefill_meta.max_kv_len, + self.scale, + True, + prefill_meta.prefill_block_tables, + self.alibi_slopes, + ) + + if decode_meta := attn_metadata.decode_metadata: + assert attn_type != AttentionType.ENCODER_ONLY, ( + "Encoder-only models should not have decode metadata.") + # Decoding run. + ( + seq_lens_arg, + max_seq_len_arg, + block_tables_arg, + ) = decode_meta.get_seq_len_block_table_args(attn_type) + + PagedAttention.forward_decode( + output[attn_metadata.num_prefill_tokens:, :, :], + query[attn_metadata.num_prefill_tokens:, :, :], + key_cache, + value_cache, + block_tables_arg, + seq_lens_arg, + max_seq_len_arg, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + layer._k_scale, + layer._v_scale, + ) + + # Reshape the output tensor. 
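# The reshape below is the inverse of the per-head view applied to query, key
# and value at the top of forward(); a quick round-trip with toy sizes. The
# decode results written just above land in output[num_prefill_tokens:], and
# the whole buffer is then flattened back to [num_tokens, num_heads * head_size].
import torch

num_tokens, num_heads, head_size = 5, 8, 64
hidden_states = torch.randn(num_tokens, num_heads * head_size)
per_head = hidden_states.view(-1, num_heads, head_size)   # [5, 8, 64]
flat_again = per_head.view(-1, num_heads * head_size)     # [5, 512]
assert torch.equal(hidden_states, flat_again)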
+ return output.view(-1, self.num_heads * self.head_size) + + def _run_sdpa_forward( + self, + output: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: TorchSDPAMetadata, + attn_type: str = AttentionType.DECODER, + ) -> None: + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=1) + value = value.repeat_interleave(self.num_queries_per_kv, dim=1) + + attn_masks = attn_metadata.get_attn_bias(attn_type) + if attn_masks is None: + if self.alibi_slopes is not None: + attn_masks = _make_alibi_bias( + self.alibi_slopes, query.dtype, + attn_metadata.seq_lens) # type: ignore + elif self.sliding_window is not None: + assert attn_metadata.seq_lens is not None + attn_masks = _make_sliding_window_bias( + attn_metadata.seq_lens, self.sliding_window, + query.dtype) # type: ignore + else: + seq_lens, _ = attn_metadata.get_seq_lens(attn_type) + attn_masks = [None] * len(seq_lens) + attn_metadata.set_attn_bias(attn_masks, attn_type) + + query = query.movedim(0, query.dim() - 2) + key = key.movedim(0, key.dim() - 2) + value = value.movedim(0, value.dim() - 2) + + causal_attn = (attn_type == AttentionType.DECODER) + + seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type) + start_q, start_kv = 0, 0 + for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv, + attn_masks): + end_q = start_q + seq_len_q + end_kv = start_kv + seq_len_kv + sub_out = scaled_dot_product_attention( + query[None, :, start_q:end_q, :], + key[None, :, start_kv:end_kv, :], + value[None, :, start_kv:end_kv, :], + attn_mask=mask, + dropout_p=0.0, + is_causal=causal_attn and mask is None, + scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0) + output[start_q:end_q, :, :] = sub_out + start_q, start_kv = end_q, end_kv + + +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + dtype: torch.dtype, + seq_lens: List[int], +) -> List[torch.Tensor]: + attn_biases: List[torch.Tensor] = [] + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. 
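# A compact numeric sketch of the ALiBi bias built just below: a relative-
# position matrix (j - i) scaled by a per-head slope, with the strictly upper
# triangle pushed to -inf for causal masking. The slope values here are
# illustrative only.
import torch

seq_len = 3
pos = torch.arange(seq_len, dtype=torch.float32)
rel = pos[None, :] - pos[:, None]                 # rel[i, j] = j - i
slopes = torch.tensor([0.5, 0.25])                # one slope per head
bias = rel[None, :, :] * slopes[:, None, None]    # [num_heads, seq_len, seq_len]
causal = torch.full((seq_len, seq_len), float("-inf")).triu(diagonal=1)
alibi_bias = bias + causal                        # added to the attention logits
assert alibi_bias.shape == (2, seq_len, seq_len)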
+ bias = bias[None, :] - bias[:, None] + + num_heads = alibi_slopes.shape[0] + bias = bias[None, :].repeat((num_heads, 1, 1)) + bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0) + inf_mask = torch.empty( + (1, seq_len, seq_len), + dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1) + attn_biases.append((bias + inf_mask).to(dtype)) + + return attn_biases + + +def _make_sliding_window_bias( + seq_lens: List[int], + window_size: Optional[int], + dtype: torch.dtype, +) -> List[torch.Tensor]: + attn_biases: List[torch.Tensor] = [] + for seq_len in seq_lens: + tensor = torch.full( + (1, seq_len, seq_len), + dtype=dtype, + fill_value=1, + ) + shift = 0 + mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore + if window_size is not None: + mask = torch.triu(mask, diagonal=shift - window_size + 1) + mask = torch.log(mask) + attn_biases.append(mask.to(dtype)) + + return attn_biases diff --git a/attention/backends/triton_mla.py b/attention/backends/triton_mla.py new file mode 100644 index 0000000..5cf7618 --- /dev/null +++ b/attention/backends/triton_mla.py @@ -0,0 +1,181 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Type + +import torch + +from vllm.attention.backends.abstract import (AttentionType, + is_quantized_kv_cache) +from vllm.attention.backends.mla.common import (MLACommonBackend, + MLACommonImpl, + MLACommonMetadata) +from vllm.attention.ops.triton_decode_attention import decode_attention_fwd + + +import json +import os + +# TODO: Configure environment variables temporarily. New versions do not need to be configured +os.environ['TRITON_ENABLE_MACA_OPT_MOVE_DOT_OPERANDS_OUT_LOOP'] = '1' +os.environ['TRITON_ENABLE_MACA_CHAIN_DOT_OPT'] = '1' + +def load_config(): + # Load JSON data from the file + json_path = config_file_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "configs", "tp8_merge.json") + with open(json_path, 'r') as file: + data = json.load(file) + return data + +JSON_DATA = load_config() + +def find_best_mla_para(json_data, batch_size, input_len, tp_size): + best_match = None + best_batch_size_diff = float('inf') + best_input_len_diff = float('inf') + + for entry in json_data: + if entry["BS"] == batch_size and entry["L"] == input_len: + return entry["num_kv_splits"], entry['num_stages'] + batch_size_diff = abs(entry["BS"] - batch_size) + input_len_diff = abs(entry["L"] - input_len) + + # Check if this is a better match than the current best match + if batch_size_diff < best_batch_size_diff or (batch_size_diff == best_batch_size_diff and input_len_diff < best_input_len_diff): + best_match = entry + best_batch_size_diff = batch_size_diff + best_input_len_diff = input_len_diff + + # If a match was found, return the best_kv_splits, otherwise return None + return best_match["num_kv_splits"],best_match["num_stages"] + + +class TritonMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "TRITON_MLA" + + @staticmethod + def get_impl_cls() -> Type["TritonMLAImpl"]: + return TritonMLAImpl + + @staticmethod + def get_metadata_cls() -> Type["TritonMLAMetadata"]: + return TritonMLAMetadata + +@dataclass +class TritonMLAMetadata(MLACommonMetadata): + num_kv_splits: int = 4 # TODO: heuristic + num_stages: int = 1 + + @property + def decode_metadata(self): + if self.num_decode_tokens == 0: + return None + if self._cached_decode_metadata is not None: + return 
self._cached_decode_metadata + + decode_metadata = super().decode_metadata + + if decode_metadata is not None: + if decode_metadata.seq_lens_tensor is not None: + batch = decode_metadata.seq_lens_tensor.shape[0] + max_seq_len = int(decode_metadata.seq_lens_tensor.max()) + num_kv_splits, num_stages = find_best_mla_para(JSON_DATA, batch, max_seq_len, 8) + else: + num_kv_splits = self.num_kv_splits + num_stages = self.num_stages + decode_metadata.num_kv_splits = num_kv_splits + decode_metadata.num_stages = num_stages + return decode_metadata + +class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + kv_sharing_target_layer_name, **mla_args) + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "TritonMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "TritonMLAImpl") + + if is_quantized_kv_cache(self.kv_cache_dtype): + raise NotImplementedError( + "TritonMLA with FP8 KV cache not yet supported") + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: TritonMLAMetadata, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + + decode_meta = attn_metadata.decode_metadata + assert decode_meta is not None + B = q_nope.shape[0] + + q = torch.cat([q_nope, q_pe], dim=-1) + o = torch.zeros(B, + self.num_heads, + self.kv_lora_rank, + dtype=q.dtype, + device=q.device) + + # TODO(lucas) Allocate ahead of time + attn_logits = torch.empty( + ( + B, + self.num_heads, + decode_meta.num_kv_splits, + # NOTE(lucas) idk why the +1 is here but sglang has it so we + # just mirror that + self.kv_lora_rank + 1, + ), + dtype=torch.float32, + device=q.device, + ) + + # Add a head dim of 1 + kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2) + kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] + PAGE_SIZE = kv_c_and_k_pe_cache.size(1) + + # Run MQA + decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, + decode_meta.block_tables, + decode_meta.seq_lens_tensor, attn_logits, + decode_meta.num_kv_splits, decode_meta.num_stages, self.scale, PAGE_SIZE) + + return self._v_up_proj(o) diff --git a/attention/backends/utils.py b/attention/backends/utils.py new file mode 100644 index 0000000..e3f02a1 --- /dev/null +++ b/attention/backends/utils.py @@ -0,0 +1,610 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Attention backend utils""" +from collections import defaultdict +from contextlib import contextmanager +from dataclasses import dataclass +from itertools import accumulate +from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, + TypeVar, Union) + 
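# A self-contained sketch of the nearest-configuration lookup used by
# find_best_mla_para in the triton_mla.py hunk above: an exact (BS, L) match
# wins, otherwise the entry with the closest batch size (ties broken by input
# length) supplies num_kv_splits and num_stages. The table entries here are
# made up for illustration.
toy_table = [
    {"BS": 1,  "L": 1024, "num_kv_splits": 4,  "num_stages": 1},
    {"BS": 8,  "L": 2048, "num_kv_splits": 8,  "num_stages": 2},
    {"BS": 32, "L": 4096, "num_kv_splits": 16, "num_stages": 2},
]

def lookup(table, batch_size, input_len):
    for entry in table:
        if entry["BS"] == batch_size and entry["L"] == input_len:
            return entry["num_kv_splits"], entry["num_stages"]
    # Fall back to the closest batch size, then the closest input length.
    best = min(table, key=lambda e: (abs(e["BS"] - batch_size),
                                     abs(e["L"] - input_len)))
    return best["num_kv_splits"], best["num_stages"]

assert lookup(toy_table, 8, 2048) == (8, 2)   # exact match
assert lookup(toy_table, 6, 3000) == (8, 2)   # nearest batch size is 8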
+import numpy as np +import torch + +from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder, + AttentionState) +from vllm.attention.backends.abstract import AttentionType +from vllm.config import ModelConfig +from vllm.logger import init_logger +from vllm.multimodal import MultiModalPlaceholderMap +from vllm.utils import async_tensor_h2d, make_tensor_with_pad + +logger = init_logger(__name__) + +if TYPE_CHECKING: + from vllm.worker.model_runner_base import ModelRunnerBase + +# Error string(s) for encoder/decoder +# unsupported attention scenarios +STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported " + "with encoder/decoder models.") + +PAD_SLOT_ID = -1 + +# Switch to numpy implementation of compute_slot_mapping +# if we have at least this many elements. Could be tuned further. +_COMPUTE_SLOT_MAPPING_NUMPY_NUMEL = 256 + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUBuilder + + +def is_block_tables_empty(block_tables: Union[None, Dict]): + """ + Check if block_tables is None or a dictionary with all None values. + """ + if block_tables is None: + return True + return (isinstance(block_tables, dict) + and all(value is None for value in block_tables.values())) + + +def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int, + context_len: int, sliding_window: int): + """ + Compute the start index of slot mapping. + """ + start_idx = 0 + if is_prompt and sliding_window is not None: + start_idx = max(0, query_len - sliding_window) + return start_idx + + +def _compute_slot_mapping_python(slot_mapping: List[int], + block_table: List[int], range_start: int, + range_end: int, block_size: int): + for i in range(range_start, range_end): + block_number = block_table[i // block_size] + block_offset = i % block_size + slot = block_number * block_size + block_offset + slot_mapping.append(slot) + + +def _compute_slot_mapping_numpy(slot_mapping: List[int], + block_table: List[int], range_start: int, + range_end: int, block_size: int): + block_table_array = np.array(block_table) + idx = np.arange(range_start, range_end) + block_offset = idx % block_size + idx //= block_size + seq_slot_mapping_array = block_table_array[idx] + seq_slot_mapping_array *= block_size + seq_slot_mapping_array += block_offset + slot_mapping.extend(seq_slot_mapping_array) + + +def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int], + seq_id: int, seq_len: int, context_len: int, + start_idx: int, block_size: int, + block_tables: Dict[int, List[int]]): + """ + Compute slot mapping. + """ + if is_profile_run: + # During memory profiling, the block tables are not + # initialized yet. In this case, we just use a dummy + # slot mapping. + # In embeddings, the block tables are {seq_id: None}. + slot_mapping.extend([PAD_SLOT_ID] * seq_len) + return + + # Mask the [0, start_idx) tokens of the prompt with + # PAD_SLOT_ID, where start_idx is max(0, seq_len - + # sliding_window). For example, if the prompt len is 10, + # sliding window is 8, and block size is 4, the first two + # tokens are masked and the slot mapping will be + # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + padding_mask_len = max(0, start_idx - context_len) + slot_mapping.extend([PAD_SLOT_ID] * padding_mask_len) + + range_start = max(start_idx, context_len) + range_end = seq_len + numel = range_end - range_start + block_table = block_tables[seq_id] + + # numpy implementation will be faster than python if we have + # many elements, otherwise it will be slower. 
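# A tiny worked example of the slot computation dispatched just below: token
# position i lives in block block_table[i // block_size] at offset
# i % block_size, so its flat slot index is block * block_size + offset.
# The block table and sizes are toy values.
block_size = 4
block_table = [7, 2]           # physical cache blocks assigned to one sequence
slot_mapping = []
for i in range(6):             # first 6 token positions of the sequence
    block_number = block_table[i // block_size]
    slot_mapping.append(block_number * block_size + i % block_size)
assert slot_mapping == [28, 29, 30, 31, 8, 9]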
+ if numel < _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL: + _compute_slot_mapping_python(slot_mapping, block_table, range_start, + range_end, block_size) + else: + _compute_slot_mapping_numpy(slot_mapping, block_table, range_start, + range_end, block_size) + + +TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata') + + +class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): + + _metadata_cls: Type[TAttentionMetadata] + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.input_builder = input_builder + self.runner = input_builder.runner + + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + + def prepare(self): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool): + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for modality, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[modality].extend( + placeholders) + + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if inter_data.prefix_cache_hit: + block_table = block_tables[seq_id] + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + if curr_sliding_window_block == 0: + block_table = block_tables[seq_id] + else: + block_table = block_tables[seq_id][ + -curr_sliding_window_block:] + self.block_tables.append(block_table) + + # Compute slot mapping. + is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. 
+ """ + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_query_len = max(query_lens) + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + query_start_loc = list(accumulate(query_lens, initial=0)) + seq_start_loc = list(accumulate(seq_lens, initial=0)) + + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size + + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + input_block_tables = self.runner.graph_block_tables[:batch_size] + for i, block_table in enumerate(self.block_tables): + if block_table: + input_block_tables[i, :len(block_table)] = block_table + block_tables = torch.from_numpy(input_block_tables).to( + device, non_blocking=True) + else: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, "query_lens: {}".format(query_lens) + + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) + query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, + device, + self.runner.pin_memory) + seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, + device, self.runner.pin_memory) + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + self.multimodal_placeholder_maps.items() + } + + return self._metadata_cls( # type: ignore + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc_tensor, + seq_start_loc=seq_start_loc_tensor, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + ) + + +class CommonAttentionState(AttentionState): + + def __init__(self, runner: "ModelRunnerBase"): + self.runner = runner + self._is_graph_capturing = False + + @contextmanager + def graph_capture(self, max_batch_size: int): + + self._is_graph_capturing = True + + self._graph_slot_mapping = torch.full((max_batch_size, ), + PAD_SLOT_ID, + dtype=torch.long, + device=self.runner.device) + self._graph_seq_lens = torch.ones(max_batch_size, + dtype=torch.int32, + device=self.runner.device) + self._graph_block_tables = torch.from_numpy( + self.runner.graph_block_tables).to(device=self.runner.device) + + yield + + self._is_graph_capturing = False + del self._graph_slot_mapping + del self._graph_seq_lens + del self._graph_block_tables + + def graph_clone(self, batch_size: int) -> "CommonAttentionState": + assert self._is_graph_capturing + return self.__class__(self.runner) + + def 
graph_capture_get_metadata_for_batch( + self, batch_size: int, is_encoder_decoder_model: bool = False): + assert self._is_graph_capturing + attn_metadata = self.runner.attn_backend.make_metadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=self._graph_slot_mapping[:batch_size], + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + seq_lens=None, + seq_lens_tensor=self._graph_seq_lens[:batch_size], + max_query_len=1, + max_decode_query_len=1, + max_prefill_seq_len=0, + max_decode_seq_len=self.runner.max_seq_len_to_capture, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self._graph_block_tables[:batch_size], + use_cuda_graph=True, + ) + if is_encoder_decoder_model: + # The encoder decoder model works only with XFormers and + # Flash Attention backend. Assert the same. + assert self.runner.attn_backend.get_name() in \ + ["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"], \ + f"Expected attn_backend name to be either 'XFORMERS'," \ + f"'ROCM_FLASH', or 'FLASH_ATTN', but " \ + f"got '{self.runner.attn_backend.get_name()}'" + self._update_captured_metadata_for_enc_dec_model( + batch_size=batch_size, attn_metadata=attn_metadata) + + return attn_metadata + + def get_graph_input_buffers( + self, + attn_metadata, + is_encoder_decoder_model: bool = False) -> Dict[str, Any]: + input_buffers = { + "slot_mapping": attn_metadata.slot_mapping, + "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, + "block_tables": attn_metadata.decode_metadata.block_tables, + } + if is_encoder_decoder_model: + # The encoder decoder model works only with XFormers and + # Flash Attention backend. Assert the same. + assert self.runner.attn_backend.get_name() in \ + ["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"], \ + f"Expected attn_backend name to be either 'XFORMERS'," \ + f"'ROCM_FLASH', or 'FLASH_ATTN', but " \ + f"got '{self.runner.attn_backend.get_name()}'" + self._add_additonal_input_buffers_for_enc_dec_model( + attn_metadata=attn_metadata, input_buffers=input_buffers) + return input_buffers + + def prepare_graph_input_buffers( + self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False) -> None: + input_buffers["seq_lens_tensor"].copy_( + attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) + input_buffers["block_tables"].copy_( + attn_metadata.decode_metadata.block_tables, non_blocking=True) + if is_encoder_decoder_model: + # The encoder decoder model works only with XFormers and + # Flash Attention backend. Assert the same. + assert self.runner.attn_backend.get_name() in\ + ["XFORMERS", "FLASH_ATTN"], \ + f"Expected attn_backend name to be either 'XFORMERS' or "\ + f"'FLASH_ATTN', but "\ + f"got '{self.runner.attn_backend.get_name()}'" + self._prepare_input_buffers_for_enc_dec_model( + attn_metadata, input_buffers) + + def begin_forward(self, model_input) -> None: + return + + def _update_captured_metadata_for_enc_dec_model(self, batch_size: int, + attn_metadata): + """ + Updates the attention metadata parameters for CUDA graph capture in an + encoder-decoder model. + + This method modifies attention-related tensors and metadata required + for CUDA graph capture in encoder-decoder models. Specifically, it + updates the cross-attention and encoder sequence tensors in the + AttentionMetadata object. + """ + # During decode phase the cross_slot_mapping will be empty. Hence set + # an empty tensor for CUDA Graph capture. 
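# A minimal sketch of the static-buffer pattern used by graph_capture and
# prepare_graph_input_buffers above: buffers are allocated once at the maximum
# batch size during capture, and every replay only copy_()s fresh values into
# a slice of the same storage so the captured graph keeps reading valid
# addresses. Sizes and values are toy numbers.
import torch

max_batch_size = 8
graph_seq_lens = torch.ones(max_batch_size, dtype=torch.int32)  # captured once

def prepare(step_seq_lens: torch.Tensor) -> None:
    # Per-step update before replay: overwrite in place, never reallocate.
    graph_seq_lens[:step_seq_lens.shape[0]].copy_(step_seq_lens)

prepare(torch.tensor([5, 9, 3], dtype=torch.int32))
assert graph_seq_lens[:3].tolist() == [5, 9, 3]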
+ attn_metadata.cross_slot_mapping = torch.tensor( + [], dtype=torch.int).cuda() + attn_metadata.cross_block_tables = torch.full( + (batch_size, self.runner.get_max_block_per_batch()), + 1, + dtype=torch.int).cuda() + attn_metadata.encoder_seq_lens = torch.full((batch_size, ), + 1, + dtype=torch.int).cuda() + attn_metadata.encoder_seq_lens_tensor = torch.full( + (batch_size, ), 1, dtype=torch.int).cuda() + attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture + attn_metadata.num_encoder_tokens = 0 + + def _add_additonal_input_buffers_for_enc_dec_model( + self, attn_metadata, input_buffers: Dict[str, Any]): + """ + Saves additional input buffers specific to the encoder-decoder model + from the attention metadata. + + This method extracts and stores encoder-decoder related input buffers + from the `attn_metadata` into the `input_buffers` dictionary. The + buffers include encoder sequence lengths, cross-slot mappings, and + cross-block tables, which are essential for the encoder-decoder model + during CUDA graph replay. + """ + input_buffers["encoder_seq_lens_tensor"] = ( + attn_metadata.decode_metadata.encoder_seq_lens_tensor) + input_buffers["cross_slot_mapping"] = ( + attn_metadata.decode_metadata.cross_slot_mapping) + input_buffers["cross_block_tables"] = ( + attn_metadata.decode_metadata.cross_block_tables) + + def _prepare_input_buffers_for_enc_dec_model(self, attn_metadata, + input_buffers: Dict[str, + Any]): + """ + Populates input buffers with data from the encoder-decoder model's + attention metadata. + + This method fills the input buffers with encoder-decoder specific + tensors. It copies data from the `attn_metadata` and keyword arguments + (`kwargs`) into corresponding buffers in the `input_buffers` dictionary. + The copied data includes attention-related metadata as well as input + IDs and positional information for the encoder. + """ + input_buffers["encoder_seq_lens_tensor"].copy_( + attn_metadata.decode_metadata.encoder_seq_lens_tensor, + non_blocking=True) + input_buffers["cross_slot_mapping"].copy_( + attn_metadata.decode_metadata.cross_slot_mapping, + non_blocking=True) + input_buffers["cross_block_tables"].copy_( + attn_metadata.decode_metadata.cross_block_tables, + non_blocking=True) + + +def is_all_encoder_attn_metadata_set(attn_metadata): + ''' + All attention metadata required for encoder attention is set. + ''' + return ((attn_metadata.encoder_seq_lens is not None) + and (attn_metadata.encoder_seq_lens_tensor is not None) + and (attn_metadata.max_encoder_seq_len is not None)) + + +def is_all_cross_attn_metadata_set(attn_metadata): + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. + ''' + return (attn_metadata.is_all_encoder_attn_metadata_set + and (attn_metadata.cross_slot_mapping is not None) + and (attn_metadata.cross_block_tables is not None)) + + +def get_seq_len_block_table_args( + attn_metadata, + is_prompt: bool, + attn_type: str, +) -> tuple: + ''' + The particular choice of sequence-length- and block-table-related + attributes which should be extracted from attn_metadata is dependent + on the type of attention operation. 
+ + Decoder attn -> select entirely decoder self-attention-related fields + Encoder/decoder cross-attn -> select encoder sequence lengths & + cross-attn block-tables fields + Encoder attn -> select encoder sequence lengths fields & no block tables + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention op + * is_prompt: True if prefill, False otherwise + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + + * Appropriate sequence-lengths tensor + * Appropriate max sequence-length scalar + * Appropriate block tables (or None) + ''' + + if attn_type == AttentionType.DECODER: + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + if is_prompt: + max_seq_len = attn_metadata.max_prefill_seq_len + else: + max_seq_len = attn_metadata.max_decode_seq_len + return (attn_metadata.seq_lens_tensor, max_seq_len, + attn_metadata.block_tables) + elif attn_type == AttentionType.ENCODER_DECODER: + # Enc/dec cross-attention KVs match encoder sequence length; + # cross-attention utilizes special "cross" block tables + return (attn_metadata.encoder_seq_lens_tensor, + attn_metadata.max_encoder_seq_len, + attn_metadata.cross_block_tables) + elif attn_type == AttentionType.ENCODER: + # No block tables associated with encoder attention + return (attn_metadata.encoder_seq_lens_tensor, + attn_metadata.max_encoder_seq_len, None) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +def get_num_prefill_decode_query_kv_tokens( + attn_metadata, + attn_type: str, +) -> Tuple[int, int, int]: + """ + Calculate the number of prefill and decode tokens for query, key/value + based on the attention metadata and the specified attention type. + + Args: + attn_metadata (AttentionMetadata): Attention Metadata object. + attn_type (AttentionType): The type of attention being used. + Returns: + Tuple[int, int, int]: A tuple containing three integers: + - The number of prefill query tokens. + - The number of prefill key/value tokens. + - The number of decode query tokens. + + Raises: + AssertionError: If the number of encoder tokens in `attn_metadata` + is `None` when required for the calculations. + """ + num_prefill_query_tokens = 0 + num_decode_query_tokens = 0 + num_prefill_kv_tokens = 0 + if attn_type == AttentionType.ENCODER: + # Encoder attention is only invoked during prefill phase. + # The same input servers a both query and key. + assert attn_metadata.num_encoder_tokens is not None + num_prefill_query_tokens = attn_metadata.num_encoder_tokens + num_prefill_kv_tokens = attn_metadata.num_encoder_tokens + num_decode_query_tokens = 0 + elif attn_type == AttentionType.ENCODER_DECODER: + assert attn_metadata.num_encoder_tokens is not None + num_prefill_query_tokens = attn_metadata.num_prefill_tokens + # The key is the encoder/cross-attention. 
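# A condensed restatement of the three cases handled by
# get_num_prefill_decode_query_kv_tokens, with made-up token counts: encoder
# attention uses encoder tokens for both query and key/value and never
# decodes; encoder/decoder cross-attention queries with decoder prefill tokens
# while its keys/values come from the encoder; decoder (and encoder-only)
# attention uses decoder counts throughout.
num_prefill_tokens, num_decode_tokens, num_encoder_tokens = 6, 2, 10
cases = {
    # (prefill query tokens, prefill key/value tokens, decode query tokens)
    "ENCODER":         (num_encoder_tokens, num_encoder_tokens, 0),
    "ENCODER_DECODER": (num_prefill_tokens, num_encoder_tokens, num_decode_tokens),
    "DECODER":         (num_prefill_tokens, num_prefill_tokens, num_decode_tokens),
}
assert cases["ENCODER_DECODER"] == (6, 10, 2)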
+ num_prefill_kv_tokens = attn_metadata.num_encoder_tokens + num_decode_query_tokens = attn_metadata.num_decode_tokens + else: # attn_type == AttentionType.DECODER or + # attn_type == AttentionType.ENCODER_ONLY + num_prefill_query_tokens = attn_metadata.num_prefill_tokens + num_prefill_kv_tokens = attn_metadata.num_prefill_tokens + num_decode_query_tokens = attn_metadata.num_decode_tokens + + return (num_prefill_query_tokens, num_prefill_kv_tokens, + num_decode_query_tokens) + + +@dataclass +class MLADims: + q_lora_rank: Optional[int] + kv_lora_rank: int + qk_nope_head_dim: int + qk_rope_head_dim: int + v_head_dim: int + + +def get_mla_dims(model_config: ModelConfig) -> MLADims: + hf_text_config = model_config.hf_text_config + + return MLADims( + q_lora_rank=getattr(hf_text_config, "q_lora_rank", None), + kv_lora_rank=hf_text_config.kv_lora_rank, + qk_nope_head_dim=hf_text_config.qk_nope_head_dim, + qk_rope_head_dim=hf_text_config.qk_rope_head_dim, + v_head_dim=hf_text_config.v_head_dim, + ) diff --git a/attention/backends/xformers.py b/attention/backends/xformers.py new file mode 100644 index 0000000..04ef928 --- /dev/null +++ b/attention/backends/xformers.py @@ -0,0 +1,802 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Attention layer with xFormers and PagedAttention.""" +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +from xformers import ops as xops +from xformers.ops.fmha.attn_bias import (AttentionBias, + BlockDiagonalCausalMask, + BlockDiagonalMask, + LowerTriangularMaskWithTensorBias) + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import ( + CommonAttentionState, CommonMetadataBuilder, + get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args, + is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set) +from vllm.attention.ops.paged_attn import (PagedAttention, + PagedAttentionMetadata) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class XFormersBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "XFORMERS" + + @staticmethod + def get_impl_cls() -> Type["XFormersImpl"]: + return XFormersImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return XFormersMetadata + + @staticmethod + def get_builder_cls() -> Type["XFormersMetadataBuilder"]: + return XFormersMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): + """Metadata for XFormersbackend. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. 
If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| + + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # FIXME: It is for flash attn. + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] = None + + # FIXME: It is for flash attn. + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] = None + + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] = None + + # Maximum query length in the batch. None for decoding. + max_query_len: Optional[int] = None + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] = None + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + + # Self-attention prefill/decode metadata cache + _cached_prefill_metadata: Optional["XFormersMetadata"] = None + _cached_decode_metadata: Optional["XFormersMetadata"] = None + + # Begin encoder attn & enc/dec cross-attn fields... + + # Encoder sequence lengths representation + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + # FIXME: It is for flash attn. + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + encoder_seq_start_loc: Optional[torch.Tensor] = None + + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + + def __post_init__(self): + # Set during the execution of the first attention op. + # It is a list because it is needed to set per prompt + # when alibi slopes is used. It is because of the limitation + # from xformer API. 
+ # will not appear in the __repr__ and __init__ + self.attn_bias: Optional[List[AttentionBias]] = None + self.encoder_attn_bias: Optional[List[AttentionBias]] = None + self.cross_attn_bias: Optional[List[AttentionBias]] = None + + @property + def is_all_encoder_attn_metadata_set(self): + ''' + All attention metadata required for encoder attention is set. + ''' + return is_all_encoder_attn_metadata_set(self) + + @property + def is_all_cross_attn_metadata_set(self): + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. + ''' + return is_all_cross_attn_metadata_set(self) + + @property + def prefill_metadata(self) -> Optional["XFormersMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + # Recover cached prefill-phase attention + # metadata structure + return self._cached_prefill_metadata + + assert ((self.seq_lens is not None) + or (self.encoder_seq_lens is not None)) + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + seq_start_loc = (None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1]) + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[:self.num_prefill_tokens]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) + block_tables = (None if self.block_tables is None else + self.block_tables[:self.num_prefills]) + + # Construct & cache prefill-phase attention metadata structure + self._cached_prefill_metadata = XFormersMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=self. + multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, + # Begin encoder & cross attn fields below... 
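+ # Illustrative slicing example (assumed sizes): with num_prefills=2,
+ # num_prefill_tokens=7 and num_decode_tokens=3, the slices above keep
+ # slot_mapping[:7] plus seq_lens[:2] / block_tables[:2] for the prefill part.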
+ encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["XFormersMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + # Recover cached decode-phase attention + # metadata structure + return self._cached_decode_metadata + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[self.num_prefill_tokens:]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) + block_tables = (None if self.block_tables is None else + self.block_tables[self.num_prefills:]) + + # Construct & cache decode-phase attention metadata structure + self._cached_decode_metadata = XFormersMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + seq_lens_tensor=seq_lens_tensor, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + block_tables=block_tables, + use_cuda_graph=self.use_cuda_graph, + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + + # Batch may be composed of prefill|decodes, adjust query start indices + # to refer to the start of decodes when the two are split apart. + # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. + if self._cached_decode_metadata.query_start_loc is not None: + qs = self._cached_decode_metadata.query_start_loc + self._cached_decode_metadata.query_start_loc = qs - qs[0] + return self._cached_decode_metadata + + +def _get_attn_bias( + attn_metadata: XFormersMetadata, + attn_type: str, +) -> Optional[AttentionBias]: + ''' + Extract appropriate attention bias from attention metadata + according to attention type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + * Appropriate attention bias value given the attention type + ''' + + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): + return attn_metadata.attn_bias + elif attn_type == AttentionType.ENCODER: + return attn_metadata.encoder_attn_bias + elif attn_type == AttentionType.ENCODER_DECODER: + return attn_metadata.cross_attn_bias + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +def _set_attn_bias( + attn_metadata: XFormersMetadata, + attn_bias: List[Optional[AttentionBias]], + attn_type: str, +) -> None: + ''' + Update appropriate attention bias field of attention metadata, + according to attention type. 
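+ For example, a bias built for AttentionType.ENCODER is stored on
+ `attn_metadata.encoder_attn_bias`, mirroring the lookup in `_get_attn_bias`.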
+ + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_bias: The desired attention bias value + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + ''' + + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): + attn_metadata.attn_bias = attn_bias + elif attn_type == AttentionType.ENCODER: + attn_metadata.encoder_attn_bias = attn_bias + elif attn_type == AttentionType.ENCODER_DECODER: + attn_metadata.cross_attn_bias = attn_bias + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]): + + _metadata_cls = XFormersMetadata + + +class XFormersImpl(AttentionImpl[XFormersMetadata]): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| + + Otherwise, the layout is as follows: + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + + |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| + |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + use_irope: bool = False, + ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") + if blocksparse_params is not None: + raise ValueError( + "XFormers does not support block-sparse attention.") + if logits_soft_cap is not None: + logger.warning_once("XFormers does not support logits soft cap. " + "Outputs may be slightly off.") + if use_irope: + logger.warning_once( + "Using irope in XFormers is not supported yet, it will fall" + " back to global attention for long context.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + supported_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. 
" + f"Supported head sizes are: {supported_head_sizes}.") + + self.attn_type = attn_type + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: Optional[torch.Tensor], + value: Optional[torch.Tensor], + kv_cache: torch.Tensor, + attn_metadata: "XFormersMetadata", + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with xFormers and PagedAttention. + + For decoder-only models: query, key and value must be non-None. + + For encoder/decoder models: + * XFormersImpl.forward() may be invoked for both self- and cross- + attention layers. + * For self-attention: query, key and value must be non-None. + * For cross-attention: + * Query must be non-None + * During prefill, key and value must be non-None; key and value + get cached for use during decode. + * During decode, key and value may be None, since: + (1) key and value tensors were cached during prefill, and + (2) cross-attention key and value tensors do not grow during + decode + + A note on how the attn_type (attention type enum) argument impacts + attention forward() behavior: + + * DECODER: normal decoder-only behavior; + use decoder self-attention block table + * ENCODER: no KV caching; pass encoder sequence + attributes (encoder_seq_lens/encoder_seq_lens_tensor/ + max_encoder_seq_len) to kernel, in lieu of decoder + sequence attributes (seq_lens/seq_lens_tensor/max_seq_len). + Used for encoder branch of encoder-decoder models. + * ENCODER_ONLY: no kv_caching, uses the normal attention + attributes (seq_lens/seq_lens_tensor/max_seq_len). + * ENCODER_DECODER: cross-attention behavior; + use cross-attention block table for caching KVs derived + from encoder hidden states; since KV sequence lengths + will match encoder sequence lengths, pass encoder sequence + attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/ + max_encoder_seq_len) + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + attn_type: Select attention type, between encoder attention, + decoder self-attention, or encoder/decoder cross- + attention. Defaults to decoder self-attention, + which is the vLLM default generally + Returns: + shape = [num_tokens, num_heads * head_size] + """ + attn_type = self.attn_type + # Check that appropriate attention metadata attributes are + # selected for the desired attention type + if (attn_type == AttentionType.ENCODER + and (not attn_metadata.is_all_encoder_attn_metadata_set)): + raise AttributeError("Encoder attention requires setting " + "encoder metadata attributes.") + + elif (attn_type == AttentionType.ENCODER_DECODER + and (not attn_metadata.is_all_cross_attn_metadata_set)): + raise AttributeError("Encoder/decoder cross-attention " + "requires setting cross-attention " + "metadata attributes.") + + query = query.view(-1, self.num_heads, self.head_size) + if key is not None: + assert value is not None + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + else: + assert value is None + + # Self-attention vs. 
cross-attention will impact + # which KV cache memory-mapping & which + # seqlen datastructures we utilize + + if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0): + # KV-cache during decoder-self- or + # encoder-decoder-cross-attention, but not + # during encoder attention. + # + # Even if there are no new key/value pairs to cache, + # we still need to break out key_cache and value_cache + # i.e. for later use by paged attention + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + if (key is not None) and (value is not None): + + if attn_type == AttentionType.ENCODER_DECODER: + # Update cross-attention KV cache (prefill-only) + # During cross-attention decode, key & value will be None, + # preventing this IF-statement branch from running + updated_slot_mapping = attn_metadata.cross_slot_mapping + else: + # Update self-attention KV cache (prefill/decode) + updated_slot_mapping = attn_metadata.slot_mapping + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory + # profiling run. + PagedAttention.write_to_paged_cache( + key, value, key_cache, value_cache, updated_slot_mapping, + self.kv_cache_dtype, layer._k_scale, layer._v_scale) + (num_prefill_query_tokens, num_prefill_kv_tokens, + num_decode_query_tokens) = \ + get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_query_tokens:] + # QKV for prefill. + query = query[:num_prefill_query_tokens] + if key is not None and value is not None: + key = key[:num_prefill_kv_tokens] + value = value[:num_prefill_kv_tokens] + + assert query.shape[0] == num_prefill_query_tokens + assert decode_query.shape[0] == num_decode_query_tokens + + if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. + if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0: + # normal attention. + # block tables are empty if the prompt does not have a cached + # prefix. + out = self._run_memory_efficient_xformers_forward( + query, key, value, prefill_meta, attn_type=attn_type) + assert out.shape == output[:num_prefill_query_tokens].shape + output[:num_prefill_query_tokens] = out + else: + assert attn_type != AttentionType.ENCODER_ONLY, ( + "Encoder-only models should not have prefix attention.") + + assert prefill_meta.query_start_loc is not None + assert prefill_meta.max_query_len is not None + + # prefix-enabled attention + # TODO(Hai) this triton kernel has regression issue (broke) to + # deal with different data types between KV and FP8 KV cache, + # to be addressed separately. 
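+ # This branch is taken when part of the prompt already has KV entries in
+ # the paged cache (non-empty block tables), e.g. prefix caching or chunked
+ # prefill; forward_prefix attends over the cached context and the new tokens.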
+ out = PagedAttention.forward_prefix( + query, + key, + value, + self.kv_cache_dtype, + key_cache, + value_cache, + prefill_meta.block_tables, + prefill_meta.query_start_loc, + prefill_meta.seq_lens_tensor, + prefill_meta.max_query_len, + self.alibi_slopes, + self.sliding_window, + layer._k_scale, + layer._v_scale, + ) + assert output[:num_prefill_query_tokens].shape == out.shape + output[:num_prefill_query_tokens] = out + + if decode_meta := attn_metadata.decode_metadata: + assert attn_type != AttentionType.ENCODER_ONLY, ( + "Encoder-only models should not have decode metadata.") + + ( + seq_lens_arg, + max_seq_len_arg, + block_tables_arg, + ) = get_seq_len_block_table_args(decode_meta, False, attn_type) + + output[num_prefill_query_tokens:] = PagedAttention.forward_decode( + decode_query, + key_cache, + value_cache, + block_tables_arg, + seq_lens_arg, + max_seq_len_arg, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + layer._k_scale, + layer._v_scale, + ) + + # Reshape the output tensor. + return output.view(-1, self.num_heads * self.head_size) + + def _run_memory_efficient_xformers_forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: XFormersMetadata, + attn_type: str = AttentionType.DECODER, + ) -> torch.Tensor: + """Attention for 1D query of multiple prompts. Multiple prompt + tokens are flattened in to `query` input. + + See https://facebookresearch.github.io/xformers/components/ops.html + for API spec. + + Args: + output: shape = [num_prefill_tokens, num_heads, head_size] + query: shape = [num_prefill_tokens, num_heads, head_size] + key: shape = [num_prefill_tokens, num_kv_heads, head_size] + value: shape = [num_prefill_tokens, num_kv_heads, head_size] + attn_metadata: Metadata for attention. + attn_type: Select attention type, between encoder attention, + decoder self-attention, or encoder/decoder cross- + attention. Defaults to decoder self-attention, + which is the vLLM default generally + """ + + original_query = query + if self.num_kv_heads != self.num_heads: + # GQA/MQA requires the shape [B, M, G, H, K]. + # Note that the output also has the same shape (which is different + # from a spec from the doc). + query = query.view(query.shape[0], self.num_kv_heads, + self.num_queries_per_kv, query.shape[-1]) + key = key[:, :, + None, :].expand(key.shape[0], self.num_kv_heads, + self.num_queries_per_kv, key.shape[-1]) + value = value[:, :, + None, :].expand(value.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + value.shape[-1]) + + # Set attention bias if not provided. This typically happens at + # the very attention layer of every iteration. + # FIXME(woosuk): This is a hack. 
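+ # The bias list is built once, on the first attention layer of an
+ # iteration, and cached on attn_metadata so later layers reuse it through
+ # _get_attn_bias instead of rebuilding the mask.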
+ attn_bias = _get_attn_bias(attn_metadata, attn_type) + if attn_bias is None: + if self.alibi_slopes is None: + + # Cross attention block of decoder branch of encoder-decoder + # model uses seq_lens for dec / encoder_seq_lens for enc + if (attn_type == AttentionType.ENCODER_DECODER): + assert attn_metadata.seq_lens is not None + assert attn_metadata.encoder_seq_lens is not None + + # Cross-attention mask is non-causal + attn_bias = BlockDiagonalMask.from_seqlens( + attn_metadata.seq_lens, + attn_metadata.encoder_seq_lens, + device=query.device) + + # Encoder branch of encoder-decoder model uses + # attn_metadata.encoder_seq_lens + elif attn_type == AttentionType.ENCODER: + + assert attn_metadata.encoder_seq_lens is not None + + # Encoder self-attention mask is non-causal + attn_bias = BlockDiagonalMask.from_seqlens( + attn_metadata.encoder_seq_lens, device=query.device) + + # Self-attention block of encoder-only model just + # uses the seq_lens directly. + elif attn_type == AttentionType.ENCODER_ONLY: + assert attn_metadata.seq_lens is not None + + # Encoder self-attention mask is non-causal + attn_bias = BlockDiagonalMask.from_seqlens( + attn_metadata.seq_lens, device=query.device) + + # Self-attention block of decoder branch just + # uses the seq_lens directly + elif attn_type == AttentionType.DECODER: + assert attn_metadata.seq_lens is not None + + # Decoder self-attention mask is causal + attn_bias = BlockDiagonalCausalMask.from_seqlens( + attn_metadata.seq_lens, device=query.device) + else: + raise ValueError("Unknown AttentionType: %s", attn_type) + + if self.sliding_window is not None: + attn_bias = attn_bias.make_local_attention( + self.sliding_window) + attn_bias = [attn_bias] + else: + assert attn_type == AttentionType.DECODER + assert attn_metadata.seq_lens is not None + attn_bias = _make_alibi_bias(self.alibi_slopes, + self.num_kv_heads, query.dtype, + attn_metadata.seq_lens) + + _set_attn_bias(attn_metadata, attn_bias, attn_type) + + # No alibi slopes. + # TODO(woosuk): Too many view operations. Let's try to reduce + # them in the future for code readability. + if self.alibi_slopes is None: + # Add the batch dimension. + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=attn_bias[0], + p=0.0, + scale=self.scale) + return out.view_as(original_query) + + # Attention with alibi slopes. + # FIXME(woosuk): Because xformers does not support dynamic sequence + # lengths with custom attention bias, we process each prompt one by + # one. This is inefficient, especially when we have many short prompts. + assert attn_metadata.seq_lens is not None + output = torch.empty_like(original_query) + start = 0 + for i, seq_len in enumerate(attn_metadata.seq_lens): + end = start + seq_len + out = xops.memory_efficient_attention_forward( + query[None, start:end], + key[None, start:end], + value[None, start:end], + attn_bias=attn_bias[i], + p=0.0, + scale=self.scale) + # TODO(woosuk): Unnecessary copy. Optimize. + output[start:end].copy_(out.view_as(original_query[start:end])) + start += seq_len + return output + + +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + num_kv_heads: int, + dtype: torch.dtype, + seq_lens: List[int], +) -> List[AttentionBias]: + attn_biases: List[AttentionBias] = [] + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. 
We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + # Calculate a matrix where each element represents ith element- jth + # element. + bias = bias[None, :] - bias[:, None] + + padded_len = (seq_len + 7) // 8 * 8 + num_heads = alibi_slopes.shape[0] + bias = torch.empty( + 1, # batch size + num_heads, + seq_len, + padded_len, + device=alibi_slopes.device, + dtype=dtype, + )[:, :, :, :seq_len].copy_(bias) + bias.mul_(alibi_slopes[:, None, None]) + attn_biases.append(LowerTriangularMaskWithTensorBias(bias)) + + return attn_biases diff --git a/attention/layer.py b/attention/layer.py new file mode 100644 index 0000000..a5fbd1a --- /dev/null +++ b/attention/layer.py @@ -0,0 +1,468 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Attention layer.""" +from typing import Any, Dict, List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import vllm.envs as envs +from vllm.attention import AttentionType +from vllm.attention.selector import backend_name_to_enum, get_attn_backend +from vllm.config import CacheConfig, get_current_vllm_config +from vllm.distributed.kv_transfer import (get_kv_transfer_group, + has_kv_transfer_group, + is_v1_kv_transfer_group) +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.model_executor.layers.linear import UnquantizedLinearMethod +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.platforms import _Backend, current_platform +from vllm.utils import direct_register_custom_op +from vllm.v1.attention.backends.utils import validate_kv_sharing_target + + +class Attention(nn.Module): + """Attention layer. + + This class takes query, key, and value tensors as input. The input tensors + can either contain prompt tokens or generation tokens. + The class does the following: + + 1. Store the input key and value tensors in the KV cache. + 2. Perform (multi-head/multi-query/grouped-query) attention. + 3. Return the output tensor. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + per_layer_sliding_window: Optional[int] = None, + use_mla: bool = False, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + **extra_impl_args, + ) -> None: + """ + The KV cache is stored inside this class and is accessed via + `self.kv_cache`. 
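+ Sliding-window resolution: `per_layer_sliding_window`, when given, takes
+ precedence over `cache_config.sliding_window`; without a cache_config the
+ layer falls back to kv_cache_dtype="auto" and block_size=16.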
+ """ + super().__init__() + if per_layer_sliding_window is not None: + # per-layer sliding window + sliding_window = per_layer_sliding_window + elif cache_config is not None: + # model-level sliding window + sliding_window = cache_config.sliding_window + else: + sliding_window = None + + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + block_size = cache_config.block_size + is_attention_free = cache_config.is_attention_free + calculate_kv_scales = cache_config.calculate_kv_scales + else: + kv_cache_dtype = "auto" + block_size = 16 + is_attention_free = False + calculate_kv_scales = False + if num_kv_heads is None: + num_kv_heads = num_heads + + # The default k/v_scale is set to 1.0. This is ignored + # when kv-cache is not fp8, and should be used with + # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we + # expect the pre-quantized k/v_scale to be loaded along + # with the model weights. + self.kv_cache_dtype = kv_cache_dtype + self.calculate_kv_scales = calculate_kv_scales + self._k_scale = torch.tensor(1.0, dtype=torch.float32) + self._v_scale = torch.tensor(1.0, dtype=torch.float32) + # FlashAttn doesn't support quantizing the kv-cache only + # but requires q to be quantized as well. + self._q_scale = torch.tensor(1.0, dtype=torch.float32) + self._prob_scale = torch.tensor(1.0, dtype=torch.float32) + + # We also keep the float32 versions of k/v_scale for attention + # backends that don't support tensors (Flashinfer) + self._k_scale_float = 1.0 + self._v_scale_float = 1.0 + + self.use_mla = use_mla + self.num_heads = num_heads + self.head_size = head_size + self.num_kv_heads = num_kv_heads + self.sliding_window = sliding_window + + quant_method = quant_config.get_quant_method( + self, prefix=prefix) if quant_config else None + if quant_method is not None and not isinstance( + quant_method, UnquantizedLinearMethod): + assert isinstance(quant_method, BaseKVCacheMethod) + # TODO (mgoin): kv cache dtype should be specified in the FP8 + # checkpoint config and become the "auto" behavior + if self.kv_cache_dtype == "fp8_e5m2": + raise ValueError("fp8_e5m2 kv-cache is not supported with " + "fp8 checkpoints.") + # If quantization is enabled, we make "k_scale" and "v_scale" + # parameters so that it can be loaded from the model checkpoint. + # The k/v_scale will then be converted back to native float32 + # values after weight loading. + self.quant_method = quant_method + self.quant_method.create_weights(self) + + # During model initialization, the default dtype is set as the model + # weight and activation dtype. + dtype = torch.get_default_dtype() + attn_backend = get_attn_backend(head_size, + dtype, + kv_cache_dtype, + block_size, + is_attention_free, + blocksparse_params is not None, + use_mla=use_mla) + impl_cls = attn_backend.get_impl_cls() + self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + kv_sharing_target_layer_name, **extra_impl_args) + self.backend = backend_name_to_enum(attn_backend.get_name()) + self.dtype = dtype + + # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how + # torch.compile works by registering the attention as one giant + # opaque custom op. For other platforms, we directly call them + # and let torch.compile handle them. 
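+ # Concretely: on CUDA/ROCm and CPU, use_direct_call is False and forward()
+ # routes through torch.ops.vllm.unified_attention(_with_output); on other
+ # platforms self.impl.forward is invoked directly.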
+ self.use_direct_call = not current_platform.is_cuda_alike( + ) and not current_platform.is_cpu() + + self.use_output = attn_backend.accept_output_buffer + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + self.layer_name = prefix + self.attn_type = attn_type + + if kv_sharing_target_layer_name is not None: + if not envs.VLLM_USE_V1: + raise NotImplementedError( + "Cross-layer KV sharing is not supported in V0.") + + validate_kv_sharing_target( + prefix, + kv_sharing_target_layer_name, + compilation_config.static_forward_context, + ) + self.kv_sharing_target_layer_name = kv_sharing_target_layer_name + + # use a placeholder kv cache tensor during init, which will be replaced + # by bind_kv_cache + # this variable will not be accessed if use_direct_call is True + self.kv_cache = [ + torch.tensor([]) for _ in range(get_current_vllm_config( + ).parallel_config.pipeline_parallel_size) + ] + + self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32) + self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) + self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + # For some alternate attention backends like MLA the attention output + # shape does not match the query shape, so we optionally let the model + # definition specify the output tensor shape. + output_shape: Optional[torch.Size] = None, + ) -> torch.Tensor: + """ + The KV cache is stored inside this class and is accessed via + `self.kv_cache`. + + Attention metadata (`attn_metadata`) is set using a context manager in + the model runner's `execute_model` method. It is accessed via forward + context using + `vllm.forward_context.get_forward_context().attn_metadata`. + """ + if self.calculate_kv_scales: + attn_metadata = get_forward_context().attn_metadata + if attn_metadata.enable_kv_scales_calculation: + self.calc_kv_scales(query, key, value) + if self.use_output: + output_shape = (output_shape + if output_shape is not None else query.shape) + output = torch.empty(output_shape, + dtype=query.dtype, + device=query.device) + hidden_size = output_shape[-1] + # We skip reshaping query, key and value tensors for the MLA + # backend since these tensors have different semantics and are + # processed differently. + if not self.use_mla: + # Reshape the query, key, and value tensors. + # NOTE(woosuk): We do this outside the custom op to minimize the + # CPU overheads from the non-CUDA-graph regions. 
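+ # For example (illustrative sizes): with num_heads=32 and head_size=128,
+ # a [num_tokens, 4096] query is viewed as [num_tokens, 32, 128] below.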
+ query = query.view(-1, self.num_heads, self.head_size) + output = output.view(-1, self.num_heads, self.head_size) + if key is not None: + key = key.view(-1, self.num_kv_heads, self.head_size) + if value is not None: + value = value.view(-1, self.num_kv_heads, self.head_size) + if self.use_direct_call: + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.layer_name] + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + self.impl.forward(self, + query, + key, + value, + self_kv_cache, + attn_metadata, + output=output) + else: + torch.ops.vllm.unified_attention_with_output( + query, key, value, output, self.layer_name) + return output.view(-1, hidden_size) + else: + if self.use_direct_call: + forward_context = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.layer_name] + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + return self.impl.forward(self, query, key, value, + self_kv_cache, attn_metadata) + else: + return torch.ops.vllm.unified_attention( + query, key, value, self.layer_name) + + def calc_kv_scales(self, query, key, value): + self._q_scale.copy_(torch.abs(query).max() / self.q_range) + self._k_scale.copy_(torch.abs(key).max() / self.k_range) + self._v_scale.copy_(torch.abs(value).max() / self.v_range) + self._k_scale_float = self._k_scale.item() + self._v_scale_float = self._v_scale.item() + # We only calculate the scales once + self.calculate_kv_scales = False + + def extra_repr(self) -> str: + s = f"head_size={self.impl.head_size}" # type: ignore + s += f", num_heads={self.impl.num_heads}" # type: ignore + s += f", num_kv_heads={self.impl.num_kv_heads}" # type: ignore + s += f", scale={self.impl.scale}" # type: ignore + s += f", backend={self.impl.__class__.__name__}" + return s + + def process_weights_after_loading(self, act_dtype: torch.dtype): + if hasattr(self.impl, "process_weights_after_loading"): + self.impl.process_weights_after_loading(act_dtype) + + +class MultiHeadAttention(nn.Module): + """Multi-headed attention without any cache, used for ViT.""" + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + ): + super().__init__() + self.num_heads = num_heads + self.head_size = head_size + self.scale = scale + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + dtype = torch.get_default_dtype() + attn_backend = get_attn_backend(head_size, + dtype, + kv_cache_dtype=None, + block_size=16, + is_attention_free=False) + backend = backend_name_to_enum(attn_backend.get_name()) + if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}: + backend = _Backend.XFORMERS + + self.attn_backend = backend if backend in { + _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1 + } else _Backend.TORCH_SDPA + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + ) -> torch.Tensor: + """Input shape: batch_size x seq_len x hidden_size""" + # TODO(Isotr0py): Use existing backend implementations and support FA3 + bsz, q_len, _ = query.size() + kv_len = key.size(1) + + query = query.view(bsz, q_len, self.num_heads, self.head_size) + key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size) 
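+ # e.g. (illustrative sizes) a [2, 577, 1024] key/value with 16 kv heads of
+ # size 64 becomes [2, 577, 16, 64] here.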
+ value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size) + + if (num_repeat := self.num_queries_per_kv) > 1: + # Handle MQA and GQA + key = torch.repeat_interleave(key, num_repeat, dim=2) + value = torch.repeat_interleave(value, num_repeat, dim=2) + + if self.attn_backend == _Backend.XFORMERS: + from xformers import ops as xops + + out = xops.memory_efficient_attention_forward(query, + key, + value, + scale=self.scale) + elif self.attn_backend == _Backend.TORCH_SDPA: + query, key, value = (x.transpose(1, 2) + for x in (query, key, value)) + out = F.scaled_dot_product_attention(query, + key, + value, + scale=self.scale) + out = out.transpose(1, 2) + elif self.attn_backend == _Backend.PALLAS_VLLM_V1: + query, key, value = (x.transpose(1, 2) + for x in (query, key, value)) + from torch_xla.experimental.custom_kernel import flash_attention + out = flash_attention(query, key, value, sm_scale=self.scale) + out = out.transpose(1, 2) + + return out.reshape(bsz, q_len, -1) + + +def wait_for_kv_layer_from_connector(layer_name: str): + if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): + return + + connector = get_kv_transfer_group() + + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + return + assert isinstance(attn_metadata, dict) + connector.wait_for_layer_load(layer_name) + + +def maybe_save_kv_layer_to_connector( + layer_name: str, + kv_cache_layer: List[torch.Tensor], +): + if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): + return + + connector = get_kv_transfer_group() + + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + return + assert isinstance(attn_metadata, dict) + connector.save_kv_layer(layer_name, kv_cache_layer, + attn_metadata[layer_name]) + + +def unified_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + layer_name: str, +) -> torch.Tensor: + wait_for_kv_layer_from_connector(layer_name) + + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] + self = forward_context.no_compile_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] + output = self.impl.forward(self, query, key, value, kv_cache, + attn_metadata) + + maybe_save_kv_layer_to_connector(layer_name, kv_cache) + return output + + +def unified_attention_fake( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + layer_name: str, +) -> torch.Tensor: + return torch.empty_like(query).contiguous() + + +direct_register_custom_op( + op_name="unified_attention", + op_func=unified_attention, + mutates_args=[], + fake_impl=unified_attention_fake, + dispatch_key=current_platform.dispatch_key, +) + + +def unified_attention_with_output( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + wait_for_kv_layer_from_connector(layer_name) + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] + self = forward_context.no_compile_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] + self.impl.forward(self, + query, + key, + value, + kv_cache, + attn_metadata, + output=output) + + 
maybe_save_kv_layer_to_connector(layer_name, kv_cache) + + +def unified_attention_with_output_fake( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="unified_attention_with_output", + op_func=unified_attention_with_output, + mutates_args=["output"], + fake_impl=unified_attention_with_output_fake, + dispatch_key=current_platform.dispatch_key, +) diff --git a/attention/ops/__init__.py b/attention/ops/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/attention/ops/blocksparse_attention/__init__.py b/attention/ops/blocksparse_attention/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py b/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py new file mode 100644 index 0000000..05fa9d1 --- /dev/null +++ b/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py @@ -0,0 +1,433 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm.triton_utils import tl, triton + + +def blocksparse_flash_attn_varlen_fwd( + q, + k, + v, # (#tokens, n_heads, head_size) + cu_seqlens_k, + cu_seqlens_q, + sm_scale, + sparse_layout, + *, + block_size=64, + q_block_size=None, + max_seqlen=None): + # split q to blocks + + assert isinstance(sparse_layout, (list, tuple)) + + _, n_heads, head_size = q.shape + batch_size = cu_seqlens_k.size(0) - 1 + q_block_size = q_block_size or block_size + + assert q.dim() == k.dim() == v.dim() == 3 + assert q.size(1) % k.size(1) == 0 + assert q.size(2) == k.size(2) + # TODO(linxihui): allow k, v to have different head_size + assert k.shape == v.shape + assert cu_seqlens_k.dim() == 1 + + q_k_ratio = q.size(1) // k.size(1) + + if cu_seqlens_q is None: + if q.size(0) == batch_size: # decoding only + cu_seqlens_q = torch.arange( + 0, + batch_size + 1, + dtype=cu_seqlens_k.dtype, + device=cu_seqlens_k.device, + ) + elif q.size(0) == k.size(0): + cu_seqlens_q = cu_seqlens_k + else: + raise ValueError("cu_seqlens_q must be specified\ + if it mix of prefilling and decoding.") + else: + assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0) + + # switch to use cpu to avoid too many kernel launches when iterated over + q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu() + k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu() + + assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), ( + "length of q should either be 1 (decoding) or same as k (prefilling).") + + if max_seqlen: + assert k_lens.max() <= max_seqlen + + n_blocks = (q_lens + q_block_size - 1) // q_block_size + + q_batch_ids = torch.tensor( + [i for i, n in enumerate(n_blocks) for _ in range(n)], + dtype=cu_seqlens_q.dtype, + device=cu_seqlens_q.device, + ) + q_start_sids = torch.tensor( + [i * q_block_size for n in n_blocks for i in range(n)], + dtype=cu_seqlens_q.dtype, + device=cu_seqlens_q.device, + ) + + out = q.new_empty(q.shape) + cu_seqlens_q = cu_seqlens_q.contiguous() + cu_seqlens_k = cu_seqlens_k.contiguous() + + layout_crow_indices, layout_col_indices = sparse_layout + block_d = triton.next_power_of_2(head_size) + + decoding_only = (q_lens == 1).all().item() + grid = (len(q_start_sids), n_heads, 1) + + _fwd_kernel_batch_inference[grid]( + q, + k, + v, + out, + sm_scale, + cu_seqlens_q[:-1], + cu_seqlens_q[1:], + cu_seqlens_k[:-1], + cu_seqlens_k[1:], + q_batch_ids, + q_start_sids, + 0, 
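+ # The literal 0 above stands in for stride_qb (HAS_BATCH_DIM=False below),
+ # and *q.stride() expands to (stride_qt, stride_qh, stride_qd); the same
+ # 0 / *x.stride() pattern repeats for k, v and out.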
+ *q.stride(), + 0, + *k.stride(), + 0, + *v.stride(), + 0, + *out.stride(), + layout_crow_indices, + layout_col_indices, + *layout_crow_indices.stride(), + *layout_col_indices.stride(), + q_k_ratio, + HAS_BATCH_DIM=False, + D_HEAD=head_size, + BLOCK_M=q_block_size, + BLOCK_N=block_size, + BLOCK_D=block_d, + BLOCK_M_LOADING=(16 if decoding_only else + q_block_size), # smaller for decoding + EVEN_D=block_d == head_size, + num_warps=1 if decoding_only else 4, + num_stages=3) + + return out + + +@triton.jit +def _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + Q, + k_block_col_idx, + layout_col_ptr, + layout_col_stride_h, + layout_col_stride_m, + k_ptrs, + v_ptrs, + off_h, + offs_m, + offs_n, + offs_d, + stride_kt, + stride_vt, + sm_scale, + k_seqlen, + past_len, + LAST_K_BLOCK: tl.constexpr, + BLOCK_M_LOADING: tl.constexpr, + BLOCK_N: tl.constexpr, + D_HEAD: tl.constexpr, + EVEN_D: tl.constexpr, + M_LT_N: tl.constexpr, +): + k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h + + k_block_col_idx * layout_col_stride_m).to(tl.int32) + start_n = k_block_id * BLOCK_N + if LAST_K_BLOCK: + if EVEN_D: + k = tl.load( + k_ptrs + start_n * stride_kt, + mask=offs_n[None, :] + start_n < k_seqlen, + other=0.0, + ) + else: + k = tl.load( + k_ptrs + start_n * stride_kt, + mask=(offs_n[None, :] + start_n < k_seqlen) & + (offs_d[:, None] < D_HEAD), + other=0.0, + ) + else: + if EVEN_D: + k = tl.load(k_ptrs + start_n * stride_kt) + else: + k = tl.load(k_ptrs + start_n * stride_kt, + mask=offs_d[:, None] < D_HEAD, + other=0.0) + + qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + + # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N + if LAST_K_BLOCK | M_LT_N: + qk += tl.where( + offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), + 0, + float("-inf"), + ) + + # flash-attn2 + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + p = tl.math.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + # update m_i + m_i = m_ij + l_i = l_i * alpha + l_ij + + p = p.to(Q.dtype.element_ty) + # update acc + if LAST_K_BLOCK: + if EVEN_D: + v = tl.load( + v_ptrs + start_n * stride_vt, + mask=offs_n[:, None] + start_n < k_seqlen, + other=0.0, + ) + else: + v = tl.load( + v_ptrs + start_n * stride_vt, + mask=(offs_n[:, None] + start_n < k_seqlen) & + (offs_d[None, :] < D_HEAD), + other=0.0, + ) + else: + if EVEN_D: + v = tl.load(v_ptrs + start_n * stride_vt) + else: + v = tl.load(v_ptrs + start_n * stride_vt, + mask=offs_d[None, :] < D_HEAD, + other=0.0) + + acc += tl.dot(p, v) + + return acc, l_i, m_i + + +@triton.heuristics({ + "M_LT_N": + lambda kwargs: kwargs["BLOCK_M"] < kwargs["BLOCK_N"], +}) +@triton.jit +def _fwd_kernel_batch_inference( + Q, + K, + V, + Out, + sm_scale, + q_batch_starts, + q_batch_ends, + k_batch_starts, + k_batch_ends, + q_batch_ids, + q_start_sids, + stride_qb, + stride_qt, + stride_qh, + stride_qd, + stride_kb, + stride_kt, + stride_kh, + stride_kd, + stride_vb, + stride_vt, + stride_vh, + stride_vd, + stride_ob, + stride_ot, + stride_oh, + stride_od, + layout_crow_ptr, + layout_col_ptr, + layout_crow_stride_h, + layout_crow_stride_m, + layout_col_stride_h, + layout_col_stride_m, + q_k_ratio, + HAS_BATCH_DIM: tl.constexpr, + D_HEAD: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_M_LOADING: tl.constexpr, + EVEN_D: tl.constexpr, + M_LT_N: tl.constexpr, +): + """ + NOTATION: + pid: position id + sid: storage id + sbid: 
storage block id + pbid: position block id + offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col) + + TODO(linxihui): + Optimize grouped-attn + """ + off_zm = tl.program_id(0) + off_h = tl.program_id(1) + + off_h_for_kv = off_h // q_k_ratio + + if HAS_BATCH_DIM: + off_z = tl.program_id(2) + Q += off_z * stride_qb + K += off_z * stride_kb + V += off_z * stride_vb + Out += off_z * stride_ob + start_m = off_zm + q_start_sid = start_m * BLOCK_M # always 0 for decoding + else: + off_z = tl.load(q_batch_ids + off_zm).to(tl.int32) # [0, 0, 0, 1] + q_start_sid = tl.load(q_start_sids + off_zm) + start_m = q_start_sid // BLOCK_M # q_sbid + + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + + q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32) + q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start + k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32) + k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start + past_len = k_seqlen - q_seqlen + + Q += q_cu_start * stride_qt + off_h * stride_qh + K += k_cu_start * stride_kt + off_h_for_kv * stride_kh + V += k_cu_start * stride_vt + off_h_for_kv * stride_vh + Out += q_cu_start * stride_ot + off_h * stride_oh + + q_pbid = (past_len + q_start_sid) // BLOCK_M + + if EVEN_D: + q = tl.load( + Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd, + mask=offs_m[:, None] < q_seqlen, + other=0.0, + ) + else: + q = tl.load( + Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd, + mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD), + other=0.0, + ) + + sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h + + q_pbid * layout_crow_stride_m) + + # TODO(linxihui): load at once, with any Triton version + # that supports `tl.split`, e.g., Triton 3.0 + k_block_start = tl.load(sparse_crow_ptr).to(tl.int32) + k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32) + + m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) + acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32) + + k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd + v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd + + sm_scale *= ( + 1.44269504 # 1/log2 as we use base2 for exponential and logarithm + ) + + for k_block_col_idx in range(k_block_start, k_block_end - 1): + acc, l_i, m_i = _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + Q, + k_block_col_idx, + layout_col_ptr, + layout_col_stride_h, + layout_col_stride_m, + k_ptrs, + v_ptrs, + off_h, + offs_m, + offs_n, + offs_d, + stride_kt, + stride_vt, + sm_scale, + k_seqlen, + past_len, + False, + BLOCK_M_LOADING, + BLOCK_N, + D_HEAD, + EVEN_D, + M_LT_N, + ) + + acc, l_i, m_i = _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + Q, + k_block_end - 1, + layout_col_ptr, + layout_col_stride_h, + layout_col_stride_m, + k_ptrs, + v_ptrs, + off_h, + offs_m, + offs_n, + offs_d, + stride_kt, + stride_vt, + sm_scale, + k_seqlen, + past_len, + True, + BLOCK_M_LOADING, + BLOCK_N, + D_HEAD, + EVEN_D, + M_LT_N, + ) + + # flash-attn 2 + m_i += tl.math.log2(l_i) + acc = acc / l_i[:, None] + + # write output + if EVEN_D: + tl.store( + Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, + acc, + mask=offs_m[:, None] < q_seqlen, + ) + else: + tl.store( + Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, + acc, + mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < 
D_HEAD), + ) diff --git a/attention/ops/blocksparse_attention/interface.py b/attention/ops/blocksparse_attention/interface.py new file mode 100644 index 0000000..c6f6cc2 --- /dev/null +++ b/attention/ops/blocksparse_attention/interface.py @@ -0,0 +1,239 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math + +import torch + +from vllm.platforms import current_platform + +from .utils import (dense_to_crow_col, get_head_sliding_step, + get_sparse_attn_mask) + +IS_COMPUTE_8_OR_ABOVE = current_platform.has_device_capability(80) + +if IS_COMPUTE_8_OR_ABOVE: + from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd + + +class LocalStridedBlockSparseAttn(torch.nn.Module): + + def __init__( + self, + n_heads, + max_seqlen, + local_blocks, + vert_stride, + block_size, + device=None, + dtype=None, + homo_head=False, + active_head_range=None, + q_block_size=None, + use_spda=None, + ): + super().__init__() + if use_spda is None: + use_spda = current_platform.is_rocm() or \ + current_platform.is_cpu() or not \ + IS_COMPUTE_8_OR_ABOVE + device = device or (torch.cuda.current_device() + if current_platform.is_cuda_alike() else "cpu") + device = torch.device(device) + # NOTE: vllm CPU backend support BF16 instead of FP16. + dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE + or device.type == "cpu" else torch.half) + + self.n_heads = n_heads + self.max_seqlen = max_seqlen + self.local_blocks = local_blocks + self.vert_stride = vert_stride + self.use_spda = use_spda + self.dtype = dtype + self.device = device + self.block_size = block_size + self.q_block_size = q_block_size + self.homo_head = homo_head + self.active_head_range = active_head_range + self.head_sliding_step = get_head_sliding_step(n_heads, vert_stride, + homo_head) + + sparse_layout, sparse_pattern, self.dense_attn_mask = ( + self.get_attn_pattern(dtype, device)) + + if q_block_size is not None and q_block_size != block_size: + if q_block_size > block_size: + assert q_block_size % block_size == 0 + blocks_to_merge = q_block_size // block_size + shape = sparse_pattern.shape + sparse_pattern = sparse_pattern.view(shape[0], -1, + blocks_to_merge, + shape[-1]) + sparse_pattern = sparse_pattern.sum(2) + sparse_layout = dense_to_crow_col(sparse_pattern) + else: + raise ValueError( + "Does not support smaller q_block_size. It will be slower." + ) + + self.sparse_layout = sparse_layout + + def get_attn_pattern(self, dtype, device): + sparse_layout, sparse_pattern, dense_attn_mask = get_sparse_attn_mask( + self.n_heads, + self.max_seqlen, + self.max_seqlen, + dtype, + device, + block_size=self.block_size, + local_blocks=self.local_blocks, + vert_stride=self.vert_stride, + homo_head=self.homo_head, + return_dense=self.use_spda, + dense_mask_type="bias", + ) + if (not self.homo_head) and (self.active_head_range is not None): + assert isinstance(self.active_head_range, tuple) + assert (len(self.active_head_range) == 2) + h_start, h_end = self.active_head_range + sparse_layout = tuple(x[h_start:h_end] for x in sparse_layout) + if self.use_spda: + dense_attn_mask = dense_attn_mask[h_start:h_end] + return sparse_layout, sparse_pattern, dense_attn_mask + + def varlen_attn(self, + q, + k, + v, + cu_seqlens_k, + cu_seqlens_q=None, + sm_scale=None): + """ + q, k, v: shape = (num_tokens, num_heads_q/kv, head_size). + Support grouped attention, with `q[:, i*r:(i*r + r)]` + is correspondent to `k[:, i]`, where `r` is the q/k ratio. 
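+ For example, with 8 query heads and 2 kv heads, r = 4, so q[:, 0:4]
+ attends against k[:, 0] / v[:, 0] and q[:, 4:8] against k[:, 1] / v[:, 1].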
+ cu_seqlens_k: shape=(batch_size + 1,), + indicating segment of samples, + e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i + cu_seqlens_q: shape=(batch_size + 1, ). + Default None: same as cu_seqlens_k for prefilling or + [0, 1, .., batch_size] for decoding. + The only case you need to specify is when q is a mix of + prefilling and decoding. + sm_scale: softmax scale, default to 1/sqrt(head_size). + + return: tensor of shape as q. + """ + assert ( + IS_COMPUTE_8_OR_ABOVE + ), "Requires compute capability of 8 or above (Ampere or newer) to use \ + Triton kernel." + + sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1)) + + return blocksparse_flash_attn_varlen_fwd( + q, + k, + v, + cu_seqlens_k, + cu_seqlens_q, + sm_scale, + self.sparse_layout, + block_size=self.block_size, + q_block_size=self.q_block_size, + max_seqlen=self.max_seqlen, + ) + + @staticmethod + def transpose_and_pad(x, cu_seqlens, maxlen, head_repeats=1): + """ + :param x: (total_tokens, n_heads, head_size) + :return: (batch, n_heads, length, head_size) + """ + x_padded = x.new_empty( + len(cu_seqlens) - 1, x.size(1), head_repeats, maxlen, x.size(2)) + cu_seqlens = cu_seqlens.cpu() + for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])): + x_padded[i, :, :, :e - s].copy_(x[s:e].transpose(0, + 1).unsqueeze(1)) + return x_padded.flatten(1, 2) + + @staticmethod + def transpose_and_unpad(x_padded, cu_seqlens): + """ + :param x_padded: (batch, n_heads, length, head_size) + :return: (total_tokens, n_heads, head_size) + """ + cu_seqlens = cu_seqlens.cpu() + total_n_tokens = cu_seqlens[-1] + x = x_padded.new_empty(total_n_tokens, x_padded.size(1), + x_padded.size(3)) + for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])): + x[s:e].copy_(x_padded[i, :, :e - s].transpose(0, 1)) + return x + + def spda(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None): + """For CPU, V100 or other older GPUs. + NOTE: torch SPDA supports nested tensor, + but seems extremely slow. Choose to pad instead. + """ + assert (cu_seqlens_q is None or + (cu_seqlens_q + == cu_seqlens_k).all()), "Can only handle prompt with SPDA." + assert q.size(0) == k.size(0), "can only handle prompt with SPDA." + + assert q.size(1) % k.size(1) == 0 + q_k_ratio = q.size(1) // k.size(1) + sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1)) + cu_seqlens = cu_seqlens_k.cpu() + maxlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + + if (self.dense_attn_mask.dtype != q.dtype + or self.dense_attn_mask.device != q.device): + _, _, self.dense_attn_mask = self.get_attn_pattern( + q.dtype, q.device) + attn_mask = self.dense_attn_mask[None, :, :maxlen, :maxlen] + + q2 = self.transpose_and_pad(q, cu_seqlens, maxlen, 1) + k2, v2 = (self.transpose_and_pad(x, cu_seqlens, maxlen, q_k_ratio) + for x in [k, v]) + spda_output = torch.nn.functional.scaled_dot_product_attention( + q2, k2, v2, attn_mask=attn_mask, scale=sm_scale) + return self.transpose_and_unpad(spda_output, cu_seqlens) + + def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None): + """Dispatch to `varlen_attn` (Ampere or newer) or + `self.spda`(cpu, Volta, Turing or older)based on + the type of device used and cuda compute capability. + + q, k, v: shape = (num_tokens, num_heads_q/kv, head_size). + Support grouped attention, with `q[:, i*r:(i*r + r)]` + is correspondent to `k[:, i]`, where `r` is the q/k ratio. 
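+ For example, two samples with 5 and 3 key tokens give
+ cu_seqlens_k = [0, 5, 8].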
+ cu_seqlens_k: shape=(batch_size + 1,), indicating segment of samples, + e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i + cu_seqlens_q: shape=(batch_size + 1, ). + Default None: same as cu_seqlens_k for prefilling or + [0, 1, .., batch_size] for decoding. + The only case you need to specify + is when q is a mix of prefilling + and decoding. + sm_scale: softmax scale, default to 1/sqrt(head_size). + + return: tensor of shape as q. + """ + assert k.dim() == 3 + if self.use_spda: + return self.spda( + q, + k, + v, + cu_seqlens_k, + cu_seqlens_q=cu_seqlens_q, + sm_scale=sm_scale, + ) + return self.varlen_attn(q, + k, + v, + cu_seqlens_k, + cu_seqlens_q=cu_seqlens_q, + sm_scale=sm_scale) diff --git a/attention/ops/blocksparse_attention/utils.py b/attention/ops/blocksparse_attention/utils.py new file mode 100644 index 0000000..445720c --- /dev/null +++ b/attention/ops/blocksparse_attention/utils.py @@ -0,0 +1,246 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Helper functions for 3D sparse pattern +# These function are not optimized and very inefficient. +# Avoid calling them too frequent or use a cache mechanism. + +from functools import lru_cache + +import numpy as np +import torch + +from vllm.triton_utils import triton + + +class csr_matrix: + """Simple implementation of CSR matrix conversion without scipy. + This replaced scipy.sparse.csr_matrix() previously used.""" + + def __init__(self, input_array): + if not isinstance(input_array, np.ndarray): + raise ValueError("Input must be a NumPy array") + + self.shape = input_array.shape + rows, cols = self.shape + data = [] + indices = [] + indptr = [0] + + for i in range(rows): + for j in range(cols): + if input_array[i, j]: + data.append(input_array[i, j]) + indices.append(j) + indptr.append(len(indices)) + + self.data = np.array(data) + self.indices = np.array(indices) + self.indptr = np.array(indptr) + + +def dense_to_crow_col(x: torch.Tensor): + """Turning a 2D/3D torch tensor (x) to CSR rows/cols indexing. 
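+ For example, x = [[1, 0, 1], [0, 1, 0]] yields crow = [0, 2, 3] and
+ col = [0, 2, 1].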
+ NOTE: col_indices padded -1 + """ + device = x.device + pad = -1 + dim = x.dim() + assert x.dim() in (2, 3) + if x.dim() == 2: + x = x[None] + x = [csr_matrix(xi.bool().cpu().numpy()) for xi in x] + crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x]) + cols = [torch.from_numpy(xi.indices) for xi in x] + max_cols = max(len(xi) for xi in cols) + cols = [ + torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])]) + for xi in cols + ] + cols = torch.vstack(cols) + if dim == 2: + crows = crows[0] + cols = cols[0] + return crows.to(device), cols.to(device) + + +def crow_col_to_dense(crows: torch.Tensor, + cols: torch.Tensor, + dtype: torch.dtype = torch.float16): + dim = crows.dim() + if dim == 1: + crows = crows[None] + cols = cols[None] + device = crows.device + crows, cols = crows.cpu(), cols.cpu() # faster in cpu + shape = (crows.shape[0], crows.shape[1] - 1, cols.max() + 1) + x = torch.zeros(shape, dtype=dtype) + for i in range(shape[0]): + for j in range(shape[1]): + x[i, j, cols[i, crows[i, j]:crows[i, j + 1]]] = 1 + if dim == 1: + x = x[0] + return x.to(device) + + +def dense_to_ccol_row(x: torch.Tensor): + """Similar, but to CSC format""" + x = x.transpose(-2, -1) + return dense_to_crow_col(x) + + +def ccol_row_to_dense(ccol: torch.Tensor, + rows: torch.Tensor, + dtype: torch.dtype = torch.float16): + return crow_col_to_dense(ccol, rows, dtype).permute(0, 2, 1).contiguous() + + +def _get_sparse_attn_mask_homo_head( + q_len: int, + max_seqlen: int, + dtype: torch.dtype, + device: torch.device, + block_size: int = 128, + local_blocks: int = 4, + vert_stride: int = 4, + return_dense: bool = False, +): + """ + :return: a tuple of 3: + - tuple of crow_indices, col_indices representation + of CSR format. + - block dense mask + - all token dense mask (be aware that it can be + OOM if it is too big) if `return_dense==True`, + otherwise, None + """ + with torch.no_grad(): + num_blocks = triton.cdiv(max_seqlen, block_size) + q_pos = torch.arange(num_blocks)[:, None] + k_pos = torch.arange(num_blocks)[None] + mask_vert_strided = (torch.arange(num_blocks) + 1) % vert_stride == 0 + block_mask_dense = (((q_pos >= k_pos) + & ((q_pos - k_pos < local_blocks) + | mask_vert_strided)).to(device).to(dtype)) + num_blocks_q = triton.cdiv(q_len, block_size) + block_mask_dense_output = (dense_to_crow_col( + block_mask_dense[-num_blocks_q:].contiguous())) + if return_dense: + mask_dense = torch.kron( + block_mask_dense, + block_mask_dense.new_ones((block_size, block_size)), + ) + causal_mask = torch.tril(torch.ones( + max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:] + mask_dense = mask_dense[-q_len:, :max_seqlen] * causal_mask + return ( + block_mask_dense_output, + block_mask_dense, + mask_dense, + ) + else: + return ( + block_mask_dense_output, + block_mask_dense, + None, + ) + + +def binary_mask_to_bias(mask_dense: torch.Tensor): + mask_dense = 1 - mask_dense + mask_dense.masked_fill_(mask_dense.bool(), -torch.inf) + return mask_dense + + +def get_head_sliding_step(n_heads: int, + vert_stride: int, + homo_head: bool = False): + if homo_head: + return 0 + return max(1, int(vert_stride / n_heads)) + + +@lru_cache +def get_sparse_attn_mask( + n_heads: int, + q_len: int, + max_seqlen: int, + dtype: torch.dtype, + device: torch.device, + block_size: int = 64, + local_blocks: int = 4, + vert_stride: int = 4, + homo_head: bool = True, + return_dense: bool = False, + dense_mask_type: str = "binary", +): + """ + :param dense_mask_type: "binary" (0 for skip token, 1 for others) + or "bias" (-inf 
for skip token, 0 or others) + :return: a tuple of 3: + - tuple of crow_indices, col_indices representation + of CSR format. + - block dense mask + - all token dense mask (be aware that it can be OOM if it + is too big) if `return_dense==True`, otherwise, None + """ + assert dense_mask_type in ("binary", "bias") + if homo_head: + with torch.no_grad(): + (crow, col), block_mask_dense, mask_dense = ( + _get_sparse_attn_mask_homo_head( + q_len, + max_seqlen, + dtype, + device, + block_size, + local_blocks, + vert_stride, + return_dense, + )) + crow = crow[None].expand(n_heads, crow.shape[0]) + col = col[None].expand(n_heads, col.shape[0]) + if return_dense: + mask_dense = mask_dense[None].expand(n_heads, + *mask_dense.shape) + if dense_mask_type == "bias": + mask_dense = binary_mask_to_bias(mask_dense) + return (crow, col), block_mask_dense, mask_dense + + with torch.no_grad(): + num_blocks = triton.cdiv(max_seqlen, block_size) + q_pos = torch.arange(num_blocks)[None, :, None] + k_pos = torch.arange(num_blocks)[None, None] + head_sliding_step = get_head_sliding_step(n_heads, vert_stride) + mask_vert_strided = [ + (torch.arange(num_blocks) + h * head_sliding_step + 1) % + vert_stride == 0 for h in range(n_heads) + ] + mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1) + block_mask_dense = (((q_pos >= k_pos) + & ((q_pos - k_pos < local_blocks) + | mask_vert_strided)).to(device).to(dtype)) + num_blocks_q = triton.cdiv(q_len, block_size) + block_mask_dense_output = block_mask_dense[:, -num_blocks_q:] + if return_dense: + mask_dense = torch.kron( + block_mask_dense, + block_mask_dense.new_ones((block_size, block_size)), + ) + causal_mask = torch.tril(torch.ones( + max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:] + mask_dense = mask_dense[..., -q_len:, :max_seqlen] * causal_mask[None] + if dense_mask_type == "bias": + mask_dense = binary_mask_to_bias(mask_dense) + + return ( + dense_to_crow_col(block_mask_dense_output), + block_mask_dense, + mask_dense, + ) + else: + return ( + dense_to_crow_col(block_mask_dense_output), + block_mask_dense, + None, + ) diff --git a/attention/ops/chunked_prefill_paged_decode.py b/attention/ops/chunked_prefill_paged_decode.py new file mode 100644 index 0000000..4f83934 --- /dev/null +++ b/attention/ops/chunked_prefill_paged_decode.py @@ -0,0 +1,368 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Authors: +# - Burkhard Ringlein +# - Jan van Lunteren +# - Chih-Chieh Yang +# - Thomas Parnell + +import torch + +from vllm import _custom_ops as ops +from vllm.platforms import current_platform +from vllm.platforms.rocm import use_rocm_custom_paged_attention +from vllm.triton_utils import tl, triton + +from .prefix_prefill import context_attention_fwd + + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@triton.jit +def kernel_paged_attention_2d( + output_ptr, # [num_tokens, num_query_heads, head_size] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] + value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + num_queries_per_kv_padded: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # 
int + query_stride_1: tl.int64, # int, should be equal to head_size + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + x: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.int64, # int + stride_k_cache_4: tl.int64, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.int64, # int + filter_by_query_len: tl.constexpr, # bool + query_start_len_ptr, # [num_seqs+1] +): + seq_idx = tl.program_id(0) + kv_head_idx = tl.program_id(1) + + if filter_by_query_len: + cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) + cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + + 1) + cur_batch_query_len = cur_batch_in_all_stop_index \ + - cur_batch_in_all_start_index + if cur_batch_query_len > 1: + return + else: + cur_batch_in_all_start_index = seq_idx + + query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange( + 0, num_queries_per_kv_padded) + + query_offset = (cur_batch_in_all_start_index * query_stride_0 + + query_head_idx[:, None] * query_stride_1) + + head_mask = query_head_idx < (kv_head_idx + 1) * num_queries_per_kv + head_mask = head_mask & (query_head_idx < num_query_heads) + + dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, + 0).to(tl.int1) + + # Q : (num_queries_per_kv, HEAD_SIZE,) + Q = tl.load( + query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED)[None, :], + mask=dim_mask[None, :] & head_mask[:, None], + other=0.0, + ) + + block_table_offset = seq_idx * block_table_stride + + M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32) + L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32) + acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED], + dtype=tl.float32) + + # sequence len for this particular sequence + seq_len = tl.load(seq_lens_ptr + seq_idx) + + # alibi slope for this head + if USE_ALIBI_SLOPES: + alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx, + mask=head_mask, + other=0.0) + + num_blocks = cdiv_fn(seq_len, BLOCK_SIZE) + + # iterate through tiles + for j in range(0, num_blocks): + + physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) + + offs_n = tl.arange(0, BLOCK_SIZE) + offs_d = tl.arange(0, HEAD_SIZE_PADDED) + + v_offset = (physical_block_idx * stride_v_cache_0 + + kv_head_idx * stride_v_cache_1 + + offs_d[None, :] * stride_v_cache_2 + + offs_n[:, None] * stride_v_cache_3) + + k_offset = (physical_block_idx * stride_k_cache_0 + + kv_head_idx * stride_k_cache_1 + + (offs_d[:, None] // x) * stride_k_cache_2 + + offs_n[None, :] * stride_k_cache_3 + + (offs_d[:, None] % x) * stride_k_cache_4) + + # K : (HEAD_SIZE, BLOCK_SIZE) + K_load = tl.load(key_cache_ptr + k_offset, + mask=dim_mask[:, None], + other=0.0) + + if K_load.dtype.is_fp8(): + K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) + else: + K = K_load + + # V : (BLOCK_SIZE, HEAD_SIZE) + V_load = tl.load(value_cache_ptr + v_offset, + mask=dim_mask[None, :], + other=0.0) + + if V_load.dtype.is_fp8(): + V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) + else: + V = V_load + + seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + boundary = 
tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32) + seq_mask = seq_offset[None, :] < boundary + + # S : (num_queries_per_kv, BLOCK_SIZE,) + S = tl.where(head_mask[:, None] & seq_mask, 0.0, + float("-inf")).to(tl.float32) + S += scale * tl.dot(Q, K) + + context_len = seq_len - 1 + + if SLIDING_WINDOW > 0: + S = tl.where((context_len - seq_offset) < SLIDING_WINDOW, S, + -10000) + + if USE_ALIBI_SLOPES: + S += alibi_slope[:, None] * (seq_offset - context_len) + + # compute running maximum + # m_j : (num_queries_per_kv,) + m_j = tl.maximum(M, tl.max(S, axis=1)) + + # P : (num_queries_per_kv, BLOCK_SIZE,) + P = tl.exp(S - m_j[:, None]) + + # l_j : (num_queries_per_kv,) + l_j = tl.sum(P, axis=1) + + # alpha : (num_queries_per_kv, ) + alpha = tl.exp(M - m_j) + + # acc : (num_queries_per_kv, BLOCK_SIZE,) + acc = acc * alpha[:, None] + + # update constants + L = L * alpha + l_j + M = m_j + + # acc : (num_queries_per_kv, BLOCK_SIZE,) + acc += tl.dot(P.to(V.dtype), V) + + # epilogue + acc = acc / L[:, None] + + output_offset = (cur_batch_in_all_start_index * output_stride_0 + + query_head_idx * output_stride_1) + + tl.store( + output_ptr + output_offset[:, None] + + tl.arange(0, HEAD_SIZE_PADDED)[None, :], + acc, + mask=dim_mask[None, :] & head_mask[:, None], + ) + + +def chunked_prefill_paged_decode( + query, + key, + value, + output, + kv_cache_dtype, + key_cache, + value_cache, + block_table, + query_start_loc, + seq_lens, + max_seq_len, + max_query_len, + k_scale, + v_scale, + alibi_slopes=None, + sliding_window=None, + sm_scale=None, +): + + if sm_scale is None: + sm_scale = 1.0 / (query.shape[1]**0.5) + + use_alibi_slopes = alibi_slopes is not None + + if sliding_window is None or sliding_window <= 0: + sliding_window = 0 + + if max_query_len > 1: + context_attention_fwd( + q=query, + k=key, + v=value, + o=output, + kv_cache_dtype=kv_cache_dtype, + k_cache=key_cache, + v_cache=value_cache, + b_loc=block_table, + b_start_loc=query_start_loc, + b_seq_len=seq_lens, + max_seq_len=max_seq_len, + max_input_len=max_query_len, + k_scale=k_scale, + v_scale=v_scale, + alibi_slopes=alibi_slopes, + sliding_window=sliding_window, + sm_scale=sm_scale, + skip_decode=True, + ) + + block_size = value_cache.shape[3] + num_seqs = len(seq_lens) + num_query_heads = query.shape[1] + num_kv_heads = key.shape[1] + num_queries_per_kv = query.shape[1] // key.shape[1] + head_size = query.shape[2] + + # Conversion of FP8 Tensor from uint8 storage to + # appropriate torch.dtype for interpretation by Triton + if "fp8" in kv_cache_dtype: + assert key_cache.dtype in [torch.uint8, current_platform.fp8_dtype()] + assert value_cache.dtype in [torch.uint8, current_platform.fp8_dtype()] + + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + target_dtype = current_platform.fp8_dtype() + elif kv_cache_dtype == "fp8_e5m2": + target_dtype = torch.float8_e5m2 + else: + raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype) + + key_cache = key_cache.view(target_dtype) + value_cache = value_cache.view(target_dtype) + + num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv), + 16) + + use_custom = use_rocm_custom_paged_attention(query.dtype, head_size, + block_size, + num_queries_per_kv, + max_seq_len, sliding_window, + kv_cache_dtype, alibi_slopes) + if use_custom: + _PARTITION_SIZE_ROCM = 256 + max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) // + _PARTITION_SIZE_ROCM) + assert _PARTITION_SIZE_ROCM % block_size == 0 + total_num_seq = block_table.shape[0] + tmp_output = torch.empty( + size=(total_num_seq, 
num_query_heads, max_num_partitions, + head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(total_num_seq, num_query_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + + ops.paged_attention_rocm( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale=sm_scale, + block_tables=block_table, + seq_lens=seq_lens, + query_start_loc=query_start_loc, + block_size=block_size, + max_seq_len=max_seq_len, + alibi_slopes=alibi_slopes, + kv_cache_dtype=kv_cache_dtype, + k_scale=k_scale, + v_scale=v_scale, + ) + else: + kernel_paged_attention_2d[( + num_seqs, + num_kv_heads, + )]( + output_ptr=output, + query_ptr=query, + key_cache_ptr=key_cache, + value_cache_ptr=value_cache, + block_tables_ptr=block_table, + seq_lens_ptr=seq_lens, + alibi_slopes_ptr=alibi_slopes, + scale=sm_scale, + k_scale=k_scale, + v_scale=v_scale, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + num_queries_per_kv_padded=num_queries_per_kv_padded, + block_table_stride=block_table.stride(0), + query_stride_0=query.stride(0), + query_stride_1=query.stride(1), + output_stride_0=output.stride(0), + output_stride_1=output.stride(1), + BLOCK_SIZE=block_size, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + SLIDING_WINDOW=sliding_window, + x=key_cache.shape[4], + stride_k_cache_0=key_cache.stride(0), + stride_k_cache_1=key_cache.stride(1), + stride_k_cache_2=key_cache.stride(2), + stride_k_cache_3=key_cache.stride(3), + stride_k_cache_4=key_cache.stride(4), + stride_v_cache_0=value_cache.stride(0), + stride_v_cache_1=value_cache.stride(1), + stride_v_cache_2=value_cache.stride(2), + stride_v_cache_3=value_cache.stride(3), + filter_by_query_len=True, + query_start_len_ptr=query_start_loc, + ) diff --git a/attention/ops/flashmla.py b/attention/ops/flashmla.py new file mode 100644 index 0000000..0eaf50b --- /dev/null +++ b/attention/ops/flashmla.py @@ -0,0 +1,138 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py +from typing import Optional, Tuple + +import torch + +from vllm.logger import init_logger +from vllm.platforms import current_platform + +logger = init_logger(__name__) + +# if current_platform.is_cuda(): +# try: +# import vllm._flashmla_C # noqa: F401 +# _flashmla_C_AVAILABLE = True +# except ImportError: +# _flashmla_C_AVAILABLE = False +# else: +# _flashmla_C_AVAILABLE = False +try : + import flash_mla + _flashmla_AVAILABLE = True +except ImportError as e: + logger.warning("Failed to import from flash_mla with %r on MACA Platform", e) + _flashmla_AVAILABLE = False + + +def is_flashmla_supported() -> Tuple[bool, Optional[str]]: + """ + Return: is_supported_flag, unsupported_reason (optional). + """ + # if not current_platform.is_cuda(): + # return False, "FlashMLA is only supported on CUDA devices." + # if current_platform.get_device_capability()[0] != 9: + # return False, "FlashMLA is only supported on Hopper devices." + # if not _flashmla_C_AVAILABLE: + # return False, "vllm._flashmla_C is not available, likely was not "\ + # "compiled due to insufficient nvcc version or a supported arch "\ + # "(only sm90a currently) was not in the list of target arches to "\ + # "compile for." 
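+    # NOTE: on this build the CUDA/Hopper capability checks above are kept
+    # only as commented-out reference; support is reported based solely on
+    # whether the external `flash_mla` package imported successfully.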
+ if not _flashmla_AVAILABLE: + return False, "flash_mla is not available" + return True, None + + +def get_mla_metadata( + cache_seqlens: torch.Tensor, + num_heads_per_head_k: int, + num_heads_k: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + cache_seqlens: (batch_size), dtype torch.int32. + num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k. + num_heads_k: num_heads_k. + + Return: + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), + dtype torch.int32. + num_splits: (batch_size + 1), dtype torch.int32. + """ + # return torch.ops._flashmla_C.get_mla_metadata(cache_seqlens, + # num_heads_per_head_k, + # num_heads_k) + return flash_mla.flash_mla_interface.get_mla_metadata(cache_seqlens, + num_heads_per_head_k, + num_heads_k) + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + head_dim_v: int, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, + softmax_scale: Optional[float] = None, + causal: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + q: (batch_size, seq_len_q, num_heads_q, head_dim). + k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + block_table: (batch_size, max_num_blocks_per_seq), torch.int32. + cache_seqlens: (batch_size), torch.int32. + head_dim_v: Head_dim of v. + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), + torch.int32, return by get_mla_metadata. + num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(head_dim). + causal: bool. Whether to apply causal attention mask. + + Return: + out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. + """ + # if softmax_scale is None: + # softmax_scale = q.shape[-1]**(-0.5) + # out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla( + # q, + # k_cache, + # None, + # head_dim_v, + # cache_seqlens, + # block_table, + # softmax_scale, + # causal, + # tile_scheduler_metadata, + # num_splits, + # ) + out, softmax_lse = flash_mla.flash_mla_interface.flash_mla_with_kvcache( + q, + k_cache, + block_table, + cache_seqlens, + head_dim_v, + tile_scheduler_metadata, + num_splits, + softmax_scale, + causal, + ) + return out, softmax_lse + + +# +# TODO: Add fake functions +# +# @register_fake("_flashmla_C::get_mla_metadata") +# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]: +# return .... +# +# @register_fake("_flashmla_C::fwd_kvcache_mla") +# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]: +# return .... +# diff --git a/attention/ops/hpu_paged_attn.py b/attention/ops/hpu_paged_attn.py new file mode 100644 index 0000000..412dd20 --- /dev/null +++ b/attention/ops/hpu_paged_attn.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +############################################################################### +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company +############################################################################### + +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import torch +from vllm_hpu_extension import cache_ops, ops + +# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. 
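+# NOTE: the HPU decode path below delegates to vllm_hpu_extension's
+# ops.flat_pa, and this constant is not referenced elsewhere in this file.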
+_PARTITION_SIZE = 512 + + +@dataclass +class HPUPagedAttentionMetadata: + """Metadata for PagedAttention.""" + block_list: Optional[torch.Tensor] + block_mapping: Optional[torch.Tensor] + block_usage: Optional[torch.Tensor] + block_indices: Optional[torch.Tensor] + block_offsets: Optional[torch.Tensor] + block_groups: Optional[torch.Tensor] + + +class HPUPagedAttention: + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [64, 80, 96, 112, 128, 256] + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def split_kv_cache( + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + key_cache = kv_cache[0] + value_cache = kv_cache[1] + return key_cache, value_cache + + @staticmethod + def write_to_paged_cache(key: torch.Tensor, value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, kv_cache_dtype: str, + is_prompt: bool) -> None: + cache_ops.reshape_and_cache(key, value, key_cache, value_cache, + slot_mapping, kv_cache_dtype, is_prompt) + + @staticmethod + def forward_decode(**kwargs) -> torch.Tensor: + return ops.flat_pa(**kwargs) + + @staticmethod + def swap_blocks( + src_kv_cache: Tuple[torch.Tensor, torch.Tensor], + dst_kv_cache: Tuple[torch.Tensor, torch.Tensor], + src_to_dsts: torch.Tensor, + ) -> None: + src_key_cache = src_kv_cache[0] + dst_key_cache = dst_kv_cache[0] + cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dsts) + + src_value_cache = src_kv_cache[1] + dst_value_cache = dst_kv_cache[1] + cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dsts) + + @staticmethod + def copy_blocks( + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + src_to_dsts: torch.Tensor, + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts) diff --git a/attention/ops/ipex_attn.py b/attention/ops/ipex_attn.py new file mode 100644 index 0000000..b7e4ba4 --- /dev/null +++ b/attention/ops/ipex_attn.py @@ -0,0 +1,195 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Dict, List, Optional, Tuple + +try: + import intel_extension_for_pytorch.llm.modules as ipex_modules + _use_ipex = True +# AttributeError is to handle a bug in ipex https://github.com/intel/intel-extension-for-pytorch/pull/813 +except (ImportError, AttributeError): + _use_ipex = False + +import torch + +from vllm import _custom_ops as ops + + +class _PagedAttention: + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [32, 64, 80, 96, 112, 128, 192, 256] + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + *args, + ) -> Tuple[int, ...]: + return (2, num_blocks, block_size * num_kv_heads * head_size) + + @staticmethod + def split_kv_cache( + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + *args, + ) -> Tuple[torch.Tensor, torch.Tensor]: + x = 16 // kv_cache.element_size() + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, + -1, x) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) + 
return key_cache, value_cache + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + *args, + ) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping.flatten(), + kv_cache_dtype, + k_scale, + v_scale, + ) + + @staticmethod + def forward_decode( + output: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + max_context_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + k_scale: torch.Tensor, + v_scale: torch.Tensor, + *args, + ) -> None: + tp_rank: int = 0 + blocksparse_local_blocks: int = 0 + blocksparse_vert_stride: int = 0 + blocksparse_block_size: int = 64 + blocksparse_head_sliding_step: int = 0 + block_size = value_cache.shape[3] + + ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + *args, + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) + + +class _IPEXPagedAttention(_PagedAttention): + + @staticmethod + def split_kv_cache( + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + *args, + ) -> Tuple[torch.Tensor, torch.Tensor]: + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, -1, head_size) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, num_kv_heads, -1, head_size) + return key_cache, value_cache + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + *args, + ) -> None: + ipex_modules.PagedAttention.reshape_and_cache( + key, value, key_cache, value_cache, + slot_mapping.flatten().int()) + + @staticmethod + def forward_decode( + output: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + max_context_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + k_scale: torch.Tensor, + v_scale: torch.Tensor, + *args, + ) -> None: + block_size = value_cache.shape[2] + head_mapping = torch.arange( + 0, + num_kv_heads, + device="cpu", + dtype=torch.int32, + ).view(num_kv_heads, + 1).repeat_interleave(query.size(1) // num_kv_heads).flatten() + ipex_modules.PagedAttention.single_query_cached_kv_attention( + output, query.contiguous(), key_cache, value_cache, head_mapping, + scale, block_tables, context_lens, block_size, max_context_len, + alibi_slopes) + + +PagedAttention = _IPEXPagedAttention if _use_ipex else _PagedAttention diff --git a/attention/ops/merge_attn_states.py b/attention/ops/merge_attn_states.py new file mode 
100644 index 0000000..5cb1a47 --- /dev/null +++ b/attention/ops/merge_attn_states.py @@ -0,0 +1,43 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch + +from vllm.platforms import current_platform + + +def merge_attn_states( + output: torch.Tensor, + prefix_output: torch.Tensor, + prefix_lse: torch.Tensor, + suffix_output: torch.Tensor, + suffix_lse: torch.Tensor, + output_lse: Optional[torch.Tensor] = None, +) -> None: + + # NOTE(DefTruth): Currently, custom merge_attn_states CUDA kernel + # is not support for FP8 dtype, fallback to use Triton kernel. + def supported_dtypes(o: torch.Tensor) -> bool: + return o.dtype in [torch.float32, torch.half, torch.bfloat16] + + # NOTE(DefTruth): Currently, custom merge_attn_states CUDA + # kernel load/store 128b(16 bytes) per memory issue within + # thread. Namely, the headsize(headdim) must be multiple of + # pack_size (float32 -> 4, half/bfloat16 -> 8). + def supported_headdim(o: torch.Tensor) -> bool: + headdim = o.shape[2] # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + if o.dtype == torch.float32: + return headdim % 4 == 0 + return headdim % 8 == 0 + + if (current_platform.is_cuda() and supported_dtypes(output) + and supported_headdim(output)): + from vllm._custom_ops import merge_attn_states + return merge_attn_states(output, prefix_output, prefix_lse, + suffix_output, suffix_lse, output_lse) + else: + from vllm.attention.ops.triton_merge_attn_states import ( + merge_attn_states) + return merge_attn_states(output, prefix_output, prefix_lse, + suffix_output, suffix_lse, output_lse) diff --git a/attention/ops/nki_flash_attn.py b/attention/ops/nki_flash_attn.py new file mode 100644 index 0000000..e28ff7e --- /dev/null +++ b/attention/ops/nki_flash_attn.py @@ -0,0 +1,906 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import neuronxcc.nki.isa as nisa +import neuronxcc.nki.language as nl +import numpy as np +import torch +from neuronxcc import nki +from neuronxcc.nki.language import par_dim + + +def ceil_div(a, b): + return (a + b - 1) // b + + +def is_power_of_2(x): + return x > 0 and (x & (x - 1)) == 0 + + +@nki.jit +def load_block_tables(block_tables_hbm, num_tiles, num_blocks_per_tile): + """ + Load block tables from HBM into SRAM + + `block_tables_hbm` has shape `(num_tiles * num_blocks_per_tile, )`. + In case `num_tiles > B_P_SIZE`, we need further tile `num_tile` dimension. + """ + B_P_SIZE = 128 + + # reshape as `(num_tiles, num_blocks_per_tile)` + assert len(block_tables_hbm.shape) == 1 + (num_total_blocks, ) = block_tables_hbm.shape + assert num_blocks_per_tile * num_tiles == num_total_blocks + block_tables_hbm = block_tables_hbm.reshape( + (num_tiles, num_blocks_per_tile)) + + block_tables_sbuf = nl.zeros( + (ceil_div(num_tiles, + B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile), + dtype=nl.int32, + ) + for i in nl.affine_range(ceil_div(num_tiles, B_P_SIZE)): + i_p = nl.arange(B_P_SIZE)[:, None] + i_f = nl.arange(num_blocks_per_tile)[None, :] + block_tables_sbuf[i, i_p, i_f] = nl.load( + block_tables_hbm[i_p + i * B_P_SIZE, i_f], + dtype=nl.int32, + mask=(i_p + i * B_P_SIZE < num_tiles), + ) + return block_tables_sbuf + + +@nki.jit +def transform_block_tables_for_indirect_load( + block_tables, + block_size_tiling_factor, + num_head, + head_id, +): + """ + This function does two things: + 1. 
calculate new `block_tables` for a `head_id` after flattening + `num_block`, `num_head`, and `block_size_tiling_factor` dimensions + 2. transpose the result so that `block_table` for each tile is mapped to + SBUF Partition dimension for vectorized DMA + + Tiling trick to further improve DMA performance: + Given KV cache shape `(num_block, num_head, block_size, D)`, when loading M + blocks of a given `head_id` from HBM, the load `cache[block_tables, + head_id]` has shape `(M, block_size, D)`. If M < B_P_SIZE = 128, DMA may not + fully utilize hardware parallelization. The solution is to tile `block_size` + into `(block_size_tiling_factor, tiled_block_size)` s.t. `M * + block_size_tiling_factor = B_P_SIZE`. After tiling, KV cache has shape + `(num_block, num_head, block_size_tiling_factor, tiled_block_size, D)`. + + Note: + We don't further tile D dimension as small DMA size also hurts performance. + """ + B_P_SIZE = 128 + num_partitions, num_tiles_per_partition, num_blocks_per_tile = ( + block_tables.shape) + assert num_tiles_per_partition == B_P_SIZE + assert is_power_of_2( + num_blocks_per_tile), f"{num_blocks_per_tile=} is not power of 2" + + num_loads = ceil_div(num_blocks_per_tile, B_P_SIZE) + block_tables_transposed = nl.ndarray( + ( + num_loads, + par_dim(B_P_SIZE), + num_partitions * num_tiles_per_partition, + ), + dtype=nl.int32, + ) + + # prepare iota ahead of time to avoid repeatedly using Gpsimd + if num_head > 1: + head_id = nisa.iota(head_id, dtype=nl.int32).reshape((1, 1)) + head_id = nl.transpose( + head_id.broadcast_to((1, num_tiles_per_partition))) + if num_blocks_per_tile > 1: + head_id = head_id.broadcast_to( + (num_tiles_per_partition, num_blocks_per_tile)) + + if block_size_tiling_factor > 1: + broadcast_shape = ( + num_tiles_per_partition, + num_blocks_per_tile, + block_size_tiling_factor, + ) + offset = nisa.iota(nl.arange(block_size_tiling_factor)[None, None, :], + dtype=nl.int32).broadcast_to(broadcast_shape) + + for partition_id in nl.affine_range(num_partitions): + block_tables_partition = block_tables[partition_id] + if num_head > 1: + # fuse num_block and num_head dimension + block_tables_partition = block_tables_partition * num_head + head_id + + if block_size_tiling_factor > 1: + # need to apply block size tiling trick + assert num_blocks_per_tile * block_size_tiling_factor == B_P_SIZE + block_tables_partition = ((block_tables_partition * + block_size_tiling_factor).reshape( + (num_tiles_per_partition, + num_blocks_per_tile, + 1)).broadcast_to(broadcast_shape)) + new_block_tables = block_tables_partition + offset + new_block_tables = new_block_tables.reshape( + (num_tiles_per_partition, B_P_SIZE)) + else: + new_block_tables = block_tables_partition + + # transpose the block table so that it can be used by vector DGE + for i in nl.affine_range(num_loads): + i_p = nl.arange(B_P_SIZE)[:, None] + i_f = (partition_id * num_tiles_per_partition + + nl.arange(num_tiles_per_partition)[None, :]) + block_tables_transposed[i, i_p, i_f] = nl.transpose( + new_block_tables[:, nl.ds(i * B_P_SIZE, B_P_SIZE)]) + return block_tables_transposed + + +@nki.jit +def load_kv_tile_from_cache( + cur_k_tile, + cur_v_tile, + kv_cache, + block_tables, + large_k_tile_idx, + num_blocks_per_large_tile, + tiled_block_size, + B_P_SIZE, + B_D_SIZE, +): + """ + Load KV cache and transform Key and Value into layout required by Matmul + + Vectorized DMA Load layout: + Key and Value: (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE) + + Layout used by attention matmuls: + Key: 
(par_dim(B_D_SIZE), seqlen_kv) + Value: (seqlen_kv // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE) + equivalent to (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE) + """ + # load key cache + num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE) + for load_idx in nl.affine_range(num_loads): + i_p = nl.arange(B_P_SIZE)[:, None] + i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :] + loaded = nl.load(kv_cache[0, block_tables[load_idx, i_p, + large_k_tile_idx], i_f]) + if cur_k_tile.dtype != loaded.dtype: + loaded = nl.copy(loaded, dtype=cur_k_tile.dtype) + # Transpose SBUF tensor using PE + for tb_i in nl.affine_range(tiled_block_size): + cur_k_tile[ + :, + nl.ds( + load_idx * B_P_SIZE * tiled_block_size + tb_i * B_P_SIZE, + B_P_SIZE, + ), + ] = nl.transpose(loaded[:, nl.ds(tb_i * B_D_SIZE, B_D_SIZE)]) + + # load value cache + for load_idx in nl.affine_range(num_loads): + loaded = nl.load(kv_cache[1, block_tables[load_idx, i_p, + large_k_tile_idx], i_f]) + if cur_v_tile.dtype != loaded.dtype: + loaded = nl.copy(loaded, dtype=cur_v_tile.dtype) + i_p = nl.arange(B_P_SIZE)[:, None] + i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :] + cur_v_tile[ + :, + nl.ds( + load_idx * tiled_block_size * B_D_SIZE, + tiled_block_size * B_D_SIZE, + ), + ] = loaded + + +@nki.jit +def transpose_p_local(p_local_transposed, + p_local, + LARGE_TILE_SZ, + B_F_SIZE=512): + for i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE): + if nisa.get_nc_version() == nisa.nc_version.gen3: + p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE), + buffer=nl.sbuf, + dtype=p_local.dtype) + else: + p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE), + buffer=nl.psum, + dtype=np.float32) + + for j in nl.affine_range(B_F_SIZE // 128): + j_128_slice = nl.ds(j * 128, 128) + i_j_128_slice = nl.ds(i * B_F_SIZE + j * 128, 128) + + if nisa.get_nc_version() == nisa.nc_version.gen3: + p_local_t_tmp[:, j_128_slice] = nisa.dma_transpose( + p_local[:, i_j_128_slice]) + else: + p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose( + p_local[:, i_j_128_slice]) + + p_local_transposed[:, nl.ds(i * B_F_SIZE, B_F_SIZE)] = nl.copy( + p_local_t_tmp, dtype=p_local_transposed.dtype) + + +@nki.jit +def _flash_attention_core( + q_local_tile, + k, + v, + o_buffer, + l_buffer, + m_buffer, + kernel_dtype, + acc_type, + tile_mask, + use_causal_mask, + q_tile_idx=None, + initialize=False, + LARGE_TILE_SZ=2048, + B_P_SIZE=128, + B_F_SIZE=512, + B_D_SIZE=128, + qk_res_buffer=None, +): + """ + The flash attention core function to calculate self attention between a tile + of q and a block of K and V. + The q_local_tile has (B_P_SIZE, B_D_SIZE) + The K and V have shape (B_D_SIZE, LARGE_TILE_SZ), whose free dimension will + be split into size B_F_SIZE tiles + + The results are stored in the following three buffers + o_buffer: (B_P_SIZE, d) + l_buffer: (B_P_SIZE, 1) + m_buffer: (B_P_SIZE, 1) + + All IO buffers are in SBUF. 
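+    m_buffer holds the running row-wise max and l_buffer the running
+    log-sum-exp of the online softmax, so o_buffer only needs the final
+    exp(m - l) rescaling when results are written back to HBM.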
+ """ + num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE + + qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), + buffer=nl.sbuf, + dtype=acc_type) + max_local = nl.ndarray((par_dim(B_P_SIZE), num_k_tile_per_large_tile), + dtype=acc_type) + for k_i in nl.affine_range(num_k_tile_per_large_tile): + k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE) + + if use_causal_mask: + # mask are used to only apply computation to the lower half of the + # matrix, which reduce the arithmetic intensity by up to 50% + multiplication_required_selection = (q_tile_idx * B_P_SIZE + >= k_i * B_F_SIZE) + else: + multiplication_required_selection = True + + if multiplication_required_selection: + qk_psum = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE), + dtype=np.float32, + buffer=nl.psum) # (128, 512) + qk_psum[:, :] = nl.matmul(q_local_tile, + k[:, k_i_b_f_slice], + transpose_x=True) # (p(128), 512) + qk_res_buf[:, k_i_b_f_slice] = nl.where( + tile_mask[:, k_i_b_f_slice], + qk_psum[:, nl.ds(0, B_F_SIZE)], + -9984.0, + dtype=acc_type, + ) + else: + qk_res_buf[:, k_i_b_f_slice] = -9984.0 + + # Calculate max of the current tile + max_local[:, k_i] = nisa.tensor_reduce( + np.max, + qk_res_buf[:, k_i_b_f_slice], + axis=(1, ), + dtype=acc_type, + negate=False, + ) + + if qk_res_buffer is not None: + qk_res_buffer[:, :] = nl.copy(qk_res_buf[:, :]) + + max_ = nisa.tensor_reduce( + np.max, + max_local[:, :], + axis=(1, ), + dtype=acc_type, + negate=False, + ) + + o_previous_scaled = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE), + dtype=o_buffer.dtype) + + if initialize: + m_buffer[:, 0] = nl.copy(max_) + m_current = max_ + else: + m_previous = nl.copy(m_buffer[:, 0]) + m_buffer[:, 0] = nl.maximum(m_previous, max_) # (128,1) + + m_current = m_buffer[:, 0] + # Compute scaling factor + alpha = nisa.activation( + np.exp, + m_previous, + bias=-1 * m_current, + scale=1.0, + ) + o_previous_scaled[...] = nl.multiply(o_buffer[:, :], alpha) + + p_local = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), + dtype=kernel_dtype) + REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2) + + p_partial_sum = nl.ndarray( + (par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE), + dtype=acc_type, + ) + + for k_r_i in nl.affine_range(LARGE_TILE_SZ // REDUCTION_TILE): + k_r_i_reduce_slice = nl.ds(k_r_i * REDUCTION_TILE, REDUCTION_TILE) + + # compute exp(qk - max) + # Compute partial row - tile sum of exp(qk - max)) + # FIXME : Use activation accumulate to accumulate over k_r_i loop ? 
+ p_local[:, k_r_i_reduce_slice] = nisa.activation_reduce( + np.exp, + qk_res_buf[:, k_r_i_reduce_slice], + bias=-1 * m_current, + scale=1.0, + reduce_op=nl.add, + reduce_res=p_partial_sum[:, k_r_i], + dtype=kernel_dtype, + ) + + ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type) + + p_local_transposed = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), + dtype=kernel_dtype) + transpose_p_local( + p_local_transposed=p_local_transposed, + p_local=p_local, + LARGE_TILE_SZ=LARGE_TILE_SZ, + B_F_SIZE=B_F_SIZE, + ) + + pv_psum = nl.zeros( + (par_dim(B_P_SIZE), B_D_SIZE), + dtype=np.float32, + buffer=nl.psum, + ) + for k_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE): + pv_psum[:, :] += nl.matmul( + p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)], + v[:, nl.ds(k_i * B_D_SIZE, B_D_SIZE)], + transpose_x=True, + ) # (128, 128) (p(Br), d) + + if initialize: + o_buffer[:, :] = nl.copy(pv_psum[:, :]) + l_buffer[:, 0] = nl.add(nl.log(ps), max_) + else: + o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum) + + l_prev = l_buffer[:, 0] + l_exp = nl.add( + nl.exp(nl.subtract(l_prev, m_current)), + ps, + ) + l_buffer[:, 0] = nl.add(m_current, nl.log(l_exp)) + + +@nki.jit +def load_v_tile(v_hbm_tile, cur_v_tile, large_tile_idx, v_i, LARGE_TILE_SZ): + B_P_SIZE = 128 + B_D_SIZE = v_hbm_tile.shape[-1] + loaded = nl.load(v_hbm_tile[ + nl.ds(large_tile_idx * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE), + :, + ]) + if cur_v_tile.dtype != loaded.dtype: + loaded = nl.copy(loaded, dtype=cur_v_tile.dtype) + cur_v_tile[:, nl.ds(v_i * B_D_SIZE, B_D_SIZE)] = loaded + + +@nki.jit +def flash_paged_attention( + query, + key, + value, + kv_cache, + block_tables, + mask, + softmax_scale=None, + mixed_precision=True, + LARGE_TILE_SZ=2048, + return_debug_tensors=False, +): + """ + Flash PagedAttention Forward Kernel. + + IO tensor layouts: + - query: shape (1, n_heads, d, seq_q) + - key: shape (1, n_kv_heads, d, seq_k) + - value: shape (1, n_kv_heads, seq_v, d) + - kv_cache: (2, num_blocks, n_kv_heads, block_size, d) + - block_tables: (num_active_blocks, ) + - mask: (seq_q, num_active_blocks * block_size + seq_q) + - o: shape (1, n_heads, seq_q, d) + + - This kernel requires seq_k == seq_v + - We use continuous batching by default, so the batch dimension is + always 1, and different requests are concatenated along sequence + dimension. + - We use paged cache blocks (kv_cache) to store KV cache. + + IO tensor dtypes: + - This kernel assumes all IO tensors have the same dtype except for + block_tables (int32) and mask (int32) + - If mixed_precision is True, then all Tensor Engine operation will be + performed in bfloat16 and accumulation will be performed in float32. + Otherwise the intermediates will be in the same type as the inputs. 
+ + Compile-time Constants: + - softmax_scale: scaling for softmax, is None, default is `1.0/(d**0.5)` + - mixed_precision: flag to set non-matmul ops in fp32 precision, default + is set to `true`, if false, we use same precision as input types + - LARGE_TILE_SZ: `default=2048`, size of the kv tile size for attention + computation reduction + + GQA support Notes: + the spmd kernel for launching kernel should be on kv_heads instead of + nheads + + Example usage: + MHA: q: [b, h, d, s], k: [b, h, d, s], v: [b, h, s, d] + usage: `flash_fwd[b, h](q, k, v, ...)` + GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d] + usage: `flash_fwd[b, kv_h](q, k, v, ...)` + """ + B_F_SIZE = 512 + B_P_SIZE = 128 + b, h, d, seqlen_q = query.shape + B_D_SIZE = d + n_tile_q = seqlen_q // B_P_SIZE # since q will be loaded on tensor engine + _, num_blocks, k_h, block_size, _ = kv_cache.shape + q_h_per_k_h = h // k_h + assert b == 1, f"invalid batch size {b=}" + assert d <= 128, f" we do not support head_dim > 128, got head dim {d=}" + cache_shape = (2, num_blocks, k_h, block_size, d) + assert (tuple(kv_cache.shape) == cache_shape + ), f"{kv_cache.shape=} mismatch, expect {cache_shape}" + assert key is None or tuple(key.shape) == ( + 1, + k_h, + d, + seqlen_q, + ), f"key shape {key.shape} mismatch!" + assert value is None or tuple(value.shape) == ( + 1, + k_h, + seqlen_q, + d, + ), f"value shape {value.shape} mismatch!" + + assert ( + nl.program_ndim() == 2 + ), f"Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!" + batch_id = nl.program_id(axis=0) + head_id = nl.program_id(axis=1) + + (num_active_blocks, ) = block_tables.shape + context_kv_len = num_active_blocks * block_size + assert ( + LARGE_TILE_SZ % B_F_SIZE == 0 + ), f"Need {LARGE_TILE_SZ=} to be divisible by {B_F_SIZE=} in transpose_p" + assert (context_kv_len % LARGE_TILE_SZ == 0 + ), f"Need {context_kv_len=} to be divisible by {LARGE_TILE_SZ=}" + + num_blocks_per_large_tile = LARGE_TILE_SZ // block_size + assert is_power_of_2( + num_blocks_per_large_tile + ), f"{num_blocks_per_large_tile=} is expected of be power of 2" + if seqlen_q > B_F_SIZE: + MAX_REDUCTION_TILE = 2048 + if seqlen_q // 2 > MAX_REDUCTION_TILE: + assert ( + seqlen_q % MAX_REDUCTION_TILE == 0 + ), f"{seqlen_q=} should be divisible by {MAX_REDUCTION_TILE=}" + else: + assert (seqlen_q % B_F_SIZE == 0 + ), f"{seqlen_q=} should be divisible by {B_F_SIZE=})" + + kernel_dtype = nl.bfloat16 if mixed_precision else query.dtype + acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype + softmax_scale = softmax_scale or (1.0 / (d**0.5)) + num_large_k_tile = context_kv_len // LARGE_TILE_SZ + + o = nl.ndarray((b, h, seqlen_q, d), + dtype=query.dtype, + buffer=nl.shared_hbm) + hbm_l_buffer, hbm_m_buffer, hbm_qk_res, qk_res_buffer = ( + None, + None, + None, + None, + ) + if return_debug_tensors: + hbm_l_buffer = nl.ndarray((b, h, seqlen_q), + dtype=acc_type, + buffer=nl.shared_hbm) + hbm_m_buffer = nl.ndarray((b, h, seqlen_q), + dtype=acc_type, + buffer=nl.shared_hbm) + hbm_qk_res = nl.ndarray((b, h, B_P_SIZE, seqlen_q), + dtype=acc_type, + buffer=nl.shared_hbm) + qk_res_buffer = nl.zeros( + (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), seqlen_q), + dtype=acc_type, + buffer=nl.sbuf, + lazy_initialization=True, + ) + block_tables_sbuf = load_block_tables( + block_tables_hbm=block_tables, + num_tiles=num_large_k_tile, + num_blocks_per_tile=num_blocks_per_large_tile, + ) + + # On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient + if 
num_blocks_per_large_tile < B_P_SIZE: + # we checked num_blocks_per_tile is a power of 2 + assert B_P_SIZE % num_blocks_per_large_tile == 0 + block_size_tiling_factor = B_P_SIZE // num_blocks_per_large_tile + # We assume block_size >= block_size_tiling_factor + assert block_size % block_size_tiling_factor == 0 + else: + block_size_tiling_factor = 1 + tiled_block_size = block_size // block_size_tiling_factor + + # Indirect DMA load must be placed along Partition Dimension + block_tables_sbuf = transform_block_tables_for_indirect_load( + block_tables_sbuf, + block_size_tiling_factor=block_size_tiling_factor, + num_head=k_h, + head_id=head_id, + ) + + # Flatten KV cache to be 3D for loading into SBUF + new_cache_shape = ( + 2, + num_blocks * k_h * block_size_tiling_factor, + tiled_block_size * d, + ) + kv_cache = kv_cache.reshape(new_cache_shape) + + # Global Flash Attention accumulators + o_buffer = nl.zeros( + (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), d), + dtype=acc_type, + buffer=nl.sbuf, + lazy_initialization=True, + ) + l_buffer = nl.zeros( + (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1), + dtype=acc_type, + buffer=nl.sbuf, + lazy_initialization=True, + ) + m_buffer = nl.zeros( + (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1), + dtype=acc_type, + buffer=nl.sbuf, + lazy_initialization=True, + ) + + for large_k_tile_idx in nl.sequential_range(0, num_large_k_tile): + num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE) + cur_k_tile = nl.ndarray( + (par_dim(B_D_SIZE), LARGE_TILE_SZ), + dtype=kernel_dtype, + ) + cur_v_tile = nl.ndarray( + (par_dim(B_P_SIZE), num_loads * tiled_block_size * B_D_SIZE), + dtype=kernel_dtype, + ) + load_kv_tile_from_cache( + cur_k_tile=cur_k_tile, + cur_v_tile=cur_v_tile, + kv_cache=kv_cache, + block_tables=block_tables_sbuf, + large_k_tile_idx=large_k_tile_idx, + num_blocks_per_large_tile=num_blocks_per_large_tile, + tiled_block_size=tiled_block_size, + B_P_SIZE=B_P_SIZE, + B_D_SIZE=B_D_SIZE, + ) + + for i in nl.affine_range(n_tile_q): + cur_mask = nl.load(mask[ + nl.ds(i * B_P_SIZE, B_P_SIZE), + nl.ds(large_k_tile_idx * LARGE_TILE_SZ, LARGE_TILE_SZ), + ]) + for i_q_h in nl.affine_range(q_h_per_k_h): + q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype) + q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h] + q_sbuf_tile = nl.load(q_hbm_tile[:, + nl.ds(i * + B_P_SIZE, B_P_SIZE)]) + if q_sbuf_tile.dtype != kernel_dtype: + q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype) + q_tile[:, :] = q_sbuf_tile * softmax_scale + + _flash_attention_core( + q_local_tile=q_tile, + k=cur_k_tile, + v=cur_v_tile, + o_buffer=o_buffer[i, i_q_h], + l_buffer=l_buffer[i, i_q_h], + m_buffer=m_buffer[i, i_q_h], + kernel_dtype=kernel_dtype, + acc_type=acc_type, + tile_mask=cur_mask, + use_causal_mask=False, + q_tile_idx=i, + initialize=large_k_tile_idx == 0, + LARGE_TILE_SZ=LARGE_TILE_SZ, + B_P_SIZE=B_P_SIZE, + B_F_SIZE=B_F_SIZE, + B_D_SIZE=B_D_SIZE, + ) + + # compute attention between input query, key and value + if key is not None and value is not None: + B_F_SIZE = min(seqlen_q, B_F_SIZE) + LARGE_TILE_SZ = seqlen_q + + cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ), + dtype=kernel_dtype) + cur_v_tile = nl.ndarray( + (par_dim(B_P_SIZE), LARGE_TILE_SZ // B_P_SIZE * B_D_SIZE), + dtype=kernel_dtype, + ) + + loaded = nl.load(key[batch_id, head_id, :, :]) + if loaded.dtype != kernel_dtype: + loaded = nl.copy(loaded, dtype=kernel_dtype) + cur_k_tile[:, :] = loaded + + v_hbm_tile = value[batch_id, head_id] + for v_i in nl.affine_range(LARGE_TILE_SZ // 
B_P_SIZE): + load_v_tile( + v_hbm_tile=v_hbm_tile, + cur_v_tile=cur_v_tile, + large_tile_idx=0, + v_i=v_i, + LARGE_TILE_SZ=LARGE_TILE_SZ, + ) + + for i in nl.affine_range(n_tile_q): + cur_mask = nl.load(mask[ + nl.ds(i * B_P_SIZE, B_P_SIZE), + nl.ds(context_kv_len, LARGE_TILE_SZ), + ]) + for i_q_h in nl.affine_range(q_h_per_k_h): + + q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype) + q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h] + q_sbuf_tile = nl.load(q_hbm_tile[:, + nl.ds(i * + B_P_SIZE, B_P_SIZE)]) + if q_sbuf_tile.dtype != kernel_dtype: + q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype) + q_tile[:, :] = q_sbuf_tile * softmax_scale + _flash_attention_core( + q_local_tile=q_tile, + k=cur_k_tile, + v=cur_v_tile, + o_buffer=o_buffer[i, i_q_h], + l_buffer=l_buffer[i, i_q_h], + m_buffer=m_buffer[i, i_q_h], + kernel_dtype=kernel_dtype, + acc_type=acc_type, + tile_mask=cur_mask, + use_causal_mask=True, + q_tile_idx=i, + initialize=False, + LARGE_TILE_SZ=LARGE_TILE_SZ, + B_P_SIZE=B_P_SIZE, + B_F_SIZE=B_F_SIZE, + B_D_SIZE=B_D_SIZE, + qk_res_buffer=(qk_res_buffer[i, i_q_h] + if qk_res_buffer is not None else None), + ) + + # -- -- -- -- write output to buffer on HBM -- -- -- -- -- -- # + for i_q_h in nl.affine_range(q_h_per_k_h): + for i in nl.affine_range(n_tile_q): + out = nl.multiply( + o_buffer[i, i_q_h], + nl.exp(m_buffer[i, i_q_h] - l_buffer[i, i_q_h]), + dtype=kernel_dtype, + ) + + nl.store( + o[ + batch_id, + head_id * q_h_per_k_h + i_q_h, + nl.ds(i * B_P_SIZE, B_P_SIZE), + :, + ], + out, + ) + # maximum and summation statistics + if return_debug_tensors: + nl.store( + hbm_m_buffer[ + batch_id, + head_id * q_h_per_k_h + i_q_h, + nl.ds(i * B_P_SIZE, B_P_SIZE), + ], + m_buffer[i, i_q_h, :, :], + ) + nl.store( + hbm_l_buffer[ + batch_id, + head_id * q_h_per_k_h + i_q_h, + nl.ds(i * B_P_SIZE, B_P_SIZE), + ], + l_buffer[i, i_q_h], + ) + nl.store( + hbm_qk_res[batch_id, head_id * q_h_per_k_h + i_q_h, :, :], + qk_res_buffer[batch_id, i_q_h, :, :], + ) + + if return_debug_tensors: + return o, hbm_m_buffer, hbm_l_buffer, hbm_qk_res + return o + + +def reorder_context_mask(mask, LARGE_TILE_SZ, block_size): + """ + Reorder the mask to make it compatible with the flash attention kernel. + + We vectorize KV cache read to improve DMA utilization. However, the layout + that maximizes DMA bandwidth changes the order tokens are consumed. + + The token layout (inner 2 dimensions) after vectorized load is (B_P_SIZE, + tiled_block_size) in a tile of `B_P_SIZE * tiled_block_size` tokens. And + each step the engine consumes a column (rather than a row) of B_P_SIZE + tokens. Therefore, the tokens are visited in a strided way. + + To make sure mask matches the order tokens are consumed, we need to properly + transpose mask. 
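+    As an illustrative case: with LARGE_TILE_SZ=2048 and block_size=32,
+    tiled_block_size is 16, so each 2048-token tile of the context mask is
+    viewed as (128, 16) and its last two axes are swapped before being
+    flattened back.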
+ """ + total_query_len, total_seq_len = mask.shape + context_kv_len = total_seq_len - total_query_len + + B_P_SIZE = 128 + assert (LARGE_TILE_SZ + >= B_P_SIZE), f"{LARGE_TILE_SZ=} must be larger than {B_P_SIZE=}" + num_tiled_blocks = max(B_P_SIZE, LARGE_TILE_SZ // block_size) + tiled_block_size = LARGE_TILE_SZ // num_tiled_blocks + if tiled_block_size > 1: + # Mask reordering is needed when tiled_block_size > 1 + device = mask.device + mask = mask.cpu() + context_mask = mask[:, :context_kv_len] + context_mask = context_mask.view( + total_query_len, + context_kv_len // LARGE_TILE_SZ, + num_tiled_blocks // B_P_SIZE, + B_P_SIZE, + tiled_block_size, + ) + context_mask = context_mask.transpose(3, 4).reshape( + total_query_len, context_kv_len) + new_mask = mask[:, context_kv_len:] + return torch.concat([context_mask, new_mask], dim=1).to(device) + else: + return mask + + +def flash_attn_varlen_nkifunc( + query, + key, + value, + kv_cache, + block_table, + attn_mask, + n_kv_head=None, + head_size=None, + LARGE_TILE_SZ=2048, + mixed_precision=True, +): + """ + Compute flash paged attention for variable length sequences. + + This function is a wrapper around the flash attention NKI kernel. It takes + in the following arguments: + - query: (1, n_heads, d, seq_q) + - key: (1, n_kv_heads, d, seq_k) + - value: (1, n_kv_heads, seq_v, d) + - kv_cache: (2, n_blocks, n_kv_heads, block_size, d) + - block_tables: (n_active_blocks, ) + - attn_mask: (seq_q, n_active_blocks * block_size + seq_q) + + Notes: + - attn_mask must be reordered outside using `reorder_context_mask` + - Key/value cache layout must be (n_blocks, n_kv_heads, block_size, d) + for better DMA throughput + """ + if n_kv_head is None: + n_kv_head = kv_cache.shape[2] + assert kv_cache.shape[0] == 2 + assert kv_cache.shape[2] == n_kv_head + if head_size is None: + head_size = kv_cache.shape[-1] + + kwargs = dict( + query=query, + key=key, + value=value, + kv_cache=kv_cache, + block_tables=block_table, + mask=attn_mask, + softmax_scale=1.0 / (head_size**0.5), + mixed_precision=mixed_precision, + LARGE_TILE_SZ=LARGE_TILE_SZ, + ) + + o = flash_paged_attention[1, n_kv_head](**kwargs) + return o + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, +) -> None: + """ + Writes key-value pairs to the KV cache at specified positions. 
+ + Args: + key (torch.Tensor): Key tensor with shape + (num_tokens, n_kv_head, d_head) + value (torch.Tensor): Value tensor with shape + (num_tokens, n_kv_head, d_head) + kv_cache (torch.Tensor): Key/value cache tensor with shape + (2, num_blocks, n_kv_head, block_size, d_head) + slot_mapping (torch.Tensor): Mapping tensor indicating cache positions + with shape (num_tokens) + + Returns: + None: Updates the kv_cache tensor in-place + """ + block_size = kv_cache.size(3) + n_kv_head = key.size(1) + + # Calculate indices with explicit floor division + block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor") + block_offsets = slot_mapping % block_size + + # Create the head indices tensor + head_indices = torch.arange(n_kv_head, device=key.device) + + # Update caches using index_put_ + kv_cache.index_put_( + (torch.tensor([0], device=key.device), block_indices[:, None], + head_indices[None, :], block_offsets[:, None]), key) + + kv_cache.index_put_( + (torch.tensor([1], device=key.device), block_indices[:, None], + head_indices[None, :], block_offsets[:, None]), value) diff --git a/attention/ops/paged_attn.py b/attention/ops/paged_attn.py new file mode 100644 index 0000000..c6d1501 --- /dev/null +++ b/attention/ops/paged_attn.py @@ -0,0 +1,256 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + from vllm.attention.ops.prefix_prefill import context_attention_fwd + +# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. +_PARTITION_SIZE = 512 + + +@dataclass +class PagedAttentionMetadata: + """Metadata for PagedAttention.""" + # (batch_size,). The length of sequences (entire tokens seen so far) per + # sequence. + seq_lens_tensor: Optional[torch.Tensor] + # Maximum sequence length in the batch. 0 if it is prefill-only batch. + max_decode_seq_len: int + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. 
+ block_tables: Optional[torch.Tensor] + + +class PagedAttention: + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [32, 64, 80, 96, 112, 120, 128, 192, 256] + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (2, num_blocks, block_size * num_kv_heads * head_size) + + @staticmethod + def split_kv_cache( + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + x = 16 // kv_cache.element_size() + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, + -1, x) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) + return key_cache, value_cache + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + ) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping.flatten(), + kv_cache_dtype, + k_scale, + v_scale, + ) + + @staticmethod + def forward_decode( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + max_seq_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + k_scale: torch.Tensor, + v_scale: torch.Tensor, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, + ) -> torch.Tensor: + if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1: + # use blocksparse paged attention + block_size = value_cache.size(-1) + assert (blocksparse_block_size > 0 and + blocksparse_block_size % block_size == 0), \ + (f"{blocksparse_block_size=} needs to be a multiple of" + f"{block_size=} used in block_tables.") + + output = torch.empty_like(query) + block_size = value_cache.shape[3] + num_seqs, num_heads, head_size = query.shape + max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // + _PARTITION_SIZE) + # NOTE(woosuk): We use a simple heuristic to decide whether to use + # PagedAttention V1 or V2. If the number of partitions is 1, we use + # V1 to avoid the overhead of reduction. Also, if the number of + # sequences or heads is large, we use V1 since there is enough work + # to parallelize. + # TODO(woosuk): Tune this heuristic. + # For context len > 8192, use V2 kernel to avoid shared memory shortage. + use_v1 = (max_seq_len <= 8192 + and (max_num_partitions == 1 or num_seqs * num_heads > 512)) + + if use_v1: + # Run PagedAttention V1. + ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + else: + # Run PagedAttention V2. 
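+            # V2 uses a split-K scheme: each sequence is cut into
+            # max_num_partitions chunks of _PARTITION_SIZE tokens. The kernel
+            # writes one partial result per partition (tmp_output) together
+            # with its softmax statistics (exp_sums, max_logits); a second
+            # reduction pass folds the partials into the final output.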
+ assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + return output + + @staticmethod + def forward_prefix( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache_dtype: str, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + query_start_loc: torch.Tensor, + seq_lens_tensor: torch.Tensor, + max_query_len: int, + alibi_slopes: Optional[torch.Tensor], + sliding_window: Optional[int], + k_scale: torch.Tensor, + v_scale: torch.Tensor, + ) -> torch.Tensor: + output = torch.empty_like(query) + max_seq_len = None + context_attention_fwd( + query, + key, + value, + output, + kv_cache_dtype, + key_cache, + value_cache, + block_tables, + # query_start_loc is (batch_size + 1,) + query_start_loc, + seq_lens_tensor, + max_seq_len, + max_query_len, + k_scale, + v_scale, + alibi_slopes, + sliding_window, + ) + return output + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + src_key_cache = src_kv_cache[0] + dst_key_cache = dst_kv_cache[0] + ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + + src_value_cache = src_kv_cache[1] + dst_value_cache = dst_kv_cache[1] + ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) diff --git a/attention/ops/prefix_prefill.py b/attention/ops/prefix_prefill.py new file mode 100644 index 0000000..13bef96 --- /dev/null +++ b/attention/ops/prefix_prefill.py @@ -0,0 +1,902 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# The kernels in this file are adapted from LightLLM's context_attention_fwd: +# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py + +import torch + +from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton + +# Static kernels parameters +BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64 +NUM_WARPS = 4 if current_platform.is_rocm() else 8 + +# To check compatibility +IS_TURING = current_platform.get_device_capability() == (7, 5) + + +# Here's an example autotuner config for this kernel. This config does provide +# a performance improvement, but dramatically increases first call latency in +# triton 3.2. Because of this tradeoff, it's currently commented out. 
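+# If the first-call latency is acceptable for a given deployment, re-enabling
+# the decorator below (and dropping the hard-coded BLOCK_M/BLOCK_N, unroll,
+# num_warps, and num_stages launch parameters at the call site, which would
+# otherwise conflict with the autotuned config) should restore the tuned
+# behavior.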
+# @triton.autotune( +# configs=[ +# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \ +# "num_unroll_cache": 4, \ +# "num_unroll_request": 1 } | \ +# ({"kpack": 2, "waves_per_eu": 2} \ +# if current_platform.is_rocm() else {}), \ +# num_warps=4, \ +# num_stages=1) +# ], +# key=["BLOCK_SIZE", "MAX_Q_LEN", "MAX_CTX_LEN"] +# ) +@triton.jit +def _fwd_kernel(Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + k_scale, + v_scale, + B_Start_Loc, + B_Seqlen, + x: tl.constexpr, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl: tl.constexpr, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: tl.constexpr, + IN_PRECISION: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DMODEL_PADDED: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BLOCK_N: tl.constexpr, + SLIDING_WINDOW: tl.constexpr, + num_unroll_cache: tl.constexpr, + num_unroll_request: tl.constexpr, + SKIP_DECODE: tl.constexpr, + MAX_Q_LEN: tl.constexpr = 0, + MAX_CTX_LEN: tl.constexpr = 0): + + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) + cur_batch_query_len = (cur_batch_in_all_stop_index - + cur_batch_in_all_start_index) + cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len + + if SKIP_DECODE and cur_batch_query_len == 1: + return + + # start position inside of the query + # generally, N goes over kv, while M goes over query_len + block_start_loc = BLOCK_M * start_m + + # initialize offsets + # [BLOCK_SIZE]; starts at 0 + offs_bs_n = tl.arange(0, BLOCK_SIZE) + # [N]; starts at 0 + offs_n = tl.arange(0, BLOCK_N) + # [D]; starts at 0 + offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) + # [M]; starts at current position in query + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # [M,D] + off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + dim_mask = tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, + 0).to(tl.int1) # [D] + + q = tl.load(Q + off_q, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_query_len), + other=0.0) # [M,D] + + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) # [M,D] + + # compute query against context (no causal mask here) + for start_n in tl.range(0, cur_batch_ctx_len, BLOCK_SIZE, \ + loop_unroll_factor=num_unroll_cache): + start_n = tl.multiple_of(start_n, BLOCK_SIZE) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + (start_n // BLOCK_SIZE) * stride_b_loc_s) + # [D,BLOCK_SIZE] + off_k = ( + bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + + # [BLOCK_SIZE,D] + off_v = (bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + 
offs_d[None, :] * stride_v_cache_d + + offs_bs_n[:, None] * stride_v_cache_bl) + + if start_n + BLOCK_SIZE > cur_batch_ctx_len or \ + BLOCK_DMODEL != BLOCK_DMODEL_PADDED: + k_load = tl.load( + K_cache + off_k, + mask=dim_mask[:, None] & + ((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] + else: + k_load = tl.load(K_cache + off_k) + + if k_load.dtype.is_fp8(): + k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) + else: + k = k_load + + qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32) # [M,N] + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) + qk = tl.where((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + if SLIDING_WINDOW > 0: + # (cur_batch_ctx_len + offs_m[:, None]) are the positions of + # Q entries in sequence + # (start_n + offs_bs_n[None, :]) are the positions of + # KV entries in sequence + # So the condition makes sure each entry in Q only attends + # to KV entries not more than SLIDING_WINDOW away. + # + # We can't use -inf here, because the + # sliding window may lead to the entire row being masked. + # This then makes m_ij contain -inf, which causes NaNs in + # exp(). + qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) - + (start_n + offs_bs_n[None, :]) < SLIDING_WINDOW, qk, + -10000) + + # compute running maximum + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + alpha = tl.exp(m_i - m_ij) + acc = acc * alpha[:, None] + + # update acc + if start_n + BLOCK_SIZE > cur_batch_ctx_len or \ + BLOCK_DMODEL != BLOCK_DMODEL_PADDED: + v_load = tl.load( + V_cache + off_v, + mask=dim_mask[None, :] & + ((start_n + offs_bs_n[:, None]) < cur_batch_ctx_len), + other=0.0) # [N,D] + else: + v_load = tl.load(V_cache + off_v) + + if v_load.dtype.is_fp8(): + v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) + else: + v = v_load + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) + # # update m_i and l_i + l_i = l_i * alpha + l_ij + m_i = m_ij + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + # block_mask is 0 when we're already past the current query length + block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0) + + # compute query against itself (with causal mask) + for start_n in tl.range(0, \ + block_mask * (start_m + 1) * BLOCK_M, BLOCK_N, \ + loop_unroll_factor=num_unroll_request): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_query_len), + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) + qk *= sm_scale + # apply causal mask + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + if SLIDING_WINDOW > 0: + qk = tl.where( + offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW, + qk, -10000) + + # compute running maximum + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + alpha = tl.exp(m_i - m_ij) + acc = acc * alpha[:, None] + + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=dim_mask[None, :] & + ((start_n + 
offs_n[:, None]) < cur_batch_query_len), + other=0.0) + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) + # update m_i and l_i + l_i = l_i * alpha + l_ij + m_i = m_ij + + acc = acc / l_i[:, None] + + # initialize pointers to output + off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len)) + return + + +@triton.jit +def _fwd_kernel_flash_attn_v2( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + q = tl.load(Q + off_q, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = ( + bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = (bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k = tl.load(K_cache + off_k, + mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, + other=0.0) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(V_cache + 
off_v, + mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) + < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) + < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + # acc /= l_i[:, None] + # initialize pointers to output + off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + return + + +@triton.jit +def _fwd_kernel_alibi( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + k_scale, + v_scale, + B_Start_Loc, + B_Seqlen, + Alibi_slopes, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + IN_PRECISION: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, # head size + BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 + BLOCK_N: tl.constexpr, + SKIP_DECODE: tl.constexpr, +): + # attn_bias[] + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + # cur_batch_seq_len: the length of prompts + # cur_batch_ctx_len: the length of prefix + # cur_batch_in_all_start_index: the start id of the dim=0 + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) + cur_batch_query_len = (cur_batch_in_all_stop_index - + cur_batch_in_all_start_index) + cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len + + if SKIP_DECODE and cur_batch_query_len == 1: + return + + block_start_loc = 
BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + dim_mask = tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1) + + q = tl.load(Q + off_q, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) + + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = 0 + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = ( + bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = (bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k_load = tl.load(K_cache + off_k, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] + + if k_load.dtype.is_fp8(): + k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) + else: + k = k_load + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi, + float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v_load = tl.load(V_cache + off_v, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_ctx_len), + other=0.0) + if v_load.dtype.is_fp8(): + v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) + else: + v = v_load + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision='ieee') + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + # init alibi + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange(0, BLOCK_M) + 
block_start_loc + cur_batch_ctx_len + alibi_start_k = cur_batch_ctx_len + # # init debugger + # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc + # offset_db_k = tl.arange(0, BLOCK_N) + # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL] + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load( + k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=dim_mask[:, None] & ((start_n + offs_n[None, :]) + < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, acc=qk, input_precision='ieee') + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi, + float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load( + v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=dim_mask[None, :] & ((start_n + offs_n[:, None]) + < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision='ieee') + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + acc = acc / l_i[:, None] + + # initialize pointers to output + off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)) + return + + +@torch.inference_mode() +def context_attention_fwd(q, + k, + v, + o, + kv_cache_dtype: str, + k_cache, + v_cache, + b_loc, + b_start_loc, + b_seq_len, + max_seq_len, + max_input_len, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + alibi_slopes=None, + sliding_window=None, + sm_scale=None, + skip_decode=False): + + q_dtype_is_f32 = q.dtype is torch.float32 + + # Turing does have tensor core for float32 multiplication + # use ieee as fallback for triton kernels work. 
There is also + # warning on vllm/config.py to inform users this fallback + # implementation + IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None + + # Conversion of FP8 Tensor from uint8 storage to + # appropriate torch.dtype for interpretation by Triton + if "fp8" in kv_cache_dtype: + assert k_cache.dtype in [torch.uint8, current_platform.fp8_dtype()] + assert v_cache.dtype in [torch.uint8, current_platform.fp8_dtype()] + + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + target_dtype = current_platform.fp8_dtype() + elif kv_cache_dtype == "fp8_e5m2": + target_dtype = torch.float8_e5m2 + else: + raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype) + + k_cache = k_cache.view(target_dtype) + v_cache = v_cache.view(target_dtype) + + if (k_cache.dtype == torch.uint8 + or v_cache.dtype == torch.uint8 and kv_cache_dtype == "auto"): + raise ValueError("kv_cache_dtype='auto' unsupported for\ + FP8 KV Cache prefill kernel") + + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + # round up Lk to a power of 2 - this is required for Triton block size + Lk_padded = triton.next_power_of_2(Lk) + + if sm_scale is None: + sm_scale = 1.0 / (Lq**0.5) + batch, head = b_seq_len.shape[0], q.shape[1] + num_queries_per_kv = q.shape[1] // k.shape[1] + + assert batch + 1 == len(b_start_loc) + + # 0 means "disable" + if sliding_window is None or sliding_window <= 0: + sliding_window = 0 + + if alibi_slopes is not None: + # need to reduce num. blocks when using fp32 + # due to increased use of GPU shared memory + # if q.dtype is torch.float32: + BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK + # batch, head, + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) + _fwd_kernel_alibi[grid]( + q, + k, + v, + k_cache, + v_cache, + b_loc, + sm_scale, + k_scale, + v_scale, + b_start_loc, + b_seq_len, + alibi_slopes, + v_cache.shape[3], + k_cache.shape[4], + o, + b_loc.stride(0), + b_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + k_cache.stride( + 4), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride( + 3), #[num_blocks, num_kv_heads, head_size, block_size] + num_queries_per_kv=num_queries_per_kv, + IN_PRECISION=IN_PRECISION, + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_DMODEL_PADDED=Lk_padded, + BLOCK_N=BLOCK, + SKIP_DECODE=skip_decode, + num_warps=NUM_WARPS, + num_stages=1, + ) + return + + max_seq_len = 0 if max_seq_len is None else max_seq_len + extra_kargs = {} + if current_platform.is_rocm(): + extra_kargs = {"kpack": 2, "waves_per_eu": 2} + + grid = lambda META: (batch, head, + triton.cdiv(max_input_len, META["BLOCK_M"])) + _fwd_kernel[grid]( + q, + k, + v, + k_cache, + v_cache, + b_loc, + sm_scale, + k_scale, + v_scale, + b_start_loc, + b_seq_len, + k_cache.shape[4], + o, + b_loc.stride(0), + b_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + k_cache.stride( + 4), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + 
v_cache.stride(3), #[num_blocks, num_kv_heads, head_size, block_size] + BLOCK_SIZE=v_cache.shape[3], + num_queries_per_kv=num_queries_per_kv, + IN_PRECISION=IN_PRECISION, + BLOCK_DMODEL=Lk, + BLOCK_DMODEL_PADDED=Lk_padded, + SLIDING_WINDOW=sliding_window, + SKIP_DECODE=skip_decode, + BLOCK_M=128, + BLOCK_N=64, + num_unroll_cache=4, + num_unroll_request=1, + num_warps=4, + num_stages=1, + **extra_kargs) + return diff --git a/attention/ops/rocm_aiter_mla.py b/attention/ops/rocm_aiter_mla.py new file mode 100644 index 0000000..cce6b46 --- /dev/null +++ b/attention/ops/rocm_aiter_mla.py @@ -0,0 +1,100 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op + + +def get_aiter_mla_metadata(max_batch_size: int, block_size: int, + max_block_per_batch: int, + device: torch.device) -> tuple[torch.Tensor, ...]: + paged_kv_indices = torch.zeros(max_batch_size * max_block_per_batch, + dtype=torch.int32, + device=device) + paged_kv_indptr = torch.zeros(max_batch_size + 1, + dtype=torch.int32, + device=device) + paged_kv_last_page_lens = torch.full((max_batch_size, ), + block_size, + dtype=torch.int32) + qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int, device=device) + return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr + + +def aiter_mla_decode_fwd( + q: torch.Tensor, + kv_buffer: torch.Tensor, + o: torch.Tensor, + sm_scale: float, + qo_indptr: torch.Tensor, + max_seqlen_qo: int, + kv_indptr: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, + kv_last_page_lens: Optional[torch.Tensor] = None, + logit_cap: float = 0.0, +): + + torch.ops.vllm.rocm_aiter_mla_decode_fwd(q, + kv_buffer.view( + -1, 1, 1, q.shape[-1]), + o, + qo_indptr, + max_seqlen_qo, + kv_indptr, + kv_indices, + kv_last_page_lens, + sm_scale=sm_scale, + logit_cap=logit_cap) + + +def mla_decode_fwd_impl( + q: torch.Tensor, + kv_buffer: torch.Tensor, + o: torch.Tensor, + qo_indptr: torch.Tensor, + max_seqlen_qo: int, + kv_indptr: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, + kv_last_page_lens: Optional[torch.Tensor] = None, + sm_scale: float = 1.0, + logit_cap: float = 0.0, +) -> None: + from aiter.mla import mla_decode_fwd + + mla_decode_fwd(q, + kv_buffer.view(-1, 1, 1, q.shape[-1]), + o, + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_lens, + max_seqlen_qo, + sm_scale=sm_scale, + logit_cap=logit_cap) + + +def mla_decode_fwd_fake( + q: torch.Tensor, + kv_buffer: torch.Tensor, + o: torch.Tensor, + qo_indptr: torch.Tensor, + max_seqlen_qo: int, + kv_indptr: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, + kv_last_page_lens: Optional[torch.Tensor] = None, + sm_scale: float = 1.0, + logit_cap: float = 0.0, +) -> None: + pass + + +if current_platform.is_rocm(): + direct_register_custom_op(op_name="rocm_aiter_mla_decode_fwd", + op_func=mla_decode_fwd_impl, + mutates_args=["o"], + fake_impl=mla_decode_fwd_fake, + tags=[torch.Tag.needs_fixed_stride_order]) diff --git a/attention/ops/rocm_aiter_paged_attn.py b/attention/ops/rocm_aiter_paged_attn.py new file mode 100644 index 0000000..ad97152 --- /dev/null +++ b/attention/ops/rocm_aiter_paged_attn.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import aiter as 
rocm_aiter +import torch + +from vllm.attention.ops.paged_attn import PagedAttention +from vllm.platforms import current_platform +from vllm.utils import cdiv + +FP8_DTYPE = current_platform.fp8_dtype() + + +class AITERPagedAttention(PagedAttention): + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + ) -> None: + if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]: + PagedAttention.write_to_paged_cache(key, value, key_cache, + value_cache, slot_mapping, + kv_cache_dtype, k_scale, + v_scale) + else: + kv_cache_torch_dtype = (FP8_DTYPE + if "fp8" in kv_cache_dtype else torch.int8) + key_cache = key_cache.view(kv_cache_torch_dtype) + value_cache = value_cache.view(kv_cache_torch_dtype) + + rocm_aiter.reshape_and_cache_with_pertoken_quant( + key, value, key_cache, value_cache, k_scale, v_scale, + slot_mapping.flatten(), True) + + @staticmethod + def forward_decode( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + max_seq_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + k_scale: torch.Tensor, + v_scale: torch.Tensor, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, + ) -> torch.Tensor: + if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]: + return PagedAttention.forward_decode( + query=query, + key_cache=key_cache, + value_cache=value_cache, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=max_seq_len, + kv_cache_dtype=kv_cache_dtype, + num_kv_heads=num_kv_heads, + scale=scale, + alibi_slopes=alibi_slopes, + k_scale=k_scale, + v_scale=v_scale, + tp_rank=tp_rank, + blocksparse_local_blocks=blocksparse_local_blocks, + blocksparse_vert_stride=blocksparse_vert_stride, + blocksparse_block_size=blocksparse_block_size, + blocksparse_head_sliding_step=blocksparse_head_sliding_step) + + if "fp8" in kv_cache_dtype: + key_cache = key_cache.view(torch.float8_e4m3fnuz) + value_cache = value_cache.view(torch.float8_e4m3fnuz) + + if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1: + # use blocksparse paged attention + block_size = value_cache.size(-1) + assert (blocksparse_block_size > 0 and + blocksparse_block_size % block_size == 0), \ + (f"{blocksparse_block_size=} needs to be a multiple of" + f"{block_size=} used in block_tables.") + + output = torch.empty_like(query) + block_size = value_cache.shape[3] + max_num_blocks_per_seq = cdiv(max_seq_len, block_size) + + rocm_aiter.pa_fwd_asm(query, key_cache, value_cache, block_tables, + seq_lens, max_num_blocks_per_seq, k_scale, + v_scale, output) + return output diff --git a/attention/ops/triton_decode_attention.py b/attention/ops/triton_decode_attention.py new file mode 100644 index 0000000..38bc59b --- /dev/null +++ b/attention/ops/triton_decode_attention.py @@ -0,0 +1,685 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/decode_attention.py +# which was originally adapted from +# 
https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py +# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py + +# Changes: +# - Add support for page size >= 1. + +# Copyright 2025 vLLM Team +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Memory-efficient attention for decoding. +It supports page size >= 1. +""" + +import logging + +from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton + +# is_hip_ = current_platform.is_rocm() + +logger = logging.getLogger(__name__) + +# Only print the following warnings when triton version < 3.2.0. +# The issue won't affect performance or accuracy. +if triton.__version__ < '3.2.0': + logger.warning( + "The following error message 'operation scheduled before its operands' " + "can be ignored.") + + +@triton.jit +def tanh(x): + # Tanh is just a scaled sigmoid + return 2 * tl.sigmoid(2 * x) - 1 + + +@triton.jit +def _fwd_kernel_stage1( + Q, + K_Buffer, + V_Buffer, + sm_scale, + Req_to_tokens, + B_Seqlen, + Att_Out, + stride_req_to_tokens_b, + stride_qbs, + stride_qh, + stride_buf_kbs, + stride_buf_kh, + stride_buf_vbs, + stride_buf_vh, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + kv_group_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_N: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + PAGE_SIZE: tl.constexpr, + logit_cap: tl.constexpr, + Lk: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + split_kv_id = tl.program_id(2) + + cur_kv_head = cur_head // kv_group_num + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_req_idx = cur_batch + + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d + q = tl.load(Q + off_q, mask=mask_d, other=0.0) + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, + cur_batch_seq_len) + + e_max = -float("inf") + e_sum = 0.0 + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + if split_kv_end > split_kv_start: + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_page_number = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + + offs_n // PAGE_SIZE, + mask=offs_n < split_kv_end, + other=0, + ) + kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE + offs_buf_k = (kv_loc[:, None] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + offs_d[None, :]) + k = tl.load( + K_Buffer + offs_buf_k, + mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]), + other=0.0, + ) 
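+            # One query head vs. BLOCK_N cached keys: broadcast q across the
+            # token axis and reduce over the head dim. (The grouped kernel
+            # below batches BLOCK_H heads and uses tl.dot instead.)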
+ qk = tl.sum(q[None, :] * k, 1) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + qk = tl.where(offs_n < split_kv_end, qk, float("-inf")) + + offs_buf_v = (kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + offs_dv[None, :]) + v = tl.load( + V_Buffer + offs_buf_v, + mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), + other=0.0, + ) + + n_e_max = tl.maximum(tl.max(qk, 0), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max) + acc *= re_scale + acc += tl.sum(p[:, None] * v, 0) + + e_sum = e_sum * re_scale + tl.sum(p, 0) + e_max = n_e_max + + offs_mid_o = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + offs_dv) + + tl.store( + Att_Out + offs_mid_o, + acc / e_sum, + mask=(mask_dv), + ) + + offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + Lv) + + tl.store( + Att_Out + offs_mid_o_1, + e_max + tl.log(e_sum), + ) + + +def _decode_att_m_fwd( + q, + k_buffer, + v_buffer, + att_out, + Req_to_tokens, + B_Seqlen, + num_kv_splits, + sm_scale, + page_size, + logit_cap, +): + # BLOCK = 64 if not is_hip_ else 8 + BLOCK = 8 + + NUM_KV_SPLITS = num_kv_splits + Lk = k_buffer.shape[-1] + Lv = v_buffer.shape[-1] + + batch, head_num = q.shape[0], q.shape[1] + + grid = (batch, head_num, NUM_KV_SPLITS) + kv_group_num = q.shape[1] // k_buffer.shape[-2] + + num_warps = 4 + if kv_group_num != 1: + # num_warps = 1 if is_hip_ else 2 + num_warps = 1 + + BLOCK_DMODEL = triton.next_power_of_2(Lk) + BLOCK_DV = triton.next_power_of_2(Lv) + + _fwd_kernel_stage1[grid]( + q, + k_buffer, + v_buffer, + sm_scale, + Req_to_tokens, + B_Seqlen, + att_out, + Req_to_tokens.stride(0), + q.stride(0), + q.stride(1), + k_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + k_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + v_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + v_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + att_out.stride(0), + att_out.stride(1), + att_out.stride(2), + kv_group_num=kv_group_num, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DV=BLOCK_DV, + BLOCK_N=BLOCK, + NUM_KV_SPLITS=NUM_KV_SPLITS, + PAGE_SIZE=page_size, + logit_cap=logit_cap, + num_warps=num_warps, + num_stages=2, + Lk=Lk, + Lv=Lv, + ) + + +@triton.jit +def _fwd_grouped_kernel_stage1( + Q, + K_Buffer, + V_Buffer, + sm_scale, + Req_to_tokens, + B_Seqlen, + Att_Out, + stride_req_to_tokens_b, + stride_qbs, + stride_qh, + stride_buf_kbs, + stride_buf_kh, + stride_buf_vbs, + stride_buf_vh, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + kv_group_num: tl.constexpr, + q_head_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DPE: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_H: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + PAGE_SIZE: tl.constexpr, + logit_cap: tl.constexpr, + Lk: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head_id = tl.program_id(1) + cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H) + split_kv_id = tl.program_id(2) + + if kv_group_num > BLOCK_H: + VALID_BLOCK_H: tl.constexpr = BLOCK_H + else: + VALID_BLOCK_H: tl.constexpr = kv_group_num + cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H) + mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H + mask_h = mask_h & (cur_head < q_head_num) + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv + 
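# All BLOCK_H query heads handled by this program share cur_kv_head;
+    # mask_d / mask_dv guard head dims padded up to the next power of two.
+    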
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_req_idx = cur_batch + + offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[ + None, :] + q = tl.load(Q + offs_q, + mask=(mask_h[:, None]) & (mask_d[None, :]), + other=0.0) + + if BLOCK_DPE > 0: + offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) + mask_dpe = offs_dpe < Lk + off_qpe = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh + + offs_dpe[None, :]) + qpe = tl.load(Q + off_qpe, + mask=(mask_h[:, None]) & (mask_dpe[None, :]), + other=0.0) + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, + cur_batch_seq_len) + + e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") + e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) + acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32) + + if split_kv_end > split_kv_start: + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_page_number = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + + offs_n // PAGE_SIZE, + mask=offs_n < split_kv_end, + other=0, + ) + kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE + offs_buf_k = (kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + offs_d[:, None]) + k = tl.load( + K_Buffer + offs_buf_k, + mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), + other=0.0, + ) + qk = tl.dot(q, k.to(q.dtype)) + if BLOCK_DPE > 0: + offs_buf_kpe = (kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_dpe[:, None]) + kpe = tl.load( + K_Buffer + offs_buf_kpe, + mask=(offs_n[None, :] < split_kv_end) & + (mask_dpe[:, None]), + other=0.0, + ) + qk += tl.dot(qpe, kpe.to(qpe.dtype)) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + qk = tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end), + qk, float("-inf")) + + offs_buf_v = (kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + offs_dv[None, :]) + v = tl.load( + V_Buffer + offs_buf_v, + mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), + other=0.0, + ) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + acc *= re_scale[:, None] + acc += tl.dot(p.to(v.dtype), v) + + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max + + offs_mid_o = (cur_batch * stride_mid_ob + + cur_head[:, None] * stride_mid_oh + + split_kv_id * stride_mid_os + offs_dv[None, :]) + + tl.store( + Att_Out + offs_mid_o, + acc / e_sum[:, None], + mask=(mask_h[:, None]) & (mask_dv[None, :]), + ) + + offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + Lv) + + tl.store( + Att_Out + offs_mid_o_1, + e_max + tl.log(e_sum), + mask=mask_h, + ) + + +def _decode_grouped_att_m_fwd( + q, + k_buffer, + v_buffer, + att_out, + Req_to_tokens, + B_Seqlen, + num_kv_splits, + num_stages, + sm_scale, + page_size, + logit_cap, +): + BLOCK = 16 + Lk = k_buffer.shape[-1] + Lv = v_buffer.shape[-1] + + # [TODO] work around shmem limit on MI3xx + # if is_hip_ and Lk >= 576: + # BLOCK = 16 + + if Lk == 576: + BLOCK_DMODEL = 512 + BLOCK_DPE = 64 + elif Lk == 288: + BLOCK_DMODEL = 256 + BLOCK_DPE = 32 + else: + BLOCK_DMODEL = triton.next_power_of_2(Lk) + BLOCK_DPE = 0 + BLOCK_DV = triton.next_power_of_2(Lv) + + batch, head_num = q.shape[0], q.shape[1] + kv_group_num = q.shape[1] // k_buffer.shape[-2] + + BLOCK_H = 16 + 
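# Query heads that share a KV head are batched BLOCK_H at a time, so the
+    # head axis of the grid is cdiv(head_num, min(BLOCK_H, kv_group_num))
+    # rather than head_num.
+    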
NUM_KV_SPLITS = num_kv_splits
+    grid = (
+        batch,
+        triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
+        NUM_KV_SPLITS,
+    )
+
+    if num_stages == 1:
+        extra_kargs = {"scenario": "mla"}
+    elif num_stages == 2:
+        extra_kargs = {"scenario": "mla", "pipeline": "cpasync"}
+    else:
+        raise ValueError("num_stages should be 1 or 2")
+    # if is_hip_:
+    #     # https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference-optimization/workload.html#mi300x-triton-kernel-performance-optimization
+    #     # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
+    #     extra_kargs = {
+    #         "waves_per_eu": 1,
+    #         "matrix_instr_nonkdim": 16,
+    #         "kpack": 2
+    #     }
+    #     num_stages = 1
+
+    _fwd_grouped_kernel_stage1[grid](
+        q,
+        k_buffer,
+        v_buffer,
+        sm_scale,
+        Req_to_tokens,
+        B_Seqlen,
+        att_out,
+        Req_to_tokens.stride(0),
+        q.stride(0),
+        q.stride(1),
+        k_buffer.stride(-3),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
+        k_buffer.stride(-2),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
+        v_buffer.stride(-3),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
+        v_buffer.stride(-2),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
+        att_out.stride(0),
+        att_out.stride(1),
+        att_out.stride(2),
+        kv_group_num=kv_group_num,
+        q_head_num=head_num,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        BLOCK_DPE=BLOCK_DPE,
+        BLOCK_DV=BLOCK_DV,
+        BLOCK_N=BLOCK,
+        BLOCK_H=BLOCK_H,
+        NUM_KV_SPLITS=NUM_KV_SPLITS,
+        PAGE_SIZE=page_size,
+        logit_cap=logit_cap,
+        num_warps=4,
+        num_stages=num_stages,
+        Lk=Lk,
+        Lv=Lv,
+        **extra_kargs,
+    )
+
+
+@triton.jit
+def _fwd_kernel_stage2(
+    Mid_O,
+    o,
+    B_Seqlen,
+    stride_mid_ob,
+    stride_mid_oh,
+    stride_mid_os,
+    stride_obs,
+    stride_oh,
+    NUM_KV_SPLITS: tl.constexpr,
+    BLOCK_DV: tl.constexpr,
+    Lv: tl.constexpr,
+):
+    cur_batch = tl.program_id(0)
+    cur_head = tl.program_id(1)
+
+    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+
+    offs_d = tl.arange(0, BLOCK_DV)
+    mask_d = offs_d < Lv
+
+    e_sum = 0.0
+    e_max = -float("inf")
+    acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
+
+    offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
+    offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv
+
+    for split_kv_id in range(0, NUM_KV_SPLITS):
+        kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
+        split_kv_start = kv_len_per_split * split_kv_id
+        split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,
+                                  cur_batch_seq_len)
+
+        if split_kv_end > split_kv_start:
+            tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os,
+                         mask=mask_d,
+                         other=0.0)
+            tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
+            n_e_max = tl.maximum(tlogic, e_max)
+
+            old_scale = tl.exp(e_max - n_e_max)
+            acc *= old_scale
+            exp_logic = tl.exp(tlogic - n_e_max)
+            acc += exp_logic * tv
+
+            e_sum = e_sum * old_scale + exp_logic
+            e_max = n_e_max
+
+    tl.store(
+        o + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
+        acc / e_sum,
+        mask=mask_d,
+    )
+
+
+def _decode_softmax_reducev_fwd(
+    logits,
+    q,
+    o,
+    v_buffer,
+    b_seq_len,
+    num_kv_splits,
+):
+    batch, head_num = q.shape[0], q.shape[1]
+    Lv = v_buffer.shape[-1]
+    BLOCK_DV = triton.next_power_of_2(Lv)
+
+    NUM_KV_SPLITS = num_kv_splits
+
+    extra_kargs = {}
+    # if is_hip_:
+    #     # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
+    #     # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
+    #     extra_kargs = {
+    #         "waves_per_eu": 4,
+    #         "matrix_instr_nonkdim": 16,
+    #         "kpack": 2
+    #     }
+
+    grid = (batch, head_num)
+    
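# Stage 2 merges the NUM_KV_SPLITS partial results: each split stored a
+    # normalized partial output plus its log-sum-exp (m + log(l)) at offset
+    # Lv, and the kernel below rescales by exp(m_split - m_running) before
+    # accumulating.
+    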
_fwd_kernel_stage2[grid]( + logits, + o, + b_seq_len, + logits.stride(0), + logits.stride(1), + logits.stride(2), + o.stride(0), + o.stride(1), + NUM_KV_SPLITS=NUM_KV_SPLITS, + BLOCK_DV=BLOCK_DV, + Lv=Lv, + num_warps=4, + num_stages=2, + **extra_kargs, + ) + + +def decode_attention_fwd_normal( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size, + logit_cap=0.0, +): + _decode_att_m_fwd( + q, + k_buffer, + v_buffer, + attn_logits, + req_to_token, + b_seq_len, + num_kv_splits, + sm_scale, + page_size, + logit_cap, + ) + _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, + num_kv_splits) + + +def decode_attention_fwd_grouped( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_seq_len, + attn_logits, + num_kv_splits, + num_stages, + sm_scale, + page_size, + logit_cap=0.0, +): + _decode_grouped_att_m_fwd( + q, + k_buffer, + v_buffer, + attn_logits, + req_to_token, + b_seq_len, + num_kv_splits, + num_stages, + sm_scale, + page_size, + logit_cap, + ) + _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, + num_kv_splits) + + +def decode_attention_fwd( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_seq_len, + attn_logits, + num_kv_splits, + num_stages, + sm_scale, + page_size=1, + logit_cap=0.0, +): + assert num_kv_splits == attn_logits.shape[2] + kv_group_num = q.shape[1] // v_buffer.shape[-2] + + if kv_group_num == 1: + # MHA + decode_attention_fwd_normal( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size, + logit_cap, + ) + else: + # GQA/MQA/MLA + decode_attention_fwd_grouped( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_seq_len, + attn_logits, + num_kv_splits, + num_stages, + sm_scale, + page_size, + logit_cap, + ) diff --git a/attention/ops/triton_flash_attention.py b/attention/ops/triton_flash_attention.py new file mode 100644 index 0000000..a26e713 --- /dev/null +++ b/attention/ops/triton_flash_attention.py @@ -0,0 +1,979 @@ +#!/usr/bin/env python +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao +(https://tridao.me/publications/flash2/flash2.pdf) +Credits: OpenAI kernel team, AMD ML Frameworks Triton team + +Features supported: + +1) Fwd with causal masking +2) Any sequence lengths without padding (currently fwd kernel only) +3) Support for different sequence lengths for q and k +4) Nested tensor API currently does not support dropout or bias. 
+ +Not currently supported: + +1) Non power of two head dims + +""" + +import torch + +from vllm.platforms import current_platform +from vllm.platforms.rocm import on_gfx1x +from vllm.triton_utils import tl, triton + +torch_dtype: tl.constexpr = torch.float16 + + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@triton.jit +def max_fn(x, y): + return tl.math.max(x, y) + + +@triton.jit +def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): + ms = tl.arange(0, m) + ns = tl.arange(0, n) + return philox_offset + ms[:, None] * stride + ns[None, :] + + +@triton.jit +def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, + stride).to(tl.uint32) + # TODO: use tl.randint for better performance + return tl.rand(philox_seed, rng_offsets) + + +@triton.jit +def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, + stride) + rng_keep = rng_output > dropout_p + return rng_keep + + +@triton.jit +def load_fn(block_ptr, first, second, pad): + if first and second: + tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) + elif first: + tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad) + elif second: + tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad) + else: + tensor = tl.load(block_ptr) + return tensor + + +@triton.jit +def _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + actual_seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + bias_ptr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + OFFS_M: tl.constexpr, + OFFS_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + MASK_STEPS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, + PADDED_HEAD: tl.constexpr, + USE_FP8: tl.constexpr, + qk_scale, + p_descale, +): + # loop over k, v, and update accumulator + for start_n in range(block_min, block_max, BLOCK_N): + # For padded blocks, we will overrun the tensor size if + # we load all BLOCK_N. For others, the blocks are all within range. + k = load_fn( + K_block_ptr, + PADDED_HEAD, + MASK_STEPS and (n_extra_tokens != 0), + "zero", + ) + if PRE_LOAD_V: + v = load_fn( + V_block_ptr, + MASK_STEPS and (n_extra_tokens != 0), + PADDED_HEAD, + "zero", + ) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + # We start from end of seqlen_k so only the first iteration would need + # to be checked for padding if it is not a multiple of block_n + # TODO: This can be optimized to only be true for the padded block. + if MASK_STEPS: # noqa: SIM102 + # If this is the last block / iteration, we want to + # mask if the sequence length is not a multiple of block size + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps + # if not is_modulo_mn. last step might get wasted but that is okay. + # check if this masking works for that case. 
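+            # Only the final BLOCK_N step can run past the end of the key
+            # sequence; n_extra_tokens != 0 signals that seqlen_k is not a
+            # multiple of BLOCK_N, and the out-of-range columns are masked
+            # to -inf below.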
+ if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): + boundary_m = tl.full([BLOCK_M], + actual_seqlen_k, + dtype=tl.int32) + size_n = start_n + OFFS_N[None, :] + mask = size_n < boundary_m[:, None] + qk = tl.where(mask, qk, float("-inf")) + if IS_CAUSAL: + causal_boundary = start_n + offs_n_causal + causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] + qk = tl.where(causal_mask, qk, float("-inf")) + # -- compute qk ---- + qk += tl.dot(q, k) + if USE_FP8: + qk *= qk_scale + if bias_ptr is not None: + bias = load_fn(bias_ptr, False, MASK_STEPS + and (n_extra_tokens != 0), "zero") + # While bias is added after multiplying qk with sm_scale, our + # optimization to use 2^x instead of e^x results in an additional + # scale factor of log2(e) which we must also multiply the bias with. + qk += bias * 1.44269504089 + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + + # CAVEAT: Must update l_ij before applying dropout + l_ij = tl.sum(p, 1) + if ENABLE_DROPOUT: + philox_offset = (batch_philox_offset + + start_m * BLOCK_M * actual_seqlen_k + start_n - + BLOCK_N) + keep = dropout_mask( + philox_seed, + philox_offset, + dropout_p, + BLOCK_M, + BLOCK_N, + actual_seqlen_k, + ) + if RETURN_ENCODED_SOFTMAX: + tl.store( + encoded_softmax_block_ptr, + tl.where(keep, p, + -p).to(encoded_softmax_block_ptr.type.element_ty), + ) + p = tl.where(keep, p, 0.0) + elif RETURN_ENCODED_SOFTMAX: + tl.store( + encoded_softmax_block_ptr, + p.to(encoded_softmax_block_ptr.type.element_ty), + ) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not PRE_LOAD_V: + v = load_fn( + V_block_ptr, + MASK_STEPS and (n_extra_tokens != 0), + PADDED_HEAD, + "zero", + ) + # -- update m_i and l_i + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + + if USE_FP8: + p *= p_descale + + acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) + + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, + (0, BLOCK_N)) + return acc, l_i, m_i + + +def get_cdna_autotune_configs(): + return [ + triton.Config( + { + 'BLOCK_M': 256, + 'BLOCK_N': 64, + 'waves_per_eu': 2, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=8), + triton.Config( + { + 'BLOCK_M': 128, + 'BLOCK_N': 128, + 'waves_per_eu': 2, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=4), + triton.Config( + { + 'BLOCK_M': 256, + 'BLOCK_N': 128, + 'waves_per_eu': 2, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=8), + triton.Config( + { + 'BLOCK_M': 128, + 'BLOCK_N': 64, + 'waves_per_eu': 1, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=4), + triton.Config( + { + 'BLOCK_M': 128, + 'BLOCK_N': 64, + 'waves_per_eu': 3, + 'PRE_LOAD_V': True + }, + num_stages=1, + num_warps=4), + triton.Config( + { + 'BLOCK_M': 128, + 'BLOCK_N': 64, + 'waves_per_eu': 3, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=4), + triton.Config( + { + 'BLOCK_M': 64, + 'BLOCK_N': 64, + 'waves_per_eu': 4, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=8), + triton.Config( + { + 'BLOCK_M': 32, + 'BLOCK_N': 32, + 'waves_per_eu': 4, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=8), + # TODO: This config fails with head_size not pow2 with data mismatches. 
+ # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, + # 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + + # Fails in AccelerateAMDMatmul (Triton) assert when using FP8: + # triton.Config( + # { + # "BLOCK_M": 16, + # "BLOCK_N": 16, + # "waves_per_eu": 1, + # "PRE_LOAD_V": False, + # }, + # num_stages=1, + # num_warps=4, + # ), + ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8'] + + +def get_rdna_autotune_configs(): + return [ + triton.Config( + { + 'BLOCK_M': 32, + 'BLOCK_N': 32, + 'waves_per_eu': 4, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=2), + triton.Config( + { + 'BLOCK_M': 32, + 'BLOCK_N': 32, + 'waves_per_eu': 2, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=2), + triton.Config( + { + 'BLOCK_M': 32, + 'BLOCK_N': 16, + 'waves_per_eu': 4, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=2), + triton.Config( + { + 'BLOCK_M': 32, + 'BLOCK_N': 16, + 'waves_per_eu': 2, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=2), + # Fails in AccelerateAMDMatmul (Triton) assert when using FP8: + # triton.Config( + # { + # 'BLOCK_M': 16, + # 'BLOCK_N': 16, + # 'waves_per_eu': 4, + # 'PRE_LOAD_V': False + # }, + # num_stages=1, + # num_warps=2), + # triton.Config( + # { + # 'BLOCK_M': 16, + # 'BLOCK_N': 16, + # 'waves_per_eu': 2, + # 'PRE_LOAD_V': False + # }, + # num_stages=1, + # num_warps=2), + # # Fall-back config. + # triton.Config( + # { + # 'BLOCK_M': 16, + # 'BLOCK_N': 16, + # 'waves_per_eu': 1, + # 'PRE_LOAD_V': False + # }, + # num_stages=1, + # num_warps=2), + ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8'] + + +def get_autotune_configs(): + if on_gfx1x(): + return get_rdna_autotune_configs() + else: + return get_cdna_autotune_configs() + + +autotune_configs, autotune_keys = get_autotune_configs() + +float8_info = torch.finfo(current_platform.fp8_dtype()) + + +@triton.autotune( + configs=autotune_configs, + key=autotune_keys, +) +@triton.jit +def attn_fwd( + Q, + K, + V, + bias, + sm_scale, + q_scale, + k_scale, + v_scale, + p_scale, + p_descale, + o_descale, + L, + Out, + stride_qz: tl.int64, + stride_qh: tl.int64, + stride_qm: tl.int64, + stride_qk: tl.int64, + stride_kz: tl.int64, + stride_kh: tl.int64, + stride_kn: tl.int64, + stride_kk: tl.int64, + stride_vz: tl.int64, + stride_vh: tl.int64, + stride_vk: tl.int64, + stride_vn: tl.int64, + stride_oz: tl.int64, + stride_oh: tl.int64, + stride_om: tl.int64, + stride_on: tl.int64, + stride_bz: tl.int64, + stride_bh: tl.int64, + stride_bm: tl.int64, + stride_bn: tl.int64, + cu_seqlens_q, + cu_seqlens_k, + dropout_p, + philox_seed, + philox_offset_base, + encoded_softmax, + HQ: tl.constexpr, + HK: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, + MAX_SEQLENS_Q: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, + VARLEN: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + USE_FP8: tl.constexpr, + USE_FP8_OUT: tl.constexpr, + BLOCK_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + BIAS_TYPE: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, + FP8_MIN: tl.constexpr = float8_info.min, + FP8_MAX: tl.constexpr = float8_info.max, +): + start_m = tl.program_id(0) + off_h_q = tl.program_id(1) + off_z = tl.program_id(2) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + if VARLEN: + cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) + cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) + seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start + # We have a one-size-fits-all 
grid in id(0). Some seqlens might be too + # small for all start_m so for those we return early. + if start_m * BLOCK_M > seqlen_q: + return + cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) + cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + else: + cu_seqlens_q_start = 0 + cu_seqlens_k_start = 0 + seqlen_q = MAX_SEQLENS_Q + seqlen_k = MAX_SEQLENS_K + + # Now we compute whether we need to exit early due to causal masking. + # This is because for seqlen_q > seqlen_k, M rows of the attn scores + # are completely masked, resulting in 0s written to the output, and + # inf written to LSE. We don't need to do any GEMMs in this case. + # This block of code determines what N is, and if this WG is operating + # on those M rows. + n_blocks = cdiv_fn(seqlen_k, BLOCK_N) + if IS_CAUSAL: + # If seqlen_q == seqlen_k, the attn scores are a square matrix. + # If seqlen_q != seqlen_k, attn scores are rectangular which means + # the causal mask boundary is bottom right aligned, and ends at either + # the top edge (seqlen_q < seqlen_k) or left edge. + # This captures the decrease in n_blocks if we have a rectangular attn + # matrix + n_blocks_seqlen = cdiv_fn( + (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + # This is what adjusts the block_max for the current WG, only + # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks + n_blocks = min(n_blocks, n_blocks_seqlen) + # If we have no blocks after adjusting for seqlen deltas, this WG is + # part of the blocks that are all 0. We exit early. + if n_blocks <= 0: + o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + + off_h_q * stride_oh) + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) + # We still need to write 0s to the result + # tl.store(O_block_ptr, + # acc.to(Out.type.element_ty), boundary_check=(0,1)) + # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + # + offs_m + # We store inf to LSE, not -inf because in the bwd pass, + # we subtract this + # from qk which makes it -inf, such that exp(qk - inf) = 0 + # for these masked blocks. + # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) + # tl.store(l_ptrs, l) + # TODO: Should dropout and return encoded softmax be handled here? + return + + # If MQA / GQA, set the K and V head offsets appropriately. + GROUP_SIZE: tl.constexpr = HQ // HK + off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q + + n_extra_tokens = 0 + if seqlen_k < BLOCK_N: + n_extra_tokens = BLOCK_N - seqlen_k + elif seqlen_k % BLOCK_N: + n_extra_tokens = seqlen_k % BLOCK_N + padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL + + # Compute pointers for all the tensors used in this kernel. 
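+    # Each program instance handles one (batch, query head) pair and one
+    # BLOCK_M tile of query rows. The block pointers below view Q as
+    # (seqlen_q, head_dim), K transposed as (head_dim, seqlen_k) and V as
+    # (seqlen_k, head_dim), so tl.dot(q, k) in the inner loop directly
+    # produces a (BLOCK_M, BLOCK_N) tile of attention scores.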
+ q_offset = (off_z * stride_qz + off_h_q * stride_qh + + cu_seqlens_q_start * stride_qm) + Q_block_ptr = tl.make_block_ptr( + base=Q + q_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + k_offset = (off_z * stride_kz + off_h_k * stride_kh + + cu_seqlens_k_start * stride_kn) + K_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1), + ) + v_offset = (off_z * stride_vz + off_h_k * stride_vh + + cu_seqlens_k_start * stride_vk) + V_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), + ) + if BIAS_TYPE != 0: + bias_ptr = tl.make_block_ptr( + base=bias + off_h_q * stride_bh, + shape=(seqlen_q, seqlen_k), + strides=(stride_bm, stride_bn), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + else: + bias_ptr = None + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base \ + + (off_z * HQ + off_h_q) \ + * seqlen_q * seqlen_k + else: + batch_philox_offset = 0 + # We can ask to return the dropout mask without actually doing any dropout. + # In this case, we return an invalid pointer so indicate the mask is not i + # valid. + # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.make_block_ptr( + base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, + shape=(seqlen_q, seqlen_k), + strides=(seqlen_k, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + else: + encoded_softmax_block_ptr = 0 + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use 2^x in the loop as we do not + # have native e^x support in HW. + qk_scale = sm_scale * 1.44269504089 + # Q is loaded once at the beginning and shared by all N blocks. + q = load_fn(Q_block_ptr, True, padded_head, "zero") + if not USE_FP8: + q = (q * qk_scale).to(Q_block_ptr.type.element_ty) + acc_scale = 1.0 + else: + qk_scale *= q_scale * k_scale + acc_scale = p_scale * v_scale + + # Here we compute how many full and masked blocks we have. + padded_block_k = n_extra_tokens != 0 + is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) + if IS_CAUSAL: + # There are always at least BLOCK_M // BLOCK_N masked blocks. + # Additionally there might be one more due to dissimilar seqlens. + masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) + else: + # Padding on Q does not need to be masked in the FA loop. + masked_blocks = padded_block_k + # if IS_CAUSAL, not is_modulo_mn does not always result in an additional + # block. In this case we might exceed n_blocks so pick the min. + masked_blocks = min(masked_blocks, n_blocks) + n_full_blocks = n_blocks - masked_blocks + block_min = 0 + block_max = n_blocks * BLOCK_N + # Compute for full blocks. Here we set causal to false regardless of its + # value because there is no masking. Similarly we do not need padding. 
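+    # Illustrative split (hypothetical numbers): with IS_CAUSAL = True,
+    # BLOCK_M = 128, BLOCK_N = 64, no K padding and seqlen_q a multiple of
+    # BLOCK_M, masked_blocks = 2, so every tile except the last two runs
+    # through _attn_fwd_inner with no boundary or causal comparisons.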
+ if n_full_blocks > 0: + block_max = (n_blocks - masked_blocks) * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ + block_min, + block_max, + 0, + 0, + 0, + bias_ptr, + # IS_CAUSAL, .... + False, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, + False, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + padded_head, + USE_FP8, + qk_scale, + p_descale, + ) + block_min = block_max + block_max = n_blocks * BLOCK_N + + tl.debug_barrier() + # Remaining blocks, if any, are full / not masked. + if masked_blocks > 0: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0 + K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, + (0, n_full_blocks)) + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + bias_ptr, + IS_CAUSAL, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, + True, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + padded_head, + USE_FP8, + qk_scale, + p_descale, + ) + # epilogue + + if USE_FP8: + acc *= acc_scale + acc = acc / l_i[:, None] + if ENABLE_DROPOUT: + acc = acc / (1 - dropout_p) + # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, + # then we have one block with a row of all NaNs which come from computing + # softmax over a row of all -infs (-inf - inf = NaN). We check for that here + # and store 0s where there are NaNs as these rows should've been zeroed out. + end_m_idx = (start_m + 1) * BLOCK_M + start_m_idx = start_m * BLOCK_M + causal_start_idx = seqlen_q - seqlen_k + if USE_FP8_OUT: + acc *= o_descale + acc = tl.clamp(acc, FP8_MIN, FP8_MAX) + acc = acc.to(Out.type.element_ty) + if IS_CAUSAL: # noqa: SIM102 + if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: + out_mask_boundary = tl.full((BLOCK_DMODEL, ), + causal_start_idx, + dtype=tl.int32) + mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) + out_ptrs_mask = (mask_m_offsets[:, None] + >= out_mask_boundary[None, :]) + z = tl.zeros((1, ), tl.float32) + acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) + # write back LSE + # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last + # few rows. This is only true for the last M block. For others, + # overflow_size will be -ve + # overflow_size = end_m_idx - seqlen_q + # if overflow_size > 0: + # boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) + # # This is a > check because mask being 0 blocks the store. 
+ # l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) + # tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) + # else: + # tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + + # write back O + o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + + off_h_q * stride_oh) + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + # Need boundary check on this to make sure the padding from the + # Q and KV tensors in both dims are not part of what we store back. + # TODO: Do the boundary check optionally. + tl.store(O_block_ptr, acc, boundary_check=(0, 1)) + + +def check_args( + q, + k, + v, + o, + varlen=True, + max_seqlens=None, + cu_seqlens_q=None, + cu_seqlens_k=None, +): + assert q.dim() == k.dim() and q.dim() == v.dim() + if varlen: + assert q.dim() == 3 + total_q, nheads_q, head_size = q.shape + total_k, nheads_k, _ = k.shape + assert cu_seqlens_q is not None + assert cu_seqlens_k is not None + assert len(cu_seqlens_q) == len(cu_seqlens_k) + else: + assert q.dim() == 4 + batch, nheads_q, seqlen_q, head_size = q.shape + _, nheads_k, seqlen_k, _ = k.shape + assert max_seqlens > 0 + assert k.shape == v.shape + assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] + # TODO: Change assert if we support qkl f8 and v f16 + assert q.dtype == k.dtype and q.dtype == v.dtype + assert head_size <= 256 + assert o.shape == q.shape + assert (nheads_q % nheads_k) == 0 + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + q, + k, + v, + o, + cu_seqlens_q, + cu_seqlens_k, + max_seqlens_q, + max_seqlens_k, + causal=False, + sm_scale=1.0, + bias=None, + fp8_scales=None, + fp8_out_scale=None, + ): + if fp8_scales is not None: + use_fp8 = True + (q_scale, k_scale, v_scale, p_scale) = fp8_scales + float8 = current_platform.fp8_dtype() + + def check_and_convert(t, scale): + if t.dtype != float8: + descale = 1.0 / scale + ts = (t * descale).clamp(min=float8_info.min, + max=float8_info.max) + return ts.to(float8) + else: + return t + + q = check_and_convert(q, q_scale) + k = check_and_convert(k, k_scale) + v = check_and_convert(v, v_scale) + else: + use_fp8 = False + q_scale = k_scale = v_scale = p_scale = 1.0 + + if o is None: + o = torch.empty_like(q, dtype=v.dtype) + + check_args( + q, + k, + v, + o, + varlen=True, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + ) + if True: # varlen + total_q, nheads_q, head_size = q.shape + total_k, nheads_k, _ = k.shape + batch = len(cu_seqlens_q) - 1 + q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) + k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) + v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) + o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) + else: + batch, seqlen_q, nheads_q, head_size = q.shape + _, seqlen_k, nheads_k, _ = k.shape + q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) + k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) + v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) + o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + + # Get closest power of 2 over or equal to 32. 
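+        # Illustrative (hypothetical) example: head_size = 80 is not a
+        # supported unpadded size, so the kernel is compiled with
+        # BLOCK_DMODEL = 128 and the extra 48 lanes are masked out by the
+        # padded-head loads and the boundary-checked store inside attn_fwd.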
+ unpadded_head_dims = {32, 64, 128, 256} + if head_size not in unpadded_head_dims: + padded_d_model = None + for i in unpadded_head_dims: + if i > head_size: + padded_d_model = i + break + assert padded_d_model is not None + else: + padded_d_model = head_size + + grid = lambda META: ( + triton.cdiv(max_seqlens_q, META["BLOCK_M"]), + nheads_q, + batch, + ) + + encoded_softmax = None + + # Seed the RNG so we get reproducible results for testing. + philox_seed = 0x1BF52 + philox_offset = 0x1D4B42 + + if bias is not None: + bias_strides = ( + bias.stride(0), + bias.stride(1), + bias.stride(2), + bias.stride(3), + ) + else: + bias_strides = (0, 0, 0, 0) + + p_descale = 1.0 / p_scale + o_descale = 1.0 / fp8_out_scale.item( + ) if fp8_out_scale is not None else 1.0 + + arg_max_seqlens_q = 0 if on_gfx1x() else max_seqlens_q + arg_max_seqlens_k = 0 if on_gfx1x() else max_seqlens_k + + attn_fwd[grid]( + q, + k, + v, + bias, + sm_scale, + q_scale, + k_scale, + v_scale, + p_scale, + p_descale, + o_descale, + None, + o, + *q_strides, + *k_strides, + *v_strides, + *o_strides, + *bias_strides, + cu_seqlens_q, + cu_seqlens_k, + dropout_p=0.0, + philox_seed=philox_seed, + philox_offset_base=philox_offset, + encoded_softmax=encoded_softmax, + HQ=nheads_q, + HK=nheads_k, + ACTUAL_BLOCK_DMODEL=head_size, + MAX_SEQLENS_Q=arg_max_seqlens_q, + MAX_SEQLENS_K=arg_max_seqlens_k, + IS_CAUSAL=causal, + VARLEN=True, + BLOCK_DMODEL=padded_d_model, + BIAS_TYPE=0 if bias is None else 1, + ENABLE_DROPOUT=False, + RETURN_ENCODED_SOFTMAX=False, + USE_FP8=use_fp8, + USE_FP8_OUT=fp8_out_scale is not None, + ) + + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = head_size + ctx.causal = causal + ctx.dropout_p = 0.0 + ctx.philox_seed = philox_seed + ctx.philox_offset = philox_offset + ctx.encoded_softmax = encoded_softmax + ctx.return_encoded_softmax = False + return o, encoded_softmax + + +triton_attention = _attention.apply diff --git a/attention/ops/triton_merge_attn_states.py b/attention/ops/triton_merge_attn_states.py new file mode 100644 index 0000000..56d78ed --- /dev/null +++ b/attention/ops/triton_merge_attn_states.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch + +from vllm.triton_utils import tl, triton + + +# Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005 +# can be used to combine partial attention results (in the split-KV case) +def merge_attn_states( + output: torch.Tensor, + prefix_output: torch.Tensor, + prefix_lse: torch.Tensor, + suffix_output: torch.Tensor, + suffix_lse: torch.Tensor, + output_lse: Optional[torch.Tensor] = None, +) -> None: + num_tokens = output.shape[0] + num_query_heads = output.shape[1] + head_size = output.shape[2] + padded_head_size = triton.next_power_of_2(head_size) + + # TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead. 
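+    # One program per (token, query head). Each program combines the two
+    # partial results using softmax weights derived from their
+    # log-sum-exp values:
+    #   m   = max(prefix_lse, suffix_lse)
+    #   den = exp(prefix_lse - m) + exp(suffix_lse - m)
+    #   out = (exp(prefix_lse - m) * prefix_output
+    #          + exp(suffix_lse - m) * suffix_output) / den
+    # and, when output_lse is provided, stores log(den) + m as the merged
+    # log-sum-exp.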
+ merge_attn_states_kernel[(num_tokens, num_query_heads)]( + output, + output_lse, + prefix_output, + prefix_lse, + suffix_output, + suffix_lse, + head_size, + padded_head_size, + output_lse is not None, + ) + + +@triton.jit +def merge_attn_states_kernel( + output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + output_lse, # [NUM_HEADS, NUM_TOKENS] + prefix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + prefix_lse, # [NUM_HEADS, NUM_TOKENS] + suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + suffix_lse, # [NUM_HEADS, NUM_TOKENS] + HEAD_SIZE: tl.constexpr, + PADDED_HEAD_SIZE: tl.constexpr, + OUTPUT_LSE: tl.constexpr, +): + token_idx = tl.program_id(0) + num_tokens = tl.num_programs(0) + head_idx = tl.program_id(1) + num_heads = tl.num_programs(1) + + p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx) + s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx) + + # FA2 and FA3 have different behavior for when the sum-exp is 0, this namely + # arises with 0 len seqlens. FA3 returns -inf here while FA2 returns inf. + # If we see an inf assume FA2 and convert inf to -inf for consistency + # and correctness. Inf generally doesn't make sense in this context outside + # of undefined-behavior/FA2-case, so I think this a safe assumption. + p_lse = float('-inf') if p_lse == float('inf') else p_lse + s_lse = float('-inf') if s_lse == float('inf') else s_lse + + max_lse = tl.maximum(p_lse, s_lse) + p_lse = p_lse - max_lse + s_lse = s_lse - max_lse + # Will reuse precomputed Exp values for scale factor computation. + p_se = tl.exp(p_lse) + s_se = tl.exp(s_lse) + out_se = (p_se + s_se) + + if OUTPUT_LSE: + out_lse = tl.log(out_se) + max_lse + tl.store(output_lse + head_idx * num_tokens + token_idx, out_lse) + + head_arange = tl.arange(0, PADDED_HEAD_SIZE) + head_mask = head_arange < HEAD_SIZE + p_out = tl.load(prefix_output + token_idx * num_heads * HEAD_SIZE + + head_idx * HEAD_SIZE + head_arange, + mask=head_mask) + s_out = tl.load(suffix_output + token_idx * num_heads * HEAD_SIZE + + head_idx * HEAD_SIZE + head_arange, + mask=head_mask) + + # NOTE(woosuk): Be careful with the numerical stability. + # We should compute the scale first, and then multiply it with the output. + # Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly. 
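+    # After subtracting max_lse, at least one of p_se / s_se equals 1, so
+    # out_se >= 1 and both scale factors below lie in [0, 1]; scaling by
+    # them cannot overflow the way multiplying by tl.exp(p_lse) directly
+    # could.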
+ p_scale = p_se / out_se + s_scale = s_se / out_se + out = p_out * p_scale + s_out * s_scale + tl.store(output + token_idx * num_heads * HEAD_SIZE + + head_idx * HEAD_SIZE + head_arange, + out, + mask=head_mask) diff --git a/attention/ops/triton_unified_attention.py b/attention/ops/triton_unified_attention.py new file mode 100644 index 0000000..92c09e6 --- /dev/null +++ b/attention/ops/triton_unified_attention.py @@ -0,0 +1,334 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Authors: +# - Burkhard Ringlein +# - Jan van Lunteren +# - Chih-Chieh Yang +# - Thomas Parnell + +import triton +import triton.language as tl + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@triton.jit +def apply_softcap(S, x): + Sdiv = S / x + p1 = tl.exp(Sdiv) + p2 = tl.exp(-Sdiv) + return x * (p1 - p2) / (p1 + p2) + + +@triton.jit +def kernel_unified_attention_2d( + output_ptr, # [num_tokens, num_query_heads, head_size] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + softcap, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + USE_SOFTCAP: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.constexpr, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.constexpr, # int + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + num_seqs: tl.int32, + BLOCK_M: tl.constexpr, # int +): + + q_block_global_idx = tl.program_id(0) + kv_head_idx = tl.program_id(1) + + left: tl.int32 = 0 + right = num_seqs + while left < right: + mid = (left + right) // 2 + mid_val = tl.load(query_start_len_ptr + mid) // BLOCK_Q + mid + if mid_val <= q_block_global_idx: + left = mid + 1 + else: + right = mid + + seq_idx = left - 1 + q_block_start_idx = tl.load(query_start_len_ptr + + seq_idx) // BLOCK_Q + seq_idx + + q_block_local_idx = q_block_global_idx - q_block_start_idx + + cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) + cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) + + cur_batch_query_len = cur_batch_in_all_stop_index \ + - cur_batch_in_all_start_index + + if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: + return + + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, HEAD_SIZE_PADDED) + query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv + + query_offset_0 = cur_batch_in_all_start_index + query_pos + query_offset_1 = kv_head_idx * num_queries_per_kv + \ + offs_m % num_queries_per_kv + query_offset = 
(query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + offs_d[None, :]) + + dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) + query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) + query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1) + + # Q : (BLOCK_M, HEAD_SIZE_PADDED) + Q = tl.load( + query_ptr + query_offset, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + other=0.0, + ) + + block_table_offset = seq_idx * block_table_stride + + M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + L = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32) + + # sequence len for this particular sequence + seq_len = tl.load(seq_lens_ptr + seq_idx) + + # context length for this particular sequences + context_len = seq_len - cur_batch_query_len + + # alibi slope for this head + if USE_ALIBI_SLOPES: + alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1, + mask=query_mask_1, + other=0.0) + + num_blocks = cdiv_fn(seq_len, BLOCK_SIZE) + + # iterate through tiles + for j in range(0, num_blocks): + + physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) + + offs_n = tl.arange(0, BLOCK_SIZE) + + v_offset = (physical_block_idx * stride_v_cache_0 + + kv_head_idx * stride_v_cache_2 + + offs_d[None, :] * stride_v_cache_3 + + offs_n[:, None] * stride_v_cache_1) + + k_offset = (physical_block_idx * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + offs_d[:, None] * stride_k_cache_3 + + offs_n[None, :] * stride_k_cache_1) + + # K : (HEAD_SIZE, BLOCK_SIZE) + K_load = tl.load(key_cache_ptr + k_offset, + mask=dim_mask[:, None], + other=0.0) + + if K_load.dtype.is_fp8(): + if Q.dtype.is_fp8(): + K = K_load + else: + K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) + else: + K = K_load + + # V : (BLOCK_SIZE, HEAD_SIZE) + V_load = tl.load(value_cache_ptr + v_offset, + mask=dim_mask[None, :], + other=0.0) + + if V_load.dtype.is_fp8(): + if Q.dtype.is_fp8(): + V = V_load + else: + V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) + else: + V = V_load + + seq_offset = j * BLOCK_SIZE + offs_n + + seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 + + # S : (BLOCK_M, BLOCK_SIZE) + S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32) + + S += scale * tl.dot(Q, K) + + if USE_SOFTCAP: + S = apply_softcap(S, softcap) + + S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, + S, float("-inf")) + + if SLIDING_WINDOW > 0: + S = tl.where((context_len + query_pos[:, None] - seq_offset) + < SLIDING_WINDOW, S, float("-inf")) + + if USE_ALIBI_SLOPES: + S += alibi_slope[:, None] * (seq_offset - context_len) + + # compute running maximum + # m_j : (BLOCK_M,) + m_j = tl.maximum(M, tl.max(S, axis=1)) + # For sliding window there's a chance the max is -inf due to masking of + # the entire row. 
In this case we need to set m_j 0 to avoid NaN + m_j = tl.where(m_j > float("-inf"), m_j, 0.0) + + # P : (BLOCK_M, BLOCK_SIZE) + P = tl.exp(S - m_j[:, None]) + + # l_j : (BLOCK_M,) + l_j = tl.sum(P, axis=1) + + # alpha : (BLOCK_M, ) + alpha = tl.exp(M - m_j) + + # acc : (BLOCK_M, HEAD_SIZE_PADDED) + acc = acc * alpha[:, None] + + # update constants + L = L * alpha + l_j + M = m_j + + # acc : (BLOCK_M, HEAD_SIZE_PADDED) + acc += tl.dot(P.to(V.dtype), V) + + # epilogue + acc = acc / L[:, None] + + output_offset = (query_offset_0[:, None] * output_stride_0 + + query_offset_1[:, None] * output_stride_1 + + offs_d[None, :]) + + tl.store( + output_ptr + output_offset, + acc, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + ) + + +def unified_attention( + q, + k, + v, + out, + cu_seqlens_q, + max_seqlen_q, + seqused_k, + max_seqlen_k, + softmax_scale, + causal, + window_size, + block_table, + softcap, + q_descale, + k_descale, + v_descale, + alibi_slopes=None, +): + assert causal, "Only causal attention is supported" + assert q_descale is None, "Q scales not supported" + + block_size = v.shape[1] + assert q.element_size() >= 2 or block_size >= 32, \ + "Block size must be at least 32 for fp8" + + use_alibi_slopes = alibi_slopes is not None + + block_size = v.shape[1] + num_seqs = len(seqused_k) + num_query_heads = q.shape[1] + num_kv_heads = k.shape[2] + num_queries_per_kv = num_query_heads // num_kv_heads + head_size = q.shape[2] + + BLOCK_M = 16 + BLOCK_Q = BLOCK_M // num_queries_per_kv + + # Ideally we would launch with kernel with: + # \sum_i[ceil(query_len[i] / BLOCK_Q)] blocks. + # However, it is slow to realize the query_lens on cpu. + # Instead we use upper-bound: + # \sum_i[ceil(query_len[i] / BLOCK_Q)] + # <= \sum_i[floor(query_len[i] / BLOCK_Q) + 1] + # = \sum_i[floor(query_len[i] / BLOCK_Q)] + num_seqs + # <= floor(\sum_i(query_len[i]) / BLOCK_Q) + num_seqs + # = floor(q.shape[0] / BLOCK_Q) + num_seqs + total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs + + kernel_unified_attention_2d[( + total_num_q_blocks, + num_kv_heads, + )]( + output_ptr=out, + query_ptr=q, + key_cache_ptr=k, + value_cache_ptr=v, + block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + alibi_slopes_ptr=alibi_slopes, + scale=softmax_scale, + k_scale=k_descale, + v_scale=v_descale, + softcap=softcap, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table.stride(0), + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + output_stride_0=out.stride(0), + output_stride_1=out.stride(1), + BLOCK_SIZE=block_size, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + USE_SOFTCAP=(softcap > 0), + SLIDING_WINDOW=(1 + window_size[0]), + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + stride_k_cache_2=k.stride(2), + stride_k_cache_3=k.stride(3), + stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, + num_seqs=num_seqs, + BLOCK_M=BLOCK_M, + ) diff --git a/attention/selector.py b/attention/selector.py new file mode 100644 index 0000000..cb577fa --- /dev/null +++ b/attention/selector.py @@ -0,0 +1,187 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +from contextlib import contextmanager +from functools import cache +from typing import Generator, Optional, Type 
+ +import torch + +import vllm.envs as envs +from vllm.attention.backends.abstract import AttentionBackend +from vllm.logger import init_logger +from vllm.platforms import _Backend, current_platform +from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname + +logger = init_logger(__name__) + + +def backend_name_to_enum(backend_name: str) -> Optional[_Backend]: + """ + Convert a string backend name to a _Backend enum value. + + Returns: + * _Backend: enum value if backend_name is a valid in-tree type + * None: otherwise it's an invalid in-tree type or an out-of-tree platform is + loaded. + """ + assert backend_name is not None + return _Backend[backend_name] if backend_name in _Backend.__members__ else \ + None + + +def get_env_variable_attn_backend() -> Optional[_Backend]: + ''' + Get the backend override specified by the vLLM attention + backend environment variable, if one is specified. + + Returns: + + * _Backend enum value if an override is specified + * None otherwise + ''' + backend_name = os.environ.get(STR_BACKEND_ENV_VAR) + return (None + if backend_name is None else backend_name_to_enum(backend_name)) + + +# Global state allows a particular choice of backend +# to be forced, overriding the logic which auto-selects +# a backend based on system & workload configuration +# (default behavior if this variable is None) +# +# THIS SELECTION TAKES PRECEDENCE OVER THE +# VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE +forced_attn_backend: Optional[_Backend] = None + + +def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None: + ''' + Force all attention operations to use a specified backend. + + Passing `None` for the argument re-enables automatic + backend selection., + + Arguments: + + * attn_backend: backend selection (None to revert to auto) + ''' + global forced_attn_backend + forced_attn_backend = attn_backend + + +def get_global_forced_attn_backend() -> Optional[_Backend]: + ''' + Get the currently-forced choice of attention backend, + or None if auto-selection is currently enabled. + ''' + return forced_attn_backend + + +def get_attn_backend( + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + is_attention_free: bool, + is_blocksparse: bool = False, + use_mla: bool = False, +) -> Type[AttentionBackend]: + """Selects which attention backend to use and lazily imports it.""" + # Accessing envs.* behind an @lru_cache decorator can cause the wrong + # value to be returned from the cache if the value changes between calls. + # To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the + # private function. + return _cached_get_attn_backend( + head_size=head_size, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + block_size=block_size, + is_attention_free=is_attention_free, + is_blocksparse=is_blocksparse, + use_v1=envs.VLLM_USE_V1, + use_mla=use_mla, + ) + + +@cache +def _cached_get_attn_backend( + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + is_attention_free: bool, + is_blocksparse: bool = False, + use_v1: bool = False, + use_mla: bool = False, +) -> Type[AttentionBackend]: + if is_blocksparse: + logger.info("Using BlocksparseFlashAttention backend.") + from vllm.attention.backends.blocksparse_attn import ( + BlocksparseFlashAttentionBackend) + return BlocksparseFlashAttentionBackend + + # If there are no attention layers (e.g. 
we are running Mamba), + # use the placeholder NO_ATTENTION + if is_attention_free: + from vllm.attention.backends.placeholder_attn import ( + PlaceholderAttentionBackend) + return PlaceholderAttentionBackend + + # Check whether a particular choice of backend was + # previously forced. + # + # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND + # ENVIRONMENT VARIABLE. + selected_backend = None + backend_by_global_setting: Optional[_Backend] = ( + get_global_forced_attn_backend()) + if backend_by_global_setting is not None: + selected_backend = backend_by_global_setting + else: + # Check the environment variable and override if specified + backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND + if backend_by_env_var is not None: + selected_backend = backend_name_to_enum(backend_by_env_var) + + # get device-specific attn_backend + attention_cls = current_platform.get_attn_backend_cls( + selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, + use_mla) + if not attention_cls: + raise ValueError( + f"Invalid attention backend for {current_platform.device_name}") + return resolve_obj_by_qualname(attention_cls) + + +@contextmanager +def global_force_attn_backend_context_manager( + attn_backend: _Backend) -> Generator[None, None, None]: + ''' + Globally force a vLLM attention backend override within a + context manager, reverting the global attention backend + override to its prior state upon exiting the context + manager. + + Arguments: + + * attn_backend: attention backend to force + + Returns: + + * Generator + ''' + + # Save the current state of the global backend override (if any) + original_value = get_global_forced_attn_backend() + + # Globally force the new backend override + global_force_attn_backend(attn_backend) + + # Yield control back to the enclosed code block + try: + yield + finally: + # Revert the original global backend override, if any + global_force_attn_backend(original_value) diff --git a/attention/utils/fa_utils.py b/attention/utils/fa_utils.py new file mode 100644 index 0000000..69cde06 --- /dev/null +++ b/attention/utils/fa_utils.py @@ -0,0 +1,55 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +from vllm import envs +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]: + # import here to avoid circular dependencies + from vllm.platforms import current_platform + try: + from vllm.vllm_flash_attn.flash_attn_interface import ( + fa_version_unsupported_reason, is_fa_version_supported) + device_capability = current_platform.get_device_capability() + + assert device_capability is not None + + # 1. default version depending on platform + fa_version = 3 if (device_capability.major == 9 + and is_fa_version_supported(3)) else 2 + + # 2. override if passed by environment + if envs.VLLM_FLASH_ATTN_VERSION is not None: + assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3] + fa_version = envs.VLLM_FLASH_ATTN_VERSION + + # 3. 
fallback for unsupported combinations + if device_capability.major == 10 and fa_version == 3: + logger.warning_once( + "Cannot use FA version 3 on Blackwell platform " + "defaulting to FA version 2.") + fa_version = 2 + + if requires_alibi and fa_version == 3: + logger.warning_once("Cannot use FA version 3 with ALiBi, " + "defaulting to FA version 2.") + fa_version = 2 + + if not is_fa_version_supported(fa_version): + logger.error("Cannot use FA version %d is not supported due to %s", + fa_version, fa_version_unsupported_reason(fa_version)) + + assert is_fa_version_supported(fa_version) + return fa_version + except (ImportError, AssertionError): + return None + + +def flash_attn_supports_fp8() -> bool: + from vllm.platforms import current_platform + return get_flash_attn_version() == 3 and \ + current_platform.get_device_capability().major == 9 diff --git a/beam_search.py b/beam_search.py new file mode 100644 index 0000000..f3bc421 --- /dev/null +++ b/beam_search.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional, Union + +from vllm.lora.request import LoRARequest +from vllm.sequence import Logprob + +if TYPE_CHECKING: + from vllm.multimodal import MultiModalDataDict + + +@dataclass +class BeamSearchSequence: + """A sequence for beam search. + It keeps track of the tokens and the log probability of the sequence. + The text field is optional and will only be filled when the sequence is + about to be returned to the user. + """ + # The tokens includes the prompt. + tokens: list[int] + logprobs: list[dict[int, Logprob]] + lora_request: Optional[LoRARequest] = None + cum_logprob: float = 0.0 + text: Optional[str] = None + finish_reason: Optional[str] = None + stop_reason: Union[int, str, None] = None + multi_modal_data: Optional["MultiModalDataDict"] = None + mm_processor_kwargs: Optional[dict[str, Any]] = None + + +@dataclass +class BeamSearchOutput: + """The output of beam search. + It contains the list of the best beam search sequences. + The length of the list is equal to the beam width. + """ + sequences: list[BeamSearchSequence] + + +class BeamSearchInstance: + + def __init__( + self, + prompt_tokens: list[int], + lora_request: Optional[LoRARequest] = None, + logprobs: Optional[list[dict[int, Logprob]]] = None, + **kwargs, + ): + self.beams: list[BeamSearchSequence] = [ + BeamSearchSequence( + tokens=prompt_tokens, + logprobs=[] if logprobs is None else list(logprobs), + lora_request=lora_request, + **kwargs, + ) + ] + self.completed: list[BeamSearchSequence] = [] + + +def get_beam_search_score( + tokens: list[int], + cumulative_logprob: float, + eos_token_id: int, + length_penalty: float = 1.0, +) -> float: + """Calculate the beam search score with length penalty. 
+ + Adapted from + + https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938 + """ + seq_len = len(tokens) + if tokens[-1] == eos_token_id: + seq_len -= 1 + + return cumulative_logprob / (seq_len**length_penalty) + + +def create_sort_beams_key_function(eos_token_id: int, length_penalty: float): + + def sort_beams_key(x: BeamSearchSequence) -> float: + return get_beam_search_score(x.tokens, x.cum_logprob, eos_token_id, + length_penalty) + + return sort_beams_key diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/benchmarks/datasets.py b/benchmarks/datasets.py new file mode 100644 index 0000000..4da9f73 --- /dev/null +++ b/benchmarks/datasets.py @@ -0,0 +1,1185 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +This module defines a framework for sampling benchmark requests from various +datasets. Each dataset subclass of BenchmarkDataset must implement sample +generation. Supported dataset types include: + - ShareGPT + - Random (synthetic) + - Sonnet + - BurstGPT + - HuggingFace + - VisionArena +""" +import base64 +import io +import json +import logging +import random +from abc import ABC, abstractmethod +from collections.abc import Mapping +from dataclasses import dataclass +from functools import cache +from io import BytesIO +from typing import Any, Callable, Optional, Union + +import numpy as np +from PIL import Image +from transformers import PreTrainedTokenizerBase + +from vllm.lora.request import LoRARequest +from vllm.lora.utils import get_adapter_absolute_path +from vllm.multimodal import MultiModalDataDict +from vllm.multimodal.image import convert_image_mode +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer +from vllm.utils import PlaceholderModule + +try: + from datasets import load_dataset +except ImportError: + datasets = PlaceholderModule("datasets") + load_dataset = datasets.placeholder_attr("load_dataset") + +try: + import pandas as pd +except ImportError: + pd = PlaceholderModule("pandas") + +try: + import librosa +except ImportError: + librosa = PlaceholderModule("librosa") + +logger = logging.getLogger(__name__) + +# ----------------------------------------------------------------------------- +# Data Classes +# ----------------------------------------------------------------------------- + + +@dataclass +class SampleRequest: + """ + Represents a single inference request for benchmarking. + """ + + prompt: Union[str, Any] + prompt_len: int + expected_output_len: int + multi_modal_data: Optional[Union[MultiModalDataDict, dict]] = None + lora_request: Optional[LoRARequest] = None + + +# ----------------------------------------------------------------------------- +# Benchmark Dataset Base Class +# ----------------------------------------------------------------------------- + + +class BenchmarkDataset(ABC): + DEFAULT_SEED = 0 + IS_MULTIMODAL = False + + def __init__( + self, + dataset_path: Optional[str] = None, + random_seed: int = DEFAULT_SEED, + ) -> None: + """ + Initialize the BenchmarkDataset with an optional dataset path and random + seed. + + Args: + dataset_path (Optional[str]): Path to the dataset. If None, it + indicates that a default or random dataset might be used. + random_seed (int): Seed value for reproducible shuffling or + sampling. Defaults to DEFAULT_SEED. 
+ """ + self.dataset_path = dataset_path + # Set the random seed, ensuring that a None value is replaced with the + # default seed. + self.random_seed = (random_seed + if random_seed is not None else self.DEFAULT_SEED) + self.data = None + + def apply_multimodal_chat_transformation( + self, + prompt: str, + mm_content: Optional[MultiModalDataDict] = None) -> list[dict]: + """ + Transform a prompt and optional multimodal content into a chat format. + This method is used for chat models that expect a specific conversation + format. + """ + content = [{"text": prompt, "type": "text"}] + if mm_content is not None: + content.append(mm_content) + return [{"role": "user", "content": content}] + + def load_data(self) -> None: + """ + Load data from the dataset path into self.data. + + This method must be overridden by subclasses since the method to load + data will vary depending on the dataset format and source. + + Raises: + NotImplementedError: If a subclass does not implement this method. + """ + # TODO (jenniferzhao): add support for downloading data + raise NotImplementedError( + "load_data must be implemented in subclasses.") + + def get_random_lora_request( + self, + tokenizer: PreTrainedTokenizerBase, + max_loras: Optional[int] = None, + lora_path: Optional[str] = None, + ) -> tuple[Optional[LoRARequest], AnyTokenizer]: + """ + Optionally select a random LoRA request and return its associated + tokenizer. + + This method is used when LoRA parameters are provided. It randomly + selects a LoRA based on max_loras and retrieves a cached tokenizer for + that LoRA if available. Otherwise, it returns the base tokenizer. + + Args: + tokenizer (PreTrainedTokenizerBase): The base tokenizer to use if no + LoRA is selected. + max_loras (Optional[int]): The maximum number of LoRAs available. + If `None`, LoRA is not used. + lora_path (Optional[str]): Path to the LoRA parameters on disk. + If `None`, LoRA is not used. + + Returns: + A tuple with the following elements: + - A new [LoRARequest][] (or `None` if not applicable). + - The tokenizer associated with the LoRA request + (or the base tokenizer). + """ + if max_loras is None or lora_path is None: + return None, tokenizer + + # Generate a random LoRA ID in the range [1, max_loras]. + lora_id = random.randint(1, max_loras) + lora_request = LoRARequest( + lora_name=str(lora_id), + lora_int_id=lora_id, + lora_path=lora_path_on_disk(lora_path), + ) + if lora_id not in lora_tokenizer_cache: + lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request) + # Return lora_request and the cached tokenizer if available; otherwise, + # return the base tokenizer + return lora_request, lora_tokenizer_cache[lora_id] or tokenizer + + @abstractmethod + def sample(self, tokenizer: PreTrainedTokenizerBase, + num_requests: int) -> list[SampleRequest]: + """ + Abstract method to generate sample requests from the dataset. + + Subclasses must override this method to implement dataset-specific logic + for generating a list of SampleRequest objects. + + Args: + tokenizer (PreTrainedTokenizerBase): The tokenizer to be used + for processing the dataset's text. + num_requests (int): The number of sample requests to generate. + + Returns: + list[SampleRequest]: A list of sample requests generated from the + dataset. 
+ """ + raise NotImplementedError("sample must be implemented in subclasses.") + + def maybe_oversample_requests(self, requests: list[SampleRequest], + num_requests: int) -> None: + """ + Oversamples the list of requests if its size is less than the desired + number. + + Args: + requests (List[SampleRequest]): The current list of sampled + requests. + num_requests (int): The target number of requests. + """ + if len(requests) < num_requests: + random.seed(self.random_seed) + additional = random.choices(requests, + k=num_requests - len(requests)) + requests.extend(additional) + logger.info("Oversampled requests to reach %d total samples.", + num_requests) + + +# ----------------------------------------------------------------------------- +# Utility Functions and Global Caches +# ----------------------------------------------------------------------------- + + +def is_valid_sequence( + prompt_len: int, + output_len: int, + min_len: int = 4, + max_prompt_len: int = 1024, + max_total_len: int = 2048, + skip_min_output_len_check: bool = False, +) -> bool: + """ + Validate a sequence based on prompt and output lengths. + + Default pruning criteria are copied from the original `sample_hf_requests` + and `sample_sharegpt_requests` functions in benchmark_serving.py, as well as + from `sample_requests` in benchmark_throughput.py. + """ + # Check for invalid conditions + prompt_too_short = prompt_len < min_len + output_too_short = (not skip_min_output_len_check) and (output_len + < min_len) + prompt_too_long = prompt_len > max_prompt_len + combined_too_long = (prompt_len + output_len) > max_total_len + + # Return True if none of the invalid conditions are met + return not (prompt_too_short or output_too_short or prompt_too_long + or combined_too_long) + + +@cache +def lora_path_on_disk(lora_path: str) -> str: + return get_adapter_absolute_path(lora_path) + + +# Global cache for LoRA tokenizers. +lora_tokenizer_cache: dict[int, AnyTokenizer] = {} + + +def process_image(image: Any) -> Mapping[str, Any]: + """ + Process a single image input and return a multimedia content dictionary. + + Supports three input types: + + 1. Dictionary with raw image bytes: - Expects a dict with a 'bytes' key + containing raw image data. - Loads the bytes as a PIL.Image.Image. + + 2. PIL.Image.Image input: - Converts the image to RGB. - Saves the image as + a JPEG in memory. - Encodes the JPEG data as a base64 string. - Returns + a dictionary with the image as a base64 data URL. + + 3. String input: - Treats the string as a URL or local file path. - + Prepends "file://" if the string doesn't start with "http://" or + "file://". - Returns a dictionary with the image URL. + + Raises: + ValueError: If the input is not a supported type. + """ + if isinstance(image, dict) and 'bytes' in image: + image = Image.open(BytesIO(image['bytes'])) + if isinstance(image, Image.Image): + image = convert_image_mode(image, "RGB") + with io.BytesIO() as image_data: + image.save(image_data, format="JPEG") + image_base64 = base64.b64encode( + image_data.getvalue()).decode("utf-8") + return { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_base64}" + }, + } + + if isinstance(image, str): + image_url = (image if image.startswith( + ("http://", "file://")) else f"file://{image}") + return {"type": "image_url", "image_url": {"url": image_url}} + + raise ValueError(f"Invalid image input {image}. 
Must be a PIL.Image.Image" + " or str or dictionary with raw image bytes.") + + +# ----------------------------------------------------------------------------- +# Random Dataset Implementation (Synthetic Data) +# ----------------------------------------------------------------------------- + + +class RandomDataset(BenchmarkDataset): + # Default values copied from benchmark_serving.py for the random dataset. + DEFAULT_PREFIX_LEN = 0 + DEFAULT_RANGE_RATIO = 0.0 + DEFAULT_INPUT_LEN = 1024 + DEFAULT_OUTPUT_LEN = 128 + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + prefix_len: int = DEFAULT_PREFIX_LEN, + range_ratio: float = DEFAULT_RANGE_RATIO, + input_len: int = DEFAULT_INPUT_LEN, + output_len: int = DEFAULT_OUTPUT_LEN, + **kwargs, + ) -> list[SampleRequest]: + # Enforce range_ratio < 1 + assert range_ratio < 1.0, ( + "random_range_ratio must be < 1.0 to ensure a valid sampling range" + ) + + vocab_size = tokenizer.vocab_size + num_special_tokens = tokenizer.num_special_tokens_to_add() + real_input_len = input_len - num_special_tokens + + prefix_token_ids = (np.random.randint( + 0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else []) + + # New sampling logic: [X * (1 - b), X * (1 + b)] + input_low = int(real_input_len * (1 - range_ratio)) + input_high = int(real_input_len * (1 + range_ratio)) + output_low = int(output_len * (1 - range_ratio)) + output_high = int(output_len * (1 + range_ratio)) + + # Add logging for debugging + logger.info( + "Sampling input_len from [%s, %s] and output_len from [%s, %s]", + input_low, input_high, output_low, output_high) + + input_lens = np.random.randint(input_low, + input_high + 1, + size=num_requests) + output_lens = np.random.randint(output_low, + output_high + 1, + size=num_requests) + offsets = np.random.randint(0, vocab_size, size=num_requests) + + requests = [] + for i in range(num_requests): + inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) % + vocab_size).tolist() + token_sequence = prefix_token_ids + inner_seq + prompt = tokenizer.decode(token_sequence) + # After decoding the prompt we have to encode and decode it again. + # This is done because in some cases N consecutive tokens + # give a string tokenized into != N number of tokens. + # For example for GPT2Tokenizer: + # [6880, 6881] -> ['Ġcalls', 'here'] -> + # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere'] + # To avoid uncontrolled change of the prompt length, + # the encoded sequence is truncated before being decode again. + re_encoded_sequence = tokenizer.encode( + prompt, add_special_tokens=False)[:input_lens[i]] + prompt = tokenizer.decode(re_encoded_sequence) + total_input_len = prefix_len + int(input_lens[i]) + requests.append( + SampleRequest( + prompt=prompt, + prompt_len=total_input_len, + expected_output_len=int(output_lens[i]), + )) + return requests + + +# ----------------------------------------------------------------------------- +# ShareGPT Dataset Implementation +# ----------------------------------------------------------------------------- + + +class ShareGPTDataset(BenchmarkDataset): + """ + Implements the ShareGPT dataset. Loads data from a JSON file and generates + sample requests based on conversation turns. 
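+
+    Each entry must contain a "conversations" list with at least two turns;
+    the first turn's "value" is used as the prompt and the second turn's
+    "value" determines the default expected output length. An illustrative
+    entry:
+    ```
+    {"conversations": [{"value": "What is the capital of India?"},
+                       {"value": "New Delhi."}]}
+    ```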
+ """ + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self) -> None: + if self.dataset_path is None: + raise ValueError("dataset_path must be provided for loading data.") + + with open(self.dataset_path, encoding="utf-8") as f: + self.data = json.load(f) + # Filter entries with at least two conversation turns. + self.data = [ + entry for entry in self.data + if "conversations" in entry and len(entry["conversations"]) >= 2 + ] + random.seed(self.random_seed) + random.shuffle(self.data) + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + lora_path: Optional[str] = None, + max_loras: Optional[int] = None, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs, + ) -> list: + samples: list = [] + for entry in self.data: + if len(samples) >= num_requests: + break + prompt, completion = ( + entry["conversations"][0]["value"], + entry["conversations"][1]["value"], + ) + + lora_request, tokenizer = self.get_random_lora_request( + tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) + prompt_ids = tokenizer(prompt).input_ids + completion_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_ids) + new_output_len = (len(completion_ids) + if output_len is None else output_len) + if not is_valid_sequence(prompt_len, + new_output_len, + skip_min_output_len_check=output_len + is not None): + continue + if enable_multimodal_chat: + prompt = self.apply_multimodal_chat_transformation( + prompt, None) + samples.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=new_output_len, + lora_request=lora_request, + )) + self.maybe_oversample_requests(samples, num_requests) + return samples + + +# ----------------------------------------------------------------------------- +# Custom Dataset Implementation +# ----------------------------------------------------------------------------- + + +class CustomDataset(BenchmarkDataset): + """ + Implements the Custom dataset. Loads data from a JSONL file and generates + sample requests based on conversation turns. E.g., + ``` + {"prompt": "What is the capital of India?"} + {"prompt": "What is the capital of Iran?"} + {"prompt": "What is the capital of China?"} + ``` + """ + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self) -> None: + if self.dataset_path is None: + raise ValueError("dataset_path must be provided for loading data.") + + # self.data will be a list of dictionaries + # e.g., [{"prompt": "What is the capital of India?"}, ...] + # This will be the standardized format which load_data() + # has to convert into depending on the filetype of dataset_path. + # sample() will assume this standardized format of self.data + self.data = [] + + # Load the JSONL file + if self.dataset_path.endswith(".jsonl"): + jsonl_data = pd.read_json(path_or_buf=self.dataset_path, + lines=True) + + # check if the JSONL file has a 'prompt' column + if "prompt" not in jsonl_data.columns: + raise ValueError("JSONL file must contain a 'prompt' column.") + + # Convert each row to a dictionary and append to self.data + # This will convert the DataFrame to a list of dictionaries + # where each dictionary corresponds to a row in the DataFrame. 
+ # This is the standardized format we want for self.data + for _, row in jsonl_data.iterrows(): + self.data.append(row.to_dict()) + else: + raise NotImplementedError( + "Only JSONL format is supported for CustomDataset.") + + random.seed(self.random_seed) + random.shuffle(self.data) + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + lora_path: Optional[str] = None, + max_loras: Optional[int] = None, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + skip_chat_template: bool = False, + **kwargs, + ) -> list: + sampled_requests = [] + for item in self.data: + if len(sampled_requests) >= num_requests: + break + prompt = item["prompt"] + + # apply template + if not skip_chat_template: + prompt = tokenizer.apply_chat_template( + [{ + "role": "user", + "content": prompt + }], + add_generation_prompt=True, + tokenize=False, + ) + + prompt_len = len(tokenizer(prompt).input_ids) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + + return sampled_requests + + +# ----------------------------------------------------------------------------- +# Sonnet Dataset Implementation +# ----------------------------------------------------------------------------- + + +class SonnetDataset(BenchmarkDataset): + """ + Simplified implementation of the Sonnet dataset. Loads poem lines from a + text file and generates sample requests. Default values here copied from + `benchmark_serving.py` for the sonnet dataset. + """ + + DEFAULT_PREFIX_LEN = 200 + DEFAULT_INPUT_LEN = 550 + DEFAULT_OUTPUT_LEN = 150 + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self) -> None: + if not self.dataset_path: + raise ValueError("dataset_path must be provided.") + with open(self.dataset_path, encoding="utf-8") as f: + self.data = f.readlines() + + def sample( + self, + tokenizer, + num_requests: int, + prefix_len: int = DEFAULT_PREFIX_LEN, + input_len: int = DEFAULT_INPUT_LEN, + output_len: int = DEFAULT_OUTPUT_LEN, + return_prompt_formatted: bool = False, + **kwargs, + ) -> list: + # Calculate average token length for a poem line. + tokenized_lines = [tokenizer(line).input_ids for line in self.data] + avg_len = sum(len(tokens) + for tokens in tokenized_lines) / len(tokenized_lines) + + # Build the base prompt. + base_prompt = "Pick as many lines as you can from these poem lines:\n" + base_msg = [{"role": "user", "content": base_prompt}] + base_fmt = tokenizer.apply_chat_template(base_msg, + add_generation_prompt=True, + tokenize=False) + base_offset = len(tokenizer(base_fmt).input_ids) + if input_len <= base_offset: + raise ValueError( + f"'input_len' must be higher than the base prompt length " + f"({base_offset}).") + + # Determine how many poem lines to use. 
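+        # Illustrative arithmetic (assumed numbers): with avg_len ~ 10 tokens
+        # per line and base_offset ~ 50, input_len = 550 gives
+        # round((550 - 50) / 10) = 50 poem lines.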
+ num_input_lines = round((input_len - base_offset) / avg_len) + num_prefix_lines = max(round((prefix_len - base_offset) / avg_len), 0) + prefix_lines = self.data[:num_prefix_lines] + + samples = [] + while len(samples) < num_requests: + extra_lines = random.choices(self.data, + k=num_input_lines - num_prefix_lines) + prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}" + msg = [{"role": "user", "content": prompt}] + prompt_formatted = tokenizer.apply_chat_template( + msg, add_generation_prompt=True, tokenize=False) + prompt_len = len(tokenizer(prompt_formatted).input_ids) + if prompt_len <= input_len: + samples.append( + SampleRequest( + prompt=prompt_formatted + if return_prompt_formatted else prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + )) + return samples + + +# ----------------------------------------------------------------------------- +# BurstGPT Dataset Implementation +# ----------------------------------------------------------------------------- + + +class BurstGPTDataset(BenchmarkDataset): + """ + Implements the BurstGPT dataset. Loads data from a CSV file and generates + sample requests based on synthetic prompt generation. Only rows with Model + "GPT-4" and positive response tokens are used. + """ + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self, ): + if self.dataset_path is None: + raise ValueError("dataset_path must be provided for loading data.") + + df = pd.read_csv(self.dataset_path) + # Filter to keep only GPT-4 rows. + gpt4_df = df[df["Model"] == "GPT-4"] + # Remove failed requests (where Response tokens is 0 or less). + gpt4_df = gpt4_df[gpt4_df["Response tokens"] > 0] + # Sample the desired number of rows. + self.data = gpt4_df + + def _sample_loaded_data(self, num_requests: int) -> list: + if num_requests <= len(self.data): + data = self.data.sample(n=num_requests, + random_state=self.random_seed) + else: + data = self.data.sample( + n=num_requests, + random_state=self.random_seed, + replace=True, + ) + # Convert the dataframe to a list of lists. + return data.values.tolist() + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + max_loras: Optional[int] = None, + lora_path: Optional[str] = None, + **kwargs, + ) -> list[SampleRequest]: + samples = [] + data = self._sample_loaded_data(num_requests=num_requests) + for i in range(num_requests): + input_len = int(data[i][2]) + output_len = int(data[i][3]) + lora_req, tokenizer = self.get_random_lora_request( + tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) + vocab_size = tokenizer.vocab_size + # Generate a synthetic prompt: a list of token IDs computed as (i + + # j) modulo vocab_size. 
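+            # For example, request i = 2 with input_len = 4 yields
+            # token_ids = [2, 3, 4, 5] (each value taken modulo vocab_size).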
+ token_ids = [(i + j) % vocab_size for j in range(input_len)] + prompt = tokenizer.decode(token_ids) + samples.append( + SampleRequest( + prompt=prompt, + prompt_len=input_len, + expected_output_len=output_len, + lora_request=lora_req, + )) + return samples + + +# ----------------------------------------------------------------------------- +# HuggingFace Dataset Base Implementation +# ----------------------------------------------------------------------------- +class HuggingFaceDataset(BenchmarkDataset): + """Base class for datasets hosted on HuggingFace.""" + + SUPPORTED_DATASET_PATHS: Union[set[str], dict[str, Callable]] = set() + + def __init__( + self, + dataset_path: str, + dataset_split: str, + dataset_subset: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__(dataset_path=dataset_path, **kwargs) + + self.dataset_split = dataset_split + self.dataset_subset = dataset_subset + self.load_data() + + def load_data(self) -> None: + """Load data from HuggingFace datasets.""" + self.data = load_dataset( + self.dataset_path, + name=self.dataset_subset, + split=self.dataset_split, + streaming=True, + ) + self.data = self.data.shuffle(seed=self.random_seed) + + +# ----------------------------------------------------------------------------- +# Conversation Dataset Implementation +# ----------------------------------------------------------------------------- + + +class ConversationDataset(HuggingFaceDataset): + """Dataset for conversation data with multimodal support.""" + SUPPORTED_DATASET_PATHS = { + 'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered' + } + IS_MULTIMODAL = True + + def sample(self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs) -> list: + # Filter examples with at least 2 conversations + filtered_data = self.data.filter( + lambda x: len(x["conversations"]) >= 2) + sampled_requests = [] + dynamic_output = output_len is None + + for item in filtered_data: + if len(sampled_requests) >= num_requests: + break + conv = item["conversations"] + prompt, completion = conv[0]["value"], conv[1]["value"] + + prompt_ids = tokenizer(prompt).input_ids + completion_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_ids) + completion_len = len(completion_ids) + output_len = completion_len if dynamic_output else output_len + assert isinstance(output_len, int) and output_len > 0 + if dynamic_output and not is_valid_sequence( + prompt_len, completion_len): + continue + mm_content = process_image( + item["image"]) if "image" in item else None + if enable_multimodal_chat: + # Note: when chat is enabled the request prompt_len is no longer + # accurate and we will be using request output to count the + # actual prompt len and output len + prompt = self.apply_multimodal_chat_transformation( + prompt, mm_content) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=mm_content, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests + + +# ----------------------------------------------------------------------------- +# Vision Arena Dataset Implementation +# ----------------------------------------------------------------------------- + + +class VisionArenaDataset(HuggingFaceDataset): + """ + Vision Arena Dataset. 
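+
+    Prompts are extracted via the per-dataset parser functions registered in
+    ``SUPPORTED_DATASET_PATHS``, and the first image of each item is attached
+    as multimodal content.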
+ """ + + DEFAULT_OUTPUT_LEN = 128 + SUPPORTED_DATASET_PATHS = { + "lmarena-ai/VisionArena-Chat": + lambda x: x["conversation"][0][0]["content"], + "lmarena-ai/vision-arena-bench-v0.1": + lambda x: x["turns"][0][0]["content"] + } + IS_MULTIMODAL = True + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs, + ) -> list: + output_len = (output_len + if output_len is not None else self.DEFAULT_OUTPUT_LEN) + sampled_requests = [] + for item in self.data: + if len(sampled_requests) >= num_requests: + break + parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path) + if parser_fn is None: + raise ValueError( + f"Unsupported dataset path: {self.dataset_path}") + prompt = parser_fn(item) + mm_content = process_image(item["images"][0]) + prompt_len = len(tokenizer(prompt).input_ids) + if enable_multimodal_chat: + # Note: when chat is enabled the request prompt_len is no longer + # accurate and we will be using request output to count the + # actual prompt len + prompt = self.apply_multimodal_chat_transformation( + prompt, mm_content) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=mm_content, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests + + +# ----------------------------------------------------------------------------- +# Instruct Coder Dataset Implementation +# ----------------------------------------------------------------------------- + + +class InstructCoderDataset(HuggingFaceDataset): + """ + InstructCoder Dataset. + https://huggingface.co/datasets/likaixin/InstructCoder + + InstructCoder is the dataset designed for general code editing. It consists + of 114,239 instruction-input-output triplets, and covers multiple distinct + code editing scenario. + """ + + DEFAULT_OUTPUT_LEN = 200 # this is the average default output length + SUPPORTED_DATASET_PATHS = { + "likaixin/InstructCoder", + } + + def sample(self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs) -> list: + output_len = (output_len + if output_len is not None else self.DEFAULT_OUTPUT_LEN) + sampled_requests = [] + for item in self.data: + if len(sampled_requests) >= num_requests: + break + prompt = f"{item['input']}\n\n{item['instruction']} Just output \ + the code, do not include any explanation." + + # apply template + prompt = tokenizer.apply_chat_template( + [{ + "role": "user", + "content": prompt + }], + add_generation_prompt=True, + tokenize=False, + ) + + prompt_len = len(tokenizer(prompt).input_ids) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests + + +# ----------------------------------------------------------------------------- +# MT-Bench Dataset Implementation +# ----------------------------------------------------------------------------- + + +class MTBenchDataset(HuggingFaceDataset): + """ + MT-Bench Dataset. + https://huggingface.co/datasets/philschmid/mt-bench + + We create a single turn dataset for MT-Bench. 
+ This is similar to Spec decoding benchmark setup in vLLM + https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18 + """ # noqa: E501 + + DEFAULT_OUTPUT_LEN = 256 # avg len used in SD bench in vLLM + SUPPORTED_DATASET_PATHS = { + "philschmid/mt-bench", + } + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs, + ) -> list: + output_len = (output_len + if output_len is not None else self.DEFAULT_OUTPUT_LEN) + sampled_requests = [] + + for item in self.data: + if len(sampled_requests) >= num_requests: + break + prompt = item["turns"][0] + + # apply template + prompt = tokenizer.apply_chat_template( + [{ + "role": "user", + "content": prompt + }], + add_generation_prompt=True, + tokenize=False, + ) + + prompt_len = len(tokenizer(prompt).input_ids) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests + + +# ----------------------------------------------------------------------------- +# AIMO Dataset Implementation +# ----------------------------------------------------------------------------- + + +class AIMODataset(HuggingFaceDataset): + """ + Dataset class for processing a AIMO dataset with reasoning questions. + """ + SUPPORTED_DATASET_PATHS = { + "AI-MO/aimo-validation-aime", "AI-MO/NuminaMath-1.5", + "AI-MO/NuminaMath-CoT" + } + + def sample(self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + **kwargs) -> list: + sampled_requests = [] + dynamic_output = output_len is None + + for item in self.data: + if len(sampled_requests) >= num_requests: + break + prompt, completion = item['problem'], item["solution"] + + prompt_ids = tokenizer(prompt).input_ids + completion_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_ids) + completion_len = len(completion_ids) + output_len = completion_len if dynamic_output else output_len + assert isinstance(output_len, int) and output_len > 0 + if dynamic_output and not is_valid_sequence(prompt_len, + completion_len, + max_prompt_len=2048, + max_total_len=32000): + continue + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=None, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests + + +# ----------------------------------------------------------------------------- +# Next Edit Prediction Dataset Implementation +# ----------------------------------------------------------------------------- + + +zeta_prompt = """### Instruction: +You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location. + +### User Edits: + +{} + +### User Excerpt: + +{} + +### Response: + +""" # noqa: E501 + + +def _format_zeta_prompt( + sample: dict, + original_start_marker: str = "<|editable_region_start|>") -> dict: + """Format the zeta prompt for the Next Edit Prediction (NEP) dataset. + + This function formats examples from the NEP dataset + into prompts and expected outputs. It could be + further extended to support more NEP datasets. + + Args: + sample: The dataset sample containing events, + inputs, and outputs. 
+ original_start_marker: The marker indicating the + start of the editable region. Defaults to + "<|editable_region_start|>". + + Returns: + A dictionary with the formatted prompts and expected outputs. + """ + events = sample["events"] + input = sample["input"] + output = sample["output"] + prompt = zeta_prompt.format(events, input) + + # following the original implementation, extract the focused region + # from the raw output + output_start_index = output.find(original_start_marker) + output_focused_region = output[output_start_index:] + expected_output = output_focused_region + + return {"prompt": prompt, "expected_output": expected_output} + + +class NextEditPredictionDataset(HuggingFaceDataset): + """ + Dataset class for processing a Next Edit Prediction dataset. + """ + + SUPPORTED_DATASET_PATHS = { + "zed-industries/zeta", + } + MAPPING_PROMPT_FUNCS = { + "zed-industries/zeta": _format_zeta_prompt, + } + + def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, + **kwargs): + formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get( + self.dataset_path) + if formatting_prompt_func is None: + raise ValueError(f"Unsupported dataset path: {self.dataset_path}") + samples = [] + for sample in self.data: + sample = formatting_prompt_func(sample) + samples.append( + SampleRequest( + prompt=sample["prompt"], + prompt_len=len(tokenizer(sample["prompt"]).input_ids), + expected_output_len=len( + tokenizer(sample["expected_output"]).input_ids), + )) + if len(samples) >= num_requests: + break + self.maybe_oversample_requests(samples, num_requests) + return samples + + +# ----------------------------------------------------------------------------- +# ASR Dataset Implementation +# ----------------------------------------------------------------------------- + + +class ASRDataset(HuggingFaceDataset): + """ + Dataset class for processing a ASR dataset for transcription. + Tested on the following set: + + +----------------+----------------------------------------+--------------------------+-----------------------------+ + | Dataset | Domain | Speaking Style | hf-subset | + +----------------+----------------------------------------+--------------------------+-----------------------------+ + | TED-LIUM | TED talks | Oratory | release1, release2, release3| + | | | | release3-speaker-adaptation | + | VoxPopuli | European Parliament | Oratory | en, de, it, fr, ... | + | LibriSpeech | Audiobook | Narrated | "LIUM/tedlium" | + | GigaSpeech | Audiobook, podcast, YouTube | Narrated, spontaneous | xs, s, m, l, xl, dev, test | + | SPGISpeech | Financial meetings | Oratory, spontaneous | S, M, L, dev, test | + | AMI | Meetings | Spontaneous | ihm, sdm | + +----------------+----------------------------------------+--------------------------+-----------------------------+ + + """ # noqa: E501 + + SUPPORTED_DATASET_PATHS = { + "openslr/librispeech_asr", + "facebook/voxpopuli", + "LIUM/tedlium", + "edinburghcstr/ami", + "speechcolab/gigaspeech", + "kensho/spgispeech", + } + + DEFAULT_OUTPUT_LEN = 128 + IS_MULTIMODAL = True + + # TODO Whisper-specific. Abstract interface when more models are supported. 
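+    # The preamble below is Whisper's decoder prompt: special tokens selecting
+    # the language (English), the transcription task, and no timestamps.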
+ TRANSCRIPTION_PREAMBLE = ( + "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>") + skip_long_audios: bool = True + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + **kwargs, + ) -> list: + output_len = (output_len + if output_len is not None else self.DEFAULT_OUTPUT_LEN) + prompt = ASRDataset.TRANSCRIPTION_PREAMBLE + prompt_len = len(tokenizer(prompt).input_ids) + sampled_requests = [] + skipped = 0 + for item in self.data: + if len(sampled_requests) >= num_requests: + break + audio = item["audio"] + y, sr = audio["array"], audio["sampling_rate"] + duration_s = librosa.get_duration(y=y, sr=sr) + # Whisper max supported duration + if self.skip_long_audios and duration_s > 30: + skipped += 1 + continue + + mm_content = {"audio": (y, sr)} + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=mm_content, + )) + if skipped: + logger.warning( + "%d samples discarded from dataset due to" + " their length being greater than" + " what Whisper supports.", + skipped, + ) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests diff --git a/benchmarks/endpoint_request_func.py b/benchmarks/endpoint_request_func.py new file mode 100644 index 0000000..aba60ed --- /dev/null +++ b/benchmarks/endpoint_request_func.py @@ -0,0 +1,381 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""The request function for API endpoints.""" + +import io +import json +import os +import sys +import time +import traceback +from dataclasses import dataclass, field +from typing import Optional + +import aiohttp +from tqdm.asyncio import tqdm + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) + + +@dataclass +class RequestFuncInput: + """The input for the request function.""" + prompt: str + api_url: str + prompt_len: int + output_len: int + model: str + model_name: Optional[str] = None + logprobs: Optional[int] = None + extra_body: Optional[dict] = None + multi_modal_content: Optional[dict] = None + ignore_eos: bool = False + language: Optional[str] = None + + +@dataclass +class RequestFuncOutput: + """The output of the request function including metrics.""" + generated_text: str = "" + success: bool = False + latency: float = 0.0 + output_tokens: int = 0 + ttft: float = 0.0 # Time to first token + itl: list[float] = field( + default_factory=list) # list of inter-token latencies + tpot: float = 0.0 # avg next-token latencies + prompt_len: int = 0 + error: str = "" + + +async def async_request_openai_completions( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + """The async request function for the OpenAI Completions API. + + Args: + request_func_input: The input for the request function. + pbar: The progress bar to display the progress. + + Returns: + The output of the request function. + """ + api_url = request_func_input.api_url + assert api_url.endswith( + ("completions", "profile") + ), "OpenAI Completions API URL must end with 'completions' or 'profile'." 
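+    # Hedged usage sketch (illustrative values only): callers await this
+    # coroutine once per request, e.g.
+    #   out = await async_request_openai_completions(RequestFuncInput(
+    #       prompt="Hello", api_url="http://127.0.0.1:8000/v1/completions",
+    #       prompt_len=1, output_len=16, model="my-model"))
+    # and then read out.ttft, out.itl and out.generated_text.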
+ + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: + payload = { + "model": request_func_input.model_name \ + if request_func_input.model_name else request_func_input.model, + "prompt": request_func_input.prompt, + "temperature": 0.0, + "repetition_penalty": 1.0, + "max_tokens": request_func_input.output_len, + "logprobs": request_func_input.logprobs, + "stream": True, + "stream_options": { + "include_usage": True, + }, + } + if request_func_input.ignore_eos: + payload["ignore_eos"] = request_func_input.ignore_eos + if request_func_input.extra_body: + payload.update(request_func_input.extra_body) + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" + } + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = "" + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, json=payload, + headers=headers) as response: + if response.status == 200: + first_chunk_received = False + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = chunk_bytes.decode("utf-8").removeprefix( + "data: ") + if chunk != "[DONE]": + data = json.loads(chunk) + + # NOTE: Some completion API might have a last + # usage summary response without a token so we + # want to check a token was generated + if choices := data.get("choices"): + # Note that text could be empty here + # e.g. for special tokens + text = choices[0].get("text") + timestamp = time.perf_counter() + # First token + if not first_chunk_received: + first_chunk_received = True + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - + most_recent_timestamp) + + most_recent_timestamp = timestamp + generated_text += text or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens") + if first_chunk_received: + output.success = True + else: + output.success = False + output.error = ( + "Never received a valid chunk to calculate TTFT." 
+ "This response will be marked as failed!") + output.generated_text = generated_text + output.latency = most_recent_timestamp - st + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +async def async_request_openai_chat_completions( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert api_url.endswith(("chat/completions", "profile")), ( + "OpenAI Chat Completions API URL must end with 'chat/completions'.") + + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: + content = [{"type": "text", "text": request_func_input.prompt}] + if request_func_input.multi_modal_content: + content.append(request_func_input.multi_modal_content) + payload = { + "model": + request_func_input.model_name + if request_func_input.model_name else request_func_input.model, + "messages": [ + { + "role": "user", + "content": content + }, + ], + "temperature": + 0.0, + "max_completion_tokens": + request_func_input.output_len, + "stream": + True, + "stream_options": { + "include_usage": True, + }, + } + if request_func_input.ignore_eos: + payload["ignore_eos"] = request_func_input.ignore_eos + if request_func_input.extra_body: + payload.update(request_func_input.extra_body) + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + } + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = "" + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, json=payload, + headers=headers) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = chunk_bytes.decode("utf-8").removeprefix( + "data: ") + if chunk != "[DONE]": + timestamp = time.perf_counter() + data = json.loads(chunk) + + if choices := data.get("choices"): + content = choices[0]["delta"].get("content") + # First token + if ttft == 0.0: + ttft = timestamp - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - + most_recent_timestamp) + + generated_text += content or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens") + + most_recent_timestamp = timestamp + + output.generated_text = generated_text + output.success = True + output.latency = most_recent_timestamp - st + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +async def async_request_openai_audio( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + # Lazy import without PlaceholderModule to avoid vllm dep. + import soundfile + + api_url = request_func_input.api_url + assert api_url.endswith(("transcriptions", "translations")), ( + "OpenAI Chat Completions API URL must end with 'transcriptions' ") + "or `translations`." 
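+    # The audio payload is read from multi_modal_content["audio"], expected to
+    # be a (waveform, sample_rate) tuple (see to_bytes() below); this matches
+    # the {"audio": (y, sr)} entries produced by ASRDataset.sample().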
+ + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: + content = [{"type": "text", "text": request_func_input.prompt}] + payload = { + "model": + request_func_input.model_name + if request_func_input.model_name else request_func_input.model, + "temperature": + 0.0, + "max_completion_tokens": + request_func_input.output_len, + "stream": + True, + "language": + "en", + # Flattened due to multipart/form-data + "stream_include_usage": + True, + "stream_continuous_usage_stats": + True, + } + if request_func_input.extra_body: + payload.update(request_func_input.extra_body) + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + } + + # Send audio file + def to_bytes(y, sr): + buffer = io.BytesIO() + soundfile.write(buffer, y, sr, format="WAV") + buffer.seek(0) + return buffer + + with to_bytes(*request_func_input.multi_modal_content["audio"]) as f: + form = aiohttp.FormData() + form.add_field("file", f, content_type="audio/wav") + for key, value in payload.items(): + form.add_field(key, str(value)) + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = "" + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, + data=form, + headers=headers) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = chunk_bytes.decode("utf-8").removeprefix( + "data: ") + if chunk != "[DONE]": + timestamp = time.perf_counter() + data = json.loads(chunk) + + if choices := data.get("choices"): + content = choices[0]["delta"].get( + "content") + # First token + if ttft == 0.0: + ttft = timestamp - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append( + timestamp - most_recent_timestamp) + + generated_text += content or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens") + + most_recent_timestamp = timestamp + + output.generated_text = generated_text + output.success = True + output.latency = most_recent_timestamp - st + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +# TODO: Add more request functions for different API protocols. 
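+# Maps endpoint/backend names (the --endpoint-type choices) to the request
+# coroutine that drives them.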
+ASYNC_REQUEST_FUNCS = { + "vllm": async_request_openai_completions, + "openai": async_request_openai_completions, + "openai-chat": async_request_openai_chat_completions, + "openai-audio": async_request_openai_audio, +} + +OPENAI_COMPATIBLE_BACKENDS = [ + k for k, v in ASYNC_REQUEST_FUNCS.items() + if v in (async_request_openai_completions, + async_request_openai_chat_completions) +] diff --git a/benchmarks/latency.py b/benchmarks/latency.py new file mode 100644 index 0000000..5c6124d --- /dev/null +++ b/benchmarks/latency.py @@ -0,0 +1,168 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Benchmark the latency of processing a single batch of requests.""" + +import argparse +import dataclasses +import json +import os +import time +from typing import Any, Optional + +import numpy as np +from tqdm import tqdm + +import vllm.envs as envs +from vllm import LLM, SamplingParams +from vllm.benchmarks.utils import (convert_to_pytorch_benchmark_format, + write_to_json) +from vllm.engine.arg_utils import EngineArgs +from vllm.inputs import PromptType +from vllm.sampling_params import BeamSearchParams + + +def save_to_pytorch_benchmark_format(args: argparse.Namespace, + results: dict[str, Any]) -> None: + pt_records = convert_to_pytorch_benchmark_format( + args=args, + metrics={"latency": results["latencies"]}, + extra_info={k: results[k] + for k in ["avg_latency", "percentiles"]}) + if pt_records: + pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" + write_to_json(pt_file, pt_records) + + +def add_cli_args(parser: argparse.ArgumentParser): + parser.add_argument("--input-len", type=int, default=32) + parser.add_argument("--output-len", type=int, default=128) + parser.add_argument("--batch-size", type=int, default=8) + parser.add_argument( + "--n", + type=int, + default=1, + help="Number of generated sequences per prompt.", + ) + parser.add_argument("--use-beam-search", action="store_true") + parser.add_argument( + "--num-iters-warmup", + type=int, + default=10, + help="Number of iterations to run for warmup.", + ) + parser.add_argument("--num-iters", + type=int, + default=30, + help="Number of iterations to run.") + parser.add_argument( + "--profile", + action="store_true", + help="profile the generation process of a single batch", + ) + parser.add_argument( + "--output-json", + type=str, + default=None, + help="Path to save the latency results in JSON format.", + ) + parser.add_argument( + "--disable-detokenize", + action="store_true", + help=("Do not detokenize responses (i.e. do not include " + "detokenization time in the latency measurement)"), + ) + + parser = EngineArgs.add_cli_args(parser) + # V1 enables prefix caching by default which skews the latency + # numbers. We need to disable prefix caching by default. + parser.set_defaults(enable_prefix_caching=False) + + +def main(args: argparse.Namespace): + if args.profile and not envs.VLLM_TORCH_PROFILER_DIR: + raise OSError( + "The environment variable 'VLLM_TORCH_PROFILER_DIR' is not set. " + "Please set it to a valid path to use torch profiler.") + engine_args = EngineArgs.from_cli_args(args) + + # NOTE(woosuk): If the request cannot be processed in a single batch, + # the engine will automatically process the request in multiple batches. 
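+    # dataclasses.asdict() expands the CLI-derived EngineArgs dataclass into
+    # keyword arguments for the LLM constructor.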
+ llm = LLM(**dataclasses.asdict(engine_args)) + assert llm.llm_engine.model_config.max_model_len >= ( + args.input_len + + args.output_len), ("Please ensure that max_model_len is greater than" + " the sum of input_len and output_len.") + + sampling_params = SamplingParams( + n=args.n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=args.output_len, + detokenize=not args.disable_detokenize, + ) + dummy_prompt_token_ids = np.random.randint(10000, + size=(args.batch_size, + args.input_len)) + dummy_prompts: list[PromptType] = [{ + "prompt_token_ids": batch + } for batch in dummy_prompt_token_ids.tolist()] + + def llm_generate(): + if not args.use_beam_search: + llm.generate(dummy_prompts, + sampling_params=sampling_params, + use_tqdm=False) + else: + llm.beam_search( + dummy_prompts, + BeamSearchParams( + beam_width=args.n, + max_tokens=args.output_len, + ignore_eos=True, + ), + ) + + def run_to_completion(profile_dir: Optional[str] = None): + if profile_dir: + llm.start_profile() + llm_generate() + llm.stop_profile() + else: + start_time = time.perf_counter() + llm_generate() + end_time = time.perf_counter() + latency = end_time - start_time + return latency + + print("Warming up...") + for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"): + run_to_completion(profile_dir=None) + + if args.profile: + profile_dir = envs.VLLM_TORCH_PROFILER_DIR + print(f"Profiling (results will be saved to '{profile_dir}')...") + run_to_completion(profile_dir=profile_dir) + return + + # Benchmark. + latencies = [] + for _ in tqdm(range(args.num_iters), desc="Profiling iterations"): + latencies.append(run_to_completion(profile_dir=None)) + latencies = np.array(latencies) + percentages = [10, 25, 50, 75, 90, 99] + percentiles = np.percentile(latencies, percentages) + print(f"Avg latency: {np.mean(latencies)} seconds") + for percentage, percentile in zip(percentages, percentiles): + print(f"{percentage}% percentile latency: {percentile} seconds") + + # Output JSON results if specified + if args.output_json: + results = { + "avg_latency": np.mean(latencies), + "latencies": latencies.tolist(), + "percentiles": dict(zip(percentages, percentiles.tolist())), + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + save_to_pytorch_benchmark_format(args, results) diff --git a/benchmarks/serve.py b/benchmarks/serve.py new file mode 100644 index 0000000..019ebcf --- /dev/null +++ b/benchmarks/serve.py @@ -0,0 +1,1135 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +r"""Benchmark online serving throughput. 
+ +On the server side, run one of the following commands +to launch the vLLM OpenAI API server: + vllm serve + +On the client side, run: + vllm bench serve \ + --endpoint-type \ + --label \ + --model \ + --dataset-name \ + --request-rate \ + --num-prompts +""" +import argparse +import asyncio +import gc +import json +import os +import random +import time +import warnings +from collections.abc import AsyncGenerator, Iterable +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Optional + +import numpy as np +from tqdm.asyncio import tqdm +from transformers import PreTrainedTokenizerBase + +from vllm.benchmarks.datasets import (AIMODataset, ASRDataset, BurstGPTDataset, + ConversationDataset, HuggingFaceDataset, + InstructCoderDataset, MTBenchDataset, + NextEditPredictionDataset, RandomDataset, + SampleRequest, ShareGPTDataset, + SonnetDataset, VisionArenaDataset) +from vllm.benchmarks.endpoint_request_func import (ASYNC_REQUEST_FUNCS, + OPENAI_COMPATIBLE_BACKENDS, + RequestFuncInput, + RequestFuncOutput) +from vllm.benchmarks.utils import (convert_to_pytorch_benchmark_format, + write_to_json) +from vllm.transformers_utils.tokenizer import get_tokenizer + +MILLISECONDS_TO_SECONDS_CONVERSION = 1000 + + +@dataclass +class BenchmarkMetrics: + completed: int + total_input: int + total_output: int + request_throughput: float + request_goodput: float + output_throughput: float + total_token_throughput: float + mean_ttft_ms: float + median_ttft_ms: float + std_ttft_ms: float + percentiles_ttft_ms: list[tuple[float, float]] + mean_tpot_ms: float + median_tpot_ms: float + std_tpot_ms: float + percentiles_tpot_ms: list[tuple[float, float]] + mean_itl_ms: float + median_itl_ms: float + std_itl_ms: float + percentiles_itl_ms: list[tuple[float, float]] + # E2EL stands for end-to-end latency per request. + # It is the time taken on the client side from sending + # a request to receiving a complete response. + mean_e2el_ms: float + median_e2el_ms: float + std_e2el_ms: float + percentiles_e2el_ms: list[tuple[float, float]] + + +async def get_request( + input_requests: list[SampleRequest], + request_rate: float, + burstiness: float = 1.0, +) -> AsyncGenerator[SampleRequest, None]: + """ + Asynchronously generates requests at a specified rate + with OPTIONAL burstiness. + + Args: + input_requests: + A list of input requests, each represented as a SampleRequest. + request_rate: + The rate at which requests are generated (requests/s). + burstiness (optional): + The burstiness factor of the request generation. + Only takes effect when request_rate is not inf. + Default value is 1, which follows a Poisson process. + Otherwise, the request intervals follow a gamma distribution. + A lower burstiness value (0 < burstiness < 1) results + in more bursty requests, while a higher burstiness value + (burstiness > 1) results in a more uniform arrival of requests. + """ + input_requests: Iterable[SampleRequest] = iter(input_requests) + + # Calculate scale parameter theta to maintain the desired request_rate. + assert burstiness > 0, ( + f"A positive burstiness factor is expected, but given {burstiness}.") + theta = 1.0 / (request_rate * burstiness) + + for request in input_requests: + yield request + + if request_rate == float("inf"): + # If the request rate is infinity, then we don't need to wait. + continue + + # Sample the request interval from the gamma distribution. + # If burstiness is 1, it follows exponential distribution. 
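+        # Gamma(shape=burstiness, scale=theta) has mean burstiness * theta
+        # = 1 / request_rate, so the average arrival rate is preserved for any
+        # burstiness; only the spread of the inter-arrival gaps changes.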
+ interval = np.random.gamma(shape=burstiness, scale=theta) + # The next request will be sent after the interval. + await asyncio.sleep(interval) + + +def calculate_metrics( + input_requests: list[SampleRequest], + outputs: list[RequestFuncOutput], + dur_s: float, + tokenizer: PreTrainedTokenizerBase, + selected_percentiles: list[float], + goodput_config_dict: dict[str, float], +) -> tuple[BenchmarkMetrics, list[int]]: + """Calculate the metrics for the benchmark. + + Args: + input_requests: The input requests. + outputs: The outputs of the requests. + dur_s: The duration of the benchmark. + tokenizer: The tokenizer to use. + selected_percentiles: The percentiles to select. + goodput_config_dict: The goodput configuration. + + Returns: + A tuple of the benchmark metrics and the actual output lengths. + """ + actual_output_lens: list[int] = [] + total_input = 0 + completed = 0 + good_completed = 0 + itls: list[float] = [] + tpots: list[float] = [] + all_tpots: list[float] = [] + ttfts: list[float] = [] + e2els: list[float] = [] + for i in range(len(outputs)): + if outputs[i].success: + output_len = outputs[i].output_tokens + + if not output_len: + # We use the tokenizer to count the number of output tokens + # for some serving backends instead of looking at + # len(outputs[i].itl) since multiple output tokens may be + # bundled together + # Note : this may inflate the output token count slightly + output_len = len( + tokenizer(outputs[i].generated_text, + add_special_tokens=False).input_ids) + actual_output_lens.append(output_len) + total_input += input_requests[i].prompt_len + tpot = 0 + if output_len > 1: + latency_minus_ttft = outputs[i].latency - outputs[i].ttft + tpot = latency_minus_ttft / (output_len - 1) + tpots.append(tpot) + # Note: if output_len <= 1, we regard tpot as 0 for goodput + all_tpots.append(tpot) + itls += outputs[i].itl + ttfts.append(outputs[i].ttft) + e2els.append(outputs[i].latency) + completed += 1 + else: + actual_output_lens.append(0) + + if goodput_config_dict: + valid_metrics = [] + slo_values = [] + + if "ttft" in goodput_config_dict: + valid_metrics.append(ttfts) + slo_values.append(goodput_config_dict["ttft"] / + MILLISECONDS_TO_SECONDS_CONVERSION) + if "tpot" in goodput_config_dict: + valid_metrics.append(all_tpots) + slo_values.append(goodput_config_dict["tpot"] / + MILLISECONDS_TO_SECONDS_CONVERSION) + if "e2el" in goodput_config_dict: + valid_metrics.append(e2els) + slo_values.append(goodput_config_dict["e2el"] / + MILLISECONDS_TO_SECONDS_CONVERSION) + + for req_metric in zip(*valid_metrics): + is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)]) + if is_good_req: + good_completed += 1 + + if completed == 0: + warnings.warn( + "All requests failed. 
This is likely due to a misconfiguration " + "on the benchmark arguments.", + stacklevel=2) + metrics = BenchmarkMetrics( + completed=completed, + total_input=total_input, + total_output=sum(actual_output_lens), + request_throughput=completed / dur_s, + request_goodput=good_completed / dur_s, + output_throughput=sum(actual_output_lens) / dur_s, + total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, + mean_ttft_ms=np.mean(ttfts or 0) * + 1000, # ttfts is empty if streaming is not supported by the endpoint + std_ttft_ms=np.std(ttfts or 0) * 1000, + median_ttft_ms=np.median(ttfts or 0) * 1000, + percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) + for p in selected_percentiles], + mean_tpot_ms=np.mean(tpots or 0) * 1000, + std_tpot_ms=np.std(tpots or 0) * 1000, + median_tpot_ms=np.median(tpots or 0) * 1000, + percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) + for p in selected_percentiles], + mean_itl_ms=np.mean(itls or 0) * 1000, + std_itl_ms=np.std(itls or 0) * 1000, + median_itl_ms=np.median(itls or 0) * 1000, + percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) + for p in selected_percentiles], + mean_e2el_ms=np.mean(e2els or 0) * 1000, + std_e2el_ms=np.std(e2els or 0) * 1000, + median_e2el_ms=np.median(e2els or 0) * 1000, + percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) + for p in selected_percentiles], + ) + + return metrics, actual_output_lens + + +async def benchmark( + endpoint_type: str, + api_url: str, + base_url: str, + model_id: str, + model_name: str, + tokenizer: PreTrainedTokenizerBase, + input_requests: list[SampleRequest], + logprobs: Optional[int], + request_rate: float, + burstiness: float, + disable_tqdm: bool, + profile: bool, + selected_percentile_metrics: list[str], + selected_percentiles: list[float], + ignore_eos: bool, + goodput_config_dict: dict[str, float], + max_concurrency: Optional[int], + lora_modules: Optional[Iterable[str]], + extra_body: Optional[dict], +): + if endpoint_type in ASYNC_REQUEST_FUNCS: + request_func = ASYNC_REQUEST_FUNCS[endpoint_type] + else: + raise ValueError(f"Unknown endpoint_type: {endpoint_type}") + + print("Starting initial single prompt test run...") + test_prompt, test_prompt_len, test_output_len, test_mm_content = ( + input_requests[0].prompt, + input_requests[0].prompt_len, + input_requests[0].expected_output_len, + input_requests[0].multi_modal_data, + ) + + assert test_mm_content is None or isinstance(test_mm_content, dict) + test_input = RequestFuncInput( + model=model_id, + model_name=model_name, + prompt=test_prompt, + api_url=api_url, + prompt_len=test_prompt_len, + output_len=test_output_len, + logprobs=logprobs, + multi_modal_content=test_mm_content, + ignore_eos=ignore_eos, + extra_body=extra_body, + ) + + test_output = await request_func(request_func_input=test_input) + if not test_output.success: + raise ValueError( + "Initial test run failed - Please make sure benchmark arguments " + f"are correctly specified. Error: {test_output.error}") + else: + print("Initial test run completed. Starting main benchmark run...") + + if lora_modules: + # For each input request, choose a LoRA module at random. 
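+        # Pre-draw one module per request and wrap the list in an iterator so
+        # the dispatch loop below can pull them with next().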
+ lora_modules = iter( + [random.choice(lora_modules) for _ in range(len(input_requests))]) + + if profile: + print("Starting profiler...") + profile_input = RequestFuncInput(model=model_id, + model_name=model_name, + prompt=test_prompt, + api_url=base_url + "/start_profile", + prompt_len=test_prompt_len, + output_len=test_output_len, + logprobs=logprobs, + multi_modal_content=test_mm_content, + ignore_eos=ignore_eos, + extra_body=extra_body) + profile_output = await request_func(request_func_input=profile_input) + if profile_output.success: + print("Profiler started") + + if burstiness == 1.0: + distribution = "Poisson process" + else: + distribution = "Gamma distribution" + + print(f"Traffic request rate: {request_rate}") + print(f"Burstiness factor: {burstiness} ({distribution})") + print(f"Maximum request concurrency: {max_concurrency}") + + pbar = None if disable_tqdm else tqdm(total=len(input_requests)) + + # This can be used once the minimum Python version is 3.10 or higher, + # and it will simplify the code in limited_request_func. + # semaphore = (asyncio.Semaphore(max_concurrency) + # if max_concurrency else contextlib.nullcontext()) + semaphore = (asyncio.Semaphore(max_concurrency) + if max_concurrency else None) + + async def limited_request_func(request_func_input, pbar): + if semaphore is None: + return await request_func(request_func_input=request_func_input, + pbar=pbar) + async with semaphore: + return await request_func(request_func_input=request_func_input, + pbar=pbar) + + benchmark_start_time = time.perf_counter() + tasks: list[asyncio.Task] = [] + async for request in get_request(input_requests, request_rate, burstiness): + prompt, prompt_len, output_len, mm_content = ( + request.prompt, + request.prompt_len, + request.expected_output_len, + request.multi_modal_data, + ) + req_model_id, req_model_name = model_id, model_name + if lora_modules: + req_lora_module = next(lora_modules) + req_model_id, req_model_name = req_lora_module, req_lora_module + + request_func_input = RequestFuncInput(model=req_model_id, + model_name=req_model_name, + prompt=prompt, + api_url=api_url, + prompt_len=prompt_len, + output_len=output_len, + logprobs=logprobs, + multi_modal_content=mm_content, + ignore_eos=ignore_eos, + extra_body=extra_body) + tasks.append( + asyncio.create_task( + limited_request_func(request_func_input=request_func_input, + pbar=pbar))) + outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) + + if profile: + print("Stopping profiler...") + profile_input = RequestFuncInput( + model=model_id, + prompt=test_prompt, + api_url=base_url + "/stop_profile", + prompt_len=test_prompt_len, + output_len=test_output_len, + logprobs=logprobs, + ) + profile_output = await request_func(request_func_input=profile_input) + if profile_output.success: + print("Profiler stopped") + + if pbar is not None: + pbar.close() + + benchmark_duration = time.perf_counter() - benchmark_start_time + + metrics, actual_output_lens = calculate_metrics( + input_requests=input_requests, + outputs=outputs, + dur_s=benchmark_duration, + tokenizer=tokenizer, + selected_percentiles=selected_percentiles, + goodput_config_dict=goodput_config_dict, + ) + + print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) + print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", + benchmark_duration)) + print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) + print("{:<40} {:<10}".format("Total 
generated tokens:", + metrics.total_output)) + print("{:<40} {:<10.2f}".format("Request throughput (req/s):", + metrics.request_throughput)) + if goodput_config_dict: + print("{:<40} {:<10.2f}".format("Request goodput (req/s):", + metrics.request_goodput)) + print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", + metrics.output_throughput)) + print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", + metrics.total_token_throughput)) + + result = { + "duration": benchmark_duration, + "completed": metrics.completed, + "total_input_tokens": metrics.total_input, + "total_output_tokens": metrics.total_output, + "request_throughput": metrics.request_throughput, + "request_goodput:": + metrics.request_goodput if goodput_config_dict else None, + "output_throughput": metrics.output_throughput, + "total_token_throughput": metrics.total_token_throughput, + "input_lens": [output.prompt_len for output in outputs], + "output_lens": actual_output_lens, + "ttfts": [output.ttft for output in outputs], + "itls": [output.itl for output in outputs], + "generated_texts": [output.generated_text for output in outputs], + "errors": [output.error for output in outputs], + } + + def process_one_metric( + # E.g., "ttft" + metric_attribute_name: str, + # E.g., "TTFT" + metric_name: str, + # E.g., "Time to First Token" + metric_header: str, + ): + # This function prints and adds statistics of the specified + # metric. + if metric_attribute_name not in selected_percentile_metrics: + return + print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) + print("{:<40} {:<10.2f}".format( + f"Mean {metric_name} (ms):", + getattr(metrics, f"mean_{metric_attribute_name}_ms"))) + print("{:<40} {:<10.2f}".format( + f"Median {metric_name} (ms):", + getattr(metrics, f"median_{metric_attribute_name}_ms"))) + result[f"mean_{metric_attribute_name}_ms"] = getattr( + metrics, f"mean_{metric_attribute_name}_ms") + result[f"median_{metric_attribute_name}_ms"] = getattr( + metrics, f"median_{metric_attribute_name}_ms") + result[f"std_{metric_attribute_name}_ms"] = getattr( + metrics, f"std_{metric_attribute_name}_ms") + for p, value in getattr(metrics, + f"percentiles_{metric_attribute_name}_ms"): + p_word = str(int(p)) if int(p) == p else str(p) + print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", + value)) + result[f"p{p_word}_{metric_attribute_name}_ms"] = value + + process_one_metric("ttft", "TTFT", "Time to First Token") + process_one_metric("tpot", "TPOT", + "Time per Output Token (excl. 1st token)") + process_one_metric("itl", "ITL", "Inter-token Latency") + process_one_metric("e2el", "E2EL", "End-to-end Latency") + + print("=" * 50) + + return result + + +def check_goodput_args(args): + # Check and parse goodput arguments + goodput_config_dict = {} + VALID_NAMES = ["ttft", "tpot", "e2el"] + if args.goodput: + goodput_config_dict = parse_goodput(args.goodput) + for slo_name, slo_val in goodput_config_dict.items(): + if slo_name not in VALID_NAMES: + raise ValueError( + f"Invalid metric name found, {slo_name}: {slo_val}. " + "The service level objective name should be one of " + f"{str(VALID_NAMES)}. ") + if slo_val < 0: + raise ValueError( + f"Invalid value found, {slo_name}: {slo_val}. 
" + "The service level objective value should be " + "non-negative.") + return goodput_config_dict + + +def parse_goodput(slo_pairs): + goodput_config_dict = {} + try: + for slo_pair in slo_pairs: + slo_name, slo_val = slo_pair.split(":") + goodput_config_dict[slo_name] = float(slo_val) + except ValueError as err: + raise argparse.ArgumentTypeError( + "Invalid format found for service level objectives. " + "Specify service level objectives for goodput as \"KEY:VALUE\" " + "pairs, where the key is a metric name, and the value is a " + "number in milliseconds.") from err + return goodput_config_dict + + +def save_to_pytorch_benchmark_format(args: argparse.Namespace, + results: dict[str, Any], + file_name: str) -> None: + metrics = [ + "median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms", + "mean_tpot_ms", "median_tpot_ms", "std_tpot_ms", "p99_tpot_ms", + "median_itl_ms", "mean_itl_ms", "std_itl_ms", "p99_itl_ms" + ] + # These raw data might be useful, but they are rather big. They can be added + # later if needed + ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"] + pt_records = convert_to_pytorch_benchmark_format( + args=args, + metrics={k: [results[k]] + for k in metrics}, + extra_info={ + k: results[k] + for k in results if k not in metrics and k not in ignored_metrics + }) + if pt_records: + # Don't use json suffix here as we don't want CI to pick it up + pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json" + write_to_json(pt_file, pt_records) + + +def add_cli_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--endpoint-type", + type=str, + default="openai", + choices=list(ASYNC_REQUEST_FUNCS.keys()), + ) + parser.add_argument( + "--label", + type=str, + default=None, + help="The label (prefix) of the benchmark results. If not specified, " + "the endpoint type will be used as the label.", + ) + parser.add_argument( + "--base-url", + type=str, + default=None, + help="Server or API base url if not using http host and port.", + ) + # Use 127.0.0.1 here instead of localhost to force the use of ipv4 + parser.add_argument("--host", type=str, default="127.0.0.1") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument( + "--endpoint", + type=str, + default="/v1/completions", + help="API endpoint.", + ) + parser.add_argument( + "--dataset-name", + type=str, + default="random", + choices=["sharegpt", "burstgpt", "sonnet", "random", "hf"], + help="Name of the dataset to benchmark on.", + ) + parser.add_argument( + "--dataset-path", + type=str, + default=None, + help="Path to the sharegpt/sonnet dataset. " + "Or the huggingface dataset ID if using HF dataset.", + ) + parser.add_argument( + "--max-concurrency", + type=int, + default=None, + help="Maximum number of concurrent requests. This can be used " + "to help simulate an environment where a higher level component " + "is enforcing a maximum number of concurrent requests. While the " + "--request-rate argument controls the rate at which requests are " + "initiated, this argument will control how many are actually allowed " + "to execute at a time. 
This means that when used in combination, the " + "actual request rate may be lower than specified with --request-rate, " + "if the server is not processing requests fast enough to keep up.") + + parser.add_argument( + "--model", + type=str, + required=True, + help="Name of the model.", + ) + parser.add_argument( + "--tokenizer", + type=str, + help= + "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 + ) + parser.add_argument("--use-beam-search", action="store_true") + parser.add_argument( + "--num-prompts", + type=int, + default=1000, + help="Number of prompts to process.", + ) + parser.add_argument( + "--logprobs", + type=int, + default=None, + help=("Number of logprobs-per-token to compute & return as part of " + "the request. If unspecified, then either (1) if beam search " + "is disabled, no logprobs are computed & a single dummy " + "logprob is returned for each token; or (2) if beam search " + "is enabled 1 logprob per token is computed"), + ) + parser.add_argument( + "--request-rate", + type=float, + default=float("inf"), + help="Number of requests per second. If this is inf, " + "then all the requests are sent at time 0. " + "Otherwise, we use Poisson process or gamma distribution " + "to synthesize the request arrival times.", + ) + parser.add_argument( + "--burstiness", + type=float, + default=1.0, + help="Burstiness factor of the request generation. " + "Only take effect when request_rate is not inf. " + "Default value is 1, which follows Poisson process. " + "Otherwise, the request intervals follow a gamma distribution. " + "A lower burstiness value (0 < burstiness < 1) results in more " + "bursty requests. A higher burstiness value (burstiness > 1) " + "results in a more uniform arrival of requests.", + ) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--trust-remote-code", + action="store_true", + help="Trust remote code from huggingface", + ) + parser.add_argument( + "--disable-tqdm", + action="store_true", + help="Specify to disable tqdm progress bar.", + ) + parser.add_argument( + "--profile", + action="store_true", + help="Use Torch Profiler. The endpoint must be launched with " + "VLLM_TORCH_PROFILER_DIR to enable profiler.", + ) + parser.add_argument( + "--save-result", + action="store_true", + help="Specify to save benchmark results to a json file", + ) + parser.add_argument( + "--save-detailed", + action="store_true", + help="When saving the results, whether to include per request " + "information such as response, error, ttfs, tpots, etc.", + ) + parser.add_argument( + "--append-result", + action="store_true", + help="Append the benchmark result to the existing json file.", + ) + parser.add_argument( + "--metadata", + metavar="KEY=VALUE", + nargs="*", + help="Key-value pairs (e.g, --metadata version=0.3.3 tp=1) " + "for metadata of this run to be saved in the result JSON file " + "for record keeping purposes.", + ) + parser.add_argument( + "--result-dir", + type=str, + default=None, + help="Specify directory to save benchmark json results." + "If not specified, results are saved in the current directory.", + ) + parser.add_argument( + "--result-filename", + type=str, + default=None, + help="Specify the filename to save benchmark json results." 
+ "If not specified, results will be saved in " + "{label}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" # noqa + " format.", + ) + parser.add_argument( + "--ignore-eos", + action="store_true", + help="Set ignore_eos flag when sending the benchmark request." + "Warning: ignore_eos is not supported in deepspeed_mii and tgi.") + parser.add_argument( + "--percentile-metrics", + type=str, + default="ttft,tpot,itl", + help="Comma-separated list of selected metrics to report percentils. " + "This argument specifies the metrics to report percentiles. " + "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". ") + parser.add_argument( + "--metric-percentiles", + type=str, + default="99", + help="Comma-separated list of percentiles for selected metrics. " + "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". " + "Default value is \"99\"." + "Use \"--percentile-metrics\" to select metrics.", + ) + parser.add_argument( + "--goodput", + nargs="+", + required=False, + help="Specify service level objectives for goodput as \"KEY:VALUE\" " + "pairs, where the key is a metric name, and the value is in " + "milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, " + "separated by spaces. Allowed request level metric names are " + "\"ttft\", \"tpot\", \"e2el\". For more context on the definition of " + "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " + "and the blog: https://hao-ai-lab.github.io/blogs/distserve", + ) + + # group for dataset specific arguments + sonnet_group = parser.add_argument_group("sonnet dataset options") + sonnet_group.add_argument( + "--sonnet-input-len", + type=int, + default=550, + help= + "Number of input tokens per request, used only for sonnet dataset.", + ) + sonnet_group.add_argument( + "--sonnet-output-len", + type=int, + default=150, + help= + "Number of output tokens per request, used only for sonnet dataset.", + ) + sonnet_group.add_argument( + "--sonnet-prefix-len", + type=int, + default=200, + help= + "Number of prefix tokens per request, used only for sonnet dataset.", + ) + + sharegpt_group = parser.add_argument_group("sharegpt dataset options") + sharegpt_group.add_argument( + "--sharegpt-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output length " + "from the ShareGPT dataset.", + ) + + random_group = parser.add_argument_group("random dataset options") + random_group.add_argument( + "--random-input-len", + type=int, + default=1024, + help= + "Number of input tokens per request, used only for random sampling.", + ) + random_group.add_argument( + "--random-output-len", + type=int, + default=128, + help= + "Number of output tokens per request, used only for random sampling.", + ) + random_group.add_argument( + "--random-range-ratio", + type=float, + default=0.0, + help="Range ratio for sampling input/output length, " + "used only for random sampling. Must be in the range [0, 1) to define " + "a symmetric sampling range" + "[length * (1 - range_ratio), length * (1 + range_ratio)].", + ) + random_group.add_argument( + "--random-prefix-len", + type=int, + default=0, + help="Number of fixed prefix tokens before random " + " context. 
The length range of context in a random " + " request is [random-prefix-len, " + " random-prefix-len + random-prefix-len * random-range-ratio).") + + hf_group = parser.add_argument_group("hf dataset options") + hf_group.add_argument("--hf-subset", + type=str, + default=None, + help="Subset of the HF dataset.") + hf_group.add_argument("--hf-split", + type=str, + default=None, + help="Split of the HF dataset.") + hf_group.add_argument( + "--hf-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output lengths " + "from the sampled HF dataset.", + ) + + sampling_group = parser.add_argument_group("sampling parameters") + sampling_group.add_argument( + "--top-p", + type=float, + default=None, + help="Top-p sampling parameter. Only has effect on " + "openai-compatible backends.", + ) + sampling_group.add_argument( + "--top-k", + type=int, + default=None, + help="Top-k sampling parameter. Only has effect on " + "openai-compatible backends.", + ) + sampling_group.add_argument( + "--min-p", + type=float, + default=None, + help="Min-p sampling parameter. Only has effect on " + "openai-compatible backends.", + ) + sampling_group.add_argument( + "--temperature", + type=float, + default=None, + help="Temperature sampling parameter. Only has effect on " + "openai-compatible backends. If not specified, default to greedy " + "decoding (i.e. temperature==0.0).", + ) + + parser.add_argument( + '--tokenizer-mode', + type=str, + default="auto", + choices=['auto', 'slow', 'mistral', 'custom'], + help='The tokenizer mode.\n\n* "auto" will use the ' + 'fast tokenizer if available.\n* "slow" will ' + 'always use the slow tokenizer. \n* ' + '"mistral" will always use the `mistral_common` tokenizer. \n*' + '"custom" will use --tokenizer to select the preregistered tokenizer.') + + parser.add_argument("--served-model-name", + type=str, + default=None, + help="The model name used in the API. " + "If not specified, the model name will be the " + "same as the ``--model`` argument. ") + + parser.add_argument("--lora-modules", + nargs='+', + default=None, + help="A subset of LoRA module names passed in when " + "launching the server. For each request, the " + "script chooses a LoRA module at random.") + + +def main(args: argparse.Namespace): + print(args) + random.seed(args.seed) + np.random.seed(args.seed) + + endpoint_type = args.endpoint_type + label = args.label + model_id = args.model + model_name = args.served_model_name + tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model + tokenizer_mode = args.tokenizer_mode + + if args.base_url is not None: + api_url = f"{args.base_url}{args.endpoint}" + base_url = f"{args.base_url}" + else: + api_url = f"http://{args.host}:{args.port}{args.endpoint}" + base_url = f"http://{args.host}:{args.port}" + + tokenizer = get_tokenizer(tokenizer_id, + tokenizer_mode=tokenizer_mode, + trust_remote_code=args.trust_remote_code) + + if args.dataset_name is None: + raise ValueError( + "Please specify '--dataset-name' and the corresponding " + "'--dataset-path' if required.") + + if args.dataset_name == "sonnet": + dataset = SonnetDataset(dataset_path=args.dataset_path) + # For the "sonnet" dataset, formatting depends on the backend. 
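+        # Chat-style endpoints take the raw prompt and let the server apply its
+        # chat template (return_prompt_formatted=False); completion-style
+        # endpoints need the prompt formatted locally, so the tokenizer must
+        # provide a chat template.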
+ if args.backend == "openai-chat": + input_requests = dataset.sample( + num_requests=args.num_prompts, + input_len=args.sonnet_input_len, + output_len=args.sonnet_output_len, + prefix_len=args.sonnet_prefix_len, + tokenizer=tokenizer, + return_prompt_formatted=False, + ) + else: + assert tokenizer.chat_template or tokenizer.default_chat_template, ( + "Tokenizer/model must have chat template for sonnet dataset.") + input_requests = dataset.sample( + num_requests=args.num_prompts, + input_len=args.sonnet_input_len, + output_len=args.sonnet_output_len, + prefix_len=args.sonnet_prefix_len, + tokenizer=tokenizer, + return_prompt_formatted=True, + ) + + elif args.dataset_name == "hf": + # all following datasets are implemented from the + # HuggingFaceDataset base class + if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS: + dataset_class = VisionArenaDataset + args.hf_split = "train" + args.hf_subset = None + elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS: + dataset_class = InstructCoderDataset + args.hf_split = "train" + elif args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS: + dataset_class = MTBenchDataset + args.hf_split = "train" + elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS: + dataset_class = ConversationDataset + args.hf_split = "train" + elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: + dataset_class = AIMODataset + args.hf_split = "train" + elif args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS: # noqa: E501 + dataset_class = NextEditPredictionDataset + args.hf_split = "train" + elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS: + dataset_class = ASRDataset + args.hf_split = "train" + else: + supported_datasets = set([ + dataset_name for cls in HuggingFaceDataset.__subclasses__() + for dataset_name in cls.SUPPORTED_DATASET_PATHS + ]) + raise ValueError( + f"Unsupported dataset path: {args.dataset_path}. " + "Huggingface dataset only supports dataset_path" + f" from one of following: {supported_datasets}. " + "Please consider contributing if you would " + "like to add support for additional dataset formats.") + + if dataset_class.IS_MULTIMODAL and endpoint_type not in [ + "openai-chat", + "openai-audio", + ]: + # multi-modal benchmark is only available on OpenAI Chat backend. + raise ValueError( + "Multi-modal content is only supported on 'openai-chat' and " + "'openai-audio' backend.") + input_requests = dataset_class( + dataset_path=args.dataset_path, + dataset_subset=args.hf_subset, + dataset_split=args.hf_split, + random_seed=args.seed, + ).sample( + num_requests=args.num_prompts, + tokenizer=tokenizer, + output_len=args.hf_output_len, + ) + + else: + # For datasets that follow a similar structure, use a mapping. + dataset_mapping = { + "sharegpt": + lambda: ShareGPTDataset(random_seed=args.seed, + dataset_path=args.dataset_path).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + output_len=args.sharegpt_output_len, + ), + "burstgpt": + lambda: BurstGPTDataset(random_seed=args.seed, + dataset_path=args.dataset_path). 
+ sample(tokenizer=tokenizer, num_requests=args.num_prompts), + "random": + lambda: RandomDataset(dataset_path=args.dataset_path).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + prefix_len=args.random_prefix_len, + input_len=args.random_input_len, + output_len=args.random_output_len, + range_ratio=args.random_range_ratio, + ), + } + + try: + input_requests = dataset_mapping[args.dataset_name]() + except KeyError as err: + raise ValueError(f"Unknown dataset: {args.dataset_name}") from err + goodput_config_dict = check_goodput_args(args) + + # Collect the sampling parameters. + sampling_params = { + k: v + for k, v in { + "top_p": args.top_p, + "top_k": args.top_k, + "min_p": args.min_p, + "temperature": args.temperature, + }.items() if v is not None + } + + # Sampling parameters are only supported by openai-compatible backend. + if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS: + raise ValueError("Sampling parameters are only supported by " + "openai-compatible backends.") + + if "temperature" not in sampling_params: + sampling_params["temperature"] = 0.0 # Default to greedy decoding. + + # Avoid GC processing "static" data - reduce pause times. + gc.collect() + gc.freeze() + + benchmark_result = asyncio.run( + benchmark( + endpoint_type=endpoint_type, + api_url=api_url, + base_url=base_url, + model_id=model_id, + model_name=model_name, + tokenizer=tokenizer, + input_requests=input_requests, + logprobs=args.logprobs, + request_rate=args.request_rate, + burstiness=args.burstiness, + disable_tqdm=args.disable_tqdm, + profile=args.profile, + selected_percentile_metrics=args.percentile_metrics.split(","), + selected_percentiles=[ + float(p) for p in args.metric_percentiles.split(",") + ], + ignore_eos=args.ignore_eos, + goodput_config_dict=goodput_config_dict, + max_concurrency=args.max_concurrency, + lora_modules=args.lora_modules, + extra_body=sampling_params, + )) + + # Save config and results to json + if args.save_result or args.append_result: + result_json: dict[str, Any] = {} + + # Setup + current_dt = datetime.now().strftime("%Y%m%d-%H%M%S") + result_json["date"] = current_dt + result_json["endpoint_type"] = endpoint_type + result_json["label"] = label + result_json["model_id"] = model_id + result_json["tokenizer_id"] = tokenizer_id + result_json["num_prompts"] = args.num_prompts + + # Metadata + if args.metadata: + for item in args.metadata: + if "=" in item: + kvstring = item.split("=") + result_json[kvstring[0].strip()] = kvstring[1].strip() + else: + raise ValueError( + "Invalid metadata format. Please use KEY=VALUE format." 
+ ) + + # Traffic + result_json["request_rate"] = (args.request_rate if args.request_rate + < float("inf") else "inf") + result_json["burstiness"] = args.burstiness + result_json["max_concurrency"] = args.max_concurrency + + # Merge with benchmark result + result_json = {**result_json, **benchmark_result} + + if not args.save_detailed: + # Remove fields with too many data points + for field in [ + "input_lens", + "output_lens", + "ttfts", + "itls", + "generated_texts", + "errors", + ]: + if field in result_json: + del result_json[field] + if field in benchmark_result: + del benchmark_result[field] + + # Save to file + base_model_id = model_id.split("/")[-1] + max_concurrency_str = (f"-concurrency{args.max_concurrency}" + if args.max_concurrency is not None else "") + label = label or endpoint_type + file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" #noqa + if args.result_filename: + file_name = args.result_filename + if args.result_dir: + os.makedirs(args.result_dir, exist_ok=True) + file_name = os.path.join(args.result_dir, file_name) + with open(file_name, + mode="a+" if args.append_result else "w", + encoding="utf-8") as outfile: + # Append a newline. + if args.append_result and outfile.tell() != 0: + outfile.write("\n") + json.dump(result_json, outfile) + save_to_pytorch_benchmark_format(args, result_json, file_name) diff --git a/benchmarks/throughput.py b/benchmarks/throughput.py new file mode 100644 index 0000000..be9ea39 --- /dev/null +++ b/benchmarks/throughput.py @@ -0,0 +1,609 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Benchmark offline inference throughput.""" +import argparse +import dataclasses +import json +import os +import random +import time +import warnings +from typing import Any, Optional, Union + +import torch +import uvloop +from tqdm import tqdm +from transformers import (AutoModelForCausalLM, AutoTokenizer, + PreTrainedTokenizerBase) + +from vllm.benchmarks.datasets import (AIMODataset, BurstGPTDataset, + ConversationDataset, + InstructCoderDataset, RandomDataset, + SampleRequest, ShareGPTDataset, + SonnetDataset, VisionArenaDataset) +from vllm.benchmarks.utils import (convert_to_pytorch_benchmark_format, + write_to_json) +from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs +from vllm.entrypoints.openai.api_server import ( + build_async_engine_client_from_engine_args) +from vllm.inputs import TextPrompt, TokensPrompt +from vllm.lora.request import LoRARequest +from vllm.outputs import RequestOutput +from vllm.sampling_params import BeamSearchParams +from vllm.utils import merge_async_iterators + + +def run_vllm( + requests: list[SampleRequest], + n: int, + engine_args: EngineArgs, + disable_detokenize: bool = False, +) -> tuple[float, Optional[list[RequestOutput]]]: + from vllm import LLM, SamplingParams + llm = LLM(**dataclasses.asdict(engine_args)) + assert all( + llm.llm_engine.model_config.max_model_len >= ( + request.prompt_len + request.expected_output_len) + for request in requests), ( + "Please ensure that max_model_len is greater than the sum of" + " prompt_len and expected_output_len for all requests.") + # Add the requests to the engine. 
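+    # Pre-tokenized prompts are wrapped in TokensPrompt and plain strings in
+    # TextPrompt; ignore_eos plus a fixed max_tokens pins every output to
+    # request.expected_output_len, so throughput is measured over a known
+    # token budget.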
+ prompts: list[Union[TextPrompt, TokensPrompt]] = [] + sampling_params: list[SamplingParams] = [] + for request in requests: + prompts.append( + TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"], + multi_modal_data=request.multi_modal_data) + if "prompt_token_ids" in request.prompt else \ + TextPrompt(prompt=request.prompt, + multi_modal_data=request.multi_modal_data)) + sampling_params.append( + SamplingParams( + n=n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=request.expected_output_len, + detokenize=not disable_detokenize, + )) + lora_requests: Optional[list[LoRARequest]] = None + if engine_args.enable_lora: + lora_requests = [request.lora_request for request in requests] + + use_beam_search = False + + outputs = None + if not use_beam_search: + start = time.perf_counter() + outputs = llm.generate(prompts, + sampling_params, + lora_request=lora_requests, + use_tqdm=True) + end = time.perf_counter() + else: + assert lora_requests is None, "BeamSearch API does not support LoRA" + prompts = [request.prompt for request in requests] + # output_len should be the same for all requests. + output_len = requests[0][2] + for request in requests: + assert request.expected_output_len == output_len + start = time.perf_counter() + llm.beam_search( + prompts, + BeamSearchParams( + beam_width=n, + max_tokens=output_len, + ignore_eos=True, + )) + end = time.perf_counter() + return end - start, outputs + + +def run_vllm_chat( + requests: list[SampleRequest], + n: int, + engine_args: EngineArgs, + disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]: + """ + Run vLLM chat benchmark. This function is recommended ONLY for benchmarking + multimodal models as it properly handles multimodal inputs and chat + formatting. For non-multimodal models, use run_vllm() instead. + """ + from vllm import LLM, SamplingParams + llm = LLM(**dataclasses.asdict(engine_args)) + + assert all( + llm.llm_engine.model_config.max_model_len >= ( + request.prompt_len + request.expected_output_len) + for request in requests), ( + "Please ensure that max_model_len is greater than the sum of " + "prompt_len and expected_output_len for all requests.") + + prompts = [] + sampling_params: list[SamplingParams] = [] + for request in requests: + prompts.append(request.prompt) + sampling_params.append( + SamplingParams( + n=n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=request.expected_output_len, + detokenize=not disable_detokenize, + )) + start = time.perf_counter() + outputs = llm.chat(prompts, sampling_params, use_tqdm=True) + end = time.perf_counter() + return end - start, outputs + + +async def run_vllm_async( + requests: list[SampleRequest], + n: int, + engine_args: AsyncEngineArgs, + disable_frontend_multiprocessing: bool = False, + disable_detokenize: bool = False, +) -> float: + from vllm import SamplingParams + + async with build_async_engine_client_from_engine_args( + engine_args, disable_frontend_multiprocessing) as llm: + model_config = await llm.get_model_config() + assert all( + model_config.max_model_len >= (request.prompt_len + + request.expected_output_len) + for request in requests), ( + "Please ensure that max_model_len is greater than the sum of" + " prompt_len and expected_output_len for all requests.") + + # Add the requests to the engine. 
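+        # Build one prompt / sampling-params / LoRA-request triple per request;
+        # the per-request generators created from them are merged below and
+        # drained concurrently, so the timed span covers all streams.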
+ prompts: list[Union[TextPrompt, TokensPrompt]] = [] + sampling_params: list[SamplingParams] = [] + lora_requests: list[Optional[LoRARequest]] = [] + for request in requests: + prompts.append( + TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"], + multi_modal_data=request.multi_modal_data) + if "prompt_token_ids" in request.prompt else \ + TextPrompt(prompt=request.prompt, + multi_modal_data=request.multi_modal_data)) + sampling_params.append( + SamplingParams( + n=n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=request.expected_output_len, + detokenize=not disable_detokenize, + )) + lora_requests.append(request.lora_request) + + generators = [] + start = time.perf_counter() + for i, (prompt, sp, + lr) in enumerate(zip(prompts, sampling_params, lora_requests)): + generator = llm.generate(prompt, + sp, + lora_request=lr, + request_id=f"test{i}") + generators.append(generator) + all_gens = merge_async_iterators(*generators) + async for i, res in all_gens: + pass + end = time.perf_counter() + return end - start + + +def run_hf( + requests: list[SampleRequest], + model: str, + tokenizer: PreTrainedTokenizerBase, + n: int, + max_batch_size: int, + trust_remote_code: bool, + disable_detokenize: bool = False, +) -> float: + llm = AutoModelForCausalLM.from_pretrained( + model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) + if llm.config.model_type == "llama": + # To enable padding in the HF backend. + tokenizer.pad_token = tokenizer.eos_token + llm = llm.cuda() + + pbar = tqdm(total=len(requests)) + start = time.perf_counter() + batch: list[str] = [] + max_prompt_len = 0 + max_output_len = 0 + for i in range(len(requests)): + prompt = requests[i].prompt + prompt_len = requests[i].prompt_len + output_len = requests[i].expected_output_len + # Add the prompt to the batch. + batch.append(prompt) + max_prompt_len = max(max_prompt_len, prompt_len) + max_output_len = max(max_output_len, output_len) + if len(batch) < max_batch_size and i != len(requests) - 1: + # Check if we can add more requests to the batch. + next_prompt_len = requests[i + 1].prompt_len + next_output_len = requests[i + 1].expected_output_len + if (max(max_prompt_len, next_prompt_len) + + max(max_output_len, next_output_len)) <= 2048: + # We can add more requests to the batch. + continue + + # Generate the sequences. + input_ids = tokenizer(batch, return_tensors="pt", + padding=True).input_ids + llm_outputs = llm.generate( + input_ids=input_ids.cuda(), + do_sample=True, + num_return_sequences=n, + temperature=1.0, + top_p=1.0, + use_cache=True, + max_new_tokens=max_output_len, + ) + if not disable_detokenize: + # Include the decoding time. + tokenizer.batch_decode(llm_outputs, skip_special_tokens=True) + pbar.update(len(batch)) + + # Clear the batch. 
+ batch = [] + max_prompt_len = 0 + max_output_len = 0 + end = time.perf_counter() + return end - start + + +def save_to_pytorch_benchmark_format(args: argparse.Namespace, + results: dict[str, Any]) -> None: + pt_records = convert_to_pytorch_benchmark_format( + args=args, + metrics={ + "requests_per_second": [results["requests_per_second"]], + "tokens_per_second": [results["tokens_per_second"]], + }, + extra_info={ + k: results[k] + for k in ["elapsed_time", "num_requests", "total_num_tokens"] + }) + if pt_records: + # Don't use json suffix here as we don't want CI to pick it up + pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" + write_to_json(pt_file, pt_records) + + +def get_requests(args, tokenizer): + # Common parameters for all dataset types. + common_kwargs = { + "dataset_path": args.dataset_path, + "random_seed": args.seed, + } + sample_kwargs = { + "tokenizer": tokenizer, + "lora_path": args.lora_path, + "max_loras": args.max_loras, + "num_requests": args.num_prompts, + "input_len": args.input_len, + "output_len": args.output_len, + } + + if args.dataset_path is None or args.dataset_name == "random": + sample_kwargs["range_ratio"] = args.random_range_ratio + sample_kwargs["prefix_len"] = args.prefix_len + dataset_cls = RandomDataset + elif args.dataset_name == "sharegpt": + dataset_cls = ShareGPTDataset + if args.backend == "vllm-chat": + sample_kwargs["enable_multimodal_chat"] = True + elif args.dataset_name == "sonnet": + assert tokenizer.chat_template or tokenizer.default_chat_template, ( + "Tokenizer/model must have chat template for sonnet dataset.") + dataset_cls = SonnetDataset + sample_kwargs["prefix_len"] = args.prefix_len + sample_kwargs["return_prompt_formatted"] = True + elif args.dataset_name == "burstgpt": + dataset_cls = BurstGPTDataset + elif args.dataset_name == "hf": + if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS: + dataset_cls = VisionArenaDataset + common_kwargs['dataset_subset'] = None + common_kwargs['dataset_split'] = "train" + sample_kwargs["enable_multimodal_chat"] = True + elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS: + dataset_cls = InstructCoderDataset + common_kwargs['dataset_split'] = "train" + elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS: + dataset_cls = ConversationDataset + common_kwargs['dataset_subset'] = args.hf_subset + common_kwargs['dataset_split'] = args.hf_split + sample_kwargs["enable_multimodal_chat"] = True + elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: + dataset_cls = AIMODataset + common_kwargs['dataset_subset'] = None + common_kwargs['dataset_split'] = "train" + else: + raise ValueError(f"Unknown dataset name: {args.dataset_name}") + # Remove None values + sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None} + return dataset_cls(**common_kwargs).sample(**sample_kwargs) + + +def validate_args(args): + """ + Validate command-line arguments. + """ + + # === Deprecation and Defaulting === + if args.dataset is not None: + warnings.warn( + "The '--dataset' argument will be deprecated in the next release. 
" + "Please use '--dataset-name' and '--dataset-path' instead.", + stacklevel=2) + args.dataset_path = args.dataset + + if not getattr(args, "tokenizer", None): + args.tokenizer = args.model + + # === Backend Validation === + valid_backends = {"vllm", "hf", "mii", "vllm-chat"} + if args.backend not in valid_backends: + raise ValueError(f"Unsupported backend: {args.backend}") + + # === Dataset Configuration === + if not args.dataset and not args.dataset_path: + print( + "When dataset path is not set, it will default to random dataset") + args.dataset_name = 'random' + if args.input_len is None: + raise ValueError("input_len must be provided for a random dataset") + + # === Dataset Name Specific Checks === + # --hf-subset and --hf-split: only used + # when dataset_name is 'hf' + if args.dataset_name != "hf" and ( + getattr(args, "hf_subset", None) is not None + or getattr(args, "hf_split", None) is not None): + warnings.warn("--hf-subset and --hf-split will be ignored \ + since --dataset-name is not 'hf'.", + stacklevel=2) + elif args.dataset_name == "hf": + if args.dataset_path in ( + VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys() + | ConversationDataset.SUPPORTED_DATASET_PATHS): + assert args.backend == "vllm-chat", f"{args.dataset_path} needs to use vllm-chat as the backend." #noqa: E501 + elif args.dataset_path in (InstructCoderDataset.SUPPORTED_DATASET_PATHS + | AIMODataset.SUPPORTED_DATASET_PATHS): + assert args.backend == "vllm", f"{args.dataset_path} needs to use vllm as the backend." #noqa: E501 + else: + raise ValueError( + f"{args.dataset_path} is not supported by hf dataset.") + + # --random-range-ratio: only used when dataset_name is 'random' + if args.dataset_name != 'random' and args.random_range_ratio is not None: + warnings.warn("--random-range-ratio will be ignored since \ + --dataset-name is not 'random'.", + stacklevel=2) + + # --prefix-len: only used when dataset_name is 'random', 'sonnet', or not + # set. 
+ if args.dataset_name not in {"random", "sonnet", None + } and args.prefix_len is not None: + warnings.warn("--prefix-len will be ignored since --dataset-name\ + is not 'random', 'sonnet', or not set.", + stacklevel=2) + + # === LoRA Settings === + if getattr(args, "enable_lora", False) and args.backend != "vllm": + raise ValueError( + "LoRA benchmarking is only supported for vLLM backend") + if getattr(args, "enable_lora", False) and args.lora_path is None: + raise ValueError("LoRA path must be provided when enable_lora is True") + + # === Backend-specific Validations === + if args.backend == "hf" and args.hf_max_batch_size is None: + raise ValueError("HF max batch size is required for HF backend") + if args.backend != "hf" and args.hf_max_batch_size is not None: + raise ValueError("HF max batch size is only for HF backend.") + + if args.backend in {"hf", "mii"} and getattr(args, "quantization", + None) is not None: + raise ValueError("Quantization is only for vLLM backend.") + + if args.backend == "mii" and args.dtype != "auto": + raise ValueError("dtype must be auto for MII backend.") + if args.backend == "mii" and args.n != 1: + raise ValueError("n must be 1 for MII backend.") + if args.backend == "mii" and args.tokenizer != args.model: + raise ValueError( + "Tokenizer must be the same as the model for MII backend.") + + +def add_cli_args(parser: argparse.ArgumentParser): + parser.add_argument("--backend", + type=str, + choices=["vllm", "hf", "mii", "vllm-chat"], + default="vllm") + parser.add_argument( + "--dataset-name", + type=str, + choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"], + help="Name of the dataset to benchmark on.", + default="sharegpt") + parser.add_argument( + "--dataset", + type=str, + default=None, + help="Path to the ShareGPT dataset, will be deprecated in\ + the next release. The dataset is expected to " + "be a json in form of list[dict[..., conversations: " + "list[dict[..., value: ]]]]") + parser.add_argument("--dataset-path", + type=str, + default=None, + help="Path to the dataset") + parser.add_argument("--input-len", + type=int, + default=None, + help="Input prompt length for each request") + parser.add_argument("--output-len", + type=int, + default=None, + help="Output length for each request. Overrides the " + "output length from the dataset.") + parser.add_argument("--n", + type=int, + default=1, + help="Number of generated sequences per prompt.") + parser.add_argument("--num-prompts", + type=int, + default=1000, + help="Number of prompts to process.") + parser.add_argument("--hf-max-batch-size", + type=int, + default=None, + help="Maximum batch size for HF backend.") + parser.add_argument( + '--output-json', + type=str, + default=None, + help='Path to save the throughput results in JSON format.') + parser.add_argument("--async-engine", + action='store_true', + default=False, + help="Use vLLM async engine rather than LLM class.") + parser.add_argument("--disable-frontend-multiprocessing", + action='store_true', + default=False, + help="Disable decoupled async engine frontend.") + parser.add_argument( + "--disable-detokenize", + action="store_true", + help=("Do not detokenize the response (i.e. do not include " + "detokenization time in the measurement)")) + # LoRA + parser.add_argument( + "--lora-path", + type=str, + default=None, + help="Path to the lora adapters to use. 
This can be an absolute path, " + "a relative path, or a Hugging Face model identifier.") + parser.add_argument( + "--prefix-len", + type=int, + default=0, + help="Number of fixed prefix tokens before the random " + "context in a request (default: 0).", + ) + # random dataset + parser.add_argument( + "--random-range-ratio", + type=float, + default=0.0, + help="Range ratio for sampling input/output length, " + "used only for RandomDataset. Must be in the range [0, 1) to define " + "a symmetric sampling range " + "[length * (1 - range_ratio), length * (1 + range_ratio)].", + ) + + # hf dtaset + parser.add_argument("--hf-subset", + type=str, + default=None, + help="Subset of the HF dataset.") + parser.add_argument("--hf-split", + type=str, + default=None, + help="Split of the HF dataset.") + + parser = AsyncEngineArgs.add_cli_args(parser) + + +def main(args: argparse.Namespace): + if args.tokenizer is None: + args.tokenizer = args.model + validate_args(args) + if args.seed is None: + args.seed = 0 + random.seed(args.seed) + # Sample the requests. + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer, trust_remote_code=args.trust_remote_code) + requests = get_requests(args, tokenizer) + is_multi_modal = any(request.multi_modal_data is not None + for request in requests) + request_outputs: Optional[list[RequestOutput]] = None + if args.backend == "vllm": + if args.async_engine: + elapsed_time = uvloop.run( + run_vllm_async( + requests, + args.n, + AsyncEngineArgs.from_cli_args(args), + args.disable_frontend_multiprocessing, + args.disable_detokenize, + )) + else: + elapsed_time, request_outputs = run_vllm( + requests, args.n, EngineArgs.from_cli_args(args), + args.disable_detokenize) + elif args.backend == "hf": + assert args.tensor_parallel_size == 1 + elapsed_time = run_hf(requests, args.model, tokenizer, args.n, + args.hf_max_batch_size, args.trust_remote_code, + args.disable_detokenize) + elif args.backend == "vllm-chat": + elapsed_time, request_outputs = run_vllm_chat( + requests, args.n, EngineArgs.from_cli_args(args), + args.disable_detokenize) + else: + raise ValueError(f"Unknown backend: {args.backend}") + + if request_outputs: + # Note: with the vllm and vllm-chat backends, + # we have request_outputs, which we use to count tokens. + total_prompt_tokens = 0 + total_output_tokens = 0 + for ro in request_outputs: + if not isinstance(ro, RequestOutput): + continue + total_prompt_tokens += len( + ro.prompt_token_ids) if ro.prompt_token_ids else 0 + total_output_tokens += sum( + len(o.token_ids) for o in ro.outputs if o) + total_num_tokens = total_prompt_tokens + total_output_tokens + else: + total_num_tokens = sum(r.prompt_len + r.expected_output_len + for r in requests) + total_output_tokens = sum(r.expected_output_len for r in requests) + total_prompt_tokens = total_num_tokens - total_output_tokens + + if is_multi_modal and args.backend != "vllm-chat": + print("\033[91mWARNING\033[0m: Multi-modal request with " + f"{args.backend} backend detected. The " + "following metrics are not accurate because image tokens are not" + " counted. See vllm-project/vllm/issues/9778 for details.") + # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length. 
+ # vllm-chat backend counts the image tokens now + + print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " + f"{total_num_tokens / elapsed_time:.2f} total tokens/s, " + f"{total_output_tokens / elapsed_time:.2f} output tokens/s") + print(f"Total num prompt tokens: {total_prompt_tokens}") + print(f"Total num output tokens: {total_output_tokens}") + + # Output JSON results if specified + if args.output_json: + results = { + "elapsed_time": elapsed_time, + "num_requests": len(requests), + "total_num_tokens": total_num_tokens, + "requests_per_second": len(requests) / elapsed_time, + "tokens_per_second": total_num_tokens / elapsed_time, + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + save_to_pytorch_benchmark_format(args, results) diff --git a/benchmarks/utils.py b/benchmarks/utils.py new file mode 100644 index 0000000..f0bb993 --- /dev/null +++ b/benchmarks/utils.py @@ -0,0 +1,70 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import json +import math +import os +from typing import Any + + +def convert_to_pytorch_benchmark_format(args: argparse.Namespace, + metrics: dict[str, list], + extra_info: dict[str, Any]) -> list: + """ + Save the benchmark results in the format used by PyTorch OSS benchmark with + on metric per record + https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database + """ + records = [] + if not os.environ.get("SAVE_TO_PYTORCH_BENCHMARK_FORMAT", False): + return records + + for name, benchmark_values in metrics.items(): + record = { + "benchmark": { + "name": "vLLM benchmark", + "extra_info": { + "args": vars(args), + }, + }, + "model": { + "name": args.model, + }, + "metric": { + "name": name, + "benchmark_values": benchmark_values, + "extra_info": extra_info, + }, + } + + tp = record["benchmark"]["extra_info"]["args"].get( + "tensor_parallel_size") + # Save tensor_parallel_size parameter if it's part of the metadata + if not tp and "tensor_parallel_size" in extra_info: + record["benchmark"]["extra_info"]["args"][ + "tensor_parallel_size"] = extra_info["tensor_parallel_size"] + + records.append(record) + + return records + + +class InfEncoder(json.JSONEncoder): + + def clear_inf(self, o: Any): + if isinstance(o, dict): + return {k: self.clear_inf(v) for k, v in o.items()} + elif isinstance(o, list): + return [self.clear_inf(v) for v in o] + elif isinstance(o, float) and math.isinf(o): + return "inf" + return o + + def iterencode(self, o: Any, *args, **kwargs) -> Any: + return super().iterencode(self.clear_inf(o), *args, **kwargs) + + +def write_to_json(filename: str, records: list) -> None: + with open(filename, "w") as f: + json.dump(records, f, cls=InfEncoder) diff --git a/collect_env.py b/collect_env.py new file mode 100644 index 0000000..64172a9 --- /dev/null +++ b/collect_env.py @@ -0,0 +1,820 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# ruff: noqa +# code borrowed from https://github.com/pytorch/pytorch/blob/main/torch/utils/collect_env.py + +import datetime +import locale +import os +import subprocess +import sys +# Unlike the rest of the PyTorch this file must be python2 compliant. 
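+# In addition to the upstream PyTorch fields, this copy also collects
+# vLLM-specific information (ROCm and Neuron SDK versions, vLLM version,
+# build flags and GPU topology); see the SystemEnv fields marked
+# "vllm specific field" below.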
+# This script outputs relevant system environment info +# Run it with `python collect_env.py` or `python -m torch.utils.collect_env` +from collections import namedtuple + +import regex as re + +from vllm.envs import environment_variables + +try: + import torch + TORCH_AVAILABLE = True +except (ImportError, NameError, AttributeError, OSError): + TORCH_AVAILABLE = False + +# System Environment Information +SystemEnv = namedtuple( + 'SystemEnv', + [ + 'torch_version', + 'is_debug_build', + 'cuda_compiled_version', + 'gcc_version', + 'clang_version', + 'cmake_version', + 'os', + 'libc_version', + 'python_version', + 'python_platform', + 'is_cuda_available', + 'cuda_runtime_version', + 'cuda_module_loading', + 'nvidia_driver_version', + 'nvidia_gpu_models', + 'cudnn_version', + 'pip_version', # 'pip' or 'pip3' + 'pip_packages', + 'conda_packages', + 'hip_compiled_version', + 'hip_runtime_version', + 'miopen_runtime_version', + 'caching_allocator_config', + 'is_xnnpack_available', + 'cpu_info', + 'rocm_version', # vllm specific field + 'neuron_sdk_version', # vllm specific field + 'vllm_version', # vllm specific field + 'vllm_build_flags', # vllm specific field + 'gpu_topo', # vllm specific field + 'env_vars', + ]) + +DEFAULT_CONDA_PATTERNS = { + "torch", + "numpy", + "cudatoolkit", + "soumith", + "mkl", + "magma", + "triton", + "optree", + "nccl", + "transformers", + "zmq", + "nvidia", + "pynvml", +} + +DEFAULT_PIP_PATTERNS = { + "torch", + "numpy", + "mypy", + "flake8", + "triton", + "optree", + "onnx", + "nccl", + "transformers", + "zmq", + "nvidia", + "pynvml", +} + + +def run(command): + """Return (return-code, stdout, stderr).""" + shell = True if type(command) is str else False + p = subprocess.Popen(command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=shell) + raw_output, raw_err = p.communicate() + rc = p.returncode + if get_platform() == 'win32': + enc = 'oem' + else: + enc = locale.getpreferredencoding() + output = raw_output.decode(enc) + if command == 'nvidia-smi topo -m': + # don't remove the leading whitespace of `nvidia-smi topo -m` + # because they are meaningful + output = output.rstrip() + else: + output = output.strip() + err = raw_err.decode(enc) + return rc, output, err.strip() + + +def run_and_read_all(run_lambda, command): + """Run command using run_lambda; reads and returns entire output if rc is 0.""" + rc, out, _ = run_lambda(command) + if rc != 0: + return None + return out + + +def run_and_parse_first_match(run_lambda, command, regex): + """Run command using run_lambda, returns the first regex match if it exists.""" + rc, out, _ = run_lambda(command) + if rc != 0: + return None + match = re.search(regex, out) + if match is None: + return None + return match.group(1) + + +def run_and_return_first_line(run_lambda, command): + """Run command using run_lambda and returns first line if output is not empty.""" + rc, out, _ = run_lambda(command) + if rc != 0: + return None + return out.split('\n')[0] + + +def get_conda_packages(run_lambda, patterns=None): + if patterns is None: + patterns = DEFAULT_CONDA_PATTERNS + conda = os.environ.get('CONDA_EXE', 'conda') + out = run_and_read_all(run_lambda, "{} list".format(conda)) + if out is None: + return out + + return "\n".join(line for line in out.splitlines() + if not line.startswith("#") and any(name in line + for name in patterns)) + + +def get_gcc_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'gcc --version', r'gcc (.*)') + + +def get_clang_version(run_lambda): + return 
run_and_parse_first_match(run_lambda, 'clang --version', + r'clang version (.*)') + + +def get_cmake_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'cmake --version', + r'cmake (.*)') + + +def get_nvidia_driver_version(run_lambda): + if get_platform() == 'darwin': + cmd = 'kextstat | grep -i cuda' + return run_and_parse_first_match(run_lambda, cmd, + r'com[.]nvidia[.]CUDA [(](.*?)[)]') + smi = get_nvidia_smi() + return run_and_parse_first_match(run_lambda, smi, + r'Driver Version: (.*?) ') + + +def get_gpu_info(run_lambda): + if get_platform() == 'darwin' or (TORCH_AVAILABLE and hasattr( + torch.version, 'hip') and torch.version.hip is not None): + if TORCH_AVAILABLE and torch.cuda.is_available(): + if torch.version.hip is not None: + prop = torch.cuda.get_device_properties(0) + if hasattr(prop, "gcnArchName"): + gcnArch = " ({})".format(prop.gcnArchName) + else: + gcnArch = "NoGCNArchNameOnOldPyTorch" + else: + gcnArch = "" + return torch.cuda.get_device_name(None) + gcnArch + return None + smi = get_nvidia_smi() + uuid_regex = re.compile(r' \(UUID: .+?\)') + rc, out, _ = run_lambda(smi + ' -L') + if rc != 0: + return None + # Anonymize GPUs by removing their UUID + return re.sub(uuid_regex, '', out) + + +def get_running_cuda_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'nvcc --version', + r'release .+ V(.*)') + + +def get_cudnn_version(run_lambda): + """Return a list of libcudnn.so; it's hard to tell which one is being used.""" + if get_platform() == 'win32': + system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') + cuda_path = os.environ.get('CUDA_PATH', "%CUDA_PATH%") + where_cmd = os.path.join(system_root, 'System32', 'where') + cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path) + elif get_platform() == 'darwin': + # CUDA libraries and drivers can be found in /usr/local/cuda/. See + # https://docs.nvidia.com/cuda/cuda-installation-guide-mac-os-x/index.html#install + # https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#installmac + # Use CUDNN_LIBRARY when cudnn library is installed elsewhere. 
+ cudnn_cmd = 'ls /usr/local/cuda/lib/libcudnn*' + else: + cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev' + rc, out, _ = run_lambda(cudnn_cmd) + # find will return 1 if there are permission errors or if not found + if len(out) == 0 or (rc != 1 and rc != 0): + l = os.environ.get('CUDNN_LIBRARY') + if l is not None and os.path.isfile(l): + return os.path.realpath(l) + return None + files_set = set() + for fn in out.split('\n'): + fn = os.path.realpath(fn) # eliminate symbolic links + if os.path.isfile(fn): + files_set.add(fn) + if not files_set: + return None + # Alphabetize the result because the order is non-deterministic otherwise + files = sorted(files_set) + if len(files) == 1: + return files[0] + result = '\n'.join(files) + return 'Probably one of the following:\n{}'.format(result) + + +def get_nvidia_smi(): + # Note: nvidia-smi is currently available only on Windows and Linux + smi = 'nvidia-smi' + if get_platform() == 'win32': + system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') + program_files_root = os.environ.get('PROGRAMFILES', + 'C:\\Program Files') + legacy_path = os.path.join(program_files_root, 'NVIDIA Corporation', + 'NVSMI', smi) + new_path = os.path.join(system_root, 'System32', smi) + smis = [new_path, legacy_path] + for candidate_smi in smis: + if os.path.exists(candidate_smi): + smi = '"{}"'.format(candidate_smi) + break + return smi + + +def get_rocm_version(run_lambda): + """Returns the ROCm version if available, otherwise 'N/A'.""" + return run_and_parse_first_match(run_lambda, 'hipcc --version', + r'HIP version: (\S+)') + + +def get_neuron_sdk_version(run_lambda): + # Adapted from your install script + try: + result = run_lambda(["neuron-ls"]) + return result if result[0] == 0 else 'N/A' + except Exception: + return 'N/A' + + +def get_vllm_version(): + from vllm import __version__, __version_tuple__ + + if __version__ == "dev": + return "N/A (dev)" + version_str = __version_tuple__[-1] + if isinstance(version_str, str) and version_str.startswith('g'): + # it's a dev build + if '.' in version_str: + # it's a dev build containing local changes + git_sha = version_str.split('.')[0][1:] + date = version_str.split('.')[-1][1:] + return f"{__version__} (git sha: {git_sha}, date: {date})" + else: + # it's a dev build without local changes + git_sha = version_str[1:] # type: ignore + return f"{__version__} (git sha: {git_sha})" + return __version__ + + +def summarize_vllm_build_flags(): + # This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc. 
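+    # Example output (assuming TORCH_CUDA_ARCH_LIST="8.0;9.0" and neither
+    # ROCM_HOME nor NEURON_CORES is set):
+    #   CUDA Archs: 8.0;9.0; ROCm: Disabled; Neuron: Disabled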
+ return 'CUDA Archs: {}; ROCm: {}; Neuron: {}'.format( + os.environ.get('TORCH_CUDA_ARCH_LIST', 'Not Set'), + 'Enabled' if os.environ.get('ROCM_HOME') else 'Disabled', + 'Enabled' if os.environ.get('NEURON_CORES') else 'Disabled', + ) + + +def get_gpu_topo(run_lambda): + output = None + + if get_platform() == 'linux': + output = run_and_read_all(run_lambda, 'nvidia-smi topo -m') + if output is None: + output = run_and_read_all(run_lambda, 'rocm-smi --showtopo') + + return output + + +# example outputs of CPU infos +# * linux +# Architecture: x86_64 +# CPU op-mode(s): 32-bit, 64-bit +# Address sizes: 46 bits physical, 48 bits virtual +# Byte Order: Little Endian +# CPU(s): 128 +# On-line CPU(s) list: 0-127 +# Vendor ID: GenuineIntel +# Model name: Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz +# CPU family: 6 +# Model: 106 +# Thread(s) per core: 2 +# Core(s) per socket: 32 +# Socket(s): 2 +# Stepping: 6 +# BogoMIPS: 5799.78 +# Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr +# sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl +# xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 +# pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand +# hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced +# fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap +# avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 +# xsaves wbnoinvd ida arat avx512vbmi pku ospke avx512_vbmi2 gfni vaes vpclmulqdq +# avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid md_clear flush_l1d arch_capabilities +# Virtualization features: +# Hypervisor vendor: KVM +# Virtualization type: full +# Caches (sum of all): +# L1d: 3 MiB (64 instances) +# L1i: 2 MiB (64 instances) +# L2: 80 MiB (64 instances) +# L3: 108 MiB (2 instances) +# NUMA: +# NUMA node(s): 2 +# NUMA node0 CPU(s): 0-31,64-95 +# NUMA node1 CPU(s): 32-63,96-127 +# Vulnerabilities: +# Itlb multihit: Not affected +# L1tf: Not affected +# Mds: Not affected +# Meltdown: Not affected +# Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown +# Retbleed: Not affected +# Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp +# Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization +# Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence +# Srbds: Not affected +# Tsx async abort: Not affected +# * win32 +# Architecture=9 +# CurrentClockSpeed=2900 +# DeviceID=CPU0 +# Family=179 +# L2CacheSize=40960 +# L2CacheSpeed= +# Manufacturer=GenuineIntel +# MaxClockSpeed=2900 +# Name=Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz +# ProcessorType=3 +# Revision=27142 +# +# Architecture=9 +# CurrentClockSpeed=2900 +# DeviceID=CPU1 +# Family=179 +# L2CacheSize=40960 +# L2CacheSpeed= +# Manufacturer=GenuineIntel +# MaxClockSpeed=2900 +# Name=Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz +# ProcessorType=3 +# Revision=27142 + + +def get_cpu_info(run_lambda): + rc, out, err = 0, '', '' + if get_platform() == 'linux': + rc, out, err = run_lambda('lscpu') + elif get_platform() == 'win32': + rc, out, err = run_lambda( + 'wmic cpu get Name,Manufacturer,Family,Architecture,ProcessorType,DeviceID, \ + CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision /VALUE' + ) + elif get_platform() == 'darwin': + 
rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string") + cpu_info = 'None' + if rc == 0: + cpu_info = out + else: + cpu_info = err + return cpu_info + + +def get_platform(): + if sys.platform.startswith('linux'): + return 'linux' + elif sys.platform.startswith('win32'): + return 'win32' + elif sys.platform.startswith('cygwin'): + return 'cygwin' + elif sys.platform.startswith('darwin'): + return 'darwin' + else: + return sys.platform + + +def get_mac_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'sw_vers -productVersion', + r'(.*)') + + +def get_windows_version(run_lambda): + system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') + wmic_cmd = os.path.join(system_root, 'System32', 'Wbem', 'wmic') + findstr_cmd = os.path.join(system_root, 'System32', 'findstr') + return run_and_read_all( + run_lambda, + '{} os get Caption | {} /v Caption'.format(wmic_cmd, findstr_cmd)) + + +def get_lsb_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'lsb_release -a', + r'Description:\t(.*)') + + +def check_release_file(run_lambda): + return run_and_parse_first_match(run_lambda, 'cat /etc/*-release', + r'PRETTY_NAME="(.*)"') + + +def get_os(run_lambda): + from platform import machine + platform = get_platform() + + if platform == 'win32' or platform == 'cygwin': + return get_windows_version(run_lambda) + + if platform == 'darwin': + version = get_mac_version(run_lambda) + if version is None: + return None + return 'macOS {} ({})'.format(version, machine()) + + if platform == 'linux': + # Ubuntu/Debian based + desc = get_lsb_version(run_lambda) + if desc is not None: + return '{} ({})'.format(desc, machine()) + + # Try reading /etc/*-release + desc = check_release_file(run_lambda) + if desc is not None: + return '{} ({})'.format(desc, machine()) + + return '{} ({})'.format(platform, machine()) + + # Unknown platform + return platform + + +def get_python_platform(): + import platform + return platform.platform() + + +def get_libc_version(): + import platform + if get_platform() != 'linux': + return 'N/A' + return '-'.join(platform.libc_ver()) + + +def get_pip_packages(run_lambda, patterns=None): + """Return `pip list` output. 
Note: will also find conda-installed pytorch and numpy packages.""" + if patterns is None: + patterns = DEFAULT_PIP_PATTERNS + + def run_with_pip(): + try: + import importlib.util + pip_spec = importlib.util.find_spec('pip') + pip_available = pip_spec is not None + except ImportError: + pip_available = False + + if pip_available: + cmd = [sys.executable, '-mpip', 'list', '--format=freeze'] + elif os.environ.get("UV") is not None: + print("uv is set") + cmd = ["uv", "pip", "list", "--format=freeze"] + else: + raise RuntimeError( + "Could not collect pip list output (pip or uv module not available)" + ) + + out = run_and_read_all(run_lambda, cmd) + return "\n".join(line for line in out.splitlines() + if any(name in line for name in patterns)) + + pip_version = 'pip3' if sys.version[0] == '3' else 'pip' + out = run_with_pip() + return pip_version, out + + +def get_cachingallocator_config(): + ca_config = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '') + return ca_config + + +def get_cuda_module_loading_config(): + if TORCH_AVAILABLE and torch.cuda.is_available(): + torch.cuda.init() + config = os.environ.get('CUDA_MODULE_LOADING', '') + return config + else: + return "N/A" + + +def is_xnnpack_available(): + if TORCH_AVAILABLE: + import torch.backends.xnnpack + return str( + torch.backends.xnnpack.enabled) # type: ignore[attr-defined] + else: + return "N/A" + + +def get_env_vars(): + env_vars = '' + secret_terms = ('secret', 'token', 'api', 'access', 'password') + report_prefix = ("TORCH", "NCCL", "PYTORCH", "CUDA", "CUBLAS", "CUDNN", + "OMP_", "MKL_", "NVIDIA") + for k, v in os.environ.items(): + if any(term in k.lower() for term in secret_terms): + continue + if k in environment_variables: + env_vars = env_vars + "{}={}".format(k, v) + "\n" + if k.startswith(report_prefix): + env_vars = env_vars + "{}={}".format(k, v) + "\n" + + return env_vars + + +def get_env_info(): + run_lambda = run + pip_version, pip_list_output = get_pip_packages(run_lambda) + + if TORCH_AVAILABLE: + version_str = torch.__version__ + debug_mode_str = str(torch.version.debug) + cuda_available_str = str(torch.cuda.is_available()) + cuda_version_str = torch.version.cuda + if not hasattr(torch.version, + 'hip') or torch.version.hip is None: # cuda version + hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A' + else: # HIP version + + def get_version_or_na(cfg, prefix): + _lst = [s.rsplit(None, 1)[-1] for s in cfg if prefix in s] + return _lst[0] if _lst else 'N/A' + + cfg = torch._C._show_config().split('\n') + hip_runtime_version = get_version_or_na(cfg, 'HIP Runtime') + miopen_runtime_version = get_version_or_na(cfg, 'MIOpen') + cuda_version_str = 'N/A' + hip_compiled_version = torch.version.hip + else: + version_str = debug_mode_str = cuda_available_str = cuda_version_str = 'N/A' + hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A' + + sys_version = sys.version.replace("\n", " ") + + conda_packages = get_conda_packages(run_lambda) + + rocm_version = get_rocm_version(run_lambda) + neuron_sdk_version = get_neuron_sdk_version(run_lambda) + vllm_version = get_vllm_version() + vllm_build_flags = summarize_vllm_build_flags() + gpu_topo = get_gpu_topo(run_lambda) + + return SystemEnv( + torch_version=version_str, + is_debug_build=debug_mode_str, + python_version='{} ({}-bit runtime)'.format( + sys_version, + sys.maxsize.bit_length() + 1), + python_platform=get_python_platform(), + is_cuda_available=cuda_available_str, + cuda_compiled_version=cuda_version_str, + 
cuda_runtime_version=get_running_cuda_version(run_lambda), + cuda_module_loading=get_cuda_module_loading_config(), + nvidia_gpu_models=get_gpu_info(run_lambda), + nvidia_driver_version=get_nvidia_driver_version(run_lambda), + cudnn_version=get_cudnn_version(run_lambda), + hip_compiled_version=hip_compiled_version, + hip_runtime_version=hip_runtime_version, + miopen_runtime_version=miopen_runtime_version, + pip_version=pip_version, + pip_packages=pip_list_output, + conda_packages=conda_packages, + os=get_os(run_lambda), + libc_version=get_libc_version(), + gcc_version=get_gcc_version(run_lambda), + clang_version=get_clang_version(run_lambda), + cmake_version=get_cmake_version(run_lambda), + caching_allocator_config=get_cachingallocator_config(), + is_xnnpack_available=is_xnnpack_available(), + cpu_info=get_cpu_info(run_lambda), + rocm_version=rocm_version, + neuron_sdk_version=neuron_sdk_version, + vllm_version=vllm_version, + vllm_build_flags=vllm_build_flags, + gpu_topo=gpu_topo, + env_vars=get_env_vars(), + ) + + +env_info_fmt = """ +============================== + System Info +============================== +OS : {os} +GCC version : {gcc_version} +Clang version : {clang_version} +CMake version : {cmake_version} +Libc version : {libc_version} + +============================== + PyTorch Info +============================== +PyTorch version : {torch_version} +Is debug build : {is_debug_build} +CUDA used to build PyTorch : {cuda_compiled_version} +ROCM used to build PyTorch : {hip_compiled_version} + +============================== + Python Environment +============================== +Python version : {python_version} +Python platform : {python_platform} + +============================== + CUDA / GPU Info +============================== +Is CUDA available : {is_cuda_available} +CUDA runtime version : {cuda_runtime_version} +CUDA_MODULE_LOADING set to : {cuda_module_loading} +GPU models and configuration : {nvidia_gpu_models} +Nvidia driver version : {nvidia_driver_version} +cuDNN version : {cudnn_version} +HIP runtime version : {hip_runtime_version} +MIOpen runtime version : {miopen_runtime_version} +Is XNNPACK available : {is_xnnpack_available} + +============================== + CPU Info +============================== +{cpu_info} + +============================== +Versions of relevant libraries +============================== +{pip_packages} +{conda_packages} +""".strip() + +# both the above code and the following code use `strip()` to +# remove leading/trailing whitespaces, so we need to add a newline +# in between to separate the two sections +env_info_fmt += "\n\n" + +env_info_fmt += """ +============================== + vLLM Info +============================== +ROCM Version : {rocm_version} +Neuron SDK Version : {neuron_sdk_version} +vLLM Version : {vllm_version} +vLLM Build Flags: + {vllm_build_flags} +GPU Topology: + {gpu_topo} + +============================== + Environment Variables +============================== +{env_vars} +""".strip() + + +def pretty_str(envinfo): + + def replace_nones(dct, replacement='Could not collect'): + for key in dct.keys(): + if dct[key] is not None: + continue + dct[key] = replacement + return dct + + def replace_bools(dct, true='Yes', false='No'): + for key in dct.keys(): + if dct[key] is True: + dct[key] = true + elif dct[key] is False: + dct[key] = false + return dct + + def prepend(text, tag='[prepend]'): + lines = text.split('\n') + updated_lines = [tag + line for line in lines] + return '\n'.join(updated_lines) + + def replace_if_empty(text, 
replacement='No relevant packages'): + if text is not None and len(text) == 0: + return replacement + return text + + def maybe_start_on_next_line(string): + # If `string` is multiline, prepend a \n to it. + if string is not None and len(string.split('\n')) > 1: + return '\n{}\n'.format(string) + return string + + mutable_dict = envinfo._asdict() + + # If nvidia_gpu_models is multiline, start on the next line + mutable_dict['nvidia_gpu_models'] = \ + maybe_start_on_next_line(envinfo.nvidia_gpu_models) + + # If the machine doesn't have CUDA, report some fields as 'No CUDA' + dynamic_cuda_fields = [ + 'cuda_runtime_version', + 'nvidia_gpu_models', + 'nvidia_driver_version', + ] + all_cuda_fields = dynamic_cuda_fields + ['cudnn_version'] + all_dynamic_cuda_fields_missing = all(mutable_dict[field] is None + for field in dynamic_cuda_fields) + if TORCH_AVAILABLE and not torch.cuda.is_available( + ) and all_dynamic_cuda_fields_missing: + for field in all_cuda_fields: + mutable_dict[field] = 'No CUDA' + if envinfo.cuda_compiled_version is None: + mutable_dict['cuda_compiled_version'] = 'None' + + # Replace True with Yes, False with No + mutable_dict = replace_bools(mutable_dict) + + # Replace all None objects with 'Could not collect' + mutable_dict = replace_nones(mutable_dict) + + # If either of these are '', replace with 'No relevant packages' + mutable_dict['pip_packages'] = replace_if_empty( + mutable_dict['pip_packages']) + mutable_dict['conda_packages'] = replace_if_empty( + mutable_dict['conda_packages']) + + # Tag conda and pip packages with a prefix + # If they were previously None, they'll show up as ie '[conda] Could not collect' + if mutable_dict['pip_packages']: + mutable_dict['pip_packages'] = prepend( + mutable_dict['pip_packages'], '[{}] '.format(envinfo.pip_version)) + if mutable_dict['conda_packages']: + mutable_dict['conda_packages'] = prepend( + mutable_dict['conda_packages'], '[conda] ') + mutable_dict['cpu_info'] = envinfo.cpu_info + return env_info_fmt.format(**mutable_dict) + + +def get_pretty_env_info(): + return pretty_str(get_env_info()) + + +def main(): + print("Collecting environment information...") + output = get_pretty_env_info() + print(output) + + if TORCH_AVAILABLE and hasattr(torch, 'utils') and hasattr( + torch.utils, '_crash_handler'): + minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR + if sys.platform == "linux" and os.path.exists(minidump_dir): + dumps = [ + os.path.join(minidump_dir, dump) + for dump in os.listdir(minidump_dir) + ] + latest = max(dumps, key=os.path.getctime) + ctime = os.path.getctime(latest) + creation_time = datetime.datetime.fromtimestamp(ctime).strftime( + '%Y-%m-%d %H:%M:%S') + msg = "\n*** Detected a minidump at {} created on {}, ".format(latest, creation_time) + \ + "if this is related to your bug please include it when you file a report ***" + print(msg, file=sys.stderr) + + +if __name__ == '__main__': + main() diff --git a/compilation/__init__.py b/compilation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/compilation/activation_quant_fusion.py b/compilation/activation_quant_fusion.py new file mode 100644 index 0000000..ce4e50a --- /dev/null +++ b/compilation/activation_quant_fusion.py @@ -0,0 +1,89 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +from torch._higher_order_ops.auto_functionalize import auto_functionalized +from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only, + 
register_replacement) + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.platforms import current_platform + +from .vllm_inductor_pass import VllmInductorPass + +logger = init_logger(__name__) + + +def silu_mul_pattern_static(result: torch.Tensor, + result_silu_mul: torch.Tensor, input: torch.Tensor, + scale: torch.Tensor): + at1 = auto_functionalized(torch.ops._C.silu_and_mul.default, + result=result_silu_mul, + input=input) + at2 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default, + result=result, + input=at1[1], + scale=scale) + return at2[1] + + +def silu_mul_replacement_static(result: torch.Tensor, + result_silu_mul: torch.Tensor, + input: torch.Tensor, scale: torch.Tensor): + at = auto_functionalized(torch.ops._C.silu_and_mul_quant.default, + result=result, + input=input, + scale=scale) + return at[1] + + +def empty_bf16(*args, **kwargs): + return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda") + + +def empty_fp8(*args, **kwargs): + fp8 = current_platform.fp8_dtype() + return torch.empty(*args, **kwargs, dtype=fp8, device="cuda") + + +def empty_fp32(*args, **kwargs): + return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda") + + +class ActivationQuantFusionPass(VllmInductorPass): + """ + This pass fuses a pre-defined set of custom ops into fused ops. + It uses the torch pattern matcher to find the patterns and replace them. + + Because patterns can only be registered once, the pass is a singleton. + This will be addressed in a future version of PyTorch: + https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980 + """ + + def __init__(self, config: VllmConfig): + super().__init__(config) + + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="activation_quant_fusion_pass") + + inputs = [ + empty_fp8(5, 4), # Quant output + empty_bf16(5, 4), # Silu_and_mul output + empty_bf16(5, 4), # Input + empty_fp32(1, 1) # Scale + ] + register_replacement(silu_mul_pattern_static, + silu_mul_replacement_static, inputs, fwd_only, + self.patterns) + + def __call__(self, graph: torch.fx.Graph): + self.begin() + self.dump_graph(graph, "before_act_quant_fusion") + + count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns in ActivationQuantFusionPass", + count) + + self.dump_graph(graph, "after_act_quant_fusion") + self.end_and_log() diff --git a/compilation/backends.py b/compilation/backends.py new file mode 100644 index 0000000..5af3b7e --- /dev/null +++ b/compilation/backends.py @@ -0,0 +1,563 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import ast +import dataclasses +import os +import pprint +import time +from collections.abc import Sequence +from typing import Any, Callable, Optional + +import torch +import torch.fx as fx +from torch._dispatch.python import enable_python_dispatcher + +import vllm.envs as envs +from vllm.config import CompilationConfig, VllmConfig +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname + +from .compiler_interface import (CompilerInterface, EagerAdaptor, + InductorAdaptor, InductorStandaloneAdaptor) +from .counter import compilation_counter +from .inductor_pass import InductorPass +from .pass_manager import PostGradPassManager + +logger = init_logger(__name__) + + +def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface: + if 
compilation_config.use_inductor: + if envs.VLLM_USE_STANDALONE_COMPILE and is_torch_equal_or_newer( + "2.8.0"): + logger.debug("Using InductorStandaloneAdaptor") + return InductorStandaloneAdaptor() + else: + logger.debug("Using InductorAdaptor") + return InductorAdaptor() + else: + logger.debug("Using EagerAdaptor") + return EagerAdaptor() + + +class CompilerManager: + """ + A manager to manage the compilation process, including + caching the compiled graph, loading the compiled graph, + and compiling the graph. + + The cache is a dict mapping + `(runtime_shape, graph_index, backend_name)` + to `any_data` returned from the compiler. + + When serializing the cache, we save it to a Python file + for readability. We don't use json here because json doesn't + support int as key. + """ + + def __init__(self, compilation_config: CompilationConfig): + self.cache: dict[tuple[Optional[int], int, str], Any] = dict() + self.is_cache_updated = False + self.compilation_config = compilation_config + self.compiler = make_compiler(compilation_config) + + def compute_hash(self, vllm_config: VllmConfig) -> str: + return self.compiler.compute_hash(vllm_config) + + def initialize_cache(self, cache_dir: str, disable_cache: bool = False): + self.disable_cache = disable_cache + self.cache_dir = cache_dir + self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py") + + if not disable_cache and os.path.exists(self.cache_file_path): + # load the cache from the file + with open(self.cache_file_path) as f: + # we use ast.literal_eval to parse the data + # because it is a safe way to parse Python literals. + # do not use eval(), it is unsafe. + self.cache = ast.literal_eval(f.read()) + + self.compiler.initialize_cache(cache_dir=cache_dir, + disable_cache=disable_cache) + + def save_to_file(self): + if self.disable_cache or not self.is_cache_updated: + return + printer = pprint.PrettyPrinter(indent=4) + data = printer.pformat(self.cache) + with open(self.cache_file_path, "w") as f: + f.write(data) + + def load(self, + graph: fx.GraphModule, + example_inputs: list[Any], + graph_index: int, + runtime_shape: Optional[int] = None) -> Optional[Callable]: + if (runtime_shape, graph_index, self.compiler.name) not in self.cache: + return None + handle = self.cache[(runtime_shape, graph_index, self.compiler.name)] + compiled_graph = self.compiler.load(handle, graph, example_inputs, + graph_index, runtime_shape) + logger.debug( + "Directly load the %s-th graph for shape %s from %s via " + "handle %s", graph_index, str(runtime_shape), self.compiler.name, + handle) + return compiled_graph + + def compile(self, + graph: fx.GraphModule, + example_inputs, + additional_inductor_config, + compilation_config: CompilationConfig, + graph_index: int = 0, + num_graphs: int = 1, + runtime_shape: Optional[int] = None) -> Any: + if graph_index == 0: + # before compiling the first graph, record the start time + global compilation_start_time + compilation_start_time = time.time() + + compilation_counter.num_backend_compilations += 1 + + compiled_graph = None + + # try to load from the cache + compiled_graph = self.load(graph, example_inputs, graph_index, + runtime_shape) + if compiled_graph is not None: + if graph_index == num_graphs - 1: + # after loading the last graph for this shape, record the time. + # there can be multiple graphs due to piecewise compilation. 
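Aside: a small, self-contained sketch of the cache round-trip that CompilerManager uses above. The cache is a plain dict keyed by `(runtime_shape, graph_index, backend_name)`; it is dumped with pprint and parsed back with `ast.literal_eval` rather than JSON, because JSON keys cannot be tuples or ints. The keys and handle values below are made up.

```python
import ast
import pprint

# Illustrative cache: (runtime_shape, graph_index, backend_name) -> handle
cache = {
    (None, 0, "inductor"): ("deadbeef01", "/tmp/cache/artifact_general_0"),
    (8192, 1, "inductor"): ("cafebabe02", "/tmp/cache/artifact_8192_1"),
}

text = pprint.PrettyPrinter(indent=4).pformat(cache)  # human-readable dump
restored = ast.literal_eval(text)                     # safe parse, no eval()
assert restored == cache
```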
+ now = time.time() + elapsed = now - compilation_start_time + logger.info( + "Directly load the compiled graph(s) for shape %s " + "from the cache, took %.3f s", str(runtime_shape), elapsed) + return compiled_graph + + # no compiler cached the graph, or the cache is disabled, + # we need to compile it + if isinstance(self.compiler, InductorAdaptor): + # Let compile_fx generate a key for us + maybe_key = None + else: + maybe_key = \ + f"artifact_shape_{runtime_shape}_subgraph_{graph_index}" + compiled_graph, handle = self.compiler.compile( + graph, example_inputs, additional_inductor_config, runtime_shape, + maybe_key) + + assert compiled_graph is not None, "Failed to compile the graph" + + # store the artifact in the cache + if handle is not None: + self.cache[(runtime_shape, graph_index, + self.compiler.name)] = handle + self.is_cache_updated = True + if graph_index == 0: + # adds some info logging for the first graph + logger.info("Cache the graph of shape %s for later use", + str(runtime_shape)) + logger.debug( + "store the %s-th graph for shape %s from %s via handle %s", + graph_index, str(runtime_shape), self.compiler.name, handle) + + # after compiling the last graph, record the end time + if graph_index == num_graphs - 1: + now = time.time() + elapsed = now - compilation_start_time + compilation_config.compilation_time += elapsed + if runtime_shape is None: + logger.info("Compiling a graph for general shape takes %.2f s", + elapsed) + else: + logger.info("Compiling a graph for shape %s takes %.2f s", + runtime_shape, elapsed) + + return compiled_graph + + +@dataclasses.dataclass +class SplitItem: + submod_name: str + graph_id: int + is_splitting_graph: bool + graph: fx.GraphModule + + +def split_graph(graph: fx.GraphModule, + ops: list[str]) -> tuple[fx.GraphModule, list[SplitItem]]: + # split graph by ops + subgraph_id = 0 + node_to_subgraph_id = {} + split_op_graphs = [] + for node in graph.graph.nodes: + if node.op in ("output", "placeholder"): + continue + if node.op == 'call_function' and str(node.target) in ops: + subgraph_id += 1 + node_to_subgraph_id[node] = subgraph_id + split_op_graphs.append(subgraph_id) + subgraph_id += 1 + else: + node_to_subgraph_id[node] = subgraph_id + + # `keep_original_order` is important! + # otherwise pytorch might reorder the nodes and + # the semantics of the graph will change when we + # have mutations in the graph + split_gm = torch.fx.passes.split_module.split_module( + graph, + None, + lambda node: node_to_subgraph_id[node], + keep_original_order=True) + + outputs = [] + + names = [name for (name, module) in split_gm.named_modules()] + + for name in names: + if "." in name or name == "": + # recursive child module or the root module + continue + + module = getattr(split_gm, name) + + graph_id = int(name.replace("submod_", "")) + outputs.append( + SplitItem(name, graph_id, (graph_id in split_op_graphs), module)) + + # sort by intetger graph_id, rather than string name + outputs.sort(key=lambda x: x.graph_id) + + return split_gm, outputs + + +# we share the global graph pool among all the backends +global_graph_pool = None + +compilation_start_time = 0.0 + + +class PiecewiseCompileInterpreter(torch.fx.Interpreter): + """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`. + It runs the given graph with fake inputs, and compile some + submodules specified by `compile_submod_names` with the given + compilation configs. 
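Aside: a toy illustration of the idea behind split_graph above, assuming torch.fx is available. Every node gets a partition id, the id is bumped around each "splitting op", and `torch.fx.passes.split_module` cuts the graph while keeping the original node order. The model and the choice of `torch.relu` as the splitting op are hypothetical.

```python
import torch
import torch.fx as fx
from torch.fx.passes.split_module import split_module

class Toy(torch.nn.Module):
    def forward(self, x):
        a = x + 1
        b = torch.relu(a)  # pretend relu is the "splitting op"
        return b * 2

gm = fx.symbolic_trace(Toy())
node_to_part, part = {}, 0
for node in gm.graph.nodes:
    if node.op in ("placeholder", "output"):
        continue
    if node.op == "call_function" and node.target is torch.relu:
        part += 1
        node_to_part[node] = part  # the splitting op gets its own partition
        part += 1
    else:
        node_to_part[node] = part

split = split_module(gm, None, lambda n: node_to_part[n],
                     keep_original_order=True)
print([name for name, _ in split.named_children()])
# e.g. ['submod_0', 'submod_1', 'submod_2']
```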
+ + NOTE: the order in `compile_submod_names` matters, because + it will be used to determine the order of the compiled piecewise + graphs. The first graph will handle logging, and the last graph + has some special cudagraph output handling. + """ + + def __init__(self, module: torch.fx.GraphModule, + compile_submod_names: list[str], vllm_config: VllmConfig, + graph_pool, vllm_backend: "VllmBackend"): + super().__init__(module) + from torch._guards import detect_fake_mode + self.fake_mode = detect_fake_mode() + self.compile_submod_names = compile_submod_names + self.compilation_config = vllm_config.compilation_config + self.graph_pool = graph_pool + self.vllm_config = vllm_config + self.vllm_backend = vllm_backend + # When True, it annoyingly dumps the torch.fx.Graph on errors. + self.extra_traceback = False + + def run(self, *args): + fake_args = [ + self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t + for t in args + ] + with self.fake_mode, enable_python_dispatcher(): + return super().run(*fake_args) + + def call_module(self, target: torch.fx.node.Target, + args: tuple[torch.fx.node.Argument, + ...], kwargs: dict[str, Any]) -> Any: + assert isinstance(target, str) + output = super().call_module(target, args, kwargs) + + if target in self.compile_submod_names: + index = self.compile_submod_names.index(target) + submod = self.fetch_attr(target) + sym_shape_indices = [ + i for i, x in enumerate(args) if isinstance(x, torch.SymInt) + ] + global compilation_start_time + compiled_graph_for_general_shape = self.vllm_backend.\ + compiler_manager.compile( + submod, + args, + self.compilation_config.inductor_compile_config, + self.compilation_config, + graph_index=index, + num_graphs=len(self.compile_submod_names), + runtime_shape=None) + + piecewise_backend = resolve_obj_by_qualname( + current_platform.get_piecewise_backend_cls()) + self.module.__dict__[target] = piecewise_backend( + submod, self.vllm_config, self.graph_pool, index, + len(self.compile_submod_names), sym_shape_indices, + compiled_graph_for_general_shape, self.vllm_backend) + + compilation_counter.num_piecewise_capturable_graphs_seen += 1 + + return output + + +class VllmBackend: + """The compilation backend for `torch.compile` with vLLM. + It is used for compilation level of `CompilationLevel.PIECEWISE`, + where we customize the compilation. + + The major work of this backend is to split the graph into + piecewise graphs, and pass them to the piecewise backend. + + This backend also adds the PostGradPassManager to Inductor config, + which handles the post-grad passes. + """ + + vllm_config: VllmConfig + compilation_config: CompilationConfig + graph_pool: Any + _called: bool = False + # the graph we compiled + graph: fx.GraphModule + # the stiching graph module for all the piecewise graphs + split_gm: fx.GraphModule + piecewise_graphs: list[SplitItem] + returned_callable: Callable + # Inductor passes to run on the graph pre-defunctionalization + post_grad_passes: Sequence[Callable] + sym_tensor_indices: list[int] + input_buffers: list[torch.Tensor] + compiler_manager: CompilerManager + + def __init__( + self, + vllm_config: VllmConfig, + ): + global global_graph_pool + if global_graph_pool is None: + global_graph_pool = current_platform.graph_pool_handle() + + # TODO: in the future, if we want to use multiple + # streams, it might not be safe to share a global pool. + # only investigate this when we use multiple streams + self.graph_pool = global_graph_pool + + # Passes to run on the graph post-grad. 
+ self.post_grad_pass_manager = PostGradPassManager() + + self.sym_tensor_indices = [] + self.input_buffers = [] + + self.vllm_config = vllm_config + self.compilation_config = vllm_config.compilation_config + + self.compiler_manager: CompilerManager = CompilerManager( + self.compilation_config) + + # `torch.compile` is JIT compiled, so we don't need to + # do anything here + + def configure_post_pass(self): + config = self.compilation_config + self.post_grad_pass_manager.configure(self.vllm_config) + + # Post-grad custom passes are run using the post_grad_custom_post_pass + # hook. If a pass for that hook exists, add it to the pass manager. + inductor_config = config.inductor_compile_config + PASS_KEY = "post_grad_custom_post_pass" + if PASS_KEY in inductor_config: + # Config should automatically wrap all inductor passes + if isinstance(inductor_config[PASS_KEY], PostGradPassManager): + assert (inductor_config[PASS_KEY].uuid() == + self.post_grad_pass_manager.uuid()) + else: + assert isinstance(inductor_config[PASS_KEY], InductorPass) + self.post_grad_pass_manager.add(inductor_config[PASS_KEY]) + inductor_config[PASS_KEY] = self.post_grad_pass_manager + + def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: + + vllm_config = self.vllm_config + if not self.compilation_config.cache_dir: + # no provided cache dir, generate one based on the known factors + # that affects the compilation. if none of the factors change, + # the cache dir will be the same so that we can reuse the compiled + # graph. + + factors = [] + # 0. factors come from the env, for example, The values of + # VLLM_PP_LAYER_PARTITION will affects the computation graph. + env_hash = envs.compute_hash() + factors.append(env_hash) + + # 1. factors come from the vllm_config (it mainly summarizes how the + # model is created) + config_hash = vllm_config.compute_hash() + factors.append(config_hash) + + # 2. factors come from the code files that are traced by Dynamo ( + # it mainly summarizes how the model is used in forward pass) + forward_code_files = list( + sorted(self.compilation_config.traced_files)) + self.compilation_config.traced_files.clear() + logger.debug( + "Traced files (to be considered for compilation cache):\n%s", + "\n".join(forward_code_files)) + hash_content = [] + for filepath in forward_code_files: + hash_content.append(filepath) + if filepath == "": + # This means the function was dynamically generated, with + # e.g. exec(). We can't actually check these. + continue + with open(filepath) as f: + hash_content.append(f.read()) + import hashlib + code_hash = hashlib.md5("\n".join(hash_content).encode(), + usedforsecurity=False).hexdigest() + factors.append(code_hash) + + # 3. 
compiler hash + compiler_hash = self.compiler_manager.compute_hash(vllm_config) + factors.append(compiler_hash) + + # combine all factors to generate the cache dir + hash_key = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest()[:10] + + cache_dir = os.path.join( + envs.VLLM_CACHE_ROOT, + "torch_compile_cache", + hash_key, + ) + self.compilation_config.cache_dir = cache_dir + + if compilation_counter.num_graphs_seen > 0: + cache_dir = self.compilation_config.cache_dir + \ + f'-{compilation_counter.num_graphs_seen}' + else: + cache_dir = self.compilation_config.cache_dir + os.makedirs(cache_dir, exist_ok=True) + self.compilation_config.cache_dir = cache_dir + rank = vllm_config.parallel_config.rank + dp_rank = vllm_config.parallel_config.data_parallel_rank + local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}") + os.makedirs(local_cache_dir, exist_ok=True) + self.compilation_config.local_cache_dir = local_cache_dir + + disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE + + if disable_cache: + logger.info("vLLM's torch.compile cache is disabled.") + else: + logger.info("Using cache directory: %s for vLLM's torch.compile", + local_cache_dir) + + self.compiler_manager.initialize_cache(local_cache_dir, disable_cache) + + # when dynamo calls the backend, it means the bytecode + # transform and analysis are done + compilation_counter.num_graphs_seen += 1 + from .monitor import torch_compile_start_time + dynamo_time = time.time() - torch_compile_start_time + logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time) + self.compilation_config.compilation_time += dynamo_time + + # we control the compilation process, each instance can only be + # called once + assert not self._called, "VllmBackend can only be called once" + + self.graph = graph + self.configure_post_pass() + + self.split_gm, self.piecewise_graphs = split_graph( + graph, self.compilation_config.splitting_ops) + + from torch._dynamo.utils import lazy_format_graph_code + + # depyf will hook lazy_format_graph_code and dump the graph + # for debugging, no need to print the graph here + lazy_format_graph_code("before split", self.graph) + lazy_format_graph_code("after split", self.split_gm) + + compilation_counter.num_piecewise_graphs_seen += len( + self.piecewise_graphs) + submod_names_to_compile = [ + item.submod_name for item in self.piecewise_graphs + if not item.is_splitting_graph + ] + + # propagate the split graph to the piecewise backend, + # compile submodules with symbolic shapes + PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile, + self.vllm_config, self.graph_pool, + self).run(*example_inputs) + + graph_path = os.path.join(local_cache_dir, "computation_graph.py") + if not os.path.exists(graph_path): + # code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa + # use `print_readable` because it can include submodules + src = "from __future__ import annotations\nimport torch\n" + \ + self.split_gm.print_readable(print_output=False) + src = src.replace("", "GraphModule") + with open(graph_path, "w") as f: + f.write(src) + + logger.debug("Computation graph saved to %s", graph_path) + + self._called = True + + if not self.compilation_config.use_cudagraph or \ + not self.compilation_config.cudagraph_copy_inputs: + return self.split_gm + + # if we need to copy input buffers for cudagraph + from torch._guards import detect_fake_mode + fake_mode = detect_fake_mode() + fake_args = [ + 
fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t + for t in example_inputs + ] + + # index of tensors that have symbolic shapes (batch size) + # for weights and static buffers, they will have concrete shapes. + # symbolic shape only happens for input tensors. + from torch.fx.experimental.symbolic_shapes import is_symbolic + self.sym_tensor_indices = [ + i for i, x in enumerate(fake_args) + if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) and \ + any(is_symbolic(d) for d in x.size()) + ] + + # compiler managed cudagraph input buffers + # we assume the first run with symbolic shapes + # has the maximum size among all the tensors + self.input_buffers = [ + example_inputs[x].clone() for x in self.sym_tensor_indices + ] + + # this is the callable we return to Dynamo to run + def copy_and_call(*args): + list_args = list(args) + for i, index in enumerate(self.sym_tensor_indices): + runtime_tensor = list_args[index] + runtime_shape = runtime_tensor.shape[0] + static_tensor = self.input_buffers[i][:runtime_shape] + + # copy the tensor to the static buffer + static_tensor.copy_(runtime_tensor) + + # replace the tensor in the list_args to the static buffer + list_args[index] = static_tensor + return self.split_gm(*list_args) + + return copy_and_call diff --git a/compilation/base_piecewise_backend.py b/compilation/base_piecewise_backend.py new file mode 100644 index 0000000..4d7aeeb --- /dev/null +++ b/compilation/base_piecewise_backend.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Callable, Protocol + +import torch.fx as fx + +from vllm.compilation.backends import VllmBackend +from vllm.config import VllmConfig + + +class AbstractPiecewiseBackend(Protocol): + """ + PiecewiseBackend interface that allows platforms to extend + piecewise static graph. + """ + + def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, + graph_pool: Any, piecewise_compile_index: int, + total_piecewise_compiles: int, sym_shape_indices: list[int], + compiled_graph_for_general_shape: Callable, + vllm_backend: VllmBackend, **kwargs): + """ + Initializes the PiecewiseBackend class with compilation and + execution-related configurations. + + This class handles piecewise compilation, graph capturing, + and dispatching for specific input shapes. + + Args: + graph (fx.GraphModule): The graph represented in fx. + vllm_config (VllmConfig): Global configuration for vLLM. + graph_pool (Any): + Graph memory pool handle, e.g., + `torch.cuda.graph_pool_handle()`. + piecewise_compile_index (int): + Index of the current piecewise subgraph. + total_piecewise_compiles (int): + Total number of piecewise-compiled graphs. + sym_shape_indices (list[int]): + Indices of symbolic shape. + compiled_graph_for_general_shape (Callable): + Callable that executes the graph compiled for general shapes. + vllm_backend (VllmBackend): + Backend compiler that manages compilation and graph runtime + for vLLM. + + Keyword Args: + kwargs: Additional keyword arguments reserved for future + extensions or custom platforms. + """ + raise NotImplementedError + + def __call__(self, *args) -> Any: + """Executes the compiled graph for given input args. + + If this is the first invocation, executes the general compiled graph + and initiates the compilation process tracking. For subsequent calls, + dynamically dispatches execution to either a compiled graph or a static + graph based on the input shape. 
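Aside: the copy_and_call closure above relies on a stable-address input buffer trick for cudagraphs. A rough sketch under assumed shapes: allocate one buffer per symbolic-shaped input at its maximum size, and on every call copy the runtime tensor into a slice of that buffer, so the captured graph always sees the same storage addresses.

```python
import torch

max_tokens, hidden = 8192, 4096          # assumed maximum batch and hidden size
static_buf = torch.empty(max_tokens, hidden)

def run_with_static_buffer(graph_fn, runtime_input: torch.Tensor):
    n = runtime_input.shape[0]
    view = static_buf[:n]      # same storage as the buffer seen at capture time
    view.copy_(runtime_input)  # copy data, keep addresses stable
    return graph_fn(view)

out = run_with_static_buffer(lambda t: t * 2, torch.randn(16, hidden))
print(out.shape)  # torch.Size([16, 4096])
```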
+ + Args: + *args: Variable length input arguments to be passed into the + graph. The symbolic shape is expected to be in position + `sym_shape_indices[0]`. + + Returns: + Any: Output of the executed graph. This can be from the general + compiled graph, a specialized compiled version for the given shape, + or a replayed static graph. + """ + raise NotImplementedError diff --git a/compilation/collective_fusion.py b/compilation/collective_fusion.py new file mode 100644 index 0000000..f754fc2 --- /dev/null +++ b/compilation/collective_fusion.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch +import torch._inductor.pattern_matcher as pm +import torch.fx as fx +from torch._inductor.pattern_matcher import PatternMatcherPass +from torch.distributed._symmetric_memory import enable_symm_mem_for_group + +from vllm.config import VllmConfig +from vllm.distributed import get_tp_group +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_world_size) +from vllm.logger import init_logger + +from .vllm_inductor_pass import VllmInductorPass + +logger = init_logger(__name__) + + +class BasePattern: + + def __init__(self, dtype: torch.dtype, device: str): + self.dtype = dtype + self.device = device + self.tp = get_tp_group() + self.tp_size = get_tensor_model_parallel_world_size() + + +class GEMMReduceScatterPattern(BasePattern): + + def get_inputs(self): + mul = torch.empty([16, 4], device=self.device, dtype=self.dtype) + mm_weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) + return [mul, mm_weight] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern(mul: torch.Tensor, mm_weight: torch.Tensor): + mm = torch.ops.aten.mm.default(mul, mm_weight) + reduce_scatter = torch.ops.vllm.reduce_scatter.default( + mm, + dim=0, + world_size=self.tp_size, + group_name=self.tp.unique_name) + return reduce_scatter + + def replacement(mul: torch.Tensor, mm_weight: torch.Tensor): + gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter( + mul, + mm_weight, + "avg", + scatter_dim=0, + group_name=self.tp.device_group.group_name, + ) + + return gemm_rs + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class AllGatherGEMMPattern(BasePattern): + + def get_inputs(self): + x = torch.empty([4, 4], device=self.device, dtype=self.dtype) + weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + return [x, weight] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + x: torch.Tensor, + weight: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + all_gather = torch.ops.vllm.all_gather.default( + x, + dim=0, + world_size=self.tp_size, + group_name=self.tp.unique_name) + + return torch.ops.aten.mm.default(all_gather, weight) + + def replacement( + x: torch.Tensor, + weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul( + x, + [weight], + gather_dim=0, + group_name=self.tp.device_group.group_name, + ) + return mm_outputs + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class AsyncTPPass(VllmInductorPass): + + def __init__(self, config: VllmConfig): + super().__init__(config) + + # Enable symmetric memory for the TP process group + enable_symm_mem_for_group(get_tp_group().device_group.group_name) + self.patterns: PatternMatcherPass = PatternMatcherPass( + 
pass_name="async_tp_pass") + GEMMReduceScatterPattern(self.model_dtype, + self.device).register(self.patterns) + + AllGatherGEMMPattern(self.model_dtype, + self.device).register(self.patterns) + + def is_applicable_for_shape(self, shape: Optional[int]) -> bool: + # only do replace for specific shapes + tp_size = get_tensor_model_parallel_world_size() + return shape is not None and shape % tp_size == 0 + + def __call__(self, graph: fx.Graph): + self.begin() + self.dump_graph(graph, "before_async_tp_pass") + count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", count) + self.dump_graph(graph, "after_async_tp_pass") + self.end_and_log() diff --git a/compilation/compiler_interface.py b/compilation/compiler_interface.py new file mode 100644 index 0000000..36c810e --- /dev/null +++ b/compilation/compiler_interface.py @@ -0,0 +1,544 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib +import copy +import hashlib +import os +from contextlib import ExitStack +from typing import Any, Callable, Optional +from unittest.mock import patch + +import torch +import torch._inductor.compile_fx +import torch.fx as fx + +import vllm.envs as envs +from vllm.compilation.counter import compilation_counter +from vllm.config import VllmConfig +from vllm.utils import is_torch_equal_or_newer + +from .inductor_pass import pass_context + + +class CompilerInterface: + """ + The interface for a compiler that can be used by vLLM. + """ + # The name of the compiler, e.g. inductor. + # This is a class-level attribute. + name: str + + def initialize_cache(self, cache_dir: str, disable_cache: bool = False): + """ + when the vLLM process uses `cache_dir` as the cache directory, + the compiler should initialize itself with the cache directory, + e.g. by re-directing its own cache directory to a sub-directory. + """ + pass + + def compute_hash(self, vllm_config: VllmConfig) -> str: + """ + Gather all the relevant information from the vLLM config, + to compute a hash so that we can cache the compiled model. + + See [`VllmConfig.compute_hash`][vllm.config.VllmConfig.compute_hash] + to check what information + is already considered by default. This function should only + consider the information that is specific to the compiler. + """ + return "" + + def compile( + self, + graph: fx.GraphModule, + example_inputs: list[Any], + compiler_config: dict[str, Any], + runtime_shape: Optional[int] = None, + key: Optional[str] = None, + ) -> tuple[Optional[Callable], Optional[Any]]: + """ + Compile the graph with the given example inputs and compiler config, + with a runtime shape. If the `runtime_shape` is None, it means + the `example_inputs` have a dynamic shape. Otherwise, the + `runtime_shape` specifies the shape of the inputs. Right now we only + support one variable shape for all inputs, which is the batchsize + (number of tokens) during inference. + + Dynamo will make sure `graph(*example_inputs)` is valid. + + The function should return a compiled callable function, as well as + a handle that can be used to directly load the compiled function. + + The handle should be a plain Python object, preferably a string or a + file path for readability. + + If the compiler doesn't support caching, it should return None for the + handle. If the compiler fails to compile the graph, it should return + None for the compiled function as well. + + `key` is required for StandaloneInductorAdapter, it specifies where to + save the compiled artifact. 
The compiled artifact gets saved to + `cache_dir/key`. + """ + return None, None + + def load(self, + handle: Any, + graph: fx.GraphModule, + example_inputs: list[Any], + graph_index: int, + runtime_shape: Optional[int] = None) -> Callable: + """ + Load the compiled function from the handle. + Raises an error if the handle is invalid. + + The handle is the second return value of the `compile` function. + """ + raise NotImplementedError("caching is not supported") + + +class AlwaysHitShapeEnv: + """ + Why do we need this class: + + For normal `torch.compile` usage, every compilation will have + one Dynamo bytecode compilation and one Inductor compilation. + The Inductor compilation happens under the context of the + Dynamo bytecode compilation, and that context is used to + determine the dynamic shape information, etc. + + For our use case, we only run Dynamo bytecode compilation once, + and run Inductor compilation multiple times with different shapes + plus a general shape. The compilation for specific shapes happens + outside of the context of the Dynamo bytecode compilation. At that + time, we don't have shape environment to provide to Inductor, and + it will fail the Inductor code cache lookup. + + By providing a dummy shape environment that always hits, we can + make the Inductor code cache lookup always hit, and we can + compile the graph for different shapes as needed. + + The following dummy methods are obtained by trial-and-error + until it works. + """ + + def __init__(self) -> None: + self.guards: list[Any] = [] + + def evaluate_guards_expression(self, *args, **kwargs): + return True + + def get_pruned_guards(self, *args, **kwargs): + return [] + + def produce_guards_expression(self, *args, **kwargs): + return "" + + +def get_inductor_factors() -> list[Any]: + factors: list[Any] = [] + # summarize system state + from torch._inductor.codecache import CacheBase + system_factors = CacheBase.get_system() + factors.append(system_factors) + + # summarize pytorch state + from torch._inductor.codecache import torch_key + torch_factors = torch_key() + factors.append(torch_factors) + return factors + + +class InductorStandaloneAdaptor(CompilerInterface): + """ + The adaptor for the Inductor compiler. + Requires PyTorch 2.8+. + This is not on by default yet, but we plan to turn it on by default for + PyTorch 2.8. + + Use VLLM_USE_STANDALONE_COMPILE to toggle this on or off. 
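Aside: both Inductor adaptors below derive their cache key the same way: gather opaque "factors" (system state, torch state, vLLM config), stringify them, md5 them, and keep a short hex prefix. A minimal sketch with made-up factor values:

```python
import hashlib

def short_hash(factors: list) -> str:
    # Same recipe as the adaptors' compute_hash: md5 of str(factors), truncated.
    return hashlib.md5(str(factors).encode(),
                       usedforsecurity=False).hexdigest()[:10]

print(short_hash(["torch-2.7.0", ("sm_90", "12.4"), {"max_autotune": True}]))
```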
+ """ + name = "inductor_standalone" + + def compute_hash(self, vllm_config: VllmConfig) -> str: + factors = get_inductor_factors() + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest()[:10] + return hash_str + + def initialize_cache(self, cache_dir: str, disable_cache: bool = False): + self.cache_dir = cache_dir + + def compile( + self, + graph: fx.GraphModule, + example_inputs: list[Any], + compiler_config: dict[str, Any], + runtime_shape: Optional[int] = None, + key: Optional[str] = None, + ) -> tuple[Optional[Callable], Optional[Any]]: + compilation_counter.num_inductor_compiles += 1 + current_config = {} + if compiler_config is not None: + current_config.update(compiler_config) + set_inductor_config(current_config, runtime_shape) + + if isinstance(runtime_shape, int): + dynamic_shapes = "from_example_inputs" + else: + dynamic_shapes = "from_tracing_context" + + from torch._inductor import standalone_compile + with pass_context(runtime_shape): + compiled_graph = standalone_compile( + graph, + example_inputs, + dynamic_shapes=dynamic_shapes, + options={"config_patches": current_config}) + + # Save the compiled artifact to disk in the specified path + assert key is not None + path = os.path.join(self.cache_dir, key) + compiled_graph.save(path=path, format="unpacked") + return compiled_graph, (key, path) + + def load(self, + handle: Any, + graph: fx.GraphModule, + example_inputs: list[Any], + graph_index: int, + runtime_shape: Optional[int] = None) -> Callable: + assert isinstance(handle, tuple) + assert isinstance(handle[0], str) + assert isinstance(handle[1], str) + path = handle[1] + inductor_compiled_graph = torch._inductor.CompiledArtifact.load( + path=path, format="unpacked") + from torch._inductor.compile_fx import graph_returns_tuple + returns_tuple = graph_returns_tuple(graph) + + def compiled_graph_wrapper(*args): + graph_output = inductor_compiled_graph(*args) + # unpack the tuple if needed + # TODO(rzou): the implication is that we're not + # reading the python bytecode correctly in vLLM? + if returns_tuple: + return graph_output + else: + return graph_output[0] + + return compiled_graph_wrapper + + +class InductorAdaptor(CompilerInterface): + """ + The adaptor for the Inductor compiler, version 2.5, 2.6, 2.7. + """ + name = "inductor" + + def compute_hash(self, vllm_config: VllmConfig) -> str: + factors = get_inductor_factors() + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest()[:10] + return hash_str + + def initialize_cache(self, cache_dir: str, disable_cache: bool = False): + self.cache_dir = cache_dir + if disable_cache: + return + # redirect the cache directory to a sub-directory + # set flags so that Inductor and Triton store their cache + # in the cache_dir, then users only need to copy the cache_dir + # to another machine to reuse the cache. 
+ inductor_cache = os.path.join(cache_dir, "inductor_cache") + os.makedirs(inductor_cache, exist_ok=True) + os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache + triton_cache = os.path.join(cache_dir, "triton_cache") + os.makedirs(triton_cache, exist_ok=True) + os.environ["TRITON_CACHE_DIR"] = triton_cache + + def compile( + self, + graph: fx.GraphModule, + example_inputs: list[Any], + compiler_config: dict[str, Any], + runtime_shape: Optional[int] = None, + key: Optional[str] = None, + ) -> tuple[Optional[Callable], Optional[Any]]: + compilation_counter.num_inductor_compiles += 1 + from torch._inductor.compile_fx import compile_fx + current_config = {} + if compiler_config is not None: + current_config.update(compiler_config) + + # disable remote cache + current_config["fx_graph_cache"] = True + current_config["fx_graph_remote_cache"] = False + + set_inductor_config(current_config, runtime_shape) + + # inductor can inplace modify the graph, so we need to copy it + # see https://github.com/pytorch/pytorch/issues/138980 + graph = copy.deepcopy(graph) + + # it's the first time we compile this graph + # the assumption is that we don't have nested Inductor compilation. + # compiled_fx_graph_hash will only be called once, and we can hook + # it to get the hash of the compiled graph directly. + + hash_str, file_path = None, None + from torch._inductor.codecache import (FxGraphCache, + compiled_fx_graph_hash) + if torch.__version__.startswith("2.5"): + original_load = FxGraphCache.load + original_load_name = "torch._inductor.codecache.FxGraphCache.load" + + def hijack_load(*args, **kwargs): + inductor_compiled_graph = original_load(*args, **kwargs) + nonlocal file_path + compiled_fn = inductor_compiled_graph.current_callable + file_path = compiled_fn.__code__.co_filename # noqa + if not file_path.startswith(self.cache_dir): + # hooked in the align_inputs_from_check_idxs function + # in torch/_inductor/utils.py + for cell in compiled_fn.__closure__: + if not callable(cell.cell_contents): + continue + if cell.cell_contents.__code__.co_filename.startswith( + self.cache_dir): + # this is the real file path compiled from Inductor + file_path = cell.cell_contents.__code__.co_filename + break + return inductor_compiled_graph + + hijacked_compile_fx_inner = torch._inductor.compile_fx.compile_fx_inner # noqa + elif torch.__version__ >= "2.6": + # function renamed in 2.6 + original_load_name = None + + def hijacked_compile_fx_inner(*args, **kwargs): + output = torch._inductor.compile_fx.compile_fx_inner( + *args, **kwargs) + nonlocal hash_str + inductor_compiled_graph = output + if inductor_compiled_graph is not None: + nonlocal file_path + compiled_fn = inductor_compiled_graph.current_callable + file_path = compiled_fn.__code__.co_filename # noqa + if not file_path.startswith(self.cache_dir): + # hooked in the align_inputs_from_check_idxs function + # in torch/_inductor/utils.py + for cell in compiled_fn.__closure__: + if not callable(cell.cell_contents): + continue + code = cell.cell_contents.__code__ + if code.co_filename.startswith(self.cache_dir): + # this is the real file path + # compiled from Inductor + file_path = code.co_filename + break + hash_str = inductor_compiled_graph._fx_graph_cache_key + return output + + def hijack_compiled_fx_graph_hash(*args, **kwargs): + out = compiled_fx_graph_hash(*args, **kwargs) + nonlocal hash_str + hash_str = out[0] + return out + + def _check_can_cache(*args, **kwargs): + # no error means it can be cached. 
+ # Inductor refuses to cache the graph outside of Dynamo + # tracing context, and also disables caching for graphs + # with high-order ops. + # For vLLM, in either case, we want to cache the graph. + # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa + return + + def _get_shape_env() -> AlwaysHitShapeEnv: + return AlwaysHitShapeEnv() + + with ExitStack() as stack: + # hijack to get the compiled graph itself + if original_load_name is not None: + stack.enter_context(patch(original_load_name, hijack_load)) + + # for hijacking the hash of the compiled graph + stack.enter_context( + patch("torch._inductor.codecache.compiled_fx_graph_hash", + hijack_compiled_fx_graph_hash)) + + # for providing a dummy shape environment + stack.enter_context( + patch("torch._inductor.codecache.FxGraphCache._get_shape_env", + _get_shape_env)) + + from torch._functorch._aot_autograd.autograd_cache import ( + AOTAutogradCache) + + # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache + if hasattr(AOTAutogradCache, "_get_shape_env"): + stack.enter_context( + patch( + "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env", + _get_shape_env)) + + # for forcing the graph to be cached + stack.enter_context( + patch( + "torch._inductor.codecache.FxGraphCache._check_can_cache", + _check_can_cache)) + + # Dynamo metrics context, see method for more details. + stack.enter_context(self.metrics_context()) + + # Disable remote caching. When these are on, on remote cache-hit, + # the monkey-patched functions never actually get called. + # vLLM today assumes and requires the monkey-patched functions to + # get hit. + # TODO(zou3519): we're going to replace this all with + # standalone_compile sometime. + if is_torch_equal_or_newer("2.6"): + stack.enter_context( + torch._inductor.config.patch(fx_graph_remote_cache=False)) + stack.enter_context( + torch._functorch.config.patch( + enable_remote_autograd_cache=False)) + + with pass_context(runtime_shape): + compiled_graph = compile_fx( + graph, + example_inputs, + inner_compile=hijacked_compile_fx_inner, + config_patches=current_config) + + # We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch + # compilation cache. So turn off the checks if we disable the + # compilation cache. + if not envs.VLLM_DISABLE_COMPILE_CACHE: + if hash_str is None: + raise RuntimeError( + "vLLM failed to compile the model. The most " + "likely reason for this is that a previous compilation " + "failed, leading to a corrupted compilation artifact. " + "We recommend trying to " + "remove ~/.cache/vllm/torch_compile_cache and try again " + "to see the real issue. 
") + assert file_path is not None, ( + "failed to get the file path of the compiled graph") + return compiled_graph, (hash_str, file_path) + + def load(self, + handle: Any, + graph: fx.GraphModule, + example_inputs: list[Any], + graph_index: int, + runtime_shape: Optional[int] = None) -> Callable: + assert isinstance(handle, tuple) + assert isinstance(handle[0], str) + assert isinstance(handle[1], str) + hash_str = handle[0] + + from torch._functorch._aot_autograd.autograd_cache import ( + AOTAutogradCache) + from torch._inductor.codecache import FxGraphCache + with ExitStack() as exit_stack: + exit_stack.enter_context( + patch("torch._inductor.codecache.FxGraphCache._get_shape_env", + lambda *args, **kwargs: AlwaysHitShapeEnv())) + # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache + if hasattr(AOTAutogradCache, "_get_shape_env"): + exit_stack.enter_context( + patch( + "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env", + lambda *args, **kwargs: AlwaysHitShapeEnv())) + + # Dynamo metrics context, see method for more details. + exit_stack.enter_context(self.metrics_context()) + + if torch.__version__.startswith("2.5"): + inductor_compiled_graph = FxGraphCache._lookup_graph( + hash_str, example_inputs, True, False) + assert inductor_compiled_graph is not None, ( + "Inductor cache lookup failed. Please remove" + f"the cache directory and try again." # noqa + ) + elif torch.__version__ >= "2.6": + from torch._inductor.output_code import ( + CompiledFxGraphConstantsWithGm) + constants = CompiledFxGraphConstantsWithGm(graph) + inductor_compiled_graph, _ = FxGraphCache._lookup_graph( + hash_str, example_inputs, True, None, constants) + assert inductor_compiled_graph is not None, ( + "Inductor cache lookup failed. Please remove" + f"the cache directory and try again." # noqa + ) + + # Inductor calling convention (function signature): + # f(list) -> tuple + # Dynamo calling convention (function signature): + # f(*args) -> Any + + # need to know if the graph returns a tuple + from torch._inductor.compile_fx import graph_returns_tuple + returns_tuple = graph_returns_tuple(graph) + + # this is the callable we return to Dynamo to run + def compiled_graph(*args): + # convert args to list + list_args = list(args) + graph_output = inductor_compiled_graph(list_args) + # unpack the tuple if needed + if returns_tuple: + return graph_output + else: + return graph_output[0] + + return compiled_graph + + def metrics_context(self) -> contextlib.AbstractContextManager: + """ + This method returns the Dynamo metrics context (if it exists, + otherwise a null context). It is used by various compile components. + Present in torch>=2.6, it's used inside FxGraphCache in + torch==2.6 (but not after). It might also be used in various other + torch.compile internal functions. + + Because it is re-entrant, we always set it (even if entering via Dynamo + and the context was already entered). We might want to revisit if it + should be set at a different level of compilation. + + This is likely a bug in PyTorch: public APIs should not rely on + manually setting up internal contexts. But we also rely on non-public + APIs which might not provide these guarantees. 
+ """ + if is_torch_equal_or_newer("2.6"): + import torch._dynamo.utils + return torch._dynamo.utils.get_metrics_context() + else: + return contextlib.nullcontext() + + +def set_inductor_config(config, runtime_shape): + if isinstance(runtime_shape, int): + # for a specific batchsize, tuning triton kernel parameters + # can be beneficial + config["max_autotune"] = True + config["coordinate_descent_tuning"] = True + + +class EagerAdaptor(CompilerInterface): + name = "eager" + + def compile( + self, + graph: fx.GraphModule, + example_inputs: list[Any], + compiler_config: dict[str, Any], + runtime_shape: Optional[int] = None, + key: Optional[str] = None, + ) -> tuple[Optional[Callable], Optional[Any]]: + compilation_counter.num_eager_compiles += 1 + # we don't need to compile the graph, just return the graph itself. + # It does not support caching, return None for the handle. + return graph, None diff --git a/compilation/counter.py b/compilation/counter.py new file mode 100644 index 0000000..165347c --- /dev/null +++ b/compilation/counter.py @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import copy +import dataclasses +from contextlib import contextmanager + + +@dataclasses.dataclass +class CompilationCounter: + num_models_seen: int = 0 + num_graphs_seen: int = 0 + # including the splitting ops + num_piecewise_graphs_seen: int = 0 + # not including the splitting ops + num_piecewise_capturable_graphs_seen: int = 0 + num_backend_compilations: int = 0 + num_cudagraph_captured: int = 0 + # InductorAdapter.compile calls + num_inductor_compiles: int = 0 + # EagerAdapter.compile calls + num_eager_compiles: int = 0 + + def clone(self) -> "CompilationCounter": + return copy.deepcopy(self) + + @contextmanager + def expect(self, **kwargs): + old = self.clone() + yield + for k, v in kwargs.items(): + assert getattr(self, k) - getattr(old, k) == v, ( + f"{k} not as expected, before it is {getattr(old, k)}" + f", after it is {getattr(self, k)}, " + f"expected diff is {v}") + + +compilation_counter = CompilationCounter() diff --git a/compilation/cuda_piecewise_backend.py b/compilation/cuda_piecewise_backend.py new file mode 100644 index 0000000..8c49ea6 --- /dev/null +++ b/compilation/cuda_piecewise_backend.py @@ -0,0 +1,218 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import dataclasses +from contextlib import ExitStack +from typing import Any, Callable, Optional +from unittest.mock import patch + +import torch +import torch.fx as fx + +import vllm.envs as envs +from vllm.compilation.backends import VllmBackend +from vllm.compilation.counter import compilation_counter +from vllm.compilation.monitor import end_monitoring_torch_compile +from vllm.config import VllmConfig +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger +from vllm.utils import weak_ref_tensors + +logger = init_logger(__name__) + + +@dataclasses.dataclass +class ConcreteSizeEntry: + runtime_shape: int + need_to_compile: bool # the size is in compile_sizes + use_cudagraph: bool # the size is in cudagraph_capture_sizes + + compiled: bool = False + runnable: Callable = None # type: ignore + num_finished_warmup: int = 0 + cudagraph: Optional[torch.cuda.CUDAGraph] = None + output: Optional[Any] = None + + # for cudagraph debugging, track the input addresses + # during capture, and check if they are the same during replay + input_addresses: 
Optional[list[int]] = None + + +class CUDAPiecewiseBackend: + + def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, + graph_pool: Any, piecewise_compile_index: int, + total_piecewise_compiles: int, sym_shape_indices: list[int], + compiled_graph_for_general_shape: Callable, + vllm_backend: VllmBackend): + """ + The backend for piecewise compilation. + It mainly handles the compilation and cudagraph capturing. + + We will compile `self.graph` once for the general shape, + and then compile for different shapes specified in + `compilation_config.compile_sizes`. + + Independently, we will capture cudagraph for different shapes. + + If a shape needs both compilation and cudagraph, we will + compile it first, and then capture cudagraph. + """ + self.graph = graph + self.vllm_config = vllm_config + self.compilation_config = vllm_config.compilation_config + self.graph_pool = graph_pool + self.piecewise_compile_index = piecewise_compile_index + self.total_piecewise_compiles = total_piecewise_compiles + self.vllm_backend = vllm_backend + + self.is_first_graph = piecewise_compile_index == 0 + self.is_last_graph = ( + piecewise_compile_index == total_piecewise_compiles - 1) + + self.compile_sizes: set[int] = set( + self.compilation_config.compile_sizes) + self.cudagraph_capture_sizes: set[int] = set( + self.compilation_config.cudagraph_capture_sizes + ) if self.compilation_config.use_cudagraph else set() + + self.first_run_finished = False + + self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa + + self.sym_shape_indices = sym_shape_indices + + self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" + + # the entries for different shapes that we need to either + # compile or capture cudagraph + self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {} + + # to_be_compiled_sizes tracks the remaining sizes to compile, + # and updates during the compilation process, so we need to copy it + self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy() + for shape in self.compile_sizes.union(self.cudagraph_capture_sizes): + self.concrete_size_entries[shape] = ConcreteSizeEntry( + runtime_shape=shape, + need_to_compile=shape in self.compile_sizes, + use_cudagraph=shape in self.cudagraph_capture_sizes, + ) + + def check_for_ending_compilation(self): + if self.is_last_graph and not self.to_be_compiled_sizes: + # no specific sizes to compile + # save the hash of the inductor graph for the next run + self.vllm_backend.compiler_manager.save_to_file() + end_monitoring_torch_compile(self.vllm_config) + + def __call__(self, *args) -> Any: + if not self.first_run_finished: + self.first_run_finished = True + self.check_for_ending_compilation() + return self.compiled_graph_for_general_shape(*args) + + runtime_shape = args[self.sym_shape_indices[0]] + if runtime_shape not in self.concrete_size_entries: + # we don't need to do anything for this shape + return self.compiled_graph_for_general_shape(*args) + + entry = self.concrete_size_entries[runtime_shape] + + if entry.runnable is None: + entry.runnable = self.compiled_graph_for_general_shape + + if entry.need_to_compile and not entry.compiled: + entry.compiled = True + self.to_be_compiled_sizes.remove(runtime_shape) + # args are real arguments + entry.runnable = self.vllm_backend.compiler_manager.compile( + self.graph, + args, + self.compilation_config.inductor_compile_config, + self.compilation_config, + graph_index=self.piecewise_compile_index, + num_graphs=self.total_piecewise_compiles, + 
runtime_shape=runtime_shape) + + # finished compilations for all required shapes + if self.is_last_graph and not self.to_be_compiled_sizes: + self.check_for_ending_compilation() + + # Skip CUDA graphs if this entry doesn't use them OR + # if we're supposed to skip them globally + skip_cuda_graphs = get_forward_context().skip_cuda_graphs + if not entry.use_cudagraph or skip_cuda_graphs: + return entry.runnable(*args) + + if entry.cudagraph is None: + if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa + entry.num_finished_warmup += 1 + if self.is_first_graph: + logger.debug( + "Warming up %s/%s for shape %s", + entry.num_finished_warmup, + self.compilation_config.cudagraph_num_of_warmups, + runtime_shape) + return entry.runnable(*args) + + if self.is_first_graph: + # Since we capture cudagraph for many different shapes and + # capturing is fast, we don't need to log it for every shape. + # We only log it in the debug mode. + logger.debug("Capturing a cudagraph for shape %s", + runtime_shape) + + input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + entry.input_addresses = input_addresses + cudagraph = torch.cuda.CUDAGraph() + + with ExitStack() as stack: + if not self.is_first_graph: + # during every model forward, we will capture + # many pieces of cudagraphs (roughly one per layer). + # running gc again and again across layers will + # make the cudagraph capture very slow. + # therefore, we only run gc for the first graph, + # and disable gc for the rest of the graphs. + stack.enter_context(patch("gc.collect", lambda: None)) + stack.enter_context( + patch("torch.cuda.empty_cache", lambda: None)) + + # mind-exploding: carefully manage the reference and memory. + with torch.cuda.graph(cudagraph, pool=self.graph_pool): + # `output` is managed by pytorch's cudagraph pool + output = entry.runnable(*args) + if self.is_last_graph: + # by converting it to weak ref, + # the original `output` will immediately be released + # to save memory. It is only safe to do this for + # the last graph, because the output of the last graph + # will not be used by any other cuda graph. + output = weak_ref_tensors(output) + + # here we always use weak ref for the output + # to save memory + entry.output = weak_ref_tensors(output) + entry.cudagraph = cudagraph + + compilation_counter.num_cudagraph_captured += 1 + + # important: we need to return the output, rather than + # the weak ref of the output, so that pytorch can correctly + # manage the memory during cuda graph capture + return output + + if self.is_debugging_mode: + # check if the input addresses are the same + new_input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + assert new_input_addresses == entry.input_addresses, ( + "Input addresses for cudagraphs are different during replay." 
+ f" Expected {entry.input_addresses}, got {new_input_addresses}" + ) + + entry.cudagraph.replay() + return entry.output diff --git a/compilation/decorators.py b/compilation/decorators.py new file mode 100644 index 0000000..05e4ca9 --- /dev/null +++ b/compilation/decorators.py @@ -0,0 +1,250 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import inspect +from typing import Callable, Optional, TypeVar, Union, overload +from unittest.mock import patch + +import torch +import torch.nn as nn +from torch._dynamo.symbolic_convert import InliningInstructionTranslator + +from vllm.compilation.counter import compilation_counter +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher +from vllm.config import CompilationLevel, VllmConfig +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors +from vllm.utils import supports_dynamo + +from .monitor import start_monitoring_torch_compile + +logger = init_logger(__name__) + +_T = TypeVar("_T", bound=type[nn.Module]) + + +@overload +def support_torch_compile( + *, + dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]], +) -> Callable[[_T], _T]: + ... + + +@overload +def support_torch_compile(cls: _T) -> _T: + ... + + +def support_torch_compile( + cls: Optional[_T] = None, + *, + dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]] = None, +) -> Union[Callable[[_T], _T], _T]: + """ + A decorator to add support for compiling the forward method of a class. + + Usage 1: use directly as a decorator without arguments: + + ```python + @support_torch_compile + class MyModel(nn.Module): + def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): + ... + ``` + + Usage 2: use as a decorator with arguments: + + ```python + @support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0}) + class MyModel(nn.Module): + def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): + ... + ``` + + `dynamic_arg_dims` is a dictionary that maps argument names to the dynamic + dimensions of the argument. The dynamic dimensions can be either a single + integer or a list of integers. + + if `dynamic_arg_dims` is `None`, it is inferred from the type annotation + of the `forward` method, based on the following default rules: + + - if the argument is annotated as `torch.Tensor` or + `Optional[torch.Tensor]`, the first dimension will be + marked as dynamic. + - if the argument is annotated as `IntermediateTensors`, the first + dimension of all the tensors in the intermediate tensors + will be marked as dynamic. + + During runtime, when we actually mark dimensions of tensors, + it depends on the value of arguments: + + - if it is a single integer (can be negative), the corresponding dimension + of the argument will be marked as dynamic. + - if it is `None`, ignored. + - if it is `IntermediateTensors`, all the tensors in the intermediate + tensors will be marked as dynamic. + - otherwise, it will raise an error. + + NOTE: if an argument is `None`, it should always be passed as `None` during + the lifetime of the model, otherwise, it cannot be captured as a single + computation graph. 
+ """ + + def cls_decorator_helper(cls: _T) -> _T: + # helper to pass `dynamic_arg_dims`` to `_support_torch_compile`` + # to avoid too much indentation for `_support_torch_compile`` + if not hasattr(cls, 'forward'): + raise TypeError("decorated class should have a forward method.") + sig = inspect.signature(cls.forward) + inferred_dynamic_arg_dims = dynamic_arg_dims + if inferred_dynamic_arg_dims is None: + inferred_dynamic_arg_dims = {} + for k, v in sig.parameters.items(): + if v.annotation in [ + torch.Tensor, Optional[torch.Tensor], + IntermediateTensors, Optional[IntermediateTensors] + ]: + inferred_dynamic_arg_dims[k] = 0 + + logger.debug(("Inferred dynamic dimensions for " + "forward method of %s: %s"), cls, + list(inferred_dynamic_arg_dims.keys())) + + if len(inferred_dynamic_arg_dims) == 0: + raise ValueError( + "No dynamic dimensions found in the forward method of " + f"{cls}. Please provide dynamic_arg_dims explicitly.") + + for k in inferred_dynamic_arg_dims: + if k not in sig.parameters: + raise ValueError( + f"Argument {k} not found in the forward method of {cls}") + return _support_torch_compile(cls, inferred_dynamic_arg_dims) + + if cls is not None: + # use `support_torch_compile` as a decorator without arguments + assert isinstance(cls, type) + return cls_decorator_helper(cls) + + return cls_decorator_helper + + +def _support_torch_compile( + cls: _T, + dynamic_arg_dims: dict[str, Union[int, list[int]]], +) -> _T: + """ + A decorator to add support for compiling the forward method of a class. + """ + if TorchCompileWrapperWithCustomDispatcher in cls.__bases__: + # support decorating multiple times + return cls + + # take care of method resolution order + # make sure super().__init__ is called on the base class + # other than TorchCompileWrapperWithCustomDispatcher + cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, ) + + old_init = cls.__init__ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): + old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs) + self.vllm_config = vllm_config + # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner + # will handle the compilation, so we don't need to do anything here. + self.do_not_compile = \ + vllm_config.compilation_config.level in [ + CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS + ] or not supports_dynamo() + if self.do_not_compile: + return + compilation_counter.num_models_seen += 1 + TorchCompileWrapperWithCustomDispatcher.__init__( + self, compilation_level=vllm_config.compilation_config.level) + + cls.__init__ = __init__ + + def __call__(self, *args, **kwargs): + # torch.compiler.is_compiling() means we are inside the compilation + # e.g. TPU has the compilation logic in model runner, so we don't + # need to compile the model inside. 
+ if self.do_not_compile or torch.compiler.is_compiling(): + return self.forward(*args, **kwargs) + + # the first compilation needs to have dynamic shapes marked + if len(self.compiled_codes) < 1: + sig = inspect.signature(self.__class__.forward) + bound_args = sig.bind(self, *args, **kwargs) + bound_args.apply_defaults() + for k, dims in dynamic_arg_dims.items(): + arg = bound_args.arguments.get(k) + if arg is not None: + dims = [dims] if isinstance(dims, int) else dims + if isinstance(arg, torch.Tensor): + # In case dims is specified with negative indexing + dims = [ + arg.ndim + dim if dim < 0 else dim for dim in dims + ] + torch._dynamo.mark_dynamic(arg, dims) + elif isinstance(arg, IntermediateTensors): + for tensor in arg.tensors.values(): + # In case dims is specified with negative indexing + dims = [ + tensor.ndim + dim if dim < 0 else dim + for dim in dims + ] + torch._dynamo.mark_dynamic(tensor, dims) + else: + raise ValueError( + "Unsupported dynamic dimensions" + f" {dims} for argument {k} with type {type(arg)}.") + # here, it is the starting point of the `torch.compile` process + start_monitoring_torch_compile(self.vllm_config) + logger.debug("Start compiling function %s", + self.original_code_object) + + # if we don't use custom dispatcher, we can directly call the + # compiled function and let torch.compile handle the dispatching, + # with the overhead of guard evaluation and recompilation. + if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher: + # it seems Dynamo reuse the compilation across instances, + # while we need to make sure the compiled code is not reused. + # we need to control all the compilation of the model. + torch._dynamo.eval_frame.remove_from_cache( + self.original_code_object) + + # collect all relevant files traced by Dynamo, + # so that the compilation cache can trigger re-compilation + # properly when any of these files change. + + # 1. the file containing the top-level forward function + self.vllm_config.compilation_config.traced_files.add( + self.original_code_object.co_filename) + + # 2. every time Dynamo sees a function call, it will inline + # the function by calling InliningInstructionTranslator.inline_call + # we hijack this function to know all the functions called + # during Dynamo tracing, and their corresponding files + inline_call = InliningInstructionTranslator.inline_call + + def patched_inline_call(parent, func, args, kwargs): + code = func.get_code() + self.vllm_config.compilation_config.traced_files.add( + code.co_filename) + return inline_call(parent, func, args, kwargs) + + with patch.object(InliningInstructionTranslator, 'inline_call', + patched_inline_call): + output = self.compiled_callable(*args, **kwargs) + return output + + # usually, capturing the model once is enough, and then we can + # dispatch to the compiled code directly, without going through + # the Dynamo guard mechanism. 
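+        # Index 0 selects the first compiled code object, i.e. the one
+        # captured for the general dynamic shape on the first call.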
+ with self.dispatch_to_code(0): + model_output = self.forward(*args, **kwargs) + return model_output + + cls.__call__ = __call__ + return cls diff --git a/compilation/fix_functionalization.py b/compilation/fix_functionalization.py new file mode 100644 index 0000000..286221d --- /dev/null +++ b/compilation/fix_functionalization.py @@ -0,0 +1,191 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import operator +from collections.abc import Iterable +from typing import Optional, Union + +import torch +from torch._higher_order_ops.auto_functionalize import auto_functionalized + +from vllm.logger import init_logger + +from .fx_utils import is_func +from .vllm_inductor_pass import VllmInductorPass + +logger = init_logger(__name__) + + +class FixFunctionalizationPass(VllmInductorPass): + """ + This pass defunctionalizes certain nodes to avoid redundant tensor copies. + After this pass, DCE (dead-code elimination) should never be run, + as de-functionalized nodes may appear as dead code. + + To add new nodes to defunctionalize, add to the if-elif chain in __call__. + """ + + def __call__(self, graph: torch.fx.Graph): + self.begin() + self.dump_graph(graph, "before_fix_functionalization") + + self.nodes_to_remove: list[torch.fx.Node] = [] + count = 0 + for node in graph.nodes: + if not is_func(node, auto_functionalized): + continue # Avoid deep if-elif nesting + + kwargs = node.kwargs + at_target = node.args[0] + + if at_target == torch.ops._C.rotary_embedding.default: + query = kwargs['query'] + mm_node = query.args[0].args[0] + + # rotary_embedding is a special case: the two mutating inputs + # are query and key, which are slices of mm_node. + # While functionalized, results at[1] and at[2] are scattered + # back into mm_node. After de-functionalization, we can just + # use mm_node directly. + for idx, user in self.getitem_users(node).items(): + for user_of_getitem in user.users: + if is_func(user_of_getitem, + torch.ops.aten.slice_scatter.default): + user_of_getitem.replace_all_uses_with(mm_node) + self._remove(user_of_getitem) + self._remove(user) + + self.insert_defunctionalized(graph, node) + self._remove(node) + + # rms_norm replacements avoid the most copies for LLaMa. + elif at_target == torch.ops._C.fused_add_rms_norm.default: + mutated_args = {1: 'input', 2: 'residual'} + self.defunctionalize(graph, node, mutated_args) + elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501 + mutated_args = {1: 'result', 2: 'residual'} + self.defunctionalize(graph, node, mutated_args) + elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501 + mutated_args = {1: 'result', 2: 'scale', 3: 'residual'} + self.defunctionalize(graph, node, mutated_args) + elif at_target in [ + torch.ops._C.rms_norm.default, + torch.ops._C.rms_norm_static_fp8_quant.default, + ]: + mutated_args = {1: 'result'} + self.defunctionalize(graph, node, mutated_args) + # For some reason we need to specify the args for both + # silu_and_mul and silu_and_mul_quant. The kwargs + # pathway gets the wrong answer. 
+ elif at_target == torch.ops._C.silu_and_mul.default: + mutated_args = {1: 'result'} + self.defunctionalize(graph, + node, + mutated_args, + args=('result', 'input')) + elif at_target == torch.ops._C.silu_and_mul_quant.default: + mutated_args = {1: 'result'} + self.defunctionalize(graph, + node, + mutated_args, + args=('result', 'input', 'scale')) + else: + continue # skip the count + + count += 1 + + self.dump_graph(graph, "before_fix_functionalization_cleanup") + + # Remove the nodes all at once + count_removed = len(self.nodes_to_remove) + for node in self.nodes_to_remove: + graph.erase_node(node) + + logger.debug("De-functionalized %s nodes, removed %s nodes", count, + count_removed) + self.dump_graph(graph, "after_fix_functionalization") + self.end_and_log() + + def _remove(self, node_or_nodes: Union[torch.fx.Node, + Iterable[torch.fx.Node]]): + """ + Stage a node (or nodes) for removal at the end of the pass. + """ + if isinstance(node_or_nodes, torch.fx.Node): + self.nodes_to_remove.append(node_or_nodes) + else: + self.nodes_to_remove.extend(node_or_nodes) + + def defunctionalize(self, + graph: torch.fx.Graph, + node: torch.fx.Node, + mutated_args: dict[int, Union[torch.fx.Node, str]], + args: Optional[tuple[Union[torch.fx.Node, str], + ...]] = None): + """ + De-functionalize a node by replacing it with a call to the original. + It also replaces the getitem users with the mutated arguments. + See replace_users_with_mutated_args and insert_defunctionalized. + """ + self.replace_users_with_mutated_args(node, mutated_args) + self.insert_defunctionalized(graph, node, args=args) + self._remove(node) + + def replace_users_with_mutated_args(self, node: torch.fx.Node, + mutated_args: dict[int, + Union[torch.fx.Node, + str]]): + """ + Replace all getitem users of the auto-functionalized node with the + mutated arguments. + :param node: The auto-functionalized node + :param mutated_args: The mutated arguments, indexed by getitem index. + If the value of an arg is a string, `node.kwargs[arg]` is used. + """ + for idx, user in self.getitem_users(node).items(): + arg = mutated_args[idx] + arg = node.kwargs[arg] if isinstance(arg, str) else arg + user.replace_all_uses_with(arg) + self._remove(user) + + def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]: + """ + Returns the operator.getitem users of the auto-functionalized node, + indexed by the index they are getting. + """ + users = {} + for user in node.users: + if is_func(user, operator.getitem): + idx = user.args[1] + users[idx] = user + return users + + def insert_defunctionalized(self, + graph: torch.fx.Graph, + node: torch.fx.Node, + args: Optional[tuple[Union[torch.fx.Node, str], + ...]] = None): + """ + Insert a new defunctionalized node into the graph before node. + If one of the kwargs is 'out', provide args directly, + as node.kwargs cannot be used. + See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 + + :param graph: Graph to insert the defunctionalized node into + :param node: The auto-functionalized node to defunctionalize + :param args: If we cannot use kwargs, specify args directly. + If an arg is a string, `node.kwargs[arg]` is used. 
+ """ # noqa: E501 + assert is_func(node, auto_functionalized), \ + f"node must be auto-functionalized, is {node} instead" + + # Create a new call to the original function + with graph.inserting_before(node): + function = node.args[0] + if args is None: + graph.call_function(function, kwargs=node.kwargs) + else: + # Args passed as strings refer to items in node.kwargs + args = tuple(node.kwargs[arg] if isinstance(arg, str) else arg + for arg in args) + graph.call_function(function, args=args) diff --git a/compilation/fusion.py b/compilation/fusion.py new file mode 100644 index 0000000..7e2c5b4 --- /dev/null +++ b/compilation/fusion.py @@ -0,0 +1,618 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Callable, NamedTuple, Optional + +import torch +import torch._inductor.pattern_matcher as pm +from torch import fx +from torch._higher_order_ops.auto_functionalize import auto_functionalized +from torch._inductor.pattern_matcher import PatternMatcherPass +from torch._ops import OpOverload + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.platforms import current_platform + +from .fx_utils import find_getitem_maybe +from .multi_output_match import MultiOutputMatch +from .vllm_inductor_pass import VllmInductorPass + +logger = init_logger(__name__) +FP8_DTYPE = current_platform.fp8_dtype() + + +def empty_bf16(*args, **kwargs): + return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda") + + +def empty_fp32(*args, **kwargs): + return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda") + + +RMS_OP = torch.ops._C.rms_norm.default +RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default + + +class QuantKey(NamedTuple): + """ + Named tuple for identifying the type of quantization. + dtype: quantized data type + static: static quantization if True, dynamic if False + per_tensor: per-tensor quantization if True, per-token if False + symmetric: symmetric if True, asymmetric if False + """ + dtype: torch.dtype + static: bool + per_tensor: bool = True + symmetric: bool = True + + def __str__(self): + return (f"QuantKey({'static' if self.static else 'dynamic'}," + f"{fx.graph.dtype_abbrs[self.dtype]}," + f"{'per_tensor' if self.per_tensor else 'per_token'}," + f"{'a' if not self.symmetric else ''}symmetric)") + + +kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, True, True) +kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, True, True) +kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, False, True) + +QUANT_OPS: dict[QuantKey, OpOverload] = { + kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa + kFp8DynamicTensorSym: + torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa + kFp8DynamicTokenSym: + torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa +} + + +class FusedRMSQuantKey(NamedTuple): + """ + Named tuple for identifying the type of RMSNorm + quant fusion. 
+ quant: type of quantization + fused_add: does the op also perform the residual add + """ + quant: QuantKey + fused_add: bool + + def __str__(self): + return (f"FusedQuantKey({self.quant}, with" + f"{'' if self.fused_add else 'out'} residual)") + + +FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = { + FusedRMSQuantKey(kFp8StaticTensorSym, False): + torch.ops._C.rms_norm_static_fp8_quant.default, # noqa + FusedRMSQuantKey(kFp8StaticTensorSym, True): + torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa + FusedRMSQuantKey(kFp8DynamicTokenSym, False): + torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa + FusedRMSQuantKey(kFp8DynamicTokenSym, True): + torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa +} + + +class QuantMultiOutputMatch(MultiOutputMatch): + + def __init__(self, match: pm.Match, quant_op, fused_op): + super().__init__(match) + assert isinstance(quant_op, OpOverload) + assert isinstance(fused_op, OpOverload) + self.QUANT_OP = quant_op # in-place quant op + self.FUSED_OP = fused_op # in-place fused quant op + + def insert_fused_node(self, fused_return_mapping: dict[int, tuple[fx.Node, + int]], + **kwargs): + """ + This utility function inserts an auto-functionalized node for FUSED_OP. + It also correctly sets its meta value and rebinds the users of the + unfused nodes to use the fused node instead. + + :param fused_return_mapping: A dictionary, mapping from getitem indices + of the fused node result to a tuple of the old node and a getitem index. + :param kwargs: kwargs that get directly forwarded to the auto_fn node + + Example: + If we want to replace this graph: + _, x1, x2 = auto_fn(op1) + _, y1, y2 = auto_fn(op2) + + with + _, x1, y2, x2 = auto_fn(FUSED_OP) + + we would call: + insert_fused_node({1: (op1_node, 1), 2: (op2_node, 2), 3: (op1_node, 2)} + + Note that the 0th element is None for auto-functionalized in-place ops. + Hence, others appear 1-indexed. + """ + fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs) + indices = fused_return_mapping.keys() + getitem_nodes = self.insert_getitems(fused_node, indices) + + # Prepare the meta value, use a list so it's mutable + meta_val = [None] * (max(indices) + 1) + + # Iterate through elements of the tuple produced by fused_node + for idx, getitem_node in zip(indices, getitem_nodes): + old_node, old_idx = fused_return_mapping[idx] + + # If the old value was never used, the old_getitem might not exist + old_getitem = find_getitem_maybe(old_node, old_idx) + if old_getitem is not None: + # Rebind the users of match getitem nodes to use the new nodes. + # The old nodes will be removed by DCE at the end of the pass. 
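+                # The new getitem node inherits the old meta["val"] so that
+                # later de-functionalization still sees the correct value.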
+ old_getitem.replace_all_uses_with(getitem_node) + getitem_node.meta["val"] = old_getitem.meta["val"] + + # Extract the appropriate meta value + # It is present even if the getitem node does not exist + meta_val[idx] = old_node.meta["val"][old_idx] + + # Fix the meta value on the new fused node + fused_node.meta["val"] = tuple(meta_val) + + +class RMSNormQuantPattern: + + def __init__(self, epsilon: float, key: FusedRMSQuantKey): + self.epsilon = epsilon + self.quant_dtype = key.quant.dtype + + assert key.quant in QUANT_OPS, \ + f"unsupported quantization scheme {key.quant}" + self.QUANT_OP = QUANT_OPS[key.quant] + + assert key in FUSED_OPS, \ + f"unsupported fused rmsnorm+quant op for {key}" + self.FUSED_OP = FUSED_OPS[key] + + +class RMSNormStaticQuantPattern(RMSNormQuantPattern): + + def __init__(self, + epsilon: float, + quant_dtype: torch.dtype, + symmetric=True): + fused_key = FusedRMSQuantKey(fused_add=False, + quant=QuantKey(dtype=quant_dtype, + static=True, + per_tensor=True, + symmetric=symmetric)) + super().__init__(epsilon, fused_key) + + def register(self, pm_pass: PatternMatcherPass): + # Cannot use methods, as the self argument affects tracing + def pattern(result: torch.Tensor, result_rms: torch.Tensor, + input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at1 = auto_functionalized(RMS_OP, + result=result_rms, + input=input, + weight=weight, + epsilon=self.epsilon) + at2 = auto_functionalized(self.QUANT_OP, + result=result, + input=at1[1], + scale=scale) + + # result + return at2[1] + + def replacement(result: torch.Tensor, result_rms: torch.Tensor, + input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon) + + # result + return at[1] + + inputs = [ + torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result + empty_bf16(5, 4), # result_rms + empty_bf16(5, 4), # input + empty_bf16(1, 5), # weight + empty_fp32(1, 1) # scale + ] + + pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, + pm_pass) + + +class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): + + def __init__(self, + epsilon: float, + quant_dtype: torch.dtype, + symmetric=True): + key = FusedRMSQuantKey(fused_add=True, + quant=QuantKey(dtype=quant_dtype, + static=True, + per_tensor=True, + symmetric=symmetric)) + super().__init__(epsilon, key) + + def register(self, pm_pass: PatternMatcherPass, + record_match: Callable[[MultiOutputMatch], bool]): + + def pattern(result: torch.Tensor, input: torch.Tensor, + residual: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(RMS_ADD_OP, + input=input, + residual=residual, + weight=weight, + epsilon=self.epsilon) + at1 = auto_functionalized(self.QUANT_OP, + result=result, + input=at[1], + scale=scale) + + # result, residual + return at1[1], at[2] + + def replacement(result: torch.Tensor, input: torch.Tensor, + residual: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(self.FUSED_OP, + result=result, + input=input, + residual=residual, + weight=weight, + scale=scale, + epsilon=self.epsilon) + + # result, residual + return at[1], at[2] + + inputs = [ + torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result + empty_bf16(5, 4), # input + empty_bf16(5, 4), # residual + empty_bf16(1, 5), # weight + empty_fp32(1, 1) # scale + ] + + pm.register_replacement( + pattern, + replacement, + inputs, + pm.fwd_only, + 
pm_pass, + extra_check=lambda m: record_match( + self.Match(m, self.QUANT_OP, self.FUSED_OP))) + + class Match(QuantMultiOutputMatch): + + def process(self): + # Find the nodes in the match that we need to rebind + rms_node = self.find_auto_fn(RMS_ADD_OP) + quant_node = self.find_auto_fn(self.QUANT_OP) + + assert len(rms_node.users) == 2 + assert len(quant_node.users) == 1 + + # First, insert a new auto_functionalized node for the fused op, + # as well as getitem nodes to extract the result and residual. + # The auto_fn node returns a tuple of (None, result, residual). + # + # The resulting graph looks like this: + # at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...) # noqa + # result_node_new = at[1] + # residual_node_new = at[2] + with self.inserting_after_match(): + # Missing epsilon, scalars cannot be inputs to the pattern + kwargs = self.match.kwargs.copy() + + # 0 is always None + fused_return_mapping = {1: (quant_node, 1), 2: (rms_node, 2)} + self.insert_fused_node(fused_return_mapping, + epsilon=rms_node.kwargs["epsilon"], + **kwargs) + + +class RMSNormDynamicQuantPattern(RMSNormQuantPattern): + + def __init__(self, + epsilon: float, + quant_dtype: torch.dtype, + per_tensor: bool, + symmetric=True): + key = FusedRMSQuantKey(fused_add=False, + quant=QuantKey(dtype=quant_dtype, + static=False, + per_tensor=per_tensor, + symmetric=symmetric)) + super().__init__(epsilon, key) + + def register(self, pm_pass: PatternMatcherPass, + record_match: Callable[[MultiOutputMatch], bool]): + + def pattern(result: torch.Tensor, result_rms: torch.Tensor, + input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at1 = auto_functionalized(RMS_OP, + result=result_rms, + input=input, + weight=weight, + epsilon=self.epsilon) + at2 = auto_functionalized(self.QUANT_OP, + result=result, + input=at1[1], + scale=scale, + scale_ub=None) + + # result, scale + return at2[1], at2[2] + + def replacement(result: torch.Tensor, result_rms: torch.Tensor, + input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon, + scale_ub=None, + residual=None) + + # result, scale + return at[1], at[2] + + inputs = [ + torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result + empty_bf16(5, 4), # result_rms + empty_bf16(5, 4), # input + empty_bf16(1, 5), # weight + empty_fp32(1, 1) # scale + ] + + pm.register_replacement( + pattern, + replacement, + inputs, + pm.fwd_only, + pm_pass, + extra_check=lambda m: record_match( + self.Match(m, self.QUANT_OP, self.FUSED_OP))) + + class Match(QuantMultiOutputMatch): + + def process(self): + # Find the nodes in the match that we need to rebind + rms_node = self.find_auto_fn(RMS_OP) + quant_node = self.find_auto_fn(self.QUANT_OP) + + assert len(rms_node.users) == 1 + assert len(quant_node.users) == 2 + + # First, insert a new auto_functionalized node for the fused op, + # as well as getitem nodes to extract the result and scale. + # The auto_fn node returns a tuple of (None, result, scale). + # + # The resulting graph looks like this: + # at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) 
# noqa + # result_node_new = at[1] + # scale_node_new = at[2] + with self.inserting_after_match(): + # Missing epsilon, scalars cannot be inputs to the pattern + kwargs = self.match.kwargs.copy() + del kwargs["result_rms"] # not used in the fused op + + fused_return_mapping = {1: (quant_node, 1), 2: (quant_node, 2)} + self.insert_fused_node( + fused_return_mapping, + epsilon=rms_node.kwargs["epsilon"], + scale_ub=None, # not used but required + residual=None, # not used but required + **kwargs) + + +class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): + + def __init__(self, + epsilon: float, + quant_dtype: torch.dtype, + per_tensor: bool = True, + symmetric=True): + key = FusedRMSQuantKey(fused_add=True, + quant=QuantKey(dtype=quant_dtype, + static=False, + per_tensor=per_tensor, + symmetric=symmetric)) + super().__init__(epsilon, key) + + def register(self, pm_pass: PatternMatcherPass, + record_match: Callable[[MultiOutputMatch], bool]): + + def pattern(result: torch.Tensor, input: torch.Tensor, + residual: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(RMS_ADD_OP, + input=input, + residual=residual, + weight=weight, + epsilon=self.epsilon) + at1 = auto_functionalized(self.QUANT_OP, + result=result, + input=at[1], + scale=scale, + scale_ub=None) + + # result, residual, scale + return at1[1], at[2], at1[2] + + def replacement(result: torch.Tensor, input: torch.Tensor, + residual: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon, + scale_ub=None, + residual=residual) + + # result, residual, scale + return at[1], at[3], at[2] + + inputs = [ + torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result + empty_bf16(5, 4), # input + empty_bf16(5, 4), # residual + empty_bf16(1, 5), # weight + empty_fp32(1, 1) # scale + ] + + pm.register_replacement( + pattern, + replacement, + inputs, + pm.fwd_only, + pm_pass, + extra_check=lambda m: record_match( + self.Match(m, self.QUANT_OP, self.FUSED_OP))) + + class Match(QuantMultiOutputMatch): + + def process(self): + # Find the nodes in the match that we need to rebind + rms_node = self.find_auto_fn(RMS_ADD_OP) + quant_node = self.find_auto_fn(self.QUANT_OP) + + assert len(rms_node.users) == 2 + assert len(quant_node.users) == 2 + + # First, insert a new auto_functionalized node for the fused op, + # as well as getitem nodes to extract result, scale, and residual. + # The auto_fn node returns a tuple (None, result, scale, residual). + # + # The resulting graph looks like this: + # at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) # noqa + # result_node_new = at[1] + # scale_node_new = at[2] + # residual_node_new = at[3] + with self.inserting_after_match(): + # Missing epsilon, scalars cannot be inputs to the pattern + kwargs = self.match.kwargs.copy() + + fused_return_mapping = { + 1: (quant_node, 1), # result + 2: (quant_node, 2), # scale + 3: (rms_node, 2), # residual + } + self.insert_fused_node( + fused_return_mapping, + epsilon=rms_node.kwargs["epsilon"], + scale_ub=None, # not used but required + **kwargs) + + +class FusionPass(VllmInductorPass): + """ + This pass fuses a pre-defined set of custom ops into fused ops. + It uses the torch pattern matcher to find the patterns and replace them. + It also manually processes multi-output matches, as those are broken in + the torch pattern matcher. 
+ + Because patterns can only be registered once, the pass is a singleton. + This will be addressed in a future version of PyTorch: + https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980 + """ + + _instance: 'Optional[FusionPass]' = None + + @classmethod + def instance(cls, config: VllmConfig): + """ + Get the singleton instance of the FusionPass. + If the instance exists, the config is updated but + initialization is not repeated. + """ + if cls._instance is None: + cls._instance = FusionPass(config) + else: + cls._instance.pass_config = config.compilation_config.pass_config + return cls._instance + + def __init__(self, config: VllmConfig): + assert self.__class__._instance is None, \ + "FusionPass singleton instance already exists" + super().__init__(config) + + self.matches: list[MultiOutputMatch] = [] + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="fusion_pass") + + for epsilon in [1e-5, 1e-6]: + # Fuse rms_norm + static fp8 quant + RMSNormStaticQuantPattern(epsilon, + FP8_DTYPE).register(self.patterns) + + # Matches for patterns below have 2 or more outputs, + # so we need to process them manually (see process_matches) + + # Fuse rms_norm + static fp8 quant + FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register( + self.patterns, self.record_match) + + # Fuse rms_norm + dynamic per-token fp8 quant + RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE, + per_tensor=False).register( + self.patterns, self.record_match) + + # Fuse fused_add_rms_norm + dynamic per-token fp8 quant + FusedAddRMSNormDynamicQuantPattern(epsilon, + FP8_DTYPE, + per_tensor=False).register( + self.patterns, + self.record_match) + + # WARNING: This is a hack to clear the pattern matcher cache + # and allow multiple values of epsilon. + torch._inductor.pattern_matcher._seen_patterns.clear() + + def record_match(self, match: MultiOutputMatch) -> bool: + # Hijack the extra_check to record the match and + # save it for post-processing. + self.matches.append(match) + + # Return False to prevent automatic replacement. + return False + + def process_matches(self, graph: fx.Graph): + """ + Manually process multi-output matches and replace them with fused nodes. + See MultiOutputMatch for more details. 
+ """ + for match in self.matches: + match.process() + + # Finally, remove matched nodes + graph.eliminate_dead_code() + assert all(node not in graph.nodes for match in self.matches + for node in match.match.nodes) + + def __call__(self, graph: fx.Graph): + self.begin() + self.dump_graph(graph, "before_fusion") + + count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", count) + self.dump_graph(graph, "after_pattern_match") + + # Manually process multi-output matches (and run DCE) + self.process_matches(graph) + logger.debug("Post-processed %s matches", len(self.matches)) + self.dump_graph(graph, "after_fusion") + self.matches.clear() + self.end_and_log() diff --git a/compilation/fx_utils.py b/compilation/fx_utils.py new file mode 100644 index 0000000..9ef3889 --- /dev/null +++ b/compilation/fx_utils.py @@ -0,0 +1,62 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import operator +from collections.abc import Iterable +from typing import Optional + +from torch import fx +from torch._higher_order_ops.auto_functionalize import auto_functionalized +from torch._ops import OpOverload + + +def is_func(node: fx.Node, target) -> bool: + return node.op == "call_function" and node.target == target + + +# Returns the first specified node with the given op (if it exists) +def find_specified_fn_maybe(nodes: Iterable[fx.Node], + op: OpOverload) -> Optional[fx.Node]: + for node in nodes: + if node.target == op: + return node + return None + + +# Returns the first specified node with the given op +def find_specified_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node: + node = find_specified_fn_maybe(nodes, op) + assert node is not None, f"Could not find {op} in nodes {nodes}" + return node + + +# Returns the first auto_functionalized node with the given op (if it exists) +def find_auto_fn_maybe(nodes: Iterable[fx.Node], + op: OpOverload) -> Optional[fx.Node]: + for node in nodes: + if is_func(node, auto_functionalized) and node.args[0] == op: # noqa + return node + return None + + +# Returns the first auto_functionalized node with the given op +def find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node: + node = find_auto_fn_maybe(nodes, op) + assert node is not None, f"Could not find {op} in nodes {nodes}" + return node + + +# Returns the getitem node that extracts the idx-th element from node +# (if it exists) +def find_getitem_maybe(node: fx.Node, idx: int) -> Optional[fx.Node]: + for user in node.users: + if is_func(user, operator.getitem) and user.args[1] == idx: + return user + return None + + +# Returns the getitem node that extracts the idx-th element from node +def find_getitem(node: fx.Node, idx: int) -> fx.Node: + ret = find_getitem_maybe(node, idx) + assert ret is not None, f"Could not find getitem {idx} in node {node}" + return ret diff --git a/compilation/inductor_pass.py b/compilation/inductor_pass.py new file mode 100644 index 0000000..810d080 --- /dev/null +++ b/compilation/inductor_pass.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +import inspect +import json +import types +from contextlib import contextmanager +from typing import Any, Callable, Optional, Union + +import torch +from torch import fx + +from vllm.utils import is_torch_equal_or_newer + +if is_torch_equal_or_newer("2.6"): + from torch._inductor.custom_graph_pass import CustomGraphPass +else: + # CustomGraphPass is not present 
in 2.5 or lower, import our version + from .torch25_custom_graph_pass import ( # noqa: E501 + Torch25CustomGraphPass as CustomGraphPass) + +_pass_context = None + + +class PassContext: + + def __init__(self, runtime_shape: Optional[int]): + self.runtime_shape = runtime_shape + + +def get_pass_context() -> PassContext: + """Get the current pass context.""" + assert _pass_context is not None + return _pass_context + + +@contextmanager +def pass_context(runtime_shape: Optional[int]): + """A context manager that stores the current pass context, + usually it is a list of sizes to specialize. + """ + global _pass_context + prev_context = _pass_context + _pass_context = PassContext(runtime_shape) + try: + yield + finally: + _pass_context = prev_context + + +class InductorPass(CustomGraphPass): + """ + A custom graph pass that uses a hash of its source as the UUID. + This is defined as a convenience and should work in most cases. + """ + + def uuid(self) -> Any: + """ + Provide a unique identifier for the pass, used in Inductor code cache. + This should depend on the pass implementation, so that changes to the + pass result in recompilation. + By default, the object source is hashed. + """ + return InductorPass.hash_source(self) + + @staticmethod + def hash_source(*srcs: Union[str, Any]): + """ + Utility method to hash the sources of functions or objects. + :param srcs: strings or objects to add to the hash. + Objects and functions have their source inspected. + :return: + """ + hasher = hashlib.sha256() + for src in srcs: + if isinstance(src, str): + src_str = src + elif isinstance(src, types.FunctionType): + src_str = inspect.getsource(src) + else: + src_str = inspect.getsource(src.__class__) + hasher.update(src_str.encode("utf-8")) + return hasher.hexdigest() + + @staticmethod + def hash_dict(dict_: dict[Any, Any]): + """ + Utility method to hash a dictionary, can alternatively be used for uuid. + :return: A sha256 hash of the json rep of the dictionary. + """ + encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") + return hashlib.sha256(encoded).hexdigest() + + def is_applicable_for_shape(self, shape: Optional[int]): + return True + + +class CallableInductorPass(InductorPass): + """ + This class is a wrapper for a callable that automatically provides an + implementation of the UUID. 
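+
+    Example (illustrative sketch; `my_pass` is a placeholder name):
+
+        def my_pass(graph: fx.Graph) -> None:
+            ...  # mutate the graph in place
+
+        wrapped = CallableInductorPass(my_pass)
+        # wrapped.uuid() is derived from the source of my_pass via hash_source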
+ """ + + def __init__(self, + callable: Callable[[fx.Graph], None], + uuid: Optional[Any] = None): + self.callable = callable + self._uuid = self.hash_source(callable) if uuid is None else uuid + + def __call__(self, graph: torch.fx.Graph): + self.callable(graph) + + def uuid(self) -> Any: + return self._uuid diff --git a/compilation/monitor.py b/compilation/monitor.py new file mode 100644 index 0000000..1e059b5 --- /dev/null +++ b/compilation/monitor.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +import time + +from vllm.config import CompilationConfig, CompilationLevel, VllmConfig +from vllm.logger import init_logger + +logger = init_logger(__name__) + +context_manager = None +torch_compile_start_time: float = 0.0 + + +def start_monitoring_torch_compile(vllm_config: VllmConfig): + global torch_compile_start_time + torch_compile_start_time = time.time() + + compilation_config: CompilationConfig = vllm_config.compilation_config + if compilation_config.level == CompilationLevel.PIECEWISE and \ + compilation_config.debug_dump_path: + import depyf + path = os.path.join(compilation_config.debug_dump_path, + f"rank_{vllm_config.parallel_config.rank}") + global context_manager + context_manager = depyf.prepare_debug(path) + context_manager.__enter__() + + +def end_monitoring_torch_compile(vllm_config: VllmConfig): + compilation_config: CompilationConfig = vllm_config.compilation_config + if compilation_config.level == CompilationLevel.PIECEWISE: + logger.info("torch.compile takes %.2f s in total", + compilation_config.compilation_time) + global context_manager + if context_manager is not None: + context_manager.__exit__(None, None, None) + context_manager = None diff --git a/compilation/multi_output_match.py b/compilation/multi_output_match.py new file mode 100644 index 0000000..6d18937 --- /dev/null +++ b/compilation/multi_output_match.py @@ -0,0 +1,109 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import abc +import operator +from abc import abstractmethod +from collections.abc import Iterable + +from torch import fx +from torch._higher_order_ops.auto_functionalize import auto_functionalized +from torch._inductor import pattern_matcher as pm +from torch._ops import OpOverload +from torch.fx import Node + +from vllm.compilation.fx_utils import find_auto_fn + + +class MultiOutputMatch(abc.ABC): + """ + This class provides utilities to process multi-output matches and + manually insert replacements. + + This is necessary because the automatic replacement for multi-output + matches is broken: https://github.com/pytorch/pytorch/issues/137280 + """ + + def __init__(self, match: pm.Match): + self.match = match + + @abstractmethod + def process(self): + """ + Process a multi-output match and manually insert the replacement. + + This method should: + 1. Insert the replacement nodes after the last node in the match. + 2. Rebind the users of nodes in the match to use the new nodes. + 3. Set meta["val"] for de-functionalization. + + The result of an auto-functionalized node is a tuple of tensors. + The first element is the return value of the function, usually None. + The remaining elements are the mutated args of the function. + + All auto-functionalized nodes must contain a proper meta["val"], + as it is used by de-functionalization. 
meta["val"] has to contain the + value of the node (tuple of tensors) that would be returned by the + functionalized node during tracing. + + Existing nodes in the graph all have this property set, but we have + to set it manually for new nodes we insert. + + Example: + # op schema: foo(a: Tensor!, b: Tensor, c: Tensor!) -> None + at = auto_functionalized(torch.ops._C.foo.default, a, b, c) + # at.meta["val"] = (None, a, c) + """ + raise NotImplementedError + + @property + def nodes(self) -> list[fx.Node]: + return self.match.nodes + + @property + def graph(self) -> fx.Graph: + return self.match.graph + + def find_auto_fn(self, op) -> fx.Node: + """ + Find the first auto_functionalized node with the given op in the match. + """ + return find_auto_fn(self.nodes, op) + + def inserting_after_match(self): + """ + Insert nodes after the last node in the match. + This is done to avoid use-before-definition errors after inserting + replacement nodes. + """ + + # match.nodes is not guaranteed to be sorted. + # Find the last node in the match. + for last_node_in_match in reversed(self.graph.nodes): + if last_node_in_match in self.match.nodes: + break + else: + raise ValueError("No nodes in graph") + + return self.graph.inserting_after(last_node_in_match) + + def insert_getitems(self, tuple_node: fx.Node, + indices: Iterable[int]) -> tuple[fx.Node, ...]: + """ + Insert operator.getitem nodes to extract elements from a tuple node. + + :param tuple_node: The tuple node to extract elements from. + :param indices: The indices of the elements to extract. + :return: Tuple of the new getitem nodes, corresponding to the indices. + """ + with self.graph.inserting_after(tuple_node): + return tuple( + self.graph.call_function(operator.getitem, (tuple_node, idx)) + for idx in indices) + + def insert_auto_fn(self, op: OpOverload, kwargs) -> Node: + """ + Insert an auto_functionalized node with the given op and kwargs. + """ + return self.graph.call_function(auto_functionalized, (op, ), + kwargs=kwargs) diff --git a/compilation/noop_elimination.py b/compilation/noop_elimination.py new file mode 100644 index 0000000..46f70dc --- /dev/null +++ b/compilation/noop_elimination.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable +from typing import Union + +import torch.fx +from torch import SymInt + +from vllm.logger import init_logger + +from .fx_utils import is_func +from .vllm_inductor_pass import VllmInductorPass + +logger = init_logger(__name__) + + +class NoOpEliminationPass(VllmInductorPass): + """ + This is an inductor pass that removes redundant reshape/slice operations. + It is required for RMSNorm-quant fusion to work properly. + That's because apply_fp8_linear adds a reshape, which is redundant + in the 2D-case. Additionally, torch internal no-op elimination pass does + not handle certain slice variants. + + Example graph 1: + getitem_1: "f16[s0, 4096]" = ... + view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096]) + at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...) + out: "f8e4m3fn[s0, 4096]" = at[1] + + Can be replaced with: + getitem_1: "f16[s0, 4096]" = ... + at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...) + out: "f8e4m3fn[s0, 4096]" = at[1] + + Example graph 2: + arg0: "s0" = SymInt(s0) + scaled_mm: "f16[s0, 4096]" = ... 
+ slice_1: "f16[s0, 4096]" = torch.slice(scaled_mm, -1, 0, arg0) + at = auto_functionalized(fused_add_rms_norm, input = slice_1, ...) + out: "f16[s0, 4096]" = torch.slice_scatter(scaled_mm, at[1], 0, 0, arg0) + + Can be replaced with: + arg0: "s0" = SymInt(s0) + scaled_mm: "f16[s0, 4096]" = ... + at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...) + out: "f16[s0, 4096]" = at[1] + + TODO(luka): This is currently tested in test_fusion, + but separate tests could be good. + """ + + def __call__(self, graph: torch.fx.Graph): + self.begin() + self.dump_graph(graph, "before_noop_elimination") + count = 0 + # Remove no-op reshapes/views: + for node in graph.nodes: + if is_func(node, torch.ops.aten.reshape.default): + input, shape = node.args[:2] + input_shape = input.meta["val"].shape + if len(shape) != len(input_shape): + # Reshape changing rank, skip + continue + + if shape.count(-1) > 1: + # Invalid reshape args, skip + continue + + if self.all_dims_equivalent(shape, input_shape): + node.replace_all_uses_with(input) + graph.erase_node(node) + count += 1 + + elif is_func(node, torch.ops.aten.slice.Tensor): + input, dim_index, start, end = node.args[:4] + input_shape = input.meta["val"].shape + i_dim = input_shape[dim_index] + + if start == 0 and self.dims_equivalent(end, i_dim): + node.replace_all_uses_with(input) + graph.erase_node(node) + count += 1 + + elif is_func(node, torch.ops.aten.slice_scatter.default): + base, view, dim_index, start, end = node.args[:5] + base_shape = base.meta["val"].shape + view_shape = view.meta["val"].shape + + view_dim = view_shape[dim_index] + + # Check that view fully covers base and the full view is used + # (if the view fully covered the base after slicing but was not + # fully used, we could replace slice_scatter with a simple slice + # but that's a niche case). + if (base_shape == view_shape and start == 0 + and self.dims_equivalent(end, view_dim)): + node.replace_all_uses_with(view) + graph.erase_node(node) + count += 1 + + logger.debug("Removed %s no-op reshapes and slices", count) + self.dump_graph(graph, "after_noop_elimination") + self.end_and_log() + + def all_dims_equivalent(self, dims: Iterable[Union[int, torch.fx.Node]], + i_dims: Iterable[Union[int, SymInt]]): + return all( + self.dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims)) + + def dims_equivalent(self, dim: Union[int, torch.fx.Node], + i_dim: Union[int, SymInt]) -> bool: + """ + This function checks if two dimensions are equivalent. + :param dim: The dimension arg to reshape/slice + :param i_dim: The corresponding dimension in the input tensor + :return: Are the dimensions equivalent? + + There are three cases in which the dimensions are equivalent: + 1. The dimensions are equal (both integers) + 2. The reshape dimension is -1 (i.e. inferred) + 3. The dimensions both correspond to the same SymInt + + While case 2 does not guarantee the dimensions are equal, + they are equal if all other dimensions are equal. + + In case 3, the reshape dimension is a torch.fx.Node, + and its value is a SymInt. That value is equal to the + input dimension. 
+ + """ + # Case 1 and 2 + if dim == i_dim or dim == -1: + return True + # Case 3 + return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim diff --git a/compilation/pass_manager.py b/compilation/pass_manager.py new file mode 100644 index 0000000..621c89a --- /dev/null +++ b/compilation/pass_manager.py @@ -0,0 +1,78 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from torch import fx as fx + +from vllm.config import VllmConfig +from vllm.logger import init_logger + +from .activation_quant_fusion import ActivationQuantFusionPass +from .collective_fusion import AsyncTPPass +from .fix_functionalization import FixFunctionalizationPass +from .fusion import FusionPass +from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context +from .noop_elimination import NoOpEliminationPass +from .sequence_parallelism import SequenceParallelismPass +from .vllm_inductor_pass import VllmInductorPass + +logger = init_logger(__name__) + + +class PostGradPassManager(CustomGraphPass): + """ + The pass manager for post-grad passes. + It handles configuration, adding custom passes, and running passes. + It supports uuid for the Inductor code cache. That includes torch<2.6 + support using pickling (in .inductor_pass.CustomGraphPass). + + The order of the post-grad post-passes is: + 1. passes (constructor parameter) + 2. default passes (NoopEliminationPass, FusionPass) + 3. config["post_grad_custom_post_pass"] (if it exists) + 4. fix_functionalization + This way, all passes operate on a functionalized graph. + """ + + def __init__(self): + self.passes: list[VllmInductorPass] = [] + + def __call__(self, graph: fx.Graph): + shape = get_pass_context().runtime_shape + for pass_ in self.passes: + if pass_.is_applicable_for_shape(shape): + pass_(graph) + + # always run fix_functionalization last + self.fix_functionalization(graph) + + def configure(self, config: VllmConfig): + self.pass_config = config.compilation_config.pass_config + if self.pass_config.enable_noop: + self.passes += [NoOpEliminationPass(config)] + + if self.pass_config.enable_fusion: + self.passes += [FusionPass.instance(config)] + self.passes += [ActivationQuantFusionPass(config)] + + if self.pass_config.enable_sequence_parallelism: + self.passes += [SequenceParallelismPass(config)] + if self.pass_config.enable_async_tp: + self.passes += [AsyncTPPass(config)] + + self.fix_functionalization = FixFunctionalizationPass(config) + + def add(self, pass_: InductorPass): + assert isinstance(pass_, InductorPass) + self.passes.append(pass_) + + def uuid(self): + """ + The PostGradPassManager is set as a custom pass in the Inductor and + affects compilation caching. Its uuid depends on the UUIDs of all + dependent passes and the pass config. See InductorPass for more info. 
+ """ + state = {"pass_config": self.pass_config.uuid(), "passes": []} + for pass_ in self.passes: + state["passes"].append(pass_.uuid()) + state["passes"].append(self.fix_functionalization.uuid()) + return InductorPass.hash_dict(state) diff --git a/compilation/sequence_parallelism.py b/compilation/sequence_parallelism.py new file mode 100644 index 0000000..d410939 --- /dev/null +++ b/compilation/sequence_parallelism.py @@ -0,0 +1,268 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch +import torch._inductor.pattern_matcher as pm +import torch.fx as fx +from torch._inductor.pattern_matcher import PatternMatcherPass + +from vllm.config import VllmConfig +from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_world_size) +from vllm.logger import init_logger + +from .vllm_inductor_pass import VllmInductorPass + +logger = init_logger(__name__) + + +class AllReduceRMSNormPattern: + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str): + self.epsilon = epsilon + self.dtype = dtype + self.device = device + + +class EmbeddingAllReduceRMSNormPattern(AllReduceRMSNormPattern): + + def get_inputs(self): + arg2_1 = torch.empty([16, 4], device=self.device, dtype=self.dtype) + mul_6 = torch.tensor([[3, 7, 1, 4, 9, 2, 5, 0]], + device=self.device, + dtype=torch.long) + unsqueeze = torch.rand([1, 8, 1], device=self.device, \ + dtype=self.dtype) > 0.5 + full_default = torch.zeros([1, 8, 4], device=self.device, \ + dtype=self.dtype) + permute = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) + arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype) + + return [arg2_1, mul_6, unsqueeze, full_default, permute, arg3_1] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + arg2_1: torch.Tensor, + mul_6: torch.Tensor, + unsqueeze: torch.Tensor, + full_default: torch.Tensor, + permute: torch.Tensor, + arg3_1: torch.Tensor, + ): + embedding = torch.ops.aten.embedding.default(arg2_1, mul_6) + where = torch.ops.aten.where.self(unsqueeze, full_default, + embedding) + all_reduce = tensor_model_parallel_all_reduce(where) + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.rms_norm.default, + result=permute, + input=all_reduce, + weight=arg3_1, + epsilon=self.epsilon, + ) + + return rmsnorm[1], all_reduce + + def replacement( + arg2_1: torch.Tensor, + mul_6: torch.Tensor, + unsqueeze: torch.Tensor, + full_default: torch.Tensor, + permute: torch.Tensor, + arg3_1: torch.Tensor, + ): + embedding = torch.ops.aten.embedding.default(arg2_1, mul_6) + where = torch.ops.aten.where.self(unsqueeze, full_default, + embedding) + + tp = get_tp_group() + tp_size = get_tensor_model_parallel_world_size() + reduce_scatter = torch.ops.vllm.reduce_scatter.default( + where, dim=0, world_size=tp_size, group_name=tp.unique_name) + + rmsnorm_result = torch.empty_like(reduce_scatter) + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.rms_norm.default, + result=rmsnorm_result, + input=reduce_scatter, + weight=arg3_1, + epsilon=self.epsilon, + ) + + all_gather = torch.ops.vllm.all_gather.default( + rmsnorm[1], + dim=0, + world_size=tp_size, + group_name=tp.unique_name) + + return all_gather, reduce_scatter + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class 
MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern): + + def get_inputs(self): + mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) + rms_norm_weights = torch.empty([4, 4], + device=self.device, + dtype=self.dtype) + + return [ + residual, + mm_1, + rms_norm_weights, + ] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + all_reduce = tensor_model_parallel_all_reduce(mm_1) + + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=all_reduce, + residual=residual, + weight=rms_norm_weights, + epsilon=self.epsilon, + ) + + return rmsnorm[1], rmsnorm[2] + + def replacement( + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + tp = get_tp_group() + tp_size = get_tensor_model_parallel_world_size() + reduce_scatter = torch.ops.vllm.reduce_scatter.default( + mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name) + + # TODO is it possible to extract epsilon from somewhere + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=reduce_scatter, + residual=residual, + weight=rms_norm_weights, + epsilon=self.epsilon, + ) + + all_gather = torch.ops.vllm.all_gather.default( + rmsnorm[1], + dim=0, + world_size=tp_size, + group_name=tp.unique_name) + return all_gather, rmsnorm[2] + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern): + + def get_inputs(self): + mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) + rms_norm_weights = torch.empty([4, 4], + device=self.device, + dtype=self.dtype) + + return [ + residual, + mm_1, + rms_norm_weights, + ] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + all_reduce = tensor_model_parallel_all_reduce(mm_1) + + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=all_reduce, + residual=residual, + weight=rms_norm_weights, + epsilon=self.epsilon, + ) + + return rmsnorm[1] + + def replacement( + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + tp = get_tp_group() + tp_size = get_tensor_model_parallel_world_size() + reduce_scatter = torch.ops.vllm.reduce_scatter.default( + mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name) + + # TODO is it possible to extract epsilon from somewhere + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=reduce_scatter, + residual=residual, + weight=rms_norm_weights, + epsilon=self.epsilon, + ) + + normalized = torch.ops.vllm.all_gather.default( + rmsnorm[1], + dim=0, + world_size=tp_size, + group_name=tp.unique_name) + + return normalized + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class SequenceParallelismPass(VllmInductorPass): + + def __init__(self, config: VllmConfig): + super().__init__(config) + + self.patterns: PatternMatcherPass = PatternMatcherPass( + 
pass_name="sequence_parallelism_pass") + for epsilon in [1e-5, 1e-6]: + EmbeddingAllReduceRMSNormPattern( + epsilon, self.model_dtype, self.device).register(self.patterns) + + MiddleAllReduceRMSNormPattern(epsilon, self.model_dtype, + self.device).register(self.patterns) + + LastAllReduceRMSNormPattern(epsilon, self.model_dtype, + self.device).register(self.patterns) + # WARNING: This is a hack to clear the pattern matcher cache + # and allow multiple values of epsilon. + torch._inductor.pattern_matcher._seen_patterns.clear() + + def is_applicable_for_shape(self, shape: Optional[int]) -> bool: + tp_size = get_tensor_model_parallel_world_size() + return shape is not None and shape % tp_size == 0 + + def __call__(self, graph: fx.Graph): + self.begin() + self.dump_graph(graph, "before_sequence_parallelism_pass") + count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", count) + self.dump_graph(graph, "after_sequence_parallelism_pass") + self.end_and_log() diff --git a/compilation/torch25_custom_graph_pass.py b/compilation/torch25_custom_graph_pass.py new file mode 100644 index 0000000..cd39706 --- /dev/null +++ b/compilation/torch25_custom_graph_pass.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod +from typing import Any, Optional + +import torch + + +class Torch25CustomGraphPass(ABC): # noqa (redefinition) + """ + This class replaces CustomGraphPass from torch==2.6 when using torch<2.6. + It conforms to the 2.6 interface but also supports pickling, as that's what + the inductor code cache uses to determine the cache key before 2.6. + (in 2.6 and above, uuid() is used.) + + Subclasses can just "pretend" that uuid is used. + """ + + @abstractmethod + def __call__(self, graph: torch.fx.graph.Graph) -> None: + """ + Implementation of the custom pass. + """ + + @abstractmethod + def uuid(self) -> Optional[Any]: + """ + Return an ID to uniquely identify your custom pass implementation. + Return None to skip inductor code caching entirely. + """ + + def __getstate__(self): + """ + Pickling is used instead of uuid() in torch<2.6. Just return uuid() + to enable subclasses to only have to implement uuid. + """ + return self.uuid() + + def __setstate__(self, state): + raise ValueError("Cannot unpickle CustomGraphPass because pickling" + " is used for cache key uuid. Use torch>=2.6 with" + " native uuid support for custom passes.") diff --git a/compilation/vllm_inductor_pass.py b/compilation/vllm_inductor_pass.py new file mode 100644 index 0000000..3ccbf52 --- /dev/null +++ b/compilation/vllm_inductor_pass.py @@ -0,0 +1,67 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import time + +import torch + +from vllm.config import PassConfig, VllmConfig +# yapf: disable +from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank +from vllm.distributed import ( + get_tensor_model_parallel_world_size as get_tp_world_size) +from vllm.distributed import model_parallel_is_initialized as p_is_init +# yapf: enable +from vllm.logger import init_logger + +from .inductor_pass import InductorPass + +logger = init_logger(__name__) + + +class VllmInductorPass(InductorPass): + """ + An inductor pass with access to vLLM PassConfig. + It provides timing, logging, and dumping utilities. 
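+    Subclasses call `begin()`/`end_and_log()` around their work and can use
+    `dump_graph()` to write the FX graph at stages enabled in the pass config.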
+ """ + + def __init__(self, config: VllmConfig): + self.pass_config = config.compilation_config.pass_config + self.model_dtype = config.model_config.dtype if config.model_config \ + else None + self.device = config.device_config.device if config.device_config \ + else None + self.pass_name = self.__class__.__name__ + + def dump_graph(self, graph: torch.fx.Graph, stage: str, always=False): + if stage in self.pass_config.dump_graph_stages or always: + # Make sure filename includes rank in the distributed setting + parallel = p_is_init() and get_tp_world_size() > 1 + rank = f"-{get_tp_rank()}" if parallel else "" + filepath = self.pass_config.dump_graph_dir / f"{stage}{rank}.py" + + logger.info("%s printing graph to %s", self.pass_name, filepath) + with open(filepath, "w") as f: + src = graph.python_code(root_module="self", verbose=True).src + # Add imports so it's not full of errors + print("import torch; from torch import device", file=f) + print(src, file=f) + + def begin(self): + self._start_time = time.perf_counter_ns() + + def end_and_log(self): + self._end_time = time.perf_counter_ns() + duration_ms = float(self._end_time - self._start_time) / 1.0e6 + logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms) + + +class PrinterInductorPass(VllmInductorPass): + + def __init__(self, name: str, config: PassConfig, always=False): + super().__init__(config) + self.name = name + self.always = always + + def __call__(self, graph: torch.fx.Graph): + self.dump_graph(graph, self.name, always=self.always) diff --git a/compilation/wrapper.py b/compilation/wrapper.py new file mode 100644 index 0000000..2a261c8 --- /dev/null +++ b/compilation/wrapper.py @@ -0,0 +1,135 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +import sys +from abc import abstractmethod +from contextlib import contextmanager +from types import CodeType +from typing import Callable, Optional + +import torch + +import vllm.envs as envs +from vllm.config import CompilationLevel, get_current_vllm_config +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class TorchCompileWrapperWithCustomDispatcher: + """ + A wrapper class for torch.compile, with a custom dispatch logic. + Subclasses should: + 1. Implement the forward method + 2. Implement the dispatch logic in the __call__ method + It can use `self.compiled_codes` to access the compiled bytecode, + and `with self.dispatch_to_code(index):` to dispatch to + the compiled code. + 3. Implement the `__init__` method to determine how to call + `torch.compile` over the forward method. 
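+    The compiled bytecode is captured by a Dynamo bytecode hook registered in
+    `__init__` and appended to `self.compiled_codes`.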
+ """ + + def __init__(self, + compiled_callable: Optional[Callable] = None, + compilation_level: int = 0): + + vllm_config = get_current_vllm_config() + self.vllm_config = vllm_config + if compiled_callable is None: + # default compilation settings + # compiling the forward method + + backend = vllm_config.compilation_config.init_backend(vllm_config) + options = None + if isinstance(backend, str) and backend == "inductor": + options = get_current_vllm_config( + ).compilation_config.inductor_compile_config + + compiled_callable = torch.compile( + self.forward, + fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + backend=backend, + options=options) + + self.compiled_callable = compiled_callable + self.original_code_object = self.__class__.forward.__code__ + self.compiled_codes: list[CodeType] = [] + torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook) + + # read the env var to determine whether to use the custom dispatcher + # subclasses can use this to switch between the custom dispatcher + # and the default Dynamo guard mechanism. + self.use_custom_dispatcher: bool = \ + compilation_level >= CompilationLevel.DYNAMO_ONCE + + def __call__(self, *args, **kwargs): + """Implement the dispatch logic here, beyond the torch.compile level. + NOTE: this function can have additional arguments beyond the forward + method, for directly dispatching to the compiled code. + """ + return self.compiled_callable(*args, **kwargs) + + @abstractmethod + def forward(self, *args, **kwargs): + ... + + def bytecode_hook(self, old_code: CodeType, new_code: CodeType): + """Hook to save the compiled bytecode for direct execution.""" + if old_code is not self.original_code_object: + return + # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25 + frame = sys._getframe() + while frame and frame.f_back: + frame = frame.f_back + code_name = frame.f_code.co_name + file_name = frame.f_code.co_filename.split(os.path.sep)[-1] + if code_name == "_compile" and file_name == "convert_frame.py": + break + frame = frame.f_locals["frame"] + assert frame.f_code == old_code + + if frame.f_locals["self"] is not self: + return + + self.compiled_codes.append(new_code) + local_cache_dir = self.vllm_config.compilation_config.local_cache_dir + if isinstance(local_cache_dir, str): + decompiled_file = os.path.join(local_cache_dir, + "transformed_code.py") + if not os.path.exists(decompiled_file): + try: + # usually the decompilation will succeed for most models, + # as we guarantee a full-graph compilation in Dynamo. + # but there's no 100% guarantee, since decompliation is + # not a reversible process. + import depyf + src = depyf.decompile(new_code) + with open(decompiled_file, "w") as f: + f.write(src) + + logger.debug("Dynamo transformed code saved to %s", + decompiled_file) + except Exception: + pass + + if self.vllm_config.compilation_config.use_cudagraph and \ + "update" in new_code.co_names: + import depyf + src = depyf.decompile(new_code) + msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src # noqa + raise RuntimeError(msg) + + @contextmanager + def dispatch_to_code(self, index: int): + """Context manager to dispatch to the compiled code. 
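+        Entering the context swaps `forward.__code__` to the compiled bytecode
+        at `index`; exiting restores the original code object.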
+ Why does this work? Because Dynamo guarantees that the compiled + bytecode has exactly the same arguments, cell variables, and free + variables as the original code. Therefore we can directly switch + the code object in the function and call it. + + See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details. + """ # noqa + self.__class__.forward.__code__ = self.compiled_codes[index] + yield + self.__class__.forward.__code__ = self.original_code_object diff --git a/config.py b/config.py new file mode 100644 index 0000000..aa13f22 --- /dev/null +++ b/config.py @@ -0,0 +1,4756 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import ast +import copy +import enum +import hashlib +import inspect +import json +import textwrap +import uuid +import warnings +from collections import Counter +from contextlib import contextmanager +from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass, + replace) +from functools import cached_property +from importlib.util import find_spec +from pathlib import Path +from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional, + Protocol, TypeVar, Union, cast, get_args, get_origin) + +import regex as re +import torch +from pydantic import (ConfigDict, SkipValidation, TypeAdapter, field_validator, + model_validator) +from pydantic.dataclasses import dataclass +from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE +from torch.distributed import ProcessGroup, ReduceOp +from transformers import PretrainedConfig +from typing_extensions import deprecated, runtime_checkable + +import vllm.envs as envs +from vllm import version +from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS, + QuantizationMethods, + get_quantization_config) +from vllm.model_executor.models import ModelRegistry +from vllm.platforms import current_platform +from vllm.tracing import is_otel_available, otel_import_error_traceback +from vllm.transformers_utils.config import ( + ConfigFormat, get_config, get_hf_image_processor_config, + get_hf_text_config, get_pooling_config, + get_sentence_transformer_tokenizer_config, is_encoder_decoder, + try_get_generation_config, try_get_safetensors_metadata, + try_get_tokenizer_config, uses_mrope) +from vllm.transformers_utils.s3_utils import S3Model +from vllm.transformers_utils.utils import is_s3, maybe_model_redirect +from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS, + MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, + POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, GiB_bytes, + LayerBlockType, common_broadcastable_dtype, + cuda_device_count_stateless, get_cpu_memory, + get_open_port, is_torch_equal_or_newer, random_uuid, + resolve_obj_by_qualname) + +if TYPE_CHECKING: + from _typeshed import DataclassInstance + from ray.util.placement_group import PlacementGroup + + from vllm.executor.executor_base import ExecutorBase + from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + from vllm.model_executor.model_loader import BaseModelLoader + from vllm.model_executor.model_loader.tensorizer import TensorizerConfig + + ConfigType = type[DataclassInstance] +else: + PlacementGroup = Any + ExecutorBase = Any + QuantizationConfig = Any + BaseModelLoader = Any + TensorizerConfig 
= Any + ConfigType = type + +logger = init_logger(__name__) + +ConfigT = TypeVar("ConfigT", bound=ConfigType) + +TaskOption = Literal["auto", "generate", "embedding", "embed", "classify", + "score", "reward", "transcription"] + +_ResolvedTask = Literal["generate", "embed", "classify", "score", "reward", + "draft", "transcription"] + +RunnerType = Literal["generate", "pooling", "draft", "transcription"] + +_RUNNER_TASKS: dict[RunnerType, list[_ResolvedTask]] = { + "generate": ["generate"], + "pooling": ["embed", "classify", "score", "reward"], + "draft": ["draft"], + "transcription": ["transcription"], +} + +_TASK_RUNNER: dict[_ResolvedTask, RunnerType] = { + task: runner + for runner, tasks in _RUNNER_TASKS.items() + for task in tasks +} + +HfOverrides = Union[dict[str, Any], Callable[[PretrainedConfig], + PretrainedConfig]] + + +@runtime_checkable +class SupportsHash(Protocol): + + def compute_hash(self) -> str: + ... + + +class SupportsMetricsInfo(Protocol): + + def metrics_info(self) -> dict[str, str]: + ... + + +class ModelImpl(str, enum.Enum): + AUTO = "auto" + VLLM = "vllm" + TRANSFORMERS = "transformers" + + +def get_attr_docs(cls: type[Any]) -> dict[str, str]: + """ + Get any docstrings placed after attribute assignments in a class body. + + https://davidism.com/mit-license/ + """ + + def pairwise(iterable): + """ + Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise + + Can be removed when Python 3.9 support is dropped. + """ + iterator = iter(iterable) + a = next(iterator, None) + + for b in iterator: + yield a, b + a = b + + cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0] + + if not isinstance(cls_node, ast.ClassDef): + raise TypeError("Given object was not a class.") + + out = {} + + # Consider each pair of nodes. + for a, b in pairwise(cls_node.body): + # Must be an assignment then a constant string. + if (not isinstance(a, (ast.Assign, ast.AnnAssign)) + or not isinstance(b, ast.Expr) + or not isinstance(b.value, ast.Constant) + or not isinstance(b.value.value, str)): + continue + + doc = inspect.cleandoc(b.value.value) + + # An assignment can have multiple targets (a = b = v), but an + # annotated assignment only has one target. + targets = a.targets if isinstance(a, ast.Assign) else [a.target] + + for target in targets: + # Must be assigning to a plain name. + if not isinstance(target, ast.Name): + continue + + out[target.id] = doc + + return out + + +def config(cls: ConfigT) -> ConfigT: + """ + A decorator that ensures all fields in a dataclass have default values + and that each field has a docstring. + + If a `ConfigT` is used as a CLI argument itself, the default value provided + by `get_kwargs` will be the result parsing a JSON string as the kwargs + (i.e. `ConfigT(**json.loads(cli_arg))`). However, if a particular `ConfigT` + requires custom construction from CLI (i.e. `CompilationConfig`), it can + have a `from_cli` method, which will be called instead. + """ + if not is_dataclass(cls): + raise TypeError("The decorated class must be a dataclass.") + attr_docs = get_attr_docs(cls) + for f in fields(cls): + if f.init and f.default is MISSING and f.default_factory is MISSING: + raise ValueError( + f"Field '{f.name}' in {cls.__name__} must have a default value." 
+ ) + + if f.name not in attr_docs: + raise ValueError( + f"Field '{f.name}' in {cls.__name__} must have a docstring.") + + if get_origin(f.type) is Union: + args = get_args(f.type) + literal_args = [arg for arg in args if get_origin(arg) is Literal] + if len(literal_args) > 1: + raise ValueError( + f"Field '{f.name}' in {cls.__name__} must use a single " + "Literal type. Please use 'Literal[Literal1, Literal2]' " + "instead of 'Union[Literal1, Literal2]'.") + return cls + + +def get_field(cls: ConfigType, name: str) -> Field: + """Get the default factory field of a dataclass by name. Used for getting + default factory fields in `EngineArgs`.""" + if not is_dataclass(cls): + raise TypeError("The given class is not a dataclass.") + cls_fields = {f.name: f for f in fields(cls)} + if name not in cls_fields: + raise ValueError(f"Field '{name}' not found in {cls.__name__}.") + named_field: Field = cls_fields[name] + if (default_factory := named_field.default_factory) is not MISSING: + return field(default_factory=default_factory) + if (default := named_field.default) is not MISSING: + return field(default=default) + raise ValueError( + f"{cls.__name__}.{name} must have a default value or default factory.") + + +def is_init_field(cls: ConfigType, name: str) -> bool: + return next(f for f in fields(cls) if f.name == name).init + + +TokenizerMode = Literal["auto", "slow", "mistral", "custom"] +ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"] + + +@config +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) +class ModelConfig: + """Configuration for the model.""" + + model: str = "facebook/opt-125m" + """Name or path of the Hugging Face model to use. It is also used as the + content for `model_name` tag in metrics output when `served_model_name` is + not specified.""" + task: Literal[TaskOption, Literal["draft"]] = "auto" + """The task to use the model for. Each vLLM instance only supports one + task, even if the same model can be used for multiple tasks. When the model + only supports one task, "auto" can be used to select it; otherwise, you + must specify explicitly which task to use.""" + tokenizer: SkipValidation[str] = None # type: ignore + """Name or path of the Hugging Face tokenizer to use. If unspecified, model + name or path will be used.""" + tokenizer_mode: TokenizerMode = "auto" + """Tokenizer mode:\n + - "auto" will use the fast tokenizer if available.\n + - "slow" will always use the slow tokenizer.\n + - "mistral" will always use the tokenizer from `mistral_common`.\n + - "custom" will use --tokenizer to select the preregistered tokenizer.""" + trust_remote_code: bool = False + """Trust remote code (e.g., from HuggingFace) when downloading the model + and tokenizer.""" + dtype: Union[ModelDType, torch.dtype] = "auto" + """Data type for model weights and activations:\n + - "auto" will use FP16 precision for FP32 and FP16 models, and BF16 + precision for BF16 models.\n + - "half" for FP16. Recommended for AWQ quantization.\n + - "float16" is the same as "half".\n + - "bfloat16" for a balance between precision and range.\n + - "float" is shorthand for FP32 precision.\n + - "float32" for FP32 precision.""" + seed: Optional[int] = None + """Random seed for reproducibility. Initialized to None in V0, but + initialized to 0 in V1.""" + hf_config_path: Optional[str] = None + """Name or path of the Hugging Face config to use. 
If unspecified, model + name or path will be used.""" + allowed_local_media_path: str = "" + """Allowing API requests to read local images or videos from directories + specified by the server file system. This is a security risk. Should only + be enabled in trusted environments.""" + revision: Optional[str] = None + """The specific model version to use. It can be a branch name, a tag name, + or a commit id. If unspecified, will use the default version.""" + code_revision: Optional[str] = None + """The specific revision to use for the model code on the Hugging Face Hub. + It can be a branch name, a tag name, or a commit id. If unspecified, will + use the default version.""" + rope_scaling: dict[str, Any] = field(default_factory=dict) + """RoPE scaling configuration. For example, + `{"rope_type":"dynamic","factor":2.0}`.""" + rope_theta: Optional[float] = None + """RoPE theta. Use with `rope_scaling`. In some cases, changing the RoPE + theta improves the performance of the scaled model.""" + tokenizer_revision: Optional[str] = None + """The specific revision to use for the tokenizer on the Hugging Face Hub. + It can be a branch name, a tag name, or a commit id. If unspecified, will + use the default version.""" + max_model_len: SkipValidation[int] = None # type: ignore + """Model context length (prompt and output). If unspecified, will be + automatically derived from the model config. + + When passing via `--max-model-len`, supports k/m/g/K/M/G in human-readable + format. Examples:\n + - 1k -> 1000\n + - 1K -> 1024\n + - 25.6k -> 25,600""" + spec_target_max_model_len: Optional[int] = None + """Specify the maximum length for spec decoding draft models.""" + quantization: SkipValidation[Optional[QuantizationMethods]] = None + """Method used to quantize the weights. If `None`, we first check the + `quantization_config` attribute in the model config file. If that is + `None`, we assume the model weights are not quantized and use `dtype` to + determine the data type of the weights.""" + enforce_eager: bool = False + """Whether to always use eager-mode PyTorch. If True, we will disable CUDA + graph and always execute the model in eager mode. If False, we will use + CUDA graph and eager execution in hybrid for maximal performance and + flexibility.""" + max_seq_len_to_capture: int = 8192 + """Maximum sequence len covered by CUDA graphs. When a sequence has context + length larger than this, we fall back to eager mode. Additionally for + encoder-decoder models, if the sequence length of the encoder input is + larger than this, we fall back to the eager mode.""" + max_logprobs: int = 20 + """Maximum number of log probabilities to return when `logprobs` is + specified in `SamplingParams`. The default value comes the default for the + OpenAI Chat Completions API.""" + disable_sliding_window: bool = False + """Whether to disable sliding window. If True, we will disable the sliding + window functionality of the model, capping to sliding window size. If the + model does not support sliding window, this argument is ignored.""" + disable_cascade_attn: bool = True + """Disable cascade attention for V1. While cascade attention does not + change the mathematical correctness, disabling it could be useful for + preventing potential numerical issues. Note that even if this is set to + False, cascade attention will be only used when the heuristic tells that + it's beneficial.""" + skip_tokenizer_init: bool = False + """Skip initialization of tokenizer and detokenizer. 
Expects valid + `prompt_token_ids` and `None` for prompt from the input. The generated + output will contain token ids.""" + enable_prompt_embeds: bool = False + """If `True`, enables passing text embeddings as inputs via the + `prompt_embeds` key. Note that enabling this will double the time required + for graph compilation.""" + served_model_name: Optional[Union[str, list[str]]] = None + """The model name(s) used in the API. If multiple names are provided, the + server will respond to any of the provided names. The model name in the + model field of a response will be the first name in this list. If not + specified, the model name will be the same as the `--model` argument. Noted + that this name(s) will also be used in `model_name` tag content of + prometheus metrics, if multiple names provided, metrics tag will take the + first one.""" + limit_mm_per_prompt: dict[str, int] = field(default_factory=dict) + """Maximum number of data items per modality per prompt. Only applicable + for multimodal models.""" + use_async_output_proc: bool = True + """Whether to use async output processor.""" + config_format: Union[str, ConfigFormat] = ConfigFormat.AUTO.value + """The format of the model config to load:\n + - "auto" will try to load the config in hf format if available else it + will try to load in mistral format.\n + - "hf" will load the config in hf format.\n + - "mistral" will load the config in mistral format.""" + hf_token: Optional[Union[bool, str]] = None + """The token to use as HTTP bearer authorization for remote files . If + `True`, will use the token generated when running `huggingface-cli login` + (stored in `~/.huggingface`).""" + hf_overrides: HfOverrides = field(default_factory=dict) + """If a dictionary, contains arguments to be forwarded to the Hugging Face + config. If a callable, it is called to update the HuggingFace config.""" + mm_processor_kwargs: Optional[dict[str, Any]] = None + """Arguments to be forwarded to the model's processor for multi-modal data, + e.g., image processor. Overrides for the multi-modal processor obtained + from `AutoProcessor.from_pretrained`. The available overrides depend on the + model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`. + """ + disable_mm_preprocessor_cache: bool = False + """If `True`, disable caching of the multi-modal preprocessor/mapper (not + recommended).""" + override_neuron_config: dict[str, Any] = field(default_factory=dict) + """Initialize non-default neuron config or override default neuron config + that are specific to Neuron devices, this argument will be used to + configure the neuron config that can not be gathered from the vllm + arguments. e.g. `{"cast_logits_dtype": "bfloat16"}`.""" + pooler_config: Optional["PoolerConfig"] = field(init=False) + """Pooler config which controls the behaviour of output pooling in pooling + models.""" + override_pooler_config: Optional[Union[dict, "PoolerConfig"]] = None + """Initialize non-default pooling config or override default pooling config + for the pooling model. e.g. `{"pooling_type": "mean", "normalize": false}`. + """ + logits_processor_pattern: Optional[str] = None + """Optional regex pattern specifying valid logits processor qualified names + that can be passed with the `logits_processors` extra completion argument. + Defaults to `None`, which allows no processors.""" + generation_config: str = "auto" + """The folder path to the generation config. Defaults to `"auto"`, the + generation config will be loaded from model path. 
If set to `"vllm"`, no + generation config is loaded, vLLM defaults will be used. If set to a folder + path, the generation config will be loaded from the specified folder path. + If `max_new_tokens` is specified in generation config, then it sets a + server-wide limit on the number of output tokens for all requests.""" + override_generation_config: dict[str, Any] = field(default_factory=dict) + """Overrides or sets generation config. e.g. `{"temperature": 0.5}`. If + used with `--generation-config auto`, the override parameters will be + merged with the default config from the model. If used with + `--generation-config vllm`, only the override parameters are used.""" + enable_sleep_mode: bool = False + """Enable sleep mode for the engine (only cuda platform is supported).""" + model_impl: Union[str, ModelImpl] = ModelImpl.AUTO.value + """Which implementation of the model to use:\n + - "auto" will try to use the vLLM implementation, if it exists, and fall + back to the Transformers implementation if no vLLM implementation is + available.\n + - "vllm" will use the vLLM model implementation.\n + - "transformers" will use the Transformers model implementation.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + factors: list[Any] = [] + factors.append(self.model) + factors.append(self.dtype) + factors.append(self.quantization) + factors.append(self.revision) + factors.append(self.code_revision) + factors.append(self.max_model_len) + factors.append(self.max_logprobs) + factors.append(self.disable_sliding_window) + factors.append(self.trust_remote_code) + factors.append(self.generation_config) + factors.append(self.model_impl) + factors.append(self.override_generation_config) + factors.append(self.rope_scaling) + factors.append(self.rope_theta) + # hf_config can control how the model looks! + factors.append(self.hf_config.to_json_string()) + str_factors = str(factors) + assert_hashable(str_factors) + return hashlib.sha256(str(factors).encode()).hexdigest() + + def __post_init__(self) -> None: + # Set the default seed to 0 in V1. + # NOTE(woosuk): In V0, we set the default seed to None because the + # driver worker shares the same process as the user process, and thus + # setting a seed affects the user process as well. + # In V1, we use separate processes for workers (unless + # VLLM_ENABLE_V1_MULTIPROCESSING=0), so setting a seed here + # doesn't affect the user process. However, without a consistent seed, + # different tensor parallel workers would sample different tokens, + # leading to inconsistent results. + if envs.VLLM_USE_V1 and self.seed is None: + self.seed = 0 + if not envs.VLLM_ENABLE_V1_MULTIPROCESSING: + logger.warning( + "The global random seed is set to %d. Since " + "VLLM_ENABLE_V1_MULTIPROCESSING is set to False, this may " + "affect the random state of the Python process that " + "launched vLLM.", self.seed) + + self.model = maybe_model_redirect(self.model) + # The tokenizer is consistent with the model by default. 
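+        # e.g. with model="facebook/opt-125m" and no explicit tokenizer, the
+        # tokenizer name also resolves to "facebook/opt-125m".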
+ if self.tokenizer is None: + self.tokenizer = self.model + if self.tokenizer_revision is None: + self.tokenizer_revision = self.revision + self.tokenizer = maybe_model_redirect(self.tokenizer) + + if isinstance(self.hf_config_path, str): + self.hf_config_path = maybe_model_redirect(self.hf_config_path) + + if callable(self.hf_overrides): + hf_overrides_kw = {} + hf_overrides_fn = self.hf_overrides + else: + hf_overrides_kw = self.hf_overrides + hf_overrides_fn = None + + if self.rope_scaling: + hf_override: dict[str, Any] = {"rope_scaling": self.rope_scaling} + hf_overrides_kw.update(hf_override) + hf_overrides_str = json.dumps(hf_overrides_kw) + msg = ( + "`--rope-scaling` will be removed in a future release. " + f"'Please instead use `--hf-overrides '{hf_overrides_str}'`") + warnings.warn(DeprecationWarning(msg), stacklevel=2) + if self.rope_theta is not None: + hf_override = {"rope_theta": self.rope_theta} + hf_overrides_kw.update(hf_override) + hf_overrides_str = json.dumps(hf_overrides_kw) + msg = ( + "`--rope-theta` will be removed in a future release. " + f"'Please instead use `--hf-overrides '{hf_overrides_str}'`") + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + self.maybe_pull_model_tokenizer_for_s3(self.model, self.tokenizer) + + if (backend := envs.VLLM_ATTENTION_BACKEND + ) and backend == "FLASHINFER" and find_spec("flashinfer") is None: + raise ValueError( + "VLLM_ATTENTION_BACKEND is set to FLASHINFER, but flashinfer " + "module was not found. See " + "https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile " # noqa: E501 + "for instructions on how to install it.") + + from vllm.platforms import current_platform + + if (self.enable_sleep_mode + and not current_platform.is_sleep_mode_available()): + raise ValueError( + "Sleep mode is not supported on current platform.") + + if isinstance(self.config_format, str): + self.config_format = ConfigFormat(self.config_format) + + hf_config = get_config(self.hf_config_path or self.model, + self.trust_remote_code, self.revision, + self.code_revision, self.config_format) + + if hf_overrides_kw: + logger.info("Overriding HF config with %s", hf_overrides_kw) + hf_config.update(hf_overrides_kw) + if hf_overrides_fn: + logger.info("Overriding HF config with %s", hf_overrides_fn) + hf_config = hf_overrides_fn(hf_config) + + self.hf_config = hf_config + + self.hf_text_config = get_hf_text_config(self.hf_config) + self.attention_chunk_size = getattr(self.hf_text_config, + "attention_chunk_size", None) + self.encoder_config = self._get_encoder_config() + self.hf_image_processor_config = get_hf_image_processor_config( + self.model, hf_token=self.hf_token, revision=self.revision) + + supported_tasks, task = self._resolve_task(self.task) + self.supported_tasks = supported_tasks + self.task = task + if self.task in ("draft", "generate"): + self.truncation_side = "left" + else: + self.truncation_side = "right" + + self.pooler_config = self._init_pooler_config() + + self.dtype = _get_and_verify_dtype( + self.model, + self.hf_config, + self.dtype, + is_pooling_model=self.runner_type == "pooling", + revision=self.revision, + ) + + # Workaround for Gemma 2 which uses interleaved sliding window + # attention, but it's not specified in its config. TODO: remove this + # when Gemma 2 is fixed in Transformers. 
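+        # (Gemma 2 interleaves sliding-window and full attention every other
+        # layer, hence the hard-coded pattern length of 2 below.)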
+ if self.hf_text_config.model_type == "gemma2": + self.hf_text_config.sliding_window_pattern = 2 + + sliding_window = getattr(self.hf_text_config, "sliding_window", None) + sliding_window_pattern = getattr(self.hf_text_config, + "sliding_window_pattern", None) + has_interleaved_attention = sliding_window_pattern is not None or ( + isinstance(sliding_window, list)) + + if not self.disable_sliding_window and has_interleaved_attention: + if (backend := + envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER"): + sliding_window_len_min = get_min_sliding_window( + self.hf_text_config.sliding_window) + + logger.warning_once( + "%s has interleaved attention, which is currently not supported by the %s backend. Disabling sliding window and capping the max length to the sliding window size (%d).", # noqa: E501 + self.hf_text_config.model_type, + backend, + sliding_window_len_min, + ) + self.disable_sliding_window = True + else: + # for a model with interleaved attention, + # the scheduler and the model treat it as full attention + # (i.e., not dropping any tokens outside the window). + # only the attention layer itself is aware of the sliding + # window, and use the window size to compute the attention. + self.hf_text_config.interleaved_sliding_window = sliding_window + + if hasattr(self.hf_text_config, "sliding_window"): + delattr(self.hf_text_config, "sliding_window") + + sliding_window = None + + self.original_max_model_len = self.max_model_len + self.max_model_len = self.get_and_verify_max_len(self.max_model_len) + self.served_model_name = get_served_model_name(self.model, + self.served_model_name) + self.multimodal_config = self._init_multimodal_config() + if not self.skip_tokenizer_init: + self._verify_tokenizer_mode() + + self.is_attention_free = self._init_attention_free() + self.is_hybrid = self._init_is_hybrid() + self.has_noops = self._init_has_noops() + self.has_inner_state = self._init_has_inner_state() + + if (not current_platform.is_neuron() and self.override_neuron_config): + raise ValueError( + "`override_neuron_config` is only supported on Neuron.") + + self._verify_quantization() + self._verify_cuda_graph() + self._verify_bnb_config() + + @field_validator("quantization", mode="before") + @classmethod + def validate_quantization_before(cls, value: Any) -> Any: + if isinstance(value, str): + return value.lower() + return value + + @model_validator(mode="after") + def validate_model_config_after(self: "ModelConfig") -> "ModelConfig": + if not isinstance(self.tokenizer, str): + raise ValueError("tokenizer must be a string after __post_init__.") + if not isinstance(self.max_model_len, int): + raise ValueError( + "max_model_len must be an integer after __post_init__.") + return self + + @property + def registry(self): + return ModelRegistry + + @property + def architectures(self) -> list[str]: + return getattr(self.hf_config, "architectures", []) + + def maybe_pull_model_tokenizer_for_s3(self, model: str, + tokenizer: str) -> None: + """Pull model/tokenizer from S3 to temporary directory when needed. 
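+        S3 objects are pulled into a local temporary directory so the rest of
+        the loading code can treat them as local paths.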
+ + Args: + model: Model name or path + tokenizer: Tokenizer name or path + """ + if not (is_s3(model) or is_s3(tokenizer)): + return + + if is_s3(model): + s3_model = S3Model() + s3_model.pull_files(model, + allow_pattern=["*.model", "*.py", "*.json"]) + self.model_weights = model + self.model = s3_model.dir + + # If tokenizer is same as model, download to same directory + if model == tokenizer: + s3_model.pull_files( + model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) + self.tokenizer = s3_model.dir + return + + # Only download tokenizer if needed and not already handled + if is_s3(tokenizer): + s3_tokenizer = S3Model() + s3_tokenizer.pull_files( + model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) + self.tokenizer = s3_tokenizer.dir + + def _init_multimodal_config(self) -> Optional["MultiModalConfig"]: + if self.registry.is_multimodal_model(self.architectures): + return MultiModalConfig( + limit_per_prompt=self.limit_mm_per_prompt, + mm_processor_kwargs=self.mm_processor_kwargs, + disable_mm_preprocessor_cache=self. + disable_mm_preprocessor_cache) + + if self.limit_mm_per_prompt: + raise ValueError("`limit_mm_per_prompt` is only supported for " + "multimodal models.") + if self.mm_processor_kwargs: + raise ValueError("`mm_processor_kwargs` is only supported for " + "multimodal models.") + if self.disable_mm_preprocessor_cache: + raise ValueError("`disable_mm_preprocessor_cache` is only " + "supported for multimodal models.") + + return None + + def _get_encoder_config(self): + return get_sentence_transformer_tokenizer_config( + self.model, self.revision) + + def _init_pooler_config(self) -> Optional["PoolerConfig"]: + if self.runner_type == "pooling": + if isinstance(self.override_pooler_config, dict): + self.override_pooler_config = PoolerConfig( + **self.override_pooler_config) + + pooler_config = self.override_pooler_config or PoolerConfig() + + base_config = get_pooling_config(self.model, self.revision) + if base_config is not None: + # Only set values that are not overridden by the user + for k, v in base_config.items(): + if getattr(pooler_config, k) is None: + setattr(pooler_config, k, v) + + if self.is_matryoshka: + if pooler_config.normalize is None: + pooler_config.normalize = True + elif not pooler_config.normalize: + raise ValueError( + "`normalize` must be enabled (set to True) " + "for models that are compatible with " + "Matryoshka Representation.") + + return pooler_config + + return None + + def _init_attention_free(self) -> bool: + return self.registry.is_attention_free_model(self.architectures) + + def _init_is_hybrid(self) -> bool: + return self.registry.is_hybrid_model(self.architectures) + + def _init_has_noops(self) -> bool: + architectures = getattr(self.hf_config, "architectures", []) + return self.registry.is_noops_model(architectures) + + def _init_has_inner_state(self) -> bool: + return self.registry.model_has_inner_state(self.architectures) + + def _verify_tokenizer_mode(self) -> None: + tokenizer_mode = cast(TokenizerMode, self.tokenizer_mode.lower()) + if tokenizer_mode not in get_args(TokenizerMode): + raise ValueError( + f"Unknown tokenizer mode: {self.tokenizer_mode}. 
Must be " + f"one of {get_args(TokenizerMode)}.") + self.tokenizer_mode = tokenizer_mode + + def _get_preferred_task( + self, + architectures: list[str], + supported_tasks: set[_ResolvedTask], + ) -> Optional[_ResolvedTask]: + model_id = self.model + if get_pooling_config(model_id, self.revision): + return "embed" + if self.registry.is_cross_encoder_model(architectures): + return "score" + if self.registry.is_transcription_model(architectures): + return "transcription" + + suffix_to_preferred_task: list[tuple[str, _ResolvedTask]] = [ + # Other models follow this pattern + ("ForCausalLM", "generate"), + ("ForConditionalGeneration", "generate"), + ("ForSequenceClassification", "classify"), + ("ChatModel", "generate"), + ("LMHeadModel", "generate"), + ("EmbeddingModel", "embed"), + ("RewardModel", "reward"), + ] + _, arch = self.registry.inspect_model_cls(architectures) + + for suffix, pref_task in suffix_to_preferred_task: + if arch.endswith(suffix) and pref_task in supported_tasks: + return pref_task + + return None + + def _resolve_task( + self, + task_option: Literal[TaskOption, Literal["draft"]], + ) -> tuple[set[_ResolvedTask], _ResolvedTask]: + if task_option == "draft": + return {"draft"}, "draft" + + registry = self.registry + architectures = self.architectures + + runner_support: dict[RunnerType, bool] = { + # NOTE: Listed from highest to lowest priority, + # in case the model supports multiple of them + "transcription": registry.is_transcription_model(architectures), + "generate": registry.is_text_generation_model(architectures), + "pooling": registry.is_pooling_model(architectures), + } + supported_runner_types_lst: list[RunnerType] = [ + runner_type + for runner_type, is_supported in runner_support.items() + if is_supported + ] + + supported_tasks_lst: list[_ResolvedTask] = [ + task for runner_type in supported_runner_types_lst + for task in _RUNNER_TASKS[runner_type] + ] + supported_tasks = set(supported_tasks_lst) + + if task_option == "auto": + selected_task = next(iter(supported_tasks_lst)) + + if len(supported_tasks_lst) > 1: + preferred_task = self._get_preferred_task( + architectures, supported_tasks) + if preferred_task is not None: + selected_task = preferred_task + + logger.info( + "This model supports multiple tasks: %s. " + "Defaulting to '%s'.", supported_tasks, selected_task) + else: + # Aliases + if task_option == "embedding": + msg = ("The 'embedding' task has been renamed to " + "'embed', please use the new name. The old name " + "will be removed in v1.0.") + warnings.warn(msg, DeprecationWarning, stacklevel=2) + + task_option = "embed" + + if task_option not in supported_tasks: + msg = ( + f"This model does not support the '{task_option}' task. 
" + f"Supported tasks: {supported_tasks}") + raise ValueError(msg) + + selected_task = task_option + + return supported_tasks, selected_task + + def _parse_quant_hf_config(self): + quant_cfg = getattr(self.hf_config, "quantization_config", None) + if quant_cfg is None: + # compressed-tensors uses a "compression_config" key + quant_cfg = getattr(self.hf_config, "compression_config", None) + return quant_cfg + + def _verify_quantization(self) -> None: + supported_quantization = QUANTIZATION_METHODS + optimized_quantization_methods = [ + "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin", + "awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8", + "quark", "modelopt_fp4", "bitblas", "gptq_bitblas" + ] + if self.quantization is not None: + self.quantization = cast(QuantizationMethods, self.quantization) + + # Parse quantization method from the HF model config, if available. + quant_cfg = self._parse_quant_hf_config() + + if quant_cfg is not None: + quant_method = quant_cfg.get("quant_method", "").lower() + quant_method = quant_method.replace("compressed_tensors", + "compressed-tensors") + quant_cfg["quant_method"] = quant_method + + # Quantization methods which are overrides (i.e. they have a + # `override_quantization_method` method) must be checked in order + # of preference (this is particularly important for GPTQ). + overrides = [ + "marlin", + "bitblas", + "gptq_marlin_24", + "gptq_marlin", + "gptq_bitblas", + "awq_marlin", + "ipex", + "moe_wna16", + ] + quantization_methods = [ + q for q in supported_quantization if q not in overrides + ] + # Any custom overrides will be in quantization_methods so we place + # them at the start of the list so custom overrides have preference + # over the built in ones. + quantization_methods = quantization_methods + overrides + + # Detect which checkpoint is it + for name in quantization_methods: + method = get_quantization_config(name) + quantization_override = method.override_quantization_method( + quant_cfg, self.quantization) + if quantization_override is not None: + # Raise error if the override is not custom (custom would + # be in QUANTIZATION_METHODS but not QuantizationMethods) + # and hasn't been added to the overrides list. + # if (name in get_args(QuantizationMethods) + # and name not in overrides): + # raise ValueError( + # f"Quantization method {name} is an override but " + # "is has not been added to the `overrides` list " + # "above. This is necessary to ensure that the " + # "overrides are checked in order of preference.") + quant_method = quantization_override + self.quantization = quantization_override + break + + # Verify quantization configurations. + if self.quantization is None: + self.quantization = quant_method + elif self.quantization != quant_method: + raise ValueError( + "Quantization method specified in the model config " + f"({quant_method}) does not match the quantization " + f"method specified in the `quantization` argument " + f"({self.quantization}).") + + if self.quantization is not None: + if self.quantization not in supported_quantization: + raise ValueError( + f"Unknown quantization method: {self.quantization}. Must " + f"be one of {supported_quantization}.") + from vllm.platforms import current_platform + current_platform.verify_quantization(self.quantization) + if self.quantization not in optimized_quantization_methods: + logger.warning( + "%s quantization is not fully " + "optimized yet. 
The speed can be slower than " + "non-quantized models.", self.quantization) + + def _verify_cuda_graph(self) -> None: + self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, + self.max_model_len) + # CUDAGraph capture not supported for enc-dec models and mllama on ROCm + ROCM_UNSUPPORTED_MODELS = ['mllama'] + unsupported_rocm = (self.hf_config.model_type + in ROCM_UNSUPPORTED_MODELS + or self.is_encoder_decoder) + + if (unsupported_rocm and not self.enforce_eager + and current_platform.is_rocm()): + logger.warning( + "CUDA graph is not supported for %s on ROCm yet, fallback " + "to eager mode.", self.hf_config.model_type) + self.enforce_eager = True + + def _verify_bnb_config(self) -> None: + """ + The current version of bitsandbytes (0.45.3) with 8-bit models does not + yet support CUDA graph. + # TODO Remove this when bitsandbytes supports. + """ + is_bitsandbytes = self.quantization == "bitsandbytes" + has_quantization_config = (getattr(self.hf_config, + "quantization_config", None) + is not None) + is_8bit = (self.hf_config.quantization_config.get( + "load_in_8bit", False) if has_quantization_config else False) + if all([ + is_bitsandbytes, + has_quantization_config, + is_8bit, + not self.enforce_eager, + ]): + logger.warning( + "CUDA graph is not supported on BitsAndBytes 8bit yet, " + "fallback to the eager mode.") + + self.enforce_eager = True + + def _verify_with_expert_parallelism(self) -> None: + num_expert_names = [ + "moe_num_experts", # Dbrx + "num_experts", # Jamba + "n_routed_experts", # DeepSeek + "num_local_experts", # Mixtral + ] + num_experts = 0 + for name in num_expert_names: + num_experts = getattr(self.hf_text_config, name, 0) + if num_experts > 0: + break + if num_experts < 1: + raise ValueError( + "Number of experts in the model must be greater than 0 " + "when expert parallelism is enabled.") + + def verify_dual_chunk_attention_config( + self, + load_config: "LoadConfig", + ) -> None: + if hasattr(self.hf_config, "dual_chunk_attention_config"): + # Try loading the sparse attention config + from vllm.model_executor.model_loader.weight_utils import ( + get_sparse_attention_config) + sparse_attn_config = get_sparse_attention_config(self, load_config) + if sparse_attn_config: + self.hf_config.dual_chunk_attention_config[ + "sparse_attention_config"] = sparse_attn_config + if "sparse_attention_enabled" not in \ + self.hf_config.dual_chunk_attention_config: + self.hf_config.dual_chunk_attention_config[ + "sparse_attention_enabled"] = True + + def verify_async_output_proc(self, parallel_config, speculative_config, + device_config) -> None: + if not self.use_async_output_proc: + # Nothing to check + return + + if parallel_config.pipeline_parallel_size > 1: + self.use_async_output_proc = False + return + + # Reminder: Please update docs/features/compatibility_matrix.md + # If the feature combo become valid + from vllm.platforms import current_platform + if not current_platform.is_async_output_supported(self.enforce_eager): + self.use_async_output_proc = False + return + + if envs.VLLM_USE_RAY_SPMD_WORKER: + self.use_async_output_proc = False + return + + # Async postprocessor is not necessary for pooling models + # since there is no token generation + if self.runner_type == "pooling": + self.use_async_output_proc = False + + # Reminder: Please update docs/features/compatibility_matrix.md + # If the feature combo become valid + if speculative_config: + self.use_async_output_proc = False + + def verify_with_parallel_config( + self, + parallel_config: 
"ParallelConfig", + ) -> None: + + if parallel_config.distributed_executor_backend == "external_launcher": + assert self.seed is not None, ( + "Seed must be set when using external launcher backend to " + "make sure sampling results are the same across workers.") + + total_num_attention_heads = getattr(self.hf_text_config, + "num_attention_heads", 0) + tensor_parallel_size = parallel_config.tensor_parallel_size + if total_num_attention_heads % tensor_parallel_size != 0: + raise ValueError( + f"Total number of attention heads ({total_num_attention_heads})" + " must be divisible by tensor parallel size " + f"({tensor_parallel_size}).") + + if parallel_config.enable_expert_parallel: + self._verify_with_expert_parallelism() + + pipeline_parallel_size = parallel_config.pipeline_parallel_size + if pipeline_parallel_size > 1: + if not self.registry.is_pp_supported_model(self.architectures): + raise NotImplementedError( + "Pipeline parallelism is not supported for this model. " + "Supported models implement the `SupportsPP` interface.") + + if self.use_async_output_proc: + self.use_async_output_proc = False + + def get_hf_config_sliding_window( + self) -> Union[Optional[int], list[Optional[int]]]: + """Get the sliding window size, or None if disabled.""" + + # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in + # addition to sliding window size. We check if that field is present + # and if it's False, return None. + if (hasattr(self.hf_text_config, "use_sliding_window") + and not self.hf_text_config.use_sliding_window): + return None + return getattr(self.hf_text_config, "sliding_window", None) + + def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]: + """Get the sliding window size, or None if disabled. + """ + # If user disables sliding window, return None. + if self.disable_sliding_window: + return None + # Otherwise get the value from the hf config. + return self.get_hf_config_sliding_window() + + def get_vocab_size(self) -> int: + return self.hf_text_config.vocab_size + + def get_hidden_size(self) -> int: + return self.hf_text_config.hidden_size + + @property + def is_deepseek_mla(self) -> bool: + if not hasattr(self.hf_text_config, "model_type"): + return False + elif self.hf_text_config.model_type in \ + ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'): + return self.hf_text_config.kv_lora_rank is not None + elif self.hf_text_config.model_type == 'eagle': + # if the model is an EAGLE module, check for the + # underlying architecture + return self.hf_text_config.model.model_type in \ + ('deepseek_v2', 'deepseek_v3') \ + and self.hf_text_config.kv_lora_rank is not None + return False + + def get_head_size(self) -> int: + # TODO remove hard code + if self.is_deepseek_mla: + qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim", + 0) + if self.use_mla: + return self.hf_text_config.kv_lora_rank + qk_rope_head_dim + else: + qk_nope_head_dim = getattr(self.hf_text_config, + "qk_nope_head_dim", 0) + if qk_rope_head_dim and qk_nope_head_dim: + return qk_rope_head_dim + qk_nope_head_dim + + if hasattr(self.hf_text_config, + "model_type") and (self.hf_text_config.model_type + == "zamba2"): + return self.hf_text_config.attention_head_dim + + if self.is_attention_free: + return 0 + + # NOTE: Some configs may set head_dim=None in the config + if getattr(self.hf_text_config, "head_dim", None) is not None: + return self.hf_text_config.head_dim + + # FIXME(woosuk): This may not be true for all models. 
+ return (self.hf_text_config.hidden_size // + self.hf_text_config.num_attention_heads) + + def get_total_num_kv_heads(self) -> int: + """Returns the total number of KV heads.""" + # For GPTBigCode & Falcon: + # NOTE: for falcon, when new_decoder_architecture is True, the + # multi_query flag is ignored and we use n_head_kv for the number of + # KV heads. + falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] + new_decoder_arch_falcon = ( + self.hf_config.model_type in falcon_model_types + and getattr(self.hf_config, "new_decoder_architecture", False)) + if not new_decoder_arch_falcon and getattr(self.hf_text_config, + "multi_query", False): + # Multi-query attention, only one KV head. + # Currently, tensor parallelism is not supported in this case. + return 1 + + # For DBRX and MPT + if self.hf_config.model_type == "mpt": + if "kv_n_heads" in self.hf_config.attn_config: + return self.hf_config.attn_config["kv_n_heads"] + return self.hf_config.num_attention_heads + if self.hf_config.model_type == "dbrx": + return getattr(self.hf_config.attn_config, "kv_n_heads", + self.hf_config.num_attention_heads) + + if self.hf_config.model_type == "nemotron-nas": + for block in self.hf_config.block_configs: + if not block.attention.no_op: + return self.hf_config.num_attention_heads \ + // block.attention.n_heads_in_group + + raise RuntimeError("Couldn't determine number of kv heads") + + if self.is_attention_free: + return 0 + + attributes = [ + # For Falcon: + "n_head_kv", + "num_kv_heads", + # For LLaMA-2: + "num_key_value_heads", + # For ChatGLM: + "multi_query_group_num", + ] + for attr in attributes: + num_kv_heads = getattr(self.hf_text_config, attr, None) + if num_kv_heads is not None: + return num_kv_heads + + # For non-grouped-query attention models, the number of KV heads is + # equal to the number of attention heads. + return self.hf_text_config.num_attention_heads + + def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: + """Returns the number of KV heads per GPU.""" + if self.use_mla: + # When using MLA during decode it becomes MQA + return 1 + + total_num_kv_heads = self.get_total_num_kv_heads() + # If tensor parallelism is used, we divide the number of KV heads by + # the tensor parallel size. We will replicate the KV heads in the + # case where the number of KV heads is smaller than the tensor + # parallel size so each GPU has at least one KV head. 
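+        # Example: 8 total KV heads with tensor_parallel_size=4 -> 2 KV heads
+        # per GPU; 2 total KV heads with tensor_parallel_size=8 -> 1 (replicated).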
+        return max(1,
+                   total_num_kv_heads // parallel_config.tensor_parallel_size)
+
+    def get_num_attention_heads(self,
+                                parallel_config: "ParallelConfig") -> int:
+        num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
+        return num_heads // parallel_config.tensor_parallel_size
+
+    def get_layers_start_end_indices(
+            self, parallel_config: "ParallelConfig") -> tuple[int, int]:
+        from vllm.distributed.utils import get_pp_indices
+        if (self.hf_text_config.model_type == "deepseek_mtp"
+                or self.hf_config.model_type == "mimo_mtp"):
+            total_num_hidden_layers = getattr(self.hf_text_config,
+                                              "num_nextn_predict_layers", 0)
+        else:
+            total_num_hidden_layers = getattr(self.hf_text_config,
+                                              "num_hidden_layers", 0)
+        # the layout order is: DP x PP x TP
+        pp_rank = (parallel_config.rank // parallel_config.tensor_parallel_size
+                   ) % parallel_config.pipeline_parallel_size
+        pp_size = parallel_config.pipeline_parallel_size
+        start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
+        return start, end
+
+    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
+        start, end = self.get_layers_start_end_indices(parallel_config)
+        return end - start
+
+    def get_num_layers_by_block_type(
+        self,
+        parallel_config: "ParallelConfig",
+        block_type: LayerBlockType = LayerBlockType.attention,
+    ) -> int:
+        # This function relies on 'layers_block_type' in hf_config;
+        # for models without this attribute, we need workarounds like the
+        # ones below.
+        attn_block_type = block_type == LayerBlockType.attention
+        is_transformer = not self.is_hybrid and \
+            not self.has_noops and \
+            not self.is_attention_free
+        start, end = self.get_layers_start_end_indices(parallel_config)
+
+        if is_transformer:
+            # Handle the basic case first
+            return end - start if attn_block_type else 0
+        elif self.is_attention_free:
+            # Attention free
+            # Note that this code assumes there
+            # is only one type of attention-free block.
+            return 0 if attn_block_type else end - start
+        elif self.has_noops:
+            block_configs = self.hf_config.block_configs
+            return sum(not bc.attention.no_op
+                       for bc in block_configs[start:end])
+        else:
+            # Hybrid model Jamba
+            layers_block_type_value = getattr(self.hf_config,
+                                              "layers_block_type", None)
+            if layers_block_type_value is not None:
+                if hasattr(self.hf_text_config,
+                           "model_type") and (self.hf_text_config.model_type
+                                              == "zamba2"):
+                    if attn_block_type:
+                        return sum(t == "hybrid"
+                                   for t in layers_block_type_value[start:end])
+                    else:
+                        return self.get_num_layers(parallel_config)
+                return sum(t == block_type.value
+                           for t in layers_block_type_value[start:end])
+
+            # Hybrid model Minimax
+            attn_type_list = getattr(self.hf_config, "attn_type_list", None)
+            if attn_type_list:
+                return sum(t == 1 for t in attn_type_list[start:end])
+
+            if layers_block_type_value is None and attn_type_list is None:
+                raise ValueError(
+                    "The model is a hybrid without a "
+                    "layers_block_type or an attn_type_list in the hf_config, "
+                    "cannot determine the number of "
+                    f"{block_type.value} layers")
+
+            return sum(t == 1 for t in attn_type_list[start:end])
+
+    def get_multimodal_config(self) -> "MultiModalConfig":
+        """
+        Get the multimodal configuration of the model.
+
+        Raises:
+            ValueError: If the model is not multimodal.
+ """ + if self.multimodal_config is None: + raise ValueError("The model is not multimodal.") + + return self.multimodal_config + + def try_get_generation_config(self) -> dict[str, Any]: + if self.generation_config in ("auto", "vllm"): + config = try_get_generation_config( + self.hf_config_path or self.model, + trust_remote_code=self.trust_remote_code, + revision=self.revision, + ) + else: + config = try_get_generation_config( + self.generation_config, + trust_remote_code=self.trust_remote_code, + ) + + if config is None: + return {} + + return config.to_diff_dict() + + def get_diff_sampling_param(self) -> dict[str, Any]: + """ + This method returns a dictionary containing the parameters + that differ from the default sampling parameters. If + `generation_config` is `"vllm"`, an empty dictionary is returned. + + Returns: + dict[str, Any]: A dictionary with the differing sampling + parameters, if `generation_config` is `"vllm"` an empty dictionary. + """ + if self.generation_config == "vllm": + config = {} + else: + config = self.try_get_generation_config() + + # Overriding with given generation config + config.update(self.override_generation_config) + + available_params = [ + "repetition_penalty", + "temperature", + "top_k", + "top_p", + "min_p", + "max_new_tokens", + ] + if any(p in config for p in available_params): + diff_sampling_param = { + p: config.get(p) + for p in available_params if config.get(p) is not None + } + # Huggingface definition of max_new_tokens is equivalent + # to vLLM's max_tokens + if "max_new_tokens" in diff_sampling_param: + diff_sampling_param["max_tokens"] = diff_sampling_param.pop( + "max_new_tokens") + else: + diff_sampling_param = {} + + if diff_sampling_param: + logger.warning_once( + "Default sampling parameters have been overridden by the " + "model's Hugging Face generation config recommended from the " + "model creator. 
If this is not intended, please relaunch " + "vLLM instance with `--generation-config vllm`.") + return diff_sampling_param + + @property + def is_encoder_decoder(self) -> bool: + """Extract the HF encoder/decoder model flag.""" + """ + For Mllama, VLLM overrides HF's is_encoder_decoder flag and sets it to + True to enable cross-attention + Neuron needs all multimodal data to be in the decoder and does not + need to explicitly enable cross-attention + """ + if (current_platform.is_neuron() + and self.hf_config.model_type == "mllama"): + return False + + return is_encoder_decoder(self.hf_config) + + @property + def uses_mrope(self) -> bool: + return uses_mrope(self.hf_config) + + @property + def is_multimodal_model(self) -> bool: + return self.multimodal_config is not None + + @property + def is_cross_encoder(self) -> bool: + return self.registry.is_cross_encoder_model(self.architectures) + + @property + def use_mla(self) -> bool: + return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE + + @property + def supported_runner_types(self) -> set[RunnerType]: + return {_TASK_RUNNER[task] for task in self.supported_tasks} + + @property + def runner_type(self) -> RunnerType: + return _TASK_RUNNER[cast(_ResolvedTask, self.task)] + + @property + def is_v1_compatible(self) -> bool: + architectures = getattr(self.hf_config, "architectures", []) + return ModelRegistry.is_v1_compatible(architectures) + + @property + def is_matryoshka(self) -> bool: + return (hasattr(self.hf_config, "matryoshka_dimensions") + or getattr(self.hf_config, "is_matryoshka", False)) + + @property + def matryoshka_dimensions(self): + return getattr(self.hf_config, "matryoshka_dimensions", None) + + def get_and_verify_max_len(self, max_model_len: int): + max_model_len = _get_and_verify_max_len( + hf_config=self.hf_text_config, + max_model_len=max_model_len, + disable_sliding_window=self.disable_sliding_window, + sliding_window_len=self.get_hf_config_sliding_window(), + spec_target_max_model_len=self.spec_target_max_model_len, + encoder_config=self.encoder_config) + + # For pooling models, the tokenizer's `model_max_length` is often a + # reliable source for the maximum sequence length. However, for + # generative models, this can be incorrect and unduly limit the + # context window (e.g., DeepSeek-R1). Therefore, we only consider + # tokenizer_config for pooling models. + tokenizer_config = None + if self.runner_type == "pooling": + tokenizer_config = try_get_tokenizer_config( + self.tokenizer, + trust_remote_code=self.trust_remote_code, + revision=self.tokenizer_revision) + + if tokenizer_config is None: + return max_model_len + + model_max_length = tokenizer_config.get("model_max_length", + max_model_len) + max_model_len = min(max_model_len, model_max_length) + return max_model_len + + +BlockSize = Literal[1, 8, 16, 32, 64, 128] +CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"] +PrefixCachingHashAlgo = Literal["builtin", "sha256"] + + +@config +@dataclass +class CacheConfig: + """Configuration for the KV cache.""" + + block_size: SkipValidation[BlockSize] = None # type: ignore + """Size of a contiguous cache block in number of tokens. This is ignored on + neuron devices and set to `--max-model-len`. On CUDA devices, only block + sizes up to 32 are supported. On HPU devices, block size defaults to 128. + + This config has no static default. 
If left unspecified by the user, it will + be set in `Platform.check_and_update_configs()` based on the current + platform.""" + gpu_memory_utilization: float = 0.9 + """The fraction of GPU memory to be used for the model executor, which can + range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory + utilization. If unspecified, will use the default value of 0.9. This is a + per-instance limit, and only applies to the current vLLM instance. It does + not matter if you have another vLLM instance running on the same GPU. For + example, if you have two vLLM instances running on the same GPU, you can + set the GPU memory utilization to 0.5 for each instance.""" + swap_space: float = 4 + """Size of the CPU swap space per GPU (in GiB).""" + cache_dtype: CacheDType = "auto" + """Data type for kv cache storage. If "auto", will use model data type. + CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports + fp8 (=fp8_e4m3).""" + is_attention_free: bool = False + """Whether the model is attention-free. This is primarily set in + `ModelConfig` and that value should be manually duplicated here.""" + num_gpu_blocks_override: Optional[int] = None + """Number of GPU blocks to use. This overrides the profiled `num_gpu_blocks` + if specified. Does nothing if `None`. Used for testing preemption.""" + sliding_window: Optional[int] = None + """Sliding window size for the KV cache. This is primarily set in + `ModelConfig` and that value should be manually duplicated here.""" + enable_prefix_caching: Optional[bool] = None + """Whether to enable prefix caching. Disabled by default for V0. Enabled by + default for V1.""" + prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin" + """Set the hash algorithm for prefix caching:\n + - "builtin" is Python's built-in hash.\n + - "sha256" is collision resistant but with certain overheads.""" + cpu_offload_gb: float = 0 + """The space in GiB to offload to CPU, per GPU. Default is 0, which means + no offloading. Intuitively, this argument can be seen as a virtual way to + increase the GPU memory size. For example, if you have one 24 GB GPU and + set this to 10, virtually you can think of it as a 34 GB GPU. Then you can + load a 13B model with BF16 weight, which requires at least 26GB GPU memory. + Note that this requires fast CPU-GPU interconnect, as part of the model is + loaded from CPU memory to GPU memory on the fly in each model forward pass. + """ + calculate_kv_scales: bool = False + """This enables dynamic calculation of `k_scale` and `v_scale` when + kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model + checkpoint if available. Otherwise, the scales will default to 1.0.""" + + # Will be set after profiling. + num_gpu_blocks: Optional[int] = field(default=None, init=False) + """The number of blocks to allocate for GPU memory.""" + num_cpu_blocks: Optional[int] = field(default=None, init=False) + """The number of blocks to allocate for CPU memory.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + factors: list[Any] = [] + factors.append(self.cache_dtype) + # `cpu_offload_gb` does not use `torch.compile` yet. 
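+        # For example, cache_dtype="fp8" yields factors == ["fp8"], so two
+        # configs that differ only in cache_dtype hash differently here.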
+ hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self) -> None: + self.swap_space_bytes = self.swap_space * GiB_bytes + + self._verify_args() + self._verify_cache_dtype() + self._verify_prefix_caching() + + def metrics_info(self): + # convert cache_config to dict(key: str, value: str) for prometheus + # metrics info + return {key: str(value) for key, value in self.__dict__.items()} + + def _verify_args(self) -> None: + if self.cpu_offload_gb < 0: + raise ValueError("CPU offload space must be non-negative" + f", but got {self.cpu_offload_gb}") + + if self.gpu_memory_utilization > 1.0: + raise ValueError( + "GPU memory utilization must be less than 1.0. Got " + f"{self.gpu_memory_utilization}.") + + def _verify_cache_dtype(self) -> None: + if self.cache_dtype == "auto": + pass + elif self.cache_dtype in get_args(CacheDType): + logger.info( + "Using fp8 data type to store kv cache. It reduces the GPU " + "memory footprint and boosts the performance. " + "Meanwhile, it may cause accuracy drop without a proper " + "scaling factor") + else: + raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") + + def _verify_prefix_caching(self) -> None: + if not self.enable_prefix_caching: + return + + if self.sliding_window is not None and not envs.VLLM_USE_V1: + raise NotImplementedError( + "Prefix caching is not supported with sliding window. " + "Run with --disable-sliding-window to use prefix caching.") + + if (self.enable_prefix_caching and self.prefix_caching_hash_algo + not in get_args(PrefixCachingHashAlgo)): + raise ValueError( + "Unknown prefix caching hash algorithm: " + f"{self.prefix_caching_hash_algo}. Must be one of " + f"{get_args(PrefixCachingHashAlgo)}.") + + def verify_with_parallel_config( + self, + parallel_config: "ParallelConfig", + ) -> None: + total_cpu_memory = get_cpu_memory() + # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel + # group are in the same node. However, the GPUs may span multiple nodes. + num_gpus_per_node = parallel_config.tensor_parallel_size + cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node + + msg = (f"{cpu_memory_usage / GiB_bytes:.2f} GiB out of the " + f"{total_cpu_memory / GiB_bytes:.2f} GiB total CPU memory " + "is allocated for the swap space.") + if cpu_memory_usage > 0.7 * total_cpu_memory: + raise ValueError("Too large swap space. " + msg) + elif cpu_memory_usage > 0.4 * total_cpu_memory: + logger.warning("Possibly too large swap space. %s", msg) + + +@config +@dataclass +class TokenizerPoolConfig: + """This config is deprecated and will be removed in a future release. + + Passing these parameters will have no effect. Please remove them from your + configurations. + """ + + pool_size: int = 0 + """This parameter is deprecated and will be removed in a future release. + Passing this parameter will have no effect. Please remove it from your + configurations.""" + pool_type: str = "ray" + """This parameter is deprecated and will be removed in a future release. + Passing this parameter will have no effect. Please remove it from your + configurations.""" + extra_config: dict = field(default_factory=dict) + """This parameter is deprecated and will be removed in a future release. + Passing this parameter will have no effect. Please remove it from your + configurations.""" + + def __post_init__(self) -> None: + logger.warning_once( + "TokenizerPoolConfig is deprecated and will be removed in a " + "future release. 
Passing this parameter will have no effect. " + "Please remove it from your configurations.") + + +class LoadFormat(str, enum.Enum): + AUTO = "auto" + PT = "pt" + SAFETENSORS = "safetensors" + NPCACHE = "npcache" + DUMMY = "dummy" + TENSORIZER = "tensorizer" + SHARDED_STATE = "sharded_state" + GGUF = "gguf" + BITSANDBYTES = "bitsandbytes" + MISTRAL = "mistral" + RUNAI_STREAMER = "runai_streamer" + RUNAI_STREAMER_SHARDED = "runai_streamer_sharded" + FASTSAFETENSORS = "fastsafetensors" + + +@config +@dataclass +class LoadConfig: + """Configuration for loading the model weights.""" + + load_format: Union[str, LoadFormat, + "BaseModelLoader"] = LoadFormat.AUTO.value + """The format of the model weights to load:\n + - "auto" will try to load the weights in the safetensors format and fall + back to the pytorch bin format if safetensors format is not available.\n + - "pt" will load the weights in the pytorch bin format.\n + - "safetensors" will load the weights in the safetensors format.\n + - "npcache" will load the weights in pytorch format and store a numpy cache + to speed up the loading.\n + - "dummy" will initialize the weights with random values, which is mainly + for profiling.\n + - "tensorizer" will use CoreWeave's tensorizer library for fast weight + loading. See the Tensorize vLLM Model script in the Examples section for + more information.\n + - "runai_streamer" will load the Safetensors weights using Run:ai Model + Streamer.\n + - "bitsandbytes" will load the weights using bitsandbytes quantization.\n + - "sharded_state" will load weights from pre-sharded checkpoint files, + supporting efficient loading of tensor-parallel models.\n + - "gguf" will load weights from GGUF format files (details specified in + https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n + - "mistral" will load weights from consolidated safetensors files used by + Mistral models.""" + download_dir: Optional[str] = None + """Directory to download and load the weights, default to the default + cache directory of Hugging Face.""" + model_loader_extra_config: Union[dict, TensorizerConfig] = field( + default_factory=dict) + """Extra config for model loader. This will be passed to the model loader + corresponding to the chosen load_format.""" + ignore_patterns: Optional[Union[list[str], str]] = None + """The list of patterns to ignore when loading the model. Default to + "original/**/*" to avoid repeated loading of llama's checkpoints.""" + use_tqdm_on_load: bool = True + """Whether to enable tqdm for showing progress bar when loading model + weights.""" + pt_load_map_location: Union[str, dict[str, str]] = "cpu" + """ + pt_load_map_location: the map location for loading pytorch checkpoint, to + support loading checkpoints can only be loaded on certain devices like + "cuda", this is equivalent to {"": "cuda"}. Another supported format is + mapping from different devices like from GPU 1 to GPU 0: + {"cuda:1": "cuda:0"}. Note that when passed from command line, the strings + in dictionary needs to be double quoted for json parsing. For more details, + see original doc for `map_location` in https://pytorch.org/docs/stable/generated/torch.load.html + """ + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. 
+ + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self): + if isinstance(self.load_format, str): + load_format = self.load_format.lower() + self.load_format = LoadFormat(load_format) + + if self.ignore_patterns is not None and len(self.ignore_patterns) > 0: + logger.info( + "Ignoring the following patterns when downloading weights: %s", + self.ignore_patterns) + else: + self.ignore_patterns = ["original/**/*"] + + +DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"] + + +@config +@dataclass +class ParallelConfig: + """Configuration for the distributed execution.""" + + pipeline_parallel_size: int = 1 + """Number of pipeline parallel groups.""" + tensor_parallel_size: int = 1 + """Number of tensor parallel groups.""" + data_parallel_size: int = 1 + """Number of data parallel groups. MoE layers will be sharded according to + the product of the tensor parallel size and data parallel size.""" + data_parallel_size_local: int = 1 + """Number of local data parallel groups.""" + data_parallel_rank: int = 0 + """Rank of the data parallel group.""" + data_parallel_rank_local: Optional[int] = None + """Local rank of the data parallel group, + set only in SPMD mode.""" + data_parallel_master_ip: str = "127.0.0.1" + """IP of the data parallel master.""" + data_parallel_rpc_port: int = 29550 + """Port for data parallel messaging.""" + data_parallel_master_port: int = 29500 + """Port of the data parallel master.""" + data_parallel_backend: str = "mp" + """Backend to use for data parallel, either "mp" or "ray".""" + enable_expert_parallel: bool = False + """Use expert parallelism instead of tensor parallelism for MoE layers.""" + max_parallel_loading_workers: Optional[int] = None + """Maximum number of parallel loading workers when loading model + sequentially in multiple batches. To avoid RAM OOM when using tensor + parallel and large models.""" + + disable_custom_all_reduce: bool = True + """Disable the custom all-reduce kernel and fall back to NCCL.""" + + tokenizer_pool_config: Optional[TokenizerPoolConfig] = None + """This parameter is deprecated and will be removed in a future release. + Please remove it from your configs""" + + ray_workers_use_nsight: bool = False + """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.""" + + placement_group: Optional["PlacementGroup"] = None + """ray distributed model workers placement group.""" + + distributed_executor_backend: Optional[Union[DistributedExecutorBackend, + type["ExecutorBase"]]] = None + """Backend to use for distributed model + workers, either "ray" or "mp" (multiprocessing). If the product + of pipeline_parallel_size and tensor_parallel_size is less than + or equal to the number of GPUs available, "mp" will be used to + keep processing on a single host. Otherwise, this will default + to "ray" if Ray is installed and fail otherwise. Note that tpu + and hpu only support Ray for distributed inference.""" + + worker_cls: str = "auto" + """The full name of the worker class to use. 
If "auto", the worker class + will be determined based on the platform.""" + sd_worker_cls: str = "auto" + """The full name of the worker class to use for speculative decofing. + If "auto", the worker class will be determined based on the platform.""" + worker_extension_cls: str = "" + """The full name of the worker extension class to use. The worker extension + class is dynamically inherited by the worker class. This is used to inject + new attributes and methods to the worker class for use in collective_rpc + calls.""" + + world_size: int = field(init=False) + """world_size is TPxPP, it affects the number of workers we create.""" + + rank: int = 0 + """Global rank in distributed setup.""" + + enable_multimodal_encoder_data_parallel: bool = False + """ Use data parallelism instead of tensor parallelism for vision encoder. + Only support LLama4 for now""" + + @property + def world_size_across_dp(self) -> int: + """world_size_across_dp is TPxPPxDP, it is the size of the world + including data parallelism.""" + return self.world_size * self.data_parallel_size + + def get_next_dp_init_port(self) -> int: + """ + We might need to initialize process groups in multiple + processes that is related to data parallelism, + e.g. both in the worker and in the engine, which + can live in different processes. To avoid port conflicts, we + increment the port number each time we need to initialize a + new process group related to data parallelism. + """ + answer = self.data_parallel_master_port + self.data_parallel_master_port += 1 + return answer + + def stateless_init_dp_group(self) -> "ProcessGroup": + from vllm.distributed.utils import ( + stateless_init_torch_distributed_process_group) + + # use gloo since the engine process might not have cuda device + dp_group = stateless_init_torch_distributed_process_group( + self.data_parallel_master_ip, + self.get_next_dp_init_port(), + self.data_parallel_rank, + self.data_parallel_size, + backend="gloo") + + return dp_group + + @staticmethod + def has_unfinished_dp(dp_group: "ProcessGroup", + has_unfinished: bool) -> bool: + tensor = torch.tensor([has_unfinished], + dtype=torch.int32, + device="cpu") + # dp rank 0: has_unfinished_seqs=True + # dp rank 1: has_unfinished_seqs=False + # aggregated: has_unfinished_seqs=True + # so this is an OR operation, i.e. MAX in integers + torch.distributed.all_reduce(tensor, op=ReduceOp.MAX, group=dp_group) + aggregated_has_unfinished = bool(tensor.item()) + return aggregated_has_unfinished + + def compute_hash(self): + """ + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + factors: list[Any] = [] + factors.append(self.pipeline_parallel_size) + factors.append(self.tensor_parallel_size) + factors.append(self.enable_expert_parallel) + factors.append(self.data_parallel_size) + factors.append(envs.VLLM_ALL2ALL_BACKEND) + return hashlib.sha256(str(factors).encode()).hexdigest() + + def __post_init__(self) -> None: + self.world_size = self.pipeline_parallel_size * \ + self.tensor_parallel_size + + if self.data_parallel_size_local > self.data_parallel_size: + raise ValueError( + f"data_parallel_size_local ({self.data_parallel_size_local}) " + f"must be <= data_parallel_size ({self.data_parallel_size})") + + if self.data_parallel_size > 1 or self.data_parallel_size_local == 0: + # Data parallel was specified in the engine args. 
+ self.data_parallel_master_port = get_open_port() + else: + # Otherwise fall back to env vars (e.g. for offline SPMD case). + self.data_parallel_size = envs.VLLM_DP_SIZE + self.data_parallel_rank = envs.VLLM_DP_RANK + self.data_parallel_rank_local = envs.VLLM_DP_RANK_LOCAL + self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP + self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT + + if self.distributed_executor_backend == "external_launcher": + import os + os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" + logger.info("Disabling V1 multiprocessing for external launcher.") + + ray_only_devices: list[str] = [] + from vllm.platforms import current_platform + if (current_platform.device_type in ray_only_devices + and self.world_size > 1): + if self.distributed_executor_backend is None: + self.distributed_executor_backend = "ray" + if self.distributed_executor_backend != "ray": + raise ValueError( + f"{current_platform.device_type.upper()} backend only " + "supports Ray for distributed inference.") + + if self.distributed_executor_backend is None and self.world_size > 1: + # We use multiprocessing by default if world_size fits on the + # current node and we aren't in a ray placement group. + + from vllm.executor import ray_utils + backend: DistributedExecutorBackend = "mp" + ray_found = ray_utils.ray_is_available() + if current_platform.is_neuron(): + # neuron uses single process to control multiple devices + backend = "uni" + elif current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD: + backend = "uni" + elif (current_platform.is_cuda() + and cuda_device_count_stateless() < self.world_size): + if not ray_found: + raise ValueError("Unable to load Ray which is " + "required for multi-node inference, " + "please install Ray with `pip install " + "ray`.") from ray_utils.ray_import_err + backend = "ray" + elif self.data_parallel_backend == "ray": + logger.info("Using ray distributed inference because " + "data_parallel_backend is ray") + backend = "ray" + elif ray_found: + if self.placement_group: + backend = "ray" + else: + from ray import is_initialized as ray_is_initialized + if ray_is_initialized(): + from ray.util import get_current_placement_group + if get_current_placement_group(): + backend = "ray" + self.distributed_executor_backend = backend + logger.info("Defaulting to use %s for distributed inference", + backend) + + if self.distributed_executor_backend is None and self.world_size == 1: + self.distributed_executor_backend = "uni" + + self._verify_args() + + @property + def use_ray(self) -> bool: + return self.distributed_executor_backend == "ray" or ( + isinstance(self.distributed_executor_backend, type) + and self.distributed_executor_backend.uses_ray) + + def _verify_args(self) -> None: + # Lazy import to avoid circular import + from vllm.executor.executor_base import ExecutorBase + from vllm.platforms import current_platform + if self.distributed_executor_backend not in ( + "ray", "mp", "uni", + "external_launcher", None) and not (isinstance( + self.distributed_executor_backend, type) and issubclass( + self.distributed_executor_backend, ExecutorBase)): + raise ValueError( + "Unrecognized distributed executor backend " + f"{self.distributed_executor_backend}. 
Supported " + "values are 'ray', 'mp' 'uni', 'external_launcher' or" + " custom ExecutorBase subclass.") + if self.use_ray: + from vllm.executor import ray_utils + ray_utils.assert_ray_available() + + if not current_platform.use_custom_allreduce(): + self.disable_custom_all_reduce = True + logger.info( + "Disabled the custom all-reduce kernel because it is not " + "supported on current platform.") + if self.ray_workers_use_nsight and not self.use_ray: + raise ValueError("Unable to use nsight profiling unless workers " + "run with Ray.") + + assert isinstance(self.worker_extension_cls, str), ( + "worker_extension_cls must be a string (qualified class name).") + + +PreemptionMode = Literal["swap", "recompute"] +SchedulerPolicy = Literal["fcfs", "priority"] + + +@config +@dataclass +class SchedulerConfig: + """Scheduler configuration.""" + + runner_type: RunnerType = "generate" + """The runner type to launch for the model.""" + + max_num_batched_tokens: SkipValidation[int] = None # type: ignore + """Maximum number of tokens to be processed in a single iteration. + + This config has no static default. If left unspecified by the user, it will + be set in `EngineArgs.create_engine_config` based on the usage context.""" + + max_num_seqs: SkipValidation[int] = None # type: ignore + """Maximum number of sequences to be processed in a single iteration. + + This config has no static default. If left unspecified by the user, it will + be set in `EngineArgs.create_engine_config` based on the usage context.""" + + max_model_len: SkipValidation[int] = None # type: ignore + """Maximum length of a sequence (including prompt and generated text). This + is primarily set in `ModelConfig` and that value should be manually + duplicated here.""" + + max_num_partial_prefills: int = 1 + """For chunked prefill, the maximum number of sequences that can be + partially prefilled concurrently.""" + + max_long_partial_prefills: int = 1 + """For chunked prefill, the maximum number of prompts longer than + long_prefill_token_threshold that will be prefilled concurrently. Setting + this less than max_num_partial_prefills will allow shorter prompts to jump + the queue in front of longer prompts in some cases, improving latency.""" + + long_prefill_token_threshold: int = 0 + """For chunked prefill, a request is considered long if the prompt is + longer than this number of tokens.""" + + num_lookahead_slots: int = 0 + """The number of slots to allocate per sequence per + step, beyond the known token ids. This is used in speculative + decoding to store KV activations of tokens which may or may not be + accepted. + + NOTE: This will be replaced by speculative config in the future; it is + present to enable correctness tests until then.""" + + cuda_graph_sizes: list[int] = field(default_factory=lambda: [256]) + """Cuda graph capture sizes, default is 256. + 1. if one value is provided, then the capture list would follow the + pattern: [1, 2, 4] + [i for i in range(8, cuda_graph_sizes + 1, 8)] + 2. more than one value (e.g. 1 2 128) is provided, then the capture list + will follow the provided list.""" + + delay_factor: float = 0.0 + """Apply a delay (of delay factor multiplied by previous + prompt latency) before scheduling next prompt.""" + + enable_chunked_prefill: SkipValidation[bool] = None # type: ignore + """If True, prefill requests can be chunked based + on the remaining max_num_batched_tokens.""" + + is_multimodal_model: bool = False + """True if the model is multimodal.""" + + # TODO (ywang96): Make this configurable. 
+    max_num_encoder_input_tokens: int = field(init=False)
+    """Multimodal encoder compute budget, only used in V1.
+
+    NOTE: This is not currently configurable. It will be overridden by
+    max_num_batched_tokens in case max multimodal embedding size is larger."""
+
+    # TODO (ywang96): Make this configurable.
+    encoder_cache_size: int = field(init=False)
+    """Multimodal encoder cache size, only used in V1.
+
+    NOTE: This is not currently configurable. It will be overridden by
+    max_num_batched_tokens in case max multimodal embedding size is larger."""
+
+    preemption_mode: Optional[PreemptionMode] = None
+    """Whether to perform preemption by swapping or
+    recomputation. If not specified, we determine the mode as follows:
+    We use recomputation by default since it incurs lower overhead than
+    swapping. However, when the sequence group has multiple sequences
+    (e.g., beam search), recomputation is not currently supported. In
+    such a case, we use swapping instead."""
+
+    num_scheduler_steps: int = 1
+    """Maximum number of forward steps per scheduler call."""
+
+    multi_step_stream_outputs: bool = True
+    """If False, then multi-step will stream outputs at the end of all
+    steps."""
+
+    send_delta_data: bool = False
+    """Private API. If used, the scheduler sends delta data to
+    workers instead of the entire data. It should be enabled only
+    when the SPMD worker architecture is enabled, i.e.
+    VLLM_USE_RAY_SPMD_WORKER=1."""
+
+    policy: SchedulerPolicy = "fcfs"
+    """The scheduling policy to use:\n
+    - "fcfs" means first come first served, i.e. requests are handled in order
+    of arrival.\n
+    - "priority" means requests are handled based on given priority (lower
+    value means earlier handling), with time of arrival deciding any ties."""
+
+    chunked_prefill_enabled: bool = field(init=False)
+    """True if chunked prefill is enabled."""
+
+    disable_chunked_mm_input: bool = False
+    """If set to True and chunked prefill is enabled, we do not want to
+    partially schedule a multimodal item. Only used in V1.
+    This ensures that if a request has a mixed prompt
+    (like text tokens TTTT followed by image tokens IIIIIIIIII) where only
+    some image tokens can be scheduled (like TTTTIIIII, leaving IIIII),
+    it will be scheduled as TTTT in one step and IIIIIIIIII in the next."""
+
+    # scheduler class or path. "vllm.core.scheduler.Scheduler" (default)
+    # or "mod.custom_class".
+    scheduler_cls: Union[str, type[object]] = "vllm.core.scheduler.Scheduler"
+    """The scheduler class to use. "vllm.core.scheduler.Scheduler" is the
+    default scheduler. Can be a class directly or the path to a class of the
+    form "mod.custom_class"."""
+
+    disable_hybrid_kv_cache_manager: bool = False
+    """If set to True, the KV cache manager will allocate the same size of
+    KV cache for all attention layers even if there are multiple types of
+    attention layers, like full attention and sliding window attention.
+    """
+
+    def compute_hash(self) -> str:
+        """
+        WARNING: Whenever a new field is added to this config,
+        ensure that it is included in the factors list if
+        it affects the computation graph.
+
+        Provide a hash that uniquely identifies all the configs
+        that affect the structure of the computation
+        graph from input ids/embeddings to the final hidden states,
+        excluding anything before input ids/embeddings and after
+        the final hidden states.
+        """
+        # no factors to consider.
+        # this config will not affect the computation graph.
+ factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self) -> None: + if self.max_model_len is None: + self.max_model_len = 8192 + + if self.max_num_seqs is None: + self.max_num_seqs = 128 + + if self.max_num_batched_tokens is None: + if self.enable_chunked_prefill: + if self.num_scheduler_steps > 1: + # Multi-step Chunked-Prefill doesn't allow prompt-chunking + # for now. Have max_num_batched_tokens set to max_model_len + # so we don't reject sequences on account of a short + # max_num_batched_tokens. + self.max_num_batched_tokens = max( + self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS) + else: + self.max_num_batched_tokens = ( + DEFAULT_MAX_NUM_BATCHED_TOKENS) + else: + # If max_model_len is too short, use + # DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value + # for higher throughput. + self.max_num_batched_tokens = max( + self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS) + + if self.runner_type == "pooling": + # Choose specific value for higher throughput + self.max_num_batched_tokens = max( + self.max_num_batched_tokens, + POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, + ) + if self.is_multimodal_model: + # The value needs to be at least the number of multimodal tokens + self.max_num_batched_tokens = max( + self.max_num_batched_tokens, + MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, + ) + + # When using default settings, + # Ensure max_num_batched_tokens does not exceed model limit. + # Some models (e.g., Whisper) have embeddings tied to max length. + self.max_num_batched_tokens = min( + self.max_num_seqs * self.max_model_len, + self.max_num_batched_tokens) + + self.max_num_encoder_input_tokens = self.max_num_batched_tokens + self.encoder_cache_size = self.max_num_batched_tokens + + if self.enable_chunked_prefill: + logger.info( + "Chunked prefill is enabled with max_num_batched_tokens=%d.", + self.max_num_batched_tokens) + + self.chunked_prefill_enabled = self.enable_chunked_prefill + if self.max_num_partial_prefills > 1: + if self.long_prefill_token_threshold == 0: + self.long_prefill_token_threshold = int(self.max_model_len * + 0.04) + + logger.info( + "Concurrent partial prefills enabled with " + "max_num_partial_prefills=%d, max_long_partial_prefills=%d, " + "long_prefill_token_threshold=%d", + self.max_num_partial_prefills, self.max_long_partial_prefills, + self.long_prefill_token_threshold) + + self._verify_args() + + def _verify_args(self) -> None: + if (self.max_num_batched_tokens < self.max_model_len + and not self.chunked_prefill_enabled): + raise ValueError( + f"max_num_batched_tokens ({self.max_num_batched_tokens}) is " + f"smaller than max_model_len ({self.max_model_len}). " + "This effectively limits the maximum sequence length to " + "max_num_batched_tokens and makes vLLM reject longer " + "sequences. Please increase max_num_batched_tokens or " + "decrease max_model_len.") + + if self.max_num_batched_tokens < self.max_num_seqs: + raise ValueError( + f"max_num_batched_tokens ({self.max_num_batched_tokens}) must " + "be greater than or equal to max_num_seqs " + f"({self.max_num_seqs}).") + + if self.max_num_batched_tokens > self.max_num_seqs * self.max_model_len: + logger.warning( + "max_num_batched_tokens (%d) exceeds max_num_seqs" + "* max_model_len (%d). 
This may lead to unexpected behavior.", + self.max_num_batched_tokens, + self.max_num_seqs * self.max_model_len) + + if self.num_lookahead_slots < 0: + raise ValueError( + "num_lookahead_slots " + f"({self.num_lookahead_slots}) must be greater than or " + "equal to 0.") + + if self.num_scheduler_steps < 1: + raise ValueError( + "num_scheduler_steps " + f"({self.num_scheduler_steps}) must be greater than or " + "equal to 1.") + + if self.max_num_partial_prefills < 1: + raise ValueError( + f"max_num_partial_prefills ({self.max_num_partial_prefills}) " + "must be greater than or equal to 1.") + elif self.max_num_partial_prefills > 1: + if not self.chunked_prefill_enabled: + raise ValueError("Chunked prefill must be enabled to set " + "max_num_partial_prefills > 1.") + + if self.long_prefill_token_threshold > self.max_model_len: + raise ValueError( + "long_prefill_token_threshold " + f"({self.long_prefill_token_threshold}) cannot be greater " + f"than the max_model_len ({self.max_model_len}).") + + if (self.max_long_partial_prefills + < 1) or (self.max_long_partial_prefills + > self.max_num_partial_prefills): + raise ValueError( + f"max_long_partial_prefills ({self.max_long_partial_prefills}) " + "must be greater than or equal to 1 and less than or equal to " + f"max_num_partial_prefills ({self.max_num_partial_prefills}).") + + @property + def is_multi_step(self) -> bool: + return self.num_scheduler_steps > 1 + + +Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu", "hpu"] + + +@config +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) +class DeviceConfig: + """Configuration for the device to use for vLLM execution.""" + + device: SkipValidation[Union[Device, torch.device]] = "auto" + """Device type for vLLM execution. + This parameter is deprecated and will be + removed in a future release. + It will now be set automatically based + on the current platform.""" + device_type: str = field(init=False) + """Device type from the current platform. This is set in + `__post_init__`.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # the device/platform information will be summarized + # by torch/vllm automatically. 
+ factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self): + if self.device == "auto": + # Automated device type detection + from vllm.platforms import current_platform + self.device_type = current_platform.device_type + if not self.device_type: + raise RuntimeError( + "Failed to infer device type, please set " + "the environment variable `VLLM_LOGGING_LEVEL=DEBUG` " + "to turn on verbose logging to help debug the issue.") + else: + # Device type is assigned explicitly + self.device_type = self.device + + # Some device types require processing inputs on CPU + if self.device_type in ["neuron"]: + self.device = torch.device("cpu") + elif self.device_type in ["tpu"]: + self.device = None + else: + # Set device with device type + self.device = torch.device(self.device_type) + + +SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa", + "mlp_speculator", "draft_model", "deepseek_mtp"] +SpeculativeAcceptanceMethod = Literal["rejection_sampler", + "typical_acceptance_sampler"] + + +@config +@dataclass +class SpeculativeConfig: + """Configuration for speculative decoding.""" + + # General speculative decoding control + num_speculative_tokens: SkipValidation[int] = None # type: ignore + """The number of speculative tokens, if provided. It will default to the + number in the draft model config if present, otherwise, it is required.""" + model: Optional[str] = None + """The name of the draft model, eagle head, or additional weights, if + provided.""" + method: Optional[SpeculativeMethod] = None + """The name of the speculative method to use. If users provide and set the + `model` param, the speculative method type will be detected automatically + if possible, if `model` param is not provided, the method name must be + provided. + + If using `ngram` method, the related configuration `prompt_lookup_max` and + `prompt_lookup_min` should be considered.""" + acceptance_method: SpeculativeAcceptanceMethod = "rejection_sampler" + """The method to use for accepting draft tokens:\n + - "rejection_sampler" maps to `RejectionSampler`.\n + - "typical_acceptance_sampler" maps to `TypicalAcceptanceSampler`. + + If using `typical_acceptance_sampler`, the related configuration + `posterior_threshold` and `posterior_alpha` should be considered.""" + draft_tensor_parallel_size: Optional[int] = None + """The degree of the tensor parallelism for the draft model. Can only be 1 + or the same as the target model's tensor parallel size.""" + disable_logprobs: bool = True + """If set to True, token log probabilities are not returned during + speculative decoding. If set to False, token log probabilities are returned + according to the log probability settings in SamplingParams.""" + + # Draft model configuration + quantization: Optional[QuantizationMethods] = None + """Quantization method that was used to quantize the draft model weights. + If `None`, we assume the model weights are not quantized. Note that it only + takes effect when using the draft model-based speculative method.""" + max_model_len: Optional[int] = None + """The maximum model length of the draft model. Used when testing the + ability to skip speculation for some sequences.""" + revision: Optional[str] = None + """The specific model version to use for the draft model. It can be a + branch name, a tag name, or a commit id. 
If unspecified, will use the + default version.""" + code_revision: Optional[str] = None + """The specific revision to use for the draft model code on Hugging Face + Hub. It can be a branch name, a tag name, or a commit id. If unspecified, + will use the default version.""" + + # Advanced control + disable_mqa_scorer: bool = False + """Disable the MQA scorer and fall back to batch expansion for scoring + proposals.""" + disable_by_batch_size: Optional[int] = None + """Disable speculative decoding for new incoming requests when the number + of enqueued requests is larger than this value, if provided.""" + + # Ngram proposer configuration + prompt_lookup_max: Optional[int] = None + """Maximum size of ngram token window when using Ngram proposer, required + when method is set to ngram.""" + prompt_lookup_min: Optional[int] = None + """Minimum size of ngram token window when using Ngram proposer, if + provided. Defaults to 1.""" + + # Typical acceptance sampler configuration + posterior_threshold: Optional[float] = None + """A threshold value that sets a lower bound on the posterior probability + of a token in the target model for it to be accepted. This threshold is + used only when we use the `TypicalAcceptanceSampler` for token acceptance. + """ + posterior_alpha: Optional[float] = None + """Scaling factor for entropy-based threshold, applied when using + `TypicalAcceptanceSampler`.""" + + speculative_token_tree: Optional[str] = None + """Specifies the tree structure for speculative token generation. + """ + # required configuration params passed from engine + target_model_config: SkipValidation[ModelConfig] = None # type: ignore + """The configuration of the target model.""" + target_parallel_config: SkipValidation[ + ParallelConfig] = None # type: ignore + """The parallel configuration for the target model.""" + enable_chunked_prefill: SkipValidation[bool] = None # type: ignore + """Whether vLLM is configured to use chunked prefill or not. Used for + raising an error since it's not yet compatible with speculative decode.""" + disable_log_stats: SkipValidation[bool] = None # type: ignore + """Whether to disable the periodic printing of stage times in speculative + decoding.""" + + # params generated in the post-init stage + draft_model_config: SkipValidation[ModelConfig] = None # type: ignore + """The configuration of the draft model initialized internal.""" + draft_parallel_config: SkipValidation[ + ParallelConfig] = None # type: ignore + """The parallel configuration for the draft model initialized internal.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + factors: list[Any] = [] + # Eagle3 affects the computation graph because it returns intermediate + # hidden states in addition to the final hidden state. 
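+        # For example, method="eagle3" appends True while any other method
+        # appends False, so eagle3 and non-eagle3 runs get different hashes.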
+ factors.append(self.method == "eagle3") + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + @classmethod + def from_dict(cls, dict_value: dict) -> "SpeculativeConfig": + """Parse the CLI value for the speculative config.""" + return cls(**dict_value) + + @staticmethod + def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: + if hf_config.model_type == "deepseek_v3": + hf_config.model_type = "deepseek_mtp" + if hf_config.model_type == "deepseek_mtp": + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "n_predict": n_predict, + "architectures": ["DeepSeekMTPModel"] + }) + + if hf_config.architectures[0] == "MiMoForCausalLM": + hf_config.model_type = "mimo_mtp" + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "num_hidden_layers": 0, + "n_predict": n_predict, + "architectures": ["MiMoMTPModel"] + }) + return hf_config + + return hf_config + + def __post_init__(self): + + # Note: "method" is a new parameter that helps to extend the + # configuration of non-model-based proposers, and the "model" parameter + # will be used to set the draft model, eagle head, or additional weight + # when needed. If users do not specify "method", the speculative method + # will be detected automatically if possible. If the speculative method + # can not be detected, it will be considered as the "draft_model" by + # default. + + if self.model is None and self.num_speculative_tokens is not None: + # TODO(Shangming): Refactor mtp configuration logic when supporting + # mtp acceleration for more models besides deepseek_v3 + if self.target_model_config and \ + (self.target_model_config.hf_text_config.model_type \ + == "deepseek_v3" or + self.target_model_config.hf_text_config.model_type \ + == "mimo"): + # use the draft model from the same model: + self.model = self.target_model_config.model + elif self.method in ("ngram", "[ngram]"): + self.model = "ngram" + else: + raise ValueError("num_speculative_tokens was provided without " + "speculative model.") + + # Automatically configure the method for ngram when "model" is used + # instead of "method" + if self.method is None and (self.model is not None + and self.model in ("ngram", "[ngram]")): + self.method = "ngram" + + if self.method in ("ngram", "[ngram]"): + # Unified to "ngram" internally + self.method = "ngram" + # Set default values if not provided + if (self.prompt_lookup_min is None + and self.prompt_lookup_max is None): + # TODO(woosuk): Tune these values. They are arbitrarily chosen. + self.prompt_lookup_min = 5 + self.prompt_lookup_max = 5 + elif self.prompt_lookup_min is None: + assert self.prompt_lookup_max is not None + self.prompt_lookup_min = self.prompt_lookup_max + elif self.prompt_lookup_max is None: + assert self.prompt_lookup_min is not None + self.prompt_lookup_max = self.prompt_lookup_min + + # Validate values + if self.prompt_lookup_min < 1: + raise ValueError( + f"prompt_lookup_min={self.prompt_lookup_min} must be > 0") + if self.prompt_lookup_max < 1: + raise ValueError( + f"prompt_lookup_max={self.prompt_lookup_max} must be > 0") + if self.prompt_lookup_min > self.prompt_lookup_max: + raise ValueError( + f"prompt_lookup_min={self.prompt_lookup_min} must " + f"be <= prompt_lookup_max={self.prompt_lookup_max}") + + # TODO: current we still need extract vocab_size from target model + # config, in future, we may try refactor it out, and set + # draft related config as None here. 
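+            # The ngram method has no separate draft network, so the target
+            # model/parallel configs are reused as placeholders here.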
+ self.draft_model_config = self.target_model_config + self.draft_parallel_config = self.target_parallel_config + else: + self.prompt_lookup_max = 0 + self.prompt_lookup_min = 0 + + if self.model is not None: + self.draft_model_config = ModelConfig( + model=self.model, + task="draft", + tokenizer=self.target_model_config.tokenizer, + tokenizer_mode=self.target_model_config.tokenizer_mode, + trust_remote_code=self.target_model_config. + trust_remote_code, + allowed_local_media_path=self.target_model_config. + allowed_local_media_path, + dtype=self.target_model_config.dtype, + seed=self.target_model_config.seed, + revision=self.revision, + code_revision=self.code_revision, + tokenizer_revision=self.target_model_config. + tokenizer_revision, + spec_target_max_model_len=self.target_model_config. + max_model_len, + quantization=self.quantization, + enforce_eager=self.target_model_config.enforce_eager, + max_seq_len_to_capture=self.target_model_config. + max_seq_len_to_capture, + max_logprobs=self.target_model_config.max_logprobs, + hf_overrides=SpeculativeConfig.hf_config_override, + ) + + # Automatically detect the method + if self.method in ('eagle', 'eagle3'): + pass + elif "eagle-" in self.draft_model_config.model.lower() or \ + "eagle3-" in self.draft_model_config.model.lower(): + self.method = "eagle" + elif self.draft_model_config.hf_config.model_type == "medusa": + self.method = "medusa" + elif (self.draft_model_config.hf_config.model_type == + "mlp_speculator"): + self.method = "mlp_speculator" + elif (self.draft_model_config.hf_config.model_type == + "deepseek_mtp"): + self.method = "deepseek_mtp" + if self.num_speculative_tokens > 1: + logger.warning( + "All Deepseek MTP models only have " \ + "one layer. Might need some code changes " \ + "to support multiple layers." + ) + else: + self.method = "draft_model" + + # Replace hf_config for EAGLE draft_model + if self.method in ("eagle", "eagle3"): + if self.enable_chunked_prefill and not envs.VLLM_USE_V1: + raise ValueError( + "Chunked prefill and EAGLE are not compatible " + "when using V0.") + + from vllm.transformers_utils.configs.eagle import ( + EAGLEConfig) + if isinstance(self.draft_model_config.hf_config, + EAGLEConfig): + pass + else: + eagle_config = EAGLEConfig( + self.draft_model_config.hf_config, + method=self.method, + model_type="eagle") + self.draft_model_config.hf_config = eagle_config + + if (self.num_speculative_tokens is not None + and hasattr(self.draft_model_config.hf_config, + "num_lookahead_tokens")): + self.draft_model_config.hf_config.num_lookahead_tokens = \ + self.num_speculative_tokens + + n_predict = getattr(self.draft_model_config.hf_config, + "n_predict", None) + if n_predict is not None: + if self.num_speculative_tokens is None: + # Default to max value defined in draft model config. + self.num_speculative_tokens = n_predict + elif self.num_speculative_tokens > n_predict and \ + self.num_speculative_tokens % n_predict != 0: + # Ensure divisibility for MTP module reuse. 
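+                    # For example, n_predict=2 accepts num_speculative_tokens
+                    # of 2 or 4 but rejects 3.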
+ raise ValueError( + f"num_speculative_tokens:{self.num_speculative_tokens}" + f" must be divisible by {n_predict=}") + + self.draft_tensor_parallel_size = \ + SpeculativeConfig._verify_and_get_draft_tp( + self.target_parallel_config, + self.draft_tensor_parallel_size, + self.draft_model_config.hf_config + ) + + self.draft_model_config.max_model_len = ( + SpeculativeConfig._maybe_override_draft_max_model_len( + self.max_model_len, + self.draft_model_config.max_model_len, + self.target_model_config.max_model_len, + )) + + self.draft_parallel_config = ( + SpeculativeConfig.create_draft_parallel_config( + self.target_parallel_config, + self.draft_tensor_parallel_size)) + + if self.acceptance_method == "typical_acceptance_sampler": + if self.posterior_threshold is None: + self.posterior_threshold = 0.09 + if self.posterior_alpha is None: + self.posterior_alpha = 0.3 + + self._verify_args() + + @staticmethod + def _maybe_override_draft_max_model_len( + speculative_max_model_len: Optional[int], + draft_max_model_len: int, + target_max_model_len: int, + ) -> int: + """Determine the max sequence len for the draft model. This is usually + the draft_max_model_len, but may be the target_max_model_len if it is + less than the draft_max_model_len, or may be speculative_max_model_len + if it is specified. + + This is necessary so that sequences do not exceed the capacity of the + draft model or the target model. + + speculative_max_model_len is mainly used for testing that sequences can + skip speculation. + """ + + if speculative_max_model_len is not None: + + if speculative_max_model_len > draft_max_model_len: + raise ValueError(f"{speculative_max_model_len=} cannot be " + f"larger than {draft_max_model_len=}") + + if speculative_max_model_len > target_max_model_len: + raise ValueError(f"{speculative_max_model_len=} cannot be " + f"larger than {target_max_model_len=}") + + return speculative_max_model_len + + return min( + draft_max_model_len, + target_max_model_len, + ) + + @staticmethod + def _verify_and_get_draft_tp( + target_parallel_config: ParallelConfig, + speculative_draft_tensor_parallel_size: Optional[int], + draft_hf_config: PretrainedConfig) -> int: + """ + Verifies and adjusts the tensor parallel size for a draft model + specified using speculative_draft_tensor_parallel_size. + """ + # If speculative_draft_tensor_parallel_size is unset then set it + # appropriately else verify that it is set correctly. + if speculative_draft_tensor_parallel_size is None: + if draft_hf_config.model_type == "mlp_speculator": + speculative_draft_tensor_parallel_size = 1 + if target_parallel_config.tensor_parallel_size > 1: + logger.warning( + "%s cannot currently be run with tp>1; " + "setting speculative_draft_tensor_parallel_size=1", + draft_hf_config.model_type) + else: + speculative_draft_tensor_parallel_size = \ + target_parallel_config.tensor_parallel_size + elif speculative_draft_tensor_parallel_size not in ( + 1, target_parallel_config.tensor_parallel_size): + raise ValueError( + f"{speculative_draft_tensor_parallel_size=} cannot be " + f"other value than 1 or target model tensor_parallel_size") + return speculative_draft_tensor_parallel_size + + @staticmethod + def create_draft_parallel_config( + target_parallel_config: ParallelConfig, + speculative_draft_tensor_parallel_size: int, + ) -> ParallelConfig: + """Create a parallel config for use by the draft worker. + + This is mostly a copy of the target parallel config, except the tp_size. 
+ """ + draft_parallel_config = ParallelConfig( + pipeline_parallel_size=target_parallel_config. + pipeline_parallel_size, + tensor_parallel_size=speculative_draft_tensor_parallel_size, + distributed_executor_backend=target_parallel_config. + distributed_executor_backend, + max_parallel_loading_workers=target_parallel_config. + max_parallel_loading_workers, + disable_custom_all_reduce=target_parallel_config. + disable_custom_all_reduce, + ray_workers_use_nsight=target_parallel_config. + ray_workers_use_nsight, + placement_group=target_parallel_config.placement_group, + ) + + return draft_parallel_config + + def _verify_args(self) -> None: + if self.num_speculative_tokens is None: + raise ValueError( + "num_speculative_tokens must be provided with " + "speculative model unless the draft model config contains an " + "n_predict parameter.") + + if self.num_speculative_tokens <= 0: + raise ValueError("Expected num_speculative_tokens to be greater " + f"than zero ({self.num_speculative_tokens}).") + + if self.draft_model_config: + self.draft_model_config.verify_with_parallel_config( + self.draft_parallel_config) + # Validate and set draft token acceptance related settings. + + if self.acceptance_method is None: + raise ValueError("acceptance_method is not set. " + "Expected values are rejection_sampler or " + "typical_acceptance_sampler.") + + if (self.acceptance_method != 'rejection_sampler' + and self.acceptance_method != 'typical_acceptance_sampler'): + raise ValueError( + "Expected acceptance_method to be either " + "rejection_sampler or typical_acceptance_sampler. Instead it " + f"is {self.acceptance_method}") + + if self.acceptance_method == "typical_acceptance_sampler" and ( + (self.posterior_threshold is not None + and self.posterior_threshold < 0) or + (self.posterior_alpha is not None and self.posterior_alpha < 0)): + raise ValueError( + "Expected the posterior_threshold and posterior_alpha of " + "typical_acceptance_sampler to be > 0. " + "Instead found posterior_threshold = " + f"{self.posterior_threshold} and posterior_alpha = " + f"{self.posterior_alpha}") + + if (self.disable_by_batch_size is not None + and self.disable_by_batch_size < 2): + raise ValueError("Expect the batch size threshold of disabling " + "speculative decoding is > 1, but got " + f"{self.disable_by_batch_size=}") + + if self.method == "eagle3" and self.target_model_config and \ + "llama" not in self.target_model_config.hf_text_config.model_type: + raise ValueError( + "Eagle3 is only supported for Llama models. " + f"Got {self.target_model_config.hf_text_config.model_type=}") + + @property + def num_lookahead_slots(self) -> int: + """The number of additional slots the scheduler should allocate per + step, in addition to the slots allocated for each known token. + + This is equal to the number of speculative tokens, as each speculative + token must be scored. 
+ """ + return self.num_speculative_tokens + + def use_eagle(self) -> bool: + return self.method in ("eagle", "eagle3", "deepseek_mtp") + + def __repr__(self) -> str: + method = self.method + model = None if method == "ngram" else self.draft_model_config.model + num_spec_tokens = self.num_speculative_tokens + return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})" + + +LoRADType = Literal["auto", "float16", "bfloat16"] + + +@config +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) +class LoRAConfig: + """Configuration for LoRA.""" + + max_lora_rank: int = 16 + """Max LoRA rank.""" + max_loras: int = 1 + """Max number of LoRAs in a single batch.""" + fully_sharded_loras: bool = False + """By default, only half of the LoRA computation is sharded with tensor + parallelism. Enabling this will use the fully sharded layers. At high + sequence length, max rank or tensor parallel size, this is likely faster. + """ + max_cpu_loras: Optional[int] = None + """Maximum number of LoRAs to store in CPU memory. Must be >= than + `max_loras`.""" + lora_dtype: Union[torch.dtype, LoRADType] = "auto" + """Data type for LoRA. If auto, will default to base model dtype.""" + lora_extra_vocab_size: int = 256 + """Maximum size of extra vocabulary that can be present in a LoRA adapter + (added to the base model vocabulary).""" + lora_vocab_padding_size: ClassVar[int] = current_platform\ + .get_lora_vocab_padding_size() + long_lora_scaling_factors: Optional[tuple[float, ...]] = None + """Specify multiple scaling factors (which can be different from base model + scaling factor - see eg. Long LoRA) to allow for multiple LoRA adapters + trained with those scaling factors to be used at the same time. If not + specified, only adapters trained with the base model scaling factor are + allowed.""" + bias_enabled: bool = False + """Enable bias for LoRA adapters.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + factors: list[Any] = [] + factors.append(self.max_lora_rank) + factors.append(self.max_loras) + factors.append(self.fully_sharded_loras) + factors.append(self.lora_dtype) + factors.append(self.lora_extra_vocab_size) + factors.append(self.lora_vocab_padding_size) + factors.append(self.long_lora_scaling_factors) + factors.append(self.bias_enabled) + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self): + # Setting the maximum rank to 512 should be able to satisfy the vast + # majority of applications. 
+ possible_max_ranks = (8, 16, 32, 64, 128, 256, 320, 512) + possible_lora_extra_vocab_size = (256, 512) + if self.max_lora_rank not in possible_max_ranks: + raise ValueError( + f"max_lora_rank ({self.max_lora_rank}) must be one of " + f"{possible_max_ranks}.") + if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size: + raise ValueError( + f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) " + f"must be one of {possible_lora_extra_vocab_size}.") + if self.max_loras < 1: + raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.") + if self.max_cpu_loras is None: + self.max_cpu_loras = self.max_loras + elif self.max_cpu_loras < self.max_loras: + raise ValueError( + f"max_cpu_loras ({self.max_cpu_loras}) must be >= " + f"max_loras ({self.max_loras})") + + def verify_with_cache_config(self, cache_config: CacheConfig): + if cache_config.cpu_offload_gb > 0 and not envs.VLLM_USE_V1: + raise ValueError( + "V0 LoRA does not support CPU offload, please use V1.") + + def verify_with_model_config(self, model_config: ModelConfig): + if self.lora_dtype in (None, "auto"): + self.lora_dtype = model_config.dtype + elif isinstance(self.lora_dtype, str): + self.lora_dtype = getattr(torch, self.lora_dtype) + + def verify_lora_support(self): + if self.long_lora_scaling_factors is not None and envs.VLLM_USE_V1: + raise ValueError( + "V1 LoRA does not support long LoRA, please use V0.") + + +@config +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) +class PromptAdapterConfig: + """Configuration for PromptAdapters.""" + + max_prompt_adapters: int = 1 + """Max number of PromptAdapters in a batch.""" + max_prompt_adapter_token: int = 0 + """Max number of PromptAdapters tokens.""" + max_cpu_prompt_adapters: Optional[int] = None + """Maximum number of PromptAdapters to store in CPU memory. Must be >= than + `max_prompt_adapters`.""" + prompt_adapter_dtype: Union[torch.dtype, str] = "auto" + """Data type for PromptAdapter. If auto, will default to base model dtype. + """ + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. 
+ factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self): + + if self.max_prompt_adapters < 1: + raise ValueError(f"max_prompt_adapters " + f"({self.max_prompt_adapters}) must be >= 1.") + if self.max_prompt_adapter_token == 0: + raise ValueError("max_prompt_adapter_token must be set.") + if self.max_cpu_prompt_adapters is None: + self.max_cpu_prompt_adapters = self.max_prompt_adapters + + def verify_with_model_config(self, model_config: ModelConfig): + if self.prompt_adapter_dtype == "auto": + self.prompt_adapter_dtype = model_config.dtype + elif isinstance(self.prompt_adapter_dtype, str): + self.prompt_adapter_dtype = getattr(torch, + self.prompt_adapter_dtype) + + +@config +@dataclass +class MultiModalConfig: + """Controls the behavior of multimodal models.""" + + limit_per_prompt: dict[str, int] = \ + cast(dict[str, int], get_field(ModelConfig, "limit_mm_per_prompt")) + """ + The maximum number of input items allowed per prompt for each modality. + Defaults to 1 (V0) or 999 (V1) for each modality. + + For example, to allow up to 16 images and 2 videos per prompt: + `{"images": 16, "videos": 2}` + """ + + mm_processor_kwargs: Optional[dict[str, object]] = None + """ + Overrides for the multi-modal processor obtained from + `transformers.AutoProcessor.from_pretrained`. + + The available overrides depend on the model that is being run. + + For example, for Phi-3-Vision: + `{"num_crops": 4}`. + """ + + disable_mm_preprocessor_cache: bool = False + """ + If `True`, disable caching of the processed multi-modal inputs. + """ + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def get_limit_per_prompt(self, modality: str) -> int: + """ + Get the maximum number of input items allowed per prompt + for the given modality. + """ + return self.limit_per_prompt.get( + modality, + 999 if envs.VLLM_USE_V1 else 1, + ) + + # TODO: Add configs to init vision tower or not. + + +@config +@dataclass +class PoolerConfig: + """Controls the behavior of output pooling in pooling models.""" + + pooling_type: Optional[str] = None + """ + The pooling method of the pooling model. This should be a key in + [`vllm.model_executor.layers.pooler.PoolingType`][]. + """ + + normalize: Optional[bool] = None + """ + Whether to normalize the pooled outputs. Usually, this should be set to + ``True`` for embedding outputs. + """ + + softmax: Optional[bool] = None + """ + Whether to apply softmax to the pooled outputs. Usually, this should be set + to ``True`` for classification outputs. + """ + + step_tag_id: Optional[int] = None + """ + If set, only the score corresponding to the ``step_tag_id`` in the + generated sentence should be returned. Otherwise, the scores for all tokens + are returned. 
+ """ + + returned_token_ids: Optional[list[int]] = None + """ + A list of indices for the vocabulary dimensions to be extracted, + such as the token IDs of ``good_token`` and ``bad_token`` in the + ``math-shepherd-mistral-7b-prm`` model. + """ + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + +_STR_DTYPE_TO_TORCH_DTYPE = { + "half": torch.float16, + "float16": torch.float16, + "float": torch.float32, + "float32": torch.float32, + "bfloat16": torch.bfloat16, +} + +# model_type -> reason +_FLOAT16_NOT_SUPPORTED_MODELS = { + "gemma2": "Numerical instability. Please use bfloat16 or float32 instead.", + "gemma3": "Numerical instability. Please use bfloat16 or float32 instead.", + "plamo2": "Numerical instability. Please use bfloat16 or float32 instead.", + "glm4": "Numerical instability. Please use bfloat16 or float32 instead.", +} + + +def _is_valid_dtype(model_type: str, dtype: torch.dtype): + if model_type in _FLOAT16_NOT_SUPPORTED_MODELS and dtype == torch.float16: # noqa: E501, SIM103 + return False + + return True + + +def _check_valid_dtype(model_type: str, dtype: torch.dtype): + if model_type in _FLOAT16_NOT_SUPPORTED_MODELS and dtype == torch.float16: + reason = _FLOAT16_NOT_SUPPORTED_MODELS[model_type] + raise ValueError(f"The model type {model_type!r} " + f"does not support float16. Reason: {reason}") + + return True + + +def _find_dtype( + model_id: str, + config: PretrainedConfig, + *, + revision: Optional[str], +): + # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct + # because config.torch_dtype can be None. 
+ config_dtype = getattr(config, "torch_dtype", None) + + # Fallbacks for multi-modal models if the root config + # does not define torch_dtype + if config_dtype is None: + config_dtype = getattr(config.get_text_config(), "torch_dtype", None) + if config_dtype is None and hasattr(config, "vision_config"): + config_dtype = getattr(config.vision_config, "torch_dtype", None) + if config_dtype is None and hasattr(config, "encoder_config"): + config_dtype = getattr(config.encoder_config, "torch_dtype", None) + + # Try to read the dtype of the weights if they are in safetensors format + if config_dtype is None: + repo_mt = try_get_safetensors_metadata(model_id, revision=revision) + + if repo_mt and (files_mt := repo_mt.files_metadata): + param_dtypes: set[torch.dtype] = { + _SAFETENSORS_TO_TORCH_DTYPE[dtype_str] + for file_mt in files_mt.values() + for dtype_str in file_mt.parameter_count + if dtype_str in _SAFETENSORS_TO_TORCH_DTYPE + } + + if param_dtypes: + return common_broadcastable_dtype(param_dtypes) + + if config_dtype is None: + config_dtype = torch.float32 + + return config_dtype + + +def _resolve_auto_dtype( + model_type: str, + config_dtype: torch.dtype, + *, + is_pooling_model: bool, +): + from vllm.platforms import current_platform + + supported_dtypes = [ + dtype for dtype in current_platform.supported_dtypes + if _is_valid_dtype(model_type, dtype) + ] + + if is_pooling_model and torch.float16 in supported_dtypes: + preferred_dtype = torch.float16 + else: + preferred_dtype = supported_dtypes[0] + + # Downcast for float32 models + if config_dtype == torch.float32: + config_dtype = preferred_dtype + + if config_dtype in supported_dtypes: + return config_dtype + + # Ensure device compatibility + device_name = current_platform.get_device_name() + device_capability = current_platform.get_device_capability() + + if device_capability is None: + device_str = f"{device_name!r}" + else: + version_str = device_capability.as_version_str() + device_str = f"{device_name!r} (with compute capability {version_str})" + + logger.warning( + "Your device %s doesn't support %s. " + "Falling back to %s for compatibility.", + device_str, + config_dtype, + preferred_dtype, + ) + + return preferred_dtype + + +def _get_and_verify_dtype( + model_id: str, + config: PretrainedConfig, + dtype: Union[str, torch.dtype], + *, + is_pooling_model: bool, + revision: Optional[str] = None, +) -> torch.dtype: + config_dtype = _find_dtype(model_id, config, revision=revision) + model_type = config.model_type + + if isinstance(dtype, str): + dtype = dtype.lower() + if dtype == "auto": + # Set default dtype from model config + torch_dtype = _resolve_auto_dtype( + model_type, + config_dtype, + is_pooling_model=is_pooling_model, + ) + else: + if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: + raise ValueError(f"Unknown dtype: {dtype!r}") + torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] + elif isinstance(dtype, torch.dtype): + torch_dtype = dtype + else: + raise ValueError(f"Unknown dtype: {dtype}") + + _check_valid_dtype(model_type, torch_dtype) + + if torch_dtype != config_dtype: + if torch_dtype == torch.float32: + # Upcasting to float32 is allowed. + logger.info("Upcasting %s to %s.", config_dtype, torch_dtype) + elif config_dtype == torch.float32: + # Downcasting from float32 to float16 or bfloat16 is allowed. + logger.info("Downcasting %s to %s.", config_dtype, torch_dtype) + else: + # Casting between float16 and bfloat16 is allowed with a warning. 
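+            # Illustrative example (assuming the platform supports bfloat16):
+            # for a bfloat16 checkpoint, dtype="auto" keeps bfloat16,
+            # dtype="float32" upcasts with an info log, and dtype="float16"
+            # cross-casts and emits the warning below.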
+ logger.warning("Casting %s to %s.", config_dtype, torch_dtype) + + return torch_dtype + + +def _get_and_verify_max_len( + hf_config: PretrainedConfig, + max_model_len: Optional[int], + disable_sliding_window: bool, + sliding_window_len: Optional[Union[int, list[Optional[int]]]], + spec_target_max_model_len: Optional[int] = None, + encoder_config: Optional[Any] = None, +) -> int: + """Get and verify the model's maximum length.""" + derived_max_model_len = float("inf") + possible_keys = [ + # OPT + "max_position_embeddings", + # GPT-2 + "n_positions", + # MPT + "max_seq_len", + # ChatGLM2 + "seq_length", + # Command-R + "model_max_length", + # Whisper + "max_target_positions", + # Others + "max_sequence_length", + "max_seq_length", + "seq_len", + ] + # Choose the smallest "max_length" from the possible keys. + max_len_key = None + for key in possible_keys: + max_len = getattr(hf_config, key, None) + if max_len is not None: + max_len_key = key if max_len < derived_max_model_len \ + else max_len_key + derived_max_model_len = min(derived_max_model_len, max_len) + # For Command-R / Cohere, Cohere2 / Aya Vision models + if tmp_max_len := getattr(hf_config, "model_max_length", None): + max_len_key = "model_max_length" + derived_max_model_len = tmp_max_len + + # If sliding window is manually disabled, max_length should be less + # than the sliding window length in the model config. + if disable_sliding_window and sliding_window_len is not None: + + sliding_window_len_min = get_min_sliding_window(sliding_window_len) + max_len_key = "sliding_window" \ + if sliding_window_len_min < derived_max_model_len else max_len_key + derived_max_model_len = min(derived_max_model_len, + sliding_window_len_min) + + # If none of the keys were found in the config, use a default and + # log a warning. + if derived_max_model_len == float("inf"): + if max_model_len is not None: + # If max_model_len is specified, we use it. + return max_model_len + + if spec_target_max_model_len is not None: + # If this is a speculative draft model, we use the max model len + # from the target model. + return spec_target_max_model_len + + default_max_len = 2048 + logger.warning( + "The model's config.json does not contain any of the following " + "keys to determine the original maximum length of the model: " + "%s. Assuming the model's maximum length is %d.", possible_keys, + default_max_len) + derived_max_model_len = default_max_len + + rope_scaling = getattr(hf_config, "rope_scaling", None) + # NOTE(woosuk): Gemma3's max_model_len (128K) is already scaled by RoPE + # scaling, so we skip applying the scaling factor again. + if rope_scaling is not None and "gemma3" not in hf_config.model_type: + # No need to consider "type" key because of patch_rope_scaling when + # loading HF config + rope_type = rope_scaling["rope_type"] + + if rope_type not in ("su", "longrope", "llama3"): + if disable_sliding_window: + # TODO(robertgshaw): Find a model that supports rope_scaling + # with sliding window to see if this case should be allowed. + raise NotImplementedError( + "Disabling sliding window is not supported for models " + "with rope_scaling. 
Please raise an issue so we can "
+                        "investigate.")
+
+            # NOTE: rope_type == "default" does not define factor
+            # https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/modeling_rope_utils.py
+            scaling_factor = rope_scaling.get("factor", 1.0)
+
+            if rope_type == "yarn":
+                derived_max_model_len = rope_scaling[
+                    "original_max_position_embeddings"]
+            derived_max_model_len *= scaling_factor
+
+    if encoder_config and "max_seq_length" in encoder_config:
+        derived_max_model_len = encoder_config["max_seq_length"]
+
+    # If the user specified a max length, make sure it is smaller than the
+    # derived length from the HF model config.
+    if max_model_len is None:
+        max_model_len = int(derived_max_model_len)
+        if current_platform.is_tpu():
+            logger.warning(
+                "--max-model-len is not specified, "
+                "so the model's default length %s is being used, "
+                "which might be too large. "
+                "Please set --max-model-len based on your "
+                "request input and output lengths to avoid "
+                "unnecessary degradation.", max_model_len)
+    elif max_model_len > derived_max_model_len:
+        # Some models might have a separate key for specifying model_max_length
+        # that will be bigger than derived_max_model_len. We compare user input
+        # with model_max_length and allow this override when it's smaller.
+        model_max_length = getattr(hf_config, "model_max_length", None)
+        if model_max_length is not None and max_model_len <= model_max_length:
+            if disable_sliding_window:
+                # TODO(robertgshaw): Find a model that has model_max_length
+                # with sliding window to see if this case should be allowed.
+                raise NotImplementedError(
+                    "Disabling sliding window is not supported for models "
+                    "with model_max_length in the config. Please raise an "
+                    "issue so we can investigate.")
+        else:
+            msg = (
+                f"User-specified max_model_len ({max_model_len}) is greater "
+                f"than the derived max_model_len ({max_len_key}="
+                f"{derived_max_model_len} or model_max_length="
+                f"{model_max_length} in model's config.json). This may lead "
+                "to incorrect model outputs or CUDA errors.")
+            if envs.VLLM_ALLOW_LONG_MAX_MODEL_LEN:
+                logger.warning(
+                    "%s Make sure the value is correct and within the "
+                    "model context size.", msg)
+            else:
+                raise ValueError(
+                    f"{msg} To allow overriding this maximum, set "
+                    "the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1")
+    return int(max_model_len)
+
+
+def get_min_sliding_window(
+        sliding_window: Union[int, list[Optional[int]]]) -> int:
+    if isinstance(sliding_window, list):
+        return min(s for s in sliding_window if s is not None)
+
+    return sliding_window
+
+
+def get_served_model_name(model: str,
+                          served_model_name: Optional[Union[str, list[str]]]):
+    """
+    If the input is a non-empty list, the first model_name in
+    `served_model_name` is taken.
+    If the input is a non-empty string, it is used directly.
+    For cases where the input is either an empty string or an
+    empty list, the fallback is to use the `model` argument.
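+
+    For example (illustrative values), get_served_model_name(
+    "facebook/opt-125m", ["my-model", "alias"]) returns "my-model", while
+    passing None or an empty list returns "facebook/opt-125m" unchanged.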
+ """ + if not served_model_name: + return model + if isinstance(served_model_name, list): + return served_model_name[0] + return served_model_name + + +GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer", + "xgrammar", "guidance"] +GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"] +GuidedDecodingBackend = Literal[GuidedDecodingBackendV0, + GuidedDecodingBackendV1] + + +@config +@dataclass +class DecodingConfig: + """Dataclass which contains the decoding strategy of the engine.""" + + @property + @deprecated( + "`guided_decoding_backend` is deprecated and has been renamed to " + "`backend`. This will be removed in v0.10.0. Please use the " + "`backend` argument instead.") + def guided_decoding_backend(self) -> GuidedDecodingBackend: + return self.backend + + @guided_decoding_backend.setter + def guided_decoding_backend(self, value: GuidedDecodingBackend): + self.backend = value + + backend: GuidedDecodingBackend = "auto" if envs.VLLM_USE_V1 else "xgrammar" + """Which engine will be used for guided decoding (JSON schema / regex etc) + by default. With "auto", we will make opinionated choices based on request + contents and what the backend libraries currently support, so the behavior + is subject to change in each release.""" + + disable_fallback: bool = False + """If `True`, vLLM will not fallback to a different backend on error.""" + + disable_any_whitespace: bool = False + """If `True`, the model will not generate any whitespace during guided + decoding. This is only supported for xgrammar and guidance backends.""" + + disable_additional_properties: bool = False + """If `True`, the `guidance` backend will not use `additionalProperties` + in the JSON schema. This is only supported for the `guidance` backend and + is used to better align its behaviour with `outlines` and `xgrammar`.""" + + reasoning_backend: str = "" + """Select the reasoning parser depending on the model that you're using. + This is used to parse the reasoning content into OpenAI API format.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. 
+ factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self): + if ":" in self.backend: + self._extract_backend_options() + + if envs.VLLM_USE_V1: + valid_guided_backends = get_args(GuidedDecodingBackendV1) + else: + valid_guided_backends = get_args(GuidedDecodingBackendV0) + if self.backend not in valid_guided_backends: + raise ValueError(f"Invalid backend '{self.backend}'," + f" must be one of {valid_guided_backends}") + if (self.disable_any_whitespace + and self.backend not in ("xgrammar", "guidance")): + raise ValueError("disable_any_whitespace is only supported for " + "xgrammar and guidance backends.") + if (self.disable_additional_properties and self.backend != "guidance"): + raise ValueError("disable_additional_properties is only supported " + "for the guidance backend.") + + @deprecated( + "Passing guided decoding backend options inside backend in the format " + "'backend:...' is deprecated. This will be removed in v0.10.0. Please " + "use the dedicated arguments '--disable-fallback', " + "'--disable-any-whitespace' and '--disable-additional-properties' " + "instead.") + def _extract_backend_options(self): + """Extract backend options from the backend string.""" + backend, options = self.backend.split(":") + self.backend = cast(GuidedDecodingBackend, backend) + options_set = set(options.strip().split(",")) + if "no-fallback" in options_set: + self.disable_fallback = True + if "disable-any-whitespace" in options_set: + self.disable_any_whitespace = True + if "no-additional-properties" in options_set: + self.disable_additional_properties = True + + +DetailedTraceModules = Literal["model", "worker", "all"] + + +@config +@dataclass +class ObservabilityConfig: + """Configuration for observability - metrics and tracing.""" + + show_hidden_metrics_for_version: Optional[str] = None + """Enable deprecated Prometheus metrics that have been hidden since the + specified version. For example, if a previously deprecated metric has been + hidden since the v0.7.0 release, you use + `--show-hidden-metrics-for-version=0.7` as a temporary escape hatch while + you migrate to new metrics. The metric is likely to be removed completely + in an upcoming release.""" + + @cached_property + def show_hidden_metrics(self) -> bool: + """Check if the hidden metrics should be shown.""" + if self.show_hidden_metrics_for_version is None: + return False + return version._prev_minor_version_was( + self.show_hidden_metrics_for_version) + + otlp_traces_endpoint: Optional[str] = None + """Target URL to which OpenTelemetry traces will be sent.""" + + collect_detailed_traces: Optional[list[DetailedTraceModules]] = None + """It makes sense to set this only if `--otlp-traces-endpoint` is set. If + set, it will collect detailed traces for the specified modules. This + involves use of possibly costly and or blocking operations and hence might + have a performance impact. 
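+    For example, pass `--collect-detailed-traces model` (or a comma-separated
+    list such as `model,worker`) to enable detailed tracing for those modules.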
+ + Note that collecting detailed timing information for each request can be + expensive.""" + + @cached_property + def collect_model_forward_time(self) -> bool: + """Whether to collect model forward time for the request.""" + return (self.collect_detailed_traces is not None + and ("model" in self.collect_detailed_traces + or "all" in self.collect_detailed_traces)) + + @cached_property + def collect_model_execute_time(self) -> bool: + """Whether to collect model execute time for the request.""" + return (self.collect_detailed_traces is not None + and ("worker" in self.collect_detailed_traces + or "all" in self.collect_detailed_traces)) + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self): + if (self.collect_detailed_traces is not None + and len(self.collect_detailed_traces) == 1 + and "," in self.collect_detailed_traces[0]): + self._parse_collect_detailed_traces() + + if not is_otel_available() and self.otlp_traces_endpoint is not None: + raise ValueError( + "OpenTelemetry is not available. Unable to configure " + "'otlp_traces_endpoint'. Ensure OpenTelemetry packages are " + f"installed. Original error:\n{otel_import_error_traceback}") + + def _parse_collect_detailed_traces(self): + assert isinstance(self.collect_detailed_traces, list) + self.collect_detailed_traces = cast( + list[DetailedTraceModules], + self.collect_detailed_traces[0].split(",")) + + +KVProducer = Literal["kv_producer", "kv_both"] +KVConsumer = Literal["kv_consumer", "kv_both"] +KVRole = Literal[KVProducer, KVConsumer] + + +@config +@dataclass +class KVTransferConfig: + """Configuration for distributed KV cache transfer.""" + + kv_connector: Optional[str] = None + """The KV connector for vLLM to transmit KV caches between vLLM instances. + """ + + engine_id: Optional[str] = None + """The engine id for KV transfers.""" + + kv_buffer_device: Optional[str] = "cuda" + """The device used by kv connector to buffer the KV cache. + Currently only support 'cuda'.""" + + kv_buffer_size: float = 1e9 + """The buffer size for TorchDistributedConnector. Measured in number of + bytes. Recommended value: 1e9 (about 1GB).""" + + kv_role: Optional[KVRole] = None + """Whether this vLLM instance produces, consumes KV cache, or both. Choices + are 'kv_producer', 'kv_consumer', and 'kv_both'.""" + + kv_rank: Optional[int] = None + """The rank of this vLLM instance in the KV cache transfer. Typical value: + 0 for prefill instance, 1 for decode instance. + Currently only 1P1D is supported.""" + + kv_parallel_size: int = 1 + """The number of parallel instances for KV cache transfer. 
For + PyNcclConnector, this should be 2.""" + + kv_ip: str = "127.0.0.1" + """The KV connector ip, used to build distributed connection.""" + + kv_port: int = 14579 + """The KV connector port, used to build distributed connection.""" + + kv_connector_extra_config: dict[str, Any] = field(default_factory=dict) + """any extra config that the connector may need.""" + + kv_connector_module_path: Optional[str] = None + """The Python module path to dynamically load the KV connector from. + Only supported in V1.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self) -> None: + if self.engine_id is None: + self.engine_id = str(uuid.uuid4()) + + if self.kv_role is not None and self.kv_role not in get_args(KVRole): + raise ValueError(f"Unsupported kv_role: {self.kv_role}. " + f"Supported roles are {get_args(KVRole)}") + + if self.kv_connector is not None and self.kv_role is None: + raise ValueError("Please specify kv_disagg_role when kv_connector " + f"is set, supported roles are {get_args(KVRole)}") + + @property + def is_kv_transfer_instance(self) -> bool: + return self.kv_connector is not None and \ + self.kv_role in get_args(KVRole) + + @property + def is_kv_producer(self) -> bool: + return self.kv_connector is not None and \ + self.kv_role in get_args(KVProducer) + + @property + def is_kv_consumer(self) -> bool: + return self.kv_connector is not None and \ + self.kv_role in get_args(KVConsumer) + + def get_from_extra_config(self, key, default) -> Any: + return self.kv_connector_extra_config.get(key, default) + + +@config +@dataclass +class KVEventsConfig: + """Configuration for KV event publishing.""" + + enable_kv_cache_events: bool = False + """If True, enable KV cache events for tracking block storage and removal. + Events can be published externally by zmq using the event publisher config. + """ + + publisher: str = "null" + """The publisher to use for publishing kv events. Can be "null", "zmq". + """ + + endpoint: str = "tcp://*:5557" + """The zmq endpoint to use for publishing kv events. + """ + + replay_endpoint: Optional[str] = None + """The zmq endpoint to use for replaying kv events. + """ + + buffer_steps: int = 10_000 + """The number of steps to cache for replay endpoint. Will only save + events from the last N steps for the replay endpoint. + """ + + hwm: int = 100_000 + """The zmq high water mark for the event publisher. After queueing N events, + events will start dropping if the consumer is not keeping up. + """ + + max_queue_size: int = 100_000 + """The maximum number of events to queue while waiting for publishing. + """ + + topic: str = "" + """The topic to use for the event publisher. Consumers can subscribe to + this topic to receive events. 
+ """ + + +class CompilationLevel: + # constants for the levels of the compilation process + NO_COMPILATION = 0 + DYNAMO_AS_IS = 1 + DYNAMO_ONCE = 2 + PIECEWISE = 3 + + +@config +@dataclass +class PassConfig: + """Configuration for custom Inductor passes. + + This is separate from general `CompilationConfig` so that inductor passes + don't all have access to full configuration - that would create a cycle as + the `PassManager` is set as a property of config.""" + + dump_graph_stages: list[str] = field(default_factory=list) + """List of stages for which we want to dump the graph. Each pass defines + its own stages (before, after, maybe in-between).""" + dump_graph_dir: Path = Path(".") + """Directory to dump the graphs.""" + # TODO(luka) better pass enabling system. + enable_fusion: bool = True + """Whether to enable the custom fusion pass.""" + enable_noop: bool = True + """Whether to enable the custom no-op elimination pass.""" + enable_sequence_parallelism: bool = False + """Whether to enable sequence parallelism.""" + enable_async_tp: bool = False + """Whether to enable async TP.""" + + def uuid(self): + """ + Produces a hash unique to the pass configuration. + Any new fields that affect compilation should be added to the hash. + Do not include dump_graph_* in the hash - they don't affect + compilation. + """ + include = { + "enable_fusion", "enable_noop", "enable_sequence_parallelism", + "enable_async_tp" + } + dict_ = {k: v for k, v in asdict(self).items() if k in include} + return InductorPass.hash_dict(dict_) + + def __post_init__(self) -> None: + if not self.enable_noop and self.enable_fusion: + logger.warning_once( + "Fusion enabled but reshape elimination disabled. " + "RMSNorm + quant (fp8) fusion might not work") + + +@config +@dataclass +class CompilationConfig: + """Configuration for compilation. It has three parts: + + - Top-level Compilation control: + - [`level`][vllm.config.CompilationConfig.level] + - [`debug_dump_path`][vllm.config.CompilationConfig.debug_dump_path] + - [`cache_dir`][vllm.config.CompilationConfig.cache_dir] + - [`backend`][vllm.config.CompilationConfig.backend] + - [`custom_ops`][vllm.config.CompilationConfig.custom_ops] + - [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops] + - CudaGraph capture: + - [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph] + - [`cudagraph_capture_sizes`] + [vllm.config.CompilationConfig.cudagraph_capture_sizes] + - [`cudagraph_num_of_warmups`] + [vllm.config.CompilationConfig.cudagraph_num_of_warmups] + - [`cudagraph_copy_inputs`] + [vllm.config.CompilationConfig.cudagraph_copy_inputs] + - [`full_cuda_graph`][vllm.config.CompilationConfig.full_cuda_graph] + - Inductor compilation: + - [`use_inductor`][vllm.config.CompilationConfig.use_inductor] + - [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes] + - [`inductor_compile_config`] + [vllm.config.CompilationConfig.inductor_compile_config] + - [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes] + - custom inductor passes + + Why we have different sizes for cudagraph and inductor: + - cudagraph: a cudagraph captured for a specific size can only be used + for the same size. We need to capture all the sizes we want to use. + - inductor: a graph compiled by inductor for a general shape can be used + for different sizes. Inductor can also compile for specific sizes, + where it can have more information to optimize the graph with fully + static shapes. However, we find the general shape compilation is + sufficient for most cases. 
It might be beneficial to compile for + certain small batchsizes, where inductor is good at optimizing. + """ + # Top-level Compilation control + level: int = 0 + """The level of compilation: + + - 0: no compilation. + - 1: dynamo as is. + - 2: dynamo once. + - 3: piecewise compilation.""" + debug_dump_path: str = "" + """The path to dump the debug information.""" + cache_dir: str = "" + """The directory to store the compiled graph, to accelerate Inductor + compilation. By default, it will use model-related information to generate + a cache directory.""" + backend: str = "" + """The backend for compilation. It needs to be a string: + + - "" (empty string): use the default backend. + - "eager"/"openxla"/...: use the specified backend registered in PyTorch. + - "full.module.name": a qualified name which can be used to import the + + backend function. + We use string to avoid serialization issues when using compilation in a + distributed setting. When the compilation level is 1 or 2, the backend is + used for the compilation directly (it sees the whole graph). When the + compilation level is 3, the backend is used for the piecewise compilation + (it sees a part of the graph).""" + custom_ops: list[str] = field(default_factory=list) + """Fine-grained control over which custom ops to enable/disable. Use 'all' + to enable all, 'none' to disable all. Also specify a list of custom op + names to enable (prefixed with a '+'), or disable (prefixed with a '-'). + Examples: + + - 'all,-op1' to enable all except op1 + - 'none,+op1,+op2' to enable only op1 and op2 + + By default, all custom ops are enabled when running without Inductor and + disabled when running with Inductor (compile_level >= Inductor).""" + splitting_ops: list[str] = field(default_factory=list) + """A list of ops to split the full graph into subgraphs, used in piecewise + compilation.""" + + # Inductor capture + use_inductor: bool = True + """Whether to use inductor compilation: + + - False: inductor compilation is not used. graph runs in eager. + - True: inductor compilation is used. one graph for symbolic shape + is compiled. In addition, compile for compile_sizes, + using configurations in inductor_compile_config.""" + compile_sizes: Optional[list[Union[int, str]]] = None + """Sizes to compile for inductor. In addition + to integers, it also supports "cudagraph_capture_sizes" to + specify the sizes for cudagraph capture.""" + inductor_compile_config: dict = field(default_factory=dict) + """Additional configurations for inductor. + - None: use default configurations.""" + inductor_passes: dict[str, str] = field(default_factory=dict) + """Additional passes for inductor. It is a dictionary + from pass name to pass function qualified name. We use function + name because the config uses JSON format. If we pass the config + from Python, functions can also be passed directly via Python object + constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`.""" + + # CudaGraph compilation + use_cudagraph: bool = envs.VLLM_USE_V1 + """Whether to use cudagraph inside compilation. + - False: cudagraph inside compilation is not used. + - True: cudagraph inside compilation is used. It requires + that all input buffers have fixed addresses, and all + splitting ops write their outputs to input buffers. + In the vLLM V1 Engine, this flag only applies for + CompilationLevel.PIECEWISE (aka -O3). + Note that this is orthogonal to the cudagraph capture logic + outside of compilation. + TODO: move outside cudagraph logic into compilation. 
+ torch.compile will handle cudagraph capture logic in the future.""" + cudagraph_num_of_warmups: int = 0 + """Number of warmup runs for cudagraph. + It means the first several runs will be treated as warmup runs. + Only after that, the execution will be recorded, and the recorded + cudagraph will be used for subsequent runs.""" + cudagraph_capture_sizes: Optional[list[int]] = None + """Sizes to capture cudagraph. + - None (default): capture sizes are inferred from vllm config. + - list[int]: capture sizes are specified as given.""" + cudagraph_copy_inputs: bool = False + """Whether to copy input tensors for + cudagraph. If the caller can guarantee that the same input buffers + are always used, it can set this to False. Otherwise, it should + set this to True, and the compiler will copy the input to an + internally managed buffer. Default is False.""" + full_cuda_graph: bool = False + """whether to use a full cuda graph for the entire forward pass rather than + splitting certain operations such as attention into subgraphs. Thus this + flag cannot be used together with splitting_ops. This may provide + performance benefits for smaller models.""" + + pass_config: PassConfig = field(default_factory=PassConfig) + """Custom inductor passes, see PassConfig for more details""" + + max_capture_size: int = field(default=None, init=False) # type: ignore + """not configurable, computed after init""" + local_cache_dir: str = field(default=None, init=False) # type: ignore + """local cache dir for each rank""" + bs_to_padded_graph_size: list[int] = field( + default=None, # type: ignore + init=False) + """optimization: + Intuitively, bs_to_padded_graph_size should be dict[int, int]. + since we know all keys are in a range [0, max_capture_size], + we can optimize it to list[int] for better lookup performance.""" + + # keep track of enabled and disabled custom ops + enabled_custom_ops: Counter[str] = field(default_factory=Counter, + init=False) + """custom ops that are enabled""" + disabled_custom_ops: Counter[str] = field(default_factory=Counter, + init=False) + """custom ops that are disabled""" + traced_files: set[str] = field(default_factory=set, init=False) + """files that are traced for compilation""" + compilation_time: float = field(default=0.0, init=False) + """time taken for compilation""" + + static_forward_context: dict[str, Any] = field(default_factory=dict, + init=False) + """Per-model forward context + Map from layer name to layer objects that need to be accessed outside + model code, e.g., Attention, FusedMOE when dp_size>1.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. 
+ """ + factors: list[Any] = [] + factors.append(self.level) + factors.append(self.backend) + factors.append(self.custom_ops) + factors.append(self.splitting_ops) + factors.append(self.use_inductor) + factors.append(self.inductor_compile_config) + factors.append(self.inductor_passes) + factors.append(self.pass_config.uuid()) + return hashlib.sha256(str(factors).encode()).hexdigest() + + def __repr__(self) -> str: + exclude = { + "static_forward_context": True, + "enabled_custom_ops": True, + "disabled_custom_ops": True, + "compilation_time": True, + "bs_to_padded_graph_size": True, + "pass_config": True, + "traced_files": True, + "inductor_compile_config": { + "post_grad_custom_post_pass": True, + }, + } + # The cast to string is necessary because Pydantic is mocked in docs + # builds and sphinx-argparse doesn't know the return type of decode() + return str( + TypeAdapter(CompilationConfig).dump_json( + self, + exclude=exclude, # type: ignore[arg-type] + exclude_unset=True).decode()) + + __str__ = __repr__ + + @classmethod + def from_cli(cls, cli_value: str) -> "CompilationConfig": + """Parse the CLI value for the compilation config.""" + if cli_value in ["0", "1", "2", "3"]: + return cls(level=int(cli_value)) + return TypeAdapter(CompilationConfig).validate_json(cli_value) + + def __post_init__(self) -> None: + count_none = self.custom_ops.count("none") + count_all = self.custom_ops.count("all") + assert count_none + count_all <= 1, "Can only specify 'none' or 'all'" + + # TODO(zou3519/luka): There are 2 issues with auto-functionalization V2: + # 1. A bug in PyTorch, fixed in 2.7: + # https://github.com/pytorch/pytorch/issues/147924 + # 2. Custom passes (fusion) rely on auto-functionalization V1 and don't + # work with V2. Addressing this will take extra engineering effort + # and it is not yet a priority. 
RFC here: + # https://github.com/vllm-project/vllm/issues/14703 + + if is_torch_equal_or_newer("2.6"): + KEY = 'enable_auto_functionalized_v2' + if KEY not in self.inductor_compile_config: + self.inductor_compile_config[KEY] = False + + for k, v in self.inductor_passes.items(): + if not isinstance(v, str): + assert callable(v), ( + f"pass {k} should be callable or a qualified name") + self.inductor_compile_config[k] = v if isinstance( + v, InductorPass) else CallableInductorPass(v) + continue + + # resolve function from qualified name + names = v.split(".") + module = ".".join(names[:-1]) + func_name = names[-1] + func = __import__(module).__dict__[func_name] + self.inductor_compile_config[k] = func if isinstance( + func, InductorPass) else CallableInductorPass(func) + + if isinstance(self.pass_config, dict): + self.pass_config = PassConfig(**self.pass_config) + + def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: + if self.level == CompilationLevel.NO_COMPILATION: + raise ValueError("No compilation level is set.") + + from torch._dynamo.backends.registry import list_backends + torch_backends = list_backends(exclude_tags=tuple()) + if self.level in [ + CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE + ]: + if self.backend == "": + return "eager" + if self.backend in torch_backends: + return self.backend + return resolve_obj_by_qualname(self.backend) + + # TODO: pass user-specified backend to piecewise compilation + # merge with the config use_inductor + assert self.level == CompilationLevel.PIECEWISE + + from vllm.compilation.backends import VllmBackend + return VllmBackend(vllm_config) + + def init_with_cudagraph_sizes(self, + cudagraph_capture_sizes: list[int]) -> None: + """To complete the initialization of config, + we need to know the cudagraph sizes.""" + + if self.cudagraph_capture_sizes is None: + self.cudagraph_capture_sizes = cudagraph_capture_sizes + else: + # de-duplicate the sizes provided by the config + dedup_sizes = list(set(self.cudagraph_capture_sizes)) + if len(dedup_sizes) < len(self.cudagraph_capture_sizes): + logger.info(("cudagraph sizes specified by model runner" + " %s is overridden by config %s"), + cudagraph_capture_sizes, dedup_sizes) + self.cudagraph_capture_sizes = dedup_sizes + + computed_compile_sizes = [] + if self.compile_sizes is not None: + # de-duplicate the sizes provided by the config + self.compile_sizes = list(set(self.compile_sizes)) + for x in self.compile_sizes: + if isinstance(x, str): + assert x == "cudagraph_capture_sizes", \ + "Unrecognized size type in compile_sizes, " \ + f"expect 'cudagraph_capture_sizes', got {x}" + computed_compile_sizes.extend(self.cudagraph_capture_sizes) + else: + assert isinstance(x, int) + computed_compile_sizes.append(x) + self.compile_sizes = computed_compile_sizes # type: ignore + + # sort to make sure cudagraph capture sizes are in descending order + self.cudagraph_capture_sizes.sort(reverse=True) + self.max_capture_size = self.cudagraph_capture_sizes[ + 0] if self.cudagraph_capture_sizes else 0 + + # pre-compute the mapping from batch size to padded graph size + self.bs_to_padded_graph_size = [ + 0 for i in range(self.max_capture_size + 1) + ] + for end, start in zip(self.cudagraph_capture_sizes, + self.cudagraph_capture_sizes[1:] + [0]): + for bs in range(start, end): + if bs == start: + self.bs_to_padded_graph_size[bs] = start + else: + self.bs_to_padded_graph_size[bs] = end + self.bs_to_padded_graph_size[ + self.max_capture_size] = self.max_capture_size + + def 
set_splitting_ops_for_v1(self): + # NOTE: this function needs to be called + if self.splitting_ops and self.full_cuda_graph: + raise ValueError("full_cuda_graph cannot be used together with " + "splitting_ops, as Full CUDA graph will override " + f"the splitting_ops: {self.splitting_ops}") + + if not self.splitting_ops: + self.splitting_ops = [] if self.full_cuda_graph else [ + "vllm.unified_attention", + "vllm.unified_attention_with_output", + ] + + +@config +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) +class VllmConfig: + """Dataclass which contains all vllm-related configuration. This + simplifies passing around the distinct configurations in the codebase. + """ + + # TODO: use default_factory once default constructing ModelConfig doesn't + # try to download a model + model_config: ModelConfig = None # type: ignore + """Model configuration.""" + cache_config: CacheConfig = field(default_factory=CacheConfig) + """Cache configuration.""" + parallel_config: ParallelConfig = field(default_factory=ParallelConfig) + """Parallel configuration.""" + scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig) + """Scheduler configuration.""" + device_config: DeviceConfig = field(default_factory=DeviceConfig) + """Device configuration.""" + load_config: LoadConfig = field(default_factory=LoadConfig) + """Load configuration.""" + lora_config: Optional[LoRAConfig] = None + """LoRA configuration.""" + speculative_config: Optional[SpeculativeConfig] = None + """Speculative decoding configuration.""" + decoding_config: DecodingConfig = field(default_factory=DecodingConfig) + """Decoding configuration.""" + observability_config: Optional[ObservabilityConfig] = None + """Observability configuration.""" + prompt_adapter_config: Optional[PromptAdapterConfig] = None + """Prompt adapter configuration.""" + quant_config: Optional[QuantizationConfig] = None + """Quantization configuration.""" + compilation_config: CompilationConfig = field( + default_factory=CompilationConfig) + """`torch.compile` configuration for the model. + + When it is a number (0, 1, 2, 3), it will be interpreted as the + optimization level. + + NOTE: level 0 is the default level without any optimization. level 1 and 2 + are for internal testing only. level 3 is the recommended level for + production. + + Following the convention of traditional compilers, using `-O` without space + is also supported. `-O3` is equivalent to `-O 3`. + + You can specify the full compilation config like so: + `{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}` + """ + kv_transfer_config: Optional[KVTransferConfig] = None + """The configurations for distributed KV cache transfer.""" + kv_events_config: Optional[KVEventsConfig] = None + """The configurations for event publishing.""" + # some opaque config, only used to provide additional information + # for the hash computation, mainly used for testing, debugging or out of + # tree config registration. + additional_config: Union[dict, SupportsHash] = field(default_factory=dict) + """Additional config for specified platform. Different platforms may + support different configs. Make sure the configs are valid for the platform + you are using. Contents must be hashable.""" + instance_id: str = "" + """The ID of the vLLM instance.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. 
+ + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + factors: list[Any] = [] + + # summarize vllm config + vllm_factors: list[Any] = [] + from vllm import __version__ + vllm_factors.append(__version__) + vllm_factors.append(envs.VLLM_USE_V1) + vllm_factors.append(envs.MACA_VLLM_USE_TN_2_NN) + if self.model_config: + vllm_factors.append(self.model_config.compute_hash()) + else: + vllm_factors.append("None") + if self.cache_config: + vllm_factors.append(self.cache_config.compute_hash()) + else: + vllm_factors.append("None") + if self.parallel_config: + vllm_factors.append(self.parallel_config.compute_hash()) + else: + vllm_factors.append("None") + if self.scheduler_config: + vllm_factors.append(self.scheduler_config.compute_hash()) + else: + vllm_factors.append("None") + if self.device_config: + vllm_factors.append(self.device_config.compute_hash()) + else: + vllm_factors.append("None") + if self.load_config: + vllm_factors.append(self.load_config.compute_hash()) + else: + vllm_factors.append("None") + if self.lora_config: + vllm_factors.append(self.lora_config.compute_hash()) + # LoRA creates static buffers based on max_num_batched_tokens. + # The tensor sizes and strides get captured in the torch.compile + # graph explicitly. + vllm_factors.append( + str(self.scheduler_config.max_num_batched_tokens)) + else: + vllm_factors.append("None") + if self.speculative_config: + vllm_factors.append(self.speculative_config.compute_hash()) + else: + vllm_factors.append("None") + if self.decoding_config: + vllm_factors.append(self.decoding_config.compute_hash()) + else: + vllm_factors.append("None") + if self.observability_config: + vllm_factors.append(self.observability_config.compute_hash()) + else: + vllm_factors.append("None") + if self.prompt_adapter_config: + vllm_factors.append(self.prompt_adapter_config.compute_hash()) + else: + vllm_factors.append("None") + if self.quant_config: + pass # should be captured by model_config.quantization + if self.compilation_config: + vllm_factors.append(self.compilation_config.compute_hash()) + else: + vllm_factors.append("None") + if self.kv_transfer_config: + vllm_factors.append(self.kv_transfer_config.compute_hash()) + else: + vllm_factors.append("None") + if self.additional_config: + if isinstance(additional_config := self.additional_config, dict): + additional_config_hash = hashlib.md5( + json.dumps(additional_config, sort_keys=True).encode(), + usedforsecurity=False, + ).hexdigest() + else: + additional_config_hash = additional_config.compute_hash() + vllm_factors.append(additional_config_hash) + else: + vllm_factors.append("None") + factors.append(vllm_factors) + + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest()[:10] + return hash_str + + def pad_for_cudagraph(self, batch_size: int) -> int: + # if batch_size > self.compilation_config.max_capture_size, + # it should raise an IndexError. 
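+        # Illustrative example (hypothetical capture sizes): with
+        # cudagraph_capture_sizes=[1, 2, 4, 8], bs_to_padded_graph_size maps
+        # 3 -> 4 and 5, 6, 7 -> 8, so pad_for_cudagraph(3) returns 4.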
+ # the caller should make sure the batch_size is within the range, + # i.e., batch_size <= self.compilation_config.max_capture_size + return self.compilation_config.bs_to_padded_graph_size[batch_size] + + @staticmethod + def _get_quantization_config( + model_config: ModelConfig, + load_config: LoadConfig) -> Optional[QuantizationConfig]: + """Get the quantization config.""" + from vllm.platforms import current_platform + if model_config.quantization is not None: + from vllm.model_executor.model_loader.weight_utils import ( + get_quant_config) + quant_config = get_quant_config(model_config, load_config) + capability_tuple = current_platform.get_device_capability() + + if capability_tuple is not None: + capability = capability_tuple.to_int() + """ + if capability < quant_config.get_min_capability(): + raise ValueError( + f"The quantization method {model_config.quantization} " + "is not supported for the current GPU. Minimum " + f"capability: {quant_config.get_min_capability()}. " + f"Current capability: {capability}.") + """ + supported_dtypes = quant_config.get_supported_act_dtypes() + if model_config.dtype not in supported_dtypes: + raise ValueError( + f"{model_config.dtype} is not supported for quantization " + f"method {model_config.quantization}. Supported dtypes: " + f"{supported_dtypes}") + return quant_config + return None + + @staticmethod + def get_quantization_config( + model_config: ModelConfig, + load_config: LoadConfig) -> Optional[QuantizationConfig]: + import copy + + # For some reason, the _ version of this modifies the model_config + # object, so using deepcopy to avoid this problem. + return VllmConfig._get_quantization_config(copy.deepcopy(model_config), + load_config) + + def with_hf_config( + self, + hf_config: PretrainedConfig, + architectures: Optional[list[str]] = None, + ) -> "VllmConfig": + if architectures is not None: + hf_config = copy.deepcopy(hf_config) + hf_config.architectures = architectures + + model_config = copy.deepcopy(self.model_config) + model_config.hf_config = hf_config + + return replace(self, model_config=model_config) + + def __post_init__(self): + """Verify configs are valid & consistent with each other. + """ + if self.model_config is not None: + self.model_config.verify_async_output_proc(self.parallel_config, + self.speculative_config, + self.device_config) + self.model_config.verify_with_parallel_config(self.parallel_config) + self.model_config.verify_dual_chunk_attention_config( + self.load_config) + + self.cache_config.verify_with_parallel_config(self.parallel_config) + + if self.lora_config is not None: + self.lora_config.verify_with_cache_config(self.cache_config) + self.lora_config.verify_with_model_config(self.model_config) + self.lora_config.verify_lora_support() + if self.prompt_adapter_config is not None: + self.prompt_adapter_config.verify_with_model_config( + self.model_config) + + if self.quant_config is None and self.model_config is not None: + self.quant_config = VllmConfig._get_quantization_config( + self.model_config, self.load_config) + + from vllm.platforms import current_platform + if self.model_config is not None and \ + self.scheduler_config.chunked_prefill_enabled and \ + self.model_config.dtype == torch.float32 and \ + current_platform.get_device_capability() == (7, 5): + logger.warning_once( + "Turing devices tensor cores do not support float32 matmul. 
" + "To workaround this limitation, vLLM will set 'ieee' input " + "precision for chunked prefill triton kernels.") + + # async tp is built on top of sequence parallelism + # and requires it to be enabled. + if self.compilation_config.pass_config.enable_async_tp: + self.compilation_config.pass_config.enable_sequence_parallelism = \ + True + if self.compilation_config.pass_config.enable_sequence_parallelism: + self.compilation_config.custom_ops.append("+rms_norm") + if envs.VLLM_USE_V1 and self.model_config is not None and \ + not self.model_config.enforce_eager: + # FIXME(rob): Add function to set all of these. + if not self.compilation_config.custom_ops: + self.compilation_config.custom_ops = ["none"] + self.compilation_config.cudagraph_num_of_warmups = 1 + self.compilation_config.pass_config.enable_fusion = False + self.compilation_config.pass_config.enable_noop = False + self.compilation_config.level = CompilationLevel.PIECEWISE + self.compilation_config.set_splitting_ops_for_v1() + + self._set_cudagraph_sizes() + + if self.cache_config.cpu_offload_gb > 0 and \ + self.compilation_config.level != CompilationLevel.NO_COMPILATION \ + and not envs.VLLM_USE_V1: + logger.warning( + "CPU offload is not supported with `torch.compile` in v0 yet." + " Disabling `torch.compile`.") + self.compilation_config.level = CompilationLevel.NO_COMPILATION + + if ((not envs.VLLM_USE_V1) and self.lora_config is not None + and self.compilation_config.level + != CompilationLevel.NO_COMPILATION): + logger.warning( + "LoRA for V0 is not supported with `torch.compile` yet. " + "Disabling `torch.compile`.") + self.compilation_config.level = CompilationLevel.NO_COMPILATION + + if self.compilation_config.full_cuda_graph and \ + not self.model_config.disable_cascade_attn: + logger.warning_once( + "full_cuda_graph is not supported with " + "cascade attention. Disabling cascade attention.") + self.model_config.disable_cascade_attn = True + self.cache_config.enable_prefix_caching = False + + if (self.kv_events_config is not None + and self.kv_events_config.enable_kv_cache_events + and not self.cache_config.enable_prefix_caching): + logger.warning( + "KV cache events are on, but prefix caching is not enabled." + "Use --enable-prefix-caching to enable.") + if (self.kv_events_config is not None + and self.kv_events_config.publisher != "null" + and not self.kv_events_config.enable_kv_cache_events): + logger.warning("KV cache events are disabled," + "but the scheduler is configured to publish them." + "Modify KVEventsConfig.enable_kv_cache_events" + "to True to enable.") + current_platform.check_and_update_config(self) + + if not self.instance_id: + self.instance_id = random_uuid()[:5] + + if (envs.VLLM_USE_V1 + and not self.scheduler_config.disable_hybrid_kv_cache_manager): + # logger should only print warning message for hybrid models. As we + # can't know whether the model is hybrid or not now, so we don't log + # warning message here and will log it later. + if not (current_platform.is_cuda() or current_platform.is_rocm()): + # Hybrid KV cache manager is not supported on non-GPU platforms. + self.scheduler_config.disable_hybrid_kv_cache_manager = True + if self.kv_transfer_config is not None: + # Hybrid KV cache manager is not compatible with KV transfer. + self.scheduler_config.disable_hybrid_kv_cache_manager = True + if self.kv_events_config is not None: + # Hybrid KV cache manager is not compatible with KV events. 
+ self.scheduler_config.disable_hybrid_kv_cache_manager = True + + def update_sizes_for_sequence_parallelism(self, + possible_sizes: list) -> list: + # remove the sizes that not multiple of tp_size when + # enable sequence parallelism + removed_sizes = [ + size for size in possible_sizes + if size % self.parallel_config.tensor_parallel_size != 0 + ] + if removed_sizes: + logger.warning( + "Batch sizes %s are removed because they are not " + "multiple of tp_size %d when " + "sequence parallelism is enabled", removed_sizes, + self.parallel_config.tensor_parallel_size) + + return [ + size for size in possible_sizes + if size % self.parallel_config.tensor_parallel_size == 0 + ] + + def _set_cudagraph_sizes(self): + """ + cudagraph batchsize padding logic: + + `[1, 2, 4] + [8 * i for i in range(1, 1025)]` is a list of all possible + batch sizes that cudagraph will capture. + + Depending on the engine's configuration of `max_num_seqs`, the + candidate batch sizes to capture cudagraph will shrink to the subset + which just cover the range of `[1, max_num_seqs]`. In the common case, + `max_num_seqs` is 256, and the cudagraph batch sizes will be + `[1, 2, 4, 8, 16, 24, 32, 40, ..., 256]`. + + However, if users specify the cudagraph capture sizes through + compilation config, we will use the specified sizes instead. + + In the end, `vllm_config.compilation_config.cudagraph_capture_sizes` + will be the final sizes to capture cudagraph (in descending order). + + During runtime, if batchsize is larger than + `vllm_config.compilation_config.cudagraph_capture_sizes`, + no cudagraph will be used. + If the batch size is no larger than + `vllm_config.compilation_config.cudagraph_capture_sizes`, + we can quickly find the padded graph size for a given batch size by + looking up `vllm_config.compilation_config.bs_to_padded_graph_size`. 
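+
+        Worked example (for illustration, using the defaults above): with
+        `max_num_seqs = 256` the capture list is `[1, 2, 4, 8, 16, ..., 256]`,
+        so a runtime batch of 33 is padded up to 40 and replayed with the
+        CUDA graph captured for size 40, while a batch of 300 exceeds the
+        largest captured size and runs without a CUDA graph.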
+ """ + + # calculate the default `batch_size_capture_list` + if not envs.VLLM_USE_V1: + batch_size_capture_list = [] + max_batchsize_to_capture = 0 + if self.scheduler_config is not None and \ + self.model_config is not None and \ + not self.model_config.enforce_eager: + + possible_sizes = [1, 2, 4] + [8 * i for i in range(1, 1025)] + if self.parallel_config.tensor_parallel_size > 1 and \ + self.compilation_config.pass_config.enable_sequence_parallelism: + possible_sizes = self.update_sizes_for_sequence_parallelism( + possible_sizes) + + # find the minimum size that is larger than max_num_seqs, + # which then becomes the max_batchsize_to_capture + larger_sizes = [ + x for x in possible_sizes + if x >= self.scheduler_config.max_num_seqs + ] + if larger_sizes: + max_batchsize_to_capture = larger_sizes[0] + else: + max_batchsize_to_capture = possible_sizes[-1] + + # filter out the sizes that are + # larger than max_batchsize_to_capture + batch_size_capture_list = [ + size for size in possible_sizes + if size <= max_batchsize_to_capture + ] + else: + batch_size_capture_list = [] + if self.model_config is not None and \ + not self.model_config.enforce_eager: + cuda_graph_sizes = self.scheduler_config.cuda_graph_sizes + if len(cuda_graph_sizes) == 1: + batch_size_capture_list = [1, 2, 4] + [ + i for i in range(8, cuda_graph_sizes[0] + 1, 8) + ] + elif len(cuda_graph_sizes) > 1: + batch_size_capture_list = sorted(cuda_graph_sizes) + else: + raise TypeError(f"Invalid value for {cuda_graph_sizes=}.") + if self.parallel_config.tensor_parallel_size > 1 and \ + self.compilation_config.pass_config.enable_sequence_parallelism: + batch_size_capture_list = \ + self.update_sizes_for_sequence_parallelism(batch_size_capture_list) + max_num_tokens = self.scheduler_config.max_num_batched_tokens + batch_size_capture_list = [ + size for size in batch_size_capture_list + if size <= max_num_tokens + ] + + self.compilation_config.init_with_cudagraph_sizes( + batch_size_capture_list) + + def recalculate_max_model_len(self, max_model_len: int): + model_config = self.model_config + max_model_len = model_config.get_and_verify_max_len(max_model_len) + self.model_config.max_model_len = max_model_len + self.scheduler_config.max_model_len = max_model_len + self.compute_hash() + + def __str__(self): + return ( + f"model={self.model_config.model!r}," + f" speculative_config={self.speculative_config!r}," + f" tokenizer={self.model_config.tokenizer!r}, " + f"skip_tokenizer_init={self.model_config.skip_tokenizer_init}," + f" tokenizer_mode={self.model_config.tokenizer_mode}, " + f"revision={self.model_config.revision}, " + f"override_neuron_config={self.model_config.override_neuron_config}," + f" tokenizer_revision={self.model_config.tokenizer_revision}, " + f"trust_remote_code={self.model_config.trust_remote_code}, " + f"dtype={self.model_config.dtype}, " + f"max_seq_len={self.model_config.max_model_len}," + f" download_dir={self.load_config.download_dir!r}, " + f"load_format={self.load_config.load_format}, " + f"tensor_parallel_size={self.parallel_config.tensor_parallel_size}," + f" pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, " # noqa + f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa + f"quantization={self.model_config.quantization}, " + f"enforce_eager={self.model_config.enforce_eager}, " + f"kv_cache_dtype={self.cache_config.cache_dtype}, " + f" device_config={self.device_config.device}, " + f"decoding_config={self.decoding_config!r}, " + 
f"observability_config={self.observability_config!r}, " + f"seed={self.model_config.seed}, " + f"served_model_name={self.model_config.served_model_name}, " + f"num_scheduler_steps={self.scheduler_config.num_scheduler_steps}, " + f"multi_step_stream_outputs={self.scheduler_config.multi_step_stream_outputs}, " # noqa + f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, " + f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, " # noqa + f"use_async_output_proc={self.model_config.use_async_output_proc}, " + f"pooler_config={self.model_config.pooler_config!r}, " + f"compilation_config={self.compilation_config!r}") + + +_current_vllm_config: Optional[VllmConfig] = None + + +@contextmanager +def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False): + """ + Temporarily set the current vLLM config. + Used during model initialization. + We save the current vLLM config in a global variable, + so that all modules can access it, e.g. custom ops + can access the vLLM config to determine how to dispatch. + """ + global _current_vllm_config + old_vllm_config = _current_vllm_config + from vllm.compilation.counter import compilation_counter + num_models_seen = compilation_counter.num_models_seen + try: + _current_vllm_config = vllm_config + yield + except Exception: + raise + else: + logger.debug("enabled custom ops: %s", + vllm_config.compilation_config.enabled_custom_ops) + logger.debug("disabled custom ops: %s", + vllm_config.compilation_config.disabled_custom_ops) + if check_compile and \ + vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \ + and compilation_counter.num_models_seen == num_models_seen: + # If the model supports compilation, + # compilation_counter.num_models_seen should be increased + # by at least 1. + # If it is not increased, it means the model does not support + # compilation (does not have @support_torch_compile decorator). + logger.warning( + "`torch.compile` is turned on, but the model %s" + " does not support it. Please open an issue on GitHub" + " if you want it to be supported.", + vllm_config.model_config.model) + finally: + _current_vllm_config = old_vllm_config + + +def get_current_vllm_config() -> VllmConfig: + if _current_vllm_config is None: + # in ci, usually when we test custom ops/modules directly, + # we don't set the vllm config. In that case, we set a default + # config. + logger.warning("Current vLLM config is not set.") + from vllm.config import VllmConfig + return VllmConfig() + return _current_vllm_config + + +def contains_object_print(text): + """ + Check if the text looks like a printed Python object, e.g. + contains any substring matching the pattern: "at 0xFFFFFFF>" + We match against 0x followed by 2-16 hex chars (there's + a max of 16 on a 64 bit system). + + Args: + text (str): The text to check + + Returns: + result (bool): `True` if a match is found, `False` otherwise. + """ + pattern = r'at 0x[a-fA-F0-9]{2,16}>' + match = re.search(pattern, text) + return match is not None + + +def assert_hashable(text): + if not contains_object_print(text): + return True + raise AssertionError( + f"vLLM tried to hash some configs that may have Python objects ids " + f"in them. This is a bug, please file an issue. 
" + f"Text being hashed: {text}") + + +T = TypeVar("T") + + +def get_layers_from_vllm_config(vllm_config: VllmConfig, + layer_type: type[T]) -> dict[str, T]: + return { + layer_name: layer + for layer_name, layer in + vllm_config.compilation_config.static_forward_context.items() + if isinstance(layer, layer_type) + } diff --git a/connections.py b/connections.py new file mode 100644 index 0000000..103505e --- /dev/null +++ b/connections.py @@ -0,0 +1,174 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Mapping, MutableMapping +from pathlib import Path +from typing import Optional +from urllib.parse import urlparse + +import aiohttp +import requests + +from vllm.version import __version__ as VLLM_VERSION + + +class HTTPConnection: + """Helper class to send HTTP requests.""" + + def __init__(self, *, reuse_client: bool = True) -> None: + super().__init__() + + self.reuse_client = reuse_client + + self._sync_client: Optional[requests.Session] = None + self._async_client: Optional[aiohttp.ClientSession] = None + + def get_sync_client(self) -> requests.Session: + if self._sync_client is None or not self.reuse_client: + self._sync_client = requests.Session() + + return self._sync_client + + # NOTE: We intentionally use an async function even though it is not + # required, so that the client is only accessible inside async event loop + async def get_async_client(self) -> aiohttp.ClientSession: + if self._async_client is None or not self.reuse_client: + self._async_client = aiohttp.ClientSession(trust_env=True) + + return self._async_client + + def _validate_http_url(self, url: str): + parsed_url = urlparse(url) + + if parsed_url.scheme not in ("http", "https"): + raise ValueError("Invalid HTTP URL: A valid HTTP URL " + "must have scheme 'http' or 'https'.") + + def _headers(self, **extras: str) -> MutableMapping[str, str]: + return {"User-Agent": f"vLLM/{VLLM_VERSION}", **extras} + + def get_response( + self, + url: str, + *, + stream: bool = False, + timeout: Optional[float] = None, + extra_headers: Optional[Mapping[str, str]] = None, + ): + self._validate_http_url(url) + + client = self.get_sync_client() + extra_headers = extra_headers or {} + + return client.get(url, + headers=self._headers(**extra_headers), + stream=stream, + timeout=timeout) + + async def get_async_response( + self, + url: str, + *, + timeout: Optional[float] = None, + extra_headers: Optional[Mapping[str, str]] = None, + ): + self._validate_http_url(url) + + client = await self.get_async_client() + extra_headers = extra_headers or {} + + return client.get(url, + headers=self._headers(**extra_headers), + timeout=timeout) + + def get_bytes(self, url: str, *, timeout: Optional[float] = None) -> bytes: + with self.get_response(url, timeout=timeout) as r: + r.raise_for_status() + + return r.content + + async def async_get_bytes( + self, + url: str, + *, + timeout: Optional[float] = None, + ) -> bytes: + async with await self.get_async_response(url, timeout=timeout) as r: + r.raise_for_status() + + return await r.read() + + def get_text(self, url: str, *, timeout: Optional[float] = None) -> str: + with self.get_response(url, timeout=timeout) as r: + r.raise_for_status() + + return r.text + + async def async_get_text( + self, + url: str, + *, + timeout: Optional[float] = None, + ) -> str: + async with await self.get_async_response(url, timeout=timeout) as r: + r.raise_for_status() + + return await r.text() + + def get_json(self, url: str, 
*, timeout: Optional[float] = None) -> str: + with self.get_response(url, timeout=timeout) as r: + r.raise_for_status() + + return r.json() + + async def async_get_json( + self, + url: str, + *, + timeout: Optional[float] = None, + ) -> str: + async with await self.get_async_response(url, timeout=timeout) as r: + r.raise_for_status() + + return await r.json() + + def download_file( + self, + url: str, + save_path: Path, + *, + timeout: Optional[float] = None, + chunk_size: int = 128, + ) -> Path: + with self.get_response(url, timeout=timeout) as r: + r.raise_for_status() + + with save_path.open("wb") as f: + for chunk in r.iter_content(chunk_size): + f.write(chunk) + + return save_path + + async def async_download_file( + self, + url: str, + save_path: Path, + *, + timeout: Optional[float] = None, + chunk_size: int = 128, + ) -> Path: + async with await self.get_async_response(url, timeout=timeout) as r: + r.raise_for_status() + + with save_path.open("wb") as f: + async for chunk in r.content.iter_chunked(chunk_size): + f.write(chunk) + + return save_path + + +global_http_connection = HTTPConnection() +""" +The global [`HTTPConnection`][vllm.connections.HTTPConnection] instance used +by vLLM. +""" diff --git a/core/__init__.py b/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/core/block/__init__.py b/core/block/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/core/block/block_table.py b/core/block/block_table.py new file mode 100644 index 0000000..444bb25 --- /dev/null +++ b/core/block/block_table.py @@ -0,0 +1,399 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math +from typing import List, Optional + +from vllm.core.block.common import BlockList +from vllm.core.block.interfaces import Block, DeviceAwareBlockAllocator +from vllm.utils import Device, cdiv, chunk_list + + +class BlockTable: + """A class to manage blocks for a specific sequence. + + The BlockTable maps a sequence of tokens to a list of blocks, where each + block represents a contiguous memory allocation for a portion of the + sequence. The blocks are managed by a DeviceAwareBlockAllocator, which is + responsible for allocating and freeing memory for the blocks. + + Args: + block_size (int): The maximum number of tokens that can be stored in a + single block. + block_allocator (DeviceAwareBlockAllocator): The block allocator used to + manage memory for the blocks. + _blocks (Optional[List[Block]], optional): An optional list of existing + blocks to initialize the BlockTable with. If not provided, an empty + BlockTable is created. + max_block_sliding_window (Optional[int], optional): The number of + blocks to keep around for each sequence. If None, all blocks + are kept (eg., when sliding window is not used). + It should at least fit the sliding window size of the model. + + Attributes: + _block_size (int): The maximum number of tokens that can be stored in a + single block. + _allocator (DeviceAwareBlockAllocator): The block allocator used to + manage memory for the blocks. + _blocks (Optional[List[Block]]): The list of blocks managed by this + BlockTable. + _num_full_slots (int): The number of tokens currently stored in the + blocks. 
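+
+    Example (illustrative sketch; assumes `allocator` is an already
+    constructed DeviceAwareBlockAllocator):
+        table = BlockTable(block_size=16, block_allocator=allocator)
+        table.allocate(token_ids=list(range(20)))   # occupies two blocks
+        table.append_token_ids([7, 8, 9])           # still fits in two blocks
+        table.free()                                # releases both blocks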
+ """ + + def __init__( + self, + block_size: int, + block_allocator: DeviceAwareBlockAllocator, + _blocks: Optional[List[Block]] = None, + max_block_sliding_window: Optional[int] = None, + ): + self._block_size = block_size + self._allocator = block_allocator + if _blocks is None: + _blocks = [] + self._blocks: BlockList = BlockList(_blocks) + + self._max_block_sliding_window = max_block_sliding_window + self._num_full_slots = self._get_num_token_ids() + + @staticmethod + def get_num_required_blocks(token_ids: List[int], + block_size: int, + num_lookahead_slots: int = 0) -> int: + """Calculates the minimum number of blocks required to store a given + sequence of token IDs along with any look-ahead slots that may be + required (like in multi-step + chunked-prefill). + + This assumes worst-case scenario, where every block requires a new + allocation (e.g. ignoring prefix caching). + + Args: + token_ids (List[int]): The sequence of token IDs to be stored. + block_size (int): The maximum number of tokens that can be stored in + a single block. + num_lookahead_slots (int): look-ahead slots that the sequence may + require. + + Returns: + int: The minimum number of blocks required to store the given + sequence of token IDs along with any required look-ahead slots. + """ + return cdiv(len(token_ids) + num_lookahead_slots, block_size) + + def allocate(self, + token_ids: List[int], + device: Device = Device.GPU, + extra_hash: Optional[int] = None) -> None: + """Allocates memory blocks for storing the given sequence of token IDs. + + This method allocates the required number of blocks to store the given + sequence of token IDs. + + Args: + token_ids (List[int]): The sequence of token IDs to be stored. + device (Device, optional): The device on which the blocks should be + allocated. Defaults to Device.GPU. + extra_hash (Optional[int]): The hash value of additional + factors, such as adapters, that influence the block hash + in the prefixcaching block. + """ + assert not self._is_allocated + assert token_ids + blocks = self._allocate_blocks_for_token_ids(prev_block=None, + token_ids=token_ids, + device=device, + extra_hash=extra_hash) + self.update(blocks) + self._num_full_slots = len(token_ids) + + def update(self, blocks: List[Block]) -> None: + """Resets the table to the newly provided blocks + (with their corresponding block ids) + """ + self._blocks.update(blocks) + + def append_token_ids(self, + token_ids: List[int], + num_lookahead_slots: int = 0, + num_computed_slots: Optional[int] = None, + extra_hash: Optional[int] = None) -> None: + """Appends a sequence of token IDs to the existing blocks in the + BlockTable. + + This method appends the given sequence of token IDs to the existing + blocks in the BlockTable. If there is not enough space in the existing + blocks, new blocks are allocated using the `ensure_num_empty_slots` + method to accommodate the additional tokens. + + The token IDs are divided into chunks of size `block_size` (except for + the first chunk, which may be smaller), and each chunk is appended to a + separate block. + + Args: + token_ids (List[int]): The sequence of token IDs to be appended. + num_computed_slots (Optional[int]): The number of KV cache slots + that are already filled (computed). + When sliding window is enabled, this is used to compute how many + blocks to drop at the front of the sequence. + Without sliding window, None can be passed. + Without chunked prefill, it should be the same as + _num_full_slots. 
+ extra_hash (Optional[int]): The hash value of additional + factors such as adapters that influence the block, apart + from the token_ids. + """ + assert self._is_allocated, "no blocks have been allocated" + assert len(self._blocks) > 0 + + # Drop blocks that are no longer needed due to sliding window + if self._max_block_sliding_window is not None: + null_block = self._allocator.allocate_or_get_null_block() + assert num_computed_slots is not None + end_block_idx = (num_computed_slots // + self._block_size) - self._max_block_sliding_window + for idx in range(0, end_block_idx): + b = self._blocks[idx] + if b is not null_block: + self._allocator.free(b) + self._blocks[idx] = null_block + + # Ensure there are enough empty slots for the new tokens plus + # lookahead slots + self.ensure_num_empty_slots(num_empty_slots=len(token_ids) + + num_lookahead_slots, + extra_hash=extra_hash) + + # Update the blocks with the new tokens + first_block_idx = self._num_full_slots // self._block_size + token_blocks = self._chunk_token_blocks_for_append(token_ids) + + for i, token_block in enumerate(token_blocks): + self._blocks.append_token_ids(first_block_idx + i, token_block) + + self._num_full_slots += len(token_ids) + + def ensure_num_empty_slots(self, + num_empty_slots: int, + extra_hash: Optional[int] = None) -> None: + """Ensures that the BlockTable has at least the specified number of + empty slots available. + + This method checks if the BlockTable has enough empty slots (i.e., + available space) to accommodate the requested number of tokens. If not, + it allocates additional blocks on the GPU to ensure that the required + number of empty slots is available. + + Args: + num_empty_slots (int): The minimum number of empty slots required. + extra_hash (Optional[int]): The hash value of additional + factors such as adapters that influence the block, apart + from the token_ids. + """ + # Currently the block table only supports + # appending tokens to GPU blocks. + device = Device.GPU + assert self._is_allocated + + if self._num_empty_slots >= num_empty_slots: + return + + slots_to_allocate = num_empty_slots - self._num_empty_slots + blocks_to_allocate = cdiv(slots_to_allocate, self._block_size) + + for _ in range(blocks_to_allocate): + assert len(self._blocks) > 0 + self._blocks.append( + self._allocator.allocate_mutable_block( + prev_block=self._blocks[-1], + device=device, + extra_hash=extra_hash)) + + def fork(self) -> "BlockTable": + """Creates a new BlockTable instance with a copy of the blocks from the + current instance. + + This method creates a new BlockTable instance with the same block size, + block allocator, and a copy of the blocks from the current instance. The + new BlockTable has its own independent set of blocks, but shares the + same underlying memory allocation with the original BlockTable. + + Returns: + BlockTable: A new BlockTable instance with a copy of the blocks from + the current instance. + """ + assert self._is_allocated + assert len(self._blocks) > 0 + forked_blocks = self._allocator.fork(self._blocks[-1]) + return BlockTable( + block_size=self._block_size, + block_allocator=self._allocator, + _blocks=forked_blocks, + max_block_sliding_window=self._max_block_sliding_window, + ) + + def free(self) -> None: + """Frees the memory occupied by the blocks in the BlockTable. + + This method iterates over all the blocks in the `_blocks` list and calls + the `free` method of the `_allocator` object to release the memory + occupied by each block. 
After freeing all the blocks, the `_blocks` list + is set to `None`. + """ + for block in self.blocks: + self._allocator.free(block) + self._blocks.reset() + + @property + def physical_block_ids(self) -> List[int]: + """Returns a list of physical block indices for the blocks in the + BlockTable. + + This property returns a list of integers, where each integer represents + the physical block index of a corresponding block in the `_blocks` list. + The physical block index is a unique identifier for the memory location + occupied by the block. + + Returns: + List[int]: A list of physical block indices for the blocks in the + BlockTable. + """ + return self._blocks.ids() + + def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]: + """Get the number of "unseen" tokens in the sequence. + + Unseen tokens are tokens in the sequence corresponding to this block + table, but are not yet appended to this block table. + + Args: + sequence_token_ids (List[int]): The list of token ids in the + sequence. + + Returns: + List[int]: The postfix of sequence_token_ids that has not yet been + appended to the block table. + """ + + # Since the block table is append-only, the unseen token ids are the + # ones after the appended ones. + return sequence_token_ids[self.num_full_slots:] + + def _allocate_blocks_for_token_ids( + self, + prev_block: Optional[Block], + token_ids: List[int], + device: Device, + extra_hash: Optional[int] = None) -> List[Block]: + blocks: List[Block] = [] + + block_token_ids = [] + tail_token_ids = [] + for cur_token_ids in chunk_list(token_ids, self._block_size): + if len(cur_token_ids) == self._block_size: + block_token_ids.append(cur_token_ids) + else: + tail_token_ids.append(cur_token_ids) + + if block_token_ids: + blocks.extend( + self._allocator.allocate_immutable_blocks( + prev_block, + block_token_ids=block_token_ids, + device=device, + extra_hash=extra_hash)) + prev_block = blocks[-1] + + if tail_token_ids: + assert len(tail_token_ids) == 1 + cur_token_ids = tail_token_ids[0] + + block = self._allocator.allocate_mutable_block( + prev_block=prev_block, device=device, extra_hash=extra_hash) + block.append_token_ids(cur_token_ids) + + blocks.append(block) + + return blocks + + def _get_all_token_ids(self) -> List[int]: + # NOTE: This function is O(seq_len); use sparingly. + token_ids: List[int] = [] + + if not self._is_allocated: + return token_ids + + for block in self.blocks: + token_ids.extend(block.token_ids) + + return token_ids + + def _get_num_token_ids(self) -> int: + res = 0 + for block in self.blocks: + res += len(block.token_ids) + + return res + + @property + def _is_allocated(self) -> bool: + return len(self._blocks) > 0 + + @property + def blocks(self) -> List[Block]: + return self._blocks.list() + + @property + def _num_empty_slots(self) -> int: + assert self._is_allocated + return len(self._blocks) * self._block_size - self._num_full_slots + + @property + def num_full_slots(self) -> int: + """Returns the total number of tokens currently stored in the + BlockTable. + + Returns: + int: The total number of tokens currently stored in the BlockTable. + """ + return self._num_full_slots + + def get_num_blocks_touched_by_append_slots( + self, token_ids: List[int], num_lookahead_slots: int) -> int: + """Determine how many blocks will be "touched" by appending the token + ids. + + This is required for the scheduler to determine whether a sequence can + continue generation, or if it must be preempted. 
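+
+        Worked example (illustrative): with a block size of 16 and 10 slots
+        already full, the current tail block has 6 empty slots; appending 10
+        token ids plus 4 lookahead slots (14 slots in total) therefore
+        touches the partially filled block and one new block, i.e. 2 blocks.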
+ """ + # Math below is equivalent to: + # all_token_ids = token_ids + [-1] * num_lookahead_slots + # token_blocks = self._chunk_token_blocks_for_append(all_token_ids) + # return len(token_blocks) + + num_token_ids = len(token_ids) + num_lookahead_slots + first_chunk_size = self._block_size - (self._num_full_slots % + self._block_size) + num_token_blocks = (1 + math.ceil( + (num_token_ids - first_chunk_size) / self._block_size)) + return num_token_blocks + + def _chunk_token_blocks_for_append( + self, token_ids: List[int]) -> List[List[int]]: + """Split the token ids into block-sized chunks so they can be easily + appended to blocks. The first such "token block" may have less token ids + than the block size, since the last allocated block may be partially + full. + + If no token ids are provided, then no chunks are returned. + """ + + if not token_ids: + return [] + + first_chunk_size = self._block_size - (self._num_full_slots % + self._block_size) + token_blocks = [token_ids[:first_chunk_size]] + token_blocks.extend( + chunk_list(token_ids[first_chunk_size:], self._block_size)) + return token_blocks diff --git a/core/block/common.py b/core/block/common.py new file mode 100644 index 0000000..a337007 --- /dev/null +++ b/core/block/common.py @@ -0,0 +1,371 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections import deque +from dataclasses import dataclass +from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple + +from vllm.core.block.interfaces import Block, BlockAllocator + +BlockId = int +RefCount = int + + +class RefCounterProtocol(Protocol): + + def incr(self, block_id: BlockId) -> RefCount: + raise NotImplementedError + + def decr(self, block_id: BlockId) -> RefCount: + raise NotImplementedError + + def get(self, block_id: BlockId) -> RefCount: + raise NotImplementedError + + +class RefCounter(RefCounterProtocol): + """A class for managing reference counts for a set of block indices. + + The RefCounter class maintains a dictionary that maps block indices to their + corresponding reference counts. It provides methods to increment, decrement, + and retrieve the reference count for a given block index. + + Args: + all_block_indices (Iterable[BlockId]): An iterable of block indices + to initialize the reference counter with. + """ + + def __init__(self, all_block_indices: Iterable[BlockId]): + deduped = set(all_block_indices) + self._refcounts: Dict[BlockId, RefCount] = { + index: 0 + for index in deduped + } + + def incr(self, block_id: BlockId) -> RefCount: + assert block_id in self._refcounts + pre_incr_refcount = self._refcounts[block_id] + + assert pre_incr_refcount >= 0 + + post_incr_refcount = pre_incr_refcount + 1 + self._refcounts[block_id] = post_incr_refcount + return post_incr_refcount + + def decr(self, block_id: BlockId) -> RefCount: + assert block_id in self._refcounts + refcount = self._refcounts[block_id] + + assert refcount > 0 + refcount -= 1 + + self._refcounts[block_id] = refcount + + return refcount + + def get(self, block_id: BlockId) -> RefCount: + assert block_id in self._refcounts + return self._refcounts[block_id] + + def as_readonly(self) -> "ReadOnlyRefCounter": + return ReadOnlyRefCounter(self) + + +class ReadOnlyRefCounter(RefCounterProtocol): + """A read-only view of the RefCounter class. + + The ReadOnlyRefCounter class provides a read-only interface to access the + reference counts maintained by a RefCounter instance. 
It does not allow + modifications to the reference counts. + + Args: + refcounter (RefCounter): The RefCounter instance to create a read-only + view for. + """ + + def __init__(self, refcounter: RefCounter): + self._refcounter = refcounter + + def incr(self, block_id: BlockId) -> RefCount: + raise ValueError("Incr not allowed") + + def decr(self, block_id: BlockId) -> RefCount: + raise ValueError("Decr not allowed") + + def get(self, block_id: BlockId) -> RefCount: + return self._refcounter.get(block_id) + + +class CopyOnWriteTracker: + """A class for tracking and managing copy-on-write operations for blocks. + + The CopyOnWriteTracker class maintains a mapping of source block indices to + their corresponding copy-on-write destination block indices. It works in + conjunction with a RefCounter. + + Args: + refcounter (RefCounter): The reference counter used to track block + reference counts. + """ + + def __init__(self, refcounter: RefCounterProtocol): + self._copy_on_writes: List[Tuple[BlockId, BlockId]] = [] + self._refcounter = refcounter + + def is_appendable(self, block: Block) -> bool: + """Checks if the block is shared or not. If shared, then it cannot + be appended and needs to be duplicated via copy-on-write + """ + block_id = block.block_id + if block_id is None: + return True + + refcount = self._refcounter.get(block_id) + return refcount <= 1 + + def record_cow(self, src_block_id: Optional[BlockId], + trg_block_id: Optional[BlockId]) -> None: + """Records a copy-on-write operation from source to target block id + Args: + src_block_id (BlockId): The source block id from which to copy + the data + trg_block_id (BlockId): The target block id to which the data + is copied + """ + assert src_block_id is not None + assert trg_block_id is not None + self._copy_on_writes.append((src_block_id, trg_block_id)) + + def clear_cows(self) -> List[Tuple[BlockId, BlockId]]: + """Clears the copy-on-write tracking information and returns the current + state. + + This method returns a list mapping source block indices to + destination block indices for the current copy-on-write operations. + It then clears the internal tracking information. + + Returns: + List[Tuple[BlockId, BlockId]]: A list mapping source + block indices to destination block indices for the + current copy-on-write operations. + """ + cows = self._copy_on_writes + self._copy_on_writes = [] + return cows + + +class BlockPool: + """Used to pre-allocate block objects, in order to avoid excessive python + object allocations/deallocations. + The pool starts from "pool_size" objects and will increase to more objects + if necessary + + Note that multiple block objects may point to the same physical block id, + which is why this pool is needed, so that it will be easier to support + prefix caching and more complicated sharing of physical blocks. 
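+
+    Illustrative usage (normally the pool is owned by a block allocator
+    rather than driven directly; `allocator` here is assumed to exist):
+        pool = BlockPool(block_size=16, create_block=NaiveBlock,
+                         allocator=allocator, pool_size=8)
+        block = pool.init_block(prev_block=None, token_ids=[],
+                                block_size=16, physical_block_id=3)
+        pool.free_block(block)  # the object is recycled, not deallocated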
+ """ + + def __init__(self, block_size: int, create_block: Block.Factory, + allocator: BlockAllocator, pool_size: int): + self._block_size = block_size + self._create_block = create_block + self._allocator = allocator + self._pool_size = pool_size + assert self._pool_size >= 0 + + self._free_ids: Deque[int] = deque(range(self._pool_size)) + self._pool = [] + for i in range(self._pool_size): + self._pool.append( + self._create_block(prev_block=None, + token_ids=[], + block_size=self._block_size, + allocator=self._allocator, + block_id=None, + extra_hash=None)) + + def increase_pool(self): + """Doubles the internal pool size + """ + cur_pool_size = self._pool_size + new_pool_size = cur_pool_size * 2 + self._pool_size = new_pool_size + + self._free_ids += deque(range(cur_pool_size, new_pool_size)) + + for i in range(cur_pool_size, new_pool_size): + self._pool.append( + self._create_block(prev_block=None, + token_ids=[], + block_size=self._block_size, + allocator=self._allocator, + block_id=None, + extra_hash=None)) + + def init_block(self, + prev_block: Optional[Block], + token_ids: List[int], + block_size: int, + physical_block_id: Optional[int], + extra_hash: Optional[int] = None) -> Block: + if len(self._free_ids) == 0: + self.increase_pool() + assert len(self._free_ids) > 0 + + pool_id = self._free_ids.popleft() + + block = self._pool[pool_id] + block.__init__( # type: ignore[misc] + prev_block=prev_block, + token_ids=token_ids, + block_size=block_size, + allocator=block._allocator, # type: ignore[attr-defined] + block_id=physical_block_id, + extra_hash=extra_hash) + block.pool_id = pool_id # type: ignore[attr-defined] + return block + + def free_block(self, block: Block) -> None: + self._free_ids.appendleft(block.pool_id) # type: ignore[attr-defined] + + +class BlockList: + """This class is an optimization to allow fast-access to physical + block ids. 
It maintains a block id list that is updated with the + block list and this avoids the need to reconstruct the block id + list on every iteration of the block manager + """ + + def __init__(self, blocks: List[Block]): + self._blocks: List[Block] = [] + self._block_ids: List[int] = [] + + self.update(blocks) + + def _add_block_id(self, block_id: Optional[BlockId]) -> None: + assert block_id is not None + self._block_ids.append(block_id) + + def _update_block_id(self, block_index: int, + new_block_id: Optional[BlockId]) -> None: + assert new_block_id is not None + self._block_ids[block_index] = new_block_id + + def update(self, blocks: List[Block]): + self._blocks = blocks + + # Cache block ids for fast query + self._block_ids = [] + for block in self._blocks: + self._add_block_id(block.block_id) + + def append_token_ids(self, block_index: int, token_ids: List[int]) -> None: + block = self._blocks[block_index] + prev_block_id = block.block_id + + block.append_token_ids(token_ids) + + # CoW or promotion may update the internal block_id + if prev_block_id != block.block_id: + self._update_block_id(block_index, block.block_id) + + def append(self, new_block: Block): + self._blocks.append(new_block) + self._add_block_id(new_block.block_id) + + def __len__(self) -> int: + return len(self._blocks) + + def __getitem__(self, block_index: int) -> Block: + return self._blocks[block_index] + + def __setitem__(self, block_index: int, new_block: Block) -> None: + self._blocks[block_index] = new_block + self._update_block_id(block_index, new_block.block_id) + + def reset(self): + self._blocks = [] + self._block_ids = [] + + def list(self) -> List[Block]: + return self._blocks + + def ids(self) -> List[int]: + return self._block_ids + + +@dataclass +class CacheMetricData: + """A utility dataclass to maintain cache metric. + To avoid overflow, we maintain the hit rate in block granularity, so that + we can maintain a single hit rate for n_completed_block x block_size, + and calculate the real time hit rate by the following: + BS = The number of queries per block. + nB = The number of completed blocks. + HR = hit rate of (nB x BS) queries. + Q = current number of queries (< BS). + H = current number of hits (< BS). + hit rate = ((HR x nB) + (H / Q) x (Q / BS)) / (nB + Q / BS) + """ + num_completed_blocks: int = 0 + completed_block_cache_hit_rate: float = 0.0 + num_incompleted_block_queries: int = 0 + num_incompleted_block_hit: int = 0 + block_size: int = 1000 + + def query(self, hit: bool): + self.num_incompleted_block_queries += 1 + self.num_incompleted_block_hit += 1 if hit else 0 + + # When a block is completed, update the cache hit rate + # and reset the incomplete numbers. 
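+        # Worked example, with the default block_size of 1000: if the
+        # completed-block hit rate is 0.50 over 2 completed blocks and the
+        # 1000 queries of the current block hit 700 times, the completed
+        # rate becomes (0.50 * 2 + 0.70) / 3 = 0.5666... over 3 blocks.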
+ if self.num_incompleted_block_queries == self.block_size: + hit_rate = (self.num_incompleted_block_hit / + self.num_incompleted_block_queries) + self.completed_block_cache_hit_rate = ( + self.completed_block_cache_hit_rate * self.num_completed_blocks + + hit_rate) / (self.num_completed_blocks + 1) + self.num_incompleted_block_queries = 0 + self.num_incompleted_block_hit = 0 + self.num_completed_blocks += 1 + + def get_hit_rate(self): + incomplete_ratio = self.num_incompleted_block_queries / self.block_size + total_blocks = self.num_completed_blocks + incomplete_ratio + if total_blocks == 0: + return 0.0 + + completed_block_hit, incompleted_block_hit = 0.0, 0.0 + if self.num_completed_blocks > 0: + completed_block_hit = (self.completed_block_cache_hit_rate * + self.num_completed_blocks) + if self.num_incompleted_block_queries > 0: + incompleted_hit_rate = (self.num_incompleted_block_hit / + self.num_incompleted_block_queries) + incompleted_block_hit = (incompleted_hit_rate * incomplete_ratio) + return (completed_block_hit + incompleted_block_hit) / total_blocks + + +def get_all_blocks_recursively(last_block: Block) -> List[Block]: + """Retrieves all the blocks in a sequence starting from the last block. + + This function recursively traverses the sequence of blocks in reverse order, + starting from the given last block, and returns a list of all the blocks in + the sequence. + + Args: + last_block (Block): The last block in the sequence. + + Returns: + List[Block]: A list of all the blocks in the sequence, in the order they + appear. + """ + + def recurse(block: Block, lst: List[Block]) -> None: + if block.prev_block is not None: + recurse(block.prev_block, lst) + lst.append(block) + + all_blocks: List[Block] = [] + recurse(last_block, all_blocks) + return all_blocks diff --git a/core/block/cpu_gpu_block_allocator.py b/core/block/cpu_gpu_block_allocator.py new file mode 100644 index 0000000..ea490c3 --- /dev/null +++ b/core/block/cpu_gpu_block_allocator.py @@ -0,0 +1,441 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Dict, FrozenSet, List, Optional, Tuple + +from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, + DeviceAwareBlockAllocator) +from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator +from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator +from vllm.platforms import current_platform +from vllm.utils import Device + + +class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): + """A block allocator that can allocate blocks on both CPU and GPU memory. + + This class implements the `DeviceAwareBlockAllocator` interface and provides + functionality for allocating and managing blocks of memory on both CPU and + GPU devices. + + The `CpuGpuBlockAllocator` maintains separate memory pools for CPU and GPU + blocks, and allows for allocation, deallocation, forking, and swapping of + blocks across these memory pools. + """ + + @staticmethod + def create( + allocator_type: str, + num_gpu_blocks: int, + num_cpu_blocks: int, + block_size: int, + ) -> DeviceAwareBlockAllocator: + """Creates a CpuGpuBlockAllocator instance with the specified + configuration. + + This static method creates and returns a CpuGpuBlockAllocator instance + based on the provided parameters. It initializes the CPU and GPU block + allocators with the specified number of blocks, block size, and + allocator type. 
+ + Args: + allocator_type (str): The type of block allocator to use for CPU + and GPU blocks. Currently supported values are "naive" and + "prefix_caching". + num_gpu_blocks (int): The number of blocks to allocate for GPU + memory. + num_cpu_blocks (int): The number of blocks to allocate for CPU + memory. + block_size (int): The size of each block in number of tokens. + + Returns: + DeviceAwareBlockAllocator: A CpuGpuBlockAllocator instance with the + specified configuration. + + Notes: + - The block IDs are assigned contiguously, with GPU block IDs coming + before CPU block IDs. + """ + # For HPU, block id 0 is used only for padding + reserved_blocks = 1 if current_platform.is_hpu() else 0 + block_ids = list( + range(reserved_blocks, num_gpu_blocks + num_cpu_blocks)) + num_gpu_blocks -= reserved_blocks + gpu_block_ids = block_ids[:num_gpu_blocks] + cpu_block_ids = block_ids[num_gpu_blocks:] + + if allocator_type == "naive": + gpu_allocator: BlockAllocator = NaiveBlockAllocator( + create_block=NaiveBlock, # type: ignore + num_blocks=num_gpu_blocks, + block_size=block_size, + block_ids=gpu_block_ids, + ) + + cpu_allocator: BlockAllocator = NaiveBlockAllocator( + create_block=NaiveBlock, # type: ignore + num_blocks=num_cpu_blocks, + block_size=block_size, + block_ids=cpu_block_ids, + ) + elif allocator_type == "prefix_caching": + gpu_allocator = PrefixCachingBlockAllocator( + num_blocks=num_gpu_blocks, + block_size=block_size, + block_ids=gpu_block_ids, + ) + + cpu_allocator = PrefixCachingBlockAllocator( + num_blocks=num_cpu_blocks, + block_size=block_size, + block_ids=cpu_block_ids, + ) + else: + raise ValueError(f"Unknown allocator type {allocator_type=}") + + return CpuGpuBlockAllocator( + cpu_block_allocator=cpu_allocator, + gpu_block_allocator=gpu_allocator, + ) + + def __init__(self, cpu_block_allocator: BlockAllocator, + gpu_block_allocator: BlockAllocator): + assert not ( + cpu_block_allocator.all_block_ids + & gpu_block_allocator.all_block_ids + ), "cpu and gpu block allocators can't have intersection of block ids" + + self._allocators = { + Device.CPU: cpu_block_allocator, + Device.GPU: gpu_block_allocator, + } + + self._swap_mapping: Dict[int, int] = {} + self._null_block: Optional[Block] = None + + self._block_ids_to_allocator: Dict[int, BlockAllocator] = {} + for _, allocator in self._allocators.items(): + for block_id in allocator.all_block_ids: + self._block_ids_to_allocator[block_id] = allocator + + def allocate_or_get_null_block(self) -> Block: + if self._null_block is None: + self._null_block = NullBlock( + self.allocate_mutable_block(None, Device.GPU)) + return self._null_block + + def allocate_mutable_block(self, + prev_block: Optional[Block], + device: Device, + extra_hash: Optional[int] = None) -> Block: + """Allocates a new mutable block on the specified device. + + Args: + prev_block (Optional[Block]): The previous block to in the sequence. + Used for prefix hashing. + device (Device): The device on which to allocate the new block. + extra_hash (Optional[int]): The hash value of additional + factors, such as adapters, that influence the block hash + in the prefix caching block. + + Returns: + Block: The newly allocated mutable block. 
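+
+        Example (illustrative; the small block counts are chosen for
+        brevity):
+            allocator = CpuGpuBlockAllocator.create(
+                allocator_type="naive", num_gpu_blocks=16,
+                num_cpu_blocks=4, block_size=16)
+            block = allocator.allocate_mutable_block(prev_block=None,
+                                                     device=Device.GPU)
+            block.append_token_ids([1, 2, 3])
+            allocator.free(block)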
+ """ + return self._allocators[device].allocate_mutable_block( + prev_block, extra_hash=extra_hash) + + def allocate_immutable_blocks( + self, + prev_block: Optional[Block], + block_token_ids: List[List[int]], + device: Device, + extra_hash: Optional[int] = None) -> List[Block]: + """Allocates a new group of immutable blocks with the provided block + token IDs on the specified device. + + Args: + prev_block (Optional[Block]): The previous block in the sequence. + Used for prefix hashing. + block_token_ids (List[int]): The list of block token IDs to be + stored in the new blocks. + device (Device): The device on which to allocate the new block. + extra_hash (Optional[int]): The hash value of additional + factors, such as adapters, that influence the block hash + in the prefix caching block. + + Returns: + List[Block]: The newly allocated list of immutable blocks + containing the provided block token IDs. + """ + return self._allocators[device].allocate_immutable_blocks( + prev_block, block_token_ids, extra_hash=extra_hash) + + def allocate_immutable_block(self, + prev_block: Optional[Block], + token_ids: List[int], + device: Device, + extra_hash: Optional[int] = None) -> Block: + """Allocates a new immutable block with the provided token IDs on the + specified device. + + Args: + prev_block (Optional[Block]): The previous block in the sequence. + Used for prefix hashing. + token_ids (List[int]): The list of token IDs to be stored in the new + block. + device (Device): The device on which to allocate the new block. + extra_hash (Optional[int]): The hash value of additional + factors, such as adapters, that influence the block hash + in the prefix caching block. + + Returns: + Block: The newly allocated immutable block containing the provided + token IDs. + """ + return self._allocators[device].allocate_immutable_block( + prev_block, token_ids, extra_hash=extra_hash) + + def free(self, block: Block) -> None: + """Frees the memory occupied by the given block. + + Args: + block (Block): The block to be freed. + """ + # Null block should never be freed + if isinstance(block, NullBlock): + return + block_id = block.block_id + assert block_id is not None + allocator = self._block_ids_to_allocator[block_id] + allocator.free(block) + + def fork(self, last_block: Block) -> List[Block]: + """Creates a new sequence of blocks that shares the same underlying + memory as the original sequence. + + Args: + last_block (Block): The last block in the original sequence. + + Returns: + List[Block]: A new list of blocks that shares the same memory as the + original sequence. + """ + # do not attempt to fork the null block + assert not isinstance(last_block, NullBlock) + block_id = last_block.block_id + assert block_id is not None + allocator = self._block_ids_to_allocator[block_id] + return allocator.fork(last_block) + + def get_num_free_blocks(self, device: Device) -> int: + """Returns the number of free blocks available on the specified device. + + Args: + device (Device): The device for which to query the number of free + blocks. AssertionError is raised if None is passed. + + Returns: + int: The number of free blocks available on the specified device. + """ + return self._allocators[device].get_num_free_blocks() + + def get_num_total_blocks(self, device: Device) -> int: + return self._allocators[device].get_num_total_blocks() + + def get_physical_block_id(self, device: Device, absolute_id: int) -> int: + """Returns the zero-offset block id on certain device given the + absolute block id. 
+ + Args: + device (Device): The device for which to query relative block id. + absolute_id (int): The absolute block id for the block in + whole allocator. + + Returns: + int: The zero-offset block id on certain device. + """ + return self._allocators[device].get_physical_block_id(absolute_id) + + def swap(self, blocks: List[Block], src_device: Device, + dst_device: Device) -> Dict[int, int]: + """Execute the swap for the given blocks from source_device + on to dest_device, save the current swap mapping and append + them to the accumulated `self._swap_mapping` for each + scheduling move. + + Args: + blocks: List of blocks to be swapped. + src_device (Device): Device to swap the 'blocks' from. + dst_device (Device): Device to swap the 'blocks' to. + + Returns: + Dict[int, int]: Swap mapping from source_device + on to dest_device. + """ + src_block_ids = [block.block_id for block in blocks] + self._allocators[src_device].swap_out(blocks) + self._allocators[dst_device].swap_in(blocks) + dst_block_ids = [block.block_id for block in blocks] + + current_swap_mapping: Dict[int, int] = {} + for src_block_id, dst_block_id in zip(src_block_ids, dst_block_ids): + if src_block_id is not None and dst_block_id is not None: + self._swap_mapping[src_block_id] = dst_block_id + current_swap_mapping[src_block_id] = dst_block_id + return current_swap_mapping + + def get_num_full_blocks_touched(self, blocks: List[Block], + device: Device) -> int: + """Returns the number of full blocks that will be touched by + swapping in/out the given blocks on to the 'device'. + + Args: + blocks: List of blocks to be swapped. + device (Device): Device to swap the 'blocks' on. + + Returns: + int: the number of full blocks that will be touched by + swapping in/out the given blocks on to the 'device'. + Non full blocks are ignored when deciding the number + of blocks to touch. + """ + return self._allocators[device].get_num_full_blocks_touched(blocks) + + def clear_copy_on_writes(self) -> List[Tuple[int, int]]: + """Clears the copy-on-write (CoW) state and returns the mapping of + source to destination block IDs. + + Returns: + List[Tuple[int, int]]: A list mapping source block IDs to + destination block IDs. + """ + # CoW only supported on GPU + device = Device.GPU + return self._allocators[device].clear_copy_on_writes() + + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + """Mark blocks as accessed, only use for prefix caching.""" + # Prefix caching only supported on GPU. + device = Device.GPU + return self._allocators[device].mark_blocks_as_accessed(block_ids, now) + + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: + """Mark blocks as accessed, only use for prefix caching.""" + # Prefix caching only supported on GPU. + device = Device.GPU + return self._allocators[device].mark_blocks_as_computed(block_ids) + + def get_common_computed_block_ids( + self, computed_seq_block_ids: List[List[int]]) -> List[int]: + # Prefix caching only supported on GPU. + device = Device.GPU + return self._allocators[device].get_common_computed_block_ids( + computed_seq_block_ids) + + @property + def all_block_ids(self) -> FrozenSet[int]: + return frozenset(self._block_ids_to_allocator.keys()) + + def get_prefix_cache_hit_rate(self, device: Device) -> float: + """Prefix cache hit rate. 
+        assert device in self._allocators
+        return self._allocators[device].get_prefix_cache_hit_rate()
+
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
+        """Reset prefix cache for specified or all devices."""
+        if device:
+            return self._allocators[device].reset_prefix_cache()
+        success = True
+        for allocator in self._allocators.values():
+            success = success and allocator.reset_prefix_cache()
+        return success
+
+    def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
+        """Returns and clears the mapping of source to destination block IDs.
+        Will be called after every swapping operation for now, and after every
+        schedule once BlockManagerV2 becomes the default. Currently not useful.
+
+        Returns:
+            List[Tuple[int, int]]: A mapping of source to destination block IDs.
+        """
+        mapping = self._swap_mapping.copy()
+        self._swap_mapping.clear()
+        return list(mapping.items())
+
+    def find_cached_blocks_prefix(
+        self,
+        block_hashes: List[int],
+        device: Device = Device.GPU,
+    ) -> List[int]:
+        return self._allocators[device].find_cached_blocks_prefix(block_hashes)
+
+
+class NullBlock(Block):
+    """
+    Null blocks are used as placeholders for KV cache blocks that have
+    been dropped due to sliding window.
+    This implementation just wraps an ordinary block and prevents it from
+    being modified. It also allows for testing if a block is NullBlock
+    via isinstance().
+    """
+
+    def __init__(self, proxy: Block):
+        super().__init__()
+        self._proxy = proxy
+
+    def append_token_ids(self, token_ids: List[BlockId]):
+        raise ValueError("null block should not be modified")
+
+    @property
+    def block_id(self):
+        return self._proxy.block_id
+
+    @block_id.setter
+    def block_id(self, value: Optional[BlockId]):
+        raise ValueError("null block should not be modified")
+
+    @property
+    def token_ids(self) -> List[BlockId]:
+        return self._proxy.token_ids
+
+    @property
+    def num_tokens_total(self) -> int:
+        raise NotImplementedError(
+            "num_tokens_total is not used for null block")
+
+    @property
+    def num_empty_slots(self) -> BlockId:
+        return self._proxy.num_empty_slots
+
+    @property
+    def is_full(self):
+        return self._proxy.is_full
+
+    @property
+    def prev_block(self):
+        return self._proxy.prev_block
+
+    @property
+    def extra_hash(self):
+        return None
+
+    @property
+    def computed(self):
+        return self._proxy.computed
+
+    @computed.setter
+    def computed(self, value):
+        self._proxy.computed = value
+
+    @property
+    def last_accessed(self) -> float:
+        return self._proxy.last_accessed
+
+    @last_accessed.setter
+    def last_accessed(self, last_accessed_ts: float):
+        self._proxy.last_accessed = last_accessed_ts
+
+    @property
+    def content_hash(self):
+        return self._proxy.content_hash
diff --git a/core/block/interfaces.py b/core/block/interfaces.py
new file mode 100644
index 0000000..1a05881
--- /dev/null
+++ b/core/block/interfaces.py
@@ -0,0 +1,319 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from abc import ABC, abstractmethod
+from typing import Dict, FrozenSet, List, Optional, Protocol, Tuple
+
+from vllm.utils import Device
+
+BlockId = int
+
+
+class Block(ABC):
+
+    @abstractmethod
+    def append_token_ids(self, token_ids: List[int]) -> None:
+        pass
+
+    @property
+    @abstractmethod
+    def block_id(self) -> Optional[int]:
+        pass
+
+    @block_id.setter
+    @abstractmethod
+    def block_id(self, value: Optional[int]) -> None:
+        """NOTE: Do not use this API outside Block."""
+        self._block_id = value
+
+    @property
+
@abstractmethod + def token_ids(self) -> List[int]: + pass + + @property + @abstractmethod + def num_tokens_total(self) -> int: + """The number of tokens till the current block (inclusive) + """ + pass + + @property + @abstractmethod + def num_empty_slots(self) -> int: + pass + + @property + @abstractmethod + def is_full(self) -> bool: + pass + + @property + @abstractmethod + def prev_block(self) -> Optional["Block"]: + pass + + @property + @abstractmethod + def extra_hash(self) -> Optional[int]: + return None + + @property + @abstractmethod + def computed(self) -> bool: + raise NotImplementedError + + @computed.setter + @abstractmethod + def computed(self, value) -> bool: + """Should be only used by PrefixCacingAllocator""" + raise NotImplementedError + + @property + @abstractmethod + def last_accessed(self) -> float: + raise NotImplementedError + + @last_accessed.setter + @abstractmethod + def last_accessed(self, last_accessed_ts: float): + raise NotImplementedError + + class Factory(Protocol): + + @abstractmethod + def __call__( + self, + prev_block: Optional["Block"], + token_ids: List[int], + block_size: int, + allocator: "BlockAllocator", + block_id: Optional[int] = None, + computed: bool = False, + extra_hash: Optional[int] = None, + ) -> "Block": + pass + + @property + @abstractmethod + def content_hash(self) -> Optional[int]: + """Return the content-based hash of the current block, or None if it is + not yet defined or not supported. + + For the content-based hash to be defined, the current block must be + full. + """ + return None + + +class BlockAllocator(ABC): + + @abstractmethod + def allocate_mutable_block(self, prev_block: Optional[Block], + extra_hash: Optional[int]) -> Block: + pass + + @abstractmethod + def allocate_immutable_block(self, prev_block: Optional[Block], + token_ids: List[int], + extra_hash: Optional[int]) -> Block: + pass + + @abstractmethod + def allocate_immutable_blocks(self, prev_block: Optional[Block], + block_token_ids: List[List[int]], + extra_hash: Optional[int]) -> List[Block]: + pass + + @abstractmethod + def free(self, block: Block) -> None: + pass + + @abstractmethod + def fork(self, last_block: Block) -> List[Block]: + pass + + @abstractmethod + def get_num_total_blocks(self) -> int: + pass + + @abstractmethod + def get_num_free_blocks(self) -> int: + pass + + @abstractmethod + def get_physical_block_id(self, absolute_id: int) -> int: + pass + + @abstractmethod + def swap_out(self, blocks: List[Block]) -> None: + pass + + @abstractmethod + def swap_in(self, blocks: List[Block]) -> None: + pass + + @property + @abstractmethod + def all_block_ids(self) -> FrozenSet[int]: + pass + + @abstractmethod + def clear_copy_on_writes(self) -> List[Tuple[int, int]]: + pass + + @abstractmethod + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + pass + + @abstractmethod + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: + pass + + @abstractmethod + def get_common_computed_block_ids( + self, computed_seq_block_ids: List[List[int]]) -> List[int]: + pass + + @abstractmethod + def cow_block_if_not_appendable(self, block: Block) -> BlockId: + """NOTE: This should not be used besides Block""" + pass + + @abstractmethod + def promote_to_immutable_block(self, block: Block) -> BlockId: + """NOTE: This should not be used besides Block""" + pass + + @abstractmethod + def get_num_full_blocks_touched(self, blocks: List[Block]) -> int: + pass + + @abstractmethod + def get_prefix_cache_hit_rate(self) -> float: + """Prefix cache 
hit rate. -1 means not supported or disabled.""" + pass + + @abstractmethod + def reset_prefix_cache(self) -> bool: + """Reset prefix cache.""" + pass + + class NoFreeBlocksError(ValueError): + pass + + @abstractmethod + def find_cached_blocks_prefix( + self, + block_hashes: List[int], + ) -> List[int]: + pass + + +class DeviceAwareBlockAllocator(ABC): + + @abstractmethod + def allocate_mutable_block(self, + prev_block: Optional[Block], + device: Device, + extra_hash: Optional[int] = None) -> Block: + pass + + @abstractmethod + def allocate_immutable_block(self, + prev_block: Optional[Block], + token_ids: List[int], + device: Device, + extra_hash: Optional[int] = None) -> Block: + pass + + @abstractmethod + def allocate_immutable_blocks( + self, + prev_block: Optional[Block], + block_token_ids: List[List[int]], + device: Device, + extra_hash: Optional[int] = None, + ) -> List[Block]: + pass + + @abstractmethod + def get_num_free_blocks(self, device: Device) -> int: + pass + + @abstractmethod + def get_num_total_blocks(self, device: Device) -> int: + pass + + @abstractmethod + def free(self, block: Block) -> None: + pass + + @abstractmethod + def fork(self, last_block: Block) -> List[Block]: + pass + + @property + @abstractmethod + def all_block_ids(self) -> FrozenSet[int]: + pass + + @abstractmethod + def clear_copy_on_writes(self) -> List[Tuple[int, int]]: + pass + + @abstractmethod + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + pass + + @abstractmethod + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: + pass + + @abstractmethod + def get_common_computed_block_ids( + self, computed_seq_block_ids: List[List[int]]) -> List[int]: + pass + + @abstractmethod + def get_num_full_blocks_touched(self, blocks: List[Block], + device: Device) -> int: + pass + + @abstractmethod + def swap(self, blocks: List[Block], src_device: Device, + dst_device: Device) -> Dict[int, int]: + pass + + @abstractmethod + def get_physical_block_id(self, device: Device, absolute_id: int) -> int: + pass + + @abstractmethod + def allocate_or_get_null_block(self) -> Block: + """ + Null blocks are used as a placeholders for KV cache blocks that have + been dropped due to sliding window. + There is at most one null block per allocator. + """ + pass + + @abstractmethod + def get_prefix_cache_hit_rate(self, device: Device) -> float: + """Prefix cache hit rate. -1 means not supported or disabled.""" + pass + + @abstractmethod + def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: + """Reset prefix cache.""" + pass + + @abstractmethod + def find_cached_blocks_prefix( + self, + block_hashes: List[int], + device: Device = Device.GPU, + ) -> List[int]: + pass diff --git a/core/block/naive_block.py b/core/block/naive_block.py new file mode 100644 index 0000000..dae6ead --- /dev/null +++ b/core/block/naive_block.py @@ -0,0 +1,466 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections import deque +from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple, Union + +from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter, + get_all_blocks_recursively) +from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device + +Refcount = int + + +class NaiveBlockAllocator(BlockAllocator): + """A simple block allocator that manages blocks of memory without prefix + caching. 
+ + Args: + create_block (Block.Factory): A factory function for creating new + blocks. This is used when a NaiveBlockAllocator is composed within + a prefix caching allocator -- the naive block allocator must + construct prefix caching blocks (but shouldn't know anything else + about them). + num_blocks (int): The total number of blocks to manage. + block_size (int): The size of each block in tokens. + block_ids (Optional[Iterable[int]], optional): An optional iterable of + block IDs. If not provided, block IDs will be assigned sequentially + from 0 to num_blocks - 1. + """ + + def __init__( + self, + create_block: Block.Factory, + num_blocks: int, + block_size: int, + block_ids: Optional[Iterable[int]] = None, + block_pool: Optional[BlockPool] = None, + ): + if block_ids is None: + block_ids = range(num_blocks) + + self._free_block_indices: Deque[BlockId] = deque(block_ids) + self._all_block_indices = frozenset(block_ids) + assert len(self._all_block_indices) == num_blocks + + self._refcounter = RefCounter( + all_block_indices=self._free_block_indices) + self._block_size = block_size + + self._cow_tracker = CopyOnWriteTracker( + refcounter=self._refcounter.as_readonly()) + + if block_pool is None: + extra_factor = 4 + # Pre-allocate "num_blocks * extra_factor" block objects. + # The "* extra_factor" is a buffer to allow more block objects + # than physical blocks + self._block_pool = BlockPool(self._block_size, create_block, self, + num_blocks * extra_factor) + else: + # In this case, the block pool is provided by the caller, + # which means that there is most likely a need to share + # a block pool between allocators + self._block_pool = block_pool + + def allocate_immutable_block(self, + prev_block: Optional[Block], + token_ids: List[int], + extra_hash: Optional[int] = None, + device: Optional[Device] = None) -> Block: + """Allocates a new immutable block with the given token IDs, linked to + the previous block. + + Args: + prev_block (Optional[Block]): The previous block in the sequence. If + None, then the block to be allocated is the first block in the + sequence. + token_ids (List[int]): The token IDs to be stored in the new block. + + Returns: + Block: The newly allocated immutable block. + """ + assert device is None + block = self.allocate_mutable_block(prev_block=prev_block) + block.append_token_ids(token_ids) + return block + + def allocate_immutable_blocks( + self, + prev_block: Optional[Block], + block_token_ids: List[List[int]], + extra_hash: Optional[int] = None, + device: Optional[Device] = None) -> List[Block]: + assert device is None + num_blocks = len(block_token_ids) + + block_ids = [] + for i in range(num_blocks): + block_ids.append(self._allocate_block_id()) + + blocks = [] + for i in range(num_blocks): + prev_block = self._block_pool.init_block( + prev_block=prev_block, + token_ids=block_token_ids[i], + block_size=self._block_size, + physical_block_id=block_ids[i]) + blocks.append(prev_block) + + return blocks + + def allocate_mutable_block(self, + prev_block: Optional[Block], + extra_hash: Optional[int] = None, + device: Optional[Device] = None) -> Block: + """Allocates a new mutable block, linked to the previous block. + + Args: + prev_block (Optional[Block]): The previous block in the sequence. If + None, then the block to be allocated is the first block in the + sequence. + + Returns: + Block: The newly allocated mutable block. 
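+
+        Example (illustrative sketch; assumes ``allocator`` is a
+        NaiveBlockAllocator instance that still has at least one free block):
+
+            block = allocator.allocate_mutable_block(prev_block=None)
+            block.append_token_ids([1, 2, 3])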
+ """ + assert device is None + block_id = self._allocate_block_id() + block = self._block_pool.init_block(prev_block=prev_block, + token_ids=[], + block_size=self._block_size, + physical_block_id=block_id) + return block + + def _allocate_block_id(self) -> BlockId: + if not self._free_block_indices: + raise BlockAllocator.NoFreeBlocksError() + + block_id = self._free_block_indices.popleft() + self._refcounter.incr(block_id) + return block_id + + def _free_block_id(self, block: Union[Block, BlockId]) -> None: + if isinstance(block, Block): + block_id = block.block_id + block.block_id = None + else: + block_id = block + assert block_id is not None + + refcount = self._refcounter.decr(block_id) + if refcount == 0: + self._free_block_indices.appendleft(block_id) + + def free(self, block: Block, keep_block_object: bool = False) -> None: + # Release the physical block id + self._free_block_id(block) + + # Release the block object + if not keep_block_object: + self._block_pool.free_block(block) + + def free_block_id(self, block_id: BlockId) -> None: + self._free_block_id(block_id) + + def fork(self, last_block: Block) -> List[Block]: + """Creates a new sequence of blocks that shares the same underlying + memory as the original sequence. + + Args: + last_block (Block): The last block in the original sequence. + + Returns: + List[Block]: The new sequence of blocks that shares the same memory + as the original sequence. + """ + source_blocks = get_all_blocks_recursively(last_block) + + forked_blocks: List[Block] = [] + prev_block = None + for block in source_blocks: + + # Increment refcount for each block. + assert block.block_id is not None + refcount = self._refcounter.incr(block.block_id) + assert refcount != 1, "can't fork free'd block" + + forked_block = self._block_pool.init_block( + prev_block=prev_block, + token_ids=block.token_ids, + block_size=self._block_size, + physical_block_id=block.block_id) + + forked_blocks.append(forked_block) + prev_block = forked_blocks[-1] + + return forked_blocks + + def get_num_free_blocks(self) -> int: + return len(self._free_block_indices) + + def get_num_total_blocks(self) -> int: + return len(self._all_block_indices) + + def get_physical_block_id(self, absolute_id: int) -> int: + """Returns the zero-offset block id on certain block allocator + given the absolute block id. + + Args: + absolute_id (int): The absolute block id for the block + in whole allocator. + + Returns: + int: The zero-offset block id on certain device. + """ + return sorted(self._all_block_indices).index(absolute_id) + + @property + def refcounter(self): + return self._refcounter + + @property + def all_block_ids(self) -> FrozenSet[int]: + return self._all_block_indices + + def cow_block_if_not_appendable(self, block: Block) -> BlockId: + """Performs a copy-on-write operation on the given block if it is not + appendable. + + Args: + block (Block): The block to check for copy-on-write. + + Returns: + BlockId: The block index of the new block if a copy-on-write + operation was performed, or the original block index if + no copy-on-write was necessary. + """ + src_block_id = block.block_id + assert src_block_id is not None + + if self._cow_tracker.is_appendable(block): + return src_block_id + + self._free_block_id(block) + trg_block_id = self._allocate_block_id() + + self._cow_tracker.record_cow(src_block_id, trg_block_id) + + return trg_block_id + + def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]: + """Returns the copy-on-write source->destination mapping and clears it. 
+ + Returns: + List[Tuple[BlockId, BlockId]]: A list mapping source + block indices to destination block indices. + """ + return self._cow_tracker.clear_cows() + + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + """Mark blocks as accessed, used in prefix caching. + + Since the naive allocator does not implement prefix caching, we do + nothing. + """ + pass + + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: + """Mark blocks as computed, used in prefix caching. + + Since the naive allocator does not implement prefix caching, we do + nothing. + """ + pass + + def get_common_computed_block_ids( + self, computed_seq_block_ids: List[List[int]]) -> List[int]: + """Determine blocks that can be skipped in prefill. + + Since the naive allocator does not support prefix caching, always return + an empty list. + """ + return [] + + def promote_to_immutable_block(self, block: Block) -> BlockId: + raise NotImplementedError("There is no promotion for naive blocks") + + def get_num_full_blocks_touched(self, blocks: List[Block]) -> int: + """Returns the number of full blocks that will be touched by + swapping in/out. + + Args: + blocks: List of blocks to be swapped. + Returns: + int: the number of full blocks that will be touched by + swapping in/out the given blocks. Non full blocks are ignored + when deciding the number of blocks to touch. + """ + # NOTE: for naive block, we use set to eliminate common blocks among + # seqs, also we compare the empty slots in the mutable blocks with + # lookahead slots to get the number of unique new block that are + # needed. + old_block_set = set() + for block in blocks: + if block.is_full: + old_block_set.add(block) + return len(old_block_set) + + def swap_out(self, blocks: List[Block]) -> None: + for block in blocks: + self._free_block_id(block) + + def swap_in(self, blocks: List[Block]) -> None: + for block in blocks: + # Here we allocate either immutable or mutable block and then + # extract its block_id. Note that the block object is released + # and the block_id is assigned to "block" to allow reusing the + # existing "block" object + if block.is_full: + tmp_block = self.allocate_immutable_block( + prev_block=block.prev_block, token_ids=block.token_ids) + else: + tmp_block = self.allocate_mutable_block( + prev_block=block.prev_block) + tmp_block.append_token_ids(block.token_ids) + + block_id = tmp_block.block_id + tmp_block.block_id = None + self._block_pool.free_block(tmp_block) + + block.block_id = block_id # Assign block_id + + def get_prefix_cache_hit_rate(self) -> float: + return -1 + + def reset_prefix_cache(self) -> bool: + """No prefix cache for naive block allocator.""" + return True + + def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]: + # Not applicable for naive block allocator. + return [] + + +class NaiveBlock(Block): + """An implementation of the Block class that does not support prefix + caching. + + The NaiveBlock class represents a block of token IDs with a fixed size. It + provides methods for appending token IDs to the block and manages copy-on + -write operations when necessary. + + Args: + prev_block (Block): The previous block in the sequence. + token_ids (List[int]): The initial token IDs to be stored in the block. + block_size (int): The maximum number of token IDs that can be stored in + the block. + allocator (BlockAllocator): The block allocator associated with this + block. + block_id (Optional[int], optional): The physical block index + of this block. 
Defaults to None, which means no allocation has been + made. + _cow_target (Optional[Block], optional): The copy-on-write target block. + If not provided, it defaults to self. + """ + + def __init__(self, + prev_block: Optional[Block], + token_ids: List[int], + block_size: int, + allocator: BlockAllocator, + block_id: Optional[int] = None, + _cow_target: Optional[Block] = None, + extra_hash: Optional[int] = None): + self._token_ids: List[int] = [] + self._block_size = block_size + self._prev_block = prev_block + self._block_id = block_id + self._allocator = allocator + self._cow_target = _cow_target if _cow_target is not None else self + + self._append_token_ids_no_cow(token_ids) + + def append_token_ids(self, token_ids: List[int]) -> None: + """Appends the given token IDs to the block and performs a + copy-on-write if necessary. + + Args: + token_ids (Optional[List[int]]): The token IDs to be appended + to the block. + """ + self._append_token_ids_no_cow(token_ids) + + if self._block_id is not None: + self._block_id = (self._allocator.cow_block_if_not_appendable( + self._cow_target)) + + def _append_token_ids_no_cow(self, token_ids: List[int]) -> None: + """Appends the given token IDs to the block + + Args: + token_ids (List[int]): The token IDs to be appended to the block. + """ + if len(token_ids) == 0: + return + + assert len(token_ids) <= self.num_empty_slots + + self._token_ids.extend(token_ids) + + @property + def computed(self) -> bool: + raise NotImplementedError + + @computed.setter + def computed(self, value) -> None: + raise NotImplementedError + + @property + def last_accessed(self) -> float: + raise NotImplementedError + + @last_accessed.setter + def last_accessed(self, last_accessed_ts: float): + raise NotImplementedError + + @property + def block_id(self) -> Optional[int]: + return self._block_id + + @block_id.setter + def block_id(self, value: Optional[int]) -> None: + self._block_id = value + + @property + def is_full(self) -> bool: + return self.num_empty_slots == 0 + + @property + def num_empty_slots(self) -> int: + return self._block_size - len(self.token_ids) + + @property + def token_ids(self) -> List[int]: + return self._token_ids + + @property + def num_tokens_total(self) -> int: + raise NotImplementedError( + "num_tokens_total is not used for naive block") + + @property + def block_size(self) -> int: + return self._block_size + + @property + def prev_block(self) -> Optional["Block"]: + return self._prev_block + + @property + def extra_hash(self): + return None + + @property + def content_hash(self) -> Optional[int]: + return None diff --git a/core/block/prefix_caching_block.py b/core/block/prefix_caching_block.py new file mode 100644 index 0000000..2913a01 --- /dev/null +++ b/core/block/prefix_caching_block.py @@ -0,0 +1,1135 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Token blocks.""" +import sys +from bisect import bisect_left +from os.path import commonprefix +from typing import (Callable, Dict, FrozenSet, Iterable, List, Optional, Set, + Tuple) + +from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker, + get_all_blocks_recursively) +from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, Device, + DeviceAwareBlockAllocator) +from vllm.core.block.naive_block import (BlockPool, NaiveBlock, + NaiveBlockAllocator) +from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor +from vllm.logger import init_logger +from vllm.sequence import Sequence + 
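+# Illustrative sketch of how block hashes chain together (the literal token
+# ids below are hypothetical; ``hash_block_tokens`` is defined on
+# PrefixCachingBlock further down in this file):
+#
+#     h0 = PrefixCachingBlock.hash_block_tokens(
+#         is_first_block=True, prev_block_hash=None,
+#         cur_block_token_ids=[1, 2])
+#     h1 = PrefixCachingBlock.hash_block_tokens(
+#         is_first_block=False, prev_block_hash=h0,
+#         cur_block_token_ids=[3, 4])
+#
+# Because each hash folds in the previous block's hash, equal hashes identify
+# equal full prefixes (up to hash collisions), which is what allows cached
+# blocks to be shared across sequences.
+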
+PrefixHash = int
+
+# By default, we init our block access time as _DEFAULT_LAST_ACCESSED_TIME
+# so that if we find one block still holds _DEFAULT_LAST_ACCESSED_TIME,
+# then we know this block hasn't been accessed yet.
+_DEFAULT_LAST_ACCESSED_TIME = -1
+
+logger = init_logger(__name__)
+
+
+class BlockTracker:
+    """Used to track the status of a block inside the prefix caching allocator
+    """
+    __slots__ = ("active", "last_accessed", "computed")
+
+    def reset(self):
+        self.last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME
+        self.computed: bool = False
+
+    def __init__(self):
+        self.active: bool = False
+        self.reset()
+
+    def enable(self):
+        assert not self.active
+        self.active = True
+        self.reset()
+
+    def disable(self):
+        assert self.active
+        self.active = False
+        self.reset()
+
+
+class PrefixCachingBlockAllocator(BlockAllocator):
+    """A block allocator that implements prefix caching.
+
+    The PrefixCachingBlockAllocator maintains a cache of blocks based on their
+    content hash. It reuses blocks with the same content hash to avoid redundant
+    memory allocation. The allocator also supports copy-on-write operations.
+
+    Args:
+        num_blocks (int): The total number of blocks to manage.
+        block_size (int): The size of each block in tokens.
+        block_ids (Optional[Iterable[int]], optional): An optional iterable of
+            block IDs. If not provided, block IDs will be assigned sequentially
+            from 0 to num_blocks - 1.
+    """
+
+    # Note that we use 'None' as a string here instead of None because
+    # as of Python 3.12, hash(None) returns a constant predictable value.
+    # This could possibly make it easier to find and exploit hash
+    # collisions. 'None' as a string will be hashed differently per process,
+    # but consistently within the same process. This is the same as the
+    # behavior of None prior to Python 3.12.
+    _none_hash: int = hash('None')
+
+    # Implements Block.Factory.
+    def __init__(
+        self,
+        num_blocks: int,
+        block_size: int,
+        block_ids: Optional[Iterable[int]] = None,
+        eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
+    ):
+        if block_ids is None:
+            block_ids = range(num_blocks)
+
+        self._block_size = block_size
+
+        # A mapping of prefix hash to block index. All blocks which have a
+        # prefix hash will be in this dict, even if they have refcount 0.
+        self._cached_blocks: Dict[PrefixHash, BlockId] = {}
+
+        # A list of immutable block IDs that have been touched by scheduler
+        # and should be marked as computed after an entire batch of sequences
+        # are scheduled.
+        self._touched_blocks: Set[BlockId] = set()
+
+        # Used to track status of each physical block id
+        self._block_tracker: Dict[BlockId, BlockTracker] = {}
+        for block_id in block_ids:
+            self._block_tracker[block_id] = BlockTracker()
+
+        # Pre-allocate "num_blocks * extra_factor" block objects.
+        # The "* extra_factor" is a buffer to allow more block objects
+        # than physical blocks
+        extra_factor = 4
+        self._block_pool = BlockPool(self._block_size, self._create_block,
+                                     self, num_blocks * extra_factor)
+
+        # An allocator for blocks that do not have prefix hashes.
+        self._hashless_allocator = NaiveBlockAllocator(
+            create_block=self._create_block,  # type: ignore
+            num_blocks=num_blocks,
+            block_size=block_size,
+            block_ids=block_ids,
+            block_pool=self._block_pool,  # Share block pool here
+        )
+
+        # Evictor used to maintain how we want to handle those computed blocks
+        # if we find memory pressure is high.
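+        # With the default EvictionPolicy.LRU, eviction prefers the cached
+        # block with the oldest last_accessed timestamp.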
+ self.eviction_policy = eviction_policy + self.evictor: Evictor = make_evictor(self.eviction_policy) + + # We share the refcounter between allocators. This allows us to promote + # blocks originally allocated in the hashless allocator to immutable + # blocks. + self._refcounter = self._hashless_allocator.refcounter + + self._cow_tracker = CopyOnWriteTracker( + refcounter=self._refcounter.as_readonly()) + + self.metric_data = CacheMetricData() + + def _create_block( + self, + prev_block: Optional[Block], + token_ids: List[int], + block_size: int, + allocator: BlockAllocator, + block_id: Optional[int] = None, + computed: bool = False, + extra_hash: Optional[int] = None, + ) -> Block: + # Bind block to self. + allocator = self + + return PrefixCachingBlock( + prev_block=prev_block, + token_ids=token_ids, + block_size=block_size, + block_id=block_id, + allocator=allocator, + computed=computed, + extra_hash=extra_hash, + ) + + def allocate_immutable_block(self, + prev_block: Optional[Block], + token_ids: List[int], + extra_hash: Optional[int] = None, + device: Optional[Device] = None) -> Block: + """Allocates an immutable block with the given token IDs, reusing cached + blocks if possible. + + Args: + prev_block (Optional[Block]): The previous block in the sequence. + token_ids (List[int]): The token IDs to be stored in the block. + + Returns: + Block: The allocated immutable block. + """ + assert device is None + assert_prefix_caching_block_or_none(prev_block) + + # First, try to create a block that points to cached data + block = self._block_pool.init_block(prev_block=prev_block, + token_ids=token_ids, + block_size=self._block_size, + physical_block_id=None, + extra_hash=extra_hash) + assert block.content_hash is not None + + cached_block_id = self._cached_blocks.get(block.content_hash, None) + if cached_block_id is not None: + self.metric_data.query(hit=True) + block.block_id = cached_block_id + self._incr_refcount_cached_block(block) + return block + self.metric_data.query(hit=False) + self._block_pool.free_block(block) + + # No cached block => Allocate a new block + block = self.allocate_mutable_block(prev_block, extra_hash=extra_hash) + block.append_token_ids(token_ids) + return block + + def allocate_immutable_blocks( + self, + prev_block: Optional[Block], + block_token_ids: List[List[int]], + extra_hash: Optional[int] = None, + device: Optional[Device] = None) -> List[Block]: + blocks = [] + for token_ids in block_token_ids: + prev_block = self.allocate_immutable_block(prev_block=prev_block, + token_ids=token_ids, + device=device, + extra_hash=extra_hash) + blocks.append(prev_block) + return blocks + + def allocate_mutable_block(self, + prev_block: Optional[Block], + extra_hash: Optional[int] = None, + device: Optional[Device] = None) -> Block: + """Allocates a mutable block. If there are no free blocks, this will + evict unused cached blocks. + + Args: + prev_block (Block): The previous block in the sequence. + None is not allowed unlike it is super class. + + Returns: + Block: The allocated mutable block. 
+ """ + assert device is None + assert_prefix_caching_block_or_none(prev_block) + + block_id = self._allocate_block_id() + block = self._block_pool.init_block(prev_block=prev_block, + token_ids=[], + block_size=self._block_size, + physical_block_id=block_id, + extra_hash=extra_hash) + assert not block.computed + assert block.content_hash is None + return block + + def _incr_refcount_cached_block(self, block: Block) -> None: + # Set this block to be "computed" since it is pointing to a + # cached block id (which was already computed) + block.computed = True + + block_id = block.block_id + assert block_id is not None + + refcount = self._refcounter.incr(block_id) + if refcount == 1: + # In case a cached block was evicted, restore its tracking + if block_id in self.evictor: + self.evictor.remove(block_id) + + self._track_block_id(block_id, computed=True) + + def _decr_refcount_cached_block(self, block: Block) -> None: + # Ensure this is immutable/cached block + assert block.content_hash is not None + + block_id = block.block_id + assert block_id is not None + + refcount = self._refcounter.decr(block_id) + if refcount > 0: + block.block_id = None + return + else: + assert refcount == 0 + + # No longer used + assert block.content_hash in self._cached_blocks + + # Add the cached block to the evictor + # (This keeps the cached block around so it can be reused) + self.evictor.add(block_id, block.content_hash, block.num_tokens_total, + self._block_tracker[block_id].last_accessed) + + # Stop tracking the block + self._untrack_block_id(block_id) + + block.block_id = None + + def _decr_refcount_hashless_block(self, block: Block) -> None: + block_id = block.block_id + assert block_id is not None + + # We may have a fork case where block is shared, + # in which case, we cannot remove it from tracking + refcount = self._refcounter.get(block_id) + if refcount == 1: + self._untrack_block_id(block_id) + + # Decrement refcount of the block_id, but do not free the block object + # itself (will be handled by the caller) + self._hashless_allocator.free(block, keep_block_object=True) + + def _allocate_block_id(self) -> BlockId: + """First tries to allocate a block id from the hashless allocator, + and if there are no blocks, then tries to evict an unused cached block. + """ + hashless_block_id = self._maybe_allocate_hashless_block_id() + if hashless_block_id is not None: + return hashless_block_id + + evicted_block_id = self._maybe_allocate_evicted_block_id() + if evicted_block_id is not None: + return evicted_block_id + + # No block available in hashless allocator, nor in unused cache blocks. 
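+        # NoFreeBlocksError is the same signal that
+        # _maybe_allocate_hashless_block_id (below) relies on to detect that
+        # the hashless pool is exhausted.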
+ raise BlockAllocator.NoFreeBlocksError() + + def _maybe_allocate_hashless_block_id(self) -> Optional[BlockId]: + try: + # Allocate mutable block and extract its block_id + block = self._hashless_allocator.allocate_mutable_block( + prev_block=None) + block_id = block.block_id + self._block_pool.free_block(block) + + self._track_block_id(block_id, computed=False) + return block_id + except BlockAllocator.NoFreeBlocksError: + return None + + def _maybe_allocate_evicted_block_id(self) -> Optional[BlockId]: + if self.evictor.num_blocks == 0: + return None + + # Here we get an evicted block, which is only added + # into evictor if its ref counter is 0 + # and since its content would be changed, we need + # to remove it from _cached_blocks's tracking list + block_id, content_hash_to_evict = self.evictor.evict() + + # Sanity checks + assert content_hash_to_evict in self._cached_blocks + _block_id = self._cached_blocks[content_hash_to_evict] + assert self._refcounter.get(_block_id) == 0 + assert _block_id == block_id + + self._cached_blocks.pop(content_hash_to_evict) + + self._refcounter.incr(block_id) + self._track_block_id(block_id, computed=False) + + return block_id + + def _free_block_id(self, block: Block) -> None: + """Decrements the refcount of the block. The block may be in two + possible states: (1) immutable/cached or (2) mutable/hashless. + In the first case, the refcount is decremented directly and the block + may be possibly added to the evictor. In other case, hashless + allocator free(..) with keep_block_object=True is called to only free + the block id (since the block object may be reused by the caller) + """ + block_id = block.block_id + assert block_id is not None, "Freeing unallocated block is undefined" + + if block.content_hash is not None: + # Immutable: This type of block is always cached, and we want to + # keep it in the evictor for future reuse + self._decr_refcount_cached_block(block) + else: + # Mutable: This type of block is not cached, so we release it + # directly to the hashless allocator + self._decr_refcount_hashless_block(block) + + assert block.block_id is None + + def free(self, block: Block, keep_block_object: bool = False) -> None: + """Release the block (look at free_block_id(..) docs) + """ + # Release the physical block index + self._free_block_id(block) + + # Release the block object to the pool + if not keep_block_object: + self._block_pool.free_block(block) + + def fork(self, last_block: Block) -> List[Block]: + """Creates a new sequence of blocks that shares the same underlying + memory as the original sequence. + + Args: + last_block (Block): The last block in the original sequence. + + Returns: + List[Block]: The new sequence of blocks that shares the same memory + as the original sequence. 
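+
+        Note: forking only increments refcounts on the shared physical
+        blocks; no token data is copied until a later write triggers
+        copy-on-write via cow_block_if_not_appendable.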
+        """
+        source_blocks = get_all_blocks_recursively(last_block)
+
+        forked_blocks: List[Block] = []
+        prev_block = None
+        for block in source_blocks:
+            block_id = block.block_id
+            assert block_id is not None
+
+            refcount = self._refcounter.incr(block_id)
+            assert refcount != 1, "can't fork free'd block_id = {}".format(
+                block_id)
+
+            forked_block = self._block_pool.init_block(
+                prev_block=prev_block,
+                token_ids=block.token_ids,
+                block_size=self._block_size,
+                physical_block_id=block_id,
+                extra_hash=block.extra_hash)
+
+            forked_blocks.append(forked_block)
+            prev_block = forked_blocks[-1]
+
+        return forked_blocks
+
+    def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
+        assert device is None
+        # The number of free blocks is the number of hashless free blocks
+        # plus the number of blocks evictor could free from its list.
+        return self._hashless_allocator.get_num_free_blocks(
+        ) + self.evictor.num_blocks
+
+    def get_num_total_blocks(self) -> int:
+        return self._hashless_allocator.get_num_total_blocks()
+
+    def get_physical_block_id(self, absolute_id: int) -> int:
+        """Returns the zero-offset block id on certain block allocator
+        given the absolute block id.
+
+        Args:
+            absolute_id (int): The absolute block id for the block
+                in whole allocator.
+
+        Returns:
+            int: The zero-offset block id on certain device.
+        """
+        return sorted(self.all_block_ids).index(absolute_id)
+
+    @property
+    def all_block_ids(self) -> FrozenSet[int]:
+        return self._hashless_allocator.all_block_ids
+
+    def get_prefix_cache_hit_rate(self) -> float:
+        return self.metric_data.get_hit_rate()
+
+    def reset_prefix_cache(self) -> bool:
+        """Reset prefix cache. This function may be used in RLHF
+        flows to invalidate prefix caching after the weights are updated,
+        or used for resetting prefix caching status for benchmarking.
+
+        Returns:
+            bool: True if the prefix cache is successfully reset,
+            False otherwise.
+        """
+        num_used_blocks = (self.get_num_total_blocks() -
+                           self.get_num_free_blocks())
+        if num_used_blocks > 0:
+            logger.warning(
+                "Failed to reset prefix cache because some "
+                "blocks (%d) are not freed yet", num_used_blocks)
+            return False
+
+        # Free all blocks in the evictor.
+        while (block_id :=
+               self._maybe_allocate_evicted_block_id()) is not None:
+            self._hashless_allocator.free_block_id(block_id)
+
+        # Should not have any cached blocks because all blocks are evicted.
+        assert not self._cached_blocks
+
+        # Reset the evictor.
+        self.evictor = make_evictor(self.eviction_policy)
+
+        # Reset the block tracker.
+        for block_id in self._block_tracker:
+            self._block_tracker[block_id] = BlockTracker()
+
+        # Reset the metrics.
+        self.metric_data = CacheMetricData()
+
+        logger.info("Successfully reset prefix cache")
+        return True
+
+    def is_block_cached(self, block: Block) -> bool:
+        assert block.content_hash is not None
+        return block.content_hash in self._cached_blocks
+
+    def promote_to_immutable_block(self, block: Block) -> BlockId:
+        """Once a mutable block is full, it can be promoted to an immutable
+        block. This means that its content can be referenced by future blocks
+        having the same prefix.
+
+        Note that if we already have a cached block with the same content, we
+        will replace the newly-promoted block's mapping with the existing cached
+        block id.
+
+        Args:
+            block: The mutable block to be promoted.
+
+        Returns:
+            BlockId: Either the original block index, or the block index of
+                the previously cached block matching the same content.
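+
+        Note: promotion is normally triggered from
+        PrefixCachingBlock.append_token_ids once the block becomes full;
+        callers rarely need to invoke this method directly.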
+ """ + # Ensure block can be promoted + assert block.content_hash is not None + assert block.block_id is not None + assert self._refcounter.get(block.block_id) > 0 + + if block.content_hash not in self._cached_blocks: + # No cached content hash => Set this block as cached. + # Note that this block cannot be marked as computed yet + # because other sequences in the same batch cannot reuse + # this block. + self._cached_blocks[block.content_hash] = block.block_id + # Mark this block as touched so that it can be marked as + # computed after the entire batch of sequences are scheduled. + self._touched_blocks.add(block.block_id) + return block.block_id + + # Reuse the cached content hash + self._decr_refcount_hashless_block(block) + block.block_id = self._cached_blocks[block.content_hash] + + # Increment refcount of the cached block and (possibly) restore + # it from the evictor. + # Note that in this case, the block is marked as computed + self._incr_refcount_cached_block(block) + + return block.block_id + + def cow_block_if_not_appendable(self, block: Block) -> BlockId: + """Performs a copy-on-write operation on the given block if it is not + appendable. + + Args: + block (Block): The block to check for copy-on-write. + + Returns: + BlockId: The block index of the new block if a copy-on-write + operation was performed, or the original block index if + no copy-on-write was necessary. + """ + src_block_id = block.block_id + assert src_block_id is not None + + if self._cow_tracker.is_appendable(block): + return src_block_id + + self._free_block_id(block) + trg_block_id = self._allocate_block_id() + + self._cow_tracker.record_cow(src_block_id, trg_block_id) + + return trg_block_id + + def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]: + """Returns the copy-on-write source->destination mapping and clears it. + + Returns: + List[Tuple[BlockId, BlockId]]: A list mapping source + block indices to destination block indices. + """ + return self._cow_tracker.clear_cows() + + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + """Mark blocks as accessed, used in prefix caching. + + If the block is added into evictor, we need to update corresponding + info in evictor's metadata. + """ + + for block_id in block_ids: + if self._block_tracker[block_id].active: + self._block_tracker[block_id].last_accessed = now + elif block_id in self.evictor: + self.evictor.update(block_id, now) + else: + raise ValueError( + "Mark block as accessed which is not belonged to GPU") + + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: + # Mark all touched blocks as computed. + for block_id in self._touched_blocks: + self._block_tracker[block_id].computed = True + self._touched_blocks.clear() + + def _track_block_id(self, block_id: Optional[BlockId], + computed: bool) -> None: + assert block_id is not None + self._block_tracker[block_id].enable() + self._block_tracker[block_id].computed = computed + + def _untrack_block_id(self, block_id: Optional[BlockId]) -> None: + assert block_id is not None + self._block_tracker[block_id].disable() + + def block_is_computed(self, block_id: int) -> bool: + if self._block_tracker[block_id].active: + return self._block_tracker[block_id].computed + else: + return block_id in self.evictor + + def get_common_computed_block_ids( + self, computed_seq_block_ids: List[List[int]]) -> List[int]: + """Return the block ids that are common for a given sequence group. 
+
+        Only those blocks that are immutable and have already been marked
+        as computed are taken into consideration.
+        """
+
+        # NOTE We exclude the last block to avoid the case where the entire
+        # prompt is cached. This would cause erroneous behavior in model
+        # runner.
+
+        # commonprefix returns a list of ints here, although its type
+        # annotation says list of strings.
+        if len(computed_seq_block_ids) == 1:
+            return computed_seq_block_ids[0]
+
+        return commonprefix([
+            ids for ids in computed_seq_block_ids  # type: ignore
+            if ids
+        ])
+
+    def get_num_full_blocks_touched(self, blocks: List[Block]) -> int:
+        """Returns the number of full blocks that will be touched by
+        swapping in/out.
+
+        Args:
+            blocks: List of blocks to be swapped.
+        Returns:
+            int: the number of full blocks that will be touched by
+                swapping in/out the given blocks. Non full blocks are ignored
+                when deciding the number of blocks to touch.
+        """
+        num_touched_blocks: int = 0
+        for block in blocks:
+            # If the block has a match in the cache and the cached
+            # block is not referenced, then we still count it as a
+            # touched block
+            if block.is_full and (not self.is_block_cached(block) or \
+                (block.content_hash is not None and \
+                self._cached_blocks[block.content_hash] in \
+                self.evictor)):
+                num_touched_blocks += 1
+        return num_touched_blocks
+
+    def swap_out(self, blocks: List[Block]) -> None:
+        """Execute the swap out actions. Basically just free the
+        given blocks.
+
+        Args:
+            blocks: List of blocks to be swapped out.
+        """
+        for block in blocks:
+            self._free_block_id(block)
+
+    def swap_in(self, blocks: List[Block]) -> None:
+        """Execute the swap in actions. Change the block id from
+        old allocator to current allocator for each block to finish
+        the block table update.
+
+        Args:
+            blocks: List of blocks to be swapped in.
+        """
+        for block in blocks:
+            # Here we allocate either immutable or mutable block and then
+            # extract its block_id. Note that the block object is released
+            # and the block_id is assigned to "block" to allow reusing the
+            # existing "block" object
+            if block.is_full:
+                tmp_block = self.allocate_immutable_block(
+                    prev_block=block.prev_block,
+                    token_ids=block.token_ids,
+                    extra_hash=block.extra_hash)
+            else:
+                tmp_block = self.allocate_mutable_block(
+                    prev_block=block.prev_block, extra_hash=block.extra_hash)
+                tmp_block.append_token_ids(block.token_ids)
+
+            block_id = tmp_block.block_id
+            self._block_pool.free_block(tmp_block)
+
+            block.block_id = block_id  # Assign block_id
+
+    def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]:
+        """
+        Given a list of block hashes, return the prefix of the block hashes that
+        are all cached.
+
+        Since a block's block hash includes the hashes of all previous blocks,
+        and blocks are only allocated/deallocated for the sequence as a whole,
+        if a block is cached, then all previous blocks are also cached. With
+        this property, we can use binary search to find the prefix of cached
+        blocks.
+
+        Args:
+            block_hashes (List[int]): The list of block hashes.
+
+        Returns:
+            List[int]: The prefix of the `block_hashes` that are cached.
+        """
+
+        def _block_is_cached(block_hash: PrefixHash) -> bool:
+            if block_hash not in self._cached_blocks:
+                return False
+
+            cached_block_id = self._cached_blocks[block_hash]
+            # We only consider the blocks that are marked as computed.
+            return self.block_is_computed(cached_block_id)
+
+        def _bisect_left(a, x, key: Callable[[PrefixHash], bool]) -> int:
+
+            # Python < 3.10 doesn't have the key argument
+            if sys.version_info < (3, 10):
+                a = [key(e) for e in a]
+                return bisect_left(a, x)
+            else:
+                return bisect_left(a, x, key=key)
+
+        # Look for the first block that's not cached, and return the prefix,
+        # i.e., the blocks that are cached.
+        idx = _bisect_left(block_hashes,
+                           True,
+                           key=lambda x: not _block_is_cached(x))
+        return block_hashes[:idx]
+
+
+class PrefixCachingBlock(Block):
+    """A block implementation that supports prefix caching.
+
+    The PrefixCachingBlock class represents a block of token IDs with prefix
+    caching capabilities. It wraps a NaiveBlock internally and provides
+    additional functionality for content hashing and promoting immutable blocks
+    with the prefix caching allocator.
+
+    Args:
+        prev_block (Optional[PrefixCachingBlock]): The previous block in the
+            sequence.
+        token_ids (List[int]): The initial token IDs to be stored in the block.
+        block_size (int): The maximum number of token IDs that can be stored in
+            the block.
+        allocator (BlockAllocator): The prefix
+            caching block allocator associated with this block.
+        block_id (Optional[int], optional): The physical block index
+            of this block. Defaults to None.
+        extra_hash (Optional[int]): The hash value of additional factors
+            such as adapters that influence the block, apart from the token_ids.
+    """
+
+    # Note that we use 'None' as a string here instead of None because
+    # as of Python 3.12, hash(None) returns a constant predictable value.
+    # This could possibly make it easier to find and exploit hash
+    # collisions. 'None' as a string will be hashed differently per process,
+    # but consistently within the same process. This is the same as the
+    # behavior of None prior to Python 3.12.
+    _none_hash: int = hash('None')
+
+    def __init__(
+        self,
+        prev_block: Optional[Block],
+        token_ids: List[int],
+        block_size: int,
+        allocator: BlockAllocator,
+        block_id: Optional[int] = None,
+        computed: bool = False,
+        extra_hash: Optional[int] = None,
+    ):
+        assert isinstance(allocator, PrefixCachingBlockAllocator), (
+            "Currently this class is only tested with "
+            "PrefixCachingBlockAllocator.
Got instead allocator = {}".format( + allocator)) + assert_prefix_caching_block_or_none(prev_block) + + self._prev_block = prev_block + self._cached_content_hash: Optional[int] = None + self._cached_num_tokens_total: int = 0 + self._allocator = allocator + self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME + self._computed = computed + self._extra_hash = extra_hash + + # On the first time, we create the block object, and next we only + # reinitialize it + if hasattr(self, "_block"): + self._block.__init__( # type: ignore[has-type] + prev_block=prev_block, + token_ids=token_ids, + block_size=block_size, + block_id=block_id, + allocator=self._allocator) + else: + self._block = NaiveBlock(prev_block=prev_block, + token_ids=token_ids, + block_size=block_size, + block_id=block_id, + allocator=self._allocator) + + self._update_num_tokens_total() + + def _update_num_tokens_total(self): + """Incrementally computes the number of tokens that there is + till the current block (included) + """ + res = 0 + + # Add all previous blocks + if self._prev_block is not None: + res += self._prev_block.num_tokens_total + + # Add current block + res += len(self.token_ids) + + self._cached_num_tokens_total = res + + @property + def computed(self) -> bool: + return self._computed + + @computed.setter + def computed(self, value) -> None: + self._computed = value + + @property + def last_accessed(self) -> float: + return self._last_accessed + + @last_accessed.setter + def last_accessed(self, last_accessed_ts: float): + self._last_accessed = last_accessed_ts + + def append_token_ids(self, token_ids: List[int]) -> None: + """Appends the given token IDs to the block and registers the block as + immutable if the block becomes full. + + Args: + token_ids (List[int]): The token IDs to be appended to the block. + """ + # Ensure this is mutable block (not promoted) + assert self.content_hash is None + assert not self.computed + + if len(token_ids) == 0: + return + + # Ensure there are input tokens + assert token_ids, "Got token_ids = {}".format(token_ids) + + # Naive block handles CoW. + self._block.append_token_ids(token_ids) + self._update_num_tokens_total() + + # If the content hash is present, then the block can be made immutable. + # Register ourselves with the allocator, potentially replacing the + # physical block index. + if self.content_hash is not None: + self.block_id = self._allocator.promote_to_immutable_block(self) + + @property + def block_id(self) -> Optional[int]: + return self._block.block_id + + @block_id.setter + def block_id(self, value) -> None: + self._block.block_id = value + + @property + def is_full(self) -> bool: + return self._block.is_full + + @property + def num_empty_slots(self) -> int: + return self._block.num_empty_slots + + @property + def num_tokens_total(self) -> int: + return self._cached_num_tokens_total + + @property + def block_size(self) -> int: + return self._block.block_size + + @property + def token_ids(self) -> List[int]: + return self._block.token_ids + + @property + def prev_block(self) -> Optional[Block]: + return self._prev_block + + @property + def extra_hash(self) -> Optional[int]: + return self._extra_hash + + @property + def content_hash(self) -> Optional[int]: + """Return the content-based hash of the current block, or None if it is + not yet defined. + + For the content-based hash to be defined, the current block must be + full. + """ + # If the hash is already computed, return it. 
+ if self._cached_content_hash is not None: + return self._cached_content_hash + + # We cannot compute a hash for the current block because it is not full. + if not self.is_full: + return None + + is_first_block = self._prev_block is None + prev_block_hash = ( + self._none_hash if is_first_block else + self._prev_block.content_hash # type: ignore + ) + + # Previous block exists but does not yet have a hash. + # Return no hash in this case. + if prev_block_hash == self._none_hash and not is_first_block: + return None + + self._cached_content_hash = PrefixCachingBlock.hash_block_tokens( + is_first_block, + prev_block_hash, + cur_block_token_ids=self.token_ids, + extra_hash=self._extra_hash) + return self._cached_content_hash + + @classmethod + def hash_block_tokens(cls, + is_first_block: bool, + prev_block_hash: Optional[int], + cur_block_token_ids: List[int], + extra_hash: Optional[int] = None) -> int: + """Computes a hash value corresponding to the contents of a block and + the contents of the preceding block(s). The hash value is used for + prefix caching. + + Parameters: + - is_first_block (bool): A flag indicating if the block is the first in + the sequence. + - prev_block_hash (Optional[int]): The hash of the previous block. None + if this is the first block. + - cur_block_token_ids (List[int]): A list of token ids in the current + block. The current block is assumed to be full. + - extra_hash (Optional[int]): The hash value of additional factors + such as adapters that influence the block, apart from the token_ids. + + Returns: + - int: The computed hash value for the block. + """ + if is_first_block and prev_block_hash is None: + prev_block_hash = cls._none_hash + return hash((is_first_block, prev_block_hash, *cur_block_token_ids, + extra_hash)) + + +class ComputedBlocksTracker: + """ + Tracks the computed blocks for each sequence. + + Internally, it maintains a map from sequence id to the list of block hashes + for the sequence. We cache the hashes of the full blocks for each sequence, + and make sure the hash is calculated in the same way as the allocator. + When a sequence is being decoded, we also update the sequence's hash + accordingly and incrementally. + + From the sequence hash, with prefix caching enabled, we could also calculate + the number of cached tokens for the sequence by looking up the number of + cached block hashes in the allocator. + """ + + # Note that we use 'None' as a string here instead of None because + # as of Python 3.12, hash(None) returns a constant predictable value. + # This could possibly make it easier to find and exploit hash + # collisions. 'None' as a string will be hashed differently per process, + # but consistently within the same process. This is the same as the + # behavior of None prior to Python 3.12. + _none_hash: int = hash('None') + + def __init__( + self, + allocator: DeviceAwareBlockAllocator, + block_size: int, + enable_caching: bool, + ): + self._allocator = allocator + self._block_size = block_size + self._enable_caching = enable_caching + + # A map from seq_id to the list of block hashes for the + # sequence. This is so that we don't have to recompute the block hashes + # for the sequence when we need to check if the sequence is cached. + # Note a block that's not full will not have its hash calculated and + # recorded. + self._seq_id_to_blocks_hashes: Dict[int, List[int]] = {} + + # A map from seq_id to the number of tokens that are cached for the + # sequence. 
+ # We need this so that a sequence in continuous prefill doesn't + # accidentally see its cached token count change. See comments in + # `get_num_cached_tokens` for more details. + self._seq_id_to_num_tokens_computed: Dict[int, int] = {} + + def _update_seq_hashes(self, seq: Sequence) -> None: + """Incrementally update the sequence's block hashes and record them.""" + assert self._enable_caching + + block_hashes_recorded = self._seq_id_to_blocks_hashes.get( + seq.seq_id, []) + cur_num_blocks_recorded = len(block_hashes_recorded) + token_ids = seq.get_token_ids() + assert len(token_ids) >= cur_num_blocks_recorded * self._block_size, ( + f"The sequence has {len(token_ids)} tokens, but" + f" already recorded {cur_num_blocks_recorded} blocks. " + "This should not happen since we assume blocks are " + "only appended other than recomputation. When the sequence is " + "recomputed, we should have removed the info of the old blocks.") + # Update the computed block hashes for the sequence. Since only full + # blocks are considered as "computed", we take floor here. + num_computed_blocks = len(token_ids) // self._block_size + + # We need to know the hash of the previous block to compute the hash of + # the current block so that blocks could be uniquely identified across + # sequences of prefixes. + prev_block_hash = (self._none_hash if cur_num_blocks_recorded == 0 else + block_hashes_recorded[-1]) + # Only update the computed block hashes for the new blocks + for i in range(cur_num_blocks_recorded, num_computed_blocks): + assert len(token_ids) >= (i + 1) * self._block_size + block_token_ids = token_ids[i * self._block_size:(i + 1) * + self._block_size] + + # NOTE: If there are any factors affecting the block besides + # token_ids, they should be added as input to extra_hash. + extra_hash = seq.extra_hash() + + # This has to be kept in sync with the allocator's hash + # calculation. + block_hash = PrefixCachingBlock.hash_block_tokens( + is_first_block=prev_block_hash == self._none_hash, + prev_block_hash=prev_block_hash, + cur_block_token_ids=block_token_ids, + extra_hash=extra_hash, + ) + block_hashes_recorded.append(block_hash) + prev_block_hash = block_hash + + self._seq_id_to_blocks_hashes[seq.seq_id] = block_hashes_recorded + + def get_num_cached_tokens(self, seq: Sequence) -> int: + if not self._enable_caching: + return 0 + + # We always try to update the sequence hashes on the fly. + # This is to ensure that we don't miss any cached tokens for the + # sequence during decode. + # This routine should only update hash for any new blocks too. + self._update_seq_hashes(seq) + + num_computed_tokens_prev = self._seq_id_to_num_tokens_computed.get( + seq.seq_id, None) + + # TODO(rickyx): This hack could be removed once we mark blocks as + # computed correctly with chunked prefills. + if num_computed_tokens_prev is not None and seq.is_prefill(): + # For a sequence that is still in prefill, we don't + # recompute the number of cached tokens. + # This also handles correctly chunked prefill since currently + # we mark blocks as computed even if the sequence is still partially + # prefilled. So a continuously prefilled sequence should not + # see its cached token count change while running. + return num_computed_tokens_prev + + block_hashes = self._seq_id_to_blocks_hashes[seq.seq_id] + + # This is O(logN), where N is the number of blocks. 
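+        # Binary search applies because cached-ness is monotone over the
+        # prefix: if block i is cached then blocks 0..i-1 are cached too
+        # (see find_cached_blocks_prefix in the allocator).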
+ num_cached_blocks = len( + self._allocator.find_cached_blocks_prefix(block_hashes)) + num_cached_tokens = num_cached_blocks * self._block_size + self._seq_id_to_num_tokens_computed[seq.seq_id] = num_cached_tokens + return num_cached_tokens + + def remove_seq(self, seq_id: int) -> None: + """Stop tracking the sequence.""" + if not self._enable_caching: + return + assert seq_id in self._seq_id_to_blocks_hashes + del self._seq_id_to_blocks_hashes[seq_id] + + assert seq_id in self._seq_id_to_num_tokens_computed + del self._seq_id_to_num_tokens_computed[seq_id] + + +class LastAccessBlocksTracker: + """Manages the last access time of the tracked sequences, in order to allow + an efficient update of allocator's block last access times + """ + + def __init__(self, allocator): + self._allocator = allocator + self._seq_last_access: Dict[int, Optional[float]] = {} + + def add_seq(self, seq_id: int) -> None: + """Start tracking seq_id + """ + assert seq_id not in self._seq_last_access + self._seq_last_access[seq_id] = None + + def remove_seq(self, seq_id: int) -> None: + """Stop tracking seq_id + """ + assert seq_id in self._seq_last_access + del self._seq_last_access[seq_id] + + def update_last_access(self, seq_id: int, time: float) -> None: + assert seq_id in self._seq_last_access + self._seq_last_access[seq_id] = time + + def update_seq_blocks_last_access(self, seq_id: int, + block_ids: List[int]) -> None: + assert seq_id in self._seq_last_access + + ts = self._seq_last_access[seq_id] + + if ts is None: + # No last access was recorded, no need to update. + return + + self._allocator.mark_blocks_as_accessed(block_ids, ts) + + +def assert_prefix_caching_block_or_none(block: Optional[Block]): + if block is None: + return + assert isinstance(block, + PrefixCachingBlock), "Got block = {}".format(block) diff --git a/core/block/utils.py b/core/block/utils.py new file mode 100644 index 0000000..e933c6e --- /dev/null +++ b/core/block/utils.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Block manager utils.""" +from vllm.sequence import SequenceGroup +from vllm.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, + STR_NOT_IMPL_ENC_DEC_SWA) + + +def check_no_caching_or_swa_for_blockmgr_encdec( + block_mgr, seq_group: SequenceGroup) -> None: + ''' + Enforce that prefix caching & sliding-window attention (SWA) + are currently unsupported *specifically* for encoder/decoder models. + + Raises NotImplementedError if unsupported scenario is detected. 
+ + Arguments: + + * block_mgr: BlockSpaceManager instance + * seq_group: SequenceGroup passed to block_mgr + ''' + + if seq_group.is_encoder_decoder(): + if block_mgr.max_block_sliding_window is not None: + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) + + if block_mgr.enable_caching: + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) diff --git a/core/block_manager.py b/core/block_manager.py new file mode 100644 index 0000000..a333992 --- /dev/null +++ b/core/block_manager.py @@ -0,0 +1,521 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""A block manager that manages token blocks.""" +from typing import Dict, List, Optional +from typing import Sequence as GenericSequence +from typing import Tuple + +from vllm.core.block.block_table import BlockTable +from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator +from vllm.core.block.interfaces import Block +from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker, + LastAccessBlocksTracker) +from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec +from vllm.core.interfaces import AllocStatus, BlockSpaceManager +from vllm.sequence import Sequence, SequenceGroup, SequenceStatus +from vllm.utils import Device + +SeqId = int +EncoderSeqId = str + + +class SelfAttnBlockSpaceManager(BlockSpaceManager): + """BlockSpaceManager which manages the allocation of KV cache. + + It owns responsibility for allocation, swapping, allocating memory for + autoregressively-generated tokens, and other advanced features such as + prefix caching, forking/copy-on-write, and sliding-window memory allocation. + + This class implements the design described in + https://github.com/vllm-project/vllm/pull/3492. + + Lookahead slots + The block manager has the notion of a "lookahead slot". These are slots + in the KV cache that are allocated for a sequence. Unlike the other + allocated slots, the content of these slots is undefined -- the worker + may use the memory allocations in any way. + + In practice, a worker could use these lookahead slots to run multiple + forward passes for a single scheduler invocation. Each successive + forward pass would write KV activations to the corresponding lookahead + slot. This allows low inter-token latency use-cases, where the overhead + of continuous batching scheduling is amortized over >1 generated tokens. + + Speculative decoding uses lookahead slots to store KV activations of + proposal tokens. + + See https://github.com/vllm-project/vllm/pull/3250 for more information + on lookahead scheduling. + + Args: + block_size (int): The size of each memory block. + num_gpu_blocks (int): The number of memory blocks allocated on GPU. + num_cpu_blocks (int): The number of memory blocks allocated on CPU. + watermark (float, optional): The threshold used for memory swapping. + Defaults to 0.01. + sliding_window (Optional[int], optional): The size of the sliding + window. Defaults to None. + enable_caching (bool, optional): Flag indicating whether caching is + enabled. Defaults to False. 
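As a rough standalone sketch of the watermark-based admission decision that `can_allocate` makes below, the function here reproduces the OK / LATER / NEVER logic. `admission_decision` and the toy block counts are assumptions for illustration, not part of this class.

    from enum import Enum, auto

    class AllocStatus(Enum):
        OK = auto()
        LATER = auto()
        NEVER = auto()

    def admission_decision(num_required_blocks: int, num_free_gpu_blocks: int,
                           num_total_gpu_blocks: int,
                           watermark_blocks: int) -> AllocStatus:
        # NEVER: even a fully empty cache could not hold the request while
        # staying above the watermark, so it can never be scheduled.
        if num_total_gpu_blocks - num_required_blocks < watermark_blocks:
            return AllocStatus.NEVER
        # OK: enough free blocks remain above the watermark right now.
        if num_free_gpu_blocks - num_required_blocks >= watermark_blocks:
            return AllocStatus.OK
        # LATER: feasible in principle, but wait for blocks to be freed.
        return AllocStatus.LATER

    # With 1000 total blocks and a 1% watermark (10 blocks):
    assert admission_decision(100, 500, 1000, 10) is AllocStatus.OK
    assert admission_decision(100, 105, 1000, 10) is AllocStatus.LATER
    assert admission_decision(995, 1000, 1000, 10) is AllocStatus.NEVER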
+ """ + + def __init__( + self, + block_size: int, + num_gpu_blocks: int, + num_cpu_blocks: int, + watermark: float = 0.01, + sliding_window: Optional[int] = None, + enable_caching: bool = False, + ) -> None: + self.block_size = block_size + self.num_total_gpu_blocks = num_gpu_blocks + self.num_total_cpu_blocks = num_cpu_blocks + + self.sliding_window = sliding_window + # max_block_sliding_window is the max number of blocks that need to be + # allocated + self.max_block_sliding_window = None + if sliding_window is not None: + # +1 here because // rounds down + num_blocks = sliding_window // block_size + 1 + # +1 here because the last block may not be full, + # and so the sequence stretches one more block at the beginning + # For example, if sliding_window is 3 and block_size is 4, + # we may need 2 blocks when the second block only holds 1 token. + self.max_block_sliding_window = num_blocks + 1 + + self.watermark = watermark + assert watermark >= 0.0 + + self.enable_caching = enable_caching + + self.watermark_blocks = int(watermark * num_gpu_blocks) + + self.block_allocator = CpuGpuBlockAllocator.create( + allocator_type="prefix_caching" if enable_caching else "naive", + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks, + block_size=block_size, + ) + + self.block_tables: Dict[SeqId, BlockTable] = {} + self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {} + + self._computed_blocks_tracker = ComputedBlocksTracker( + self.block_allocator, self.block_size, self.enable_caching) + self._last_access_blocks_tracker = LastAccessBlocksTracker( + self.block_allocator) + + def can_allocate(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> AllocStatus: + # FIXME(woosuk): Here we assume that all sequences in the group share + # the same prompt. This may not be true for preempted sequences. + + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) + + seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] + num_required_blocks = BlockTable.get_num_required_blocks( + seq.get_token_ids(), + block_size=self.block_size, + num_lookahead_slots=num_lookahead_slots, + ) + + if seq_group.is_encoder_decoder(): + encoder_seq = seq_group.get_encoder_seq() + assert encoder_seq is not None + num_required_blocks += BlockTable.get_num_required_blocks( + encoder_seq.get_token_ids(), + block_size=self.block_size, + ) + + if self.max_block_sliding_window is not None: + num_required_blocks = min(num_required_blocks, + self.max_block_sliding_window) + + num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( + device=Device.GPU) + + # Use watermark to avoid frequent cache eviction. + if (self.num_total_gpu_blocks - num_required_blocks + < self.watermark_blocks): + return AllocStatus.NEVER + if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks: + return AllocStatus.OK + else: + return AllocStatus.LATER + + def _allocate_sequence(self, seq: Sequence) -> BlockTable: + block_table = BlockTable( + block_size=self.block_size, + block_allocator=self.block_allocator, + max_block_sliding_window=self.max_block_sliding_window, + ) + if seq.get_token_ids(): + # NOTE: If there are any factors affecting the block besides + # token_ids, they should be added as input to extra_hash. + extra_hash = seq.extra_hash() + + # Add blocks to the block table only if the sequence is non empty. 
+ block_table.allocate(token_ids=seq.get_token_ids(), + extra_hash=extra_hash) + + return block_table + + def allocate(self, seq_group: SequenceGroup) -> None: + + # Allocate self-attention block tables for decoder sequences + waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) + assert not (set(seq.seq_id for seq in waiting_seqs) + & self.block_tables.keys()), "block table already exists" + + # NOTE: Here we assume that all sequences in the group have the same + # prompt. + seq = waiting_seqs[0] + block_table: BlockTable = self._allocate_sequence(seq) + self.block_tables[seq.seq_id] = block_table + + # Track seq + self._last_access_blocks_tracker.add_seq(seq.seq_id) + + # Assign the block table for each sequence. + for seq in waiting_seqs[1:]: + self.block_tables[seq.seq_id] = block_table.fork() + + # Track seq + self._last_access_blocks_tracker.add_seq(seq.seq_id) + + # Allocate cross-attention block table for encoder sequence + # + # NOTE: Here we assume that all sequences in the group have the same + # encoder prompt. + request_id = seq_group.request_id + + assert (request_id + not in self.cross_block_tables), \ + "block table already exists" + + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) + + if seq_group.is_encoder_decoder(): + encoder_seq = seq_group.get_encoder_seq() + assert encoder_seq is not None + block_table = self._allocate_sequence(encoder_seq) + self.cross_block_tables[request_id] = block_table + + def can_append_slots(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> bool: + """Determine if there is enough space in the GPU KV cache to continue + generation of the specified sequence group. + + We use a worst-case heuristic: assume each touched block will require a + new allocation (either via CoW or new block). We can append slots if the + number of touched blocks is less than the number of free blocks. + + "Lookahead slots" are slots that are allocated in addition to the slots + for known tokens. The contents of the lookahead slots are not defined. + This is used by speculative decoding when speculating future tokens. + """ + + num_touched_blocks = 0 + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + block_table = self.block_tables[seq.seq_id] + + num_touched_blocks += ( + block_table.get_num_blocks_touched_by_append_slots( + token_ids=block_table.get_unseen_token_ids( + seq.get_token_ids()), + num_lookahead_slots=num_lookahead_slots, + )) + + num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( + Device.GPU) + return num_touched_blocks <= num_free_gpu_blocks + + def append_slots( + self, + seq: Sequence, + num_lookahead_slots: int, + ) -> List[Tuple[int, int]]: + + block_table = self.block_tables[seq.seq_id] + + block_table.append_token_ids( + token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()), + num_lookahead_slots=num_lookahead_slots, + num_computed_slots=seq.data.get_num_computed_tokens(), + extra_hash=seq.extra_hash(), + ) + # Return any new copy-on-writes. + new_cows = self.block_allocator.clear_copy_on_writes() + return new_cows + + def free(self, seq: Sequence) -> None: + seq_id = seq.seq_id + + if seq_id not in self.block_tables: + # Already freed or haven't been scheduled yet. 
+ return + + # Update seq block ids with the latest access time + self._last_access_blocks_tracker.update_seq_blocks_last_access( + seq_id, self.block_tables[seq.seq_id].physical_block_ids) + + # Untrack seq + self._last_access_blocks_tracker.remove_seq(seq_id) + self._computed_blocks_tracker.remove_seq(seq_id) + + # Free table/blocks + self.block_tables[seq_id].free() + del self.block_tables[seq_id] + + def free_cross(self, seq_group: SequenceGroup) -> None: + request_id = seq_group.request_id + if request_id not in self.cross_block_tables: + # Already freed or hasn't been scheduled yet. + return + self.cross_block_tables[request_id].free() + del self.cross_block_tables[request_id] + + def get_block_table(self, seq: Sequence) -> List[int]: + block_ids = self.block_tables[seq.seq_id].physical_block_ids + return block_ids # type: ignore + + def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]: + request_id = seq_group.request_id + assert request_id in self.cross_block_tables + block_ids = self.cross_block_tables[request_id].physical_block_ids + assert all(b is not None for b in block_ids) + return block_ids # type: ignore + + def access_all_blocks_in_seq(self, seq: Sequence, now: float): + if self.enable_caching: + # Record the latest access time for the sequence. The actual update + # of the block ids is deferred to the sequence free(..) call, since + # only during freeing of block ids, the blocks are actually added to + # the evictor (which is when the most updated time is required) + # (This avoids expensive calls to mark_blocks_as_accessed(..)) + self._last_access_blocks_tracker.update_last_access( + seq.seq_id, now) + + def mark_blocks_as_computed(self, seq_group: SequenceGroup, + token_chunk_size: int): + # If prefix caching is enabled, mark immutable blocks as computed + # right after they have been scheduled (for prefill). This assumes + # the scheduler is synchronous so blocks are actually computed when + # scheduling the next batch. + self.block_allocator.mark_blocks_as_computed([]) + + def get_common_computed_block_ids( + self, seqs: List[Sequence]) -> GenericSequence[int]: + """Determine which blocks for which we skip prefill. + + With prefix caching we can skip prefill for previously-generated blocks. + Currently, the attention implementation only supports skipping cached + blocks if they are a contiguous prefix of cached blocks. + + This method determines which blocks can be safely skipped for all + sequences in the sequence group. + """ + computed_seq_block_ids = [] + for seq in seqs: + all_blocks = self.block_tables[seq.seq_id].physical_block_ids + num_cached_tokens = ( + self._computed_blocks_tracker.get_num_cached_tokens(seq)) + assert num_cached_tokens % self.block_size == 0 + num_cached_blocks = num_cached_tokens // self.block_size + computed_block_ids = all_blocks[:num_cached_blocks] + computed_seq_block_ids.append(computed_block_ids) + + # NOTE(sang): This assumes seq_block_ids doesn't contain any None. + return self.block_allocator.get_common_computed_block_ids( + computed_seq_block_ids) # type: ignore + + def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: + if parent_seq.seq_id not in self.block_tables: + # Parent sequence has either been freed or never existed. 
+ return + src_block_table = self.block_tables[parent_seq.seq_id] + self.block_tables[child_seq.seq_id] = src_block_table.fork() + + # Track child seq + self._last_access_blocks_tracker.add_seq(child_seq.seq_id) + + def can_swap_in(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> AllocStatus: + """Returns the AllocStatus for the given sequence_group + with num_lookahead_slots. + + Args: + sequence_group (SequenceGroup): The sequence group to swap in. + num_lookahead_slots (int): Number of lookahead slots used in + speculative decoding, default to 0. + + Returns: + AllocStatus: The AllocStatus for the given sequence group. + """ + return self._can_swap(seq_group, Device.GPU, SequenceStatus.SWAPPED, + num_lookahead_slots) + + def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + """Returns the block id mapping (from CPU to GPU) generated by + swapping in the given seq_group with num_lookahead_slots. + + Args: + seq_group (SequenceGroup): The sequence group to swap in. + + Returns: + List[Tuple[int, int]]: The mapping of swapping block from CPU + to GPU. + """ + physical_block_id_mapping = [] + for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): + blocks = self.block_tables[seq.seq_id].blocks + if len(blocks) == 0: + continue + + seq_swap_mapping = self.block_allocator.swap(blocks=blocks, + src_device=Device.CPU, + dst_device=Device.GPU) + + # Refresh the block ids of the table (post-swap) + self.block_tables[seq.seq_id].update(blocks) + + seq_physical_block_id_mapping = { + self.block_allocator.get_physical_block_id( + Device.CPU, cpu_block_id): + self.block_allocator.get_physical_block_id( + Device.GPU, gpu_block_id) + for cpu_block_id, gpu_block_id in seq_swap_mapping.items() + } + + physical_block_id_mapping.extend( + list(seq_physical_block_id_mapping.items())) + + return physical_block_id_mapping + + def can_swap_out(self, seq_group: SequenceGroup) -> bool: + """Returns whether we can swap out the given sequence_group + with num_lookahead_slots. + + Args: + seq_group (SequenceGroup): The sequence group to swap out. + num_lookahead_slots (int): Number of lookahead slots used in + speculative decoding, default to 0. + + Returns: + bool: Whether it's possible to swap out current sequence group. + """ + alloc_status = self._can_swap(seq_group, Device.CPU, + SequenceStatus.RUNNING) + return alloc_status == AllocStatus.OK + + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + """Returns the block id mapping (from GPU to CPU) generated by + swapping out the given sequence_group with num_lookahead_slots. + + Args: + sequence_group (SequenceGroup): The sequence group to swap out. + + Returns: + List[Tuple[int, int]]: The mapping of swapping block from + GPU to CPU. 
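The swap-in path above (and the swap-out path that follows) turns each per-sequence swap map into a flat list of (source, destination) physical block ids. The minimal sketch below assumes a toy id scheme in which CPU blocks simply follow the GPU blocks in one physical id space; the helper names and sizes are illustrative only.

    from typing import Dict, List, Tuple

    NUM_GPU_BLOCKS = 1024  # assumed toy size, not a vLLM default

    def cpu_to_physical(cpu_block_id: int) -> int:
        # Toy convention: CPU blocks live after all GPU blocks.
        return NUM_GPU_BLOCKS + cpu_block_id

    def gpu_to_physical(gpu_block_id: int) -> int:
        return gpu_block_id

    def swap_in_mapping(seq_swap_mapping: Dict[int, int]) -> List[Tuple[int, int]]:
        """Translate a per-sequence CPU->GPU swap map into physical id pairs."""
        return [(cpu_to_physical(cpu_id), gpu_to_physical(gpu_id))
                for cpu_id, gpu_id in seq_swap_mapping.items()]

    # CPU blocks 3 and 7 are brought back into GPU blocks 0 and 1.
    assert swap_in_mapping({3: 0, 7: 1}) == [(1027, 0), (1031, 1)]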
+ """ + physical_block_id_mapping = [] + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + blocks = self.block_tables[seq.seq_id].blocks + if len(blocks) == 0: + continue + + seq_swap_mapping = self.block_allocator.swap(blocks=blocks, + src_device=Device.GPU, + dst_device=Device.CPU) + + # Refresh the block ids of the table (post-swap) + self.block_tables[seq.seq_id].update(blocks) + + seq_physical_block_id_mapping = { + self.block_allocator.get_physical_block_id( + Device.GPU, gpu_block_id): + self.block_allocator.get_physical_block_id( + Device.CPU, cpu_block_id) + for gpu_block_id, cpu_block_id in seq_swap_mapping.items() + } + + physical_block_id_mapping.extend( + list(seq_physical_block_id_mapping.items())) + + return physical_block_id_mapping + + def get_num_free_gpu_blocks(self) -> int: + return self.block_allocator.get_num_free_blocks(Device.GPU) + + def get_num_free_cpu_blocks(self) -> int: + return self.block_allocator.get_num_free_blocks(Device.CPU) + + def get_prefix_cache_hit_rate(self, device: Device) -> float: + return self.block_allocator.get_prefix_cache_hit_rate(device) + + def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: + return self.block_allocator.reset_prefix_cache(device) + + def _can_swap(self, + seq_group: SequenceGroup, + device: Device, + status: SequenceStatus, + num_lookahead_slots: int = 0) -> AllocStatus: + """Returns the AllocStatus for swapping in/out the given sequence_group + on to the 'device'. + + Args: + sequence_group (SequenceGroup): The sequence group to swap in/out. + device (Device): device to swap the 'seq_group' on. + status (SequenceStatus): The status of sequence which is needed + for action. RUNNING for swap out and SWAPPED for swap in + num_lookahead_slots (int): Number of lookahead slots used in + speculative decoding, default to 0. + + Returns: + AllocStatus: The AllocStatus for swapping in/out the given + sequence_group on to the 'device'. + """ + # First determine the number of blocks that will be touched by this + # swap. Then verify if there are available blocks in the device + # to perform the swap. + num_blocks_touched = 0 + blocks: List[Block] = [] + for seq in seq_group.get_seqs(status=status): + block_table = self.block_tables[seq.seq_id] + if block_table.blocks is not None: + # Compute the number blocks to touch for the tokens to be + # appended. This does NOT include the full blocks that need + # to be touched for the swap. + num_blocks_touched += \ + block_table.get_num_blocks_touched_by_append_slots( + block_table.get_unseen_token_ids(seq.get_token_ids()), + num_lookahead_slots=num_lookahead_slots) + blocks.extend(block_table.blocks) + # Compute the number of full blocks to touch and add it to the + # existing count of blocks to touch. + num_blocks_touched += self.block_allocator.get_num_full_blocks_touched( + blocks, device=device) + + watermark_blocks = 0 + if device == Device.GPU: + watermark_blocks = self.watermark_blocks + + if self.block_allocator.get_num_total_blocks( + device) < num_blocks_touched: + return AllocStatus.NEVER + elif self.block_allocator.get_num_free_blocks( + device) - num_blocks_touched >= watermark_blocks: + return AllocStatus.OK + else: + return AllocStatus.LATER + + def get_num_cached_tokens(self, seq: Sequence) -> int: + """Get the number of tokens in blocks that are already computed and + cached in the block manager for the sequence. 
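The cached-token count is derived from how many leading full-block hashes hit the cache; only a contiguous prefix of cached blocks counts. A small sketch, with a plain set standing in for the allocator's cached-hash lookup (the function name is a made-up placeholder):

    from typing import List, Set

    def num_cached_tokens(block_hashes: List[int], cached_hashes: Set[int],
                          block_size: int) -> int:
        """Count tokens covered by the longest cached prefix of blocks."""
        num_cached_blocks = 0
        for h in block_hashes:
            if h not in cached_hashes:
                break  # only a contiguous prefix of cached blocks counts
            num_cached_blocks += 1
        return num_cached_blocks * block_size

    # Blocks 0 and 1 are cached, block 2 is not; block 3 cannot count even
    # though its hash happens to be cached.
    assert num_cached_tokens([11, 22, 33, 44], {11, 22, 44}, 16) == 32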
+ """ + return self._computed_blocks_tracker.get_num_cached_tokens(seq) diff --git a/core/evictor.py b/core/evictor.py new file mode 100644 index 0000000..7ec4768 --- /dev/null +++ b/core/evictor.py @@ -0,0 +1,157 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import enum +import heapq +from abc import ABC, abstractmethod +from typing import Dict, List, Tuple + + +class EvictionPolicy(enum.Enum): + """Enum for eviction policy used by make_evictor to instantiate the correct + Evictor subclass. + """ + LRU = enum.auto() + + +class Evictor(ABC): + """The Evictor subclasses should be used by the BlockAllocator class to + handle eviction of freed Blocks. + """ + + @abstractmethod + def __init__(self): + pass + + @abstractmethod + def __contains__(self, block_id: int) -> bool: + pass + + @abstractmethod + def evict(self) -> Tuple[int, int]: + """Runs the eviction algorithm and returns the evicted block's + content hash along with physical block id along with physical block id + """ + pass + + @abstractmethod + def add(self, block_id: int, content_hash: int, num_hashed_tokens: int, + last_accessed: float): + """Adds block to the evictor, making it a candidate for eviction""" + pass + + @abstractmethod + def update(self, block_id: int, last_accessed: float): + """Update corresponding block's access time in metadata""" + pass + + @abstractmethod + def remove(self, block_id: int): + """Remove a given block id from the cache.""" + pass + + @property + @abstractmethod + def num_blocks(self) -> int: + pass + + +class BlockMetaData: + """Data structure for storing key data describe cached block, so that + evitor could use to make its decision which one to choose for eviction + + Here we use physical block id as the dict key, as there maybe several + blocks with the same content hash, but their physical id is unique. + """ + + def __init__(self, content_hash: int, num_hashed_tokens: int, + last_accessed: float): + self.content_hash = content_hash + self.num_hashed_tokens = num_hashed_tokens + self.last_accessed = last_accessed + + +class LRUEvictor(Evictor): + """Evicts in a least-recently-used order using the last_accessed timestamp + that's recorded in the Block. If there are multiple blocks with + the same last_accessed time, then the one with the largest num_hashed_tokens + will be evicted. If two blocks each have the lowest last_accessed time and + highest num_hashed_tokens value, then one will be chose arbitrarily + """ + + # CLEANUP_THRESHOLD determines the maximum allowable size of the priority + # queue relative to the free table size. When this threshold is exceeded, + # a cleanup operation is triggered to reduce memory usage. + CLEANUP_THRESHOLD = 50 + + def __init__(self): + self.free_table: Dict[int, BlockMetaData] = {} + self.priority_queue = [] + + def __contains__(self, block_id: int) -> bool: + return block_id in self.free_table + + def evict(self) -> Tuple[int, int]: + if len(self.free_table) == 0: + raise ValueError("No usable cache memory left") + + while self.priority_queue: + # We do not remove outdated entries from the priority queue at the + # time of updating the last_accessed timestamp. Instead, outdated + # entries are filtered out here during eviction. Outdated entries + # would either not in the free table, or have older last accessed + # time. 
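As a standalone illustration of this lazy-deletion step (stale heap entries are skipped at eviction time rather than removed eagerly), the sketch below uses made-up names (`touch`, `evict_lru`) and a plain timestamp table instead of the BlockMetaData records used here.

    import heapq
    from typing import Dict, List, Tuple

    # free_table maps block_id -> latest last_accessed timestamp.
    free_table: Dict[int, float] = {}
    # Heap entries are (last_accessed, block_id); stale entries stay in place.
    heap: List[Tuple[float, int]] = []

    def touch(block_id: int, ts: float) -> None:
        """Record (or refresh) a freed block without purging its old entry."""
        free_table[block_id] = ts
        heapq.heappush(heap, (ts, block_id))

    def evict_lru() -> int:
        """Pop entries until one matches the table's current timestamp."""
        while heap:
            ts, block_id = heapq.heappop(heap)
            if free_table.get(block_id) == ts:  # skip stale/removed entries
                del free_table[block_id]
                return block_id
        raise ValueError("No usable cache memory left")

    touch(0, 1.0)
    touch(1, 2.0)
    touch(0, 3.0)            # refresh block 0; entry (1.0, 0) is now stale
    assert evict_lru() == 1  # block 1 is the true LRU block
    assert evict_lru() == 0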
+ last_accessed, _, block_id, content_hash = heapq.heappop( + self.priority_queue) + if (block_id in self.free_table and + self.free_table[block_id].last_accessed == last_accessed): + self.free_table.pop(block_id) + return block_id, content_hash + + raise ValueError("No usable cache memory left") + + def add(self, block_id: int, content_hash: int, num_hashed_tokens: int, + last_accessed: float): + self.free_table[block_id] = BlockMetaData(content_hash, + num_hashed_tokens, + last_accessed) + heapq.heappush( + self.priority_queue, + (last_accessed, -num_hashed_tokens, block_id, content_hash)) + self._cleanup_if_necessary() + + def update(self, block_id: int, last_accessed: float): + self.free_table[block_id].last_accessed = last_accessed + + def _cleanup_if_necessary(self): + if len(self.priority_queue) > LRUEvictor.CLEANUP_THRESHOLD * len( + self.free_table): + self._cleanup() + + def _cleanup(self): + new_priority_queue: List[Tuple[float, int, int, int]] = [] + + for block_id, block in self.free_table.items(): + new_priority_queue.append( + (block.last_accessed, -block.num_hashed_tokens, block_id, + block.content_hash)) + heapq.heapify(new_priority_queue) + + self.priority_queue = new_priority_queue + + def remove(self, block_id: int): + if block_id not in self.free_table: + raise ValueError( + "Attempting to remove block that's not in the evictor") + self.free_table.pop(block_id) + + @property + def num_blocks(self) -> int: + return len(self.free_table) + + +def make_evictor(eviction_policy: EvictionPolicy) -> Evictor: + if eviction_policy == EvictionPolicy.LRU: + return LRUEvictor() + else: + raise ValueError(f"Unknown cache eviction policy: {eviction_policy}") diff --git a/core/interfaces.py b/core/interfaces.py new file mode 100644 index 0000000..ba290ee --- /dev/null +++ b/core/interfaces.py @@ -0,0 +1,135 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import enum +from abc import ABC, abstractmethod +from typing import List, Optional +from typing import Sequence as GenericSequence +from typing import Tuple + +from vllm.sequence import Sequence, SequenceGroup +from vllm.utils import Device + + +class AllocStatus(enum.Enum): + """Result for BlockSpaceManager.can_allocate + + 1. Ok: seq_group can be allocated now. + 2. Later: seq_group cannot be allocated. + The capacity of allocator is larger than seq_group required. + 3. Never: seq_group can never be allocated. + The seq_group is too large to allocated in GPU. 
+ """ + OK = enum.auto() + LATER = enum.auto() + NEVER = enum.auto() + + +class BlockSpaceManager(ABC): + + @staticmethod + def get_block_space_manager_class(version: str): + version = version.lower() + + if version == "selfattn": + from vllm.core.block_manager import SelfAttnBlockSpaceManager + return SelfAttnBlockSpaceManager + + if version == "placeholder": + from vllm.core.placeholder_block_space_manager import ( + PlaceholderBlockSpaceManager) + return PlaceholderBlockSpaceManager + + raise ValueError(f"Unknown version {version=}") + + @abstractmethod + def can_allocate(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> AllocStatus: + pass + + @abstractmethod + def allocate(self, seq_group: SequenceGroup) -> None: + pass + + @abstractmethod + def can_append_slots(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> bool: + pass + + @abstractmethod + def append_slots( + self, + seq: Sequence, + num_lookahead_slots: int, + ) -> List[Tuple[int, int]]: + pass + + @abstractmethod + def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: + pass + + @abstractmethod + def can_swap_in(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> AllocStatus: + pass + + @abstractmethod + def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + pass + + @abstractmethod + def can_swap_out(self, seq_group: SequenceGroup) -> bool: + pass + + @abstractmethod + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + pass + + @abstractmethod + def free(self, seq: Sequence) -> None: + pass + + @abstractmethod + def get_block_table(self, seq: Sequence) -> List[int]: + pass + + @abstractmethod + def get_num_free_gpu_blocks(self) -> int: + pass + + @abstractmethod + def get_num_free_cpu_blocks(self) -> int: + pass + + @abstractmethod + def access_all_blocks_in_seq( + self, + seq: Sequence, + access_time: float, + ) -> None: + pass + + @abstractmethod + def get_common_computed_block_ids( + self, seqs: List[Sequence]) -> GenericSequence[int]: + pass + + @abstractmethod + def mark_blocks_as_computed(self, seq_group: SequenceGroup, + token_chunk_size: int): + pass + + @abstractmethod + def get_prefix_cache_hit_rate(self, device: Device) -> float: + """Prefix cache hit rate. -1 means not supported or disabled.""" + pass + + @abstractmethod + def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: + """Reset prefix cache for specified or all devices.""" + pass + + @abstractmethod + def get_num_cached_tokens(self, seq: Sequence) -> int: + pass diff --git a/core/placeholder_block_space_manager.py b/core/placeholder_block_space_manager.py new file mode 100644 index 0000000..71b2294 --- /dev/null +++ b/core/placeholder_block_space_manager.py @@ -0,0 +1,100 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import List, Optional, Tuple + +from vllm.core.interfaces import AllocStatus, BlockSpaceManager +from vllm.sequence import Sequence, SequenceGroup +from vllm.utils import Device + + +class PlaceholderBlockSpaceManager(BlockSpaceManager): + """A version of BlockSpaceManager for use in environments + where block management is not required. + For example: pooling models or attention-free models like Mamba. + + This class provides the same interface as BlockSpaceManager, but its + methods perform no actions or return simple values like True in specific + actions. 
It's designed to be used in scenarios where the overhead of + block management is unnecessary, such as in an embedding environment. + """ + + def __init__( + self, + **kwargs, + ) -> None: + pass + + def can_allocate(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> AllocStatus: + # Always return OK for dummy purposes + return AllocStatus.OK + + def allocate(self, seq_group: SequenceGroup) -> None: + # No actual allocation logic needed + pass + + def can_append_slots(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> bool: + return True + + def append_slots( + self, + seq: Sequence, + num_lookahead_slots: int, + ) -> List[Tuple[int, int]]: + return [] + + def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: + pass + + def can_swap_in(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> AllocStatus: + return AllocStatus.OK + + def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + return None # type: ignore + + def can_swap_out(self, seq_group: SequenceGroup) -> bool: + return True + + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + return None # type: ignore + + def free(self, seq: Sequence) -> None: + # No operation on free + return + + def get_block_table(self, seq: Sequence) -> List[int]: + return None # type: ignore + + def get_num_free_gpu_blocks(self) -> int: + return 1 + + def get_num_free_cpu_blocks(self) -> int: + return 1 + + def access_all_blocks_in_seq( + self, + seq: Sequence, + access_time: float, + ) -> None: + pass + + def get_common_computed_block_ids(self, + seq_group: List[Sequence]) -> List[int]: + return [] + + def mark_blocks_as_computed(self, seq_group: SequenceGroup, + token_chunk_size: int): + pass + + def get_prefix_cache_hit_rate(self, device: Device) -> float: + return -1 + + def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: + return True + + def get_num_cached_tokens(self, seq: Sequence) -> int: + return 0 diff --git a/core/scheduler.py b/core/scheduler.py new file mode 100644 index 0000000..44be855 --- /dev/null +++ b/core/scheduler.py @@ -0,0 +1,2093 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import enum +import os +import random +import time +from collections import deque +from dataclasses import dataclass, field +from typing import Callable, Deque, Dict, Iterable, List, Optional +from typing import Sequence as GenericSequence +from typing import Set, Tuple, Union + +from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig +from vllm.core.interfaces import AllocStatus, BlockSpaceManager +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sequence import (Sequence, SequenceData, SequenceGroup, + SequenceGroupBase, SequenceGroupMetadata, + SequenceGroupMetadataDelta, SequenceStage, + SequenceStatus) +from vllm.utils import Device, PyObjectCache + +logger = init_logger(__name__) + +# Test-only. If configured, decode is preempted with +# ARTIFICIAL_PREEMPTION_PROB% probability. +ENABLE_ARTIFICIAL_PREEMPT = bool( + os.getenv("VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT", False)) # noqa +ARTIFICIAL_PREEMPTION_PROB = 0.5 +ARTIFICIAL_PREEMPTION_MAX_CNT = 500 + + +class PreemptionMode(enum.Enum): + """Preemption modes. + + 1. Swapping: Swap out the blocks of the preempted sequences to CPU memory + and swap them back in when the sequences are resumed. + 2. 
Recomputation: Discard the blocks of the preempted sequences and + recompute them when the sequences are resumed, treating the sequences as + new prompts. + """ + + SWAP = enum.auto() + RECOMPUTE = enum.auto() + + +@dataclass +class SchedulingBudget: + """The available slots for scheduling. + + TODO(sang): Right now, the budget is request_id-aware meaning it can ignore + budget update from the same request_id. It is because in normal scheduling + path, we update RUNNING num_seqs ahead of time, meaning it could be + updated more than once when scheduling RUNNING requests. Since this won't + happen if we only have chunked prefill scheduling, we can remove this + feature from the API when chunked prefill is enabled by default. + """ + + token_budget: int + max_num_seqs: int + _request_ids_num_batched_tokens: Set[str] = field(default_factory=set) + _request_ids_num_curr_seqs: Set[str] = field(default_factory=set) + # Number of cached tokens in the batch. + _num_cached_tokens: int = 0 + # Number of actual non-cached tokens in the batch. + _num_batched_tokens: int = 0 + _num_curr_seqs: int = 0 + + def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int): + # We allow num_new_tokens to be 0 when the entire sequence has + # been cached. + assert num_new_tokens >= 0 + assert num_new_seqs != 0 + return (self.num_batched_tokens + num_new_tokens <= self.token_budget + and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs) + + def remaining_token_budget(self): + return self.token_budget - self.num_batched_tokens + + def add_num_batched_tokens(self, + req_id: str, + num_batched_tokens: int, + num_cached_tokens: int = 0): + if req_id in self._request_ids_num_batched_tokens: + return + assert num_cached_tokens >= 0 + assert num_batched_tokens >= 0 + + self._request_ids_num_batched_tokens.add(req_id) + self._num_batched_tokens += num_batched_tokens + self._num_cached_tokens += num_cached_tokens + + def subtract_num_batched_tokens(self, req_id: str, + num_batched_tokens: int): + if req_id in self._request_ids_num_batched_tokens: + self._request_ids_num_batched_tokens.remove(req_id) + self._num_batched_tokens -= num_batched_tokens + + def add_num_seqs(self, req_id: str, num_curr_seqs: int): + if req_id in self._request_ids_num_curr_seqs: + return + + self._request_ids_num_curr_seqs.add(req_id) + self._num_curr_seqs += num_curr_seqs + + def subtract_num_seqs(self, req_id: str, num_curr_seqs: int): + if req_id in self._request_ids_num_curr_seqs: + self._request_ids_num_curr_seqs.remove(req_id) + self._num_curr_seqs -= num_curr_seqs + + @property + def num_batched_tokens(self): + return self._num_batched_tokens + + @property + def num_curr_seqs(self): + return self._num_curr_seqs + + @property + def num_cached_tokens(self): + return self._num_cached_tokens + + +@dataclass +class ScheduledSequenceGroup: + # A sequence group that's scheduled. + seq_group: SequenceGroup + # The total chunk size (number of tokens) to process for next iteration. + # 1 for decoding. Same as prompt tokens for prefill, but if prefill is + # chunked, it can be smaller than that. + token_chunk_size: int + + +@dataclass +class SchedulerOutputs: + """The scheduling decision made from a scheduler.""" + + # Scheduled sequence groups. + scheduled_seq_groups: GenericSequence[ScheduledSequenceGroup] + # Number of prefill groups scheduled. + num_prefill_groups: int + # Total number of batched tokens. + num_batched_tokens: int + # Blocks to swap in. List of CPU -> GPU block number. 
+ blocks_to_swap_in: List[Tuple[int, int]] + # Blocks to swap out. List of GPU -> CPU block number. + blocks_to_swap_out: List[Tuple[int, int]] + # Blocks to copy. Source to dest block. + blocks_to_copy: List[Tuple[int, int]] + # Sequence groups that are going to be ignored. + ignored_seq_groups: List[SequenceGroup] + # The number of slots for lookahead decoding. + num_lookahead_slots: int + # The number of requests in the running queue + running_queue_size: int + preempted: int + + def __post_init__(self): + # Swap in and swap out should never happen at the same time. + assert not (self.blocks_to_swap_in and self.blocks_to_swap_out) + + self.num_loras: int = len(self.lora_requests) + if self.num_loras > 0: + self._sort_by_lora_ids() + + self.num_prompt_adapters: int = len(self.prompt_adapter_requests) + + def is_empty(self) -> bool: + # NOTE: We do not consider the ignored sequence groups. + return (not self.scheduled_seq_groups and not self.blocks_to_swap_in + and not self.blocks_to_swap_out and not self.blocks_to_copy) + + def _sort_by_lora_ids(self): + assert 0 <= self.num_prefill_groups <= len(self.scheduled_seq_groups) + + def key_fn(group: ScheduledSequenceGroup): + key = (group.seq_group.lora_int_id, group.seq_group.request_id) + if 0 < self.num_prefill_groups < len(self.scheduled_seq_groups): + # Sort sequence groups so that all prefills come before all + # decodes as required by chunked prefill. + return (not group.seq_group.is_prefill(), *key) + return key + + self.scheduled_seq_groups = sorted(self.scheduled_seq_groups, + key=key_fn) + + @property + def lora_requests(self) -> Set[LoRARequest]: + return { + g.seq_group.lora_request + for g in self.scheduled_seq_groups + if g.seq_group.lora_request is not None + } + + @property + def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]: + return { + g.seq_group.prompt_adapter_request + for g in self.scheduled_seq_groups + if g.seq_group.prompt_adapter_request is not None + } + + +@dataclass +class SchedulerRunningOutputs: + """The requests that are scheduled from a running queue. + + Could contain prefill (prefill that's chunked) or decodes. If there's not + enough memory, it can be preempted (for recompute) or swapped out. + """ + + # Selected sequences that are running and in a decoding phase. + decode_seq_groups: List[ScheduledSequenceGroup] + # Selected sequences that are running and in a prefill phase. + # I.e., it means the prefill has been chunked. + prefill_seq_groups: List[ScheduledSequenceGroup] + # The preempted sequences. + preempted: List[SequenceGroup] + # Sequences that are swapped out. + swapped_out: List[SequenceGroup] + # The blocks to swap out. + blocks_to_swap_out: List[Tuple[int, int]] + # The blocks to copy. + blocks_to_copy: List[Tuple[int, int]] + # The number of slots for lookahead decoding. + num_lookahead_slots: int + + # Optimization for fast-access to seq_group lists + decode_seq_groups_list: List[SequenceGroup] + prefill_seq_groups_list: List[SequenceGroup] + + @classmethod + def create_empty(cls) -> "SchedulerRunningOutputs": + return SchedulerRunningOutputs( + decode_seq_groups=[], + prefill_seq_groups=[], + preempted=[], + swapped_out=[], + blocks_to_swap_out=[], + blocks_to_copy=[], + num_lookahead_slots=0, + decode_seq_groups_list=[], + prefill_seq_groups_list=[], + ) + + +@dataclass +class SchedulerSwappedInOutputs: + """The requests that are scheduled from a swap queue. + + Could contain prefill (prefill that's chunked) or decodes. 
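The SchedulingBudget defined earlier gates scheduling on both token and sequence counts, and its updates are idempotent per request id. The cut-down stand-in below (`ToyBudget` is a made-up name, not the actual dataclass) shows the same arithmetic.

    from dataclasses import dataclass, field
    from typing import Set

    @dataclass
    class ToyBudget:
        token_budget: int
        max_num_seqs: int
        num_batched_tokens: int = 0
        num_curr_seqs: int = 0
        _seen: Set[str] = field(default_factory=set)

        def can_schedule(self, num_new_tokens: int, num_new_seqs: int) -> bool:
            return (self.num_batched_tokens + num_new_tokens <= self.token_budget
                    and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs)

        def add(self, req_id: str, num_new_tokens: int, num_new_seqs: int) -> None:
            # Idempotent per request id, mirroring the request_id-aware updates.
            if req_id in self._seen:
                return
            self._seen.add(req_id)
            self.num_batched_tokens += num_new_tokens
            self.num_curr_seqs += num_new_seqs

    budget = ToyBudget(token_budget=2048, max_num_seqs=4)
    assert budget.can_schedule(512, 1)
    budget.add("req-0", 512, 1)
    budget.add("req-0", 512, 1)  # duplicate update is ignored
    assert budget.num_batched_tokens == 512 and budget.num_curr_seqs == 1
    assert not budget.can_schedule(2048, 1)  # would exceed the token budget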
+ """ + + # Selected sequences that are going to be swapped in and is in a + # decoding phase. + decode_seq_groups: List[ScheduledSequenceGroup] + # Selected sequences that are going to be swapped in and in a prefill + # phase. I.e., it means the prefill has been chunked. + prefill_seq_groups: List[ScheduledSequenceGroup] + # The blocks to swap in. + blocks_to_swap_in: List[Tuple[int, int]] + # The blocks to copy. + blocks_to_copy: List[Tuple[int, int]] + # The number of slots for lookahead decoding. + num_lookahead_slots: int + # Infeasible sequence groups. + infeasible_seq_groups: List[SequenceGroup] + + @classmethod + def create_empty(cls) -> "SchedulerSwappedInOutputs": + return SchedulerSwappedInOutputs( + decode_seq_groups=[], + prefill_seq_groups=[], + blocks_to_swap_in=[], + blocks_to_copy=[], + num_lookahead_slots=0, + infeasible_seq_groups=[], + ) + + +@dataclass +class SchedulerPrefillOutputs: + """The requests that are scheduled from a waiting queue. + + Could contain a fresh prefill requests or preempted requests that need + to be recomputed from scratch. + """ + + # Selected sequences for prefill. + seq_groups: List[ScheduledSequenceGroup] + # Ignored sequence groups. + ignored_seq_groups: List[SequenceGroup] + num_lookahead_slots: int + + @classmethod + def create_empty(cls) -> "SchedulerPrefillOutputs": + return SchedulerPrefillOutputs( + seq_groups=[], + ignored_seq_groups=[], + num_lookahead_slots=0, + ) + + +def seq_group_metadata_builder(): + return SequenceGroupMetadata(request_id="", + is_prompt=False, + seq_data={}, + sampling_params=None, + block_tables={}) + + +def scheduler_running_outputs_builder(): + return SchedulerRunningOutputs(decode_seq_groups=[], + prefill_seq_groups=[], + preempted=[], + swapped_out=[], + blocks_to_swap_out=[], + blocks_to_copy=[], + num_lookahead_slots=0, + prefill_seq_groups_list=[], + decode_seq_groups_list=[]) + + +def scheduled_seq_group_builder(): + return ScheduledSequenceGroup(SequenceGroup.__new__(SequenceGroup), + token_chunk_size=0) + # return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0) + + +@dataclass +class PartialPrefillMetadata: + """Holds information about the partial prefills that are currently running + during a single iteration of the Scheduler. + When chunked prefill is enabled, we allow a certain number of seqs to be + partially prefilled during each iteration. Having multiple partial prefills + in flight allows us to minimize TTFT and avoid decode starvation in cases + where a single sequence group with a very large prompt blocks the queue for + too many iterations. + The number of long prefill requests is limited so that smaller + requests may jump the queue in front of them and get to the decode + phase faster. 
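A toy version of the long-prefill limiting described here: a long prompt is deferred only when the long-prefill slots are already occupied and concurrent partial prefills are actually enabled. The function and its default thresholds are illustrative assumptions, not the real SchedulerConfig values.

    def can_schedule_prefill(num_new_tokens: int, running_long_prefills: int,
                             long_threshold: int = 4096,
                             max_long_partial_prefills: int = 1,
                             max_num_partial_prefills: int = 4) -> bool:
        """Reject a long prompt only when the long-prefill slots are full and
        more than one partial prefill is allowed to run concurrently."""
        is_long = num_new_tokens > long_threshold
        return not (is_long
                    and running_long_prefills >= max_long_partial_prefills
                    and max_num_partial_prefills > 1)

    # A short request always fits; a second long one must wait for the first.
    assert can_schedule_prefill(512, running_long_prefills=1)
    assert not can_schedule_prefill(8192, running_long_prefills=1)
    assert can_schedule_prefill(8192, running_long_prefills=0)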
+ """ + + # A minimum bound on the total number of prefills to be scheduled during + # this iteration + schedulable_prefills: int + + # The number of long prefill requests currently running + long_prefills: int + + scheduler_config: SchedulerConfig + + def can_schedule(self, seq_group: SequenceGroup) -> bool: + """When concurrent partial prefills are enabled, + we limit the number of long requests and only accept + shorter requests from the queue while running them + concurrently""" + return not (seq_group.first_seq.get_num_new_tokens() + > self.scheduler_config.long_prefill_token_threshold + and self.long_prefills + >= self.scheduler_config.max_long_partial_prefills + and self.scheduler_config.max_num_partial_prefills > 1) + + def maybe_increment_partial_prefills(self, + seq_group: SequenceGroup) -> None: + # When a new prefill is scheduled, we need to know if it is a + # long request + if (seq_group.first_seq.get_num_new_tokens() + > self.scheduler_config.long_prefill_token_threshold): + self.long_prefills += 1 + + @classmethod + def from_queues( + cls, + running: Deque[SequenceGroup], + waiting: Deque[SequenceGroup], + scheduler_config: SchedulerConfig, + ) -> "PartialPrefillMetadata": + """Create a PartialPrefillMetadata object from the current state of + the scheduler's queues. + This accounts for the currently running prefill requests, and peeks into + the waiting queue to see if there are more prefills to potentially be + scheduled during this iteration.""" + prefills = 0 + long_prefills = 0 + + waiting_long_prefills = 0 + + for sg in running: + if sg.first_seq.data.stage == SequenceStage.PREFILL: + prefills += 1 + if (sg.first_seq.get_num_new_tokens() + > scheduler_config.long_prefill_token_threshold): + long_prefills += 1 + + for sg in waiting: + # Don't bother looping through the rest of the queue if we know + # there are already at + # least max_partial_prefills requests to fill + if prefills >= scheduler_config.max_num_partial_prefills: + break + + # Don't count long requests from the waiting queue if we aren't + # going to schedule them anyway + if (sg.first_seq.get_num_new_tokens() + > scheduler_config.long_prefill_token_threshold): + if (long_prefills + waiting_long_prefills + >= scheduler_config.max_long_partial_prefills): + continue + waiting_long_prefills += 1 + prefills += 1 + + # NB: long_prefills and waiting_long_prefills are tracked separately. + # We don't account for the waiting requests here because we need to use + # this metadata to track how many have actually been scheduled. + return PartialPrefillMetadata( + schedulable_prefills=min( + prefills, scheduler_config.max_num_partial_prefills), + long_prefills=long_prefills, + scheduler_config=scheduler_config, + ) + + +class Scheduler: + + def __init__( + self, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + lora_config: Optional[LoRAConfig], + pipeline_parallel_size: int = 1, + output_proc_callback: Optional[Callable] = None, + ) -> None: + self.scheduler_config = scheduler_config + self.cache_config = cache_config + # Note for LoRA scheduling: the current policy is extremely + # simple and NOT fair. It can lead to starvation of some + # LoRAs. This should be improved in the future. 
+ self.lora_config = lora_config + + version = "selfattn" + if (self.scheduler_config.runner_type == "pooling" + or self.cache_config.is_attention_free): + version = "placeholder" + + BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( + version) + + num_gpu_blocks = cache_config.num_gpu_blocks + if num_gpu_blocks: + num_gpu_blocks //= pipeline_parallel_size + + num_cpu_blocks = cache_config.num_cpu_blocks + if num_cpu_blocks: + num_cpu_blocks //= pipeline_parallel_size + + # Create the block space manager. + self.block_manager = BlockSpaceManagerImpl( + block_size=self.cache_config.block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks, + sliding_window=self.cache_config.sliding_window, + enable_caching=self.cache_config.enable_prefix_caching, + ) + + # Sequence groups in the WAITING state. + # Contain new prefill or preempted requests. + self.waiting: Deque[SequenceGroup] = deque() + # Sequence groups in the RUNNING state. + # Contain decode requests. + self.running: Deque[SequenceGroup] = deque() + # Sequence groups in the SWAPPED state. + # Contain decode requests that are swapped out. + self.swapped: Deque[SequenceGroup] = deque() + # Sequence groups finished requests ids since last step iteration. + # It lets the model know that any state associated with these requests + # can and must be released after the current step. + # This is used to evict the finished requests from the Mamba cache. + self._finished_requests_ids: List[str] = list() + # Time at previous scheduling step + self.prev_time = 0.0 + # Did we schedule a prompt at previous step? + self.prev_prompt = False + # Latency of the last prompt step + self.last_prompt_latency = 0.0 + # preemption mode, RECOMPUTE or SWAP + self.user_specified_preemption_mode = scheduler_config.preemption_mode + + # The following field is test-only. It is used to inject artificial + # preemption. + self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT + self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT + if self.enable_artificial_preemption + else 0) + self.num_cumulative_preemption: int = 0 + + # Used to cache python objects + self._seq_group_metadata_cache: List[PyObjectCache] = [] + self._scheduler_running_outputs_cache: List[PyObjectCache] = [] + self._scheduled_seq_group_cache: List[PyObjectCache] = [] + + # For async output processing, we need to swap cache buffers between + # iterations. I.e. since the output processing is lagged one step, + # we cannot reuse the cached objects immediately when the schedule() + # is called again, but only when schedule() is called the second time. + self.output_proc_callback = output_proc_callback + self.use_async_output_proc = self.output_proc_callback is not None + self.num_cache_iters = 2 if self.use_async_output_proc else 1 + + self.cache_id = 0 + for i in range(self.num_cache_iters): + self._seq_group_metadata_cache.append( + PyObjectCache(seq_group_metadata_builder)) + self._scheduler_running_outputs_cache.append( + PyObjectCache(scheduler_running_outputs_builder)) + self._scheduled_seq_group_cache.append( + PyObjectCache(scheduled_seq_group_builder)) + + # For async postprocessor, the extra decode run cannot be done + # when the request reaches max_model_len. 
In this case, the request + # will be stopped during schedule() call and added to this stop list + # for processing and deallocation by the free_finished_seq_groups() + self._async_stopped: List[SequenceGroup] = [] + + # List with the chunk sizes to hand out to each sequence depending + # on how many partial prefills are running. This is slightly faster than + # running an integer division every time a prefill is scheduled. + # This splits the budget evenly among all prefills. + self.partial_prefill_budget_lookup_list = [0] * ( + self.scheduler_config.max_num_partial_prefills + 1) + self.partial_prefill_budget_lookup_list[0] = ( + scheduler_config.max_num_batched_tokens) + for i in range(1, self.scheduler_config.max_num_partial_prefills + 1): + self.partial_prefill_budget_lookup_list[i] = ( + scheduler_config.max_num_batched_tokens // i) + + @property + def next_cache_id(self): + return (self.cache_id + 1) % self.num_cache_iters + + @property + def lora_enabled(self) -> bool: + return bool(self.lora_config) + + @property + def num_decoding_tokens_per_seq(self) -> int: + """The number of new tokens.""" + return 1 + + def add_seq_group(self, seq_group: SequenceGroup) -> None: + # Add sequence groups to the waiting queue. + self.waiting.append(seq_group) + + def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None: + # Add sequence groups to the running queue. + # Only for testing purposes. + self.running.append(seq_group) + + def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None: + # Add sequence groups to the swapped queue. + # Only for testing purposes. + self.swapped.append(seq_group) + + def abort_seq_group( + self, + request_id: Union[str, Iterable[str]], + seq_id_to_seq_group: Optional[Dict[str, SequenceGroupBase]] = None, + ) -> None: + """Aborts a sequence group with the given ID. + + Check if the sequence group with the given ID + is present in any of the state queue. + If present, remove the sequence group from the state queue. + Also, if any of the sequences in the sequence group is not finished, + free the sequence with status `FINISHED_ABORTED`. + Otherwise, do nothing. + + Args: + request_id: The ID(s) of the sequence group to abort. + seq_id_to_seq_group: helper for groups with n>1 + """ + if isinstance(request_id, str): + request_id = (request_id, ) + request_ids = set(request_id) + seq_id_to_seq_group = seq_id_to_seq_group or {} + for state_queue in [self.waiting, self.running, self.swapped]: + aborted_groups: List[SequenceGroup] = [] + for seq_group in state_queue: + # When n>1, seq_group.request_id looks like + # foo_parallel_sample_0, while request_ids is just foo, and we + # should resolve it as real_request_id to match. + if seq_group.request_id in seq_id_to_seq_group: + real_request_id = seq_id_to_seq_group[ + seq_group.request_id].group_id + else: + real_request_id = seq_group.request_id + if real_request_id in request_ids: + # Appending aborted group into pending list. + aborted_groups.append(seq_group) + # We can't remove real_request_id in request_ids here, + # because there may be other seq groups sharing the same + # real_request_id + for aborted_group in aborted_groups: + # Remove the sequence group from the state queue. + state_queue.remove(aborted_group) + # Remove the aborted request from the Mamba cache. 
+ self._finished_requests_ids.append(aborted_group.request_id) + for seq in aborted_group.get_seqs(): + if seq.is_finished(): + continue + seq.status = SequenceStatus.FINISHED_ABORTED + self.free_seq(seq) + if aborted_group.request_id in seq_id_to_seq_group: + del seq_id_to_seq_group[aborted_group.request_id] + + self._free_seq_group_cross_attn_blocks(aborted_group) + + def _free_seq_group_cross_attn_blocks( + self, + seq_group: SequenceGroup, + ) -> None: + """ + Free a sequence group from a cross-attention block table. + Has no effect on decoder-only models. + """ + if seq_group.is_encoder_decoder(): + self.block_manager.free_cross(seq_group) + + def has_unfinished_seqs(self) -> bool: + return (len(self.waiting) != 0 or len(self.running) != 0 + or len(self.swapped) != 0) + + def get_prefix_cache_hit_rate(self, device: Device) -> float: + return self.block_manager.get_prefix_cache_hit_rate(device) + + def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: + return self.block_manager.reset_prefix_cache(device) + + def get_num_unfinished_seq_groups(self) -> int: + return len(self.waiting) + len(self.running) + len(self.swapped) + + def get_and_reset_finished_requests_ids(self) -> List[str]: + """Flushes the list of request ids of previously finished seq_groups.""" + finished_requests_ids = self._finished_requests_ids + self._finished_requests_ids = list() + return finished_requests_ids + + def _schedule_running( + self, + budget: SchedulingBudget, + curr_loras: Optional[Set[int]], + enable_chunking: bool = False, + partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, + ) -> SchedulerRunningOutputs: + """Schedule sequence groups that are running. + + Running queue should include decode and chunked prefill requests. + + Args: + budget: The scheduling budget. The argument is in-place updated + when any decodes are preempted. + curr_loras: Currently batched lora request ids. The argument is + in-place updated when any decodes are preempted. + enable_chunking: If True, seq group can be chunked and only a + chunked number of tokens are scheduled if + `budget.num_batched_tokens` has not enough capacity to schedule + all tokens. + partial_prefill_metadata: information about the partial prefills + that are currently running + + Returns: + SchedulerRunningOutputs. + """ + ret: SchedulerRunningOutputs = self._scheduler_running_outputs_cache[ + self.cache_id].get_object() + ret.blocks_to_swap_out.clear() + ret.blocks_to_copy.clear() + ret.decode_seq_groups.clear() + ret.prefill_seq_groups.clear() + ret.preempted.clear() + ret.swapped_out.clear() + + ret.num_lookahead_slots = self._get_num_lookahead_slots( + is_prefill=False, enable_chunking=enable_chunking) + + ret.decode_seq_groups_list.clear() + ret.prefill_seq_groups_list.clear() + + # Blocks that need to be swapped or copied before model execution. + blocks_to_swap_out: List[Tuple[int, int]] = ret.blocks_to_swap_out + blocks_to_copy: List[Tuple[int, int]] = ret.blocks_to_copy + + decode_seq_groups: List[ScheduledSequenceGroup] = ret.decode_seq_groups + prefill_seq_groups: List[ + ScheduledSequenceGroup] = ret.prefill_seq_groups + preempted: List[SequenceGroup] = ret.preempted + swapped_out: List[SequenceGroup] = ret.swapped_out + + running_queue = self.running + assert len(self._async_stopped) == 0 + while running_queue: + seq_group = running_queue[0] + # We discard the cached tokens info here because we don't need it + # for running sequence: + # 1. 
If a sequence is running with chunked prefill, the cached + # tokens info was already used for the first prefill. + # 2. If a sequence is running with non-chunked prefill, then + # there it's a decoding sequence, and the cached tokens info is + # irrelevant. + num_uncached_new_tokens, _ = \ + self._get_num_new_uncached_and_cached_tokens( + seq_group, + SequenceStatus.RUNNING, + enable_chunking, + budget, + partial_prefill_metadata, + ) + + num_running_tokens = num_uncached_new_tokens + if num_running_tokens == 0: + # No budget => Stop + break + + running_queue.popleft() + + # With async postprocessor, an extra decode run is done + # to process the final tokens. The check below avoids this extra + # decode run when the model max len is reached, in order to avoid + # a memory overflow. + if (self.use_async_output_proc and seq_group.seqs[0].get_len() + > self.scheduler_config.max_model_len): + self._async_stopped.append(seq_group) + continue + + # NOTE(woosuk): Preemption happens only when there is no available + # slot to keep all the sequence groups in the RUNNING state. + while not self._can_append_slots(seq_group, enable_chunking): + budget.subtract_num_batched_tokens(seq_group.request_id, + num_running_tokens) + num_running_seqs = seq_group.get_max_num_running_seqs() + budget.subtract_num_seqs(seq_group.request_id, + num_running_seqs) + + if (curr_loras is not None and seq_group.lora_int_id > 0 + and seq_group.lora_int_id in curr_loras): + curr_loras.remove(seq_group.lora_int_id) + + # Determine victim sequence + cont_loop = True + if running_queue: + # Preempt the lowest-priority sequence group. + victim_seq_group = running_queue.pop() + else: + # No other sequence group can be preempted. + # Preempt the current sequence group. + # Note: This is also where we stop this loop + # (since there is nothing else to preempt) + victim_seq_group = seq_group + cont_loop = False + + # With async postprocessor, before preempting a sequence + # we need to ensure it has no pending async postprocessor + do_preempt = True + if self.use_async_output_proc: + assert self.output_proc_callback is not None + self.output_proc_callback( + request_id=victim_seq_group.request_id) + + # It may be that the async pending "victim_seq_group" + # becomes finished, in which case we simply free it. + if victim_seq_group.is_finished(): + self._free_finished_seq_group(victim_seq_group) + do_preempt = False + + # Do preemption + if do_preempt: + preempted_mode = self._preempt(victim_seq_group, + blocks_to_swap_out) + if preempted_mode == PreemptionMode.RECOMPUTE: + preempted.append(victim_seq_group) + else: + swapped_out.append(victim_seq_group) + + if not cont_loop: + break + else: + self._append_slots(seq_group, blocks_to_copy, enable_chunking) + is_prefill = seq_group.is_prefill() + + scheduled_seq_group: ScheduledSequenceGroup = ( + self._scheduled_seq_group_cache[ + self.cache_id].get_object()) + scheduled_seq_group.seq_group = seq_group + if is_prefill: + scheduled_seq_group.token_chunk_size = num_running_tokens + prefill_seq_groups.append(scheduled_seq_group) + ret.prefill_seq_groups_list.append(seq_group) + else: + scheduled_seq_group.token_chunk_size = 1 + decode_seq_groups.append(scheduled_seq_group) + ret.decode_seq_groups_list.append(seq_group) + + budget.add_num_batched_tokens(seq_group.request_id, + num_running_tokens) + # OPTIMIZATION: Note that get_max_num_running_seqs is + # expensive. 
For the default scheduling chase where + # enable_chunking is False, num_seqs are updated before running + # this method, so we don't have to update it again here. + if enable_chunking: + num_running_seqs = seq_group.get_max_num_running_seqs() + budget.add_num_seqs(seq_group.request_id, num_running_seqs) + if curr_loras is not None and seq_group.lora_int_id > 0: + curr_loras.add(seq_group.lora_int_id) + + self._scheduler_running_outputs_cache[self.next_cache_id].reset() + self._scheduled_seq_group_cache[self.next_cache_id].reset() + + return ret + + def _schedule_swapped( + self, + budget: SchedulingBudget, + curr_loras: Optional[Set[int]], + enable_chunking: bool = False, + ) -> SchedulerSwappedInOutputs: + """Schedule sequence groups that are swapped out. + + It schedules swapped requests as long as it fits `budget` and + curr_loras <= max_lora from the scheduling config. The input arguments + `budget` and `curr_loras` are updated based on scheduled seq_groups. + + Args: + budget: The scheduling budget. The argument is in-place updated + when any requests are swapped in. + curr_loras: Currently batched lora request ids. The argument is + in-place updated when any requests are swapped in. + enable_chunking: If True, seq group can be chunked and only a + chunked number of tokens are scheduled if + `budget.num_batched_tokens` has not enough capacity to schedule + all tokens. + + Returns: + SchedulerSwappedInOutputs. + """ + # Blocks that need to be swapped or copied before model execution. + blocks_to_swap_in: List[Tuple[int, int]] = [] + blocks_to_copy: List[Tuple[int, int]] = [] + decode_seq_groups: List[ScheduledSequenceGroup] = [] + prefill_seq_groups: List[ScheduledSequenceGroup] = [] + infeasible_seq_groups: List[SequenceGroup] = [] + + swapped_queue = self.swapped + + leftover_swapped: Deque[SequenceGroup] = deque() + while swapped_queue: + seq_group = swapped_queue[0] + + # If the sequence group cannot be swapped in, stop. + is_prefill = seq_group.is_prefill() + alloc_status = self.block_manager.can_swap_in( + seq_group, + self._get_num_lookahead_slots(is_prefill, enable_chunking)) + if alloc_status == AllocStatus.LATER: + break + elif alloc_status == AllocStatus.NEVER: + logger.warning( + "Failing the request %s because there's not enough kv " + "cache blocks to run the entire sequence.", + seq_group.request_id, + ) + for seq in seq_group.get_seqs(): + seq.status = SequenceStatus.FINISHED_IGNORED + infeasible_seq_groups.append(seq_group) + swapped_queue.popleft() + continue + + lora_int_id = 0 + if self.lora_enabled: + lora_int_id = seq_group.lora_int_id + assert curr_loras is not None + assert self.lora_config is not None + if (lora_int_id > 0 and (lora_int_id not in curr_loras) + and len(curr_loras) >= self.lora_config.max_loras): + # We don't have a space for another LoRA, so + # we ignore this request for now. + leftover_swapped.appendleft(seq_group) + swapped_queue.popleft() + continue + + # The total number of sequences in the RUNNING state should not + # exceed the maximum number of sequences. 
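+            # Illustrative sketch (hypothetical numbers): can_schedule()
+            # below admits this swapped-out group only while both the token
+            # and sequence budgets have room; e.g. with
+            # max_num_batched_tokens=2048 and 2000 tokens already batched, a
+            # decode needing 1 uncached token fits, whereas a 64-token
+            # prefill chunk would not and the loop breaks.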
+ num_new_seqs = seq_group.get_max_num_running_seqs() + num_new_tokens_uncached, num_new_tokens_cached = ( + self._get_num_new_uncached_and_cached_tokens( + seq_group, SequenceStatus.SWAPPED, enable_chunking, + budget)) + + if num_new_tokens_uncached == 0 or not budget.can_schedule( + num_new_tokens=num_new_tokens_uncached, + num_new_seqs=num_new_seqs, + ): + break + + if lora_int_id > 0 and curr_loras is not None: + curr_loras.add(lora_int_id) + swapped_queue.popleft() + self._swap_in(seq_group, blocks_to_swap_in) + self._append_slots(seq_group, blocks_to_copy, enable_chunking) + if is_prefill: + prefill_seq_groups.append( + ScheduledSequenceGroup( + seq_group, + token_chunk_size=num_new_tokens_uncached + + num_new_tokens_cached, + )) + else: + decode_seq_groups.append( + ScheduledSequenceGroup(seq_group, token_chunk_size=1)) + budget.add_num_batched_tokens( + seq_group.request_id, + num_batched_tokens=num_new_tokens_uncached, + num_cached_tokens=num_new_tokens_cached, + ) + budget.add_num_seqs(seq_group.request_id, num_new_seqs) + + swapped_queue.extendleft(leftover_swapped) + + return SchedulerSwappedInOutputs( + decode_seq_groups=decode_seq_groups, + prefill_seq_groups=prefill_seq_groups, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_copy=blocks_to_copy, + num_lookahead_slots=self._get_num_lookahead_slots( + is_prefill=False, enable_chunking=enable_chunking), + infeasible_seq_groups=infeasible_seq_groups, + ) + + def _get_prompt_limit(self, seq_group: SequenceGroup) -> int: + if (self.scheduler_config.chunked_prefill_enabled + and not self.scheduler_config.is_multi_step): + prompt_limit = self.scheduler_config.max_model_len + else: + prompt_limit = min( + self.scheduler_config.max_model_len, + self.scheduler_config.max_num_batched_tokens, + ) + + # Model is fine tuned with long context. Return the fine tuned max_len. + if seq_group.lora_request and seq_group.lora_request.long_lora_max_len: + assert prompt_limit <= seq_group.lora_request.long_lora_max_len + return seq_group.lora_request.long_lora_max_len + else: + return prompt_limit + + def _get_priority(self, + seq_group: SequenceGroup) -> Tuple[Optional[int], float]: + """Get the priority of the sequence group. + Highest preference to user-defined priority, followed by arrival time. + Args: + seq_group: The sequence group input. + Returns: + The priority of the sequence group. + """ + return seq_group.priority, seq_group.arrival_time + + def _schedule_priority_preemption( + self, + budget: SchedulingBudget, + ) -> int: + """Sorts waiting and running queue. Also, force preempt requests + from the running queue if their priority is lower. + Priority-based preemption is used with the priority policy. + Args: + budget: The scheduling budget. The argument is in-place updated + when any requests are scheduled. + Returns: + A count of priority-based preemptions. 
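+        Example (illustrative, hypothetical priorities): a waiting group
+        with priority 0 may force-preempt a running group with priority 5;
+        ties on priority fall back to the earlier arrival_time, per
+        _get_priority.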
+ """ + + waiting_queue = self.waiting + + running_queue = deque(sorted(self.running, key=self._get_priority)) + + blocks_to_swap_out: List[Tuple[int, int]] = [] + force_preemption_count = 0 + + if waiting_queue: + seq_group = waiting_queue.popleft() + num_new_seqs = seq_group.get_max_num_running_seqs() + num_new_tokens_uncached, _ = \ + self._get_num_new_uncached_and_cached_tokens( + seq_group, SequenceStatus.WAITING, False, budget) + + # Only preempt if priority inversion exists + while running_queue and self._get_priority( + running_queue[-1]) > self._get_priority(seq_group): + # Only preempt if waiting sequence cannot be allocated + can_allocate = self.block_manager.can_allocate(seq_group) + if (num_new_tokens_uncached > 0 + and can_allocate == AllocStatus.OK + and budget.can_schedule( + num_new_tokens=num_new_tokens_uncached, + num_new_seqs=num_new_seqs, + )): + break + + # Adjust budget to remove the victim sequence group + vseq_group = running_queue.pop() + num_running_tokens_uncached, _ = ( + self._get_num_new_uncached_and_cached_tokens( + vseq_group, SequenceStatus.RUNNING, False, budget)) + budget.subtract_num_batched_tokens( + vseq_group.request_id, num_running_tokens_uncached) + num_running_seqs = vseq_group.get_max_num_running_seqs() + budget.subtract_num_seqs(vseq_group.request_id, + num_running_seqs) + + # Preempt out the victim sequence group + self._preempt(vseq_group, blocks_to_swap_out) + waiting_queue.appendleft(vseq_group) + force_preemption_count += 1 + # Put the sequence back into the waiting queue + waiting_queue.appendleft(seq_group) + + waiting_queue = deque(sorted(waiting_queue, key=self._get_priority)) + + self.waiting = waiting_queue + self.running = running_queue + return force_preemption_count + + def _schedule_prefills( + self, + budget: SchedulingBudget, + curr_loras: Optional[Set[int]], + enable_chunking: bool = False, + partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, + ) -> SchedulerPrefillOutputs: + """Schedule sequence groups that are in prefill stage. + + Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE + as a new prefill (that starts from beginning -> most recently generated + tokens). + + It schedules waiting requests as long as it fits `budget` and + curr_loras <= max_lora from the scheduling config. The input arguments + `budget` and `curr_loras` are updated based on scheduled seq_groups. + + Args: + budget: The scheduling budget. The argument is in-place updated + when any requests are scheduled. + curr_loras: Currently batched lora request ids. The argument is + in-place updated when any requests are scheduled. + enable_chunking: If True, seq group can be chunked and only a + chunked number of tokens are scheduled if + `budget.num_batched_tokens` has not enough capacity to schedule + all tokens. + partial_prefill_metadata: information about the partial prefills + that are currently running + + Returns: + SchedulerPrefillOutputs. 
+ """ + if budget.remaining_token_budget() == 0: + # Do nothing: Can't add any more prefill anyway + return SchedulerPrefillOutputs( + seq_groups=[], + ignored_seq_groups=[], + num_lookahead_slots=self._get_num_lookahead_slots( + is_prefill=True, enable_chunking=enable_chunking), + ) + ignored_seq_groups: List[SequenceGroup] = [] + seq_groups: List[ScheduledSequenceGroup] = [] + using_prompt_embeds: bool = False + + waiting_queue = self.waiting + + leftover_waiting_sequences: Deque[SequenceGroup] = deque() + while self._passed_delay(time.time()) and waiting_queue: + seq_group = waiting_queue[0] + + waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) + assert len(waiting_seqs) == 1, ( + "Waiting sequence group should have only one prompt " + "sequence.") + if (partial_prefill_metadata is not None + and not partial_prefill_metadata.can_schedule(seq_group)): + leftover_waiting_sequences.appendleft(seq_group) + waiting_queue.popleft() + continue + num_new_tokens_uncached, num_new_tokens_cached = ( + self._get_num_new_uncached_and_cached_tokens( + seq_group, + SequenceStatus.WAITING, + enable_chunking, + budget, + partial_prefill_metadata=partial_prefill_metadata, + )) + num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached + + if not enable_chunking: + num_prompt_tokens = waiting_seqs[0].get_len() + assert num_new_tokens == num_prompt_tokens + + prompt_limit = self._get_prompt_limit(seq_group) + if num_new_tokens > prompt_limit: + logger.warning( + "Input prompt (%d tokens) is too long" + " and exceeds limit of %d", + num_new_tokens, + prompt_limit, + ) + for seq in waiting_seqs: + seq.status = SequenceStatus.FINISHED_IGNORED + ignored_seq_groups.append(seq_group) + waiting_queue.popleft() + continue + + num_lookahead_slots: int = 0 + if self.scheduler_config.is_multi_step and enable_chunking: + num_lookahead_slots = self._get_num_lookahead_slots( + True, enable_chunking) + + # If the sequence group cannot be allocated, stop. + can_allocate = self.block_manager.can_allocate( + seq_group, num_lookahead_slots=num_lookahead_slots) + if can_allocate == AllocStatus.LATER: + break + elif can_allocate == AllocStatus.NEVER: + logger.warning( + "Input prompt (%d tokens) + lookahead slots (%d) is " + "too long and exceeds the capacity of block_manager", + num_new_tokens, + num_lookahead_slots, + ) + for seq in waiting_seqs: + seq.status = SequenceStatus.FINISHED_IGNORED + ignored_seq_groups.append(seq_group) + waiting_queue.popleft() + continue + + # We cannot mix sequence groups that use prompt embeds and + # those that do not. + if len(seq_groups) == 0: + using_prompt_embeds = seq_group.uses_prompt_embeds() + if using_prompt_embeds != seq_group.uses_prompt_embeds(): + leftover_waiting_sequences.appendleft(seq_group) + waiting_queue.popleft() + continue + + lora_int_id = 0 + if self.lora_enabled: + lora_int_id = seq_group.lora_int_id + assert curr_loras is not None + assert self.lora_config is not None + if (self.lora_enabled and lora_int_id > 0 + and lora_int_id not in curr_loras + and len(curr_loras) >= self.lora_config.max_loras): + # We don't have a space for another LoRA, so + # we ignore this request for now. + leftover_waiting_sequences.appendleft(seq_group) + waiting_queue.popleft() + continue + + if (budget.num_batched_tokens + >= self.scheduler_config.max_num_batched_tokens): + # We've reached the budget limit - since there might be + # continuous prefills in the running queue, we should break + # to avoid scheduling any new prefills. 
+ break + + num_new_seqs = seq_group.get_max_num_running_seqs() + if num_new_tokens_uncached == 0 or not budget.can_schedule( + num_new_tokens=num_new_tokens_uncached, + num_new_seqs=num_new_seqs, + ): + break + + # Can schedule this request. + if curr_loras is not None and lora_int_id > 0: + curr_loras.add(lora_int_id) + waiting_queue.popleft() + self._allocate_and_set_running(seq_group) + + if partial_prefill_metadata is not None: + partial_prefill_metadata.maybe_increment_partial_prefills( + seq_group) + + if enable_chunking and self.scheduler_config.is_multi_step: + blocks_to_copy: List[Tuple[int, int]] = [] + # init_multi_step_from_lookahead_slots happens in append_slots + self._append_slots(seq_group, blocks_to_copy, enable_chunking) + # This assert will trip when a copy-on-write happens. This is + # not a concern as the very first sequence-group block + # allocation happens above. Still, we have the assert to + # catch any edge-cases. + assert not blocks_to_copy + else: + seq_group.init_multi_step_from_lookahead_slots( + num_lookahead_slots, + num_scheduler_steps=self.scheduler_config. + num_scheduler_steps, + is_multi_step=self.scheduler_config.is_multi_step, + enable_chunking=enable_chunking, + ) + + seq_groups.append( + ScheduledSequenceGroup(seq_group=seq_group, + token_chunk_size=num_new_tokens)) + budget.add_num_batched_tokens( + seq_group.request_id, + num_batched_tokens=num_new_tokens_uncached, + num_cached_tokens=num_new_tokens_cached, + ) + budget.add_num_seqs(seq_group.request_id, num_new_seqs) + + # Queue requests that couldn't be scheduled. + waiting_queue.extendleft(leftover_waiting_sequences) + if len(seq_groups) > 0: + self.prev_prompt = True + + return SchedulerPrefillOutputs( + seq_groups=seq_groups, + ignored_seq_groups=ignored_seq_groups, + num_lookahead_slots=self._get_num_lookahead_slots( + is_prefill=True, enable_chunking=enable_chunking), + ) + + def _schedule_default(self) -> SchedulerOutputs: + """Schedule queued requests. + + The current policy is designed to optimize the throughput. First, + it batches as many prefill requests as possible. And it schedules + decodes. If there's a pressure on GPU memory, decode requests can + be swapped or preempted. + """ + # Include running requests to the budget. + budget = SchedulingBudget( + token_budget=self.scheduler_config.max_num_batched_tokens, + max_num_seqs=self.scheduler_config.max_num_seqs, + ) + # Make sure we include num running seqs before scheduling prefill, + # so that we don't schedule beyond max_num_seqs for prefill. + for seq_group in self.running: + budget.add_num_seqs(seq_group.request_id, + seq_group.get_max_num_running_seqs()) + curr_loras = (set( + seq_group.lora_int_id for seq_group in self.running + if seq_group.lora_int_id > 0) if self.lora_enabled else None) + + prefills = SchedulerPrefillOutputs.create_empty() + running_scheduled = SchedulerRunningOutputs.create_empty() + swapped_in = SchedulerSwappedInOutputs.create_empty() + + # If any requests are swapped, prioritized swapped requests. + if not self.swapped: + prefills = self._schedule_prefills(budget, + curr_loras, + enable_chunking=False) + + if len(prefills.seq_groups + ) == 0 and self.scheduler_config.policy == "priority": + self._schedule_priority_preemption(budget) + + # Don't schedule decodes if prefills are scheduled. + # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running + # only contains decode requests, not chunked prefills. 
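+        # Illustrative note: in this non-chunked policy the sub-schedules
+        # are mutually exclusive within one step: either new prefills or
+        # running decodes, with swap-in attempted only when nothing was
+        # preempted. E.g. (hypothetical) a step that admits 3 waiting
+        # prompts schedules no decode tokens in that same step.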
+ if len(prefills.seq_groups) == 0: + running_scheduled = self._schedule_running(budget, + curr_loras, + enable_chunking=False) + + # If any sequence group is preempted, do not swap in any sequence + # group. because it means there's no slot for new running requests. + if (len(running_scheduled.preempted) + + len(running_scheduled.swapped_out) == 0): + swapped_in = \ + self._schedule_swapped(budget, curr_loras) + + assert (budget.num_batched_tokens + <= self.scheduler_config.max_num_batched_tokens) + assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs + + # Update waiting requests. + self.waiting.extendleft(running_scheduled.preempted) + # Update new running requests. + if len(prefills.seq_groups) > 0: + self.running.extend([s.seq_group for s in prefills.seq_groups]) + + self.running.extend(running_scheduled.decode_seq_groups_list) + + if len(swapped_in.decode_seq_groups) > 0: + self.running.extend( + [s.seq_group for s in swapped_in.decode_seq_groups]) + + # Update swapped requests. + self.swapped.extend(running_scheduled.swapped_out) + preempted = len(running_scheduled.preempted) + len( + running_scheduled.swapped_out) + + # There should be no prefill from running queue because this policy + # doesn't allow chunked prefills. + assert len(running_scheduled.prefill_seq_groups) == 0 + assert len(swapped_in.prefill_seq_groups) == 0 + + # Merge lists + num_prefill_groups = len(prefills.seq_groups) + ignored_seq_groups_for_embeds = list[SequenceGroup]() + if num_prefill_groups > 0: + scheduled_seq_groups = prefills.seq_groups + scheduled_seq_groups.extend(running_scheduled.decode_seq_groups) + ignored_seq_groups_for_embeds.clear() + else: + scheduled_seq_groups = running_scheduled.decode_seq_groups + if len(scheduled_seq_groups) > 0: + using_prompt_embeds = scheduled_seq_groups[ + 0].seq_group.uses_prompt_embeds() + ignored_seq_groups_for_embeds.clear() + indices_ignored = list[int]() + for i, schedule_seq_group in enumerate(scheduled_seq_groups): + if using_prompt_embeds !=\ + schedule_seq_group.seq_group.uses_prompt_embeds(): + ignored_seq_groups_for_embeds.append( + schedule_seq_group.seq_group) + indices_ignored.append(i) + if len(ignored_seq_groups_for_embeds) > 0: + scheduled_seq_groups = [ + group for i, group in enumerate(scheduled_seq_groups) + if i not in indices_ignored + ] + else: + ignored_seq_groups_for_embeds.clear() + + scheduled_seq_groups.extend(swapped_in.decode_seq_groups) + + blocks_to_copy = running_scheduled.blocks_to_copy + blocks_to_copy.extend(swapped_in.blocks_to_copy) + + ignored_seq_groups = prefills.ignored_seq_groups + ignored_seq_groups.extend(ignored_seq_groups_for_embeds) + ignored_seq_groups.extend(swapped_in.infeasible_seq_groups) + + return SchedulerOutputs( + scheduled_seq_groups=scheduled_seq_groups, + num_prefill_groups=num_prefill_groups, + num_batched_tokens=budget.num_batched_tokens + + budget.num_cached_tokens, + blocks_to_swap_in=swapped_in.blocks_to_swap_in, + blocks_to_swap_out=running_scheduled.blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + ignored_seq_groups=ignored_seq_groups, + num_lookahead_slots=running_scheduled.num_lookahead_slots, + running_queue_size=len(self.running), + preempted=preempted, + ) + + def _schedule_chunked_prefill(self) -> SchedulerOutputs: + """Schedule queued requests. + + Chunked prefill allows to chunk prefill requests, batch them together + with decode requests. This policy 1. schedule as many decoding requests + as possible. 2. schedule chunked prefill requests that are not + finished. 3. 
schedule swapped request. 4. schedule new prefill + requests. + + The policy can sustain the high GPU utilization because it can put + prefill and decodes requests to the same batch, while it improves + inter token latency because decodes requests don't need to be blocked + by prefill requests. + """ + budget = SchedulingBudget( + token_budget=self.scheduler_config.max_num_batched_tokens, + max_num_seqs=self.scheduler_config.max_num_seqs, + ) + curr_loras: Set[int] = set() + + prefills = SchedulerPrefillOutputs.create_empty() + swapped_in = SchedulerSwappedInOutputs.create_empty() + + # Create partial prefill metadata + partial_prefill_metadata = PartialPrefillMetadata.from_queues( + running=self.running, + waiting=self.waiting, + scheduler_config=self.scheduler_config, + ) + + # Decoding should be always scheduled first by fcfs. + running_scheduled = self._schedule_running( + budget, + curr_loras, + enable_chunking=True, + partial_prefill_metadata=partial_prefill_metadata, + ) + + # Schedule swapped out requests. + # If preemption happens, it means we don't have space for swap-in. + if len(running_scheduled.preempted) + len( + running_scheduled.swapped_out) == 0: + swapped_in = self._schedule_swapped(budget, curr_loras) + + prefills = self._schedule_prefills( + budget, + curr_loras, + enable_chunking=True, + partial_prefill_metadata=partial_prefill_metadata, + ) + + assert (budget.num_batched_tokens + <= self.scheduler_config.max_num_batched_tokens) + assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs + + # Update waiting requests. + self.waiting.extendleft(running_scheduled.preempted) + + # Update new running requests. + # By default, vLLM scheduler prioritizes prefills. + # Once chunked prefill is enabled, + # the policy is changed to prioritize decode requests. + self.running.extend( + [s.seq_group for s in swapped_in.decode_seq_groups]) + self.running.extend( + [s.seq_group for s in swapped_in.prefill_seq_groups]) + self.running.extend( + [s.seq_group for s in running_scheduled.decode_seq_groups]) + # Because multiple prefills may be running concurrently, we need to + # make sure that prefills which are scheduled to finish are listed + # before those that won't. This is so that on the next scheduling + # iteration when they have transitioned to the decode stage, they are + # properly prioritized over sequences that are still in the prefill + # stage. + self.running.extend( + self._order_finishing_prefills_first( + running_scheduled.prefill_seq_groups)) + self.running.extend([s.seq_group for s in prefills.seq_groups]) + + # Update swapped requests. + self.swapped.extend(running_scheduled.swapped_out) + # Put prefills first due to Attention backend ordering assumption. 
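+        # Illustrative note (hypothetical counts): the concatenation below
+        # keeps every prefill group ahead of every decode group, which is
+        # the ordering num_prefill_groups relies on. E.g. 2 new prefills +
+        # 1 continuing prefill + 5 decodes gives num_prefill_groups == 3,
+        # with the decodes at indices 3..7 of scheduled_seq_groups.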
+ scheduled_seq_groups = (prefills.seq_groups + + running_scheduled.prefill_seq_groups + + swapped_in.prefill_seq_groups + + running_scheduled.decode_seq_groups + + swapped_in.decode_seq_groups) + num_prefill_groups = (len(prefills.seq_groups) + + len(swapped_in.prefill_seq_groups) + + len(running_scheduled.prefill_seq_groups)) + # If all prompts, then we set num_lookahead_slots to 0 + # this allows us to go through the `no_spec` path in + # `spec_decode_worker.py` + all_prefills = len(scheduled_seq_groups) == num_prefill_groups + num_lookahead_slots = (0 if + (all_prefills + and not self.scheduler_config.is_multi_step) + else running_scheduled.num_lookahead_slots) + return SchedulerOutputs( + scheduled_seq_groups=scheduled_seq_groups, + num_prefill_groups=num_prefill_groups, + num_batched_tokens=budget.num_batched_tokens + + budget.num_cached_tokens, + blocks_to_swap_in=swapped_in.blocks_to_swap_in, + blocks_to_swap_out=running_scheduled.blocks_to_swap_out, + blocks_to_copy=running_scheduled.blocks_to_copy + + swapped_in.blocks_to_copy, + ignored_seq_groups=prefills.ignored_seq_groups + + swapped_in.infeasible_seq_groups, + num_lookahead_slots=num_lookahead_slots, + running_queue_size=len(self.running), + preempted=(len(running_scheduled.preempted) + + len(running_scheduled.swapped_out)), + ) + + def _order_finishing_prefills_first( + self, scheduled_prefill_seqs: List[ScheduledSequenceGroup] + ) -> List[SequenceGroup]: + """Returns a list of prefilling SequenceGroups where sequences that are + scheduled to finish prefilling are listed first""" + finishing = [ + s.seq_group for s in scheduled_prefill_seqs + if s.seq_group.get_num_uncomputed_tokens() == s.token_chunk_size + ] + not_finishing = [ + s.seq_group for s in scheduled_prefill_seqs + if s.seq_group.get_num_uncomputed_tokens() != s.token_chunk_size + ] + return finishing + not_finishing + + def _schedule(self) -> SchedulerOutputs: + """Schedule queued requests.""" + if self.scheduler_config.chunked_prefill_enabled: + return self._schedule_chunked_prefill() + else: + return self._schedule_default() + + def _can_append_slots(self, seq_group: SequenceGroup, + enable_chunking: bool) -> bool: + """Determine whether or not we have enough space in the KV cache to + continue generation of the sequence group. + """ + # It is True only for testing case to trigger artificial preemption. + if (self.enable_artificial_preemption + and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB + and self.artificial_preempt_cnt > 0): + self.artificial_preempt_cnt -= 1 + return False + + is_prefill = seq_group.is_prefill() + num_lookahead_slots = self._get_num_lookahead_slots( + is_prefill, enable_chunking) + + if is_prefill and num_lookahead_slots > 0: + # Appending prefill slots only happens multi-step and + # chunked-prefill are enabled together. + assert self.scheduler_config.is_multi_step and enable_chunking + + return self.block_manager.can_append_slots( + seq_group=seq_group, num_lookahead_slots=num_lookahead_slots) + + def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool: + # async_output_proc is allowed only when we have a single sequence + # in the sequence group + no_single_seq = seq_group.sampling_params is None or ( + seq_group.sampling_params.n == 1) + return no_single_seq + + def schedule( + self + ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]: + # Schedule sequence groups. + # This function call changes the internal states of the scheduler + # such as self.running, self.swapped, and self.waiting. 
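+        # Illustrative usage sketch (the surrounding engine calls are
+        # assumptions, not taken from this file): a driver loop typically
+        # does something like
+        #     seq_group_metadata_list, scheduler_outputs, allow_async = \
+        #         scheduler.schedule()
+        #     # ... execute the model on the scheduled groups ...
+        #     scheduler.free_finished_seq_groups()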
+ scheduler_start_time = time.perf_counter() + + scheduler_outputs: SchedulerOutputs = self._schedule() + now = time.time() + + if not self.cache_config.enable_prefix_caching: + common_computed_block_nums = [] + + allow_async_output_proc: bool = self.use_async_output_proc + + # Create input data structures. + seq_group_metadata_list: List[SequenceGroupMetadata] = [] + for i, scheduled_seq_group in enumerate( + scheduler_outputs.scheduled_seq_groups): + seq_group = scheduled_seq_group.seq_group + token_chunk_size = scheduled_seq_group.token_chunk_size + seq_group.maybe_set_first_scheduled_time(now) + + seq_group_metadata = self._seq_group_metadata_cache[ + self.cache_id].get_object() + seq_group_metadata.seq_data.clear() + seq_group_metadata.block_tables.clear() + + # seq_id -> SequenceData + seq_data: Dict[int, SequenceData] = {} + # seq_id -> physical block numbers + block_tables: Dict[int, List[int]] = {} + + if seq_group.is_encoder_decoder(): + # Encoder associated with SequenceGroup + encoder_seq = seq_group.get_encoder_seq() + assert encoder_seq is not None + encoder_seq_data = encoder_seq.data + # Block table for cross-attention + # Also managed at SequenceGroup level + cross_block_table = self.block_manager.get_cross_block_table( + seq_group) + else: + encoder_seq_data = None + cross_block_table = None + + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + seq_id = seq.seq_id + seq_data[seq_id] = seq.data + block_tables[seq_id] = self.block_manager.get_block_table(seq) + self.block_manager.access_all_blocks_in_seq(seq, now) + + if self.cache_config.enable_prefix_caching: + common_computed_block_nums = ( + self.block_manager.get_common_computed_block_ids( + seq_group.get_seqs(status=SequenceStatus.RUNNING))) + + do_sample = True + is_prompt = seq_group.is_prefill() + # We should send the metadata to workers when the first prefill + # is sent. Subsequent requests could be chunked prefill or decode. + is_first_prefill = False + if is_prompt: + seqs = seq_group.get_seqs() + # Prefill has only 1 sequence. + assert len(seqs) == 1 + num_computed_tokens = seqs[0].data.get_num_computed_tokens() + is_first_prefill = num_computed_tokens == 0 + # In the next iteration, all prompt tokens are not computed. + # It means the prefill is chunked, and we don't need sampling. + # NOTE: We use get_len instead of get_prompt_len because when + # a sequence is preempted, prefill includes previous generated + # output tokens. + if (token_chunk_size + num_computed_tokens + < seqs[0].data.get_len()): + do_sample = False + + # It assumes the scheduled_seq_groups is ordered by + # prefill < decoding. + if is_first_prefill or not self.scheduler_config.send_delta_data: + seq_group_metadata = SequenceGroupMetadata( + request_id=seq_group.request_id, + is_prompt=is_prompt, + seq_data=seq_data, + sampling_params=seq_group.sampling_params, + block_tables=block_tables, + do_sample=do_sample, + pooling_params=seq_group.pooling_params, + token_chunk_size=token_chunk_size, + lora_request=seq_group.lora_request, + computed_block_nums=common_computed_block_nums, + encoder_seq_data=encoder_seq_data, + cross_block_table=cross_block_table, + state=seq_group.state, + token_type_ids=seq_group.token_type_ids, + # `multi_modal_data` will only be present for the 1st comm + # between engine and worker. + # the subsequent comms can still use delta, but + # `multi_modal_data` will be None. 
+ multi_modal_data=(seq_group.multi_modal_data + if scheduler_outputs.num_prefill_groups + > 0 else None), + multi_modal_placeholders=( + seq_group.multi_modal_placeholders + if scheduler_outputs.num_prefill_groups > 0 else None), + prompt_adapter_request=seq_group.prompt_adapter_request, + ) + else: + # When SPMD mode is enabled, we only send delta data except for + # the first request to reduce serialization cost. + seq_data_delta = {} + for id, data in seq_data.items(): + seq_data_delta[id] = data.get_delta_and_reset() + seq_group_metadata = SequenceGroupMetadataDelta( + seq_data_delta, + seq_group.request_id, + block_tables, + is_prompt, + do_sample=do_sample, + token_chunk_size=token_chunk_size, + computed_block_nums=common_computed_block_nums, + ) + seq_group_metadata_list.append(seq_group_metadata) + + if allow_async_output_proc: + allow_async_output_proc = self._allow_async_output_proc( + seq_group) + + # Now that the batch has been created, we can assume all blocks in the + # batch will have been computed before the next scheduling invocation. + # This is because the engine assumes that a failure in model execution + # will crash the vLLM instance / will not retry. + for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: + self.block_manager.mark_blocks_as_computed( + scheduled_seq_group.seq_group, + scheduled_seq_group.token_chunk_size) + + self._seq_group_metadata_cache[self.next_cache_id].reset() + + scheduler_time = time.perf_counter() - scheduler_start_time + # Add this to scheduler time to all the sequences that are currently + # running. This will help estimate if the scheduler is a significant + # component in the e2e latency. + for seq_group in self.running: + if seq_group is not None and seq_group.metrics is not None: + if seq_group.metrics.scheduler_time is not None: + seq_group.metrics.scheduler_time += scheduler_time + else: + seq_group.metrics.scheduler_time = scheduler_time + + # Move to next cache (if exists) + self.cache_id = self.next_cache_id + + # Return results + return (seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc) + + def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: + self.block_manager.fork(parent_seq, child_seq) + + def free_seq(self, seq: Sequence) -> None: + """Free a sequence from a block table.""" + self.block_manager.free(seq) + + def _free_finished_seqs(self, seq_group: SequenceGroup) -> None: + """Free finished seqs in a sequence group.""" + for seq in seq_group.get_seqs(): + if seq.is_finished(): + self.free_seq(seq) + + def _free_finished_seq_group(self, seq_group: SequenceGroup) -> None: + if seq_group.is_finished(): + # Free cross-attention block table, if it exists + self._free_seq_group_cross_attn_blocks(seq_group) + + # Add the finished requests to the finished requests list. + # This list will be used to update the Mamba cache in the + # next step. 
+ self._finished_requests_ids.append(seq_group.request_id) + + # Free finished seqs + self._free_finished_seqs(seq_group) + + def free_finished_seq_groups(self) -> None: + remaining: Deque[SequenceGroup] = deque() + for seq_group in self.running: + self._free_finished_seq_group(seq_group) + if not seq_group.is_finished(): + remaining.append(seq_group) + + self.running = remaining + + # Handle async stopped sequence groups + # (ones that reached max model len) + if self._async_stopped: + for seq_group in self._async_stopped: + self._free_seq_group_cross_attn_blocks(seq_group) + self._finished_requests_ids.append(seq_group.request_id) + + # Free finished seqs + self._free_finished_seqs(seq_group) + + self._async_stopped.clear() + + def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: + self.block_manager.allocate(seq_group) + for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): + seq.status = SequenceStatus.RUNNING + + def _append_slots( + self, + seq_group: SequenceGroup, + blocks_to_copy: List[Tuple[int, int]], + enable_chunking: bool = False, + ) -> None: + """Appends new slots to the sequences in the given sequence group. + + Args: + seq_group (SequenceGroup): The sequence group containing the + sequences to append slots to. + blocks_to_copy (List[Tuple[int, int]]): A list of tuple of two + ints, the first int is the source block index, and the second + int is the destination block index. This list is updated with + the new source and destination block indices for the appended + slots. + enable_chunking (bool): True if chunked prefill is enabled. + """ + is_prefill: bool = seq_group.is_prefill() + num_lookahead_slots: int = self._get_num_lookahead_slots( + is_prefill, enable_chunking) + + seq_group.init_multi_step_from_lookahead_slots( + num_lookahead_slots, + num_scheduler_steps=self.scheduler_config.num_scheduler_steps, + is_multi_step=self.scheduler_config.is_multi_step, + enable_chunking=enable_chunking, + ) + + seq_status: Optional[SequenceStatus] = SequenceStatus.RUNNING + if self.scheduler_config.is_multi_step and enable_chunking: + # In multi-step chunked-prefill any sequence type can have + # slots appended. + seq_status = None + + for seq in seq_group.get_seqs(status=seq_status): + cows = self.block_manager.append_slots(seq, num_lookahead_slots) + if len(cows) > 0: + blocks_to_copy.extend(cows) + + def _preempt(self, seq_group: SequenceGroup, + blocks_to_swap_out: List[Tuple[int, int]]) -> PreemptionMode: + # If preemption mode is not specified, we determine the mode as follows: + # We use recomputation by default since it incurs lower overhead than + # swapping. However, when the sequence group has multiple sequences + # (e.g., beam search), recomputation is not currently supported. In + # such a case, we use swapping instead. + # FIXME(woosuk): This makes our scheduling policy a bit bizarre. + # As swapped sequences are prioritized over waiting sequences, + # sequence groups with multiple sequences are implicitly prioritized + # over sequence groups with a single sequence. + # TODO(woosuk): Support recomputation for sequence groups with multiple + # sequences. This may require a more sophisticated CUDA kernel. 
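+        # Illustrative decision table for the branches below (summary only,
+        # no additional logic):
+        #     user_specified_preemption_mode | seqs in group | chosen mode
+        #     None                           | 1             | RECOMPUTE
+        #     None                           | > 1           | SWAP
+        #     "swap"                         | any           | SWAP
+        #     any other value                | any           | RECOMPUTE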
+ if self.user_specified_preemption_mode is None: + if seq_group.get_max_num_running_seqs() == 1: + preemption_mode = PreemptionMode.RECOMPUTE + else: + preemption_mode = PreemptionMode.SWAP + + elif self.user_specified_preemption_mode == "swap": + preemption_mode = PreemptionMode.SWAP + else: + preemption_mode = PreemptionMode.RECOMPUTE + + if self.num_cumulative_preemption % 50 == 0: + logger.warning( + "Sequence group %s is preempted by %s mode because there is " + "not enough KV cache space. This can affect the end-to-end " + "performance. Increase gpu_memory_utilization or " + "tensor_parallel_size to provide more KV cache memory. " + "total_num_cumulative_preemption=%d", + seq_group.request_id, + preemption_mode, + self.num_cumulative_preemption + 1, + ) + self.num_cumulative_preemption += 1 + + if preemption_mode == PreemptionMode.RECOMPUTE: + self._preempt_by_recompute(seq_group) + elif preemption_mode == PreemptionMode.SWAP: + self._preempt_by_swap(seq_group, blocks_to_swap_out) + else: + raise AssertionError("Invalid preemption mode.") + return preemption_mode + + def _preempt_by_recompute( + self, + seq_group: SequenceGroup, + ) -> None: + seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + assert len(seqs) == 1 + for seq in seqs: + seq.status = SequenceStatus.WAITING + self.free_seq(seq) + seq.reset_state_for_recompute() + self._free_seq_group_cross_attn_blocks(seq_group) + + def _preempt_by_swap( + self, + seq_group: SequenceGroup, + blocks_to_swap_out: List[Tuple[int, int]], + ) -> None: + self._swap_out(seq_group, blocks_to_swap_out) + + def _swap_in( + self, + seq_group: SequenceGroup, + blocks_to_swap_in: List[Tuple[int, int]], + ) -> None: + mapping = self.block_manager.swap_in(seq_group) + blocks_to_swap_in.extend(mapping) + for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): + seq.status = SequenceStatus.RUNNING + + def _swap_out( + self, + seq_group: SequenceGroup, + blocks_to_swap_out: List[Tuple[int, int]], + ) -> None: + if not self.block_manager.can_swap_out(seq_group): + # FIXME(woosuk): Abort the sequence group instead of aborting the + # entire engine. + raise RuntimeError( + "Aborted due to the lack of CPU swap space. Please increase " + "the swap space to avoid this error.") + mapping = self.block_manager.swap_out(seq_group) + blocks_to_swap_out.extend(mapping) + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + seq.status = SequenceStatus.SWAPPED + + def _passed_delay(self, now: float) -> bool: + if self.prev_prompt: + self.last_prompt_latency = now - self.prev_time + self.prev_time, self.prev_prompt = now, False + # Delay scheduling prompts to let waiting queue fill up + if self.scheduler_config.delay_factor > 0 and self.waiting: + earliest_arrival_time = min( + [e.metrics.arrival_time for e in self.waiting]) + passed_delay = ((now - earliest_arrival_time) + > (self.scheduler_config.delay_factor * + self.last_prompt_latency) or not self.running) + else: + passed_delay = True + return passed_delay + + def _get_num_lookahead_slots(self, is_prefill: bool, + enable_chunking: bool) -> int: + """The number of slots to allocate per sequence per step, beyond known + token ids. Speculative decoding uses these slots to store KV activations + of tokens which may or may not be accepted. + + Speculative decoding does not yet support prefill, so we do not perform + lookahead allocation for prefill. 
+ + When chunking is enabled with multi-step, we allocate lookahead slots + for the prefills for when the prefills turn into decodes in the first + step. + """ + if is_prefill: + if self.scheduler_config.is_multi_step and enable_chunking: + # num_lookahead_slots was introduced in the context of decodes, + # in Speculative Decoding. + # When the num_scheduler_steps is 8, say, then the + # num_lookahead_slots is 7. Meaning, we are doing a 1-step of + # decode anyways and we wish to do 7 more. + # + # "lookaheads" for prefills, is introduced in support for + # Chunked-Prefill in Multi-Step. + return self.scheduler_config.num_lookahead_slots + 1 + else: + return 0 + + return self.scheduler_config.num_lookahead_slots + + def _get_num_new_uncached_and_cached_tokens( + self, + seq_group: SequenceGroup, + status: SequenceStatus, + enable_chunking: bool, + budget: SchedulingBudget, + partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, + ) -> Tuple[int, int]: + """ + Returns the number of new uncached and cached tokens to schedule for a + given sequence group that's in a given `status`. + + The API could chunk the number of tokens to compute based on `budget` + if `enable_chunking` is True. If a sequence group has multiple + sequences (e.g., running beam search), it means it is in decoding + phase, so chunking doesn't happen. + + Returns (0, 0) if the new token cannot be computed due to token budget. + + The cached tokens's blocks are already computed, and the attention + backend will reuse the cached blocks rather than recomputing them. So + the scheduler could schedule these cached tokens "for free". + + Args: + seq_group: The sequence group to get the number of new tokens to + schedule. + status: The status of the sequences to get the number of new tokens + to schedule. + enable_chunking: Whether to chunk the number of tokens to compute. + budget: The budget to chunk the number of tokens to compute. + partial_prefill_metadata: information about the partial prefills + that are currently running + + + Returns: + A tuple of two ints. The first int is the number of new uncached + tokens to schedule. The second int is the number of cached tokens. + If no more new tokens can be scheduled, returns (0, 0). + """ + num_cached_new_tokens = 0 + num_uncached_new_tokens = 0 + + seqs = seq_group.get_seqs(status=status) + # Compute the number of new uncached and cached tokens for + # each sequence. + for seq in seqs: + if not seq.is_prefill(): + # Decode sequences should always just have 1 uncached token + # TODO(rickyx): Actually is this still correct for multi-step? + num_uncached_new_tokens += 1 + continue + + num_computed_tokens_seq = seq.get_num_computed_tokens() + all_num_new_tokens_seq = seq.get_len() - num_computed_tokens_seq + if not self.cache_config.enable_prefix_caching: + # If prefix caching is not enabled, all new tokens are uncached. + num_uncached_new_tokens += all_num_new_tokens_seq + continue + + # NOTE: the cache token might be currently in a block that's in an + # evictor meaning that it's not yet allocated. However, we don't + # exclude such tokens in the cache count because it will be + # guaranteed to be allocated later if the sequence can be allocated. + num_cached_tokens_seq = self.block_manager.get_num_cached_tokens( + seq) + + # Sanity check. + if num_cached_tokens_seq < num_computed_tokens_seq: + # This should only happen with chunked prefill, and + # the seq is still in prefill. 
The `num_cached_tokens_seq` + # is the value we calculated on scheduling the first prefill. + # For subsequent continuous prefill steps, we cached the + # number of cache tokens for the sequence so the cached token + # count could be less than the number of computed tokens. + # See comments on `ComputedBlocksTracker` for more details. + assert ( + seq.is_prefill() and seq.status == SequenceStatus.RUNNING + and self.scheduler_config.chunked_prefill_enabled + ), ("Number of cached tokens should not be less than the " + "number of computed tokens for a sequence that's still " + f"in prefill. But there are {num_cached_tokens_seq} cached " + f"tokens and {num_computed_tokens_seq} computed tokens " + f"for sequence {seq.seq_id}.") + + num_cached_new_tokens_seq = max( + 0, num_cached_tokens_seq - num_computed_tokens_seq) + num_uncached_new_tokens_seq = (all_num_new_tokens_seq - + num_cached_new_tokens_seq) + + num_uncached_new_tokens += num_uncached_new_tokens_seq + num_cached_new_tokens += num_cached_new_tokens_seq + + if num_uncached_new_tokens == 0 and num_cached_new_tokens > 0: + # For a fully cached hit sequence, we actually need to recompute the + # last token. So we need at least 1 uncached token to schedule. + # See ModelRunner._compute_for_prefix_cache_hit for more details. + num_uncached_new_tokens = 1 + num_cached_new_tokens -= 1 + + if enable_chunking and len(seqs) == 1: + # Chunk if a running request cannot fit in the given budget. + # If number of seq > 1, it means it is doing beam search + # in a decode phase. Do not chunk. + num_uncached_new_tokens = self._chunk_new_tokens_to_schedule( + self.scheduler_config, + self.cache_config, + budget, + self._get_prompt_limit(seq_group), + num_uncached_new_tokens, + self.partial_prefill_budget_lookup_list, + partial_prefill_metadata, + ) + + return num_uncached_new_tokens, num_cached_new_tokens + + @staticmethod + def _chunk_new_tokens_to_schedule( + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + budget: SchedulingBudget, + prompt_limit: int, + num_new_tokens: int, + partial_prefill_budget_lookup_list: List[int], + partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, + ) -> int: + """ + Chunks the number of new tokens to schedule based on the budget when + chunked prefill is enabled. + + Args: + scheduler_config: The scheduler config. + cache_config: The cache config. + budget: The budget to chunk the number of tokens to compute. + prompt_limit: The maximum number of tokens allowed in a prompt. + num_new_tokens: The number of new tokens to schedule. + + Returns: + The number of new tokens to schedule after chunking. + """ + remaining_token_budget = budget.remaining_token_budget() + if scheduler_config.is_multi_step: + # The current multi-step + chunked prefill capability does + # not actually support chunking prompts. + # + # Therefore, `num_new_tokens` is computed in the same fashion + # for both multi-step+chunked-prefill & + # multi-step+chunked-prefill+APC + # + # Prompts with more tokens than the current remaining budget + # are postponed to future scheduler steps + if num_new_tokens > prompt_limit: + # If the seq_group is in prompt-stage, pass the + # num_new_tokens as-is so the caller can ignore + # the sequence. 
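+                # Illustrative example (hypothetical sizes): a 9000-token
+                # prompt with prompt_limit == 8192 is returned unchanged
+                # here, so _schedule_prefills can mark it FINISHED_IGNORED
+                # rather than silently truncating it.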
+ return num_new_tokens + + return 0 if num_new_tokens > \ + remaining_token_budget else num_new_tokens + + # Get the number of tokens to allocate to this prefill slot + prefill_slot_budget = ( + remaining_token_budget if partial_prefill_metadata is None else + partial_prefill_budget_lookup_list[ + partial_prefill_metadata.schedulable_prefills]) + + if cache_config.enable_prefix_caching: + # When prefix caching is enabled and we're partially prefilling + # a sequence, we always allocate a number of new tokens that is + # divisible by the block size to avoid partial block matching. + block_size = cache_config.block_size + # Don't exceed either the total budget or slot budget. + # Take min of those and get the next lowest multiple of the + # block size: + remaining_token_budget = ( + min(remaining_token_budget, prefill_slot_budget) // + block_size) * block_size + # NB: In the case where num_new_tokens < budget, we are + # finishing prefill for this sequence, so we do not need to + # allocate a full block. + + num_new_tokens = min(num_new_tokens, remaining_token_budget, + prefill_slot_budget) + + return num_new_tokens diff --git a/cumem_allocator.abi3.so b/cumem_allocator.abi3.so new file mode 100755 index 0000000..589985d Binary files /dev/null and b/cumem_allocator.abi3.so differ diff --git a/device_allocator/__init__.py b/device_allocator/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/device_allocator/cumem.py b/device_allocator/cumem.py new file mode 100644 index 0000000..942e866 --- /dev/null +++ b/device_allocator/cumem.py @@ -0,0 +1,281 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# cumem-based pytorch pluggable allocator to implement sleep mode. +# other approaches tried but failed: +# - cuda-python package binding +# - custom libcuda driver ctypes wrapper +# both of them failed because of cuda context mismatch. +# not sure why, they are created from a different context. +# the only successful approach is to call cuda driver API in C. +import dataclasses +import gc +import os +from contextlib import contextmanager +from typing import Any, Callable, Optional, Union + +import torch + +from vllm.utils import is_pin_memory_available + + +def find_loaded_library(lib_name) -> Optional[str]: + """ + According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html, + the file `/proc/self/maps` contains the memory maps of the process, which includes the + shared libraries loaded by the process. We can use this file to find the path of the + a loaded library. 
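+    Example (illustrative; the resolved path is hypothetical)::
+
+        find_loaded_library("cumem_allocator")
+        # e.g. ".../site-packages/vllm/cumem_allocator.abi3.so"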
+ """ # noqa + found_line = None + with open("/proc/self/maps") as f: + for line in f: + if lib_name in line: + found_line = line + break + if found_line is None: + # the library is not loaded in the current process + return None + # if lib_name is libcudart, we need to match a line with: + # address /path/to/libcudart-hash.so.11.0 + start = found_line.index("/") + path = found_line[start:].strip() + filename = path.split("/")[-1] + assert filename.rpartition(".so")[0].startswith(lib_name), \ + f"Unexpected filename: {filename} for library {lib_name}" + return path + + +cumem_available = False +try: + from vllm.cumem_allocator import (init_module, python_create_and_map, + python_unmap_and_release) + from vllm.distributed.device_communicators.cuda_wrapper import ( + CudaRTLibrary) + lib_name = find_loaded_library("cumem_allocator") + libcudart = CudaRTLibrary() + cumem_available = True +except ModuleNotFoundError: + # rocm platform does not support cumem allocator + init_module = None + python_create_and_map = None + python_unmap_and_release = None + CudaRTLibrary = None + lib_name = None + libcudart = None + +# py_device, py_alignedSize, py_d_mem, py_p_memHandle +HandleType = tuple[int, int, int, int] + + +@dataclasses.dataclass +class AllocationData: + handle: HandleType + tag: str + cpu_backup_tensor: Optional[torch.Tensor] = None + + +def create_and_map(allocation_handle: HandleType) -> None: + python_create_and_map(*allocation_handle) + + +def unmap_and_release(allocation_handle: HandleType) -> None: + python_unmap_and_release(*allocation_handle) + + +def get_pluggable_allocator( + python_malloc_fn: Callable[[int], + int], python_free_func: Callable[[int, int], + None] +) -> torch.cuda.memory.CUDAPluggableAllocator: + init_module(python_malloc_fn, python_free_func) + new_alloc = torch.cuda.memory.CUDAPluggableAllocator( + lib_name, 'my_malloc', 'my_free') + return new_alloc + + +@contextmanager +def use_memory_pool_with_allocator( + python_malloc_fn: Callable[[int], int], + python_free_func: Callable[[int, int], None]) -> None: + new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func) + mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator) + with torch.cuda.memory.use_mem_pool(mem_pool): + yield mem_pool, new_alloc + + +class CuMemAllocator: + """ + A singleton class that manages a memory pool for CUDA tensors. + The memory in this pool can be offloaded or discarded when the + allocator sleeps. + + Inside the `use_memory_pool(tag)` context, all tensors created will + be allocated in the memory pool, and has the same tag as the + tag passed to the context. + + When we call `sleep`, all tensors with the specified tag will be + offloaded to CPU memory, and the rest of the tensors will be discarded. + When we call `wake_up`, all tensors that are previously offloaded + will be loaded back to GPU memory, and the rest of the tensors will + have empty memory. + + Why it needs to be a singleton? + When allocated tensors are garbage collected, PyTorch will call + the free callback, which will call the `python_free_callback` method. + The C-extension uses a global variable to store the function of an + instance of this class. If we create multiple instances of this class, + the global variable will be overwritten and the free callback will + not work as expected. + """ + instance: "CuMemAllocator" = None + default_tag: str = "default" + + @staticmethod + def get_instance() -> "CuMemAllocator": + """ + CuMemAllocator is a singleton class. 
+ We cannot call the constructor directly. + Call this method to get the instance. + """ + assert cumem_available, "cumem allocator is not available" + if CuMemAllocator.instance is None: + CuMemAllocator.instance = CuMemAllocator() + return CuMemAllocator.instance + + def __init__(self): + conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "") + assert "expandable_segments:True" not in conf, \ + ("Expandable segments are not compatible with memory pool. " + "Please track https://github.com/pytorch/pytorch/issues/147851 " + "for the latest updates.") + + self.pointer_to_data: dict[int, AllocationData] = {} + self.current_tag: str = CuMemAllocator.default_tag + self.allocator_and_pools: dict[str, Any] = {} + + def python_malloc_callback(self, allocation_handle: HandleType) -> None: + """ + Internal method to store the allocation data + when memory is allocated in the memory pool.""" + py_d_mem = allocation_handle[2] + self.pointer_to_data[py_d_mem] = AllocationData( + allocation_handle, self.current_tag) + return + + def python_free_callback(self, ptr: int) -> HandleType: + """ + Internal method to look up the allocation data + when memory is freed in the memory pool.""" + data = self.pointer_to_data.pop(ptr) + if data.cpu_backup_tensor is not None: + data.cpu_backup_tensor = None + return data.handle + + def sleep( + self, + offload_tags: Optional[Union[tuple[str, ...], + str]] = None) -> None: + """ + Put the allocator in sleep mode. + All data in the memory allocation with the specified tag will be + offloaded to CPU memory, and others will be discarded. + + :param offload_tags: The tags of the memory allocation that will be + offloaded. The rest of the memory allocation will be discarded. + """ + if offload_tags is None: + # by default, allocated tensors are offloaded + # when the allocator sleeps + offload_tags = (CuMemAllocator.default_tag, ) + elif isinstance(offload_tags, str): + offload_tags = (offload_tags, ) + + assert isinstance(offload_tags, tuple) + + for ptr, data in self.pointer_to_data.items(): + handle = data.handle + if data.tag in offload_tags: + size_in_bytes = handle[1] + cpu_backup_tensor = torch.empty( + size_in_bytes, + dtype=torch.uint8, + device='cpu', + pin_memory=is_pin_memory_available()) + cpu_ptr = cpu_backup_tensor.data_ptr() + libcudart.cudaMemcpy(cpu_ptr, ptr, size_in_bytes) + data.cpu_backup_tensor = cpu_backup_tensor + unmap_and_release(handle) + + gc.collect() + torch.cuda.empty_cache() + + def wake_up(self, tags: Optional[list[str]] = None) -> None: + """ + Wake up the allocator from sleep mode. + All data that is previously offloaded will be loaded back to GPU + memory, and the rest of the data will have empty memory. + + :param tags: The tags of the memory allocation that will be loaded + back to GPU memory. If None, all memory allocation will be loaded + back to GPU memory. + """ + for ptr, data in self.pointer_to_data.items(): + if tags is None or data.tag in tags: + handle = data.handle + create_and_map(handle) + if data.cpu_backup_tensor is not None: + cpu_backup_tensor = data.cpu_backup_tensor + if cpu_backup_tensor is not None: + size_in_bytes = cpu_backup_tensor.numel( + ) * cpu_backup_tensor.element_size() + cpu_ptr = cpu_backup_tensor.data_ptr() + libcudart.cudaMemcpy(ptr, cpu_ptr, size_in_bytes) + data.cpu_backup_tensor = None + + @contextmanager + def use_memory_pool(self, tag: Optional[str] = None): + """ + A context manager to use the memory pool. 
+ All memory allocation created inside the context will be allocated + in the memory pool, and has the specified tag. + + :param tag: The tag of the memory allocation. If None, the default tag + will be used. + """ + if tag is None: + tag = CuMemAllocator.default_tag + + assert isinstance(tag, str) + + old_tag = self.current_tag + self.current_tag = tag + with use_memory_pool_with_allocator(self.python_malloc_callback, + self.python_free_callback) as data: + # start to hit another PyTorch bug in PyTorch 2.6, + # possibly because of gc-related issue w.r.t. the allocator and + # the memory pool. + # to avoid the issue, we keep a reference of the data. + # see https://github.com/pytorch/pytorch/issues/146431 . + self.allocator_and_pools[tag] = data + yield + # PyTorch's bug, calling torch.cuda.empty_cache() will error + # when using pluggable allocator, see + # https://github.com/pytorch/pytorch/issues/145168 . + # if we have some memory allocated and then freed, + # the memory will not be released. + # right now it is fine, because we only use this allocator + # during weight loading and kv cache creation, where we only + # allocate memory. + # TODO: we need to find a way to release the memory, + # i.e. calling torch.cuda.empty_cache() + self.current_tag = old_tag + + def get_current_usage(self) -> int: + """ + Get the total number of bytes allocated in the memory pool. + """ + sum_bytes: int = 0 + for ptr, data in self.pointer_to_data.items(): + handle = data.handle + sum_bytes += handle[1] + return sum_bytes diff --git a/distributed/__init__.py b/distributed/__init__.py new file mode 100644 index 0000000..e911b2a --- /dev/null +++ b/distributed/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from .communication_op import * +from .parallel_state import * +from .utils import * diff --git a/distributed/communication_op.py b/distributed/communication_op.py new file mode 100644 index 0000000..0a5a951 --- /dev/null +++ b/distributed/communication_op.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Optional, Union + +import torch +import torch.distributed + +from .parallel_state import get_tp_group + + +def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: + """All-reduce the input tensor across model parallel group.""" + return get_tp_group().all_reduce(input_) + + +def tensor_model_parallel_all_gather(input_: torch.Tensor, + dim: int = -1) -> torch.Tensor: + """All-gather the input tensor across model parallel group.""" + return get_tp_group().all_gather(input_, dim) + + +def tensor_model_parallel_reduce_scatter(input_: torch.Tensor, + dim: int = -1) -> torch.Tensor: + """Reduce-Scatter the input tensor across model parallel group.""" + return get_tp_group().reduce_scatter(input_, dim) + + +def tensor_model_parallel_gather(input_: torch.Tensor, + dst: int = 0, + dim: int = -1) -> Optional[torch.Tensor]: + """Gather the input tensor across model parallel group.""" + return get_tp_group().gather(input_, dst, dim) + + +def broadcast_tensor_dict(tensor_dict: Optional[dict[Any, Union[torch.Tensor, + Any]]] = None, + src: int = 0): + if not torch.distributed.is_initialized(): + return tensor_dict + return get_tp_group().broadcast_tensor_dict(tensor_dict, src) diff --git a/distributed/device_communicators/__init__.py b/distributed/device_communicators/__init__.py new file mode 100644 
index 0000000..e69de29 diff --git a/distributed/device_communicators/all2all.py b/distributed/device_communicators/all2all.py new file mode 100644 index 0000000..35f2fd0 --- /dev/null +++ b/distributed/device_communicators/all2all.py @@ -0,0 +1,264 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import importlib.util +from typing import TYPE_CHECKING, Any + +import torch +import torch.distributed as dist + +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger + +from .base_device_communicator import All2AllManagerBase, Cache + +logger = init_logger(__name__) + +if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.layer import FusedMoE +else: + FusedMoE = None + + +class NaiveAll2AllManager(All2AllManagerBase): + """ + A naive implementation of all2all communication. + It uses all-reduce under the hood, which is not + efficient at all. The main purpose is for testing and + debugging. + """ + + def __init__(self, cpu_group): + super().__init__(cpu_group) + + def naive_multicast(self, x: torch.Tensor, + cu_tokens_across_dp_cpu: torch.Tensor): + assert (len(x.shape) == 2) + buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)), + device=x.device, + dtype=x.dtype) + + start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ + self.dp_rank - 1] + end = cu_tokens_across_dp_cpu[self.dp_rank] + buffer[start:end, :].copy_(x) + for idx in range(self.dp_world_size): + start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1] + end = cu_tokens_across_dp_cpu[idx] + self.dp_group.broadcast(buffer[start:end, :], idx) + + return buffer + + def dispatch(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + cu_tokens_across_dp_cpu = get_forward_context( + ).dp_metadata.cu_tokens_across_dp_cpu + + hidden_states = self.naive_multicast(hidden_states, + cu_tokens_across_dp_cpu) + router_logits = self.naive_multicast(router_logits, + cu_tokens_across_dp_cpu) + return hidden_states, router_logits + + def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + cu_tokens_across_dp_cpu = get_forward_context( + ).dp_metadata.cu_tokens_across_dp_cpu + start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ + self.dp_rank - 1] + end = cu_tokens_across_dp_cpu[self.dp_rank] + + all_hidden_states = self.dp_group.all_reduce(hidden_states) + hidden_states = all_hidden_states[start:end, :] + return hidden_states + + def destroy(self): + pass + + +class PPLXAll2AllManager(All2AllManagerBase): + """ + All2All communication based on PPLX kernels. + """ + + def __init__(self, cpu_group): + has_pplx = importlib.util.find_spec("pplx_kernels") is not None + assert has_pplx, "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." 
# noqa + super().__init__(cpu_group) + + if self.internode: + # inter-node communication needs nvshmem, + # intra-node communication uses p2p mapping directly + from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, + nvshmem_get_unique_id, + nvshmem_init) + logger.debug( + "Initialize NVSHMEM for pplx_kernels: " + "rank=%d, world size=%d", self.rank, self.world_size) + uid = nvshmem_get_unique_id( + ) if self.rank == 0 else nvshmem_alloc_empty_unique_id() + dist.broadcast(uid, + src=dist.get_process_group_ranks(self.cpu_group)[0], + group=self.cpu_group) + logger.debug("PPLX NVSHMEM UID = %s", uid) + nvshmem_init(uid, self.rank, self.world_size) + + self.handle_cache = Cache() + + def get_handle(self, kwargs): + import pplx_kernels as pplx + return self.handle_cache.get_or_create( + kwargs, pplx.AllToAll.internode + if self.internode else pplx.AllToAll.intranode) + + def dispatch(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + raise NotImplementedError + + def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + def destroy(self): + with self.handle_cache._lock: + for _, handle in self.handle_cache._cache.items(): + handle.destroy() + + if self.internode: + from pplx_kernels.nvshmem import nvshmem_finalize + logger.debug("PPLX NVSHMEM finalize") + nvshmem_finalize() + + +class DeepEPAll2AllManagerBase(All2AllManagerBase): + """ + All2All communication based on DeepEP High-Throughput kernels. + """ + + def __init__(self, cpu_group): + has_deepep = importlib.util.find_spec("deep_ep") is not None + assert has_deepep, "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels." # noqa + super().__init__(cpu_group) + self.handle_cache = Cache() + + # This is the DeepEP default. Stick to it till we can establish + # reasonable defaults based on profiling. + self.num_sms = 20 + + def get_handle(self, kwargs): + raise NotImplementedError + + def dispatch(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + raise NotImplementedError + + def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + def destroy(self): + pass + + +class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): + """ + All2All communication based on DeepEP High-Throughput kernels. + """ + + def __init__(self, cpu_group): + super().__init__(cpu_group) + + def _make_all2all_kwargs(self) -> dict[Any, Any]: + # Defaults for internode and intranode are taken from DeepEP tests. + num_nvl_bytes = 1024 * 1024 * 1024 + num_rdma_bytes = None + num_qps_per_rank = None + + if self.internode: + num_rdma_bytes = 1024 * 1024 * 1024 + num_qps_per_rank = self.num_sms // 2 + else: + num_rdma_bytes = 0 + num_qps_per_rank = 1 + + assert num_rdma_bytes is not None + assert num_qps_per_rank is not None + return dict(group=self.cpu_group, + num_nvl_bytes=num_nvl_bytes, + num_rdma_bytes=num_rdma_bytes, + low_latency_mode=False, + num_qps_per_rank=num_qps_per_rank) + + def get_handle(self, kwargs): + + assert len(kwargs) == 0, ( + "DeepEPHTAll2AllManager expects no arguments. All the required " + "args are computed in the Manager itself.") + + import deep_ep + buffer_kwargs = self._make_all2all_kwargs() + logger.debug("DeepEP all2all args %s", buffer_kwargs) + handle: deep_ep.Buffer = self.handle_cache.get_or_create( + buffer_kwargs, deep_ep.Buffer) + # It is dangerous to set num sms outside this function. 
num_sms is not + # a part of the hash-key that identifies this object. If we are in a + # situation where we make objects with different num_sms, the hash key + # in get_or_create must be updated. + handle.set_num_sms(self.num_sms) + return handle + + +class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): + """ + All2All communication based on DeepEP Low-Latency kernels. + """ + + def __init__(self, cpu_group): + super().__init__(cpu_group) + + def _make_all2all_kwargs( + self, + max_num_tokens_per_dp_rank: int, + token_hidden_size: int, + num_ep_ranks: int, + num_global_experts: int, + num_local_experts: int, + ) -> dict[Any, Any]: + """ + max_num_tokens_per_dp_rank : the maximum number of tokens a DP rank + can dispatch all the ranks must hold the same value. + token_hidden_size: the hidden dimension of each token. + num_ep_ranks: the number of EP group ranks. + num_global_experts: Number of experts in the model. + num_local_experts: Number of experts in an EP rank. + """ + import deep_ep + + # Defaults for internode and intranode are taken from DeepEP tests. + num_nvl_bytes = 1024 * 1024 * 1024 + num_qps_per_rank = num_local_experts + num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( + num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank, + hidden=token_hidden_size, + num_ranks=num_ep_ranks, + num_experts=num_global_experts) + + assert num_rdma_bytes is not None + return dict(group=self.cpu_group, + num_nvl_bytes=num_nvl_bytes, + num_rdma_bytes=num_rdma_bytes, + low_latency_mode=True, + num_qps_per_rank=num_qps_per_rank) + + def get_handle(self, kwargs): + """ + The kwargs for DeepEPLLAll2AllManager is dictated by + _make_all2all_kwargs. + """ + import deep_ep + buffer_kwargs = self._make_all2all_kwargs(**kwargs) + logger.debug("DeepEP all2all args %s", buffer_kwargs) + handle: deep_ep.Buffer = self.handle_cache.get_or_create( + buffer_kwargs, deep_ep.Buffer) + # It is dangerous to set num sms outside this function. num_sms is not + # a part of the hash-key that identifies this object. If we are in a + # situation where we make objects with different num_sms, the hash key + # in get_or_create must be updated. 
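+ # Illustrative note (not part of the upstream comment): Cache.get_or_create
+ # keys purely on the kwargs, roughly tuple(sorted(buffer_kwargs.items())),
+ # so calls with identical buffer_kwargs share one deep_ep.Buffer, and
+ # set_num_sms() below mutates that shared buffer in place.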
+ handle.set_num_sms(self.num_sms) + return handle diff --git a/distributed/device_communicators/base_device_communicator.py b/distributed/device_communicators/base_device_communicator.py new file mode 100644 index 0000000..1bc2d8e --- /dev/null +++ b/distributed/device_communicators/base_device_communicator.py @@ -0,0 +1,260 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import threading +from typing import Optional +from weakref import WeakValueDictionary + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + + +class Cache: + + def __init__(self): + self._cache: WeakValueDictionary = WeakValueDictionary() + self._lock = threading.RLock() # Reentrant lock for thread safety + + def get_or_create(self, kwargs, func): + # Create a hashable key from the kwargs + key = tuple(sorted((k, v) for k, v in kwargs.items())) + + with self._lock: + instance = self._cache.get(key) + if instance is None: + instance = func(**kwargs) + self._cache[key] = instance + return instance + + +class All2AllManagerBase: + + def __init__(self, cpu_group): + self.cpu_group = cpu_group + + # compute some common properties + from vllm.distributed.parallel_state import (get_dp_group, + get_tp_group, + in_the_same_node_as) + + # all2all lives in ep group, which is merged from dp and tp group + self.dp_group = get_dp_group() + self.tp_group = get_tp_group() + # no self.ep_group since self.ep_group is still in construction + # when we create this object + self.dp_rank = self.dp_group.rank_in_group + self.dp_world_size = self.dp_group.world_size + self.rank = dist.get_rank(cpu_group) + self.world_size = dist.get_world_size(cpu_group) + + # all2all communication often has separate implementations for + # intra-node and inter-node communication + self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0)) + + def get_handle(self, kwargs): + # get a handle for the all2all communication, + # based on the kwargs. + # different layers can have different configs, + # e.g. one layer has hidden size 1024, another has 2048. + # usually the underlying implementation caches the handle + # and reuse it for the same config. + raise NotImplementedError + + def dispatch(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + raise NotImplementedError + + def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + def destroy(self): + pass + + +class DeviceCommunicatorBase: + """ + Base class for device-specific communicator. + It can use the `cpu_group` to initialize the communicator. + If the device has PyTorch integration (PyTorch can recognize its + communication backend), the `device_group` will also be given. 
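+
+    A rough sketch of the expected subclassing pattern (illustrative only;
+    the concrete communicators in this package follow it):
+
+        class MyCommunicator(DeviceCommunicatorBase):
+            def all_reduce(self, input_):
+                dist.all_reduce(input_, group=self.device_group)
+                return input_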
+ """ + + def __init__(self, + cpu_group: ProcessGroup, + device: Optional[torch.device] = None, + device_group: Optional[ProcessGroup] = None, + unique_name: str = ""): + self.device = device or torch.device("cpu") + self.cpu_group = cpu_group + self.device_group = device_group + self.unique_name = unique_name + self.rank = dist.get_rank(cpu_group) + self.world_size = dist.get_world_size(cpu_group) + self.ranks = dist.get_process_group_ranks(cpu_group) + self.global_rank = dist.get_rank() + self.global_world_size = dist.get_world_size() + self.rank_in_group = dist.get_group_rank(self.cpu_group, + self.global_rank) + + use_ep = False + from vllm.config import get_current_vllm_config + config = get_current_vllm_config() + if config is not None: + # as long as we use data parallel (coupled data parallel + # where all data parallel ranks execute forward together), + # we initialize the all2all manager used in expert parallel. + use_ep = config.parallel_config.data_parallel_size > 1 + + self.use_all2all = "ep" in unique_name and use_ep + self.all2all_manager: Optional[All2AllManagerBase] = None + + def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + dist.all_reduce(input_, group=self.device_group) + return input_ + + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + input_size = input_.size() + # NOTE: we have to use concat-style all-gather here, + # stack-style all-gather has compatibility issues with + # torch.compile . see https://github.com/pytorch/pytorch/issues/138795 + output_size = (input_size[0] * self.world_size, ) + input_size[1:] + # Allocate output tensor. + output_tensor = torch.empty(output_size, + dtype=input_.dtype, + device=input_.device) + # All-gather. + dist.all_gather_into_tensor(output_tensor, + input_, + group=self.device_group) + # Reshape + output_tensor = output_tensor.reshape((self.world_size, ) + input_size) + output_tensor = output_tensor.movedim(0, dim) + output_tensor = output_tensor.reshape(input_size[:dim] + + (self.world_size * + input_size[dim], ) + + input_size[dim + 1:]) + return output_tensor + + def reduce_scatter(self, + input_: torch.Tensor, + dim: int = -1) -> torch.Tensor: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + + # Note: This will produce an incorrect answer if we don't make + # the input_tensor contiguous. Possible bug in reduce_scatter_tensor? + input_tensor = input_.movedim(0, dim).contiguous() + + assert input_tensor.shape[0] % world_size == 0 + chunk_size = input_tensor.shape[0] // world_size + output_shape = (chunk_size, ) + input_tensor.shape[1:] + + output_tensor = torch.empty(output_shape, + dtype=input_tensor.dtype, + device=input_tensor.device) + + # Perform reduce-scatter operation + torch.distributed.reduce_scatter_tensor(output_tensor, + input_tensor, + group=self.device_group) + + # Reshape before returning + return output_tensor.movedim(0, dim).contiguous() + + def gather(self, + input_: torch.Tensor, + dst: int = 0, + dim: int = -1) -> Optional[torch.Tensor]: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. 
+ """ + world_size = self.world_size + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + + # Allocate output tensor. + if self.rank_in_group == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + # Gather. + torch.distributed.gather(input_, + gather_list, + dst=self.ranks[dst], + group=self.device_group) + if self.rank_in_group == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the local rank of the destination rank.""" + if dst is None: + dst = (self.rank_in_group + 1) % self.world_size + torch.distributed.send(tensor, self.ranks[dst], self.device_group) + + def recv(self, + size: torch.Size, + dtype: torch.dtype, + src: Optional[int] = None) -> torch.Tensor: + """Receives a tensor from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + if src is None: + src = (self.rank_in_group - 1) % self.world_size + + tensor = torch.empty(size, dtype=dtype, device=self.device) + torch.distributed.recv(tensor, self.ranks[src], self.device_group) + return tensor + + def destroy(self): + pass + + def prepare_communication_buffer_for_model(self, + model: torch.nn.Module) -> None: + """ + Prepare the communication buffer for the model. + """ + if not self.use_all2all: + return + + moe_modules = [ + module for module in model.modules() + if module.__class__.__name__ == "FusedMoE" + ] + for module in moe_modules: + module.quant_method.init_prepare_finalize(module.moe_config, + module.quant_config) + + def dispatch( + self, hidden_states: torch.Tensor, + router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Dispatch the hidden states and router logits to the appropriate device. + This is a no-op in the base class. + """ + return hidden_states, router_logits + + def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Combine the hidden states and router logits from the appropriate device. + This is a no-op in the base class. 
+ """ + return hidden_states diff --git a/distributed/device_communicators/cpu_communicator.py b/distributed/device_communicators/cpu_communicator.py new file mode 100644 index 0000000..94effa0 --- /dev/null +++ b/distributed/device_communicators/cpu_communicator.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +from typing import Optional + +import torch +from torch.distributed import ProcessGroup + +from vllm.platforms import current_platform +from vllm.platforms.interface import CpuArchEnum + +from .base_device_communicator import DeviceCommunicatorBase + + +class CpuCommunicator(DeviceCommunicatorBase): + + def __init__(self, + cpu_group: ProcessGroup, + device: Optional[torch.device] = None, + device_group: Optional[ProcessGroup] = None, + unique_name: str = ""): + super().__init__(cpu_group, device, device_group, unique_name) + self.dist_module = torch.distributed + + if (current_platform.get_cpu_architecture() + == CpuArchEnum.X86) and hasattr( + torch.ops._C, + "init_shm_manager") and unique_name.startswith("tp"): + self.dist_module = _CPUSHMDistributed(self) + + def all_reduce(self, input_): + self.dist_module.all_reduce(input_, group=self.device_group) + return input_ + + def gather(self, + input_: torch.Tensor, + dst: int = 0, + dim: int = -1) -> Optional[torch.Tensor]: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. + """ + world_size = self.world_size + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + + # Allocate output tensor. + if self.rank_in_group == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + + # Gather. + self.dist_module.gather(input_, + gather_list, + dst=self.ranks[dst], + group=self.device_group) + + if self.rank_in_group == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + input_size = input_.size() + # NOTE: we have to use concat-style all-gather here, + # stack-style all-gather has compatibility issues with + # torch.compile . see https://github.com/pytorch/pytorch/issues/138795 + output_size = (input_size[0] * self.world_size, ) + input_size[1:] + # Allocate output tensor. + output_tensor = torch.empty(output_size, + dtype=input_.dtype, + device=input_.device) + # All-gather. 
+ self.dist_module.all_gather_into_tensor(output_tensor, + input_, + group=self.device_group) + + # Reshape + output_tensor = output_tensor.reshape((self.world_size, ) + input_size) + output_tensor = output_tensor.movedim(0, dim) + output_tensor = output_tensor.reshape(input_size[:dim] + + (self.world_size * + input_size[dim], ) + + input_size[dim + 1:]) + return output_tensor + + +class _CPUSHMDistributed: + + def __init__(self, communicator: CpuCommunicator): + instance_identifier = os.environ["VLLM_DIST_IDENT"] + unique_name = communicator.unique_name + instance_identifier = f"{instance_identifier}-{unique_name}" + self.communicator = communicator + + group_ranks = [str(rank) for rank in self.communicator.ranks] + shm_group_identifier = f"[{'-'.join(group_ranks)}]" + self.group_name = f"{instance_identifier}-{shm_group_identifier}-cpushm" + + self.handle = self._init_cpu_shm() + + def _init_cpu_shm(self) -> int: + handle = torch.ops._C.init_shm_manager( + self.group_name, + self.communicator.world_size, + self.communicator.rank, + ) + torch.distributed.barrier(self.communicator.device_group) + torch.ops._C.join_shm_manager( + handle, + self.group_name, + ) + torch.distributed.barrier(self.communicator.device_group) + + return handle + + def all_reduce(self, + input: torch.Tensor, + group: Optional[ProcessGroup] = None) -> None: + torch.ops._C.shm_allreduce(self.handle, input) + + def gather(self, + input: torch.Tensor, + gather_list: Optional[list[torch.Tensor]], + dst: int = -1, + group: Optional[ProcessGroup] = None) -> None: + # Note: different from the torch gather, here we use local dst rank. + torch.ops._C.shm_gather(self.handle, input, gather_list, + torch.distributed.get_group_rank(group, dst)) + + def all_gather_into_tensor(self, + output: torch.Tensor, + input: torch.Tensor, + group: Optional[ProcessGroup] = None) -> None: + torch.ops._C.shm_all_gather(self.handle, input, output) diff --git a/distributed/device_communicators/cuda_communicator.py b/distributed/device_communicators/cuda_communicator.py new file mode 100644 index 0000000..055d916 --- /dev/null +++ b/distributed/device_communicators/cuda_communicator.py @@ -0,0 +1,176 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch +from torch.distributed import ProcessGroup + +import vllm.envs as envs +from vllm.logger import init_logger + +from .base_device_communicator import DeviceCommunicatorBase + +logger = init_logger(__name__) + + +class CudaCommunicator(DeviceCommunicatorBase): + + def __init__(self, + cpu_group: ProcessGroup, + device: Optional[torch.device] = None, + device_group: Optional[ProcessGroup] = None, + unique_name: str = ""): + super().__init__(cpu_group, device, device_group, unique_name) + if "tp" not in unique_name: + # only tp uses custom allreduce + use_custom_allreduce = False + else: + from vllm.distributed.parallel_state import ( + _ENABLE_CUSTOM_ALL_REDUCE) + use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE + + # ep does not use pynccl + use_pynccl = "ep" not in unique_name + + self.use_pynccl = use_pynccl + self.use_custom_allreduce = use_custom_allreduce + + # lazy import to avoid documentation build error + from vllm.distributed.device_communicators.custom_all_reduce import ( + CustomAllreduce) + from vllm.distributed.device_communicators.pynccl import ( + PyNcclCommunicator) + + self.pynccl_comm: Optional[PyNcclCommunicator] = None + if use_pynccl and self.world_size > 1: + 
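+ # Note (added for clarity): PyNcclCommunicator is built from the
+ # gloo-backed cpu_group and maintains its own NCCL communicator
+ # internally rather than reusing device_group.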
self.pynccl_comm = PyNcclCommunicator( + group=self.cpu_group, + device=self.device, + ) + + self.ca_comm: Optional[CustomAllreduce] = None + if use_custom_allreduce and self.world_size > 1: + # Initialize a custom fast all-reduce implementation. + self.ca_comm = CustomAllreduce( + group=self.cpu_group, + device=self.device, + ) + + if self.use_all2all: + all2all_backend = envs.VLLM_ALL2ALL_BACKEND + if all2all_backend == "naive": + from .all2all import NaiveAll2AllManager + self.all2all_manager = NaiveAll2AllManager(self.cpu_group) + logger.info("Using naive all2all manager.") + elif all2all_backend == "pplx": + from .all2all import PPLXAll2AllManager + self.all2all_manager = PPLXAll2AllManager(self.cpu_group) + logger.info("Using PPLX all2all manager.") + elif all2all_backend == "deepep_high_throughput": + from .all2all import DeepEPHTAll2AllManager + self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group) + logger.info("Using DeepEP High-Throughput all2all manager.") + elif all2all_backend == "deepep_low_latency": + from .all2all import DeepEPLLAll2AllManager + self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group) + logger.info("Using DeepEP Low-Latency all2all manager.") + else: + raise ValueError(f"Unknown all2all backend: {all2all_backend}") + + def all_reduce(self, input_): + # always try custom allreduce first, + # and then pynccl. + ca_comm = self.ca_comm + if ca_comm is not None and not ca_comm.disabled and \ + ca_comm.should_custom_ar(input_): + out = ca_comm.custom_all_reduce(input_) + assert out is not None + return out + pynccl_comm = self.pynccl_comm + assert pynccl_comm is not None + out = pynccl_comm.all_reduce(input_) + if out is None: + # fall back to the default all-reduce using PyTorch. + # this usually happens during testing. + # when we run the model, allreduce only happens for the TP + # group, where we always have either custom allreduce or pynccl. + out = input_.clone() + torch.distributed.all_reduce(out, group=self.device_group) + return out + + def reduce_scatter(self, input_: torch.Tensor, dim: int = -1): + world_size = self.world_size + pynccl_comm = self.pynccl_comm + assert pynccl_comm is not None + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + + # Note: This will produce an incorrect answer if we don't make + # the input_tensor contiguous. Possible bug in reduce_scatter_tensor? 
+ input_tensor = input_.movedim(0, dim).contiguous() + + assert input_tensor.shape[0] % world_size == 0 + chunk_size = input_tensor.shape[0] // world_size + output_shape = (chunk_size, ) + input_tensor.shape[1:] + + output = torch.empty(output_shape, + dtype=input_tensor.dtype, + device=input_tensor.device) + + pynccl_comm.reduce_scatter(output, input_) + + # Reshape before returning + return output.movedim(0, dim).contiguous() + + def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the local rank of the destination rank.""" + if dst is None: + dst = (self.rank_in_group + 1) % self.world_size + + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.send(tensor, dst) + else: + torch.distributed.send(tensor, self.ranks[dst], self.device_group) + + def recv(self, + size: torch.Size, + dtype: torch.dtype, + src: Optional[int] = None) -> torch.Tensor: + """Receives a tensor from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + if src is None: + src = (self.rank_in_group - 1) % self.world_size + + tensor = torch.empty(size, dtype=dtype, device=self.device) + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.recv(tensor, src) + else: + torch.distributed.recv(tensor, self.ranks[src], self.device_group) + return tensor + + def destroy(self): + if self.pynccl_comm is not None: + self.pynccl_comm = None + if self.ca_comm is not None: + self.ca_comm = None + if self.all2all_manager is not None: + self.all2all_manager.destroy() + self.all2all_manager = None + + def dispatch( + self, hidden_states: torch.Tensor, + router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + assert self.all2all_manager is not None + hidden_states, router_logits = self.all2all_manager.dispatch( + hidden_states, router_logits) + return hidden_states, router_logits + + def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + assert self.all2all_manager is not None + hidden_states = self.all2all_manager.combine(hidden_states) + return hidden_states diff --git a/distributed/device_communicators/cuda_wrapper.py b/distributed/device_communicators/cuda_wrapper.py new file mode 100644 index 0000000..0052ba0 --- /dev/null +++ b/distributed/device_communicators/cuda_wrapper.py @@ -0,0 +1,180 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""This file is a pure Python wrapper for the cudart library. +It avoids the need to compile a separate shared library, and is +convenient for use when we just need to call a few functions. 
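+
+A minimal usage sketch (illustrative; every call below is defined in this
+file):
+
+    lib = CudaRTLibrary()
+    lib.cudaSetDevice(0)
+    ptr = lib.cudaMalloc(1024)
+    lib.cudaMemset(ptr, 0, 1024)
+    lib.cudaDeviceSynchronize()
+    lib.cudaFree(ptr)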
+""" + +import ctypes +from dataclasses import dataclass +from typing import Any, Optional + +# this line makes it possible to directly load `libcudart.so` using `ctypes` +import torch # noqa + +import vllm.envs as envs +from vllm.logger import init_logger + +logger = init_logger(__name__) + +# === export types and functions from cudart to Python === +# for the original cudart definition, please check +# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html + +cudaError_t = ctypes.c_int +cudaMemcpyKind = ctypes.c_int + + +class cudaIpcMemHandle_t(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 128)] + + +@dataclass +class Function: + name: str + restype: Any + argtypes: list[Any] + + +def find_loaded_library(lib_name) -> Optional[str]: + """ + According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html, + the file `/proc/self/maps` contains the memory maps of the process, which includes the + shared libraries loaded by the process. We can use this file to find the path of the + a loaded library. + """ # noqa + found = False + with open("/proc/self/maps") as f: + for line in f: + if lib_name in line: + found = True + break + if not found: + # the library is not loaded in the current process + return None + # if lib_name is libcudart, we need to match a line with: + # address /path/to/libcudart-hash.so.11.0 + start = line.index("/") + path = line[start:].strip() + filename = path.split("/")[-1] + assert filename.rpartition(".so")[0].startswith(lib_name), \ + f"Unexpected filename: {filename} for library {lib_name}" + return path + + +class CudaRTLibrary: + exported_functions = [ + # ​cudaError_t cudaSetDevice ( int device ) + Function("mcSetDevice", cudaError_t, [ctypes.c_int]), + # cudaError_t cudaDeviceSynchronize ( void ) + Function("mcDeviceSynchronize", cudaError_t, []), + # ​cudaError_t cudaDeviceReset ( void ) + Function("mcDeviceReset", cudaError_t, []), + + # const char* cudaGetErrorString ( cudaError_t error ) + Function("mcGetErrorString", ctypes.c_char_p, [cudaError_t]), + + # ​cudaError_t cudaMalloc ( void** devPtr, size_t size ) + Function("mcMalloc", cudaError_t, + [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]), + # ​cudaError_t cudaFree ( void* devPtr ) + Function("mcFree", cudaError_t, [ctypes.c_void_p]), + # ​cudaError_t cudaMemset ( void* devPtr, int value, size_t count ) + Function("mcMemset", cudaError_t, + [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]), + # ​cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa + Function("mcMemcpy", cudaError_t, [ + ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind + ]), + + # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa + Function("mcIpcGetMemHandle", cudaError_t, + [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]), + # ​cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa + Function("mcIpcOpenMemHandle", cudaError_t, [ + ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint + ]), + ] + + # class attribute to store the mapping from the path to the library + # to avoid loading the same library multiple times + path_to_library_cache: dict[str, Any] = {} + + # class attribute to store the mapping from library path + # to the corresponding dictionary + path_to_dict_mapping: dict[str, dict[str, Any]] = {} + + def __init__(self, so_file: Optional[str] = None): + if so_file is None: + so_file = 
find_loaded_library("libmcruntime") + if so_file is None: + so_file = envs.VLLM_CUDART_SO_PATH # fallback to env var + assert so_file is not None, \ + ( + "libcudart is not loaded in the current process, " + "try setting VLLM_CUDART_SO_PATH" + ) + if so_file not in CudaRTLibrary.path_to_library_cache: + lib = ctypes.CDLL(so_file) + CudaRTLibrary.path_to_library_cache[so_file] = lib + self.lib = CudaRTLibrary.path_to_library_cache[so_file] + + if so_file not in CudaRTLibrary.path_to_dict_mapping: + _funcs = {} + for func in CudaRTLibrary.exported_functions: + f = getattr(self.lib, func.name) + f.restype = func.restype + f.argtypes = func.argtypes + _funcs[func.name] = f + CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs + self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file] + + def CUDART_CHECK(self, result: cudaError_t) -> None: + if result != 0: + error_str = self.cudaGetErrorString(result) + raise RuntimeError(f"CUDART error: {error_str}") + + def cudaGetErrorString(self, error: cudaError_t) -> str: + return self.funcs["mcGetErrorString"](error).decode("utf-8") + + def cudaSetDevice(self, device: int) -> None: + self.CUDART_CHECK(self.funcs["mcSetDevice"](device)) + + def cudaDeviceSynchronize(self) -> None: + self.CUDART_CHECK(self.funcs["mcDeviceSynchronize"]()) + + def cudaDeviceReset(self) -> None: + self.CUDART_CHECK(self.funcs["mcDeviceReset"]()) + + def cudaMalloc(self, size: int) -> ctypes.c_void_p: + devPtr = ctypes.c_void_p() + self.CUDART_CHECK(self.funcs["mcMalloc"](ctypes.byref(devPtr), size)) + return devPtr + + def cudaFree(self, devPtr: ctypes.c_void_p) -> None: + self.CUDART_CHECK(self.funcs["mcFree"](devPtr)) + + def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, + count: int) -> None: + self.CUDART_CHECK(self.funcs["mcMemset"](devPtr, value, count)) + + def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p, + count: int) -> None: + cudaMemcpyDefault = 4 + kind = cudaMemcpyDefault + self.CUDART_CHECK(self.funcs["mcMemcpy"](dst, src, count, kind)) + + def cudaIpcGetMemHandle(self, + devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t: + handle = cudaIpcMemHandle_t() + self.CUDART_CHECK(self.funcs["mcIpcGetMemHandle"]( + ctypes.byref(handle), devPtr)) + return handle + + def cudaIpcOpenMemHandle(self, + handle: cudaIpcMemHandle_t) -> ctypes.c_void_p: + cudaIpcMemLazyEnablePeerAccess = 1 + devPtr = ctypes.c_void_p() + self.CUDART_CHECK(self.funcs["mcIpcOpenMemHandle"]( + ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess)) + return devPtr diff --git a/distributed/device_communicators/custom_all_reduce.py b/distributed/device_communicators/custom_all_reduce.py new file mode 100644 index 0000000..7dd104a --- /dev/null +++ b/distributed/device_communicators/custom_all_reduce.py @@ -0,0 +1,304 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from contextlib import contextmanager +from typing import Optional, Union + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.distributed.device_communicators.custom_all_reduce_utils import ( + gpu_p2p_access_check) +from vllm.distributed.parallel_state import in_the_same_node_as +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils import cuda_device_count_stateless + +try: + ops.meta_size() + custom_ar = True +except Exception: + # For CPUs + custom_ar = False + +logger = 
init_logger(__name__) + + +def _can_p2p(rank: int, world_size: int) -> bool: + for i in range(world_size): + if i == rank: + continue + if envs.VLLM_SKIP_P2P_CHECK: + logger.info( + "Skipping P2P check and trusting the driver's P2P report.") + return torch.cuda.can_device_access_peer(rank, i) + if not gpu_p2p_access_check(rank, i): + return False + return True + + +def is_weak_contiguous(inp: torch.Tensor): + return inp.is_contiguous() or (inp.storage().nbytes() - + inp.storage_offset() * inp.element_size() + == inp.numel() * inp.element_size()) + + +class CustomAllreduce: + + _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] + + # max_size: max supported allreduce size + def __init__(self, + group: ProcessGroup, + device: Union[int, str, torch.device], + max_size=8192 * 1024) -> None: + """ + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the CustomAllreduce to. If None, + it will be bind to f"cuda:{local_rank}". + It is the caller's responsibility to make sure each communicator + is bind to a unique device, and all communicators in this group + are in the same node. + """ + self._IS_CAPTURING = False + self.disabled = True + + if not custom_ar: + # disable because of missing custom allreduce library + # e.g. in a non-GPU environment + logger.info("Custom allreduce is disabled because " + "of missing custom allreduce library") + return + + self.group = group + + assert dist.get_backend(group) != dist.Backend.NCCL, ( + "CustomAllreduce should be attached to a non-NCCL group.") + + if not all(in_the_same_node_as(group, source_rank=0)): + # No need to initialize custom allreduce for multi-node case. + logger.warning( + "Custom allreduce is disabled because this process group" + " spans across nodes.") + return + + rank = dist.get_rank(group=self.group) + self.rank = rank + world_size = dist.get_world_size(group=self.group) + if world_size == 1: + # No need to initialize custom allreduce for single GPU case. + return + + if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES: + logger.warning( + "Custom allreduce is disabled due to an unsupported world" + " size: %d. Supported world sizes: %s. 
To silence this " + "warning, specify disable_custom_all_reduce=True explicitly.", + world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES)) + return + + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + # now `device` is a `torch.device` object + assert isinstance(device, torch.device) + self.device = device + + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + if cuda_visible_devices: + device_ids = list(map(int, cuda_visible_devices.split(","))) + else: + device_ids = list(range(cuda_device_count_stateless())) + + physical_device_id = device_ids[device.index] + tensor = torch.tensor([physical_device_id], + dtype=torch.int, + device="cpu") + gather_list = [ + torch.tensor([0], dtype=torch.int, device="cpu") + for _ in range(world_size) + ] + dist.all_gather(gather_list, tensor, group=self.group) + physical_device_ids = [t.item() for t in gather_list] + + # test nvlink first, this will filter out most of the cases + # where custom allreduce is not supported + # this checks hardware and driver support for NVLink + assert current_platform.is_cuda_alike() + fully_connected = current_platform.is_fully_connected( + physical_device_ids) + if world_size > 2 and not fully_connected: + logger.warning( + "Custom allreduce is disabled because it's not supported on" + " more than two PCIe-only GPUs. To silence this warning, " + "specify disable_custom_all_reduce=True explicitly.") + return + # test P2P capability, this checks software/cudaruntime support + # this is expensive to compute at the first time + # then we cache the result + # On AMD GPU, p2p is always enabled between XGMI connected GPUs + if not current_platform.is_rocm() and not _can_p2p(rank, world_size): + logger.warning( + "Custom allreduce is disabled because your platform lacks " + "GPU P2P capability or P2P test failed. To silence this " + "warning, specify disable_custom_all_reduce=True explicitly.") + return + + self.disabled = False + # Buffers memory are owned by this Python class and passed to C++. + # Meta data composes of two parts: meta data for synchronization and a + # temporary buffer for storing intermediate allreduce results. + self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size, + group=group, + uncached=True) + # This is a pre-registered IPC buffer. In eager mode, input tensors + # are first copied into this buffer before allreduce is performed + self.buffer_ptrs = self.create_shared_buffer(max_size, group=group) + # This is a buffer for storing the tuples of pointers pointing to + # IPC buffers from all ranks. Each registered tuple has size of + # 8*world_size bytes where world_size is at most 8. Allocating 8MB + # is enough for 131072 such tuples. The largest model I've seen only + # needs less than 10000 of registered tuples. + self.rank_data = torch.empty(8 * 1024 * 1024, + dtype=torch.uint8, + device=self.device) + self.max_size = max_size + self.rank = rank + self.world_size = world_size + self.fully_connected = fully_connected + self._ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data, rank, + self.fully_connected) + ops.register_buffer(self._ptr, self.buffer_ptrs) + + @contextmanager + def capture(self): + """ + The main responsibility of this context manager is the + `register_graph_buffers` call at the end of the context. + It records all the buffer addresses used in the CUDA graph. 
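+
+        Rough usage sketch (illustrative; `ca_comm` is a hypothetical name):
+
+            with ca_comm.capture():
+                ...  # capture CUDA graphs that call custom_all_reduce()
+            # on exit, register_graph_buffers() records the graph buffers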
+ """ + try: + self._IS_CAPTURING = True + yield + finally: + self._IS_CAPTURING = False + if not self.disabled: + self.register_graph_buffers() + + def register_graph_buffers(self): + handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr) + logger.info("Registering %d cuda graph addresses", len(offset)) + # We cannot directly use `dist.all_gather_object` here + # because it is incompatible with `gloo` backend under inference mode. + # see https://github.com/pytorch/pytorch/issues/126032 for details. + all_data = [[None, None] + for _ in range(dist.get_world_size(group=self.group))] + all_data[self.rank] = [handle, offset] + ranks = sorted(dist.get_process_group_ranks(group=self.group)) + for i, rank in enumerate(ranks): + dist.broadcast_object_list(all_data[i], + src=rank, + group=self.group, + device="cpu") + # Unpack list of tuples to tuple of lists. + handles = [d[0] for d in all_data] # type: ignore + offsets = [d[1] for d in all_data] # type: ignore + ops.register_graph_buffers(self._ptr, handles, offsets) + + def should_custom_ar(self, inp: torch.Tensor): + if self.disabled: + return False + inp_size = inp.numel() * inp.element_size() + # custom allreduce requires input byte size to be multiples of 16 + if inp_size % 16 != 0: + return False + if not is_weak_contiguous(inp): + return False + # for 4 or more non NVLink-capable GPUs, custom allreduce provides + # little performance improvement over NCCL. + if self.world_size == 2 or self.fully_connected: + return inp_size < self.max_size + return False + + def all_reduce(self, + inp: torch.Tensor, + *, + out: torch.Tensor = None, + registered: bool = False): + """Performs an out-of-place all reduce. + + If registered is True, this assumes inp's pointer is already + IPC-registered. Otherwise, inp is first copied into a pre-registered + buffer. + """ + if out is None: + out = torch.empty_like(inp) + if registered: + ops.all_reduce(self._ptr, inp, out, 0, 0) + else: + ops.all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank], + self.max_size) + return out + + def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: + """The main allreduce API that provides support for cuda graph.""" + # When custom allreduce is disabled, this will be None. + if self.disabled or not self.should_custom_ar(input): + return None + if self._IS_CAPTURING: + if torch.cuda.is_current_stream_capturing(): + return self.all_reduce(input, registered=True) + else: + # If warm up, mimic the allocation pattern since custom + # allreduce is out-of-place. 
+ return torch.empty_like(input) + else: + # Note: outside of cuda graph context, custom allreduce incurs a + # cost of cudaMemcpy, which should be small (<=1% of overall + # latency) compared to the performance gain of using custom kernels + return self.all_reduce(input, registered=False) + + def close(self): + if not self.disabled and self._ptr: + if ops is not None: + ops.dispose(self._ptr) + self._ptr = 0 + self.free_shared_buffer(self.meta_ptrs, rank=self.rank) + self.free_shared_buffer(self.buffer_ptrs, rank=self.rank) + + def __del__(self): + self.close() + + @staticmethod + def create_shared_buffer(size_in_bytes: int, + group: Optional[ProcessGroup] = None, + uncached: Optional[bool] = False) -> list[int]: + pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes) + + world_size = dist.get_world_size(group=group) + rank = dist.get_rank(group=group) + handles = [None] * world_size + dist.all_gather_object(handles, handle, group=group) + + pointers: list[int] = [] + for i, h in enumerate(handles): + if i == rank: + pointers.append(pointer) # type: ignore + else: + pointers.append(ops.open_mem_handle(h)) + return pointers + + @staticmethod + def free_shared_buffer(pointers: list[int], + group: Optional[ProcessGroup] = None, + rank: Optional[int] = 0) -> None: + if rank is None: + rank = dist.get_rank(group=group) + if ops is not None: + ops.free_shared_buffer(pointers[rank]) diff --git a/distributed/device_communicators/custom_all_reduce_utils.py b/distributed/device_communicators/custom_all_reduce_utils.py new file mode 100644 index 0000000..7c6001e --- /dev/null +++ b/distributed/device_communicators/custom_all_reduce_utils.py @@ -0,0 +1,259 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import ctypes +import json +import os +import pickle +import subprocess +import sys +import tempfile +from collections.abc import Sequence +from itertools import product +from typing import Optional + +import torch.distributed as dist +import torch.multiprocessing as mp + +import vllm.envs as envs +from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary +from vllm.logger import init_logger +from vllm.utils import (cuda_device_count_stateless, + update_environment_variables) + +logger = init_logger(__name__) + + +def producer(batch_src: Sequence[int], + producer_queue, + consumer_queue, + result_queue, + cuda_visible_devices: Optional[str] = None): + if cuda_visible_devices is not None: + update_environment_variables( + {"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) + + lib = CudaRTLibrary() + for i in batch_src: + lib.cudaSetDevice(i) + pointer = lib.cudaMalloc(1024) + lib.cudaMemset(pointer, 1, 1024) + lib.cudaDeviceSynchronize() + handle = lib.cudaIpcGetMemHandle(pointer) + producer_queue.put(handle) + open_success = consumer_queue.get() + if open_success: + # use two queues to simulate barrier + producer_queue.put(0) + consumer_queue.get() + # check if the memory is modified + host_data = (ctypes.c_char * 1024)() + lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore + for i in range(1024): + if ord(host_data[i]) != 2: + open_success = False + break + result_queue.put(open_success) + lib.cudaDeviceReset() + + +def consumer(batch_tgt: Sequence[int], + producer_queue, + consumer_queue, + result_queue, + cuda_visible_devices: Optional[str] = None): + if cuda_visible_devices is not None: + update_environment_variables( + {"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) + + lib = CudaRTLibrary() + 
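+ # Handshake per (src, tgt) pair, mirroring producer() above: open the
+ # producer's IPC handle on GPU tgt, report success, write the value 2 into
+ # the shared buffer, synchronize through the two queues (a makeshift
+ # barrier), then read the buffer back and check every byte equals 2.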
for j in batch_tgt: + lib.cudaSetDevice(j) + handle = producer_queue.get() + open_success = False + try: + pointer = lib.cudaIpcOpenMemHandle(handle) # type: ignore + open_success = True + except RuntimeError: + # cannot error out here, because the producer process + # is still waiting for the response. + pass + consumer_queue.put(open_success) + if open_success: + # modify the memory + lib.cudaMemset(pointer, 2, 1024) + lib.cudaDeviceSynchronize() + # use two queues to simulate barrier + producer_queue.get() + consumer_queue.put(0) + # check if the memory is modified + host_data = (ctypes.c_char * 1024)() + lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore + for i in range(1024): + if ord(host_data[i]) != 2: + open_success = False + break + result_queue.put(open_success) + lib.cudaDeviceReset() + + +def can_actually_p2p( + batch_src: Sequence[int], + batch_tgt: Sequence[int], +) -> Sequence[bool]: + """ + Usually, checking if P2P access is enabled can be done by + `torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes + the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)` + returns `True` even if P2P access is not actually possible. + See https://github.com/vllm-project/vllm/issues/2728 and + https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10 + Therefore, we have to perform a real P2P access to check if it is actually + possible. + + Note on p2p and cuda IPC: + Usually, one process uses one GPU: + GPU src --> cuda context src --> tensor src --> process src + + We need to combine p2p and cuda IPC, so that: + GPU src --> cuda context src --> tensor src --> process src + |shared| + GPU tgt --> cuda context tgt --> tensor tgt --> process tgt + That is to say, process src creates a tensor in GPU src, passes IPC handle to + process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the + tensor in process tgt will be reflected in the tensor in process src, because + they are the same memory segment. + It is important to note that process tgt accesses the tensor in GPU tgt, not + GPU src. That's why we need p2p access. + + The most time-consuming part is the process creation. To avoid creating + processes for every pair of GPUs, we use batched testing. We create two + processes for testing all pairs of GPUs in batch. The trick is to reset + the device after each test (which is not available in PyTorch). + """ # noqa + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + # pass the CUDA_VISIBLE_DEVICES to the child process + # to make sure they see the same set of GPUs + + # make sure the processes are spawned + smp = mp.get_context("spawn") + producer_queue = smp.Queue() + consumer_queue = smp.Queue() + result_queue = smp.Queue() + p_src = smp.Process(target=producer, + args=(batch_src, producer_queue, consumer_queue, + result_queue, cuda_visible_devices)) + p_tgt = smp.Process(target=consumer, + args=(batch_tgt, producer_queue, consumer_queue, + result_queue, cuda_visible_devices)) + p_src.start() + p_tgt.start() + p_src.join() + p_tgt.join() + assert p_src.exitcode == 0 and p_tgt.exitcode == 0 + result: list[bool] = [] + for src, tgt in zip(batch_src, batch_tgt): + a = result_queue.get() + b = result_queue.get() + if a != b: + logger.warning( + "Two processes do not agree on the P2P access" + " status on %d -> %d, treat as disabled.", src, tgt) + result.append(False) + else: + result.append(a) + return result + + +# why do we need this cache? 
+# we are testing peer-to-peer (p2p) access between GPUs,across processes. +# if we test it every time, it will be very slow, because we need to create +# N * N * 2 processes, where N is the world size. This is very slow. +# to reduce the time, we use a cache file to store the p2p access status. +# the cache file is generated by the master process if it does not exist. +# then all the processes can read the cache file to check the p2p access status. +# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we +# can have different cache files for different CUDA_VISIBLE_DEVICES settings, +# e.g. used by different vllm engines. The device id in the cache file is a +# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number +# of visible devices in the vllm engine. +_gpu_p2p_access_cache: Optional[dict[str, bool]] = None + + +def gpu_p2p_access_check(src: int, tgt: int) -> bool: + """Check if GPU src can access GPU tgt.""" + + # if the cache variable is already calculated, + # read from the cache instead of checking it again + global _gpu_p2p_access_cache + if _gpu_p2p_access_cache is not None: + return _gpu_p2p_access_cache[f"{src}->{tgt}"] + + is_distributed = dist.is_initialized() + + num_dev = cuda_device_count_stateless() + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + if cuda_visible_devices is None: + cuda_visible_devices = ",".join(str(i) for i in range(num_dev)) + + path = os.path.join( + envs.VLLM_CACHE_ROOT, + f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json") + os.makedirs(os.path.dirname(path), exist_ok=True) + from vllm.distributed.parallel_state import get_world_group + if ((not is_distributed or get_world_group().local_rank == 0) + and (not os.path.exists(path))): + # only the local master process (with local_rank == 0) can + # enter this block to calculate the cache + logger.info("generating GPU P2P access cache in %s", path) + cache: dict[str, bool] = {} + ids = list(range(num_dev)) + # batch of all pairs of GPUs + batch_src, batch_tgt = zip(*list(product(ids, ids))) + # NOTE: we use `subprocess` rather than `multiprocessing` here + # because the caller might not have `if __name__ == "__main__":`, + # in that case we cannot use spawn method in multiprocessing. + # However, `can_actually_p2p` requires spawn method. + # The fix is, we use `subprocess` to call the function, + # where we have `if __name__ == "__main__":` in this file. 
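+ # For clarity, the child-process protocol is: (batch_src, batch_tgt,
+ # output_file.name) is pickled to the child's stdin; the child's
+ # `if __name__ == "__main__":` block at the bottom of this file runs
+ # can_actually_p2p() and pickles the resulting list of booleans into
+ # output_file.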
+ + # use a temporary file to store the result + # we don't use the output of the subprocess directly, + # because the subprocess might produce logging output + with tempfile.NamedTemporaryFile() as output_file: + input_bytes = pickle.dumps( + (batch_src, batch_tgt, output_file.name)) + returned = subprocess.run([sys.executable, __file__], + input=input_bytes, + capture_output=True) + # check if the subprocess is successful + try: + returned.check_returncode() + except Exception as e: + # wrap raised exception to provide more information + raise RuntimeError( + f"Error happened when batch testing " + f"peer-to-peer access from {batch_src} to {batch_tgt}:\n" + f"{returned.stderr.decode()}") from e + with open(output_file.name, "rb") as f: + result = pickle.load(f) + for _i, _j, r in zip(batch_src, batch_tgt, result): + cache[f"{_i}->{_j}"] = r + with open(path, "w") as f: + json.dump(cache, f, indent=4) + if is_distributed: + get_world_group().barrier() + logger.info("reading GPU P2P access cache from %s", path) + with open(path) as f: + cache = json.load(f) + _gpu_p2p_access_cache = cache + return _gpu_p2p_access_cache[f"{src}->{tgt}"] + + +__all__ = ["gpu_p2p_access_check"] + +if __name__ == "__main__": + batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read()) + result = can_actually_p2p(batch_src, batch_tgt) + with open(output_file, "wb") as f: + f.write(pickle.dumps(result)) diff --git a/distributed/device_communicators/hpu_communicator.py b/distributed/device_communicators/hpu_communicator.py new file mode 100644 index 0000000..f00f6b6 --- /dev/null +++ b/distributed/device_communicators/hpu_communicator.py @@ -0,0 +1,46 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +import torch.distributed as dist + +from vllm.platforms import current_platform + +from .base_device_communicator import DeviceCommunicatorBase + +if current_platform.is_hpu(): + import habana_frameworks.torch as htorch # noqa: F401 + + +class HpuCommunicator(DeviceCommunicatorBase): + + def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge + # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used + # (which is required for tensor parallel HPUGraph inference) + htorch.core.mark_step() + dist.all_reduce(input_, group=self.device_group) + return input_ + + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + world_size = self.world_size + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + input_size = input_.size() + # Allocate output tensor. + output_tensor = torch.empty((world_size, ) + input_size, + dtype=input_.dtype, + device=input_.device) + # All-gather. 
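+        # Presumably the mark_step below serves the same purpose as in
+        # all_reduce above: working around the Habana PT bridge issue seen
+        # with PT_HPU_ENABLE_LAZY_COLLECTIVES=true.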
+ htorch.core.mark_step() + dist.all_gather_into_tensor(output_tensor, + input_, + group=self.device_group) + # Reshape + output_tensor = output_tensor.movedim(0, dim) + output_tensor = output_tensor.reshape(input_size[:dim] + + (world_size * + input_size[dim], ) + + input_size[dim + 1:]) + return output_tensor diff --git a/distributed/device_communicators/neuron_communicator.py b/distributed/device_communicators/neuron_communicator.py new file mode 100644 index 0000000..5b61a16 --- /dev/null +++ b/distributed/device_communicators/neuron_communicator.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +from vllm.distributed.device_communicators.base_device_communicator import ( + DeviceCommunicatorBase) +from vllm.platforms import current_platform + +if current_platform.is_neuron(): + import torch_xla.core.xla_model as xm + + +class NeuronCommunicator(DeviceCommunicatorBase): + + def all_reduce(self, x: torch.Tensor) -> torch.Tensor: + return xm.all_reduce(xm.REDUCE_SUM, x) + + def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + assert dim == -1, "Neuron only supports dim=-1 for all-gather." + return xm.all_gather(x, dim=dim) diff --git a/distributed/device_communicators/pynccl.py b/distributed/device_communicators/pynccl.py new file mode 100644 index 0000000..2948629 --- /dev/null +++ b/distributed/device_communicators/pynccl.py @@ -0,0 +1,218 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional, Union + +# ===================== import region ===================== +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup, ReduceOp + +from vllm.distributed.device_communicators.pynccl_wrapper import ( + NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum, + ncclRedOpTypeEnum, ncclUniqueId) +from vllm.distributed.utils import StatelessProcessGroup +from vllm.logger import init_logger +from vllm.utils import current_stream + +logger = init_logger(__name__) + + +class PyNcclCommunicator: + + def __init__( + self, + group: Union[ProcessGroup, StatelessProcessGroup], + device: Union[int, str, torch.device], + library_path: Optional[str] = None, + ): + """ + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the PyNcclCommunicator to. If None, + it will be bind to f"cuda:{local_rank}". + library_path: the path to the NCCL library. If None, it will + use the default library path. + It is the caller's responsibility to make sure each communicator + is bind to a unique device. + """ + if not isinstance(group, StatelessProcessGroup): + assert dist.is_initialized() + assert dist.get_backend(group) != dist.Backend.NCCL, ( + "PyNcclCommunicator should be attached to a non-NCCL group.") + # note: this rank is the rank in the group + self.rank = dist.get_rank(group) + self.world_size = dist.get_world_size(group) + else: + self.rank = group.rank + self.world_size = group.world_size + + self.group = group + + # if world_size == 1, no need to create communicator + if self.world_size == 1: + self.available = False + self.disabled = True + return + try: + self.nccl = NCCLLibrary(library_path) + except Exception: + # disable because of missing NCCL library + # e.g. 
in a non-GPU environment + self.available = False + self.disabled = True + return + + self.available = True + self.disabled = False + + logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion()) + + if self.rank == 0: + # get the unique id from NCCL + self.unique_id = self.nccl.ncclGetUniqueId() + else: + # construct an empty unique id + self.unique_id = ncclUniqueId() + + if not isinstance(group, StatelessProcessGroup): + tensor = torch.ByteTensor(list(self.unique_id.internal)) + ranks = dist.get_process_group_ranks(group) + # arg `src` in `broadcast` is the global rank + dist.broadcast(tensor, src=ranks[0], group=group) + byte_list = tensor.tolist() + for i, byte in enumerate(byte_list): + self.unique_id.internal[i] = byte + else: + self.unique_id = group.broadcast_obj(self.unique_id, src=0) + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + # now `device` is a `torch.device` object + assert isinstance(device, torch.device) + self.device = device + # nccl communicator and stream will use this device + # `torch.cuda.device` is a context manager that changes the + # current cuda device to the specified one + with torch.cuda.device(device): + self.comm: ncclComm_t = self.nccl.ncclCommInitRank( + self.world_size, self.unique_id, self.rank) + + stream = current_stream() + # A small all_reduce for warmup. + data = torch.zeros(1, device=device) + self.all_reduce(data) + stream.synchronize() + del data + + def all_reduce(self, + in_tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, + stream=None) -> torch.Tensor: + if self.disabled: + return None + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert in_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {in_tensor.device}") + + out_tensor = torch.empty_like(in_tensor) + + if stream is None: + stream = current_stream() + self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()), + buffer_type(out_tensor.data_ptr()), + in_tensor.numel(), + ncclDataTypeEnum.from_torch(in_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), self.comm, + cudaStream_t(stream.cuda_stream)) + return out_tensor + + def all_gather(self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + stream=None): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {input_tensor.device}") + if stream is None: + stream = current_stream() + self.nccl.ncclAllGather( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), input_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm, + cudaStream_t(stream.cuda_stream)) + + def reduce_scatter(self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, + stream=None): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on 
{input_tensor.device}") + if stream is None: + stream = current_stream() + self.nccl.ncclReduceScatter( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), output_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), self.comm, + cudaStream_t(stream.cuda_stream)) + + def send(self, tensor: torch.Tensor, dst: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = current_stream() + self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), dst, + self.comm, cudaStream_t(stream.cuda_stream)) + + def recv(self, tensor: torch.Tensor, src: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = current_stream() + self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), src, + self.comm, cudaStream_t(stream.cuda_stream)) + + def broadcast(self, tensor: torch.Tensor, src: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = current_stream() + if src == self.rank: + sendbuff = buffer_type(tensor.data_ptr()) + # NCCL requires the sender also to have a receive buffer + recvbuff = buffer_type(tensor.data_ptr()) + else: + sendbuff = buffer_type() + recvbuff = buffer_type(tensor.data_ptr()) + self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), src, + self.comm, cudaStream_t(stream.cuda_stream)) diff --git a/distributed/device_communicators/pynccl_wrapper.py b/distributed/device_communicators/pynccl_wrapper.py new file mode 100644 index 0000000..718d7ea --- /dev/null +++ b/distributed/device_communicators/pynccl_wrapper.py @@ -0,0 +1,341 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# This file is a pure Python wrapper for the NCCL library. +# The main purpose is to use NCCL combined with CUDA graph. +# Before writing this script, we tried the following approach: +# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself +# often gets stuck when initializing the NCCL communicator. +# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce` +# contains many other potential cuda APIs, that are not allowed during +# capturing the CUDA graph. For further details, please check +# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ . +# +# Another rejected idea is to write a C/C++ binding for NCCL. It is usually +# doable, but we often encounter issues related with nccl versions, and need +# to switch between different versions of NCCL. See +# https://github.com/NVIDIA/nccl/issues/1234 for more details. +# A C/C++ binding is not flexible enough to handle this. It requires +# recompilation of the code every time we want to switch between different +# versions. This current implementation, with a **pure** Python wrapper, is +# more flexible. 
We can easily switch between different versions of NCCL by +# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file` +# variable in the code. + +import ctypes +import platform +from dataclasses import dataclass +from typing import Any, Optional + +import torch +from torch.distributed import ReduceOp + +from vllm.logger import init_logger +from vllm.utils import find_nccl_library + +logger = init_logger(__name__) + +# === export types and functions from nccl to Python === +# for the original nccl definition, please check +# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in + +ncclResult_t = ctypes.c_int +ncclComm_t = ctypes.c_void_p + + +class ncclUniqueId(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 128)] + + +cudaStream_t = ctypes.c_void_p +buffer_type = ctypes.c_void_p + +ncclDataType_t = ctypes.c_int + + +class ncclDataTypeEnum: + ncclInt8 = 0 + ncclChar = 0 + ncclUint8 = 1 + ncclInt32 = 2 + ncclInt = 2 + ncclUint32 = 3 + ncclInt64 = 4 + ncclUint64 = 5 + ncclFloat16 = 6 + ncclHalf = 6 + ncclFloat32 = 7 + ncclFloat = 7 + ncclFloat64 = 8 + ncclDouble = 8 + ncclBfloat16 = 9 + ncclNumTypes = 10 + + @classmethod + def from_torch(cls, dtype: torch.dtype) -> int: + if dtype == torch.int8: + return cls.ncclInt8 + if dtype == torch.uint8: + return cls.ncclUint8 + if dtype == torch.int32: + return cls.ncclInt32 + if dtype == torch.int64: + return cls.ncclInt64 + if dtype == torch.float16: + return cls.ncclFloat16 + if dtype == torch.float32: + return cls.ncclFloat32 + if dtype == torch.float64: + return cls.ncclFloat64 + if dtype == torch.bfloat16: + return cls.ncclBfloat16 + raise ValueError(f"Unsupported dtype: {dtype}") + + +ncclRedOp_t = ctypes.c_int + + +class ncclRedOpTypeEnum: + ncclSum = 0 + ncclProd = 1 + ncclMax = 2 + ncclMin = 3 + ncclAvg = 4 + ncclNumOps = 5 + + @classmethod + def from_torch(cls, op: ReduceOp) -> int: + if op == ReduceOp.SUM: + return cls.ncclSum + if op == ReduceOp.PRODUCT: + return cls.ncclProd + if op == ReduceOp.MAX: + return cls.ncclMax + if op == ReduceOp.MIN: + return cls.ncclMin + if op == ReduceOp.AVG: + return cls.ncclAvg + raise ValueError(f"Unsupported op: {op}") + + +@dataclass +class Function: + name: str + restype: Any + argtypes: list[Any] + + +class NCCLLibrary: + exported_functions = [ + # const char* ncclGetErrorString(ncclResult_t result) + Function("mcclGetErrorString", ctypes.c_char_p, [ncclResult_t]), + # ncclResult_t ncclGetVersion(int *version); + Function("mcclGetVersion", ncclResult_t, + [ctypes.POINTER(ctypes.c_int)]), + # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); + Function("mcclGetUniqueId", ncclResult_t, + [ctypes.POINTER(ncclUniqueId)]), + # ncclResult_t ncclCommInitRank( + # ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); + # note that ncclComm_t is a pointer type, so the first argument + # is a pointer to a pointer + Function("mcclCommInitRank", ncclResult_t, [ + ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, + ctypes.c_int + ]), + # ncclResult_t ncclAllReduce( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("mcclAllReduce", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclRedOp_t, ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclAllGather( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclComm_t 
comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("mcclAllGather", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclReduceScatter( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("mcclReduceScatter", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclRedOp_t, ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclSend( + # const void* sendbuff, size_t count, ncclDataType_t datatype, + # int dest, ncclComm_t comm, cudaStream_t stream); + Function("mcclSend", ncclResult_t, [ + buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, + ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclRecv( + # void* recvbuff, size_t count, ncclDataType_t datatype, + # int src, ncclComm_t comm, cudaStream_t stream); + Function("mcclRecv", ncclResult_t, [ + buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, + ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclBroadcast( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, int root, ncclComm_t comm, + # cudaStream_t stream); + Function("mcclBroadcast", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ctypes.c_int, ncclComm_t, cudaStream_t + ]), + + # be cautious! this is a collective call, it will block until all + # processes in the communicator have called this function. + # because Python object destruction can happen in random order, + # it is better not to call it at all. + # ncclResult_t ncclCommDestroy(ncclComm_t comm); + Function("mcclCommDestroy", ncclResult_t, [ncclComm_t]), + ] + + # class attribute to store the mapping from the path to the library + # to avoid loading the same library multiple times + path_to_library_cache: dict[str, Any] = {} + + # class attribute to store the mapping from library path + # to the corresponding dictionary + path_to_dict_mapping: dict[str, dict[str, Any]] = {} + + def __init__(self, so_file: Optional[str] = None): + + so_file = so_file or find_nccl_library() + + try: + if so_file not in NCCLLibrary.path_to_dict_mapping: + lib = ctypes.CDLL(so_file) + NCCLLibrary.path_to_library_cache[so_file] = lib + self.lib = NCCLLibrary.path_to_library_cache[so_file] + except Exception as e: + logger.error( + "Failed to load NCCL library from %s. " + "It is expected if you are not running on NVIDIA/AMD GPUs." + "Otherwise, the nccl library might not exist, be corrupted " + "or it does not support the current platform %s. 
" + "If you already have the library, please set the " + "environment variable VLLM_NCCL_SO_PATH" + " to point to the correct nccl library path.", so_file, + platform.platform()) + raise e + + if so_file not in NCCLLibrary.path_to_dict_mapping: + _funcs: dict[str, Any] = {} + for func in NCCLLibrary.exported_functions: + f = getattr(self.lib, func.name) + f.restype = func.restype + f.argtypes = func.argtypes + _funcs[func.name] = f + NCCLLibrary.path_to_dict_mapping[so_file] = _funcs + self._funcs = NCCLLibrary.path_to_dict_mapping[so_file] + + def ncclGetErrorString(self, result: ncclResult_t) -> str: + return self._funcs["mcclGetErrorString"](result).decode("utf-8") + + def NCCL_CHECK(self, result: ncclResult_t) -> None: + if result != 0: + error_str = self.ncclGetErrorString(result) + raise RuntimeError(f"MCCL error: {error_str}") + + def ncclGetVersion(self) -> str: + version = ctypes.c_int() + self.NCCL_CHECK(self._funcs["mcclGetVersion"](ctypes.byref(version))) + version_str = str(version.value) + # something like 21903 --> "2.19.3" + major = version_str[0].lstrip("0") + minor = version_str[1:3].lstrip("0") + patch = version_str[3:].lstrip("0") + return f"{major}.{minor}.{patch}" + + def ncclGetUniqueId(self) -> ncclUniqueId: + unique_id = ncclUniqueId() + self.NCCL_CHECK(self._funcs["mcclGetUniqueId"]( + ctypes.byref(unique_id))) + return unique_id + + def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId, + rank: int) -> ncclComm_t: + comm = ncclComm_t() + self.NCCL_CHECK(self._funcs["mcclCommInitRank"](ctypes.byref(comm), + world_size, unique_id, + rank)) + return comm + + def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, op: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["mcclAllReduce"](sendbuff, recvbuff, count, + datatype, op, comm, + stream)) + + def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, op: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["mcclReduceScatter"](sendbuff, recvbuff, + count, datatype, op, + comm, stream)) + + def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # which is an aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["mcclAllGather"](sendbuff, recvbuff, count, + datatype, comm, stream)) + + def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int, + dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["mcclSend"](sendbuff, count, datatype, + dest, comm, stream)) + + def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int, + src: int, comm: ncclComm_t, stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["mcclRecv"](recvbuff, count, datatype, src, + comm, stream)) + + def 
ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, root: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["mcclBroadcast"](sendbuff, recvbuff, count, + datatype, root, comm, + stream)) + + def ncclCommDestroy(self, comm: ncclComm_t) -> None: + self.NCCL_CHECK(self._funcs["mcclCommDestroy"](comm)) + + +__all__ = [ + "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId", + "ncclComm_t", "cudaStream_t", "buffer_type" +] diff --git a/distributed/device_communicators/shm_broadcast.py b/distributed/device_communicators/shm_broadcast.py new file mode 100644 index 0000000..c781004 --- /dev/null +++ b/distributed/device_communicators/shm_broadcast.py @@ -0,0 +1,585 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pickle +import time +from contextlib import contextmanager +from dataclasses import dataclass, field +from multiprocessing import shared_memory +from threading import Event +from typing import Any, Optional, Union +from unittest.mock import patch + +import torch +import torch.distributed as dist +import zmq +from torch.distributed import ProcessGroup +from zmq import IPV6 # type: ignore +from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore + +import vllm.envs as envs +from vllm.distributed.utils import StatelessProcessGroup, sched_yield +from vllm.logger import init_logger +from vllm.utils import (get_ip, get_open_port, get_open_zmq_ipc_path, + is_valid_ipv6_address) + +VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL + +logger = init_logger(__name__) + + +class SpinTimer: + + def record_activity(self): + pass + + def spin(self): + sched_yield() + + +class SpinSleepTimer(SpinTimer): + """ + In setups which have long inactivity periods it is desirable to reduce + system power consumption when vllm does nothing. This would lead to more + CPU thermal headroom when a request eventually comes, especially when + multiple GPUs are connected as each GPU would otherwise pin one thread at + 100% CPU usage. + + The simplest solution is to reduce polling frequency when there is no + activity for a certain period of time. + """ + + def __init__(self, busy_loop_s: float = 3.0, wait_sleep_s: float = 0.1): + self.last_activity = time.monotonic() + self.busy_loop_s = busy_loop_s + self.wait_sleep_s = wait_sleep_s + + def record_activity(self): + self.last_activity = time.monotonic() + + def spin(self): + curr_time = time.monotonic() + if curr_time >= self.last_activity + self.busy_loop_s: + time.sleep(self.wait_sleep_s) + else: + sched_yield() + + +class ShmRingBuffer: + + def __init__(self, + n_reader: int, + max_chunk_bytes: int, + max_chunks: int, + name: Optional[str] = None): + """ + A shared memory ring buffer implementation for broadcast communication. + Essentially, it is a queue where only one will `enqueue` and multiple + will `dequeue`. The max size of each item, together with the max number + of items that can be stored in the buffer are known in advance. + In this case, we don't need to synchronize the access to + the buffer. + + Buffer memory layout: + data metadata + | | + | (current_idx) | (current_idx) + v v + +-------------------------------+----------------------------------------+ + | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... 
| metadata | + +-------------------------------+----------------------------------------+ + | max_chunks x max_chunk_bytes | max_chunks x (1 + n_reader) bytes | + + metadata memory layout: each byte is a flag, the first byte is the written + flag, and the rest are reader flags. The flags are set to 0 by default. + +--------------+--------------+--------------+-----+--------------+ + | written_flag | reader0_flag | reader1_flag | ... | readerN_flag | + +--------------+--------------+--------------+-----+--------------+ + + The state of metadata is as follows: + + (case 1) 0???...???: the block is not written yet, cannot read, can write + (case 2) 1000...000: the block is just written, can read, cannot write + (case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write + (case 4) 1111...111: the block is written and read by all readers, cannot read, can write + + State transition for readers: + + When a reader finds a block that it can read (case 2 or 3), it can yield the block for caller to read. + Only after the caller finishes reading the block, the reader can mark the block as read. + Readers only mark the block as read (from 0 to 1), the writer marks the block as ready to read (from 1 to 0). + + State transition for writer: + + When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case + to case 1. Then it can yield the block for caller to write. After the caller finishes writing the block, the writer + can reset the reader flags to 0, and mark the block as written (from 0 to 1). + NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct. + + During creation, `name` is None and the buffer is created. We can pass the + created object to other processes by pickling it. The other processes will + get the name of the shared memory and open it, so that they can access the + same shared memory buffer. + """# noqa + self.n_reader = n_reader + self.metadata_size = 1 + n_reader + self.max_chunk_bytes = max_chunk_bytes + self.max_chunks = max_chunks + self.total_bytes_of_buffer = (self.max_chunk_bytes + + self.metadata_size) * self.max_chunks + self.data_offset = 0 + self.metadata_offset = self.max_chunk_bytes * self.max_chunks + + if name is None: + # we are creating a buffer + self.is_creator = True + self.shared_memory = shared_memory.SharedMemory( + create=True, size=self.total_bytes_of_buffer) + # initialize the metadata section to 0 + with memoryview(self.shared_memory.buf[self.metadata_offset:] + ) as metadata_buffer: + torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0) + else: + # we are opening an existing buffer + self.is_creator = False + # fix to https://stackoverflow.com/q/62748654/9191338 + # Python incorrectly tracks shared memory even if it is not + # created by the process. The following patch is a workaround. + with patch("multiprocessing.resource_tracker.register", + lambda *args, **kwargs: None): + try: + self.shared_memory = shared_memory.SharedMemory(name=name) + # See https://docs.python.org/3/library/multiprocessing.shared_memory.html # noqa + # Some platforms allocate memory based on page size, + # so the shared memory block size may be larger or equal + # to the requested size. 
The size parameter is ignored + # when attaching to an existing block. + assert (self.shared_memory.size + >= self.total_bytes_of_buffer) + except FileNotFoundError: + # we might deserialize the object in a different node + # in this case, this object is not used, + # and we should suppress the error + pass + + def handle(self): + return (self.n_reader, self.max_chunk_bytes, self.max_chunks, + self.shared_memory.name) + + def __reduce__(self): + return ( + self.__class__, + self.handle(), + ) + + def __del__(self): + if hasattr(self, "shared_memory"): + self.shared_memory.close() + if self.is_creator: + self.shared_memory.unlink() + + @contextmanager + def get_data(self, current_idx: int): + start = self.data_offset + current_idx * self.max_chunk_bytes + end = start + self.max_chunk_bytes + with memoryview(self.shared_memory.buf[start:end]) as buf: + yield buf + + @contextmanager + def get_metadata(self, current_idx: int): + start = self.metadata_offset + current_idx * self.metadata_size + end = start + self.metadata_size + with memoryview(self.shared_memory.buf[start:end]) as buf: + yield buf + + +@dataclass +class Handle: + local_reader_ranks: list[int] = field(default_factory=list) + + buffer_handle: Optional[tuple[int, int, int, str]] = None + local_subscribe_addr: Optional[str] = None + remote_subscribe_addr: Optional[str] = None + remote_addr_ipv6: bool = False + + +class MessageQueue: + + def __init__( + self, + n_reader, # number of all readers + n_local_reader, # number of local readers through shared memory + local_reader_ranks: Optional[list[int]] = None, + max_chunk_bytes: int = 1024 * 1024 * 10, + max_chunks: int = 10, + connect_ip: Optional[str] = None, + ): + if local_reader_ranks is None: + local_reader_ranks = list(range(n_local_reader)) + else: + assert len(local_reader_ranks) == n_local_reader + self.n_local_reader = n_local_reader + n_remote_reader = n_reader - n_local_reader + self.n_remote_reader = n_remote_reader + + context = Context() + + if n_local_reader > 0: + # for local readers, we will: + # 1. create a shared memory ring buffer to communicate small data + # 2. create a publish-subscribe socket to communicate large data + self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes, + max_chunks) + + # XPUB is very similar to PUB, + # except that it can receive subscription messages + # to confirm the number of subscribers + self.local_socket = context.socket(XPUB) + # set the verbose option so that we can receive every subscription + # message. 
otherwise, we will only receive the first subscription + # see http://api.zeromq.org/3-3:zmq-setsockopt for more details + self.local_socket.setsockopt(XPUB_VERBOSE, True) + local_subscribe_addr = get_open_zmq_ipc_path() + logger.debug("Binding to %s", local_subscribe_addr) + self.local_socket.bind(local_subscribe_addr) + + self.current_idx = 0 + else: + self.buffer = None # type: ignore + local_subscribe_addr = None + self.local_socket = None + self.current_idx = -1 + + remote_addr_ipv6 = False + if n_remote_reader > 0: + # for remote readers, we will: + # create a publish-subscribe socket to communicate large data + if not connect_ip: + connect_ip = get_ip() + self.remote_socket = context.socket(XPUB) + self.remote_socket.setsockopt(XPUB_VERBOSE, True) + remote_subscribe_port = get_open_port() + if is_valid_ipv6_address(connect_ip): + self.remote_socket.setsockopt(IPV6, 1) + remote_addr_ipv6 = True + connect_ip = f"[{connect_ip}]" + socket_addr = f"tcp://{connect_ip}:{remote_subscribe_port}" + self.remote_socket.bind(socket_addr) + remote_subscribe_addr = f"tcp://{connect_ip}:{remote_subscribe_port}" + else: + remote_subscribe_addr = None + self.remote_socket = None + + self._is_writer = True + self._is_local_reader = False + self.local_reader_rank = -1 + # rank does not matter for remote readers + self._is_remote_reader = False + self._read_spin_timer = SpinTimer() + + self.handle = Handle( + local_reader_ranks=local_reader_ranks, + buffer_handle=self.buffer.handle() + if self.buffer is not None else None, + local_subscribe_addr=local_subscribe_addr, + remote_subscribe_addr=remote_subscribe_addr, + remote_addr_ipv6=remote_addr_ipv6, + ) + + logger.info("vLLM message queue communication handle: %s", self.handle) + + def export_handle(self) -> Handle: + return self.handle + + @staticmethod + def create_from_handle(handle: Handle, rank) -> "MessageQueue": + self = MessageQueue.__new__(MessageQueue) + self.handle = handle + self._is_writer = False + + context = Context() + + if rank in handle.local_reader_ranks: + assert handle.buffer_handle is not None + self.buffer = ShmRingBuffer(*handle.buffer_handle) + self.current_idx = 0 + self.local_reader_rank = handle.local_reader_ranks.index(rank) + self._is_local_reader = True + self._is_remote_reader = False + + self.local_socket = context.socket(SUB) + self.local_socket.setsockopt_string(SUBSCRIBE, "") + socket_addr = handle.local_subscribe_addr + logger.debug("Connecting to %s", socket_addr) + self.local_socket.connect(socket_addr) + + self.remote_socket = None + + self._read_spin_timer = SpinSleepTimer( + ) if envs.VLLM_SLEEP_WHEN_IDLE else SpinTimer() + else: + self.buffer = None # type: ignore + self.current_idx = -1 + self.local_reader_rank = -1 + self._is_local_reader = False + self._is_remote_reader = True + + self.local_socket = None + + self.remote_socket = context.socket(SUB) + self.remote_socket.setsockopt_string(SUBSCRIBE, "") + if handle.remote_addr_ipv6: + self.remote_socket.setsockopt(IPV6, 1) + socket_addr = handle.remote_subscribe_addr + logger.debug("Connecting to %s", socket_addr) + self.remote_socket.connect(socket_addr) + + return self + + def wait_until_ready(self): + """This is a collective operation. All processes (including the + readers and the writer) should call this function. 
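+        The writer blocks until every local and remote reader has
+        subscribed, and each reader blocks until it receives the writer's
+        b"READY" message, so calling this from only a subset of the
+        processes will hang.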
+ """ + if self._is_writer: + # wait for all readers to connect + + # local readers + for i in range(self.n_local_reader): + # wait for subscription messages from all local readers + self.local_socket.recv() + if self.n_local_reader > 0: + # send a message to all local readers + # to make sure the publish channel is working + self.local_socket.send(b"READY") + + # remote readers + for i in range(self.n_remote_reader): + # wait for subscription messages from all remote readers + self.remote_socket.recv() + if self.n_remote_reader > 0: + # send a message to all remote readers + # to make sure the publish channel is working + self.remote_socket.send(b"READY") + elif self._is_local_reader: + # wait for the writer to send a message + recv = self.local_socket.recv() + assert recv == b"READY" + elif self._is_remote_reader: + # wait for the writer to send a message + recv = self.remote_socket.recv() + assert recv == b"READY" + + @contextmanager + def acquire_write(self, timeout: Optional[float] = None): + assert self._is_writer, "Only writers can acquire write" + start_time = time.monotonic() + n_warning = 1 + while True: + with self.buffer.get_metadata(self.current_idx) as metadata_buffer: + read_count = sum(metadata_buffer[1:]) + written_flag = metadata_buffer[0] + if written_flag and read_count != self.buffer.n_reader: + # this block is written and not read by all readers + # for writers, `self.current_idx` is the next block to write + # if this block is not ready to write, + # we need to wait until it is read by all readers + + # Release the processor to other threads + sched_yield() + + # if we wait for a long time, log a message + if (time.monotonic() - start_time + > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning): + logger.debug( + ("No available shared memory broadcast block found" + " in %s second."), + VLLM_RINGBUFFER_WARNING_INTERVAL, + ) + n_warning += 1 + + # if we time out, raise an exception + if (timeout is not None + and time.monotonic() - start_time > timeout): + raise TimeoutError + + continue + # found a block that is either + # (1) not written + # (2) read by all readers + + # mark the block as not written + metadata_buffer[0] = 0 + # let caller write to the buffer + with self.buffer.get_data(self.current_idx) as buf: + yield buf + + # caller has written to the buffer + # NOTE: order is important here + # first set the read flags to 0 + # then set the written flag to 1 + # otherwise, the readers may think they already read the block + for i in range(1, self.buffer.n_reader + 1): + # set read flag to 0, meaning it is not read yet + metadata_buffer[i] = 0 + # mark the block as written + metadata_buffer[0] = 1 + self.current_idx = (self.current_idx + + 1) % self.buffer.max_chunks + break + + @contextmanager + def acquire_read(self, + timeout: Optional[float] = None, + cancel: Optional[Event] = None): + assert self._is_local_reader, "Only readers can acquire read" + start_time = time.monotonic() + n_warning = 1 + while True: + with self.buffer.get_metadata(self.current_idx) as metadata_buffer: + read_flag = metadata_buffer[self.local_reader_rank + 1] + written_flag = metadata_buffer[0] + if not written_flag or read_flag: + # this block is either + # (1) not written + # (2) already read by this reader + + # for readers, `self.current_idx` is the next block to read + # if this block is not ready, + # we need to wait until it is written + + # Release the processor to other threads + self._read_spin_timer.spin() + + # if we wait for a long time, log a message + if (time.monotonic() - 
start_time + > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning): + logger.debug( + ("No available shared memory broadcast block found" + " in %s second."), + VLLM_RINGBUFFER_WARNING_INTERVAL, + ) + n_warning += 1 + + if cancel is not None and cancel.is_set(): + raise RuntimeError("cancelled") + + # if we time out, raise an exception + if (timeout is not None + and time.monotonic() - start_time > timeout): + raise TimeoutError + + continue + # found a block that is not read by this reader + # let caller read from the buffer + with self.buffer.get_data(self.current_idx) as buf: + yield buf + + # caller has read from the buffer + # set the read flag + metadata_buffer[self.local_reader_rank + 1] = 1 + self.current_idx = (self.current_idx + + 1) % self.buffer.max_chunks + + self._read_spin_timer.record_activity() + break + + def enqueue(self, obj, timeout: Optional[float] = None): + """ Write to message queue with optional timeout (in seconds) """ + assert self._is_writer, "Only writers can enqueue" + serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL) + if self.n_local_reader > 0: + if len(serialized_obj) >= self.buffer.max_chunk_bytes: + with self.acquire_write(timeout) as buf: + buf[0] = 1 # overflow + self.local_socket.send(serialized_obj) + else: + with self.acquire_write(timeout) as buf: + buf[0] = 0 # not overflow + buf[1:len(serialized_obj) + 1] = serialized_obj + if self.n_remote_reader > 0: + self.remote_socket.send(serialized_obj) + + def dequeue(self, + timeout: Optional[float] = None, + cancel: Optional[Event] = None): + """ Read from message queue with optional timeout (in seconds) """ + if self._is_local_reader: + with self.acquire_read(timeout, cancel) as buf: + overflow = buf[0] == 1 + if not overflow: + # no need to know the size of serialized object + # pickle format contains the size information internally + # see https://docs.python.org/3/library/pickle.html + obj = pickle.loads(buf[1:]) + if overflow: + obj = MessageQueue.recv(self.local_socket, timeout) + elif self._is_remote_reader: + obj = MessageQueue.recv(self.remote_socket, timeout) + else: + raise RuntimeError("Only readers can dequeue") + return obj + + @staticmethod + def recv(socket: zmq.Socket, timeout: Optional[float]) -> Any: + timeout_ms = None if timeout is None else int(timeout * 1000) + if not socket.poll(timeout=timeout_ms): + raise TimeoutError + recv = socket.recv(copy=False) + return pickle.loads(recv.buffer) + + def broadcast_object(self, obj=None): + if self._is_writer: + self.enqueue(obj) + return obj + else: + return self.dequeue() + + @staticmethod + def create_from_process_group(pg: Union[ProcessGroup, + StatelessProcessGroup], + max_chunk_bytes, + max_chunks, + writer_rank=0) -> "MessageQueue": + if isinstance(pg, ProcessGroup): + group_rank = dist.get_rank(pg) + group_world_size = dist.get_world_size(pg) + global_ranks = dist.get_process_group_ranks(pg) + else: + group_rank = pg.rank + group_world_size = pg.world_size + global_ranks = list(range(pg.world_size)) + + from vllm.distributed.parallel_state import in_the_same_node_as + status = in_the_same_node_as(pg, source_rank=writer_rank) + same_node_ranks = [i for i, s in enumerate(status) if s] + n_reader = group_world_size - 1 + n_local_reader = len(same_node_ranks) - 1 + local_reader_ranks = [i for i in same_node_ranks if i != writer_rank] + buffer_io: MessageQueue + if group_rank == writer_rank: + buffer_io = MessageQueue( + n_reader=n_reader, + n_local_reader=n_local_reader, + local_reader_ranks=local_reader_ranks, + 
max_chunk_bytes=max_chunk_bytes, + max_chunks=max_chunks, + ) + handle = buffer_io.export_handle() + if isinstance(pg, ProcessGroup): + dist.broadcast_object_list([handle], + src=global_ranks[writer_rank], + group=pg) + else: + pg.broadcast_obj(handle, writer_rank) + else: + if isinstance(pg, ProcessGroup): + recv = [None] + dist.broadcast_object_list(recv, + src=global_ranks[writer_rank], + group=pg) + handle = recv[0] # type: ignore + else: + handle = pg.broadcast_obj(None, writer_rank) + buffer_io = MessageQueue.create_from_handle(handle, group_rank) + buffer_io.wait_until_ready() + return buffer_io diff --git a/distributed/device_communicators/tpu_communicator.py b/distributed/device_communicators/tpu_communicator.py new file mode 100644 index 0000000..c60a7a7 --- /dev/null +++ b/distributed/device_communicators/tpu_communicator.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +from typing import Optional + +import torch +from torch.distributed import ProcessGroup + +from vllm.config import get_current_vllm_config +from vllm.logger import init_logger +from vllm.platforms import current_platform + +from .base_device_communicator import DeviceCommunicatorBase + +USE_RAY = parallel_config = get_current_vllm_config( +).parallel_config.distributed_executor_backend == "ray" + +logger = init_logger(__name__) + +if current_platform.is_tpu(): + import torch_xla + import torch_xla.core.xla_model as xm + import torch_xla.runtime as xr + from torch_xla._internal import pjrt + from torch_xla.distributed.xla_multiprocessing import ( + create_optimized_replica_groups) + + if USE_RAY: + from vllm.executor import ray_utils + + +class TpuCommunicator(DeviceCommunicatorBase): + + def __init__(self, + cpu_group: ProcessGroup, + device: Optional[torch.device] = None, + device_group: Optional[ProcessGroup] = None, + unique_name: str = ""): + super().__init__(cpu_group, device, device_group, unique_name) + + # NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node + # must be used together. Therefore, the local rank and world size can + # be simply calculated as follows. + global_rank = self.global_rank + global_world_size = self.global_world_size + + if USE_RAY: + logger.info("TpuCommunicator initialized with RAY") + # Calculate how many TPU nodes are in the current deployment. This + # is the Ray placement group if it is deployed with Ray. Default + # to the number of TPU nodes in the Ray cluster. The number of TPU + # nodes is computed by the total number of TPUs divided by the + # number of TPU accelerators per node, to account for clusters + # with both CPUs and TPUs. + num_nodes = ray_utils.get_num_tpu_nodes() + num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group() + if num_nodes_in_pg > 0: + num_nodes = num_nodes_in_pg + + local_world_size = global_world_size // num_nodes + local_rank = global_rank % local_world_size + else: + logger.info("TpuCommunicator initialized with MP") + # Sanity: Verify we run on a single host + num_hosts = torch_xla.tpu.num_tpu_workers() + assert num_hosts == 1 + + # Get the current number of TPUs (we have locally) + local_world_size = torch_xla.tpu.num_available_chips() + + # Get current rank + local_rank = global_rank % local_world_size + + # Ensure environment variables are set for multihost deployments. + # On GKE, this is needed for libtpu and TPU driver to know which TPU + # chip is actually visible. 
Otherwise the TPU driver will fail to + # initialize because the number of devices would be different from + # the number of visible worker addresses. + os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank) + os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank) + + pjrt.initialize_multiprocess(local_rank, local_world_size) + xr._init_world_size_ordinal() + self.groups = create_optimized_replica_groups() + + def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + # TODO: Remove the groups specification after XLA compiler can support + # auto-reordering the ring order for all-reduce. + return xm.all_reduce(xm.REDUCE_SUM, input_, groups=self.groups) + + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + assert dim == -1, "TPUs only support dim=-1 for all-gather." + return xm.all_gather(input_, dim=dim) + + +try: + from tpu_commons.distributed.device_communicators import ( + TpuCommunicator as TpuCommonsCommunicator) + TpuCommunicator = TpuCommonsCommunicator # type: ignore +except ImportError: + logger.info("tpu_commons not found, using vLLM's TpuCommunicator") + pass diff --git a/distributed/device_communicators/xpu_communicator.py b/distributed/device_communicators/xpu_communicator.py new file mode 100644 index 0000000..216ff85 --- /dev/null +++ b/distributed/device_communicators/xpu_communicator.py @@ -0,0 +1,55 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from .base_device_communicator import DeviceCommunicatorBase + + +class XpuCommunicator(DeviceCommunicatorBase): + + def __init__(self, + cpu_group: ProcessGroup, + device: Optional[torch.device] = None, + device_group: Optional[ProcessGroup] = None, + unique_name: str = ""): + super().__init__(cpu_group, device, device_group, unique_name) + + def all_reduce(self, input_) -> torch.Tensor: + dist.all_reduce(input_, group=self.device_group) + return input_ + + def gather(self, + input_: torch.Tensor, + dst: int = 0, + dim: int = -1) -> Optional[torch.Tensor]: + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # For xpu path, gather doesn't work properly together with ray + # cluster so we use all_gather instead for now. + input_size = input_.size() + # Allocate output tensor. + output_tensor = torch.empty((self.world_size, ) + input_size, + dtype=input_.dtype, + device=input_.device) + # All-gather. 
+ dist.all_gather_into_tensor(output_tensor, + input_, + group=self.device_group) + if self.rank_in_group == dst: + # Reshape + output_tensor = output_tensor.movedim(0, dim) + output_tensor = output_tensor.reshape(input_size[:dim] + + (self.world_size * + input_size[dim], ) + + input_size[dim + 1:]) + else: + output_tensor = None + return output_tensor diff --git a/distributed/kv_events.py b/distributed/kv_events.py new file mode 100644 index 0000000..2d79357 --- /dev/null +++ b/distributed/kv_events.py @@ -0,0 +1,356 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import queue +import threading +import time +from abc import ABC, abstractmethod +from collections import deque +from dataclasses import asdict +from itertools import count +from queue import Queue +from typing import Any, Callable, Optional, Union + +import msgspec +import zmq + +from vllm.config import KVEventsConfig +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class EventBatch( + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + gc=False, # type: ignore[call-arg] +): + ts: float + events: list[Any] + data_parallel_rank: Optional[int] = None + + +class KVCacheEvent( + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + gc=False, # type: ignore[call-arg] + tag=True): + """Base class for all KV cache-related events""" + + +class BlockStored(KVCacheEvent): + block_hashes: list[int] + parent_block_hash: Optional[int] + token_ids: list[int] + block_size: int + lora_id: Optional[int] + + +class BlockRemoved(KVCacheEvent): + block_hashes: list[int] + + +class AllBlocksCleared(KVCacheEvent): + pass + + +class KVEventBatch(EventBatch): + events: list[Union[BlockStored, BlockRemoved, AllBlocksCleared]] + + +class EventPublisher(ABC): + """Lightweight publisher for EventBatch batches with data parallelism + support. + + In data parallel setups, each DP rank runs its own EventPublisher instance + to avoid duplicate events and ensure proper event attribution: + + - Each DP rank creates a separate publisher + - Publishers automatically annotate events with their data_parallel_rank + - This allows consumers to distinguish events from different DP ranks + + The publisher is responsible for adding DP metadata since the scheduler + operates independently of DP topology and shouldn't need DP awareness. + """ + + def __init__(self, data_parallel_rank: int = 0) -> None: + self._data_parallel_rank = data_parallel_rank + + @abstractmethod + def publish(self, events: EventBatch) -> None: + """Emit events in order. + + Implementations should guarantee at-least-once delivery and + monotonic ordering (e.g., via sequence numbers). + """ + + @abstractmethod + def shutdown(self) -> None: + """Shutdown the publisher.""" + + +class NullEventPublisher(EventPublisher): + """No-op implementation (default when disabled).""" + + def publish(self, events) -> None: + return + + def shutdown(self) -> None: + return + + +class ZmqEventPublisher(EventPublisher): + """Reliable PUB/ROUTER publisher with an in-memory replay buffer. + + Spawns a separate thread to handle publishing from a queue. + + Parameters + ---------- + endpoint: + PUB address. Use ``tcp://*:5557`` to bind or ``tcp://host:5557`` to + connect. + replay_endpoint: + Optional ROUTER address for replay requests. 
When given, subscribers can + request missed batches by sending the starting sequence number as an + 8-byte big-endian integer. + buffer_steps: + Number of past batches to keep for replay. + hwm: + ZeroMQ high-water-mark for PUB socket. + max_queue_size: + Maximum number of events to buffer in memory. + topic: + Topic to publish events to. + """ + SHUTDOWN_TIMEOUT: float = 1.0 + END_SEQ = (-1).to_bytes(8, "big", signed=True) + + def __init__( + self, + data_parallel_rank: int, + endpoint: str = "tcp://*:5557", + replay_endpoint: Optional[str] = None, + buffer_steps: int = 10_000, + hwm: int = 100_000, + max_queue_size: int = 100_000, + topic: str = "", + ) -> None: + # Storage + super().__init__(data_parallel_rank) + self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size) + self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps) + + # ZMQ sockets + self._ctx = zmq.Context.instance() + self._pub: Optional[zmq.Socket] = None + self._replay: Optional[zmq.Socket] = None + self._dp_rank = data_parallel_rank + + self._endpoint = self.offset_endpoint_port(endpoint, self._dp_rank) + self._replay_endpoint = self.offset_endpoint_port( + replay_endpoint, self._dp_rank) + self._hwm = hwm + self._socket_setup() + + # Payload + self._seq_gen = count() + self._topic_bytes = topic.encode('utf-8') + + # Thread + self._running = True + logger.info("Starting ZMQ publisher thread") + + self._thread = threading.Thread(target=self._publisher_thread, + daemon=True, + name="zmq-publisher") + self._thread.start() + + def publish(self, events: EventBatch) -> None: + if not self._running: + raise RuntimeError("Publisher is closed") + if events.data_parallel_rank is None: + events.data_parallel_rank = self._data_parallel_rank + self._event_queue.put(events) + + def shutdown(self) -> None: + """Stop the publisher thread and clean up resources.""" + self._running = False + self._event_queue.put_nowait(None) + + start = time.time() + pending_items = True + while pending_items and (time.time() - start < self.SHUTDOWN_TIMEOUT): + pending_items = not self._event_queue.empty() + if pending_items: + time.sleep(0.1) + + if pending_items: + logger.warning( + "Warning: Queue still has %s items after %s seconds timeout", + self._event_queue.qsize(), + self.SHUTDOWN_TIMEOUT, + ) + + if self._thread.is_alive(): + self._thread.join(timeout=self.SHUTDOWN_TIMEOUT) + + # Clean up ZMQ resources + try: + if self._pub is not None: + self._pub.close(linger=0) + if self._replay is not None: + self._replay.close(linger=0) + finally: + pass # Do not terminate context; other sockets may use it + + def _socket_setup(self) -> None: + """Initialize sockets + https://pyzmq.readthedocs.io/en/v19.0.0/morethanbindings.html#thread-safety + """ + if self._pub is None: + self._pub = self._ctx.socket(zmq.PUB) + self._pub.set_hwm(self._hwm) + # Heuristic: bind if wildcard / * present, else connect. 
+ # bind stable, connect volatile convention + if (self._endpoint is not None + and ("*" in self._endpoint or "::" in self._endpoint + or self._endpoint.startswith("ipc://") + or self._endpoint.startswith("inproc://"))): + self._pub.bind(self._endpoint) + elif self._endpoint is not None: + self._pub.connect(self._endpoint) + + # Set up replay socket: use ROUTER + # 1) handles multiple REQ clients (identities) + # 2) lets us send back one request → many replies (streamed events) + # 3) works in our non‑blocking poll loop alongside PUB + if self._replay_endpoint is not None: + self._replay = self._ctx.socket(zmq.ROUTER) + self._replay.bind(self._replay_endpoint) + + def _publisher_thread(self) -> None: + """Background thread that processes the event queue.""" + self._pack = msgspec.msgpack.Encoder() + + assert self._pub is not None # narrows type for mypy + + while self._running or self._event_queue.qsize() > 0: + # --- replay (non-critical) --------------------------------- + if self._replay is not None and self._replay.poll(0): + try: + self._service_replay() + except Exception as e: + logger.exception("Error in replay: %s", e) + + # --- main queue (critical) --------------------------------- + try: + event = self._event_queue.get(timeout=0.1) + if event is None: + break # Sentinel received, exit thread + except queue.Empty: + continue + + try: + seq = next(self._seq_gen) + + payload = self._pack.encode(event) + seq_bytes = seq.to_bytes(8, "big") + self._pub.send_multipart( + (self._topic_bytes, seq_bytes, payload)) + + self._buffer.append((seq, payload)) + self._event_queue.task_done() + + except Exception as e: + # Publishing failed; back-off a bit to avoid a tight error loop + logger.exception("Error in publisher thread: %s", e) + time.sleep(0.1) + + def _service_replay(self) -> None: + """If a replay request is waiting, send buffered batches.""" + assert self._replay is not None # narrows type for mypy + + frame = self._replay.recv_multipart() + if len(frame) != 3: + logger.warning("Invalid replay request: %s", frame) + return + client_id, _, start_seq_bytes = frame + start_seq = int.from_bytes(start_seq_bytes, "big") + + for seq, buf in self._buffer: + if seq >= start_seq: + # [identity, empty_delim, seq_bytes, payload] + # (identity, empty_delim) are stripped off by the router + # receiving payload is (seq_bytes, payload) + self._replay.send_multipart( + (client_id, b"", seq.to_bytes(8, "big"), buf)) + # Send end of sequence marker + # receiving payload is (-1, b""") + self._replay.send_multipart((client_id, b"", self.END_SEQ, b"")) + + @staticmethod + def offset_endpoint_port(endpoint: Optional[str], + data_parallel_rank: int) -> Optional[str]: + """Helper function to offset the port in an endpoint by + the data parallel rank. 
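+
+        For example (illustrative values): "tcp://*:5557" with
+        data_parallel_rank=2 becomes "tcp://*:5559", and "inproc://cache"
+        becomes "inproc://cache_dp2".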
+ + Args: + endpoint: The endpoint string + (e.g., "tcp://*:5557" or "inproc://cache") + data_parallel_rank: The data parallel rank to offset by + + Returns: + The endpoint with the port offset by data_parallel_rank + or suffix appended + """ + # Do nothing if input is None or data_parallel_rank is 0 + if not endpoint or data_parallel_rank == 0: + return endpoint + + if "inproc" in endpoint: + return f"{endpoint}_dp{data_parallel_rank}" + if "tcp" in endpoint: + if endpoint and ":" in endpoint: + # Get everything after the last colon (the port) + last_colon_idx = endpoint.rfind(":") + base_addr = endpoint[:last_colon_idx] + base_port = int(endpoint[last_colon_idx + 1:]) + new_port = base_port + data_parallel_rank + return f"{base_addr}:{new_port}" + return endpoint + raise ValueError("Invalid endpoint: must contain 'inproc' or 'tcp'") + + +class EventPublisherFactory: + _registry: dict[str, Callable[..., EventPublisher]] = { + "null": NullEventPublisher, + "zmq": ZmqEventPublisher, + } + + @classmethod + def register_publisher(cls, name: str, + ctor: Callable[..., EventPublisher]) -> None: + if name in cls._registry: + raise KeyError(f"publisher '{name}' already registered") + cls._registry[name] = ctor + + @classmethod + def create(cls, + config: Optional[KVEventsConfig], + data_parallel_rank: int = 0) -> EventPublisher: + """Create publisher from a config mapping.""" + if not config: + return NullEventPublisher() + + config_dict = asdict(config) + + kind = config_dict.pop("publisher", "null") + config_dict.pop("enable_kv_cache_events") + try: + constructor = cls._registry[kind] + except KeyError as exc: + raise ValueError(f"Unknown event publisher '{kind}'") from exc + return constructor(data_parallel_rank=data_parallel_rank, + **config_dict) diff --git a/distributed/kv_transfer/README.md b/distributed/kv_transfer/README.md new file mode 100644 index 0000000..349d3df --- /dev/null +++ b/distributed/kv_transfer/README.md @@ -0,0 +1,29 @@ + +# Distributed KV cache transfer + +This folder implements distributed KV cache transfer across vLLM instances. +Currently the main usecase is for disaggregated prefilling. + +## Abstractions + +The KV cache transfer contains three layer of abstractions: + +- KV pipe: a FIFO pipe for torch.tensor transmission. Key APIs: `send_tensor` and `recv_tensor`. +- KV lookup buffer: a lookup buffer for KV caches. Key: the tokens, value: the KV caches (and/or hidden states). Key APIs: `insert` and `drop_select` (similar to SQL semantics). +- KV connector: a connector that connects the KV pipe and KV lookup buffer to vLLM. Key APIs: `send_kv_caches_and_hidden_states` and `recv_kv_caches_and_hidden_states`. + +Why we need KV lookup buffer: FIFO pipe itself is not enough as prefill vLLM worker may process requests in a different order compared to decode vLLM worker. Say the QPS is really high, prefill worker may handle requests in order A -> B -> C, but the decode worker may process request C first. This is not the case that can be naturally handled by FIFO pipe, so we provide KV lookup buffer to help translate a FIFO pipe to a lookup buffer. + +NOTE: KV pipe layer is bypassible: you can skip this layer if your distributed +communication service already supports key-value-based lookup (like redis or +RDMA database). 
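+
+For concreteness, below is a minimal sketch of how the two sides use the lookup-buffer API described above. This is an editor's illustration, not a verbatim API reference: `buffer` stands for an already-constructed lookup buffer (such as `SimpleBuffer`), and all tensor shapes are placeholders.
+
+```python
+import torch
+
+# Placeholder tensors standing in for real prompt tokens and KV data.
+tokens = torch.tensor([101, 2023, 2003, 102])
+roi = torch.ones_like(tokens, dtype=torch.bool)  # region of interest: all tokens
+keys = torch.randn(32, len(tokens), 8, 128)      # [layers, tokens, heads, head_size]
+values = torch.randn(32, len(tokens), 8, 128)
+hidden_states = torch.randn(len(tokens), 4096)
+
+# Prefill side (KV producer): store KV caches and hidden states keyed by tokens.
+buffer.insert(tokens, roi, keys, values, hidden_states)
+
+# Decode side (KV consumer): look up by the same tokens. drop_select follows
+# SQL-like semantics: select the matching entry and drop it from the buffer.
+result = buffer.drop_select(tokens, roi)
+if result[0] is None:
+    ...  # cache miss: fall back to computing prefill locally
+else:
+    _, roi, keys, values, hidden = result
+```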
+ +NOTE: If you want to not only transfer KV caches, but adjust the model execution flow of vLLM as well (for example, allow vLLM to receive KV caches on some tokens and do prefill on the remaining tokens), you can bypass both KV pipe layer and KV lookup buffer layer, and directly implement on KV connector layer. Bear in mind that as vLLM's model input is constantly changing, this implementation will likely be broken when vLLM has new updates. + +## Disaggregated prefilling + +The example usage is in [this file](../../../examples/online_serving/disaggregated_prefill.sh). + +Here is the diagram of how we run disaggregated prefilling. + +![Disaggregated prefill workflow](./disagg_prefill_workflow.jpg) diff --git a/distributed/kv_transfer/__init__.py b/distributed/kv_transfer/__init__.py new file mode 100644 index 0000000..fa9b7e4 --- /dev/null +++ b/distributed/kv_transfer/__init__.py @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.distributed.kv_transfer.kv_transfer_state import ( + KVConnectorBaseType, ensure_kv_transfer_initialized, get_kv_transfer_group, + has_kv_transfer_group, is_v1_kv_transfer_group) + +__all__ = [ + "get_kv_transfer_group", "has_kv_transfer_group", + "is_v1_kv_transfer_group", "ensure_kv_transfer_initialized", + "KVConnectorBaseType" +] diff --git a/distributed/kv_transfer/disagg_prefill_workflow.jpg b/distributed/kv_transfer/disagg_prefill_workflow.jpg new file mode 100644 index 0000000..a25ec5e Binary files /dev/null and b/distributed/kv_transfer/disagg_prefill_workflow.jpg differ diff --git a/distributed/kv_transfer/kv_connector/__init__.py b/distributed/kv_transfer/kv_connector/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/distributed/kv_transfer/kv_connector/base.py b/distributed/kv_transfer/kv_connector/base.py new file mode 100644 index 0000000..181c339 --- /dev/null +++ b/distributed/kv_transfer/kv_connector/base.py @@ -0,0 +1,128 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +KVConnectorBase Class for Distributed KV Cache & Hidden State communication + +The class provides two primary abstract methods: +1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states +2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Union + +import torch + +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.config import VllmConfig + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + + +class KVConnectorBase(ABC): + """ + Abstract base class for a KV connector. + + The class provides two primary abstract methods: + 1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states + 2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states + """ + + @abstractmethod + def __init__( + self, + rank: int, + local_rank: int, + config: "VllmConfig", + ): + raise NotImplementedError + + @abstractmethod + def close(self) -> None: + """Close the buffer and release resources. + + This method is responsible for cleaning up resources related to the + connector when it is no longer needed. + + Raises: + NotImplementedError: This method must be implemented in subclasses. 
+ """ + raise NotImplementedError + + @abstractmethod + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: list[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + """ + Send KV caches and hidden states to the connector. + + This method processes the input tokens, KV caches, and + hidden/intermediate states for a given model and sends the data to the + decode instance. + + Args: + model_executable (torch.nn.Module): The model executable containing + start and end layer information. + model_input (ModelInputForGPUWithSamplingMetadata): The input + metadata from vLLM. + kv_caches (list[torch.Tensor]): List of KV caches (keys and values) + for each layer. + hidden_or_intermediate_states (Union[torch.Tensor, + IntermediateTensors]): + The hidden or intermediate states associated with the tokens. + + Returns: + None + + """ + + raise NotImplementedError + + @abstractmethod + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: list[torch.Tensor] + ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + """ + Receive KV caches and hidden states from the connector. + + This method attempts to retrieve KV caches and hidden states for input + tokens. If all required KV caches and hidden states are received, it + will bypass model input, else it will fall back to normal vLLM model + forwarding. + + Args: + model_executable (torch.nn.Module): + The model executable from vLLM modelrunner. + model_input (ModelInputForGPUWithSamplingMetadata): + The model input from vLLM modelrunner. + kv_caches (list[torch.Tensor]): + List of KV caches for each layer. + + Returns: + - hidden_or_intermediate_states (torch.Tensor or + IntermediateTensors): + Concatenated hidden states if all required data is retrieved, + otherwise `None`. + - bypass_model_exec (bool): + Indicates whether the model execution can be skipped (True) or + needs to be redone (False). + - model_input (ModelInputForGPUWithSamplingMetadata): + Optionally adjusted input metadata for re-execution when + `bypass_model_exec=False`. 
+ + """ + + raise NotImplementedError + + +KVConnectorBaseType = Union[KVConnectorBase, KVConnectorBase_V1] diff --git a/distributed/kv_transfer/kv_connector/factory.py b/distributed/kv_transfer/kv_connector/factory.py new file mode 100644 index 0000000..58dfa25 --- /dev/null +++ b/distributed/kv_transfer/kv_connector/factory.py @@ -0,0 +1,128 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import importlib +from typing import TYPE_CHECKING, Callable + +import vllm.envs as envs +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType +from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, + KVConnectorRole) +from vllm.logger import init_logger + +from .base import KVConnectorBase + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +logger = init_logger(__name__) + + +class KVConnectorFactory: + _registry: dict[str, Callable[[], type[KVConnectorBaseType]]] = {} + + @classmethod + def register_connector(cls, name: str, module_path: str, + class_name: str) -> None: + """Register a connector with a lazy-loading module and class name.""" + if name in cls._registry: + raise ValueError(f"Connector '{name}' is already registered.") + + def loader() -> type[KVConnectorBaseType]: + module = importlib.import_module(module_path) + return getattr(module, class_name) + + cls._registry[name] = loader + + @classmethod + def create_connector_v0(cls, rank: int, local_rank: int, + config: "VllmConfig") -> KVConnectorBase: + if envs.VLLM_USE_V1: + raise ValueError("Attempting to initialize a V0 Connector, " + f"but found {envs.VLLM_USE_V1=}") + + connector_name = config.kv_transfer_config.kv_connector + if connector_name not in cls._registry: + raise ValueError(f"Unsupported connector type: {connector_name}") + + connector_cls = cls._registry[connector_name]() + assert issubclass(connector_cls, KVConnectorBase) + return connector_cls(rank, local_rank, config) + + @classmethod + def create_connector_v1( + cls, + config: "VllmConfig", + role: KVConnectorRole, + ) -> KVConnectorBase_V1: + if not envs.VLLM_USE_V1: + raise ValueError("Attempting to initialize a V1 Connector, " + f"but found {envs.VLLM_USE_V1=}") + + kv_transfer_config = config.kv_transfer_config + connector_name = kv_transfer_config.kv_connector + if connector_name in cls._registry: + connector_cls = cls._registry[connector_name]() + else: + connector_module_path = kv_transfer_config.kv_connector_module_path + if connector_module_path is None: + raise ValueError( + f"Unsupported connector type: {connector_name}") + connector_module = importlib.import_module(connector_module_path) + connector_cls = getattr(connector_module, connector_name) + assert issubclass(connector_cls, KVConnectorBase_V1) + logger.info("Creating v1 connector with name: %s and engine_id: %s", + connector_name, kv_transfer_config.engine_id) + # NOTE(Kuntai): v1 connector is explicitly separated into two roles. + # Scheduler connector: + # - Co-locate with scheduler process + # - Should only be used inside the Scheduler class + # Worker connector: + # - Co-locate with worker process + # - Should only be used inside the forward context & attention layer + # We build separately to enforce strict separation + return connector_cls(config, role) + + +# Register various connectors here. +# The registration should not be done in each individual file, as we want to +# only load the files corresponding to the current connector. 
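+#
+# As an illustration (editor's sketch, not part of this registry): an
+# out-of-tree connector could be registered in the same way, assuming a
+# hypothetical module "my_pkg.my_connector" that defines a
+# KVConnectorBase_V1 subclass named "MyConnector":
+#
+#   KVConnectorFactory.register_connector(
+#       "MyConnector",
+#       "my_pkg.my_connector",
+#       "MyConnector")
+#
+# The module is only imported when that connector is actually selected via
+# the kv_transfer_config, because register_connector stores a lazy loader.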
+KVConnectorFactory.register_connector( + "PyNcclConnector", + "vllm.distributed.kv_transfer.kv_connector.simple_connector", + "SimpleConnector") + +KVConnectorFactory.register_connector( + "MooncakeConnector", + "vllm.distributed.kv_transfer.kv_connector.simple_connector", + "SimpleConnector") + +KVConnectorFactory.register_connector( + "LMCacheConnector", + "vllm.distributed.kv_transfer.kv_connector.lmcache_connector", + "LMCacheConnector") + +KVConnectorFactory.register_connector( + "MooncakeStoreConnector", + "vllm.distributed.kv_transfer.kv_connector.mooncake_store_connector", + "MooncakeStoreConnector") + +KVConnectorFactory.register_connector( + "SharedStorageConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector", + "SharedStorageConnector") + +KVConnectorFactory.register_connector( + "LMCacheConnectorV1", + "vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector", + "LMCacheConnectorV1") + +KVConnectorFactory.register_connector( + "NixlConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector", + "NixlConnector") + +KVConnectorFactory.register_connector( + "MultiConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.multi_connector", + "MultiConnector") diff --git a/distributed/kv_transfer/kv_connector/lmcache_connector.py b/distributed/kv_transfer/kv_connector/lmcache_connector.py new file mode 100644 index 0000000..78bf309 --- /dev/null +++ b/distributed/kv_transfer/kv_connector/lmcache_connector.py @@ -0,0 +1,99 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +LMCache KV Cache Connector for Distributed Machine Learning Inference + +The LMCacheConnector can (1) transfer KV caches between prefill vLLM worker +(KV cache producer) and decode vLLM worker (KV cache consumer) using LMCache; +(2) offload and share KV caches. 
+""" + +from typing import TYPE_CHECKING, Union + +import torch + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + +logger = init_logger(__name__) + + +class LMCacheConnector(KVConnectorBase): + + def __init__( + self, + rank: int, + local_rank: int, + config: VllmConfig, + ): + + self.transfer_config = config.kv_transfer_config + self.vllm_config = config + + from lmcache.experimental.cache_engine import LMCacheEngineBuilder + from lmcache.integration.vllm.utils import ENGINE_NAME + from lmcache.integration.vllm.vllm_adapter import ( + RetrieveStatus, StoreStatus, init_lmcache_engine, + lmcache_retrieve_kv, lmcache_should_retrieve, lmcache_should_store, + lmcache_store_kv) + logger.info("Initializing LMCacheConfig under kv_transfer_config %s", + self.transfer_config) + + # TODO (Jiayi): Find model_config, parallel_config, and cache_config + self.engine = init_lmcache_engine(config.model_config, + config.parallel_config, + config.cache_config) + self.lmcache_engine_name = ENGINE_NAME + self.lmcache_engine_builder = LMCacheEngineBuilder + + self.model_config = config.model_config + self.parallel_config = config.parallel_config + self.cache_config = config.cache_config + self.lmcache_retrieve_kv = lmcache_retrieve_kv + self.lmcache_store_kv = lmcache_store_kv + self.lmcache_should_retrieve = lmcache_should_retrieve + self.lmcache_should_store = lmcache_should_store + self.store_status = StoreStatus + self.retrieve_status = RetrieveStatus + + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: list[torch.Tensor] + ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + + retrieve_status = self.lmcache_should_retrieve(model_input) + model_input, bypass_model_exec, hidden_or_intermediate_states =\ + self.lmcache_retrieve_kv( + model_executable, model_input, self.cache_config, kv_caches, + retrieve_status) + return hidden_or_intermediate_states, bypass_model_exec, model_input + + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: list[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + + store_status = self.lmcache_should_store(model_input) + self.lmcache_store_kv( + self.model_config, + self.parallel_config, + self.cache_config, + model_executable, + model_input, + kv_caches, + store_status, + ) + + def close(self): + self.lmcache_engine_builder.destroy(self.lmcache_engine_name) diff --git a/distributed/kv_transfer/kv_connector/mooncake_store_connector.py b/distributed/kv_transfer/kv_connector/mooncake_store_connector.py new file mode 100644 index 0000000..94a7ce9 --- /dev/null +++ b/distributed/kv_transfer/kv_connector/mooncake_store_connector.py @@ -0,0 +1,203 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +MooncakeStore Connector for Distributed Machine Learning Inference +The MooncakeStoreConnector transfers KV caches between prefill vLLM workers +(KV cache producer) and decode vLLM workers (KV cache consumer) using a +database-style KVStore. 
+""" +import hashlib +from typing import TYPE_CHECKING, Union + +import torch + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.distributed.kv_transfer.kv_connector.utils import ( + model_aware_kv_ops_helper as kv_helper) +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + +logger = init_logger(__name__) + + +class MooncakeStoreConnector(KVConnectorBase): + + def __init__( + self, + rank: int, + local_rank: int, + config: VllmConfig, + ): + self.kv_transfer_config = config.kv_transfer_config + self.kv_helper = kv_helper(config) + self.local_tp_rank = local_rank + + # Init kv_store + if self.kv_transfer_config.kv_connector == "MooncakeStoreConnector": + # Check if MOONCAKE_CONFIG_PATH is set + import os + use_mooncake_store = os.getenv('MOONCAKE_CONFIG_PATH') is not None + + if not use_mooncake_store: + raise ValueError( + "To use MooncakeStoreConnector, you need to pass the ENV: " + "'MOONCAKE_CONFIG_PATH=/path/to/mooncake_config.json'.") + else: + from vllm.distributed.kv_transfer.kv_lookup_buffer.mooncake_store import ( # noqa: E501 + MooncakeStore) + logger.info( + "Initializing KVStoreConnector under kv_transfer_config %s", + self.kv_transfer_config) + self.kv_store = MooncakeStore(config) + else: + logger.error("Can not find %s", + self.kv_transfer_config.kv_connector) + + assert self.kv_store is not None + + def close(self) -> None: + """Close the buffer and release resources. + This method is responsible for cleaning up resources related to the + connector when it is no longer needed. + Raises: + NotImplementedError: This method must be implemented in subclasses. 
+ """ + self.kv_store.close() + + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: list[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() + start_layer = model_executable.model.start_layer + end_layer = model_executable.model.end_layer + num_heads, head_size = self.kv_helper.get_model_args(model_executable) + + for idx, slen in enumerate(seq_lens): + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + + current_tokens = input_tokens_tensor[start_pos:end_pos] + store_key_prefix = self.tensor_hash(current_tokens) + keys, values = [], [] + + for layer_id in range(start_layer, end_layer): + kv_cache = kv_caches[layer_id - start_layer] + key_cache, value_cache = self.kv_helper.get_kv_from_cache( + kv_cache, num_heads, head_size) + current_slot_mapping = slot_mapping_flat[start_pos:end_pos] + + keys.append(key_cache[current_slot_mapping].unsqueeze(0)) + values.append(value_cache[current_slot_mapping].unsqueeze(0)) + + keys = torch.cat(keys, dim=0) + values = torch.cat(values, dim=0) + kvcache_to_sent = torch.stack((keys, values), dim=0) + store_kvcache_key = f"{store_key_prefix}_{self.local_tp_rank}" + self.kv_store.put(store_kvcache_key, kvcache_to_sent) + + hidden_key = f"{store_key_prefix}_hidden_{self.local_tp_rank}" + self.kv_store.put(hidden_key, + hidden_or_intermediate_states[start_pos:end_pos]) + + logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) + + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: list[torch.Tensor] + ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + bypass_model_exec = True + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens + slot_mapping = model_input.attn_metadata.slot_mapping.flatten() + start_layer = model_executable.model.start_layer + end_layer = model_executable.model.end_layer + hidden_or_intermediate_states_for_one_req = [] + + for idx, slen in enumerate(seq_lens): + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + + if start_pos >= num_prefill_tokens: + # This can happen during inflight batching. See: + # vllm/worker/model_runner.py::_prepare_model_input_tensors: + # - input_tokens[:num_prefill_tokens] contains prefill tokens. + # - input_tokens[num_prefill_tokens:] contains decode tokens. + logger.warning("You should set --enable_chunked_prefill=False " + "and --max_num_batched_tokens " + "should be equal to max_seq_len_to_capture") + bypass_model_exec = False + assert start_pos == num_prefill_tokens + break + + current_tokens = input_tokens_tensor[start_pos:end_pos] + + # get roi for current seq + load_key_prefix = self.tensor_hash(current_tokens) + load_kvcache_key = f"{load_key_prefix}_{self.local_tp_rank}" + remote_kv = self.kv_store.get(load_kvcache_key) + hidden_key = f"{load_key_prefix}_hidden_{self.local_tp_rank}" + hidden = self.kv_store.get(hidden_key) + + if remote_kv is None or hidden is None: + # didn't find any match. 
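+                # (this request falls back to recomputation; the whole batch
+                #  will redo normal model forwarding below)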
+ bypass_model_exec = False + continue + + num_computed_tokens = current_tokens.shape[0] + + # update the end position based on how many tokens are cached. + end_pos = start_pos + num_computed_tokens + + # call self.kv_store to get kv layer by layer + for layer_id in range(start_layer, end_layer): + layer = model_executable.model.layers[layer_id] + # get kvcache object + kv_cache = kv_caches[layer_id - start_layer] + + # get remote kvcache + remote_k, remote_v = remote_kv[0][layer_id], remote_kv[1][ + layer_id] + + self.kv_helper.put_kv_to_cache(model_executable, remote_k, + remote_v, layer, kv_cache, + slot_mapping, start_pos, + end_pos) + + hidden_or_intermediate_states_for_one_req.append(hidden) + + if not bypass_model_exec: + logger.warning( + "[rank%d]: Failed to receive all KVs and hidden " + "states, redo model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = None + + else: + logger.debug( + "[rank%d]: Successfully received all KVs and hidden " + "states, skip model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = torch.cat( + hidden_or_intermediate_states_for_one_req, dim=0) + + return hidden_or_intermediate_states, bypass_model_exec, model_input + + @staticmethod + def tensor_hash(tensor: torch.Tensor) -> int: + """Calculate the hash value of the tensor.""" + tensor_bytes = tensor.clone().detach().cpu().numpy().tobytes() + hash_object = hashlib.blake2b(tensor_bytes) + hash_hex = hash_object.hexdigest() + return int(hash_hex[:16], 16) diff --git a/distributed/kv_transfer/kv_connector/simple_connector.py b/distributed/kv_transfer/kv_connector/simple_connector.py new file mode 100644 index 0000000..e7c079e --- /dev/null +++ b/distributed/kv_transfer/kv_connector/simple_connector.py @@ -0,0 +1,329 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Simple KV Cache Connector for Distributed Machine Learning Inference + +The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache +producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or +MooncakePipe. + +But the logic can be extended to support other pipe and lookup buffer. 
+""" +from typing import TYPE_CHECKING, Optional, Union + +import torch + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.distributed.kv_transfer.kv_connector.utils import ( + model_aware_kv_ops_helper as kv_helper) +from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import ( + SimpleBuffer) +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + +logger = init_logger(__name__) + + +class SimpleConnector(KVConnectorBase): + + def __init__( + self, + rank: int, + local_rank: int, + config: VllmConfig, + ): + + self.config = config.kv_transfer_config + self.kv_helper = kv_helper(config) + + if self.config.kv_connector == "PyNcclConnector": + from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import ( + PyNcclPipe) + logger.info( + "Initializing PyNcclConfig under kv_transfer_config %s", + self.config) + elif self.config.kv_connector == "MooncakeConnector": + # Check if MOONCAKE_CONFIG_PATH is set + import os + use_mooncake_distributed_pipe = os.getenv( + 'MOONCAKE_CONFIG_PATH') is not None + + if not use_mooncake_distributed_pipe: + raise ValueError( + "To use MooncakeConnector, you need to pass the ENV: " + "'MOONCAKE_CONFIG_PATH=/path/to/mooncake_config.json'.") + else: + from vllm.distributed.kv_transfer.kv_pipe.mooncake_pipe import ( # noqa: E501 + MooncakePipe) + logger.info( + "Initializing MooncakeConfig under kv_transfer_config %s", + self.config) + + self.lookup_buffer_size = self.config.kv_buffer_size + + self.producer_buffer: Optional[SimpleBuffer] = None + self.consumer_buffer: Optional[SimpleBuffer] = None + + self.producer_data_pipe: Union[PyNcclPipe, MooncakePipe] + self.consumer_data_pipe: Union[PyNcclPipe, MooncakePipe] + self.producer_signal_pipe: Union[PyNcclPipe, MooncakePipe] + self.consumer_signal_pipe: Union[PyNcclPipe, MooncakePipe] + + # 2 pipes for every rank in the world + port_offset_base = 2 * rank + + # In disaggregated prefill, the prefill vLLM only uses send pipe + # and the decode vLLM only uses recv pipe + if self.config.is_kv_producer: + + if self.config.kv_connector == "PyNcclConnector": + self.producer_data_pipe = PyNcclPipe( + local_rank=local_rank, + config=self.config, + port_offset=port_offset_base, + ) + self.producer_signal_pipe = PyNcclPipe( + local_rank=local_rank, + config=self.config, + port_offset=port_offset_base + 1, + device="cpu", + ) + elif self.config.kv_connector == "MooncakeConnector": + self.producer_data_pipe = MooncakePipe( + local_rank=local_rank, + config=self.config, + ) + # We only need to initialize MooncakePipe once + self.producer_signal_pipe = self.producer_data_pipe + + self.producer_buffer = SimpleBuffer(self.producer_signal_pipe, + self.producer_data_pipe, + self.config.kv_buffer_size) + + else: + + # the current vLLM instance is KV consumer, so it needs to connect + # its recv pipe to the send pipe of KV producer + if self.config.kv_connector == "PyNcclConnector": + self.consumer_data_pipe = PyNcclPipe( + local_rank=local_rank, + config=self.config, + port_offset=port_offset_base, + ) + self.consumer_signal_pipe = PyNcclPipe( + local_rank=local_rank, + config=self.config, + port_offset=port_offset_base + 1, + device="cpu", + ) + elif self.config.kv_connector == "MooncakeConnector": + self.consumer_data_pipe = MooncakePipe( + local_rank=local_rank, + config=self.config, + ) + self.consumer_signal_pipe = 
self.consumer_data_pipe + + self.consumer_buffer = SimpleBuffer( + self.consumer_signal_pipe, + self.consumer_data_pipe, + self.config.kv_buffer_size, + ) + + def select(self, input_tokens: Optional[torch.Tensor], + roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]: + + assert self.consumer_buffer is not None, "Please initialize the "\ + "consumer buffer before calling select." + return self.consumer_buffer.drop_select(input_tokens, roi) + + def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor) -> None: + + assert self.producer_buffer is not None, "Please initialize the "\ + "producer buffer before calling insert." + + self.producer_buffer.insert(input_tokens, roi, key, value, hidden) + + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: list[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() + num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens + start_layer = model_executable.model.start_layer + end_layer = model_executable.model.end_layer + num_heads, head_size = self.kv_helper.get_model_args(model_executable) + + # query_lens contains new KV caches that are added to vLLM. + # so we will send them to decode instance + # FIXME(Kuntai): This assume that all requests are prefill. + for idx, slen in enumerate(seq_lens): + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + + if start_pos >= num_prefill_tokens: + # vllm/worker/model_runner.py::_prepare_model_input_tensors: + # - input_tokens[:num_prefill_tokens] contains prefill tokens. + # - input_tokens[num_prefill_tokens:] contains decode tokens. + logger.warning("You have some decode requests while using " + "SimpleConnector. Their KVCache won't be sent.") + break + + current_tokens = input_tokens_tensor[start_pos:end_pos] + + keys, values = [], [] + + for layer_id in range(start_layer, end_layer): + kv_cache = kv_caches[layer_id - start_layer] + key_cache, value_cache = self.kv_helper.get_kv_from_cache( + kv_cache, num_heads, head_size) + + current_slot_mapping = slot_mapping_flat[start_pos:end_pos] + + keys.append(key_cache[current_slot_mapping].unsqueeze(0)) + values.append(value_cache[current_slot_mapping].unsqueeze(0)) + + keys = torch.cat(keys, dim=0) + values = torch.cat(values, dim=0) + + self.insert(current_tokens, + torch.ones_like(current_tokens, + dtype=bool), keys, values, + hidden_or_intermediate_states[start_pos:end_pos]) + + logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) + + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: list[torch.Tensor] + ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + + # When bypass_model_exec is set to False, it means that at least for one + # request its corresponding KV cache or hidden state is missing. + # In this case we need to do prefilling to recompute missing KV cache + # and hidden states. 
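+        # Start optimistically; this is flipped to False as soon as a
+        # request's KV cache or hidden states cannot be retrieved.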
+ bypass_model_exec = True + + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens + slot_mapping = model_input.attn_metadata.slot_mapping.flatten() + start_layer = model_executable.model.start_layer + end_layer = model_executable.model.end_layer + + hidden_or_intermediate_states_for_one_req = [] + + input_tokens_list = [] + num_computed_tokens_list = [] + start_pos_list = [] + + # enumerate different requests + # FIXME(Kuntai): This impl assumes that all requests are prefill. + for idx, slen in enumerate(seq_lens): + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + + if start_pos >= num_prefill_tokens: + # This can happen during inflight batching. See: + # vllm/worker/model_runner.py::_prepare_model_input_tensors: + # - input_tokens[:num_prefill_tokens] contains prefill tokens. + # - input_tokens[num_prefill_tokens:] contains decode tokens. + logger.warning("You should set --enable_chunked_prefill=False " + "and --max_num_batched_tokens " + "should be equal to --max_seq_len_to_capture") + bypass_model_exec = False + assert start_pos == num_prefill_tokens + break + + current_tokens = input_tokens_tensor[start_pos:end_pos] + num_tokens = slen + + # collecting data for rebuilding the input + input_tokens_list.append(current_tokens) + start_pos_list.append(start_pos) + + ret = self.select(current_tokens, + torch.ones_like(current_tokens, dtype=bool)) + if ret[0] is None: + # didn't find any match. + bypass_model_exec = False + num_computed_tokens_list.append(0) + continue + + roi: torch.Tensor = ret[1] + keys: torch.Tensor = ret[2] + values: torch.Tensor = ret[3] + hidden: torch.Tensor = ret[4] + + num_computed_tokens = roi.shape[0] + num_computed_tokens_list.append(num_computed_tokens) + + # check if both KV cache and the hidden states are received + # If not, need to redo the forwarding to compute missing states + if not all([(num_computed_tokens == num_tokens), hidden is not None + ]): + bypass_model_exec = False + + # update the end position based on how many tokens are cached. + end_pos = start_pos + num_computed_tokens + + # put received KV caches into paged memory + for cur_layer in range(start_layer, end_layer): + + layer_id = cur_layer - start_layer + kv_cache = kv_caches[layer_id] + layer = model_executable.model.layers[cur_layer] + + # get remote kvcache + remote_k, remote_v = keys[layer_id], values[layer_id] + + self.kv_helper.put_kv_to_cache(model_executable, remote_k, + remote_v, layer, kv_cache, + slot_mapping, start_pos, + end_pos) + + hidden_or_intermediate_states_for_one_req.append(hidden) + + if not bypass_model_exec: + # Some of the KV cache is not retrieved + # Here we will fall back to normal model forwarding + # But optionally you can adjust model_input so that you only do + # prefilling on those tokens that are missing KV caches. 
+ logger.warning( + "[rank%d]: Failed to receive all KVs and hidden " + "states, redo model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = None + + else: + logger.debug( + "[rank%d]: Successfully received all KVs and hidden " + "states, skip model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = torch.cat( + hidden_or_intermediate_states_for_one_req, dim=0) + + return hidden_or_intermediate_states, bypass_model_exec, model_input + + def close(self): + self.producer_data_pipe.close() + self.consumer_data_pipe.close() + if self.config.kv_connector == "PyNcclConnector": + self.producer_signal_pipe.close() + self.consumer_signal_pipe.close() + elif self.config.kv_connector == "MooncakeConnector": + # MooncakePipe reuses data_pipe for signal_pipe, so we only have to + # close the data_pipe. + pass diff --git a/distributed/kv_transfer/kv_connector/utils.py b/distributed/kv_transfer/kv_connector/utils.py new file mode 100644 index 0000000..b9bed06 --- /dev/null +++ b/distributed/kv_transfer/kv_connector/utils.py @@ -0,0 +1,108 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +KV cache helper for store. +""" + +import torch + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.config import VllmConfig, get_current_vllm_config +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class model_aware_kv_ops_helper: + + def __init__(self, config: VllmConfig): + self.is_deepseek_mla = config.model_config.is_deepseek_mla + self.use_mla_opt = not envs.VLLM_MLA_DISABLE + self.tp_size = config.parallel_config.tensor_parallel_size + + def get_model_args(self, model_executable: torch.nn.Module): + + model_config = model_executable.model.config + self.model_executable = model_executable + num_heads = int(model_config.num_key_value_heads / self.tp_size) + hidden_size = model_config.hidden_size + num_attention_heads = model_config.num_attention_heads + + # Deepseek's MLA (Multi-head Latent Attention) uses two different + # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0. + # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied, + # resulting in a kv_cache shape of [num_blks, blk_size, 1, + # kv_lora_rank + qk_rope_head_dim]. + # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading + # to a kv_cache shape of [2, num_blks, blk_size, + # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim]. + # For more details, see vllm/attention/backends/mla/common.py. 
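+        # Illustrative numbers only (assumed DeepSeek-style values such as
+        # kv_lora_rank=512, qk_rope_head_dim=64, qk_nope_head_dim=128):
+        # VLLM_MLA_DISABLE=0 gives num_heads=1, head_size=512+64=576, while
+        # VLLM_MLA_DISABLE=1 gives head_size=128+64=192 with
+        # num_key_value_heads / tp_size heads.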
+ if self.is_deepseek_mla and self.use_mla_opt: + head_size = model_config.kv_lora_rank + \ + model_config.qk_rope_head_dim + num_heads = 1 + elif self.is_deepseek_mla and not self.use_mla_opt: + head_size = model_config.qk_nope_head_dim + \ + model_config.qk_rope_head_dim + else: + head_size = getattr(model_config, "head_dim", None) + if head_size is None: + head_size = int(hidden_size // num_attention_heads) + + return num_heads, head_size + + def get_kv_from_cache(self, kv_cache, num_heads, head_size): + if self.is_deepseek_mla and self.use_mla_opt: + key_cache = kv_cache.reshape(-1, num_heads, head_size) + value_cache = kv_cache.reshape(-1, num_heads, head_size) + else: + key_cache = kv_cache[0].reshape(-1, num_heads, head_size) + value_cache = kv_cache[1].reshape(-1, num_heads, head_size) + return key_cache, value_cache + + def put_kv_to_cache(self, model_executable: torch.nn.Module, keys, values, + layer, kv_cache, slot_mapping, start_pos, end_pos): + + model_config = model_executable.model.config + + if self.is_deepseek_mla and self.use_mla_opt: + layer.self_attn.attn = layer.self_attn.mla_attn + k_c_normed_k_pe = keys.squeeze(1) + k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank] + k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:] + ops.concat_and_cache_mla( + k_c_normed.to(kv_cache.device), + k_pe.to(kv_cache.device), + kv_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + ) + else: + key_cache, value_cache = kv_cache[0], kv_cache[1] + ops.reshape_and_cache_flash( + keys.to(key_cache.device), + values.to(value_cache.device), + key_cache, + value_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + layer.self_attn.attn._v_scale, + ) + + +def get_kv_connector_cache_layout(): + vllm_config = get_current_vllm_config() + kv_config = vllm_config.kv_transfer_config + if vllm_config.model_config is None: + logger.warning("Unable to detect current VLLM config. " \ + "Defaulting to NHD kv cache layout.") + else: + use_mla = vllm_config.model_config.use_mla + if not use_mla and kv_config.kv_connector == "NixlConnector": + logger.info("NixlConnector detected. Setting KV cache " \ + "layout to HND for better xfer performance.") + return "HND" + return "NHD" diff --git a/distributed/kv_transfer/kv_connector/v1/__init__.py b/distributed/kv_transfer/kv_connector/v1/__init__.py new file mode 100644 index 0000000..f00f31d --- /dev/null +++ b/distributed/kv_transfer/kv_connector/v1/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorRole) + +__all__ = ["KVConnectorRole", "KVConnectorBase_V1"] diff --git a/distributed/kv_transfer/kv_connector/v1/base.py b/distributed/kv_transfer/kv_connector/v1/base.py new file mode 100644 index 0000000..f80b5eb --- /dev/null +++ b/distributed/kv_transfer/kv_connector/v1/base.py @@ -0,0 +1,283 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +KVConnectorBase_V1 Class for Distributed KV Cache & Hidden State +communication in vLLM v1 + +The class provides the following primitives: + Scheduler-side: runs in the scheduler, binds metadata, which + is used by the worker-side to load/save KV cache. 
+ get_num_new_matched_tokens() - get number of new tokens + that exist in the remote KV cache. Might be called multiple + times for a given request and should be side-effect free. + update_state_after_alloc() - update KVConnector state after + temporary buffer alloc by the CacheManager. + request_finished() - called when a request is finished, with + the computed kv cache blocks for the request. + Returns whether KV cache should be freed now or will be + freed asynchronously and optionally returns KV transfer + params. + + Worker-side: runs in each worker, loads/saves KV cache to/from + the Connector based on the metadata. + start_load_kv() - starts loading all KVs (maybe async) + wait_for_layer_load() - blocks until layer i load is done + + save_kv_layer() - starts saving KV for layer i (maybe async) + wait_for_save() - blocks until all saves are done + + get_finished() - called with ids of finished requests, returns + ids of requests that have completed async sending/recving. +""" + +import enum +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Optional + +import torch + +from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.config import VllmConfig + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +class KVConnectorRole(enum.Enum): + # Connector running in the scheduler process + SCHEDULER = 0 + + # Connector running in the worker process + WORKER = 1 + + +class KVConnectorMetadata: + """ + Abstract Metadata used to communicate between the + Scheduler KVConnector and Worker KVConnector. + """ + pass + + +class KVConnectorBase_V1(ABC): + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + logger.warning( + "Initializing KVConnectorBase_V1. This API is experimental and " + "subject to change in the future as we iterate the design.") + self._connector_metadata = KVConnectorMetadata() + self._vllm_config = vllm_config + self._role = role + + @property + def role(self) -> KVConnectorRole: + return self._role + + # ============================== + # Worker-side methods + # ============================== + + def bind_connector_metadata( + self, connector_metadata: KVConnectorMetadata) -> None: + """Set the connector metadata from the scheduler. + + This function should be called by the model runner every time + before the model execution. The metadata will be used for runtime + KV cache loading and saving. + + Args: + connector_metadata (dict): the connector metadata. + """ + self._connector_metadata = connector_metadata + + def clear_connector_metadata(self) -> None: + """Clear the connector metadata. + + This function should be called by the model runner every time + after the model execution. + """ + self._connector_metadata = KVConnectorMetadata() + + def _get_connector_metadata(self) -> KVConnectorMetadata: + """Get the connector metadata. + + This function should only be called inside the connector. + + Returns: + ConnectorMetadata: the connector metadata. + """ + return self._connector_metadata + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """ + Initialize with the KV caches. Useful for pre-registering the + KV Caches in the KVConnector (e.g. for NIXL). 
+ + Args: kv_caches: + dictionary of layer names, kv cache + """ + return + + @abstractmethod + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + """ + Start loading the KV cache from the connector to vLLM's paged + KV buffer. This is called from the forward context before the + forward pass to enable async loading during model execution. + + Args: + forward_context (ForwardContext): the forward context. + **kwargs: additional arguments for the load operation + + Note: + The number of elements in kv_caches and layer_names should be + the same. + + """ + pass + + @abstractmethod + def wait_for_layer_load(self, layer_name: str) -> None: + """ + Block until the KV for a specific layer is loaded into vLLM's + paged buffer. This is called from within attention layer to ensure + async copying from start_load_kv is complete. + + This interface will be useful for layer-by-layer pipelining. + + Args: + layer_name: the name of that layer + """ + pass + + @abstractmethod + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """ + Start saving a layer of KV cache from vLLM's paged buffer + to the connector. This is called from within attention layer to + enable async copying during execution. + + Args: + layer_name (str): the name of the layer. + kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. + """ + pass + + @abstractmethod + def wait_for_save(self): + """ + Block until all the save operations is done. This is called + as the forward context exits to ensure that the async saving + from save_kv_layer is complete before finishing the forward. + + This prevents overwrites of paged KV buffer before saving done. + """ + pass + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + """ + Notifies worker-side connector ids of requests that have + finished generating tokens. + + Returns: + ids of requests that have finished asynchronous transfer + (requests that previously returned True from request_finished()), + tuple of (sending/saving ids, recving/loading ids). + The finished saves/sends req ids must belong to a set provided in a + call to this method (this call or a prior one). + """ + return None, None + + # ============================== + # Scheduler-side methods + # ============================== + + @abstractmethod + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + """ + Get number of new tokens that can be loaded from the + external KV cache beyond the num_computed_tokens. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + A tuple with the following elements: + - The number of tokens that can be loaded from the + external KV cache beyond what is already computed. + - `True` if external KV cache tokens will be loaded + asynchronously (between scheduler steps). Must be + 'False' if the first element is 0. + """ + pass + + @abstractmethod + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + """ + Update KVConnector state after block allocation. 
+ + If get_num_new_matched_tokens previously returned True for a + request, this function may be called twice for that same request - + first when blocks are allocated for the connector tokens to be + asynchronously loaded into, and second when any additional blocks + are allocated, after the load/transfer is complete. + + Args: + request (Request): the request object. + blocks (KVCacheBlocks): the blocks allocated for the request. + num_external_tokens (int): the number of tokens that will be + loaded from the external KV cache. + """ + pass + + @abstractmethod + def build_connector_meta( + self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + """ + Build the connector metadata for this step. + + This function should NOT modify fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. + """ + pass + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Called when a request has finished, before its blocks are freed. + + Returns: + True if the request is being saved/sent asynchronously and blocks + should not be freed until the request_id is returned from + get_finished(). + Optional KVTransferParams to be included in the request outputs + returned by the engine. + """ + return False, None diff --git a/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py new file mode 100644 index 0000000..cc1f4ba --- /dev/null +++ b/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -0,0 +1,134 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING + +import torch +from lmcache.integration.vllm.vllm_v1_adapter import LMCacheConnectorV1Impl + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +class LMCacheConnectorV1(KVConnectorBase_V1): + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config=vllm_config, role=role) + self._lmcache_engine = LMCacheConnectorV1Impl(vllm_config, role, self) + + # ============================== + # Worker-side methods + # ============================== + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + """ + Start loading the KV cache from the connector to vLLM's paged + KV buffer. This is called from the forward context before the + forward pass to enable async loading during model execution. + + Args: + forward_context (ForwardContext): the forward context. + **kwargs: additional arguments for the load operation + + Note: + The number of elements in kv_caches and layer_names should be + the same. + + """ + self._lmcache_engine.start_load_kv(forward_context, **kwargs) + + def wait_for_layer_load(self, layer_name: str) -> None: + """ + Block until the KV for a specific layer is loaded into vLLM's + paged buffer. 
This is called from within attention layer to ensure + async copying from start_load_kv is complete. + + This interface will be useful for layer-by-layer pipelining. + + Args: + layer_name: the name of that layer + """ + self._lmcache_engine.wait_for_layer_load(layer_name) + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """ + Start saving the a layer of KV cache from vLLM's paged buffer + to the connector. This is called from within attention layer to + enable async copying during execution. + + Args: + layer_name (str): the name of the layer. + kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. + """ + self._lmcache_engine.save_kv_layer(layer_name, kv_layer, attn_metadata, + **kwargs) + + def wait_for_save(self): + """ + Block until all the save operations is done. This is called + as the forward context exits to ensure that the async saving + from save_kv_layer is complete before finishing the forward. + + This prevents overwrites of paged KV buffer before saving done. + """ + self._lmcache_engine.wait_for_save() + + # ============================== + # Scheduler-side methods + # ============================== + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + """ + Get number of new tokens that can be loaded from the + external KV cache beyond the num_computed_tokens. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + """ + return self._lmcache_engine.get_num_new_matched_tokens( + request, num_computed_tokens), False + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + """ + Update KVConnector state after block allocation. + """ + self._lmcache_engine.update_state_after_alloc(request, + num_external_tokens) + + def build_connector_meta( + self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + """ + Build the connector metadata for this step. + + This function should NOT modify fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. 
+ """ + return self._lmcache_engine.build_connector_meta(scheduler_output) diff --git a/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/distributed/kv_transfer/kv_connector/v1/multi_connector.py new file mode 100644 index 0000000..be3c233 --- /dev/null +++ b/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -0,0 +1,201 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional + +import torch + +from vllm.config import KVTransferConfig, VllmConfig +from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory) +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.logger import init_logger +from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.sched.output import SchedulerOutput + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +@dataclass +class MultiKVConnectorMetadata(KVConnectorMetadata): + metadata: tuple[KVConnectorMetadata, ...] + extra_async_saves: Optional[dict[str, int]] = None + + +class MultiConnector(KVConnectorBase_V1): + """ + A wrapper for using multiple KVConnectors at the same time. + + The current logic is: + - Load KV from the first connector that advertises available tokens from + get_num_new_matched_tokens(), based on the order in the config. + - Save to all connectors. + """ + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config=vllm_config, role=role) + self._connectors: list[KVConnectorBase_V1] = [] + ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get( + "connectors") + assert ktcs is not None + for ktc in ktcs: + temp_config = copy.copy(vllm_config) + temp_config.kv_transfer_config = KVTransferConfig(**ktc) + self._connectors.append( + KVConnectorFactory.create_connector_v1(temp_config, role)) + + # A mapping from request id to the index of the connector chosen to + # load the request from (if any). + self._requests_to_connector: dict[str, int] = {} + + # Keeps track of *additional* remaining async saves (beyond 1) to be + # finished per request. Not needed for async loads since we only allow + # a single connector to load. + # Propagated from scheduler to worker side via the connector metadata. + self._extra_async_saves: dict[str, int] = {} + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + for c in self._connectors: + c.register_kv_caches(kv_caches) + + # We must override the base class method here because we need to bind + # the metadata to each connector in the order of the connectors in the + # MultiKVConnectorMetadata. 
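+    # For example (editor's illustration only): with a kv_transfer_config
+    # whose kv_connector_extra_config is
+    #   {"connectors": [{"kv_connector": "NixlConnector", ...},
+    #                   {"kv_connector": "SharedStorageConnector", ...}]}
+    # metadata[0] is bound to the NixlConnector instance and metadata[1] to
+    # the SharedStorageConnector instance, in that order.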
+ def bind_connector_metadata( + self, connector_metadata: KVConnectorMetadata) -> None: + assert isinstance(connector_metadata, MultiKVConnectorMetadata) + if connector_metadata.extra_async_saves: + self._extra_async_saves.update( + connector_metadata.extra_async_saves) + for c, cm in zip(self._connectors, connector_metadata.metadata): + c.bind_connector_metadata(cm) + + def clear_connector_metadata(self) -> None: + for c in self._connectors: + c.clear_connector_metadata() + + # ============================== + # Worker-side methods + # ============================== + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + for c in self._connectors: + c.start_load_kv(forward_context, **kwargs) + + def wait_for_layer_load(self, layer_name: str) -> None: + for c in self._connectors: + c.wait_for_layer_load(layer_name) + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + for c in self._connectors: + c.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs) + + def wait_for_save(self): + for c in self._connectors: + c.wait_for_save() + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + finished_sending: set[str] = set() + finished_recving: set[str] = set() + for c in self._connectors: + sending, recving = c.get_finished(finished_req_ids) + if not recving and not sending: + continue + # Aggregate finished recving request ids. + finished_recving.update(recving or ()) + # Aggregate finished sending request ids - only include + # once we've drained the "extra" count (for cases where + # more than one connector is async-saving the same request). + for req_id in sending or (): + extra_pending = self._extra_async_saves.get(req_id) + if extra_pending is None: + finished_sending.add(req_id) + continue + assert extra_pending > 0 + if extra_pending == 1: + del self._extra_async_saves[req_id] + else: + self._extra_async_saves[req_id] = extra_pending - 1 + + return finished_sending or None, finished_recving or None + + # ============================== + # Scheduler-side methods + # ============================== + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + to_return = (0, False) + for i, c in enumerate(self._connectors): + toks, load_async = c.get_num_new_matched_tokens( + request, num_computed_tokens) + # The first connector that has new matched tokens will be assigned + # to this request. + if to_return[0] == 0 and toks > 0: + self._requests_to_connector[request.request_id] = i + to_return = (toks, load_async) + return to_return + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + chosen_connector = self._requests_to_connector.get( + request.request_id, -1) + empty_blocks = blocks.new_empty() + for i, c in enumerate(self._connectors): + if i == chosen_connector: + # Forward call to the chosen connector (if any). + c.update_state_after_alloc(request, blocks, + num_external_tokens) + else: + # Call with empty blocks for other connectors. 
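+                # Passing new_empty() blocks and 0 external tokens means the
+                # non-chosen connectors still observe the allocation event but
+                # are not given any external tokens to load for this request.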
+ c.update_state_after_alloc(request, empty_blocks, 0) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata: + metadata = MultiKVConnectorMetadata(metadata=tuple( + c.build_connector_meta(scheduler_output) + for c in self._connectors)) + if self._extra_async_saves: + metadata.extra_async_saves = self._extra_async_saves + self._extra_async_saves = {} + return metadata + + def request_finished( + self, + request: "Request", + blocks: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + async_saves = 0 + kv_txfer_params = None + for c in self._connectors: + async_save, txfer_params = c.request_finished(request, blocks) + if async_save: + async_saves += 1 + if txfer_params is not None: + if kv_txfer_params is not None: + #TODO we can probably change this to merge the dicts here, + # checking for key clashes. + raise RuntimeError( + "Only one connector can produce KV transfer params") + kv_txfer_params = txfer_params + if async_saves > 1: + self._extra_async_saves[request.request_id] = async_saves - 1 + + # Clean up other state for this request. + self._requests_to_connector.pop(request.request_id, None) + + return async_saves > 0, kv_txfer_params diff --git a/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/distributed/kv_transfer/kv_connector/v1/nixl_connector.py new file mode 100644 index 0000000..7552fc8 --- /dev/null +++ b/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -0,0 +1,1030 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib +import math +import threading +import time +import uuid +from collections import defaultdict +from collections.abc import Iterator +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional + +import msgspec +import torch +import zmq + +from vllm import envs +from vllm.attention.selector import backend_name_to_enum, get_attn_backend +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, + get_tp_group) +from vllm.logger import init_logger +from vllm.platforms import _Backend +from vllm.utils import make_zmq_path, make_zmq_socket, round_down +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.request import RequestStatus + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.request import Request + +Transfer = tuple[int, float] # (xfer_handle, start_time) +GET_META_MSG = b"get_meta_msg" + +logger = init_logger(__name__) + +# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used +try: + from nixl._api import nixl_agent as NixlWrapper + logger.info("NIXL is available") +except ImportError: + logger.warning("NIXL is not available") + NixlWrapper = None + + +class NixlAgentMetadata( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. 
+ dict=True): + engine_id: str + agent_metadata: bytes + kv_caches_base_addr: list[int] + num_blocks: int + tp_size: int + block_len: int + attn_backend_name: str + + +@dataclass +class ReqMeta: + local_block_ids: list[int] + remote_block_ids: list[int] + remote_host: str + remote_port: int + remote_engine_id: str + + +class NixlConnectorMetadata(KVConnectorMetadata): + + def __init__(self): + self.requests: dict[str, ReqMeta] = {} + + def add_new_req( + self, + request_id: str, + local_block_ids: list[int], + kv_transfer_params: dict[str, Any], + ): + self.requests[request_id] = ReqMeta( + local_block_ids=local_block_ids, + remote_block_ids=kv_transfer_params["remote_block_ids"], + remote_engine_id=kv_transfer_params["remote_engine_id"], + remote_host=kv_transfer_params["remote_host"], + remote_port=kv_transfer_params["remote_port"], + ) + + +class NixlConnector(KVConnectorBase_V1): + + def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + assert vllm_config.kv_transfer_config is not None + self.engine_id = vllm_config.kv_transfer_config.engine_id + + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler : Optional[NixlConnectorScheduler] = \ + NixlConnectorScheduler(vllm_config, str(self.engine_id)) + self.connector_worker: Optional[NixlConnectorWorker] = None + elif role == KVConnectorRole.WORKER: + self.connector_scheduler = None + self.connector_worker = NixlConnectorWorker( + vllm_config, str(self.engine_id)) + + ############################################################ + # Scheduler Side Methods + ############################################################ + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + assert self.connector_scheduler is not None + return self.connector_scheduler.update_state_after_alloc( + request, blocks, num_external_tokens) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta(scheduler_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished(request, block_ids) + + ############################################################ + # Worker Side Methods + ############################################################ + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + assert self.connector_worker is not None + self.connector_worker.register_kv_caches(kv_caches) + + def get_finished(self, + finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + """Get the finished recving and sending requests.""" + assert self.connector_worker is not None + return self.connector_worker.get_finished() + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + assert self.connector_worker is not None + assert isinstance(self._connector_metadata, NixlConnectorMetadata) + self.connector_worker.start_load_kv(self._connector_metadata) + + def wait_for_layer_load(self, layer_name: str) -> None: + """NixlConnector does not do layerwise saving.""" + pass + + def save_kv_layer(self, 
layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """NixlConnector does not save explicitly.""" + pass + + def wait_for_save(self): + """NixlConnector does not save explicitly.""" + pass + + +class NixlConnectorScheduler: + """Implementation of Scheduler side methods""" + + def __init__(self, vllm_config: VllmConfig, engine_id: str): + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + self.engine_id = engine_id + self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST + self.side_channel_port = ( + envs.VLLM_NIXL_SIDE_CHANNEL_PORT + + vllm_config.parallel_config.data_parallel_rank_local * + vllm_config.parallel_config.tensor_parallel_size) + logger.info("Initializing NIXL Scheduler %s", engine_id) + + # Requests that need to start recv. + # New requests are added by update_state_after_alloc in + # the scheduler. Used to make metadata passed to Worker. + self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {} + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + """ + For remote prefill, pull all prompt blocks from remote + asynchronously relative to engine execution. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + Returns: + * the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + * true if the external KV cache tokens will be loaded + asynchronously (between scheduler steps). + """ + + params = request.kv_transfer_params + logger.debug( + "NIXLConnector get_num_new_matched_tokens: " + "num_computed_tokens=%s, kv_transfer_params=%s", + num_computed_tokens, params) + + if params is not None and params.get("do_remote_prefill"): + # Remote prefill: get all prompt blocks from remote. + assert num_computed_tokens % self.block_size == 0 + rounded_num_prompt_tokens = round_down( + len(request.prompt_token_ids), self.block_size) + count = max(rounded_num_prompt_tokens - num_computed_tokens, 0) + if count > 0: + return count, True + + # No remote prefill for this request. + return 0, False + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + + params = request.kv_transfer_params + logger.debug( + "NIXLConnector update_state_after_alloc: " + "num_external_tokens=%s, kv_transfer_params=%s", + num_external_tokens, params) + + if params is not None and params.get("do_remote_prefill"): + if params.get("remote_block_ids"): + if all(p in params for p in ("remote_engine_id", "remote_host", + "remote_port")): + # If remote_blocks and num_external_tokens = 0, we have + # a full prefix cache hit on the D worker. We need to call + # send_notif in _read_blocks to free the memory on the P. + local_block_ids = (blocks.get_unhashed_block_ids() + if num_external_tokens > 0 else []) + # Get unhashed blocks to pull from remote. + self._reqs_need_recv[request.request_id] = ( + request, local_block_ids) + else: + logger.warning( + "Got invalid KVTransferParams: %s. This " + "request will not utilize KVTransfer", params) + else: + assert num_external_tokens == 0 + # Only trigger 1 KV transfer per request. + params["do_remote_prefill"] = False + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + meta = NixlConnectorMetadata() + + # Loop through scheduled reqs and convert to ReqMeta. 
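+        # Each entry pairs the newly allocated local block ids on this
+        # (decode) instance with the remote prefill worker's block ids,
+        # host and port carried in the request's kv_transfer_params.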
+ for req_id, (req, block_ids) in self._reqs_need_recv.items(): + assert req.kv_transfer_params is not None + meta.add_new_req( + request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params, + ) + + # Clear the list once workers start the transfers + self._reqs_need_recv.clear() + + return meta + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Once a request is finished, determine whether request blocks + should be freed now or will be sent asynchronously and freed later. + """ + + params = request.kv_transfer_params + logger.debug( + "NIXLConnector request_finished, request_status=%s, " + "kv_transfer_params=%s", request.status, params) + + if (params is None or not params.get("do_remote_decode") + or request.status != RequestStatus.FINISHED_LENGTH_CAPPED): + return False, None + + # Get computed blocks. + all_full = request.num_computed_tokens % self.block_size == 0 + computed_block_ids = block_ids if all_full else block_ids[:-1] + + # If prompt < block_size, no xfer so free blocks immediately. + delay_free_blocks = len(computed_block_ids) > 0 + + return delay_free_blocks, dict( + do_remote_prefill=True, + do_remote_decode=False, + remote_block_ids=computed_block_ids, + remote_engine_id=self.engine_id, + remote_host=self.side_channel_host, + remote_port=self.side_channel_port, + ) + + +class NixlConnectorWorker: + """Implementation of Worker side methods""" + + def __init__(self, vllm_config: VllmConfig, engine_id: str): + if NixlWrapper is None: + logger.error("NIXL is not available") + raise RuntimeError("NIXL is not available") + logger.info("Initializing NIXL wrapper") + logger.info("Initializing NIXL worker %s", engine_id) + + # Config. + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + + # Agent. + self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None) + # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. + self._remote_agents: dict[str, dict[int, str]] = defaultdict(dict) + + # NIXL handshake port. + # NOTE(rob): Within a DP group, each DP rank gets its own + # base port (which is sent in the KVTransferParams). + # Each TP rank listens/queries on the base_port + tp_rank. + self.side_channel_port = ( + envs.VLLM_NIXL_SIDE_CHANNEL_PORT + + vllm_config.parallel_config.data_parallel_rank_local * + vllm_config.parallel_config.tensor_parallel_size) + + # Metadata. + self.engine_id = engine_id + self.tp_rank = get_tensor_model_parallel_rank() + self.world_size = get_tensor_model_parallel_world_size() + self.tp_group = get_tp_group() + + # KV Caches and nixl tracking data. + self.kv_caches: dict[str, torch.Tensor] = {} + + # Map of engine_id -> kv_caches_base_addr. For TP case, each local + # rank will still only pull from a single remote TP worker. + self.kv_caches_base_addr: dict[str, list[int]] = {} + + # Number of NIXL regions. Currently one region per cache + # (so 1 per layer for MLA, otherwise 2 per layer) + self.num_regions = 0 + self.num_layers = 0 + + # nixl_prepped_dlist_handle. + self.src_xfer_side_handle: int = 0 + # Map of engine_id -> nixl_prepped_dlist_handle (int)]. + self.dst_xfer_side_handles: dict[str, int] = {} + + # Map of engine_id -> num_blocks. All ranks in the same deployment will + # have the same number of blocks. + self.dst_num_blocks: dict[str, int] = {} + self._registered_descs: list[Any] = [] + + # In progress transfers. 
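+        # Each tracked item is a Transfer tuple (xfer_handle, start_time),
+        # appended in _read_blocks() and drained in _pop_done_transfers().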
+ # [req_id -> list[handle]] + self._recving_transfers = defaultdict[str, list[Transfer]](list) + + # Complete transfer tracker. Used by the rank 0 to track finished + # transactions on ranks 1 to N-1. + # [req_id -> count] + self._done_recving_count: defaultdict[str, + int] = defaultdict(lambda: 0) + self._done_sending_count: defaultdict[str, + int] = defaultdict(lambda: 0) + + # Background thread for establishing new connections. + self._nixl_handshake_listener_t: Optional[threading.Thread] = None + + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + + # TODO(mgoin): remove this once we have hybrid memory allocator + # Optimization for models with local attention (Llama 4) + # List of block window sizes for each layer for local attention + self.block_window_per_layer: list[Optional[int]] = [] + self.use_mla = self.model_config.use_mla + + backend = get_attn_backend(self.model_config.get_head_size(), + self.model_config.dtype, + self.cache_config.cache_dtype, + self.block_size, + self.model_config.is_attention_free, + use_mla=self.use_mla) + self.backend_name = backend.get_name() + attn_backend = backend_name_to_enum(self.backend_name) + self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1 + logger.debug("Detected attention backend %s", self.backend_name) + + self._tp_size: dict[str, int] = {self.engine_id: self.world_size} + # With heterogeneous TP, P must wait for all assigned D TP workers to + # finish reading before safely freeing the blocks. + self.consumer_notification_counts_by_req = defaultdict[str, int](int) + + @staticmethod + def _nixl_handshake_listener(metadata: NixlAgentMetadata, + ready_event: threading.Event, base_port: int, + tp_rank: int): + """Background thread for getting new NIXL handshakes.""" + # NOTE(rob): this is a simple implementation. We will move + # to a better approach via HTTP endpoint soon. + + encoder = msgspec.msgpack.Encoder() + encoded_data = encoder.encode(metadata) + size_in_bytes = len(encoded_data) + logger.debug("Size of encoded NixlAgentMetadata: %s bytes", + str(size_in_bytes)) + + # Listen for new requests for metadata. + host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST + path = make_zmq_path("tcp", host, base_port + tp_rank) + logger.debug("Starting listening on path: %s", path) + with zmq_ctx(zmq.ROUTER, path) as sock: + ready_event.set() + while True: + identity, _, msg = sock.recv_multipart() + if msg != GET_META_MSG: + logger.warning( + "Connection listener got unexpected message %s", msg) + sock.send_multipart((identity, b"", encoded_data)) + + def _nixl_handshake(self, host: str, port: int): + """Do a NIXL handshake with a remote instance.""" + + start_time = time.perf_counter() + + # NOTE(rob): we need each rank to have a unique port. This is + # a hack to keep us moving. We will switch when moving to etcd + # or where we have a single ZMQ socket in the scheduler. + + def handshake(path: str, rank: int) -> NixlAgentMetadata: + # Send query for the request. + with zmq_ctx(zmq.REQ, path) as sock: + sock.send(GET_META_MSG) + metadata_bytes = sock.recv() + decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) + metadata = decoder.decode(metadata_bytes) + got_metadata_time = time.perf_counter() + + # Register Remote agent. 
+            self.add_remote_agent(metadata, rank)
+            setup_agent_time = time.perf_counter()
+
+            logger.debug("NIXL handshake: get metadata took: %s",
+                         got_metadata_time - start_time)
+            logger.debug("NIXL handshake: add agent took: %s",
+                         setup_agent_time - got_metadata_time)
+            return metadata
+
+        # Handshake with remote agent-rank0 first to get the tp_size of the
+        # remote deployment.
+        path = make_zmq_path("tcp", host, port)
+        logger.debug("Querying master rank metadata on path: %s", path)
+        metadata = handshake(path, 0)
+
+        # Handshake only with the remote TP rank that the current local rank
+        # will pull from. With homogeneous TP it happens to be the same rank_i.
+        tp_ratio = self._tp_size[self.engine_id] // metadata.tp_size
+        p_remote_rank = self.tp_rank // tp_ratio
+        if p_remote_rank > 0:
+            path = make_zmq_path("tcp", host, port + p_remote_rank)
+            logger.debug("Querying metadata on path: %s at remote rank %s",
+                         path, p_remote_rank)
+            _ = handshake(path, p_remote_rank)
+
+    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
+        """Register the KV Cache data in nixl."""
+
+        _, first_kv_cache = next(iter(kv_caches.items()))
+        kv_elem_size = first_kv_cache.element_size()
+
+        # TODO(tms): Find a more robust way to detect and handle MLA
+        # NOTE (NickLucche) To move blocks efficiently with NIXL, the expected
+        # KV memory layout is HND, as opposed to the default NHD. Note that it
+        # only affects the strides. For MLA, by contrast, we require no such
+        # layout change and resort to the standard layout.
+        use_mla = len(first_kv_cache.shape) == 3
+        assert use_mla == self.use_mla
+
+        # TODO (NickLucche) not compatible with hybrid allocator. Enforce check
+        # once it goes live, as a single kv layout is expected for xfers.
+        if use_mla:
+            # MLA case.
+            self.num_blocks = first_kv_cache.shape[0]
+            block_rank = 2  # [block_size, latent_dim]
+            block_shape = first_kv_cache.shape[-block_rank:]
+            block_size, kv_latent_dim = block_shape
+            self.slot_size_bytes = kv_elem_size * kv_latent_dim
+        else:
+            # [2 (k and v), num_blocks, ...]
+            if self._use_flashinfer:
+                # FlashInfer swaps 2<->num_blocks dimensions.
+                self.num_blocks = first_kv_cache.shape[0]
+                block_rank = 4  # [2, block_size, kv_heads, head_dim]
+            else:
+                self.num_blocks = first_kv_cache.shape[1]
+                block_rank = 3  # [block_size, kv_heads, head_dim]
+            block_shape = first_kv_cache.shape[-block_rank:]
+            block_size, n_kv_heads, head_dim = block_shape[-3:]
+            # head size in bytes.
+            self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
+        assert block_size == self.block_size
+        # TODO(tms): self.block_len needs to be per-layer for sliding window,
+        # hybrid attn, etc
+        # block size in bytes
+        self.block_len = kv_elem_size * math.prod(block_shape)
+        logger.info(
+            "Registering KV_Caches: use_mla: %s, num_blocks: %s, "
+            "block_shape: %s, per_layer_kv_cache_shape: %s", use_mla,
+            self.num_blocks, block_shape, first_kv_cache.shape)
+        self.dst_num_blocks[self.engine_id] = self.num_blocks
+        self.kv_caches = kv_caches
+        kv_caches_base_addr = []
+        caches_data = []
+
+        # Note(tms): I modified this from the original region setup code.
+        # K and V are now in different regions. Advantage is that we can
+        # elegantly support MLA and any cases where the K and V tensors
+        # are non-contiguous (it's not locally guaranteed that they will be).
+        # Disadvantage is that the encoded NixlAgentMetadata is now larger
+        # (roughly 8KB vs 5KB).
+        # Conversely for FlashInfer, K and V are transferred in the same tensor
+        # to better exploit the memory layout (i.e. num_blocks is the first dim).
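+        # Illustrative example (hypothetical model size): a 32-layer non-MLA,
+        # non-FlashInfer model therefore registers 64 regions (a separate K
+        # and V region per layer), each spanning num_blocks * block_len bytes
+        # starting at that cache tensor's data_ptr().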
+ for cache_or_caches in kv_caches.values(): + # Normalize to always be a list of caches + cache_list = [cache_or_caches] if use_mla or self._use_flashinfer \ + else cache_or_caches + for cache in cache_list: + base_addr = cache.data_ptr() + region_len = self.num_blocks * self.block_len + caches_data.append( + (base_addr, region_len, cache.device.index, "")) + kv_caches_base_addr.append(base_addr) + self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr + self.num_regions = len(caches_data) + self.num_layers = len(self.kv_caches.keys()) + + # TODO(mgoin): remove this once we have hybrid memory allocator + # Optimization for models with local attention (Llama 4) + if self.vllm_config.model_config.hf_config.model_type == "llama4": + from transformers import Llama4TextConfig + assert isinstance(self.vllm_config.model_config.hf_text_config, + Llama4TextConfig) + llama4_config = self.vllm_config.model_config.hf_text_config + no_rope_layers = llama4_config.no_rope_layers + chunk_size = llama4_config.attention_chunk_size + chunk_block_size = math.ceil(chunk_size / self.block_size) + for layer_idx in range(self.num_layers): + # no_rope_layers[layer_idx] == 0 means NoPE (global) + # Any other value means RoPE (local chunked) + is_local_attention = no_rope_layers[layer_idx] != 0 + block_window = chunk_block_size if is_local_attention else None + self.block_window_per_layer.append(block_window) + logger.debug("Llama 4 block window per layer mapping: %s", + self.block_window_per_layer) + assert len(self.block_window_per_layer) == self.num_layers + + descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM") + logger.debug("Registering descs: %s", caches_data) + self.nixl_wrapper.register_memory(descs) + logger.debug("Done registering descs") + self._registered_descs.append(descs) + + # Register local/src descr for NIXL xfer. + blocks_data = [] + for base_addr in self.kv_caches_base_addr[self.engine_id]: + # NOTE With heter-TP, more blocks are prepared than what are + # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We + # could create fewer, but then _get_block_descs_ids needs to + # select agent_meta.num_blocks instead of self.num_blocks for + # local descr, and that makes handling regular flow less clean. + for block_id in range(self.num_blocks): + block_offset = block_id * self.block_len + addr = base_addr + block_offset + # (addr, len, device id) + blocks_data.append((addr, self.block_len, self.tp_rank)) + logger.debug("Created %s blocks for src engine %s and rank %s", + len(blocks_data), self.engine_id, self.tp_rank) + + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") + # NIXL_INIT_AGENT to be used for preparations of local descs. + self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( + "NIXL_INIT_AGENT", descs) + + # After KV Caches registered, listen for new connections. 
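+        # The listener thread serves this metadata over ZMQ on
+        # side_channel_port + tp_rank; remote workers fetch it during
+        # _nixl_handshake() before issuing any block reads.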
+ metadata = NixlAgentMetadata( + engine_id=self.engine_id, + agent_metadata=self.nixl_wrapper.get_agent_metadata(), + kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], + num_blocks=self.num_blocks, + tp_size=self.world_size, + block_len=self.block_len, + attn_backend_name=self.backend_name) + ready_event = threading.Event() + self._nixl_handshake_listener_t = threading.Thread( + target=self._nixl_handshake_listener, + args=(metadata, ready_event, self.side_channel_port, self.tp_rank), + daemon=True, + name="nixl_handshake_listener") + self._nixl_handshake_listener_t.start() + ready_event.wait() + + def add_remote_agent(self, + nixl_agent_meta: NixlAgentMetadata, + remote_tp_rank: int = 0): + """ + Add the remote NIXL agent and prepare the descriptors for reading cache + blocks from remote. + + In particular, handle both homogeneous and heterogeneous TP. The former + requires local rank_i to read from remote rank_i. + The latter, assuming D.world_size > P.world_size, requires that two or + more local TP worker share the xfer from a single TP worker. + + Here's an example: + + rank_offset p_remote_tp_rank + (kv split no) + -------------------------------- + 0 0 Worker0 ---- 1st half of KV ----> Worker0 [ KV Cache ] + / + 1 0 Worker1 ---- 2nd half of KV -----/ + + 0 1 Worker2 ---- 1st half of KV ----> Worker1 [ KV Cache ] + / + 1 1 Worker3 ---- 2nd half of KV -----/ + + + Decoder TP workers Prefix TP workers + (world_size=4) (world_size=2) + tp_ratio = 4 // 2 = 2 + + Considering the KV Caches, if P-Worker_i has cache size [2, num_blocksP, kv_heads, block_size, head_dim] + then D-Worker_j has [2, num_blocksD, kv_heads//tp_ratio, block_size, head_dim]. Mind the "HND" layout format. + Assuming num_blocksD >= num_blocksP, D-Worker0 reads from P-Worker0 by preparing the kv_heads//tp_ratio + first heads from all the slots of all the blocks. D-Worker1 will do the same, but reading the second split + along the kv_heads dimension, and so forth until "tp_ratio" D TP workers have pulled from P-Worker0. + + Note that the above will also hold true for the homogeneous TP case, where tp_ratio evaluates to 1. + + Regarding MLA case, the cache is replicated across TP workers so the rank_offset will just always be 0 + so that the whole cache is shared by "tp_ratio" D TP workers. + """ # noqa: E501 + engine_id = nixl_agent_meta.engine_id + # TODO re-evaluate refreshing for scaling/recovery + if remote_tp_rank in self._remote_agents.get(engine_id, ()): + return + + if engine_id in self._tp_size: + assert self._tp_size[engine_id] == nixl_agent_meta.tp_size + else: + self._tp_size[engine_id] = nixl_agent_meta.tp_size + # We may eventually enable this after asserting equality in cache + # layout and close outputs. + assert nixl_agent_meta.attn_backend_name == self.backend_name + + self._remote_agents[engine_id][ + remote_tp_rank] = self.nixl_wrapper.add_remote_agent( + nixl_agent_meta.agent_metadata) + + # Number of D TP workers reading from a single P TP worker. This is + # 1 when P and D `--tensor-parallel-size` match. + assert self._tp_size[self.engine_id] % self._tp_size[engine_id] == 0, ( + "Local TP size must be divisible by remote TP size.") + tp_ratio = self._tp_size[self.engine_id] // self._tp_size[engine_id] + assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP" + if self.use_mla: + # With MLA the only difference is in the number of blocks. 
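+            # (The latent KV cache is replicated across TP ranks, so the
+            # local and remote block_len are expected to match exactly, as
+            # asserted just below.)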
+            remote_block_size = nixl_agent_meta.block_len // (
+                self.slot_size_bytes)
+            assert self.block_len == nixl_agent_meta.block_len
+        else:
+            remote_block_size = nixl_agent_meta.block_len // (
+                self.slot_size_bytes * tp_ratio)
+            if self._use_flashinfer:
+                # Account for joint KV in FlashInfer.
+                remote_block_size //= 2
+
+            assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
+                "Remote P worker KV layer cache must be of shape [2, N, "
+                "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
+            )
+
+        assert self.block_size == remote_block_size, "Remote P worker with " \
+            "different block size is not supported"
+
+        assert self.num_blocks >= nixl_agent_meta.num_blocks
+
+        # Create dst descs and xfer side handles. TP workers have same #blocks.
+        if engine_id in self.dst_num_blocks:
+            assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks
+        else:
+            self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
+
+        blocks_data = []
+        # With homogeneous TP, D pulls the whole kv cache from the
+        # corresponding rank. With heterogeneous TP, prepare the descriptors
+        # by splitting the P KV cache along the kv_head dim, in chunks of the
+        # D worker's kv_head size (D>P).
+        # E.g. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
+        p_remote_tp_rank = self.tp_rank // tp_ratio
+        # Only register the remote's descriptors if current rank pulls from it.
+        if p_remote_tp_rank == remote_tp_rank:
+            self.kv_caches_base_addr[
+                engine_id] = nixl_agent_meta.kv_caches_base_addr
+            rank_offset = self.tp_rank % tp_ratio * self.block_len \
+                if not self.use_mla else 0
+            # Register all remote blocks, but only the corresponding kv heads.
+            for base_addr in nixl_agent_meta.kv_caches_base_addr:
+                for block_id in range(nixl_agent_meta.num_blocks):
+                    block_offset = block_id * nixl_agent_meta.block_len
+                    # For each block, grab the heads chunk belonging to rank_i
+                    # of size remote_nheads // tp_ratio, which corresponds to
+                    # self.block_len == remote_block_len//tp_ratio bytes.
+                    addr = base_addr + block_offset + rank_offset
+                    # (addr, len, device id)
+                    blocks_data.append((addr, self.block_len, remote_tp_rank))
+            logger.debug(
+                "Created %s blocks for dst engine %s with remote rank %s and "
+                "local rank %s", len(blocks_data), engine_id, remote_tp_rank,
+                self.tp_rank)
+
+            # Register with NIXL.
+            descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
+            self.dst_xfer_side_handles[
+                engine_id] = self.nixl_wrapper.prep_xfer_dlist(
+                    self._remote_agents[engine_id][remote_tp_rank], descs)
+
+    def get_finished(self) -> tuple[set[str], set[str]]:
+        """
+        Get requests that are done sending or recving.
+
+        In a TP>1 setup, each rank exchanges KVs with its counterpart
+        ranks independently. get_finished() runs in each worker and builds
+        the done_sending and done_recving sets, which Rank 0 sends to the
+        scheduler via ModelRunnerOutput. To ensure transactions are complete
+        before being reported as finished, Ranks 1 to N-1 notify Rank 0 once
+        their transaction is done, and Rank 0 returns the finished sets to
+        the Scheduler only once all ranks are done.
+        """
+        done_sending = self._get_new_notifs()
+        done_recving = self._pop_done_transfers(self._recving_transfers)
+        if len(done_sending) > 0 or len(done_recving) > 0:
+            logger.debug(
+                "Rank %s, get_finished: %s requests done sending "
+                "and %s requests done recving", self.tp_rank,
+                len(done_sending), len(done_recving))
+
+        if self.world_size == 1:
+            return done_sending, done_recving
+
+        # Rank 0: get finished from all other ranks.
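+        # A request id is surfaced to the scheduler only once every one of
+        # the world_size ranks has reported it, so partially finished
+        # transfers are never freed early.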
+ if self.tp_rank == 0: + for req_id in done_sending: + self._done_sending_count[req_id] += 1 + for req_id in done_recving: + self._done_recving_count[req_id] += 1 + + # Keep track of how many other ranks have finished. + other_ranks_finished_ids: list[str] = [] + for i in range(1, self.world_size): + other_ranks_finished_ids.extend( + self.tp_group.recv_object(src=i)) + for req_id in other_ranks_finished_ids: + if (req_id in self._done_recving_count + or req_id in self._recving_transfers): + self._done_recving_count[req_id] += 1 + else: + self._done_sending_count[req_id] += 1 + + # Return ids that finished on all ranks to the scheduler. + all_done_recving: set[str] = set() + for req_id in list(self._done_recving_count.keys()): + if self._done_recving_count[req_id] == self.world_size: + del self._done_recving_count[req_id] + all_done_recving.add(req_id) + + all_done_sending: set[str] = set() + for req_id in list(self._done_sending_count.keys()): + if self._done_sending_count[req_id] == self.world_size: + del self._done_sending_count[req_id] + all_done_sending.add(req_id) + + return all_done_sending, all_done_recving + + # Ranks 1 to N-1: send finished ids to Rank 0. + else: + finished_req_ids = list(done_recving.union(done_sending)) + self.tp_group.send_object(finished_req_ids, dst=0) + + # Unused as only Rank 0 results are sent to scheduler. + return done_sending, done_recving + + def _get_new_notifs(self) -> set[str]: + """ + Get req_ids which got a remote xfer message. When multiple consumers + are reading from the same producer (heterogeneous TP scenario), wait + for all consumers to be done pulling. + """ + notified_req_ids: set[str] = set() + for notifs in self.nixl_wrapper.get_new_notifs().values(): + for notif in notifs: + req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1) + self.consumer_notification_counts_by_req[req_id] += 1 + # Wait all consumers (D) to be done reading before freeing. + if self.consumer_notification_counts_by_req[req_id] == int( + tp_ratio): + notified_req_ids.add(req_id) + del self.consumer_notification_counts_by_req[req_id] + return notified_req_ids + + def _pop_done_transfers( + self, transfers: dict[str, list[tuple[int, float]]]) -> set[str]: + """ + Pop completed xfers by checking for DONE state. + Args: + transfers: dict of req_id -> list[running_xfer] + Returns: + set of req_ids that have all done xfers + """ + done_req_ids: set[str] = set() + for req_id, handles in list(transfers.items()): + for handle, xfer_stime in handles: + xfer_state = self.nixl_wrapper.check_xfer_state(handle) + if xfer_state == "DONE": + self.nixl_wrapper.release_xfer_handle(handle) + done_req_ids.add(req_id) + del transfers[req_id] + elif xfer_state == "PROC": + continue + else: + raise RuntimeError("Transfer failed with state %s", + xfer_state) + return done_req_ids + + def start_load_kv(self, metadata: NixlConnectorMetadata): + """ + Start loading by triggering non-blocking nixl_xfer. + We check for these trnxs to complete in each step(). + """ + for req_id, meta in metadata.requests.items(): + logger.debug( + "start_load_kv for request %s from remote engine %s. " + "Num local_block_ids: %s. Num remote_block_ids: %s. 
", req_id, + meta.remote_engine_id, len(meta.local_block_ids), + len(meta.remote_block_ids)) + self._read_blocks( + request_id=req_id, + dst_engine_id=meta.remote_engine_id, + local_block_ids=meta.local_block_ids, + remote_block_ids=meta.remote_block_ids, + remote_host=meta.remote_host, + remote_port=meta.remote_port, + ) + + def _read_blocks( + self, + local_block_ids: list[int], + remote_block_ids: list[int], + remote_host: str, + remote_port: int, + dst_engine_id: str, + request_id: str, + ): + # NOTE(rob): this takes ~2s. We need to get this off the hotpath. + if dst_engine_id not in self._remote_agents: + self._nixl_handshake(remote_host, remote_port) + + # NOTE(rob): having the staging blocks be on the READER side is + # not going to work well (since we will have to call rearrange tensors). + # after we detect the txn is complete (which means we cannot make the + # read trxn async easily). If we want to make "READ" happen cleanly, + # then we will need to have the staging blocks on the remote side. + + # NOTE(rob): according to nvidia the staging blocks are used to + # saturate IB with heterogeneous TP sizes. We should remove the staging + # blocks until we are ready. + + # Number of D TP workers that will read from dst P. Propagate tp_ratio + # on notification so that dst worker can wait before freeing blocks. + tp_ratio = self._tp_size[ + self.engine_id] // self._tp_size[dst_engine_id] + notif_id = f"{request_id}:{tp_ratio}".encode() + + # Full prefix cache hit: do not need to read remote blocks, + # just notify P worker that we have the blocks we need. + num_local_blocks = len(local_block_ids) + if num_local_blocks == 0: + remote_rank = self.tp_rank // tp_ratio + agent_name = self._remote_agents[dst_engine_id][remote_rank] + self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id) + return + + # Partial prefix cache hit: just read uncomputed blocks. + num_remote_blocks = len(remote_block_ids) + assert num_local_blocks <= num_remote_blocks + if num_local_blocks < num_remote_blocks: + remote_block_ids = remote_block_ids[-num_local_blocks:] + + # Get side handles. + local_xfer_side_handle = self.src_xfer_side_handle + remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id] + + # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from + # corresponding rank. With heterogeneous TP, fixing D>P, the D tp + # workers will issue xfers to parts of the P worker remote kv caches. + + # Get descs ids. + local_block_descs_ids: list[int] = [] + remote_block_descs_ids: list[int] = [] + if not self.block_window_per_layer: + # Default case: assume global attention + remote_block_descs_ids = self._get_block_descs_ids( + dst_engine_id, remote_block_ids) + local_block_descs_ids = self._get_block_descs_ids( + self.engine_id, local_block_ids) + else: + # TODO(mgoin): remove this once we have hybrid memory allocator + # Optimization for models with local attention (Llama 4) + for layer_idx, block_window in enumerate( + self.block_window_per_layer): + # For each layer: + if block_window is None: + # If not chunked, we just use the + # full block lists (global attention) + layer_local_block_ids = local_block_ids + layer_remote_block_ids = remote_block_ids + else: + # If chunked, get the last block_window blocks + layer_local_block_ids = local_block_ids[-block_window:] + layer_remote_block_ids = remote_block_ids[-block_window:] + + # Get descs ids for the layer. 
+ layer_local_desc_ids = self._get_block_descs_ids( + self.engine_id, layer_local_block_ids, layer_idx) + layer_remote_desc_ids = self._get_block_descs_ids( + dst_engine_id, layer_remote_block_ids, layer_idx) + + local_block_descs_ids.extend(layer_local_desc_ids) + remote_block_descs_ids.extend(layer_remote_desc_ids) + + assert len(local_block_descs_ids) == len(remote_block_descs_ids) + + # Prepare transfer with Nixl. + handle = self.nixl_wrapper.make_prepped_xfer( + "READ", + local_xfer_side_handle, + local_block_descs_ids, + remote_xfer_side_handle, + remote_block_descs_ids, + notif_msg=notif_id, + ) + + # Begin async xfer. + self.nixl_wrapper.transfer(handle) + + # Use handle to check completion in future step(). + # TODO (NickLucche) surface xfer elapsed time + self._recving_transfers[request_id].append( + (handle, time.perf_counter())) + + def _get_block_descs_ids(self, + engine_id: str, + block_ids: list[int], + layer_idx: Optional[int] = None) -> list[int]: + """ + Get the descs ids for a set of block ids. + If layer_idx is provided, we use the region_ids for the given layer. + Otherwise, we use all regions. + """ + if layer_idx is None: + region_ids = range(self.num_regions) + else: + assert layer_idx < self.num_layers + if self.num_layers < self.num_regions: + # If we have more regions than layers, we assume that + # the regions are organized as [K0, V0, K1, V1, ...] + # and we select K_i and V_i + assert 2 * self.num_layers == self.num_regions + region_ids = range(2 * layer_idx, 2 * layer_idx + 2) + else: + # Otherwise, we assume we have MLA and select i-th layer + assert self.num_layers == self.num_regions + region_ids = range(layer_idx, layer_idx + 1) + + num_blocks = self.dst_num_blocks[engine_id] + + # Compute the desc ids for each block. 
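+        # Descriptor ids are region-major: desc_id = reg_id * num_blocks +
+        # block_id, mirroring the order in which blocks_data was built when
+        # the regions were registered (e.g. with num_blocks=1000, block 5 of
+        # region 2 maps to desc id 2005 -- illustrative numbers only).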
+ descs_ids: list[int] = [] + for reg_id in region_ids: + for block_id in block_ids: + descs_ids.append(reg_id * num_blocks + block_id) + return descs_ids + + +@contextlib.contextmanager +def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: + """Context manager for a ZMQ socket""" + + if socket_type not in (zmq.ROUTER, zmq.REQ): + raise ValueError(f"Unexpected socket type: {socket_type}") + + ctx: Optional[zmq.Context] = None + try: + ctx = zmq.Context() # type: ignore[attr-defined] + yield make_zmq_socket(ctx=ctx, + path=addr, + socket_type=socket_type, + bind=socket_type == zmq.ROUTER) + finally: + if ctx is not None: + ctx.destroy(linger=0) diff --git a/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py new file mode 100644 index 0000000..f86b926 --- /dev/null +++ b/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -0,0 +1,384 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import hashlib +import os +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import safetensors +import torch + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.logger import init_logger +from vllm.v1.attention.backends.mla.common import MLACommonMetadata +from vllm.v1.core.sched.output import SchedulerOutput + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +@dataclass +class ReqMeta: + # Request tokens + token_ids: torch.Tensor + # Slot mappings, should have the same length as token_ids + slot_mapping: torch.Tensor + # Is store or load + is_store: bool + + @staticmethod + def make_meta(token_ids: list[int], block_ids: list[int], block_size: int, + is_store: bool) -> "ReqMeta": + valid_num_tokens = align_to_block_size(len(token_ids), block_size) + token_ids_tensor = torch.tensor(token_ids)[:valid_num_tokens] + block_ids_tensor = torch.tensor(block_ids) + num_blocks = block_ids_tensor.shape[0] + block_offsets = torch.arange(0, block_size) + slot_mapping = block_offsets.reshape((1, block_size)) + \ + block_ids_tensor.reshape((num_blocks, 1)) * block_size + slot_mapping = slot_mapping.flatten()[:valid_num_tokens] + return ReqMeta( + token_ids=token_ids_tensor, + slot_mapping=slot_mapping, + is_store=is_store, + ) + + +@dataclass +class SharedStorageConnectorMetadata(KVConnectorMetadata): + requests: list[ReqMeta] + + def __init__(self): + self.requests = [] + + def add_request( + self, + token_ids: list[int], + block_ids: list[int], + block_size: int, + is_store: bool, + ) -> None: + self.requests.append( + ReqMeta.make_meta(token_ids, block_ids, block_size, is_store)) + + +class SharedStorageConnector(KVConnectorBase_V1): + # NOTE: This is Simple debug implementation of the KV connector. + # It save / load the KV cache to / from the disk. 
+ # It does extra work which will overwrite the existing prefix-cache in GPU + # - to remove the overhead, need to add some "mask" in the ReqMeta class + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config=vllm_config, role=role) + self._block_size = vllm_config.cache_config.block_size + self._requests_need_load: dict[str, Request] = {} + transfer_config = vllm_config.kv_transfer_config + self._storage_path = transfer_config.get_from_extra_config( + "shared_storage_path", "/tmp") + logger.info(vllm_config.kv_transfer_config) + logger.info("Shared storage path is %s", self._storage_path) + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + """Start loading the KV cache from the connector buffer to vLLM's + paged KV buffer. + + Args: + forward_context (ForwardContext): the forward context. + **kwargs: additional arguments for the load operation + + Note: + The number of elements in kv_caches and layer_names should be + the same. + """ + attn_metadata = forward_context.attn_metadata + + def inject_kv_into_layer( + dst_kv_cache_layer: torch.Tensor, + src_kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + ) -> None: + """Inject the KV cache into the layer. + + Args: + dst_kv_cache_layer (torch.Tensor): the destination KV cache + layer. In shape [2, num_pages, page_size, xxx] if not + using MLA, [num_pages, page_size, xxx] otherwise. + src_kv_cache (torch.Tensor): the source KV cache. In shape + [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx] + otherwise. + slot_mapping (torch.Tensor): the slot mapping. In shape + [num_tokens]. + """ + dst_kv_cache_layer_shape = dst_kv_cache_layer.shape + if isinstance(attn_metadata, MLACommonMetadata): + num_pages = dst_kv_cache_layer_shape[0] + page_size = dst_kv_cache_layer_shape[1] + dst_kv_cache_layer = dst_kv_cache_layer.reshape( + num_pages * page_size, -1) + dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache + dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) + else: + num_pages = dst_kv_cache_layer_shape[1] + page_size = dst_kv_cache_layer_shape[2] + dst_kv_cache_layer = dst_kv_cache_layer.reshape( + 2, num_pages * page_size, -1) + dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache + dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) + + # Get the metadata + metadata: KVConnectorMetadata = self._get_connector_metadata() + assert isinstance(metadata, SharedStorageConnectorMetadata) + + if metadata is None: + logger.warning( + "In connector.start_load_kv, but the connector metadata is None" + ) + return + + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + logger.warning( + "In connector.start_load_kv, but the attn_metadata is None") + return + + # Load the KV for each request each layer + for request in metadata.requests: + if request.is_store: + continue + logger.info("Inject KV cache of %d tokens to the paged memory", + len(request.slot_mapping)) + for layer_name in forward_context.no_compile_layers: + attn_layer = forward_context.no_compile_layers[layer_name] + kv_cache_layer = attn_layer.kv_cache[\ + forward_context.virtual_engine] + + filename = self._generate_filename_debug( + layer_name, request.token_ids) + kv_cache = safetensors.torch.load_file( + filename)["kv_cache"].cuda() + inject_kv_into_layer(kv_cache_layer, kv_cache, + request.slot_mapping) + + def wait_for_layer_load(self, layer_name: str) -> None: + """Blocking until the KV for a specific layer is loaded into vLLM's + paged buffer. 
+ + This interface will be useful for layer-by-layer pipelining. + + Args: + layer_name: the name of that layer + """ + return + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """Start saving the KV cache of the layer from vLLM's paged buffer + to the connector. + + Args: + layer_name (str): the name of the layer. + kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. + """ + + def extract_kv_from_layer( + layer: torch.Tensor, + slot_mapping: torch.Tensor, + ) -> torch.Tensor: + """Extract the KV cache from the layer. + + Assume the shape of the layer is (2, num_pages, page_size, xxx) + if MLA is not used, and (num_pages, page_size, xxx) otherwise. + """ + if isinstance(attn_metadata, MLACommonMetadata): + num_pages, page_size = layer.shape[0], layer.shape[1] + return layer.reshape(num_pages * page_size, -1)[slot_mapping, + ...] + num_pages, page_size = layer.shape[1], layer.shape[2] + return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, + ...] + + connector_metadata = self._get_connector_metadata() + assert isinstance(connector_metadata, SharedStorageConnectorMetadata) + for request in connector_metadata.requests: + if request.is_store: + filename = self._generate_filename_debug( + layer_name, request.token_ids) + kv_cache = extract_kv_from_layer(kv_layer, + request.slot_mapping) + tensors = {"kv_cache": kv_cache.detach().cpu()} + safetensors.torch.save_file(tensors, filename) + + def wait_for_save(self): + return + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + """ + Get number of new tokens that can be loaded from the + external KV cache beyond the num_computed_tokens. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + """ + # NOTE: in this debug implementation, we assume that the prompt is + # cached_prompt + newly_generated_single_token + # Therefore, we use prompt_token_ids[:-1] to determine the folder name + + # NOTE: in current v1 scheduler, the num_computed_tokens is aligned + # with the block granularity. And it expects the returned blocks and + # num_computed_tokens to also be aligned with the block granularity. + if not self._found_match_for_request(request): + return 0, False + + logger.info("External Cache Hit!") + + # Now, first num_tokens_to_check tokens are hit, we need to prepare + # the metadata for the worker connector to correctly load the KV + num_tokens_to_check = align_to_block_size( + len(request.prompt_token_ids) - 1, self._block_size) + + return num_tokens_to_check - num_computed_tokens, False + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + """ + Update KVConnector state after block allocation. + + If blocks were allocated, add to _requests_need_load, + such that we load the KVs in the next forward pass. + """ + if num_external_tokens > 0: + self._requests_need_load[request.request_id] = request + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + """Build the connector metadata for this step. 
+ + This function should NOT modify any fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. + """ + meta = SharedStorageConnectorMetadata() + + total_need_load = 0 + for new_req in scheduler_output.scheduled_new_reqs: + if new_req.req_id in self._requests_need_load: + meta.add_request(token_ids=new_req.prompt_token_ids, + block_ids=new_req.block_ids[0], + block_size=self._block_size, + is_store=False) + total_need_load += 1 + else: + # NOTE: here, we set the store and load being exclusive, + # but a single request can have both store and load. + # NOTE(rob): for this debug implementation, we only cache + # the original prompt tokens. + if not self._found_match_for_request(new_req): + meta.add_request(token_ids=new_req.prompt_token_ids, + block_ids=new_req.block_ids[0], + block_size=self._block_size, + is_store=True) + + for cached_req in scheduler_output.scheduled_cached_reqs: + # NOTE(rob): here we rely on the resumed requests being + # the first N requests in the list scheduled_cache_reqs. + if not cached_req.resumed_from_preemption: + break + if cached_req.req_id in self._requests_need_load: + # NOTE(rob): cached_req_data does not have the full + # list of token ids (only new tokens). So we look it + # up in the actual request object. + request = self._requests_need_load[cached_req.req_id] + total_tokens = (len(cached_req.new_token_ids) + + cached_req.num_computed_tokens) + token_ids = request.all_token_ids[:total_tokens] + + # NOTE(rob): For resumed req, new_block_ids is all + # of the block_ids for the request. + block_ids = cached_req.new_block_ids[0] + + meta.add_request(token_ids=token_ids, + block_ids=block_ids, + block_size=self._block_size, + is_store=False) + total_need_load += 1 + + assert total_need_load == len(self._requests_need_load) + self._requests_need_load.clear() + return meta + + # ============================== + # Helper functions + # ============================== + + def _found_match_for_request( + self, + request: "Request", + ) -> bool: + """Check if the cache is hit for the request. + """ + num_tokens_to_check = align_to_block_size( + len(request.prompt_token_ids) - 1, self._block_size) + foldername = self._generate_foldername_debug(torch.tensor( + request.prompt_token_ids)[:num_tokens_to_check], + create_folder=False) + return os.path.exists(foldername) + + def _generate_foldername_debug( + self, + input_ids: torch.Tensor, + create_folder=False, + ) -> str: + """Generate a folder name based on the hash of the bytes of the input + ids. + """ + input_ids_bytes = input_ids.numpy().tobytes() + input_ids_hash = hashlib.md5(input_ids_bytes, + usedforsecurity=False).hexdigest() + foldername = os.path.join(self._storage_path, input_ids_hash) + if create_folder: + os.makedirs(foldername, exist_ok=True) + return foldername + + def _generate_filename_debug( + self, + layer_name: str, + input_ids: torch.Tensor, + ) -> str: + """Generate a file name based on the layer name and the hash + of the bytes of the input ids. + """ + foldername = self._generate_foldername_debug(input_ids, + create_folder=True) + return os.path.join(foldername, f"{layer_name}.safetensors") + + +def align_to_block_size(num_tokens: int, block_size) -> int: + """Align the number of tokens to the block size. 
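+
+    For example, with block_size=16: 33 tokens align to 32 and 17 tokens
+    align to 16. Note that an exact multiple (e.g. 16) maps to the previous
+    boundary (0), since the computation uses num_tokens - 1.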
+ """ + return (num_tokens - 1) // block_size * block_size diff --git a/distributed/kv_transfer/kv_connector_agent.py b/distributed/kv_transfer/kv_connector_agent.py new file mode 100644 index 0000000..8633fda --- /dev/null +++ b/distributed/kv_transfer/kv_connector_agent.py @@ -0,0 +1,77 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""A centralized entrypoint to perform distributed KV cache transfer. + +This implementation is a shim wrapper on two APIs exposed by `kv_connector`: +1. `send_kv_caches_and_hidden_states` +2. `recv_kv_caches_and_hidden_states +""" +from typing import TYPE_CHECKING, Union + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + from vllm.config import VllmConfig + +import torch + +from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory) +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +logger = init_logger(__name__) + + +class KVTransferAgent: + """ + A class designated for distributed KV transfer + + Target use cases: + 1. Disaggregated prefill + 2. Remote KV cache storage + """ + + def __init__( + self, + rank: int, + local_rank: int, + config: "VllmConfig", + ): + + self.config = config + + if config.kv_transfer_config is None: + raise ValueError("KVTransferConfig is not set in the VllmConfig," + " cannot initialize KVConnector.") + + assert self.config.kv_transfer_config.is_kv_transfer_instance, "KV"\ + "TransferAgent should only be used when kv_connector is set." + + self.connector = KVConnectorFactory.create_connector_v0( + rank, local_rank, config) + + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: list[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + + self.connector.send_kv_caches_and_hidden_states( + model_executable, model_input, kv_caches, + hidden_or_intermediate_states) + + def close(self) -> None: + self.connector.close() + + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: list[torch.Tensor] + ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + + return self.connector.recv_kv_caches_and_hidden_states( + model_executable, model_input, kv_caches) diff --git a/distributed/kv_transfer/kv_lookup_buffer/__init__.py b/distributed/kv_transfer/kv_lookup_buffer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/distributed/kv_transfer/kv_lookup_buffer/base.py b/distributed/kv_transfer/kv_lookup_buffer/base.py new file mode 100644 index 0000000..eef1426 --- /dev/null +++ b/distributed/kv_transfer/kv_lookup_buffer/base.py @@ -0,0 +1,175 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +This file contains a new class `KVLookupBufferBase` that allows developers to +think of KV cache operations as inserting new KV cache entries (`insert`) +into the lookup buffer and querying existing KV caches (`drop_select`) +from the lookup buffer. + +This file also contains a new class `KVStoreBufferBase` that allows developers +to manage the KVCache buffer as a simple key-value storage buffer with basic +put/get operations. + +These classes above are abstracted behind class `KVCacheBufferBase`. 
+""" + +from abc import ABC, abstractmethod +from typing import Optional + +import torch + + +class KVCacheBufferBase(ABC): + """ + Abstract base class for a KVCache buffer. + """ + + @abstractmethod + def close(self) -> None: + """Close the buffer and release resources. + + This method is responsible for cleaning up resources related to the + KVCache buffer when it is no longer needed. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + +class KVLookupBufferBase(KVCacheBufferBase): + """ + Abstract base class for a KVCache lookup buffer. + + This class provides an abstraction for a key-value (KV) cache lookup buffer. + + The key of the lookup buffer: + - input_tokens: token IDs of the request + - roi: a binary mask on top of input_tokens. + - Purpose of roi: Since KV cache may only be available for a subset of + tokens in the input (for example, when vLLM is connected to an external + KV cache service), roi specifies the subset of tokens that the KV cache + is associated with. + - NOTE: roi can be further extended to describe which part of KV the + current process is holding (each process may only hold a part of KV + due to TP and PP). This is not implemented for now. + + The value of the lookup buffer: + - key: the key tensor in the KV cache + - value: the value tensor in the KV cache + - hidden: the final hidden state generated by model forwarding. This allows + vLLM to bypass further model forwarding by transmitting the hidden state. + """ + + @abstractmethod + def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor) -> None: + """Insert into the lookup buffer. + + The functionality is similar to the following python statement + ``` + buffer[input_tokens, roi] = [key, value, hidden] + ``` + + FIXME: in the future, we should only have two arguments, key and value, + where key is a tensor dict and value is a tensor dict. + + FIXME: we should transmit both sampler outputs and the hidden states. + + Args: + input_tokens (torch.Tensor): token IDs. + roi (torch.Tensor): A binary mask on top of the input tokens + key (torch.Tensor): The key tensor in the KV cache. + value (torch.Tensor): The value tensor in the KV cache. + hidden (torch.Tensor): The final hidden state tensor generated + during model forwarding to bypass model + forwarding. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def drop_select( + self, input_tokens: Optional[torch.Tensor], + roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]: + """Select and *drop* KV cache entries from the lookup buffer. + + The functionality is similar to the following python statements + ``` + ret = buffer.pop(input_tokens, roi) + return ret + ``` + + If `input_tokens` and `roi` is `None`, it means selecting any of the + KV caches in the buffer, return, and remove it from the buffer, useful + when offloading KV cache to KV cache storage service. + + Args: + input_tokens (torch.Tensor): token IDs. + roi (torch.Tensor): A binary mask on top of the input tokens + + Returns: + list[Optional[torch.Tensor]]: A list of tensors. Can be None. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + +class KVStoreBufferBase(KVCacheBufferBase): + """ + Abstract base class for a KVCache storage buffer with key-value semantics. 
+    This class provides a simple key-value storage buffer abstraction with
+    basic put/get operations, enabling flexible, fine-grained control over
+    KVCache transfers.
+
+    The functionality is similar to a distributed key-value store, where:
+    - Key: A unique string identifier for the cached entry
+    - Value:
+        - Tensor to be stored and retrieved
+        - None (indicating deletion or empty value)
+    """
+
+    @abstractmethod
+    def put(
+        self,
+        key: str,
+        value: Optional[torch.Tensor],
+    ) -> None:
+        """Store a key-value pair in the buffer.
+
+        Args:
+            key (str): Unique identifier for a tensor; the tensor may be the
+                key cache tensor, the value cache tensor, or the hidden state
+                tensor generated during model forwarding.
+
+            value (Optional[torch.Tensor]): Tensor to be stored.
+
+        Raises:
+            NotImplementedError: This method must be implemented in subclasses.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def get(
+        self,
+        key: str,
+    ) -> Optional[torch.Tensor]:
+        """Retrieve a value from the buffer by key.
+
+        Args:
+            key (str): Unique identifier for a tensor; the tensor may be the
+                key cache tensor, the value cache tensor, or the hidden state
+                tensor generated during model forwarding.
+
+        Returns:
+            Optional[torch.Tensor]: The stored tensor if it exists, None
+                otherwise.
+
+        Raises:
+            NotImplementedError: This method must be implemented in subclasses.
+        """
+        raise NotImplementedError
diff --git a/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py b/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py
new file mode 100644
index 0000000..4381aad
--- /dev/null
+++ b/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py
@@ -0,0 +1,161 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+This file contains a new class `MooncakeStore` that allows developers to
+think of KV cache transfer operations as putting new KV cache entries
+into a remote KVStore-based lookup buffer and getting existing KV caches
+from this remote lookup buffer.
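+
+Illustrative usage (a sketch; the key naming scheme shown here is
+hypothetical and chosen by the caller):
+
+    store = MooncakeStore(vllm_config)
+    store.put("layer0.kv/<prompt-hash>", kv_tensor)  # serialize and upload
+    kv = store.get("layer0.kv/<prompt-hash>")        # None if key is absent
+    store.close()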
+""" +import json +import os +from dataclasses import dataclass +from typing import Optional + +import torch +from safetensors.torch import load as safetensors_load +from safetensors.torch import save as safetensors_save + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_lookup_buffer.base import ( + KVStoreBufferBase) +from vllm.logger import init_logger + +DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB +DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB + +logger = init_logger(__name__) + + +@dataclass +class MooncakeStoreConfig: + local_hostname: str + metadata_server: str + global_segment_size: int + local_buffer_size: int + protocol: str + device_name: str + master_server_address: str + + @staticmethod + def from_file(file_path: str) -> 'MooncakeStoreConfig': + """Load the config from a JSON file.""" + with open(file_path) as fin: + config = json.load(fin) + return MooncakeStoreConfig( + local_hostname=config.get("local_hostname"), + metadata_server=config.get("metadata_server"), + global_segment_size=config.get("global_segment_size", + DEFAULT_GLOBAL_SEGMENT_SIZE), + local_buffer_size=config.get("local_buffer_size", + DEFAULT_LOCAL_BUFFER_SIZE), + protocol=config.get("protocol", "tcp"), + device_name=config.get("device_name", ""), + master_server_address=config.get("master_server_address"), + ) + + @staticmethod + def load_from_env() -> 'MooncakeStoreConfig': + """Load config from a file specified in the environment variable.""" + config_file_path = os.getenv('MOONCAKE_CONFIG_PATH') + if config_file_path is None: + raise ValueError( + "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.") + return MooncakeStoreConfig.from_file(config_file_path) + + +class MooncakeStore(KVStoreBufferBase): + + def __init__( + self, + config: VllmConfig, + ): + + try: + from mooncake.store import MooncakeDistributedStore + except ImportError as e: + raise ImportError( + "Please install mooncake by following the instructions at " + "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 + "to run vLLM with MooncakeConnector.") from e + + try: + self.store = MooncakeDistributedStore() + self.config = MooncakeStoreConfig.load_from_env() + logger.info("Mooncake Configuration loaded successfully.") + + self.store.setup(self.config.local_hostname, + self.config.metadata_server, + self.config.global_segment_size, + self.config.local_buffer_size, + self.config.protocol, self.config.device_name, + self.config.master_server_address) + + except ValueError as e: + logger.error("Configuration loading failed: %s", e) + raise + except Exception as exc: + logger.error( + "An error occurred while loading the configuration: %s", exc) + raise + + def close(self): + # MooncakeDistributedStore will automatically call the destructor, so + # it is unnecessary to close it manually. + pass + + def put( + self, + key: str, + value: Optional[torch.Tensor], + ) -> None: + # A message queue needs to be introduced before making it asynchronous. + if value is not None: + self._put_impl(key, value) + + def get( + self, + key: str, + ) -> Optional[torch.Tensor]: + # A message queue needs to be introduced before making it asynchronous. 
+        value = self._get_impl(key)
+        return value
+
+    def _put_impl(
+        self,
+        key: str,
+        value: torch.Tensor,
+    ) -> None:
+        """Put KVCache to Mooncake Store"""
+        device_id = value.device.index if value.device.type == 'cuda' else -1
+        device_tensor = torch.tensor(device_id, dtype=torch.int32)
+        value_bytes = safetensors_save({
+            "tensor": value,
+            "device_id": device_tensor
+        })
+        try:
+            self.store.put(key, value_bytes)
+        except TypeError as err:
+            logger.error("Failed to put value into Mooncake Store: %s", err)
+            raise TypeError("Mooncake Store Put Type Error.") from err
+
+    def _get_impl(
+        self,
+        key: str,
+    ) -> Optional[torch.Tensor]:
+        """Get KVCache from Mooncake Store"""
+        try:
+            data = self.store.get(key)
+        except TypeError as err:
+            logger.error("Failed to get value from Mooncake Store: %s", err)
+            raise TypeError("Mooncake Store Get Type Error.") from err
+
+        if data:
+            loaded_tensors = safetensors_load(data)
+            tensor = loaded_tensors["tensor"]
+            device_id_tensor = loaded_tensors["device_id"]
+            device_id = int(device_id_tensor.item())
+            device = torch.device(
+                'cuda', device_id) if device_id >= 0 else torch.device('cpu')
+            return tensor.to(device)
+
+        return None
diff --git a/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py
new file mode 100644
index 0000000..a0ff7c3
--- /dev/null
+++ b/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py
@@ -0,0 +1,237 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+    Implements a distributed key-value (KV) cache transfer mechanism.
+
+    Key Features:
+    - Distributed KV cache transmission using PyNccl pipes.
+    - Non-blocking `insert`, blocking `drop_select`.
+    - Uses a CPU signal pipe to avoid race conditions.
+    - Handles buffer size constraints and provides a backpressure mechanism
+      to stop the prefill instance when the decode instance is slow.
+"""
+import threading
+from collections import deque
+from typing import Optional, Union
+
+import torch
+
+from vllm.distributed.kv_transfer.kv_lookup_buffer.base import (
+    KVLookupBufferBase)
+from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class SimpleBuffer(KVLookupBufferBase):
+
+    def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase,
+                 buffer_size_thresh: float):
+        """
+        signal_pipe: on CPU
+
+        NOTE: on-device recv will block all threads in the process, making the
+        KV cache producer unable to listen for new requests while transmitting
+        KV cache. Luckily, CPU recv only blocks the current thread, so we use
+        CPU recv to listen for new requests.
+
+        data_pipe: on device (e.g.
GPU) + """ + + self.buffer: deque[list[torch.Tensor]] = deque() + + self.buffer_size = 0 + self.buffer_size_threshold = buffer_size_thresh + self.buffer_cv = threading.Condition() + self.signal_pipe = signal_pipe + self.data_pipe = data_pipe + self.request_handling_thread: Optional[threading.Thread] = None + + self.normal_signal = torch.tensor([0], device="cpu") + self.end_signal = None + + def _matches(self, tokens_roi_sender: list[torch.Tensor], + tokens_roi_recver: list[torch.Tensor]): + + # tokens_roi_sender: tokens and roi of the producer (in the buffer) + # tokens_roi_recver: tokens and roi of the consumer (query) + + tokens_sender = tokens_roi_sender[0] + tokens_recver = tokens_roi_recver[0] + roi_sender = tokens_roi_sender[1] + roi_recver = tokens_roi_recver[1] + + if tokens_recver is None: + # consumer sends an empty request + # semantics: DROP SELECT * LIMIT 1 + # so any of the data in the buffer can be drop-selected + return True + + # Assuming that roi is a binary mask on tokens + tokens_sender = tokens_sender[roi_sender] + tokens_recver = tokens_recver[roi_recver] + + # simple common prefix matching + min_length = min(len(tokens_sender), len(tokens_recver)) + if torch.allclose(tokens_sender[:min_length], + tokens_recver[:min_length]): + return min_length + + return 0 + + def _send_tensor_and_dec_size(self, + tensor: Optional[torch.Tensor]) -> None: + + assert tensor is not None, "Use self.data_pipe.send(None) instead" + self.buffer_size -= tensor.element_size() * tensor.numel() + if tensor.dtype == torch.bool: + tensor = tensor.float() + self.data_pipe.send_tensor(tensor) + + def _get_element_size(self, data: Optional[Union[list, torch.Tensor]]): + + if isinstance(data, torch.Tensor): + return data.element_size() * data.numel() + if not data: + # cannot perform `not data` on a tensor + # so this check needs to go after the check above + return 0 + + raise AssertionError(f"Unknown data type {type(data)}") + + def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor): + + if isinstance(input_tokens, torch.Tensor): + input_tokens = input_tokens.clone() + if isinstance(roi, torch.Tensor): + roi = roi.clone() + if isinstance(key, torch.Tensor): + key = key.clone() + if isinstance(value, torch.Tensor): + value = value.clone() + if isinstance(hidden, torch.Tensor): + hidden = hidden.clone() + + buffer_item = [input_tokens, roi, key, value, hidden] + data_size = sum([self._get_element_size(data) for data in buffer_item]) + + with self.buffer_cv: + if self.buffer_size + data_size > self.buffer_size_threshold: + # log outside the while loop to avoid this message being logged + # repeatedly. + logger.debug("KV transfer buffer is full. 
Handling...") + while self.buffer_size + data_size > self.buffer_size_threshold: + self.buffer_cv.wait() + + self.buffer_size += data_size + self.buffer.append(buffer_item) + self.buffer_cv.notify() + + def _is_end_signal(self, signal): + return signal is None + + def drop_select_handler(self): + + try: + + while True: + signal = self.signal_pipe.recv_tensor() + if self._is_end_signal(signal): + logger.info("Received end signal!") + break + + input_tokens = self.data_pipe.recv_tensor() + + roi = self.data_pipe.recv_tensor() + assert roi is not None, "Please provide the roi when sending "\ + "drop-select request" + roi = (roi > 0.5) + tokens_roi_recver = [input_tokens, roi] + + def is_buffer_available( + tokens_roi_recver: list[torch.Tensor], ) -> bool: + # perform input tokens and roi matching + # FIXME: this matching is O(n), ideally it should be O(1) + # but this buffer size won't (and shouldn't) be too large so + # the fix is not urgent. + for _ in range(len(self.buffer)): + if self._matches(self.buffer[0], + tokens_roi_recver) > 0: + return True + # rotate the element we just accessed to the end + self.buffer.rotate(-1) + return False + + with self.buffer_cv: + while not is_buffer_available(tokens_roi_recver): + logger.debug( + "KV transfer buffer is not available. Waiting...") + self.buffer_cv.wait() + # need to clone the tensor + # in case the tensor is freed before sending finishes + matched_item = self.buffer.popleft() + for tensor in matched_item: + self._send_tensor_and_dec_size(tensor) + self.buffer_cv.notify() + + except RuntimeError as e: + if 'Connection closed by peer' not in str(e): + raise e + + logger.debug("Closing drop_select_handler") + + def drop_select( + self, input_tokens: Optional[torch.Tensor], + roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]: + + assert self.request_handling_thread is None, \ + "drop_select should be called by the KV cache consumer "\ + "(e.g. the decode vLLM instance)" + + if isinstance(input_tokens, torch.Tensor): + input_tokens = input_tokens.clone() + if isinstance(roi, torch.Tensor): + roi = roi.clone().float() + + self.signal_pipe.send_tensor(self.normal_signal) + self.data_pipe.send_tensor(input_tokens) + self.data_pipe.send_tensor(roi) + + input_tokens = self.data_pipe.recv_tensor() + roi = self.data_pipe.recv_tensor() + if roi is not None: + # convert from float tensor to bool tensor + # as PyNccl does not support sending bool tensor + roi = (roi > 0.5) + key = self.data_pipe.recv_tensor() + value = self.data_pipe.recv_tensor() + hidden = self.data_pipe.recv_tensor() + + return [input_tokens, roi, key, value, hidden] + + def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor) -> None: + + self._add_to_buffer(input_tokens, roi, key, value, hidden) + + # when calling the insert, the current process is a sender + # need to launch the request handler and start listening to request. 
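+        # Illustrative flow (a sketch, not executable here): the prefill
+        # instance calls insert(tokens, roi, key, value, hidden); the decode
+        # instance later calls drop_select(tokens, roi) and receives the
+        # matching [tokens, roi, key, value, hidden] over the data pipe.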
+        if self.request_handling_thread is None:
+            self.request_handling_thread = threading.Thread(
+                target=self.drop_select_handler)
+            self.request_handling_thread.start()
+
+    def close(self):
+
+        if hasattr(self, "request_handling_thread"
+                   ) and self.request_handling_thread is not None:
+            self.request_handling_thread.join()
+
+        else:
+            # TODO: have an explicit close signal and an explicit way to
+            # check whether this instance is the requester
+            self.signal_pipe.send_tensor(self.end_signal)
diff --git a/distributed/kv_transfer/kv_pipe/__init__.py b/distributed/kv_transfer/kv_pipe/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/distributed/kv_transfer/kv_pipe/base.py b/distributed/kv_transfer/kv_pipe/base.py
new file mode 100644
index 0000000..1423fd0
--- /dev/null
+++ b/distributed/kv_transfer/kv_pipe/base.py
@@ -0,0 +1,67 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+This file defines an interface `KVPipeBase`
+that provides an abstraction for sending and receiving tensors, or None, via
+distributed communication.
+
+All classes implementing this interface are assumed to behave as FIFO pipes.
+
+If your distributed communication platform already supports key-value lookup,
+you can bypass this interface and directly start from `kv_lookup_buffer`.
+"""
+
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import torch
+
+
+class KVPipeBase(ABC):
+    """
+    This class provides an interface for sending and receiving tensors, or
+    None, via distributed communication.
+    """
+
+    @abstractmethod
+    def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
+        """Send a tensor, or None, via the pipe.
+
+        Sending None must be supported -- it is important for error handling.
+
+        TODO: add a `key` argument so that we can use a traditional
+        key-value database as the distributed communication mechanism behind
+        the pipe.
+
+        Args:
+            tensor (Optional[torch.Tensor]): The tensor to be sent. Can be None.
+
+        Raises:
+            NotImplementedError: This method must be implemented in subclasses.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def recv_tensor(self) -> Optional[torch.Tensor]:
+        """Receive a tensor (can be None) from the pipe.
+
+        Returns:
+            Optional[torch.Tensor]: The tensor received from the pipe. Can
+                be None.
+
+        Raises:
+            NotImplementedError: This method must be implemented in subclasses.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def close(self) -> None:
+        """Close the pipe and release resources.
+
+        This method is responsible for closing the communication pipe
+        and releasing any resources associated with it.
+
+        Raises:
+            NotImplementedError: This method must be implemented in subclasses.
+ """ + raise NotImplementedError diff --git a/distributed/kv_transfer/kv_pipe/mooncake_pipe.py b/distributed/kv_transfer/kv_pipe/mooncake_pipe.py new file mode 100644 index 0000000..9f3494b --- /dev/null +++ b/distributed/kv_transfer/kv_pipe/mooncake_pipe.py @@ -0,0 +1,280 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +import os +import struct +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from typing import Optional, Union + +import torch +import zmq +from safetensors.torch import load as safetensors_load +from safetensors.torch import save as safetensors_save + +from vllm.config import KVTransferConfig +from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase +from vllm.logger import init_logger + +logger = init_logger(__name__) +NONE_INT = -150886311 + + +@dataclass +class MooncakeTransferEngineConfig: + prefill_url: str + decode_url: str + metadata_backend: Union[str, None] + metadata_server: str + protocol: str + device_name: str + + @staticmethod + def from_file(file_path: str) -> 'MooncakeTransferEngineConfig': + """Load the config from a JSON file.""" + with open(file_path) as fin: + config = json.load(fin) + return MooncakeTransferEngineConfig( + prefill_url=config.get("prefill_url"), + decode_url=config.get("decode_url"), + metadata_backend=config.get("metadata_backend", None), + metadata_server=config.get("metadata_server"), + protocol=config.get("protocol", "tcp"), + device_name=config.get("device_name", ""), + ) + + @staticmethod + def load_from_env() -> 'MooncakeTransferEngineConfig': + """Load config from a file specified in the environment variable.""" + config_file_path = os.getenv('MOONCAKE_CONFIG_PATH') + if config_file_path is None: + raise ValueError( + "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.") + return MooncakeTransferEngineConfig.from_file(config_file_path) + + +class MooncakeTransferEngine: + """Handles the transfer of data using mooncake_vllm_adaptor and ZeroMQ.""" + + def __init__(self, kv_rank: int, local_rank: int): + try: + from mooncake.engine import TransferEngine + except ImportError as e: + raise ImportError( + "Please install mooncake by following the instructions at " + "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 + "to run vLLM with MooncakeConnector.") from e + + self.engine = TransferEngine() + self.local_rank = local_rank + + try: + self.config = MooncakeTransferEngineConfig.load_from_env() + logger.info("Mooncake Configuration loaded successfully.") + except ValueError as e: + logger.error(e) + raise + except Exception as exc: + logger.error( + "An error occurred while loading the configuration: %s", exc) + raise + prefill_host, base_prefill_port = self.config.prefill_url.split(':') + decode_host, base_decode_port = self.config.decode_url.split(':') + + # Avoid ports conflict when running prefill and decode on the same node + if prefill_host == decode_host and \ + base_prefill_port == base_decode_port: + base_decode_port = str(int(base_decode_port) + 100) + + prefill_port = int(base_prefill_port) + self.local_rank + decode_port = int(base_decode_port) + self.local_rank + self.prefill_url = ':'.join([prefill_host, str(prefill_port)]) + self.decode_url = ':'.join([decode_host, str(decode_port)]) + + self.initialize(self.prefill_url if kv_rank == 0 else self.decode_url, + self.config.metadata_server, self.config.protocol, + self.config.device_name, 
self.config.metadata_backend) + + self.remote_url = (self.decode_url + if kv_rank == 0 else self.prefill_url) + + # Initialize ZeroMQ context and sockets + self.context = zmq.Context() # type: ignore[attr-defined] + self.sender_socket = self.context.socket(zmq.constants.PUSH) + self.receiver_socket = self.context.socket(zmq.constants.PULL) + self.sender_ack = self.context.socket(zmq.constants.PULL) + self.receiver_ack = self.context.socket(zmq.constants.PUSH) + + self.buffer_cleaner = ThreadPoolExecutor(max_workers=1) + self._setup_metadata_sockets(kv_rank, prefill_host, base_prefill_port, + decode_host, base_decode_port) + + def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: str, + d_host: str, d_port: str) -> None: + """Set up ZeroMQ sockets for sending and receiving data.""" + # Offsets < 8 are left for initialization in case tp and pp are enabled + p_rank_offset = int(p_port) + 8 + self.local_rank * 2 + d_rank_offset = int(d_port) + 8 + self.local_rank * 2 + if kv_rank == 0: + self.sender_socket.bind(f"tcp://{p_host}:{p_rank_offset + 1}") + self.receiver_socket.connect(f"tcp://{d_host}:{d_rank_offset + 1}") + self.sender_ack.connect(f"tcp://{d_host}:{d_rank_offset + 2}") + self.receiver_ack.bind(f"tcp://{p_host}:{p_rank_offset + 2}") + else: + self.receiver_socket.connect(f"tcp://{p_host}:{p_rank_offset + 1}") + self.sender_socket.bind(f"tcp://{d_host}:{d_rank_offset + 1}") + self.receiver_ack.bind(f"tcp://{d_host}:{d_rank_offset + 2}") + self.sender_ack.connect(f"tcp://{p_host}:{p_rank_offset + 2}") + + def initialize(self, local_hostname: str, metadata_server: str, + protocol: str, device_name: str, + metadata_backend: Union[str, None]) -> None: + """Initialize the mooncake instance.""" + if metadata_backend is None: + self.engine.initialize(local_hostname, metadata_server, protocol, + device_name) + else: + supported_backend = ["etcd", "redis"] + metadata_backend = metadata_backend.lower() + if metadata_backend not in supported_backend: + raise ValueError( + "Mooncake Configuration error. 
`metadata_backend`" + f" should be one of {supported_backend}.") + + self.engine.initialize_ext(local_hostname, metadata_server, + protocol, device_name, metadata_backend) + + def allocate_managed_buffer(self, length: int) -> int: + """Allocate a managed buffer of the specified length.""" + ret = self.engine.allocate_managed_buffer(length) + if ret <= 0: + logger.error("Allocation Return Error") + raise Exception("Allocation Return Error") + return ret + + def free_managed_buffer(self, buffer: int, length: int) -> int: + """Free a previously allocated managed buffer.""" + return self.engine.free_managed_buffer(buffer, length) + + def transfer_sync(self, buffer: int, peer_buffer_address: int, + length: int) -> int: + """Synchronously transfer data to the specified address.""" + ret = self.engine.transfer_sync_read(self.remote_url, buffer, + peer_buffer_address, length) + if ret < 0: + logger.error("Transfer Return Error") + raise Exception("Transfer Return Error") + return ret + + def write_bytes_to_buffer(self, buffer: int, user_data: bytes, + length: int) -> int: + """Write bytes to the allocated buffer.""" + return self.engine.write_bytes_to_buffer(buffer, user_data, length) + + def read_bytes_from_buffer(self, buffer: int, length: int) -> bytes: + """Read bytes from the allocated buffer.""" + return self.engine.read_bytes_from_buffer(buffer, length) + + def wait_for_ack(self, src_ptr: int, length: int) -> None: + """Asynchronously wait for ACK from the receiver.""" + ack = self.sender_ack.recv() + if ack != b'ACK': + logger.error("Failed to receive ACK from the receiver") + + self.free_managed_buffer(src_ptr, length) + + def send_bytes(self, user_data: bytes) -> None: + """Send bytes to the remote process.""" + length = len(user_data) + src_ptr = self.allocate_managed_buffer(length) + self.write_bytes_to_buffer(src_ptr, user_data, length) + self.sender_socket.send_multipart( + [struct.pack("!Q", src_ptr), + struct.pack("!Q", length)]) + self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length) + + def recv_bytes(self) -> bytes: + """Receive bytes from the remote process.""" + data = self.receiver_socket.recv_multipart() + src_ptr = struct.unpack("!Q", data[0])[0] + length = struct.unpack("!Q", data[1])[0] + dst_ptr = self.allocate_managed_buffer(length) + self.transfer_sync(dst_ptr, src_ptr, length) + ret = self.read_bytes_from_buffer(dst_ptr, length) + + # Buffer cleanup + self.receiver_ack.send(b'ACK') + self.free_managed_buffer(dst_ptr, length) + + return ret + + +class MooncakePipe(KVPipeBase): + """MooncakeTransferEngine based Pipe implementation.""" + + def __init__(self, + local_rank: int, + config: KVTransferConfig, + device: Optional[str] = None): + """Initialize the mooncake pipe and set related parameters.""" + self.config = config + self.local_rank = local_rank + self.kv_rank = self.config.kv_rank + if device is None: + self.device = self._select_device(self.config.kv_buffer_device) + else: + self.device = self._select_device(device) + + self.transfer_engine = MooncakeTransferEngine(self.kv_rank, + self.local_rank) + self.transport_thread: Optional[ThreadPoolExecutor] = None + self.none_tensor = torch.tensor([NONE_INT], device=self.device) + + def _select_device(self, device: str) -> torch.device: + """Select available device (CUDA or CPU).""" + logger.info("Selecting device: %s", device) + if device == "cuda": + return torch.device(f"cuda:{self.local_rank}") + else: + return torch.device("cpu") + + def tensor_hash(self, tensor: torch.Tensor) -> int: + """Calculate the 
hash value of the tensor.""" + return hash(tensor.data_ptr()) + + def _send_impl(self, tensor: torch.Tensor) -> None: + """Implement the tensor sending logic using safetensors.""" + self.transfer_engine.send_bytes(safetensors_save({"tensor": tensor})) + + def _recv_impl(self) -> torch.Tensor: + """Implement the tensor receiving logic using safetensors.""" + data = self.transfer_engine.recv_bytes() + return safetensors_load(data)["tensor"].to(self.device) + + def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: + """Send tensor to the target process.""" + if self.transport_thread is None: + self.transport_thread = ThreadPoolExecutor(max_workers=1) + tensor = tensor if tensor is not None else self.none_tensor + assert (len(tensor.shape) > 0) + self.transport_thread.submit(self._send_impl, tensor) + + def recv_tensor(self) -> Optional[torch.Tensor]: + """Receive tensor from other processes.""" + if self.transport_thread is None: + self.transport_thread = ThreadPoolExecutor(max_workers=1) + tensor = self.transport_thread.submit(self._recv_impl).result() + if tensor.numel() == 1 and tensor.item() == NONE_INT: + return None + else: + return tensor + + def close(self) -> None: + """Cleanup logic when closing the pipe.""" + self.transfer_engine.sender_socket.close() + self.transfer_engine.receiver_socket.close() + self.transfer_engine.sender_ack.close() + self.transfer_engine.receiver_ack.close() + self.transfer_engine.context.term() # Terminate the ZMQ context + logger.info("Closed the transfer engine and cleaned up resources.") diff --git a/distributed/kv_transfer/kv_pipe/pynccl_pipe.py b/distributed/kv_transfer/kv_pipe/pynccl_pipe.py new file mode 100644 index 0000000..09de0b6 --- /dev/null +++ b/distributed/kv_transfer/kv_pipe/pynccl_pipe.py @@ -0,0 +1,280 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" + This module implements a PyNccl pipe for sending and receiving + Optional[torch.Tensor] between distributed ranks with advanced + communication features. 
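+
+    Wire format note: each send first transmits a small metadata dict
+    (``dtype``/``shape``) describing the tensor; if the tensor is not None,
+    the raw payload then follows (via PyNCCL on CUDA devices, or the CPU
+    group otherwise).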
+ + Key Features: + - Supports sending and receiving tensors with metadata + - Handles both CUDA and CPU device communications + - Implements a non-blocking tensor transfer mechanism + - Manages buffer size and provides backpressure control + - Supports distributed process groups with configurable parameters +""" + +import threading +import time +from concurrent.futures import ThreadPoolExecutor +from typing import Callable, Optional + +import torch + +from vllm.config import KVTransferConfig +from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator +from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase +from vllm.distributed.utils import StatelessProcessGroup +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class BrokenPipeException(Exception): + + def __init__(self, message): + self.message = message + super().__init__(self.message) + + +Metadata = dict[str, Optional[torch.Tensor]] + + +class PyNcclPipe(KVPipeBase): + + METADATA_LENGTH = 16 + MAX_TENSOR_DIMENSIONS = 14 + METADATA_DTYPE = torch.int64 + + def __init__(self, + local_rank: int, + config: KVTransferConfig, + device: Optional[str] = None, + port_offset: int = 0): + self.config = config + self.local_rank = local_rank + self.kv_rank = self.config.kv_rank + self.kv_parallel_size = self.config.kv_parallel_size + if device is None: + self.device = self._select_device(self.config.kv_buffer_device) + else: + self.device = self._select_device(device) + + # build distributed connection and send/recv implementation + store_timeout = self.config.get_from_extra_config("store_timeout", 300) + self.group = StatelessProcessGroup.create( + host=self.config.kv_ip, + port=self.config.kv_port + port_offset, + rank=self.kv_rank, + world_size=self.kv_parallel_size, + store_timeout=store_timeout, + ) + # add a barrier to make sure the connection is initiated properly + self.group.barrier() + impl = self._get_device_send_recv_impl(self.group) + self.device_send_func, self.device_recv_func = impl + # set target rank + self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size + self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size + + # transportation-related variables + self.transport_thread: Optional[ThreadPoolExecutor] = None + self.buffer_size = 0 + self.buffer_size_lock = threading.Lock() + self.buffer_size_thresh = self.config.kv_buffer_size + + def _get_device_send_recv_impl( + self, group: StatelessProcessGroup + ) -> tuple[Callable[[torch.Tensor, int], None], Callable[ + [torch.Tensor, int], None]]: + + send: Callable[[torch.Tensor, int], None] + recv: Callable[[torch.Tensor, int], None] + if self.device.type == "cuda": + # use PyNCCL for send / recv + comm = PyNcclCommunicator(group, device=self.local_rank) + comm.disabled = False + send, recv = comm.send, comm.recv # type: ignore + else: + # This send / recv implementation here is NOT intended to transfer + # KV caches (and should NOT be repurposed to transfer KV caches). + # Currently it is only used to transmit control-plane messages + # for PyNcclBuffer. + send = group.send_obj + + def my_recv(x, src): + x[...] 
= group.recv_obj(src) + + recv = my_recv + + return send, recv + + def _select_device(self, device: str): + logger.info("Selecting device: %s", device) + if device == "cuda": + return torch.device(f"cuda:{self.local_rank}") + else: + return torch.device("cpu") + + def _make_metadata(self, tensor: Optional[torch.Tensor]) -> Metadata: + """ + Create the metadata as a dictionary based on the input tensor. + + Args: + tensor: The input tensor or None if no tensor is provided. + + Returns: + metadata: A dictionary with the following keys: + - "dtype": The data type of the tensor or None. + - "shape": The shape of the tensor or None. + """ + if tensor is None: + return {"dtype": None, "shape": None} + else: + return {"dtype": tensor.dtype, "shape": tensor.shape} + + def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor: + """ + Create a buffer to receive the tensor based on the provided metadata. + + Args: + metadata: A dictionary with keys "dtype" and "shape", + describing the tensor's data type and shape. + + Returns: + buffer: A tensor of the specified type and shape, + allocated on `self.device`. + """ + return torch.empty(metadata["shape"], + dtype=metadata["dtype"], + device=self.device) + + def _send_metadata(self, metadata: Metadata): + """ + Send the metadata dictionary to the target rank. + + Args: + metadata: A dictionary with keys "dtype" and "shape". + """ + self.group.send_obj(metadata, self.target_rank_for_send) + + def _recv_metadata(self) -> Metadata: + """ + Receive the metadata dictionary from the target rank. + + Returns: + metadata: A dictionary with keys "dtype" and "shape" + describing the tensor. + """ + return self.group.recv_obj(self.target_rank_for_recv) + + def _send_impl(self, tensor: Optional[torch.Tensor]) -> None: + """ + The actual implementation of sending the tensor and its metadata to the + target rank. + + Args: + tensor: The input tensor to be sent, or `None` if no tensor is + being sent. + """ + metadata = self._make_metadata(tensor) + self._send_metadata(metadata) + if tensor is not None: + self.device_send_func(tensor.to(self.device), + self.target_rank_for_send) + + def _recv_impl(self) -> Optional[torch.Tensor]: + """ + The actual implementation of receiving a tensor and its metadata from + the target rank. + + Returns: + buffer: The received tensor, or `None` if no tensor is received. + """ + metadata = self._recv_metadata() + if metadata["dtype"] is None: + return None + buffer = self._prepare_recv_buffer(metadata) + self.device_recv_func(buffer, self.target_rank_for_recv) + + return buffer + + def send_tensor_wrapper(self, tensor: Optional[torch.Tensor], + tensor_size: int) -> None: + """ + Wrapper for _send_impl to handle exceptions and update buffer size. + """ + try: + self._send_impl(tensor) + + with self.buffer_size_lock: + self.buffer_size -= tensor_size + except Exception as e: + logger.error("[rank%d]: Exception when trying to send %s, msg: %s", + torch.distributed.get_rank(), str(tensor), str(e)) + import traceback + traceback.print_exc() + + def block_if_full(self): + """ + Block the current thread if the buffer size is larger than the + threshold. + """ + while self.buffer_size > self.buffer_size_thresh: + logger.debug("KV cache transfer pipe is full. Waiting...") + time.sleep(0.05) + + def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: + """ + Sends a tensor and its metadata to the destination rank in a + non-blocking way. + + Args: + tensor: The tensor to send, or `None` if no tensor is being sent. 
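+
+        The call returns immediately: the transfer itself is submitted to a
+        single-worker thread pool, and `block_if_full()` applies
+        backpressure once more than `kv_buffer_size` bytes are in flight.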
+        """
+        if self.transport_thread is None:
+            self.transport_thread = ThreadPoolExecutor(max_workers=1)
+
+        if tensor is not None:
+            tensor_size = tensor.element_size() * tensor.numel()
+        else:
+            tensor_size = 0
+
+        self.block_if_full()
+
+        with self.buffer_size_lock:
+            self.buffer_size += tensor_size
+
+        self.transport_thread.submit(self.send_tensor_wrapper, tensor,
+                                     tensor_size)
+
+    def recv_tensor(self) -> Optional[torch.Tensor]:
+        """
+        Receives a tensor and its metadata from the source rank. Blocking call.
+
+        Returns:
+            tensor: The received tensor, or `None` if no tensor is received.
+        """
+        if self.transport_thread is None:
+            self.transport_thread = ThreadPoolExecutor(max_workers=1)
+
+        future = self.transport_thread.submit(self._recv_impl)
+
+        try:
+            tensor = future.result()
+        except Exception as e:
+            logger.error("Encountered an exception in the KV receiving thread")
+            logger.error("%s", e)
+            logger.error("My device: %s", self.device)
+            import traceback
+            traceback.print_exc()
+            raise e
+
+        return tensor
+
+    def close(self):
+        """
+        Close the pipe and release associated resources.
+        """
+        if hasattr(self,
+                   "transport_thread") and self.transport_thread is not None:
+            self.transport_thread.shutdown()
diff --git a/distributed/kv_transfer/kv_transfer_state.py b/distributed/kv_transfer/kv_transfer_state.py
new file mode 100644
index 0000000..60f1d5d
--- /dev/null
+++ b/distributed/kv_transfer/kv_transfer_state.py
@@ -0,0 +1,71 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import TYPE_CHECKING, Optional
+
+from vllm import envs
+from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
+from vllm.distributed.kv_transfer.kv_connector.factory import (
+    KVConnectorFactory)
+from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
+                                                          KVConnectorRole)
+from vllm.distributed.parallel_state import get_world_group
+
+if TYPE_CHECKING:
+    from vllm.config import VllmConfig
+
+_KV_CONNECTOR_AGENT: Optional[KVConnectorBaseType] = None
+
+
+def get_kv_transfer_group() -> KVConnectorBaseType:
+    assert _KV_CONNECTOR_AGENT is not None, (
+        "disaggregated KV cache transfer parallel group is not initialized")
+    return _KV_CONNECTOR_AGENT
+
+
+def has_kv_transfer_group() -> bool:
+    return _KV_CONNECTOR_AGENT is not None
+
+
+def is_v1_kv_transfer_group(
+        connector: Optional[KVConnectorBaseType] = None) -> bool:
+    """Check if the KV connector is the v1 connector.
+    If the argument is None, it checks the global KV connector.
+
+    Args:
+        connector: The KV connector to check. If None, it will check the
+            global KV connector.
+
+    Note:
+        This function will no longer be needed after the v1 KV connector
+        becomes the default.
+    """
+    if connector is None:
+        connector = _KV_CONNECTOR_AGENT
+
+    if connector is None:
+        return False
+
+    return isinstance(connector, KVConnectorBase_V1)
+
+
+def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
+    """
+    Initialize the KV cache transfer parallel group.
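+
+    This is a no-op unless `vllm_config.kv_transfer_config` marks this
+    process as a KV transfer instance; the created connector is stored in
+    the module-level `_KV_CONNECTOR_AGENT` singleton.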
+ """ + + global _KV_CONNECTOR_AGENT + + if vllm_config.kv_transfer_config is None: + return + + if (vllm_config.kv_transfer_config.is_kv_transfer_instance + and _KV_CONNECTOR_AGENT is None): + if envs.VLLM_USE_V1: + _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v1( + config=vllm_config, role=KVConnectorRole.WORKER) + else: + _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v0( + rank=get_world_group().rank, + local_rank=get_world_group().local_rank, + config=vllm_config, + ) diff --git a/distributed/parallel_state.py b/distributed/parallel_state.py new file mode 100644 index 0000000..c596f88 --- /dev/null +++ b/distributed/parallel_state.py @@ -0,0 +1,1297 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2023 The vLLM team. +# Adapted from +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +"""vLLM distributed state. +It takes over the control of the distributed environment from PyTorch. +The typical workflow is: + +- call `init_distributed_environment` to initialize the distributed environment. +- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to + initialize the model parallel groups. + +- any code dealing with the distributed stuff + +- call `destroy_model_parallel` to destroy the model parallel groups. +- call `destroy_distributed_environment` to destroy the distributed environment. + +If you only need to use the distributed environment without model/pipeline + parallelism, you can skip the model parallel initialization and destruction + steps. +""" +import contextlib +import gc +import pickle +import weakref +from collections import namedtuple +from contextlib import contextmanager, nullcontext +from dataclasses import dataclass +from multiprocessing import shared_memory +from typing import Any, Callable, Optional, Union +from unittest.mock import patch + +import torch +import torch.distributed +from torch.distributed import Backend, ProcessGroup + +import vllm.envs as envs +from vllm.distributed.device_communicators.base_device_communicator import ( + DeviceCommunicatorBase) +from vllm.distributed.utils import StatelessProcessGroup +from vllm.logger import init_logger +from vllm.utils import (direct_register_custom_op, get_distributed_init_method, + resolve_obj_by_qualname, supports_custom_op) + + +@dataclass +class GraphCaptureContext: + stream: torch.cuda.Stream + + +TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) + + +def _split_tensor_dict( + tensor_dict: dict[str, Union[torch.Tensor, Any]] +) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]: + """Split the tensor dictionary into two parts: + 1. A list of (key, value) pairs. If the value is a tensor, it is replaced + by its metadata. + 2. A list of tensors. + """ + metadata_list: list[tuple[str, Any]] = [] + tensor_list: list[torch.Tensor] = [] + for key, value in tensor_dict.items(): + if isinstance(value, torch.Tensor): + # Note: we cannot use `value.device` here, + # because it contains not only the device type but also the device + # index (e.g. "cuda:0"). We only need the device type. + # receiving side will set the device index. 
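+            # Illustrative result (hypothetical input): a dict
+            # {"h": cuda_tensor, "step": 3} splits into
+            # [("h", TensorMetadata("cuda", dtype, size)), ("step", 3)]
+            # plus the tensor list [cuda_tensor].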
+ device = value.device.type + metadata_list.append( + (key, TensorMetadata(device, value.dtype, value.size()))) + tensor_list.append(value) + else: + metadata_list.append((key, value)) + return metadata_list, tensor_list + + +_group_name_counter: dict[str, int] = {} + + +def _get_unique_name(name: str) -> str: + """Get a unique name for the group. + Example: + _get_unique_name("tp") -> "tp:0" + _get_unique_name("tp") -> "tp:1" + """ + if name not in _group_name_counter: + _group_name_counter[name] = 0 + newname = f"{name}:{_group_name_counter[name]}" + _group_name_counter[name] += 1 + return newname + + +_groups: dict[str, Callable[[], Optional["GroupCoordinator"]]] = {} + + +def _register_group(group: "GroupCoordinator") -> None: + _groups[group.unique_name] = weakref.ref(group) + + +def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + return group._all_reduce_out_place(tensor) + + +def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + return torch.empty_like(tensor) + + +def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int, + group_name: str) -> torch.Tensor: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + return group._reduce_scatter_out_place(tensor, dim) + + +def reduce_scatter_fake(tensor: torch.Tensor, dim: int, world_size: int, + group_name: str) -> torch.Tensor: + new_shape = list(tensor.shape) + new_shape[dim] = tensor.shape[dim] // world_size + return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device) + + +def all_gather(tensor: torch.Tensor, dim: int, world_size: int, + group_name: str) -> torch.Tensor: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + return group._all_gather_out_place(tensor, dim) + + +def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int, + group_name: str) -> torch.Tensor: + new_shape = list(tensor.shape) + new_shape[dim] = tensor.shape[dim] * world_size + return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device) + + +if supports_custom_op(): + from vllm.platforms import current_platform + direct_register_custom_op( + op_name="all_reduce", + op_func=all_reduce, + mutates_args=[], + fake_impl=all_reduce_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="reduce_scatter", + op_func=reduce_scatter, + mutates_args=[], + fake_impl=reduce_scatter_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="all_gather", + op_func=all_gather, + mutates_args=[], + fake_impl=all_gather_fake, + dispatch_key=current_platform.dispatch_key, + ) + + +class GroupCoordinator: + """ + PyTorch ProcessGroup wrapper for a group of processes. + PyTorch ProcessGroup is bound to one specific communication backend, + e.g. NCCL, Gloo, MPI, etc. + GroupCoordinator takes charge of all the communication operations among + the processes in the group. It manages both CPU and device + communication. 
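+
+    Collective operations (`all_reduce`, `all_gather`, `reduce_scatter`)
+    are additionally registered as vLLM custom ops, so on CUDA-alike and
+    TPU platforms they are dispatched through `torch.ops.vllm.*` and remain
+    traceable by the compiler.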
+ """ + + # available attributes: + rank: int # global rank + ranks: list[int] # global ranks in the group + world_size: int # size of the group + # difference between `local_rank` and `rank_in_group`: + # if we have a group of size 4 across two nodes: + # Process | Node | Rank | Local Rank | Rank in Group + # 0 | 0 | 0 | 0 | 0 + # 1 | 0 | 1 | 1 | 1 + # 2 | 1 | 2 | 0 | 2 + # 3 | 1 | 3 | 1 | 3 + local_rank: int # local rank used to assign devices + rank_in_group: int # rank inside the group + cpu_group: ProcessGroup # group for CPU communication + device_group: ProcessGroup # group for device communication + use_device_communicator: bool # whether to use device communicator + device_communicator: DeviceCommunicatorBase # device communicator + mq_broadcaster: Optional[Any] # shared memory broadcaster + + def __init__( + self, + group_ranks: list[list[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + use_device_communicator: bool, + use_message_queue_broadcaster: bool = False, + group_name: Optional[str] = None, + ): + group_name = group_name or "anonymous" + self.unique_name = _get_unique_name(group_name) + _register_group(self) + + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + self.cpu_group = None + + for ranks in group_ranks: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. + cpu_group = torch.distributed.new_group(ranks, backend="gloo") + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + self.cpu_group = cpu_group + + assert self.cpu_group is not None + assert self.device_group is not None + + from vllm.platforms import current_platform + + if current_platform.is_cuda_alike(): + self.device = torch.device(f"cuda:{local_rank}") + elif current_platform.is_out_of_tree(): + self.device = torch.device( + f"{current_platform.device_name}:{local_rank}") + else: + self.device = torch.device("cpu") + + self.use_device_communicator = use_device_communicator + + self.device_communicator: DeviceCommunicatorBase = None # type: ignore + if use_device_communicator and self.world_size > 1: + device_comm_cls = resolve_obj_by_qualname( + current_platform.get_device_communicator_cls()) + self.device_communicator = device_comm_cls( + cpu_group=self.cpu_group, + device=self.device, + device_group=self.device_group, + unique_name=self.unique_name, + ) + + from vllm.distributed.device_communicators.shm_broadcast import ( + MessageQueue) + self.mq_broadcaster: Optional[MessageQueue] = None + if use_message_queue_broadcaster and self.world_size > 1: + self.mq_broadcaster = MessageQueue.create_from_process_group( + self.cpu_group, 1 << 22, 6) + + from vllm.platforms import current_platform + self.use_custom_op_call = (current_platform.is_cuda_alike() + or current_platform.is_tpu()) + + @property + def first_rank(self): + """Return the global rank of the first process in the group""" + return self.ranks[0] + + @property + def last_rank(self): + """Return the global rank of the last process in the group""" + return self.ranks[-1] + + @property + def is_first_rank(self): + """Return whether the caller is the first process in the group""" + return self.rank == self.first_rank + + @property + def is_last_rank(self): + """Return whether the caller is the last process in the group""" + return 
self.rank == self.last_rank + + @property + def next_rank(self): + """Return the global rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group + 1) % world_size] + + @property + def prev_rank(self): + """Return the global rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group - 1) % world_size] + + @contextmanager + def graph_capture( + self, graph_capture_context: Optional[GraphCaptureContext] = None): + if graph_capture_context is None: + stream = torch.cuda.Stream() + graph_capture_context = GraphCaptureContext(stream) + else: + stream = graph_capture_context.stream + + # only cuda uses this function, + # so we don't abstract it into the base class + maybe_ca_context = nullcontext() + from vllm.distributed.device_communicators.cuda_communicator import ( + CudaCommunicator) + if self.device_communicator is not None: + assert isinstance(self.device_communicator, CudaCommunicator) + ca_comm = self.device_communicator.ca_comm + if ca_comm is not None: + maybe_ca_context = ca_comm.capture() # type: ignore + + # ensure all initialization operations complete before attempting to + # capture the graph on another stream + curr_stream = torch.cuda.current_stream() + if curr_stream != stream: + stream.wait_stream(curr_stream) + + with torch.cuda.stream(stream), maybe_ca_context: + yield graph_capture_context + + def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + """ + User-facing all-reduce function before we actually call the + all-reduce operation. + + We need this because Dynamo does not support passing an arbitrary + object (`self` in this case) to a custom op. We need to pass the + group name as a string, and then look up the group coordinator from + the group name, dispatch the all-reduce operation to the group + coordinator. + + In addition, PyTorch custom ops do not support mutation or returning + a new tensor in the same op. So we always make the all-reduce operation + out-of-place. + """ + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + + if self.use_custom_op_call: + return torch.ops.vllm.all_reduce(input_, + group_name=self.unique_name) + else: + return self._all_reduce_out_place(input_) + + def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor: + return self.device_communicator.all_reduce(input_) + + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + + if self.use_custom_op_call: + return torch.ops.vllm.all_gather(input_, + dim, + world_size, + group_name=self.unique_name) + else: + return self._all_gather_out_place(input_, dim) + + def _all_gather_out_place(self, input_: torch.Tensor, + dim: int) -> torch.Tensor: + return self.device_communicator.all_gather(input_, dim) + + def reduce_scatter(self, + input_: torch.Tensor, + dim: int = -1) -> torch.Tensor: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. 
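+        # (with a single rank, the reduce-scatter result equals the input,
+        # so no communication is needed)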
+ if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + + if self.use_custom_op_call: + return torch.ops.vllm.reduce_scatter(input_, + dim, + world_size, + group_name=self.unique_name) + else: + return self._reduce_scatter_out_place(input_, dim) + + def _reduce_scatter_out_place(self, input_: torch.Tensor, + dim: int) -> torch.Tensor: + return self.device_communicator.reduce_scatter(input_, dim) + + def gather(self, + input_: torch.Tensor, + dst: int = 0, + dim: int = -1) -> Optional[torch.Tensor]: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. + """ + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + return self.device_communicator.gather(input_, dst, dim) + + def broadcast(self, input_: torch.Tensor, src: int = 0): + """Broadcast the input tensor. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + # Broadcast. + torch.distributed.broadcast(input_, + src=self.ranks[src], + group=self.device_group) + return input_ + + def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): + """Broadcast the input object. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj + if self.mq_broadcaster is not None: + assert src == 0, "Message queue broadcaster only supports src=0" + return self.mq_broadcaster.broadcast_object(obj) + if self.rank_in_group == src: + torch.distributed.broadcast_object_list([obj], + src=self.ranks[src], + group=self.cpu_group) + return obj + else: + recv = [None] + torch.distributed.broadcast_object_list(recv, + src=self.ranks[src], + group=self.cpu_group) + return recv[0] + + def broadcast_object_list(self, + obj_list: list[Any], + src: int = 0, + group: Optional[ProcessGroup] = None): + """Broadcast the input object list. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj_list + # Broadcast. + torch.distributed.broadcast_object_list(obj_list, + src=self.ranks[src], + group=self.device_group) + return obj_list + + def send_object(self, obj: Any, dst: int) -> None: + """Send the input object list to the destination rank.""" + """NOTE: `dst` is the local rank of the destination rank.""" + + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + assert dst != self.rank_in_group, ( + "Invalid destination rank. 
Destination rank is the same " + "as the current rank.") + + # Serialize object to tensor and get the size as well + object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) + + size_tensor = torch.tensor([object_tensor.numel()], + dtype=torch.long, + device="cpu") + + # Send object size + + torch.distributed.send(size_tensor, + dst=self.ranks[dst], + group=self.cpu_group) + + # Send object + torch.distributed.send(object_tensor, + dst=self.ranks[dst], + group=self.cpu_group) + + return None + + def recv_object(self, src: int) -> Any: + """Receive the input object list from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + + assert src < self.world_size, f"Invalid src rank ({src})" + + assert src != self.rank_in_group, ( + "Invalid source rank. Source rank is the same as the current rank." + ) + + size_tensor = torch.empty(1, dtype=torch.long, device="cpu") + + # Receive object size + rank_size = torch.distributed.recv(size_tensor, + src=self.ranks[src], + group=self.cpu_group) + + # Tensor to receive serialized objects into. + object_tensor = torch.empty( # type: ignore[call-overload] + size_tensor.item(), # type: ignore[arg-type] + dtype=torch.uint8, + device="cpu") + + rank_object = torch.distributed.recv(object_tensor, + src=self.ranks[src], + group=self.cpu_group) + + assert rank_object == rank_size, ( + "Received object sender rank does not match the size sender rank.") + + obj = pickle.loads(object_tensor.numpy().tobytes()) + + return obj + + def broadcast_tensor_dict( + self, + tensor_dict: Optional[dict[str, Union[torch.Tensor, Any]]] = None, + src: int = 0, + group: Optional[ProcessGroup] = None, + metadata_group: Optional[ProcessGroup] = None + ) -> Optional[dict[str, Union[torch.Tensor, Any]]]: + """Broadcast the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if (not torch.distributed.is_initialized() or self.world_size == 1): + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + assert src < self.world_size, f"Invalid src rank ({src})" + + rank_in_group = self.rank_in_group + if rank_in_group == src: + metadata_list: list[tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, + dict), (f"Expecting a dictionary, got {type(tensor_dict)}") + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.broadcast_object(metadata_list, src=src) + async_handles = [] + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast(tensor, + src=self.ranks[src], + group=metadata_group, + async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast(tensor, + src=self.ranks[src], + group=group, + async_op=True) + async_handles.append(handle) + for async_handle in async_handles: + async_handle.wait() + + else: + metadata_list = self.broadcast_object(None, src=src) + tensor_dict = {} + async_handles = [] + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, + dtype=value.dtype, + device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. 
+ tensor_dict[key] = tensor + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast( + tensor, + src=self.ranks[src], + group=metadata_group, + async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast( + tensor, + src=self.ranks[src], + group=group, + async_op=True) + async_handles.append(handle) + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + for async_handle in async_handles: + async_handle.wait() + return tensor_dict + + def send_tensor_dict( + self, + tensor_dict: dict[str, Union[torch.Tensor, Any]], + dst: Optional[int] = None, + all_gather_group: Optional["GroupCoordinator"] = None, + ) -> Optional[dict[str, Union[torch.Tensor, Any]]]: + """Send the input tensor dictionary. + NOTE: `dst` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + all_gather_size = (1 if all_gather_group is None else + all_gather_group.world_size) + all_gather_rank = (0 if all_gather_group is None else + all_gather_group.rank_in_group) + + group = self.device_group + metadata_group = self.cpu_group + + if dst is None: + dst = (self.rank_in_group + 1) % self.world_size + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + metadata_list: list[tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, + dict), f"Expecting a dictionary, got {type(tensor_dict)}" + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `send_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.send_object(metadata_list, dst=dst) + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip sending empty tensors. + continue + + # send-allgather: send only a slice, then do allgather. + if (all_gather_group is not None + and tensor.numel() % all_gather_size == 0): + tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank] + + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.send(tensor, + dst=self.ranks[dst], + group=metadata_group) + else: + # use group for GPU tensors + torch.distributed.send(tensor, + dst=self.ranks[dst], + group=group) + return None + + def recv_tensor_dict( + self, + src: Optional[int] = None, + all_gather_group: Optional["GroupCoordinator"] = None, + ) -> Optional[dict[str, Union[torch.Tensor, Any]]]: + """Recv the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return None + + all_gather_size = (1 if all_gather_group is None else + all_gather_group.world_size) + all_gather_rank = (0 if all_gather_group is None else + all_gather_group.rank_in_group) + + group = self.device_group + metadata_group = self.cpu_group + + if src is None: + src = (self.rank_in_group - 1) % self.world_size + assert src < self.world_size, f"Invalid src rank ({src})" + + recv_metadata_list = self.recv_object(src=src) + tensor_dict: dict[str, Any] = {} + for key, value in recv_metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, + dtype=value.dtype, + device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + tensor_dict[key] = tensor + continue + + # send-allgather: send only a slice, then do allgather. 
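+                # Illustrative example (assumed sizes): with all_gather_size=4
+                # and a 4096-element tensor, each rank receives only a
+                # 1024-element slice here and rebuilds the full tensor with
+                # all_gather(dim=0) plus a reshape back to orig_shape below.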
+ use_all_gather = (all_gather_group is not None + and tensor.numel() % all_gather_size == 0) + + if use_all_gather: + orig_shape = tensor.shape + tensor = tensor.reshape(all_gather_size, + -1)[all_gather_rank] + + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.recv(tensor, + src=self.ranks[src], + group=metadata_group) + else: + # use group for GPU tensors + torch.distributed.recv(tensor, + src=self.ranks[src], + group=group) + if use_all_gather: + # do the allgather + tensor = all_gather_group.all_gather( # type: ignore + tensor, dim=0) + tensor = tensor.reshape(orig_shape) + + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + return tensor_dict + + def barrier(self): + """Barrier synchronization among the group. + NOTE: don't use `device_group` here! `barrier` in NCCL is + terrible because it is internally a broadcast operation with + secretly created GPU tensors. It is easy to mess up the current + device. Use the CPU group instead. + """ + torch.distributed.barrier(group=self.cpu_group) + + def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the local rank of the destination rank.""" + self.device_communicator.send(tensor, dst) + + def recv(self, + size: torch.Size, + dtype: torch.dtype, + src: Optional[int] = None) -> torch.Tensor: + """Receives a tensor from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + return self.device_communicator.recv(size, dtype, src) + + def destroy(self): + if self.device_group is not None: + torch.distributed.destroy_process_group(self.device_group) + self.device_group = None + if self.cpu_group is not None: + torch.distributed.destroy_process_group(self.cpu_group) + self.cpu_group = None + if self.device_communicator is not None: + self.device_communicator.destroy() + if self.mq_broadcaster is not None: + self.mq_broadcaster = None + + def prepare_communication_buffer_for_model(self, model: torch.nn.Module): + if self.device_communicator is not None: + self.device_communicator.prepare_communication_buffer_for_model( + model) + + def dispatch( + self, hidden_states: torch.Tensor, + router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + if self.device_communicator is not None: + return self.device_communicator.dispatch(hidden_states, + router_logits) + else: + return hidden_states, router_logits + + def combine(self, hidden_states) -> torch.Tensor: + if self.device_communicator is not None: + return self.device_communicator.combine(hidden_states) + else: + return hidden_states + + +_WORLD: Optional[GroupCoordinator] = None + + +def get_world_group() -> GroupCoordinator: + assert _WORLD is not None, ("world group is not initialized") + return _WORLD + + +def init_world_group(ranks: list[int], local_rank: int, + backend: str) -> GroupCoordinator: + return GroupCoordinator( + group_ranks=[ranks], + local_rank=local_rank, + torch_distributed_backend=backend, + use_device_communicator=False, + group_name="world", + ) + + +def init_model_parallel_group( + group_ranks: list[list[int]], + local_rank: int, + backend: str, + use_message_queue_broadcaster: bool = False, + group_name: Optional[str] = None, +) -> GroupCoordinator: + + return GroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + use_device_communicator=True, + use_message_queue_broadcaster=use_message_queue_broadcaster, + group_name=group_name, + ) + + 
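+# Illustrative sketch (not an upstream call site; the NCCL backend and rank
+# layout are assumed): with a world of 4 ranks and
+# tensor_model_parallel_size=2, `initialize_model_parallel` below ends up
+# calling roughly
+#
+#     init_model_parallel_group(group_ranks=[[0, 1], [2, 3]],
+#                               local_rank=get_world_group().local_rank,
+#                               backend="nccl",
+#                               use_message_queue_broadcaster=True,
+#                               group_name="tp")
+#
+# and each rank then uses the coordinator whose `ranks` list contains it.
+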
+_TP: Optional[GroupCoordinator] = None + + +def get_tp_group() -> GroupCoordinator: + assert _TP is not None, ("tensor model parallel group is not initialized") + return _TP + + +# kept for backward compatibility +get_tensor_model_parallel_group = get_tp_group + +_PP: Optional[GroupCoordinator] = None + +_DP: Optional[GroupCoordinator] = None + + +def get_dp_group() -> GroupCoordinator: + assert _DP is not None, ("data parallel group is not initialized") + return _DP + + +_EP: Optional[GroupCoordinator] = None + + +def get_ep_group() -> GroupCoordinator: + assert _EP is not None, ("expert parallel group is not initialized") + return _EP + + +def get_pp_group() -> GroupCoordinator: + assert _PP is not None, ( + "pipeline model parallel group is not initialized") + return _PP + + +# kept for backward compatibility +get_pipeline_model_parallel_group = get_pp_group + + +@contextmanager +def graph_capture(device: torch.device): + """ + `graph_capture` is a context manager which should surround the code that + is capturing the CUDA graph. Its main purpose is to ensure that the + some operations will be run after the graph is captured, before the graph + is replayed. It returns a `GraphCaptureContext` object which contains the + necessary data for the graph capture. Currently, it only contains the + stream that the graph capture is running on. This stream is set to the + current CUDA stream when the context manager is entered and reset to the + default stream when the context manager is exited. This is to ensure that + the graph capture is running on a separate stream from the default stream, + in order to explicitly distinguish the kernels to capture + from other kernels possibly launched on background in the default stream. + """ + context = GraphCaptureContext(torch.cuda.Stream(device=device)) + with get_tp_group().graph_capture(context), get_pp_group().graph_capture( + context): + yield context + + +logger = init_logger(__name__) + +_ENABLE_CUSTOM_ALL_REDUCE = True + + +def set_custom_all_reduce(enable: bool): + global _ENABLE_CUSTOM_ALL_REDUCE + _ENABLE_CUSTOM_ALL_REDUCE = enable + + +def init_distributed_environment( + world_size: int = -1, + rank: int = -1, + distributed_init_method: str = "env://", + local_rank: int = -1, + backend: str = "nccl", +): + logger.debug( + "world_size=%d rank=%d local_rank=%d " + "distributed_init_method=%s backend=%s", world_size, rank, local_rank, + distributed_init_method, backend) + from vllm.config import get_current_vllm_config + config = get_current_vllm_config() + if config is not None and config.parallel_config.data_parallel_size > 1: + parallel_config = config.parallel_config + # adjust to take into account data parallelism + # offset the rank by the data parallel rank + rank = parallel_config.data_parallel_rank * world_size + rank + # adjust the world size to take into account data parallelism + world_size = parallel_config.world_size_across_dp + ip = parallel_config.data_parallel_master_ip + port = parallel_config.get_next_dp_init_port() + distributed_init_method = get_distributed_init_method(ip, port) + logger.info( + "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP", + world_size, rank, distributed_init_method) + if not torch.distributed.is_initialized(): + assert distributed_init_method is not None, ( + "distributed_init_method must be provided when initializing " + "distributed environment") + # this backend is used for WORLD + torch.distributed.init_process_group( + backend=backend, + init_method=distributed_init_method, + 
world_size=world_size, + rank=rank) + # set the local rank + # local_rank is not available in torch ProcessGroup, + # see https://github.com/pytorch/pytorch/issues/122816 + if local_rank == -1: + # local rank not set, this usually happens in single-node + # setting, where we can use rank as local rank + if distributed_init_method == "env://": + local_rank = envs.LOCAL_RANK + else: + local_rank = rank + global _WORLD + if _WORLD is None: + ranks = list(range(torch.distributed.get_world_size())) + _WORLD = init_world_group(ranks, local_rank, backend) + else: + assert _WORLD.world_size == torch.distributed.get_world_size(), ( + "world group already initialized with a different world size") + + +def initialize_model_parallel( + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + backend: Optional[str] = None, +) -> None: + """ + Initialize model parallel groups. + + Arguments: + tensor_model_parallel_size: number of GPUs used for tensor model + parallelism. + pipeline_model_parallel_size: number of GPUs used for pipeline model + parallelism. + + Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we + use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize + the model pipeline. The present function will + create 4 tensor model-parallel groups and 2 pipeline model-parallel groups: + 4 tensor model-parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7] + 2 pipeline model-parallel groups: + [g0, g2, g4, g6], [g1, g3, g5, g7] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + backend = backend or torch.distributed.get_backend( + get_world_group().device_group) + + data_parallel_size = 1 + from vllm.config import get_current_vllm_config + config = get_current_vllm_config() + if config is not None: + data_parallel_size = config.parallel_config.data_parallel_size + + # the layout order is: ExternalDP x DP x PP x TP + # ExternalDP is the data parallel group that is not part of the model, + # every dp rank can generate independently (in verl integration). + # DP is the data parallel group that is part of the model, + # all the ranks in the same DP group should generate simultaneously, + # i.e. the `generate` call in the same DP group should be called together, + # otherwise it will cause deadlock. + # to get group_ranks for each dimension, transpose that dimension to the + # last dimension, then reshape to 2D, then unbind the last dimension + all_ranks = torch.arange(world_size).reshape( + -1, data_parallel_size, pipeline_model_parallel_size, + tensor_model_parallel_size) # noqa + + # Build the tensor model-parallel groups. + global _TP + assert _TP is None, ("tensor model parallel group is already initialized") + group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] + + # message queue broadcaster is only used in tensor model parallel group + _TP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + use_message_queue_broadcaster=True, + group_name="tp") + + # Build the pipeline model-parallel groups. 
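+    # Worked example (illustrative): with world_size=8, data_parallel_size=1,
+    # pipeline_model_parallel_size=4 and tensor_model_parallel_size=2,
+    # `all_ranks` has shape (1, 1, 4, 2) and the TP groups built above are
+    # [[0, 1], [2, 3], [4, 5], [6, 7]]. Transposing the PP and TP dimensions
+    # and reshaping to (-1, 4) below gives the PP groups
+    # [[0, 2, 4, 6], [1, 3, 5, 7]], matching the docstring above.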
+ global _PP + assert _PP is None, ( + "pipeline model parallel group is already initialized") + group_ranks = all_ranks.transpose(2, 3).reshape( + -1, pipeline_model_parallel_size).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] + _PP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + group_name="pp") + + global _DP + assert _DP is None, ("data parallel group is already initialized") + group_ranks = all_ranks.transpose(1, + 3).reshape(-1, + data_parallel_size).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] + _DP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + group_name="dp") + + global _EP + assert _EP is None, ("expert parallel group is already initialized") + group_ranks = all_ranks.transpose(1, 2).reshape( + -1, data_parallel_size * tensor_model_parallel_size).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] + _EP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + group_name="ep") + + logger.info( + "rank %s in world size %s is assigned as " + "DP rank %s, PP rank %s, TP rank %s, EP rank %s", rank, world_size, + _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group, + _EP.rank_in_group) + + +def ensure_model_parallel_initialized( + tensor_model_parallel_size: int, + pipeline_model_parallel_size: int, + backend: Optional[str] = None, +) -> None: + """Helper to initialize model parallel groups if they are not initialized, + or ensure tensor-parallel and pipeline-parallel sizes are equal to expected + values if the model parallel groups are initialized. + """ + backend = backend or torch.distributed.get_backend( + get_world_group().device_group) + if not model_parallel_is_initialized(): + initialize_model_parallel(tensor_model_parallel_size, + pipeline_model_parallel_size, backend) + return + + assert ( + get_tensor_model_parallel_world_size() == tensor_model_parallel_size + ), ("tensor parallel group already initialized, but of unexpected size: " + f"{get_tensor_model_parallel_world_size()=} vs. " + f"{tensor_model_parallel_size=}") + pp_world_size = get_pp_group().world_size + assert (pp_world_size == pipeline_model_parallel_size), ( + "pipeline parallel group already initialized, but of unexpected size: " + f"{pp_world_size=} vs. " + f"{pipeline_model_parallel_size=}") + + +def prepare_communication_buffer_for_model(model: torch.nn.Module): + """Prepare the communication buffer for the model. + Traditional communication libraries like NCCL are almost + model agnostic. However, emerging new communication libraries like + MoE all2all (DeepEP) usually allocate the communication buffer + based on the model shape for optimal performance. + """ + if _TP is not None: + _TP.prepare_communication_buffer_for_model(model) + if _PP is not None: + _PP.prepare_communication_buffer_for_model(model) + if _DP is not None: + _DP.prepare_communication_buffer_for_model(model) + if _EP is not None: + _EP.prepare_communication_buffer_for_model(model) + + +def model_parallel_is_initialized(): + """Check if tensor and pipeline parallel groups are initialized.""" + return (_TP is not None and _PP is not None) + + +_TP_STATE_PATCHED = False + + +@contextmanager +def patch_tensor_parallel_group(tp_group: GroupCoordinator): + """Patch the tp group temporarily until this function ends. + + This method is for draft workers of speculative decoding to run draft model + with different tp degree from that of target model workers. 
+ + Args: + tp_group (GroupCoordinator): the tp group coordinator + """ + global _TP_STATE_PATCHED + assert not _TP_STATE_PATCHED, "Should not call when it's already patched" + + _TP_STATE_PATCHED = True + old_tp_group = get_tp_group() + global _TP + _TP = tp_group + try: + yield + finally: + # restore the original state + _TP_STATE_PATCHED = False + _TP = old_tp_group + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + return get_tp_group().world_size + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + return get_tp_group().rank_in_group + + +def destroy_model_parallel(): + """Set the groups to none and destroy them.""" + global _TP + + if _TP: + _TP.destroy() + _TP = None + + global _PP + if _PP: + _PP.destroy() + _PP = None + + global _DP + if _DP: + _DP.destroy() + _DP = None + + global _EP + if _EP: + _EP.destroy() + _EP = None + + +def destroy_distributed_environment(): + global _WORLD + if _WORLD: + _WORLD.destroy() + _WORLD = None + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + +def cleanup_dist_env_and_memory(shutdown_ray: bool = False): + destroy_model_parallel() + destroy_distributed_environment() + with contextlib.suppress(AssertionError): + torch.distributed.destroy_process_group() + if shutdown_ray: + import ray # Lazy import Ray + ray.shutdown() + gc.collect() + from vllm.platforms import current_platform + empty_cache = current_platform.empty_cache + if empty_cache is not None: + empty_cache() + """ + try: + if not current_platform.is_cpu(): + torch._C._host_emptyCache() + except AttributeError: + logger.warning( + "torch._C._host_emptyCache() only available in Pytorch >=2.5") + """ + +def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup], + source_rank: int = 0) -> list[bool]: + """ + This is a collective operation that returns if each rank is in the same node + as the source rank. It tests if processes are attached to the same + memory system (shared access to shared memory). 
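+
+    For example, `in_the_same_node_as(pg, source_rank=0)` returns one boolean
+    per rank in `pg`, True for every rank that can open the shared-memory
+    segment created by rank 0 (i.e. that lives on the same node).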
+ """ + if isinstance(pg, ProcessGroup): + assert torch.distributed.get_backend( + pg) != torch.distributed.Backend.NCCL, ( + "in_the_same_node_as should be tested with a non-NCCL group.") + # local rank inside the group + rank = torch.distributed.get_rank(group=pg) + world_size = torch.distributed.get_world_size(group=pg) + + # global ranks of the processes in the group + ranks = torch.distributed.get_process_group_ranks(pg) + else: + rank = pg.rank + world_size = pg.world_size + ranks = list(range(world_size)) + + # local tensor in each process to store the result + is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32) + + magic_message = b"magic_message" + shm = None + + try: + with contextlib.suppress(OSError): + if rank == source_rank: + # create a shared memory segment + shm = shared_memory.SharedMemory(create=True, size=128) + shm.buf[:len(magic_message)] = magic_message + if isinstance(pg, ProcessGroup): + torch.distributed.broadcast_object_list( + [shm.name], src=ranks[source_rank], group=pg) + else: + pg.broadcast_obj(shm.name, src=source_rank) + is_in_the_same_node[rank] = 1 + else: + # try to open the shared memory segment + if isinstance(pg, ProcessGroup): + recv = [None] + torch.distributed.broadcast_object_list( + recv, src=ranks[source_rank], group=pg) + name = recv[0] + else: + name = pg.broadcast_obj(None, src=source_rank) + # fix to https://stackoverflow.com/q/62748654/9191338 + # Python incorrectly tracks shared memory even if it is not + # created by the process. The following patch is a workaround. + with patch("multiprocessing.resource_tracker.register", + lambda *args, **kwargs: None): + shm = shared_memory.SharedMemory(name=name) + if shm.buf[:len(magic_message)] == magic_message: + is_in_the_same_node[rank] = 1 + except Exception as e: + logger.error("Error ignored in is_in_the_same_node: %s", e) + finally: + if shm: + shm.close() + + if isinstance(pg, ProcessGroup): + torch.distributed.barrier(group=pg) + else: + pg.barrier() + + # clean up the shared memory segment + with contextlib.suppress(OSError): + if rank == source_rank and shm: + shm.unlink() + + if isinstance(pg, ProcessGroup): + torch.distributed.all_reduce(is_in_the_same_node, group=pg) + aggregated_data = is_in_the_same_node + else: + aggregated_data = torch.zeros_like(is_in_the_same_node) + for i in range(world_size): + rank_data = pg.broadcast_obj(is_in_the_same_node, src=i) + aggregated_data += rank_data + + return [x == 1 for x in aggregated_data.tolist()] diff --git a/distributed/tpu_distributed_utils.py b/distributed/tpu_distributed_utils.py new file mode 100644 index 0000000..36ab2eb --- /dev/null +++ b/distributed/tpu_distributed_utils.py @@ -0,0 +1,177 @@ +# SPDX-License-Identifier: Apache-2.0 +from collections import OrderedDict +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_xla.distributed.spmd as xs +from torch.nn.parameter import Parameter + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) + +logger = init_logger(__name__) + + +class XlaQKVParallelLinear(nn.Module): + + def __init__(self, + qkv_linear: nn.Module, + mesh: Optional["xs.Mesh"] = None): + super().__init__() + assert isinstance(qkv_linear, QKVParallelLinear) + self.skip_bias_add = qkv_linear.skip_bias_add + self.return_bias = qkv_linear.return_bias + assert qkv_linear.tp_size == 1, "TP > 1 is only supported under SPMD." 
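+        # Under SPMD the partitioning comes from the `xs.mark_sharding`
+        # annotations applied in `_shard_weight` below, so the wrapped layer
+        # is expected to be built with tp_size == 1 rather than with manually
+        # sliced weights.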
+
+        self.q_weight: Parameter
+        self.k_weight: Parameter
+        self.v_weight: Parameter
+        self.q_bias: Optional[Parameter]
+        self.k_bias: Optional[Parameter]
+        self.v_bias: Optional[Parameter]
+        self._load_weights_from_qkv_linear(qkv_linear)
+        if mesh is not None:
+            self._shard_weight(mesh)
+
+    def _shard_weight(self, mesh: "xs.Mesh"):
+        self.q_weight = Parameter(self.q_weight.to('xla'), requires_grad=False)
+        self.k_weight = Parameter(self.k_weight.to('xla'), requires_grad=False)
+        self.v_weight = Parameter(self.v_weight.to('xla'), requires_grad=False)
+        xs.mark_sharding(self.q_weight, mesh, ('x', None))
+        xs.mark_sharding(self.k_weight, mesh, ('x', None))
+        xs.mark_sharding(self.v_weight, mesh, ('x', None))
+        if self.q_bias is not None:
+            assert self.k_bias is not None and self.v_bias is not None, \
+                "QKVParallelLinear should have q, k, and v biases together."
+            self.q_bias = Parameter(self.q_bias.to('xla'), requires_grad=False)
+            xs.mark_sharding(self.q_bias, mesh, ('x', ))
+            self.k_bias = Parameter(self.k_bias.to('xla'), requires_grad=False)
+            xs.mark_sharding(self.k_bias, mesh, ('x', ))
+            self.v_bias = Parameter(self.v_bias.to('xla'), requires_grad=False)
+            xs.mark_sharding(self.v_bias, mesh, ('x', ))
+
+    def _load_weights_from_qkv_linear(self, qkv_linear: nn.Module):
+        q_proj_size, k_proj_size, _ = qkv_linear.output_sizes
+        # The weight of qkv linear is a concatenation of q, k, and v weights
+        # along the output dimension.
+        qkv_weight = qkv_linear.weight.data.cpu()
+        q_weight = Parameter(qkv_weight[:q_proj_size], requires_grad=False)
+        k_weight = Parameter(qkv_weight[q_proj_size:q_proj_size + k_proj_size],
+                             requires_grad=False)
+        v_weight = Parameter(qkv_weight[q_proj_size + k_proj_size:],
+                             requires_grad=False)
+        self.register_parameter("q_weight", q_weight)
+        self.register_parameter("k_weight", k_weight)
+        self.register_parameter("v_weight", v_weight)
+
+        if qkv_linear.bias is not None:
+            q_bias = Parameter(qkv_linear.bias[:q_proj_size],
+                               requires_grad=False)
+            k_bias = Parameter(qkv_linear.bias[q_proj_size:q_proj_size +
+                                               k_proj_size],
+                               requires_grad=False)
+            v_bias = Parameter(qkv_linear.bias[q_proj_size + k_proj_size:],
+                               requires_grad=False)
+            self.register_parameter("q_bias", q_bias)
+            self.register_parameter("k_bias", k_bias)
+            self.register_parameter("v_bias", v_bias)
+        else:
+            self.register_parameter("q_bias", None)
+            self.register_parameter("k_bias", None)
+            self.register_parameter("v_bias", None)
+
+    def forward(self, input):
+        # Same forward functionality as QKVParallelLinear, but doing the q/k/v
+        # projections separately.
+        q_bias = self.q_bias if not self.skip_bias_add else None
+        k_bias = self.k_bias if not self.skip_bias_add else None
+        v_bias = self.v_bias if not self.skip_bias_add else None
+        q_proj = F.linear(input, self.q_weight, q_bias)
+        k_proj = F.linear(input, self.k_weight, k_bias)
+        v_proj = F.linear(input, self.v_weight, v_bias)
+        # The q/k/v projections will be split outside of the QKVParallelLinear.
+        # Because the QKVParallelLinear is replaced by this
+        # XlaQKVParallelLinear, we need to concatenate the q, k, and v
+        # projections to match the output shape of the QKVParallelLinear
+        # implementation, even if it seems redundant.
+        # The concat and the following split will be a no-op and should be
+        # optimized away by the compiler.
+ qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=-1) + output_bias = torch.cat([q_bias, k_bias, v_bias], dim=-1) if \ + self.skip_bias_add else None + if not self.return_bias: + return qkv_proj + return qkv_proj, output_bias + + +def partition_column_parallel_linear(layer: torch.nn.Module, + mesh: xs.Mesh) -> torch.nn.Module: + assert isinstance(layer, ColumnParallelLinear) + xs.mark_sharding(layer.weight, mesh, ('x', None)) + logger.debug("Applied column-parallel sharding to %s", layer) + return layer + + +def partition_row_parallel_linear(layer: torch.nn.Module, + mesh: xs.Mesh) -> torch.nn.Module: + assert isinstance(layer, RowParallelLinear) + xs.mark_sharding(layer.weight, mesh, (None, 'x')) + logger.debug("Applied row-parallel sharding to %s", layer) + return layer + + +def partition_qkv_parallel_linear(layer: torch.nn.Module, + mesh: xs.Mesh) -> torch.nn.Module: + assert isinstance(layer, QKVParallelLinear) + xla_layer = XlaQKVParallelLinear(layer, mesh) + logger.debug("Applied qkv parallel sharding to %s", layer) + return xla_layer + + +MODULE_TYPE_TO_WRAPPING_FUNC = OrderedDict([ + ("QKVParallelLinear", partition_qkv_parallel_linear), + ("ColumnParallelLinear", partition_column_parallel_linear), + ("RowParallelLinear", partition_row_parallel_linear), +]) + + +def get_fqn(module): + # Get the fully qualified name of the module + return module.__class__.__qualname__ + + +def shard_model(model: torch.nn.Module, mesh: "xs.Mesh") -> None: + """ + Recursively check a PyTorch model and apply appropriate sharding based on + the MODULE_TYPE_TO_WRAPPING_FUNC mapping. + + Args: + model: torch.nn.Module to process + mesh: An XLA SPMD mesh object used for sharding + """ + + def _process_module(module, name=None, parent=None): + for module_type, wrapping_func in MODULE_TYPE_TO_WRAPPING_FUNC.items(): + if get_fqn(module) == module_type: + wrapped_module = wrapping_func(module, mesh) + + assert parent is not None and name is not None, ( + "Top Level module is not expected to be wrapped.") + if wrapped_module is not module: + # Wrapped module and module are different py object. + # The original module should be replaced by the + # wrapped_module. + logger.debug("replace %s with %s", module, wrapped_module) + setattr(parent, name, wrapped_module) + + module = wrapped_module + break + + for child_name, child_module in list(module.named_children()): + _process_module(child_module, child_name, module) + + _process_module(model) diff --git a/distributed/utils.py b/distributed/utils.py new file mode 100644 index 0000000..67f7164 --- /dev/null +++ b/distributed/utils.py @@ -0,0 +1,536 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2023 The vLLM team. +# Adapted from +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
+import dataclasses +import os +import pickle +import socket +import sys +import time +import uuid +from collections import deque +from collections.abc import Sequence +from datetime import timedelta +from typing import Any, Optional + +import torch +from torch.distributed import ProcessGroup, TCPStore +from torch.distributed.distributed_c10d import (Backend, PrefixStore, + _get_default_timeout, + _unregister_process_group) +from torch.distributed.rendezvous import rendezvous + +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.utils import get_tcp_uri, is_torch_equal_or_newer + +logger = init_logger(__name__) + +# We prefer to use os.sched_yield as it results in tighter polling loops, +# measured to be around 3e-7 seconds. However on earlier versions of Python +# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0) +USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1)) + or (sys.version_info[:2] == (3, 10) + and sys.version_info[2] >= 8)) + + +def sched_yield(): + if USE_SCHED_YIELD: + os.sched_yield() + else: + time.sleep(0) + + +def ensure_divisibility(numerator, denominator): + """Ensure that numerator is divisible by the denominator.""" + assert numerator % denominator == 0, "{} is not divisible by {}".format( + numerator, denominator) + + +def divide(numerator, denominator): + """Ensure that numerator is divisible by the denominator and return + the division value.""" + ensure_divisibility(numerator, denominator) + return numerator // denominator + + +def split_tensor_along_last_dim( + tensor: torch.Tensor, + num_partitions: int, + contiguous_split_chunks: bool = False, +) -> Sequence[torch.Tensor]: + """ Split a tensor along its last dimension. + + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + + Returns: + A list of Tensors + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = divide(tensor.size()[last_dim], num_partitions) + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # NOTE: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +def get_pp_indices(num_hidden_layers: int, pp_rank: int, + pp_size: int) -> tuple[int, int]: + """Try to evenly distribute layers across partitions. + + If the number of layers is not divisible by the number of partitions, + the remaining layers are evenly distributed across all but the last + partition. The last partition is excluded because it often contains an + additional norm layer and we are attempting to balance compute. + + If `pp_size > 2` and the number of remaining layers is + `0 < x <= pp_size - 2` then the remaining layers are evenly distributed + across the middle partitions. The first and last partitions are excluded + because they contain the input and output embeddings respectively and we + are attempting to reduce maximum memory consumption across partitions. 
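+
+    For example (an illustrative case, assuming `VLLM_PP_LAYER_PARTITION` is
+    unset): with `num_hidden_layers=13` and `pp_size=4` the base split is
+    `[3, 3, 3, 3]` and the single remaining layer is added to the
+    second-to-last partition, giving `[3, 3, 4, 3]`, so `pp_rank=2` is
+    assigned the layer range `(6, 10)`.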
+ """ + partition_list_str = envs.VLLM_PP_LAYER_PARTITION + if partition_list_str is not None: + try: + partitions = [ + int(layer) for layer in partition_list_str.split(",") + ] + except ValueError as err: + raise ValueError("Invalid partition string: {}".format( + partition_list_str)) from err + if len(partitions) != pp_size: + raise ValueError(f"{len(partitions)=} does not match {pp_size=}.") + if sum(partitions) != num_hidden_layers: + raise ValueError( + f"{sum(partitions)=} does not match {num_hidden_layers=}.") + else: + layers_per_partition = num_hidden_layers // pp_size + partitions = [layers_per_partition for _ in range(pp_size)] + + if remaining_layers := num_hidden_layers % pp_size: + for i in range(2, remaining_layers + 2): + partitions[-i] += 1 + logger.info( + "Hidden layers were unevenly partitioned: [%s]. " + "This can be manually overridden using the " + "VLLM_PP_LAYER_PARTITION environment variable", + ",".join(str(p) for p in partitions)) + + start_layer = sum(partitions[:pp_rank]) + end_layer = start_layer + partitions[pp_rank] + + return (start_layer, end_layer) + + +@dataclasses.dataclass +class StatelessProcessGroup: + """A dataclass to hold a metadata store, and the rank, world_size of the + group. Only use it to communicate metadata between processes. + For data-plane communication, create NCCL-related objects. + """ + rank: int + world_size: int + store: torch._C._distributed_c10d.Store + + # stores a reference to the socket so that the file descriptor stays alive + socket: Optional[socket.socket] + + data_expiration_seconds: int = 3600 # 1 hour + + # dst rank -> counter + send_dst_counter: dict[int, int] = dataclasses.field(default_factory=dict) + # src rank -> counter + recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict) + broadcast_send_counter: int = 0 + broadcast_recv_src_counter: dict[int, int] = dataclasses.field( + default_factory=dict) + + # A deque to store the data entries, with key and timestamp. + entries: deque[tuple[str, + float]] = dataclasses.field(default_factory=deque) + + def __post_init__(self): + assert self.rank < self.world_size + self.send_dst_counter = {i: 0 for i in range(self.world_size)} + self.recv_src_counter = {i: 0 for i in range(self.world_size)} + self.broadcast_recv_src_counter = { + i: 0 + for i in range(self.world_size) + } + + def send_obj(self, obj: Any, dst: int): + """Send an object to a destination rank.""" + self.expire_data() + key = f"send_to/{dst}/{self.send_dst_counter[dst]}" + self.store.set(key, pickle.dumps(obj)) + self.send_dst_counter[dst] += 1 + self.entries.append((key, time.time())) + + def expire_data(self): + """Expire data that is older than `data_expiration_seconds` seconds.""" + while self.entries: + # check the oldest entry + key, timestamp = self.entries[0] + if time.time() - timestamp > self.data_expiration_seconds: + self.store.delete_key(key) + self.entries.popleft() + else: + break + + def recv_obj(self, src: int) -> Any: + """Receive an object from a source rank.""" + obj = pickle.loads( + self.store.get( + f"send_to/{self.rank}/{self.recv_src_counter[src]}")) + self.recv_src_counter[src] += 1 + return obj + + def broadcast_obj(self, obj: Optional[Any], src: int) -> Any: + """Broadcast an object from a source rank to all other ranks. + It does not clean up after all ranks have received the object. + Use it for limited times, e.g., for initialization. 
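+
+        For example, rank 0 may call `broadcast_obj(master_addr, src=0)` while
+        every other rank calls `broadcast_obj(None, src=0)` and gets
+        `master_addr` back (`master_addr` is just an illustrative payload).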
+ """ + if self.rank == src: + self.expire_data() + key = (f"broadcast_from/{src}/" + f"{self.broadcast_send_counter}") + self.store.set(key, pickle.dumps(obj)) + self.broadcast_send_counter += 1 + self.entries.append((key, time.time())) + return obj + else: + key = (f"broadcast_from/{src}/" + f"{self.broadcast_recv_src_counter[src]}") + recv_obj = pickle.loads(self.store.get(key)) + self.broadcast_recv_src_counter[src] += 1 + return recv_obj + + def all_gather_obj(self, obj: Any) -> list[Any]: + """All gather an object from all ranks.""" + gathered_objs = [] + for i in range(self.world_size): + if i == self.rank: + gathered_objs.append(obj) + self.broadcast_obj(obj, src=self.rank) + else: + recv_obj = self.broadcast_obj(None, src=i) + gathered_objs.append(recv_obj) + return gathered_objs + + def barrier(self, timeout: float = 30.0): + """A robust barrier to synchronize all ranks. + + + Uses a multi-phase approach to ensure all processes reach the barrier + before proceeding: + + 1. Each process signals it has reached the barrier + + 2. Each process signals that it has confirmed the arrival of all other + ranks. + + 3. Rank 0 waits for all other ranks to signal their departure to ensure + that all ranks have departed the barrier first. + + Args: + timeout: Maximum time in seconds to wait for each phase (in seconds) + + + Raises: + RuntimeError: If coordination fails or times out + """ + # Generate a barrier ID that is globally unique + try: + if self.rank == 0: + barrier_id = f"barrier_{uuid.uuid4()}" + self.broadcast_obj(barrier_id, src=0) + else: + barrier_id = self.broadcast_obj(None, src=0) + except Exception as e: + raise RuntimeError("Failed to broadcast barrier_id") from e + + # Phase 1: Signal arrival at barrier + # Wait for all processes to arrive + # We need all ranks to confirm the arrival of all other ranks. + # This is the key synchronization point. + arrival_key = f"arrival_{barrier_id}_{self.rank}" + try: + self.store.set(arrival_key, b"1") + except Exception as e: + raise RuntimeError("Failed to signal barrier arrival") from e + + start_time = time.time() + processes_arrived: set[int] = set() + + while len(processes_arrived) < self.world_size: + # Check for timeout + cur_time = time.time() + if cur_time - start_time > timeout: + raise RuntimeError("Barrier timed out after %f seconds", + timeout) + + # Check for each process + for i in range(self.world_size): + if i in processes_arrived: + continue + + key = f"arrival_{barrier_id}_{i}" + try: + # Try to get the key - if it exists, we'll get a value + # If it doesn't exist, it will throw an exception + self.store.get(key) + processes_arrived.add(i) + except KeyError: + # Key doesn't exist yet + pass + except Exception as check_e: + logger.debug("Error checking key existence: %s", check_e) + sched_yield() + + # Short sleep to avoid tight polling + if len(processes_arrived) < self.world_size: + sched_yield() + + # Phase 2: Signal departure from barrier + # We only care to block at this stage in rank 0, which runs the + # server side of the TCPStore. We want to make sure that all + # clients have departed the barrier before rank 0 in case the + # next thing after the barrier is a shutdown, including tearing + # down the TCPStore. Other ranks can exit the barrier immediately + # after signaling their departure. 
+ departure_key = f"departure_{barrier_id}_{self.rank}" + try: + self.store.set(departure_key, b"1") + except Exception as e: + raise RuntimeError("Failed to signal barrier departure") from e + + if self.rank != 0: + return + + # Make rank 0 wait for all processes to signal departure + start_time = time.time() + processes_departed: set[int] = set() + + while len(processes_departed) < self.world_size: + # Check for timeout + if time.time() - start_time > timeout: + raise RuntimeError("Barrier departure timed out after %f s", + timeout) + + # Check for each process + for i in range(self.world_size): + if i in processes_departed: + continue + + key = f"departure_{barrier_id}_{i}" + try: + # Try to get the key - if it exists, we'll get a value + # If it doesn't exist, it will throw an exception + self.store.get(key) + processes_departed.add(i) + except KeyError: + # Key doesn't exist yet + pass + except Exception as check_e: + logger.debug("Error checking key existence: %s", check_e) + sched_yield() + + # Short sleep to avoid tight polling + if len(processes_departed) < self.world_size: + sched_yield() + + # Clean up keys to avoid leaking memory in the store + for i in range(self.world_size): + try: + self.store.delete_key(f"arrival_{barrier_id}_{i}") + except Exception: + logger.debug("Error deleting key: %s", + f'arrival_{barrier_id}_{i}') + + try: + self.store.delete_key(f"departure_{barrier_id}_{i}") + except Exception: + logger.debug("Error deleting key: %s", + f'departure_{barrier_id}_{i}') + + @staticmethod + def create( + host: str, + port: int, + rank: int, + world_size: int, + data_expiration_seconds: int = 3600, + store_timeout: int = 300, + ) -> "StatelessProcessGroup": + """A replacement for `torch.distributed.init_process_group` that does not + pollute the global state. + + If we have process A and process B called `torch.distributed.init_process_group` + to form a group, and then we want to form another group with process A, B, C, + D, it is not possible in PyTorch, because process A and process B have already + formed a group, and process C and process D cannot join that group. This + function is a workaround for this issue. + + `torch.distributed.init_process_group` is a global call, while this function + is a stateless call. It will return a `StatelessProcessGroup` object that can be + used for exchanging metadata. With this function, process A and process B + can call `StatelessProcessGroup.create` to form a group, and then process A, B, + C, and D can call `StatelessProcessGroup.create` to form another group. 
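+
+        Example (illustrative; the host and port are placeholders): each of
+        four processes can call
+        `StatelessProcessGroup.create(host="127.0.0.1", port=29555, rank=i,
+        world_size=4)` and then exchange metadata via `broadcast_obj`,
+        `all_gather_obj` or `barrier` without touching the global
+        `torch.distributed` state.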
+ """ # noqa + launch_server = rank == 0 + if launch_server: + # listen on the specified interface (instead of 0.0.0.0) + listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + listen_socket.bind((host, port)) + listen_socket.listen() + listen_fd = listen_socket.fileno() + else: + listen_socket = None + listen_fd = None + + store = TCPStore( + host_name=host, + port=port, + world_size=world_size, + is_master=launch_server, + timeout=timedelta(seconds=store_timeout), + use_libuv=False, # for now: github.com/pytorch/pytorch/pull/150215 + master_listen_fd=listen_fd, + ) + + return StatelessProcessGroup( + rank=rank, + world_size=world_size, + store=store, + socket=listen_socket, + data_expiration_seconds=data_expiration_seconds) + + +def init_gloo_process_group(backend: Backend, prefix_store: PrefixStore, + group_rank: int, group_size: int, + timeout: timedelta) -> ProcessGroup: + """ + Stateless init ProcessGroup with gloo backend compatible with + different torch versions. + """ + if is_torch_equal_or_newer("2.6"): + pg = ProcessGroup( + prefix_store, + group_rank, + group_size, + ) + else: + options = ProcessGroup.Options(backend=backend) + pg = ProcessGroup( + prefix_store, + group_rank, + group_size, + options, + ) + from torch.distributed.distributed_c10d import ProcessGroupGloo + backend_class = ProcessGroupGloo(prefix_store, + group_rank, + group_size, + timeout=timeout) + backend_type = ProcessGroup.BackendType.GLOO + device = torch.device("cpu") + if is_torch_equal_or_newer("2.6"): + # _set_default_backend is supported in torch >= 2.6 + pg._set_default_backend(backend_type) + backend_class._set_sequence_number_for_group() + + pg._register_backend(device, backend_type, backend_class) + return pg + + +def stateless_init_torch_distributed_process_group( + host: str, port: int, rank: int, world_size: int, + backend: str) -> ProcessGroup: + """ + A replacement for `torch.distributed.init_process_group` that does not + pollute the global state. The created ProcessGroup object can be used for + some operations such as `allreduce`, because it does not depend on the + global rank. However, some operations such as `broadcast` cannot be used + because it depends on the global rank. + + # TODO: ask for help from PyTorch team if we need the `broadcast` operation. + + This function is useful when we are not sure about the total number of + processes in the process group. For example, we may have process + 1, 2, ..., 8 who want to communicate, and process 9 might be the same + process as process 1, or it might be a different process; process 10 + might be the same process as process 5, or it might be a different process. + In this case, how can we reliably form a communication channel within + process 9 and 10, without affecting the communication channel within + process 1, 2, ..., 8? + + One possible solution is to figure out if process 9 and 10 are the same + as process 1 and 5 beforehand, and then form a communication channel + based on the information, adjusting the ranks and world_size etc. However, + figuring out the information is not always easy, and it will interfere + with the main communication channel. + + Our solution is to always form a communication channel with process 1, 2, + ..., 8, and then use this function to form another communication channel + with process 9 and 10. 
This way, regardless of whether process 9 and 10 + are the same as process 1 and 5, the main communication channel is + always formed with process 1, 2, ..., 8, and the additional communication + channel is formed with process 9 and 10. + """ + init_method = get_tcp_uri(host, port) + backend = Backend(backend) # it is basically string + timeout = _get_default_timeout(backend) + + store, rank, world_size = next( + rendezvous(init_method, rank, world_size, timeout=timeout)) + store.set_timeout(timeout) + + group_rank = rank + group_size = world_size + + # Use a PrefixStore to avoid accidental overrides of keys used by + # different systems (e.g. RPC) in case the store is multi-tenant. + prefix_store = PrefixStore(init_method, store) + + if backend == "gloo": + return init_gloo_process_group(backend=backend, + prefix_store=prefix_store, + group_rank=group_rank, + group_size=group_size, + timeout=timeout) + from vllm.platforms import current_platform + return current_platform.stateless_init_device_torch_dist_pg( + backend=backend, + prefix_store=prefix_store, + group_rank=group_rank, + group_size=group_size, + timeout=timeout) + + +def stateless_destroy_torch_distributed_process_group( + pg: ProcessGroup) -> None: + """ + Destroy ProcessGroup returned by + stateless_init_torch_distributed_process_group(). + """ + if is_torch_equal_or_newer("2.7"): + pg.shutdown() + else: + # Lazy import for non-CUDA backends. + from torch.distributed.distributed_c10d import _shutdown_backend + _shutdown_backend(pg) + + _unregister_process_group(pg.group_name) diff --git a/engine/__init__.py b/engine/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/engine/arg_utils.py b/engine/arg_utils.py new file mode 100644 index 0000000..4ce1b41 --- /dev/null +++ b/engine/arg_utils.py @@ -0,0 +1,1708 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# yapf: disable +import argparse +import dataclasses +import json +import sys +import threading +import warnings +from dataclasses import MISSING, dataclass, fields, is_dataclass +from itertools import permutations +from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional, + Type, TypeVar, Union, cast, get_args, get_origin) + +import regex as re +import torch +from pydantic import TypeAdapter, ValidationError +from typing_extensions import TypeIs, deprecated + +import vllm.envs as envs +from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig, + ConfigFormat, ConfigType, DecodingConfig, + DetailedTraceModules, Device, DeviceConfig, + DistributedExecutorBackend, GuidedDecodingBackend, + GuidedDecodingBackendV1, HfOverrides, KVEventsConfig, + KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig, + ModelConfig, ModelDType, ModelImpl, MultiModalConfig, + ObservabilityConfig, ParallelConfig, PoolerConfig, + PrefixCachingHashAlgo, PromptAdapterConfig, + SchedulerConfig, SchedulerPolicy, SpeculativeConfig, + TaskOption, TokenizerMode, TokenizerPoolConfig, + VllmConfig, get_attr_docs, get_field) +from vllm.executor.executor_base import ExecutorBase +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.plugins import load_general_plugins +from vllm.reasoning import ReasoningParserManager +from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 +from vllm.transformers_utils.utils import check_gguf_file +from vllm.usage.usage_lib import UsageContext +from vllm.utils import 
(STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser, + GiB_bytes, get_ip, is_in_ray_actor) + +# yapf: enable + +logger = init_logger(__name__) + +# object is used to allow for special typing forms +T = TypeVar("T") +TypeHint = Union[type[Any], object] +TypeHintT = Union[type[T], object] + + +def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]: + + def _parse_type(val: str) -> T: + try: + if return_type is json.loads and not re.match("^{.*}$", val): + return cast(T, nullable_kvs(val)) + return return_type(val) + except ValueError as e: + raise argparse.ArgumentTypeError( + f"Value {val} cannot be converted to {return_type}.") from e + + return _parse_type + + +def optional_type( + return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]: + + def _optional_type(val: str) -> Optional[T]: + if val == "" or val == "None": + return None + return parse_type(return_type)(val) + + return _optional_type + + +def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]: + if not re.match("^{.*}$", val): + return str(val) + return optional_type(json.loads)(val) + + +@deprecated( + "Passing a JSON argument as a string containing comma separated key=value " + "pairs is deprecated. This will be removed in v0.10.0. Please use a JSON " + "string instead.") +def nullable_kvs(val: str) -> dict[str, int]: + """Parses a string containing comma separate key [str] to value [int] + pairs into a dictionary. + + Args: + val: String value to be parsed. + + Returns: + Dictionary with parsed values. + """ + out_dict: dict[str, int] = {} + for item in val.split(","): + kv_parts = [part.lower().strip() for part in item.split("=")] + if len(kv_parts) != 2: + raise argparse.ArgumentTypeError( + "Each item should be in the form KEY=VALUE") + key, value = kv_parts + + try: + parsed_value = int(value) + except ValueError as exc: + msg = f"Failed to parse value of item {key}={value}" + raise argparse.ArgumentTypeError(msg) from exc + + if key in out_dict and out_dict[key] != parsed_value: + raise argparse.ArgumentTypeError( + f"Conflicting values specified for key: {key}") + out_dict[key] = parsed_value + + return out_dict + + +def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]: + """Check if the type hint is a specific type.""" + return type_hint is type or get_origin(type_hint) is type + + +def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool: + """Check if the type hints contain a specific type.""" + return any(is_type(type_hint, type) for type_hint in type_hints) + + +def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT: + """Get the specific type from the type hints.""" + return next((th for th in type_hints if is_type(th, type)), None) + + +def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]: + """Convert Literal type hints to argparse kwargs.""" + type_hint = get_type(type_hints, Literal) + choices = get_args(type_hint) + choice_type = type(choices[0]) + if not all(isinstance(choice, choice_type) for choice in choices): + raise ValueError( + "All choices must be of the same type. 
" + f"Got {choices} with types {[type(c) for c in choices]}") + return {"type": choice_type, "choices": sorted(choices)} + + +def is_not_builtin(type_hint: TypeHint) -> bool: + """Check if the class is not a built-in type.""" + return type_hint.__module__ != "builtins" + + +def get_type_hints(type_hint: TypeHint) -> set[TypeHint]: + """Extract type hints from Annotated or Union type hints.""" + type_hints: set[TypeHint] = set() + origin = get_origin(type_hint) + args = get_args(type_hint) + + if origin is Annotated: + type_hints.update(get_type_hints(args[0])) + elif origin is Union: + for arg in args: + type_hints.update(get_type_hints(arg)) + else: + type_hints.add(type_hint) + + return type_hints + + +def get_kwargs(cls: ConfigType) -> dict[str, Any]: + cls_docs = get_attr_docs(cls) + kwargs = {} + for field in fields(cls): + # Get the set of possible types for the field + type_hints: set[TypeHint] = get_type_hints(field.type) + + # If the field is a dataclass, we can use the model_validate_json + generator = (th for th in type_hints if is_dataclass(th)) + dataclass_cls = next(generator, None) + + # Get the default value of the field + if field.default is not MISSING: + default = field.default + elif field.default_factory is not MISSING: + default = field.default_factory() + + # Get the help text for the field + name = field.name + help = cls_docs[name].strip() + # Escape % for argparse + help = help.replace("%", "%%") + + # Initialise the kwargs dictionary for the field + kwargs[name] = {"default": default, "help": help} + + # Set other kwargs based on the type hints + json_tip = """\n\nShould either be a valid JSON string or JSON keys + passed individually. For example, the following sets of arguments are + equivalent:\n\n + - `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n + - `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n\n""" + if dataclass_cls is not None: + + def parse_dataclass(val: str, cls=dataclass_cls) -> Any: + try: + if hasattr(cls, "from_cli"): + return cls.from_cli(val) + return TypeAdapter(cls).validate_json(val) + except ValidationError as e: + raise argparse.ArgumentTypeError(repr(e)) from e + + kwargs[name]["type"] = parse_dataclass + kwargs[name]["help"] += json_tip + elif contains_type(type_hints, bool): + # Creates --no- and -- flags + kwargs[name]["action"] = argparse.BooleanOptionalAction + elif contains_type(type_hints, Literal): + kwargs[name].update(literal_to_kwargs(type_hints)) + elif contains_type(type_hints, tuple): + type_hint = get_type(type_hints, tuple) + types = get_args(type_hint) + tuple_type = types[0] + assert all(t is tuple_type for t in types if t is not Ellipsis), ( + "All non-Ellipsis tuple elements must be of the same " + f"type. Got {types}.") + kwargs[name]["type"] = tuple_type + kwargs[name]["nargs"] = "+" if Ellipsis in types else len(types) + elif contains_type(type_hints, list): + type_hint = get_type(type_hints, list) + types = get_args(type_hint) + assert len(types) == 1, ( + "List type must have exactly one type. 
Got " + f"{type_hint} with types {types}") + kwargs[name]["type"] = types[0] + kwargs[name]["nargs"] = "+" + elif contains_type(type_hints, int): + kwargs[name]["type"] = int + # Special case for large integers + if name in {"max_model_len", "max_num_batched_tokens"}: + kwargs[name]["type"] = human_readable_int + elif contains_type(type_hints, float): + kwargs[name]["type"] = float + elif (contains_type(type_hints, dict) + and (contains_type(type_hints, str) + or any(is_not_builtin(th) for th in type_hints))): + kwargs[name]["type"] = union_dict_and_str + elif contains_type(type_hints, dict): + kwargs[name]["type"] = parse_type(json.loads) + kwargs[name]["help"] += json_tip + elif (contains_type(type_hints, str) + or any(is_not_builtin(th) for th in type_hints)): + kwargs[name]["type"] = str + else: + raise ValueError( + f"Unsupported type {type_hints} for argument {name}.") + + # If the type hint was a sequence of literals, use the helper function + # to update the type and choices + if get_origin(kwargs[name].get("type")) is Literal: + kwargs[name].update(literal_to_kwargs({kwargs[name]["type"]})) + + # If None is in type_hints, make the argument optional. + # But not if it's a bool, argparse will handle this better. + if type(None) in type_hints and not contains_type(type_hints, bool): + kwargs[name]["type"] = optional_type(kwargs[name]["type"]) + if kwargs[name].get("choices"): + kwargs[name]["choices"].append("None") + return kwargs + + +@dataclass +class EngineArgs: + """Arguments for vLLM engine.""" + model: str = ModelConfig.model + served_model_name: Optional[Union[ + str, List[str]]] = ModelConfig.served_model_name + tokenizer: Optional[str] = ModelConfig.tokenizer + hf_config_path: Optional[str] = ModelConfig.hf_config_path + task: TaskOption = ModelConfig.task + skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init + enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds + tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode + trust_remote_code: bool = ModelConfig.trust_remote_code + allowed_local_media_path: str = ModelConfig.allowed_local_media_path + download_dir: Optional[str] = LoadConfig.download_dir + load_format: str = LoadConfig.load_format + config_format: str = ModelConfig.config_format + dtype: ModelDType = ModelConfig.dtype + kv_cache_dtype: CacheDType = CacheConfig.cache_dtype + seed: Optional[int] = ModelConfig.seed + max_model_len: Optional[int] = ModelConfig.max_model_len + cuda_graph_sizes: list[int] = get_field(SchedulerConfig, + "cuda_graph_sizes") + # Note: Specifying a custom executor backend by passing a class + # is intended for expert use only. The API may change without + # notice. 
+ distributed_executor_backend: Optional[Union[ + DistributedExecutorBackend, + Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend + # number of P/D disaggregation (or other disaggregation) workers + pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size + tensor_parallel_size: int = ParallelConfig.tensor_parallel_size + data_parallel_size: int = ParallelConfig.data_parallel_size + data_parallel_size_local: Optional[int] = None + data_parallel_address: Optional[str] = None + data_parallel_rpc_port: Optional[int] = None + data_parallel_backend: str = ParallelConfig.data_parallel_backend + enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel + max_parallel_loading_workers: Optional[ + int] = ParallelConfig.max_parallel_loading_workers + block_size: Optional[BlockSize] = CacheConfig.block_size + enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching + prefix_caching_hash_algo: PrefixCachingHashAlgo = \ + CacheConfig.prefix_caching_hash_algo + disable_sliding_window: bool = ModelConfig.disable_sliding_window + disable_cascade_attn: bool = ModelConfig.disable_cascade_attn + use_v2_block_manager: bool = True + swap_space: float = CacheConfig.swap_space + cpu_offload_gb: float = CacheConfig.cpu_offload_gb + gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization + max_num_batched_tokens: Optional[ + int] = SchedulerConfig.max_num_batched_tokens + max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills + max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills + long_prefill_token_threshold: int = \ + SchedulerConfig.long_prefill_token_threshold + max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs + max_logprobs: int = ModelConfig.max_logprobs + disable_log_stats: bool = False + revision: Optional[str] = ModelConfig.revision + code_revision: Optional[str] = ModelConfig.code_revision + rope_scaling: dict[str, Any] = get_field(ModelConfig, "rope_scaling") + rope_theta: Optional[float] = ModelConfig.rope_theta + hf_token: Optional[Union[bool, str]] = ModelConfig.hf_token + hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides") + tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision + quantization: Optional[QuantizationMethods] = ModelConfig.quantization + enforce_eager: bool = ModelConfig.enforce_eager + max_seq_len_to_capture: int = ModelConfig.max_seq_len_to_capture + disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce + # The following three fields are deprecated and will be removed in a future + # release. Setting them will have no effect. Please remove them from your + # configurations. 
+ tokenizer_pool_size: int = TokenizerPoolConfig.pool_size + tokenizer_pool_type: str = TokenizerPoolConfig.pool_type + tokenizer_pool_extra_config: dict = \ + get_field(TokenizerPoolConfig, "extra_config") + limit_mm_per_prompt: dict[str, int] = \ + get_field(MultiModalConfig, "limit_per_prompt") + mm_processor_kwargs: Optional[Dict[str, Any]] = \ + MultiModalConfig.mm_processor_kwargs + disable_mm_preprocessor_cache: bool = \ + MultiModalConfig.disable_mm_preprocessor_cache + # LoRA fields + enable_lora: bool = False + enable_lora_bias: bool = LoRAConfig.bias_enabled + max_loras: int = LoRAConfig.max_loras + max_lora_rank: int = LoRAConfig.max_lora_rank + fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras + max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras + lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype + lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size + long_lora_scaling_factors: Optional[tuple[float, ...]] = \ + LoRAConfig.long_lora_scaling_factors + # PromptAdapter fields + enable_prompt_adapter: bool = False + max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters + max_prompt_adapter_token: int = \ + PromptAdapterConfig.max_prompt_adapter_token + + device: Device = DeviceConfig.device + num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps + multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs + ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight + num_gpu_blocks_override: Optional[ + int] = CacheConfig.num_gpu_blocks_override + num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots + model_loader_extra_config: dict = \ + get_field(LoadConfig, "model_loader_extra_config") + ignore_patterns: Optional[Union[str, + List[str]]] = LoadConfig.ignore_patterns + preemption_mode: Optional[str] = SchedulerConfig.preemption_mode + + scheduler_delay_factor: float = SchedulerConfig.delay_factor + enable_chunked_prefill: Optional[ + bool] = SchedulerConfig.enable_chunked_prefill + disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input + + disable_hybrid_kv_cache_manager: bool = ( + SchedulerConfig.disable_hybrid_kv_cache_manager) + + guided_decoding_backend: GuidedDecodingBackend = DecodingConfig.backend + guided_decoding_disable_fallback: bool = DecodingConfig.disable_fallback + guided_decoding_disable_any_whitespace: bool = \ + DecodingConfig.disable_any_whitespace + guided_decoding_disable_additional_properties: bool = \ + DecodingConfig.disable_additional_properties + logits_processor_pattern: Optional[ + str] = ModelConfig.logits_processor_pattern + + speculative_config: Optional[Dict[str, Any]] = None + + qlora_adapter_name_or_path: Optional[str] = None + show_hidden_metrics_for_version: Optional[str] = \ + ObservabilityConfig.show_hidden_metrics_for_version + otlp_traces_endpoint: Optional[str] = \ + ObservabilityConfig.otlp_traces_endpoint + collect_detailed_traces: Optional[list[DetailedTraceModules]] = \ + ObservabilityConfig.collect_detailed_traces + disable_async_output_proc: bool = not ModelConfig.use_async_output_proc + scheduling_policy: SchedulerPolicy = SchedulerConfig.policy + scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls + + override_neuron_config: dict[str, Any] = \ + get_field(ModelConfig, "override_neuron_config") + override_pooler_config: Optional[Union[dict, PoolerConfig]] = \ + ModelConfig.override_pooler_config + compilation_config: CompilationConfig = \ + get_field(VllmConfig, "compilation_config") + 
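+    # Illustrative note: most of the fields above default to the matching
+    # attribute of their config class, so callers normally construct
+    # EngineArgs with only the overrides they need, e.g. (hypothetical
+    # values):
+    #     EngineArgs(model="facebook/opt-125m", tensor_parallel_size=2)
+    # Every unset field falls back to ModelConfig, ParallelConfig, etc.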
worker_cls: str = ParallelConfig.worker_cls + worker_extension_cls: str = ParallelConfig.worker_extension_cls + + kv_transfer_config: Optional[KVTransferConfig] = None + kv_events_config: Optional[KVEventsConfig] = None + + generation_config: str = ModelConfig.generation_config + enable_sleep_mode: bool = ModelConfig.enable_sleep_mode + override_generation_config: dict[str, Any] = \ + get_field(ModelConfig, "override_generation_config") + model_impl: str = ModelConfig.model_impl + + calculate_kv_scales: bool = CacheConfig.calculate_kv_scales + + additional_config: dict[str, Any] = \ + get_field(VllmConfig, "additional_config") + enable_reasoning: Optional[bool] = None # DEPRECATED + reasoning_parser: str = DecodingConfig.reasoning_backend + + use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load + pt_load_map_location: str = LoadConfig.pt_load_map_location + + enable_multimodal_encoder_data_parallel: bool = \ + ParallelConfig.enable_multimodal_encoder_data_parallel + + def __post_init__(self): + # support `EngineArgs(compilation_config={...})` + # without having to manually construct a + # CompilationConfig object + if isinstance(self.compilation_config, (int, dict)): + self.compilation_config = CompilationConfig.from_cli( + str(self.compilation_config)) + if self.qlora_adapter_name_or_path is not None: + warnings.warn( + "The `qlora_adapter_name_or_path` is deprecated " + "and will be removed in v0.10.0. ", + DeprecationWarning, + stacklevel=2, + ) + # Setup plugins + from vllm.plugins import load_general_plugins + load_general_plugins() + + @staticmethod + def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: + """Shared CLI arguments for vLLM engine.""" + + # Model arguments + model_kwargs = get_kwargs(ModelConfig) + model_group = parser.add_argument_group( + title="ModelConfig", + description=ModelConfig.__doc__, + ) + if not ('serve' in sys.argv[1:] and '--help' in sys.argv[1:]): + model_group.add_argument("--model", **model_kwargs["model"]) + model_group.add_argument("--task", **model_kwargs["task"]) + model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"]) + model_group.add_argument("--tokenizer-mode", + **model_kwargs["tokenizer_mode"]) + model_group.add_argument("--trust-remote-code", + **model_kwargs["trust_remote_code"]) + model_group.add_argument("--dtype", **model_kwargs["dtype"]) + model_group.add_argument("--seed", **model_kwargs["seed"]) + model_group.add_argument("--hf-config-path", + **model_kwargs["hf_config_path"]) + model_group.add_argument("--allowed-local-media-path", + **model_kwargs["allowed_local_media_path"]) + model_group.add_argument("--revision", **model_kwargs["revision"]) + model_group.add_argument("--code-revision", + **model_kwargs["code_revision"]) + model_group.add_argument("--rope-scaling", + **model_kwargs["rope_scaling"]) + model_group.add_argument("--rope-theta", **model_kwargs["rope_theta"]) + model_group.add_argument("--tokenizer-revision", + **model_kwargs["tokenizer_revision"]) + model_group.add_argument("--max-model-len", + **model_kwargs["max_model_len"]) + model_group.add_argument("--quantization", "-q", + **model_kwargs["quantization"]) + model_group.add_argument("--enforce-eager", + **model_kwargs["enforce_eager"]) + model_group.add_argument("--max-seq-len-to-capture", + **model_kwargs["max_seq_len_to_capture"]) + model_group.add_argument("--max-logprobs", + **model_kwargs["max_logprobs"]) + model_group.add_argument("--disable-sliding-window", + **model_kwargs["disable_sliding_window"]) + 
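+        # Illustrative sketch: each model_kwargs entry was produced by
+        # get_kwargs(ModelConfig) above and roughly has the shape
+        #     {"default": ..., "help": "<attribute docstring>", "type": ...}
+        # For example, max_model_len is parsed with human_readable_int, so a
+        # value such as "--max-model-len 32k" is accepted on the command line.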
model_group.add_argument("--disable-cascade-attn", + **model_kwargs["disable_cascade_attn"]) + model_group.add_argument("--skip-tokenizer-init", + **model_kwargs["skip_tokenizer_init"]) + model_group.add_argument("--enable-prompt-embeds", + **model_kwargs["enable_prompt_embeds"]) + model_group.add_argument("--served-model-name", + **model_kwargs["served_model_name"]) + # This one is a special case because it is the + # opposite of ModelConfig.use_async_output_proc + model_group.add_argument( + "--disable-async-output-proc", + action="store_true", + default=EngineArgs.disable_async_output_proc, + help="Disable async output processing. This may result in " + "lower performance.") + model_group.add_argument("--config-format", + choices=[f.value for f in ConfigFormat], + **model_kwargs["config_format"]) + # This one is a special case because it can bool + # or str. TODO: Handle this in get_kwargs + model_group.add_argument("--hf-token", + type=str, + nargs="?", + const=True, + default=model_kwargs["hf_token"]["default"], + help=model_kwargs["hf_token"]["help"]) + model_group.add_argument("--hf-overrides", + **model_kwargs["hf_overrides"]) + model_group.add_argument("--override-neuron-config", + **model_kwargs["override_neuron_config"]) + model_group.add_argument("--override-pooler-config", + **model_kwargs["override_pooler_config"]) + model_group.add_argument("--logits-processor-pattern", + **model_kwargs["logits_processor_pattern"]) + model_group.add_argument("--generation-config", + **model_kwargs["generation_config"]) + model_group.add_argument("--override-generation-config", + **model_kwargs["override_generation_config"]) + model_group.add_argument("--enable-sleep-mode", + **model_kwargs["enable_sleep_mode"]) + model_group.add_argument("--model-impl", + choices=[f.value for f in ModelImpl], + **model_kwargs["model_impl"]) + + # Model loading arguments + load_kwargs = get_kwargs(LoadConfig) + load_group = parser.add_argument_group( + title="LoadConfig", + description=LoadConfig.__doc__, + ) + load_group.add_argument("--load-format", + choices=[f.value for f in LoadFormat], + **load_kwargs["load_format"]) + load_group.add_argument("--download-dir", + **load_kwargs["download_dir"]) + load_group.add_argument("--model-loader-extra-config", + **load_kwargs["model_loader_extra_config"]) + load_group.add_argument("--ignore-patterns", + **load_kwargs["ignore_patterns"]) + load_group.add_argument("--use-tqdm-on-load", + **load_kwargs["use_tqdm_on_load"]) + load_group.add_argument( + "--qlora-adapter-name-or-path", + type=str, + default=None, + help="The `--qlora-adapter-name-or-path` has no effect, do not set" + " it, and it will be removed in v0.10.0.", + deprecated=True, + ) + load_group.add_argument('--pt-load-map-location', + **load_kwargs["pt_load_map_location"]) + + # Guided decoding arguments + guided_decoding_kwargs = get_kwargs(DecodingConfig) + guided_decoding_group = parser.add_argument_group( + title="DecodingConfig", + description=DecodingConfig.__doc__, + ) + guided_decoding_group.add_argument("--guided-decoding-backend", + **guided_decoding_kwargs["backend"]) + guided_decoding_group.add_argument( + "--guided-decoding-disable-fallback", + **guided_decoding_kwargs["disable_fallback"]) + guided_decoding_group.add_argument( + "--guided-decoding-disable-any-whitespace", + **guided_decoding_kwargs["disable_any_whitespace"]) + guided_decoding_group.add_argument( + "--guided-decoding-disable-additional-properties", + **guided_decoding_kwargs["disable_additional_properties"]) + 
guided_decoding_group.add_argument( + "--enable-reasoning", + action=argparse.BooleanOptionalAction, + deprecated=True, + help="[DEPRECATED] The `--enable-reasoning` flag is deprecated as " + "of v0.9.0. Use `--reasoning-parser` to specify the reasoning " + "parser backend instead. This flag (`--enable-reasoning`) will be " + "removed in v0.10.0. When `--reasoning-parser` is specified, " + "reasoning mode is automatically enabled.") + guided_decoding_group.add_argument( + "--reasoning-parser", + # This choices is a special case because it's not static + choices=list(ReasoningParserManager.reasoning_parsers), + **guided_decoding_kwargs["reasoning_backend"]) + + # Parallel arguments + parallel_kwargs = get_kwargs(ParallelConfig) + parallel_group = parser.add_argument_group( + title="ParallelConfig", + description=ParallelConfig.__doc__, + ) + parallel_group.add_argument( + "--distributed-executor-backend", + **parallel_kwargs["distributed_executor_backend"]) + parallel_group.add_argument( + "--pipeline-parallel-size", "-pp", + **parallel_kwargs["pipeline_parallel_size"]) + parallel_group.add_argument("--tensor-parallel-size", "-tp", + **parallel_kwargs["tensor_parallel_size"]) + parallel_group.add_argument("--data-parallel-size", "-dp", + **parallel_kwargs["data_parallel_size"]) + parallel_group.add_argument('--data-parallel-size-local', + '-dpl', + type=int, + help='Number of data parallel replicas ' + 'to run on this node.') + parallel_group.add_argument('--data-parallel-address', + '-dpa', + type=str, + help='Address of data parallel cluster ' + 'head-node.') + parallel_group.add_argument('--data-parallel-rpc-port', + '-dpp', + type=int, + help='Port for data parallel RPC ' + 'communication.') + parallel_group.add_argument('--data-parallel-backend', + '-dpb', + type=str, + default='mp', + help='Backend for data parallel, either ' + '"mp" or "ray".') + parallel_group.add_argument( + "--enable-expert-parallel", + **parallel_kwargs["enable_expert_parallel"]) + parallel_group.add_argument( + "--max-parallel-loading-workers", + **parallel_kwargs["max_parallel_loading_workers"]) + parallel_group.add_argument( + "--ray-workers-use-nsight", + **parallel_kwargs["ray_workers_use_nsight"]) + parallel_group.add_argument( + "--disable-custom-all-reduce", + **parallel_kwargs["disable_custom_all_reduce"]) + parallel_group.add_argument("--worker-cls", + **parallel_kwargs["worker_cls"]) + parallel_group.add_argument("--worker-extension-cls", + **parallel_kwargs["worker_extension_cls"]) + parallel_group.add_argument( + "--enable-multimodal-encoder-data-parallel", + **parallel_kwargs["enable_multimodal_encoder_data_parallel"]) + + # KV cache arguments + cache_kwargs = get_kwargs(CacheConfig) + cache_group = parser.add_argument_group( + title="CacheConfig", + description=CacheConfig.__doc__, + ) + cache_group.add_argument("--block-size", **cache_kwargs["block_size"]) + cache_group.add_argument("--gpu-memory-utilization", + **cache_kwargs["gpu_memory_utilization"]) + cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"]) + cache_group.add_argument("--kv-cache-dtype", + **cache_kwargs["cache_dtype"]) + cache_group.add_argument("--num-gpu-blocks-override", + **cache_kwargs["num_gpu_blocks_override"]) + cache_group.add_argument("--enable-prefix-caching", + **cache_kwargs["enable_prefix_caching"]) + cache_group.add_argument("--prefix-caching-hash-algo", + **cache_kwargs["prefix_caching_hash_algo"]) + cache_group.add_argument("--cpu-offload-gb", + **cache_kwargs["cpu_offload_gb"]) + 
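+        # Illustrative note: boolean config fields are registered with
+        # argparse.BooleanOptionalAction by get_kwargs, so each one also gets
+        # a negated form, e.g.
+        #     --enable-prefix-caching / --no-enable-prefix-caching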
cache_group.add_argument("--calculate-kv-scales", + **cache_kwargs["calculate_kv_scales"]) + + # Tokenizer arguments + tokenizer_kwargs = get_kwargs(TokenizerPoolConfig) + tokenizer_group = parser.add_argument_group( + title="TokenizerPoolConfig", + description=TokenizerPoolConfig.__doc__, + ) + tokenizer_group.add_argument("--tokenizer-pool-size", + **tokenizer_kwargs["pool_size"]) + tokenizer_group.add_argument("--tokenizer-pool-type", + **tokenizer_kwargs["pool_type"]) + tokenizer_group.add_argument("--tokenizer-pool-extra-config", + **tokenizer_kwargs["extra_config"]) + + # Multimodal related configs + multimodal_kwargs = get_kwargs(MultiModalConfig) + multimodal_group = parser.add_argument_group( + title="MultiModalConfig", + description=MultiModalConfig.__doc__, + ) + multimodal_group.add_argument("--limit-mm-per-prompt", + **multimodal_kwargs["limit_per_prompt"]) + multimodal_group.add_argument( + "--mm-processor-kwargs", + **multimodal_kwargs["mm_processor_kwargs"]) + multimodal_group.add_argument( + "--disable-mm-preprocessor-cache", + **multimodal_kwargs["disable_mm_preprocessor_cache"]) + + # LoRA related configs + lora_kwargs = get_kwargs(LoRAConfig) + lora_group = parser.add_argument_group( + title="LoRAConfig", + description=LoRAConfig.__doc__, + ) + lora_group.add_argument( + "--enable-lora", + action=argparse.BooleanOptionalAction, + help="If True, enable handling of LoRA adapters.") + lora_group.add_argument("--enable-lora-bias", + **lora_kwargs["bias_enabled"]) + lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"]) + lora_group.add_argument("--max-lora-rank", + **lora_kwargs["max_lora_rank"]) + lora_group.add_argument("--lora-extra-vocab-size", + **lora_kwargs["lora_extra_vocab_size"]) + lora_group.add_argument( + "--lora-dtype", + **lora_kwargs["lora_dtype"], + ) + lora_group.add_argument("--long-lora-scaling-factors", + **lora_kwargs["long_lora_scaling_factors"]) + lora_group.add_argument("--max-cpu-loras", + **lora_kwargs["max_cpu_loras"]) + lora_group.add_argument("--fully-sharded-loras", + **lora_kwargs["fully_sharded_loras"]) + + # PromptAdapter related configs + prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig) + prompt_adapter_group = parser.add_argument_group( + title="PromptAdapterConfig", + description=PromptAdapterConfig.__doc__, + ) + prompt_adapter_group.add_argument( + "--enable-prompt-adapter", + action=argparse.BooleanOptionalAction, + help="If True, enable handling of PromptAdapters.") + prompt_adapter_group.add_argument( + "--max-prompt-adapters", + **prompt_adapter_kwargs["max_prompt_adapters"]) + prompt_adapter_group.add_argument( + "--max-prompt-adapter-token", + **prompt_adapter_kwargs["max_prompt_adapter_token"]) + + # Device arguments + device_kwargs = get_kwargs(DeviceConfig) + device_group = parser.add_argument_group( + title="DeviceConfig", + description=DeviceConfig.__doc__, + ) + device_group.add_argument("--device", + **device_kwargs["device"], + deprecated=True) + + # Speculative arguments + speculative_group = parser.add_argument_group( + title="SpeculativeConfig", + description=SpeculativeConfig.__doc__, + ) + speculative_group.add_argument( + "--speculative-config", + type=json.loads, + default=None, + help="The configurations for speculative decoding. 
Should be a " + "JSON string.") + + # Observability arguments + observability_kwargs = get_kwargs(ObservabilityConfig) + observability_group = parser.add_argument_group( + title="ObservabilityConfig", + description=ObservabilityConfig.__doc__, + ) + observability_group.add_argument( + "--show-hidden-metrics-for-version", + **observability_kwargs["show_hidden_metrics_for_version"]) + observability_group.add_argument( + "--otlp-traces-endpoint", + **observability_kwargs["otlp_traces_endpoint"]) + # TODO: generalise this special case + choices = observability_kwargs["collect_detailed_traces"]["choices"] + metavar = f"{{{','.join(choices)}}}" + observability_kwargs["collect_detailed_traces"]["metavar"] = metavar + observability_kwargs["collect_detailed_traces"]["choices"] += [ + ",".join(p) + for p in permutations(get_args(DetailedTraceModules), r=2) + ] + observability_group.add_argument( + "--collect-detailed-traces", + **observability_kwargs["collect_detailed_traces"]) + + # Scheduler arguments + scheduler_kwargs = get_kwargs(SchedulerConfig) + scheduler_group = parser.add_argument_group( + title="SchedulerConfig", + description=SchedulerConfig.__doc__, + ) + scheduler_group.add_argument( + "--max-num-batched-tokens", + **scheduler_kwargs["max_num_batched_tokens"]) + scheduler_group.add_argument("--max-num-seqs", + **scheduler_kwargs["max_num_seqs"]) + scheduler_group.add_argument( + "--max-num-partial-prefills", + **scheduler_kwargs["max_num_partial_prefills"]) + scheduler_group.add_argument( + "--max-long-partial-prefills", + **scheduler_kwargs["max_long_partial_prefills"]) + scheduler_group.add_argument('--cuda-graph-sizes', + **scheduler_kwargs["cuda_graph_sizes"]) + scheduler_group.add_argument( + "--long-prefill-token-threshold", + **scheduler_kwargs["long_prefill_token_threshold"]) + scheduler_group.add_argument("--num-lookahead-slots", + **scheduler_kwargs["num_lookahead_slots"]) + scheduler_group.add_argument("--scheduler-delay-factor", + **scheduler_kwargs["delay_factor"]) + scheduler_group.add_argument("--preemption-mode", + **scheduler_kwargs["preemption_mode"]) + scheduler_group.add_argument("--num-scheduler-steps", + **scheduler_kwargs["num_scheduler_steps"]) + scheduler_group.add_argument( + "--multi-step-stream-outputs", + **scheduler_kwargs["multi_step_stream_outputs"]) + scheduler_group.add_argument("--scheduling-policy", + **scheduler_kwargs["policy"]) + scheduler_group.add_argument( + "--enable-chunked-prefill", + **scheduler_kwargs["enable_chunked_prefill"]) + scheduler_group.add_argument( + "--disable-chunked-mm-input", + **scheduler_kwargs["disable_chunked_mm_input"]) + scheduler_group.add_argument("--scheduler-cls", + **scheduler_kwargs["scheduler_cls"]) + scheduler_group.add_argument( + "--disable-hybrid-kv-cache-manager", + **scheduler_kwargs["disable_hybrid_kv_cache_manager"]) + + # vLLM arguments + vllm_kwargs = get_kwargs(VllmConfig) + vllm_group = parser.add_argument_group( + title="VllmConfig", + description=VllmConfig.__doc__, + ) + vllm_group.add_argument("--kv-transfer-config", + **vllm_kwargs["kv_transfer_config"]) + vllm_group.add_argument('--kv-events-config', + **vllm_kwargs["kv_events_config"]) + vllm_group.add_argument("--compilation-config", "-O", + **vllm_kwargs["compilation_config"]) + vllm_group.add_argument("--additional-config", + **vllm_kwargs["additional_config"]) + + # Other arguments + parser.add_argument('--use-v2-block-manager', + action='store_true', + default=True, + deprecated=True, + help='[DEPRECATED] block manager v1 has been ' + 
'removed and SelfAttnBlockSpaceManager (i.e. ' + 'block manager v2) is now the default. ' + 'Setting this flag to True or False' + ' has no effect on vLLM behavior.') + parser.add_argument('--disable-log-stats', + action='store_true', + help='Disable logging statistics.') + + return parser + + @classmethod + def from_cli_args(cls, args: argparse.Namespace): + # Get the list of attributes of this dataclass. + attrs = [attr.name for attr in dataclasses.fields(cls)] + # Set the attributes from the parsed arguments. + engine_args = cls(**{attr: getattr(args, attr) for attr in attrs}) + return engine_args + + def create_model_config(self) -> ModelConfig: + # gguf file needs a specific model loader and doesn't use hf_repo + if check_gguf_file(self.model): + self.quantization = self.load_format = "gguf" + + # NOTE: This is to allow model loading from S3 in CI + if (not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3 + and self.model in MODELS_ON_S3 + and self.load_format == LoadFormat.AUTO): # noqa: E501 + self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}" + self.load_format = LoadFormat.RUNAI_STREAMER + + return ModelConfig( + model=self.model, + hf_config_path=self.hf_config_path, + task=self.task, + tokenizer=self.tokenizer, + tokenizer_mode=self.tokenizer_mode, + trust_remote_code=self.trust_remote_code, + allowed_local_media_path=self.allowed_local_media_path, + dtype=self.dtype, + seed=self.seed, + revision=self.revision, + code_revision=self.code_revision, + rope_scaling=self.rope_scaling, + rope_theta=self.rope_theta, + hf_token=self.hf_token, + hf_overrides=self.hf_overrides, + tokenizer_revision=self.tokenizer_revision, + max_model_len=self.max_model_len, + quantization=self.quantization, + enforce_eager=self.enforce_eager, + max_seq_len_to_capture=self.max_seq_len_to_capture, + max_logprobs=self.max_logprobs, + disable_sliding_window=self.disable_sliding_window, + disable_cascade_attn=self.disable_cascade_attn, + skip_tokenizer_init=self.skip_tokenizer_init, + enable_prompt_embeds=self.enable_prompt_embeds, + served_model_name=self.served_model_name, + limit_mm_per_prompt=self.limit_mm_per_prompt, + use_async_output_proc=not self.disable_async_output_proc, + config_format=self.config_format, + mm_processor_kwargs=self.mm_processor_kwargs, + disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache, + override_neuron_config=self.override_neuron_config, + override_pooler_config=self.override_pooler_config, + logits_processor_pattern=self.logits_processor_pattern, + generation_config=self.generation_config, + override_generation_config=self.override_generation_config, + enable_sleep_mode=self.enable_sleep_mode, + model_impl=self.model_impl, + ) + + def create_load_config(self) -> LoadConfig: + + if self.quantization == "bitsandbytes": + self.load_format = "bitsandbytes" + + return LoadConfig( + load_format=self.load_format, + download_dir=self.download_dir, + model_loader_extra_config=self.model_loader_extra_config, + ignore_patterns=self.ignore_patterns, + use_tqdm_on_load=self.use_tqdm_on_load, + pt_load_map_location=self.pt_load_map_location, + ) + + def create_speculative_config( + self, + target_model_config: ModelConfig, + target_parallel_config: ParallelConfig, + enable_chunked_prefill: bool, + disable_log_stats: bool, + ) -> Optional["SpeculativeConfig"]: + """Initializes and returns a SpeculativeConfig object based on + `speculative_config`. + + This function utilizes `speculative_config` to create a + SpeculativeConfig object. 
The `speculative_config` can either be + provided as a JSON string input via CLI arguments or directly as a + dictionary from the engine. + """ + if self.speculative_config is None: + return None + + # Note(Shangming): These parameters are not obtained from the cli arg + # '--speculative-config' and must be passed in when creating the engine + # config. + self.speculative_config.update({ + "target_model_config": target_model_config, + "target_parallel_config": target_parallel_config, + "enable_chunked_prefill": enable_chunked_prefill, + "disable_log_stats": disable_log_stats, + }) + speculative_config = SpeculativeConfig.from_dict( + self.speculative_config) + + return speculative_config + + def create_engine_config( + self, + usage_context: Optional[UsageContext] = None, + ) -> VllmConfig: + """ + Create the VllmConfig. + + NOTE: for autoselection of V0 vs V1 engine, we need to + create the ModelConfig first, since ModelConfig's attrs + (e.g. the model arch) are needed to make the decision. + + This function set VLLM_USE_V1=X if VLLM_USE_V1 is + unspecified by the user. + + If VLLM_USE_V1 is specified by the user but the VllmConfig + is incompatible, we raise an error. + """ + from vllm.platforms import current_platform + current_platform.pre_register_and_update() + + device_config = DeviceConfig(device=current_platform.device_type) + model_config = self.create_model_config() + + # * If VLLM_USE_V1 is unset, we enable V1 for "supported features" + # and fall back to V0 for experimental or unsupported features. + # * If VLLM_USE_V1=1, we enable V1 for supported + experimental + # features and raise error for unsupported features. + # * If VLLM_USE_V1=0, we disable V1. + use_v1 = False + try_v1 = envs.VLLM_USE_V1 or not envs.is_set("VLLM_USE_V1") + if try_v1 and self._is_v1_supported_oracle(model_config): + use_v1 = True + + # If user explicitly set VLLM_USE_V1, sanity check we respect it. + if envs.is_set("VLLM_USE_V1"): + assert use_v1 == envs.VLLM_USE_V1 + # Otherwise, set the VLLM_USE_V1 variable globally. + else: + envs.set_vllm_use_v1(use_v1) + + # Set default arguments for V0 or V1 Engine. + if use_v1: + self._set_default_args_v1(usage_context) + else: + self._set_default_args_v0(model_config) + + assert self.enable_chunked_prefill is not None + + if envs.VLLM_ATTENTION_BACKEND in [STR_DUAL_CHUNK_FLASH_ATTN_VAL]: + assert self.enforce_eager, ( + "Cuda graph is not supported with DualChunkFlashAttention. " + "To run the model in eager mode, set 'enforce_eager=True' " + "or use '--enforce-eager' in the CLI.") + assert current_platform.is_cuda(), ( + "DualChunkFlashAttention is only supported on CUDA platform.") + assert not use_v1, ( + "DualChunkFlashAttention is not supported on V1 engine. " + "To run the model in V0 engine, try set 'VLLM_USE_V1=0'") + + cache_config = CacheConfig( + block_size=self.block_size, + gpu_memory_utilization=self.gpu_memory_utilization, + swap_space=self.swap_space, + cache_dtype=self.kv_cache_dtype, + is_attention_free=model_config.is_attention_free, + num_gpu_blocks_override=self.num_gpu_blocks_override, + sliding_window=model_config.get_sliding_window(), + enable_prefix_caching=self.enable_prefix_caching, + prefix_caching_hash_algo=self.prefix_caching_hash_algo, + cpu_offload_gb=self.cpu_offload_gb, + calculate_kv_scales=self.calculate_kv_scales, + ) + + # Get the current placement group if Ray is initialized and + # we are in a Ray actor. If so, then the placement group will be + # passed to spawned processes. 
+ placement_group = None + if is_in_ray_actor(): + import ray + + # This call initializes Ray automatically if it is not initialized, + # but we should not do this here. + placement_group = ray.util.get_current_placement_group() + + # Local DP size defaults to global DP size if not set. + data_parallel_size_local = self.data_parallel_size if ( + self.data_parallel_size_local + is None) else self.data_parallel_size_local + + # DP address, used in multi-node case for torch distributed group + # and ZMQ sockets. + if self.data_parallel_address is None: + if self.data_parallel_backend == "ray": + host_ip = get_ip() + logger.info( + "Using host IP %s as ray-based data parallel address", + host_ip) + data_parallel_address = host_ip + else: + assert self.data_parallel_backend == "mp", ( + "data_parallel_backend can only be ray or mp, got %s", + self.data_parallel_backend) + data_parallel_address = ParallelConfig.data_parallel_master_ip + else: + data_parallel_address = self.data_parallel_address + + # This port is only used when there are remote data parallel engines, + # otherwise the local IPC transport is used. + data_parallel_rpc_port = self.data_parallel_rpc_port if ( + self.data_parallel_rpc_port + is not None) else ParallelConfig.data_parallel_rpc_port + + data_parallel_backend = self.data_parallel_backend + + parallel_config = ParallelConfig( + pipeline_parallel_size=self.pipeline_parallel_size, + tensor_parallel_size=self.tensor_parallel_size, + data_parallel_size=self.data_parallel_size, + data_parallel_size_local=data_parallel_size_local, + data_parallel_master_ip=data_parallel_address, + data_parallel_rpc_port=data_parallel_rpc_port, + data_parallel_backend=data_parallel_backend, + enable_expert_parallel=self.enable_expert_parallel, + max_parallel_loading_workers=self.max_parallel_loading_workers, + disable_custom_all_reduce=self.disable_custom_all_reduce, + ray_workers_use_nsight=self.ray_workers_use_nsight, + placement_group=placement_group, + distributed_executor_backend=self.distributed_executor_backend, + worker_cls=self.worker_cls, + worker_extension_cls=self.worker_extension_cls, + enable_multimodal_encoder_data_parallel=self. 
+            enable_multimodal_encoder_data_parallel,
+        )
+
+        speculative_config = self.create_speculative_config(
+            target_model_config=model_config,
+            target_parallel_config=parallel_config,
+            enable_chunked_prefill=self.enable_chunked_prefill,
+            disable_log_stats=self.disable_log_stats,
+        )
+
+        # Reminder: Please update docs/features/compatibility_matrix.md
+        # if the feature combo becomes valid.
+        if self.num_scheduler_steps > 1:
+            if speculative_config is not None:
+                raise ValueError("Speculative decoding is not supported with "
+                                 "multi-step (--num-scheduler-steps > 1)")
+            if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
+                raise ValueError("Multi-Step Chunked-Prefill is not supported "
+                                 "for pipeline-parallel-size > 1")
+            from vllm.platforms import current_platform
+            if current_platform.is_cpu():
+                logger.warning("Multi-Step (--num-scheduler-steps > 1) is "
+                               "currently not supported for CPUs and has been "
+                               "disabled.")
+                self.num_scheduler_steps = 1
+
+        # Make sure num_lookahead_slots is set to the higher value depending
+        # on whether we are using speculative decoding or multi-step.
+        num_lookahead_slots = max(self.num_lookahead_slots,
+                                  self.num_scheduler_steps - 1)
+        num_lookahead_slots = num_lookahead_slots \
+            if speculative_config is None \
+            else speculative_config.num_lookahead_slots
+
+        scheduler_config = SchedulerConfig(
+            runner_type=model_config.runner_type,
+            max_num_batched_tokens=self.max_num_batched_tokens,
+            max_num_seqs=self.max_num_seqs,
+            max_model_len=model_config.max_model_len,
+            cuda_graph_sizes=self.cuda_graph_sizes,
+            num_lookahead_slots=num_lookahead_slots,
+            delay_factor=self.scheduler_delay_factor,
+            enable_chunked_prefill=self.enable_chunked_prefill,
+            disable_chunked_mm_input=self.disable_chunked_mm_input,
+            is_multimodal_model=model_config.is_multimodal_model,
+            preemption_mode=self.preemption_mode,
+            num_scheduler_steps=self.num_scheduler_steps,
+            multi_step_stream_outputs=self.multi_step_stream_outputs,
+            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
+                             and parallel_config.use_ray),
+            policy=self.scheduling_policy,
+            scheduler_cls=self.scheduler_cls,
+            max_num_partial_prefills=self.max_num_partial_prefills,
+            max_long_partial_prefills=self.max_long_partial_prefills,
+            long_prefill_token_threshold=self.long_prefill_token_threshold,
+            disable_hybrid_kv_cache_manager=self.
+ disable_hybrid_kv_cache_manager, + ) + + lora_config = LoRAConfig( + bias_enabled=self.enable_lora_bias, + max_lora_rank=self.max_lora_rank, + max_loras=self.max_loras, + fully_sharded_loras=self.fully_sharded_loras, + lora_extra_vocab_size=self.lora_extra_vocab_size, + long_lora_scaling_factors=self.long_lora_scaling_factors, + lora_dtype=self.lora_dtype, + max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras + and self.max_cpu_loras > 0 else None) if self.enable_lora else None + + # bitsandbytes pre-quantized model need a specific model loader + if model_config.quantization == "bitsandbytes": + self.quantization = self.load_format = "bitsandbytes" + + load_config = self.create_load_config() + + prompt_adapter_config = PromptAdapterConfig( + max_prompt_adapters=self.max_prompt_adapters, + max_prompt_adapter_token=self.max_prompt_adapter_token) \ + if self.enable_prompt_adapter else None + + decoding_config = DecodingConfig( + backend=self.guided_decoding_backend, + disable_fallback=self.guided_decoding_disable_fallback, + disable_any_whitespace=self.guided_decoding_disable_any_whitespace, + disable_additional_properties=\ + self.guided_decoding_disable_additional_properties, + reasoning_backend=self.reasoning_parser + ) + + observability_config = ObservabilityConfig( + show_hidden_metrics_for_version=self. + show_hidden_metrics_for_version, + otlp_traces_endpoint=self.otlp_traces_endpoint, + collect_detailed_traces=self.collect_detailed_traces, + ) + + config = VllmConfig( + model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + lora_config=lora_config, + speculative_config=speculative_config, + load_config=load_config, + decoding_config=decoding_config, + observability_config=observability_config, + prompt_adapter_config=prompt_adapter_config, + compilation_config=self.compilation_config, + kv_transfer_config=self.kv_transfer_config, + kv_events_config=self.kv_events_config, + additional_config=self.additional_config, + ) + + return config + + def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: + """Oracle for whether to use V0 or V1 Engine by default.""" + + ############################################################# + # Unsupported Feature Flags on V1. 
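+        # Illustrative note: each check below reports the offending feature
+        # via _raise_or_fallback, which raises NotImplementedError when the
+        # user explicitly set VLLM_USE_V1=1 and otherwise only logs a warning
+        # so that this method can return False and fall back to V0, e.g.
+        #     _raise_or_fallback(feature_name="--preemption-mode",
+        #                        recommend_to_remove=True)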
+ + if self.load_format == LoadFormat.SHARDED_STATE.value: + _raise_or_fallback( + feature_name=f"--load_format {self.load_format}", + recommend_to_remove=False) + return False + + if (self.logits_processor_pattern + != EngineArgs.logits_processor_pattern): + _raise_or_fallback(feature_name="--logits-processor-pattern", + recommend_to_remove=False) + return False + + if self.preemption_mode != SchedulerConfig.preemption_mode: + _raise_or_fallback(feature_name="--preemption-mode", + recommend_to_remove=True) + return False + + if (self.disable_async_output_proc + != EngineArgs.disable_async_output_proc): + _raise_or_fallback(feature_name="--disable-async-output-proc", + recommend_to_remove=True) + return False + + if self.scheduling_policy != SchedulerConfig.policy: + _raise_or_fallback(feature_name="--scheduling-policy", + recommend_to_remove=False) + return False + + if self.num_scheduler_steps != SchedulerConfig.num_scheduler_steps: + _raise_or_fallback(feature_name="--num-scheduler-steps", + recommend_to_remove=True) + return False + + if self.scheduler_delay_factor != SchedulerConfig.delay_factor: + _raise_or_fallback(feature_name="--scheduler-delay-factor", + recommend_to_remove=True) + return False + + if self.guided_decoding_backend not in get_args( + GuidedDecodingBackendV1): + _raise_or_fallback( + feature_name= + f"--guided-decoding-backend={self.guided_decoding_backend}", + recommend_to_remove=False) + return False + + # Need at least Ampere for now (FA support required). + # Skip this check if we are running on a non-GPU platform, + # or if the device capability is not available + # (e.g. in a Ray actor without GPUs). + from vllm.platforms import CpuArchEnum, current_platform + if (current_platform.is_cuda() + and current_platform.get_device_capability() + and current_platform.get_device_capability().major < 8): + _raise_or_fallback(feature_name="Compute Capability < 8.0", + recommend_to_remove=False) + return False + + # No Fp8 KV cache so far. + if self.kv_cache_dtype != "auto": + fp8_attention = self.kv_cache_dtype.startswith("fp8") + will_use_fa = ( + current_platform.is_cuda() + and not envs.is_set("VLLM_ATTENTION_BACKEND") + ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1" + supported = False + if current_platform.is_rocm(): + supported = True + elif fp8_attention and will_use_fa: + from vllm.attention.utils.fa_utils import ( + flash_attn_supports_fp8) + supported = flash_attn_supports_fp8() + if not supported: + _raise_or_fallback(feature_name="--kv-cache-dtype", + recommend_to_remove=False) + return False + + # No Prompt Adapter so far. + if self.enable_prompt_adapter: + _raise_or_fallback(feature_name="--enable-prompt-adapter", + recommend_to_remove=False) + return False + + # No text embedding inputs so far. + if self.enable_prompt_embeds: + _raise_or_fallback(feature_name="--enable-prompt-embeds", + recommend_to_remove=False) + return False + + # Only Fp16 and Bf16 dtypes since we only support FA. + V1_SUPPORTED_DTYPES = [torch.bfloat16, torch.float16] + if model_config.dtype not in V1_SUPPORTED_DTYPES: + _raise_or_fallback(feature_name=f"--dtype {model_config.dtype}", + recommend_to_remove=False) + return False + + # No Embedding Models so far. + if model_config.task not in ["generate"]: + _raise_or_fallback(feature_name=f"--task {model_config.task}", + recommend_to_remove=False) + return False + + # No Mamba or Encoder-Decoder so far. 
+ if not model_config.is_v1_compatible: + _raise_or_fallback(feature_name=model_config.architectures, + recommend_to_remove=False) + return False + + # No Concurrent Partial Prefills so far. + if (self.max_num_partial_prefills + != SchedulerConfig.max_num_partial_prefills + or self.max_long_partial_prefills + != SchedulerConfig.max_long_partial_prefills): + _raise_or_fallback(feature_name="Concurrent Partial Prefill", + recommend_to_remove=False) + return False + + # No OTLP observability so far. + if (self.otlp_traces_endpoint or self.collect_detailed_traces): + _raise_or_fallback(feature_name="--otlp-traces-endpoint", + recommend_to_remove=False) + return False + + # V1 supports N-gram, Medusa, and Eagle speculative decoding. + is_ngram_enabled = False + is_eagle_enabled = False + is_medusa_enabled = False + if self.speculative_config is not None: + # This is supported but experimental (handled below). + speculative_method = self.speculative_config.get("method") + if speculative_method: + if speculative_method in ("ngram", "[ngram]"): + is_ngram_enabled = True + elif speculative_method == "medusa": + is_medusa_enabled = True + elif speculative_method in ("eagle", "eagle3", "deepseek_mtp"): + is_eagle_enabled = True + else: + speculative_model = self.speculative_config.get("model") + if speculative_model in ("ngram", "[ngram]"): + is_ngram_enabled = True + if not (is_ngram_enabled or is_eagle_enabled or is_medusa_enabled): + # Other speculative decoding methods are not supported yet. + _raise_or_fallback(feature_name="Speculative Decoding", + recommend_to_remove=False) + return False + + # No XFormers so far. + V1_BACKENDS = [ + "FLASH_ATTN_VLLM_V1", + "FLASH_ATTN", + "PALLAS", + "PALLAS_VLLM_V1", + "TRITON_ATTN_VLLM_V1", + "TRITON_MLA", + "CUTLASS_MLA_VLLM_V1", + "FLASHMLA", + "FLASHINFER", + "FLASHINFER_VLLM_V1", + "ROCM_AITER_MLA", + "TORCH_SDPA_VLLM_V1", + "FLEX_ATTENTION", + ] + if (envs.is_set("VLLM_ATTENTION_BACKEND") + and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS): + name = f"VLLM_ATTENTION_BACKEND={envs.VLLM_ATTENTION_BACKEND}" + _raise_or_fallback(feature_name=name, recommend_to_remove=True) + return False + + # Platforms must decide if they can support v1 for this model + if not current_platform.supports_v1(model_config=model_config): + _raise_or_fallback( + feature_name=f"device type={current_platform.device_type}", + recommend_to_remove=False) + return False + ############################################################# + # Experimental Features - allow users to opt in. + + # Signal Handlers requires running in main thread. + if (threading.current_thread() != threading.main_thread() + and _warn_or_fallback("Engine in background thread")): + return False + + if (self.pipeline_parallel_size > 1 + and self.distributed_executor_backend + not in (ParallelConfig.distributed_executor_backend, "ray", + "mp", "external_launcher")): + name = "Pipeline Parallelism without Ray distributed executor " \ + "or multiprocessing executor or external launcher" + _raise_or_fallback(feature_name=name, recommend_to_remove=False) + return False + + # Non-[CUDA, TPU] may be supported on V1, but off by default for now. 
+        v0_hardware = not any(
+            (current_platform.is_cuda(), current_platform.is_tpu(),
+             (current_platform.is_cpu()
+              and current_platform.get_cpu_architecture() == CpuArchEnum.X86)))
+        if v0_hardware and _warn_or_fallback(  # noqa: SIM103
+                current_platform.device_name):
+            return False
+        #############################################################
+
+        return True
+
+    def _set_default_args_v0(self, model_config: ModelConfig) -> None:
+        """Set Default Arguments for V0 Engine."""
+
+        max_model_len = model_config.max_model_len
+        use_long_context = max_model_len > 32768
+        if self.enable_chunked_prefill is None:
+            # Chunked prefill not supported for Multimodal or MLA in V0.
+            if model_config.is_multimodal_model or model_config.use_mla:
+                self.enable_chunked_prefill = False
+
+            # Enable chunked prefill by default for long context (> 32K)
+            # models to avoid OOM errors in initial memory profiling phase.
+            elif use_long_context:
+                from vllm.platforms import current_platform
+                is_gpu = current_platform.is_cuda()
+                use_sliding_window = (model_config.get_sliding_window()
+                                      is not None)
+                use_spec_decode = self.speculative_config is not None
+
+                if (is_gpu and not use_sliding_window and not use_spec_decode
+                        and not self.enable_lora
+                        and not self.enable_prompt_adapter
+                        and model_config.runner_type != "pooling"):
+                    self.enable_chunked_prefill = True
+                    logger.warning(
+                        "Chunked prefill is enabled by default for models "
+                        "with max_model_len > 32K. Chunked prefill might "
+                        "not work with some features or models. If you "
+                        "encounter any issues, please disable by launching "
+                        "with --enable-chunked-prefill=False.")
+
+        if self.enable_chunked_prefill is None:
+            self.enable_chunked_prefill = False
+
+        if not self.enable_chunked_prefill and use_long_context:
+            logger.warning(
+                "The model has a long context length (%s). This may cause "
+                "OOM during the initial memory profiling phase, or result "
+                "in low performance due to small KV cache size. Consider "
+                "setting --max-model-len to a smaller value.", max_model_len)
+        elif (self.enable_chunked_prefill
+              and model_config.runner_type == "pooling"):
+            msg = "Chunked prefill is not supported for pooling models"
+            raise ValueError(msg)
+
+        # If using prefix caching, we must set a hash algo.
+        if self.enable_prefix_caching:
+            # Disable prefix caching for multimodal models for VLLM_V0.
+            if model_config.is_multimodal_model:
+                logger.warning(
+                    "--enable-prefix-caching is not supported for multimodal "
+                    "models in V0 and has been disabled.")
+                self.enable_prefix_caching = False
+
+            # VLLM_V0 only supports the builtin hash algo for prefix caching.
+            if self.prefix_caching_hash_algo == "sha256":
+                raise ValueError(
+                    "sha256 is not supported for prefix caching in V0 engine. "
+                    "Please use 'builtin'.")
+
+        # Set max_num_seqs to 256 for VLLM_V0.
+        if self.max_num_seqs is None:
+            self.max_num_seqs = 256
+
+    def _set_default_args_v1(self, usage_context: UsageContext) -> None:
+        """Set Default Arguments for V1 Engine."""
+
+        # V1 always uses chunked prefills.
+        self.enable_chunked_prefill = True
+
+        # V1 enables prefix caching by default.
+        if self.enable_prefix_caching is None:
+            self.enable_prefix_caching = True
+
+        # V1 should use the new scheduler by default.
+        # Swap it only if this arg is set to the original V0 default.
+        if self.scheduler_cls == EngineArgs.scheduler_cls:
+            self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"
+
+        # When no user override, set the default values based on the usage
+        # context.
+ # Use different default values for different hardware. + + # Try to query the device name on the current platform. If it fails, + # it may be because the platform that imports vLLM is not the same + # as the platform that vLLM is running on (e.g. the case of scaling + # vLLM with Ray) and has no GPUs. In this case we use the default + # values for non-H100/H200 GPUs. + from vllm.platforms import current_platform + try: + device_memory = current_platform.get_device_total_memory() + device_name = current_platform.get_device_name().lower() + except Exception: + # This is only used to set default_max_num_batched_tokens + device_memory = 0 + + # NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces + # throughput, see PR #17885 for more details. + # So here we do an extra device name check to prevent such regression. + if device_memory >= 70 * GiB_bytes and "a100" not in device_name: + # For GPUs like H100 and MI300x, use larger default values. + default_max_num_batched_tokens = { + UsageContext.LLM_CLASS: 16384, + UsageContext.OPENAI_API_SERVER: 8192, + } + default_max_num_seqs = 1024 + else: + # TODO(woosuk): Tune the default values for other hardware. + default_max_num_batched_tokens = { + UsageContext.LLM_CLASS: 8192, + UsageContext.OPENAI_API_SERVER: 2048, + } + default_max_num_seqs = 256 + + # tpu specific default values. + if current_platform.is_tpu(): + default_max_num_batched_tokens_tpu = { + UsageContext.LLM_CLASS: { + 'V6E': 2048, + 'V5E': 1024, + 'V5P': 512, + }, + UsageContext.OPENAI_API_SERVER: { + 'V6E': 1024, + 'V5E': 512, + 'V5P': 256, + } + } + + use_context_value = usage_context.value if usage_context else None + if (self.max_num_batched_tokens is None + and usage_context in default_max_num_batched_tokens): + if current_platform.is_tpu(): + chip_name = current_platform.get_device_name() + if chip_name in default_max_num_batched_tokens_tpu[ + usage_context]: + self.max_num_batched_tokens = \ + default_max_num_batched_tokens_tpu[ + usage_context][chip_name] + else: + self.max_num_batched_tokens = \ + default_max_num_batched_tokens[usage_context] + else: + self.max_num_batched_tokens = default_max_num_batched_tokens[ + usage_context] + logger.debug( + "Setting max_num_batched_tokens to %d for %s usage context.", + self.max_num_batched_tokens, use_context_value) + + if self.max_num_seqs is None: + self.max_num_seqs = default_max_num_seqs + + logger.debug("Setting max_num_seqs to %d for %s usage context.", + self.max_num_seqs, use_context_value) + + +@dataclass +class AsyncEngineArgs(EngineArgs): + """Arguments for asynchronous vLLM engine.""" + disable_log_requests: bool = False + + @staticmethod + def add_cli_args(parser: FlexibleArgumentParser, + async_args_only: bool = False) -> FlexibleArgumentParser: + # Initialize plugin to update the parser, for example, The plugin may + # adding a new kind of quantization method to --quantization argument or + # a new device to --device argument. 
+        load_general_plugins()
+        if not async_args_only:
+            parser = EngineArgs.add_cli_args(parser)
+        parser.add_argument('--disable-log-requests',
+                            action='store_true',
+                            help='Disable logging requests.')
+        from vllm.platforms import current_platform
+        current_platform.pre_register_and_update(parser)
+        return parser
+
+
+def _raise_or_fallback(feature_name: str, recommend_to_remove: bool):
+    if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
+        raise NotImplementedError(
+            f"VLLM_USE_V1=1 is not supported with {feature_name}.")
+    msg = f"{feature_name} is not supported by the V1 Engine. "
+    msg += "Falling back to V0. "
+    if recommend_to_remove:
+        msg += f"We recommend removing {feature_name} from your config "
+        msg += "in favor of the V1 Engine."
+    logger.warning(msg)
+
+
+def _warn_or_fallback(feature_name: str) -> bool:
+    if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
+        logger.warning(
+            "Detected VLLM_USE_V1=1 with %s. Usage should "
+            "be considered experimental. Please report any "
+            "issues on GitHub.", feature_name)
+        should_exit = False
+    else:
+        logger.info(
+            "%s is experimental on VLLM_USE_V1=1. "
+            "Falling back to V0 Engine.", feature_name)
+        should_exit = True
+    return should_exit
+
+
+def human_readable_int(value):
+    """Parse human-readable integers like '1k', '2M', etc.
+    Decimal values are supported with decimal multipliers (lowercase suffixes).
+
+    Examples:
+    - '1k' -> 1,000
+    - '1K' -> 1,024
+    - '25.6k' -> 25,600
+    """
+    value = value.strip()
+    match = re.fullmatch(r'(\d+(?:\.\d+)?)([kKmMgGtT])', value)
+    if match:
+        decimal_multiplier = {
+            'k': 10**3,
+            'm': 10**6,
+            'g': 10**9,
+        }
+        binary_multiplier = {
+            'K': 2**10,
+            'M': 2**20,
+            'G': 2**30,
+        }
+
+        number, suffix = match.groups()
+        if suffix in decimal_multiplier:
+            mult = decimal_multiplier[suffix]
+            return int(float(number) * mult)
+        elif suffix in binary_multiplier:
+            mult = binary_multiplier[suffix]
+            # Do not allow decimals with binary multipliers.
+            try:
+                return int(number) * mult
+            except ValueError as e:
+                raise argparse.ArgumentTypeError("Decimals are not allowed " \
+                    f"with binary suffixes like {suffix}. Did you mean to use " \
+                    f"{number}{suffix.lower()} instead?") from e
+
+    # Regular plain number.
+ return int(value) + + +# These functions are used by sphinx to build the documentation +def _engine_args_parser(): + return EngineArgs.add_cli_args(FlexibleArgumentParser()) + + +def _async_engine_args_parser(): + return AsyncEngineArgs.add_cli_args(FlexibleArgumentParser(), + async_args_only=True) diff --git a/engine/async_llm_engine.py b/engine/async_llm_engine.py new file mode 100644 index 0000000..3d7d280 --- /dev/null +++ b/engine/async_llm_engine.py @@ -0,0 +1,1200 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import copy +import time +import weakref +from functools import partial +from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List, + Mapping, Optional, Set, Tuple, Type, Union) +from weakref import ReferenceType + +import vllm.envs as envs +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig, VllmConfig) +from vllm.core.scheduler import SchedulerOutputs +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_timeout import asyncio_timeout +from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState +from vllm.engine.metrics_types import StatLoggerBase +from vllm.engine.protocol import EngineClient +from vllm.executor.executor_base import ExecutorBase +from vllm.inputs import PromptType +from vllm.inputs.preprocess import InputPreprocessor +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.guided_decoding import ( + get_guided_decoding_logits_processor) +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.outputs import PoolingRequestOutput, RequestOutput +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.sequence import ExecuteModelRequest +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.usage.usage_lib import UsageContext +from vllm.utils import Device, weak_bind + +logger = init_logger(__name__) +ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S + + +class AsyncEngineDeadError(RuntimeError): + pass + + +def _log_task_completion(task: asyncio.Task, + error_callback: Callable[[Exception], None]) -> None: + """This function is only intended for the `engine.run_engine_loop()` task. + + In particular, that task runs a `while True` loop that can only exit if + there is an exception. + """ + + exception = None + try: + return_value = task.result() + raise AssertionError( + f"The engine background task should never finish without an " + f"exception. {return_value}") + except asyncio.exceptions.CancelledError: + # We assume that if the task is cancelled, we are gracefully shutting + # down. This should only happen on program exit. + logger.info("Engine is gracefully shutting down.") + except Exception as e: + exception = e + logger.error("Engine background task failed", exc_info=e) + error_callback(exception) + raise AsyncEngineDeadError( + "Task finished unexpectedly. This should never happen! " + "Please open an issue on GitHub. 
See stack trace above for the " + "actual cause.") from e + + +STOP_ITERATION = Exception() # Sentinel + + +class AsyncStream: + """A stream of RequestOutputs or PoolingRequestOutputs for a request + that can be iterated over asynchronously via an async generator.""" + + def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None: + self.request_id = request_id + self._cancel = cancel + self._queue: asyncio.Queue = asyncio.Queue() + self._finished = False + + def put(self, item: Union[RequestOutput, PoolingRequestOutput, + Exception]) -> None: + if not self._finished: + self._queue.put_nowait(item) + + def finish( + self, + exception: Optional[Union[BaseException, Type[BaseException]]] = None, + ) -> None: + if not self._finished: + self._finished = True + self._queue.put_nowait( + exception if self._is_raisable(exception) else STOP_ITERATION) + + @property + def finished(self) -> bool: + return self._finished + + async def generator( + self + ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: + try: + while True: + result = await self._queue.get() + if self._is_raisable(result): + if result == STOP_ITERATION: + return + raise result + yield result + except GeneratorExit: + self._cancel(self.request_id) + raise asyncio.CancelledError from None + + @staticmethod + def _is_raisable(value: Any): + return isinstance(value, BaseException) or \ + (isinstance(value, type) and \ + issubclass(value, BaseException)) + + +class RequestTracker: + """Synchronous abstraction for tracking requests.""" + + def __init__(self) -> None: + self._request_streams: Dict[str, AsyncStream] = {} + self._aborted_requests: asyncio.Queue[str] = asyncio.Queue() + self._new_requests: asyncio.Queue[Tuple[AsyncStream, + dict]] = asyncio.Queue() + self.new_requests_event = asyncio.Event() + + def __contains__(self, item): + return item in self._request_streams + + def __len__(self) -> int: + return len(self._request_streams) + + def propagate_exception(self, + exc: Exception, + request_id: Optional[str] = None) -> None: + """Propagate an exception to request streams + (all if request_id is None).""" + if request_id is not None: + self.abort_request(request_id, exception=exc) + else: + # NB: tuple() used here because self.abort_request pops the stream + # out of self._request_streams, so we can't iterate on it directly + for rid in tuple(self._request_streams.keys()): + self.abort_request(rid, exception=exc) + + def process_request_output(self, + request_output: Union[RequestOutput, + PoolingRequestOutput], + *, + verbose: bool = False) -> None: + """Process a request output from the engine.""" + request_id = request_output.request_id + finished = request_output.finished + + if finished: + stream = self._request_streams.pop(request_id, None) + else: + stream = self._request_streams.get(request_id) + # Guard against a KeyError which can occur if the request was aborted + # while the output was generated + if stream is not None: + stream.put(request_output) + if finished: + stream.finish() + + if verbose and finished: + logger.info("Finished request %s.", request_id) + + def process_exception(self, + request_id: str, + exception: BaseException, + *, + verbose: bool = False) -> None: + """Propagate an exception from the engine.""" + if verbose: + logger.info("Finished request %s.", request_id) + self.abort_request(request_id, exception=exception) + + def add_request(self, + request_id: str, + *, + verbose: bool = False, + **engine_add_request_kwargs) -> AsyncStream: + """Add a request to be 
sent to the engine on the next background
+        loop iteration."""
+        if request_id in self._request_streams:
+            raise KeyError(f"Request {request_id} already exists.")
+
+        abort_request = partial(self.abort_request, verbose=verbose)
+        stream = AsyncStream(request_id, abort_request)
+        self._new_requests.put_nowait((stream, {
+            "request_id": request_id,
+            **engine_add_request_kwargs
+        }))
+
+        self.new_requests_event.set()
+
+        if verbose:
+            logger.info("Added request %s.", request_id)
+
+        return stream
+
+    def abort_request(self,
+                      request_id: str,
+                      *,
+                      exception: Optional[Union[BaseException,
+                                                Type[BaseException]]] = None,
+                      verbose: bool = False) -> None:
+        """Abort a request during next background loop iteration."""
+        if verbose:
+            logger.info("Aborted request %s.", request_id)
+
+        self._aborted_requests.put_nowait(request_id)
+
+        stream = self._request_streams.pop(request_id, None)
+        if stream is not None:
+            stream.finish(exception=exception)
+
+    def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
+        """Get the new requests and finished requests to be
+        sent to the engine."""
+        new_requests: List[Dict] = []
+        finished_requests: Set[str] = set()
+
+        while not self._aborted_requests.empty():
+            request_id = self._aborted_requests.get_nowait()
+            finished_requests.add(request_id)
+
+        while not self._new_requests.empty():
+            stream, new_request = self._new_requests.get_nowait()
+            request_id = stream.request_id
+            if request_id in finished_requests:
+                # The request has already been aborted.
+                stream.finish(asyncio.CancelledError)
+                finished_requests.discard(request_id)
+            else:
+                self._request_streams[request_id] = stream
+                new_requests.append(new_request)
+
+        return new_requests, finished_requests
+
+    async def wait_for_new_requests(self):
+        if not self.has_new_requests():
+            await self.new_requests_event.wait()
+        self.new_requests_event.clear()
+
+    def has_new_requests(self):
+        return not self._new_requests.empty()
+
+
+class _AsyncLLMEngine(LLMEngine):
+    """Extension of LLMEngine to add async methods."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    async def step_async(
+        self, virtual_engine: int
+    ) -> List[Union[RequestOutput, PoolingRequestOutput]]:
+        """Performs one decoding iteration and returns newly generated results.
+        The workers are run asynchronously if possible.
+
+        This function performs one decoding iteration of the engine. It first
+        schedules the sequences to be executed in the next iteration and the
+        token blocks to be swapped in/out/copied. Then, it executes the model
+        and updates the scheduler with the model outputs. Finally, it decodes
+        the sequences and returns the newly generated results.
+        """
+        # These are cached outputs from previous iterations. None if on first
+        # iteration.
+        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
+        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
+        scheduler_outputs = cached_outputs.scheduler_outputs
+        allow_async_output_proc = cached_outputs.allow_async_output_proc
+
+        ctx = self.scheduler_contexts[virtual_engine]
+
+        # Clear outputs for each new scheduler iteration
+        ctx.request_outputs.clear()
+
+        # Skip the scheduler if there are any remaining steps in the seq groups.
+        # This ensures that the scheduler is only called again when the current
+        # batch has completed.
+ if not self._has_remaining_steps(seq_group_metadata_list): + + # Schedule iteration + (seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc + ) = self.scheduler[virtual_engine].schedule() + + ctx.seq_group_metadata_list = seq_group_metadata_list + ctx.scheduler_outputs = scheduler_outputs + + if not scheduler_outputs.is_empty(): + # this will cause mamba_cache/minimax_cache failed + # to release finished_requests_ids of the last steps + finished_requests_ids = self.scheduler[ + virtual_engine].get_and_reset_finished_requests_ids() + + # Maybe switch from async mode to sync mode + if not allow_async_output_proc and len(ctx.output_queue) > 0: + self._process_model_outputs(ctx=ctx) + + if (self.scheduler_config.is_multi_step + and scheduler_outputs.num_lookahead_slots > 0): + # cache the scheduler outputs for the next iteration if we have + # lookahead slots + self._cache_scheduler_outputs_for_multi_step( + virtual_engine, seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc) + else: + finished_requests_ids = list() + + assert seq_group_metadata_list is not None + assert scheduler_outputs is not None + + if not scheduler_outputs.is_empty(): + + # Check if we have a cached last_output from the previous iteration. + # For supporting PP this is probably the best way to pass the + # sampled_token_ids, as a separate broadcast over all the PP stages + # will cause one virtual engine's microbatch to block the pipeline. + last_sampled_token_ids = \ + self._get_last_sampled_token_ids(virtual_engine) + + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, + blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, + blocks_to_copy=scheduler_outputs.blocks_to_copy, + virtual_engine=virtual_engine, + num_lookahead_slots=scheduler_outputs.num_lookahead_slots, + running_queue_size=scheduler_outputs.running_queue_size, + finished_requests_ids=finished_requests_ids, + # We use ExecuteModelRequest to pass the last sampled_token_ids + # to each of the non-last PP stages for in-place prepare_input. + last_sampled_token_ids=last_sampled_token_ids) + + if allow_async_output_proc: + execute_model_req.async_callback = self.async_callbacks[ + virtual_engine] + + # Execute the model. + outputs = await self.model_executor.execute_model_async( + execute_model_req) + + # we need to do this here so that last step's sampled_token_ids can + # be passed to the next iteration for PP. + if self.scheduler_config.is_multi_step: + self._update_cached_scheduler_output(virtual_engine, outputs) + else: + if len(ctx.output_queue) > 0: + self._process_model_outputs(ctx=ctx) + outputs = [] + + # Finish the current step for all the sequence groups. + if self.scheduler_config.is_multi_step: + for seq_group in seq_group_metadata_list: + seq_group.finish_step() + + if not self._has_remaining_steps(seq_group_metadata_list): + # Clear the cache if we have finished all the steps + if self.scheduler_config.is_multi_step: + self.cached_scheduler_outputs[ + virtual_engine] = SchedulerOutputState() + + # is_first_step_output is True only when the num_steps of all + # the sequences are 1. When the num_steps > 1, + # multi_step_model_runner does the first-step output append. 
+ is_first_step_output: bool = False if not seq_group_metadata_list \ + else seq_group_metadata_list[0].state.num_steps == 1 + + ctx.append_output(outputs=outputs, + seq_group_metadata_list=seq_group_metadata_list, + scheduler_outputs=scheduler_outputs, + is_async=allow_async_output_proc, + is_last_step=True, + is_first_step_output=is_first_step_output) + + if outputs and allow_async_output_proc: + assert len( + outputs + ) == 1, "Async postprocessor expects only a single output set" + self._advance_to_next_step( + outputs[0], seq_group_metadata_list, + scheduler_outputs.scheduled_seq_groups) + + if not allow_async_output_proc: + self._process_model_outputs(ctx=ctx) + + # Log stats. + self.do_log_stats(scheduler_outputs, outputs) + + # Tracing + self.do_tracing(scheduler_outputs) + + else: + # Multi-step case + return ctx.request_outputs + + if not self.has_unfinished_requests(): + # Drain async postprocessor (if exists) + if len(ctx.output_queue) > 0: + self._process_model_outputs(ctx=ctx) + assert len(ctx.output_queue) == 0 + + return ctx.request_outputs + + async def stop_remote_worker_execution_loop_async(self) -> None: + """Stop the remote worker execution loop.""" + await self.model_executor.stop_remote_worker_execution_loop_async() + + async def get_tokenizer_async(self, + lora_request: Optional[LoRARequest] = None + ) -> AnyTokenizer: + return await ( + self.get_tokenizer_group().get_lora_tokenizer_async(lora_request)) + + async def add_request_async( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + data_parallel_rank: Optional[int] = None, + ) -> None: + """ + Async version of + [`add_request`][vllm.engine.llm_engine.LLMEngine.add_request]. + """ + if lora_request is not None and not self.lora_config: + raise ValueError(f"Got lora_request {lora_request} but LoRA is " + "not enabled!") + if priority != 0 and not self.scheduler_config.policy == "priority": + raise ValueError(f"Got priority {priority} but " + "Priority scheduling is not enabled.") + if arrival_time is None: + arrival_time = time.time() + + if data_parallel_rank is not None: + raise ValueError("Targeting data_parallel_rank only supported " + "in v1 client.") + + if (isinstance(prompt, dict) + and prompt.get("prompt_embeds", None) is not None + and not prompt.get("prompt_token_ids", None)): + # We use the -2 dimension (instead of 0) in case a batched input + # of batch size 1 is passed in. + prompt["prompt_token_ids"] = [0 + ] * prompt["prompt_embeds"].shape[-2] + + processed_inputs = await self.input_preprocessor.preprocess_async( + prompt, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + + if isinstance(params, SamplingParams) and \ + params.guided_decoding is not None: + # Guided decoding has an async implementation for building logits + # processors in a separate threadpool. 
+ # We want to invoke that here instead of using the blocking + # implementation in the LLMEngine + params = await build_guided_decoding_logits_processor_async( + sampling_params=params, + tokenizer=await self.get_tokenizer_async(lora_request), + default_guided_backend=self.decoding_config.backend, + reasoning_backend=self.decoding_config.reasoning_backend, + model_config=self.model_config) + + self._add_processed_request( + request_id=request_id, + processed_inputs=processed_inputs, + params=params, + arrival_time=arrival_time, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + trace_headers=trace_headers, + priority=priority, + ) + + async def check_health_async(self) -> None: + self.model_executor.check_health() + + async def collective_rpc_async(self, + method: str, + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None): + raise NotImplementedError + + +async def build_guided_decoding_logits_processor_async( + sampling_params: SamplingParams, tokenizer: AnyTokenizer, + default_guided_backend: str, reasoning_backend: Optional[str], + model_config: ModelConfig) -> SamplingParams: + """Constructs logits processors based on the guided_decoding, + logits_bias, and allowed_token_ids fields in sampling_params. Deletes + those fields and adds the constructed logits processors to the + logits_processors field. Modifies sampling params in-place and returns + the modified sampling params.""" + if sampling_params.guided_decoding is None: + return sampling_params + + # Defensively copy sampling params since guided decoding logits + # processors can have different state for each request + sampling_params = copy.copy(sampling_params) + guided_decoding = sampling_params.guided_decoding + + logger.debug( + "Building guided decoding logits processor. " + "guided_decoding: %s%s", guided_decoding, + f", reasoning_backend: {reasoning_backend}" + if reasoning_backend is not None else "") + + guided_decoding.backend = guided_decoding.backend or default_guided_backend + + processor = await get_guided_decoding_logits_processor( + guided_params=guided_decoding, + tokenizer=tokenizer, + reasoning_backend=reasoning_backend, + model_config=model_config) + + if processor: + if sampling_params.logits_processors is None: + sampling_params.logits_processors = [] + sampling_params.logits_processors.append(processor) + + # Unset guided decoding params after constructing the lp from them + sampling_params.guided_decoding = None + + return sampling_params + + +class AsyncLLMEngine(EngineClient): + """An asynchronous wrapper for [`LLMEngine`][vllm.LLMEngine]. + + This class is used to wrap the [`LLMEngine`][vllm.LLMEngine] class to + make it asynchronous. It uses asyncio to create a background loop that keeps + processing incoming requests. The [`LLMEngine`][vllm.LLMEngine] is kicked + by the generate method when there are requests in the waiting queue. The + generate method yields the outputs from the [`LLMEngine`][vllm.LLMEngine] + to the caller. + + Args: + log_requests: Whether to log the requests. + start_engine_loop: If True, the background task to run the engine + will be automatically started in the generate call. + *args: Arguments for [`LLMEngine`][vllm.LLMEngine]. + **kwargs: Arguments for [`LLMEngine`][vllm.LLMEngine]. 
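+
+    Example (illustrative sketch, not a complete program; the model name
+    below is only a placeholder):
+
+        >>> from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
+        >>> engine_args = AsyncEngineArgs(model="facebook/opt-125m")
+        >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
+        >>>
+        >>> async def run(prompt: str, request_id: str):
+        ...     final = None
+        ...     # generate() streams RequestOutput objects as tokens arrive.
+        ...     async for out in engine.generate(
+        ...             prompt, SamplingParams(max_tokens=16), request_id):
+        ...         final = out
+        ...     return final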
+    """
+
+    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
+
+    def __init__(self,
+                 *args,
+                 log_requests: bool = True,
+                 start_engine_loop: bool = True,
+                 **kwargs) -> None:
+        if envs.VLLM_USE_V1:
+            raise ValueError(
+                "Using V0 AsyncLLMEngine, but envs.VLLM_USE_V1=True. "
+                "This should not happen. As a workaround, try using "
+                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
+                "VLLM_USE_V1=0 or 1 and report this issue on Github.")
+
+        self.log_requests = log_requests
+        self.engine = self._engine_class(*args, **kwargs)
+
+        # This ensures quick processing of request outputs
+        # so the append to asyncio queues is not delayed,
+        # especially for multi-step.
+        self.use_process_request_outputs_callback = (
+            self.engine.model_config.use_async_output_proc)
+
+        if self.use_process_request_outputs_callback:
+            self.engine.process_request_outputs_callback = \
+                weak_bind(self.process_request_outputs)
+
+        self.background_loop: Optional[asyncio.Future] = None
+        # We need to keep a reference to the unshielded
+        # task as well to prevent it from being garbage
+        # collected.
+        self._background_loop_unshielded: Optional[asyncio.Task] = None
+        self.start_engine_loop = start_engine_loop
+        self._errored_with: Optional[BaseException] = None
+
+        # Lazy initialized fields
+        self._request_tracker: RequestTracker
+
+    def __del__(self):
+        if rt := getattr(self, "_request_tracker", None):
+            # Wake up engine loop so that it will exit cleanly
+            rt.new_requests_event.set()
+
+    @classmethod
+    def _get_executor_cls(cls,
+                          engine_config: VllmConfig) -> Type[ExecutorBase]:
+        return LLMEngine._get_executor_cls(engine_config)
+
+    @classmethod
+    def from_vllm_config(
+        cls,
+        vllm_config: VllmConfig,
+        start_engine_loop: bool = True,
+        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
+        stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
+        disable_log_requests: bool = False,
+        disable_log_stats: bool = False,
+    ) -> "AsyncLLMEngine":
+        """Create an AsyncLLMEngine from the provided VllmConfig."""
+
+        return cls(
+            vllm_config=vllm_config,
+            executor_class=cls._get_executor_cls(vllm_config),
+            start_engine_loop=start_engine_loop,
+            log_requests=not disable_log_requests,
+            log_stats=not disable_log_stats,
+            usage_context=usage_context,
+            stat_loggers=stat_loggers,
+        )
+
+    @classmethod
+    def from_engine_args(
+        cls,
+        engine_args: AsyncEngineArgs,
+        start_engine_loop: bool = True,
+        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
+        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
+    ) -> "AsyncLLMEngine":
+        """Creates an async LLM engine from the engine arguments."""
+
+        vllm_config = engine_args.create_engine_config(usage_context)
+
+        async_engine_cls = cls
+        if envs.VLLM_USE_V1:
+            from vllm.v1.engine.async_llm import AsyncLLM as V1AsyncLLMEngine
+            async_engine_cls = V1AsyncLLMEngine
+
+        return async_engine_cls.from_vllm_config(
+            vllm_config=vllm_config,
+            start_engine_loop=start_engine_loop,
+            usage_context=usage_context,
+            stat_loggers=stat_loggers,
+            disable_log_stats=engine_args.disable_log_stats,
+            disable_log_requests=engine_args.disable_log_requests,
+        )
+
+    @property
+    def is_running(self) -> bool:
+        return (self.background_loop is not None
+                and self._background_loop_unshielded is not None
+                and not self._background_loop_unshielded.done())
+
+    @property
+    def is_stopped(self) -> bool:
+        return self.errored or (self.background_loop is not None and
+                                self._background_loop_unshielded is not None
+                                and self._background_loop_unshielded.done())
+
+    @property
+    def errored(self) -> bool:
+ return self._errored_with is not None + + @property + def dead_error(self) -> BaseException: + return AsyncEngineDeadError( + "Background loop is not running. If it was running, " + "inspect the output to find the stacktrace of the " + "error that caused the background loop to stop " + "(AsyncEngineDeadError).") + + def set_errored(self, exc: Exception) -> None: + self._errored_with = exc + + def _error_callback(self, exc: Exception) -> None: + self.set_errored(exc) + self._request_tracker.propagate_exception(exc) + + async def get_input_preprocessor(self) -> InputPreprocessor: + return self.engine.input_preprocessor + + async def get_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + return await self.engine.get_tokenizer_async(lora_request) + + def start_background_loop(self) -> None: + """Start the background loop.""" + if self.errored: + raise AsyncEngineDeadError( + "Background loop has errored already.") from self._errored_with + if self.is_running: + raise RuntimeError("Background loop is already running.") + # Initialize the RequestTracker here so it uses the right event loop. + self._request_tracker = RequestTracker() + + self._background_loop_unshielded = asyncio.get_event_loop( + ).create_task(self.run_engine_loop(weakref.ref(self))) + self._background_loop_unshielded.add_done_callback( + partial(_log_task_completion, error_callback=self._error_callback)) + self.background_loop = asyncio.shield(self._background_loop_unshielded) + + def shutdown_background_loop(self) -> None: + """ + Shut down the background loop. + + This method needs to be called during cleanup to remove + references to `self` and properly GC the resources held + by the async LLM engine (e.g., the executors as well as + their resources). + """ + if self._background_loop_unshielded is not None: + self._background_loop_unshielded.cancel() + self._background_loop_unshielded = None + self.background_loop = None + + async def engine_step(self, virtual_engine: int) -> bool: + """Kick the engine to process the waiting requests. + + Returns True if there are in-progress requests.""" + + new_requests, aborted_requests = ( + self._request_tracker.get_new_and_aborted_requests()) + + for new_request in new_requests: + # Add the request into the vLLM engine's waiting queue. + try: + await self.engine.add_request_async(**new_request) + except ValueError as e: + # TODO: use a vLLM specific error for failed validation + self._request_tracker.process_exception( + new_request["request_id"], + e, + verbose=self.log_requests, + ) + + if aborted_requests: + await self._engine_abort(aborted_requests) + + request_outputs = await self.engine.step_async(virtual_engine) + + # Put the outputs into the corresponding streams. + # If used as a callback, then already invoked inside + # LLMEngine's _process_model_outputs + if not self.use_process_request_outputs_callback: + all_finished = self.process_request_outputs(request_outputs) + else: + # For callback case, we only need to detect when all + # requests are finished + all_finished = all(request_output.finished + for request_output in request_outputs) + + return not all_finished + + def process_request_outputs(self, request_outputs) -> bool: + # Put the outputs into the corresponding streams. 
+ all_finished = True + for request_output in request_outputs: + self._request_tracker.process_request_output( + request_output, verbose=self.log_requests) + all_finished = all_finished and request_output.finished + + return all_finished + + async def _engine_abort(self, request_ids: Iterable[str]): + self.engine.abort_request(request_ids) + + @staticmethod + async def run_engine_loop(engine_ref: ReferenceType): + """We use a weakref to the engine so that the running loop + doesn't prevent the engine being garbage collected.""" + engine: Optional[AsyncLLMEngine] = engine_ref() + if not engine: + return + + pipeline_parallel_size = \ + engine.engine.parallel_config.pipeline_parallel_size + has_requests_in_progress = [False] * pipeline_parallel_size + while True: + if not any(has_requests_in_progress): + logger.debug("Waiting for new requests...") + # Stop the execute model loop in parallel workers until there + # are more requests to process. This avoids waiting + # indefinitely in torch.distributed ops which may otherwise + # timeout, and unblocks the RPC thread in the workers so that + # they can process any other queued control plane messages, + # such as add/remove lora adapters. + await engine.engine.stop_remote_worker_execution_loop_async() + request_tracker = engine._request_tracker + # Allow engine to be garbage collected while + # waiting for new requests + del engine + await asyncio.sleep(0) + if engine_ref() is None: + return + await request_tracker.wait_for_new_requests() + engine = engine_ref() + if not engine: + return + logger.debug("Got new requests!") + requests_in_progress = [ + asyncio.create_task(engine.engine_step(ve)) + for ve in range(pipeline_parallel_size) + ] + has_requests_in_progress = [True] * pipeline_parallel_size + + # Abort if iteration takes too long due to unrecoverable errors + # (eg. NCCL timeouts). + try: + async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S): + done, _ = await asyncio.wait( + requests_in_progress, + return_when=asyncio.FIRST_COMPLETED) + for _ in range(pipeline_parallel_size): + await asyncio.sleep(0) + for task in done: + result = task.result() + virtual_engine = requests_in_progress.index(task) + has_unfinished_requests = ( + engine.engine. + has_unfinished_requests_for_virtual_engine( + virtual_engine)) + if result or has_unfinished_requests: + requests_in_progress[virtual_engine] = ( + asyncio.create_task( + engine.engine_step(virtual_engine))) + has_requests_in_progress[virtual_engine] = True + else: + has_requests_in_progress[virtual_engine] = False + except asyncio.TimeoutError as exc: + logger.error( + "Engine iteration timed out. This should never happen!") + engine.set_errored(exc) + raise + await asyncio.sleep(0) + + async def add_request( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + data_parallel_rank: Optional[int] = None, + ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: + if not self.is_running: + if self.start_engine_loop: + self.start_background_loop() + else: + raise AsyncEngineDeadError( + "Background loop is not running. 
If it was running, " + "inspect the output to find the stacktrace of the " + "error that caused the background loop to stop " + "(AsyncEngineDeadError).") + + if (priority != 0 + and not self.engine.scheduler_config.policy == "priority"): + raise ValueError(f"Got priority {priority} but " + "Priority scheduling is not enabled.") + + stream = self._request_tracker.add_request( + request_id, + verbose=self.log_requests, + prompt=prompt, + params=params, + arrival_time=arrival_time or time.time(), + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority, + data_parallel_rank=data_parallel_rank, + ) + + return stream.generator() + + async def generate( + self, + prompt: PromptType, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + data_parallel_rank: Optional[int] = None, + ) -> AsyncGenerator[RequestOutput, None]: + """Generate outputs for a request. + + Generate outputs for a request. This method is a coroutine. It adds the + request into the waiting queue of the LLMEngine and streams the outputs + from the LLMEngine to the caller. + + Args: + prompt: The prompt to the LLM. See + [`PromptType`][vllm.inputs.PromptType] for more details about + the format of each input. + sampling_params: The sampling parameters of the request. + request_id: The unique id of the request. + lora_request: LoRA request to use for generation, if any. + trace_headers: OpenTelemetry trace headers. + prompt_adapter_request: Prompt Adapter request to use + for generation, if any. + priority: The priority of the request. + Only applicable with priority scheduling. + data_parallel_rank: The (global) data parallel rank that must + handle this request. Only applicable if DP is enabled. + Yields: + The output `RequestOutput` objects from the LLMEngine + for the request. + + Details: + - If the engine is not running, start the background loop, + which iteratively invokes + [`engine_step`][vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step] + to process the waiting requests. + - Add the request to the engine's `RequestTracker`. + On the next background loop, this request will be sent to + the underlying engine. + Also, a corresponding `AsyncStream` will be created. + - Wait for the request outputs from `AsyncStream` and yield them. + + Example: + >>> # Please refer to entrypoints/api_server.py for + >>> # the complete example. + >>> + >>> # initialize the engine and the example input + >>> # note that engine_args here is AsyncEngineArgs instance + >>> engine = AsyncLLMEngine.from_engine_args(engine_args) + >>> example_input = { + >>> "prompt": "What is LLM?", + >>> "stream": False, # assume the non-streaming case + >>> "temperature": 0.0, + >>> "request_id": 0, + >>> } + >>> + >>> # start the generation + >>> results_generator = engine.generate( + >>> example_input["prompt"], + >>> SamplingParams(temperature=example_input["temperature"]), + >>> example_input["request_id"]) + >>> + >>> # get the results + >>> final_output = None + >>> async for request_output in results_generator: + >>> if await request.is_disconnected(): + >>> # Abort the request if the client disconnects. + >>> await engine.abort(request_id) + >>> # Return or raise an error + >>> ... + >>> final_output = request_output + >>> + >>> # Process and return the final output + >>> ... 
+ """ + try: + async for output in await self.add_request( + request_id, + prompt, + sampling_params, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority, + data_parallel_rank=data_parallel_rank, + ): + yield LLMEngine.validate_output(output, RequestOutput) + except asyncio.CancelledError: + await self.abort(request_id) + raise + + async def encode( + self, + prompt: PromptType, + pooling_params: PoolingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + priority: int = 0, + ) -> AsyncGenerator[PoolingRequestOutput, None]: + """Generate outputs for a request from a pooling model. + + Generate outputs for a request. This method is a coroutine. It adds the + request into the waiting queue of the LLMEngine and streams the outputs + from the LLMEngine to the caller. + + Args: + prompt: The prompt to the LLM. See + [`PromptType`][vllm.inputs.PromptType] for more details about + the format of each input. + pooling_params: The pooling parameters of the request. + request_id: The unique id of the request. + lora_request: LoRA request to use for generation, if any. + trace_headers: OpenTelemetry trace headers. + priority: The priority of the request. + Only applicable with priority scheduling. + + Yields: + The output `PoolingRequestOutput` objects from the LLMEngine + for the request. + + Details: + - If the engine is not running, start the background loop, + which iteratively invokes + [`vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`][] + to process the waiting requests. + - Add the request to the engine's `RequestTracker`. + On the next background loop, this request will be sent to + the underlying engine. + Also, a corresponding `AsyncStream` will be created. + - Wait for the request outputs from `AsyncStream` and yield them. + + Example: + ``` + # Please refer to entrypoints/api_server.py for + # the complete example. + + # initialize the engine and the example input + # note that engine_args here is AsyncEngineArgs instance + engine = AsyncLLMEngine.from_engine_args(engine_args) + example_input = { + "input": "What is LLM?", + "request_id": 0, + } + + # start the generation + results_generator = engine.encode( + example_input["input"], + PoolingParams(), + example_input["request_id"]) + + # get the results + final_output = None + async for request_output in results_generator: + if await request.is_disconnected(): + # Abort the request if the client disconnects. + await engine.abort(request_id) + # Return or raise an error + ... + final_output = request_output + + # Process and return the final output + ... + ``` + """ + try: + async for output in await self.add_request( + request_id, + prompt, + pooling_params, + lora_request=lora_request, + trace_headers=trace_headers, + priority=priority, + ): + yield LLMEngine.validate_output(output, PoolingRequestOutput) + except asyncio.CancelledError: + await self.abort(request_id) + raise + + async def abort(self, request_id: str) -> None: + """Abort a request. + + Abort a submitted request. If the request is finished or not found, + this method will be a no-op. + + Args: + request_id: The unique id of the request. + """ + if not self.is_running: + raise AsyncEngineDeadError( + "Background loop is not running. 
If it was running, " + "inspect the output to find the stacktrace of the " + "error that caused the background loop to stop " + "(AsyncEngineDeadError).") + + return self._abort(request_id) + + def _abort(self, request_id: str) -> None: + """Abort a request. + + Abort a submitted request. If the request is finished or not found, + this method will be a no-op. + + Args: + request_id: The unique id of the request. + """ + self._request_tracker.abort_request(request_id, + exception=asyncio.CancelledError, + verbose=self.log_requests) + + async def get_vllm_config(self) -> VllmConfig: + """Get the vllm configuration of the vLLM engine.""" + return self.engine.get_vllm_config() + + async def get_model_config(self) -> ModelConfig: + """Get the model configuration of the vLLM engine.""" + return self.engine.get_model_config() + + async def get_parallel_config(self) -> ParallelConfig: + """Get the parallel configuration of the vLLM engine.""" + return self.engine.get_parallel_config() + + async def get_decoding_config(self) -> DecodingConfig: + """Get the decoding configuration of the vLLM engine.""" + return self.engine.get_decoding_config() + + async def get_scheduler_config(self) -> SchedulerConfig: + """Get the scheduling configuration of the vLLM engine.""" + return self.engine.get_scheduler_config() + + async def get_lora_config(self) -> LoRAConfig: + """Get the lora configuration of the vLLM engine.""" + return self.engine.get_lora_config() + + async def do_log_stats( + self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[List[SamplerOutput]] = None) -> None: + self.engine.do_log_stats() + + async def check_health(self) -> None: + """Raises an error if engine is unhealthy.""" + t = time.perf_counter() + logger.debug("Starting health check...") + if self.is_stopped: + raise AsyncEngineDeadError("Background loop is stopped.") + + await self.engine.check_health_async() + logger.debug("Health check took %fs", time.perf_counter() - t) + + async def is_tracing_enabled(self) -> bool: + return self.engine.is_tracing_enabled() + + def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None: + self.engine.add_logger(logger_name=logger_name, logger=logger) + + def remove_logger(self, logger_name: str) -> None: + self.engine.remove_logger(logger_name=logger_name) + + async def start_profile(self) -> None: + self.engine.start_profile() + + async def stop_profile(self) -> None: + self.engine.stop_profile() + + async def reset_mm_cache(self) -> None: + self.engine.reset_mm_cache() + + async def reset_prefix_cache(self, + device: Optional[Device] = None) -> None: + self.engine.reset_prefix_cache(device) + + async def sleep(self, level: int = 1) -> None: + self.engine.sleep(level) + + async def wake_up(self, tags: Optional[list[str]] = None) -> None: + self.engine.wake_up(tags) + + async def is_sleeping(self) -> bool: + return self.engine.is_sleeping() + + async def add_lora(self, lora_request: LoRARequest) -> None: + self.engine.add_lora(lora_request) + + async def collective_rpc(self, + method: str, + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None): + """ + Perform a collective RPC call to the given path. + """ + return await self.engine.collective_rpc_async(method, timeout, args, + kwargs) + + +# TODO(v1): Remove this class proxy when V1 goes default. 
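+# NOTE: When VLLM_USE_V1 is enabled, the module-level name ``AsyncLLMEngine``
+# is rebound below to the V1 ``AsyncLLM`` implementation, so existing imports
+# of ``AsyncLLMEngine`` transparently pick up the V1 engine.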
+if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1: + from vllm.v1.engine.async_llm import AsyncLLM + + AsyncLLMEngine = AsyncLLM # type: ignore diff --git a/engine/async_timeout.py b/engine/async_timeout.py new file mode 100644 index 0000000..28a023a --- /dev/null +++ b/engine/async_timeout.py @@ -0,0 +1,173 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Workaround for https://github.com/python/cpython/issues/86296 +# +# From https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py +# Licensed under the Apache License (Apache-2.0) + +import asyncio +import enum +import sys +from types import TracebackType +from typing import Any, Optional, Type + +if sys.version_info[:2] >= (3, 11): + from asyncio import timeout as asyncio_timeout +else: + + def asyncio_timeout(delay: Optional[float]) -> "Timeout": + """timeout context manager. + Useful in cases when you want to apply timeout logic around block + of code or in cases when asyncio.wait_for is not suitable. For example: + >>> async with timeout(0.001): + ... async with aiohttp.get('https://github.com') as r: + ... await r.text() + delay - value in seconds or None to disable timeout logic + """ + loop = asyncio.get_running_loop() + deadline = loop.time() + delay if delay is not None else None + return Timeout(deadline, loop) + + class _State(enum.Enum): + INIT = "INIT" + ENTER = "ENTER" + TIMEOUT = "TIMEOUT" + EXIT = "EXIT" + + class Timeout: + # Internal class, please don't instantiate it directly + # Use timeout() and timeout_at() public factories instead. + # + # Implementation note: `async with timeout()` is preferred + # over `with timeout()`. + # While technically the Timeout class implementation + # doesn't need to be async at all, + # the `async with` statement explicitly points that + # the context manager should be used from async function context. + # + # This design allows to avoid many silly misusages. + # + # TimeoutError is raised immediately when scheduled + # if the deadline is passed. + # The purpose is to time out as soon as possible + # without waiting for the next await expression. + + __slots__ = ("_deadline", "_loop", "_state", "_timeout_handler") + + def __init__(self, deadline: Optional[float], + loop: asyncio.AbstractEventLoop) -> None: + self._loop = loop + self._state = _State.INIT + + self._timeout_handler = None # type: Optional[asyncio.Handle] + if deadline is None: + self._deadline = None # type: Optional[float] + else: + self.update(deadline) + + async def __aenter__(self) -> "Timeout": + self._do_enter() + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + self._do_exit(exc_type) + return None + + @property + def expired(self) -> bool: + """Is timeout expired during execution?""" + return self._state == _State.TIMEOUT + + @property + def deadline(self) -> Optional[float]: + return self._deadline + + def reject(self) -> None: + """Reject scheduled timeout if any.""" + # cancel is maybe better name but + # task.cancel() raises CancelledError in asyncio world. 
+ if self._state not in (_State.INIT, _State.ENTER): + raise RuntimeError(f"invalid state {self._state.value}") + self._reject() + + def _reject(self) -> None: + if self._timeout_handler is not None: + self._timeout_handler.cancel() + self._timeout_handler = None + + def shift(self, delay: float) -> None: + """Advance timeout on delay seconds. + The delay can be negative. + Raise RuntimeError if shift is called when deadline is not scheduled + """ + deadline = self._deadline + if deadline is None: + raise RuntimeError( + "cannot shift timeout if deadline is not scheduled") + self.update(deadline + delay) + + def update(self, deadline: float) -> None: + """Set deadline to absolute value. + deadline argument points on the time in the same clock system + as loop.time(). + If new deadline is in the past the timeout is raised immediately. + Please note: it is not POSIX time but a time with + undefined starting base, e.g. the time of the system power on. + """ + if self._state == _State.EXIT: + raise RuntimeError( + "cannot reschedule after exit from context manager") + if self._state == _State.TIMEOUT: + raise RuntimeError("cannot reschedule expired timeout") + if self._timeout_handler is not None: + self._timeout_handler.cancel() + self._deadline = deadline + if self._state != _State.INIT: + self._reschedule() + + def _reschedule(self) -> None: + assert self._state == _State.ENTER + deadline = self._deadline + if deadline is None: + return + + now = self._loop.time() + if self._timeout_handler is not None: + self._timeout_handler.cancel() + + task = asyncio.current_task() + if deadline <= now: + self._timeout_handler = self._loop.call_soon( + self._on_timeout, task) + else: + self._timeout_handler = self._loop.call_at( + deadline, self._on_timeout, task) + + def _do_enter(self) -> None: + if self._state != _State.INIT: + raise RuntimeError(f"invalid state {self._state.value}") + self._state = _State.ENTER + self._reschedule() + + def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None: + if exc_type is asyncio.CancelledError and \ + self._state == _State.TIMEOUT: + self._timeout_handler = None + raise asyncio.TimeoutError + # timeout has not expired + self._state = _State.EXIT + self._reject() + return None + + def _on_timeout(self, task: "Optional[asyncio.Task[Any]]") -> None: + if task: + task.cancel() + self._state = _State.TIMEOUT + # drop the reference early + self._timeout_handler = None diff --git a/engine/llm_engine.py b/engine/llm_engine.py new file mode 100644 index 0000000..8fccf9b --- /dev/null +++ b/engine/llm_engine.py @@ -0,0 +1,2097 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import copy +import time +from collections import Counter as collectionsCounter +from collections import deque +from contextlib import contextmanager +from dataclasses import dataclass +from functools import partial +from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict, + Iterable, List, Literal, Mapping, NamedTuple, Optional) +from typing import Sequence as GenericSequence +from typing import Set, Type, Union, cast + +import torch +from typing_extensions import TypeVar + +import vllm.envs as envs +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ObservabilityConfig, ParallelConfig, SchedulerConfig, + VllmConfig) +from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs +from vllm.engine.arg_utils import EngineArgs +from vllm.engine.metrics_types import StatLoggerBase, 
Stats +from vllm.engine.output_processor.interfaces import ( + SequenceGroupOutputProcessor) +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.engine.output_processor.util import create_output_by_sequence_group +from vllm.entrypoints.openai.logits_processors import ( + get_logits_processors as get_openai_logits_processors) +from vllm.executor.executor_base import ExecutorBase +from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs +from vllm.inputs.parse import split_enc_dec_inputs +from vllm.inputs.preprocess import InputPreprocessor +from vllm.logger import init_logger +from vllm.logits_process import get_bad_words_logits_processors +from vllm.lora.request import LoRARequest +from vllm.model_executor.guided_decoding import ( + get_local_guided_decoding_logits_processor) +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.multimodal.processing import EncDecMultiModalProcessor +from vllm.outputs import (PoolingRequestOutput, RequestOutput, + RequestOutputFactory) +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import RequestOutputKind, SamplingParams +from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup, + PoolingSequenceGroupOutput, Sequence, SequenceGroup, + SequenceGroupBase, SequenceGroupMetadata, + SequenceGroupOutput, SequenceStatus) +from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, + init_tracer) +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.transformers_utils.tokenizer_group import ( + TokenizerGroup, init_tokenizer_from_configs) +from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, + usage_message) +from vllm.utils import Counter, Device, resolve_obj_by_qualname, weak_bind +from vllm.version import __version__ as VLLM_VERSION +from vllm.worker.model_runner_base import InputProcessingError + +logger = init_logger(__name__) +_LOCAL_LOGGING_INTERVAL_SEC = 5 + +_O = TypeVar("_O", RequestOutput, PoolingRequestOutput) +_R = TypeVar("_R", default=Any) + + +@dataclass +class SchedulerOutputState: + """Caches the scheduler outputs for a virtual engine. Used for Multi-Step""" + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None + scheduler_outputs: Optional[SchedulerOutputs] = None + allow_async_output_proc: bool = False + last_output: Optional[SamplerOutput] = None + + +class OutputData(NamedTuple): + outputs: List[SamplerOutput] + seq_group_metadata_list: List[SequenceGroupMetadata] + scheduler_outputs: SchedulerOutputs + is_async: bool + is_last_step: bool + # Indicates if this output is from the first step of the + # multi-step. When multi-step is disabled, this is always + # set to True. + # is_first_step_output is invalid when `outputs` has + # outputs from multiple steps. 
+ is_first_step_output: Optional[bool] + skip: List[int] + + +class SchedulerContext: + + def __init__(self, multi_step_stream_outputs: bool = False): + self.output_queue: Deque[OutputData] = deque() + self.request_outputs: List[Union[RequestOutput, + PoolingRequestOutput]] = [] + self.seq_group_metadata_list: Optional[ + List[SequenceGroupMetadata]] = None + self.scheduler_outputs: Optional[SchedulerOutputs] = None + + self.multi_step_stream_outputs: bool = multi_step_stream_outputs + + def append_output(self, outputs: List[SamplerOutput], + seq_group_metadata_list: List[SequenceGroupMetadata], + scheduler_outputs: SchedulerOutputs, is_async: bool, + is_last_step: bool, + is_first_step_output: Optional[bool]): + self.output_queue.append( + OutputData(outputs=outputs, + seq_group_metadata_list=seq_group_metadata_list, + scheduler_outputs=scheduler_outputs, + is_async=is_async, + is_last_step=is_last_step, + is_first_step_output=is_first_step_output, + skip=[])) + + +class LLMEngine: + """An LLM engine that receives requests and generates texts. + + This is the main class for the vLLM engine. It receives requests + from clients and generates texts from the LLM. It includes a tokenizer, a + language model (possibly distributed across multiple GPUs), and GPU memory + space allocated for intermediate states (aka KV cache). This class utilizes + iteration-level scheduling and efficient memory management to maximize the + serving throughput. + + The [`LLM`][vllm.LLM] class wraps this class for offline batched inference + and the [`AsyncLLMEngine`][vllm.engine.async_llm_engine.AsyncLLMEngine] + class wraps this class for online serving. + + The config arguments are derived from [`EngineArgs`][vllm.EngineArgs]. + + Args: + vllm_config: The configuration for initializing and running vLLM. + executor_class: The model executor class for managing distributed + execution. + log_stats: Whether to log statistics. + usage_context: Specified entry point, used for usage info collection. 
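+
+    Example (illustrative sketch; the model name below is only a placeholder):
+
+        >>> from vllm import EngineArgs, LLMEngine, SamplingParams
+        >>> engine = LLMEngine.from_engine_args(
+        ...     EngineArgs(model="facebook/opt-125m"))
+        >>> engine.add_request("0", "Hello, my name is",
+        ...                    SamplingParams(max_tokens=16))
+        >>> # step() runs one scheduling + execution iteration.
+        >>> while engine.has_unfinished_requests():
+        ...     for out in engine.step():
+        ...         if out.finished:
+        ...             print(out.outputs[0].text)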
+ """ + + DO_VALIDATE_OUTPUT: ClassVar[bool] = False + """A flag to toggle whether to validate the type of request output.""" + + @classmethod + @contextmanager + def enable_output_validation(cls): + cls.DO_VALIDATE_OUTPUT = True + + yield + + cls.DO_VALIDATE_OUTPUT = False + + @classmethod + def validate_output( + cls, + output: object, + output_type: Type[_O], + ) -> _O: + do_validate = cls.DO_VALIDATE_OUTPUT + + if ((TYPE_CHECKING or do_validate) + and not isinstance(output, output_type)): + raise TypeError(f"Expected output of type {output_type}, " + f"but found type {type(output)}") + + return cast(_O, output) + + @classmethod + def validate_outputs( + cls, + outputs: GenericSequence[object], + output_type: Type[_O], + ) -> List[_O]: + do_validate = cls.DO_VALIDATE_OUTPUT + + outputs_: List[_O] + if TYPE_CHECKING or do_validate: + outputs_ = [] + for output in outputs: + if not isinstance(output, output_type): + raise TypeError(f"Expected output of type {output_type}, " + f"but found type {type(output)}") + + outputs_.append(output) + else: + outputs_ = outputs + + return outputs_ + + tokenizer: Optional[TokenizerGroup] + + def __init__( + self, + vllm_config: VllmConfig, + executor_class: Type[ExecutorBase], + log_stats: bool, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + use_cached_outputs: bool = False, + ) -> None: + if envs.VLLM_USE_V1: + raise ValueError( + "Using V0 LLMEngine, but envs.VLLM_USE_V1=True. " + "This should not happen. As a workaround, try using " + "LLMEngine.from_vllm_config(...) or explicitly set " + "VLLM_USE_V1=0 or 1 and report this issue on Github.") + + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config # noqa + self.load_config = vllm_config.load_config + self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa + ) + self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa + self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa + ) + + logger.info( + "Initializing a V0 LLM engine (v%s) with config: %s, " + "use_cached_outputs=%s, ", + VLLM_VERSION, + vllm_config, + use_cached_outputs, + ) + + self.log_stats = log_stats + self.use_cached_outputs = use_cached_outputs + + if not self.model_config.skip_tokenizer_init: + self.tokenizer = self._init_tokenizer() + self.detokenizer = Detokenizer(self.tokenizer) + tokenizer_group = self.get_tokenizer_group() + else: + self.tokenizer = None + self.detokenizer = None + tokenizer_group = None + + # Ensure that the function doesn't contain a reference to self, + # to avoid engine GC issues + def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: + assert tokenizer_group, ("tokenizer_group cannot be None, " + "make sure skip_tokenizer_init is False") + return tokenizer_group.get_lora_tokenizer(sequence.lora_request) + + self.seq_counter = Counter() + self.generation_config_fields = ( + self.model_config.try_get_generation_config()) + + self.input_preprocessor = InputPreprocessor(self.model_config, + self.tokenizer, + mm_registry) + + self.model_executor = 
executor_class(vllm_config=vllm_config) + + if self.model_config.runner_type != "pooling": + self._initialize_kv_caches() + + # If usage stat is enabled, collect relevant info. + if is_usage_stats_enabled(): + from vllm.model_executor.model_loader import ( + get_architecture_class_name) + usage_message.report_usage( + get_architecture_class_name(self.model_config), + usage_context, + extra_kvs={ + # Common configuration + "dtype": + str(self.model_config.dtype), + "tensor_parallel_size": + self.parallel_config.tensor_parallel_size, + "block_size": + self.cache_config.block_size, + "gpu_memory_utilization": + self.cache_config.gpu_memory_utilization, + + # Quantization + "quantization": + self.model_config.quantization, + "kv_cache_dtype": + str(self.cache_config.cache_dtype), + + # Feature flags + "enable_lora": + bool(self.lora_config), + "enable_prompt_adapter": + bool(self.prompt_adapter_config), + "enable_prefix_caching": + self.cache_config.enable_prefix_caching, + "enforce_eager": + self.model_config.enforce_eager, + "disable_custom_all_reduce": + self.parallel_config.disable_custom_all_reduce, + }) + + self.cached_scheduler_outputs = [ + SchedulerOutputState() + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + + self.scheduler_contexts = [ + SchedulerContext(multi_step_stream_outputs=self.scheduler_config. + multi_step_stream_outputs) + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + + if self.model_config.use_async_output_proc: + process_model_outputs = weak_bind(self._process_model_outputs) + + self.async_callbacks = [ + partial(process_model_outputs, + ctx=self.scheduler_contexts[v_id]) + for v_id in range(self.parallel_config.pipeline_parallel_size) + ] + else: + self.async_callbacks = [] + + # Currently used by AsyncLLMEngine to ensure quick append + # of request outputs to asyncio queues + self.process_request_outputs_callback: Optional[Callable] = None + + # Create the scheduler. + # NOTE: the cache_config here have been updated with the numbers of + # GPU and CPU blocks, which are profiled in the distributed executor. + if isinstance(self.vllm_config.scheduler_config.scheduler_cls, str): + Scheduler = resolve_obj_by_qualname( + self.vllm_config.scheduler_config.scheduler_cls) + else: + Scheduler = self.vllm_config.scheduler_config.scheduler_cls + self.scheduler = [ + Scheduler( + self.scheduler_config, self.cache_config, self.lora_config, + self.parallel_config.pipeline_parallel_size, + self.async_callbacks[v_id] + if self.model_config.use_async_output_proc else None) + for v_id in range(self.parallel_config.pipeline_parallel_size) + ] + + # Metric Logging. + if self.log_stats: + if stat_loggers is not None: + self.stat_loggers = stat_loggers + else: + # Lazy import for prometheus multiprocessing. + # We need to set PROMETHEUS_MULTIPROC_DIR environment variable + # before prometheus_client is imported. 
+ # See https://prometheus.github.io/client_python/multiprocess/ + from vllm.engine.metrics import (LoggingStatLogger, + PrometheusStatLogger) + + self.stat_loggers = { + "logging": + LoggingStatLogger( + local_interval=_LOCAL_LOGGING_INTERVAL_SEC, + vllm_config=vllm_config), + "prometheus": + PrometheusStatLogger( + local_interval=_LOCAL_LOGGING_INTERVAL_SEC, + labels=dict( + model_name=self.model_config.served_model_name), + vllm_config=vllm_config), + } + self.stat_loggers["prometheus"].info("cache_config", + self.cache_config) + + self.tracer = None + if self.observability_config.otlp_traces_endpoint: + self.tracer = init_tracer( + "vllm.llm_engine", + self.observability_config.otlp_traces_endpoint) + + # Create sequence output processor, e.g. for beam search or + # speculative decoding. + self.output_processor = ( + SequenceGroupOutputProcessor.create_output_processor( + self.scheduler_config, + self.detokenizer, + self.scheduler, + self.seq_counter, + get_tokenizer_for_seq, + stop_checker=StopChecker(self.scheduler_config.max_model_len, + get_tokenizer_for_seq), + )) + + self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {} + + # Flag to set when an input fails to process and the engine should run + # the next step without re-scheduling. + self._skip_scheduling_next_step = False + + # Don't keep the dummy data in memory + self.reset_mm_cache() + + def _initialize_kv_caches(self) -> None: + """Initialize the KV cache in the worker(s). + + The workers will determine the number of blocks in both the GPU cache + and the swap CPU cache. + """ + start = time.time() + num_gpu_blocks, num_cpu_blocks = ( + self.model_executor.determine_num_available_blocks()) + + if self.cache_config.num_gpu_blocks_override is not None: + num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override + logger.info( + "Overriding num_gpu_blocks=%d with " + "num_gpu_blocks_override=%d", num_gpu_blocks, + num_gpu_blocks_override) + num_gpu_blocks = num_gpu_blocks_override + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks) + elapsed = time.time() - start + logger.info(("init engine (profile, create kv cache, " + "warmup model) took %.2f seconds"), elapsed) + + @classmethod + def _get_executor_cls(cls, + engine_config: VllmConfig) -> Type[ExecutorBase]: + # distributed_executor_backend must be set in VllmConfig.__post_init__ + distributed_executor_backend = ( + engine_config.parallel_config.distributed_executor_backend) + # Initialize the cluster and specify the executor class. + if isinstance(distributed_executor_backend, type): + if not issubclass(distributed_executor_backend, ExecutorBase): + raise TypeError( + "distributed_executor_backend must be a subclass of " + f"ExecutorBase. Got {distributed_executor_backend}.") + executor_class = distributed_executor_backend + elif distributed_executor_backend == "ray": + from vllm.executor.ray_distributed_executor import ( + RayDistributedExecutor) + executor_class = RayDistributedExecutor + elif distributed_executor_backend == "mp": + from vllm.executor.mp_distributed_executor import ( + MultiprocessingDistributedExecutor) + assert not envs.VLLM_USE_RAY_SPMD_WORKER, ( + "multiprocessing distributed executor backend does not " + "support VLLM_USE_RAY_SPMD_WORKER=1") + executor_class = MultiprocessingDistributedExecutor + elif distributed_executor_backend == "uni": + # JAX-style, single-process, multi-device executor. 
+ from vllm.executor.uniproc_executor import UniProcExecutor + executor_class = UniProcExecutor + elif distributed_executor_backend == "external_launcher": + # executor with external launcher + from vllm.executor.uniproc_executor import ( # noqa + ExecutorWithExternalLauncher) + executor_class = ExecutorWithExternalLauncher + else: + raise ValueError("unrecognized distributed_executor_backend: " + f"{distributed_executor_backend}") + return executor_class + + @classmethod + def from_vllm_config( + cls, + vllm_config: VllmConfig, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + disable_log_stats: bool = False, + ) -> "LLMEngine": + return cls( + vllm_config=vllm_config, + executor_class=cls._get_executor_cls(vllm_config), + log_stats=(not disable_log_stats), + usage_context=usage_context, + stat_loggers=stat_loggers, + ) + + @classmethod + def from_engine_args( + cls, + engine_args: EngineArgs, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + ) -> "LLMEngine": + """Creates an LLM engine from the engine arguments.""" + # Create the engine configs. + vllm_config = engine_args.create_engine_config(usage_context) + + engine_cls = cls + if envs.VLLM_USE_V1: + from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine + engine_cls = V1LLMEngine + + return engine_cls.from_vllm_config( + vllm_config=vllm_config, + usage_context=usage_context, + stat_loggers=stat_loggers, + disable_log_stats=engine_args.disable_log_stats, + ) + + def __reduce__(self): + # This is to ensure that the LLMEngine is not referenced in + # the closure used to initialize Ray worker actors + raise RuntimeError("LLMEngine should not be pickled!") + + def __del__(self): + # Shutdown model executor when engine is garbage collected + # Use getattr since __init__ can fail before the field is set + if model_executor := getattr(self, "model_executor", None): + model_executor.shutdown() + + def get_tokenizer_group(self) -> TokenizerGroup: + if self.tokenizer is None: + raise ValueError("Unable to get tokenizer because " + "skip_tokenizer_init is True") + + return self.tokenizer + + def get_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + return self.get_tokenizer_group().get_lora_tokenizer(lora_request) + + def _init_tokenizer(self) -> TokenizerGroup: + return init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=self.scheduler_config, + lora_config=self.lora_config) + + def _verify_args(self) -> None: + self.model_config.verify_with_parallel_config(self.parallel_config) + self.cache_config.verify_with_parallel_config(self.parallel_config) + if self.lora_config: + self.lora_config.verify_with_model_config(self.model_config) + self.lora_config.verify_with_scheduler_config( + self.scheduler_config) + if self.prompt_adapter_config: + self.prompt_adapter_config.verify_with_model_config( + self.model_config) + + def _add_processed_request( + self, + request_id: str, + processed_inputs: ProcessorInputs, + params: Union[SamplingParams, PoolingParams], + arrival_time: float, + lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], + trace_headers: Optional[Mapping[str, str]] = None, + priority: int = 0, + ) -> Optional[SequenceGroup]: + """Add a processed request to the engine's request pool. + return the created sequence group. 
+ """ + if isinstance(params, SamplingParams) and params.n > 1: + ParallelSampleSequenceGroup.add_request( + request_id, + self, + params, + processed_inputs=processed_inputs, + arrival_time=arrival_time, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority, + ) + return None + + self._validate_model_inputs(processed_inputs, lora_request) + # Create the sequences. + block_size = self.cache_config.block_size + seq_id = next(self.seq_counter) + eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) + + encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) + + seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id, + lora_request, prompt_adapter_request) + + encoder_seq = (None if encoder_inputs is None else Sequence( + seq_id, encoder_inputs, block_size, eos_token_id, lora_request, + prompt_adapter_request)) + + # Create a SequenceGroup based on SamplingParams or PoolingParams + if isinstance(params, SamplingParams): + seq_group = self._create_sequence_group_with_sampling( + request_id, + seq, + params, + arrival_time=arrival_time, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + encoder_seq=encoder_seq, + priority=priority) + elif isinstance(params, PoolingParams): + seq_group = self._create_sequence_group_with_pooling( + request_id, + seq, + params, + arrival_time=arrival_time, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + encoder_seq=encoder_seq, + priority=priority) + else: + raise ValueError( + "Either SamplingParams or PoolingParams must be provided.") + + # Add the sequence group to the scheduler with least unfinished seqs. + costs = [ + scheduler.get_num_unfinished_seq_groups() + for scheduler in self.scheduler + ] + min_cost_scheduler = self.scheduler[costs.index(min(costs))] + min_cost_scheduler.add_seq_group(seq_group) + + return seq_group + + def stop_remote_worker_execution_loop(self) -> None: + self.model_executor.stop_remote_worker_execution_loop() + + def add_request( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + tokenization_kwargs: Optional[dict[str, Any]] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> None: + """Add a request to the engine's request pool. + + The request is added to the request pool and will be processed by the + scheduler as `engine.step()` is called. The exact scheduling policy is + determined by the scheduler. + + Args: + request_id: The unique ID of the request. + prompt: The prompt to the LLM. See + [PromptType][vllm.inputs.PromptType] + for more details about the format of each input. + params: Parameters for sampling or pooling. + [SamplingParams][vllm.SamplingParams] for text generation. + [PoolingParams][vllm.PoolingParams] for pooling. + arrival_time: The arrival time of the request. If None, we use + the current monotonic time. + lora_request: The LoRA request to add. + trace_headers: OpenTelemetry trace headers. + prompt_adapter_request: The prompt adapter request to add. + priority: The priority of the request. + Only applicable with priority scheduling. + + Details: + - Set arrival_time to the current time if it is None. + - Set prompt_token_ids to the encoded prompt if it is None. 
+ - Create `n` number of [Sequence][vllm.Sequence] objects. + - Create a [SequenceGroup][vllm.SequenceGroup] object + from the list of [Sequence][vllm.Sequence]. + - Add the [SequenceGroup][vllm.SequenceGroup] object to the + scheduler. + + Example: + >>> # initialize engine + >>> engine = LLMEngine.from_engine_args(engine_args) + >>> # set request arguments + >>> example_prompt = "Who is the president of the United States?" + >>> sampling_params = SamplingParams(temperature=0.0) + >>> request_id = 0 + >>> + >>> # add the request to the engine + >>> engine.add_request( + >>> str(request_id), + >>> example_prompt, + >>> SamplingParams(temperature=0.0)) + >>> # continue the request processing + >>> ... + """ + if lora_request is not None and not self.lora_config: + raise ValueError(f"Got lora_request {lora_request} but LoRA is " + "not enabled!") + + if priority != 0 and not self.scheduler_config.policy == "priority": + raise ValueError(f"Got priority {priority} but " + "Priority scheduling is not enabled.") + + if isinstance(params, SamplingParams) \ + and (params.guided_decoding or params.logits_processors) \ + and self.scheduler_config.num_scheduler_steps > 1: + raise ValueError( + "Guided decoding and logits processors are not supported " + "in multi-step decoding") + + if arrival_time is None: + arrival_time = time.time() + + if (isinstance(prompt, dict) + and prompt.get("prompt_embeds", None) is not None + and not prompt.get("prompt_token_ids", None)): + seq_len = prompt["prompt_embeds"].shape[0] + prompt["prompt_token_ids"] = [0] * seq_len + + processed_inputs = self.input_preprocessor.preprocess( + prompt, + tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + + self._add_processed_request( + request_id=request_id, + processed_inputs=processed_inputs, + params=params, + arrival_time=arrival_time, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + trace_headers=trace_headers, + priority=priority, + ) + + def _create_sequence_group_with_sampling( + self, + request_id: str, + seq: Sequence, + sampling_params: SamplingParams, + arrival_time: float, + lora_request: Optional[LoRARequest], + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + encoder_seq: Optional[Sequence] = None, + priority: int = 0, + ) -> SequenceGroup: + """Creates a SequenceGroup with SamplingParams.""" + max_logprobs = self.get_model_config().max_logprobs + if (sampling_params.logprobs + and sampling_params.logprobs > max_logprobs) or ( + sampling_params.prompt_logprobs + and sampling_params.prompt_logprobs > max_logprobs): + raise ValueError(f"Cannot request more than " + f"{max_logprobs} logprobs.") + + sampling_params = self._build_logits_processors( + sampling_params, lora_request) + + # Defensive copy of SamplingParams, which are used by the sampler, + # this doesn't deep-copy LogitsProcessor objects + sampling_params = sampling_params.clone() + + sampling_params.update_from_generation_config( + self.generation_config_fields, seq.eos_token_id) + + # Create the sequence group. 
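+        # With speculative decoding, each step may emit up to
+        # num_speculative_tokens + 1 tokens per sequence, so the group is
+        # sized accordingly via draft_size.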
+ draft_size = 1 + if self.vllm_config.speculative_config is not None: + draft_size = \ + self.vllm_config.speculative_config.num_speculative_tokens + 1 + seq_group = SequenceGroup( + request_id=request_id, + seqs=[seq], + arrival_time=arrival_time, + sampling_params=sampling_params, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + encoder_seq=encoder_seq, + priority=priority, + draft_size=draft_size) + + return seq_group + + def _create_sequence_group_with_pooling( + self, + request_id: str, + seq: Sequence, + pooling_params: PoolingParams, + arrival_time: float, + lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], + encoder_seq: Optional[Sequence] = None, + priority: int = 0, + ) -> SequenceGroup: + """Creates a SequenceGroup with PoolingParams.""" + # Defensive copy of PoolingParams, which are used by the pooler + pooling_params = pooling_params.clone() + # Create the sequence group. + seq_group = SequenceGroup( + request_id=request_id, + seqs=[seq], + arrival_time=arrival_time, + lora_request=lora_request, + pooling_params=pooling_params, + prompt_adapter_request=prompt_adapter_request, + encoder_seq=encoder_seq, + priority=priority) + return seq_group + + def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: + """Aborts a request(s) with the given ID. + + Args: + request_id: The ID(s) of the request to abort. + + Details: + - Refer to [vllm.core.scheduler.Scheduler.abort_seq_group][]. + + Example: + >>> # initialize engine and add a request with request_id + >>> request_id = str(0) + >>> # abort the request + >>> engine.abort_request(request_id) + """ + for scheduler in self.scheduler: + scheduler.abort_seq_group( + request_id, seq_id_to_seq_group=self.seq_id_to_seq_group) + + def get_vllm_config(self) -> VllmConfig: + """Gets the vllm configuration.""" + return self.vllm_config + + def get_model_config(self) -> ModelConfig: + """Gets the model configuration.""" + return self.model_config + + def get_parallel_config(self) -> ParallelConfig: + """Gets the parallel configuration.""" + return self.parallel_config + + def get_decoding_config(self) -> DecodingConfig: + """Gets the decoding configuration.""" + return self.decoding_config + + def get_scheduler_config(self) -> SchedulerConfig: + """Gets the scheduler configuration.""" + return self.scheduler_config + + def get_lora_config(self) -> LoRAConfig: + """Gets the LoRA configuration.""" + return self.lora_config + + def get_num_unfinished_requests(self) -> int: + """Gets the number of unfinished requests.""" + return sum(scheduler.get_num_unfinished_seq_groups() + for scheduler in self.scheduler) + + def has_unfinished_requests(self) -> bool: + """Returns True if there are unfinished requests.""" + return any(scheduler.has_unfinished_seqs() + for scheduler in self.scheduler) + + def has_unfinished_requests_for_virtual_engine( + self, virtual_engine: int) -> bool: + """ + Returns True if there are unfinished requests for the virtual engine. 
+ """ + return self.scheduler[virtual_engine].has_unfinished_seqs() + + def reset_mm_cache(self) -> bool: + """Reset the multi-modal cache.""" + return self.input_preprocessor.mm_registry.reset_processor_cache() + + def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: + """Reset prefix cache for all devices.""" + + success = True + for scheduler in self.scheduler: + success = success and scheduler.reset_prefix_cache(device) + return success + + @staticmethod + def _process_sequence_group_outputs( + seq_group: SequenceGroup, + outputs: List[PoolingSequenceGroupOutput], + ) -> None: + seq_group.pooled_data = outputs[0].data + + for seq in seq_group.get_seqs(): + seq.status = SequenceStatus.FINISHED_STOPPED + + return + + def _update_num_computed_tokens_for_multi_step_prefill( + self, seq_group: SequenceGroup, + seq_group_meta: SequenceGroupMetadata, + is_first_step_output: Optional[bool]): + """ + This function updates num_computed_tokens for prompt sequences + when Multi-Step is enabled. + + seq_group: SequenceGroup to update the num_computed_tokens for. + seq_group_meta: Metadata of the given SequenceGroup. + is_first_step_output: Optional[bool] - + When available, is_first_step_output indicates if the appended + output token is the output of the first-step in multi-step. + A value of None indicates that outputs from all steps in + in multi-step are submitted in a single burst. + """ + + assert self.scheduler_config.is_multi_step + + if not seq_group_meta.is_prompt: + # num_computed_token updates for multi-step decodes happen after + # the tokens are appended to the sequence. + return + + do_update: bool = False + if self.scheduler_config.chunked_prefill_enabled: + # In multi-step + chunked-prefill case, the prompt sequences + # that are scheduled are fully processed in the first step. + do_update = is_first_step_output is None or is_first_step_output + else: + # Normal multi-step decoding case. In this case prompt-sequences + # are actually single-stepped. Always update in this case. + assert seq_group.state.num_steps == 1 + do_update = True + + if do_update: + seq_group.update_num_computed_tokens( + seq_group_meta.token_chunk_size) + + def _process_model_outputs(self, + ctx: SchedulerContext, + request_id: Optional[str] = None) -> None: + """Apply the model output to the sequences in the scheduled seq groups + and return responses. + + ctx: The virtual engine context to work on + request_id: If provided, then only this request is going to be processed + """ + + now = time.time() + + if len(ctx.output_queue) == 0: + return None + + # Get pending async postprocessor + if request_id: + # When we process only one request, no pop is required + # (since later we will process all of the rest) + (outputs, seq_group_metadata_list, scheduler_outputs, is_async, + is_last_step, is_first_step_output, skip) = ctx.output_queue[0] + else: + (outputs, seq_group_metadata_list, scheduler_outputs, is_async, + is_last_step, is_first_step_output, + skip) = ctx.output_queue.popleft() + + # Sanity check + assert len(seq_group_metadata_list) == len( + scheduler_outputs.scheduled_seq_groups) + + has_multiple_outputs: bool = len(outputs) > 1 + outputs_by_sequence_group: List[List[SequenceGroupOutput]] + if has_multiple_outputs: + assert self.scheduler_config.is_multi_step or \ + self.speculative_config + # Organize outputs by [step][sequence group] instead of + # [sequence group][step]. 
+ if self.scheduler_config.is_multi_step: + outputs_by_sequence_group = create_output_by_sequence_group( + outputs, len(seq_group_metadata_list)) + elif self.speculative_config: + # Decodes are multi-steps while prefills are not, outputting at + # most 1 token. Separate them so that we can trigger chunk + # processing without having to pad or copy over prompts K times + # to match decodes structure (costly with prompt_logprobs). + num_prefills = sum(sg.is_prompt + for sg in seq_group_metadata_list) + prefills, decodes = outputs[:num_prefills], outputs[ + num_prefills:] + outputs_by_sequence_group = create_output_by_sequence_group( + decodes, + num_seq_groups=len(seq_group_metadata_list) - num_prefills) + outputs_by_sequence_group = [p.outputs for p in prefills + ] + outputs_by_sequence_group + # We have outputs for multiple steps submitted in a single burst, + # so invalidate is_first_step_output. + is_first_step_output = None + else: + outputs_by_sequence_group = outputs + + # Determine the requests we need to operate on + if request_id: + indices = [] + for i, seq_group_meta in enumerate(seq_group_metadata_list): + if seq_group_meta.request_id == request_id: + assert i not in skip # Cannot be called twice + indices.append(i) + break + + # If the request_id was not found, then it means that + # this is a new request that has no pending async + # postprocessor + if not indices: + return + else: + indices = range(len(seq_group_metadata_list)) # type: ignore + + finished_before: List[int] = [] + finished_now: List[int] = [] + for i in indices: + if i in skip: + continue + + seq_group_meta = seq_group_metadata_list[i] + scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] + + seq_group: SequenceGroup = scheduled_seq_group.seq_group + + if seq_group.is_finished(): + finished_before.append(i) + continue + + output: List[SequenceGroupOutput] + if has_multiple_outputs: + output = outputs_by_sequence_group[i] + else: + output = [outputs_by_sequence_group[0][i]] + + if not is_async: + if self.scheduler_config.is_multi_step: + # Updates happen only if the sequence is prefill + self._update_num_computed_tokens_for_multi_step_prefill( + seq_group, seq_group_meta, is_first_step_output) + else: + seq_group.update_num_computed_tokens( + seq_group_meta.token_chunk_size or 0) + + if outputs: + for o in outputs: + if (isinstance(o, SamplerOutput) + and seq_group.metrics is not None): + if seq_group.metrics.model_forward_time is not None: + seq_group.metrics.model_forward_time += ( + o.model_forward_time or 0) + else: + seq_group.metrics.model_forward_time = ( + o.model_forward_time) + if seq_group.metrics.model_execute_time is not None: + seq_group.metrics.model_execute_time += ( + o.model_execute_time or 0) + else: + seq_group.metrics.model_execute_time = ( + o.model_execute_time) + + if self.model_config.runner_type == "pooling": + self._process_sequence_group_outputs(seq_group, output) + else: + self.output_processor.process_prompt_logprob(seq_group, output) + if seq_group_meta.do_sample: + self.output_processor.process_outputs( + seq_group, output, is_async) + + if seq_group.is_finished(): + finished_now.append(i) + + # Generate outputs for the requests that finished this iteration + for i in finished_now: + scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] + + seq_group = scheduled_seq_group.seq_group + seq_group.maybe_set_first_token_time(now) + if not seq_group.is_prefill(): + seq_group.set_last_token_time(now) + request_output = RequestOutputFactory.create( + seq_group, + 
self.seq_id_to_seq_group, + use_cache=self.use_cached_outputs) + if request_output: + ctx.request_outputs.append(request_output) + + # When we process a single request, we skip it for the next time, + # and invoke the request output callback (if there was final output) + if request_id: + assert len(indices) == 1 + skip.append(indices[0]) + + if (finished_now + and self.process_request_outputs_callback is not None): + self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() + return + + # Free currently finished requests + if finished_now: + for scheduler in self.scheduler: + scheduler.free_finished_seq_groups() + + # For multi-step without streaming, don't create outputs each iteration + if not is_last_step and not ctx.multi_step_stream_outputs: + # Immediately process request outputs here (if callback is given) + if (finished_now + and self.process_request_outputs_callback is not None): + self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() + return + + # Create the outputs + for i in indices: + if i in skip or i in finished_before or i in finished_now: + continue # Avoids double processing + + scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] + + seq_group = scheduled_seq_group.seq_group + seq_group.maybe_set_first_token_time(now) + if not seq_group.is_prefill(): + seq_group.set_last_token_time(now) + request_output = RequestOutputFactory.create( + seq_group, + self.seq_id_to_seq_group, + use_cache=self.use_cached_outputs) + if request_output: + ctx.request_outputs.append(request_output) + + # For multi-step with streaming, create outputs each iteration + if not is_last_step and ctx.multi_step_stream_outputs: + # Immediately process request outputs here (if callback is given) + if self.process_request_outputs_callback is not None: + self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() + return + + for seq_group in scheduler_outputs.ignored_seq_groups: + params = seq_group.sampling_params + if params is not None and params.output_kind == ( + RequestOutputKind.DELTA) and not seq_group.is_finished(): + continue + + request_output = RequestOutputFactory.create( + seq_group, + self.seq_id_to_seq_group, + use_cache=self.use_cached_outputs, + ) + if request_output: + ctx.request_outputs.append(request_output) + + # Immediately process request outputs here (if callback is given) + if (ctx.request_outputs + and self.process_request_outputs_callback is not None): + self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() + + # For async case, we need to record the stats here. + # For non-async case, the stats are done in the + # LLMEngine/AsyncLLMEngine directly + if is_async: + # Log stats. + self.do_log_stats(scheduler_outputs, outputs, finished_before, + skip) + + # Tracing + self.do_tracing(scheduler_outputs, finished_before) + + return None + + def _advance_to_next_step( + self, output: SamplerOutput, + seq_group_metadata_list: List[SequenceGroupMetadata], + scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None: + """Given model output from a single run, append the tokens to the + sequences. This is normally done inside output processor, but it is + required if the worker is to perform async forward pass to next step. 
+ """ + for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \ + zip(seq_group_metadata_list, output, scheduled_seq_groups): + seq_group = scheduled_seq_group.seq_group + + if seq_group.is_finished(): + continue + + if self.scheduler_config.is_multi_step: + # Updates happen only if the sequence is prefill + self._update_num_computed_tokens_for_multi_step_prefill( + seq_group, seq_group_metadata, + seq_group.state.num_steps == 1) + else: + token_chunk_size = (seq_group_metadata.token_chunk_size + if seq_group_metadata.token_chunk_size + is not None else 0) + seq_group.update_num_computed_tokens(token_chunk_size) + + if seq_group_metadata.do_sample: + assert len(sequence_group_outputs.samples) == 1, ( + "Async output processor expects a single sample" + " (i.e sampling_params.n == 1)") + sample = sequence_group_outputs.samples[0] + + assert len(seq_group.seqs) == 1 + seq = seq_group.seqs[0] + + if self.scheduler_config.is_multi_step: + is_prefill_append = seq.data.get_num_uncomputed_tokens( + ) == 0 + seq.append_token_id(sample.output_token, sample.logprobs, + sample.output_embed) + if not is_prefill_append: + seq_group.update_num_computed_tokens(1) + else: + seq.append_token_id(sample.output_token, sample.logprobs, + sample.output_embed) + + def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: + """Performs one decoding iteration and returns newly generated results. + +
+        ![Overview of the step function](https://i.imgur.com/sv2HssD.png)
+ + Details: + - Step 1: Schedules the sequences to be executed in the next + iteration and the token blocks to be swapped in/out/copy. + + - Depending on the scheduling policy, + sequences may be `preempted/reordered`. + - A Sequence Group (SG) refer to a group of sequences + that are generated from the same prompt. + + - Step 2: Calls the distributed executor to execute the model. + - Step 3: Processes the model output. This mainly includes: + + - Decodes the relevant outputs. + - Updates the scheduled sequence groups with model outputs + based on its `sampling parameters` (`use_beam_search` or not). + - Frees the finished sequence groups. + + - Finally, it creates and returns the newly generated results. + + Example: + ``` + # Please see the example/ folder for more detailed examples. + + # initialize engine and request arguments + engine = LLMEngine.from_engine_args(engine_args) + example_inputs = [(0, "What is LLM?", + SamplingParams(temperature=0.0))] + + # Start the engine with an event loop + while True: + if example_inputs: + req_id, prompt, sampling_params = example_inputs.pop(0) + engine.add_request(str(req_id),prompt,sampling_params) + + # continue the request processing + request_outputs = engine.step() + for request_output in request_outputs: + if request_output.finished: + # return or show the request output + + if not (engine.has_unfinished_requests() or example_inputs): + break + ``` + """ + if self.parallel_config.pipeline_parallel_size > 1: + raise NotImplementedError( + "Pipeline parallelism is only supported through AsyncLLMEngine " + "as performance will be severely degraded otherwise.") + + # For llm_engine, there is no pipeline parallel support, so the engine + # used is always 0. + virtual_engine = 0 + + # These are cached outputs from previous iterations. None if on first + # iteration + cached_outputs = self.cached_scheduler_outputs[virtual_engine] + seq_group_metadata_list = cached_outputs.seq_group_metadata_list + scheduler_outputs = cached_outputs.scheduler_outputs + allow_async_output_proc = cached_outputs.allow_async_output_proc + + ctx = self.scheduler_contexts[virtual_engine] + + # Clear outputs for each new scheduler iteration + ctx.request_outputs.clear() + + # Skip the scheduler if there are any remaining steps in the seq groups. + # This ensures that the scheduler is only called again when the current + # batch has completed. + # The scheduler is also skipped if a single request caused the last + # engine step to fail, and the previous schedule needs to be rerun. + if not self._has_remaining_steps( + seq_group_metadata_list + ) and not self._skip_scheduling_next_step: + # Schedule iteration + (seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc + ) = self.scheduler[virtual_engine].schedule() + + ctx.seq_group_metadata_list = seq_group_metadata_list + ctx.scheduler_outputs = scheduler_outputs + + finished_requests_ids = self.scheduler[ + virtual_engine].get_and_reset_finished_requests_ids() + # When n>1, elements in self.seq_id_to_seq_group should be deleted + # here, otherwise memory leaks. 
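+            # get_and_reset_finished_requests_ids() only reports each finished
+            # request once, so this is the last chance to drop its entry.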
+ for finished_request_id in finished_requests_ids: + if finished_request_id in self.seq_id_to_seq_group: + del self.seq_id_to_seq_group[finished_request_id] + + # Maybe switch from async mode to sync mode + if not allow_async_output_proc and len(ctx.output_queue) > 0: + self._process_model_outputs(ctx=ctx) + + if (self.scheduler_config.is_multi_step + and scheduler_outputs.num_lookahead_slots > 0): + # cache the scheduler outputs for the next iteration if we have + # lookahead slots + self._cache_scheduler_outputs_for_multi_step( + virtual_engine, seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc) + else: + finished_requests_ids = list() + + assert seq_group_metadata_list is not None + assert scheduler_outputs is not None + + if not scheduler_outputs.is_empty(): + + # Check if we have a cached last_output from the previous iteration. + # For supporting PP this is probably the best way to pass the + # sampled_token_ids, as a separate broadcast over all the PP stages + # will cause one virtual engine's microbatch to block the pipeline. + last_sampled_token_ids = \ + self._get_last_sampled_token_ids(virtual_engine) + + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, + blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, + blocks_to_copy=scheduler_outputs.blocks_to_copy, + num_lookahead_slots=scheduler_outputs.num_lookahead_slots, + running_queue_size=scheduler_outputs.running_queue_size, + finished_requests_ids=finished_requests_ids, + # We use ExecuteModelRequest to pass the last sampled_token_ids + # to each of the non-last PP stages for in-place prepare_input. + last_sampled_token_ids=last_sampled_token_ids) + + if allow_async_output_proc: + execute_model_req.async_callback = self.async_callbacks[ + virtual_engine] + + try: + outputs = self.model_executor.execute_model( + execute_model_req=execute_model_req) + self._skip_scheduling_next_step = False + except InputProcessingError as e: + # The input for this request cannot be processed, so we must + # abort it. If there are remaining requests in the batch that + # have been scheduled, they will be retried on the next step. + invalid_request_id = e.request_id + self._abort_and_cache_schedule( + request_id=invalid_request_id, + virtual_engine=virtual_engine, + seq_group_metadata_list=seq_group_metadata_list, + scheduler_outputs=scheduler_outputs, + allow_async_output_proc=allow_async_output_proc) + # Raise so the caller is notified that this request failed + raise + + # We need to do this here so that last step's sampled_token_ids can + # be passed to the next iteration for PP. + if self.scheduler_config.is_multi_step: + self._update_cached_scheduler_output(virtual_engine, outputs) + else: + # Nothing scheduled => If there is pending async postprocessor, + # then finish it here. + if len(ctx.output_queue) > 0: + self._process_model_outputs(ctx=ctx) + # No outputs in this case + outputs = [] + + # Finish the current step for all the sequence groups. + if self.scheduler_config.is_multi_step: + for seq_group in seq_group_metadata_list: + seq_group.finish_step() + + if not self._has_remaining_steps(seq_group_metadata_list): + # clear the cache if we have finished all the steps. + if self.scheduler_config.is_multi_step: + self.cached_scheduler_outputs[0] = SchedulerOutputState() + + # is_first_step_output is True only when the num_steps of all + # the sequences are 1. 
When the num_steps > 1, + # multi_step_model_runner does the first-step output append. + is_first_step_output: bool = False if not seq_group_metadata_list \ + else seq_group_metadata_list[0].state.num_steps == 1 + + # Add results to the output_queue + ctx.append_output(outputs=outputs, + seq_group_metadata_list=seq_group_metadata_list, + scheduler_outputs=scheduler_outputs, + is_async=allow_async_output_proc, + is_last_step=True, + is_first_step_output=is_first_step_output) + + if outputs and allow_async_output_proc: + assert len(outputs) == 1, ( + "Async postprocessor expects only a single output set") + + self._advance_to_next_step( + outputs[0], seq_group_metadata_list, + scheduler_outputs.scheduled_seq_groups) + + # Check if need to run the usual non-async path + if not allow_async_output_proc: + self._process_model_outputs(ctx=ctx) + + # Log stats. + self.do_log_stats(scheduler_outputs, outputs) + + # Tracing + self.do_tracing(scheduler_outputs) + else: + # Multi-step case + return ctx.request_outputs + + if not self.has_unfinished_requests(): + # Drain async postprocessor (if exists) + if len(ctx.output_queue) > 0: + self._process_model_outputs(ctx=ctx) + assert len(ctx.output_queue) == 0 + + # Stop the execute model loop in parallel workers until there are + # more requests to process. This avoids waiting indefinitely in + # torch.distributed ops which may otherwise timeout, and unblocks + # the RPC thread in the workers so that they can process any other + # queued control plane messages, such as add/remove lora adapters. + logger.debug("Stopping remote worker execution loop.") + self.model_executor.stop_remote_worker_execution_loop() + + return ctx.request_outputs + + def _abort_and_cache_schedule( + self, request_id: str, virtual_engine: int, + seq_group_metadata_list: List[SequenceGroupMetadata], + scheduler_outputs: SchedulerOutputs, + allow_async_output_proc: bool) -> None: + """Aborts a single request, and caches the scheduler outputs minus that + request. This allows the next step to continue processing the remaining + requests without having to re-run the scheduler.""" + + # Abort the request and remove its sequence group from the current + # schedule + self.abort_request(request_id) + for i, metadata in enumerate(seq_group_metadata_list): + if metadata.request_id == request_id: + del seq_group_metadata_list[i] + break + for i, group in enumerate(scheduler_outputs.scheduled_seq_groups): + if group.seq_group.request_id == request_id: + del scheduler_outputs.scheduled_seq_groups[i] + break + + # If there are still other sequence groups left in the schedule, cache + # them and flag the engine to reuse the schedule. + if len(seq_group_metadata_list) > 0: + self._skip_scheduling_next_step = True + # Reuse multi-step caching logic + self._cache_scheduler_outputs_for_multi_step( + virtual_engine=virtual_engine, + scheduler_outputs=scheduler_outputs, + seq_group_metadata_list=seq_group_metadata_list, + allow_async_output_proc=allow_async_output_proc) + + def _has_remaining_steps( + self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] + ) -> bool: + if (not self.scheduler_config.is_multi_step + or not seq_group_metadata_list): + return False + + # TODO(will) this is a sanity check for nowto make sure that all the + # seqs are on the same steps. Eventually we will want to do some sort of + # dynamic scheduling when doing multi-step decoding. 
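+        # Use the first group's remaining_steps as the reference; any group
+        # that disagrees trips the assertion below.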
+ ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps + if any([ + seq_group.state.remaining_steps != ref_remaining_steps + for seq_group in seq_group_metadata_list[1:] + ]): + raise AssertionError("All running sequence groups should " + "have the same remaining steps.") + + return ref_remaining_steps > 0 + + def _cache_scheduler_outputs_for_multi_step( + self, virtual_engine: int, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + scheduler_outputs: SchedulerOutputs, + allow_async_output_proc: bool) -> None: + co = self.cached_scheduler_outputs[virtual_engine] + + co.seq_group_metadata_list = seq_group_metadata_list + co.scheduler_outputs = scheduler_outputs + co.allow_async_output_proc = allow_async_output_proc + co.last_output = None + + def _update_cached_scheduler_output( + self, virtual_engine: int, + output: List[Optional[SamplerOutput]]) -> None: + if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0 + and output[0] is not None): + last_output = output[-1] + assert last_output is not None + assert last_output.sampled_token_ids_cpu is not None + assert last_output.sampled_token_ids is None + assert last_output.sampled_token_probs is None + self.cached_scheduler_outputs[ + virtual_engine].last_output = last_output + + def _get_last_sampled_token_ids( + self, virtual_engine: int) -> Optional[torch.Tensor]: + cached_last_output = self.cached_scheduler_outputs[ + virtual_engine].last_output + if (self.scheduler_config.is_multi_step + and self.parallel_config.pipeline_parallel_size > 1 + and cached_last_output is not None + and cached_last_output.sampled_token_ids_cpu is not None): + return cached_last_output.sampled_token_ids_cpu + return None + + def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None: + if not self.log_stats: + raise RuntimeError( + "Stat logging is disabled. Set `disable_log_stats=False` " + "argument to enable.") + if logger_name in self.stat_loggers: + raise KeyError(f"Logger with name {logger_name} already exists.") + self.stat_loggers[logger_name] = logger + + def remove_logger(self, logger_name: str) -> None: + if not self.log_stats: + raise RuntimeError( + "Stat logging is disabled. Set `disable_log_stats=False` " + "argument to enable.") + if logger_name not in self.stat_loggers: + raise KeyError(f"Logger with name {logger_name} does not exist.") + del self.stat_loggers[logger_name] + + def do_log_stats(self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[List[SamplerOutput]] = None, + finished_before: Optional[List[int]] = None, + skip: Optional[List[int]] = None) -> None: + """Forced log when no requests active.""" + if self.log_stats: + stats = self._get_stats(scheduler_outputs, model_output, + finished_before, skip) + for logger in self.stat_loggers.values(): + logger.log(stats) + + def _get_stats(self, + scheduler_outputs: Optional[SchedulerOutputs], + model_output: Optional[List[SamplerOutput]] = None, + finished_before: Optional[List[int]] = None, + skip: Optional[List[int]] = None) -> Stats: + """Get Stats to be Logged to Prometheus. + + Args: + scheduler_outputs: Optional, used to populate metrics related to + the scheduled batch, + model_output: Optional, used to emit speculative decoding metrics + which are created by the workers. + finished_before: Optional, indices of sequences that were finished + before. These sequences will be ignored. + skip: Optional, indices of sequences that were preempted. These + sequences will be ignored. 
+ """ + now = time.time() + + # System State + # Scheduler State + num_running_sys = sum( + len(scheduler.running) for scheduler in self.scheduler) + num_swapped_sys = sum( + len(scheduler.swapped) for scheduler in self.scheduler) + num_waiting_sys = sum( + len(scheduler.waiting) for scheduler in self.scheduler) + + # KV Cache Usage in % + num_total_gpu = self.cache_config.num_gpu_blocks + gpu_cache_usage_sys = 0. + if num_total_gpu: # Guard against both None and 0 + num_free_gpu = sum( + scheduler.block_manager.get_num_free_gpu_blocks() + for scheduler in self.scheduler) + gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu) + + num_total_cpu = self.cache_config.num_cpu_blocks + cpu_cache_usage_sys = 0. + if num_total_cpu: # Guard against both None and 0 + num_free_cpu = sum( + scheduler.block_manager.get_num_free_cpu_blocks() + for scheduler in self.scheduler) + cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu) + + # Prefix Cache Hit Rate. Note that we always use + # the cache hit rate of the first virtual engine. + cpu_prefix_cache_hit_rate = self.scheduler[ + 0].get_prefix_cache_hit_rate(Device.CPU) + gpu_prefix_cache_hit_rate = self.scheduler[ + 0].get_prefix_cache_hit_rate(Device.GPU) + + # Exchange the uasge and cache hit stats between gpu and cpu when + # running on cpu because the cpu_worker.py intentionally reports the + # number of cpu blocks as gpu blocks in favor of cache management. + if self.device_config.device_type == "cpu": + num_total_gpu, num_total_cpu = num_total_cpu, num_total_gpu + gpu_cache_usage_sys, cpu_cache_usage_sys = ( + cpu_cache_usage_sys, + gpu_cache_usage_sys, + ) + gpu_prefix_cache_hit_rate, cpu_prefix_cache_hit_rate = ( + cpu_prefix_cache_hit_rate, + gpu_prefix_cache_hit_rate, + ) + + # Iteration stats + num_prompt_tokens_iter = 0 + num_generation_tokens_iter = 0 + num_tokens_iter = 0 + time_to_first_tokens_iter: List[float] = [] + time_per_output_tokens_iter: List[float] = [] + num_preemption_iter = (0 if scheduler_outputs is None else + scheduler_outputs.preempted) + + # Request stats + # Latency + time_e2e_requests: List[float] = [] + time_queue_requests: List[float] = [] + time_inference_requests: List[float] = [] + time_prefill_requests: List[float] = [] + time_decode_requests: List[float] = [] + # Metadata + num_prompt_tokens_requests: List[int] = [] + num_generation_tokens_requests: List[int] = [] + n_requests: List[int] = [] + max_num_generation_tokens_requests: List[int] = [] + max_tokens_requests: List[int] = [] + finished_reason_requests: List[str] = [] + + # LoRA requests + running_lora_adapters = dict( + collectionsCounter([ + running_request.lora_request.lora_name + for scheduler in self.scheduler + for running_request in scheduler.running + if running_request.lora_request + ])) + waiting_lora_adapters = dict( + collectionsCounter([ + waiting_request.lora_request.lora_name + for scheduler in self.scheduler + for waiting_request in scheduler.waiting + if waiting_request.lora_request + ])) + max_lora_stat = "0" + if self.lora_config: + max_lora_stat = str(self.lora_config.max_loras) + + # NOTE: This loop assumes prefill seq_groups are before + # decode seq_groups in scheduled_seq_groups. 
+ if scheduler_outputs is not None: + # For async postprocessor, already finished sequences need to be + # not counted (to avoid double counting) + actual_num_batched_tokens = scheduler_outputs.num_batched_tokens # type: ignore + + num_generation_tokens_from_prefill_groups = 0 + # NOTE: if scheduler_outputs.num_prefill_groups > 0 and + # the len of scheduler_outputs.scheduled_seq_groups is != + # scheduler_outputs.num_prefill_groups, this means that + # chunked prefills have been detected. + + for idx, scheduled_seq_group in enumerate( + scheduler_outputs.scheduled_seq_groups): + # Skip double logging when using async output proc + if finished_before and idx in finished_before: + actual_num_batched_tokens -= 1 + continue + + # Currently, skip == preempted sequences, so we need to skip + # their log stats + if skip and idx in skip: + continue + + group_was_prefill = idx < scheduler_outputs.num_prefill_groups + seq_group = scheduled_seq_group.seq_group + + # NOTE: a seq_group that completed all of its prefill tokens + # in the last iteration will have seq_group.is_prefill() = False + # with group_was_prefill = True + if group_was_prefill: + # Number of prompt tokens. + num_prompt_tokens_iter += ( + scheduled_seq_group.token_chunk_size) + + # If the seq_group just finished the prefill state + # get TTFT. + if not seq_group.is_prefill(): + latency = seq_group.get_last_token_latency() + time_to_first_tokens_iter.append(latency) + + # One generation token per finished prefill. + num_generation_tokens_from_prefill_groups += ( + seq_group.num_seqs()) + else: + # TPOTs. + latency = seq_group.get_last_token_latency() + time_per_output_tokens_iter.append(latency) + if seq_group.state.current_step == 0: + # For async_output_proc, the do_log_stats() + # is called following init_multi_step(), which + # sets the current_step to zero. + actual_num_batched_tokens +=\ + seq_group.state.num_steps - 1 + else: + actual_num_batched_tokens +=\ + seq_group.state.current_step - 1 + + # Because of chunked prefill, we can have a single sequence + # group that does multiple prompt_runs. To prevent logging + # the same metadata more than once per request, we standardize + # on logging request level information for finished requests, + # which can only happen once. + if seq_group.is_finished(): + # Latency timings + time_e2e_requests.append(now - + seq_group.metrics.arrival_time) + if (seq_group.metrics.first_scheduled_time is not None and + seq_group.metrics.first_token_time is not None): + time_queue_requests.append( + seq_group.metrics.first_scheduled_time - + seq_group.metrics.arrival_time) + time_prefill_requests.append( + seq_group.metrics.first_token_time - + seq_group.metrics.first_scheduled_time) + time_decode_requests.append( + now - seq_group.metrics.first_token_time) + time_inference_requests.append( + now - seq_group.metrics.first_scheduled_time) + # Metadata + num_prompt_tokens_requests.append( + len(seq_group.prompt_token_ids)) + num_generation_tokens_requests.extend([ + seq.get_output_len() + for seq in seq_group.get_finished_seqs() + ]) + max_num_generation_tokens_requests.append( + max(seq.get_output_len() + for seq in seq_group.get_seqs())) + if seq_group.sampling_params is not None: + n_requests.append(seq_group.sampling_params.n) + max_tokens_requests.append( + seq_group.sampling_params.max_tokens) + finished_reason_requests.extend([ + SequenceStatus.get_finished_reason(seq.status) + for seq in seq_group.get_finished_seqs() + ]) + + # Number of generation tokens. 
+ # num_batched_tokens equals the number of prompt_tokens plus the + # number of decode_tokens in a single iteration. So, + # num_generation_tokens = num_batched_tokens - num_prompt_tokens + # + num_generation_tokens_from_prefill_groups (since we generate + # one token on prefills on iters where the prefill finishes). + num_generation_tokens_iter = ( + actual_num_batched_tokens - num_prompt_tokens_iter + + num_generation_tokens_from_prefill_groups) + num_tokens_iter = (num_generation_tokens_iter + + num_prompt_tokens_iter) + # Spec decode, if enabled, emits specialized metrics from the worker in + # sampler output. + if model_output and isinstance(model_output[0], SamplerOutput) and ( + model_output[0].spec_decode_worker_metrics is not None): + spec_decode_metrics = model_output[0].spec_decode_worker_metrics + else: + spec_decode_metrics = None + + return Stats( + now=now, + # System stats + # Scheduler State + num_running_sys=num_running_sys, + num_swapped_sys=num_swapped_sys, + num_waiting_sys=num_waiting_sys, + # KV Cache Usage in % + gpu_cache_usage_sys=gpu_cache_usage_sys, + cpu_cache_usage_sys=cpu_cache_usage_sys, + # Prefix Cache Hit Rate + cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate, + gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate, + + # Iteration stats + num_prompt_tokens_iter=num_prompt_tokens_iter, + num_generation_tokens_iter=num_generation_tokens_iter, + num_tokens_iter=num_tokens_iter, + time_to_first_tokens_iter=time_to_first_tokens_iter, + time_per_output_tokens_iter=time_per_output_tokens_iter, + spec_decode_metrics=spec_decode_metrics, + num_preemption_iter=num_preemption_iter, + + # Request stats + # Latency + time_e2e_requests=time_e2e_requests, + time_queue_requests=time_queue_requests, + time_inference_requests=time_inference_requests, + time_prefill_requests=time_prefill_requests, + time_decode_requests=time_decode_requests, + # Metadata + num_prompt_tokens_requests=num_prompt_tokens_requests, + num_generation_tokens_requests=num_generation_tokens_requests, + max_num_generation_tokens_requests= + max_num_generation_tokens_requests, + n_requests=n_requests, + max_tokens_requests=max_tokens_requests, + finished_reason_requests=finished_reason_requests, + max_lora=str(max_lora_stat), + waiting_lora_adapters=list(waiting_lora_adapters.keys()), + running_lora_adapters=list(running_lora_adapters.keys())) + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.model_executor.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.model_executor.remove_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.model_executor.list_loras() + + def pin_lora(self, lora_id: int) -> bool: + return self.model_executor.pin_lora(lora_id) + + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + return self.model_executor.add_prompt_adapter(prompt_adapter_request) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + return self.model_executor.remove_prompt_adapter(prompt_adapter_id) + + def list_prompt_adapters(self) -> List[int]: + return self.model_executor.list_prompt_adapters() + + def start_profile(self) -> None: + self.model_executor.start_profile() + + def stop_profile(self) -> None: + self.model_executor.stop_profile() + + def sleep(self, level: int = 1) -> None: + assert self.vllm_config.model_config.enable_sleep_mode, ( + "Sleep mode is not enabled in the model config") + self.model_executor.sleep(level=level) + + def wake_up(self, tags: 
Optional[list[str]] = None) -> None: + assert self.vllm_config.model_config.enable_sleep_mode, ( + "Sleep mode is not enabled in the model config") + self.model_executor.wake_up(tags) + + def is_sleeping(self) -> bool: + return self.model_executor.is_sleeping + + def check_health(self) -> None: + self.model_executor.check_health() + + def is_tracing_enabled(self) -> bool: + return self.tracer is not None + + def do_tracing(self, + scheduler_outputs: SchedulerOutputs, + finished_before: Optional[List[int]] = None) -> None: + if self.tracer is None: + return + + for idx, scheduled_seq_group in enumerate( + scheduler_outputs.scheduled_seq_groups): + # Skip double tracing when using async output proc + if finished_before and idx in finished_before: + continue + + seq_group = scheduled_seq_group.seq_group + if seq_group.is_finished(): + self.create_trace_span(seq_group) + + def create_trace_span(self, seq_group: SequenceGroup) -> None: + if self.tracer is None or seq_group.sampling_params is None: + return + arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9) + + trace_context = extract_trace_context(seq_group.trace_headers) + + with self.tracer.start_as_current_span( + "llm_request", + kind=SpanKind.SERVER, + context=trace_context, + start_time=arrival_time_nano_seconds) as seq_span: + metrics = seq_group.metrics + ttft = metrics.first_token_time - metrics.arrival_time + e2e_time = metrics.finished_time - metrics.arrival_time + seq_span.set_attribute(SpanAttributes.GEN_AI_RESPONSE_MODEL, + self.model_config.model) + seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID, + seq_group.request_id) + seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, + seq_group.sampling_params.temperature) + seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, + seq_group.sampling_params.top_p) + seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, + seq_group.sampling_params.max_tokens) + seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N, + seq_group.sampling_params.n) + seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_NUM_SEQUENCES, + seq_group.num_seqs()) + seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, + len(seq_group.prompt_token_ids)) + seq_span.set_attribute( + SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS, + sum([ + seq.get_output_len() + for seq in seq_group.get_finished_seqs() + ])) + seq_span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE, + metrics.time_in_queue) + seq_span.set_attribute( + SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN, ttft) + seq_span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time) + if metrics.scheduler_time is not None: + seq_span.set_attribute( + SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER, + metrics.scheduler_time) + if metrics.model_forward_time is not None: + seq_span.set_attribute( + SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD, + metrics.model_forward_time / 1000.0) + if metrics.model_execute_time is not None: + seq_span.set_attribute( + SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE, + metrics.model_execute_time) + + def _validate_model_inputs(self, inputs: ProcessorInputs, + lora_request: Optional[LoRARequest]): + encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs) + + if encoder_inputs is not None: + self._validate_model_input(encoder_inputs, + lora_request, + prompt_type="encoder") + + self._validate_model_input(decoder_inputs, + lora_request, + prompt_type="decoder") + + def _validate_model_input( + self, + prompt_inputs: SingletonInputs, + 
lora_request: Optional[LoRARequest], + *, + prompt_type: Literal["encoder", "decoder"], + ): + model_config = self.model_config + tokenizer = (None if self.tokenizer is None else + self.tokenizer.get_lora_tokenizer(lora_request)) + + prompt_ids = prompt_inputs.get("prompt_token_ids", []) + if not prompt_ids: + if prompt_type == "encoder" and model_config.is_multimodal_model: + pass # Mllama may have empty encoder inputs for text-only data + elif prompt_inputs["type"] == "embeds": + pass + else: + raise ValueError(f"The {prompt_type} prompt cannot be empty") + + if tokenizer is not None: + max_input_id = max(prompt_ids, default=0) + if max_input_id > tokenizer.max_token_id: + raise ValueError( + f"Token id {max_input_id} is out of vocabulary") + + max_prompt_len = self.model_config.max_model_len + if len(prompt_ids) > max_prompt_len: + if prompt_type == "encoder" and model_config.is_multimodal_model: + mm_registry = self.input_preprocessor.mm_registry + mm_processor = mm_registry.create_processor( + model_config, + tokenizer=tokenizer or object(), # Dummy if no tokenizer + ) + assert isinstance(mm_processor, EncDecMultiModalProcessor) + + if mm_processor.pad_dummy_encoder_prompt: + return # Skip encoder length check for Whisper + + if model_config.is_multimodal_model: + suggestion = ( + "Make sure that `max_model_len` is no smaller than the " + "number of text tokens plus multimodal tokens. For image " + "inputs, the number of image tokens depends on the number " + "of images, and possibly their aspect ratios as well.") + else: + suggestion = ( + "Make sure that `max_model_len` is no smaller than the " + "number of text tokens.") + + raise ValueError( + f"The {prompt_type} prompt (length {len(prompt_ids)}) is " + f"longer than the maximum model length of {max_prompt_len}. " + f"{suggestion}") + + # TODO: Find out how many placeholder tokens are there so we can + # check that chunked prefill does not truncate them + # max_batch_len = self.scheduler_config.max_num_batched_tokens + + def _build_logits_processors( + self, sampling_params: SamplingParams, + lora_request: Optional[LoRARequest]) -> SamplingParams: + """Constructs logits processors based on the guided_decoding, + logits_bias, and allowed_token_ids fields in sampling_params. Deletes + those fields and adds the constructed logits processors to the + logits_processors field. Returns the modified sampling params.""" + + logits_processors = [] + + if sampling_params.guided_decoding is not None: + # Defensively copy sampling params since guided decoding logits + # processors can have different state for each request + sampling_params = copy.copy(sampling_params) + guided_decoding = sampling_params.guided_decoding + + logger.debug( + "Building guided decoding logits processor in " + "LLMEngine. 
Params: %s", guided_decoding) + + tokenizer = self.get_tokenizer(lora_request=lora_request) + guided_decoding.backend = guided_decoding.backend or \ + self.decoding_config.backend + + if self.decoding_config.reasoning_backend: + logger.debug("Building with reasoning backend %s", + self.decoding_config.reasoning_backend) + + processor = get_local_guided_decoding_logits_processor( + guided_params=guided_decoding, + tokenizer=tokenizer, + model_config=self.model_config, + reasoning_backend=self.decoding_config.reasoning_backend, + ) + if processor: + logits_processors.append(processor) + + # Unset so this doesn't get passed down to the model + sampling_params.guided_decoding = None + + if (sampling_params.logit_bias or sampling_params.allowed_token_ids): + tokenizer = self.get_tokenizer(lora_request=lora_request) + + processors = get_openai_logits_processors( + logit_bias=sampling_params.logit_bias, + allowed_token_ids=sampling_params.allowed_token_ids, + tokenizer=tokenizer) + logits_processors.extend(processors) + + # Unset so these don't get passed down to the model + sampling_params.logit_bias = None + sampling_params.allowed_token_ids = None + + if len(sampling_params.bad_words) > 0: + tokenizer = self.get_tokenizer(lora_request) + processors = get_bad_words_logits_processors( + bad_words=sampling_params.bad_words, tokenizer=tokenizer) + logits_processors.extend(processors) + + if logits_processors: + if sampling_params.logits_processors is None: + sampling_params.logits_processors = logits_processors + else: + sampling_params.logits_processors.extend(logits_processors) + + return sampling_params + + def collective_rpc(self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + return self.model_executor.collective_rpc(method, timeout, args, + kwargs) + + +if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1: + from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine + LLMEngine = V1LLMEngine # type: ignore diff --git a/engine/metrics.py b/engine/metrics.py new file mode 100644 index 0000000..8d51f04 --- /dev/null +++ b/engine/metrics.py @@ -0,0 +1,629 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import time +from typing import TYPE_CHECKING +from typing import Counter as CollectionsCounter +from typing import Dict, List, Optional, Type, Union, cast + +import numpy as np +import prometheus_client + +from vllm.config import SupportsMetricsInfo, VllmConfig +from vllm.engine.metrics_types import StatLoggerBase, Stats +from vllm.executor.ray_utils import ray +from vllm.logger import init_logger + +if ray is not None: + from ray.util import metrics as ray_metrics +else: + ray_metrics = None + +if TYPE_CHECKING: + from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics + +logger = init_logger(__name__) + +prometheus_client.disable_created_metrics() + +# The begin-* and end* here are used by the documentation generator +# to extract the metrics definitions. + + +# --8<-- [start:metrics-definitions] +class Metrics: + """ + vLLM uses a multiprocessing-based frontend for the OpenAI server. + This means that we need to run prometheus_client in multiprocessing mode + See https://prometheus.github.io/client_python/multiprocess/ for more + details on limitations. 
+ """ + + labelname_finish_reason = "finished_reason" + labelname_waiting_lora_adapters = "waiting_lora_adapters" + labelname_running_lora_adapters = "running_lora_adapters" + labelname_max_lora = "max_lora" + _gauge_cls = prometheus_client.Gauge + _counter_cls = prometheus_client.Counter + _histogram_cls = prometheus_client.Histogram + + def __init__(self, labelnames: List[str], vllm_config: VllmConfig): + # Unregister any existing vLLM collectors (for CI/CD) + self._unregister_vllm_metrics() + + max_model_len = vllm_config.model_config.max_model_len + + # Use this flag to hide metrics that were deprecated in + # a previous release and which will be removed in a future release + self.show_hidden_metrics = \ + vllm_config.observability_config.show_hidden_metrics + + # System stats + # Scheduler State + self.gauge_scheduler_running = self._gauge_cls( + name="vllm:num_requests_running", + documentation="Number of requests currently running on GPU.", + labelnames=labelnames, + multiprocess_mode="sum") + self.gauge_scheduler_waiting = self._gauge_cls( + name="vllm:num_requests_waiting", + documentation="Number of requests waiting to be processed.", + labelnames=labelnames, + multiprocess_mode="sum") + self.gauge_lora_info = self._gauge_cls( + name="vllm:lora_requests_info", + documentation="Running stats on lora requests.", + labelnames=[ + self.labelname_running_lora_adapters, + self.labelname_max_lora, + self.labelname_waiting_lora_adapters, + ], + multiprocess_mode="livemostrecent", + ) + + # KV Cache Usage in % + self.gauge_gpu_cache_usage = self._gauge_cls( + name="vllm:gpu_cache_usage_perc", + documentation="GPU KV-cache usage. 1 means 100 percent usage.", + labelnames=labelnames, + multiprocess_mode="sum") + + # Iteration stats + self.counter_num_preemption = self._counter_cls( + name="vllm:num_preemptions_total", + documentation="Cumulative number of preemptions from the engine.", + labelnames=labelnames) + self.counter_prompt_tokens = self._counter_cls( + name="vllm:prompt_tokens_total", + documentation="Number of prefill tokens processed.", + labelnames=labelnames) + self.counter_generation_tokens = self._counter_cls( + name="vllm:generation_tokens_total", + documentation="Number of generation tokens processed.", + labelnames=labelnames) + self.histogram_iteration_tokens = self._histogram_cls( + name="vllm:iteration_tokens_total", + documentation="Histogram of number of tokens per engine_step.", + labelnames=labelnames, + buckets=[ + 1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384 + ]) + self.histogram_time_to_first_token = self._histogram_cls( + name="vllm:time_to_first_token_seconds", + documentation="Histogram of time to first token in seconds.", + labelnames=labelnames, + buckets=[ + 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5, + 0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0, 640.0, + 2560.0 + ]) + self.histogram_time_per_output_token = self._histogram_cls( + name="vllm:time_per_output_token_seconds", + documentation="Histogram of time per output token in seconds.", + labelnames=labelnames, + buckets=[ + 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, + 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0 + ]) + + # Request stats + # Latency + request_latency_buckets = [ + 0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, + 40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0 + ] + self.histogram_e2e_time_request = self._histogram_cls( + name="vllm:e2e_request_latency_seconds", + documentation="Histogram of end to end request 
latency in seconds.", + labelnames=labelnames, + buckets=request_latency_buckets) + self.histogram_queue_time_request = self._histogram_cls( + name="vllm:request_queue_time_seconds", + documentation= + "Histogram of time spent in WAITING phase for request.", + labelnames=labelnames, + buckets=request_latency_buckets) + self.histogram_inference_time_request = self._histogram_cls( + name="vllm:request_inference_time_seconds", + documentation= + "Histogram of time spent in RUNNING phase for request.", + labelnames=labelnames, + buckets=request_latency_buckets) + self.histogram_prefill_time_request = self._histogram_cls( + name="vllm:request_prefill_time_seconds", + documentation= + "Histogram of time spent in PREFILL phase for request.", + labelnames=labelnames, + buckets=request_latency_buckets) + self.histogram_decode_time_request = self._histogram_cls( + name="vllm:request_decode_time_seconds", + documentation= + "Histogram of time spent in DECODE phase for request.", + labelnames=labelnames, + buckets=request_latency_buckets) + + # Metadata + self.histogram_num_prompt_tokens_request = self._histogram_cls( + name="vllm:request_prompt_tokens", + documentation="Number of prefill tokens processed.", + labelnames=labelnames, + buckets=build_1_2_5_buckets(max_model_len), + ) + self.histogram_num_generation_tokens_request = \ + self._histogram_cls( + name="vllm:request_generation_tokens", + documentation="Number of generation tokens processed.", + labelnames=labelnames, + buckets=build_1_2_5_buckets(max_model_len), + ) + self.histogram_max_num_generation_tokens_request = self._histogram_cls( + name="vllm:request_max_num_generation_tokens", + documentation= + "Histogram of maximum number of requested generation tokens.", + labelnames=labelnames, + buckets=build_1_2_5_buckets(max_model_len)) + self.histogram_n_request = self._histogram_cls( + name="vllm:request_params_n", + documentation="Histogram of the n request parameter.", + labelnames=labelnames, + buckets=[1, 2, 5, 10, 20], + ) + self.histogram_max_tokens_request = self._histogram_cls( + name="vllm:request_params_max_tokens", + documentation="Histogram of the max_tokens request parameter.", + labelnames=labelnames, + buckets=build_1_2_5_buckets(max_model_len), + ) + self.counter_request_success = self._counter_cls( + name="vllm:request_success_total", + documentation="Count of successfully processed requests.", + labelnames=labelnames + [Metrics.labelname_finish_reason]) + + # Speculative decoding stats + self.gauge_spec_decode_draft_acceptance_rate = self._gauge_cls( + name="vllm:spec_decode_draft_acceptance_rate", + documentation="Speculative token acceptance rate.", + labelnames=labelnames, + multiprocess_mode="sum") + self.gauge_spec_decode_efficiency = self._gauge_cls( + name="vllm:spec_decode_efficiency", + documentation="Speculative decoding system efficiency.", + labelnames=labelnames, + multiprocess_mode="sum") + self.counter_spec_decode_num_accepted_tokens = (self._counter_cls( + name="vllm:spec_decode_num_accepted_tokens_total", + documentation="Number of accepted tokens.", + labelnames=labelnames)) + self.counter_spec_decode_num_draft_tokens = self._counter_cls( + name="vllm:spec_decode_num_draft_tokens_total", + documentation="Number of draft tokens.", + labelnames=labelnames) + self.counter_spec_decode_num_emitted_tokens = (self._counter_cls( + name="vllm:spec_decode_num_emitted_tokens_total", + documentation="Number of emitted tokens.", + labelnames=labelnames)) + + +# --8<-- [end:metrics-definitions] + + def 
_unregister_vllm_metrics(self) -> None: + for collector in list(prometheus_client.REGISTRY._collector_to_names): + if hasattr(collector, "_name") and "vllm" in collector._name: + prometheus_client.REGISTRY.unregister(collector) + + +class _RayGaugeWrapper: + """Wraps around ray.util.metrics.Gauge to provide the same API as + prometheus_client.Gauge""" + + def __init__(self, + name: str, + documentation: str = "", + labelnames: Optional[List[str]] = None, + multiprocess_mode: str = ""): + del multiprocess_mode + labelnames_tuple = tuple(labelnames) if labelnames else None + self._gauge = ray_metrics.Gauge(name=name, + description=documentation, + tag_keys=labelnames_tuple) + + def labels(self, **labels): + self._gauge.set_default_tags(labels) + return self + + def set(self, value: Union[int, float]): + return self._gauge.set(value) + + def set_to_current_time(self): + # ray metrics doesn't have set_to_current_time, see https://docs.ray.io/en/latest/_modules/ray/util/metrics.html + return self._gauge.set(time.time()) + + +class _RayCounterWrapper: + """Wraps around ray.util.metrics.Counter to provide the same API as + prometheus_client.Counter""" + + def __init__(self, + name: str, + documentation: str = "", + labelnames: Optional[List[str]] = None): + labelnames_tuple = tuple(labelnames) if labelnames else None + self._counter = ray_metrics.Counter(name=name, + description=documentation, + tag_keys=labelnames_tuple) + + def labels(self, **labels): + self._counter.set_default_tags(labels) + return self + + def inc(self, value: Union[int, float] = 1.0): + if value == 0: + return + return self._counter.inc(value) + + +class _RayHistogramWrapper: + """Wraps around ray.util.metrics.Histogram to provide the same API as + prometheus_client.Histogram""" + + def __init__(self, + name: str, + documentation: str = "", + labelnames: Optional[List[str]] = None, + buckets: Optional[List[float]] = None): + labelnames_tuple = tuple(labelnames) if labelnames else None + boundaries = buckets if buckets else [] + self._histogram = ray_metrics.Histogram(name=name, + description=documentation, + tag_keys=labelnames_tuple, + boundaries=boundaries) + + def labels(self, **labels): + self._histogram.set_default_tags(labels) + return self + + def observe(self, value: Union[int, float]): + return self._histogram.observe(value) + + +class RayMetrics(Metrics): + """ + RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics. + Provides the same metrics as Metrics but uses Ray's util.metrics library. + """ + _gauge_cls: Type[prometheus_client.Gauge] = cast( + Type[prometheus_client.Gauge], _RayGaugeWrapper) + _counter_cls: Type[prometheus_client.Counter] = cast( + Type[prometheus_client.Counter], _RayCounterWrapper) + _histogram_cls: Type[prometheus_client.Histogram] = cast( + Type[prometheus_client.Histogram], _RayHistogramWrapper) + + def __init__(self, labelnames: List[str], vllm_config: VllmConfig): + if ray_metrics is None: + raise ImportError("RayMetrics requires Ray to be installed.") + super().__init__(labelnames, vllm_config) + + def _unregister_vllm_metrics(self) -> None: + # No-op on purpose + pass + + +def build_buckets(mantissa_lst: List[int], max_value: int) -> List[int]: + """ + Builds a list of buckets with increasing powers of 10 multiplied by + mantissa values until the value exceeds the specified maximum. 
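+ + Example: + >>> build_buckets([1, 5], 300) + [1, 5, 10, 50, 100]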
+ + """ + exponent = 0 + buckets: List[int] = [] + while True: + for m in mantissa_lst: + value = m * 10**exponent + if value <= max_value: + buckets.append(value) + else: + return buckets + exponent += 1 + + +def build_1_2_5_buckets(max_value: int) -> List[int]: + """ + Example: + >>> build_1_2_5_buckets(100) + [1, 2, 5, 10, 20, 50, 100] + """ + return build_buckets([1, 2, 5], max_value) + + +def build_1_2_3_5_8_buckets(max_value: int) -> List[int]: + """ + Example: + >>> build_1_2_3_5_8_buckets(100) + [1, 2, 3, 5, 8, 10, 20, 30, 50, 80, 100] + """ + return build_buckets([1, 2, 3, 5, 8], max_value) + + +def local_interval_elapsed(now: float, last_log: float, + local_interval: float) -> bool: + elapsed_time = now - last_log + return elapsed_time > local_interval + + +def get_throughput(tracked_stats: List[int], now: float, + last_log: float) -> float: + return float(np.sum(tracked_stats) / (now - last_log)) + + +class LoggingStatLogger(StatLoggerBase): + """LoggingStatLogger is used in LLMEngine to log to stdout.""" + + def __init__(self, local_interval: float, vllm_config: VllmConfig) -> None: + super().__init__(local_interval, vllm_config) + self.last_prompt_throughput: Optional[float] = None + self.last_generation_throughput: Optional[float] = None + + def log(self, stats: Stats) -> None: + """Called by LLMEngine. + Logs to stdout every self.local_interval seconds.""" + + # Save tracked stats for token counters. + self.num_prompt_tokens.append(stats.num_prompt_tokens_iter) + self.num_generation_tokens.append(stats.num_generation_tokens_iter) + + # Update spec decode metrics + self.maybe_update_spec_decode_metrics(stats) + + # Log locally every local_interval seconds. + if local_interval_elapsed(stats.now, self.last_local_log, + self.local_interval): + # Compute summary metrics for tracked stats (and log them + # to prometheus if applicable). + prompt_throughput = get_throughput(self.num_prompt_tokens, + now=stats.now, + last_log=self.last_local_log) + generation_throughput = get_throughput( + self.num_generation_tokens, + now=stats.now, + last_log=self.last_local_log) + + log_fn = logger.info + if not any((prompt_throughput, generation_throughput, + self.last_prompt_throughput, + self.last_generation_throughput)): + # Avoid log noise on an idle production system + log_fn = logger.debug + + log_fn( + "Avg prompt throughput: %.1f tokens/s, " + "Avg generation throughput: %.1f tokens/s, " + "Running: %d reqs, Swapped: %d reqs, " + "Pending: %d reqs, GPU KV cache usage: %.1f%%, " + "CPU KV cache usage: %.1f%%.", + prompt_throughput, + generation_throughput, + stats.num_running_sys, + stats.num_swapped_sys, + stats.num_waiting_sys, + stats.gpu_cache_usage_sys * 100, + stats.cpu_cache_usage_sys * 100, + ) + if (stats.cpu_prefix_cache_hit_rate >= 0 + or stats.gpu_prefix_cache_hit_rate >= 0): + log_fn( + "Prefix cache hit rate: GPU: %.2f%%, CPU: %.2f%%", + stats.gpu_prefix_cache_hit_rate * 100, + stats.cpu_prefix_cache_hit_rate * 100, + ) + if self.spec_decode_metrics is not None: + log_fn( + self._format_spec_decode_metrics_str( + self.spec_decode_metrics)) + + self._reset(stats, prompt_throughput, generation_throughput) + + def _reset(self, stats, prompt_throughput, generation_throughput) -> None: + # Reset tracked stats for next interval. 
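+ # (The previous throughputs are kept so that the next interval can + # recognize an idle system and downgrade the log level to debug.)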
+ self.num_prompt_tokens = [] + self.num_generation_tokens = [] + self.last_local_log = stats.now + self.spec_decode_metrics = None + self.last_prompt_throughput = prompt_throughput + self.last_generation_throughput = generation_throughput + + def _format_spec_decode_metrics_str( + self, metrics: "SpecDecodeWorkerMetrics") -> str: + + return ("Speculative metrics: " + f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, " + f"System efficiency: {metrics.system_efficiency:.3f}, " + f"Number of speculative tokens: {metrics.num_spec_tokens}, " + f"Number of accepted tokens: {metrics.accepted_tokens}, " + f"Number of draft tokens: {metrics.draft_tokens}, " + f"Number of emitted tokens: {metrics.emitted_tokens}.") + + def info(self, type: str, obj: SupportsMetricsInfo) -> None: + raise NotImplementedError + + +class PrometheusStatLogger(StatLoggerBase): + """PrometheusStatLogger is used by LLMEngine to log to Prometheus.""" + _metrics_cls = Metrics + _gauge_cls = prometheus_client.Gauge + + def __init__(self, local_interval: float, labels: Dict[str, str], + vllm_config: VllmConfig) -> None: + super().__init__(local_interval, vllm_config) + # Prometheus metrics + self.labels = labels + self.metrics = self._metrics_cls(labelnames=list(labels.keys()), + vllm_config=vllm_config) + + def _log_gauge(self, gauge, data: Union[int, float]) -> None: + # Convenience function for logging to gauge. + gauge.labels(**self.labels).set(data) + + def _log_counter(self, counter, data: Union[int, float]) -> None: + # Convenience function for logging to counter. + # Prevent ValueError from negative increment + if data < 0: + logger.warning("Skipping negative increment of %g to %s", data, + counter) + return + counter.labels(**self.labels).inc(data) + + def _log_counter_labels(self, counter, data: CollectionsCounter, + label_key: str) -> None: + # Convenience function for logging a counter of labels. + for label, count in data.items(): + counter.labels(**{**self.labels, label_key: label}).inc(count) + + def _log_histogram(self, histogram, data: Union[List[int], + List[float]]) -> None: + # Convenience function for logging list to histogram. + for datum in data: + histogram.labels(**self.labels).observe(datum) + + def _log_gauge_string(self, gauge, data: Dict[str, str]) -> None: + gauge.labels(**data).set_to_current_time() + + def _log_prometheus(self, stats: Stats) -> None: + # System state data + self._log_gauge(self.metrics.gauge_scheduler_running, + stats.num_running_sys) + self._log_gauge(self.metrics.gauge_scheduler_waiting, + stats.num_waiting_sys) + self._log_gauge(self.metrics.gauge_gpu_cache_usage, + stats.gpu_cache_usage_sys) + # Including max-lora in metric, in the future this property of the + # lora config may be extended to be dynamic. 
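+ # For example, with hypothetical label values, the gauge is emitted as: + # vllm:lora_requests_info{max_lora="4", running_lora_adapters="a,b", + # waiting_lora_adapters=""} <current time>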
+ lora_info = { + self.metrics.labelname_running_lora_adapters: + ",".join(stats.running_lora_adapters), + self.metrics.labelname_waiting_lora_adapters: + ",".join(stats.waiting_lora_adapters), + self.metrics.labelname_max_lora: + stats.max_lora, + } + self._log_gauge_string(self.metrics.gauge_lora_info, lora_info) + # Iteration level data + self._log_counter(self.metrics.counter_num_preemption, + stats.num_preemption_iter) + self._log_counter(self.metrics.counter_prompt_tokens, + stats.num_prompt_tokens_iter) + self._log_counter(self.metrics.counter_generation_tokens, + stats.num_generation_tokens_iter) + self._log_histogram(self.metrics.histogram_iteration_tokens, + [stats.num_tokens_iter]) + self._log_histogram(self.metrics.histogram_time_to_first_token, + stats.time_to_first_tokens_iter) + self._log_histogram(self.metrics.histogram_time_per_output_token, + stats.time_per_output_tokens_iter) + + # Request level data + # Latency + self._log_histogram(self.metrics.histogram_e2e_time_request, + stats.time_e2e_requests) + self._log_histogram(self.metrics.histogram_queue_time_request, + stats.time_queue_requests) + self._log_histogram(self.metrics.histogram_inference_time_request, + stats.time_inference_requests) + self._log_histogram(self.metrics.histogram_prefill_time_request, + stats.time_prefill_requests) + self._log_histogram(self.metrics.histogram_decode_time_request, + stats.time_decode_requests) + # Metadata + finished_reason_counter = CollectionsCounter( + stats.finished_reason_requests) + self._log_counter_labels(self.metrics.counter_request_success, + finished_reason_counter, + Metrics.labelname_finish_reason) + self._log_histogram(self.metrics.histogram_num_prompt_tokens_request, + stats.num_prompt_tokens_requests) + self._log_histogram( + self.metrics.histogram_num_generation_tokens_request, + stats.num_generation_tokens_requests) + self._log_histogram(self.metrics.histogram_n_request, stats.n_requests) + self._log_histogram( + self.metrics.histogram_max_num_generation_tokens_request, + stats.max_num_generation_tokens_requests) + self._log_histogram(self.metrics.histogram_max_tokens_request, + stats.max_tokens_requests) + + def log(self, stats: Stats): + """Logs to prometheus and tracked stats every iteration.""" + # Log to prometheus. + self._log_prometheus(stats) + + # Save tracked stats for token counters. + self.num_prompt_tokens.append(stats.num_prompt_tokens_iter) + self.num_generation_tokens.append(stats.num_generation_tokens_iter) + + # Update spec decode metrics + self.maybe_update_spec_decode_metrics(stats) + + # Log locally every local_interval seconds. + if local_interval_elapsed(stats.now, self.last_local_log, + self.local_interval): + if self.spec_decode_metrics is not None: + self._log_gauge( + self.metrics.gauge_spec_decode_draft_acceptance_rate, + self.spec_decode_metrics.draft_acceptance_rate) + self._log_gauge(self.metrics.gauge_spec_decode_efficiency, + self.spec_decode_metrics.system_efficiency) + self._log_counter( + self.metrics.counter_spec_decode_num_accepted_tokens, + self.spec_decode_metrics.accepted_tokens) + self._log_counter( + self.metrics.counter_spec_decode_num_draft_tokens, + self.spec_decode_metrics.draft_tokens) + self._log_counter( + self.metrics.counter_spec_decode_num_emitted_tokens, + self.spec_decode_metrics.emitted_tokens) + + # Reset tracked stats for next interval. 
+ self.num_prompt_tokens = [] + self.num_generation_tokens = [] + self.last_local_log = stats.now + self.spec_decode_metrics = None + + def info(self, type: str, obj: SupportsMetricsInfo) -> None: + # Info type metrics are syntactic sugar for a gauge permanently set to 1 + # Since prometheus multiprocessing mode does not support Info, emulate + # info here with a gauge. + if type == "cache_config": + metrics_info = obj.metrics_info() + info_gauge = self._gauge_cls( + name="vllm:cache_config_info", + documentation="Information of the LLMEngine CacheConfig", + labelnames=metrics_info.keys(), + multiprocess_mode="mostrecent") + info_gauge.labels(**metrics_info).set(1) + + +class RayPrometheusStatLogger(PrometheusStatLogger): + """RayPrometheusStatLogger uses Ray metrics instead.""" + _metrics_cls = RayMetrics + + def info(self, type: str, obj: SupportsMetricsInfo) -> None: + return None diff --git a/engine/metrics_types.py b/engine/metrics_types.py new file mode 100644 index 0000000..9375dc4 --- /dev/null +++ b/engine/metrics_types.py @@ -0,0 +1,94 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +These types are defined in this file to avoid importing vllm.engine.metrics +and therefore importing prometheus_client. + +This is required due to usage of Prometheus multiprocess mode to enable +metrics after splitting out the uvicorn process from the engine process. + +Prometheus multiprocess mode requires setting PROMETHEUS_MULTIPROC_DIR +before prometheus_client is imported. Typically, this is done by setting +the env variable before launch, but since we are a library, we need to +do this in Python code and lazily import prometheus_client. +""" + +import time +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import List, Optional + +from vllm.config import SupportsMetricsInfo, VllmConfig +from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics + + +@dataclass +class Stats: + """Created by LLMEngine for use by StatLogger.""" + now: float + + # System stats (should have _sys suffix) + # Scheduler State + num_running_sys: int + num_waiting_sys: int + num_swapped_sys: int + # KV Cache Usage in % + gpu_cache_usage_sys: float + cpu_cache_usage_sys: float + # Prefix caching block hit rate + cpu_prefix_cache_hit_rate: float + gpu_prefix_cache_hit_rate: float + + # Iteration stats (should have _iter suffix) + num_prompt_tokens_iter: int + num_generation_tokens_iter: int + num_tokens_iter: int + time_to_first_tokens_iter: List[float] + time_per_output_tokens_iter: List[float] + num_preemption_iter: int + + # Request stats (should have _requests suffix) + # Latency + time_e2e_requests: List[float] + time_queue_requests: List[float] + time_inference_requests: List[float] + time_prefill_requests: List[float] + time_decode_requests: List[float] + # Metadata + num_prompt_tokens_requests: List[int] + num_generation_tokens_requests: List[int] + n_requests: List[int] + max_num_generation_tokens_requests: List[int] + max_tokens_requests: List[int] + finished_reason_requests: List[str] + waiting_lora_adapters: List[str] + running_lora_adapters: List[str] + max_lora: str + + spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None + + +class StatLoggerBase(ABC): + """Base class for StatLogger.""" + + def __init__(self, local_interval: float, vllm_config: VllmConfig) -> None: + # Tracked stats over current local logging interval. 
+ self.num_prompt_tokens: List[int] = [] + self.num_generation_tokens: List[int] = [] + self.last_local_log = time.time() + self.local_interval = local_interval + self.spec_decode_metrics: Optional[SpecDecodeWorkerMetrics] = None + + @abstractmethod + def log(self, stats: Stats) -> None: + raise NotImplementedError + + @abstractmethod + def info(self, type: str, obj: SupportsMetricsInfo) -> None: + raise NotImplementedError + + def maybe_update_spec_decode_metrics(self, stats: Stats): + """Save spec decode metrics (since they are unlikely + to be emitted at same time as log interval).""" + if stats.spec_decode_metrics is not None: + self.spec_decode_metrics = stats.spec_decode_metrics diff --git a/engine/multiprocessing/__init__.py b/engine/multiprocessing/__init__.py new file mode 100644 index 0000000..db968cd --- /dev/null +++ b/engine/multiprocessing/__init__.py @@ -0,0 +1,148 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import uuid +from dataclasses import dataclass, field +from enum import Enum +from typing import List, Mapping, Optional, Union + +from vllm import PoolingParams +from vllm.inputs import PromptType +from vllm.lora.request import LoRARequest +from vllm.outputs import RequestOutput +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.utils import Device + +VLLM_RPC_SUCCESS_STR = "SUCCESS" + +IPC_INPUT_EXT = "_input_socket" +IPC_OUTPUT_EXT = "_output_socket" +IPC_HEALTH_EXT = "_health_socket" +IPC_DATA_EXT = "_data_socket" + + +class MQEngineDeadError(RuntimeError): + pass + + +@dataclass +class RPCProcessRequest: + prompt: PromptType + params: Union[SamplingParams, PoolingParams] + request_id: str + lora_request: Optional[LoRARequest] = None + trace_headers: Optional[Mapping[str, str]] = None + prompt_adapter_request: Optional[PromptAdapterRequest] = None + priority: int = 0 + + def __init__( + self, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> None: + super().__init__() + + self.prompt = prompt + self.params = params + self.request_id = request_id + self.lora_request = lora_request + self.trace_headers = trace_headers + self.prompt_adapter_request = prompt_adapter_request + self.priority = priority + + +@dataclass +class RPCError: + request_id: Optional[str] + is_engine_errored: bool + exception: BaseException + + +@dataclass +class RPCAbortRequest: + request_id: str + + +class RPCStartupRequest(Enum): + IS_SERVER_READY = 1 + + +@dataclass +class RPCStartupResponse: + tracing_enabled: bool + + +class RPCUProfileRequest(Enum): + START_PROFILE = 1 + STOP_PROFILE = 2 + + +class RPCResetMultiModalCacheRequest(Enum): + RESET = 1 + + +@dataclass +class RPCResetPrefixCacheRequest: + device: Device + + +class RPCSleepRequest(Enum): + SLEEP_LEVEL_1 = 1 + SLEEP_LEVEL_2 = 2 + + +@dataclass +class RPCWakeUpRequest: + tags: Optional[list[str]] = None + + +@dataclass +class RPCIsSleepingRequest: + # Set the default value of request_id to a new UUID + request_id: str = field(default_factory=lambda: str(uuid.uuid4())) + + +@dataclass +class RPCIsSleepingResponse: + request_id: str + is_sleeping: bool + + +@dataclass +class RPCLoadAdapterRequest: + lora_request: LoRARequest + # Set the default value of request_id to a 
new UUID + request_id: str = field(default_factory=lambda: str(uuid.uuid4())) + + +@dataclass +class RPCAdapterLoadedResponse: + request_id: str + + +RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest, + RPCUProfileRequest, RPCLoadAdapterRequest, + RPCResetMultiModalCacheRequest, + RPCResetPrefixCacheRequest, RPCSleepRequest, + RPCWakeUpRequest, RPCIsSleepingRequest] + +REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCAdapterLoadedResponse, + RPCIsSleepingResponse, RPCError] + + +def ENGINE_DEAD_ERROR( + error: Optional[BaseException] = None) -> MQEngineDeadError: + if error is None: + return MQEngineDeadError( + "Engine loop is not running. Inspect the stacktrace to " + "find the original error") + + return MQEngineDeadError( + "Engine loop is not running. Inspect the stacktrace to " + f"find the original error: {repr(error)}.") diff --git a/engine/multiprocessing/client.py b/engine/multiprocessing/client.py new file mode 100644 index 0000000..9e018ec --- /dev/null +++ b/engine/multiprocessing/client.py @@ -0,0 +1,681 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import copy +import pickle +from contextlib import contextmanager, suppress +from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping, + Optional, Union, cast) + +import cloudpickle +import psutil +import zmq +import zmq.asyncio +from zmq import Frame # type: ignore[attr-defined] +from zmq.asyncio import Socket + +from vllm import PoolingParams +from vllm.config import DecodingConfig, ModelConfig, VllmConfig +from vllm.core.scheduler import SchedulerOutputs +# yapf conflicts with isort for this block +# yapf: disable +from vllm.engine.async_llm_engine import ( + build_guided_decoding_logits_processor_async) +from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, + IPC_HEALTH_EXT, IPC_INPUT_EXT, + IPC_OUTPUT_EXT, RPC_REQUEST_T, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCAdapterLoadedResponse, RPCError, + RPCIsSleepingRequest, + RPCIsSleepingResponse, + RPCLoadAdapterRequest, + RPCProcessRequest, + RPCResetMultiModalCacheRequest, + RPCResetPrefixCacheRequest, + RPCSleepRequest, RPCStartupRequest, + RPCStartupResponse, + RPCUProfileRequest, RPCWakeUpRequest) +from vllm.engine.protocol import EngineClient +# yapf: enable +from vllm.envs import VLLM_RPC_TIMEOUT +from vllm.inputs import PromptType +from vllm.inputs.preprocess import InputPreprocessor +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.outputs import PoolingRequestOutput, RequestOutput +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +from vllm.utils import Device + +logger = init_logger(__name__) + + +class MQClientClosedError(Exception): + """Exception class raised when the client is used post-close. + + The client can be closed, which closes the ZMQ context. This normally + happens on server shutdown. In some cases, methods like abort and + do_log_stats will still be called and then try to open a socket, which + causes a ZMQError and creates a huge stack trace. + So, we throw this error such that we can suppress it. + """ + + +class MQLLMEngineClient(EngineClient): + """A client wrapper for MQLLMEngine that conforms to the + EngineClient protocol. 
+ + MQLLMEngine and MQLLMEngineClient are intended to run in separate + processes communicating via zeromq ipc sockets. + + The entrypoint to MQLLMEngineClient is through the generate() + method. On generate() MQLLMEngineClient does three things: + - Creates an asyncio output queue + - Sends an RPCProcessRequest to the MQLLMEngine via zmq + - Pulls RequestOutputs from its queue and yields them + + MQLLMEngineClient runs two background loops: + - output_loop: the output loop pulls List[RequestOutput] + from the MQLLMEngine via zmq (each list is the output + of one engine_step in the LLMEngine). It then parses + the list and pushes individual request_outputs into + the corresponding output_queue such that they can be + consumed by the .generate() method. + - health_loop: the health loop queries the health socket + every N seconds, confirming the engine is healthy + """ + + def __init__(self, ipc_path: str, engine_config: VllmConfig, + engine_pid: int): + self.context = zmq.asyncio.Context() + self._errored_with: Optional[BaseException] = None + + # Get the configs. + self.vllm_config = engine_config + self.model_config = engine_config.model_config + self.decoding_config = engine_config.decoding_config + + # Create the tokenizer group. + self.tokenizer = init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=engine_config.scheduler_config, + lora_config=engine_config.lora_config) + self.input_preprocessor = InputPreprocessor(self.model_config, + self.tokenizer) + + # Send RPCProcessRequest to the MQLLMEngine. + self.input_socket: Socket = self.context.socket(zmq.constants.PUSH) + self.input_socket.connect(f"{ipc_path}{IPC_INPUT_EXT}") + + # Receive streams of RequestOutput from the MQLLMEngine. + self.output_socket: Socket = self.context.socket(zmq.constants.PULL) + self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}") + + # IPC path for acking heartbeats. + self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL) + self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") + + # IPC path for the data socket. + self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" + + # Stream for each individual request. + self.output_queues: Dict[str, asyncio.Queue] = {} + + # Loop to handle output of the LLMEngine periodically. + # Started after the MQLLMEngine is ready so that we can + # build the Client in an executor to enable clean shutdown. + self.output_loop: Optional[asyncio.Task] = None + + # Loop to check health of the LLMEngine periodically. + # Started after the MQLLMEngine is ready. + self.health_loop: Optional[asyncio.Task] = None + self._engine_process = psutil.Process(engine_pid) + + @staticmethod + def is_unsupported_config(vllm_config: VllmConfig): + # Pipeline parallel not yet supported + return vllm_config.parallel_config.pipeline_parallel_size > 1 + + @contextmanager + def get_data_socket(self) -> Iterator[Socket]: + socket = self.context.socket(zmq.constants.DEALER) + try: + socket.connect(self.data_ipc_path) + yield socket + finally: + socket.close(linger=0) + + async def run_heartbeat_loop(self, timeout: int): + """Background loop that continually checks to ensure the engine process + is still alive. 
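+ + Note that a zombie engine process is treated as dead: psutil's + is_running() returns True for zombies, so the process status is + also checked explicitly below.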
+ """ + try: + while True: + # Check if the engine process is running: + if not self._engine_process.is_running() or ( + self._engine_process.status() == psutil.STATUS_ZOMBIE): + # NB: is_running() returns True for zombies + self._set_errored( + RuntimeError( + f"Engine process (pid {self._engine_process.pid}) " + "died.")) + break + + if await self.heartbeat_socket.poll(timeout=timeout): + # Heartbeat received - check the message + await self._check_success( + error_message="Heartbeat failed.", + socket=self.heartbeat_socket) + + logger.debug("Heartbeat successful.") + + except asyncio.CancelledError: + logger.debug("Shutting down MQLLMEngineClient check health loop.") + + except psutil.NoSuchProcess: + self._set_errored( + RuntimeError( + f"Engine process (pid {self._engine_process.pid}) died.")) + + except Exception as e: + self._set_errored(e) + + async def run_output_handler_loop(self): + """Get RequestOutputs from Engine and stream to Request Queues""" + + try: + while True: + # Poll, checking for ENGINE_DEAD + while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT + ) == 0: + logger.debug("Waiting for output from MQLLMEngine.") + + # If errored, alert all running requests. + if self.errored: + for queue_j in tuple(self.output_queues.values()): + queue_j.put_nowait( + ENGINE_DEAD_ERROR(self._errored_with)) + return + + message: Frame = await self.output_socket.recv(copy=False) + request_outputs = pickle.loads(message.buffer) + + is_error = isinstance(request_outputs, + (BaseException, RPCError)) + if is_error: + if isinstance(request_outputs, RPCError): + rpc_error: RPCError = request_outputs + request_id = rpc_error.request_id + exception = rpc_error.exception + is_engine_errored = rpc_error.is_engine_errored + else: + # MQLLMEngine should always return an RPCError to + # the output_socket when an issue arises. + # If we are here, we are in a bad state and + # should shut down the server. + error: BaseException = request_outputs + logger.error( + "Received Exception %s rather than RPCError from " + "MQLLMEngine. This should never happen.", error) + request_id = None + exception = error + is_engine_errored = True + + # Set to error state only on engine critical error + # (and record only the first one) + if is_engine_errored and not self._errored_with: + self._errored_with = exception + # If engine is errored, no matter the type of exception + # it will no longer be able to receive new requests, + # therefore we have to inform that the current + # processed requests failed as well. Send back a dead + # engine error to give this feedback and also a + # 'hint' to the server to shut down next. + exception = self.dead_error + + if request_id is None: + # If request_id is None, then the engine raised an + # exception for a batch, and we may not know the + # request that caused it, nor whether it was actually + # caused by any of them (e.g. CUDA OOM). Therefore we + # broadcast the same exception for all requests. + for queue_i in tuple(self.output_queues.values()): + queue_i.put_nowait(exception) + else: + queue = self.output_queues.get(request_id) + if queue is not None: + queue.put_nowait(exception) + # Put each output into the appropriate queue. 
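+ # (Outputs are demultiplexed by request_id: each request has its + # own asyncio.Queue, created in _process_request.)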
+ elif isinstance( + request_outputs, + (RPCAdapterLoadedResponse, RPCIsSleepingResponse)): + self._add_output(request_outputs) + else: + for request_output in request_outputs: + self._add_output(request_output) + + except asyncio.CancelledError: + logger.debug("Shutting down MQLLMEngineClient output handler.") + + def _add_output(self, request_output: Union[RequestOutput, + RPCAdapterLoadedResponse, + RPCIsSleepingResponse]): + queue = self.output_queues.get(request_output.request_id) + if queue is not None: + queue.put_nowait(request_output) + + async def setup(self): + """Setup the client before it starts sending server requests.""" + + # Start output_loop + if self.output_loop is None: + # Only create the task once to avoid multiple concurrent + # output_loops, which would lead to race conditions and + # wrong orders of tokens returned by the engine. Note that + # setup() may be called multiple times during the startup of + # the engine. + self.output_loop = asyncio.create_task( + self.run_output_handler_loop()) + + with self.get_data_socket() as socket: + # Wait until server is ready. + response = await self._wait_for_server_rpc(socket) + + self.tracing_flag = response.tracing_enabled + + # Start health_loop. + if self.health_loop is None: + self.health_loop = asyncio.create_task( + self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT)) + + def close(self): + """Destroy the ZeroMQ Context.""" + # Close all sockets and terminate the context. + self.context.destroy(linger=0) + + # Cancel background tasks. + if self.health_loop is not None: + self.health_loop.cancel() + if self.output_loop is not None: + self.output_loop.cancel() + + def _set_errored(self, e: BaseException): + logger.exception(repr(e)) + if self._errored_with is None: + self._errored_with = e + + @staticmethod + async def _send_get_data_rpc_request(request: RPCStartupRequest, + expected_type: Any, + error_message: str, + socket: Socket) -> Any: + """Send an RPC request that is expecting data back.""" + + # Ping RPCServer with a request. + await socket.send_multipart((pickle.dumps(request), ), copy=False) + + # Make sure the server responds in time. + if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: + raise TimeoutError("RPCServer didn't reply within " + f"{VLLM_RPC_TIMEOUT} ms") + + # Await the data from the Server. 
+ frame = await socket.recv(copy=False) + data = pickle.loads(frame.buffer) + + if isinstance(data, BaseException): + raise data + elif not isinstance(data, expected_type): + raise ValueError(error_message) + + return data + + @staticmethod + async def _send_one_way_rpc_request(request: RPC_REQUEST_T, + socket: Socket): + """Send one-way RPC request to trigger an action.""" + + if socket.closed: + raise MQClientClosedError() + + await socket.send_multipart((pickle.dumps(request), )) + + async def _await_ack(self, error_message: str, socket: Socket): + """Await acknowledgement that a request succeeded.""" + + if socket.closed: + raise MQClientClosedError() + + if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: + raise TimeoutError("MQLLMEngine didn't reply within " + f"{VLLM_RPC_TIMEOUT} ms") + + await self._check_success(error_message, socket) + + @staticmethod + async def _check_success(error_message: str, socket: Socket): + """Confirm that socket has a VLLM_RPC_SUCCESS_STR message""" + + if socket.closed: + raise MQClientClosedError() + + frame = await socket.recv(copy=False) + response = pickle.loads(frame.buffer) + + # Raise error if unsuccessful + if isinstance(response, BaseException): + raise response + elif (not isinstance(response, str) + or response != VLLM_RPC_SUCCESS_STR): + raise ValueError(error_message) + + async def get_input_preprocessor(self) -> InputPreprocessor: + return self.input_preprocessor + + async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None): + return await self.tokenizer.get_lora_tokenizer_async(lora_request) + + async def get_vllm_config(self) -> VllmConfig: + return self.vllm_config + + async def get_decoding_config(self) -> DecodingConfig: + return self.decoding_config + + async def get_model_config(self) -> ModelConfig: + return self.model_config + + async def is_tracing_enabled(self) -> bool: + return self.tracing_flag + + async def _wait_for_server_rpc(self, socket: Socket) -> RPCStartupResponse: + """Wait for the RPCServer to start up.""" + + return await self._send_get_data_rpc_request( + request=RPCStartupRequest.IS_SERVER_READY, + expected_type=RPCStartupResponse, + error_message="Unable to start RPC Server", + socket=socket) + + async def abort(self, request_id: str): + """Send an ABORT_REQUEST signal to the RPC Server""" + + with suppress(MQClientClosedError): + await self._send_one_way_rpc_request( + request=RPCAbortRequest(request_id), socket=self.input_socket) + + async def do_log_stats( + self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[List[SamplerOutput]] = None, + ) -> None: + """ + Ignore do_log_stats (handled on MQLLMEngine polling) + """ + pass + + async def check_health(self): + """ + The health loop probes the health of the engine + process every N seconds and sets _errored_with + if the engine is unhealthy. 
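+ + This method itself does not probe anything; it simply re-raises + the first error recorded by that loop, if any.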
+ """ + if self._errored_with is not None: + raise self._errored_with + + @property + def is_running(self) -> bool: + return not self.errored + + @property + def is_stopped(self) -> bool: + return self.errored + + @property + def errored(self) -> bool: + return self._errored_with is not None + + @property + def dead_error(self) -> BaseException: + return ENGINE_DEAD_ERROR(self._errored_with) + + def generate( + self, + prompt: PromptType, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> AsyncGenerator[RequestOutput, None]: + """Generate outputs for a request. + + This method is a coroutine. It adds the request into the waiting + queue of the LLMEngine and streams the outputs from the LLMEngine + to the caller. + + Args: + prompt: The prompt to the LLM. See + [`PromptType`][vllm.inputs.PromptType] for more details about + the format of each input. + sampling_params: The sampling parameters of the request. + request_id: The unique id of the request. + lora_request: LoRA request to use for generation, if any. + trace_headers: OpenTelemetry trace headers. + prompt_adapter_request: Prompt Adapter request to use + for generation, if any. + priority: Priority of the request (lower means earlier handling). + Any priority other than 0 will lead to an error if the + scheduling policy is not "priority". + """ + return cast( + AsyncGenerator[RequestOutput, None], + self._process_request(prompt, sampling_params, request_id, + lora_request, trace_headers, + prompt_adapter_request, priority)) + + def encode( + self, + prompt: PromptType, + pooling_params: PoolingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + priority: int = 0, + ) -> AsyncGenerator[PoolingRequestOutput, None]: + """Generate outputs for a request from a pooling model. + + This method is a coroutine. It adds the request into the waiting + queue of the LLMEngine and streams the outputs from the LLMEngine + to the caller. + + Args: + prompt: The prompt to the LLM. See + [`PromptType`][vllm.inputs.PromptType] for more details about + the format of each input. + pooling_params: The pooling parameters of the request. + request_id: The unique id of the request. + lora_request: LoRA request to use for generation, if any. + trace_headers: OpenTelemetry trace headers. + + Yields: + The output `PoolingRequestOutput` objects from the LLMEngine + for the request. + """ + return cast( + AsyncGenerator[PoolingRequestOutput, None], + self._process_request(prompt, + pooling_params, + request_id, + lora_request, + trace_headers, + priority=priority)) + + async def _process_request( + self, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[ + PoolingRequestOutput, None]]: + """Send an RPCProcessRequest to the RPCServer and stream responses.""" + + # If already dead, error out. 
+ if self._errored_with is not None: + raise ENGINE_DEAD_ERROR(self._errored_with) + + # Ensure the request id is unique among running requests + if request_id in self.output_queues: + raise ValueError(f"Request {request_id} already exists") + + # Constructing guided decoding logits processors is expensive, so we do + # it here to avoid contending with CPU resources and the GIL on the + # backend process. + if isinstance(params, SamplingParams) and \ + params.guided_decoding is not None: + params = await \ + build_guided_decoding_logits_processor_async( + sampling_params=params, + tokenizer=await self.get_tokenizer(lora_request), + default_guided_backend=(self.decoding_config.backend + if self.decoding_config + else DecodingConfig.backend), + model_config=self.model_config, + reasoning_backend=self.decoding_config.reasoning_backend, + ) + + # 1) Create output queue for this request. + queue: asyncio.Queue[Union[RequestOutput, + BaseException]] = asyncio.Queue() + self.output_queues[request_id] = queue + + try: + # 2) Detach logits processors so that they can be pickled + # separately (may require cloudpickle which is slower) + if isinstance(params, SamplingParams) and params.logits_processors: + # Defensive shallow copy + params = copy.copy(params) + logits_processors = params.logits_processors + params.logits_processors = None + lp_bytes = cloudpickle.dumps(logits_processors) + else: + lp_bytes = None + + request_bytes = pickle.dumps( + RPCProcessRequest( + prompt=prompt, + params=params, + request_id=request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority, + )) + + # 3) Send the RPCProcessRequest to the MQLLMEngine. + parts = (request_bytes, + lp_bytes) if lp_bytes else (request_bytes, ) + await self.input_socket.send_multipart(parts, copy=False) + + # 4) Stream the RequestOutputs from the output queue. Note + # that the output_loop pushes RequestOutput objects to this + # queue after pulling them from the zmq socket. + finished = False + try: + while not finished: + request_output = await queue.get() + + if isinstance(request_output, BaseException): + raise request_output + + finished = request_output.finished + yield request_output + finally: + # Request was canceled by the client. 
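+ # (Reaching here without finished=True means the consumer stopped + # iterating early, so the request is also aborted engine-side.)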
+ if not finished and not self.errored: + await self.abort(request_id) + finally: + self.output_queues.pop(request_id) + + async def start_profile(self) -> None: + """Start profiling the engine""" + + await self._send_one_way_rpc_request( + request=RPCUProfileRequest.START_PROFILE, socket=self.input_socket) + + async def stop_profile(self) -> None: + """Stop profiling the engine""" + + await self._send_one_way_rpc_request( + request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket) + + async def reset_mm_cache(self) -> None: + """Reset the multi-modal cache""" + + await self._send_one_way_rpc_request( + request=RPCResetMultiModalCacheRequest.RESET, + socket=self.input_socket) + + async def reset_prefix_cache(self, + device: Optional[Device] = None) -> None: + """Reset the prefix cache""" + + await self._send_one_way_rpc_request( + request=RPCResetPrefixCacheRequest(device), + socket=self.input_socket) + + async def sleep(self, level: int = 1) -> None: + """Sleep the engine for a given level""" + return await self._send_one_way_rpc_request( + request=RPCSleepRequest(level), socket=self.input_socket) + + async def wake_up(self, tags: Optional[list[str]] = None) -> None: + """Wake up the engine""" + return await self._send_one_way_rpc_request( + request=RPCWakeUpRequest(tags), socket=self.input_socket) + + async def is_sleeping(self) -> bool: + """Check whether the engine is sleeping""" + request = RPCIsSleepingRequest() + + queue: asyncio.Queue[Union[BaseException, + RPCIsSleepingResponse]] = asyncio.Queue() + self.output_queues[request.request_id] = queue + + request_bytes = pickle.dumps(request) + await self.input_socket.send_multipart((request_bytes, ), copy=False) + + request_output = await queue.get() + self.output_queues.pop(request.request_id) + + if isinstance(request_output, BaseException): + raise request_output + return request_output.is_sleeping + + async def add_lora(self, lora_request: LoRARequest) -> None: + """Load a new LoRA adapter into the engine for future requests.""" + # Uses the same I/O as generate requests + request = RPCLoadAdapterRequest(lora_request) + + # Create output queue for this request. 
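+ # (add_lora reuses the same request/response plumbing as generate: + # the response is routed back through output_queues by request_id.)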
+ queue: asyncio.Queue[Union[None, BaseException]] = asyncio.Queue() + self.output_queues[request.request_id] = queue + + # Send the request + request_bytes = pickle.dumps(request) + await self.input_socket.send_multipart((request_bytes, ), copy=False) + + # Wait for the response + request_output = await queue.get() + self.output_queues.pop(request.request_id) + + # Raise on error, otherwise happily return None + if isinstance(request_output, BaseException): + raise request_output diff --git a/engine/multiprocessing/engine.py b/engine/multiprocessing/engine.py new file mode 100644 index 0000000..ef088bd --- /dev/null +++ b/engine/multiprocessing/engine.py @@ -0,0 +1,460 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pickle +import signal +from contextlib import contextmanager +from typing import Iterator, List, Optional, Union + +import cloudpickle +import zmq + +from vllm import AsyncEngineArgs, SamplingParams +from vllm.config import VllmConfig +from vllm.engine.llm_engine import LLMEngine +# yapf conflicts with isort for this block +# yapf: disable +from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, + IPC_HEALTH_EXT, IPC_INPUT_EXT, + IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCAdapterLoadedResponse, RPCError, + RPCIsSleepingRequest, + RPCIsSleepingResponse, + RPCLoadAdapterRequest, + RPCProcessRequest, + RPCResetMultiModalCacheRequest, + RPCResetPrefixCacheRequest, + RPCSleepRequest, RPCStartupRequest, + RPCStartupResponse, + RPCUProfileRequest, RPCWakeUpRequest) +# yapf: enable +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from vllm.transformers_utils.config import ( + maybe_register_config_serialize_by_value) +from vllm.usage.usage_lib import UsageContext +from vllm.worker.model_runner_base import InputProcessingError + +logger = init_logger(__name__) + +POLLING_TIMEOUT_MS = 10000 +HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), ) + + +class MQLLMEngine: + """A multiprocessing wrapper for + [`LLMEngine`][vllm.engine.llm_engine.LLMEngine]. + + This class is used to wrap the + [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] class to enable its use + in a concurrent manner. It runs a background loop and uses zeromq to + receive new requests and stream outputs incrementally via ipc. + + The [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] generate or encode + process is kicked off when a new RPCProcessRequest is received by the + input_socket. + + The self.engine_loop checks the input_socket for new requests, + adds them to the LLMEngine if there are any, calls the internal + [`LLMEngine.step()`][vllm.engine.llm_engine.LLMEngine.step], and sends + the RequestOutputs back over the output_socket. + + If use_async_sockets is set, the logic associated with reading new + requests from the socket and sending data to the socket is passed + as a callback to the llm_engine, which calls the logic asynchronously + such that the IPC can be overlapped with the GPU. + + Args: + ipc_path: Base path for zeromq interprocess messaging + use_async_sockets: Whether to make send/recv async with GPU + log_requests: Whether to log the requests. + *args: Arguments for [`LLMEngine`][vllm.engine.llm_engine.LLMEngine]. + **kwargs: Arguments for [`LLMEngine`][vllm.engine.llm_engine.LLMEngine]. 
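+ + Example (hypothetical wiring; the ipc path is arbitrary): + + engine = MQLLMEngine.from_engine_args( + engine_args, UsageContext.ENGINE_CONTEXT, "ipc:///tmp/vllm_mq") + engine.start()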
+ """ + + def __init__(self, + ipc_path: str, + use_async_sockets: bool, + *args, + log_requests: bool = True, + **kwargs) -> None: + # For MQLLMEngine, we can use cached outputs, since each new request + # output is immediately pickled and sent over the socket, which frees + # the Python object to be reused. + kwargs['use_cached_outputs'] = True + + self.engine = LLMEngine(*args, **kwargs) + self.log_requests = log_requests + + self.use_async_sockets = use_async_sockets + if self.use_async_sockets: + self.engine.process_request_outputs_callback = \ + self._async_socket_engine_callback + + self.ctx = zmq.Context() # type: ignore[attr-defined] + + # Receive input from the client. + self.input_socket = self.ctx.socket(zmq.constants.PULL) + self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}") + + # Send output stream back to client. + self.output_socket = self.ctx.socket(zmq.constants.PUSH) + self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}") + + # Send heartbeats back to client. + self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH) + self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") + + # IPC path for the data socket. + self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" + + # Error state. + self._errored_with: Optional[BaseException] = None + + @property + def dead_error(self) -> BaseException: + if self._errored_with is not None: + return ENGINE_DEAD_ERROR(self._errored_with) + else: + return ENGINE_DEAD_ERROR() + + @classmethod + def from_vllm_config(cls, vllm_config: VllmConfig, + usage_context: UsageContext, + disable_log_requests: bool, disable_log_stats: bool, + ipc_path: str) -> "MQLLMEngine": + # Setup plugins for each process + from vllm.plugins import load_general_plugins + load_general_plugins() + + use_async_sockets = vllm_config.model_config.use_async_output_proc + + return cls( + vllm_config=vllm_config, + executor_class=LLMEngine._get_executor_cls(vllm_config), + ipc_path=ipc_path, + usage_context=usage_context, + use_async_sockets=use_async_sockets, + log_requests=(not disable_log_requests), + log_stats=(not disable_log_stats), + ) + + @staticmethod + def from_engine_args(engine_args: AsyncEngineArgs, + usage_context: UsageContext, ipc_path: str): + """Creates an MQLLMEngine from the engine arguments.""" + + vllm_config = engine_args.create_engine_config(usage_context) + return MQLLMEngine.from_vllm_config( + ipc_path=ipc_path, + vllm_config=vllm_config, + usage_context=usage_context, + disable_log_requests=engine_args.disable_log_requests, + disable_log_stats=engine_args.disable_log_stats, + ) + + def start(self): + try: + try: + logger.debug("Starting Startup Loop.") + self.run_startup_loop() + logger.debug("Starting Engine Loop.") + self.run_engine_loop() + except Exception as e: + logger.exception(repr(e)) + except KeyboardInterrupt: + logger.debug("Shutting down MQLLMEngine.") + finally: + logger.debug("MQLLMEngine is shut down.") + self.cleanup() + + def cleanup(self): + """Cleanup zeromq state on shutdown.""" + # Closes all sockets and destroys context. 
+ self.ctx.destroy(linger=0) + del self.engine + + @contextmanager + def make_data_socket( + self) -> Iterator[zmq.Socket]: # type: ignore[name-defined] + socket = self.ctx.socket(zmq.constants.ROUTER) + try: + socket.bind(self.data_ipc_path) + yield socket + finally: + socket.close(linger=0) + + def run_startup_loop(self) -> None: + """Startup loop for sending data from Engine -> Client.""" + + with self.make_data_socket() as socket: + response: Union[RPCStartupResponse, BaseException] + try: + identity, message = socket.recv_multipart(copy=False) + request: RPCStartupRequest = pickle.loads(message.buffer) + + # Handle the query from the Client. + if request == RPCStartupRequest.IS_SERVER_READY: + tracing_enabled = self.engine.is_tracing_enabled() + response = RPCStartupResponse( + tracing_enabled=tracing_enabled) + + except Exception as e: + response = e + + socket.send_multipart((identity, pickle.dumps(response)), + copy=False) + + def run_engine_loop(self): + """Core busy loop of the LLMEngine.""" + + while True: + if not self.engine.has_unfinished_requests(): + # Poll until there is work to do. + while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: + # When there's no work, check on engine health and send + # health status back to client + self._health_check() + self.engine.do_log_stats() + logger.debug("Waiting for new requests in engine loop.") + + # Handle any input from the client. + self.handle_new_input() + + # Engine step. + request_outputs = self.engine_step() + + # Send request outputs (if async, done in engine_step callback). + if not self.use_async_sockets: + self._send_outputs(request_outputs) + + def engine_step(self) -> List[RequestOutput]: + """Engine step wrapper with error handling.""" + try: + return self.engine.step() + except SystemExit: + raise + except InputProcessingError as e: + # Special case where we handle an error preparing the inputs for + # a single request in the batch + rpc_err = RPCError(request_id=e.request_id, + is_engine_errored=False, + exception=e.__cause__) + self._send_outputs(rpc_err) + return [] + except BaseException as e: + self._set_errored(e) + rpc_err = RPCError(request_id=None, + is_engine_errored=True, + exception=e) + self._send_outputs(rpc_err) + raise e + + def handle_new_input(self): + """Handle new input from the socket""" + try: + while self.input_socket.poll(timeout=0) != 0: + frames = self.input_socket.recv_multipart(copy=False) + request = pickle.loads(frames[0].buffer) + + if isinstance(request, RPCProcessRequest): + if len(frames) > 1: + # Use cloudpickle for logits processors + assert isinstance(request.params, SamplingParams) + lprocs = cloudpickle.loads(frames[1].buffer) + request.params.logits_processors = lprocs + self._handle_process_request(request) + elif isinstance(request, RPCAbortRequest): + self._handle_abort_request(request) + elif isinstance(request, RPCUProfileRequest): + if request == RPCUProfileRequest.START_PROFILE: + self.start_profile() + else: + self.stop_profile() + elif isinstance(request, RPCLoadAdapterRequest): + self._handle_load_adapter_request(request) + elif isinstance(request, RPCResetMultiModalCacheRequest): + self.reset_mm_cache() + elif isinstance(request, RPCResetPrefixCacheRequest): + self.reset_prefix_cache() + elif isinstance(request, RPCSleepRequest): + self.sleep(request.value) + elif isinstance(request, RPCWakeUpRequest): + self.wake_up(request.tags) + elif isinstance(request, RPCIsSleepingRequest): + self._handle_is_sleeping_request(request) + else: + raise 
ValueError("Unknown RPCRequest Type: " + f"{type(request)}") + + except Exception as e: + self._set_errored(e) + self._send_unhealthy(e) + raise e from None + + def _handle_process_request(self, request: RPCProcessRequest): + """Handle RPCProcessRequest by adding it to the LLMEngine.""" + request_id = request.request_id + + if self._errored_with is not None: + rpc_err = RPCError(request_id=request_id, + is_engine_errored=True, + exception=ENGINE_DEAD_ERROR(self._errored_with)) + self._send_outputs(rpc_err) + + try: + self.engine.add_request( + request_id=request_id, + prompt=request.prompt, + params=request.params, + lora_request=request.lora_request, + trace_headers=request.trace_headers, + prompt_adapter_request=request.prompt_adapter_request, + priority=request.priority) + + if self.log_requests: + logger.info("Added request %s.", request.request_id) + + except Exception as e: + # We do not set self._errored = True here, since the error + # is due to an issue adding this request to the engine, + # rather than an issue with the engine itself. + logger.debug("Failed to add request %s to engine. %s", + request.request_id, e) + is_errored = self._errored_with is not None + rpc_err = RPCError(request_id=request_id, + is_engine_errored=is_errored, + exception=e) + self._send_outputs(rpc_err) + + # Remove request from the engine. + self.engine.abort_request(request_id) + + def _handle_abort_request(self, request: RPCAbortRequest): + self.engine.abort_request(request.request_id) + if self.log_requests: + logger.info("Aborted request %s.", request.request_id) + + def _handle_load_adapter_request(self, request: RPCLoadAdapterRequest): + try: + self.engine.add_lora(request.lora_request) + except BaseException as e: + # Send back an error if the adater fails to load + rpc_err = RPCError(request_id=request.request_id, + is_engine_errored=False, + exception=e) + self._send_outputs(rpc_err) + return + # Otherwise, send back the successful load message + self._send_outputs( + RPCAdapterLoadedResponse(request_id=request.request_id)) + + def _handle_is_sleeping_request(self, request: RPCIsSleepingRequest): + is_sleeping = self.is_sleeping() + self._send_outputs( + RPCIsSleepingResponse(request_id=request.request_id, + is_sleeping=is_sleeping)) + + def _health_check(self): + # Send unhealthy if engine has already errored + if self._errored_with is not None: + self._send_unhealthy(self._errored_with) + try: + self.engine.check_health() + self._send_healthy() + except Exception as e: + self._set_errored(e) + self._send_unhealthy(e) + + def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): + """Send outputs back to the engine client. These can be: + - Exceptions + - A list of generation outputs + - A response from loading a lora adapter + """ + if outputs: + try: + from ray.exceptions import RayTaskError + + # RayTaskError might not pickelable here. We need to unpack the + # underlying exception as the real exception in the output. 
+ if (isinstance(outputs, RPCError) + and isinstance(outputs.exception, RayTaskError)): + outputs.exception = outputs.exception.cause + except ImportError: + pass + + output_bytes = pickle.dumps(outputs) + self.output_socket.send_multipart((output_bytes, ), copy=False) + + def _send_healthy(self): + """Send HEALTHY message to RPCClient.""" + if not self.heartbeat_socket.closed: + self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False) + + def _send_unhealthy(self, error: BaseException): + """Send UNHEALTHY message to RPCClient.""" + if not self.heartbeat_socket.closed: + error_bytes = pickle.dumps(error) + self.heartbeat_socket.send_multipart((error_bytes, ), copy=False) + + def _async_socket_engine_callback(self, + request_outputs: REQUEST_OUTPUTS_T): + """Callback used by engine to make socket handling async with GPU.""" + self._send_outputs(request_outputs) + self.handle_new_input() + + def _set_errored(self, e: BaseException): + """Log and set errored status if this is the first issue.""" + if self._errored_with is None: + self._errored_with = e + + def start_profile(self) -> None: + self.engine.start_profile() + + def stop_profile(self) -> None: + self.engine.stop_profile() + + def reset_mm_cache(self) -> bool: + return self.engine.reset_mm_cache() + + def reset_prefix_cache(self) -> bool: + return self.engine.reset_prefix_cache() + + def sleep(self, level: int = 1) -> None: + self.engine.sleep(level) + + def wake_up(self, tags: Optional[list[str]] = None) -> None: + self.engine.wake_up(tags) + + def is_sleeping(self) -> bool: + return self.engine.is_sleeping() + + +def signal_handler(*_) -> None: + raise KeyboardInterrupt("MQLLMEngine terminated") + + +def run_mp_engine(vllm_config: VllmConfig, usage_context: UsageContext, + ipc_path: str, disable_log_stats: bool, + disable_log_requests: bool, engine_alive): + try: + # Ensure we can serialize transformer config before spawning + maybe_register_config_serialize_by_value() + + engine = MQLLMEngine.from_vllm_config( + vllm_config=vllm_config, + usage_context=usage_context, + disable_log_stats=disable_log_stats, + disable_log_requests=disable_log_requests, + ipc_path=ipc_path) + + signal.signal(signal.SIGTERM, signal_handler) + + engine.start() + + except BaseException as e: + logger.exception(e) + engine_alive.value = False + raise e from None diff --git a/engine/output_processor/__init__.py b/engine/output_processor/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/engine/output_processor/interfaces.py b/engine/output_processor/interfaces.py new file mode 100644 index 0000000..19c5963 --- /dev/null +++ b/engine/output_processor/interfaces.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod +from typing import Callable, List + +from vllm.config import SchedulerConfig +from vllm.core.scheduler import Scheduler +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import Counter + + +class SequenceGroupOutputProcessor(ABC): + """Interface for logic that processes new token ids in sequence groups, + managing detokenization, stop checking, and freeing/forking sequences with + the scheduler. + + This is highly coupled with the LLMEngine and should be seen as an extension + of it. 
The logic is separated to simplify the LLMEngine class and allow + separate implementations for single-step decoding (which supports beam + search sequence forking) and multi-step decoding (which does not support + beam search, but does support speculative decoding). + """ + + @staticmethod + def create_output_processor( + scheduler_config: SchedulerConfig, + detokenizer: Detokenizer, + scheduler: List[Scheduler], + seq_counter: Counter, + get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer], + stop_checker: "StopChecker", + ): + """Create an output processor. + + This returns a single-step output processor if num_lookahead_slots is + zero, else returns a multi-step output processor. + """ + if scheduler_config.num_lookahead_slots == 0: + # Importing here to avoid cycle. + from vllm.engine.output_processor.single_step import ( + SingleStepOutputProcessor) + return SingleStepOutputProcessor(scheduler_config, detokenizer, + scheduler, seq_counter, + stop_checker) + else: + # Importing here to avoid cycle. + from vllm.engine.output_processor.multi_step import ( + MultiStepOutputProcessor) + return MultiStepOutputProcessor( + detokenizer, + scheduler, + seq_counter, + get_tokenizer_for_seq, + stop_checker, + ) + + @abstractmethod + def process_outputs(self, sequence_group: SequenceGroup, + outputs: List[SequenceGroupOutput], + is_async: bool) -> None: + """Process new token ids for the sequence group. Handles logic such as + detokenization, stop checking, and freeing/forking sequences in the + scheduler. + """ + pass + + @abstractmethod + def process_prompt_logprob(self, seq_group: SequenceGroup, + outputs: List[SequenceGroupOutput]) -> None: + """Update prompt logprobs received from outputs to seq_group.""" + pass diff --git a/engine/output_processor/multi_step.py b/engine/output_processor/multi_step.py new file mode 100644 index 0000000..e0fa6a0 --- /dev/null +++ b/engine/output_processor/multi_step.py @@ -0,0 +1,216 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import functools +from typing import Callable, List, cast + +from vllm.core.scheduler import Scheduler +from vllm.engine.output_processor.interfaces import ( + SequenceGroupOutputProcessor) +from vllm.engine.output_processor.single_step import ( + single_step_process_prompt_logprob) +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.logger import init_logger +from vllm.sampling_params import SamplingParams +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, + CompletionSequenceGroupOutput, Sequence, + SequenceGroup, SequenceGroupOutput, SequenceOutput, + SequenceStatus) +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import Counter + +logger = init_logger(__name__) + + +class MultiStepOutputProcessor(SequenceGroupOutputProcessor): + """SequenceGroupOutputProcessor which handles logic related to + detokenization and stopping conditions. It specializes to "multi-step + decoding", where vLLM's worker may generate multiple tokens per invocation. + This is currently mutually exclusive with advanced sampling techniques like + beam search, which motivates the separation of this logic from the single + step output processor. 
+ + This class is responsible for things such as correctly appending all new + token ids to their sequence, detokenizing new token ids, truncating new + output tokens after an eos token, and correctly handling the case where the + number of new output tokens per sequence differs in a single batch. + """ + + def __init__( + self, + detokenizer: Detokenizer, + scheduler: List[Scheduler], + seq_counter: Counter, + get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer], + stop_checker: StopChecker, + ): + self.detokenizer = detokenizer + self.scheduler = scheduler + self.seq_counter = seq_counter + self.get_tokenizer_for_seq = get_tokenizer_for_seq + self.stop_checker = stop_checker + + def process_prompt_logprob(self, seq_group: SequenceGroup, + outputs: List[SequenceGroupOutput]) -> None: + """Process prompt logprobs associated with each step of a multi-step- + scheduled computation. + + Args: + seq_group: the outputs are associated with this + [`SequenceGroup`][vllm.sequence.SequenceGroup] + outputs: the + [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput]s + for all scheduler steps + """ + for output in outputs: + # Concatenate single-step prompt logprob processing results. + assert isinstance(output, CompletionSequenceGroupOutput) + single_step_process_prompt_logprob(self, seq_group, output) + + @staticmethod + @functools.lru_cache + def _log_prompt_logprob_unsupported_warning_once(): + # Reminder: Please update docs/features/compatibility_matrix.md + # If the feature combo become valid + logger.warning( + "Prompt logprob is not supported by multi step workers. " + "(e.g., speculative decode uses multi step workers).") + + def process_outputs(self, + sequence_group: SequenceGroup, + outputs: List[SequenceGroupOutput], + is_async: bool = False) -> None: + """Append new tokens in the outputs to sequences in the sequence group. + + This only supports sequence groups of size 1. It supports greater than + one new token per sequence. + + This applies logic like stop condition checking and detokenization. + It also handles cases where there are tokens emitted after + the EOS token. + + is_async - Indicates whether this postprocessor runs in + parallel with the GPU forward pass and is processing + tokens from the previous step. If this is true, then + no tokens need to be appended since it is already done + externally (before the next schedule() call) + """ + # Sequences can be in RUNNING or FINISHED_ABORTED state + # once scheduled, as a sequence is moved to FINISHED_ABORTED + # if a client disconnects from the api server. + seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING) + if seqs is None: + seqs = sequence_group.get_seqs( + status=SequenceStatus.FINISHED_ABORTED) + + for output in outputs: + if output.samples[0].output_token != VLLM_INVALID_TOKEN_ID: + sequence_group.metrics.spec_token_acceptance_counts[ + output.step_index] += 1 + + assert seqs, "Expected RUNNING or FINISHED_ABORTED sequences" + assert len(seqs) == 1, ( + "Beam search not supported in multi-step decoding.") + seq = seqs[0] + seq_id = seq.seq_id + # This method is defined in the more generic + # SequenceGroupOutputProcessor, but here we assume that the outputs are + # of a more specific type. 
+ assert all([ + isinstance(output, CompletionSequenceGroupOutput) + for output in outputs + ]) + compl_outputs = cast(List[CompletionSequenceGroupOutput], outputs) + assert all([ + seq_id == output.samples[0].parent_seq_id + for output in compl_outputs + ]) + + if is_async: + # Async case: We process tokens one by one. Here, we know the token + # was already appended, so we only need to do the rest of the + # postprocessor: Detokenization + stopping logic + self._process_decode_and_stop(seq, sequence_group.sampling_params) + else: + # Standard multi-step case + + # Since there's only one sequence per sequence group, + # we can take the first sample. + samples = [output.samples[0] for output in compl_outputs] + + # entries in sample tokens may be invalid (eg. due to spec decode + # rejecting tokens). + valid_samples = [ + sample for sample in samples + if sample.output_token != VLLM_INVALID_TOKEN_ID + ] + + # When both spec-decode and pre-fill chunking are enabled, we + # don't have guaranteed samples here (e.g. all -1s). + if valid_samples: + self._process_seq_outputs(seq, valid_samples, + sequence_group.sampling_params) + + def _process_decode_and_stop(self, seq: Sequence, + sampling_params: SamplingParams) -> None: + new_char_count = 0 + if sampling_params.detokenize and self.detokenizer: + new_char_count = self.detokenizer.decode_sequence_inplace( + seq, sampling_params) + + # TODO(sang): Support lora. + self.stop_checker.maybe_stop_sequence( + seq, + new_char_count=new_char_count, + sampling_params=sampling_params, + ) + + def _process_seq_outputs(self, seq: Sequence, + valid_samples: List[SequenceOutput], + sampling_params: SamplingParams) -> None: + output_token_ids = [sample.output_token for sample in valid_samples] + output_logprobs = [sample.logprobs for sample in valid_samples] + output_embeds = [sample.output_embed for sample in valid_samples] + + # Truncate to max_tokens if necessary. + remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() + + len(output_token_ids)) + if remaining_tokens < 0: + output_token_ids = output_token_ids[:remaining_tokens] + + # Truncate any tokens after EOS. This is required as spec decode + # generates a fixed number of tokens without evaluating stopping + # conditions within the block. This can cause an eos token to be + # unintentionally ignored. + if not sampling_params.ignore_eos and self.detokenizer: + eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id + # Avoiding .index calls as exception throwing in the happy path + # is expensive. + for i in range(len(output_token_ids)): + if output_token_ids[i] == eos_token_id: + output_token_ids = output_token_ids[:i + 1] + break + + is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0 + # Incrementally append tokens to the sequence, as if we had only one new + # token. + for output_token_id, output_logprob, output_embed in zip( + output_token_ids, output_logprobs, output_embeds): + seq.append_token_id( + token_id=output_token_id, + logprobs=output_logprob, + token_embed=output_embed, + ) + + if is_prefill_sampled_token: + is_prefill_sampled_token = False + else: + # Update num_computed_tokens iff the sampled token is not from + # a prefill step. 
+ seq.data.update_num_computed_tokens(1) + + self._process_decode_and_stop(seq, sampling_params) + + if seq.is_finished(): + break diff --git a/engine/output_processor/single_step.py b/engine/output_processor/single_step.py new file mode 100644 index 0000000..dbf6a37 --- /dev/null +++ b/engine/output_processor/single_step.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import List + +from vllm.config import SchedulerConfig +from vllm.core.scheduler import Scheduler +from vllm.engine.output_processor.interfaces import ( + SequenceGroupOutputProcessor) +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.logger import init_logger +from vllm.sequence import (CompletionSequenceGroupOutput, SequenceGroup, + SequenceGroupOutput) +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.utils import Counter + +logger = init_logger(__name__) + + +def single_step_process_prompt_logprob( + sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup, + output: CompletionSequenceGroupOutput) -> None: + """Process prompt logprobs associated with the + [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput] for a given step. + + Do nothing if the output has no prompt logprobs. + + Account for the fact that transformers do not compute first-token logprobs. + + Args: + sg_output_proc: + [`SequenceGroupOutputProcessor`][vllm.engine.output_processor.interfaces.SequenceGroupOutputProcessor] + instance + seq_group: the output is associated with this + [`SequenceGroup`][vllm.sequence.SequenceGroup] + output: the [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput] + for a single scheduler step + """ + prompt_logprobs = output.prompt_logprobs + + # If this is the first (or only) "chunk" of the prefill, we need + # to prepend None to the list of prompt logprobs. The reason for this + # is that for N prompt tokens, the Sampler will generate N-1 total + # prompt logprobs during prefill since the token at idx 0 will not + # have a logprob associated with it. + if prompt_logprobs is not None: + if not seq_group.prompt_logprobs: + prompt_logprobs = [None] + prompt_logprobs + seq_group.prompt_logprobs = [] + + assert hasattr(sg_output_proc, 'detokenizer') + if (seq_group.sampling_params.detokenize + and sg_output_proc.detokenizer): + sg_output_proc.detokenizer.decode_prompt_logprobs_inplace( + seq_group, + prompt_logprobs, + position_offset=len(seq_group.prompt_logprobs)) + + seq_group.prompt_logprobs.extend(prompt_logprobs) + + +class SingleStepOutputProcessor(SequenceGroupOutputProcessor): + """SequenceGroupOutputProcessor which handles "output processing" logic, + which happens after the model returns generated token ids and before + scheduling of the next batch. Output processing logic includes + detokenization, and determining if a sequence is finished (e.g. via max len + or eos token). + + The SingleStepOutputProcessor is specialized to the case where the model + emits at most a single token per invocation, which precludes configurations + such as speculative decoding or multi-step decoding. This enables beam + search sampling, which requires forking/finishing/freeing sequences in a way + that is currently difficult to schedule multiple steps ahead of time. 
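+
+    It is selected by `SequenceGroupOutputProcessor.create_output_processor`
+    when `scheduler_config.num_lookahead_slots` is zero.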
+ """ + + def __init__(self, scheduler_config: SchedulerConfig, + detokenizer: Detokenizer, scheduler: List[Scheduler], + seq_counter: Counter, stop_checker: StopChecker): + self.scheduler_config = scheduler_config + self.detokenizer = detokenizer + self.scheduler = scheduler + self.seq_counter = seq_counter + self.stop_checker = stop_checker + + def process_outputs(self, sequence_group: SequenceGroup, + outputs: List[SequenceGroupOutput], + is_async: bool) -> None: + """Append all new tokens to sequences in the sequence group. Fork any + surviving beam candidates; free any unsurviving ones. + + Invokes detokenizer to detokenize new tokens, and also marks sequences + as finished if they meet stop conditions. + + is_async - Indicates whether this postprocessor runs in + parallel with the GPU forward pass and is processing + tokens from the previous step. If this is true, then + no tokens need to be appended since it is already done + externally (before the next schedule() call) + """ + assert (len(outputs) == 1 + ), f"{type(self)} does not support multiple outputs per step" + return self._process_sequence_group_outputs(sequence_group, outputs[0], + is_async) + + def process_prompt_logprob(self, seq_group: SequenceGroup, + outputs: List[SequenceGroupOutput]) -> None: + """Process prompt logprobs associated with one step of a single-step- + scheduled computation. + + Args: + seq_group: the output is associated with this + [`SequenceGroup`][vllm.sequence.SequenceGroup] + outputs: the + [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput] + for a single scheduler step + """ + assert len(outputs) == 1, "Single step should only have 1 output." + output = outputs[0] + assert isinstance(output, CompletionSequenceGroupOutput) + single_step_process_prompt_logprob(self, seq_group, output) + + def _process_sequence_group_outputs(self, seq_group: SequenceGroup, + outputs: SequenceGroupOutput, + is_async: bool) -> None: + sampling_params = seq_group.sampling_params + + sample = outputs.samples[0] + seq = seq_group.first_seq + if not is_async: + seq.append_token_id(sample.output_token, sample.logprobs, + sample.output_embed) + if sampling_params.detokenize and self.detokenizer: + new_char_count = self.detokenizer.decode_sequence_inplace( + seq, sampling_params) + else: + new_char_count = 0 + self.stop_checker.maybe_stop_sequence( + seq, + new_char_count, + sampling_params, + lora_req=seq_group.lora_request, + ) + if seq.is_finished(): + for scheduler in self.scheduler: + scheduler.free_seq(seq) diff --git a/engine/output_processor/stop_checker.py b/engine/output_processor/stop_checker.py new file mode 100644 index 0000000..3fb2f71 --- /dev/null +++ b/engine/output_processor/stop_checker.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Callable, List, Optional, Tuple + +from vllm.lora.request import LoRARequest +from vllm.sampling_params import SamplingParams +from vllm.sequence import Sequence, SequenceStatus +from vllm.transformers_utils.tokenizer import AnyTokenizer + + +class StopChecker: + """LLMEngine helper class which separates out the logic involving stop + checking. This checks things such as: whether the eos token was emitted, + whether the max_tokens has been consumed, whether a stop string has been + emitted, or if we have exceeded the max model len. 
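+
+    Checks are applied in order: the min_tokens gate, the EOS token, any
+    stop token ids, stop strings, the max model length, and finally
+    max_tokens.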
+ """ + + def __init__(self, max_model_len: int, + get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer]): + # Do not use it directly, but use `self._get_max_model_len`. + self._max_model_len = max_model_len + self.get_tokenizer_for_seq = get_tokenizer_for_seq + + def _get_max_model_len(self, lora_req: Optional[LoRARequest]): + if lora_req and lora_req.long_lora_max_len: + return lora_req.long_lora_max_len + else: + return self._max_model_len + + def maybe_stop_sequence( + self, + seq: Sequence, + new_char_count: int, + sampling_params: SamplingParams, + lora_req: Optional[LoRARequest] = None, + ) -> None: + """Stop the finished sequences. + + new_char_count is the number of chars added to the + sequence's output text for the newly generated token + """ + + # Check if the minimum number of tokens has been generated yet; + # skip the stop string/token checks if not + if seq.get_output_len() < sampling_params.min_tokens: + return + + # Check if the sequence has generated the EOS token. + if ((not sampling_params.ignore_eos) + and seq.get_last_token_id() == seq.eos_token_id): + # Remove the last EOS token unless explicitly specified + # This prevents unintended exposure of the EOS token + if new_char_count and ( + not sampling_params.include_stop_str_in_output): + seq.output_text = seq.output_text[:-new_char_count] + seq.status = SequenceStatus.FINISHED_STOPPED + return + + # Check if a stop token was encountered. + # This assumes a single token produced per step. + last_token_id = seq.get_last_token_id() + if last_token_id in (sampling_params.stop_token_ids or ()): + if new_char_count and ( + not sampling_params.include_stop_str_in_output): + # Remove last token + seq.output_text = seq.output_text[:-new_char_count] + seq.status = SequenceStatus.FINISHED_STOPPED + seq.stop_reason = last_token_id + return + + # Check if any stop strings are matched. + stop = self.check_stop_strings( + seq.output_text, new_char_count, sampling_params.stop, + sampling_params.include_stop_str_in_output) + if stop is not None: + stop_str, truncate_to = stop + if truncate_to != -1: + seq.output_text = seq.output_text[:truncate_to] + seq.status = SequenceStatus.FINISHED_STOPPED + seq.stop_reason = stop_str + return + + # Check if the sequence has reached max_model_len. + if seq.get_len() >= self._get_max_model_len(lora_req): + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + return + + # Check if the sequence has reached max_tokens. + if seq.get_output_len() == sampling_params.max_tokens: + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + return + + @staticmethod + def check_stop_strings( + output_text: str, + new_char_count: int, + stop: List[str], + include_in_output: bool, + ) -> Optional[Tuple[str, int]]: + """Check if any stop strings are matched and truncate sequence + output text accordingly. + + Returns tuple (stop_string, offset) if matched or else None. + + Where stop_string is the matched stop string and offset is the + length to which output_text should be truncated, or -1 for no + truncation. + """ + if not new_char_count or not stop: + return None + + for stop_str in stop: + stop_string_len = len(stop_str) + # Avoid searching already-searched text. + stop_index = output_text.find(stop_str, + 1 - new_char_count - stop_string_len) + if stop_index == -1: + continue + + if include_in_output: + # Truncate to end of stop string. + stop_index += stop_string_len + if stop_index >= len(output_text): + # No truncation required. 
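+                    # e.g. output_text ends exactly with stop_str, so the
+                    # text is already what the caller should keep (-1 means
+                    # "do not truncate").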
+ return stop_str, -1 + + # Truncate the output text to either the beginning + # or end of the stop string. + return stop_str, stop_index + return None diff --git a/engine/output_processor/util.py b/engine/output_processor/util.py new file mode 100644 index 0000000..1e127eb --- /dev/null +++ b/engine/output_processor/util.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import List +from typing import Sequence as GenericSequence +from typing import cast + +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import CompletionSequenceGroupOutput, SequenceGroupOutput + + +def create_output_by_sequence_group( + outputs: GenericSequence[SamplerOutput], + num_seq_groups: int) -> List[List[SequenceGroupOutput]]: + """Helper method which transforms a 2d list organized by + [step][sequence group] into [sequence group][step]. + """ + output_by_sequence_group: List[List[CompletionSequenceGroupOutput]] = [ + [] for _ in range(num_seq_groups) + ] + for step in outputs: + sequence_group_output: CompletionSequenceGroupOutput + for i, sequence_group_output in enumerate(step): + output_by_sequence_group[i].append(sequence_group_output) + + # Cast to the more generic type that CompletionSequenceGroupOutput + # inherits from. + return cast(List[List[SequenceGroupOutput]], output_by_sequence_group) diff --git a/engine/protocol.py b/engine/protocol.py new file mode 100644 index 0000000..727d592 --- /dev/null +++ b/engine/protocol.py @@ -0,0 +1,317 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +from abc import ABC, abstractmethod +from typing import AsyncGenerator, Mapping, Optional + +from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function +from vllm.config import DecodingConfig, ModelConfig, VllmConfig +from vllm.core.scheduler import SchedulerOutputs +from vllm.inputs.data import PromptType, TokensPrompt +from vllm.inputs.parse import is_explicit_encoder_decoder_prompt +from vllm.inputs.preprocess import InputPreprocessor +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import BeamSearchParams, SamplingParams +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import Device, collect_from_async_generator, random_uuid + +logger = init_logger(__name__) + + +class EngineClient(ABC): + """Protocol class for Clients to Engine""" + + @property + @abstractmethod + def is_running(self) -> bool: + ... + + @property + @abstractmethod + def is_stopped(self) -> bool: + ... + + @property + @abstractmethod + def errored(self) -> bool: + ... + + @property + @abstractmethod + def dead_error(self) -> BaseException: + ... + + @abstractmethod + def generate( + self, + prompt: PromptType, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> AsyncGenerator[RequestOutput, None]: + """Generate outputs for a request.""" + ... 
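+
+    # The default beam_search implementation below is layered on top of
+    # generate(): each iteration requests exactly one new token per live beam
+    # (with logprobs for the top 2 * beam_width candidates), expands the
+    # beams with those candidates, keeps the best beam_width of them, and
+    # finally detokenizes the winners into a single RequestOutput.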
+ + async def beam_search( + self, + prompt: PromptType, + request_id: str, + params: BeamSearchParams, + lora_request: Optional[LoRARequest] = None, + ) -> AsyncGenerator[RequestOutput, None]: + + beam_width = params.beam_width + max_tokens = params.max_tokens + ignore_eos = params.ignore_eos + temperature = params.temperature + length_penalty = params.length_penalty + include_stop_str_in_output = params.include_stop_str_in_output + + preprocessor = await self.get_input_preprocessor() + tokenizer_group = preprocessor.get_tokenizer_group() + tokenizer = await tokenizer_group.get_lora_tokenizer_async() + + if is_explicit_encoder_decoder_prompt(prompt): + raise NotImplementedError + else: + processed_inputs = preprocessor._prompt_to_llm_inputs(prompt) + + if processed_inputs["type"] == "embeds": + raise NotImplementedError + + prompt_token_ids = processed_inputs["prompt_token_ids"] + prompt_text = processed_inputs.get("prompt") + multi_modal_data = processed_inputs.get("multi_modal_data") + mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs") + + tokenized_length = len(prompt_token_ids) + + sort_beams_key = create_sort_beams_key_function( + tokenizer.eos_token_id, length_penalty) + + beam_search_params = SamplingParams( + logprobs=2 * beam_width, + max_tokens=1, + temperature=temperature, + ) + all_beams = [ + BeamSearchSequence(tokens=prompt_token_ids, + cum_logprob=0, + logprobs=[], + multi_modal_data=multi_modal_data, + mm_processor_kwargs=mm_processor_kwargs, + lora_request=lora_request) + ] + completed = [] + + for _ in range(max_tokens): + prompts_batch, lora_req_batch = zip(*[( + TokensPrompt(prompt_token_ids=beam.tokens, + multi_modal_data=beam.multi_modal_data, + mm_processor_kwargs=beam.mm_processor_kwargs), + beam.lora_request, + ) for beam in all_beams]) + + tasks = [] + + request_id = f"beam_search-{random_uuid()}" + for i, (individual_prompt, + lora_req) in enumerate(zip(prompts_batch, lora_req_batch)): + request_id_item = f"{request_id}-{i}" + task = asyncio.create_task( + collect_from_async_generator( + self.generate(individual_prompt, + beam_search_params, + request_id_item, + lora_request=lora_req))) + tasks.append(task) + + output = await asyncio.gather(*tasks) + + output = [x[0] for x in output] + + new_beams = [] + for i, current_beam in enumerate(all_beams): + result = output[i] + + if result.outputs[0].logprobs is not None: + logprobs = result.outputs[0].logprobs[0] + for token_id, logprob_obj in logprobs.items(): + if token_id == tokenizer.eos_token_id and \ + not ignore_eos: + completed.append( + BeamSearchSequence( + tokens=current_beam.tokens + + [token_id] if include_stop_str_in_output + else current_beam.tokens, + logprobs=current_beam.logprobs + + [logprobs], + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob, + finish_reason="stop", + stop_reason=tokenizer.eos_token_id)) + else: + new_beams.append( + BeamSearchSequence( + tokens=current_beam.tokens + [token_id], + logprobs=current_beam.logprobs + + [logprobs], + lora_request=current_beam.lora_request, + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob, + multi_modal_data=current_beam. + multi_modal_data, + mm_processor_kwargs=current_beam. 
+                                    mm_processor_kwargs))
+
+            sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
+            all_beams = sorted_beams[:beam_width]
+
+        completed.extend(all_beams)
+        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
+        best_beams = sorted_completed[:beam_width]
+
+        for beam in best_beams:
+            if (beam.tokens[-1] == tokenizer.eos_token_id and not ignore_eos):
+                # Skip the eos token in the text.
+                tokens = beam.tokens[tokenized_length:-1]
+            else:
+                tokens = beam.tokens[tokenized_length:]
+            beam.text = tokenizer.decode(tokens)
+
+        beam_search_output = RequestOutput(
+            request_id=request_id,
+            prompt=prompt_text,
+            outputs=[
+                CompletionOutput(text=beam.text,
+                                 cumulative_logprob=beam.cum_logprob,
+                                 token_ids=beam.tokens[tokenized_length:],
+                                 index=i,
+                                 logprobs=beam.logprobs,
+                                 finish_reason=beam.finish_reason if
+                                 beam.finish_reason is not None else "length",
+                                 stop_reason=beam.stop_reason)
+                for (i, beam) in enumerate(best_beams)
+            ],
+            finished=True,
+            prompt_token_ids=prompt_token_ids,
+            prompt_logprobs=None)
+
+        yield beam_search_output
+
+    @abstractmethod
+    def encode(
+        self,
+        prompt: PromptType,
+        pooling_params: PoolingParams,
+        request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        priority: int = 0,
+    ) -> AsyncGenerator[PoolingRequestOutput, None]:
+        """Generate outputs for a request from a pooling model."""
+        ...
+
+    @abstractmethod
+    async def abort(self, request_id: str) -> None:
+        """Abort a request.
+
+        Args:
+            request_id: The unique id of the request.
+        """
+        ...
+
+    @abstractmethod
+    async def get_vllm_config(self) -> VllmConfig:
+        """Get the vllm configuration of the vLLM engine."""
+        ...
+
+    @abstractmethod
+    async def get_model_config(self) -> ModelConfig:
+        """Get the model configuration of the vLLM engine."""
+        ...
+
+    @abstractmethod
+    async def get_decoding_config(self) -> DecodingConfig:
+        """Get the decoding configuration of the vLLM engine."""
+        ...
+
+    @abstractmethod
+    async def get_input_preprocessor(self) -> InputPreprocessor:
+        """Get the input preprocessor of the vLLM engine."""
+        ...
+
+    @abstractmethod
+    async def get_tokenizer(
+        self,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> AnyTokenizer:
+        """Get the appropriate tokenizer for the request"""
+        ...
+
+    @abstractmethod
+    async def is_tracing_enabled(self) -> bool:
+        ...
+
+    @abstractmethod
+    async def do_log_stats(
+        self,
+        scheduler_outputs: Optional[SchedulerOutputs] = None,
+        model_output: Optional[list[SamplerOutput]] = None,
+    ) -> None:
+        ...
+
+    @abstractmethod
+    async def check_health(self) -> None:
+        """Raise if unhealthy"""
+        ...
+
+    @abstractmethod
+    async def start_profile(self) -> None:
+        """Start profiling the engine"""
+        ...
+
+    @abstractmethod
+    async def stop_profile(self) -> None:
+        """Stop profiling the engine"""
+        ...
+
+    @abstractmethod
+    async def reset_mm_cache(self) -> None:
+        """Reset the multi-modal cache"""
+        ...
+
+    @abstractmethod
+    async def reset_prefix_cache(self,
+                                 device: Optional[Device] = None) -> None:
+        """Reset the prefix cache"""
+        ...
+
+    @abstractmethod
+    async def sleep(self, level: int = 1) -> None:
+        """Sleep the engine"""
+        ...
+
+    @abstractmethod
+    async def wake_up(self, tags: Optional[list[str]] = None) -> None:
+        """Wake up the engine"""
+        ...
+
+    @abstractmethod
+    async def is_sleeping(self) -> bool:
+        """Check whether the engine is sleeping"""
+        ...
+ + @abstractmethod + async def add_lora(self, lora_request: LoRARequest) -> None: + """Load a new LoRA adapter into the engine for future requests.""" + ... diff --git a/entrypoints/__init__.py b/entrypoints/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/entrypoints/api_server.py b/entrypoints/api_server.py new file mode 100644 index 0000000..3d1e5dc --- /dev/null +++ b/entrypoints/api_server.py @@ -0,0 +1,178 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +NOTE: This API server is used only for demonstrating usage of AsyncEngine +and simple performance benchmarks. It is not intended for production use. +For production use, we recommend using our OpenAI compatible server. +We are also not going to accept PRs modifying this file, please +change `vllm/entrypoints/openai/api_server.py` instead. +""" +import asyncio +import json +import ssl +from argparse import Namespace +from collections.abc import AsyncGenerator +from typing import Any, Optional + +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, Response, StreamingResponse + +import vllm.envs as envs +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.launcher import serve_http +from vllm.entrypoints.utils import with_cancellation +from vllm.logger import init_logger +from vllm.sampling_params import SamplingParams +from vllm.usage.usage_lib import UsageContext +from vllm.utils import FlexibleArgumentParser, random_uuid, set_ulimit +from vllm.version import __version__ as VLLM_VERSION + +logger = init_logger("vllm.entrypoints.api_server") + +app = FastAPI() +engine = None + + +@app.get("/health") +async def health() -> Response: + """Health check.""" + return Response(status_code=200) + + +@app.post("/generate") +async def generate(request: Request) -> Response: + """Generate completion for the request. + + The request should be a JSON object with the following fields: + - prompt: the prompt to use for the generation. + - stream: whether to stream the results or not. + - other fields: the sampling parameters (See `SamplingParams` for details). 
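+
+    Example request body (illustrative values):
+
+        {"prompt": "San Francisco is a", "max_tokens": 16, "stream": false}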
+ """ + request_dict = await request.json() + return await _generate(request_dict, raw_request=request) + + +@with_cancellation +async def _generate(request_dict: dict, raw_request: Request) -> Response: + prompt = request_dict.pop("prompt") + stream = request_dict.pop("stream", False) + sampling_params = SamplingParams(**request_dict) + request_id = random_uuid() + + assert engine is not None + results_generator = engine.generate(prompt, sampling_params, request_id) + + # Streaming case + async def stream_results() -> AsyncGenerator[bytes, None]: + async for request_output in results_generator: + prompt = request_output.prompt + assert prompt is not None + text_outputs = [ + prompt + output.text for output in request_output.outputs + ] + ret = {"text": text_outputs} + yield (json.dumps(ret) + "\n").encode("utf-8") + + if stream: + return StreamingResponse(stream_results()) + + # Non-streaming case + final_output = None + try: + async for request_output in results_generator: + final_output = request_output + except asyncio.CancelledError: + return Response(status_code=499) + + assert final_output is not None + prompt = final_output.prompt + assert prompt is not None + text_outputs = [prompt + output.text for output in final_output.outputs] + ret = {"text": text_outputs} + return JSONResponse(ret) + + +def build_app(args: Namespace) -> FastAPI: + global app + + app.root_path = args.root_path + return app + + +async def init_app( + args: Namespace, + llm_engine: Optional[AsyncLLMEngine] = None, +) -> FastAPI: + app = build_app(args) + + global engine + + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = (llm_engine + if llm_engine is not None else AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.API_SERVER)) + app.state.engine_client = engine + return app + + +async def run_server(args: Namespace, + llm_engine: Optional[AsyncLLMEngine] = None, + **uvicorn_kwargs: Any) -> None: + logger.info("vLLM API server version %s", VLLM_VERSION) + logger.info("args: %s", args) + + set_ulimit() + + app = await init_app(args, llm_engine) + assert engine is not None + + shutdown_task = await serve_http( + app, + sock=None, + enable_ssl_refresh=args.enable_ssl_refresh, + host=args.host, + port=args.port, + log_level=args.log_level, + timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + **uvicorn_kwargs, + ) + + await shutdown_task + + +if __name__ == "__main__": + parser = FlexibleArgumentParser() + parser.add_argument("--host", type=str, default=None) + parser.add_argument("--port", type=parser.check_port, default=8000) + parser.add_argument("--ssl-keyfile", type=str, default=None) + parser.add_argument("--ssl-certfile", type=str, default=None) + parser.add_argument("--ssl-ca-certs", + type=str, + default=None, + help="The CA certificates file") + parser.add_argument( + "--enable-ssl-refresh", + action="store_true", + default=False, + help="Refresh SSL Context when SSL certificate files change") + parser.add_argument( + "--ssl-cert-reqs", + type=int, + default=int(ssl.CERT_NONE), + help="Whether client certificate is required (see stdlib ssl module's)" + ) + parser.add_argument( + "--root-path", + type=str, + default=None, + help="FastAPI root_path when app is behind a path based routing proxy") + parser.add_argument("--log-level", type=str, default="debug") + parser = AsyncEngineArgs.add_cli_args(parser) + args = parser.parse_args() 
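+
+    # Example launch (model name is illustrative):
+    #   python -m vllm.entrypoints.api_server --model facebook/opt-125m --port 8000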
+ + asyncio.run(run_server(args)) diff --git a/entrypoints/chat_utils.py b/entrypoints/chat_utils.py new file mode 100644 index 0000000..95c806c --- /dev/null +++ b/entrypoints/chat_utils.py @@ -0,0 +1,1299 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import json +from abc import ABC, abstractmethod +from collections import defaultdict, deque +from collections.abc import Awaitable, Iterable +from functools import cache, lru_cache, partial +from pathlib import Path +from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union, + cast) + +import jinja2.nodes +import transformers.utils.chat_template_utils as hf_chat_utils +# yapf conflicts with isort for this block +# yapf: disable +from openai.types.chat import (ChatCompletionAssistantMessageParam, + ChatCompletionContentPartImageParam, + ChatCompletionContentPartInputAudioParam) +from openai.types.chat import ( + ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam) +from openai.types.chat import (ChatCompletionContentPartRefusalParam, + ChatCompletionContentPartTextParam) +from openai.types.chat import ( + ChatCompletionMessageParam as OpenAIChatCompletionMessageParam) +from openai.types.chat import (ChatCompletionMessageToolCallParam, + ChatCompletionToolMessageParam) +from openai.types.chat.chat_completion_content_part_input_audio_param import ( + InputAudio) +from pydantic import TypeAdapter +# yapf: enable +from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast, + ProcessorMixin) +# pydantic needs the TypedDict from typing_extensions +from typing_extensions import Required, TypeAlias, TypedDict + +from vllm.config import ModelConfig +from vllm.logger import init_logger +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict +from vllm.multimodal.utils import MediaConnector +# yapf: disable +from vllm.transformers_utils.chat_templates import ( + get_chat_template_fallback_path) +# yapf: enable +from vllm.transformers_utils.processor import cached_get_processor +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer +from vllm.utils import deprecate_kwargs, random_uuid + +logger = init_logger(__name__) + + +class AudioURL(TypedDict, total=False): + url: Required[str] + """ + Either a URL of the audio or a data URL with base64 encoded audio data. + """ + + +class ChatCompletionContentPartAudioParam(TypedDict, total=False): + audio_url: Required[AudioURL] + + type: Required[Literal["audio_url"]] + """The type of the content part.""" + + +class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False): + image_embeds: Required[Union[str, dict[str, str]]] + """ + The image embeddings. It can be either: + - A single base64 string. + - A dictionary where each value is a base64 string. + """ + type: Required[Literal["image_embeds"]] + """The type of the content part.""" + + +class VideoURL(TypedDict, total=False): + url: Required[str] + """ + Either a URL of the video or a data URL with base64 encoded video data. + """ + + +class ChatCompletionContentPartVideoParam(TypedDict, total=False): + video_url: Required[VideoURL] + + type: Required[Literal["video_url"]] + """The type of the content part.""" + + +class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False): + """A simpler version of the param that only accepts a plain image_url. + This is supported by OpenAI API, although it is not documented. 
+
+    Example:
+    {
+        "image_url": "https://example.com/image.jpg"
+    }
+    """
+    image_url: Required[str]
+
+
+class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
+    """A simpler version of the param that only accepts a plain audio_url.
+
+    Example:
+    {
+        "audio_url": "https://example.com/audio.mp3"
+    }
+    """
+    audio_url: Required[str]
+
+
+class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
+    """A simpler version of the param that only accepts a plain video_url.
+
+    Example:
+    {
+        "video_url": "https://example.com/video.mp4"
+    }
+    """
+    video_url: Required[str]
+
+
+ChatCompletionContentPartParam: TypeAlias = Union[
+    OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
+    ChatCompletionContentPartInputAudioParam,
+    ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam,
+    CustomChatCompletionContentSimpleImageParam,
+    ChatCompletionContentPartImageEmbedsParam,
+    CustomChatCompletionContentSimpleAudioParam,
+    CustomChatCompletionContentSimpleVideoParam, str]
+
+
+class CustomChatCompletionMessageParam(TypedDict, total=False):
+    """Enables custom roles in the Chat Completion API."""
+    role: Required[str]
+    """The role of the message's author."""
+
+    content: Union[str, list[ChatCompletionContentPartParam]]
+    """The contents of the message."""
+
+    name: str
+    """An optional name for the participant.
+
+    Provides the model information to differentiate between participants of the
+    same role.
+    """
+
+    tool_call_id: Optional[str]
+    """Tool call that this message is responding to."""
+
+    tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
+    """The tool calls generated by the model, such as function calls."""
+
+
+ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
+                                   CustomChatCompletionMessageParam]
+
+
+# TODO: Make fields ReadOnly once mypy supports it
+class ConversationMessage(TypedDict, total=False):
+    role: Required[str]
+    """The role of the message's author."""
+
+    content: Union[Optional[str], list[dict[str, str]]]
+    """The contents of the message"""
+
+    tool_call_id: Optional[str]
+    """Tool call that this message is responding to."""
+
+    name: Optional[str]
+    """The name of the function to call"""
+
+    tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
+    """The tool calls generated by the model, such as function calls."""
+
+
+# Passed in by user
+ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]
+
+# Used internally
+_ChatTemplateContentFormat = Literal["string", "openai"]
+
+
+def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
+    if isinstance(node, jinja2.nodes.Name):
+        return node.ctx == "load" and node.name == varname
+
+    return False
+
+
+def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
+    if isinstance(node, jinja2.nodes.Getitem):
+        return (_is_var_access(node.node, varname)
+                and isinstance(node.arg, jinja2.nodes.Const)
+                and node.arg.value == key)
+
+    if isinstance(node, jinja2.nodes.Getattr):
+        return _is_var_access(node.node, varname) and node.attr == key
+
+    return False
+
+
+def _is_var_or_elems_access(
+    node: jinja2.nodes.Node,
+    varname: str,
+    key: Optional[str] = None,
+) -> bool:
+    if isinstance(node, jinja2.nodes.Filter):
+        return (node.node is not None
+                and _is_var_or_elems_access(node.node, varname, key))
+    if isinstance(node, jinja2.nodes.Test):
+        return _is_var_or_elems_access(node.node, varname, key)
+
+    if (isinstance(node, jinja2.nodes.Getitem)
+            and 
isinstance(node.arg, jinja2.nodes.Slice)): + return _is_var_or_elems_access(node.node, varname, key) + + # yapf: disable + return ( + _is_attr_access(node, varname, key) if key + else _is_var_access(node, varname) + ) # yapf: enable + + +def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str): + # Global variable that is implicitly defined at the root + yield root, varname + + # Iterative BFS + related_varnames = deque([varname]) + while related_varnames: + related_varname = related_varnames.popleft() + + for assign_ast in root.find_all(jinja2.nodes.Assign): + lhs = assign_ast.target + rhs = assign_ast.node + + if _is_var_or_elems_access(rhs, related_varname): + assert isinstance(lhs, jinja2.nodes.Name) + yield assign_ast, lhs.name + + # Avoid infinite looping for self-assignment + if lhs.name != related_varname: + related_varnames.append(lhs.name) + + +# NOTE: The proper way to handle this is to build a CFG so that we can handle +# the scope in which each variable is defined, but that is too complicated +def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node): + messages_varnames = [ + varname + for _, varname in _iter_nodes_assign_var_or_elems(root, "messages") + ] + + # Search for {%- for message in messages -%} loops + for loop_ast in root.find_all(jinja2.nodes.For): + loop_iter = loop_ast.iter + loop_target = loop_ast.target + + for varname in messages_varnames: + if _is_var_or_elems_access(loop_iter, varname): + assert isinstance(loop_target, jinja2.nodes.Name) + yield loop_ast, loop_target.name + break + + +def _iter_nodes_assign_content_item(root: jinja2.nodes.Node): + message_varnames = [ + varname for _, varname in _iter_nodes_assign_messages_item(root) + ] + + # Search for {%- for content in message['content'] -%} loops + for loop_ast in root.find_all(jinja2.nodes.For): + loop_iter = loop_ast.iter + loop_target = loop_ast.target + + for varname in message_varnames: + if _is_var_or_elems_access(loop_iter, varname, "content"): + assert isinstance(loop_target, jinja2.nodes.Name) + yield loop_ast, loop_target.name + break + + +def _try_extract_ast(chat_template: str) -> Optional[jinja2.nodes.Template]: + try: + jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template) + return jinja_compiled.environment.parse(chat_template) + except Exception: + logger.exception("Error when compiling Jinja template") + return None + + +def _detect_content_format( + chat_template: str, + *, + default: _ChatTemplateContentFormat, +) -> _ChatTemplateContentFormat: + jinja_ast = _try_extract_ast(chat_template) + if jinja_ast is None: + return default + + try: + next(_iter_nodes_assign_content_item(jinja_ast)) + except StopIteration: + return "string" + except Exception: + logger.exception("Error when parsing AST of Jinja template") + return default + else: + return "openai" + + +def resolve_mistral_chat_template( + chat_template: Optional[str], + **kwargs: Any, +) -> Optional[str]: + if chat_template is not None: + logger.warning_once( + "'chat_template' cannot be overridden for mistral tokenizer.") + if "add_generation_prompt" in kwargs: + logger.warning_once( + "'add_generation_prompt' is not supported for mistral tokenizer, " + "so it will be ignored.") + if "continue_final_message" in kwargs: + logger.warning_once( + "'continue_final_message' is not supported for mistral tokenizer, " + "so it will be ignored.") + return None + +@deprecate_kwargs( + "trust_remote_code", + additional_message="Please use `model_config.trust_remote_code` instead.", +) +def 
resolve_hf_chat_template( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + chat_template: Optional[str], + tools: Optional[list[dict[str, Any]]], + *, + model_config: ModelConfig, + trust_remote_code: Optional[bool] = None, +) -> Optional[str]: + # 1st priority: The given chat template + if chat_template is not None: + return chat_template + + # 2nd priority: AutoProcessor chat template, unless tool calling is enabled + if tools is None: + try: + processor = cached_get_processor( + tokenizer.name_or_path, + processor_cls=(PreTrainedTokenizer, PreTrainedTokenizerFast, + ProcessorMixin), + trust_remote_code=model_config.trust_remote_code, + ) + if isinstance(processor, ProcessorMixin) and \ + hasattr(processor, 'chat_template') and \ + processor.chat_template is not None: + return processor.chat_template + except Exception: + logger.debug("Failed to load AutoProcessor chat template for %s", tokenizer.name_or_path, exc_info=True) # noqa: E501 + + # 3rd priority: AutoTokenizer chat template + try: + return tokenizer.get_chat_template(chat_template, tools=tools) + except Exception: + logger.debug("Failed to load AutoTokenizer chat template for %s", + tokenizer.name_or_path, exc_info=True) + + # 4th priority: Predefined fallbacks + path = get_chat_template_fallback_path( + model_type=model_config.hf_config.model_type, + tokenizer_name_or_path=model_config.tokenizer, + ) + if path is not None: + logger.info("Loading chat template fallback for %s as there isn't one " + "defined on HF Hub.", tokenizer.name_or_path) + chat_template = load_chat_template(path) + else: + logger.debug("There is no chat template fallback for %s", + tokenizer.name_or_path) + + return chat_template + + +def _resolve_chat_template_content_format( + chat_template: Optional[str], + tools: Optional[list[dict[str, Any]]], + tokenizer: AnyTokenizer, + *, + model_config: ModelConfig, +) -> _ChatTemplateContentFormat: + if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): + hf_chat_template = resolve_hf_chat_template( + tokenizer, + chat_template=chat_template, + tools=tools, + model_config=model_config, + ) + else: + hf_chat_template = None + + jinja_text = (hf_chat_template if isinstance(hf_chat_template, str) + else load_chat_template(chat_template, is_literal=True)) + + detected_format = ("string" if jinja_text is None else + _detect_content_format(jinja_text, default="string")) + + return detected_format + + +@lru_cache +def _log_chat_template_content_format( + chat_template: Optional[str], + given_format: ChatTemplateContentFormatOption, + detected_format: ChatTemplateContentFormatOption, +): + logger.info( + "Detected the chat template content format to be '%s'. " + "You can set `--chat-template-content-format` to override this.", + detected_format, + ) + + if given_format != "auto" and given_format != detected_format: + logger.warning( + "You specified `--chat-template-content-format %s` " + "which is different from the detected format '%s'. 
" + "If our automatic detection is incorrect, please consider " + "opening a GitHub issue so that we can improve it: " + "https://github.com/vllm-project/vllm/issues/new/choose", + given_format, + detected_format, + ) + + +@deprecate_kwargs( + "trust_remote_code", + additional_message="Please use `model_config.trust_remote_code` instead.", +) +def resolve_chat_template_content_format( + chat_template: Optional[str], + tools: Optional[list[dict[str, Any]]], + given_format: ChatTemplateContentFormatOption, + tokenizer: AnyTokenizer, + *, + model_config: ModelConfig, + trust_remote_code: Optional[bool] = None, +) -> _ChatTemplateContentFormat: + detected_format = _resolve_chat_template_content_format( + chat_template, + tools, + tokenizer, + model_config=model_config, + ) + + _log_chat_template_content_format( + chat_template, + given_format=given_format, + detected_format=detected_format, + ) + + return detected_format if given_format == "auto" else given_format + + + +ModalityStr = Literal["image", "audio", "video", "image_embeds"] +_T = TypeVar("_T") + + +class BaseMultiModalItemTracker(ABC, Generic[_T]): + """ + Tracks multi-modal items in a given request and ensures that the number + of multi-modal items in a given request does not exceed the configured + maximum per prompt. + """ + + def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer): + super().__init__() + + self._model_config = model_config + self._tokenizer = tokenizer + + self._items_by_modality = defaultdict[str, list[_T]](list) + + @property + def model_config(self) -> ModelConfig: + return self._model_config + + @property + def allowed_local_media_path(self): + return self._model_config.allowed_local_media_path + + @property + def mm_registry(self): + return MULTIMODAL_REGISTRY + + @staticmethod + @cache + def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str: + return tokenizer.decode(token_index) + + def _placeholder_str(self, modality: ModalityStr, + current_count: int) -> Optional[str]: + # TODO: Let user specify how to insert image tokens into prompt + # (similar to chat template) + hf_config = self._model_config.hf_config + model_type = hf_config.model_type + + if modality in ("image", "image_embeds"): + if model_type == "chatglm": + return "<|begin_of_image|><|endoftext|><|end_of_image|>" + if model_type in ("phi3_v", "phi4mm"): + return f"<|image_{current_count}|>" + if model_type in ("minicpmo", "minicpmv"): + return "(./)" + if model_type in ("blip-2", "florence2", "fuyu", "paligemma", + "pixtral", "mistral3"): + # These models do not use image tokens in the prompt + return None + if model_type == "qwen": + return f"Picture {current_count}: " + if model_type.startswith("llava"): + return self._cached_token_str(self._tokenizer, + hf_config.image_token_index) + + if model_type in ("aya_vision", "chameleon", "deepseek_vl_v2", + "internvl_chat", "ovis", "skywork_chat", + "NVLM_D", "h2ovl_chat", "idefics3", "smolvlm"): + return "" + if model_type in ("mllama", "llama4"): + return "<|image|>" + if model_type in ("qwen2_vl", "qwen2_5_vl"): + return "<|vision_start|><|image_pad|><|vision_end|>" + if model_type == "qwen2_5_omni": + return "<|vision_start|><|IMAGE|><|vision_end|>" + if model_type == "molmo": + return "" + if model_type == "aria": + return "<|fim_prefix|><|img|><|fim_suffix|>" + if model_type == "gemma3": + return "" + if model_type == "kimi_vl": + return "<|media_start|>image<|media_content|><|media_pad|><|media_end|>" # noqa: E501 + + raise TypeError(f"Unknown {modality} 
model type: {model_type}") + elif modality == "audio": + if model_type in ("ultravox", "granite_speech"): + return "<|audio|>" + if model_type == "phi4mm": + return f"<|audio_{current_count}|>" + if model_type in ("qwen2_audio", "qwen2_5_omni"): + return (f"Audio {current_count}: " + f"<|audio_bos|><|AUDIO|><|audio_eos|>") + if model_type == "minicpmo": + return "()" + raise TypeError(f"Unknown model type: {model_type}") + elif modality == "video": + if model_type == "internvl_chat": + return "