[Feature] improve TBO: two chunk overlap (#8144)
This commit is contained in:
@@ -420,16 +420,12 @@ class ForwardBatch:
|
||||
batch.extend_prefix_lens, dtype=torch.int32
|
||||
).to(device, non_blocking=True)
|
||||
ret.extend_num_tokens = batch.extend_num_tokens
|
||||
if support_triton(model_runner.server_args.attention_backend):
|
||||
positions, ret.extend_start_loc = compute_position_triton(
|
||||
ret.extend_prefix_lens,
|
||||
ret.extend_seq_lens,
|
||||
ret.extend_num_tokens,
|
||||
)
|
||||
else:
|
||||
positions, ret.extend_start_loc = compute_position_torch(
|
||||
ret.extend_prefix_lens, ret.extend_seq_lens
|
||||
)
|
||||
positions, ret.extend_start_loc = compute_position(
|
||||
model_runner.server_args.attention_backend,
|
||||
ret.extend_prefix_lens,
|
||||
ret.extend_seq_lens,
|
||||
ret.extend_num_tokens,
|
||||
)
|
||||
if ret.positions is None:
|
||||
ret.positions = positions
|
||||
ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
|
||||
@@ -882,6 +878,25 @@ class PPProxyTensors:
|
||||
return f"PPProxyTensors(tensors={self.tensors})"
|
||||
|
||||
|
||||
def compute_position(
|
||||
attn_backend: str,
|
||||
extend_prefix_lens: torch.Tensor,
|
||||
extend_seq_lens: torch.Tensor,
|
||||
extend_seq_lens_sum: int,
|
||||
):
|
||||
if support_triton(attn_backend):
|
||||
positions, extend_start_loc = compute_position_triton(
|
||||
extend_prefix_lens,
|
||||
extend_seq_lens,
|
||||
extend_seq_lens_sum,
|
||||
)
|
||||
else:
|
||||
positions, extend_start_loc = compute_position_torch(
|
||||
extend_prefix_lens, extend_seq_lens
|
||||
)
|
||||
return positions, extend_start_loc
|
||||
|
||||
|
||||
def compute_position_triton(
|
||||
extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user