From ef9d3b3c2c8355993cd86508d03cc35da5e34a91 Mon Sep 17 00:00:00 2001
From: Ke Bao
Date: Thu, 6 Mar 2025 03:23:53 +0800
Subject: [PATCH] Fix triton kernel illegal memory issue for eagle (#4100)

---
 .../sglang/srt/model_executor/forward_batch_info.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 826fd306b..70b8c6f46 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -292,11 +292,7 @@ class ForwardBatch:
         ret.extend_prefix_lens = torch.tensor(
             batch.extend_prefix_lens, dtype=torch.int32
         ).to(device, non_blocking=True)
-        if (
-            model_runner.server_args.attention_backend != "torch_native"
-            # TODO: Fix triton kernel illegal memory access for EAGLE
-            and model_runner.server_args.speculative_algorithm != "EAGLE"
-        ):
+        if model_runner.server_args.attention_backend != "torch_native":
             ret.extend_num_tokens = batch.extend_num_tokens
             positions, ret.extend_start_loc = compute_position_triton(
                 ret.extend_prefix_lens,
@@ -386,6 +382,8 @@ def compute_position_triton(
 ):
     """Compute positions. It is a fused version of `compute_position_torch`."""
     batch_size = extend_seq_lens.shape[0]
+    has_prefix = extend_prefix_lens.shape[0] == batch_size
+
     positions = torch.empty(
         extend_seq_lens_sum, dtype=torch.int64, device=extend_seq_lens.device
     )
@@ -399,6 +397,7 @@ def compute_position_triton(
         extend_start_loc,
         extend_prefix_lens,
         extend_seq_lens,
+        has_prefix,
     )
 
     return positions, extend_start_loc
@@ -410,11 +409,12 @@ def compute_position_kernel(
     extend_start_loc,
     extend_prefix_lens,
     extend_seq_lens,
+    has_prefix: tl.constexpr,
 ):
     BLOCK_SIZE: tl.constexpr = 512
     pid = tl.program_id(0).to(tl.int64)
 
-    prefix_len = tl.load(extend_prefix_lens + pid)
+    prefix_len = tl.load(extend_prefix_lens + pid) if has_prefix else 0
     seq_len = tl.load(extend_seq_lens + pid)
 
     # TODO: optimize this?
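
Notes (not part of the patch): judging from the guard added above, EAGLE batches can reach `compute_position_triton` with an `extend_prefix_lens` whose length differs from the batch size, so the previously unconditional `tl.load(extend_prefix_lens + pid)` in `compute_position_kernel` could read past the end of the buffer. The new `has_prefix: tl.constexpr` flag compiles that load away and falls back to a prefix of 0, which is also what lets the EAGLE exclusion in `ForwardBatch` be dropped. Below is a minimal pure-torch sketch of what the fixed kernel computes; `compute_position_reference` is a hypothetical name used only for illustration (the real file's unfused counterpart is `compute_position_torch`).

import torch

def compute_position_reference(
    extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor
):
    batch_size = extend_seq_lens.shape[0]
    # The guard this patch adds: only index extend_prefix_lens when it has
    # one entry per sequence; otherwise (e.g. an EAGLE draft batch) treat
    # every prefix as empty instead of reading past the buffer.
    has_prefix = extend_prefix_lens.shape[0] == batch_size

    extend_start_loc = torch.zeros(batch_size, dtype=torch.int64)
    positions = []
    cursor = 0
    for i in range(batch_size):
        seq_len = int(extend_seq_lens[i])
        prefix_len = int(extend_prefix_lens[i]) if has_prefix else 0
        # Each sequence's extended tokens start where the previous ones end.
        extend_start_loc[i] = cursor
        # Position ids of the newly extended tokens continue after the
        # cached prefix of sequence i.
        positions.append(torch.arange(prefix_len, prefix_len + seq_len))
        cursor += seq_len
    positions = (
        torch.cat(positions) if positions else torch.empty(0, dtype=torch.int64)
    )
    return positions, extend_start_loc

With an empty `extend_prefix_lens` (the EAGLE-style case), every sequence's positions start at 0, matching the `if has_prefix else 0` branch the patch adds to the Triton kernel.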