From fc48b79ae9f9a969e3902dd196d566a51ee62d77 Mon Sep 17 00:00:00 2001 From: fromck <74886593+fromck@users.noreply.github.com> Date: Wed, 11 Feb 2026 18:32:30 +0800 Subject: [PATCH] support glm4.7 mtp (#187) Signed-off-by: chengxiaokang Co-authored-by: chengxiaokang --- .../v1/attention/backends/kunlun_attn.py | 527 ++++++++++++------ 1 file changed, 350 insertions(+), 177 deletions(-) diff --git a/vllm_kunlun/v1/attention/backends/kunlun_attn.py b/vllm_kunlun/v1/attention/backends/kunlun_attn.py index 95ef1bb..edf3935 100644 --- a/vllm_kunlun/v1/attention/backends/kunlun_attn.py +++ b/vllm_kunlun/v1/attention/backends/kunlun_attn.py @@ -14,39 +14,53 @@ # limitations under the License. # This file is a part of the vllm-kunlun project. # -from vllm.config import VllmConfig, get_layers_from_vllm_config -import xtorch_ops from dataclasses import dataclass -from typing import Any, Dict, List, Optional, ClassVar, Tuple, Type, TYPE_CHECKING +from itertools import accumulate +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + List, + Optional, + Tuple, + Type, + TypeVar, +) -import torch import numpy as np -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionLayer, AttentionType) +import torch +import xtorch_ops +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionLayer, + AttentionMetadata, + AttentionType, +) +from vllm.config import VllmConfig +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + CommonAttentionMetadata, + split_decodes_and_prefills, +) + # from vllm.attention.backends.utils import CommonAttentionState # from vllm.attention.backends.utils import is_block_tables_empty, compute_slot_mapping_start_idx, compute_slot_mapping -from vllm_kunlun.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) -from vllm_kunlun.ops._kunlun_ops import KunlunOps +from vllm_kunlun.ops.paged_attn import PagedAttention, PagedAttentionMetadata -from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, - AttentionCGSupport, - split_decodes_and_prefills) -from vllm.forward_context import ForwardContext, get_forward_context -from itertools import accumulate -from vllm.utils import async_tensor_h2d, make_tensor_with_pad if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.gpu_input_batch import InputBatch - from vllm.v1.worker.gpu_model_runner import GPUModelRunner + +import inspect from vllm.v1.kv_cache_interface import AttentionSpec -from vllm.v1.worker.block_table import BlockTable -from vllm.config import VllmConfig, get_layers_from_vllm_config -import inspect class KunlunAttentionBackend(AttentionBackend): """KunlunAttentionBackend""" + # crucial to cuda graph accept_output_buffer = True @@ -81,12 +95,13 @@ class KunlunAttentionBackend(AttentionBackend): block_size: int, num_kv_heads: int, head_size: int, - cache_dtype_str: str = "auto" + cache_dtype_str: str = "auto", ) -> Tuple[int, ...]: """get_kv_cache_shape""" # return (2, num_blocks, block_size, num_kv_heads * head_size) - return PagedAttention.get_kv_cache_shape(num_blocks, block_size, - num_kv_heads, head_size) + return PagedAttention.get_kv_cache_shape( + num_blocks, block_size, num_kv_heads, head_size + ) @staticmethod def swap_blocks( @@ -104,13 +119,12 @@ class KunlunAttentionBackend(AttentionBackend): ) -> None: """copy_blocks""" raise NotImplementedError - + @dataclass class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata): 
"""KunlunMetadata""" - # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| # |- tokenA -|......................|-- newTokens ---| @@ -133,7 +147,7 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata): # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool - + slot_mapping: torch.Tensor block_tables: torch.Tensor @@ -203,11 +217,13 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata): use_cascade: Optional[bool] = False seq_lens_tensor_cpu: Optional[torch.Tensor] = None - + num_prefill_tokens: int = 0 num_decode_tokens: int = 0 num_prefills: int = 0 num_decodes: int = 0 + is_speculative: Optional[bool] = False + max_model_len: int = 0 def __post_init__(self): """__post_init__""" @@ -218,16 +234,20 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata): @property def is_all_encoder_attn_metadata_set(self): """is_all_encoder_attn_metadata_set""" - return ((self.encoder_seq_lens is not None) - and (self.encoder_seq_lens_tensor is not None) - and (self.max_encoder_seq_len is not None)) + return ( + (self.encoder_seq_lens is not None) + and (self.encoder_seq_lens_tensor is not None) + and (self.max_encoder_seq_len is not None) + ) @property def is_all_cross_attn_metadata_set(self): """is_all_cross_attn_metadata_set""" - return (self.is_all_encoder_attn_metadata_set - and (self.cross_slot_mapping is not None) - and (self.cross_block_tables is not None)) + return ( + self.is_all_encoder_attn_metadata_set + and (self.cross_slot_mapping is not None) + and (self.cross_block_tables is not None) + ) @property def prefill_metadata(self) -> Optional["KunlunMetadata"]: @@ -240,35 +260,60 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata): # metadata structure return self._cached_prefill_metadata - assert ((self.seq_lens_tensor is not None) - or (self.encoder_seq_lens_tensor is not None)) + assert (self.seq_lens_tensor is not None) or ( + self.encoder_seq_lens_tensor is not None + ) # Compute some attn_metadata fields which default to None - query_start_loc = (None if self.query_start_loc is None else - self.query_start_loc[-(self.num_prefills + 1):] - self.query_start_loc[-(self.num_prefills + 1)]) + query_start_loc = ( + None + if self.query_start_loc is None + else self.query_start_loc[-(self.num_prefills + 1) :] + - self.query_start_loc[-(self.num_prefills + 1)] + ) # flash attention needs both lod information on host and device - query_start_loc_host = (None if self.query_start_loc_host is None else - self.query_start_loc_host[-(self.num_prefills + 1):] - self.query_start_loc_host[-(self.num_prefills + 1)]) - + query_start_loc_host = ( + None + if self.query_start_loc_host is None + else self.query_start_loc_host[-(self.num_prefills + 1) :] + - self.query_start_loc_host[-(self.num_prefills + 1)] + ) + # TODO(chengruichang):how to support prefix cache kv_prefix_start_loc_host = None kv_prefix_start_loc = None - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping[-self.num_prefill_tokens:]) + slot_mapping = ( + None + if self.slot_mapping is None + else self.slot_mapping[-self.num_prefill_tokens :] + ) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[-self.num_prefills:]) - seq_lens = (None if self.seq_lens is None else self.seq_lens[-self.num_prefills:]) + seq_lens_tensor = ( + None + if self.seq_lens_tensor is None + else 
self.seq_lens_tensor[-self.num_prefills :] + ) + seq_lens = ( + None if self.seq_lens is None else self.seq_lens[-self.num_prefills :] + ) - context_lens_tensor = (None if self.context_lens_tensor is None else - self.context_lens_tensor[-self.num_prefills:]) - - block_tables = (None if self.block_tables is None else - self.block_tables[-self.num_prefills:]) - input_positions = (None if self.input_positions is None else - self.input_positions[-self.num_prefills:]) + context_lens_tensor = ( + None + if self.context_lens_tensor is None + else self.context_lens_tensor[-self.num_prefills :] + ) + + block_tables = ( + None + if self.block_tables is None + else self.block_tables[-self.num_prefills :] + ) + input_positions = ( + None + if self.input_positions is None + else self.input_positions[-self.num_prefills :] + ) - if self.kv_lod_cpu is None: kv_lod_cpu = None kv_lod_xpu = None @@ -280,19 +325,17 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata): base_xpu = self.kv_lod_xpu[start] kv_lod_xpu = self.kv_lod_xpu[start:] - base_xpu - # Construct & cache prefill-phase attention metadata structure self._cached_prefill_metadata = KunlunMetadata( num_actual_tokens=self.num_actual_tokens, - multi_modal_placeholder_index_maps=self. - multi_modal_placeholder_index_maps, + multi_modal_placeholder_index_maps=self.multi_modal_placeholder_index_maps, num_prefills=self.num_prefills, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, slot_mapping=slot_mapping, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, - seq_start_loc = None, + seq_start_loc=None, kv_lod_cpu=kv_lod_cpu, kv_lod_xpu=kv_lod_xpu, max_query_len=self.max_query_len, @@ -314,7 +357,9 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata): cross_slot_mapping=self.cross_slot_mapping, cross_block_tables=self.cross_block_tables, enable_kv_scales_calculation=False, - use_cascade=self.use_cascade) + use_cascade=self.use_cascade, + is_speculative=self.is_speculative, + ) return self._cached_prefill_metadata @property @@ -327,40 +372,47 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata): # Recover cached decode-phase attention # metadata structure return self._cached_decode_metadata - assert ((self.seq_lens_tensor is not None) - or (self.encoder_seq_lens_tensor is not None)) + assert (self.seq_lens_tensor is not None) or ( + self.encoder_seq_lens_tensor is not None + ) if self.num_prefills != 0: # Compute some attn_metadata fields which default to None - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping[:-self.num_prefill_tokens]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[:-self.num_prefills]) - seq_lens_tensor_cpu = (None if self.seq_lens_tensor_cpu is None else - self.seq_lens_tensor_cpu[:-self.num_prefills]) - - block_tables = (None if self.block_tables is None else - self.block_tables[:-self.num_prefills]) + slot_mapping = ( + None + if self.slot_mapping is None + else self.slot_mapping[: -self.num_prefill_tokens] + ) + seq_lens_tensor = ( + None + if self.seq_lens_tensor is None + else self.seq_lens_tensor[: -self.num_prefills] + ) + seq_lens_tensor_cpu = ( + None + if self.seq_lens_tensor_cpu is None + else self.seq_lens_tensor_cpu[: -self.num_prefills] + ) + block_tables = ( + None + if self.block_tables is None + else self.block_tables[: -self.num_prefills] + ) else: # Compute some attn_metadata fields which default to None - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping) - 
seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor) - - seq_lens_tensor_cpu = (None if self.seq_lens_tensor_cpu is None else - self.seq_lens_tensor_cpu) - - - block_tables = (None if self.block_tables is None else - self.block_tables) - + slot_mapping = None if self.slot_mapping is None else self.slot_mapping + seq_lens_tensor = ( + None if self.seq_lens_tensor is None else self.seq_lens_tensor + ) + seq_lens_tensor_cpu = ( + None if self.seq_lens_tensor_cpu is None else self.seq_lens_tensor_cpu + ) + block_tables = None if self.block_tables is None else self.block_tables # Construct & cache decode-phase attention metadata structure self._cached_decode_metadata = KunlunMetadata( num_actual_tokens=self.num_actual_tokens, - multi_modal_placeholder_index_maps=self. - multi_modal_placeholder_index_maps, + multi_modal_placeholder_index_maps=self.multi_modal_placeholder_index_maps, num_prefills=0, num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, @@ -378,19 +430,29 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata): cross_slot_mapping=self.cross_slot_mapping, cross_block_tables=self.cross_block_tables, enable_kv_scales_calculation=False, - use_cascade=self.use_cascade) + use_cascade=self.use_cascade, + is_speculative=self.is_speculative, + max_model_len=self.max_model_len, + ) return self._cached_decode_metadata +M = TypeVar("M") + class KunlunAttentionMetadataBuilder: """KunlunAttentionMetadataBuilder""" - cudagraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + + cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH reorder_batch_threshold: ClassVar[Optional[int]] = 1 - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): """__init__""" self.vllm_config = vllm_config self.model_config = vllm_config.model_config @@ -398,17 +460,45 @@ class KunlunAttentionMetadataBuilder: self.compilation_config = vllm_config.compilation_config self.num_heads_q = self.model_config.get_num_attention_heads( - self.parallel_config) - self.num_heads_kv = self.model_config.get_num_kv_heads( - self.parallel_config) + self.parallel_config + ) + self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config) self.headdim = self.model_config.get_head_size() self.block_size = kv_cache_spec.block_size self.kv_cache_spec = kv_cache_spec self.device = device - def reorder_batch(self, input_batch: "InputBatch", - scheduler_output: "SchedulerOutput") -> bool: + def _init_reorder_batch_threshold( + self, + reorder_batch_threshold: int | None = 1, + supports_spec_as_decode: bool = False, + supports_dcp_with_varlen: bool = False, + ) -> None: + self.reorder_batch_threshold = reorder_batch_threshold + if self.reorder_batch_threshold is not None and supports_spec_as_decode: + # If the backend supports spec-as-decode kernels, then we can set + # the reorder_batch_threshold based on the number of speculative + # tokens from the config. 
+ speculative_config = self.vllm_config.speculative_config + if ( + speculative_config is not None + and speculative_config.num_speculative_tokens is not None + ): + self.reorder_batch_threshold = max( + self.reorder_batch_threshold, + 1 + speculative_config.num_speculative_tokens, + ) + + if ( + self.vllm_config.parallel_config.decode_context_parallel_size > 1 + and not supports_dcp_with_varlen + ): + self.reorder_batch_threshold = 1 + + def reorder_batch( + self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput" + ) -> bool: """reorder_batch""" decodes = [] prefills = [] @@ -432,8 +522,9 @@ class KunlunAttentionMetadataBuilder: for i in range(1, min(num_decodes, num_prefills) + 1): if decodes[num_decodes - i] >= num_decodes: - input_batch.swap_states(prefills[first_prefill], - decodes[num_decodes - i]) + input_batch.swap_states( + prefills[first_prefill], decodes[num_decodes - i] + ) first_prefill += 1 modified_batch = True else: @@ -443,7 +534,7 @@ class KunlunAttentionMetadataBuilder: self._num_decode_tokens = num_decode_tokens self._num_prefill_tokens = num_prefill_tokens return modified_batch - + def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata ) -> KunlunMetadata: @@ -454,8 +545,30 @@ class KunlunAttentionMetadataBuilder: attn_metadata.seq_lens_tensor.fill_(1) return attn_metadata - def build(self, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata): + def build_for_drafting( + self, + common_attn_metadata: CommonAttentionMetadata, + draft_index: int, + ) -> M: + """ + Build attention metadata for draft model. Uses build by default. + + Args: + common_attn_metadata: The common attention metadata. + draft_index: The index of the current draft operation. + When speculating a chain of tokens, this index refers to the + draft attempt for the i-th token. + For tree-based attention, this index instead refers to the + draft attempt for the i-th level in the tree of tokens. 
+ """ + return self.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) + + def build( + self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata + ): """build""" num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens @@ -464,30 +577,38 @@ class KunlunAttentionMetadataBuilder: block_table_tensor = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping - max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) - query_start_loc_host = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1] - query_start_loc = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1].to( - self.device, non_blocking=True) - + query_start_loc_host = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1] + query_start_loc = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1].to( + self.device, non_blocking=True + ) + seq_lens = common_attn_metadata.seq_lens seq_lens_cpu = common_attn_metadata.seq_lens_cpu - + seq_start_loc = list(accumulate(seq_lens, initial=0)) - - - seq_start_loc_tensor = torch.empty(len(seq_start_loc), dtype=torch.int32, device=self.device) + seq_start_loc_tensor = torch.empty( + len(seq_start_loc), dtype=torch.int32, device=self.device + ) seq_start_loc_tensor.copy_(torch.as_tensor(seq_start_loc, dtype=torch.int32)) kv_lod_cpu = torch.zeros(num_reqs + 1, dtype=torch.int32, device="cpu") kv_lod_cpu[1:] = seq_lens_cpu.to(torch.int32).cumsum(dim=0) kv_lod_xpu = kv_lod_cpu.to(self.device) - - num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\ - split_decodes_and_prefills(common_attn_metadata) - num_scheduled_tokens = np.diff(common_attn_metadata.query_start_loc_cpu[:num_reqs + 1]) + self._init_reorder_batch_threshold(1, supports_spec_as_decode=True) + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold or 1, + require_uniform=True, + ) + ) + + num_scheduled_tokens = np.diff( + common_attn_metadata.query_start_loc_cpu[: num_reqs + 1] + ) tmp_decode_scheduled_tokens = num_scheduled_tokens[:num_decodes] if num_decode_tokens == 0: @@ -495,18 +616,19 @@ class KunlunAttentionMetadataBuilder: else: max_decode_seq_len = np.max(tmp_decode_scheduled_tokens) - tmp_prefill_scheduled_tokens = num_scheduled_tokens[num_decodes: num_reqs] - + tmp_prefill_scheduled_tokens = num_scheduled_tokens[num_decodes:num_reqs] + if num_prefill_tokens == 0: max_prefill_seq_len = 0 else: max_prefill_seq_len = np.max(tmp_prefill_scheduled_tokens) - + use_cascade = False attn_metadata = KunlunMetadata( num_actual_tokens=num_actual_tokens, num_prefills=num_prefills, + num_decodes=num_decodes, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, enable_kv_scales_calculation=True, @@ -525,11 +647,14 @@ class KunlunAttentionMetadataBuilder: block_tables=block_table_tensor, use_cuda_graph=False, use_cascade=use_cascade, + is_speculative=self.reorder_batch_threshold > 1, + max_model_len=self.vllm_config.model_config.max_model_len, ) return attn_metadata def can_run_in_cudagraph( - self, common_attn_metadata: CommonAttentionMetadata) -> bool: + self, common_attn_metadata: CommonAttentionMetadata + ) -> bool: """can_run_in_cudagraph""" # Full CUDA Graph always supported (FA2 support checked separately) return True @@ -538,6 +663,7 @@ class KunlunAttentionMetadataBuilder: """use_cascade_attention""" return use_cascade_attention(*args, **kwargs) + class 
KunlunAttentionImpl(AttentionImpl[KunlunMetadata]): """KunlunAttentionImpl""" @@ -555,13 +681,12 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]): kv_sharing_target_layer_name: Optional[str] = None, attn_type: AttentionType = AttentionType.DECODER, use_irope: bool = False, - sinks:Optional[torch.Tensor]= None, - multi_modal_placeholder_index_maps:Optional[torch.Tensor]= None, + sinks: Optional[torch.Tensor] = None, + multi_modal_placeholder_index_maps: Optional[torch.Tensor] = None, ) -> None: """__init__""" if blocksparse_params is not None: - raise ValueError( - "kunlunAttention does not support block-sparse attention.") + raise ValueError("kunlunAttention does not support block-sparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) @@ -582,15 +707,17 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]): if head_size not in suppored_head_sizes: raise ValueError( f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {suppored_head_sizes}.") + f"Supported head sizes are: {suppored_head_sizes}." + ) self.sinks = sinks if sinks is not None: assert sinks.shape[0] == num_heads, ( "Sinks must have the same number of heads as the number of " f"heads in the layer. Sinks shape: {sinks.shape}, " - f"num_heads: {num_heads}.") - self.multi_modal_placeholder_index_maps = multi_modal_placeholder_index_maps + f"num_heads: {num_heads}." + ) + self.multi_modal_placeholder_index_maps = multi_modal_placeholder_index_maps def forward( self, @@ -605,7 +732,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]): attn_type: AttentionType = AttentionType.DECODER, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """forward""" query = query.view(-1, self.num_heads, self.head_size) @@ -624,7 +751,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]): # Self-attention vs. cross-attention will impact # which KV cache memory-mapping & which # seqlen datastructures we utilize - if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0): + if attn_type != AttentionType.ENCODER and kv_cache.numel() > 0: # KV-cache during decoder-self- or # encoder-decoder-cross-attention, but not # during encoder attention. @@ -633,7 +760,8 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]): # we still need to break out key_cache and value_cache # i.e. for later use by paged attention key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) + kv_cache, self.num_kv_heads, self.head_size + ) if (key is not None) and (value is not None): updated_slot_mapping = attn_metadata.slot_mapping @@ -644,11 +772,12 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]): value = value.contiguous() if key_cache.is_contiguous(): xtorch_ops.reshape_and_cache( - key, - value, + key[: attn_metadata.num_actual_tokens], + value[: attn_metadata.num_actual_tokens], key_cache, value_cache, - updated_slot_mapping) + updated_slot_mapping, + ) else: cast_key_cache = key_cache.squeeze(1).unsqueeze(-2) cast_value_cache = value_cache.squeeze(1).unsqueeze(-2) @@ -657,7 +786,8 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]): value, cast_key_cache, cast_value_cache, - updated_slot_mapping) + updated_slot_mapping, + ) assert attn_type == AttentionType.DECODER # Decoder self-attention supports chunked prefill. 
@@ -668,88 +798,98 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]): if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - prefill_query = query[num_decode_tokens:attn_metadata.num_actual_tokens] - prefill_key = key[num_decode_tokens:attn_metadata.num_actual_tokens] - prefill_value = value[num_decode_tokens:attn_metadata.num_actual_tokens] + prefill_query = query[num_decode_tokens : attn_metadata.num_actual_tokens] + prefill_key = key[num_decode_tokens : attn_metadata.num_actual_tokens] + prefill_value = value[num_decode_tokens : attn_metadata.num_actual_tokens] # For hybrid Attention (Qwen3-Next.) if key_cache.is_contiguous(): tmp_block_tables = prefill_meta.block_tables else: # For hybrid Attention (Qwen3-Next) - tmp_block_tables = prefill_meta.block_tables * 2 - + tmp_block_tables = prefill_meta.block_tables * 2 + # Prefix cache if prefill_meta.query_start_loc_host[-1] != prefill_meta.kv_lod_cpu[-1]: xtorch_ops.prefill_attention( q=prefill_query, - k=key_cache, # Key Cache [block_num, head, block_size, dim] + k=key_cache, # Key Cache [block_num, head, block_size, dim] v=value_cache, - out=output[num_decode_tokens:attn_metadata.num_actual_tokens], + out=output[num_decode_tokens : attn_metadata.num_actual_tokens], is_causal=True, - is_prefix_cache=True, - block_table=tmp_block_tables, + is_prefix_cache=True, + block_table=tmp_block_tables, context_qlen_lod_cpu=prefill_meta.query_start_loc_host, context_qlen_lod_xpu=prefill_meta.query_start_loc, context_kvlen_lod_cpu=prefill_meta.kv_lod_cpu, context_kvlen_lod_xpu=prefill_meta.kv_lod_xpu, alibi_slopes=self.alibi_slopes, - softmax_lse=None + softmax_lse=None, ) else: xtorch_ops.prefill_attention( q=prefill_query, k=prefill_key, v=prefill_value, - out=output[num_decode_tokens:attn_metadata.num_actual_tokens], + out=output[num_decode_tokens : attn_metadata.num_actual_tokens], is_causal=True, context_qlen_lod_cpu=prefill_meta.query_start_loc_host, context_qlen_lod_xpu=prefill_meta.query_start_loc, alibi_slopes=self.alibi_slopes, - softmax_lse=None, - swa_left = self.sliding_window if self.sliding_window is not None else -1, - swa_right = 0 if self.sliding_window is not None else -1, - sink = self.sinks.to(torch.float32) if self.sinks is not None else None + softmax_lse=None, + swa_left=( + self.sliding_window if self.sliding_window is not None else -1 + ), + swa_right=0 if self.sliding_window is not None else -1, + sink=( + self.sinks.to(torch.float32) if self.sinks is not None else None + ), ) - - if decode_meta := attn_metadata.decode_metadata: - assert attn_type != AttentionType.ENCODER_ONLY, ( - "Encoder-only models should not have decode metadata.") + if decode_meta := attn_metadata.decode_metadata: + assert ( + attn_type != AttentionType.ENCODER_ONLY + ), "Encoder-only models should not have decode metadata." 
decode_query = query[:num_decode_tokens] # For hybrid Attention (Qwen3-Next if key_cache.is_contiguous(): tmp_block_tables = decode_meta.block_tables else: - tmp_block_tables = decode_meta.block_tables * 2 # only test in Qwen3-Next - + tmp_block_tables = ( + decode_meta.block_tables * 2 + ) # only test in Qwen3-Next + sig = inspect.signature(xtorch_ops.speculative_attention) if "max_window_size" in sig.parameters: xtorch_ops.speculative_attention( out=output[:num_decode_tokens], - # Only MLA support q len > 1 right now - q=decode_query.unsqueeze(0), - k_cache=key_cache, - v_cache=value_cache, - context_lens_cpu=decode_meta.seq_lens_tensor_cpu, - context_lens_xpu=decode_meta.seq_lens_tensor, - batch_num=decode_meta.block_tables.shape[0], - # TODO (@xyDong23): Support MTP(q lens >1) - qlen=1, - # TODO (@xyDong23): Support max_context_len to (262144) - max_context_len=131072, - head_num=self.num_heads, - head_dim=self.head_size, - scale=0.0, - kv_head_num=self.num_kv_heads, - block_size=key_cache.shape[2], - max_num_blocks_per_seq=decode_meta.block_tables.shape[1], - max_window_size=self.sliding_window if self.sliding_window is not None else -1, - block_tables=tmp_block_tables, - sink = self.sinks.to(torch.float32) if self.sinks is not None else None + # Only MLA support q len > 1 right now + q=decode_query.unsqueeze(0), + k_cache=key_cache, + v_cache=value_cache, + context_lens_cpu=decode_meta.seq_lens_tensor_cpu, + context_lens_xpu=decode_meta.seq_lens_tensor, + batch_num=decode_meta.block_tables.shape[0], + # TODO (@xyDong23): Support MTP(q lens >1) + qlen=1, + # TODO (@xyDong23): Support max_context_len to (262144) + max_context_len=131072, + head_num=self.num_heads, + head_dim=self.head_size, + scale=0.0, + kv_head_num=self.num_kv_heads, + block_size=key_cache.shape[2], + max_num_blocks_per_seq=decode_meta.block_tables.shape[1], + max_window_size=( + self.sliding_window if self.sliding_window is not None else -1 + ), + block_tables=tmp_block_tables, + sink=( + self.sinks.to(torch.float32) if self.sinks is not None else None + ), ) - else: + elif not attn_metadata.is_speculative: xtorch_ops.paged_attention( x=decode_query, k_cache=key_cache, @@ -760,10 +900,38 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]): is_context=False, is_causal=True, out=output[:num_decode_tokens], - vo_head_dim=self.head_size - ) + vo_head_dim=self.head_size, + ) + else: + batch_size = attn_metadata.num_decodes + query_seq_len, head_num, head_dim = decode_query.shape + assert query_seq_len % batch_size == 0 + qlen = query_seq_len // batch_size + out = output[:num_decode_tokens] + assert out.is_contiguous() + + xtorch_ops.speculative_attention( + out=out.view(batch_size, qlen, head_num, self.head_size), + q=decode_query.view(batch_size, qlen, head_num, head_dim), + k_cache=key_cache, + v_cache=value_cache, + context_lens_cpu=decode_meta.seq_lens_tensor_cpu, + context_lens_xpu=decode_meta.seq_lens_tensor, + batch_num=batch_size, + qlen=qlen, + max_context_len=decode_meta.max_model_len, + head_num=self.num_heads, + head_dim=self.head_size, + scale=0.0, + kv_head_num=self.num_kv_heads, + block_size=key_cache.shape[2], + max_num_blocks_per_seq=decode_meta.block_tables.shape[1], + block_tables=tmp_block_tables, + ) # Reshape the output tensor. return output.view(-1, self.num_heads * self.head_size) + + def use_cascade_attention( common_prefix_len: int, query_lens: np.ndarray, @@ -785,7 +953,7 @@ def use_cascade_attention( # NOTE(woosuk): This is the common case. 
We should return False as soon as # possible to avoid any unnecessary computation. return False - + if common_prefix_len < 256: return False # Cascade attention is currently not supported with these variants. @@ -803,8 +971,12 @@ def use_cascade_attention( num_queries_per_kv = num_query_heads // num_kv_heads # The criteria for using FlashDecoding can be found in the following link: # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535 - use_flash_decoding = (num_queries_per_kv > 1 and not use_sliding_window - and not use_alibi and np.all(query_lens == 1)) + use_flash_decoding = ( + num_queries_per_kv > 1 + and not use_sliding_window + and not use_alibi + and np.all(query_lens == 1) + ) if not use_flash_decoding: # Use cascade attention. return True @@ -826,10 +998,11 @@ def use_cascade_attention( cascade_waves = cdiv(cascade_ctas, num_sms) cascade_time = cascade_waves * num_prefix_tiles - flash_decoding_ctas = (num_reqs * num_kv_heads * - cdiv(num_queries_per_kv, q_tile_size)) + flash_decoding_ctas = ( + num_reqs * num_kv_heads * cdiv(num_queries_per_kv, q_tile_size) + ) flash_decoding_ctas *= num_prefix_tiles flash_decoding_time = cdiv(flash_decoding_ctas, num_sms) # Use cascade attention if it is faster than FlashDecoding. - return cascade_time < flash_decoding_time \ No newline at end of file + return cascade_time < flash_decoding_time
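
Note on the builder change above: the patch widens the decode bucket so that MTP requests, which carry 1 + num_speculative_tokens query tokens per step, are still classified as decodes by split_decodes_and_prefills. Below is a minimal, self-contained sketch of that threshold derivation, not part of the applied diff; SpecConfig is a stand-in for vllm's speculative_config, and only its num_speculative_tokens field is assumed.

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class SpecConfig:
        # Stand-in for vllm's speculative_config; only this field is used here.
        num_speculative_tokens: Optional[int] = None


    def derive_reorder_batch_threshold(
        base_threshold: int,
        speculative_config: Optional[SpecConfig],
        supports_spec_as_decode: bool,
    ) -> int:
        """Mirror the intent of _init_reorder_batch_threshold: widen the decode
        bucket so requests carrying 1 + num_speculative_tokens query tokens are
        still treated as decodes when splitting decodes and prefills."""
        threshold = base_threshold
        if supports_spec_as_decode and speculative_config is not None:
            if speculative_config.num_speculative_tokens is not None:
                threshold = max(
                    threshold, 1 + speculative_config.num_speculative_tokens
                )
        return threshold


    if __name__ == "__main__":
        # MTP with one draft token: each decode step may carry 2 query tokens.
        print(derive_reorder_batch_threshold(1, SpecConfig(1), True))   # -> 2
        print(derive_reorder_batch_threshold(1, None, True))            # -> 1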
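
On the decode side, the new speculative path views the flattened query and output as [batch, qlen, heads, head_dim] before handing them to xtorch_ops.speculative_attention. The sketch below reproduces only that reshape bookkeeping with plain torch and omits the kernel call itself; the uniform-qlen assumption follows from the UNIFORM_BATCH cudagraph support declared by the builder.

    import torch


    def reshape_for_speculative_decode(
        decode_query: torch.Tensor,  # [num_decode_tokens, num_heads, head_dim]
        output: torch.Tensor,        # [num_decode_tokens, num_heads, head_dim]
        num_decodes: int,
    ):
        """Group per-token rows into per-request rows of uniform qlen, as the
        speculative decode branch in KunlunAttentionImpl.forward does before
        calling the attention kernel."""
        num_decode_tokens, num_heads, head_dim = decode_query.shape
        assert num_decode_tokens % num_decodes == 0, "uniform qlen expected"
        qlen = num_decode_tokens // num_decodes
        q = decode_query.view(num_decodes, qlen, num_heads, head_dim)
        out = output.view(num_decodes, qlen, num_heads, head_dim)
        return q, out, qlen


    if __name__ == "__main__":
        bs, qlen, heads, dim = 4, 2, 8, 128  # e.g. one extra MTP draft token
        dq = torch.randn(bs * qlen, heads, dim)
        out = torch.empty_like(dq)
        q, o, k = reshape_for_speculative_decode(dq, out, bs)
        print(q.shape, o.shape, k)  # torch.Size([4, 2, 8, 128]) ... 2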