Handle with_prefill_across_dp for multistream mla (#1322)

### What this PR does / why we need it?
After #1094, decode may be executed in non-compiled mode even when
`torchair_graph_config.enabled` is set, which causes multistream MLA to
fail, since it assumed decode always runs under torchair compiled mode
whenever `torchair_graph_config.enabled == True`.
This PR augments that assumption: whether multistream MLA can be used is
now decided per forward pass from the attention metadata (it is skipped
when the current step still contains prefill across DP ranks), instead of
from the static config alone.
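
As a rough, illustrative sketch of the resulting gating logic (the helper name below is hypothetical; the metadata fields mirror those used in the `CustomDeepseekV2MLAAttention.forward` diff further down):

```python
# Sketch only: per-forward decision for multistream MLA, mirroring the
# condition added in this PR. `should_enable_multistream_mla` is a
# hypothetical helper name, not actual repo code.
def should_enable_multistream_mla(config_enabled: bool, attn_metadata) -> bool:
    return (
        config_enabled  # torchair_graph_config.enable_multistream_mla
        and attn_metadata is not None  # profiling/dummy runs carry no metadata
        and not attn_metadata.with_prefill_across_dp  # no DP rank is still prefilling
        and attn_metadata.num_decodes > 0  # this step actually decodes tokens
    )
```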

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Tested offline and via the graph-mode MLA e2e test case.
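
For reference, the multistream path specifically is covered by the new `test_e2e_deepseekv3_with_torchair_ms_mla` case added below; assuming an Ascend multi-card environment with the V1 engine, it can be run directly:

```
pytest tests/multicard/test_torchair_graph_mode.py::test_e2e_deepseekv3_with_torchair_ms_mla
```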

---------

Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
Author: sdmyzlp
Date: 2025-06-26 09:32:07 +08:00
Committed by: GitHub
Parent: 2690697caa
Commit: 53c2d58ae1
3 changed files with 82 additions and 60 deletions

tests/multicard/test_torchair_graph_mode.py

@@ -20,6 +20,7 @@
 Run `pytest tests/multicard/test_torchair_graph_mode.py`.
 """
 import os
+from typing import Dict
 
 import pytest
@@ -28,41 +29,38 @@ from tests.conftest import VllmRunner
 os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
 
 
-@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
-                    reason="torchair graph is not supported on v0")
-def test_e2e_deepseekv3_with_torchair(monkeypatch: pytest.MonkeyPatch):
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_MODELSCOPE", "True")
-        m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
-
-        example_prompts = [
-            "Hello, my name is",
-            "The president of the United States is",
-            "The capital of France is",
-            "The future of AI is",
-        ]
-        dtype = "half"
-        max_tokens = 5
-        # torchair is only work without chunked-prefill now
-        with VllmRunner(
-                "vllm-ascend/DeepSeek-V3-Pruning",
-                dtype=dtype,
-                tensor_parallel_size=4,
-                distributed_executor_backend="mp",
-                additional_config={
-                    "torchair_graph_config": {
-                        "enabled": True,
-                    },
-                    "ascend_scheduler_config": {
-                        "enabled": True,
-                    },
-                    "refresh": True,
-                },
-                enforce_eager=False,
-        ) as vllm_model:
-            # use greedy sampler to make sure the generated results are fix
-            vllm_output = vllm_model.generate_greedy(example_prompts,
-                                                     max_tokens)
+def _deepseek_torchair_test_fixture(
+    additional_config: Dict,
+    *,
+    tensor_parallel_size=4,
+):
+    example_prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    # torchair is only work without chunked-prefill now
+    kwargs = {
+        "ascend_scheduler_config": {
+            "enabled": True,
+        },
+        "refresh": True,
+    }
+    additional_config.update(**kwargs)
+
+    with VllmRunner(
+            "vllm-ascend/DeepSeek-V3-Pruning",
+            dtype="half",
+            tensor_parallel_size=tensor_parallel_size,
+            distributed_executor_backend="mp",
+            enforce_eager=False,
+            additional_config=additional_config,
+    ) as vllm_model:
+        # use greedy sampler to make sure the generated results are fix
+        vllm_output = vllm_model.generate_greedy(example_prompts, 5)
 
     # NOTE: vllm-ascend/DeepSeek-V3-Pruning is a random weight of
     # DeepSeek-V3 with 2 hidden layers, thus the golden results seems
     # inaccurate. This will only change if accuracy improves with the
@@ -78,3 +76,26 @@ def test_e2e_deepseekv3_with_torchair(monkeypatch: pytest.MonkeyPatch):
     for i in range(len(vllm_output)):
         assert golden_results[i] == vllm_output[i][1]
         print(f"Generated text: {vllm_output[i][1]!r}")
+
+
+@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
+                    reason="torchair graph is not supported on v0")
+def test_e2e_deepseekv3_with_torchair():
+    additional_config = {
+        "torchair_graph_config": {
+            "enabled": True,
+        },
+    }
+    _deepseek_torchair_test_fixture(additional_config)
+
+
+@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
+                    reason="torchair graph is not supported on v0")
+def test_e2e_deepseekv3_with_torchair_ms_mla():
+    additional_config = {
+        "torchair_graph_config": {
+            "enabled": True,
+            "enable_multistream_mla": True,
+        },
+    }
+    _deepseek_torchair_test_fixture(additional_config)


@@ -563,8 +563,6 @@ class AscendMLAImpl(MLAAttentionImpl):
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
-        self.enable_multistream_mla = \
-            ascend_config.torchair_graph_config.enable_multistream_mla
 
         # Adapt torch air graph mode with spec decoding.
         speculative_config = get_current_vllm_config().speculative_config
@@ -863,6 +861,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         sin: torch.Tensor,
         kv_cache: Tuple,
         slots: torch.Tensor,
+        enable_multistream_mla: bool = False,
     ):
         B = hidden_states.shape[0]
@@ -874,7 +873,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
         with npu_stream_switch("mla_secondary",
                                0,
-                               enabled=self.enable_multistream_mla):
+                               enabled=enable_multistream_mla):
             k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
                 kv,
                 self.kv_a_layernorm.weight,
@@ -1034,6 +1033,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         kv_cache: torch.Tensor,
         attn_metadata: M,
         output: Optional[torch.Tensor] = None,
+        enable_multistream_mla: bool = False,
     ) -> torch.Tensor:
         assert output is not None, "Output tensor must be provided."
         if attn_metadata is None:
@@ -1093,22 +1093,22 @@ class AscendMLAImpl(MLAAttentionImpl):
                 # KvRmsNormRopeCache and SingleRope.
                 npu_wait_tensor(decode_hs_or_q_c,
                                 cos,
-                                enabled=self.enable_multistream_mla)
+                                enabled=enable_multistream_mla)
                 npu_wait_tensor(decode_hs_or_q_c,
                                 sin,
-                                enabled=self.enable_multistream_mla)
+                                enabled=enable_multistream_mla)
             decode_ql_nope, decode_q_pe = \
                 self._q_proj_and_k_up_proj(decode_hs_or_q_c)
             if self.running_in_graph:
                 decode_k_pe, decode_k_nope = self.exec_kv(
                     hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                    attn_metadata.slot_mapping)
+                    attn_metadata.slot_mapping, enable_multistream_mla)
                 with npu_stream_switch("mla_secondary",
                                        0,
-                                       enabled=self.enable_multistream_mla):
+                                       enabled=enable_multistream_mla):
                     npu_wait_tensor(decode_q_pe,
                                     decode_k_pe,
-                                    enabled=self.enable_multistream_mla)
+                                    enabled=enable_multistream_mla)
                     decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(


@@ -555,20 +555,21 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
             hidden_states: torch.Tensor,
             kv_cache: Optional[torch.Tensor] = None,
             attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
+        enable_multistream_mla = (self.enable_multistream_mla
+                                  and attn_metadata is not None
+                                  and not attn_metadata.with_prefill_across_dp
+                                  and attn_metadata.num_decodes > 0)
+        forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
         if self.q_lora_rank is not None:
             ckq = self.q_a_proj(hidden_states)[0]
-            use_multistream_mla = (self.enable_multistream_mla
-                                   and attn_metadata is not None
-                                   and attn_metadata.num_decodes > 0)
-            npu_wait_tensor(hidden_states, ckq, enabled=use_multistream_mla)
+            npu_wait_tensor(hidden_states, ckq, enabled=enable_multistream_mla)
             with npu_stream_switch("mla_secondary",
                                    0,
-                                   enabled=use_multistream_mla):
+                                   enabled=enable_multistream_mla):
                 hidden_states_or_q_c = self.q_a_layernorm(ckq)
         else:
             hidden_states_or_q_c = hidden_states
         if self.torchair_graph_enabled:
-            forward_kwargs = {}
             if envs.VLLM_USE_V1:
                 output_shape = hidden_states.shape
                 output = torch.empty(output_shape,