[Feature] improve TBO: two chunk overlap (#8144)

This commit is contained in:
HouseWest
2025-08-06 12:11:01 +08:00
committed by GitHub
parent d26ca84f39
commit ca47e24f5d
6 changed files with 218 additions and 29 deletions

View File

@@ -420,16 +420,12 @@ class ForwardBatch:
batch.extend_prefix_lens, dtype=torch.int32
).to(device, non_blocking=True)
ret.extend_num_tokens = batch.extend_num_tokens
if support_triton(model_runner.server_args.attention_backend):
positions, ret.extend_start_loc = compute_position_triton(
ret.extend_prefix_lens,
ret.extend_seq_lens,
ret.extend_num_tokens,
)
else:
positions, ret.extend_start_loc = compute_position_torch(
ret.extend_prefix_lens, ret.extend_seq_lens
)
positions, ret.extend_start_loc = compute_position(
model_runner.server_args.attention_backend,
ret.extend_prefix_lens,
ret.extend_seq_lens,
ret.extend_num_tokens,
)
if ret.positions is None:
ret.positions = positions
ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
@@ -882,6 +878,25 @@ class PPProxyTensors:
return f"PPProxyTensors(tensors={self.tensors})"
def compute_position(
    attn_backend: str,
    extend_prefix_lens: torch.Tensor,
    extend_seq_lens: torch.Tensor,
    extend_seq_lens_sum: int,
):
    """Compute positions and extend start locations via the matching helper.

    Picks between ``compute_position_triton`` and ``compute_position_torch``
    based on ``support_triton(attn_backend)``.

    Args:
        attn_backend: Attention backend name, passed to ``support_triton``.
        extend_prefix_lens: Forwarded unchanged to the selected helper.
        extend_seq_lens: Forwarded unchanged to the selected helper.
        extend_seq_lens_sum: Forwarded only to the Triton helper; the torch
            helper is called without it.

    Returns:
        The ``(positions, extend_start_loc)`` pair produced by the selected
        helper.
    """
    if not support_triton(attn_backend):
        # Fallback path: the torch helper takes only the two length tensors.
        pos, start_loc = compute_position_torch(extend_prefix_lens, extend_seq_lens)
        return pos, start_loc

    pos, start_loc = compute_position_triton(
        extend_prefix_lens,
        extend_seq_lens,
        extend_seq_lens_sum,
    )
    return pos, start_loc
def compute_position_triton(
extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
):