Fix triton kernel illegal memory issue for eagle (#4100)

commit ef9d3b3c2c (parent fc91d08a8f)
Author: Ke Bao (committed via GitHub)
Date: 2025-03-06 03:23:53 +08:00


@@ -292,11 +292,7 @@ class ForwardBatch:
         ret.extend_prefix_lens = torch.tensor(
             batch.extend_prefix_lens, dtype=torch.int32
         ).to(device, non_blocking=True)
-        if (
-            model_runner.server_args.attention_backend != "torch_native"
-            # TODO: Fix triton kernel illegal memory access for EAGLE
-            and model_runner.server_args.speculative_algorithm != "EAGLE"
-        ):
+        if model_runner.server_args.attention_backend != "torch_native":
            ret.extend_num_tokens = batch.extend_num_tokens
            positions, ret.extend_start_loc = compute_position_triton(
                ret.extend_prefix_lens,
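The guard above existed only to keep EAGLE off the fused Triton path; with the kernel made safe for prefix-less batches in the hunks below, the backend check alone suffices. For orientation, here is a rough sketch of the unfused computation that `compute_position_triton` fuses (written for this note as an illustration, not the repo's actual `compute_position_torch`):

import torch

def compute_position_reference(extend_prefix_lens, extend_seq_lens):
    # Request i occupies positions prefix_len .. prefix_len + extend_len - 1,
    # concatenated over the batch; start locs are the exclusive cumsum of
    # the extend lengths.
    positions = torch.cat(
        [
            torch.arange(int(p), int(p) + int(e), dtype=torch.int64)
            for p, e in zip(extend_prefix_lens, extend_seq_lens)
        ]
    )
    extend_start_loc = torch.zeros_like(extend_seq_lens)
    extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
    return positions, extend_start_loc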
@@ -386,6 +382,8 @@ def compute_position_triton(
 ):
     """Compute positions. It is a fused version of `compute_position_torch`."""
     batch_size = extend_seq_lens.shape[0]
+    has_prefix = extend_prefix_lens.shape[0] == batch_size
+
     positions = torch.empty(
         extend_seq_lens_sum, dtype=torch.int64, device=extend_seq_lens.device
     )
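The new `has_prefix` flag is the host-side half of the fix: the EAGLE batches that crashed evidently hand the kernel an `extend_prefix_lens` shorter than the batch (empty in the simplest case), so comparing its length against `batch_size` detects the prefix-less case before launch. A tiny illustration with made-up shapes:

import torch

extend_seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)  # batch_size == 3
normal_prefix = torch.tensor([2, 0, 4], dtype=torch.int32)    # one entry per request
eagle_prefix = torch.empty(0, dtype=torch.int32)              # no prefix entries at all

batch_size = extend_seq_lens.shape[0]
print(normal_prefix.shape[0] == batch_size)  # True:  kernel may load prefix lens
print(eagle_prefix.shape[0] == batch_size)   # False: kernel must assume prefix_len = 0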
@@ -399,6 +397,7 @@ def compute_position_triton(
         extend_start_loc,
         extend_prefix_lens,
         extend_seq_lens,
+        has_prefix,
     )
     return positions, extend_start_loc
 
@@ -410,11 +409,12 @@ def compute_position_kernel(
     extend_start_loc,
     extend_prefix_lens,
     extend_seq_lens,
+    has_prefix: tl.constexpr,
 ):
     BLOCK_SIZE: tl.constexpr = 512
     pid = tl.program_id(0).to(tl.int64)
 
-    prefix_len = tl.load(extend_prefix_lens + pid)
+    prefix_len = tl.load(extend_prefix_lens + pid) if has_prefix else 0
     seq_len = tl.load(extend_seq_lens + pid)
 
     # TODO: optimize this?
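Why this removes the illegal access: `has_prefix` is a `tl.constexpr`, so the conditional is resolved when Triton specializes the kernel, and with `has_prefix=False` the `tl.load` on `extend_prefix_lens` is never compiled in at all; a too-short tensor can no longer be dereferenced out of bounds. A minimal self-contained sketch of the whole pattern (the kernel body past the diff and the launcher are reconstructed for illustration; only the lines shown in the diff are verbatim):

import torch
import triton
import triton.language as tl


@triton.jit
def positions_kernel(
    positions,           # int64 out: one position per extended token
    extend_start_loc,    # int64 out: start offset of each request's tokens
    extend_prefix_lens,  # int32 in: may be shorter than the batch (EAGLE case)
    extend_seq_lens,     # int32 in: one entry per request
    has_prefix: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0).to(tl.int64)

    # Compile-time branch: with has_prefix=False the load is never emitted.
    prefix_len = tl.load(extend_prefix_lens + pid) if has_prefix else 0
    seq_len = tl.load(extend_seq_lens + pid)

    # Start offset of this request = sum of earlier requests' extend lengths.
    cumsum_start = tl.cast(0, tl.int64)
    for i in range(pid):
        cumsum_start += tl.load(extend_seq_lens + i)
    tl.store(extend_start_loc + pid, cumsum_start)

    # This request's positions are prefix_len .. prefix_len + seq_len - 1.
    for i in range(tl.cdiv(seq_len, BLOCK_SIZE)):
        offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        tl.store(
            positions + cumsum_start + offset,
            prefix_len + offset,
            mask=offset < seq_len,
        )


def compute_positions(extend_prefix_lens, extend_seq_lens):
    # Host side: detect the prefix-less case and pass it as a constexpr,
    # mirroring the has_prefix check added in the diff above.
    batch_size = extend_seq_lens.shape[0]
    has_prefix = extend_prefix_lens.shape[0] == batch_size
    total = int(extend_seq_lens.sum().item())
    positions = torch.empty(total, dtype=torch.int64, device=extend_seq_lens.device)
    extend_start_loc = torch.empty(
        batch_size, dtype=torch.int64, device=extend_seq_lens.device
    )
    positions_kernel[(batch_size,)](
        positions, extend_start_loc, extend_prefix_lens, extend_seq_lens, has_prefix
    )
    return positions, extend_start_loc

With CUDA tensors, e.g. `compute_positions(torch.empty(0, dtype=torch.int32, device="cuda"), torch.tensor([5, 3], dtype=torch.int32, device="cuda"))` should give positions `[0, 1, 2, 3, 4, 0, 1, 2]` and start locs `[0, 5]`: an EAGLE-shaped (empty-prefix) input now runs instead of faulting.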