Enable overlap scheduler by default for the triton attention backend (#2105)

This commit is contained in:
Lianmin Zheng
2024-11-20 02:58:35 -08:00
committed by GitHub
parent 56a347f7d3
commit 722530fa01
6 changed files with 21 additions and 24 deletions

View File

@@ -170,7 +170,6 @@ class CudaGraphRunner:
self.encoder_lens = None
if self.enable_dp_attention:
self.global_num_tokens = [0] * self.tp_size
self.gathered_buffer = torch.zeros(
(
self.max_bs * self.tp_size,
@@ -264,10 +263,10 @@ class CudaGraphRunner:
mrope_positions = self.mrope_positions[:, :bs]
if self.enable_dp_attention:
self.global_num_tokens[:] = [bs] * self.tp_size
global_num_tokens = [bs] * self.tp_size
gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
else:
self.global_num_tokens = None
global_num_tokens = None
gathered_buffer = None
# Attention backend
@@ -296,7 +295,7 @@ class CudaGraphRunner:
top_logprobs_nums=[0] * bs,
positions=clamp_position(seq_lens),
mrope_positions=mrope_positions,
global_num_tokens=self.global_num_tokens,
global_num_tokens=global_num_tokens,
gathered_buffer=gathered_buffer,
)
logits_output = forward(input_ids, forward_batch.positions, forward_batch)
@@ -348,8 +347,6 @@ class CudaGraphRunner:
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
if forward_batch.mrope_positions is not None:
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
if self.enable_dp_attention:
self.global_num_tokens[:] = [bs] * self.tp_size
# Attention backend
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(