update
This commit is contained in:
0
vllm_old/v1/attention/backends/mla/__init__.py
Normal file
0
vllm_old/v1/attention/backends/mla/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
2200
vllm_old/v1/attention/backends/mla/common.py
Normal file
2200
vllm_old/v1/attention/backends/mla/common.py
Normal file
File diff suppressed because it is too large
Load Diff
275
vllm_old/v1/attention/backends/mla/cutlass_mla.py
Normal file
275
vllm_old/v1/attention/backends/mla/cutlass_mla.py
Normal file
@@ -0,0 +1,275 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
|
||||
# enable full CUDA Graph support for decode-only capture
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = (
|
||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
)
|
||||
|
||||
|
||||
class CutlassMLABackend(MLACommonBackend):
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [128]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"fp8",
|
||||
"fp8_e4m3",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "CUTLASS_MLA"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["CutlassMLAImpl"]:
|
||||
return CutlassMLAImpl
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]:
|
||||
return CutlassMLAMetadataBuilder
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return capability.major == 10
|
||||
|
||||
|
||||
class SM100Workspace:
|
||||
def __init__(self, initial_workspace_size):
|
||||
self._workspace_buf = torch.empty(
|
||||
initial_workspace_size, device="cuda", dtype=torch.uint8
|
||||
)
|
||||
|
||||
self._block_size = 128 # Forced to 128
|
||||
|
||||
# Pre-compute sm_count to avoid recomputing it. Use device 0 as a proxy
|
||||
# (assumes all devices are similar)
|
||||
properties = torch.cuda.get_device_properties(torch.device("cuda:0"))
|
||||
self._sm_count = properties.multi_processor_count
|
||||
|
||||
def get_buf(self):
|
||||
return self._workspace_buf
|
||||
|
||||
def ensure_size(self, attn_metadata: MLACommonMetadata, num_kv_splits: int):
|
||||
batch_size = attn_metadata.num_reqs
|
||||
max_seq_len = attn_metadata.max_query_len
|
||||
|
||||
workspace_size = ops.sm100_cutlass_mla_get_workspace_size(
|
||||
max_seq_len * self._block_size,
|
||||
batch_size,
|
||||
self._sm_count,
|
||||
num_kv_splits=num_kv_splits,
|
||||
)
|
||||
|
||||
if self._workspace_buf.shape[0] < workspace_size:
|
||||
self._workspace_buf.resize_(workspace_size)
|
||||
|
||||
|
||||
g_sm100_workspace = SM100Workspace(128 * 1024 * 1024) # 128MB
|
||||
|
||||
MAX_HEADS = 128
|
||||
|
||||
|
||||
class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
can_return_lse_for_decode: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None,
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
# MLA Specific Arguments
|
||||
**mla_args,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
num_heads,
|
||||
head_size,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
kv_cache_dtype,
|
||||
logits_soft_cap,
|
||||
attn_type,
|
||||
kv_sharing_target_layer_name,
|
||||
q_pad_num_heads=MAX_HEADS,
|
||||
**mla_args,
|
||||
)
|
||||
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"CutlassMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, logits_soft_cap"
|
||||
)
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"CutlassMLAImpl"
|
||||
)
|
||||
|
||||
# TODO: Currently, num_kv_splits is limited to 16 to avoid hanging
|
||||
# issues. In case the code hangs, use:
|
||||
# FORCE_NUM_KV_SPLITS=1
|
||||
force_num_kv_splits = os.environ.get("FORCE_NUM_KV_SPLITS", None)
|
||||
if force_num_kv_splits:
|
||||
logger.debug_once("Forcing num_kv_splits to %d", int(force_num_kv_splits))
|
||||
self._num_kv_splits = int(force_num_kv_splits)
|
||||
else:
|
||||
self._num_kv_splits = -1 # => Auto-detect
|
||||
|
||||
# Share workspace buffer across all executions
|
||||
self._workspace = g_sm100_workspace
|
||||
|
||||
def _sm100_cutlass_mla_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
page_table: torch.Tensor,
|
||||
workspace: torch.Tensor,
|
||||
sm_scale: float,
|
||||
num_kv_splits: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert q_nope.ndim == 3, f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
|
||||
assert q_pe.ndim == 3, f"q_pe must be a 3D tensor, but got {q_pe.ndim}"
|
||||
assert kv_c_and_k_pe_cache.ndim == 3, (
|
||||
"kv_c_and_k_pe_cache must be a 3D tensor, but got {}".format(
|
||||
kv_c_and_k_pe_cache.ndim
|
||||
)
|
||||
)
|
||||
|
||||
B_q, H, D_q_nope = q_nope.shape
|
||||
B_q_2, H_2, D_q_pe = q_pe.shape
|
||||
assert (B_q == B_q_2) and (H == H_2)
|
||||
|
||||
_, PAGE_SIZE, D_ckv = kv_c_and_k_pe_cache.shape
|
||||
|
||||
D_latent = 512
|
||||
D_rope = 64
|
||||
assert D_q_nope == D_latent
|
||||
assert D_q_pe == D_rope
|
||||
assert D_ckv == D_latent + D_rope
|
||||
|
||||
MAX_HEADS = 128
|
||||
assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
|
||||
|
||||
assert len(page_table.shape) == 2
|
||||
B_block_table, block_num = page_table.shape
|
||||
assert B_block_table == B_q
|
||||
assert block_num > 0, f"block num must be greater than 0, got {block_num}"
|
||||
assert block_num % (128 / PAGE_SIZE) == 0
|
||||
|
||||
assert q_nope.dtype in (torch.float16, torch.bfloat16, torch.float8_e4m3fn), (
|
||||
f"q_nope.dtype needs to be fp16 or bf16 or e4m3 but got {q_nope.dtype}."
|
||||
)
|
||||
assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype
|
||||
assert seq_lens.dtype == torch.int32, (
|
||||
f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}."
|
||||
)
|
||||
assert page_table.dtype == torch.int32, (
|
||||
f"page_table.dtype needs to be int32 but got {page_table.dtype}."
|
||||
)
|
||||
|
||||
dtype = (
|
||||
torch.bfloat16
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype)
|
||||
else q_nope.dtype
|
||||
)
|
||||
out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype)
|
||||
lse = (
|
||||
torch.empty((B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device)
|
||||
if self.need_to_return_lse_for_decode
|
||||
else torch.Tensor()
|
||||
)
|
||||
|
||||
ops.sm100_cutlass_mla_decode(
|
||||
out,
|
||||
lse,
|
||||
q_nope,
|
||||
q_pe,
|
||||
kv_c_and_k_pe_cache,
|
||||
seq_lens,
|
||||
page_table,
|
||||
workspace,
|
||||
sm_scale,
|
||||
num_kv_splits,
|
||||
)
|
||||
|
||||
if H < MAX_HEADS:
|
||||
# Extract the subsets of the outputs
|
||||
lse = lse[:, :H] if self.need_to_return_lse_for_decode else lse
|
||||
out = out[:, :H]
|
||||
|
||||
return out, lse
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if type(q) is tuple:
|
||||
q_nope, q_pe = q
|
||||
else:
|
||||
q_nope, q_pe = torch.split(
|
||||
q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
|
||||
# Adjust workspace size (if necessary)
|
||||
self._workspace.ensure_size(attn_metadata, self._num_kv_splits)
|
||||
|
||||
# Run MLA
|
||||
o, lse = self._sm100_cutlass_mla_decode(
|
||||
q_nope,
|
||||
q_pe,
|
||||
kv_c_and_k_pe_cache,
|
||||
attn_metadata.decode.seq_lens,
|
||||
attn_metadata.decode.block_table,
|
||||
self._workspace.get_buf(),
|
||||
self.scale,
|
||||
self._num_kv_splits,
|
||||
)
|
||||
|
||||
return o, (lse if self.need_to_return_lse_for_decode else None)
|
||||
337
vllm_old/v1/attention/backends/mla/flashattn_mla.py
Normal file
337
vllm_old/v1/attention/backends/mla/flashattn_mla.py
Normal file
@@ -0,0 +1,337 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.attention.utils.fa_utils import (
|
||||
flash_attn_supports_mla,
|
||||
get_flash_attn_version,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonDecodeMetadata,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder,
|
||||
QueryLenSupport,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FlashAttnMLABackend(MLACommonBackend):
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASH_ATTN_MLA"
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]:
|
||||
return FlashAttnMLAMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlashAttnMLAImpl"]:
|
||||
return FlashAttnMLAImpl
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return capability.major == 9
|
||||
|
||||
@classmethod
|
||||
def supports_combination(
|
||||
cls,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: CacheDType | None,
|
||||
block_size: int,
|
||||
use_mla: bool,
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
device_capability: DeviceCapability,
|
||||
) -> str | None:
|
||||
if not flash_attn_supports_mla():
|
||||
return "FlashAttention MLA not supported on this device"
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata):
|
||||
query_start_loc: torch.Tensor
|
||||
max_query_len: int
|
||||
max_seq_len: int
|
||||
scheduler_metadata: torch.Tensor | None = None
|
||||
max_num_splits: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):
|
||||
pass
|
||||
|
||||
|
||||
class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.VARLEN
|
||||
reorder_batch_threshold: int = 512 # process small prefills with decode pathway
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(
|
||||
kv_cache_spec,
|
||||
layer_names,
|
||||
vllm_config,
|
||||
device,
|
||||
FlashAttnMLAMetadata,
|
||||
supports_dcp_with_varlen=True,
|
||||
)
|
||||
self.max_num_splits = 0 # No upper bound on the number of splits.
|
||||
self.fa_aot_schedule = get_flash_attn_version() == 3
|
||||
|
||||
self.use_full_cuda_graph = (
|
||||
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||
)
|
||||
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
|
||||
|
||||
if self.use_full_cuda_graph and self.fa_aot_schedule:
|
||||
self.scheduler_metadata = torch.zeros(
|
||||
vllm_config.scheduler_config.max_num_seqs + 1,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
# When using cuda graph, we need to set the upper bound of the
|
||||
# number of splits so that large enough intermediate buffers are
|
||||
# pre-allocated during capture.
|
||||
self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
|
||||
|
||||
if vllm_is_batch_invariant():
|
||||
self.max_num_splits = 1
|
||||
|
||||
def _schedule_decode(
|
||||
self,
|
||||
num_reqs,
|
||||
cu_query_lens,
|
||||
max_query_len,
|
||||
seqlens,
|
||||
max_seq_len,
|
||||
causal,
|
||||
max_num_splits,
|
||||
):
|
||||
if self.fa_aot_schedule:
|
||||
return get_scheduler_metadata(
|
||||
batch_size=num_reqs,
|
||||
max_seqlen_q=max_query_len,
|
||||
max_seqlen_k=max_seq_len,
|
||||
num_heads_q=self.num_heads * self.dcp_world_size,
|
||||
num_heads_kv=1,
|
||||
headdim=self.mla_dims.qk_rope_head_dim,
|
||||
cache_seqlens=seqlens,
|
||||
qkv_dtype=self.kv_cache_spec.dtype,
|
||||
headdim_v=self.mla_dims.kv_lora_rank,
|
||||
page_size=self.page_size,
|
||||
cu_seqlens_q=cu_query_lens,
|
||||
causal=causal,
|
||||
num_splits=max_num_splits,
|
||||
)
|
||||
return None
|
||||
|
||||
def _build_decode(
|
||||
self,
|
||||
block_table_tensor: torch.Tensor,
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
seq_lens_device: torch.Tensor,
|
||||
query_start_loc_cpu: torch.Tensor,
|
||||
query_start_loc_device: torch.Tensor,
|
||||
num_decode_tokens: int,
|
||||
dcp_tot_seq_lens_device: torch.Tensor | None,
|
||||
) -> FlashAttnMLADecodeMetadata:
|
||||
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
max_query_len = query_lens_cpu.max().item()
|
||||
max_seq_len = seq_lens_cpu.max().item()
|
||||
|
||||
# For Flash Attention MLA + full cudagraph
|
||||
max_num_splits = 0
|
||||
if self.use_full_cuda_graph and num_decode_tokens <= self.max_cudagraph_size:
|
||||
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
|
||||
# usage, because the intermediate buffers of size [num_splits,
|
||||
# num_heads, num_tokens, head_size] are allocated. Therefore,
|
||||
# we only set num_splits when using cuda graphs.
|
||||
max_num_splits = self.max_num_splits
|
||||
|
||||
if vllm_is_batch_invariant():
|
||||
max_num_splits = 1
|
||||
|
||||
scheduler_metadata = self._schedule_decode(
|
||||
num_reqs=seq_lens_cpu.numel(),
|
||||
cu_query_lens=query_start_loc_device,
|
||||
max_query_len=max_query_len,
|
||||
seqlens=seq_lens_device,
|
||||
max_seq_len=max_seq_len,
|
||||
causal=True,
|
||||
max_num_splits=max_num_splits,
|
||||
)
|
||||
|
||||
if self.use_full_cuda_graph and scheduler_metadata is not None:
|
||||
n = scheduler_metadata.shape[0]
|
||||
# Ensure the persistent buffer is large enough
|
||||
assert n <= self.scheduler_metadata.shape[0], (
|
||||
f"Scheduler metadata size {n} exceeds buffer size "
|
||||
+ f"{self.scheduler_metadata.shape[0]}"
|
||||
)
|
||||
self.scheduler_metadata[:n] = scheduler_metadata
|
||||
# NOTE(woosuk): We should zero out the rest of the scheduler
|
||||
# metadata to guarantee the correctness. Otherwise, some thread
|
||||
# blocks may use the invalid scheduler metadata and overwrite the
|
||||
# output buffer.
|
||||
self.scheduler_metadata[n:] = 0
|
||||
scheduler_metadata = self.scheduler_metadata[:n]
|
||||
|
||||
metadata = FlashAttnMLADecodeMetadata(
|
||||
block_table=block_table_tensor,
|
||||
seq_lens=seq_lens_device,
|
||||
query_start_loc=query_start_loc_device,
|
||||
max_query_len=max_query_len,
|
||||
max_seq_len=max_seq_len,
|
||||
scheduler_metadata=scheduler_metadata,
|
||||
max_num_splits=max_num_splits,
|
||||
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
|
||||
)
|
||||
return metadata
|
||||
|
||||
|
||||
class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
|
||||
can_return_lse_for_decode: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None,
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
# MLA Specific Arguments
|
||||
**mla_args,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
num_heads,
|
||||
head_size,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
kv_cache_dtype,
|
||||
logits_soft_cap,
|
||||
attn_type,
|
||||
kv_sharing_target_layer_name,
|
||||
**mla_args,
|
||||
)
|
||||
|
||||
assert flash_attn_supports_mla(), "FlashAttnMLA is not supported on this device"
|
||||
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"FlashAttnMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, logits_soft_cap"
|
||||
)
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashAttnMLAImpl"
|
||||
)
|
||||
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
raise NotImplementedError(
|
||||
"FlashAttnMLA V1 with FP8 KV cache not yet supported"
|
||||
)
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: FlashAttnMLAMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if type(q) is tuple:
|
||||
q_nope, q_pe = q
|
||||
else:
|
||||
q_nope, q_pe = torch.split(
|
||||
q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
raise NotImplementedError("FP8 FlashAttention MLA not yet supported")
|
||||
|
||||
kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank]
|
||||
k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank :]
|
||||
|
||||
# NOTE(matt): During CUDA graph capture, max_query_len can be 0, but the
|
||||
# kernel uses this to calculate grid dimensions. Ensure it's at least 1
|
||||
# to prevent invalid grid configuration during graph capture.
|
||||
max_seqlen_q = max(attn_metadata.decode.max_query_len, 1)
|
||||
|
||||
attn_out = flash_attn_varlen_func(
|
||||
q=q_pe,
|
||||
k=k_pe_cache.unsqueeze(-2), # Add head dim of 1
|
||||
v=kv_c_cache.unsqueeze(-2), # Add head dim of 1
|
||||
q_v=q_nope,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
cu_seqlens_q=attn_metadata.decode.query_start_loc,
|
||||
max_seqlen_k=attn_metadata.decode.max_seq_len,
|
||||
seqused_k=attn_metadata.decode.seq_lens,
|
||||
block_table=attn_metadata.decode.block_table,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
return_softmax_lse=self.need_to_return_lse_for_decode,
|
||||
fa_version=3, # only version 3 is supported
|
||||
scheduler_metadata=attn_metadata.decode.scheduler_metadata,
|
||||
num_splits=attn_metadata.decode.max_num_splits,
|
||||
cp_world_size=self.dcp_world_size,
|
||||
cp_rank=self.dcp_rank,
|
||||
cp_tot_seqused_k=attn_metadata.decode.dcp_tot_seq_lens,
|
||||
)
|
||||
|
||||
if self.need_to_return_lse_for_decode:
|
||||
o, lse = attn_out
|
||||
# FA returns LSE in shape [ H, B ] but DCP wants [ B, H ]
|
||||
return o, lse.transpose(0, 1) # [ H, B ] -> [ B, H ]
|
||||
else:
|
||||
o = attn_out
|
||||
return o, None
|
||||
171
vllm_old/v1/attention/backends/mla/flashinfer_mla.py
Normal file
171
vllm_old/v1/attention/backends/mla/flashinfer_mla.py
Normal file
@@ -0,0 +1,171 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder,
|
||||
QueryLenSupport,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport, KVCacheLayoutType
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
|
||||
|
||||
|
||||
class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
|
||||
|
||||
|
||||
class FlashInferMLABackend(MLACommonBackend):
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [32, 64]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"fp8",
|
||||
"fp8_e4m3",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASHINFER_MLA"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlashInferMLAImpl"]:
|
||||
return FlashInferMLAImpl
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]:
|
||||
return FlashInferMLAMetadataBuilder
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return capability.major == 10
|
||||
|
||||
@classmethod
|
||||
def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
|
||||
return "HND"
|
||||
|
||||
|
||||
g_fi_workspace = torch.zeros(
|
||||
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,
|
||||
dtype=torch.uint8,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
|
||||
class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None,
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
# MLA Specific Arguments
|
||||
**mla_args,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
num_heads,
|
||||
head_size,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
kv_cache_dtype,
|
||||
logits_soft_cap,
|
||||
attn_type,
|
||||
kv_sharing_target_layer_name,
|
||||
**mla_args,
|
||||
)
|
||||
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"FlashInferMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, logits_soft_cap"
|
||||
)
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashInferMLAImpl"
|
||||
)
|
||||
|
||||
self._workspace_buffer = g_fi_workspace
|
||||
self.bmm1_scale: float | None = None
|
||||
self.bmm2_scale: float | None = None
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if isinstance(q, tuple):
|
||||
q_nope, q_pe = q
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
|
||||
# trtllm API requires extra dimension q_len_per_request for MTP
|
||||
if attn_metadata.num_decode_tokens % attn_metadata.num_decodes != 0:
|
||||
logger.warning_once(
|
||||
"""FlashInferMLAImpl got a query of uneven length.
|
||||
This usually indicates an issue in batch reordering
|
||||
or incorrect setup in dummy_run."""
|
||||
)
|
||||
q = q.unsqueeze(1)
|
||||
else:
|
||||
q = q.view(attn_metadata.num_decodes, -1, q.shape[-2], q.shape[-1])
|
||||
|
||||
if self.bmm1_scale is None:
|
||||
self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
|
||||
if self.bmm2_scale is None:
|
||||
self.bmm2_scale = layer._v_scale_float
|
||||
|
||||
o = trtllm_batch_decode_with_kv_cache_mla(
|
||||
query=q,
|
||||
kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
|
||||
workspace_buffer=self._workspace_buffer,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
block_tables=attn_metadata.decode.block_table,
|
||||
seq_lens=attn_metadata.decode.seq_lens,
|
||||
max_seq_len=attn_metadata.max_seq_len,
|
||||
bmm1_scale=self.bmm1_scale,
|
||||
bmm2_scale=self.bmm2_scale,
|
||||
)
|
||||
|
||||
# Flatten the output for consistent shape
|
||||
o = o.view(-1, o.shape[-2], o.shape[-1])
|
||||
|
||||
# TODO: Return LSE pending support from Flashinfer API:
|
||||
# https://github.com/flashinfer-ai/flashinfer/pull/1566
|
||||
return o, None
|
||||
314
vllm_old/v1/attention/backends/mla/flashmla.py
Normal file
314
vllm_old/v1/attention/backends/mla/flashmla.py
Normal file
@@ -0,0 +1,314 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf
|
||||
from vllm.attention.ops.flashmla import (
|
||||
flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
is_flashmla_dense_supported,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonDecodeMetadata,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder,
|
||||
QueryLenSupport,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
reshape_attn_output_for_spec_decode,
|
||||
reshape_query_for_spec_decode,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FlashMLABackend(MLACommonBackend):
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"fp8",
|
||||
"fp8_e4m3",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASHMLA"
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
|
||||
return FlashMLAMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlashMLAImpl"]:
|
||||
return FlashMLAImpl
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return capability.major in [9, 10]
|
||||
|
||||
@classmethod
|
||||
def supports_combination(
|
||||
cls,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: CacheDType | None,
|
||||
block_size: int,
|
||||
use_mla: bool,
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
device_capability: DeviceCapability,
|
||||
) -> str | None:
|
||||
if use_sparse:
|
||||
from vllm.attention.ops.flashmla import is_flashmla_sparse_supported
|
||||
|
||||
return is_flashmla_sparse_supported()[1]
|
||||
else:
|
||||
from vllm.attention.ops.flashmla import is_flashmla_dense_supported
|
||||
|
||||
return is_flashmla_dense_supported()[1]
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
|
||||
tile_scheduler_metadata: torch.Tensor
|
||||
num_splits: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
|
||||
pass
|
||||
|
||||
|
||||
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
|
||||
reorder_batch_threshold: int = 128 # process small prefills with decode pathway
|
||||
# ^ TODO(matt): tune this
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(
|
||||
kv_cache_spec, layer_names, vllm_config, device, FlashMLAMetadata
|
||||
)
|
||||
|
||||
self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
|
||||
vllm_config.parallel_config
|
||||
)
|
||||
|
||||
self.cg_buf_tile_scheduler_metadata = None
|
||||
self.cg_buf_num_splits = None
|
||||
self.is_fp8_kvcache = vllm_config.cache_config.cache_dtype.startswith("fp8")
|
||||
|
||||
device_properties = torch.cuda.get_device_properties(self.device)
|
||||
num_sms = device_properties.multi_processor_count
|
||||
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
self.cg_buf_tile_scheduler_metadata = torch.zeros(
|
||||
# Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
|
||||
# TileSchedulerMetaDataSize = 8
|
||||
(num_sms, 8),
|
||||
device=self.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
self.cg_buf_num_splits = torch.empty(
|
||||
(vllm_config.scheduler_config.max_num_seqs + 1),
|
||||
device=self.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
def _build_decode(
|
||||
self,
|
||||
block_table_tensor: torch.Tensor,
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
seq_lens_device: torch.Tensor,
|
||||
query_start_loc_cpu: torch.Tensor,
|
||||
query_start_loc_device: torch.Tensor,
|
||||
num_decode_tokens: int,
|
||||
dcp_tot_seq_lens_device: torch.Tensor | None,
|
||||
) -> FlashMLADecodeMetadata:
|
||||
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
# we use the max but all should be the same due to uniform length requirement
|
||||
max_query_len = query_lens_cpu.max().item()
|
||||
num_q_tokens_per_head_k = max_query_len * self.num_q_heads // 1
|
||||
tile_scheduler_metadata, num_splits = get_mla_metadata(
|
||||
seq_lens_device,
|
||||
num_q_tokens_per_head_k,
|
||||
1, # MQA for the decode path
|
||||
is_fp8_kvcache=self.is_fp8_kvcache,
|
||||
)
|
||||
|
||||
# TODO: we can disambiguate between decode and mixed-prefill decode here
|
||||
# so we can only use the persistent buffer if a cudagraph is actually
|
||||
# being used.
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
assert self.cg_buf_tile_scheduler_metadata is not None
|
||||
assert self.cg_buf_num_splits is not None
|
||||
|
||||
sm_parts = tile_scheduler_metadata.size(0)
|
||||
# Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize)
|
||||
assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0)
|
||||
tile_scheduler_metadata_view = self.cg_buf_tile_scheduler_metadata[
|
||||
:sm_parts
|
||||
]
|
||||
tile_scheduler_metadata_view.copy_(tile_scheduler_metadata)
|
||||
tile_scheduler_metadata = tile_scheduler_metadata_view
|
||||
|
||||
# Num splits is per-batch, varying size (batch_size,)
|
||||
n = num_splits.size(0)
|
||||
# make sure static buffer is large enough
|
||||
assert n <= self.cg_buf_num_splits.size(0)
|
||||
num_splits_view = self.cg_buf_num_splits[:n]
|
||||
num_splits_view.copy_(num_splits)
|
||||
# Num splits needs to monotonically increasing
|
||||
# (with: https://github.com/vllm-project/FlashMLA/pull/3, otherwise
|
||||
# it needs to monotonically increasing by 1)
|
||||
self.cg_buf_num_splits[n:].fill_(num_splits[-1])
|
||||
num_splits = num_splits_view
|
||||
|
||||
return FlashMLADecodeMetadata(
|
||||
block_table=block_table_tensor,
|
||||
seq_lens=seq_lens_device,
|
||||
tile_scheduler_metadata=tile_scheduler_metadata,
|
||||
num_splits=num_splits,
|
||||
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
|
||||
)
|
||||
|
||||
|
||||
class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
can_return_lse_for_decode: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None,
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
# MLA Specific Arguments
|
||||
**mla_args,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
num_heads,
|
||||
head_size,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
kv_cache_dtype,
|
||||
logits_soft_cap,
|
||||
attn_type,
|
||||
kv_sharing_target_layer_name,
|
||||
**mla_args,
|
||||
)
|
||||
|
||||
is_supported, reason = is_flashmla_dense_supported()
|
||||
assert is_supported, reason
|
||||
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"FlashMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, logits_soft_cap"
|
||||
)
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashMLAImpl"
|
||||
)
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: FlashMLAMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
# TODO: (zyongye) decode function for mla here
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if type(q) is tuple:
|
||||
q = torch.cat(q, dim=-1)
|
||||
|
||||
# mypy assertion: q is now always a tensor
|
||||
assert isinstance(q, torch.Tensor)
|
||||
|
||||
num_decodes = attn_metadata.num_decodes
|
||||
q = reshape_query_for_spec_decode(q, num_decodes)
|
||||
|
||||
tile_scheduler_metadata = attn_metadata.decode.tile_scheduler_metadata
|
||||
num_splits = attn_metadata.decode.num_splits
|
||||
if vllm_is_batch_invariant():
|
||||
device = q.device
|
||||
dtype = torch.int32
|
||||
|
||||
B = q.shape[0]
|
||||
# block_table shape: [batch_size, max_num_blocks_per_seq]
|
||||
# The number of blocks per sequence is in the second dimension
|
||||
topk = attn_metadata.decode.block_table.shape[-1]
|
||||
B_TOPK = 64
|
||||
assert topk % B_TOPK == 0, f"topk ({topk}) must be divisible by {B_TOPK}"
|
||||
end_block_idx = topk // B_TOPK
|
||||
|
||||
# Single partition => num_sm_parts = 1
|
||||
# TileSchedulerMetaDataSize = 8, layout:
|
||||
# [begin_idx, begin_block_idx, end_idx, end_block_idx,
|
||||
# begin_n_split_idx, _, _, _]
|
||||
tile_scheduler_metadata = torch.zeros((1, 8), dtype=dtype, device=device)
|
||||
tile_scheduler_metadata[0, 0] = 0 # begin_idx
|
||||
tile_scheduler_metadata[0, 1] = 0 # sched_begin_block_idx
|
||||
tile_scheduler_metadata[0, 2] = B - 1 # end_idx
|
||||
tile_scheduler_metadata[0, 3] = end_block_idx
|
||||
tile_scheduler_metadata[0, 4] = 0 # begin_n_split_idx
|
||||
# fields [5..7] stay 0
|
||||
|
||||
# Non-split path ignores num_splits, but the API requires it:
|
||||
# zeros of length B+1
|
||||
num_splits = torch.zeros((B + 1,), dtype=dtype, device=device)
|
||||
|
||||
o, lse = flash_mla_with_kvcache(
|
||||
q=q,
|
||||
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
|
||||
block_table=attn_metadata.decode.block_table,
|
||||
cache_seqlens=attn_metadata.decode.seq_lens,
|
||||
head_dim_v=self.kv_lora_rank,
|
||||
tile_scheduler_metadata=tile_scheduler_metadata,
|
||||
num_splits=num_splits,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
descale_q=layer._q_scale.reshape(1),
|
||||
descale_k=layer._k_scale.reshape(1),
|
||||
)
|
||||
|
||||
o = reshape_attn_output_for_spec_decode(o)
|
||||
|
||||
return o, lse
|
||||
560
vllm_old/v1/attention/backends/mla/flashmla_sparse.py
Normal file
560
vllm_old/v1/attention/backends/mla/flashmla_sparse.py
Normal file
@@ -0,0 +1,560 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, ClassVar, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionLayer,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.attention.backends.utils import get_mla_dims
|
||||
from vllm.attention.ops.flashmla import (
|
||||
flash_mla_sparse_prefill,
|
||||
flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.deepseek_v2 import Indexer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
"""
|
||||
NOTE: FlashMLA Sparse uses an fp8 cache with the following format
|
||||
|
||||
In the "FP8 with scale" format, each token's KV cache is 656 Bytes,
|
||||
structured as:
|
||||
- **First 512 bytes:** The "quantized NoPE" part, containing 512
|
||||
`float8_e4m3` values.
|
||||
- **Next 16 bytes:** Scale factors, containing 4 `float32` values.
|
||||
The first `float32` is the scale for the first 128 `float8_e4m3` values,
|
||||
the second for the next 128, and so on.
|
||||
- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
|
||||
part is not quantized for accuracy.
|
||||
"""
|
||||
|
||||
|
||||
class FlashMLASparseBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
|
||||
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "fp8_ds_mla"]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASHMLA_SPARSE"
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]:
|
||||
return FlashMLASparseMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlashMLASparseImpl"]:
|
||||
return FlashMLASparseImpl
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [576]
|
||||
|
||||
@classmethod
|
||||
def is_mla(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def is_sparse(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return capability.major in [9, 10]
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int, # assumed to be 1 for MLA
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
if cache_dtype_str == "fp8_ds_mla":
|
||||
# custom storage fromat is 656 bytes
|
||||
# see FlashMLA readme.md for details
|
||||
return (num_blocks, block_size, 656)
|
||||
else:
|
||||
return (num_blocks, block_size, head_size)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLASparseMetadata:
|
||||
num_reqs: int
|
||||
max_query_len: int
|
||||
max_seq_len: int
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
query_start_loc: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
block_table: torch.Tensor
|
||||
req_id_per_token: torch.Tensor
|
||||
block_size: int = 64
|
||||
topk_tokens: int = 2048
|
||||
|
||||
@dataclass
|
||||
class FP8KernelMetadata:
|
||||
scheduler_metadata: torch.Tensor | None
|
||||
num_splits: torch.Tensor
|
||||
dummy_block_table: torch.Tensor
|
||||
cache_lens: torch.Tensor
|
||||
|
||||
fp8_extra_metadata: FP8KernelMetadata | None = None
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _convert_req_index_to_global_index_kernel(
|
||||
req_id_ptr, # int32 [num_tokens]
|
||||
block_table_ptr, # int32 [num_requests, max_num_blocks_per_req]
|
||||
token_indices_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS]
|
||||
out_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS]
|
||||
# shapes (compile-time where possible)
|
||||
max_num_blocks_per_req: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr, # tile width along columns
|
||||
# strides (in elements)
|
||||
bt_stride0,
|
||||
bt_stride1,
|
||||
ti_stride0,
|
||||
ti_stride1,
|
||||
out_stride0,
|
||||
out_stride1,
|
||||
):
|
||||
# program_id(0) -> token_id (row)
|
||||
# program_id(1) -> tile index along columns
|
||||
token_id = tl.program_id(0)
|
||||
tile_id = tl.program_id(1)
|
||||
|
||||
# Each program covers BLOCK_N consecutive columns
|
||||
indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
|
||||
# Load request id for this token (no mask: grid is exact)
|
||||
req = tl.load(req_id_ptr + token_id)
|
||||
|
||||
# Load token indices for this tile
|
||||
ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1
|
||||
tok = tl.load(ti_ptr) # int32
|
||||
|
||||
# Only token == -1 should propagate as -1
|
||||
is_invalid_tok = tok < 0
|
||||
|
||||
# Compute block id and in-block offset
|
||||
block_id = tok // BLOCK_SIZE
|
||||
inblock_off = tok % BLOCK_SIZE
|
||||
|
||||
# Guard block_table access
|
||||
valid_block = block_id < max_num_blocks_per_req
|
||||
bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1
|
||||
base = tl.load(bt_ptr, mask=valid_block, other=0)
|
||||
|
||||
# If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset
|
||||
out_val = tl.where(
|
||||
is_invalid_tok | (~valid_block), -1, base * BLOCK_SIZE + inblock_off
|
||||
)
|
||||
|
||||
# Store results
|
||||
out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1
|
||||
tl.store(out_ptr_ij, out_val)
|
||||
|
||||
|
||||
def triton_convert_req_index_to_global_index(
|
||||
req_id: torch.Tensor, # int32 [num_tokens]
|
||||
block_table: torch.Tensor, # int32 [num_requests, max_num_blocks_per_req]
|
||||
token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS]
|
||||
BLOCK_SIZE: int = 64,
|
||||
NUM_TOPK_TOKENS: int = 2048,
|
||||
BLOCK_N: int = 128, # tile width along columns
|
||||
):
|
||||
"""
|
||||
out[token_id, indice_id] =
|
||||
block_table[req_id[token_id],
|
||||
token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE
|
||||
+ token_indices[token_id, indice_id] % BLOCK_SIZE
|
||||
|
||||
Only when token_indices[token_id, indice_id] == -1 do we output -1.
|
||||
For safety, we also output -1 if the derived block_id would be
|
||||
out-of-bounds.
|
||||
"""
|
||||
assert req_id.dtype == torch.int32
|
||||
assert block_table.dtype == torch.int32
|
||||
assert token_indices.dtype == torch.int32
|
||||
assert token_indices.shape[1] == NUM_TOPK_TOKENS
|
||||
assert NUM_TOPK_TOKENS % BLOCK_N == 0, (
|
||||
f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible byBLOCK_N ({BLOCK_N})"
|
||||
)
|
||||
|
||||
num_tokens = req_id.shape[0]
|
||||
num_requests, max_num_blocks_per_req = block_table.shape
|
||||
tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N
|
||||
|
||||
# Ensure contiguous tensors on the same device
|
||||
req_id_c = req_id.contiguous()
|
||||
block_table_c = block_table.contiguous()
|
||||
token_indices_c = token_indices.contiguous()
|
||||
out = torch.empty_like(token_indices_c)
|
||||
|
||||
# Strides in elements
|
||||
bt_stride0, bt_stride1 = block_table_c.stride()
|
||||
ti_stride0, ti_stride1 = token_indices_c.stride()
|
||||
out_stride0, out_stride1 = out.stride()
|
||||
|
||||
# Exact 2D grid: tokens × column tiles
|
||||
grid = (num_tokens, tiles_per_row)
|
||||
|
||||
_convert_req_index_to_global_index_kernel[grid](
|
||||
req_id_c,
|
||||
block_table_c,
|
||||
token_indices_c,
|
||||
out,
|
||||
# shapes / constexprs
|
||||
max_num_blocks_per_req,
|
||||
BLOCK_SIZE,
|
||||
BLOCK_N,
|
||||
# strides
|
||||
bt_stride0,
|
||||
bt_stride1,
|
||||
ti_stride0,
|
||||
ti_stride1,
|
||||
out_stride0,
|
||||
out_stride1,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]):
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
cache_config = vllm_config.cache_config
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.model_config = vllm_config.model_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
self.device = device
|
||||
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
sm_count = props.multi_processor_count
|
||||
|
||||
self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
|
||||
self.mla_dims = get_mla_dims(self.model_config)
|
||||
self.topk_tokens = vllm_config.model_config.hf_config.index_topk
|
||||
self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla"
|
||||
self.topk_tokens_tensor = torch.tensor(
|
||||
[self.topk_tokens], device=device, dtype=torch.int32
|
||||
)
|
||||
self.max_model_len_tensor = torch.tensor(
|
||||
[self.model_config.max_model_len], device=device, dtype=torch.int32
|
||||
)
|
||||
# this is ignored by `flash_mla_with_kvcache` if indices not None
|
||||
self.dummy_block_table = torch.empty(
|
||||
(1, 1), dtype=torch.int32, device=self.device
|
||||
)
|
||||
|
||||
# Equation taken from FlashMLA/csrc/pybind.cpp
|
||||
h_q, h_k = self.num_heads, 1
|
||||
s_q = 1 # inversely proportional to s_q, so s_q = 1 is the largest
|
||||
max_num_sm_parts = int(
|
||||
max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1)
|
||||
)
|
||||
if current_platform.is_device_capability(100):
|
||||
max_num_sm_parts *= 2
|
||||
self.tile_scheduler_metadata_buffer = torch.empty(
|
||||
# TileSchedulerMetaDataSize = 8
|
||||
# see: FlashMLA/csrc/params.h
|
||||
(max_num_sm_parts, 8),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.num_splits_buffer = torch.empty(
|
||||
# We pack all the tokens into one batch for sparse attention.
|
||||
# Otherwise, we can exceed the sm of `get_mla_metadata`.
|
||||
(2,),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.req_id_per_token_buffer = torch.empty(
|
||||
(vllm_config.scheduler_config.max_num_batched_tokens,),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> FlashMLASparseMetadata:
|
||||
num_tokens = common_attn_metadata.num_actual_tokens
|
||||
starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32)
|
||||
seg_lengths = np.diff(starts)
|
||||
req_id_per_token = np.repeat(
|
||||
np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths
|
||||
)
|
||||
# Zero-fill for cudagraphs
|
||||
self.req_id_per_token_buffer.fill_(0)
|
||||
self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
|
||||
torch.from_numpy(req_id_per_token), non_blocking=True
|
||||
)
|
||||
req_id_per_token = self.req_id_per_token_buffer[:num_tokens]
|
||||
|
||||
fp8_extra_metadata = None
|
||||
if self.use_fp8_kv_cache:
|
||||
tile_scheduler_metadata, num_splits = get_mla_metadata(
|
||||
cache_seqlens=self.topk_tokens_tensor,
|
||||
num_q_tokens_per_head_k=num_tokens * self.num_heads,
|
||||
topk=self.topk_tokens,
|
||||
num_heads_q=self.num_heads,
|
||||
num_heads_k=1,
|
||||
is_fp8_kvcache=True,
|
||||
)
|
||||
|
||||
num_sm_parts = tile_scheduler_metadata.size(0)
|
||||
# Copy to persistent buffer for full-CG support
|
||||
tile_scheduler_metadata_buffer = self.tile_scheduler_metadata_buffer[
|
||||
:num_sm_parts
|
||||
]
|
||||
tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata)
|
||||
self.num_splits_buffer.copy_(num_splits)
|
||||
|
||||
fp8_extra_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
|
||||
scheduler_metadata=tile_scheduler_metadata_buffer,
|
||||
num_splits=self.num_splits_buffer,
|
||||
# cache_lens and block_table are basically unused in sparse case
|
||||
# but the decode kernel will treat -1 and indices >= cache_lens
|
||||
# as invalid so we make sure cache_lens is large enough to not
|
||||
# accidentally mark indices invalid, we will use -1 exclusively
|
||||
# to mark invalid indices
|
||||
cache_lens=self.max_model_len_tensor,
|
||||
dummy_block_table=self.dummy_block_table,
|
||||
)
|
||||
|
||||
metadata = FlashMLASparseMetadata(
|
||||
num_reqs=common_attn_metadata.num_reqs,
|
||||
max_query_len=common_attn_metadata.max_query_len,
|
||||
max_seq_len=common_attn_metadata.max_seq_len,
|
||||
num_actual_tokens=common_attn_metadata.num_actual_tokens,
|
||||
query_start_loc=common_attn_metadata.query_start_loc,
|
||||
slot_mapping=common_attn_metadata.slot_mapping,
|
||||
block_table=common_attn_metadata.block_table_tensor,
|
||||
req_id_per_token=req_id_per_token,
|
||||
block_size=self.kv_cache_spec.block_size,
|
||||
topk_tokens=self.topk_tokens,
|
||||
fp8_extra_metadata=fp8_extra_metadata,
|
||||
)
|
||||
return metadata
|
||||
|
||||
|
||||
class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None,
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
# MLA Specific Arguments
|
||||
topk_indice_buffer: torch.Tensor | None = None,
|
||||
indexer: Optional["Indexer"] = None,
|
||||
**mla_args,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
num_heads,
|
||||
head_size,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
kv_cache_dtype,
|
||||
logits_soft_cap,
|
||||
attn_type,
|
||||
kv_sharing_target_layer_name,
|
||||
**mla_args,
|
||||
)
|
||||
self.softmax_scale = scale
|
||||
assert indexer is not None
|
||||
self.topk_indices_buffer = indexer.topk_indices_buffer
|
||||
self.padding = 128 if current_platform.is_device_capability(100) else 64
|
||||
|
||||
def _forward_bf16_kv(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
topk_indices: torch.Tensor,
|
||||
attn_metadata: FlashMLASparseMetadata,
|
||||
) -> torch.Tensor:
|
||||
num_tokens = q.shape[0]
|
||||
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
|
||||
-1, 1, kv_c_and_k_pe_cache.shape[-1]
|
||||
)
|
||||
|
||||
# NOTE(Chen): kernel requires num_local_head to be a multiple of
|
||||
# 64 on hopper and 128 on blackwell
|
||||
if self.num_heads % self.padding != 0:
|
||||
assert self.padding % self.num_heads == 0
|
||||
logger.warning_once(
|
||||
f"padding num_heads to {self.padding} \
|
||||
due to sparse attn kernel requirement"
|
||||
)
|
||||
q_padded = q.new_empty((q.shape[0], self.padding, q.shape[2]))
|
||||
q_padded[:, : self.num_heads, :] = q
|
||||
q = q_padded
|
||||
|
||||
topk_indices = topk_indices.view(num_tokens, 1, -1)
|
||||
output = flash_mla_sparse_prefill(
|
||||
q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale
|
||||
)
|
||||
output = output[:, : self.num_heads, :]
|
||||
return output
|
||||
|
||||
def _forward_fp8_kv(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
topk_indices: torch.Tensor,
|
||||
attn_metadata: FlashMLASparseMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert attn_metadata.fp8_extra_metadata is not None
|
||||
extra_metadata = attn_metadata.fp8_extra_metadata
|
||||
|
||||
_attn_out, _ = flash_mla_with_kvcache(
|
||||
q=q.unsqueeze(0), # unsqueeze to add batch_dim
|
||||
k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2),
|
||||
block_table=extra_metadata.dummy_block_table,
|
||||
head_dim_v=512,
|
||||
cache_seqlens=extra_metadata.cache_lens,
|
||||
tile_scheduler_metadata=extra_metadata.scheduler_metadata,
|
||||
num_splits=extra_metadata.num_splits,
|
||||
is_fp8_kvcache=True,
|
||||
indices=topk_indices.unsqueeze(0), # unsqueeze to add batch_dim
|
||||
softmax_scale=self.softmax_scale,
|
||||
)
|
||||
|
||||
return _attn_out
|
||||
|
||||
def forward_prepare(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
) -> None:
|
||||
self.positions = positions
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
q: torch.Tensor,
|
||||
k_c_normed: torch.Tensor, # key in unified attn
|
||||
k_pe: torch.Tensor, # value in unified attn
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlashMLASparseMetadata,
|
||||
output: torch.Tensor | None = None,
|
||||
kv_cache_scale: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
|
||||
# MQA 576/512 approach for both prefill and decode
|
||||
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported for MLACommonImpl"
|
||||
)
|
||||
|
||||
if attn_metadata is None:
|
||||
# The zero fill is required when used with DP + EP
|
||||
# to ensure all ranks within a DP group compute the
|
||||
# same expert outputs.
|
||||
output = torch.empty(output.shape[0], self.v_head_dim * self.num_heads, device=q.device,
|
||||
dtype=q.dtype)
|
||||
return output
|
||||
|
||||
num_actual_toks = attn_metadata.num_actual_tokens
|
||||
|
||||
# Inputs and outputs may be padded for CUDA graphs
|
||||
k_pe = k_pe.unsqueeze(1)
|
||||
q = q[:num_actual_toks, ...]
|
||||
k_c_normed = k_c_normed[:num_actual_toks, ...]
|
||||
k_pe = k_pe[:num_actual_toks, ...]
|
||||
|
||||
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
|
||||
dim=-1)
|
||||
q_pe, k_pe = self.rotary_emb(self.positions[:num_actual_toks], q_pe, k_pe)
|
||||
|
||||
q_nope = self._k_up_proj(q_nope)
|
||||
q_nope = q_nope.view(-1, self.num_heads, self.kv_lora_rank)
|
||||
|
||||
topk_indices = self.topk_indices_buffer[:num_actual_toks]
|
||||
|
||||
# TODO: handle index / kv_cache correctly
|
||||
topk_indices_global = triton_convert_req_index_to_global_index(
|
||||
attn_metadata.req_id_per_token,
|
||||
attn_metadata.block_table,
|
||||
topk_indices,
|
||||
BLOCK_SIZE=attn_metadata.block_size,
|
||||
NUM_TOPK_TOKENS=attn_metadata.topk_tokens,
|
||||
)
|
||||
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
|
||||
# write the latent and rope to kv cache
|
||||
if kv_cache.numel() > 0:
|
||||
ops.concat_and_cache_mla(
|
||||
k_c_normed,
|
||||
k_pe,
|
||||
kv_cache,
|
||||
attn_metadata.slot_mapping.flatten(),
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
scale=layer._k_scale,
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype != "fp8_ds_mla":
|
||||
attn_out = self._forward_bf16_kv(
|
||||
q, kv_cache, topk_indices_global, attn_metadata
|
||||
)
|
||||
else:
|
||||
attn_out = self._forward_fp8_kv(
|
||||
q, kv_cache, topk_indices_global, attn_metadata
|
||||
)
|
||||
output = torch.empty(output.shape[0],
|
||||
self.num_heads, self.v_head_dim,
|
||||
device=q.device,
|
||||
dtype=q.dtype)
|
||||
|
||||
output[:num_actual_toks] = self._v_up_proj(attn_out)
|
||||
return output.view(output.shape[0], self.v_head_dim * self.num_heads)
|
||||
362
vllm_old/v1/attention/backends/mla/indexer.py
Normal file
362
vllm_old/v1/attention/backends/mla/indexer.py
Normal file
@@ -0,0 +1,362 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DeepseekV32IndexerBackend(AttentionBackend):
|
||||
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [32, 64, 128]
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["DeepseekV32IndexerMetadataBuilder"]:
|
||||
return DeepseekV32IndexerMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
assert num_kv_heads == 1
|
||||
return (num_blocks, block_size, head_size)
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_stride_order() -> tuple[int, ...]:
|
||||
return (0, 1, 2)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepseekV32IndexerPrefillChunkMetadata:
|
||||
block_table: torch.Tensor
|
||||
cu_seqlen_ks: torch.Tensor
|
||||
cu_seqlen_ke: torch.Tensor
|
||||
cu_seq_lens: torch.Tensor
|
||||
total_seq_lens: int
|
||||
token_start: int
|
||||
token_end: int
|
||||
num_reqs: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepseekV32IndexerPrefillMetadata:
|
||||
chunks: list[DeepseekV32IndexerPrefillChunkMetadata]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepSeekV32IndexerDecodeMetadata:
|
||||
block_table: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
decode_lens: torch.Tensor
|
||||
requires_padding: bool
|
||||
# schedule_metadata: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepseekV32IndexerMetadata:
|
||||
# FIXME (zyongye)
|
||||
# hacky way to access the data now, need to be in chunked meta
|
||||
seq_lens: torch.Tensor
|
||||
|
||||
num_reqs: int
|
||||
max_query_len: int
|
||||
max_seq_len: int
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
query_start_loc: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
# The dimension of the attention heads
|
||||
head_dim: int
|
||||
|
||||
# New for MLA (compared to FlashAttention)
|
||||
# For handling prefill decode split
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
num_prefills: int
|
||||
num_prefill_tokens: int
|
||||
|
||||
decode: DeepSeekV32IndexerDecodeMetadata | None = None
|
||||
prefill: DeepseekV32IndexerPrefillMetadata | None = None
|
||||
|
||||
|
||||
# TODO (zyongye) optimize this, this is now vibe coded
|
||||
def kv_spans_from_batches(
|
||||
start_seq_loc: torch.Tensor, seq_len_per_batch: torch.Tensor, device: torch.device
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
start_seq_loc: 1D long tensor [B+1], cumulative counts of
|
||||
selected tokens per batch.
|
||||
Example: [0, 2, 4, 7] ->
|
||||
batch sizes (selected) [2, 2, 3], N=7 tokens total.
|
||||
seq_len_per_batch: 1D long tensor [B],
|
||||
full sequence length (KV length) of each batch.
|
||||
Example: [5, 9, 4].
|
||||
|
||||
Returns:
|
||||
start_tensor: 1D long tensor [N], start offset in the
|
||||
concatenated KV cache for each token's batch.
|
||||
end_location: 1D long tensor [N],
|
||||
**exclusive** end = start + token's local position.
|
||||
(So the attended KV slice is kv[start:end].)
|
||||
|
||||
Assumes each batch contributes its full `seq_len_per_batch[i]`
|
||||
keys to the KV cache, andthe selected tokens within a batch
|
||||
are the **last** `counts[i]` positions of that sequence.
|
||||
"""
|
||||
q = start_seq_loc.to(dtype=torch.long)
|
||||
L = seq_len_per_batch.to(dtype=torch.long)
|
||||
assert q.dim() == 1 and L.dim() == 1
|
||||
assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1"
|
||||
|
||||
# Selected tokens per batch and totals
|
||||
counts = q[1:] - q[:-1] # [B]
|
||||
N = int(q[-1].item()) # total selected tokens
|
||||
B = L.numel()
|
||||
|
||||
if N == 0:
|
||||
return (
|
||||
torch.empty(0, dtype=torch.long, device=device),
|
||||
torch.empty(0, dtype=torch.long, device=device),
|
||||
)
|
||||
|
||||
# KV start offsets per batch in the concatenated KV cache
|
||||
kv_starts_per_batch = torch.cumsum(L, dim=0) - L # [B]
|
||||
|
||||
# For each selected token, which batch does it belong to?
|
||||
batch_id = torch.repeat_interleave(torch.arange(B), counts) # [N]
|
||||
|
||||
# Map batch KV start to each token
|
||||
start_tensor = kv_starts_per_batch[batch_id] # [N]
|
||||
|
||||
# End-align local positions inside each batch:
|
||||
# local_pos = L[b] - counts[b] + (1..counts[b]) for each batch b
|
||||
L_expand = torch.repeat_interleave(L, counts) # [N]
|
||||
m_expand = torch.repeat_interleave(counts, counts) # [N]
|
||||
# position within the selected block: 1..counts[b]
|
||||
pos_within = (
|
||||
torch.arange(N, dtype=torch.long) - torch.repeat_interleave(q[:-1], counts) + 1
|
||||
)
|
||||
|
||||
local_pos = L_expand - m_expand + pos_within # [N], 1-based
|
||||
end_location = start_tensor + local_pos # exclusive end
|
||||
|
||||
return start_tensor.int().to(device), end_location.int().to(device)
|
||||
|
||||
|
||||
def get_max_prefill_buffer_size(vllm_config: VllmConfig):
|
||||
max_model_len = vllm_config.model_config.max_model_len
|
||||
# NOTE(Chen): 2 is a magic number for controlling the prefill buffer size.
|
||||
# May be tuned later.
|
||||
return max_model_len * 2
|
||||
|
||||
|
||||
def split_prefill_chunks(
|
||||
seq_lens_cpu: torch.Tensor, max_prefill_buffer_size: int, reqs_start: int
|
||||
) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Split the prefill chunks into a list of tuples of (reqs_start, reqs_end)
|
||||
such that the total sequence length of each chunk is less than the
|
||||
maximum prefill buffer size.
|
||||
|
||||
Args:
|
||||
seq_lens_cpu: The sequence lengths of the prefill requests.
|
||||
max_prefill_buffer_size: The maximum prefill buffer size.
|
||||
reqs_start: The start index of the prefill requests.
|
||||
|
||||
Returns:
|
||||
A list of tuples of (reqs_start, reqs_end).
|
||||
"""
|
||||
chunk_seq_ids = []
|
||||
total_seq_lens = 0
|
||||
for i in range(reqs_start, len(seq_lens_cpu)):
|
||||
cur_seq_len = seq_lens_cpu[i].item()
|
||||
assert cur_seq_len <= max_prefill_buffer_size
|
||||
total_seq_lens += cur_seq_len
|
||||
if total_seq_lens > max_prefill_buffer_size:
|
||||
chunk_seq_ids.append((reqs_start, i))
|
||||
reqs_start = i
|
||||
total_seq_lens = cur_seq_len
|
||||
if total_seq_lens > 0:
|
||||
chunk_seq_ids.append((reqs_start, len(seq_lens_cpu)))
|
||||
return chunk_seq_ids
|
||||
|
||||
|
||||
class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = (
|
||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
)
|
||||
|
||||
reorder_batch_threshold: int = 1
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
scheduler_config = self.vllm_config.scheduler_config
|
||||
# NOTE(Chen):an estimated max size of flattened_kv. Need to double check.
|
||||
self.max_prefill_buffer_size = get_max_prefill_buffer_size(self.vllm_config)
|
||||
self.num_speculative_tokens = (
|
||||
self.vllm_config.speculative_config.num_speculative_tokens
|
||||
if self.vllm_config.speculative_config
|
||||
else 0
|
||||
)
|
||||
# Now deepgemm fp8_paged_mqa_logits does not support next_n > 2
|
||||
self.reorder_batch_threshold += min(self.num_speculative_tokens, 1)
|
||||
|
||||
props = torch.cuda.get_device_properties(self.device)
|
||||
sm_count = props.multi_processor_count
|
||||
self.num_sms = sm_count
|
||||
|
||||
self.decode_lens_buffer = torch.empty(
|
||||
(scheduler_config.max_num_seqs,), dtype=torch.int32, device=self.device
|
||||
)
|
||||
|
||||
# See: DeepGMM/csrc/apis/attention.hpp
|
||||
self.scheduler_metadata_buffer = torch.empty(
|
||||
(self.num_sms + 1, 2), dtype=torch.int32, device=self.device
|
||||
)
|
||||
|
||||
def build_one_prefill_chunk(
|
||||
self, reqs_start, reqs_end, query_start_loc_cpu, seq_lens_cpu, block_table
|
||||
):
|
||||
prefill_query_start_loc = (
|
||||
query_start_loc_cpu[reqs_start : reqs_end + 1]
|
||||
- query_start_loc_cpu[reqs_start]
|
||||
)
|
||||
cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
|
||||
prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end], self.device
|
||||
)
|
||||
token_start = query_start_loc_cpu[reqs_start].item()
|
||||
token_end = query_start_loc_cpu[reqs_end].item()
|
||||
total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
|
||||
assert total_seq_lens <= self.max_prefill_buffer_size
|
||||
cu_seq_lens = (
|
||||
torch.cat(
|
||||
[
|
||||
torch.zeros(1, dtype=torch.int32),
|
||||
seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0),
|
||||
]
|
||||
)
|
||||
.to(torch.int32)
|
||||
.to(self.device)
|
||||
)
|
||||
return DeepseekV32IndexerPrefillChunkMetadata(
|
||||
cu_seqlen_ks=cu_seqlen_ks,
|
||||
cu_seqlen_ke=cu_seqlen_ke,
|
||||
cu_seq_lens=cu_seq_lens,
|
||||
total_seq_lens=total_seq_lens,
|
||||
block_table=block_table[reqs_start:reqs_end],
|
||||
token_start=token_start,
|
||||
token_end=token_end,
|
||||
num_reqs=reqs_end - reqs_start,
|
||||
)
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> DeepseekV32IndexerMetadata:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_tokens = common_attn_metadata.num_actual_tokens
|
||||
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
split_decodes_and_prefills(
|
||||
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
|
||||
)
|
||||
)
|
||||
|
||||
assert num_decodes + num_prefills == num_reqs
|
||||
assert num_decode_tokens + num_prefill_tokens == num_tokens
|
||||
|
||||
prefill_metadata = None
|
||||
if num_prefills > 0:
|
||||
chunk_seq_ids = split_prefill_chunks(
|
||||
common_attn_metadata.seq_lens_cpu,
|
||||
self.max_prefill_buffer_size,
|
||||
num_decodes,
|
||||
)
|
||||
chunks = [
|
||||
self.build_one_prefill_chunk(
|
||||
reqs_start,
|
||||
reqs_end,
|
||||
query_start_loc_cpu,
|
||||
common_attn_metadata.seq_lens_cpu,
|
||||
common_attn_metadata.block_table_tensor,
|
||||
)
|
||||
for reqs_start, reqs_end in chunk_seq_ids
|
||||
]
|
||||
prefill_metadata = DeepseekV32IndexerPrefillMetadata(
|
||||
chunks=chunks,
|
||||
)
|
||||
|
||||
decode_metadata = None
|
||||
if num_decodes > 0:
|
||||
torch.diff(
|
||||
common_attn_metadata.query_start_loc[: num_decodes + 1],
|
||||
out=self.decode_lens_buffer[:num_decodes],
|
||||
)
|
||||
decode_lens = self.decode_lens_buffer[:num_decodes]
|
||||
decode_lens_cpu = torch.diff(
|
||||
common_attn_metadata.query_start_loc_cpu[: num_decodes + 1]
|
||||
)
|
||||
|
||||
# Use CPU to avoid GPU sync; breaking async scheduling
|
||||
requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()
|
||||
|
||||
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
|
||||
|
||||
# self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
|
||||
# seq_lens, self.kv_cache_spec.block_size, self.num_sms
|
||||
# )
|
||||
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
|
||||
block_table=common_attn_metadata.block_table_tensor[:num_decodes, ...],
|
||||
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
|
||||
decode_lens=decode_lens,
|
||||
requires_padding=requires_padding,
|
||||
# schedule_metadata=self.scheduler_metadata_buffer,
|
||||
)
|
||||
|
||||
attn_metadata = DeepseekV32IndexerMetadata(
|
||||
seq_lens=common_attn_metadata.seq_lens,
|
||||
num_reqs=common_attn_metadata.num_reqs,
|
||||
max_query_len=common_attn_metadata.max_query_len,
|
||||
max_seq_len=common_attn_metadata.max_seq_len,
|
||||
num_actual_tokens=common_attn_metadata.num_actual_tokens,
|
||||
query_start_loc=common_attn_metadata.query_start_loc,
|
||||
slot_mapping=common_attn_metadata.slot_mapping,
|
||||
head_dim=128,
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
prefill=prefill_metadata,
|
||||
decode=decode_metadata,
|
||||
)
|
||||
|
||||
# if get_tensor_model_parallel_rank() == 0:
|
||||
# logger.info(f"attn_metadata: {attn_metadata}")
|
||||
return attn_metadata
|
||||
294
vllm_old/v1/attention/backends/mla/rocm_aiter_mla.py
Normal file
294
vllm_old/v1/attention/backends/mla/rocm_aiter_mla.py
Normal file
@@ -0,0 +1,294 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.attention.backends.abstract import AttentionLayer
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonDecodeMetadata,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
|
||||
class AiterMLABackend(MLACommonBackend):
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "ROCM_AITER_MLA"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["AiterMLAImpl"]:
|
||||
return AiterMLAImpl
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["AiterMLAMetadataBuilder"]:
|
||||
return AiterMLAMetadataBuilder
|
||||
|
||||
|
||||
@dataclass
|
||||
class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
|
||||
# The indptr of the paged kv cache, shape: [batch_size + 1]
|
||||
paged_kv_indptr: torch.Tensor | None = None
|
||||
# The page indices of the paged kv cache
|
||||
paged_kv_indices: torch.Tensor | None = None
|
||||
# The number of entries in the last page of each request in
|
||||
# the paged kv cache, shape: [batch_size]
|
||||
paged_kv_last_page_len: torch.Tensor | None = None
|
||||
# The query indptr, shape : [num_decode + 1]
|
||||
qo_indptr: torch.Tensor | None = None
|
||||
|
||||
|
||||
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
|
||||
pass
|
||||
|
||||
|
||||
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
# TODO(luka, lucas): audit this as part of:
|
||||
# https://github.com/vllm-project/vllm/issues/22945
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = (
|
||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(
|
||||
kv_cache_spec, layer_names, vllm_config, device, AiterMLAMetadata
|
||||
)
|
||||
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
max_num_pages_per_req = cdiv(
|
||||
vllm_config.model_config.max_model_len, self.kv_cache_spec.block_size
|
||||
)
|
||||
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
|
||||
max_num_pages = max_num_reqs * max_num_pages_per_req
|
||||
|
||||
# Preparing persistent buffers
|
||||
# TODO: we can disambiguate between decode and mixed-prefill decode here
|
||||
# so we can only use the persistent buffer if a cudagraph is actually
|
||||
# being used.
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
self.block_table_remapping = torch.zeros(
|
||||
[max_num_reqs, max_num_pages_per_req * self.kv_cache_spec.block_size],
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.paged_kv_indptr = torch.zeros(
|
||||
max_num_reqs + 1, dtype=torch.int32, device=device
|
||||
)
|
||||
self.paged_kv_indices = torch.zeros(
|
||||
max_num_pages, dtype=torch.int32, device=device
|
||||
)
|
||||
self.paged_kv_last_page_len = torch.zeros(
|
||||
max_num_reqs, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
self.qo_indptr = torch.arange(
|
||||
0, max_num_reqs + 1, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
def _build_decode(
|
||||
self,
|
||||
block_table_tensor: torch.Tensor,
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
seq_lens_device: torch.Tensor,
|
||||
query_start_loc_cpu: torch.Tensor,
|
||||
query_start_loc_device: torch.Tensor,
|
||||
num_decode_tokens: int,
|
||||
dcp_tot_seq_lens_device: torch.Tensor | None,
|
||||
) -> AiterMLADecodeMetadata:
|
||||
page_size = self.kv_cache_spec.block_size
|
||||
device = self.device
|
||||
num_reqs = seq_lens_device.size(0)
|
||||
bs, _ = block_table_tensor.shape
|
||||
block_table_tensor = (
|
||||
block_table_tensor.unsqueeze(-1).expand(-1, -1, page_size) * page_size
|
||||
)
|
||||
block_table_tensor = (
|
||||
block_table_tensor
|
||||
+ torch.arange(
|
||||
0,
|
||||
page_size,
|
||||
device=block_table_tensor.device,
|
||||
dtype=block_table_tensor.dtype,
|
||||
)[None, None, :]
|
||||
)
|
||||
block_table_tensor = block_table_tensor.view(bs, -1)
|
||||
|
||||
# after remapping, we assume the block size already equals to 1
|
||||
|
||||
max_blk_size_per_req = block_table_tensor.shape[-1]
|
||||
mask = torch.arange(
|
||||
block_table_tensor.size(1), dtype=block_table_tensor.dtype, device=device
|
||||
).unsqueeze(0) < seq_lens_device.unsqueeze(1)
|
||||
paged_kv_indices = block_table_tensor[mask]
|
||||
|
||||
paged_kv_last_page_len = seq_lens_device % page_size
|
||||
paged_kv_last_page_len = torch.where(
|
||||
paged_kv_last_page_len == 0, page_size, paged_kv_last_page_len
|
||||
)
|
||||
|
||||
paged_kv_indptr = torch.cat(
|
||||
[
|
||||
torch.zeros(1, dtype=seq_lens_device.dtype, device=device),
|
||||
seq_lens_device.cumsum(dim=0, dtype=torch.int32),
|
||||
]
|
||||
)
|
||||
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
num_actual_pages = paged_kv_indices.size(0)
|
||||
self.block_table_remapping[:num_reqs, :max_blk_size_per_req].copy_(
|
||||
block_table_tensor, non_blocking=True
|
||||
)
|
||||
block_table_tensor = self.block_table_remapping[
|
||||
:num_reqs, :max_blk_size_per_req
|
||||
]
|
||||
|
||||
self.paged_kv_indices[:num_actual_pages].copy_(
|
||||
paged_kv_indices, non_blocking=True
|
||||
)
|
||||
self.paged_kv_indices[num_actual_pages:].fill_(-1)
|
||||
paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
|
||||
|
||||
self.paged_kv_indptr[: 1 + num_reqs].copy_(
|
||||
paged_kv_indptr, non_blocking=True
|
||||
)
|
||||
self.paged_kv_indptr[1 + num_reqs :].fill_(paged_kv_indptr[-1])
|
||||
paged_kv_indptr = self.paged_kv_indptr[: 1 + num_reqs]
|
||||
|
||||
self.paged_kv_last_page_len[:num_reqs].copy_(
|
||||
paged_kv_last_page_len, non_blocking=True
|
||||
)
|
||||
self.paged_kv_last_page_len[num_reqs:].fill_(1)
|
||||
paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]
|
||||
|
||||
qo_indptr = self.qo_indptr[: 1 + num_reqs]
|
||||
|
||||
else:
|
||||
qo_indptr = torch.arange(
|
||||
0, num_reqs + 1, step=1, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
attn_metadata = AiterMLADecodeMetadata(
|
||||
block_table=block_table_tensor,
|
||||
seq_lens=seq_lens_device,
|
||||
paged_kv_indptr=paged_kv_indptr,
|
||||
paged_kv_indices=paged_kv_indices,
|
||||
paged_kv_last_page_len=paged_kv_last_page_len,
|
||||
qo_indptr=qo_indptr,
|
||||
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
|
||||
)
|
||||
|
||||
return attn_metadata
|
||||
|
||||
|
||||
class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None,
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
# MLA Specific Arguments
|
||||
**mla_args,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
num_heads,
|
||||
head_size,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
kv_cache_dtype,
|
||||
logits_soft_cap,
|
||||
attn_type,
|
||||
kv_sharing_target_layer_name,
|
||||
**mla_args,
|
||||
)
|
||||
assert num_heads == 16 or num_heads == 128, (
|
||||
f"Aiter MLA only supports 16 or 128 number of heads.\n"
|
||||
f"Provided {num_heads} number of heads.\n"
|
||||
"Try adjusting tensor_parallel_size value."
|
||||
)
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"Aiter MLA does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, logits_soft_cap"
|
||||
)
|
||||
|
||||
from aiter import flash_attn_varlen_func
|
||||
|
||||
self.flash_attn_varlen_func = flash_attn_varlen_func
|
||||
|
||||
def _flash_attn_varlen_diff_headdims(
|
||||
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
|
||||
):
|
||||
output = self.flash_attn_varlen_func(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
softmax_scale=softmax_scale,
|
||||
return_lse=return_softmax_lse,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: AiterMLAMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if type(q) is tuple:
|
||||
q = torch.cat(q, dim=-1)
|
||||
|
||||
assert isinstance(q, torch.Tensor)
|
||||
B = q.shape[0]
|
||||
o = torch.zeros(
|
||||
B, self.num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device
|
||||
)
|
||||
|
||||
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
|
||||
|
||||
# max_seqlen_qo must be 1 except for MTP
|
||||
# TODO: Find the best value for MTP
|
||||
max_seqlen_qo = 1
|
||||
rocm_aiter_ops.mla_decode_fwd(
|
||||
q,
|
||||
kv_buffer,
|
||||
o,
|
||||
self.scale,
|
||||
attn_metadata.decode.qo_indptr,
|
||||
max_seqlen_qo,
|
||||
attn_metadata.decode.paged_kv_indptr,
|
||||
attn_metadata.decode.paged_kv_indices,
|
||||
attn_metadata.decode.paged_kv_last_page_len,
|
||||
)
|
||||
|
||||
return o, None
|
||||
206
vllm_old/v1/attention/backends/mla/triton_mla.py
Normal file
206
vllm_old/v1/attention/backends/mla/triton_mla.py
Normal file
@@ -0,0 +1,206 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.distributed.parallel_state import get_dcp_group
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
)
|
||||
import ixformer.inference.functions as ixf_ops
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TritonMLABackend(MLACommonBackend):
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TRITON_MLA"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["TritonMLAImpl"]:
|
||||
return TritonMLAImpl
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
can_return_lse_for_decode: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None,
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
# MLA Specific Arguments
|
||||
**mla_args,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
num_heads,
|
||||
head_size,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
kv_cache_dtype,
|
||||
logits_soft_cap,
|
||||
attn_type,
|
||||
kv_sharing_target_layer_name,
|
||||
**mla_args,
|
||||
)
|
||||
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"TritonMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, logits_soft_cap"
|
||||
)
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"TritonMLAImpl"
|
||||
)
|
||||
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
raise NotImplementedError(
|
||||
"TritonMLA V1 with FP8 KV cache not yet supported"
|
||||
)
|
||||
|
||||
def _flash_attn_varlen_diff_headdims(
|
||||
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
|
||||
):
|
||||
return super()._flash_attn_varlen_diff_headdims(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
return_softmax_lse=return_softmax_lse,
|
||||
softmax_scale=softmax_scale,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
k_c_normed: torch.Tensor | None,
|
||||
k_pe: torch.Tensor | None,
|
||||
kv_c_and_k_pe_cache_scale: torch.Tensor | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
raise NotImplementedError("FP8 Triton MLA not yet supported")
|
||||
|
||||
decode_meta = attn_metadata.decode
|
||||
q_nope = self._k_up_proj(q_nope)
|
||||
q_nope = q_nope.view(-1, self.num_heads, self.kv_lora_rank)
|
||||
|
||||
B = q_nope.shape[0]
|
||||
|
||||
if self.dcp_world_size > 1:
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
q = get_dcp_group().all_gather(q, dim=1)
|
||||
o = torch.empty(B,
|
||||
q.shape[1],
|
||||
self.kv_lora_rank,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
if envs.VLLM_USE_INT8_MLA:
|
||||
q_int8, q_scale = ops.quant_kv(q)
|
||||
attn_out, softmax_lse = ixf_ops.vllm_paged_attention_mla_int8(
|
||||
o,
|
||||
q_int8,
|
||||
q_scale,
|
||||
kv_c_and_k_pe_cache,
|
||||
kv_c_and_k_pe_cache_scale,
|
||||
self.scale,
|
||||
attn_metadata.decode.block_table,
|
||||
attn_metadata.decode.seq_lens,
|
||||
attn_metadata.decode.max_decode_seq_len,
|
||||
return_softmax_lse=True
|
||||
)
|
||||
else:
|
||||
attn_out, softmax_lse = ixf_ops.vllm_paged_attention_mla(
|
||||
output=o,
|
||||
query=q,
|
||||
kv_cache=kv_c_and_k_pe_cache,
|
||||
scale=self.scale,
|
||||
block_tables=attn_metadata.decode.block_table,
|
||||
context_lens=attn_metadata.decode.seq_lens,
|
||||
max_context_len=decode_meta.max_decode_seq_len,
|
||||
return_softmax_lse=True)
|
||||
return attn_out, softmax_lse
|
||||
|
||||
o = torch.empty(B,
|
||||
self.num_heads,
|
||||
self.kv_lora_rank,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
|
||||
if envs.VLLM_USE_INT8_MLA:
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
q_int8, q_scale = ops.quant_kv(q)
|
||||
ixf_ops.vllm_paged_attention_mla_int8(
|
||||
o,
|
||||
q_int8,
|
||||
q_scale,
|
||||
kv_c_and_k_pe_cache,
|
||||
kv_c_and_k_pe_cache_scale,
|
||||
self.scale,
|
||||
attn_metadata.decode.block_table,
|
||||
attn_metadata.decode.seq_lens,
|
||||
attn_metadata.decode.max_decode_seq_len,
|
||||
attn_metadata.decode.use_cuda_graph
|
||||
)
|
||||
else:
|
||||
# fused q concat & cache write
|
||||
ixf_ops.vllm_paged_attention_mla_fused(
|
||||
output=o,
|
||||
q_nope=q_nope,
|
||||
q_pe=q_pe.contiguous(),
|
||||
kv_cache=kv_c_and_k_pe_cache,
|
||||
scale=self.scale,
|
||||
block_tables=attn_metadata.decode.block_table,
|
||||
context_lens=attn_metadata.decode.seq_lens,
|
||||
max_context_len=decode_meta.max_decode_seq_len,
|
||||
k_c_normed=k_c_normed,
|
||||
k_pe=k_pe,
|
||||
use_cuda_graph=decode_meta.use_cuda_graph
|
||||
)
|
||||
return self._v_up_proj(o), None
|
||||
Reference in New Issue
Block a user