Initial commit for vLLM-Kunlun Plugin

This commit is contained in:
dongxinyu03
2025-12-10 12:05:39 +08:00
commit c728e52505
131 changed files with 28816 additions and 0 deletions

View File

View File

@@ -0,0 +1,3 @@
# from .backends import KunlunMetadata
# __all__ = ['KunlunMetadata']

View File

@@ -0,0 +1,3 @@
from .kunlun_attn import KunlunMetadata
__all__ = ['KunlunMetadata']

View File

@@ -0,0 +1,706 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Author: Dong Xinyu, Bao Qian, Chen Zhennan, Ma Tianyu, Wang Haowen
# Email: dongxinyu03@baidu.com
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from vllm.config import VllmConfig, get_layers_from_vllm_config
import xtorch_ops
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, ClassVar, Tuple, Type, TYPE_CHECKING
import torch
import numpy as np
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.backends.utils import is_block_tables_empty, compute_slot_mapping_start_idx, compute_slot_mapping
from vllm_kunlun.ops.paged_attn import (PagedAttention, PagedAttentionMetadata)
from vllm_kunlun.ops._kunlun_ops import KunlunOps
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
AttentionCGSupport,
split_decodes_and_prefills)
from vllm.forward_context import ForwardContext, get_forward_context
from itertools import accumulate
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
from vllm.config import VllmConfig, get_layers_from_vllm_config
class KunlunAttentionBackend(AttentionBackend):
"""KunlunAttentionBackend"""
# crucial to cuda graph
accept_output_buffer = True
@staticmethod
def get_name() -> str:
"""get_name"""
return "Kunlun_v1"
@staticmethod
def get_impl_cls() -> Type["KunlunAttentionImpl"]:
"""get_impl_cls"""
return KunlunAttentionImpl
@staticmethod
def get_metadata_cls() -> Type["KunlunMetadata"]:
"""get_metadata_cls"""
return KunlunMetadata
@staticmethod
def get_builder_cls() -> Type["KunlunAttentionMetadataBuilder"]:
"""get_builder_cls"""
return KunlunAttentionMetadataBuilder
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
"""get_state_cls"""
return CommonAttentionState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
"""get_kv_cache_shape"""
# return (2, num_blocks, block_size, num_kv_heads * head_size)
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: List[torch.Tensor],
dst_kv_cache: List[torch.Tensor],
src_to_dst: torch.Tensor,
) -> None:
"""swap_blocks"""
raise NotImplementedError
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
"""copy_blocks"""
raise NotImplementedError
@dataclass
class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
"""KunlunMetadata"""
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# FIXME: It is for flash attn.
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len: int
num_actual_tokens: int
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]] = None
# FIXME: It is for flash attn.
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor] = None
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor] = None
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int] = None
# Max number of key/value length in the batch, especially for prefix cache
max_kv_len: Optional[int] = None
# Max number of query tokens among request in the batch.
max_decode_query_len: Optional[int] = None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc: Optional[torch.Tensor] = None
query_start_loc_host: Optional[torch.Tensor] = None
# serve only for prefix cache
kv_prefix_start_loc_host: Optional[torch.Tensor] = None
kv_prefix_start_loc: Optional[torch.Tensor] = None
# Self-attention prefill/decode metadata cache
_cached_prefill_metadata: Optional["KunlunMetadata"] = None
_cached_decode_metadata: Optional["KunlunMetadata"] = None
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens: Optional[List[int]] = None
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
# Maximum sequence length among encoder sequences
max_encoder_seq_len: Optional[int] = None
# Number of tokens input to encoder
num_encoder_tokens: Optional[int] = None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping: Optional[torch.Tensor] = None
cross_block_tables: Optional[torch.Tensor] = None
# Input positions for rotrary embeddings since for MLA the rotary
# position embeddings are applied inside the attention backend
input_positions: Optional[torch.Tensor] = None
use_cascade: Optional[bool] = False
seq_lens_tensor_cpu: Optional[torch.Tensor] = None
def __post_init__(self):
"""__post_init__"""
self.attn_bias: Optional[List[AttentionBias]] = None
self.encoder_attn_bias: Optional[List[AttentionBias]] = None
self.cross_attn_bias: Optional[List[AttentionBias]] = None
@property
def is_all_encoder_attn_metadata_set(self):
"""is_all_encoder_attn_metadata_set"""
return ((self.encoder_seq_lens is not None)
and (self.encoder_seq_lens_tensor is not None)
and (self.max_encoder_seq_len is not None))
@property
def is_all_cross_attn_metadata_set(self):
"""is_all_cross_attn_metadata_set"""
return (self.is_all_encoder_attn_metadata_set
and (self.cross_slot_mapping is not None)
and (self.cross_block_tables is not None))
@property
def prefill_metadata(self) -> Optional["KunlunMetadata"]:
"""prefill_metadata"""
if self.num_prefills == 0:
return None
if self._cached_prefill_metadata is not None:
# Recover cached prefill-phase attention
# metadata structure
return self._cached_prefill_metadata
assert ((self.seq_lens_tensor is not None)
or (self.encoder_seq_lens_tensor is not None))
# Compute some attn_metadata fields which default to None
query_start_loc = (None if self.query_start_loc is None else
self.query_start_loc[-(self.num_prefills + 1):] - self.query_start_loc[-(self.num_prefills + 1)])
# flash attention needs both lod information on host and device
query_start_loc_host = (None if self.query_start_loc_host is None else
self.query_start_loc_host[-(self.num_prefills + 1):] - self.query_start_loc_host[-(self.num_prefills + 1)])
# TODO(chengruichang):how to support prefix cache
kv_prefix_start_loc_host = None
kv_prefix_start_loc = None
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[-self.num_prefill_tokens:])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[-self.num_prefills:])
seq_lens = (None if self.seq_lens is None else self.seq_lens[-self.num_prefills:])
context_lens_tensor = (None if self.context_lens_tensor is None else
self.context_lens_tensor[-self.num_prefills:])
block_tables = (None if self.block_tables is None else
self.block_tables[-self.num_prefills:])
input_positions = (None if self.input_positions is None else
self.input_positions[-self.num_prefills:])
# Construct & cache prefill-phase attention metadata structure
self._cached_prefill_metadata = KunlunMetadata(
num_actual_tokens=self.num_actual_tokens,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=slot_mapping,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
seq_start_loc=None,
max_query_len=self.max_query_len,
max_kv_len=self.max_kv_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
query_start_loc=query_start_loc,
query_start_loc_host=query_start_loc_host,
input_positions=input_positions,
kv_prefix_start_loc=kv_prefix_start_loc,
kv_prefix_start_loc_host=kv_prefix_start_loc_host,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=False,
# Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables,
enable_kv_scales_calculation=False,
use_cascade=self.use_cascade)
return self._cached_prefill_metadata
@property
def decode_metadata(self) -> Optional["KunlunMetadata"]:
"""decode_metadata"""
if self.num_decode_tokens == 0:
return None
if self._cached_decode_metadata is not None:
# Recover cached decode-phase attention
# metadata structure
return self._cached_decode_metadata
assert ((self.seq_lens_tensor is not None)
or (self.encoder_seq_lens_tensor is not None))
if self.num_prefills != 0:
# Compute some attn_metadata fields which default to None
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[:-self.num_prefill_tokens])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[:-self.num_prefills])
seq_lens_tensor_cpu = (None if self.seq_lens_tensor_cpu is None else
self.seq_lens_tensor_cpu[:-self.num_prefills])
block_tables = (None if self.block_tables is None else
self.block_tables[:-self.num_prefills])
else:
# Compute some attn_metadata fields which default to None
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping)
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor)
seq_lens_tensor_cpu = (None if self.seq_lens_tensor_cpu is None else
self.seq_lens_tensor_cpu)
block_tables = (None if self.block_tables is None else
self.block_tables)
# Construct & cache decode-phase attention metadata structure
self._cached_decode_metadata = KunlunMetadata(
num_actual_tokens=self.num_actual_tokens,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping,
seq_lens_tensor=seq_lens_tensor,
seq_lens_tensor_cpu=seq_lens_tensor_cpu,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
block_tables=block_tables,
use_cuda_graph=self.use_cuda_graph,
# Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables,
enable_kv_scales_calculation=False,
use_cascade=self.use_cascade)
return self._cached_decode_metadata
class KunlunAttentionMetadataBuilder:
"""KunlunAttentionMetadataBuilder"""
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
reorder_batch_threshold: ClassVar[Optional[int]] = 1
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
"""__init__"""
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.parallel_config = vllm_config.parallel_config
self.compilation_config = vllm_config.compilation_config
self.num_heads_q = self.model_config.get_num_attention_heads(
self.parallel_config)
self.num_heads_kv = self.model_config.get_num_kv_heads(
self.parallel_config)
self.headdim = self.model_config.get_head_size()
self.block_size = kv_cache_spec.block_size
self.kv_cache_spec = kv_cache_spec
self.device = device
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
"""reorder_batch"""
decodes = []
prefills = []
num_decode_tokens = 0
num_prefill_tokens = 0
for i, req_id in enumerate(input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
# TODO: how if a prefilled sequence has only one token
if num_tokens == 1:
decodes.append(i)
num_decode_tokens += num_tokens
else:
prefills.append(i)
num_prefill_tokens += num_tokens
num_decodes = len(decodes)
num_prefills = len(prefills)
first_prefill = 0
modified_batch = False
for i in range(1, min(num_decodes, num_prefills) + 1):
if decodes[num_decodes - i] >= num_decodes:
input_batch.swap_states(prefills[first_prefill],
decodes[num_decodes - i])
first_prefill += 1
modified_batch = True
else:
break
self._num_decodes = num_decodes
self._num_prefills = num_prefills
self._num_decode_tokens = num_decode_tokens
self._num_prefill_tokens = num_prefill_tokens
return modified_batch
def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
"""build"""
num_reqs=common_attn_metadata.num_reqs
num_actual_tokens=common_attn_metadata.num_actual_tokens
max_query_len=common_attn_metadata.max_query_len
common_prefix_len=common_prefix_len
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
query_start_loc_host = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1]
query_start_loc = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1].to(
self.device, non_blocking=True)
seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
seq_start_loc = list(accumulate(seq_lens, initial=0))
if len(seq_start_loc) != num_reqs + 1:
seq_start_loc = query_start_loc_host.tolist()
if seq_start_loc[-1] != num_actual_tokens:
seq_start_loc = query_start_loc_host.tolist()
seq_start_loc_tensor = torch.empty(len(seq_start_loc), dtype=torch.int32, device=self.device)
seq_start_loc_tensor.copy_(torch.as_tensor(seq_start_loc, dtype=torch.int32))
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
split_decodes_and_prefills(common_attn_metadata)
num_scheduled_tokens = np.diff(common_attn_metadata.query_start_loc_cpu[:num_reqs + 1])
tmp_decode_scheduled_tokens = num_scheduled_tokens[:num_decodes]
if num_decode_tokens == 0:
max_decode_seq_len = 0
else:
max_decode_seq_len = np.max(tmp_decode_scheduled_tokens)
tmp_prefill_scheduled_tokens = num_scheduled_tokens[num_decodes: num_reqs]
if num_prefill_tokens == 0:
max_prefill_seq_len = 0
else:
max_prefill_seq_len = np.max(tmp_prefill_scheduled_tokens)
use_cascade = False
attn_metadata = KunlunMetadata(
num_actual_tokens=num_actual_tokens,
num_prefills=num_prefills,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens_tensor=seq_lens,
seq_lens_tensor_cpu=seq_lens_cpu,
max_query_len=max_prefill_seq_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc,
query_start_loc_host=query_start_loc_host,
context_lens_tensor=None,
block_tables=block_table_tensor,
use_cuda_graph=False,
use_cascade=use_cascade,
)
return attn_metadata
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
"""can_run_in_cudagraph"""
# Full CUDA Graph always supported (FA2 support checked separately)
return True
def use_cascade_attention(self, *args, **kwargs) -> bool:
"""use_cascade_attention"""
return use_cascade_attention(*args, **kwargs)
class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
"""KunlunAttentionImpl"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
kv_sharing_target_layer_name: Optional[str] = None,
attn_type: AttentionType = AttentionType.DECODER,
use_irope: bool = False,
sinks:Optional[torch.Tensor]= None,
) -> None:
"""__init__"""
if blocksparse_params is not None:
raise ValueError(
"kunlunAttention does not support block-sparse attention.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.sliding_window = sliding_window
self.kv_cache_dtype = kv_cache_dtype
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.use_irope = use_irope
suppored_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in suppored_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.")
self.sinks = sinks
if sinks is not None:
assert sinks.shape[0] == num_heads, (
"Sinks must have the same number of heads as the number of "
f"heads in the layer. Sinks shape: {sinks.shape}, "
f"num_heads: {num_heads}.")
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: Optional[torch.Tensor],
value: Optional[torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: Optional[KunlunMetadata],
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""forward"""
query = query.view(-1, self.num_heads, self.head_size)
if output is None:
output = torch.empty_like(query)
if attn_metadata is None:
# Profiling run.
return output.view(-1, self.num_heads * self.head_size)
if key is not None:
assert value is not None
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
else:
assert value is None
# Self-attention vs. cross-attention will impact
# which KV cache memory-mapping & which
# seqlen datastructures we utilize
if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
# KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not
# during encoder attention.
#
# Even if there are no new key/value pairs to cache,
# we still need to break out key_cache and value_cache
# i.e. for later use by paged attention
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
if (key is not None) and (value is not None):
updated_slot_mapping = attn_metadata.slot_mapping
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory
value = value.contiguous()
xtorch_ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
updated_slot_mapping)
assert attn_type == AttentionType.DECODER
# Decoder self-attention supports chunked prefill.
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
# Only enforce this shape-constraint for decoder
# self-attention
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
prefill_query = query[num_decode_tokens:attn_metadata.num_actual_tokens]
prefill_key = key[num_decode_tokens:attn_metadata.num_actual_tokens]
prefill_value = value[num_decode_tokens:attn_metadata.num_actual_tokens]
assert prefill_query.shape[0] == num_prefill_tokens
output[num_decode_tokens:attn_metadata.num_actual_tokens] = KunlunOps.multi_query_kv_attention(
prefill_meta.query_start_loc,prefill_meta.query_start_loc_host, prefill_query, prefill_key, prefill_value,
alibi_slopes=self.alibi_slopes).view_as(prefill_query)
if decode_meta := attn_metadata.decode_metadata:
assert attn_type != AttentionType.ENCODER_ONLY, (
"Encoder-only models should not have decode metadata.")
decode_query = query[:num_decode_tokens]
xtorch_ops.paged_attention(
x=decode_query,
k_cache=key_cache,
v_cache=value_cache,
block_tables=decode_meta.block_tables,
context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
context_lens_xpu=decode_meta.seq_lens_tensor,
is_context=False,
is_causal=True,
out=output[:num_decode_tokens],
vo_head_dim=self.head_size
)
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
def use_cascade_attention(
common_prefix_len: int,
query_lens: np.ndarray,
num_query_heads: int,
num_kv_heads: int,
use_alibi: bool,
use_sliding_window: bool,
num_sms: int,
use_local_attention: bool = False,
) -> bool:
"""
TODO: Not Yet Supported on Kunlun platform
"""
# Too short common prefix. Probably not worth using cascade attention.
# We use an arbitrary threshold of 256 tokens. TODO: Tune this threshold.
# NOTE(woosuk): This is the common case. We should return False as soon as
# possible to avoid any unnecessary computation.
if common_prefix_len < 256:
return False
# Cascade attention is currently not supported with these variants.
if use_alibi or use_sliding_window or use_local_attention:
return False
# Too few queries. Probably not worth using cascade attention.
# We use an arbitrary threshold of 8 queries. TODO: Tune this threshold.
num_reqs = len(query_lens)
if num_reqs < 8:
return False
# Heuristics to decide whether using cascade attention is beneficial.
# 1. When FlashDecoding is not used for normal attention, cascade attention
# is likely to be faster since it saves memory bandwidth.
num_queries_per_kv = num_query_heads // num_kv_heads
# The criteria for using FlashDecoding can be found in the following link:
# https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535
use_flash_decoding = (num_queries_per_kv > 1 and not use_sliding_window
and not use_alibi and np.all(query_lens == 1))
if not use_flash_decoding:
# Use cascade attention.
return True
# 2. When FlashDecoding is used for normal attention, it is not clear
# whether cascade attention is beneficial, because FlashDecoding can
# launch more CTAs than cascade attention.
# We use a simple performance model to compare the two methods.
# NOTE(woosuk): The performance model is very rough and may not be
# accurate.
num_tokens = num_reqs
# NOTE(woosuk): These are default tile sizes. flash-attn might use
# different tile sizes (e.g., 64 or 256) depending on the configuration.
q_tile_size = 128
kv_tile_size = 128
num_prefix_tiles = cdiv(common_prefix_len, kv_tile_size)
cascade_ctas = num_query_heads * cdiv(num_tokens, q_tile_size)
cascade_waves = cdiv(cascade_ctas, num_sms)
cascade_time = cascade_waves * num_prefix_tiles
flash_decoding_ctas = (num_reqs * num_kv_heads *
cdiv(num_queries_per_kv, q_tile_size))
flash_decoding_ctas *= num_prefix_tiles
flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)
# Use cascade attention if it is faster than FlashDecoding.
return cascade_time < flash_decoding_time

