Commit vllm 0.11.0 development branch
@@ -1,8 +1,5 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Author: Dong Xinyu, Bao Qian, Chen Zhennan, Ma Tianyu, Wang Haowen
# Email: dongxinyu03@baidu.com
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-kunlun project.
#
from vllm.config import VllmConfig, get_layers_from_vllm_config
import xtorch_ops
from dataclasses import dataclass
@@ -24,8 +23,8 @@ import torch
import numpy as np
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.backends.utils import is_block_tables_empty, compute_slot_mapping_start_idx, compute_slot_mapping
# from vllm.attention.backends.utils import CommonAttentionState
# from vllm.attention.backends.utils import is_block_tables_empty, compute_slot_mapping_start_idx, compute_slot_mapping
from vllm_kunlun.ops.paged_attn import (PagedAttention, PagedAttentionMetadata)
from vllm_kunlun.ops._kunlun_ops import KunlunOps

@@ -45,6 +44,7 @@ from vllm.v1.worker.block_table import BlockTable

from vllm.config import VllmConfig, get_layers_from_vllm_config


class KunlunAttentionBackend(AttentionBackend):
    """KunlunAttentionBackend"""
    # crucial to cuda graph
@@ -70,10 +70,10 @@ class KunlunAttentionBackend(AttentionBackend):
        """get_builder_cls"""
        return KunlunAttentionMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """get_state_cls"""
        return CommonAttentionState
    # @staticmethod
    # def get_state_cls() -> Type["CommonAttentionState"]:
    #     """get_state_cls"""
    #     return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
@@ -81,6 +81,7 @@ class KunlunAttentionBackend(AttentionBackend):
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto"
    ) -> Tuple[int, ...]:
        """get_kv_cache_shape"""
        # return (2, num_blocks, block_size, num_kv_heads * head_size)
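For orientation, a minimal sketch of the shape arithmetic behind the commented-out return above. Note this commit's new layout sits outside the visible hunk, so treat this as an assumption about the historical shape, not the current behavior:

def sketch_kv_cache_shape(num_blocks: int, block_size: int,
                          num_kv_heads: int, head_size: int) -> tuple:
    # Leading 2 packs the key cache and value cache into one tensor; each
    # block holds block_size token slots of num_kv_heads * head_size features.
    return (2, num_blocks, block_size, num_kv_heads * head_size)

sketch_kv_cache_shape(1024, 16, 8, 128)  # -> (2, 1024, 16, 1024)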
@@ -132,7 +133,11 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    slot_mapping: torch.Tensor
    block_tables: torch.Tensor

    multi_modal_placeholder_index_maps: Optional[torch.Tensor] = None
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]] = None
@@ -143,6 +148,11 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None

    # Prefix cache loc
    kv_lod_cpu: Optional[torch.Tensor] = None
    kv_lod_xpu: Optional[torch.Tensor] = None

    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor] = None
@@ -181,6 +191,7 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
    # Number of tokens input to encoder
    num_encoder_tokens: Optional[int] = None

    enable_kv_scales_calculation: Optional[bool] = False
    # Cross-attention memory-mapping data structures: slot mapping
    # and block tables
    cross_slot_mapping: Optional[torch.Tensor] = None
@@ -193,6 +204,11 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
    use_cascade: Optional[bool] = False

    seq_lens_tensor_cpu: Optional[torch.Tensor] = None

    num_prefill_tokens: int = 0
    num_decode_tokens: int = 0
    num_prefills: int = 0
    num_decodes: int = 0

    def __post_init__(self):
        """__post_init__"""
@@ -253,6 +269,19 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
        input_positions = (None if self.input_positions is None else
                           self.input_positions[-self.num_prefills:])

        if self.kv_lod_cpu is None:
            kv_lod_cpu = None
            kv_lod_xpu = None
        else:
            start = -(self.num_prefills + 1)
            base_cpu = self.kv_lod_cpu[start]
            kv_lod_cpu = self.kv_lod_cpu[start:] - base_cpu

            base_xpu = self.kv_lod_xpu[start]
            kv_lod_xpu = self.kv_lod_xpu[start:] - base_xpu
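A minimal numeric sketch of the re-basing above, with toy lengths assumed here: for a batch of one decode followed by two prefills, slicing the last num_prefills + 1 entries and subtracting the base yields kv offsets that restart at zero for the prefill sub-batch:

import torch

# Assumed toy batch: [decode kv=4, prefill kv=6, prefill kv=3].
kv_lod_cpu = torch.tensor([0, 4, 10, 13], dtype=torch.int32)
num_prefills = 2
start = -(num_prefills + 1)        # -3
base = kv_lod_cpu[start]           # 4, the kv offset where the prefills begin
print(kv_lod_cpu[start:] - base)   # tensor([0, 6, 9], dtype=torch.int32)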
        # Construct & cache prefill-phase attention metadata structure
        self._cached_prefill_metadata = KunlunMetadata(
            num_actual_tokens=self.num_actual_tokens,
@@ -264,7 +293,9 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
            slot_mapping=slot_mapping,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            seq_start_loc=None,
            seq_start_loc = None,
            kv_lod_cpu=kv_lod_cpu,
            kv_lod_xpu=kv_lod_xpu,
            max_query_len=self.max_query_len,
            max_kv_len=self.max_kv_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
@@ -413,18 +444,28 @@ class KunlunAttentionMetadataBuilder:
        self._num_decode_tokens = num_decode_tokens
        self._num_prefill_tokens = num_prefill_tokens
        return modified_batch

    def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata
    ) -> KunlunMetadata:
        attn_metadata = self.build(0, common_attn_metadata)
        # When doing full graph capture, setting seq_lens to
        # max_model_len will cause graph capture to be extremely
        # slow, so here we set it to 1.
        attn_metadata.seq_lens_tensor.fill_(1)
        return attn_metadata

    def build(self, common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata):
        """build"""
        num_reqs=common_attn_metadata.num_reqs
        num_actual_tokens=common_attn_metadata.num_actual_tokens
        max_query_len=common_attn_metadata.max_query_len
        common_prefix_len=common_prefix_len
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len
        common_prefix_len = common_prefix_len
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping

        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
        query_start_loc_host = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1]
        query_start_loc = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1].to(
@@ -432,18 +473,18 @@ class KunlunAttentionMetadataBuilder:

        seq_lens = common_attn_metadata.seq_lens
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu

        seq_start_loc = list(accumulate(seq_lens, initial=0))

        if len(seq_start_loc) != num_reqs + 1:
            seq_start_loc = query_start_loc_host.tolist()

        if seq_start_loc[-1] != num_actual_tokens:
            seq_start_loc = query_start_loc_host.tolist()

        seq_start_loc_tensor = torch.empty(len(seq_start_loc), dtype=torch.int32, device=self.device)
        seq_start_loc_tensor.copy_(torch.as_tensor(seq_start_loc, dtype=torch.int32))

        kv_lod_cpu = torch.zeros(num_reqs + 1, dtype=torch.int32, device="cpu")
        kv_lod_cpu[1:] = seq_lens_cpu.to(torch.int32).cumsum(dim=0)
        kv_lod_xpu = kv_lod_cpu.to(self.device)
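To make the two offset arrays above concrete, a small sketch with assumed per-request lengths: seq_start_loc is a running sum with a leading zero, and kv_lod is the same cumulative shape built as a tensor, so both map request i to its [start, end) token range:

import torch
from itertools import accumulate

seq_lens = [3, 5, 2]                          # assumed toy lengths
print(list(accumulate(seq_lens, initial=0)))  # [0, 3, 8, 10]

kv_lod = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
kv_lod[1:] = torch.tensor(seq_lens, dtype=torch.int32).cumsum(dim=0)
print(kv_lod)                                 # tensor([0, 3, 8, 10], dtype=torch.int32)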
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
            split_decodes_and_prefills(common_attn_metadata)
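A sketch of the split's contract, assuming the vLLM v1 convention that the scheduler orders decode requests (one scheduled token each) ahead of prefills; this toy helper only illustrates the returned counts, not the real implementation:

def sketch_split(num_scheduled_tokens):
    # Leading requests scheduled for exactly one token are decodes.
    num_decodes = 0
    while (num_decodes < len(num_scheduled_tokens)
           and num_scheduled_tokens[num_decodes] == 1):
        num_decodes += 1
    num_prefills = len(num_scheduled_tokens) - num_decodes
    num_decode_tokens = sum(num_scheduled_tokens[:num_decodes])
    num_prefill_tokens = sum(num_scheduled_tokens[num_decodes:])
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens

print(sketch_split([1, 1, 4, 7]))  # (2, 2, 2, 11)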
@@ -456,6 +497,7 @@
            max_decode_seq_len = np.max(tmp_decode_scheduled_tokens)

        tmp_prefill_scheduled_tokens = num_scheduled_tokens[num_decodes: num_reqs]

        if num_prefill_tokens == 0:
            max_prefill_seq_len = 0
        else:
@@ -473,6 +515,8 @@
            num_decode_tokens=num_decode_tokens,
            seq_lens_tensor=seq_lens,
            seq_lens_tensor_cpu=seq_lens_cpu,
            kv_lod_xpu=kv_lod_xpu,
            kv_lod_cpu=kv_lod_cpu,
            max_query_len=max_prefill_seq_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
@@ -483,7 +527,6 @@
            use_cuda_graph=False,
            use_cascade=use_cascade,
        )

        return attn_metadata

    def can_run_in_cudagraph(
@@ -514,11 +557,15 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
        attn_type: AttentionType = AttentionType.DECODER,
        use_irope: bool = False,
        sinks: Optional[torch.Tensor] = None,
        multi_modal_placeholder_index_maps: Optional[torch.Tensor] = None,
    ) -> None:
        """__init__"""
        if blocksparse_params is not None:
            raise ValueError(
                "kunlunAttention does not support block-sparse attention.")
        # if logits_soft_cap is not None:
        #     raise ValueError(
        #         "kunlunAttention does not support attention logits soft capping.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
@@ -547,6 +594,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
                "Sinks must have the same number of heads as the number of "
                f"heads in the layer. Sinks shape: {sinks.shape}, "
                f"num_heads: {num_heads}.")
        self.multi_modal_placeholder_index_maps = multi_modal_placeholder_index_maps

    def forward(
        self,
@@ -560,7 +608,8 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """forward"""
        query = query.view(-1, self.num_heads, self.head_size)
@@ -597,12 +646,22 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory
            value = value.contiguous()
            xtorch_ops.reshape_and_cache(
                key,
                value,
                key_cache,
                value_cache,
                updated_slot_mapping)
            if key_cache.is_contiguous():
                xtorch_ops.reshape_and_cache(
                    key,
                    value,
                    key_cache,
                    value_cache,
                    updated_slot_mapping)
            else:
                cast_key_cache = key_cache.squeeze(1).unsqueeze(-2)
                cast_value_cache = value_cache.squeeze(1).unsqueeze(-2)
                xtorch_ops.reshape_and_cache_flash(
                    key,
                    value,
                    cast_key_cache,
                    cast_value_cache,
                    updated_slot_mapping)
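The squeeze(1).unsqueeze(-2) pair above is a pure axis shuffle that produces a view, not a copy; a sketch on a dummy tensor whose shape is assumed purely for illustration (the real Kunlun cache layout may differ):

import torch

cache = torch.empty(128, 1, 64, 256)         # assumed (blocks, 1, block_size, feat)
cast = cache.squeeze(1).unsqueeze(-2)        # drop dim 1, add an axis before last
print(cast.shape)                            # torch.Size([128, 64, 1, 256])
print(cast.data_ptr() == cache.data_ptr())   # True: same storage, no copy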
        assert attn_type == AttentionType.DECODER
        # Decoder self-attention supports chunked prefill.
@@ -614,22 +673,38 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            prefill_query = query[num_decode_tokens:attn_metadata.num_actual_tokens]
            prefill_key = key[num_decode_tokens:attn_metadata.num_actual_tokens]
            prefill_value = value[num_decode_tokens:attn_metadata.num_actual_tokens]
            assert prefill_query.shape[0] == num_prefill_tokens
            output[num_decode_tokens:attn_metadata.num_actual_tokens] = KunlunOps.multi_query_kv_attention(
                prefill_meta.query_start_loc, prefill_meta.query_start_loc_host, prefill_query, prefill_key, prefill_value,
                alibi_slopes=self.alibi_slopes).view_as(prefill_query)
            xtorch_ops.prefill_attention(
                q=prefill_query,
                k=key_cache,  # Key Cache (block_num, head, block_size, dim)
                v=value_cache,
                out=output[num_decode_tokens:attn_metadata.num_actual_tokens],
                is_causal=True,
                is_prefix_cache=True,
                block_table=prefill_meta.block_tables,
                context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
                context_qlen_lod_xpu=prefill_meta.query_start_loc,
                context_kvlen_lod_cpu=prefill_meta.kv_lod_cpu,
                context_kvlen_lod_xpu=prefill_meta.kv_lod_xpu,
                alibi_slopes=self.alibi_slopes,
                softmax_lse=None,
                sink=self.sinks
            )
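On the two pairs of LOD ("length of data") offsets passed above: with prefix caching, the query LOD spans only the newly scheduled tokens, while the kv LOD spans cached context plus new tokens, so the two differ whenever part of the prompt is already in the cache. A toy example with assumed values:

# One prefill request: 6 tokens already cached, 4 new tokens scheduled.
query_start_loc = [0, 4]   # query LOD: offsets over the 4 new query tokens
kv_lod = [0, 10]           # kv LOD: offsets over all 10 attended kv positions
# The kernel reads 4 query rows but attends across 10 key/value slots.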
        if decode_meta := attn_metadata.decode_metadata:
            assert attn_type != AttentionType.ENCODER_ONLY, (
                "Encoder-only models should not have decode metadata.")
            decode_query = query[:num_decode_tokens]

            if key_cache.is_contiguous():
                tmp_block_tables = decode_meta.block_tables
            else:
                tmp_block_tables = decode_meta.block_tables * 2  # only tested with Qwen3-Next

            xtorch_ops.paged_attention(
                x=decode_query,
                k_cache=key_cache,
                v_cache=value_cache,
                block_tables=decode_meta.block_tables,
                block_tables=tmp_block_tables,
                context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
                context_lens_xpu=decode_meta.seq_lens_tensor,
                is_context=False,
@@ -639,7 +714,6 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
            )
        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)

def use_cascade_attention(
    common_prefix_len: int,
    query_lens: np.ndarray,
@@ -650,13 +724,18 @@ def use_cascade_attention(
    num_sms: int,
    use_local_attention: bool = False,
) -> bool:
    """
    TODO: Not Yet Supported on Kunlun platform
    """Decide whether to use cascade attention.

    This function 1) checks whether cascade attention is supported with the
    given configuration, and 2) heuristically decides whether using cascade
    attention can improve performance.
    """
    # Too short common prefix. Probably not worth using cascade attention.
    # We use an arbitrary threshold of 256 tokens. TODO: Tune this threshold.
    # NOTE(woosuk): This is the common case. We should return False as soon as
    # possible to avoid any unnecessary computation.
    return False

    if common_prefix_len < 256:
        return False
    # Cascade attention is currently not supported with these variants.