Fix Triton kernel illegal memory access for EAGLE (#4100)
This commit is contained in:
@@ -292,11 +292,7 @@ class ForwardBatch:
|
|||||||
ret.extend_prefix_lens = torch.tensor(
|
ret.extend_prefix_lens = torch.tensor(
|
||||||
batch.extend_prefix_lens, dtype=torch.int32
|
batch.extend_prefix_lens, dtype=torch.int32
|
||||||
).to(device, non_blocking=True)
|
).to(device, non_blocking=True)
|
||||||
if (
|
if model_runner.server_args.attention_backend != "torch_native":
|
||||||
model_runner.server_args.attention_backend != "torch_native"
|
|
||||||
# TODO: Fix triton kernel illegal memory access for EAGLE
|
|
||||||
and model_runner.server_args.speculative_algorithm != "EAGLE"
|
|
||||||
):
|
|
||||||
ret.extend_num_tokens = batch.extend_num_tokens
|
ret.extend_num_tokens = batch.extend_num_tokens
|
||||||
positions, ret.extend_start_loc = compute_position_triton(
|
positions, ret.extend_start_loc = compute_position_triton(
|
||||||
ret.extend_prefix_lens,
|
ret.extend_prefix_lens,
|
||||||
@@ -386,6 +382,8 @@ def compute_position_triton(
|
|||||||
):
|
):
|
||||||
"""Compute positions. It is a fused version of `compute_position_torch`."""
|
"""Compute positions. It is a fused version of `compute_position_torch`."""
|
||||||
batch_size = extend_seq_lens.shape[0]
|
batch_size = extend_seq_lens.shape[0]
|
||||||
|
has_prefix = extend_prefix_lens.shape[0] == batch_size
|
||||||
|
|
||||||
positions = torch.empty(
|
positions = torch.empty(
|
||||||
extend_seq_lens_sum, dtype=torch.int64, device=extend_seq_lens.device
|
extend_seq_lens_sum, dtype=torch.int64, device=extend_seq_lens.device
|
||||||
)
|
)
|
||||||
@@ -399,6 +397,7 @@ def compute_position_triton(
|
|||||||
extend_start_loc,
|
extend_start_loc,
|
||||||
extend_prefix_lens,
|
extend_prefix_lens,
|
||||||
extend_seq_lens,
|
extend_seq_lens,
|
||||||
|
has_prefix,
|
||||||
)
|
)
|
||||||
|
|
||||||
return positions, extend_start_loc
|
return positions, extend_start_loc
|
||||||
@@ -410,11 +409,12 @@ def compute_position_kernel(
|
|||||||
extend_start_loc,
|
extend_start_loc,
|
||||||
extend_prefix_lens,
|
extend_prefix_lens,
|
||||||
extend_seq_lens,
|
extend_seq_lens,
|
||||||
|
has_prefix: tl.constexpr,
|
||||||
):
|
):
|
||||||
BLOCK_SIZE: tl.constexpr = 512
|
BLOCK_SIZE: tl.constexpr = 512
|
||||||
pid = tl.program_id(0).to(tl.int64)
|
pid = tl.program_id(0).to(tl.int64)
|
||||||
|
|
||||||
prefix_len = tl.load(extend_prefix_lens + pid)
|
prefix_len = tl.load(extend_prefix_lens + pid) if has_prefix else 0
|
||||||
seq_len = tl.load(extend_seq_lens + pid)
|
seq_len = tl.load(extend_seq_lens + pid)
|
||||||
|
|
||||||
# TODO: optimize this?
|
# TODO: optimize this?
|
||||||
|
|||||||
Reference in New Issue
Block a user