[bugfix] fix multistream moe in torchair (#3164)
### What this PR does / why we need it? The multistream MoE in torchair was only validated in decode and cannot be applied to chunked prefill, so this PR adds some judgments to isolate the scenarios. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: hust17yixuan <303660421@qq.com>
This commit is contained in:
@@ -176,7 +176,7 @@ def test_torchair_deepseek_v2_merged_replicated_linear(mock_distributed):
|
|||||||
TorchairDeepseekV2RowParallelLinearReplaceAllreduce,
|
TorchairDeepseekV2RowParallelLinearReplaceAllreduce,
|
||||||
TorchairDeepseekV2RowParallelLinear
|
TorchairDeepseekV2RowParallelLinear
|
||||||
])
|
])
|
||||||
def test_row_parallel_linear(cls, mock_distributed):
|
def test_row_parallel_linear(cls, mock_distributed, mock_forward_context):
|
||||||
linear = cls(input_size=128, output_size=64, bias=False, quant_config=None)
|
linear = cls(input_size=128, output_size=64, bias=False, quant_config=None)
|
||||||
linear.quant_method = Mock()
|
linear.quant_method = Mock()
|
||||||
linear.quant_method.apply.return_value = torch.randn(2, 4, 64)
|
linear.quant_method.apply.return_value = torch.randn(2, 4, 64)
|
||||||
@@ -282,7 +282,7 @@ def test_torchair_deepseek_v2_decoder_layer(mock_maybe_chunk_residual,
|
|||||||
mock_maybe_wait_prefetch_done,
|
mock_maybe_wait_prefetch_done,
|
||||||
mock_rms_norm, mock_add_norm,
|
mock_rms_norm, mock_add_norm,
|
||||||
mock_distributed, base_config,
|
mock_distributed, base_config,
|
||||||
vllm_config):
|
vllm_config, mock_forward_context):
|
||||||
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
|
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
|
||||||
mock_add_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128),
|
mock_add_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128),
|
||||||
torch.randn(2, 128))
|
torch.randn(2, 128))
|
||||||
|
|||||||
@@ -161,12 +161,13 @@ class TorchairDeepseekV2RowParallelLinearReplaceAllreduce(RowParallelLinear):
|
|||||||
output_parallel = self.quant_method.apply(self,
|
output_parallel = self.quant_method.apply(self,
|
||||||
input_parallel,
|
input_parallel,
|
||||||
bias=bias_)
|
bias=bias_)
|
||||||
|
forward_context = get_forward_context()
|
||||||
if self.reduce_results and self.tp_size > 1:
|
if self.reduce_results and self.tp_size > 1:
|
||||||
num_tokens = output_parallel.shape[0]
|
num_tokens = output_parallel.shape[0]
|
||||||
if is_force_scatter and num_tokens % self.tp_size:
|
if is_force_scatter and num_tokens % self.tp_size:
|
||||||
output_parallel = nn.functional.pad(
|
output_parallel = nn.functional.pad(
|
||||||
output_parallel, (0, 0, 0, -num_tokens % self.tp_size))
|
output_parallel, (0, 0, 0, -num_tokens % self.tp_size))
|
||||||
if is_force_scatter or (not is_prefill
|
if is_force_scatter or (not forward_context.with_prefill
|
||||||
and output_parallel.shape[0] % self.tp_size
|
and output_parallel.shape[0] % self.tp_size
|
||||||
== 0):
|
== 0):
|
||||||
output = tensor_model_parallel_reduce_scatter(output_parallel,
|
output = tensor_model_parallel_reduce_scatter(output_parallel,
|
||||||
@@ -945,15 +946,15 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
|||||||
replace_allreduce: bool = False,
|
replace_allreduce: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# Self Attention
|
# Self Attention
|
||||||
|
forward_context = get_forward_context()
|
||||||
if attn_metadata is not None:
|
if attn_metadata is not None:
|
||||||
decoding_condition_met = (
|
decoding_condition_met = (
|
||||||
not attn_metadata.is_prefill if self.use_sfa else
|
not attn_metadata.is_prefill if self.use_sfa else
|
||||||
attn_metadata.num_decodes > 0 if self.use_mla else False)
|
not forward_context.with_prefill if self.use_mla else False)
|
||||||
mla_moe_communication = decoding_condition_met and self.mla_moe_communication and replace_allreduce
|
mla_moe_communication = decoding_condition_met and self.mla_moe_communication and replace_allreduce
|
||||||
else:
|
else:
|
||||||
mla_moe_communication = False
|
mla_moe_communication = False
|
||||||
|
|
||||||
forward_context = get_forward_context()
|
|
||||||
if (envs.VLLM_ASCEND_ENABLE_MLAPO
|
if (envs.VLLM_ASCEND_ENABLE_MLAPO
|
||||||
and isinstance(self.self_attn, TorchairDeepseekV2SFAAttention)
|
and isinstance(self.self_attn, TorchairDeepseekV2SFAAttention)
|
||||||
and attn_metadata is not None
|
and attn_metadata is not None
|
||||||
|
|||||||
Reference in New Issue
Block a user