[fix] PD disaggregation when enable mtp and tp!=dp (#7420)
This commit is contained in:
@@ -766,7 +766,7 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
if batch:
|
||||
result = self.run_batch(batch)
|
||||
if not delay_process:
|
||||
self.prepare_mlp_sync_batch(batch, result)
|
||||
self.process_batch_result(batch, result)
|
||||
return batch, result
|
||||
|
||||
def get_next_disagg_decode_batch_to_run(
|
||||
|
||||
@@ -310,4 +310,4 @@ def attn_tp_reduce_scatter(
|
||||
|
||||
|
||||
def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
|
||||
return get_attention_tp_group().all_gather(input_, tensor_list=output_list)
|
||||
return get_attention_tp_group().all_gather(input_, output_tensor_list=output_list)
|
||||
|
||||
@@ -1435,7 +1435,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
|
||||
self.layer_scatter_modes = LayerScatterModes.init_new(
|
||||
layer_id=layer_id,
|
||||
num_layers=config.num_hidden_layers,
|
||||
num_layers=1 if is_nextn else config.num_hidden_layers,
|
||||
is_layer_sparse=self.is_layer_sparse,
|
||||
is_previous_layer_sparse=is_previous_layer_sparse,
|
||||
)
|
||||
@@ -1488,6 +1488,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
residual: Optional[torch.Tensor],
|
||||
zero_allocator: BumpAllocator,
|
||||
) -> torch.Tensor:
|
||||
|
||||
hidden_states, residual = self.layer_communicator.prepare_attn(
|
||||
hidden_states, residual, forward_batch
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user