[fix] PD disaggregation when MTP is enabled and tp != dp (#7420)
This commit is contained in:
@@ -766,7 +766,7 @@ class SchedulerDisaggregationDecodeMixin:
|
|||||||
if batch:
|
if batch:
|
||||||
result = self.run_batch(batch)
|
result = self.run_batch(batch)
|
||||||
if not delay_process:
|
if not delay_process:
|
||||||
self.prepare_mlp_sync_batch(batch, result)
|
self.process_batch_result(batch, result)
|
||||||
return batch, result
|
return batch, result
|
||||||
|
|
||||||
def get_next_disagg_decode_batch_to_run(
|
def get_next_disagg_decode_batch_to_run(
|
||||||
|
|||||||
@@ -310,4 +310,4 @@ def attn_tp_reduce_scatter(
|
|||||||
|
|
||||||
|
|
||||||
def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
|
def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
|
||||||
return get_attention_tp_group().all_gather(input_, tensor_list=output_list)
|
return get_attention_tp_group().all_gather(input_, output_tensor_list=output_list)
|
||||||
|
|||||||
@@ -1435,7 +1435,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
|
|
||||||
self.layer_scatter_modes = LayerScatterModes.init_new(
|
self.layer_scatter_modes = LayerScatterModes.init_new(
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
num_layers=config.num_hidden_layers,
|
num_layers=1 if is_nextn else config.num_hidden_layers,
|
||||||
is_layer_sparse=self.is_layer_sparse,
|
is_layer_sparse=self.is_layer_sparse,
|
||||||
is_previous_layer_sparse=is_previous_layer_sparse,
|
is_previous_layer_sparse=is_previous_layer_sparse,
|
||||||
)
|
)
|
||||||
@@ -1488,6 +1488,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
residual: Optional[torch.Tensor],
|
residual: Optional[torch.Tensor],
|
||||||
zero_allocator: BumpAllocator,
|
zero_allocator: BumpAllocator,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
hidden_states, residual = self.layer_communicator.prepare_attn(
|
hidden_states, residual = self.layer_communicator.prepare_attn(
|
||||||
hidden_states, residual, forward_batch
|
hidden_states, residual, forward_batch
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user