support glm4.7 mtp (#187)

Signed-off-by: chengxiaokang <chengxiaokang@baidu.com>
Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
fromck
2026-02-11 18:32:30 +08:00
committed by GitHub
parent bd8c999335
commit fc48b79ae9


@@ -14,39 +14,53 @@
# limitations under the License.
# This file is a part of the vllm-kunlun project.
#
from vllm.config import VllmConfig, get_layers_from_vllm_config
import xtorch_ops
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, ClassVar, Tuple, Type, TYPE_CHECKING
from itertools import accumulate
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
List,
Optional,
Tuple,
Type,
TypeVar,
)
import torch
import numpy as np
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionLayer, AttentionType)
import torch
import xtorch_ops
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionLayer,
AttentionMetadata,
AttentionType,
)
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
CommonAttentionMetadata,
split_decodes_and_prefills,
)
# from vllm.attention.backends.utils import CommonAttentionState
# from vllm.attention.backends.utils import is_block_tables_empty, compute_slot_mapping_start_idx, compute_slot_mapping
from vllm_kunlun.ops.paged_attn import (PagedAttention, PagedAttentionMetadata)
from vllm_kunlun.ops._kunlun_ops import KunlunOps
from vllm_kunlun.ops.paged_attn import PagedAttention, PagedAttentionMetadata
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
AttentionCGSupport,
split_decodes_and_prefills)
from vllm.forward_context import ForwardContext, get_forward_context
from itertools import accumulate
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
import inspect
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
from vllm.config import VllmConfig, get_layers_from_vllm_config
import inspect
class KunlunAttentionBackend(AttentionBackend):
"""KunlunAttentionBackend"""
# crucial for CUDA graphs
accept_output_buffer = True
@@ -81,12 +95,13 @@ class KunlunAttentionBackend(AttentionBackend):
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto"
cache_dtype_str: str = "auto",
) -> Tuple[int, ...]:
"""get_kv_cache_shape"""
# return (2, num_blocks, block_size, num_kv_heads * head_size)
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)
return PagedAttention.get_kv_cache_shape(
num_blocks, block_size, num_kv_heads, head_size
)
@staticmethod
def swap_blocks(
@@ -104,13 +119,12 @@ class KunlunAttentionBackend(AttentionBackend):
) -> None:
"""copy_blocks"""
raise NotImplementedError
@dataclass
class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
"""KunlunMetadata"""
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
@@ -133,7 +147,7 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
slot_mapping: torch.Tensor
block_tables: torch.Tensor
@@ -203,11 +217,13 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
use_cascade: Optional[bool] = False
seq_lens_tensor_cpu: Optional[torch.Tensor] = None
num_prefill_tokens: int = 0
num_decode_tokens: int = 0
num_prefills: int = 0
num_decodes: int = 0
is_speculative: Optional[bool] = False
max_model_len: int = 0
def __post_init__(self):
"""__post_init__"""
@@ -218,16 +234,20 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
@property
def is_all_encoder_attn_metadata_set(self):
"""is_all_encoder_attn_metadata_set"""
return ((self.encoder_seq_lens is not None)
and (self.encoder_seq_lens_tensor is not None)
and (self.max_encoder_seq_len is not None))
return (
(self.encoder_seq_lens is not None)
and (self.encoder_seq_lens_tensor is not None)
and (self.max_encoder_seq_len is not None)
)
@property
def is_all_cross_attn_metadata_set(self):
"""is_all_cross_attn_metadata_set"""
return (self.is_all_encoder_attn_metadata_set
and (self.cross_slot_mapping is not None)
and (self.cross_block_tables is not None))
return (
self.is_all_encoder_attn_metadata_set
and (self.cross_slot_mapping is not None)
and (self.cross_block_tables is not None)
)
@property
def prefill_metadata(self) -> Optional["KunlunMetadata"]:
@@ -240,35 +260,60 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
# metadata structure
return self._cached_prefill_metadata
assert ((self.seq_lens_tensor is not None)
or (self.encoder_seq_lens_tensor is not None))
assert (self.seq_lens_tensor is not None) or (
self.encoder_seq_lens_tensor is not None
)
# Compute some attn_metadata fields which default to None
query_start_loc = (None if self.query_start_loc is None else
self.query_start_loc[-(self.num_prefills + 1):] - self.query_start_loc[-(self.num_prefills + 1)])
query_start_loc = (
None
if self.query_start_loc is None
else self.query_start_loc[-(self.num_prefills + 1) :]
- self.query_start_loc[-(self.num_prefills + 1)]
)
# flash attention needs lod information on both host and device
query_start_loc_host = (None if self.query_start_loc_host is None else
self.query_start_loc_host[-(self.num_prefills + 1):] - self.query_start_loc_host[-(self.num_prefills + 1)])
query_start_loc_host = (
None
if self.query_start_loc_host is None
else self.query_start_loc_host[-(self.num_prefills + 1) :]
- self.query_start_loc_host[-(self.num_prefills + 1)]
)
# TODO(chengruichang): how to support prefix cache
kv_prefix_start_loc_host = None
kv_prefix_start_loc = None
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[-self.num_prefill_tokens:])
slot_mapping = (
None
if self.slot_mapping is None
else self.slot_mapping[-self.num_prefill_tokens :]
)
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[-self.num_prefills:])
seq_lens = (None if self.seq_lens is None else self.seq_lens[-self.num_prefills:])
seq_lens_tensor = (
None
if self.seq_lens_tensor is None
else self.seq_lens_tensor[-self.num_prefills :]
)
seq_lens = (
None if self.seq_lens is None else self.seq_lens[-self.num_prefills :]
)
context_lens_tensor = (None if self.context_lens_tensor is None else
self.context_lens_tensor[-self.num_prefills:])
block_tables = (None if self.block_tables is None else
self.block_tables[-self.num_prefills:])
input_positions = (None if self.input_positions is None else
self.input_positions[-self.num_prefills:])
context_lens_tensor = (
None
if self.context_lens_tensor is None
else self.context_lens_tensor[-self.num_prefills :]
)
block_tables = (
None
if self.block_tables is None
else self.block_tables[-self.num_prefills :]
)
input_positions = (
None
if self.input_positions is None
else self.input_positions[-self.num_prefills :]
)
if self.kv_lod_cpu is None:
kv_lod_cpu = None
kv_lod_xpu = None
@@ -280,19 +325,17 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
base_xpu = self.kv_lod_xpu[start]
kv_lod_xpu = self.kv_lod_xpu[start:] - base_xpu
# Construct & cache prefill-phase attention metadata structure
self._cached_prefill_metadata = KunlunMetadata(
num_actual_tokens=self.num_actual_tokens,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
multi_modal_placeholder_index_maps=self.multi_modal_placeholder_index_maps,
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=slot_mapping,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
seq_start_loc = None,
seq_start_loc=None,
kv_lod_cpu=kv_lod_cpu,
kv_lod_xpu=kv_lod_xpu,
max_query_len=self.max_query_len,
@@ -314,7 +357,9 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables,
enable_kv_scales_calculation=False,
use_cascade=self.use_cascade)
use_cascade=self.use_cascade,
is_speculative=self.is_speculative,
)
return self._cached_prefill_metadata
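
The tail slicing in prefill_metadata rebases the prefill portion of each cumulative-length tensor to start at zero. A minimal worked sketch of that arithmetic, with hypothetical request lengths:

import torch

# Hypothetical batch after reorder_batch: decodes first, then prefills.
# Two decode requests (1 token each), two prefill requests (3 and 4 tokens).
query_start_loc = torch.tensor([0, 1, 2, 5, 9])
num_prefills = 2

# Same tail slice + rebase as prefill_metadata above:
prefill_loc = (
    query_start_loc[-(num_prefills + 1):]
    - query_start_loc[-(num_prefills + 1)]
)
print(prefill_loc)  # tensor([0, 3, 7])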
@property
@@ -327,40 +372,47 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
# Recover cached decode-phase attention
# metadata structure
return self._cached_decode_metadata
assert ((self.seq_lens_tensor is not None)
or (self.encoder_seq_lens_tensor is not None))
assert (self.seq_lens_tensor is not None) or (
self.encoder_seq_lens_tensor is not None
)
if self.num_prefills != 0:
# Compute some attn_metadata fields which default to None
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[:-self.num_prefill_tokens])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[:-self.num_prefills])
seq_lens_tensor_cpu = (None if self.seq_lens_tensor_cpu is None else
self.seq_lens_tensor_cpu[:-self.num_prefills])
block_tables = (None if self.block_tables is None else
self.block_tables[:-self.num_prefills])
slot_mapping = (
None
if self.slot_mapping is None
else self.slot_mapping[: -self.num_prefill_tokens]
)
seq_lens_tensor = (
None
if self.seq_lens_tensor is None
else self.seq_lens_tensor[: -self.num_prefills]
)
seq_lens_tensor_cpu = (
None
if self.seq_lens_tensor_cpu is None
else self.seq_lens_tensor_cpu[: -self.num_prefills]
)
block_tables = (
None
if self.block_tables is None
else self.block_tables[: -self.num_prefills]
)
else:
# Compute some attn_metadata fields which default to None
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping)
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor)
seq_lens_tensor_cpu = (None if self.seq_lens_tensor_cpu is None else
self.seq_lens_tensor_cpu)
block_tables = (None if self.block_tables is None else
self.block_tables)
slot_mapping = None if self.slot_mapping is None else self.slot_mapping
seq_lens_tensor = (
None if self.seq_lens_tensor is None else self.seq_lens_tensor
)
seq_lens_tensor_cpu = (
None if self.seq_lens_tensor_cpu is None else self.seq_lens_tensor_cpu
)
block_tables = None if self.block_tables is None else self.block_tables
# Construct & cache decode-phase attention metadata structure
self._cached_decode_metadata = KunlunMetadata(
num_actual_tokens=self.num_actual_tokens,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
multi_modal_placeholder_index_maps=self.multi_modal_placeholder_index_maps,
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
@@ -378,19 +430,29 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables,
enable_kv_scales_calculation=False,
use_cascade=self.use_cascade)
use_cascade=self.use_cascade,
is_speculative=self.is_speculative,
max_model_len=self.max_model_len,
)
return self._cached_decode_metadata
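
decode_metadata takes the complementary slices: decode requests sit at the front of the reordered batch, so it drops the prefill tail rather than the decode head. Continuing the hypothetical layout from the prefill sketch:

import torch

# 2 decode tokens followed by 7 prefill tokens.
num_prefill_tokens = 7
slot_mapping = torch.arange(9)

decode_slots = slot_mapping[:-num_prefill_tokens]
print(decode_slots)  # tensor([0, 1])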
M = TypeVar("M")
class KunlunAttentionMetadataBuilder:
"""KunlunAttentionMetadataBuilder"""
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
reorder_batch_threshold: ClassVar[Optional[int]] = 1
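
The cudagraph_support bump from UNIFORM_SINGLE_TOKEN_DECODE to UNIFORM_BATCH is what allows MTP decode steps with q len > 1 to be captured. A hedged sketch of the invariant that mode assumes (values hypothetical):

# Every decode request in a captured batch carries the same query length,
# e.g. 1 target token + 2 speculative tokens per request.
query_lens = [3, 3, 3, 3]
assert len(set(query_lens)) == 1  # uniform batch: a single shape to capture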
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
"""__init__"""
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
@@ -398,17 +460,45 @@ class KunlunAttentionMetadataBuilder:
self.compilation_config = vllm_config.compilation_config
self.num_heads_q = self.model_config.get_num_attention_heads(
self.parallel_config)
self.num_heads_kv = self.model_config.get_num_kv_heads(
self.parallel_config)
self.parallel_config
)
self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config)
self.headdim = self.model_config.get_head_size()
self.block_size = kv_cache_spec.block_size
self.kv_cache_spec = kv_cache_spec
self.device = device
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
def _init_reorder_batch_threshold(
self,
reorder_batch_threshold: int | None = 1,
supports_spec_as_decode: bool = False,
supports_dcp_with_varlen: bool = False,
) -> None:
self.reorder_batch_threshold = reorder_batch_threshold
if self.reorder_batch_threshold is not None and supports_spec_as_decode:
# If the backend supports spec-as-decode kernels, then we can set
# the reorder_batch_threshold based on the number of speculative
# tokens from the config.
speculative_config = self.vllm_config.speculative_config
if (
speculative_config is not None
and speculative_config.num_speculative_tokens is not None
):
self.reorder_batch_threshold = max(
self.reorder_batch_threshold,
1 + speculative_config.num_speculative_tokens,
)
if (
self.vllm_config.parallel_config.decode_context_parallel_size > 1
and not supports_dcp_with_varlen
):
self.reorder_batch_threshold = 1
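
A sketch of the value _init_reorder_batch_threshold settles on when speculative decoding is configured (num_speculative_tokens is a hypothetical config value):

# With num_speculative_tokens = 2, requests carrying up to
# 1 + 2 = 3 query tokens are still routed down the decode path.
num_speculative_tokens = 2
reorder_batch_threshold = max(1, 1 + num_speculative_tokens)
print(reorder_batch_threshold)  # 3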
def reorder_batch(
self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
) -> bool:
"""reorder_batch"""
decodes = []
prefills = []
@@ -432,8 +522,9 @@ class KunlunAttentionMetadataBuilder:
for i in range(1, min(num_decodes, num_prefills) + 1):
if decodes[num_decodes - i] >= num_decodes:
input_batch.swap_states(prefills[first_prefill],
decodes[num_decodes - i])
input_batch.swap_states(
prefills[first_prefill], decodes[num_decodes - i]
)
first_prefill += 1
modified_batch = True
else:
@@ -443,7 +534,7 @@ class KunlunAttentionMetadataBuilder:
self._num_decode_tokens = num_decode_tokens
self._num_prefill_tokens = num_prefill_tokens
return modified_batch
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
) -> KunlunMetadata:
@@ -454,8 +545,30 @@ class KunlunAttentionMetadataBuilder:
attn_metadata.seq_lens_tensor.fill_(1)
return attn_metadata
def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
def build_for_drafting(
self,
common_attn_metadata: CommonAttentionMetadata,
draft_index: int,
) -> M:
"""
Build attention metadata for the draft model. Uses build() by default.
Args:
common_attn_metadata: The common attention metadata.
draft_index: The index of the current draft operation.
When speculating a chain of tokens, this index refers to the
draft attempt for the i-th token.
For tree-based attention, this index instead refers to the
draft attempt for the i-th level in the tree of tokens.
"""
return self.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
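
A hedged usage sketch of build_for_drafting; the builder instance, common_attn_metadata, and num_speculative_tokens are assumed to come from the surrounding runner code:

def draft_chain(builder, common_attn_metadata, num_speculative_tokens):
    # One metadata build per speculated token in a draft chain.
    for draft_index in range(num_speculative_tokens):
        yield builder.build_for_drafting(
            common_attn_metadata=common_attn_metadata,
            draft_index=draft_index,  # ignored by the default build path above
        )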
def build(
self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata
):
"""build"""
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
@@ -464,30 +577,38 @@ class KunlunAttentionMetadataBuilder:
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
query_start_loc_host = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1]
query_start_loc = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1].to(
self.device, non_blocking=True)
query_start_loc_host = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1]
query_start_loc = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1].to(
self.device, non_blocking=True
)
seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
seq_start_loc = list(accumulate(seq_lens, initial=0))
seq_start_loc_tensor = torch.empty(len(seq_start_loc), dtype=torch.int32, device=self.device)
seq_start_loc_tensor = torch.empty(
len(seq_start_loc), dtype=torch.int32, device=self.device
)
seq_start_loc_tensor.copy_(torch.as_tensor(seq_start_loc, dtype=torch.int32))
kv_lod_cpu = torch.zeros(num_reqs + 1, dtype=torch.int32, device="cpu")
kv_lod_cpu[1:] = seq_lens_cpu.to(torch.int32).cumsum(dim=0)
kv_lod_xpu = kv_lod_cpu.to(self.device)
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
split_decodes_and_prefills(common_attn_metadata)
num_scheduled_tokens = np.diff(common_attn_metadata.query_start_loc_cpu[:num_reqs + 1])
self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold or 1,
require_uniform=True,
)
)
num_scheduled_tokens = np.diff(
common_attn_metadata.query_start_loc_cpu[: num_reqs + 1]
)
tmp_decode_scheduled_tokens = num_scheduled_tokens[:num_decodes]
if num_decode_tokens == 0:
@@ -495,18 +616,19 @@ class KunlunAttentionMetadataBuilder:
else:
max_decode_seq_len = np.max(tmp_decode_scheduled_tokens)
tmp_prefill_scheduled_tokens = num_scheduled_tokens[num_decodes: num_reqs]
tmp_prefill_scheduled_tokens = num_scheduled_tokens[num_decodes:num_reqs]
if num_prefill_tokens == 0:
max_prefill_seq_len = 0
else:
max_prefill_seq_len = np.max(tmp_prefill_scheduled_tokens)
use_cascade = False
attn_metadata = KunlunMetadata(
num_actual_tokens=num_actual_tokens,
num_prefills=num_prefills,
num_decodes=num_decodes,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
@@ -525,11 +647,14 @@ class KunlunAttentionMetadataBuilder:
block_tables=block_table_tensor,
use_cuda_graph=False,
use_cascade=use_cascade,
is_speculative=self.reorder_batch_threshold > 1,
max_model_len=self.vllm_config.model_config.max_model_len,
)
return attn_metadata
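
Passing decode_threshold and require_uniform to split_decodes_and_prefills is what lets uniform MTP batches count as decodes. A simplified sketch of the classification rule (lengths hypothetical; not the library's exact implementation):

# Query lengths after reorder_batch (decodes first), threshold = 3.
query_lens = [3, 3, 3, 8, 5]
decode_threshold = 3

num_decodes = 0
for qlen in query_lens:
    if qlen > decode_threshold:
        break                      # everything from here on is a prefill
    num_decodes += 1
print(num_decodes)                 # 3
# require_uniform=True additionally expects the decode qlens to match,
# as they do here (3, 3, 3).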
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
self, common_attn_metadata: CommonAttentionMetadata
) -> bool:
"""can_run_in_cudagraph"""
# Full CUDA Graph always supported (FA2 support checked separately)
return True
@@ -538,6 +663,7 @@ class KunlunAttentionMetadataBuilder:
"""use_cascade_attention"""
return use_cascade_attention(*args, **kwargs)
class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
"""KunlunAttentionImpl"""
@@ -555,13 +681,12 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
kv_sharing_target_layer_name: Optional[str] = None,
attn_type: AttentionType = AttentionType.DECODER,
use_irope: bool = False,
sinks:Optional[torch.Tensor]= None,
multi_modal_placeholder_index_maps:Optional[torch.Tensor]= None,
sinks: Optional[torch.Tensor] = None,
multi_modal_placeholder_index_maps: Optional[torch.Tensor] = None,
) -> None:
"""__init__"""
if blocksparse_params is not None:
raise ValueError(
"kunlunAttention does not support block-sparse attention.")
raise ValueError("kunlunAttention does not support block-sparse attention.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
@@ -582,15 +707,17 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
if head_size not in suppored_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.")
f"Supported head sizes are: {suppored_head_sizes}."
)
self.sinks = sinks
if sinks is not None:
assert sinks.shape[0] == num_heads, (
"Sinks must have the same number of heads as the number of "
f"heads in the layer. Sinks shape: {sinks.shape}, "
f"num_heads: {num_heads}.")
self.multi_modal_placeholder_index_maps = multi_modal_placeholder_index_maps
f"num_heads: {num_heads}."
)
self.multi_modal_placeholder_index_maps = multi_modal_placeholder_index_maps
def forward(
self,
@@ -605,7 +732,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
attn_type: AttentionType = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""forward"""
query = query.view(-1, self.num_heads, self.head_size)
@@ -624,7 +751,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
# Self-attention vs. cross-attention will impact
# which KV cache memory-mapping & which
# seqlen data structures we utilize
if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
if attn_type != AttentionType.ENCODER and kv_cache.numel() > 0:
# KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not
# during encoder attention.
@@ -633,7 +760,8 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
# we still need to break out key_cache and value_cache
# i.e. for later use by paged attention
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
kv_cache, self.num_kv_heads, self.head_size
)
if (key is not None) and (value is not None):
updated_slot_mapping = attn_metadata.slot_mapping
@@ -644,11 +772,12 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
value = value.contiguous()
if key_cache.is_contiguous():
xtorch_ops.reshape_and_cache(
key,
value,
key[: attn_metadata.num_actual_tokens],
value[: attn_metadata.num_actual_tokens],
key_cache,
value_cache,
updated_slot_mapping)
updated_slot_mapping,
)
else:
cast_key_cache = key_cache.squeeze(1).unsqueeze(-2)
cast_value_cache = value_cache.squeeze(1).unsqueeze(-2)
@@ -657,7 +786,8 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
value,
cast_key_cache,
cast_value_cache,
updated_slot_mapping)
updated_slot_mapping,
)
assert attn_type == AttentionType.DECODER
# Decoder self-attention supports chunked prefill.
@@ -668,88 +798,98 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
prefill_query = query[num_decode_tokens:attn_metadata.num_actual_tokens]
prefill_key = key[num_decode_tokens:attn_metadata.num_actual_tokens]
prefill_value = value[num_decode_tokens:attn_metadata.num_actual_tokens]
prefill_query = query[num_decode_tokens : attn_metadata.num_actual_tokens]
prefill_key = key[num_decode_tokens : attn_metadata.num_actual_tokens]
prefill_value = value[num_decode_tokens : attn_metadata.num_actual_tokens]
# For hybrid Attention (Qwen3-Next)
if key_cache.is_contiguous():
tmp_block_tables = prefill_meta.block_tables
else:
# For hybrid Attention (Qwen3-Next)
tmp_block_tables = prefill_meta.block_tables * 2
tmp_block_tables = prefill_meta.block_tables * 2
# Prefix cache
if prefill_meta.query_start_loc_host[-1] != prefill_meta.kv_lod_cpu[-1]:
xtorch_ops.prefill_attention(
q=prefill_query,
k=key_cache, # Key Cache [block_num, head, block_size, dim]
k=key_cache, # Key Cache [block_num, head, block_size, dim]
v=value_cache,
out=output[num_decode_tokens:attn_metadata.num_actual_tokens],
out=output[num_decode_tokens : attn_metadata.num_actual_tokens],
is_causal=True,
is_prefix_cache=True,
block_table=tmp_block_tables,
is_prefix_cache=True,
block_table=tmp_block_tables,
context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
context_qlen_lod_xpu=prefill_meta.query_start_loc,
context_kvlen_lod_cpu=prefill_meta.kv_lod_cpu,
context_kvlen_lod_xpu=prefill_meta.kv_lod_xpu,
alibi_slopes=self.alibi_slopes,
softmax_lse=None
softmax_lse=None,
)
else:
xtorch_ops.prefill_attention(
q=prefill_query,
k=prefill_key,
v=prefill_value,
out=output[num_decode_tokens:attn_metadata.num_actual_tokens],
out=output[num_decode_tokens : attn_metadata.num_actual_tokens],
is_causal=True,
context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
context_qlen_lod_xpu=prefill_meta.query_start_loc,
alibi_slopes=self.alibi_slopes,
softmax_lse=None,
swa_left = self.sliding_window if self.sliding_window is not None else -1,
swa_right = 0 if self.sliding_window is not None else -1,
sink = self.sinks.to(torch.float32) if self.sinks is not None else None
softmax_lse=None,
swa_left=(
self.sliding_window if self.sliding_window is not None else -1
),
swa_right=0 if self.sliding_window is not None else -1,
sink=(
self.sinks.to(torch.float32) if self.sinks is not None else None
),
)
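
The branch above selects the prefix-cache variant of prefill_attention whenever the total KV length exceeds the number of newly scheduled query tokens, i.e. some keys/values must be read back through the block tables. A worked sketch of the test with hypothetical lods:

# Two prefill requests: 4 new query tokens each, but the first request
# already holds 8 KV tokens in the prefix cache.
query_start_loc_host = [0, 4, 8]   # 8 new query tokens in total
kv_lod_cpu = [0, 12, 16]           # 16 KV tokens in total
needs_prefix_cache_kernel = query_start_loc_host[-1] != kv_lod_cpu[-1]
print(needs_prefix_cache_kernel)   # True -> read K/V from the paged cache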
if decode_meta := attn_metadata.decode_metadata:
assert attn_type != AttentionType.ENCODER_ONLY, (
"Encoder-only models should not have decode metadata.")
if decode_meta := attn_metadata.decode_metadata:
assert (
attn_type != AttentionType.ENCODER_ONLY
), "Encoder-only models should not have decode metadata."
decode_query = query[:num_decode_tokens]
# For hybrid Attention (Qwen3-Next)
if key_cache.is_contiguous():
tmp_block_tables = decode_meta.block_tables
else:
tmp_block_tables = decode_meta.block_tables * 2 # only tested on Qwen3-Next
tmp_block_tables = (
decode_meta.block_tables * 2
) # only tested on Qwen3-Next
sig = inspect.signature(xtorch_ops.speculative_attention)
if "max_window_size" in sig.parameters:
xtorch_ops.speculative_attention(
out=output[:num_decode_tokens],
# Only MLA supports q len > 1 right now
q=decode_query.unsqueeze(0),
k_cache=key_cache,
v_cache=value_cache,
context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
context_lens_xpu=decode_meta.seq_lens_tensor,
batch_num=decode_meta.block_tables.shape[0],
# TODO (@xyDong23): Support MTP (q len > 1)
qlen=1,
# TODO (@xyDong23): Support max_context_len up to 262144
max_context_len=131072,
head_num=self.num_heads,
head_dim=self.head_size,
scale=0.0,
kv_head_num=self.num_kv_heads,
block_size=key_cache.shape[2],
max_num_blocks_per_seq=decode_meta.block_tables.shape[1],
max_window_size=self.sliding_window if self.sliding_window is not None else -1,
block_tables=tmp_block_tables,
sink = self.sinks.to(torch.float32) if self.sinks is not None else None
# Only MLA supports q len > 1 right now
q=decode_query.unsqueeze(0),
k_cache=key_cache,
v_cache=value_cache,
context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
context_lens_xpu=decode_meta.seq_lens_tensor,
batch_num=decode_meta.block_tables.shape[0],
# TODO (@xyDong23): Support MTP (q len > 1)
qlen=1,
# TODO (@xyDong23): Support max_context_len up to 262144
max_context_len=131072,
head_num=self.num_heads,
head_dim=self.head_size,
scale=0.0,
kv_head_num=self.num_kv_heads,
block_size=key_cache.shape[2],
max_num_blocks_per_seq=decode_meta.block_tables.shape[1],
max_window_size=(
self.sliding_window if self.sliding_window is not None else -1
),
block_tables=tmp_block_tables,
sink=(
self.sinks.to(torch.float32) if self.sinks is not None else None
),
)
else:
elif not attn_metadata.is_speculative:
xtorch_ops.paged_attention(
x=decode_query,
k_cache=key_cache,
@@ -760,10 +900,38 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
is_context=False,
is_causal=True,
out=output[:num_decode_tokens],
vo_head_dim=self.head_size
)
vo_head_dim=self.head_size,
)
else:
batch_size = attn_metadata.num_decodes
query_seq_len, head_num, head_dim = decode_query.shape
assert query_seq_len % batch_size == 0
qlen = query_seq_len // batch_size
out = output[:num_decode_tokens]
assert out.is_contiguous()
xtorch_ops.speculative_attention(
out=out.view(batch_size, qlen, head_num, self.head_size),
q=decode_query.view(batch_size, qlen, head_num, head_dim),
k_cache=key_cache,
v_cache=value_cache,
context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
context_lens_xpu=decode_meta.seq_lens_tensor,
batch_num=batch_size,
qlen=qlen,
max_context_len=decode_meta.max_model_len,
head_num=self.num_heads,
head_dim=self.head_size,
scale=0.0,
kv_head_num=self.num_kv_heads,
block_size=key_cache.shape[2],
max_num_blocks_per_seq=decode_meta.block_tables.shape[1],
block_tables=tmp_block_tables,
)
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
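
The new speculative branch in forward views the flat decode tokens as (batch, qlen, heads, dim) so xtorch_ops.speculative_attention can attend over qlen > 1 draft tokens per request. A shape-only sketch of that bookkeeping (sizes hypothetical):

import torch

# 4 requests, 3 tokens each (1 target + 2 draft).
batch_size, qlen, head_num, head_dim = 4, 3, 32, 128
num_decode_tokens = batch_size * qlen                 # 12 flat tokens

decode_query = torch.randn(num_decode_tokens, head_num, head_dim)
assert decode_query.shape[0] % batch_size == 0        # same check as above
q = decode_query.view(batch_size, qlen, head_num, head_dim)
print(q.shape)  # torch.Size([4, 3, 32, 128])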
def use_cascade_attention(
common_prefix_len: int,
query_lens: np.ndarray,
@@ -785,7 +953,7 @@ def use_cascade_attention(
# NOTE(woosuk): This is the common case. We should return False as soon as
# possible to avoid any unnecessary computation.
return False
if common_prefix_len < 256:
return False
# Cascade attention is currently not supported with these variants.
@@ -803,8 +971,12 @@ def use_cascade_attention(
num_queries_per_kv = num_query_heads // num_kv_heads
# The criteria for using FlashDecoding can be found in the following link:
# https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535
use_flash_decoding = (num_queries_per_kv > 1 and not use_sliding_window
and not use_alibi and np.all(query_lens == 1))
use_flash_decoding = (
num_queries_per_kv > 1
and not use_sliding_window
and not use_alibi
and np.all(query_lens == 1)
)
if not use_flash_decoding:
# Use cascade attention.
return True
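
The FlashDecoding eligibility check above in numbers, with hypothetical model shapes:

import numpy as np

# GQA model: 32 query heads sharing 8 KV heads.
num_query_heads, num_kv_heads = 32, 8
num_queries_per_kv = num_query_heads // num_kv_heads   # 4
query_lens = np.array([1, 1, 1])                       # pure decode batch
use_sliding_window = use_alibi = False                 # assumed off

use_flash_decoding = (
    num_queries_per_kv > 1
    and not use_sliding_window
    and not use_alibi
    and bool(np.all(query_lens == 1))
)
print(use_flash_decoding)  # True -> fall through to the cost model below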
@@ -826,10 +998,11 @@ def use_cascade_attention(
cascade_waves = cdiv(cascade_ctas, num_sms)
cascade_time = cascade_waves * num_prefix_tiles
flash_decoding_ctas = (num_reqs * num_kv_heads *
cdiv(num_queries_per_kv, q_tile_size))
flash_decoding_ctas = (
num_reqs * num_kv_heads * cdiv(num_queries_per_kv, q_tile_size)
)
flash_decoding_ctas *= num_prefix_tiles
flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)
# Use cascade attention if it is faster than FlashDecoding.
return cascade_time < flash_decoding_time
return cascade_time < flash_decoding_time
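
A worked instance of the final comparison; the CTA counts, tile sizes, and SM count are hypothetical stand-ins for values computed earlier in the real function:

def cdiv(a, b):
    return (a + b - 1) // b

num_sms, num_prefix_tiles, q_tile_size = 108, 16, 64
num_reqs, num_kv_heads, num_queries_per_kv = 64, 8, 4

cascade_ctas = 512                                       # hypothetical
cascade_time = cdiv(cascade_ctas, num_sms) * num_prefix_tiles      # 5 * 16 = 80

flash_decoding_ctas = num_reqs * num_kv_heads * cdiv(num_queries_per_kv, q_tile_size)
flash_decoding_ctas *= num_prefix_tiles                            # 512 * 16
flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)           # 76

print(cascade_time < flash_decoding_time)  # False -> prefer FlashDecoding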