Use trtllm_mla decode kernel for draft extend in speculative decoding (#11664)
This commit is contained in:
@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.attention.flashinfer_mla_backend import (
|
||||
FlashInferMLAAttnBackend,
|
||||
@@ -48,6 +49,151 @@ DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB
|
||||
# compute the LCM with other padding constraints.
|
||||
TRTLLM_BLOCK_CONSTRAINT = 128
|
||||
|
||||
|
||||
@triton.jit
def pad_draft_extend_query_kernel(
    q_ptr,  # Input query tensor [total_seq_len, num_heads, head_dim]
    padded_q_ptr,  # Output padded query tensor [batch_size, max_seq_len, num_heads, head_dim]
    seq_lens_q_ptr,  # Sequence lengths for each sequence [batch_size]
    cumsum_ptr,  # Cumulative sum of accept lengths [batch_size + 1]
    batch_size,
    max_seq_len,
    num_heads,
    head_dim,
    BLOCK_SIZE: tl.constexpr,
):
    """Triton kernel for padding draft extended query tensor with parallelized head and dim processing.

    Each program instance handles one (batch, seq_pos) token and copies its
    [num_heads, head_dim] slice from the flattened input into the padded
    per-batch layout, tiled over heads (grid axis 1) and head_dim (grid
    axis 2) in BLOCK_SIZE chunks. Positions at or beyond a sequence's real
    length return early and are never written, so zero padding is only
    guaranteed if the caller passes a zero-initialized `padded_q`.
    Offsets assume both tensors are contiguous in their stated layouts —
    TODO confirm at call sites.
    """
    # Use 3D program IDs: (batch_seq, head_block, dim_block)
    batch_seq_pid = tl.program_id(0)
    head_pid = tl.program_id(1)
    dim_pid = tl.program_id(2)

    # Decompose the fused (batch, position) program id.
    batch_id = batch_seq_pid // max_seq_len
    seq_pos = batch_seq_pid % max_seq_len

    # Guard against grids over-provisioned beyond batch_size * max_seq_len.
    if batch_id >= batch_size:
        return

    # Load accept length for this batch
    seq_len = tl.load(seq_lens_q_ptr + batch_id)

    # Skip padded positions beyond this sequence's real length (no store).
    if seq_pos >= seq_len:
        return

    # Load cumulative sum to get start position in input tensor
    input_start = tl.load(cumsum_ptr + batch_id)
    input_pos = input_start + seq_pos

    # Calculate head and dim block ranges; masks keep partial tiles in bounds.
    head_start = head_pid * BLOCK_SIZE
    head_end = tl.minimum(head_start + BLOCK_SIZE, num_heads)
    head_mask = tl.arange(0, BLOCK_SIZE) < (head_end - head_start)

    dim_start = dim_pid * BLOCK_SIZE
    dim_end = tl.minimum(dim_start + BLOCK_SIZE, head_dim)
    dim_mask = tl.arange(0, BLOCK_SIZE) < (dim_end - dim_start)

    # Calculate input offset for a [BLOCK_SIZE, BLOCK_SIZE] (head, dim) tile
    # within the flattened [total_seq_len, num_heads, head_dim] input.
    input_offset = (
        input_pos * num_heads * head_dim
        + (head_start + tl.arange(0, BLOCK_SIZE))[:, None] * head_dim
        + (dim_start + tl.arange(0, BLOCK_SIZE))[None, :]
    )

    # Load data (masked lanes read as 0.0 but are also masked on the store)
    data = tl.load(
        q_ptr + input_offset,
        mask=head_mask[:, None] & dim_mask[None, :],
        other=0.0,
    )

    # Calculate output offset into [batch_size, max_seq_len, num_heads, head_dim]
    output_offset = (
        batch_id * max_seq_len * num_heads * head_dim
        + seq_pos * num_heads * head_dim
        + (head_start + tl.arange(0, BLOCK_SIZE))[:, None] * head_dim
        + (dim_start + tl.arange(0, BLOCK_SIZE))[None, :]
    )

    # Store data
    tl.store(
        padded_q_ptr + output_offset,
        data,
        mask=head_mask[:, None] & dim_mask[None, :],
    )
|
||||
|
||||
|
||||
@triton.jit
def unpad_draft_extend_output_kernel(
    raw_out_ptr,  # Input raw output tensor (batch_size, token_per_batch, tp_q_head_num, v_head_dim)
    output_ptr,  # Output tensor (-1, tp_q_head_num, v_head_dim)
    accept_length_ptr,  # Accept lengths for each sequence [batch_size]
    cumsum_ptr,  # Cumulative sum of accept lengths [batch_size + 1]
    batch_size,
    token_per_batch,
    tp_q_head_num,
    v_head_dim,
    BLOCK_SIZE: tl.constexpr,
):
    """Triton kernel for unpadding draft extended output tensor with parallelized head and dim processing.

    Inverse of `pad_draft_extend_query_kernel`: each program instance copies
    one accepted (batch, seq_pos) token from the padded per-batch layout back
    into a flattened [sum(accept_lengths), tp_q_head_num, v_head_dim] tensor,
    tiled over heads (grid axis 1) and dims (grid axis 2). Tokens at or
    beyond a batch's accept length are skipped entirely. Offsets assume
    contiguous tensors in the stated layouts — TODO confirm at call sites.
    """
    # 3D program IDs: (batch_seq, head_block, dim_block)
    batch_seq_pid = tl.program_id(0)
    head_pid = tl.program_id(1)
    dim_pid = tl.program_id(2)

    # Decompose the fused (batch, position) program id.
    batch_id = batch_seq_pid // token_per_batch
    seq_pos = batch_seq_pid % token_per_batch

    # Guard against over-provisioned grids.
    if batch_id >= batch_size:
        return

    # Load accept length for this batch
    accept_len = tl.load(accept_length_ptr + batch_id)

    # Only accepted token positions are copied out.
    if seq_pos >= accept_len:
        return

    # Load cumulative sum to get start position in output tensor
    output_start = tl.load(cumsum_ptr + batch_id)
    output_pos = output_start + seq_pos

    # Calculate head and dim block ranges; masks keep partial tiles in bounds.
    head_start = head_pid * BLOCK_SIZE
    head_end = tl.minimum(head_start + BLOCK_SIZE, tp_q_head_num)
    head_mask = tl.arange(0, BLOCK_SIZE) < (head_end - head_start)

    dim_start = dim_pid * BLOCK_SIZE
    dim_end = tl.minimum(dim_start + BLOCK_SIZE, v_head_dim)
    dim_mask = tl.arange(0, BLOCK_SIZE) < (dim_end - dim_start)

    # Calculate input offset: (batch_id, seq_pos, head_id, dim_id)
    input_offset = (
        batch_id * token_per_batch * tp_q_head_num * v_head_dim
        + seq_pos * tp_q_head_num * v_head_dim
        + (head_start + tl.arange(0, BLOCK_SIZE))[:, None] * v_head_dim
        + (dim_start + tl.arange(0, BLOCK_SIZE))[None, :]
    )

    # Load data (masked lanes read 0.0 but are also masked on the store)
    data = tl.load(
        raw_out_ptr + input_offset,
        mask=head_mask[:, None] & dim_mask[None, :],
        other=0.0,
    )

    # Output offset into the flattened [total_tokens, heads, dim] tensor.
    output_offset = (
        output_pos * tp_q_head_num * v_head_dim
        + (head_start + tl.arange(0, BLOCK_SIZE))[:, None] * v_head_dim
        + (dim_start + tl.arange(0, BLOCK_SIZE))[None, :]
    )

    # Store data
    tl.store(
        output_ptr + output_offset,
        data,
        mask=head_mask[:, None] & dim_mask[None, :],
    )
|
||||
|
||||
|
||||
global_zero_init_workspace_buffer = None
|
||||
|
||||
|
||||
@@ -65,7 +211,11 @@ class TRTLLMMLADecodeMetadata:
|
||||
"""Metadata for TRTLLM MLA decode operations."""
|
||||
|
||||
block_kv_indices: Optional[torch.Tensor] = None
|
||||
max_seq_len: Optional[int] = None
|
||||
max_seq_len_k: Optional[int] = None
|
||||
max_seq_len_q: Optional[int] = None
|
||||
sum_seq_lens_q: Optional[int] = None
|
||||
cu_seqlens_q: Optional[torch.Tensor] = None
|
||||
seq_lens_q: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
@@ -120,6 +270,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
# CUDA graph state
|
||||
self.decode_cuda_graph_metadata = {}
|
||||
self.decode_cuda_graph_kv_indices = None
|
||||
self.padded_q_buffer = None
|
||||
self.unpad_output_buffer = None
|
||||
self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None
|
||||
self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
|
||||
|
||||
@@ -203,6 +355,21 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
self.decode_cuda_graph_kv_indices = torch.full(
|
||||
(max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
|
||||
)
|
||||
num_tokens_per_bs = max_num_tokens // max_bs
|
||||
|
||||
# Buffer for padded query: (max_bs, max_draft_tokens, num_q_heads, v_head_dim)
|
||||
self.padded_q_buffer = torch.zeros(
|
||||
(max_bs, num_tokens_per_bs, self.num_q_heads, self.kv_cache_dim),
|
||||
dtype=self.data_type,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
# Buffer for unpadded output: (max_num_tokens, num_q_heads, v_head_dim)
|
||||
self.unpad_output_buffer = torch.zeros(
|
||||
(max_num_tokens, self.num_q_heads, 512),
|
||||
dtype=self.data_type,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)
|
||||
|
||||
@@ -219,7 +386,11 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
"""Initialize metadata for CUDA graph capture."""
|
||||
|
||||
# Delegate to parent for non-decode modes.
|
||||
if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
|
||||
if (
|
||||
not forward_mode.is_decode_or_idle()
|
||||
and not forward_mode.is_target_verify()
|
||||
and not forward_mode.is_draft_extend()
|
||||
):
|
||||
return super().init_forward_metadata_capture_cuda_graph(
|
||||
bs,
|
||||
num_tokens,
|
||||
@@ -259,6 +430,20 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
block_kv_indices,
|
||||
max_seq_len_val,
|
||||
)
|
||||
if forward_mode.is_draft_extend():
|
||||
num_tokens_per_bs = num_tokens // bs
|
||||
metadata.max_seq_len_q = num_tokens_per_bs + 1
|
||||
metadata.sum_seq_lens_q = num_tokens_per_bs * bs
|
||||
metadata.cu_seqlens_q = torch.arange(
|
||||
0,
|
||||
bs * num_tokens_per_bs + 1,
|
||||
num_tokens_per_bs,
|
||||
dtype=torch.int32,
|
||||
device=seq_lens.device,
|
||||
)
|
||||
metadata.seq_lens_q = torch.full(
|
||||
(bs,), num_tokens_per_bs, dtype=torch.int32, device=seq_lens.device
|
||||
)
|
||||
self.decode_cuda_graph_metadata[bs] = metadata
|
||||
self.forward_decode_metadata = metadata
|
||||
|
||||
@@ -275,7 +460,11 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
):
|
||||
"""Replay CUDA graph with new inputs."""
|
||||
# Delegate to parent for non-decode modes.
|
||||
if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
|
||||
if (
|
||||
not forward_mode.is_decode_or_idle()
|
||||
and not forward_mode.is_target_verify()
|
||||
and not forward_mode.is_draft_extend()
|
||||
):
|
||||
return super().init_forward_metadata_replay_cuda_graph(
|
||||
bs,
|
||||
req_pool_indices,
|
||||
@@ -293,6 +482,19 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
|
||||
metadata = self.decode_cuda_graph_metadata[bs]
|
||||
|
||||
if forward_mode.is_draft_extend():
|
||||
accept_length = spec_info.accept_length[:bs]
|
||||
if spec_info.accept_length_cpu:
|
||||
metadata.max_seq_len_q = max(spec_info.accept_length_cpu[:bs])
|
||||
metadata.sum_seq_lens_q = sum(spec_info.accept_length_cpu[:bs])
|
||||
else:
|
||||
metadata.max_seq_len_q = 1
|
||||
metadata.sum_seq_lens_q = bs
|
||||
metadata.cu_seqlens_q[1:].copy_(
|
||||
torch.cumsum(accept_length, dim=0, dtype=torch.int32)
|
||||
)
|
||||
metadata.seq_lens_q.copy_(accept_length)
|
||||
|
||||
# Update block indices for new sequences.
|
||||
create_flashmla_kv_indices_triton[(bs,)](
|
||||
self.req_to_token,
|
||||
@@ -344,6 +546,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
elif (
|
||||
forward_batch.forward_mode.is_decode_or_idle()
|
||||
or forward_batch.forward_mode.is_target_verify()
|
||||
or forward_batch.forward_mode.is_draft_extend()
|
||||
):
|
||||
bs = forward_batch.batch_size
|
||||
|
||||
@@ -372,6 +575,23 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
|
||||
block_kv_indices, max_seq_len_val
|
||||
)
|
||||
if forward_batch.forward_mode.is_draft_extend():
|
||||
max_seq = forward_batch.seq_lens_cpu.max().item()
|
||||
|
||||
sum_seq_lens_q = sum(forward_batch.extend_seq_lens_cpu)
|
||||
max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
|
||||
cu_seqlens_q = torch.nn.functional.pad(
|
||||
torch.cumsum(
|
||||
forward_batch.extend_seq_lens, dim=0, dtype=torch.int32
|
||||
),
|
||||
(1, 0),
|
||||
)
|
||||
|
||||
self.forward_decode_metadata.max_seq_len_q = max_seq_len_q
|
||||
self.forward_decode_metadata.sum_seq_lens_q = sum_seq_lens_q
|
||||
self.forward_decode_metadata.cu_seqlens_q = cu_seqlens_q
|
||||
self.forward_decode_metadata.seq_lens_q = forward_batch.extend_seq_lens
|
||||
|
||||
forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
|
||||
else:
|
||||
return super().init_forward_metadata(forward_batch)
|
||||
@@ -457,6 +677,86 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
|
||||
return q_out, k_nope_out, k_rope_out
|
||||
|
||||
def pad_draft_extend_query(
    self,
    q: torch.Tensor,
    padded_q: torch.Tensor,
    seq_lens_q: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
) -> torch.Tensor:
    """Scatter the flattened query `q` into the padded per-batch layout.

    Args:
        q: Flattened query, [sum(seq_lens_q), num_heads, head_dim].
        padded_q: Destination buffer, [batch, max_seq_len, num_heads, head_dim];
            positions past each sequence's length are left untouched.
        seq_lens_q: Per-sequence query lengths, [batch].
        cu_seqlens_q: Exclusive prefix sums of `seq_lens_q`, [batch + 1].

    Returns:
        `padded_q`, filled in place by the Triton kernel.
    """
    num_seqs = cu_seqlens_q.shape[0] - 1
    _, pad_len, n_heads, hdim = padded_q.shape

    # One program per (batch, position); heads/dims tiled in BLOCK_SIZE chunks.
    BLOCK_SIZE = 64
    launch_grid = (
        num_seqs * pad_len,
        triton.cdiv(n_heads, BLOCK_SIZE),
        triton.cdiv(hdim, BLOCK_SIZE),
    )

    pad_draft_extend_query_kernel[launch_grid](
        q_ptr=q,
        padded_q_ptr=padded_q,
        seq_lens_q_ptr=seq_lens_q,
        cumsum_ptr=cu_seqlens_q,
        batch_size=num_seqs,
        max_seq_len=pad_len,
        num_heads=n_heads,
        head_dim=hdim,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return padded_q
|
||||
|
||||
def unpad_draft_extend_output(
    self,
    raw_out: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    seq_lens_q: torch.Tensor,
    sum_seq_lens_q: int,
) -> torch.Tensor:
    """Gather accepted tokens from the padded attention output.

    Args:
        raw_out: Padded output, [batch, token_per_batch, tp_q_head_num, v_head_dim].
        cu_seqlens_q: Exclusive prefix sums of `seq_lens_q`, [batch + 1].
        seq_lens_q: Per-sequence accepted-token counts, [batch].
        sum_seq_lens_q: Total number of accepted tokens across the batch.

    Returns:
        Flattened output of shape [sum_seq_lens_q, tp_q_head_num, v_head_dim].
    """
    num_seqs = seq_lens_q.shape[0]
    pad_len = raw_out.shape[1]  # max_seq_len
    n_heads = raw_out.shape[2]  # num_heads
    vdim = raw_out.shape[3]  # head_dim
    n_tokens = sum_seq_lens_q

    if self.unpad_output_buffer is not None:
        # CUDA-graph path: reuse the pre-allocated buffer (cast to match input).
        out = self.unpad_output_buffer[:n_tokens, :, :].to(dtype=raw_out.dtype)
    else:
        # Eager path: allocate a fresh destination tensor.
        out = torch.empty(
            (n_tokens, n_heads, vdim),
            dtype=raw_out.dtype,
            device=raw_out.device,
        )

    # One program per (batch, position); heads/dims tiled in BLOCK_SIZE chunks.
    BLOCK_SIZE = 64
    launch_grid = (
        num_seqs * pad_len,
        triton.cdiv(n_heads, BLOCK_SIZE),
        triton.cdiv(vdim, BLOCK_SIZE),
    )

    unpad_draft_extend_output_kernel[launch_grid](
        raw_out_ptr=raw_out,
        output_ptr=out,
        accept_length_ptr=seq_lens_q,
        cumsum_ptr=cu_seqlens_q,
        batch_size=num_seqs,
        token_per_batch=pad_len,
        tp_q_head_num=n_heads,
        v_head_dim=vdim,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return out[:n_tokens, :, :]
|
||||
|
||||
def forward_decode(
|
||||
self,
|
||||
q: torch.Tensor, # q_nope
|
||||
@@ -550,7 +850,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
block_tables=metadata.block_kv_indices,
|
||||
seq_lens=forward_batch.seq_lens.to(torch.int32),
|
||||
max_seq_len=metadata.max_seq_len,
|
||||
max_seq_len=metadata.max_seq_len_k,
|
||||
bmm1_scale=bmm1_scale,
|
||||
)
|
||||
|
||||
@@ -571,11 +871,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
cos_sin_cache: Optional[torch.Tensor] = None,
|
||||
is_neox: Optional[bool] = False,
|
||||
) -> torch.Tensor:
|
||||
if forward_batch.forward_mode.is_draft_extend():
|
||||
return super().forward_extend(
|
||||
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
|
||||
)
|
||||
|
||||
# TODO refactor to avoid code duplication
|
||||
merge_query = q_rope is not None
|
||||
if (
|
||||
@@ -627,7 +922,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
|
||||
v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
|
||||
|
||||
if forward_batch.forward_mode.is_target_verify():
|
||||
if (
|
||||
forward_batch.forward_mode.is_target_verify()
|
||||
or forward_batch.forward_mode.is_draft_extend()
|
||||
):
|
||||
metadata = (
|
||||
getattr(forward_batch, "decode_trtllm_mla_metadata", None)
|
||||
or self.forward_decode_metadata
|
||||
@@ -635,7 +933,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
|
||||
# Ensure query has shape [bs, num_draft_tokens, num_q_heads, head_dim]
|
||||
bs = forward_batch.batch_size
|
||||
q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
|
||||
|
||||
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||
kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1)
|
||||
@@ -646,17 +943,42 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
if getattr(layer, "k_scale_float", None) is not None
|
||||
else 1.0
|
||||
)
|
||||
q = q.to(self.data_type)
|
||||
|
||||
bmm1_scale = q_scale * k_scale * layer.scaling
|
||||
|
||||
seq_lens = (
|
||||
forward_batch.seq_lens.to(torch.int32)
|
||||
+ forward_batch.spec_info.draft_token_num
|
||||
)
|
||||
max_seq_len = metadata.max_seq_len + forward_batch.spec_info.draft_token_num
|
||||
if forward_batch.forward_mode.is_target_verify():
|
||||
seq_lens = (
|
||||
forward_batch.seq_lens.to(torch.int32)
|
||||
+ forward_batch.spec_info.draft_token_num
|
||||
)
|
||||
max_seq_len = (
|
||||
metadata.max_seq_len_k + forward_batch.spec_info.draft_token_num
|
||||
)
|
||||
else:
|
||||
seq_lens = forward_batch.seq_lens.to(torch.int32)
|
||||
max_seq_len = metadata.max_seq_len_k
|
||||
# Check if we're in CUDA graph mode (buffers are pre-allocated)
|
||||
if self.padded_q_buffer is not None:
|
||||
# Use pre-allocated buffer for CUDA graph compatibility
|
||||
padded_q = self.padded_q_buffer[
|
||||
:bs, : metadata.max_seq_len_q, :, :
|
||||
].to(dtype=q.dtype)
|
||||
else:
|
||||
# Dynamic allocation for non-CUDA graph mode
|
||||
padded_q = torch.zeros(
|
||||
bs,
|
||||
metadata.max_seq_len_q,
|
||||
layer.tp_q_head_num,
|
||||
layer.head_dim,
|
||||
dtype=q.dtype,
|
||||
device=q.device,
|
||||
)
|
||||
q = self.pad_draft_extend_query(
|
||||
q, padded_q, metadata.seq_lens_q, metadata.cu_seqlens_q
|
||||
)
|
||||
|
||||
# TODO may use `mla_rope_quantize_fp8` fusion
|
||||
q = q.to(self.data_type)
|
||||
q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
|
||||
assert kv_cache.dtype == self.data_type
|
||||
|
||||
raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
|
||||
@@ -673,6 +995,14 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
)
|
||||
|
||||
# Reshape output directly without slicing
|
||||
|
||||
if forward_batch.forward_mode.is_draft_extend():
|
||||
raw_out = self.unpad_draft_extend_output(
|
||||
raw_out,
|
||||
metadata.cu_seqlens_q,
|
||||
metadata.seq_lens_q,
|
||||
metadata.sum_seq_lens_q,
|
||||
)
|
||||
output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||
return output
|
||||
|
||||
|
||||
@@ -1263,6 +1263,178 @@ class TestTRTLLMMLA(CustomTestCase):
|
||||
f"Max diff: {(out_trtllm - out_reference).abs().max().item()}",
|
||||
)
|
||||
|
||||
def test_draft_extend_padding_unpadding_kernels(self):
    """Test TRTLLM MLA Triton kernels: pad_draft_extend_query_kernel and unpad_draft_extend_output_kernel.

    Launches each kernel on randomly-sized CUDA test data and checks the
    copied/skipped regions element-by-element against a Python reference loop.
    NOTE(review): both launches use a 1-D grid, so head/dim program ids are
    always 0; this only covers the full tensor because num_heads and head_dim
    here are <= BLOCK_SIZE (64). Production code launches a 3-D grid.
    """

    # Import the kernels
    from sglang.srt.layers.attention.trtllm_mla_backend import (
        pad_draft_extend_query_kernel,
        unpad_draft_extend_output_kernel,
    )

    # NOTE(review): these helpers are nested defs taking an explicit `self`
    # parameter and are invoked as `_create_test_data(self, ...)` below —
    # unusual style, but functionally equivalent to plain local functions.
    def _create_test_data(
        self, batch_size, max_seq_len, num_heads, head_dim, dtype=torch.float32
    ):
        """Create test data for kernel testing."""
        device = torch.device("cuda")

        # Create sequence lengths (varying lengths for each batch)
        seq_lens = torch.randint(
            1, max_seq_len + 1, (batch_size,), device=device, dtype=torch.int32
        )

        # Create cumulative sequence lengths (exclusive prefix sum, length bs+1)
        cum_seq_lens = torch.zeros(batch_size + 1, device=device, dtype=torch.int32)
        cum_seq_lens[1:] = torch.cumsum(seq_lens, dim=0)

        # Create input query tensor (flattened format)
        total_tokens = cum_seq_lens[-1].item()
        q_input = torch.randn(
            total_tokens, num_heads, head_dim, device=device, dtype=dtype
        )

        # Create padded query tensor (batch format); zero-init so that the
        # kernel's untouched positions can be asserted to be zero later.
        padded_q = torch.zeros(
            batch_size, max_seq_len, num_heads, head_dim, device=device, dtype=dtype
        )

        return q_input, padded_q, seq_lens, cum_seq_lens

    def _create_test_output_data(
        self,
        batch_size,
        token_per_batch,
        tp_q_head_num,
        v_head_dim,
        dtype=torch.float32,
    ):
        """Create test data for unpad kernel testing."""
        device = torch.device("cuda")

        # Create accept lengths (varying lengths for each batch)
        accept_lengths = torch.randint(
            1, token_per_batch + 1, (batch_size,), device=device, dtype=torch.int32
        )

        # Create cumulative accept lengths (exclusive prefix sum, length bs+1)
        cum_accept_lengths = torch.zeros(
            batch_size + 1, device=device, dtype=torch.int32
        )
        cum_accept_lengths[1:] = torch.cumsum(accept_lengths, dim=0)

        # Create raw output tensor (batch format)
        raw_out = torch.randn(
            batch_size,
            token_per_batch,
            tp_q_head_num,
            v_head_dim,
            device=device,
            dtype=dtype,
        )

        # Create output tensor (flattened format); contents are garbage until
        # the kernel writes the accepted positions.
        total_tokens = cum_accept_lengths[-1].item()
        output = torch.empty(
            total_tokens, tp_q_head_num, v_head_dim, device=device, dtype=dtype
        )

        return raw_out, output, accept_lengths, cum_accept_lengths

    # Test 1: pad_draft_extend_query_kernel basic functionality
    with self.subTest(test="pad_kernel_basic"):
        batch_size = 4
        max_seq_len = 8
        num_heads = 16
        head_dim = 64

        q_input, padded_q, seq_lens, cum_seq_lens = _create_test_data(
            self, batch_size, max_seq_len, num_heads, head_dim
        )

        # Launch kernel (1-D grid; see NOTE in docstring)
        BLOCK_SIZE = 64
        grid = (batch_size * max_seq_len,)

        pad_draft_extend_query_kernel[grid](
            q_ptr=q_input,
            padded_q_ptr=padded_q,
            seq_lens_q_ptr=seq_lens,
            cumsum_ptr=cum_seq_lens,
            batch_size=batch_size,
            max_seq_len=max_seq_len,
            num_heads=num_heads,
            head_dim=head_dim,
            BLOCK_SIZE=BLOCK_SIZE,
        )

        # Verify the padding worked correctly
        for i in range(batch_size):
            seq_len = seq_lens[i].item()

            # Check that valid positions are copied correctly
            for pos in range(seq_len):
                input_start = cum_seq_lens[i].item()
                input_pos = input_start + pos

                # Compare input and output for valid positions
                input_data = q_input[input_pos]
                output_data = padded_q[i, pos]

                torch.testing.assert_close(
                    input_data, output_data, rtol=1e-5, atol=1e-6
                )

            # Check that invalid positions are zero (kernel never wrote them
            # and the buffer was zero-initialized)
            for pos in range(seq_len, max_seq_len):
                output_data = padded_q[i, pos]
                self.assertTrue(
                    torch.allclose(output_data, torch.zeros_like(output_data)),
                    f"Position {pos} in batch {i} should be zero",
                )

    # Test 2: unpad_draft_extend_output_kernel basic functionality
    with self.subTest(test="unpad_kernel_basic"):
        batch_size = 4
        token_per_batch = 8
        tp_q_head_num = 16
        v_head_dim = 64

        raw_out, output, accept_lengths, cum_accept_lengths = (
            _create_test_output_data(
                self, batch_size, token_per_batch, tp_q_head_num, v_head_dim
            )
        )

        # Launch kernel (1-D grid; see NOTE in docstring)
        BLOCK_SIZE = 64
        grid = (batch_size * token_per_batch,)

        unpad_draft_extend_output_kernel[grid](
            raw_out_ptr=raw_out,
            output_ptr=output,
            accept_length_ptr=accept_lengths,
            cumsum_ptr=cum_accept_lengths,
            batch_size=batch_size,
            token_per_batch=token_per_batch,
            tp_q_head_num=tp_q_head_num,
            v_head_dim=v_head_dim,
            BLOCK_SIZE=BLOCK_SIZE,
        )

        # Verify the unpadding worked correctly
        for i in range(batch_size):
            accept_len = accept_lengths[i].item()
            output_start = cum_accept_lengths[i].item()

            # Check that valid positions are copied correctly
            for pos in range(accept_len):
                input_data = raw_out[i, pos]
                output_data = output[output_start + pos]

                torch.testing.assert_close(
                    input_data, output_data, rtol=1e-5, atol=1e-6
                )
|
||||
|
||||
|
||||
# Allow running this test module directly (in addition to via a test runner).
if __name__ == "__main__":
    unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user