Enable overlap scheduler by default for the triton attention backend (#2105)
This commit is contained in:
@@ -170,7 +170,6 @@ class CudaGraphRunner:
|
||||
self.encoder_lens = None
|
||||
|
||||
if self.enable_dp_attention:
|
||||
self.global_num_tokens = [0] * self.tp_size
|
||||
self.gathered_buffer = torch.zeros(
|
||||
(
|
||||
self.max_bs * self.tp_size,
|
||||
@@ -264,10 +263,10 @@ class CudaGraphRunner:
|
||||
mrope_positions = self.mrope_positions[:, :bs]
|
||||
|
||||
if self.enable_dp_attention:
|
||||
self.global_num_tokens[:] = [bs] * self.tp_size
|
||||
global_num_tokens = [bs] * self.tp_size
|
||||
gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
|
||||
else:
|
||||
self.global_num_tokens = None
|
||||
global_num_tokens = None
|
||||
gathered_buffer = None
|
||||
|
||||
# Attention backend
|
||||
@@ -296,7 +295,7 @@ class CudaGraphRunner:
|
||||
top_logprobs_nums=[0] * bs,
|
||||
positions=clamp_position(seq_lens),
|
||||
mrope_positions=mrope_positions,
|
||||
global_num_tokens=self.global_num_tokens,
|
||||
global_num_tokens=global_num_tokens,
|
||||
gathered_buffer=gathered_buffer,
|
||||
)
|
||||
logits_output = forward(input_ids, forward_batch.positions, forward_batch)
|
||||
@@ -348,8 +347,6 @@ class CudaGraphRunner:
|
||||
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
|
||||
if forward_batch.mrope_positions is not None:
|
||||
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
|
||||
if self.enable_dp_attention:
|
||||
self.global_num_tokens[:] = [bs] * self.tp_size
|
||||
|
||||
# Attention backend
|
||||
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
|
||||
|
||||
Reference in New Issue
Block a user