From d9a20fd28ae67c0b7666ee40dc86f73239a95c3d Mon Sep 17 00:00:00 2001 From: Qiaolin Yu Date: Mon, 20 Oct 2025 20:42:09 -0700 Subject: [PATCH] Use trtllm_mla decode kernel for draft extend in speculative decoding (#11664) --- .../layers/attention/trtllm_mla_backend.py | 366 +++++++++++++++++- .../test/attention/test_trtllm_mla_backend.py | 172 ++++++++ 2 files changed, 520 insertions(+), 18 deletions(-) diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index 4943eed90..12eb04b09 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Optional, Union import torch import triton +import triton.language as tl from sglang.srt.layers.attention.flashinfer_mla_backend import ( FlashInferMLAAttnBackend, @@ -48,6 +49,151 @@ DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB # compute the LCM with other padding constraints. TRTLLM_BLOCK_CONSTRAINT = 128 + +@triton.jit +def pad_draft_extend_query_kernel( + q_ptr, # Input query tensor [total_seq_len, num_heads, head_dim] + padded_q_ptr, # Output padded query tensor [batch_size, max_seq_len, num_heads, head_dim] + seq_lens_q_ptr, # Sequence lengths for each sequence [batch_size] + cumsum_ptr, # Cumulative sum of accept lengths [batch_size + 1] + batch_size, + max_seq_len, + num_heads, + head_dim, + BLOCK_SIZE: tl.constexpr, +): + """Triton kernel for padding draft extended query tensor with parallelized head and dim processing.""" + # Use 3D program IDs: (batch_seq, head_block, dim_block) + batch_seq_pid = tl.program_id(0) + head_pid = tl.program_id(1) + dim_pid = tl.program_id(2) + + batch_id = batch_seq_pid // max_seq_len + seq_pos = batch_seq_pid % max_seq_len + + if batch_id >= batch_size: + return + + # Load accept length for this batch + seq_len = tl.load(seq_lens_q_ptr + batch_id) + + if seq_pos >= seq_len: + return + + # Load cumulative sum to get start position in input tensor + input_start = tl.load(cumsum_ptr + batch_id) + input_pos = input_start + seq_pos + + # Calculate head and dim block ranges + head_start = head_pid * BLOCK_SIZE + head_end = tl.minimum(head_start + BLOCK_SIZE, num_heads) + head_mask = tl.arange(0, BLOCK_SIZE) < (head_end - head_start) + + dim_start = dim_pid * BLOCK_SIZE + dim_end = tl.minimum(dim_start + BLOCK_SIZE, head_dim) + dim_mask = tl.arange(0, BLOCK_SIZE) < (dim_end - dim_start) + + # Calculate input offset + input_offset = ( + input_pos * num_heads * head_dim + + (head_start + tl.arange(0, BLOCK_SIZE))[:, None] * head_dim + + (dim_start + tl.arange(0, BLOCK_SIZE))[None, :] + ) + + # Load data + data = tl.load( + q_ptr + input_offset, + mask=head_mask[:, None] & dim_mask[None, :], + other=0.0, + ) + + # Calculate output offset + output_offset = ( + batch_id * max_seq_len * num_heads * head_dim + + seq_pos * num_heads * head_dim + + (head_start + tl.arange(0, BLOCK_SIZE))[:, None] * head_dim + + (dim_start + tl.arange(0, BLOCK_SIZE))[None, :] + ) + + # Store data + tl.store( + padded_q_ptr + output_offset, + data, + mask=head_mask[:, None] & dim_mask[None, :], + ) + + +@triton.jit +def unpad_draft_extend_output_kernel( + raw_out_ptr, # Input raw output tensor (batch_size, token_per_batch, tp_q_head_num, v_head_dim) + output_ptr, # Output tensor (-1, tp_q_head_num, v_head_dim) + accept_length_ptr, # Accept lengths for each sequence [batch_size] + cumsum_ptr, # Cumulative sum of accept lengths 
[batch_size + 1] + batch_size, + token_per_batch, + tp_q_head_num, + v_head_dim, + BLOCK_SIZE: tl.constexpr, +): + """Triton kernel for unpadding draft extended output tensor with parallelized head and dim processing.""" + batch_seq_pid = tl.program_id(0) + head_pid = tl.program_id(1) + dim_pid = tl.program_id(2) + + batch_id = batch_seq_pid // token_per_batch + seq_pos = batch_seq_pid % token_per_batch + + if batch_id >= batch_size: + return + + # Load accept length for this batch + accept_len = tl.load(accept_length_ptr + batch_id) + + if seq_pos >= accept_len: + return + + # Load cumulative sum to get start position in output tensor + output_start = tl.load(cumsum_ptr + batch_id) + output_pos = output_start + seq_pos + + # Calculate head and dim block ranges + head_start = head_pid * BLOCK_SIZE + head_end = tl.minimum(head_start + BLOCK_SIZE, tp_q_head_num) + head_mask = tl.arange(0, BLOCK_SIZE) < (head_end - head_start) + + dim_start = dim_pid * BLOCK_SIZE + dim_end = tl.minimum(dim_start + BLOCK_SIZE, v_head_dim) + dim_mask = tl.arange(0, BLOCK_SIZE) < (dim_end - dim_start) + + # Calculate input offset: (batch_id, seq_pos, head_id, dim_id) + input_offset = ( + batch_id * token_per_batch * tp_q_head_num * v_head_dim + + seq_pos * tp_q_head_num * v_head_dim + + (head_start + tl.arange(0, BLOCK_SIZE))[:, None] * v_head_dim + + (dim_start + tl.arange(0, BLOCK_SIZE))[None, :] + ) + + # Load data + data = tl.load( + raw_out_ptr + input_offset, + mask=head_mask[:, None] & dim_mask[None, :], + other=0.0, + ) + + output_offset = ( + output_pos * tp_q_head_num * v_head_dim + + (head_start + tl.arange(0, BLOCK_SIZE))[:, None] * v_head_dim + + (dim_start + tl.arange(0, BLOCK_SIZE))[None, :] + ) + + # Store data + tl.store( + output_ptr + output_offset, + data, + mask=head_mask[:, None] & dim_mask[None, :], + ) + + global_zero_init_workspace_buffer = None @@ -65,7 +211,11 @@ class TRTLLMMLADecodeMetadata: """Metadata for TRTLLM MLA decode operations.""" block_kv_indices: Optional[torch.Tensor] = None - max_seq_len: Optional[int] = None + max_seq_len_k: Optional[int] = None + max_seq_len_q: Optional[int] = None + sum_seq_lens_q: Optional[int] = None + cu_seqlens_q: Optional[torch.Tensor] = None + seq_lens_q: Optional[torch.Tensor] = None class TRTLLMMLABackend(FlashInferMLAAttnBackend): @@ -120,6 +270,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend): # CUDA graph state self.decode_cuda_graph_metadata = {} self.decode_cuda_graph_kv_indices = None + self.padded_q_buffer = None + self.unpad_output_buffer = None self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None @@ -203,6 +355,21 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend): self.decode_cuda_graph_kv_indices = torch.full( (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device ) + num_tokens_per_bs = max_num_tokens // max_bs + + # Buffer for padded query: (max_bs, max_draft_tokens, num_q_heads, v_head_dim) + self.padded_q_buffer = torch.zeros( + (max_bs, num_tokens_per_bs, self.num_q_heads, self.kv_cache_dim), + dtype=self.data_type, + device=self.device, + ) + + # Buffer for unpadded output: (max_num_tokens, num_q_heads, v_head_dim) + self.unpad_output_buffer = torch.zeros( + (max_num_tokens, self.num_q_heads, 512), + dtype=self.data_type, + device=self.device, + ) super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf) @@ -219,7 +386,11 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend): """Initialize metadata for 
CUDA graph capture.""" # Delegate to parent for non-decode modes. - if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify(): + if ( + not forward_mode.is_decode_or_idle() + and not forward_mode.is_target_verify() + and not forward_mode.is_draft_extend() + ): return super().init_forward_metadata_capture_cuda_graph( bs, num_tokens, @@ -259,6 +430,20 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend): block_kv_indices, max_seq_len_val, ) + if forward_mode.is_draft_extend(): + num_tokens_per_bs = num_tokens // bs + metadata.max_seq_len_q = num_tokens_per_bs + 1 + metadata.sum_seq_lens_q = num_tokens_per_bs * bs + metadata.cu_seqlens_q = torch.arange( + 0, + bs * num_tokens_per_bs + 1, + num_tokens_per_bs, + dtype=torch.int32, + device=seq_lens.device, + ) + metadata.seq_lens_q = torch.full( + (bs,), num_tokens_per_bs, dtype=torch.int32, device=seq_lens.device + ) self.decode_cuda_graph_metadata[bs] = metadata self.forward_decode_metadata = metadata @@ -275,7 +460,11 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend): ): """Replay CUDA graph with new inputs.""" # Delegate to parent for non-decode modes. - if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify(): + if ( + not forward_mode.is_decode_or_idle() + and not forward_mode.is_target_verify() + and not forward_mode.is_draft_extend() + ): return super().init_forward_metadata_replay_cuda_graph( bs, req_pool_indices, @@ -293,6 +482,19 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend): metadata = self.decode_cuda_graph_metadata[bs] + if forward_mode.is_draft_extend(): + accept_length = spec_info.accept_length[:bs] + if spec_info.accept_length_cpu: + metadata.max_seq_len_q = max(spec_info.accept_length_cpu[:bs]) + metadata.sum_seq_lens_q = sum(spec_info.accept_length_cpu[:bs]) + else: + metadata.max_seq_len_q = 1 + metadata.sum_seq_lens_q = bs + metadata.cu_seqlens_q[1:].copy_( + torch.cumsum(accept_length, dim=0, dtype=torch.int32) + ) + metadata.seq_lens_q.copy_(accept_length) + # Update block indices for new sequences. 
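+        # (This rewrite, like the cu_seqlens_q/seq_lens_q copies above, must
+        # happen in place: CUDA-graph replay cannot allocate new tensors, so
+        # all metadata updates reuse buffers created at capture time.)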
create_flashmla_kv_indices_triton[(bs,)]( self.req_to_token, @@ -344,6 +546,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend): elif ( forward_batch.forward_mode.is_decode_or_idle() or forward_batch.forward_mode.is_target_verify() + or forward_batch.forward_mode.is_draft_extend() ): bs = forward_batch.batch_size @@ -372,6 +575,23 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend): self.forward_decode_metadata = TRTLLMMLADecodeMetadata( block_kv_indices, max_seq_len_val ) + if forward_batch.forward_mode.is_draft_extend(): + max_seq = forward_batch.seq_lens_cpu.max().item() + + sum_seq_lens_q = sum(forward_batch.extend_seq_lens_cpu) + max_seq_len_q = max(forward_batch.extend_seq_lens_cpu) + cu_seqlens_q = torch.nn.functional.pad( + torch.cumsum( + forward_batch.extend_seq_lens, dim=0, dtype=torch.int32 + ), + (1, 0), + ) + + self.forward_decode_metadata.max_seq_len_q = max_seq_len_q + self.forward_decode_metadata.sum_seq_lens_q = sum_seq_lens_q + self.forward_decode_metadata.cu_seqlens_q = cu_seqlens_q + self.forward_decode_metadata.seq_lens_q = forward_batch.extend_seq_lens + forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata else: return super().init_forward_metadata(forward_batch) @@ -457,6 +677,86 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend): return q_out, k_nope_out, k_rope_out + def pad_draft_extend_query( + self, + q: torch.Tensor, + padded_q: torch.Tensor, + seq_lens_q: torch.Tensor, + cu_seqlens_q: torch.Tensor, + ) -> torch.Tensor: + """Pad draft extended query using Triton kernel.""" + batch_size = cu_seqlens_q.shape[0] - 1 + max_seq_len_q = padded_q.shape[1] + num_heads = padded_q.shape[2] + head_dim = padded_q.shape[3] + + # Launch Triton kernel with 3D grid for parallelized head and dim processing + BLOCK_SIZE = 64 + num_head_blocks = triton.cdiv(num_heads, BLOCK_SIZE) + num_dim_blocks = triton.cdiv(head_dim, BLOCK_SIZE) + grid = (batch_size * max_seq_len_q, num_head_blocks, num_dim_blocks) + + pad_draft_extend_query_kernel[grid]( + q_ptr=q, + padded_q_ptr=padded_q, + seq_lens_q_ptr=seq_lens_q, + cumsum_ptr=cu_seqlens_q, + batch_size=batch_size, + max_seq_len=max_seq_len_q, + num_heads=num_heads, + head_dim=head_dim, + BLOCK_SIZE=BLOCK_SIZE, + ) + return padded_q + + def unpad_draft_extend_output( + self, + raw_out: torch.Tensor, + cu_seqlens_q: torch.Tensor, + seq_lens_q: torch.Tensor, + sum_seq_lens_q: int, + ) -> torch.Tensor: + """Unpad draft extended output using Triton kernel.""" + # raw_out: (batch_size, token_per_batch, layer.tp_q_head_num, layer.v_head_dim) + batch_size = seq_lens_q.shape[0] + token_per_batch = raw_out.shape[1] # max_seq_len + tp_q_head_num = raw_out.shape[2] # num_heads + v_head_dim = raw_out.shape[3] # head_dim + total_tokens = sum_seq_lens_q + + # Check if we're in CUDA graph mode (buffers are pre-allocated) + if self.unpad_output_buffer is not None: + # Use pre-allocated buffer for CUDA graph compatibility + output = self.unpad_output_buffer[:total_tokens, :, :].to( + dtype=raw_out.dtype + ) + else: + # Dynamic allocation for non-CUDA graph mode + output = torch.empty( + (total_tokens, tp_q_head_num, v_head_dim), + dtype=raw_out.dtype, + device=raw_out.device, + ) + + # Launch Triton kernel with 3D grid for parallelized head and dim processing + BLOCK_SIZE = 64 + num_head_blocks = triton.cdiv(tp_q_head_num, BLOCK_SIZE) + num_dim_blocks = triton.cdiv(v_head_dim, BLOCK_SIZE) + grid = (batch_size * token_per_batch, num_head_blocks, num_dim_blocks) + + unpad_draft_extend_output_kernel[grid]( + raw_out_ptr=raw_out, + 
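+            # seq_lens_q is passed as accept_length_ptr below: for draft
+            # extend it carries the per-request count of accepted draft tokens.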
output_ptr=output, + accept_length_ptr=seq_lens_q, + cumsum_ptr=cu_seqlens_q, + batch_size=batch_size, + token_per_batch=token_per_batch, + tp_q_head_num=tp_q_head_num, + v_head_dim=v_head_dim, + BLOCK_SIZE=BLOCK_SIZE, + ) + return output[:total_tokens, :, :] + def forward_decode( self, q: torch.Tensor, # q_nope @@ -550,7 +850,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend): qk_rope_head_dim=self.qk_rope_head_dim, block_tables=metadata.block_kv_indices, seq_lens=forward_batch.seq_lens.to(torch.int32), - max_seq_len=metadata.max_seq_len, + max_seq_len=metadata.max_seq_len_k, bmm1_scale=bmm1_scale, ) @@ -571,11 +871,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend): cos_sin_cache: Optional[torch.Tensor] = None, is_neox: Optional[bool] = False, ) -> torch.Tensor: - if forward_batch.forward_mode.is_draft_extend(): - return super().forward_extend( - q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope - ) - # TODO refactor to avoid code duplication merge_query = q_rope is not None if ( @@ -627,7 +922,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend): v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim) - if forward_batch.forward_mode.is_target_verify(): + if ( + forward_batch.forward_mode.is_target_verify() + or forward_batch.forward_mode.is_draft_extend() + ): metadata = ( getattr(forward_batch, "decode_trtllm_mla_metadata", None) or self.forward_decode_metadata @@ -635,7 +933,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend): # Ensure query has shape [bs, num_draft_tokens, num_q_heads, head_dim] bs = forward_batch.batch_size - q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim) k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1) @@ -646,17 +943,42 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend): if getattr(layer, "k_scale_float", None) is not None else 1.0 ) + q = q.to(self.data_type) bmm1_scale = q_scale * k_scale * layer.scaling - - seq_lens = ( - forward_batch.seq_lens.to(torch.int32) - + forward_batch.spec_info.draft_token_num - ) - max_seq_len = metadata.max_seq_len + forward_batch.spec_info.draft_token_num + if forward_batch.forward_mode.is_target_verify(): + seq_lens = ( + forward_batch.seq_lens.to(torch.int32) + + forward_batch.spec_info.draft_token_num + ) + max_seq_len = ( + metadata.max_seq_len_k + forward_batch.spec_info.draft_token_num + ) + else: + seq_lens = forward_batch.seq_lens.to(torch.int32) + max_seq_len = metadata.max_seq_len_k + # Check if we're in CUDA graph mode (buffers are pre-allocated) + if self.padded_q_buffer is not None: + # Use pre-allocated buffer for CUDA graph compatibility + padded_q = self.padded_q_buffer[ + :bs, : metadata.max_seq_len_q, :, : + ].to(dtype=q.dtype) + else: + # Dynamic allocation for non-CUDA graph mode + padded_q = torch.zeros( + bs, + metadata.max_seq_len_q, + layer.tp_q_head_num, + layer.head_dim, + dtype=q.dtype, + device=q.device, + ) + q = self.pad_draft_extend_query( + q, padded_q, metadata.seq_lens_q, metadata.cu_seqlens_q + ) # TODO may use `mla_rope_quantize_fp8` fusion - q = q.to(self.data_type) + q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim) assert kv_cache.dtype == self.data_type raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( @@ -673,6 +995,14 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend): ) # Reshape output directly without slicing + + if forward_batch.forward_mode.is_draft_extend(): + raw_out = self.unpad_draft_extend_output( + raw_out, + 
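+                # raw_out comes back padded as [bs, max_seq_len_q, heads,
+                # v_head_dim]; strip the padding back to one row per token.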
metadata.cu_seqlens_q, + metadata.seq_lens_q, + metadata.sum_seq_lens_q, + ) output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim) return output diff --git a/python/sglang/test/attention/test_trtllm_mla_backend.py b/python/sglang/test/attention/test_trtllm_mla_backend.py index 6f610baf0..3fabb83bc 100755 --- a/python/sglang/test/attention/test_trtllm_mla_backend.py +++ b/python/sglang/test/attention/test_trtllm_mla_backend.py @@ -1263,6 +1263,178 @@ class TestTRTLLMMLA(CustomTestCase): f"Max diff: {(out_trtllm - out_reference).abs().max().item()}", ) + def test_draft_extend_padding_unpadding_kernels(self): + """Test TRTLLM MLA Triton kernels: pad_draft_extend_query_kernel and unpad_draft_extend_output_kernel.""" + + # Import the kernels + from sglang.srt.layers.attention.trtllm_mla_backend import ( + pad_draft_extend_query_kernel, + unpad_draft_extend_output_kernel, + ) + + def _create_test_data( + self, batch_size, max_seq_len, num_heads, head_dim, dtype=torch.float32 + ): + """Create test data for kernel testing.""" + device = torch.device("cuda") + + # Create sequence lengths (varying lengths for each batch) + seq_lens = torch.randint( + 1, max_seq_len + 1, (batch_size,), device=device, dtype=torch.int32 + ) + + # Create cumulative sequence lengths + cum_seq_lens = torch.zeros(batch_size + 1, device=device, dtype=torch.int32) + cum_seq_lens[1:] = torch.cumsum(seq_lens, dim=0) + + # Create input query tensor (flattened format) + total_tokens = cum_seq_lens[-1].item() + q_input = torch.randn( + total_tokens, num_heads, head_dim, device=device, dtype=dtype + ) + + # Create padded query tensor (batch format) + padded_q = torch.zeros( + batch_size, max_seq_len, num_heads, head_dim, device=device, dtype=dtype + ) + + return q_input, padded_q, seq_lens, cum_seq_lens + + def _create_test_output_data( + self, + batch_size, + token_per_batch, + tp_q_head_num, + v_head_dim, + dtype=torch.float32, + ): + """Create test data for unpad kernel testing.""" + device = torch.device("cuda") + + # Create accept lengths (varying lengths for each batch) + accept_lengths = torch.randint( + 1, token_per_batch + 1, (batch_size,), device=device, dtype=torch.int32 + ) + + # Create cumulative accept lengths + cum_accept_lengths = torch.zeros( + batch_size + 1, device=device, dtype=torch.int32 + ) + cum_accept_lengths[1:] = torch.cumsum(accept_lengths, dim=0) + + # Create raw output tensor (batch format) + raw_out = torch.randn( + batch_size, + token_per_batch, + tp_q_head_num, + v_head_dim, + device=device, + dtype=dtype, + ) + + # Create output tensor (flattened format) + total_tokens = cum_accept_lengths[-1].item() + output = torch.empty( + total_tokens, tp_q_head_num, v_head_dim, device=device, dtype=dtype + ) + + return raw_out, output, accept_lengths, cum_accept_lengths + + # Test 1: pad_draft_extend_query_kernel basic functionality + with self.subTest(test="pad_kernel_basic"): + batch_size = 4 + max_seq_len = 8 + num_heads = 16 + head_dim = 64 + + q_input, padded_q, seq_lens, cum_seq_lens = _create_test_data( + self, batch_size, max_seq_len, num_heads, head_dim + ) + + # Launch kernel + BLOCK_SIZE = 64 + grid = (batch_size * max_seq_len,) + + pad_draft_extend_query_kernel[grid]( + q_ptr=q_input, + padded_q_ptr=padded_q, + seq_lens_q_ptr=seq_lens, + cumsum_ptr=cum_seq_lens, + batch_size=batch_size, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=head_dim, + BLOCK_SIZE=BLOCK_SIZE, + ) + + # Verify the padding worked correctly + for i in range(batch_size): + seq_len = seq_lens[i].item() 
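+                # Rows for batch i live at cum_seq_lens[i]:cum_seq_lens[i+1]
+                # in the flat input; padded slots past seq_len must stay zero.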
+ + # Check that valid positions are copied correctly + for pos in range(seq_len): + input_start = cum_seq_lens[i].item() + input_pos = input_start + pos + + # Compare input and output for valid positions + input_data = q_input[input_pos] + output_data = padded_q[i, pos] + + torch.testing.assert_close( + input_data, output_data, rtol=1e-5, atol=1e-6 + ) + + # Check that invalid positions are zero + for pos in range(seq_len, max_seq_len): + output_data = padded_q[i, pos] + self.assertTrue( + torch.allclose(output_data, torch.zeros_like(output_data)), + f"Position {pos} in batch {i} should be zero", + ) + + # Test 2: unpad_draft_extend_output_kernel basic functionality + with self.subTest(test="unpad_kernel_basic"): + batch_size = 4 + token_per_batch = 8 + tp_q_head_num = 16 + v_head_dim = 64 + + raw_out, output, accept_lengths, cum_accept_lengths = ( + _create_test_output_data( + self, batch_size, token_per_batch, tp_q_head_num, v_head_dim + ) + ) + + # Launch kernel + BLOCK_SIZE = 64 + grid = (batch_size * token_per_batch,) + + unpad_draft_extend_output_kernel[grid]( + raw_out_ptr=raw_out, + output_ptr=output, + accept_length_ptr=accept_lengths, + cumsum_ptr=cum_accept_lengths, + batch_size=batch_size, + token_per_batch=token_per_batch, + tp_q_head_num=tp_q_head_num, + v_head_dim=v_head_dim, + BLOCK_SIZE=BLOCK_SIZE, + ) + + # Verify the unpadding worked correctly + for i in range(batch_size): + accept_len = accept_lengths[i].item() + output_start = cum_accept_lengths[i].item() + + # Check that valid positions are copied correctly + for pos in range(accept_len): + input_data = raw_out[i, pos] + output_data = output[output_start + pos] + + torch.testing.assert_close( + input_data, output_data, rtol=1e-5, atol=1e-6 + ) + if __name__ == "__main__": unittest.main()
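
Note on the two Triton kernels in this patch: together they implement a
ragged-to-padded round trip. pad_draft_extend_query_kernel scatters a
concatenated [total_tokens, heads, dim] query into a zero-padded
[bs, max_seq_len_q, heads, dim] buffer, and unpad_draft_extend_output_kernel
gathers the attention output back into ragged form. Below is a minimal
pure-PyTorch sketch of the same round trip, useful for cross-checking the
kernels on small inputs; pad_ref and unpad_ref are illustrative names and are
not part of this patch.

import torch
import torch.nn.functional as F


def pad_ref(q, seq_lens_q, cu_seqlens_q, max_seq_len_q):
    """Scatter ragged [total_tokens, H, D] into zero-padded [bs, max, H, D]."""
    bs = seq_lens_q.numel()
    padded = q.new_zeros(bs, max_seq_len_q, q.shape[1], q.shape[2])
    for i in range(bs):
        n, s = int(seq_lens_q[i]), int(cu_seqlens_q[i])
        padded[i, :n] = q[s : s + n]
    return padded


def unpad_ref(out_padded, seq_lens_q, cu_seqlens_q):
    """Gather padded [bs, max, H, D] back to ragged [total_tokens, H, D]."""
    total = int(cu_seqlens_q[-1])
    out = out_padded.new_empty(total, out_padded.shape[2], out_padded.shape[3])
    for i in range(seq_lens_q.numel()):
        n, s = int(seq_lens_q[i]), int(cu_seqlens_q[i])
        out[s : s + n] = out_padded[i, :n]
    return out


if __name__ == "__main__":
    torch.manual_seed(0)
    # Three requests with 3, 1, and 4 accepted draft tokens respectively.
    seq_lens = torch.tensor([3, 1, 4], dtype=torch.int32)
    cu = F.pad(torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0))
    q = torch.randn(int(cu[-1]), 16, 64)
    # Pad to the batch's max length, then unpad; the round trip is lossless.
    padded = pad_ref(q, seq_lens, cu, max_seq_len_q=4)
    assert torch.equal(unpad_ref(padded, seq_lens, cu), q)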