Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -191,6 +191,8 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
|
||||
query_start_loc_device: torch.Tensor,
|
||||
num_decode_tokens: int,
|
||||
dcp_tot_seq_lens_device: torch.Tensor | None,
|
||||
max_decode_seq_len: int = 0,
|
||||
use_cuda_graph: bool = False,
|
||||
) -> FlashAttnMLADecodeMetadata:
|
||||
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
max_query_len = query_lens_cpu.max().item()
|
||||
@@ -239,12 +241,14 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
|
||||
metadata = FlashAttnMLADecodeMetadata(
|
||||
block_table=block_table_tensor,
|
||||
seq_lens=seq_lens_device,
|
||||
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
|
||||
max_decode_seq_len=max_decode_seq_len,
|
||||
use_cuda_graph=use_cuda_graph,
|
||||
query_start_loc=query_start_loc_device,
|
||||
max_query_len=max_query_len,
|
||||
max_seq_len=max_seq_len,
|
||||
scheduler_metadata=scheduler_metadata,
|
||||
max_num_splits=max_num_splits,
|
||||
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
|
||||
)
|
||||
return metadata
|
||||
|
||||
|
||||
@@ -156,6 +156,8 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
query_start_loc_device: torch.Tensor,
|
||||
num_decode_tokens: int,
|
||||
dcp_tot_seq_lens_device: torch.Tensor | None,
|
||||
max_decode_seq_len: int = 0,
|
||||
use_cuda_graph: bool = False,
|
||||
) -> FlashMLADecodeMetadata:
|
||||
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
# we use the max but all should be the same due to uniform length requirement
|
||||
@@ -179,8 +181,10 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
return FlashMLADecodeMetadata(
|
||||
block_table=block_table_tensor,
|
||||
seq_lens=seq_lens_device,
|
||||
scheduler_metadata=scheduler_metadata,
|
||||
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
|
||||
max_decode_seq_len=max_decode_seq_len,
|
||||
use_cuda_graph=use_cuda_graph,
|
||||
scheduler_metadata=scheduler_metadata,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -13,6 +13,11 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention.mla_attention import (
|
||||
get_mla_dims,
|
||||
)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
LinearBase,
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
@@ -37,13 +42,17 @@ from vllm.v1.attention.backends.utils import (
|
||||
)
|
||||
from vllm.v1.attention.ops.flashmla import (
|
||||
FlashMLASchedMeta,
|
||||
flash_mla_sparse_fwd,
|
||||
flash_mla_sparse_prefill,
|
||||
flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.workspace import current_workspace_manager
|
||||
|
||||
import functools
|
||||
from vllm import envs
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import scaled_dequantize
|
||||
import ixformer.inference.functions as ixf_ops
|
||||
import numpy as np
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.deepseek_v2 import Indexer
|
||||
|
||||
@@ -74,7 +83,15 @@ structured as:
|
||||
- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
|
||||
part is not quantized for accuracy.
|
||||
"""
|
||||
|
||||
def dynamic_per_batched_tensor_quant(
|
||||
x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
|
||||
):
|
||||
DTYPE_MAX = torch.finfo(dtype).max
|
||||
min_val, max_val = x.aminmax()
|
||||
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
|
||||
scale = DTYPE_MAX / amax
|
||||
x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
|
||||
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
|
||||
|
||||
class FlashMLASparseBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
@@ -558,6 +575,11 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.kv_lora_rank: int = mla_args["kv_lora_rank"]
|
||||
self.qk_nope_head_dim = mla_args["qk_nope_head_dim"]
|
||||
self.qk_rope_head_dim = mla_args["qk_rope_head_dim"]
|
||||
self.qk_head_dim = mla_args["qk_head_dim"]
|
||||
self.v_head_dim = mla_args["v_head_dim"]
|
||||
self.kv_b_proj = mla_args["kv_b_proj"]
|
||||
self.softmax_scale = scale
|
||||
assert indexer is not None
|
||||
self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
|
||||
@@ -580,6 +602,65 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
|
||||
(self.prefill_workspace_shape, torch.bfloat16)
|
||||
)
|
||||
)
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
def get_layer_weight(layer):
|
||||
WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
|
||||
for attr in WEIGHT_NAMES:
|
||||
if hasattr(layer, attr):
|
||||
return getattr(layer, attr)
|
||||
raise AttributeError(
|
||||
f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}."
|
||||
)
|
||||
|
||||
def get_and_maybe_dequant_weights(layer: LinearBase):
|
||||
if layer.quant_method is not None and not isinstance(
|
||||
layer.quant_method, UnquantizedLinearMethod
|
||||
):
|
||||
# NOTE: This should only be used offline, since it's O(N^3)
|
||||
eye = torch.eye(
|
||||
layer.input_size_per_partition,
|
||||
dtype=act_dtype,
|
||||
device=get_layer_weight(layer).device,
|
||||
)
|
||||
dequant_weights = layer.quant_method.apply(layer, eye, bias=None)
|
||||
del eye
|
||||
# standardize to (output, input)
|
||||
return dequant_weights.T
|
||||
return layer.weight
|
||||
|
||||
# we currently do not have quantized bmm's which are needed for
|
||||
# `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
|
||||
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
|
||||
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
|
||||
assert kv_b_proj_weight.shape == (
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||
), (
|
||||
f"{kv_b_proj_weight.shape=}, "
|
||||
f"{self.kv_lora_rank=}, "
|
||||
f"{self.num_heads=}, "
|
||||
f"{self.qk_nope_head_dim=}, "
|
||||
f"{self.v_head_dim=}"
|
||||
)
|
||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim,
|
||||
)
|
||||
|
||||
W_UK, W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1
|
||||
)
|
||||
self.W_UV = W_UV
|
||||
self.W_UK = W_UK
|
||||
# self.W_UK_T = W_UK.permute(1, 2, 0)
|
||||
|
||||
def _v_up_proj(self, x: torch.Tensor):
|
||||
|
||||
return torch.einsum("bnl,lnv->bnv", x, self.W_UV)
|
||||
def _k_up_proj(self, q_nope):
|
||||
|
||||
return torch.einsum("bnp,lnp->bnl", q_nope, self.W_UK).view(-1, self.num_heads, self.kv_lora_rank)
|
||||
|
||||
def _forward_bf16_kv(
|
||||
self,
|
||||
@@ -590,12 +671,11 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
|
||||
) -> torch.Tensor:
|
||||
# Convert per-request indices to global slots (decode) or workspace
|
||||
# offsets (prefill).
|
||||
topk_indices = triton_convert_req_index_to_global_index(
|
||||
topk_indices = ops.dsa_convert_req_index_to_global_index(
|
||||
attn_metadata.req_id_per_token,
|
||||
attn_metadata.block_table,
|
||||
topk_indices,
|
||||
BLOCK_SIZE=attn_metadata.block_size,
|
||||
NUM_TOPK_TOKENS=topk_indices.shape[1],
|
||||
attn_metadata.block_size,
|
||||
)
|
||||
|
||||
return self._bf16_flash_mla_kernel(q, kv_c_and_k_pe_cache, topk_indices)
|
||||
@@ -790,22 +870,10 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
|
||||
-1, 1, kv_c_and_k_pe_cache.shape[-1]
|
||||
)
|
||||
|
||||
# NOTE(Chen): kernel requires num_local_head to be a multiple of
|
||||
# 64 on hopper and 128 on blackwell
|
||||
if self.num_heads % self.prefill_padding != 0:
|
||||
assert self.prefill_padding % self.num_heads == 0
|
||||
logger.warning_once(
|
||||
f"Padding num_heads from {self.num_heads} to "
|
||||
f"{self.prefill_padding} for BF16 sparse prefill kernel"
|
||||
)
|
||||
q_padded = q.new_empty((q.shape[0], self.prefill_padding, q.shape[2]))
|
||||
q_padded[:, : self.num_heads, :] = q
|
||||
q = q_padded
|
||||
|
||||
topk_indices = topk_indices.view(num_tokens, 1, -1)
|
||||
output = flash_mla_sparse_fwd(
|
||||
output = flash_mla_sparse_prefill(
|
||||
q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale
|
||||
)[0]
|
||||
)
|
||||
output = output[:, : self.num_heads, :]
|
||||
return output
|
||||
|
||||
@@ -843,5 +911,5 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
|
||||
attn_out = self._forward_fp8_kv_separate_prefill_decode(
|
||||
q, kv_c_and_k_pe_cache, topk_indices, attn_metadata
|
||||
)
|
||||
|
||||
return attn_out, None
|
||||
|
||||
return attn_out
|
||||
|
||||
@@ -8,7 +8,11 @@ import torch
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, has_deep_gemm
|
||||
from vllm.utils.deep_gemm import (
|
||||
get_paged_mqa_logits_metadata,
|
||||
is_deep_gemm_supported,
|
||||
)
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
@@ -21,6 +25,7 @@ from vllm.v1.attention.backends.utils import (
|
||||
split_decodes_and_prefills,
|
||||
split_prefill_chunks,
|
||||
)
|
||||
from vllm.v1.worker.cp_utils import get_total_cp_world_size
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -68,11 +73,15 @@ class DeepseekV32IndexerPrefillChunkMetadata:
|
||||
cu_seqlen_ks: torch.Tensor
|
||||
cu_seqlen_ke: torch.Tensor
|
||||
cu_seq_lens: torch.Tensor
|
||||
cu_seqlens_q: torch.Tensor
|
||||
token_to_seq: torch.Tensor
|
||||
total_seq_lens: int
|
||||
token_start: int
|
||||
token_end: int
|
||||
num_reqs: int
|
||||
max_context_len: int
|
||||
max_q_len: int # Maximum query length for dsa_indexer_mqa_logits_with_blocks
|
||||
max_kv_len: int # Maximum key-value length for dsa_indexer_mqa_logits_with_blocks
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -86,9 +95,16 @@ class DeepSeekV32IndexerDecodeMetadata:
|
||||
seq_lens: torch.Tensor
|
||||
decode_lens: torch.Tensor
|
||||
requires_padding: bool
|
||||
schedule_metadata: torch.Tensor
|
||||
# schedule_metadata: torch.Tensor
|
||||
use_large_context_topk: bool
|
||||
offsets: torch.Tensor | None # Precomputed offsets for speculative decoding
|
||||
cu_seqlen_ks: torch.Tensor
|
||||
cu_seqlen_ke: torch.Tensor
|
||||
cu_seqlens_kv: torch.Tensor
|
||||
cu_seqlens_q: torch.Tensor
|
||||
max_context_len: int
|
||||
max_q_len: int
|
||||
max_kv_len: int
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -211,20 +227,39 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
||||
if self.vllm_config.speculative_config
|
||||
else 0
|
||||
)
|
||||
if self.num_speculative_tokens > 1:
|
||||
raise ValueError(
|
||||
"Sparse MLA only supports "
|
||||
"num_speculative_tokens <= 1 because the DeepGEMM "
|
||||
"fp8_paged_mqa_logits kernel does not support next_n > 2. "
|
||||
f"Got num_speculative_tokens={self.num_speculative_tokens}."
|
||||
)
|
||||
self.reorder_batch_threshold += self.num_speculative_tokens
|
||||
|
||||
sm_count = num_compute_units(self.device.index)
|
||||
self.num_sms = sm_count
|
||||
|
||||
self.decode_lens_buffer = torch.empty(
|
||||
(scheduler_config.max_num_seqs,), dtype=torch.int32, device=self.device
|
||||
(scheduler_config.max_num_batched_tokens,),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
# Pre-allocated buffers for flattening (spec decode).
|
||||
self.arange_buffer = torch.arange(
|
||||
scheduler_config.max_num_seqs * (1 + self.num_speculative_tokens),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
self.expanded_seq_lens_buffer = torch.zeros(
|
||||
(scheduler_config.max_num_batched_tokens,),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
max_num_blocks_per_req = cdiv(
|
||||
self.vllm_config.model_config.max_model_len,
|
||||
self.kv_cache_spec.block_size * get_total_cp_world_size(),
|
||||
)
|
||||
self.expanded_block_table_buffer = torch.zeros(
|
||||
(
|
||||
scheduler_config.max_num_batched_tokens,
|
||||
max_num_blocks_per_req,
|
||||
),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
# See: DeepGMM/csrc/apis/attention.hpp
|
||||
@@ -260,18 +295,88 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
||||
.to(torch.int32)
|
||||
.to(self.device)
|
||||
)
|
||||
cu_seqlens_q = prefill_query_start_loc.to(torch.int32).to(self.device)
|
||||
max_context_len = seq_lens_cpu[reqs_start:reqs_end].max().item()
|
||||
# max_q_len is the maximum query length among all batches in this chunk
|
||||
# prefill_query_start_loc is cumsum of lengths with shape [batch+1]
|
||||
max_q_len = (prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]).max().item()
|
||||
return DeepseekV32IndexerPrefillChunkMetadata(
|
||||
cu_seqlen_ks=cu_seqlen_ks,
|
||||
cu_seqlen_ke=cu_seqlen_ke,
|
||||
cu_seq_lens=cu_seq_lens,
|
||||
token_to_seq=token_to_seq,
|
||||
total_seq_lens=total_seq_lens,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
block_table=block_table[reqs_start:reqs_end],
|
||||
token_start=token_start,
|
||||
token_end=token_end,
|
||||
num_reqs=reqs_end - reqs_start,
|
||||
max_context_len=max_context_len,
|
||||
max_q_len=max_q_len,
|
||||
max_kv_len=max_context_len
|
||||
)
|
||||
|
||||
def build_decode_metadata(
|
||||
self, common_attn_metadata, num_decodes, decode_lens, use_large_context_topk, offsets
|
||||
):
|
||||
decode_lens_cpu = torch.diff(
|
||||
common_attn_metadata.query_start_loc_cpu[: num_decodes + 1]
|
||||
)
|
||||
assert (
|
||||
decode_lens_cpu.max().item()
|
||||
== decode_lens_cpu.min().item()
|
||||
== 1
|
||||
), "Only support single token decode in dsa_indexer backend"
|
||||
|
||||
# Calculate decode metadata parameters
|
||||
seq_lens_decode = common_attn_metadata.seq_lens_cpu[:num_decodes]
|
||||
max_context_len = seq_lens_decode.max().item()
|
||||
max_kv_len = max_context_len
|
||||
max_q_len = 1 # Single token decode
|
||||
|
||||
# Create cu_seqlens_q: cumulative sum of query lengths (all 1s)
|
||||
cu_seqlens_q = torch.arange(
|
||||
num_decodes + 1, dtype=torch.int32, device=self.device
|
||||
)
|
||||
|
||||
# Create cu_seqlens_kv and related tensors using kv_spans_from_batches
|
||||
decode_query_start_loc = torch.arange(
|
||||
num_decodes + 1, dtype=torch.long
|
||||
)
|
||||
cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
|
||||
decode_query_start_loc, seq_lens_decode, self.device
|
||||
)
|
||||
|
||||
cu_seqlens_kv = torch.cat(
|
||||
[
|
||||
torch.zeros(1, dtype=torch.int32, device=self.device),
|
||||
torch.cumsum(seq_lens_decode.to(self.device), dim=0)
|
||||
.to(torch.int32),
|
||||
]
|
||||
)
|
||||
|
||||
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
|
||||
block_table=common_attn_metadata.block_table_tensor[
|
||||
:num_decodes, ...
|
||||
],
|
||||
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
|
||||
decode_lens=decode_lens,
|
||||
requires_padding=(
|
||||
decode_lens_cpu.max() > decode_lens_cpu.min()
|
||||
).item(),
|
||||
use_large_context_topk=use_large_context_topk,
|
||||
offsets=offsets,
|
||||
cu_seqlen_ks=cu_seqlen_ks,
|
||||
cu_seqlen_ke=cu_seqlen_ke,
|
||||
cu_seqlens_kv=cu_seqlens_kv,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
max_context_len=max_context_len,
|
||||
max_q_len=max_q_len,
|
||||
max_kv_len=max_kv_len,
|
||||
# schedule_metadata=self.scheduler_metadata_buffer,
|
||||
)
|
||||
return decode_metadata
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
@@ -323,45 +428,103 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
||||
common_attn_metadata.query_start_loc_cpu[: num_decodes + 1]
|
||||
)
|
||||
|
||||
# Use CPU to avoid GPU sync; breaking async scheduling
|
||||
requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()
|
||||
|
||||
# Decide which top-k kernel to use based on batch size and sequence length
|
||||
batch_size = num_decodes
|
||||
_is_large_context = common_attn_metadata.max_seq_len > 8192
|
||||
|
||||
# Decision logic based on micro-benchmark results:
|
||||
# - large_context_topk wins for batch <= 128 and seq_len > 8K
|
||||
# - top_k_per_row_decode wins for batch > 128 or seq_len <= 8K
|
||||
use_large_context_topk = batch_size <= 128 and _is_large_context
|
||||
|
||||
next_n = 1 + self.num_speculative_tokens
|
||||
if next_n > 1:
|
||||
offsets = torch.arange(next_n, device=self.device, dtype=torch.int32)
|
||||
else:
|
||||
offsets = None
|
||||
|
||||
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
|
||||
|
||||
# DeepGEMM is required for the paged MQA logits on CUDA devices
|
||||
if current_platform.is_cuda() and has_deep_gemm():
|
||||
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
|
||||
seq_lens, self.kv_cache_spec.block_size, self.num_sms
|
||||
)
|
||||
block_table = common_attn_metadata.block_table_tensor[:num_decodes, ...]
|
||||
|
||||
# Padded CUDA graph requests have block_table entries of -1.
|
||||
# Clamp to 0 to prevent OOB access in the DeepGEMM kernel.
|
||||
# This is safe because padded requests have seq_lens=0, so the
|
||||
# kernel produces no meaningful output for those rows.
|
||||
block_table.clamp_(min=0)
|
||||
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
|
||||
block_table=block_table,
|
||||
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
|
||||
decode_lens=decode_lens,
|
||||
requires_padding=requires_padding,
|
||||
schedule_metadata=self.scheduler_metadata_buffer,
|
||||
use_large_context_topk=use_large_context_topk,
|
||||
offsets=offsets,
|
||||
|
||||
max_decode_len = int(decode_lens_cpu.max().item())
|
||||
if max_decode_len > 1:
|
||||
# Flatten multi-token decode requests into single-token
|
||||
# batch entries, expanding seq_lens and block tables so
|
||||
# the kernel always sees next_n=1.
|
||||
|
||||
# Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is
|
||||
# padding) and decode_lens [3, 1, 4, 0] in the below example comments.
|
||||
# The context lengths are therefore
|
||||
# [10-3, 7-1, 12-4, 0-0] = [7, 6, 8, 0].
|
||||
|
||||
# 3 + 1 + 4 + 0 = 8
|
||||
actual_expanded = int(decode_lens_cpu.sum().item())
|
||||
|
||||
# [7, 6, 8, 0] -> [7, 7, 7, 6, 8, 8, 8, 8]
|
||||
expanded_base = torch.repeat_interleave(
|
||||
seq_lens - decode_lens, decode_lens
|
||||
)
|
||||
|
||||
# [0, 3, 4, 8] -> [0, 0, 0, 3, 4, 4, 4, 4]
|
||||
expanded_starts = torch.repeat_interleave(
|
||||
common_attn_metadata.query_start_loc[:num_decodes], decode_lens
|
||||
)
|
||||
|
||||
# [0, 1, 2, 0, 0, 1, 2, 3]
|
||||
positions_within = (
|
||||
self.arange_buffer[:actual_expanded] - expanded_starts
|
||||
)
|
||||
|
||||
# [8, 9, 10, 7, 9, 10, 11, 12, ...] where ... is unused buffer space
|
||||
self.expanded_seq_lens_buffer[:actual_expanded] = (
|
||||
expanded_base + positions_within + 1
|
||||
)
|
||||
self.expanded_seq_lens_buffer[actual_expanded:] = 0
|
||||
seq_lens = self.expanded_seq_lens_buffer[:num_decode_tokens]
|
||||
|
||||
# Give each of the flattened entries the same block table row as the
|
||||
# original request.
|
||||
self.expanded_block_table_buffer[:actual_expanded] = (
|
||||
torch.repeat_interleave(block_table, decode_lens, dim=0)
|
||||
)
|
||||
if actual_expanded < num_decode_tokens:
|
||||
self.expanded_block_table_buffer[
|
||||
actual_expanded:num_decode_tokens, 0
|
||||
] = 0
|
||||
block_table = self.expanded_block_table_buffer[:num_decode_tokens]
|
||||
|
||||
# All reqs now have decode_len=1
|
||||
self.decode_lens_buffer[:num_decode_tokens] = 1
|
||||
decode_lens = self.decode_lens_buffer[:num_decode_tokens]
|
||||
offsets = None
|
||||
batch_size = num_decode_tokens
|
||||
else:
|
||||
next_n = 1 + self.num_speculative_tokens
|
||||
if next_n > 1:
|
||||
offsets = torch.arange(
|
||||
next_n, device=self.device, dtype=torch.int32
|
||||
)
|
||||
else:
|
||||
offsets = None
|
||||
batch_size = num_decodes
|
||||
|
||||
# DeepGEMM is required for the paged MQA logits on CUDA devices
|
||||
if current_platform.is_cuda() and is_deep_gemm_supported():
|
||||
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
|
||||
seq_lens,
|
||||
self.kv_cache_spec.block_size,
|
||||
self.num_sms,
|
||||
)
|
||||
|
||||
# Decide which top-k kernel to use based on batch size and sequence length
|
||||
# Decision logic based on micro-benchmark results:
|
||||
# - large_context_topk wins for batch <= 128 and seq_len > 8K
|
||||
# - top_k_per_row_decode wins for batch > 128 or seq_len <= 8K
|
||||
_is_large_context = common_attn_metadata.max_seq_len > 8192
|
||||
use_large_context_topk = batch_size <= 128 and _is_large_context
|
||||
|
||||
# decode_metadata = DeepSeekV32IndexerDecodeMetadata(
|
||||
# block_table=block_table,
|
||||
# seq_lens=seq_lens,
|
||||
# decode_lens=decode_lens,
|
||||
# requires_padding=False,
|
||||
# # schedule_metadata=self.scheduler_metadata_buffer,
|
||||
# use_large_context_topk=use_large_context_topk,
|
||||
# offsets=offsets,
|
||||
# )
|
||||
decode_metadata = self.build_decode_metadata(
|
||||
common_attn_metadata, num_decodes, decode_lens, use_large_context_topk, offsets
|
||||
)
|
||||
|
||||
attn_metadata = DeepseekV32IndexerMetadata(
|
||||
|
||||
@@ -115,6 +115,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
query_start_loc_device: torch.Tensor,
|
||||
num_decode_tokens: int,
|
||||
dcp_tot_seq_lens_device: torch.Tensor | None,
|
||||
max_decode_seq_len: int = 0,
|
||||
use_cuda_graph: bool = False,
|
||||
) -> AiterMLADecodeMetadata:
|
||||
# kernel block size is always 1, although the kv block size is not 1.
|
||||
device = self.device
|
||||
@@ -170,11 +172,13 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
attn_metadata = AiterMLADecodeMetadata(
|
||||
block_table=block_table_tensor,
|
||||
seq_lens=seq_lens_device,
|
||||
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
|
||||
max_decode_seq_len=max_decode_seq_len,
|
||||
use_cuda_graph=use_cuda_graph,
|
||||
paged_kv_indptr=paged_kv_indptr,
|
||||
paged_kv_indices=paged_kv_indices,
|
||||
paged_kv_last_page_len=paged_kv_last_page_len,
|
||||
qo_indptr=qo_indptr,
|
||||
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
|
||||
max_qo_len=max_qo_len,
|
||||
attn_out_dtype=self.decode_attn_out_dtype,
|
||||
)
|
||||
|
||||
@@ -15,6 +15,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.distributed.parallel_state import get_dcp_group
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionLayer,
|
||||
@@ -22,20 +23,19 @@ from vllm.v1.attention.backend import (
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd
|
||||
|
||||
import ixformer.inference.functions as ixf_ops
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.distributed.parallel_state import get_dcp_group
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TritonMLABackend(MLACommonBackend):
|
||||
# supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
# supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
# "auto",
|
||||
# "bfloat16",
|
||||
# ]
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"bfloat16",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
@@ -120,10 +120,9 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
# layer: AttentionLayer,
|
||||
k_c_normed: torch.Tensor |None = None,
|
||||
k_pe: torch.Tensor |None = None,
|
||||
kv_c_and_k_pe_cache_scale: torch.Tensor |None = None,
|
||||
k_c_normed: torch.Tensor | None,
|
||||
k_pe: torch.Tensor | None,
|
||||
kv_c_and_k_pe_cache_scale: torch.Tensor | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
@@ -136,7 +135,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
q_nope = q_nope.view(-1, self.num_heads, self.kv_lora_rank)
|
||||
|
||||
B = q_nope.shape[0]
|
||||
|
||||
|
||||
if self.dcp_world_size > 1:
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
q = get_dcp_group().all_gather(q, dim=1)
|
||||
@@ -147,7 +146,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
device=q_nope.device)
|
||||
if envs.VLLM_USE_INT8_MLA:
|
||||
q_int8, q_scale = ops.quant_kv(q)
|
||||
attn_out, softmax_lse = ixf_ops.ref_vllm_paged_attention_mla_int8(
|
||||
attn_out, softmax_lse = ixf_ops.vllm_paged_attention_mla_int8(
|
||||
o,
|
||||
q_int8,
|
||||
q_scale,
|
||||
@@ -160,7 +159,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
return_softmax_lse=True
|
||||
)
|
||||
else:
|
||||
attn_out, softmax_lse = ixf_ops.ref_vllm_paged_attention_mla(
|
||||
attn_out, softmax_lse = ixf_ops.vllm_paged_attention_mla(
|
||||
output=o,
|
||||
query=q,
|
||||
kv_cache=kv_c_and_k_pe_cache,
|
||||
@@ -170,12 +169,12 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
max_context_len=decode_meta.max_decode_seq_len,
|
||||
return_softmax_lse=True)
|
||||
return attn_out, softmax_lse
|
||||
|
||||
|
||||
o = torch.empty(B,
|
||||
self.num_heads,
|
||||
self.kv_lora_rank,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
device=q_nope.device)
|
||||
|
||||
if envs.VLLM_USE_INT8_MLA:
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
@@ -193,18 +192,30 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
attn_metadata.decode.use_cuda_graph
|
||||
)
|
||||
else:
|
||||
# fused q concat & cache write
|
||||
ixf_ops.vllm_paged_attention_mla_fused(
|
||||
output=o,
|
||||
q_nope=q_nope,
|
||||
q_pe=q_pe.contiguous(),
|
||||
kv_cache=kv_c_and_k_pe_cache,
|
||||
scale=self.scale,
|
||||
block_tables=attn_metadata.decode.block_table,
|
||||
context_lens=attn_metadata.decode.seq_lens,
|
||||
max_context_len=decode_meta.max_decode_seq_len,
|
||||
k_c_normed=k_c_normed,
|
||||
k_pe=k_pe,
|
||||
use_cuda_graph=decode_meta.use_cuda_graph
|
||||
)
|
||||
if k_c_normed is None:
|
||||
q = torch.cat([q_nope, q_pe.contiguous()], dim=-1)
|
||||
ixf_ops.vllm_paged_attention_mla(
|
||||
output=o,
|
||||
query=q,
|
||||
kv_cache=kv_c_and_k_pe_cache,
|
||||
scale=self.scale,
|
||||
block_tables=attn_metadata.decode.block_table,
|
||||
context_lens=attn_metadata.decode.seq_lens,
|
||||
max_context_len=decode_meta.max_decode_seq_len,
|
||||
use_cuda_graph=decode_meta.use_cuda_graph,
|
||||
)
|
||||
else:
|
||||
ixf_ops.vllm_paged_attention_mla_fused(
|
||||
output=o,
|
||||
q_nope=q_nope.contiguous(),
|
||||
q_pe=q_pe.contiguous(),
|
||||
kv_cache=kv_c_and_k_pe_cache,
|
||||
scale=self.scale,
|
||||
block_tables=attn_metadata.decode.block_table,
|
||||
context_lens=attn_metadata.decode.seq_lens,
|
||||
max_context_len=decode_meta.max_decode_seq_len,
|
||||
k_c_normed=k_c_normed,
|
||||
k_pe=k_pe,
|
||||
use_cuda_graph=decode_meta.use_cuda_graph,
|
||||
)
|
||||
return self._v_up_proj(o), None
|
||||
|
||||
Reference in New Issue
Block a user