# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar

import numpy as np
import torch

from vllm import _custom_ops as ops
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
    get_mla_dims,
)
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.utils.platform_utils import num_compute_units
from vllm.v1.attention.backend import (
    AttentionBackend,
    AttentionCGSupport,
    AttentionLayer,
    AttentionMetadata,
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
    MultipleOf,
    SparseMLAAttentionImpl,
)
from vllm.v1.attention.backends.mla.sparse_utils import (
    triton_convert_req_index_to_global_index,
)
from vllm.v1.attention.backends.utils import (
    reshape_attn_output_for_spec_decode,
    reshape_query_for_spec_decode,
    split_decodes_and_prefills,
    split_prefill_chunks,
)
from vllm.v1.attention.ops.flashmla import (
    FlashMLASchedMeta,
    flash_mla_sparse_fwd,
    flash_mla_with_kvcache,
    get_mla_metadata,
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.workspace import current_workspace_manager

if TYPE_CHECKING:
    from vllm.model_executor.models.deepseek_v2 import Indexer

logger = init_logger(__name__)

# For FP8 sparse attention we have two implementations:
# 1. Mixed batch mode: use the FP8 decode kernel for both prefill and decode;
#    this is done by treating all tokens as a single batch.
# 2. Separate prefill and decode mode: use the BF16 prefill kernel for prefill
#    (upconverting the FP8 cache to BF16, then calling the prefill kernel) and
#    the FP8 decode kernel for decode.
# Currently we use #1 when the number of heads per rank is low (i.e. high TP),
# since the BF16 prefill kernel requires padding the number of heads to 128
# while the decode kernel does not; so when the per-rank head count is below
# MIN_HEADS_FOR_BF16_PREFILL we use the mixed batch mode (#1).
MIN_HEADS_FOR_BF16_PREFILL = 32
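
# Illustrative example of the dispatch above (hypothetical shapes, not
# measured): a model with 128 query heads at TP=8 has 16 heads per rank
# (< 32), so the mixed batch mode (#1) is selected; at TP=2 it has 64 heads
# per rank (>= 32), so the separate prefill/decode mode (#2) is used.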

"""
NOTE: FlashMLA Sparse uses an fp8 cache with the following format

In the "FP8 with scale" format, each token's KV cache is 656 Bytes,
structured as:
- **First 512 bytes:** The "quantized NoPE" part, containing 512
  `float8_e4m3` values.
- **Next 16 bytes:** Scale factors, containing 4 `float32` values.
  The first `float32` is the scale for the first 128 `float8_e4m3` values,
  the second for the next 128, and so on.
- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
  part is not quantized for accuracy.
"""


class FlashMLASparseBackend(AttentionBackend):
    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "auto",
        "bfloat16",
        "fp8_ds_mla",
    ]

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        return [64]

    @staticmethod
    def get_name() -> str:
        return "FLASHMLA_SPARSE"

    @staticmethod
    def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]:
        return FlashMLASparseMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["FlashMLASparseImpl"]:
        return FlashMLASparseImpl

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [576]

    @classmethod
    def is_mla(cls) -> bool:
        return True

    @classmethod
    def is_sparse(cls) -> bool:
        return True

    @classmethod
    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
        return capability.major in [9, 10]

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        if cache_dtype_str == "fp8_ds_mla":
            # custom storage format is 656 bytes per token
            # see the FlashMLA README for details
            return (num_blocks, block_size, 656)
        else:
            return (num_blocks, block_size, head_size)


@dataclass
class FlashMLASparseMetadata(AttentionMetadata):
    num_reqs: int
    max_query_len: int
    max_seq_len: int

    num_actual_tokens: int  # Number of tokens excluding padding.
    query_start_loc: torch.Tensor
    slot_mapping: torch.Tensor

    block_table: torch.Tensor
    req_id_per_token: torch.Tensor
    block_size: int = 64
    topk_tokens: int = 2048

    @dataclass
    class FP8KernelMetadata:
        scheduler_metadata: FlashMLASchedMeta
        dummy_block_table: torch.Tensor
        cache_lens: torch.Tensor

    @dataclass
    class FP8SeparatePrefillDecode:
        @dataclass
        class Decode:
            kernel_metadata: "FlashMLASparseMetadata.FP8KernelMetadata"
            decode_query_len: int  # needed for reshape in spec decode

        @dataclass
        class Prefill:
            # Sequence lengths (context + query) for prefill requests
            # Shape: [num_prefill_reqs]
            seq_lens: torch.Tensor

            # Request ID for each token: -1 for decode tokens, request index
            # (0, 1, 2, ...) for prefill tokens.
            # Shape: [num_actual_tokens]
            request_ids: torch.Tensor

            # Workspace start offsets for all prefill requests.
            # Shape: [num_prefill_reqs], adjusted in-place per chunk to be
            # 0-indexed within each chunk. Used to map prefill tokens to
            # workspace offsets in convert_logical_index_to_physical_index.
            workspace_starts: torch.Tensor

            @dataclass
            class Chunk:
                """Metadata for a chunk of prefill requests.

                Prefill requests may be chunked to fit within the fixed
                workspace size.
                """

                seq_lens: torch.Tensor
                tokens_slice: slice
                block_table: torch.Tensor
                req_start_idx: int
                workspace_starts: torch.Tensor
                chunk_tot_seqlen: int

            chunks: list[Chunk]

        num_prefills: int = 0
        num_decodes: int = 0
        num_prefill_tokens: int = 0
        num_decode_tokens: int = 0

        decode: Decode | None = None
        prefill: Prefill | None = None

    fp8_extra_metadata: FP8SeparatePrefillDecode | FP8KernelMetadata | None = None
    fp8_use_mixed_batch: bool = False


def get_prefill_workspace_size(max_model_len: int) -> int:
    # Returns the workspace size in tokens (rows of 576 bf16 values).
    # NOTE(Lucas): 5 is a magic number for controlling the prefill buffer
    # size. May be tuned later.
    # Memory usage: 5 * max_model_len * 576 * 2 bytes
    # Example: DeepSeek-V3.2 with max_model_len=163840 ->
    #   5 * 163840 * 576 * 2 = ~900 MB
    # This fits nicely below the typical MoE workspace size of >2GB, so this
    # is "free".
    return max_model_len * 5


class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]):
    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ) -> None:
        self.vllm_config = vllm_config
        self.layer_names = layer_names
        cache_config = vllm_config.cache_config
        self.kv_cache_spec = kv_cache_spec
        self.model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        self.device = device

        # Treat requests with query length <= 1 as decodes to match the
        # DeepGEMM indexer constraint (fp8_paged_mqa_logits only supports
        # next_n <= 2).
        self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)

        sm_count = num_compute_units(device.index)

        self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
        self.mla_dims = get_mla_dims(self.model_config)
        # The FP8 decode kernel only supports h_q = 64 or 128, so we need to
        # pad.
        self.fp8_decode_padded_heads = (
            FlashMLASparseImpl._compute_fp8_decode_padded_heads(self.num_heads)
        )

        self.topk_tokens = vllm_config.model_config.hf_config.index_topk
        self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla"
        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        # Shape: [max_num_seqs], all elements = topk_tokens (constant for
        # full-CG)
        self.topk_tokens_tensor = torch.full(
            (max_num_seqs,), self.topk_tokens, device=device, dtype=torch.int32
        )
        # Shape: [max_num_seqs], all elements = max_model_len
        self.max_model_len_tensor = torch.full(
            (max_num_seqs,),
            self.model_config.max_model_len,
            device=device,
            dtype=torch.int32,
        )
        # this is ignored by `flash_mla_with_kvcache` if indices is not None
        self.dummy_block_table = torch.empty(
            (max_num_seqs, 1), dtype=torch.int32, device=self.device
        )

        # Equation taken from FlashMLA/csrc/api/sparse_decode.h
        # For sparse FP8 decode, the formula depends on the architecture:
        # - SM90 (Hopper): num_sm_parts = num_sms / s_q / (h_q/64)
        # - SM100 (Blackwell head64/head64x2): num_sm_parts = num_sms / s_q
        # - SM100 (Blackwell head128): num_sm_parts = num_sms / s_q / 2
        # For the max buffer size, use s_q = 1 (the case that produces the
        # largest output). Use the padded head count since that's what will
        # be passed to the kernel.
        h_q = self.fp8_decode_padded_heads
        if current_platform.is_device_capability_family(100):
            # SM100 head64 or head64x2 uses the full SM count
            max_num_sm_parts = sm_count
        else:
            # SM90 uses the h_q/64 divisor
            max_num_sm_parts = sm_count // max(1, h_q // 64)
        self.tile_scheduler_metadata_buffer = torch.empty(
            # TileSchedulerMetaDataSize = 8
            # see: FlashMLA/csrc/params.h
            (max_num_sm_parts, 8),
            dtype=torch.int32,
            device=device,
        )
        # Sized for per-request batching (num_decodes + 1)
        self.num_splits_buffer = torch.empty(
            (max_num_seqs + 1,),
            dtype=torch.int32,
            device=device,
        )
        self.req_id_per_token_buffer = torch.empty(
            (vllm_config.scheduler_config.max_num_batched_tokens,),
            dtype=torch.int32,
            device=device,
        )
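
    # Worked example of the sizing above (illustrative; assumes a GPU with
    # 132 SMs, e.g. an H100 SXM): on SM90 with padded h_q = 128,
    # max_num_sm_parts = 132 // (128 // 64) = 66; with padded h_q = 64,
    # max_num_sm_parts = 132 // 1 = 132. On SM100 the full SM count is used.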

    def _build_fp8_mixed_decode_prefill(
        self,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> "FlashMLASparseMetadata.FP8KernelMetadata":
        """Build FP8 metadata treating all tokens as one mixed batch.

        This matches main branch's approach and avoids the BF16 prefill
        kernel, which has head padding overhead when num_heads is small
        (high TP case).
        """
        num_tokens = common_attn_metadata.num_actual_tokens

        # Use the padded head count since that's what the kernel will see
        padded_heads = self.fp8_decode_padded_heads

        # Build metadata for all tokens as a single batch
        scheduler_metadata, _ = get_mla_metadata(
            cache_seqlens=self.topk_tokens_tensor[:1],  # Single batch
            num_q_tokens_per_head_k=num_tokens * padded_heads,
            topk=self.topk_tokens,
            num_heads_q=padded_heads,
            num_heads_k=1,
            is_fp8_kvcache=True,
        )

        fp8_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
            scheduler_metadata=scheduler_metadata,
            cache_lens=self.max_model_len_tensor[:1],
            dummy_block_table=self.dummy_block_table[:1],
        )

        return fp8_metadata

    def _build_fp8_separate_prefill_decode(
        self,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> "FlashMLASparseMetadata.FP8SeparatePrefillDecode":
        num_tokens = common_attn_metadata.num_actual_tokens

        (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = (
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold or 1,
                require_uniform=True,
            )
        )

        FP8Meta = FlashMLASparseMetadata.FP8SeparatePrefillDecode
        fp8_metadata = FP8Meta(
            num_decodes=num_decodes,
            num_prefills=num_prefills,
            num_decode_tokens=num_decode_tokens,
            num_prefill_tokens=num_prefill_tokens,
        )

        # Extract prefill sequence lengths (context + query, not just query).
        # Decode requests come first in the batch; prefill requests follow.
        prefill_seq_lens = None
        prefill_request_id = None
        prefill_workspace_starts = None
        prefill_chunks = None

        # For pure decode batches, prefill_request_id will be None.
        # For mixed batches, it will have -1 for decode tokens and the
        # request index for prefill tokens.
        if num_prefills > 0:
            seq_lens_cpu = common_attn_metadata.seq_lens.cpu()
            seq_lens = common_attn_metadata.seq_lens
            query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu

            prefill_seq_lens_cpu = seq_lens_cpu[num_decodes:]
            prefill_seq_lens = seq_lens[num_decodes:]

            # Build prefill_request_id: -1 for decode, request index for
            # prefill. This enables a single
            # convert_logical_index_to_physical_index call for all tokens.
            prefill_request_id = torch.full(
                (num_tokens,), -1, dtype=torch.int32, device=self.device
            )
            # Map prefill tokens to their request IDs (0, 1, 2, ...)
            for req_idx in range(num_prefills):
                # Get the query token range for this prefill request
                global_req_idx = num_decodes + req_idx
                req_query_start = query_start_loc_cpu[global_req_idx]
                req_query_end = query_start_loc_cpu[global_req_idx + 1]
                prefill_request_id[req_query_start:req_query_end] = req_idx

            # will be adjusted by the chunk loop below
            prefill_workspace_starts_cpu = torch.zeros(
                num_prefills, dtype=torch.int32, pin_memory=True
            )
            prefill_workspace_starts_cpu[1:] = torch.cumsum(
                prefill_seq_lens_cpu[:-1], dim=0
            )
            # populated by a non-blocking copy after
            # prefill_workspace_starts_cpu is updated by each chunk
            prefill_workspace_starts = torch.empty(
                num_prefills, dtype=torch.int32, device=self.device
            )

            # Chunk prefill requests to fit within the workspace size
            max_prefill_buffer_size = get_prefill_workspace_size(
                self.vllm_config.model_config.max_model_len
            )
            chunk_bounds = split_prefill_chunks(
                prefill_seq_lens_cpu, max_prefill_buffer_size
            )

            prefill_chunks = []
            for chunk_start, chunk_end in chunk_bounds:
                # Adjust workspace_starts in-place per chunk to be
                # 0-indexed within each chunk.
                # Example: seq_lens=[10,15,20,5], chunks=[[0,2],[2,4]]
                #   Initial: workspace_starts=[0,10,25,45]
                #   After:   workspace_starts=[0,10,0,20]
                #   (chunk 0 starts at 0, chunk 1 starts at 0)
                offset = prefill_workspace_starts_cpu[chunk_start].item()
                prefill_workspace_starts_cpu[chunk_start:chunk_end] -= offset

                chunk_seq_lens = prefill_seq_lens[chunk_start:chunk_end]
                chunk_tot_seqlen = prefill_seq_lens_cpu[chunk_start:chunk_end].sum()
                token_start = query_start_loc_cpu[num_decodes + chunk_start].item()
                token_end = query_start_loc_cpu[num_decodes + chunk_end].item()
                tokens_slice = slice(token_start, token_end)

                # Create chunk views of the GPU tensors
                chunk_workspace_starts = prefill_workspace_starts[chunk_start:chunk_end]
                chunk_block_table = common_attn_metadata.block_table_tensor[
                    num_decodes + chunk_start : num_decodes + chunk_end
                ]

                prefill_chunks.append(
                    FP8Meta.Prefill.Chunk(
                        seq_lens=chunk_seq_lens,
                        tokens_slice=tokens_slice,
                        block_table=chunk_block_table,
                        req_start_idx=chunk_start,
                        workspace_starts=chunk_workspace_starts,
                        chunk_tot_seqlen=chunk_tot_seqlen,
                    )
                )

            prefill_workspace_starts.copy_(
                prefill_workspace_starts_cpu, non_blocking=True
            )

            fp8_metadata.prefill = FP8Meta.Prefill(
                seq_lens=prefill_seq_lens,
                request_ids=prefill_request_id,
                workspace_starts=prefill_workspace_starts,
                chunks=prefill_chunks,
            )

        if num_decodes > 0:
            # Compute decode_query_len for spec decode (uniform due to
            # require_uniform)
            query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
            decode_query_len = (query_start_loc_cpu[1] - query_start_loc_cpu[0]).item()

            # Use the padded head count since that's what the kernel will see
            padded_heads = self.fp8_decode_padded_heads
            scheduler_metadata, _ = get_mla_metadata(
                cache_seqlens=self.topk_tokens_tensor[:num_decodes],
                num_q_tokens_per_head_k=decode_query_len * padded_heads,
                topk=self.topk_tokens,
                num_heads_q=padded_heads,
                num_heads_k=1,
                is_fp8_kvcache=True,
            )

            kernel_meta = FlashMLASparseMetadata.FP8KernelMetadata(
                scheduler_metadata=scheduler_metadata,
                dummy_block_table=self.dummy_block_table[:num_decodes],
                cache_lens=self.max_model_len_tensor[:num_decodes],
            )
            fp8_metadata.decode = FP8Meta.Decode(
                kernel_metadata=kernel_meta,
                decode_query_len=decode_query_len,
            )

        return fp8_metadata

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> FlashMLASparseMetadata:
        cm = common_attn_metadata
        num_tokens = cm.num_actual_tokens
        starts = np.asarray(cm.query_start_loc_cpu, dtype=np.int32)
        seg_lengths = np.diff(starts)
        req_id_per_token = np.repeat(
            np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths
        )
        # Zero-fill for cudagraphs
        self.req_id_per_token_buffer.fill_(0)
        self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
            torch.from_numpy(req_id_per_token), non_blocking=True
        )
        req_id_per_token = self.req_id_per_token_buffer[:num_tokens]

        fp8_extra_metadata: (
            FlashMLASparseMetadata.FP8SeparatePrefillDecode
            | FlashMLASparseMetadata.FP8KernelMetadata
            | None
        ) = None
        fp8_use_mixed_batch = self.num_heads < MIN_HEADS_FOR_BF16_PREFILL
        if self.use_fp8_kv_cache:
            if fp8_use_mixed_batch:
                fp8_extra_metadata = self._build_fp8_mixed_decode_prefill(cm)
            else:
                fp8_extra_metadata = self._build_fp8_separate_prefill_decode(cm)

        metadata = FlashMLASparseMetadata(
            num_reqs=cm.num_reqs,
            max_query_len=cm.max_query_len,
            max_seq_len=cm.max_seq_len,
            num_actual_tokens=cm.num_actual_tokens,
            query_start_loc=cm.query_start_loc,
            slot_mapping=cm.slot_mapping,
            block_table=cm.block_table_tensor,
            req_id_per_token=req_id_per_token,
            block_size=self.kv_cache_spec.block_size,
            topk_tokens=self.topk_tokens,
            fp8_extra_metadata=fp8_extra_metadata,
            fp8_use_mixed_batch=fp8_use_mixed_batch,
        )

        return metadata


class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
    @staticmethod
    def _compute_fp8_decode_padded_heads(num_heads: int) -> int:
        # The FP8 decode kernel only supports h_q = 64 or 128, so compute the
        # padded head count for decode.
        return 64 if num_heads <= 64 else 128

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        topk_indice_buffer: torch.Tensor | None = None,
        indexer: "Indexer | None" = None,
        **mla_args,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype
        self.kv_lora_rank: int = mla_args["kv_lora_rank"]
        self.softmax_scale = scale
        assert indexer is not None
        self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
        # The prefill BF16 kernel requires a head-count multiple of 64 on
        # Hopper and 128 on Blackwell.
        self.prefill_padding = (
            128 if current_platform.is_device_capability_family(100) else 64
        )
        self.fp8_decode_padded_heads = self._compute_fp8_decode_padded_heads(num_heads)

        if kv_cache_dtype == "fp8_ds_mla":
            # Reserve workspace during initialization
            vllm_config = get_current_vllm_config()
            assert vllm_config is not None and vllm_config.model_config is not None
            prefill_workspace_size = get_prefill_workspace_size(
                vllm_config.model_config.max_model_len
            )
            self.prefill_workspace_shape = (prefill_workspace_size, head_size)
            (self.prefill_bf16_workspace,) = (
                current_workspace_manager().get_simultaneous(
                    (self.prefill_workspace_shape, torch.bfloat16)
                )
            )

    def _forward_bf16_kv(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
        attn_metadata: FlashMLASparseMetadata,
    ) -> torch.Tensor:
        # Convert per-request indices to global cache slots (the BF16 cache
        # path never uses the prefill workspace).
        topk_indices = triton_convert_req_index_to_global_index(
            attn_metadata.req_id_per_token,
            attn_metadata.block_table,
            topk_indices,
            BLOCK_SIZE=attn_metadata.block_size,
            NUM_TOPK_TOKENS=topk_indices.shape[1],
        )

        return self._bf16_flash_mla_kernel(q, kv_c_and_k_pe_cache, topk_indices)

    def _forward_fp8_kv_separate_prefill_decode(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
        attn_metadata: FlashMLASparseMetadata,
    ) -> torch.Tensor:
        fp8_metadata = attn_metadata.fp8_extra_metadata
        assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeparatePrefillDecode)
        num_decodes = fp8_metadata.num_decodes

        prefill_request_ids = None
        prefill_workspace_starts = None
        has_prefill_workspace = False
        if fp8_metadata.prefill is not None:
            prefill_request_ids = fp8_metadata.prefill.request_ids
            prefill_workspace_starts = fp8_metadata.prefill.workspace_starts
            has_prefill_workspace = True

        # Convert per-request indices to global slots (decode) or workspace
        # offsets (prefill).
        # For the FP8 cache, prefill uses the workspace mapping (the cache is
        # upconverted to BF16); decode uses global cache slots.
        # prefill_workspace_starts has been adjusted in-place per chunk, so
        # prefill indices automatically come out chunk-local.
        topk_indices = triton_convert_req_index_to_global_index(
            attn_metadata.req_id_per_token,
            attn_metadata.block_table,
            topk_indices,
            BLOCK_SIZE=attn_metadata.block_size,
            NUM_TOPK_TOKENS=topk_indices.shape[1],
            HAS_PREFILL_WORKSPACE=has_prefill_workspace,
            prefill_workspace_request_ids=prefill_request_ids,
            prefill_workspace_starts=prefill_workspace_starts,
        )

        def _fp8_decode(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor:
            # Reshape q: (num_decode_tokens, num_heads, head_dim)
            # -> (num_decodes, seq_len, num_heads, head_dim)
            q = reshape_query_for_spec_decode(q, num_decodes)
            seq_len = q.shape[1]
            # Reshape topk_indices: (num_decode_tokens, topk)
            # -> (num_decodes, seq_len, topk)
            topk_indices = topk_indices.view(num_decodes, seq_len, -1)
            assert fp8_metadata.decode is not None
            attn_out, _ = self._fp8_flash_mla_kernel(
                q=q,
                kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
                topk_indices=topk_indices,
                kernel_metadata=fp8_metadata.decode.kernel_metadata,
            )
            # Reshape output: (num_decodes, seq_len, num_heads, head_dim_v)
            # -> (num_decode_tokens, num_heads, head_dim_v)
            return reshape_attn_output_for_spec_decode(attn_out)

        num_decode_tokens = fp8_metadata.num_decode_tokens
        num_prefill_tokens = fp8_metadata.num_prefill_tokens

        # Pure decode: direct call without allocation
        if num_decode_tokens > 0 and num_prefill_tokens == 0:
            assert fp8_metadata.decode is not None
            attn_out = _fp8_decode(q, topk_indices)
        else:
            # Mixed or pure prefill: allocate the output tensor
            attn_out = q.new_empty(
                (attn_metadata.num_actual_tokens, self.num_heads, self.kv_lora_rank),
                dtype=q.dtype,
                device=q.device,
            )

            if num_decode_tokens > 0:
                attn_out[:num_decode_tokens] = _fp8_decode(
                    q[:num_decode_tokens], topk_indices[:num_decode_tokens]
                )

            assert fp8_metadata.prefill is not None
            for chunk in fp8_metadata.prefill.chunks:
                # Gather this chunk's FP8 KV into the BF16 workspace
                chunk_workspace = self.prefill_bf16_workspace[: chunk.chunk_tot_seqlen]
                ops.cp_gather_and_upconvert_fp8_kv_cache(
                    kv_c_and_k_pe_cache,
                    chunk_workspace,
                    chunk.block_table,
                    chunk.seq_lens,
                    chunk.workspace_starts,
                    len(chunk.block_table),
                )

                chunk_q = q[chunk.tokens_slice]
                chunk_topk_indices_workspace = topk_indices[chunk.tokens_slice]

                attn_out[chunk.tokens_slice] = self._bf16_flash_mla_kernel(
                    chunk_q,
                    chunk_workspace,
                    chunk_topk_indices_workspace,
                )

        return attn_out

    def _forward_fp8_kv_mixed_batch(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
        attn_metadata: FlashMLASparseMetadata,
    ) -> torch.Tensor:
        """Mixed batch FP8 forward path that treats all tokens as one batch.

        This is equivalent to main branch's approach and avoids the BF16
        prefill kernel, which has head padding overhead when num_heads is
        small. Used when fp8_use_mixed_batch is True.
        """
        # Convert per-request indices to global cache slots (no prefill
        # workspace in the mixed batch path).
        topk_indices = triton_convert_req_index_to_global_index(
            attn_metadata.req_id_per_token,
            attn_metadata.block_table,
            topk_indices,
            BLOCK_SIZE=attn_metadata.block_size,
            NUM_TOPK_TOKENS=topk_indices.shape[1],
        )

        assert attn_metadata.fp8_extra_metadata is not None
        assert isinstance(
            attn_metadata.fp8_extra_metadata, FlashMLASparseMetadata.FP8KernelMetadata
        )
        fp8_metadata = attn_metadata.fp8_extra_metadata

        _attn_out, _ = self._fp8_flash_mla_kernel(
            q=q.unsqueeze(0),  # add a batch dim: (T, H, D) -> (1, T, H, D)
            kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
            topk_indices=topk_indices.unsqueeze(0),  # (T, topk) -> (1, T, topk)
            kernel_metadata=fp8_metadata,
        )

        # Output is (1, T, H, D_v); squeeze back to (T, H, D_v)
        return _attn_out.squeeze(0)

    def _fp8_flash_mla_kernel(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
        kernel_metadata: FlashMLASparseMetadata.FP8KernelMetadata,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # q shape: (batch, seq_len, num_heads, head_dim)
        actual_num_heads = q.size(2)
        padded_num_heads = self.fp8_decode_padded_heads

        # Pad the query if needed (the kernel only supports h_q = 64 or 128)
        if actual_num_heads < padded_num_heads:
            logger.warning_once(
                f"Padding num_heads from {actual_num_heads} to "
                f"{padded_num_heads} for FP8 sparse decode kernel"
            )
            q_padded = q.new_zeros((q.size(0), q.size(1), padded_num_heads, q.size(3)))
            q_padded[:, :, :actual_num_heads, :] = q
            q = q_padded

        out, lse = flash_mla_with_kvcache(
            q=q,
            k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2),
            block_table=kernel_metadata.dummy_block_table,
            head_dim_v=512,
            cache_seqlens=kernel_metadata.cache_lens,
            tile_scheduler_metadata=kernel_metadata.scheduler_metadata,
            is_fp8_kvcache=True,
            indices=topk_indices,
            softmax_scale=self.softmax_scale,
        )

        # Slice the output back to the actual head count if we padded
        if actual_num_heads < padded_num_heads:
            out = out[:, :, :actual_num_heads, :]

        return out, lse

    def _bf16_flash_mla_kernel(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
    ) -> torch.Tensor:
        num_tokens = q.shape[0]
        kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
            -1, 1, kv_c_and_k_pe_cache.shape[-1]
        )

        # NOTE(Chen): the kernel requires the local head count to be a
        # multiple of 64 on Hopper and 128 on Blackwell
        if self.num_heads % self.prefill_padding != 0:
            assert self.prefill_padding % self.num_heads == 0
            logger.warning_once(
                f"Padding num_heads from {self.num_heads} to "
                f"{self.prefill_padding} for BF16 sparse prefill kernel"
            )
            q_padded = q.new_empty((q.shape[0], self.prefill_padding, q.shape[2]))
            q_padded[:, : self.num_heads, :] = q
            q = q_padded

        topk_indices = topk_indices.view(num_tokens, 1, -1)
        output = flash_mla_sparse_fwd(
            q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale
        )[0]
        output = output[:, : self.num_heads, :]
        return output

    def forward_mqa(
        self,
        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashMLASparseMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        # NOTE(lucas): the sparse FlashMLA kernels use the MQA 576/512
        # approach for both prefill and decode

        # Concatenate q if it's a tuple (ql_nope, q_pe)
        if isinstance(q, tuple):
            q = torch.cat(q, dim=-1)

        num_actual_toks = q.shape[0]

        # Get topk indices
        assert self.topk_indices_buffer is not None
        topk_indices = self.topk_indices_buffer[:num_actual_toks]

        use_fp8_cache = self.kv_cache_dtype == "fp8_ds_mla"

        if not use_fp8_cache:
            attn_out = self._forward_bf16_kv(
                q, kv_c_and_k_pe_cache, topk_indices, attn_metadata
            )
        elif attn_metadata.fp8_use_mixed_batch:
            attn_out = self._forward_fp8_kv_mixed_batch(
                q, kv_c_and_k_pe_cache, topk_indices, attn_metadata
            )
        else:
            attn_out = self._forward_fp8_kv_separate_prefill_decode(
                q, kv_c_and_k_pe_cache, topk_indices, attn_metadata
            )

        return attn_out, None