support glm4.7 mtp (#187)
Signed-off-by: chengxiaokang <chengxiaokang@baidu.com>
Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
@@ -14,39 +14,53 @@
# limitations under the License.
# This file is a part of the vllm-kunlun project.
#
from vllm.config import VllmConfig, get_layers_from_vllm_config
import xtorch_ops
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, ClassVar, Tuple, Type, TYPE_CHECKING
from itertools import accumulate
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
List,
Optional,
Tuple,
Type,
TypeVar,
)

import torch
import numpy as np
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionLayer, AttentionType)
import torch
import xtorch_ops
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionLayer,
AttentionMetadata,
AttentionType,
)
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
CommonAttentionMetadata,
split_decodes_and_prefills,
)

# from vllm.attention.backends.utils import CommonAttentionState
# from vllm.attention.backends.utils import is_block_tables_empty, compute_slot_mapping_start_idx, compute_slot_mapping
from vllm_kunlun.ops.paged_attn import (PagedAttention, PagedAttentionMetadata)
from vllm_kunlun.ops._kunlun_ops import KunlunOps
from vllm_kunlun.ops.paged_attn import PagedAttention, PagedAttentionMetadata

from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
AttentionCGSupport,
split_decodes_and_prefills)
from vllm.forward_context import ForwardContext, get_forward_context
from itertools import accumulate
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

import inspect

from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable

from vllm.config import VllmConfig, get_layers_from_vllm_config
import inspect

class KunlunAttentionBackend(AttentionBackend):
"""KunlunAttentionBackend"""

# crucial to cuda graph
accept_output_buffer = True

@@ -81,12 +95,13 @@ class KunlunAttentionBackend(AttentionBackend):
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto"
cache_dtype_str: str = "auto",
) -> Tuple[int, ...]:
"""get_kv_cache_shape"""
# return (2, num_blocks, block_size, num_kv_heads * head_size)
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)
return PagedAttention.get_kv_cache_shape(
num_blocks, block_size, num_kv_heads, head_size
)
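# Illustrative sizing (hypothetical values, using the commented-out layout above):
# num_blocks=1024, block_size=16, num_kv_heads=8, head_size=128 would give a
# combined key/value cache of shape (2, 1024, 16, 1024) per layer.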

@staticmethod
def swap_blocks(
@@ -104,13 +119,12 @@ class KunlunAttentionBackend(AttentionBackend):
) -> None:
"""copy_blocks"""
raise NotImplementedError



@dataclass
class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
"""KunlunMetadata"""


# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
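# Illustrative reading of the diagram (hypothetical lengths): if tokenA spans 7
# tokens processed in iteration N-1 and newTokens spans 3 tokens processed in
# iteration N, the request's total sequence length at iteration N is 10.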
@@ -133,7 +147,7 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool


slot_mapping: torch.Tensor
block_tables: torch.Tensor

@@ -203,11 +217,13 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
use_cascade: Optional[bool] = False

seq_lens_tensor_cpu: Optional[torch.Tensor] = None


num_prefill_tokens: int = 0
num_decode_tokens: int = 0
num_prefills: int = 0
num_decodes: int = 0
is_speculative: Optional[bool] = False
max_model_len: int = 0
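# Illustrative usage (as wired up by the builder below): is_speculative is set to
# True when reorder_batch_threshold > 1, i.e. when MTP draft tokens make decode
# steps carry more than one query token per request.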

def __post_init__(self):
"""__post_init__"""
@@ -218,16 +234,20 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
@property
def is_all_encoder_attn_metadata_set(self):
"""is_all_encoder_attn_metadata_set"""
return ((self.encoder_seq_lens is not None)
and (self.encoder_seq_lens_tensor is not None)
and (self.max_encoder_seq_len is not None))
return (
(self.encoder_seq_lens is not None)
and (self.encoder_seq_lens_tensor is not None)
and (self.max_encoder_seq_len is not None)
)

@property
def is_all_cross_attn_metadata_set(self):
"""is_all_cross_attn_metadata_set"""
return (self.is_all_encoder_attn_metadata_set
and (self.cross_slot_mapping is not None)
and (self.cross_block_tables is not None))
return (
self.is_all_encoder_attn_metadata_set
and (self.cross_slot_mapping is not None)
and (self.cross_block_tables is not None)
)

@property
def prefill_metadata(self) -> Optional["KunlunMetadata"]:
@@ -240,35 +260,60 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
# metadata structure
return self._cached_prefill_metadata

assert ((self.seq_lens_tensor is not None)
or (self.encoder_seq_lens_tensor is not None))
assert (self.seq_lens_tensor is not None) or (
self.encoder_seq_lens_tensor is not None
)

# Compute some attn_metadata fields which default to None
query_start_loc = (None if self.query_start_loc is None else
self.query_start_loc[-(self.num_prefills + 1):] - self.query_start_loc[-(self.num_prefills + 1)])
query_start_loc = (
None
if self.query_start_loc is None
else self.query_start_loc[-(self.num_prefills + 1) :]
- self.query_start_loc[-(self.num_prefills + 1)]
)
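# Illustrative rebase (hypothetical values): with query_start_loc = [0, 1, 2, 7, 10]
# and num_prefills = 2, the slice [-3:] is [2, 7, 10]; subtracting its first element
# yields [0, 5, 8], i.e. prefill offsets restarted from zero.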
# flash attention needs both lod information on host and device
query_start_loc_host = (None if self.query_start_loc_host is None else
self.query_start_loc_host[-(self.num_prefills + 1):] - self.query_start_loc_host[-(self.num_prefills + 1)])

query_start_loc_host = (
None
if self.query_start_loc_host is None
else self.query_start_loc_host[-(self.num_prefills + 1) :]
- self.query_start_loc_host[-(self.num_prefills + 1)]
)

# TODO(chengruichang):how to support prefix cache
kv_prefix_start_loc_host = None
kv_prefix_start_loc = None
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[-self.num_prefill_tokens:])
slot_mapping = (
None
if self.slot_mapping is None
else self.slot_mapping[-self.num_prefill_tokens :]
)

seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[-self.num_prefills:])
seq_lens = (None if self.seq_lens is None else self.seq_lens[-self.num_prefills:])
seq_lens_tensor = (
None
if self.seq_lens_tensor is None
else self.seq_lens_tensor[-self.num_prefills :]
)
seq_lens = (
None if self.seq_lens is None else self.seq_lens[-self.num_prefills :]
)

context_lens_tensor = (None if self.context_lens_tensor is None else
self.context_lens_tensor[-self.num_prefills:])

block_tables = (None if self.block_tables is None else
self.block_tables[-self.num_prefills:])
input_positions = (None if self.input_positions is None else
self.input_positions[-self.num_prefills:])
context_lens_tensor = (
None
if self.context_lens_tensor is None
else self.context_lens_tensor[-self.num_prefills :]
)

block_tables = (
None
if self.block_tables is None
else self.block_tables[-self.num_prefills :]
)
input_positions = (
None
if self.input_positions is None
else self.input_positions[-self.num_prefills :]
)


if self.kv_lod_cpu is None:
kv_lod_cpu = None
kv_lod_xpu = None
@@ -280,19 +325,17 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
base_xpu = self.kv_lod_xpu[start]
kv_lod_xpu = self.kv_lod_xpu[start:] - base_xpu


# Construct & cache prefill-phase attention metadata structure
self._cached_prefill_metadata = KunlunMetadata(
num_actual_tokens=self.num_actual_tokens,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
multi_modal_placeholder_index_maps=self.multi_modal_placeholder_index_maps,
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=slot_mapping,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
seq_start_loc = None,
seq_start_loc=None,
kv_lod_cpu=kv_lod_cpu,
kv_lod_xpu=kv_lod_xpu,
max_query_len=self.max_query_len,
@@ -314,7 +357,9 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables,
enable_kv_scales_calculation=False,
use_cascade=self.use_cascade)
use_cascade=self.use_cascade,
is_speculative=self.is_speculative,
)
return self._cached_prefill_metadata

@property
@@ -327,40 +372,47 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
# Recover cached decode-phase attention
# metadata structure
return self._cached_decode_metadata
assert ((self.seq_lens_tensor is not None)
or (self.encoder_seq_lens_tensor is not None))
assert (self.seq_lens_tensor is not None) or (
self.encoder_seq_lens_tensor is not None
)

if self.num_prefills != 0:
# Compute some attn_metadata fields which default to None
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[:-self.num_prefill_tokens])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[:-self.num_prefills])
seq_lens_tensor_cpu = (None if self.seq_lens_tensor_cpu is None else
self.seq_lens_tensor_cpu[:-self.num_prefills])

block_tables = (None if self.block_tables is None else
self.block_tables[:-self.num_prefills])
slot_mapping = (
None
if self.slot_mapping is None
else self.slot_mapping[: -self.num_prefill_tokens]
)
seq_lens_tensor = (
None
if self.seq_lens_tensor is None
else self.seq_lens_tensor[: -self.num_prefills]
)
seq_lens_tensor_cpu = (
None
if self.seq_lens_tensor_cpu is None
else self.seq_lens_tensor_cpu[: -self.num_prefills]
)
block_tables = (
None
if self.block_tables is None
else self.block_tables[: -self.num_prefills]
)
else:
# Compute some attn_metadata fields which default to None
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping)
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor)

seq_lens_tensor_cpu = (None if self.seq_lens_tensor_cpu is None else
self.seq_lens_tensor_cpu)


block_tables = (None if self.block_tables is None else
self.block_tables)

slot_mapping = None if self.slot_mapping is None else self.slot_mapping
seq_lens_tensor = (
None if self.seq_lens_tensor is None else self.seq_lens_tensor
)
seq_lens_tensor_cpu = (
None if self.seq_lens_tensor_cpu is None else self.seq_lens_tensor_cpu
)
block_tables = None if self.block_tables is None else self.block_tables

# Construct & cache decode-phase attention metadata structure
self._cached_decode_metadata = KunlunMetadata(
num_actual_tokens=self.num_actual_tokens,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
multi_modal_placeholder_index_maps=self.multi_modal_placeholder_index_maps,
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
@@ -378,19 +430,29 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables,
enable_kv_scales_calculation=False,
use_cascade=self.use_cascade)
use_cascade=self.use_cascade,
is_speculative=self.is_speculative,
max_model_len=self.max_model_len,
)
return self._cached_decode_metadata


M = TypeVar("M")


class KunlunAttentionMetadataBuilder:
"""KunlunAttentionMetadataBuilder"""
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
reorder_batch_threshold: ClassVar[Optional[int]] = 1

def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
"""__init__"""
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
@@ -398,17 +460,45 @@ class KunlunAttentionMetadataBuilder:
self.compilation_config = vllm_config.compilation_config

self.num_heads_q = self.model_config.get_num_attention_heads(
self.parallel_config)
self.num_heads_kv = self.model_config.get_num_kv_heads(
self.parallel_config)
self.parallel_config
)
self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config)
self.headdim = self.model_config.get_head_size()

self.block_size = kv_cache_spec.block_size
self.kv_cache_spec = kv_cache_spec
self.device = device

def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
def _init_reorder_batch_threshold(
self,
reorder_batch_threshold: int | None = 1,
supports_spec_as_decode: bool = False,
supports_dcp_with_varlen: bool = False,
) -> None:
self.reorder_batch_threshold = reorder_batch_threshold
if self.reorder_batch_threshold is not None and supports_spec_as_decode:
# If the backend supports spec-as-decode kernels, then we can set
# the reorder_batch_threshold based on the number of speculative
# tokens from the config.
speculative_config = self.vllm_config.speculative_config
if (
speculative_config is not None
and speculative_config.num_speculative_tokens is not None
):
self.reorder_batch_threshold = max(
self.reorder_batch_threshold,
1 + speculative_config.num_speculative_tokens,
)

if (
self.vllm_config.parallel_config.decode_context_parallel_size > 1
and not supports_dcp_with_varlen
):
self.reorder_batch_threshold = 1
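# Illustrative example (assumed config): with speculative decoding enabled and
# num_speculative_tokens = 1 (a single MTP draft token, as for GLM-style MTP),
# the threshold becomes max(1, 1 + 1) = 2, so requests with up to 2 query tokens
# are still treated as decodes.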

def reorder_batch(
self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
) -> bool:
"""reorder_batch"""
decodes = []
prefills = []
@@ -432,8 +522,9 @@ class KunlunAttentionMetadataBuilder:

for i in range(1, min(num_decodes, num_prefills) + 1):
if decodes[num_decodes - i] >= num_decodes:
input_batch.swap_states(prefills[first_prefill],
decodes[num_decodes - i])
input_batch.swap_states(
prefills[first_prefill], decodes[num_decodes - i]
)
first_prefill += 1
modified_batch = True
else:
@@ -443,7 +534,7 @@ class KunlunAttentionMetadataBuilder:
self._num_decode_tokens = num_decode_tokens
self._num_prefill_tokens = num_prefill_tokens
return modified_batch
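# Illustrative reordering (hypothetical batch): with per-request query lengths
# [9, 1, 1] the leading prefill is swapped with a trailing decode so decodes end
# up in the front slots before attention metadata is built.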


def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
) -> KunlunMetadata:
@@ -454,8 +545,30 @@ class KunlunAttentionMetadataBuilder:
attn_metadata.seq_lens_tensor.fill_(1)
return attn_metadata

def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
def build_for_drafting(
self,
common_attn_metadata: CommonAttentionMetadata,
draft_index: int,
) -> M:
"""
Build attention metadata for draft model. Uses build by default.

Args:
common_attn_metadata: The common attention metadata.
draft_index: The index of the current draft operation.
When speculating a chain of tokens, this index refers to the
draft attempt for the i-th token.
For tree-based attention, this index instead refers to the
draft attempt for the i-th level in the tree of tokens.
"""
return self.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)

def build(
self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata
):
"""build"""
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
@@ -464,30 +577,38 @@ class KunlunAttentionMetadataBuilder:
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping


max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
query_start_loc_host = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1]
query_start_loc = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1].to(
self.device, non_blocking=True)

query_start_loc_host = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1]
query_start_loc = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1].to(
self.device, non_blocking=True
)

seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu


seq_start_loc = list(accumulate(seq_lens, initial=0))



seq_start_loc_tensor = torch.empty(len(seq_start_loc), dtype=torch.int32, device=self.device)
seq_start_loc_tensor = torch.empty(
len(seq_start_loc), dtype=torch.int32, device=self.device
)
seq_start_loc_tensor.copy_(torch.as_tensor(seq_start_loc, dtype=torch.int32))

kv_lod_cpu = torch.zeros(num_reqs + 1, dtype=torch.int32, device="cpu")
kv_lod_cpu[1:] = seq_lens_cpu.to(torch.int32).cumsum(dim=0)
kv_lod_xpu = kv_lod_cpu.to(self.device)
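# Illustrative lod (hypothetical lengths): for per-request KV lengths [4, 2, 6]
# the cumulative kv_lod becomes [0, 4, 6, 12]; the CPU copy feeds kernel setup
# while the device copy is consumed by the attention op itself.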

num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
split_decodes_and_prefills(common_attn_metadata)

num_scheduled_tokens = np.diff(common_attn_metadata.query_start_loc_cpu[:num_reqs + 1])
self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold or 1,
require_uniform=True,
)
)
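# Illustrative split (hypothetical batch): with decode_threshold = 2 and query
# lengths [2, 2, 2, 9], the first three requests count as decodes
# (num_decodes = 3, num_decode_tokens = 6) and the last one as a prefill
# (num_prefills = 1, num_prefill_tokens = 9).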

num_scheduled_tokens = np.diff(
common_attn_metadata.query_start_loc_cpu[: num_reqs + 1]
)
tmp_decode_scheduled_tokens = num_scheduled_tokens[:num_decodes]

if num_decode_tokens == 0:
@@ -495,18 +616,19 @@ class KunlunAttentionMetadataBuilder:
else:
max_decode_seq_len = np.max(tmp_decode_scheduled_tokens)

tmp_prefill_scheduled_tokens = num_scheduled_tokens[num_decodes: num_reqs]

tmp_prefill_scheduled_tokens = num_scheduled_tokens[num_decodes:num_reqs]

if num_prefill_tokens == 0:
max_prefill_seq_len = 0
else:
max_prefill_seq_len = np.max(tmp_prefill_scheduled_tokens)


use_cascade = False

attn_metadata = KunlunMetadata(
num_actual_tokens=num_actual_tokens,
num_prefills=num_prefills,
num_decodes=num_decodes,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
@@ -525,11 +647,14 @@ class KunlunAttentionMetadataBuilder:
block_tables=block_table_tensor,
use_cuda_graph=False,
use_cascade=use_cascade,
is_speculative=self.reorder_batch_threshold > 1,
max_model_len=self.vllm_config.model_config.max_model_len,
)
return attn_metadata

def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
self, common_attn_metadata: CommonAttentionMetadata
) -> bool:
"""can_run_in_cudagraph"""
# Full CUDA Graph always supported (FA2 support checked separately)
return True
@@ -538,6 +663,7 @@ class KunlunAttentionMetadataBuilder:
"""use_cascade_attention"""
return use_cascade_attention(*args, **kwargs)


class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
"""KunlunAttentionImpl"""

@@ -555,13 +681,12 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
kv_sharing_target_layer_name: Optional[str] = None,
attn_type: AttentionType = AttentionType.DECODER,
use_irope: bool = False,
sinks:Optional[torch.Tensor]= None,
multi_modal_placeholder_index_maps:Optional[torch.Tensor]= None,
sinks: Optional[torch.Tensor] = None,
multi_modal_placeholder_index_maps: Optional[torch.Tensor] = None,
) -> None:
"""__init__"""
if blocksparse_params is not None:
raise ValueError(
"kunlunAttention does not support block-sparse attention.")
raise ValueError("kunlunAttention does not support block-sparse attention.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
@@ -582,15 +707,17 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
if head_size not in suppored_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.")
f"Supported head sizes are: {suppored_head_sizes}."
)

self.sinks = sinks
if sinks is not None:
assert sinks.shape[0] == num_heads, (
"Sinks must have the same number of heads as the number of "
f"heads in the layer. Sinks shape: {sinks.shape}, "
f"num_heads: {num_heads}.")
self.multi_modal_placeholder_index_maps = multi_modal_placeholder_index_maps
f"num_heads: {num_heads}."
)
self.multi_modal_placeholder_index_maps = multi_modal_placeholder_index_maps

def forward(
self,
@@ -605,7 +732,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
attn_type: AttentionType = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""forward"""
query = query.view(-1, self.num_heads, self.head_size)
@@ -624,7 +751,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
# Self-attention vs. cross-attention will impact
# which KV cache memory-mapping & which
# seqlen datastructures we utilize
if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
if attn_type != AttentionType.ENCODER and kv_cache.numel() > 0:
# KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not
# during encoder attention.
@@ -633,7 +760,8 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
# we still need to break out key_cache and value_cache
# i.e. for later use by paged attention
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
kv_cache, self.num_kv_heads, self.head_size
)

if (key is not None) and (value is not None):
updated_slot_mapping = attn_metadata.slot_mapping
@@ -644,11 +772,12 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
value = value.contiguous()
if key_cache.is_contiguous():
xtorch_ops.reshape_and_cache(
key,
value,
key[: attn_metadata.num_actual_tokens],
value[: attn_metadata.num_actual_tokens],
key_cache,
value_cache,
updated_slot_mapping)
updated_slot_mapping,
)
else:
cast_key_cache = key_cache.squeeze(1).unsqueeze(-2)
cast_value_cache = value_cache.squeeze(1).unsqueeze(-2)
@@ -657,7 +786,8 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
value,
cast_key_cache,
cast_value_cache,
updated_slot_mapping)
updated_slot_mapping,
)

assert attn_type == AttentionType.DECODER
# Decoder self-attention supports chunked prefill.
@@ -668,88 +798,98 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):

if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
prefill_query = query[num_decode_tokens:attn_metadata.num_actual_tokens]
prefill_key = key[num_decode_tokens:attn_metadata.num_actual_tokens]
prefill_value = value[num_decode_tokens:attn_metadata.num_actual_tokens]
prefill_query = query[num_decode_tokens : attn_metadata.num_actual_tokens]
prefill_key = key[num_decode_tokens : attn_metadata.num_actual_tokens]
prefill_value = value[num_decode_tokens : attn_metadata.num_actual_tokens]

# For hybrid Attention (Qwen3-Next.)
if key_cache.is_contiguous():
tmp_block_tables = prefill_meta.block_tables
else:
# For hybrid Attention (Qwen3-Next)
tmp_block_tables = prefill_meta.block_tables * 2

tmp_block_tables = prefill_meta.block_tables * 2

# Prefix cache
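# Illustrative condition (hypothetical lod values): if the query lod ends at 3 but
# kv_lod ends at 10, then 7 prefix tokens are already resident in the paged KV
# cache, so this branch reads K/V from the cache through the block table.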
if prefill_meta.query_start_loc_host[-1] != prefill_meta.kv_lod_cpu[-1]:
xtorch_ops.prefill_attention(
q=prefill_query,
k=key_cache, # Key Cache [block_num, head, block_size, dim]
k=key_cache, # Key Cache [block_num, head, block_size, dim]
v=value_cache,
out=output[num_decode_tokens:attn_metadata.num_actual_tokens],
out=output[num_decode_tokens : attn_metadata.num_actual_tokens],
is_causal=True,
is_prefix_cache=True,
block_table=tmp_block_tables,
is_prefix_cache=True,
block_table=tmp_block_tables,
context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
context_qlen_lod_xpu=prefill_meta.query_start_loc,
context_kvlen_lod_cpu=prefill_meta.kv_lod_cpu,
context_kvlen_lod_xpu=prefill_meta.kv_lod_xpu,
alibi_slopes=self.alibi_slopes,
softmax_lse=None
softmax_lse=None,
)
else:
xtorch_ops.prefill_attention(
q=prefill_query,
k=prefill_key,
v=prefill_value,
out=output[num_decode_tokens:attn_metadata.num_actual_tokens],
out=output[num_decode_tokens : attn_metadata.num_actual_tokens],
is_causal=True,
context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
context_qlen_lod_xpu=prefill_meta.query_start_loc,
alibi_slopes=self.alibi_slopes,
softmax_lse=None,
swa_left = self.sliding_window if self.sliding_window is not None else -1,
swa_right = 0 if self.sliding_window is not None else -1,
sink = self.sinks.to(torch.float32) if self.sinks is not None else None
softmax_lse=None,
swa_left=(
self.sliding_window if self.sliding_window is not None else -1
),
swa_right=0 if self.sliding_window is not None else -1,
sink=(
self.sinks.to(torch.float32) if self.sinks is not None else None
),
)
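# Illustrative window (hypothetical size): with self.sliding_window = 4096 the
# call above passes swa_left=4096 and swa_right=0, so each query token attends
# only to itself and at most the 4096 tokens before it.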


if decode_meta := attn_metadata.decode_metadata:
assert attn_type != AttentionType.ENCODER_ONLY, (
"Encoder-only models should not have decode metadata.")
if decode_meta := attn_metadata.decode_metadata:
assert (
attn_type != AttentionType.ENCODER_ONLY
), "Encoder-only models should not have decode metadata."
decode_query = query[:num_decode_tokens]

# For hybrid Attention (Qwen3-Next
if key_cache.is_contiguous():
tmp_block_tables = decode_meta.block_tables
else:
tmp_block_tables = decode_meta.block_tables * 2 # only test in Qwen3-Next

tmp_block_tables = (
decode_meta.block_tables * 2
) # only test in Qwen3-Next

sig = inspect.signature(xtorch_ops.speculative_attention)
if "max_window_size" in sig.parameters:
xtorch_ops.speculative_attention(
out=output[:num_decode_tokens],
# Only MLA support q len > 1 right now
q=decode_query.unsqueeze(0),
k_cache=key_cache,
v_cache=value_cache,
context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
context_lens_xpu=decode_meta.seq_lens_tensor,
batch_num=decode_meta.block_tables.shape[0],
# TODO (@xyDong23): Support MTP(q lens >1)
qlen=1,
# TODO (@xyDong23): Support max_context_len to (262144)
max_context_len=131072,
head_num=self.num_heads,
head_dim=self.head_size,
scale=0.0,
kv_head_num=self.num_kv_heads,
block_size=key_cache.shape[2],
max_num_blocks_per_seq=decode_meta.block_tables.shape[1],
max_window_size=self.sliding_window if self.sliding_window is not None else -1,
block_tables=tmp_block_tables,
sink = self.sinks.to(torch.float32) if self.sinks is not None else None
# Only MLA support q len > 1 right now
q=decode_query.unsqueeze(0),
k_cache=key_cache,
v_cache=value_cache,
context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
context_lens_xpu=decode_meta.seq_lens_tensor,
batch_num=decode_meta.block_tables.shape[0],
# TODO (@xyDong23): Support MTP(q lens >1)
qlen=1,
# TODO (@xyDong23): Support max_context_len to (262144)
max_context_len=131072,
head_num=self.num_heads,
head_dim=self.head_size,
scale=0.0,
kv_head_num=self.num_kv_heads,
block_size=key_cache.shape[2],
max_num_blocks_per_seq=decode_meta.block_tables.shape[1],
max_window_size=(
self.sliding_window if self.sliding_window is not None else -1
),
block_tables=tmp_block_tables,
sink=(
self.sinks.to(torch.float32) if self.sinks is not None else None
),
)
else:
elif not attn_metadata.is_speculative:
xtorch_ops.paged_attention(
x=decode_query,
k_cache=key_cache,
@@ -760,10 +900,38 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
is_context=False,
is_causal=True,
out=output[:num_decode_tokens],
vo_head_dim=self.head_size
)
vo_head_dim=self.head_size,
)
else:
batch_size = attn_metadata.num_decodes
query_seq_len, head_num, head_dim = decode_query.shape
assert query_seq_len % batch_size == 0
qlen = query_seq_len // batch_size
out = output[:num_decode_tokens]
assert out.is_contiguous()

xtorch_ops.speculative_attention(
out=out.view(batch_size, qlen, head_num, self.head_size),
q=decode_query.view(batch_size, qlen, head_num, head_dim),
k_cache=key_cache,
v_cache=value_cache,
context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
context_lens_xpu=decode_meta.seq_lens_tensor,
batch_num=batch_size,
qlen=qlen,
max_context_len=decode_meta.max_model_len,
head_num=self.num_heads,
head_dim=self.head_size,
scale=0.0,
kv_head_num=self.num_kv_heads,
block_size=key_cache.shape[2],
max_num_blocks_per_seq=decode_meta.block_tables.shape[1],
block_tables=tmp_block_tables,
)
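# Illustrative shapes (hypothetical batch): with num_decodes = 4 requests and
# qlen = 2 (one verified token plus one MTP draft token per request), decode_query
# of shape [8, head_num, head_dim] is viewed as [4, 2, head_num, head_dim] for the
# speculative attention call above.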
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)


def use_cascade_attention(
common_prefix_len: int,
query_lens: np.ndarray,
@@ -785,7 +953,7 @@ def use_cascade_attention(
# NOTE(woosuk): This is the common case. We should return False as soon as
# possible to avoid any unnecessary computation.
return False


if common_prefix_len < 256:
return False
# Cascade attention is currently not supported with these variants.
@@ -803,8 +971,12 @@ def use_cascade_attention(
num_queries_per_kv = num_query_heads // num_kv_heads
# The criteria for using FlashDecoding can be found in the following link:
# https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535
use_flash_decoding = (num_queries_per_kv > 1 and not use_sliding_window
and not use_alibi and np.all(query_lens == 1))
use_flash_decoding = (
num_queries_per_kv > 1
and not use_sliding_window
and not use_alibi
and np.all(query_lens == 1)
)
if not use_flash_decoding:
# Use cascade attention.
return True
@@ -826,10 +998,11 @@ def use_cascade_attention(
cascade_waves = cdiv(cascade_ctas, num_sms)
cascade_time = cascade_waves * num_prefix_tiles

flash_decoding_ctas = (num_reqs * num_kv_heads *
cdiv(num_queries_per_kv, q_tile_size))
flash_decoding_ctas = (
num_reqs * num_kv_heads * cdiv(num_queries_per_kv, q_tile_size)
)
flash_decoding_ctas *= num_prefix_tiles
flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)
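# Illustrative arithmetic (hypothetical sizes): with num_reqs = 32, num_kv_heads = 8,
# num_queries_per_kv = 4 and q_tile_size = 128, flash_decoding_ctas starts as
# 32 * 8 * cdiv(4, 128) = 256, is then scaled by num_prefix_tiles, and the resulting
# flash_decoding_time is compared against cascade_time below.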

# Use cascade attention if it is faster than FlashDecoding.
return cascade_time < flash_decoding_time
return cascade_time < flash_decoding_time