View File

View File

View File

@@ -0,0 +1,91 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
def get_token_bin_counts_and_mask(
tokens: torch.Tensor,
vocab_size: int,
num_seqs: int,
) -> tuple[torch.Tensor, torch.Tensor]:
# Compute the bin counts for the tokens.
# vocab_size + 1 for padding.
bin_counts = torch.zeros((num_seqs, vocab_size + 1),
dtype=torch.long,
device=tokens.device)
bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
bin_counts = bin_counts[:, :vocab_size]
mask = bin_counts > 0
return bin_counts, mask
def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
output_tokens_tensor: torch.Tensor,
presence_penalties: torch.Tensor,
frequency_penalties: torch.Tensor,
repetition_penalties: torch.Tensor) -> torch.Tensor:
"""
Applies penalties in place to the logits tensor
logits : The input logits tensor of shape [num_seqs, vocab_size]
prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts
are padded to the maximum prompt length within the batch using
`vocab_size` as the padding value. The value `vocab_size` is used
for padding because it does not correspond to any valid token ID
in the vocabulary.
output_tokens_tensor: The output tokens tensor.
presence_penalties: The presence penalties of shape (num_seqs, )
frequency_penalties: The frequency penalties of shape (num_seqs, )
repetition_penalties: The repetition penalties of shape (num_seqs, )
"""
num_seqs, vocab_size = logits.shape
_, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor,
vocab_size, num_seqs)
output_bin_counts, output_mask = get_token_bin_counts_and_mask(
output_tokens_tensor, vocab_size, num_seqs)
# Apply repetition penalties as a custom op
from vllm._custom_ops import apply_repetition_penalties_torch
apply_repetition_penalties_torch(logits, prompt_mask, output_mask,
repetition_penalties)
# We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
logits -= presence_penalties.unsqueeze(dim=1) * output_mask
return logits
def apply_all_penalties(
logits: torch.Tensor,
prompt_token_ids: torch.Tensor,
presence_penalties: torch.Tensor,
frequency_penalties: torch.Tensor,
repetition_penalties: torch.Tensor,
output_token_ids: list[list[int]],
) -> torch.Tensor:
"""
Applies presence, frequency and repetition penalties to the logits.
"""
_, vocab_size = logits.shape
output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size,
logits.device)
return apply_penalties(logits, prompt_token_ids, output_tokens_t,
presence_penalties, frequency_penalties,
repetition_penalties)
def _convert_to_tensors(output_token_ids: list[list[int]], vocab_size: int,
device: torch.device) -> torch.Tensor:
"""
Convert the different list data structures to tensors.
"""
output_tokens_tensor = make_tensor_with_pad(
output_token_ids,
# Use the value of vocab_size as a pad since we don't have a
# token_id of this value.
pad=vocab_size,
device="cpu",
dtype=torch.int64,
pin_memory=is_pin_memory_available(),
)
return output_tokens_tensor.to(device, non_blocking=True)

