"""
Support attention backend for TRTLLM MLA kernels from flashinfer.
"""

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union

import torch
import triton

from sglang.srt.layers.attention.flashinfer_mla_backend import (
    FlashInferMLAAttnBackend,
    FlashInferMLAMultiStepDraftBackend,
)
from sglang.srt.layers.attention.utils import (
    TRITON_PAD_NUM_PAGE_PER_BLOCK,
    create_flashmla_kv_indices_triton,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_flashinfer_available

if is_flashinfer_available():
    import flashinfer

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.speculative.spec_info import SpecInfo

# Constants
DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB

# Block constraint from flashinfer requirements
# From flashinfer.decode._check_trtllm_gen_mla_shape:
#   block_num % (128 / block_size) == 0
# This imposes that the total number of blocks must be divisible by
# (128 / block_size). We capture the 128 constant here so we can
# compute the LCM with other padding constraints.
TRTLLM_BLOCK_CONSTRAINT = 128
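# Illustrative example (assuming page_size=64): the kernel then requires
# block_num % (128 / 64) == block_num % 2 == 0; _calc_padded_blocks below
# combines this with the Triton padding requirement via an LCM.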

global_zero_init_workspace_buffer = None


@dataclass
class TRTLLMMLAPrefillMetadata:
    """Metadata for TRTLLM MLA prefill operations."""

    max_seq_len: int
    cum_seq_lens: torch.Tensor
    seq_lens: torch.Tensor


@dataclass
class TRTLLMMLADecodeMetadata:
    """Metadata for TRTLLM MLA decode operations."""

    workspace: Optional[torch.Tensor] = None
    block_kv_indices: Optional[torch.Tensor] = None
    max_seq_len: Optional[int] = None


class TRTLLMMLABackend(FlashInferMLAAttnBackend):
    """TRTLLM MLA attention kernel from flashinfer."""

    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
        q_indptr_decode_buf: Optional[torch.Tensor] = None,
    ):
        super().__init__(
            model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf
        )

        config = model_runner.model_config

        # Model parameters
        self.num_q_heads = config.num_attention_heads // get_attention_tp_size()
        self.num_kv_heads = config.get_num_kv_heads(get_attention_tp_size())
        self.num_local_heads = config.num_attention_heads // get_attention_tp_size()

        # MLA-specific dimensions
        self.kv_lora_rank = config.kv_lora_rank
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.v_head_dim = config.v_head_dim
        self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim

        # Runtime parameters
        self.scaling = config.scaling
        self.data_type = model_runner.kv_cache_dtype
        self.q_data_type = model_runner.dtype
        self.page_size = model_runner.page_size
        self.req_to_token = model_runner.req_to_token_pool.req_to_token

        # Workspace allocation
        self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
        global global_zero_init_workspace_buffer
        if global_zero_init_workspace_buffer is None:
            global_zero_init_workspace_buffer = torch.zeros(
                self.workspace_size,
                dtype=torch.uint8,
                device=model_runner.device,
            )
        self.workspace_buffer = global_zero_init_workspace_buffer
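        # Note: the zero-initialized workspace is module-global, so every
        # TRTLLMMLABackend instance (including the per-step backends created by
        # TRTLLMMLAMultiStepDraftBackend below) shares a single allocation instead
        # of reserving DEFAULT_WORKSPACE_SIZE_MB each.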

        # CUDA graph state
        self.decode_cuda_graph_metadata = {}
        self.decode_cuda_graph_kv_indices = None
        self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None
        self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None

    def _calc_padded_blocks(self, max_seq_len: int) -> int:
        """
        Calculate a padded block count that satisfies both TRT-LLM and Triton constraints.

        Args:
            max_seq_len: Maximum sequence length in tokens

        Returns:
            Number of blocks padded to satisfy all constraints
        """
        blocks = triton.cdiv(max_seq_len, self.page_size)

        # Apply dual constraints (take the LCM to satisfy both):
        # 1. TRT-LLM: block_num % (128 / page_size) == 0
        # 2. Triton: the page table builder uses 64-index bursts, so it needs a multiple of 64
        trtllm_constraint = TRTLLM_BLOCK_CONSTRAINT // self.page_size
        constraint_lcm = math.lcm(trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK)
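        # Worked example (illustrative, assuming page_size=32 and the 64-page
        # Triton constant noted above): trtllm_constraint = 128 // 32 = 4,
        # constraint_lcm = lcm(4, 64) = 64, so max_seq_len = 1000 gives
        # cdiv(1000, 32) = 32 raw blocks, which is padded up to 64 below.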

        if blocks % constraint_lcm != 0:
            blocks = triton.cdiv(blocks, constraint_lcm) * constraint_lcm
        return blocks

    def _create_block_kv_indices(
        self,
        batch_size: int,
        max_blocks: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        device: torch.device,
    ) -> torch.Tensor:
        """
        Create block KV indices tensor using Triton kernel.

        Args:
            batch_size: Batch size
            max_blocks: Maximum number of blocks per sequence
            req_pool_indices: Request pool indices
            seq_lens: Sequence lengths
            device: Target device

        Returns:
            Block KV indices tensor
        """
        block_kv_indices = torch.full(
            (batch_size, max_blocks), -1, dtype=torch.int32, device=device
        )

        create_flashmla_kv_indices_triton[(batch_size,)](
            self.req_to_token,
            req_pool_indices,
            seq_lens,
            None,
            block_kv_indices,
            self.req_to_token.stride(0),
            max_blocks,
            NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
            PAGED_SIZE=self.page_size,
        )
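        # After the kernel above, row i of block_kv_indices holds the page ids of
        # request i in order, with unused trailing slots left at -1, e.g.
        # (illustrative) [17, 42, 3, -1, -1, ...] for a request spanning 3 pages.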

        return block_kv_indices

    def init_cuda_graph_state(
        self,
        max_bs: int,
        max_num_tokens: int,
        kv_indices_buf: Optional[torch.Tensor] = None,
    ):
        """Initialize CUDA graph state for TRTLLM MLA."""

        max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)

        self.decode_cuda_graph_kv_indices = torch.full(
            (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
        )
        self.decode_cuda_graph_workspace = torch.empty(
            self.workspace_size, dtype=torch.int8, device=self.device
        )

        super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
    ):
        """Initialize metadata for CUDA graph capture."""

        # Delegate to parent for non-decode modes.
        if not forward_mode.is_decode_or_idle():
            return super().init_forward_metadata_capture_cuda_graph(
                bs,
                num_tokens,
                req_pool_indices,
                seq_lens,
                encoder_lens,
                forward_mode,
                spec_info,
            )

        # Custom fast path for decode/idle.
        # Capture with full width so future longer sequences are safe during replay.
        max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
        block_kv_indices = self.decode_cuda_graph_kv_indices[:bs, :max_blocks_per_seq]

        create_flashmla_kv_indices_triton[(bs,)](
            self.req_to_token,
            req_pool_indices,
            seq_lens,
            None,
            block_kv_indices,
            self.req_to_token.stride(0),
            max_blocks_per_seq,
            NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
            PAGED_SIZE=self.page_size,
        )

        # Record the true maximum sequence length for this capture batch so that
        # the kernel launch path (which requires an int, not a tensor) can reuse
        # it safely during both capture and replay.
        max_seq_len_val = int(seq_lens.max().item())

        metadata = TRTLLMMLADecodeMetadata(
            self.decode_cuda_graph_workspace,
            block_kv_indices,
            max_seq_len_val,
        )
        self.decode_cuda_graph_metadata[bs] = metadata
        self.forward_decode_metadata = metadata

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        """Replay CUDA graph with new inputs."""
        # Delegate to parent for non-decode modes.
        if not forward_mode.is_decode_or_idle():
            return super().init_forward_metadata_replay_cuda_graph(
                bs,
                req_pool_indices,
                seq_lens,
                seq_lens_sum,
                encoder_lens,
                forward_mode,
                spec_info,
                seq_lens_cpu,
            )

        metadata = self.decode_cuda_graph_metadata[bs]

        # Update block indices for new sequences.
        create_flashmla_kv_indices_triton[(bs,)](
            self.req_to_token,
            req_pool_indices[:bs],
            seq_lens[:bs],
            None,
            metadata.block_kv_indices,
            self.req_to_token.stride(0),
            metadata.block_kv_indices.shape[1],
            NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
            PAGED_SIZE=self.page_size,
        )

        # Update the stored max_seq_len so subsequent kernel calls use the correct value.
        # Prefer the CPU tensor, when available, to avoid a GPU synchronization.
        if seq_lens_cpu is not None:
            metadata.max_seq_len = int(seq_lens_cpu.max().item())
        else:
            metadata.max_seq_len = int(seq_lens.max().item())

    def get_cuda_graph_seq_len_fill_value(self) -> int:
        """Get the fill value for sequence lengths in CUDA graph."""
        return 1

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Initialize the metadata for a forward pass."""
        # Normal extend (prefill) and decode/idle take the TRTLLM fast paths below;
        # every other mode is delegated to the parent backend.
        if (
            forward_batch.forward_mode.is_extend()
            and not forward_batch.forward_mode.is_target_verify()
            and not forward_batch.forward_mode.is_draft_extend()
        ):
            seq_lens = forward_batch.seq_lens - forward_batch.extend_prefix_lens
            cum_seq_lens_q = torch.cat(
                (
                    torch.tensor([0], device=forward_batch.seq_lens.device),
                    torch.cumsum(seq_lens, dim=0),
                )
            ).int()
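            # Illustrative example: extend lengths [3, 5, 2] yield
            # cum_seq_lens_q = [0, 3, 8, 10].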
            max_seq_len = max(forward_batch.extend_seq_lens_cpu)
            self.forward_prefill_metadata = TRTLLMMLAPrefillMetadata(
                max_seq_len,
                cum_seq_lens_q,
                seq_lens,
            )
        elif forward_batch.forward_mode.is_decode_or_idle():
            bs = forward_batch.batch_size

            # Get maximum sequence length.
            if getattr(forward_batch, "seq_lens_cpu", None) is not None:
                max_seq = forward_batch.seq_lens_cpu.max().item()
            else:
                max_seq = forward_batch.seq_lens.max().item()

            max_seqlen_pad = self._calc_padded_blocks(max_seq)
            block_kv_indices = self._create_block_kv_indices(
                bs,
                max_seqlen_pad,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens.device,
            )

            max_seq_len_val = int(max_seq)
            self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
                self.workspace_buffer, block_kv_indices, max_seq_len_val
            )
            forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
        else:
            return super().init_forward_metadata(forward_batch)

    def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
        super().init_mha_chunk_metadata(forward_batch, disable_flashinfer_ragged=True)

    def quantize_and_rope_for_fp8(
        self,
        q_nope: torch.Tensor,
        q_rope: torch.Tensor,
        k_nope: torch.Tensor,
        k_rope: torch.Tensor,
        forward_batch: ForwardBatch,
        cos_sin_cache: torch.Tensor,
        is_neox: bool,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Quantize and apply RoPE for the FP8 attention path.

        This function handles the FP8 quantization and RoPE application for MLA attention.
        It takes separate query/key nope and rope components, applies RoPE to the rope parts,
        quantizes all components to FP8, and merges the query components into a single tensor.

        Args:
            q_nope: Query no-position-encoding component [seq_len, num_heads, kv_lora_rank]
                - expected dtype: torch.bfloat16
            q_rope: Query RoPE component [seq_len, num_heads, qk_rope_head_dim]
                - expected dtype: torch.bfloat16
            k_nope: Key no-position-encoding component [seq_len, num_heads, kv_lora_rank]
                - expected dtype: torch.bfloat16
            k_rope: Key RoPE component [seq_len, num_heads, qk_rope_head_dim]
                - expected dtype: torch.bfloat16
            forward_batch: Forward batch containing position information
            cos_sin_cache: Precomputed cosine/sine cache for RoPE
                - expected dtype: matches q_/k_ input dtype (torch.bfloat16)
            is_neox: Whether to use NeoX-style RoPE (interleaved) or GPT-style (half rotation)

        Returns:
            tuple: (merged_q_out, k_nope_out, k_rope_out) quantized to FP8
                - merged_q_out: [seq_len, num_heads, kv_lora_rank + qk_rope_head_dim], dtype=torch.float8_e4m3fn
                - k_nope_out: [seq_len, num_heads, kv_lora_rank], dtype=torch.float8_e4m3fn
                - k_rope_out: [seq_len, num_heads, qk_rope_head_dim], dtype=torch.float8_e4m3fn
        """
        attn_dtype = torch.float8_e4m3fn
        q_len, num_heads = q_rope.shape[0], q_rope.shape[1]

        # Allocate output tensors with FP8 dtype.
        # The query output will contain the merged nope + rope components.
        q_out = q_rope.new_empty(
            q_len,
            num_heads,
            self.kv_lora_rank + self.qk_rope_head_dim,
            dtype=attn_dtype,
        )
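        # Example (illustrative, assuming DeepSeek-V3-style dims kv_lora_rank=512
        # and qk_rope_head_dim=64): each merged query row is 512 + 64 = 576 wide.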

        # Key outputs maintain original shapes but with FP8 dtype
        k_rope_out = k_rope.new_empty(k_rope.shape, dtype=attn_dtype)
        k_nope_out = k_nope.new_empty(k_nope.shape, dtype=attn_dtype)

        # Apply RoPE and quantize all components in a single fused kernel call
        # This kernel handles:
        # 1. RoPE application to q_rope and k_rope using cos_sin_cache and positions
        # 2. Quantization of all components to FP8 format
        # 3. Output placement into pre-allocated tensors
        flashinfer.rope.mla_rope_quantize_fp8(
            q_rope=q_rope,
            k_rope=k_rope,
            q_nope=q_nope,
            k_nope=k_nope,
            cos_sin_cache=cos_sin_cache,
            pos_ids=forward_batch.positions,
            is_neox=is_neox,
            quantize_dtype=attn_dtype,
            # Output tensor slicing: q_out contains [nope_part, rope_part]
            q_rope_out=q_out[..., self.kv_lora_rank :],  # RoPE part goes to end
            k_rope_out=k_rope_out,
            q_nope_out=q_out[..., : self.kv_lora_rank],  # Nope part goes to beginning
            k_nope_out=k_nope_out,
            # Quantization scales (set to 1.0 for no additional scaling)
            quant_scale_q=1.0,
            quant_scale_kv=1.0,
        )

        return q_out, k_nope_out, k_rope_out

    def forward_decode(
        self,
        q: torch.Tensor,  # q_nope
        k: torch.Tensor,  # k_nope
        v: torch.Tensor,  # not used in this backend
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        q_rope: Optional[torch.Tensor] = None,
        k_rope: Optional[torch.Tensor] = None,
        cos_sin_cache: Optional[torch.Tensor] = None,
        is_neox: Optional[bool] = False,
    ) -> torch.Tensor:
        """Run forward for decode using the TRTLLM MLA kernel."""
        merge_query = q_rope is not None
        if self.data_type == torch.float8_e4m3fn:
            # For the FP8 path, quantize the query nope/rope parts and merge them
            # into a single tensor. Note: RoPE application in
            # deepseek_v2.py:forward_absorb_prepare is skipped for the FP8 decode
            # path of this trtllm_mla backend.
            assert all(
                x is not None for x in [q_rope, k_rope, cos_sin_cache]
            ), "For the FP8 path using flashinfer.rope.mla_rope_quantize_fp8, q_rope, k_rope and cos_sin_cache must all be not None."
            q, k, k_rope = self.quantize_and_rope_for_fp8(
                q,
                q_rope,
                k.squeeze(1),
                k_rope.squeeze(1),
                forward_batch,
                cos_sin_cache,
                is_neox,
            )
            merge_query = False

        # Save KV cache if requested
        if save_kv_cache:
            assert (
                k is not None and k_rope is not None
            ), "For populating the trtllm_mla KV cache, both k_nope and k_rope must be not None."
            forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                layer, forward_batch.out_cache_loc, k, k_rope
            )

        # Prepare query tensor inline
        if merge_query:
            # For the FP16 path, merge the query nope and rope parts into a single tensor
            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
            q_rope_reshaped = q_rope.view(
                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
            )
            query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
        else:
            # For the FP8 path, the query and rope parts were already merged by quantize_and_rope_for_fp8
            query = q.view(-1, layer.tp_q_head_num, layer.head_dim)

        # Ensure query has shape [bs, acc_q_len, num_q_heads, head_dim] when seq_len is 1
        if query.dim() == 3:
            query = query.unsqueeze(1)

        # Prepare KV cache inline
        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
        kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1)
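        # The view above exposes the latent KV cache as
        # [num_pages, 1, page_size, kv_lora_rank + qk_rope_head_dim],
        # i.e. a single shared latent KV head per page.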

        # Get metadata
        metadata = (
            getattr(forward_batch, "decode_trtllm_mla_metadata", None)
            or self.forward_decode_metadata
        )

        # Scale computation for the TRTLLM MLA kernel BMM1 operation:
        # The final BMM1 scale is computed as: q_scale * k_scale * softmax_scale
        # Scale components:
        #   - q_scale: query scaling factor (set to 1.0 for both the FP16 and FP8 paths)
        #   - k_scale: key scaling factor from the model checkpoint (defaults to 1.0 if not available)
        #   - softmax_scale: attention softmax scaling = 1/sqrt(head_dim), pre-computed as layer.scaling
        # This unified approach works for both FP16 and FP8 quantized attention paths.
        q_scale = 1.0
        k_scale = (
            layer.k_scale_float
            if getattr(layer, "k_scale_float", None) is not None
            else 1.0
        )

        bmm1_scale = q_scale * k_scale * layer.scaling
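        # Illustrative numbers (assuming k_scale == 1.0 and a 192-dim MLA query
        # head, i.e. 128 nope + 64 rope, with no extra mscale adjustment):
        # layer.scaling ~= 1 / sqrt(192) ~= 0.0722, so bmm1_scale ~= 0.0722.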

        # Call TRT-LLM kernel
        raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
            query=query,
            kv_cache=kv_cache,
            workspace_buffer=metadata.workspace,
            qk_nope_head_dim=self.qk_nope_head_dim,
            kv_lora_rank=self.kv_lora_rank,
            qk_rope_head_dim=self.qk_rope_head_dim,
            block_tables=metadata.block_kv_indices,
            seq_lens=forward_batch.seq_lens.to(torch.int32),
            max_seq_len=metadata.max_seq_len,
            bmm1_scale=bmm1_scale,
        )

        # Reshape output directly without slicing
        output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
        return output

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        q_rope: Optional[torch.Tensor] = None,
        k_rope: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if (
            forward_batch.forward_mode.is_target_verify()
            or forward_batch.forward_mode.is_draft_extend()
        ):
            return super().forward_extend(
                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
            )

        if not forward_batch.attn_attend_prefix_cache:
            q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
            k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
            v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
            output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
                query=q,
                key=k,
                value=v,
                workspace_buffer=self.workspace_buffer,
                seq_lens=self.forward_prefill_metadata.seq_lens,
                max_q_len=self.forward_prefill_metadata.max_seq_len,
                max_kv_len=self.forward_prefill_metadata.max_seq_len,
                bmm1_scale=layer.scaling,
                bmm2_scale=1.0,
                o_sf_scale=1.0,
                batch_size=forward_batch.batch_size,
                window_left=-1,
                cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
                cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
                enable_pdl=False,
                is_causal=True,
                return_lse=forward_batch.mha_return_lse,
            )
        else:
            # TODO: replace with trtllm ragged attention once accuracy is resolved.
            output = super().forward_extend(
                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
            )
        return output


class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
    """Multi-step draft backend for TRT-LLM MLA used by EAGLE."""

    def __init__(
        self, model_runner: "ModelRunner", topk: int, speculative_num_steps: int
    ):
        super().__init__(model_runner, topk, speculative_num_steps)

        for i in range(self.speculative_num_steps):
            self.attn_backends[i] = TRTLLMMLABackend(
                model_runner,
                skip_prefill=True,
                kv_indptr_buf=self.kv_indptr[i],
                q_indptr_decode_buf=self.q_indptr_decode,
            )
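
# Note: this backend is wired up by the model runner's attention-backend selection
# (assumed CLI option: --attention-backend trtllm_mla); TRTLLMMLAMultiStepDraftBackend
# is only constructed when EAGLE-style speculative decoding is enabled.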