Commit vllm 0.11.0 development branch

This commit is contained in:
chenyili
2025-12-10 17:51:24 +08:00
parent deab7dd0b6
commit 7c22d621fb
175 changed files with 31856 additions and 8683 deletions


@@ -1,8 +1,5 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Author: Dong Xinyu, Bao Qian, Chen Zhennan, Ma Tianyu, Wang Haowen
# Email: dongxinyu03@baidu.com
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-kunlun project.
#
from vllm.config import VllmConfig, get_layers_from_vllm_config
import xtorch_ops
from dataclasses import dataclass
@@ -24,8 +23,8 @@ import torch
import numpy as np
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.backends.utils import is_block_tables_empty, compute_slot_mapping_start_idx, compute_slot_mapping
# from vllm.attention.backends.utils import CommonAttentionState
# from vllm.attention.backends.utils import is_block_tables_empty, compute_slot_mapping_start_idx, compute_slot_mapping
from vllm_kunlun.ops.paged_attn import (PagedAttention, PagedAttentionMetadata)
from vllm_kunlun.ops._kunlun_ops import KunlunOps
@@ -45,6 +44,7 @@ from vllm.v1.worker.block_table import BlockTable
from vllm.config import VllmConfig, get_layers_from_vllm_config
class KunlunAttentionBackend(AttentionBackend):
"""KunlunAttentionBackend"""
# crucial to CUDA graphs
@@ -70,10 +70,10 @@ class KunlunAttentionBackend(AttentionBackend):
"""get_builder_cls"""
return KunlunAttentionMetadataBuilder
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
"""get_state_cls"""
return CommonAttentionState
# @staticmethod
# def get_state_cls() -> Type["CommonAttentionState"]:
# """get_state_cls"""
# return CommonAttentionState
@staticmethod
def get_kv_cache_shape(
@@ -81,6 +81,7 @@ class KunlunAttentionBackend(AttentionBackend):
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto"
) -> Tuple[int, ...]:
"""get_kv_cache_shape"""
# return (2, num_blocks, block_size, num_kv_heads * head_size)
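As a rough illustration of how a caller might consume the extended get_kv_cache_shape signature when allocating the per-layer cache; the leading num_blocks argument, the "xpu" device, the dtype, and the sizes below are assumptions for the sketch, not taken from this diff:

import torch

# Sketch only: num_blocks is assumed to be the hidden first parameter above.
kv_shape = KunlunAttentionBackend.get_kv_cache_shape(
    num_blocks=1024, block_size=16, num_kv_heads=8, head_size=128,
    cache_dtype_str="auto")
kv_cache = torch.zeros(kv_shape, dtype=torch.bfloat16, device="xpu")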
@@ -132,7 +133,11 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
slot_mapping: torch.Tensor
block_tables: torch.Tensor
multi_modal_placeholder_index_maps: Optional[torch.Tensor] = None
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens. None if it is a decoding.
seq_lens: Optional[List[int]] = None
@@ -143,6 +148,11 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor] = None
# Prefix cache loc
kv_lod_cpu: Optional[torch.Tensor] = None
kv_lod_xpu: Optional[torch.Tensor] = None
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor] = None
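A side note on the kv_lod_* fields added above: "lod" here appears to be a cumulative-offset ("length of data") vector over per-request KV lengths, mirroring the cumsum construction in the metadata builder further down. A minimal sketch with invented lengths:

import torch

seq_lens_cpu = torch.tensor([3, 5, 2], dtype=torch.int32)   # per-request KV lengths (invented)
kv_lod = torch.zeros(seq_lens_cpu.numel() + 1, dtype=torch.int32)
kv_lod[1:] = seq_lens_cpu.to(torch.int32).cumsum(dim=0)
# kv_lod == [0, 3, 8, 10]; request i owns KV rows kv_lod[i]:kv_lod[i+1]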
@@ -181,6 +191,7 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
# Number of tokens input to encoder
num_encoder_tokens: Optional[int] = None
enable_kv_scales_calculation: Optional[bool] = False
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping: Optional[torch.Tensor] = None
@@ -193,6 +204,11 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
use_cascade: Optional[bool] = False
seq_lens_tensor_cpu: Optional[torch.Tensor] = None
num_prefill_tokens: int = 0
num_decode_tokens: int = 0
num_prefills: int = 0
num_decodes: int = 0
def __post_init__(self):
"""__post_init__"""
@@ -253,6 +269,19 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
input_positions = (None if self.input_positions is None else
self.input_positions[-self.num_prefills:])
if self.kv_lod_cpu is None:
kv_lod_cpu = None
kv_lod_xpu = None
else:
start = -(self.num_prefills + 1)
base_cpu = self.kv_lod_cpu[start]
kv_lod_cpu = self.kv_lod_cpu[start:] - base_cpu
base_xpu = self.kv_lod_xpu[start]
kv_lod_xpu = self.kv_lod_xpu[start:] - base_xpu
# Construct & cache prefill-phase attention metadata structure
self._cached_prefill_metadata = KunlunMetadata(
num_actual_tokens=self.num_actual_tokens,
@@ -264,7 +293,9 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
slot_mapping=slot_mapping,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
seq_start_loc=None,
seq_start_loc = None,
kv_lod_cpu=kv_lod_cpu,
kv_lod_xpu=kv_lod_xpu,
max_query_len=self.max_query_len,
max_kv_len=self.max_kv_len,
max_prefill_seq_len=self.max_prefill_seq_len,
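A small worked sketch (request counts and offsets invented for illustration) of the rebasing used above when slicing the prefill-only portion out of the full-batch kv_lod:

import torch

# assume 2 decode requests followed by 2 prefill requests
num_prefills = 2
kv_lod_cpu = torch.tensor([0, 4, 9, 12, 20], dtype=torch.int32)  # full batch
start = -(num_prefills + 1)                                      # -3
base = kv_lod_cpu[start]                                         # 9
prefill_kv_lod = kv_lod_cpu[start:] - base                       # [0, 3, 11]
# the two prefill requests own KV spans [0, 3) and [3, 11)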
@@ -413,18 +444,28 @@ class KunlunAttentionMetadataBuilder:
self._num_decode_tokens = num_decode_tokens
self._num_prefill_tokens = num_prefill_tokens
return modified_batch
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
) -> KunlunMetadata:
attn_metadata = self.build(0, common_attn_metadata)
# When doing full graph capture, setting seq_lens to
# max_model_len will cause graph capture to be extremely
# slow, so here we set it to 1.
attn_metadata.seq_lens_tensor.fill_(1)
return attn_metadata
def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
"""build"""
num_reqs=common_attn_metadata.num_reqs
num_actual_tokens=common_attn_metadata.num_actual_tokens
max_query_len=common_attn_metadata.max_query_len
common_prefix_len=common_prefix_len
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
common_prefix_len = common_prefix_len
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
query_start_loc_host = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1]
query_start_loc = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1].to(
@@ -432,18 +473,18 @@ class KunlunAttentionMetadataBuilder:
seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
seq_start_loc = list(accumulate(seq_lens, initial=0))
if len(seq_start_loc) != num_reqs + 1:
seq_start_loc = query_start_loc_host.tolist()
if seq_start_loc[-1] != num_actual_tokens:
seq_start_loc = query_start_loc_host.tolist()
seq_start_loc_tensor = torch.empty(len(seq_start_loc), dtype=torch.int32, device=self.device)
seq_start_loc_tensor.copy_(torch.as_tensor(seq_start_loc, dtype=torch.int32))
kv_lod_cpu = torch.zeros(num_reqs + 1, dtype=torch.int32, device="cpu")
kv_lod_cpu[1:] = seq_lens_cpu.to(torch.int32).cumsum(dim=0)
kv_lod_xpu = kv_lod_cpu.to(self.device)
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
split_decodes_and_prefills(common_attn_metadata)
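A small sketch of the seq_start_loc construction and its fallback checks, using made-up lengths; the behaviour of split_decodes_and_prefills (decode requests ordered before prefills, returning the four counts unpacked above) is assumed from its use here rather than shown in this diff:

from itertools import accumulate

seq_lens = [4, 4, 7, 9]                                  # per-request lengths (invented)
num_reqs, num_actual_tokens = 4, 24
seq_start_loc = list(accumulate(seq_lens, initial=0))    # [0, 4, 8, 15, 24]
assert len(seq_start_loc) == num_reqs + 1
assert seq_start_loc[-1] == num_actual_tokens
# if either check failed, build() falls back to query_start_loc_host.tolist()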
@@ -456,6 +497,7 @@ class KunlunAttentionMetadataBuilder:
max_decode_seq_len = np.max(tmp_decode_scheduled_tokens)
tmp_prefill_scheduled_tokens = num_scheduled_tokens[num_decodes: num_reqs]
if num_prefill_tokens == 0:
max_prefill_seq_len = 0
else:
@@ -473,6 +515,8 @@ class KunlunAttentionMetadataBuilder:
num_decode_tokens=num_decode_tokens,
seq_lens_tensor=seq_lens,
seq_lens_tensor_cpu=seq_lens_cpu,
kv_lod_xpu=kv_lod_xpu,
kv_lod_cpu=kv_lod_cpu,
max_query_len=max_prefill_seq_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
@@ -483,7 +527,6 @@ class KunlunAttentionMetadataBuilder:
use_cuda_graph=False,
use_cascade=use_cascade,
)
return attn_metadata
def can_run_in_cudagraph(
@@ -514,11 +557,15 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
attn_type: AttentionType = AttentionType.DECODER,
use_irope: bool = False,
sinks: Optional[torch.Tensor] = None,
multi_modal_placeholder_index_maps: Optional[torch.Tensor] = None,
) -> None:
"""__init__"""
if blocksparse_params is not None:
raise ValueError(
"kunlunAttention does not support block-sparse attention.")
# if logits_soft_cap is not None:
# raise ValueError(
# "kunlunAttention does not support attention logits soft capping.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
@@ -547,6 +594,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
"Sinks must have the same number of heads as the number of "
f"heads in the layer. Sinks shape: {sinks.shape}, "
f"num_heads: {num_heads}.")
self.multi_modal_placeholder_index_maps = multi_modal_placeholder_index_maps
def forward(
self,
@@ -560,7 +608,8 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""forward"""
query = query.view(-1, self.num_heads, self.head_size)
@@ -597,12 +646,22 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
value = value.contiguous()
xtorch_ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
updated_slot_mapping)
if key_cache.is_contiguous():
xtorch_ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
updated_slot_mapping)
else:
cast_key_cache = key_cache.squeeze(1).unsqueeze(-2)
cast_value_cache = value_cache.squeeze(1).unsqueeze(-2)
xtorch_ops.reshape_and_cache_flash(
key,
value,
cast_key_cache,
cast_value_cache,
updated_slot_mapping)
assert attn_type == AttentionType.DECODER
# Decoder self-attention supports chunked prefill.
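For the non-contiguous path above, the squeeze/unsqueeze cast is a metadata-only view change before calling reshape_and_cache_flash; a toy sketch with a purely hypothetical cache shape (the real Kunlun cache layout is not shown in this diff):

import torch

# Hypothetical layout, for shape illustration only:
# (num_blocks, 1, num_kv_heads * head_size, block_size)
key_cache = torch.empty(4, 1, 1024, 16)
cast_key_cache = key_cache.squeeze(1).unsqueeze(-2)        # (4, 1024, 1, 16)
assert cast_key_cache.data_ptr() == key_cache.data_ptr()   # a view, no copy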
@@ -614,22 +673,38 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
prefill_query = query[num_decode_tokens:attn_metadata.num_actual_tokens]
prefill_key = key[num_decode_tokens:attn_metadata.num_actual_tokens]
prefill_value = value[num_decode_tokens:attn_metadata.num_actual_tokens]
assert prefill_query.shape[0] == num_prefill_tokens
output[num_decode_tokens:attn_metadata.num_actual_tokens] = KunlunOps.multi_query_kv_attention(
prefill_meta.query_start_loc,prefill_meta.query_start_loc_host, prefill_query, prefill_key, prefill_value,
alibi_slopes=self.alibi_slopes).view_as(prefill_query)
xtorch_ops.prefill_attention(
q=prefill_query,
k=key_cache, # Key Cache (block_num, head, block_size, dim)
v=value_cache,
out=output[num_decode_tokens:attn_metadata.num_actual_tokens],
is_causal=True,
is_prefix_cache=True,
block_table=prefill_meta.block_tables,
context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
context_qlen_lod_xpu=prefill_meta.query_start_loc,
context_kvlen_lod_cpu=prefill_meta.kv_lod_cpu,
context_kvlen_lod_xpu=prefill_meta.kv_lod_xpu,
alibi_slopes=self.alibi_slopes,
softmax_lse=None,
sink=self.sinks
)
if decode_meta := attn_metadata.decode_metadata:
assert attn_type != AttentionType.ENCODER_ONLY, (
"Encoder-only models should not have decode metadata.")
decode_query = query[:num_decode_tokens]
if key_cache.is_contiguous():
tmp_block_tables = decode_meta.block_tables
else:
tmp_block_tables = decode_meta.block_tables * 2  # so far only tested with Qwen3-Next
xtorch_ops.paged_attention(
x=decode_query,
k_cache=key_cache,
v_cache=value_cache,
block_tables=decode_meta.block_tables,
block_tables=tmp_block_tables,
context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
context_lens_xpu=decode_meta.seq_lens_tensor,
is_context=False,
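The prefill/decode slicing above assumes the token batch is laid out with all decode tokens first, followed by prefill tokens; a minimal sketch with invented sizes:

import torch

num_decode_tokens, num_prefill_tokens = 3, 12
num_actual_tokens = num_decode_tokens + num_prefill_tokens
query = torch.randn(num_actual_tokens, 8, 128)        # (tokens, heads, head_size)

decode_query = query[:num_decode_tokens]               # one token per decode request
prefill_query = query[num_decode_tokens:num_actual_tokens]
assert prefill_query.shape[0] == num_prefill_tokens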
@@ -639,7 +714,6 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
)
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
def use_cascade_attention(
common_prefix_len: int,
query_lens: np.ndarray,
@@ -650,13 +724,18 @@ def use_cascade_attention(
num_sms: int,
use_local_attention: bool = False,
) -> bool:
"""
TODO: Not Yet Supported on Kunlun platform
"""Decide whether to use cascade attention.
This function 1) checks whether cascade attention is supported with the
given configuration, and 2) heuristically decides whether using cascade
attention can improve performance.
"""
# Too short common prefix. Probably not worth using cascade attention.
# We use an arbitrary threshold of 256 tokens. TODO: Tune this threshold.
# NOTE(woosuk): This is the common case. We should return False as soon as
# possible to avoid any unnecessary computation.
return False
if common_prefix_len < 256:
return False
# Cascade attention is currently not supported with these variants.