View File

@@ -0,0 +1,198 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
import torch.nn as nn
from packaging import version
from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
import xtorch_ops
logger = init_logger(__name__)
class TopKTopPSampler(nn.Module):
"""
Module that performs optional top-k and top-p filtering followed by
weighted random sampling of logits.
Implementations may update the logits tensor in-place.
"""
def __init__(self):
super().__init__()
logger.info_once(
"Using FlashInfer for top-p & top-k sampling.")
self.forward = self.forward_kunlun
def forward_native(
self,
logits: torch.Tensor,
generators: dict[int, torch.Generator],
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
) -> torch.Tensor:
"""
PyTorch-native implementation of top-k and top-p sampling.
The logits tensor may be updated in-place.
"""
logits = apply_top_k_top_p(logits, k, p)
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators)
def forward_kunlun(
self,
logits: torch.Tensor,
generators: dict[int, torch.Generator],
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
) -> torch.Tensor:
"""More optimized implementation for top-k and top-p sampling."""
if k is None and p is None:
# We prefer `random_sample` over `flashinfer_sample` when sorting is
# not needed. This is because `random_sample` does not require
# CPU-GPU synchronization while `flashinfer_sample` does.
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators)
if generators:
logger.warning_once("FlashInfer 0.2.3+ does not support "
"per-request generators. Falling back to "
"PyTorch-native implementation.")
return self.forward_native(logits, generators, k, p)
# flashinfer sampling functions expect contiguous logits.
# In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
# because of slicing operation in logits_processor.
return flashinfer_sample(logits.contiguous(), k, p, generators)
def apply_top_k_top_p(
logits: torch.Tensor,
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
) -> torch.Tensor:
"""Apply top-k and top-p masks to the logits.
If a top-p is used, this function will sort the logits tensor,
which can be slow for large batches.
The logits tensor may be updated in-place.
"""
if p is None:
if k is None:
return logits
# Avoid sorting vocab for top-k only case.
return apply_top_k_only(logits, k)
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
if k is not None:
# Apply top-k.
top_k_mask = logits_sort.size(1) - k.to(torch.long) # shape: B
# Get all the top_k values.
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
top_k_mask = logits_sort < top_k_mask
logits_sort.masked_fill_(top_k_mask, -float("inf"))
if p is not None:
# Apply top-p.
probs_sort = logits_sort.softmax(dim=-1)
probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
# at least one
top_p_mask[:, -1] = False
logits_sort.masked_fill_(top_p_mask, -float("inf"))
# Re-sort the probabilities.
logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
return logits
def apply_top_k_only(
logits: torch.Tensor,
k: torch.Tensor,
) -> torch.Tensor:
"""
Apply top-k mask to the logits.
This implementation doesn't involve sorting the entire vocab.
The logits tensor may be updated in-place.
"""
no_top_k_mask = k == logits.shape[1]
# Set non-top-k rows to 1 so that we can gather.
k = k.masked_fill(no_top_k_mask, 1)
max_top_k = k.max()
# topk.values tensor has shape [batch_size, max_top_k].
# Convert top k to 0-based index in range [0, max_top_k).
k_index = k.sub_(1).unsqueeze(1)
top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
# Handle non-topk rows.
top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
logits.masked_fill_(logits < top_k_mask, -float("inf"))
return logits
def random_sample(
probs: torch.Tensor,
generators: dict[int, torch.Generator],
) -> torch.Tensor:
"""Randomly sample from the probabilities.
We use this function instead of torch.multinomial because torch.multinomial
causes CPU-GPU synchronization.
"""
q = torch.empty_like(probs)
# NOTE(woosuk): To batch-process the requests without their own seeds,
# which is the common case, we first assume that every request does
# not have its own seed. Then, we overwrite the values for the requests
# that have their own seeds.
if len(generators) != probs.shape[0]:
q.exponential_()
if generators:
# TODO(woosuk): This can be slow because we handle each request
# one by one. Optimize this.
for i, generator in generators.items():
q[i].exponential_(generator=generator)
return probs.div_(q).argmax(dim=-1).view(-1)
def flashinfer_sample(
logits: torch.Tensor,
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
generators: dict[int, torch.Generator],
) -> torch.Tensor:
"""Sample from the logits using FlashInfer.
Statistically, this function is equivalent to the `random_sample` function.
However, this function is faster because it avoids sorting the logits tensor
via rejection sampling.
NOTE: The outputs of this function do not necessarily match the outputs of
the `random_sample` function. It only guarantees that the outputs are
statistically equivalent.
NOTE: This function includes CPU-GPU synchronization, while `random_sample`
does not. Call this function at the end of the forward pass to minimize
the synchronization overhead.
"""
assert not (k is None and p is None)
probs = logits.softmax(dim=-1, dtype=torch.float32)
if k is None:
# Top-p only.
next_token_ids = xtorch_ops.top_p_sampling_from_probs(
probs,top_p=p, deterministic=True)
elif p is None:
# Top-k only.
next_token_ids = xtorch_ops.top_k_sampling_from_probs(
probs, top_k=k, deterministic=True)
else:
# Both top-k and top-p.
next_token_ids = xtorch_ops.top_k_top_p_sampling_from_probs(
probs, top_k=k, top_p=p, deterministic=True)
return next_token_ids.view(-1)

