[DP] fix the compatibility issue between DP attention and --attention-backend triton (#8723)
This commit is contained in:
@@ -646,12 +646,17 @@ class ForwardBatch:
|
|||||||
device=model_runner.device,
|
device=model_runner.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
bs = self.batch_size
|
|
||||||
if len(global_num_tokens) > 1:
|
if len(global_num_tokens) > 1:
|
||||||
num_tokens = global_num_tokens[get_attention_dp_rank()]
|
num_tokens = global_num_tokens[get_attention_dp_rank()]
|
||||||
else:
|
else:
|
||||||
num_tokens = global_num_tokens[0]
|
num_tokens = global_num_tokens[0]
|
||||||
|
|
||||||
|
if self.forward_mode.is_decode():
|
||||||
|
setattr(self, "raw_bs", self.batch_size)
|
||||||
|
self.batch_size = num_tokens
|
||||||
|
|
||||||
|
bs = self.batch_size
|
||||||
|
|
||||||
# padding
|
# padding
|
||||||
self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens)
|
self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens)
|
||||||
self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs)
|
self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs)
|
||||||
@@ -659,6 +664,9 @@ class ForwardBatch:
|
|||||||
seq_len_fill_value = (
|
seq_len_fill_value = (
|
||||||
model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
|
model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
|
||||||
)
|
)
|
||||||
|
self.seq_lens_sum = self.seq_lens_sum + seq_len_fill_value * (
|
||||||
|
bs - self.seq_lens.shape[0]
|
||||||
|
)
|
||||||
self.seq_lens = self._pad_tensor_to_size(
|
self.seq_lens = self._pad_tensor_to_size(
|
||||||
self.seq_lens, bs, value=seq_len_fill_value
|
self.seq_lens, bs, value=seq_len_fill_value
|
||||||
)
|
)
|
||||||
@@ -702,7 +710,7 @@ class ForwardBatch:
|
|||||||
|
|
||||||
def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):
|
def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):
|
||||||
|
|
||||||
bs = self.batch_size
|
bs = getattr(self, "raw_bs", self.batch_size)
|
||||||
|
|
||||||
if self.spec_info is not None:
|
if self.spec_info is not None:
|
||||||
if self.forward_mode.is_decode(): # draft
|
if self.forward_mode.is_decode(): # draft
|
||||||
|
|||||||
Reference in New Issue
Block a user