Fix triton kernel illegal memory issue for eagle (#4100)
This commit is contained in:
@@ -292,11 +292,7 @@ class ForwardBatch:
|
||||
ret.extend_prefix_lens = torch.tensor(
|
||||
batch.extend_prefix_lens, dtype=torch.int32
|
||||
).to(device, non_blocking=True)
|
||||
if (
|
||||
model_runner.server_args.attention_backend != "torch_native"
|
||||
# TODO: Fix triton kernel illegal memory access for EAGLE
|
||||
and model_runner.server_args.speculative_algorithm != "EAGLE"
|
||||
):
|
||||
if model_runner.server_args.attention_backend != "torch_native":
|
||||
ret.extend_num_tokens = batch.extend_num_tokens
|
||||
positions, ret.extend_start_loc = compute_position_triton(
|
||||
ret.extend_prefix_lens,
|
||||
@@ -386,6 +382,8 @@ def compute_position_triton(
|
||||
):
|
||||
"""Compute positions. It is a fused version of `compute_position_torch`."""
|
||||
batch_size = extend_seq_lens.shape[0]
|
||||
has_prefix = extend_prefix_lens.shape[0] == batch_size
|
||||
|
||||
positions = torch.empty(
|
||||
extend_seq_lens_sum, dtype=torch.int64, device=extend_seq_lens.device
|
||||
)
|
||||
@@ -399,6 +397,7 @@ def compute_position_triton(
|
||||
extend_start_loc,
|
||||
extend_prefix_lens,
|
||||
extend_seq_lens,
|
||||
has_prefix,
|
||||
)
|
||||
|
||||
return positions, extend_start_loc
|
||||
@@ -410,11 +409,12 @@ def compute_position_kernel(
|
||||
extend_start_loc,
|
||||
extend_prefix_lens,
|
||||
extend_seq_lens,
|
||||
has_prefix: tl.constexpr,
|
||||
):
|
||||
BLOCK_SIZE: tl.constexpr = 512
|
||||
pid = tl.program_id(0).to(tl.int64)
|
||||
|
||||
prefix_len = tl.load(extend_prefix_lens + pid)
|
||||
prefix_len = tl.load(extend_prefix_lens + pid) if has_prefix else 0
|
||||
seq_len = tl.load(extend_seq_lens + pid)
|
||||
|
||||
# TODO: optimize this?
|
||||
|
||||
Reference in New Issue
Block a user