View File

View File

@@ -0,0 +1,174 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.logger import init_logger
from vllm.utils import cdiv
logger = init_logger(__name__)
class BlockTable:
def __init__(
self,
block_size: int,
max_num_reqs: int,
max_num_blocks_per_req: int,
max_num_batched_tokens: int,
pin_memory: bool,
device: torch.device,
):
self.block_size = block_size
self.max_num_reqs = max_num_reqs
self.max_num_blocks_per_req = max_num_blocks_per_req
self.max_num_batched_tokens = max_num_batched_tokens
self.pin_memory = pin_memory
self.device = device
self.block_table = torch.zeros(
(max_num_reqs, max_num_blocks_per_req),
device=self.device,
dtype=torch.int32,
)
self.block_table_cpu = torch.zeros(
(max_num_reqs, max_num_blocks_per_req),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.block_table_np = self.block_table_cpu.numpy()
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
self.slot_mapping = torch.zeros(self.max_num_batched_tokens,
dtype=torch.int32,
device=self.device)
def append_row(
self,
block_ids: list[int],
row_idx: int,
) -> None:
if not block_ids:
return
num_blocks = len(block_ids)
start = self.num_blocks_per_row[row_idx]
self.num_blocks_per_row[row_idx] += num_blocks
self.block_table_np[row_idx, start:start + num_blocks] = block_ids
def add_row(self, block_ids: list[int], row_idx: int) -> None:
self.num_blocks_per_row[row_idx] = 0
self.append_row(block_ids, row_idx)
def move_row(self, src: int, tgt: int) -> None:
num_blocks = self.num_blocks_per_row[src]
self.block_table_np[tgt, :num_blocks] = self.block_table_np[
src, :num_blocks]
self.num_blocks_per_row[tgt] = num_blocks
def swap_row(self, src: int, tgt: int) -> None:
num_blocks_src = self.num_blocks_per_row[src]
num_blocks_tgt = self.num_blocks_per_row[tgt]
self.num_blocks_per_row[src] = num_blocks_tgt
self.num_blocks_per_row[tgt] = num_blocks_src
self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]]
def compute_slot_mapping(self, req_indices: np.ndarray,
positions: np.ndarray) -> None:
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2.
# NOTE(woosuk): We can't simply use `token_indices // block_size`
# here because M (max_model_len) is not necessarily divisible by
# block_size.
block_table_indices = (req_indices * self.max_num_blocks_per_req +
positions // self.block_size)
block_numbers = self.block_table_np.ravel()[block_table_indices]
block_offsets = positions % self.block_size
np.add(block_numbers * self.block_size,
block_offsets,
out=self.slot_mapping_np[:req_indices.shape[0]])
def commit_block_table(self, num_reqs: int) -> None:
self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
non_blocking=True)
def commit_slot_mapping(self, num_tokens: int) -> None:
self.slot_mapping[:num_tokens].copy_(
self.slot_mapping_cpu[:num_tokens], non_blocking=True)
def clear(self) -> None:
self.block_table.fill_(0)
self.block_table_cpu.fill_(0)
def get_device_tensor(self) -> torch.Tensor:
"""Ruturns the device tensor of the block table."""
return self.block_table
def get_cpu_tensor(self) -> torch.Tensor:
"""Returns the CPU tensor of the block table."""
return self.block_table_cpu
def get_numpy_array(self) -> np.ndarray:
"""Returns the numpy array of the block table."""
return self.block_table_np
class MultiGroupBlockTable:
"""The BlockTables for each KV cache group."""
def __init__(self, max_num_reqs: int, max_model_len: int,
max_num_batched_tokens: int, pin_memory: bool,
device: torch.device, block_sizes: list[int]) -> None:
self.block_tables = [
BlockTable(block_size, max_num_reqs, cdiv(max_model_len,
block_size),
max_num_batched_tokens, pin_memory, device)
for block_size in block_sizes
]
def append_row(self, block_ids: tuple[list[int], ...],
row_idx: int) -> None:
for i, block_table in enumerate(self.block_tables):
block_table.append_row(block_ids[i], row_idx)
def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
for i, block_table in enumerate(self.block_tables):
block_table.add_row(block_ids[i], row_idx)
def move_row(self, src: int, tgt: int) -> None:
for block_table in self.block_tables:
block_table.move_row(src, tgt)
def swap_row(self, src: int, tgt: int) -> None:
for block_table in self.block_tables:
block_table.swap_row(src, tgt)
def compute_slot_mapping(self, req_indices: np.ndarray,
positions: np.ndarray) -> None:
for block_table in self.block_tables:
block_table.compute_slot_mapping(req_indices, positions)
def commit_block_table(self, num_reqs: int) -> None:
for block_table in self.block_tables:
block_table.commit_block_table(num_reqs)
def commit_slot_mapping(self, num_tokens: int) -> None:
for block_table in self.block_tables:
block_table.commit_slot_mapping(num_tokens)
def clear(self) -> None:
for block_table in self.block_tables:
block_table.clear()
def __getitem__(self, idx: int) -> "BlockTable":
"""Returns the BlockTable for the i-th KV cache group."""
return self.block_tables[idx]