diff --git a/tests/e2e/multicard/test_torchair_graph_mode.py b/tests/e2e/multicard/test_torchair_graph_mode.py index d06ec7d..3aca92c 100644 --- a/tests/e2e/multicard/test_torchair_graph_mode.py +++ b/tests/e2e/multicard/test_torchair_graph_mode.py @@ -20,6 +20,7 @@ Run `pytest tests/multicard/test_torchair_graph_mode.py`. """ import os +from typing import Dict import pytest @@ -28,53 +29,73 @@ from tests.conftest import VllmRunner os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256" +def _deepseek_torchair_test_fixture( + additional_config: Dict, + *, + tensor_parallel_size=4, +): + example_prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + # torchair only works without chunked-prefill now + kwargs = { + "ascend_scheduler_config": { + "enabled": True, + }, + "refresh": True, + } + additional_config.update(**kwargs) + + with VllmRunner( + "vllm-ascend/DeepSeek-V3-Pruning", + dtype="half", + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend="mp", + enforce_eager=False, + additional_config=additional_config, + ) as vllm_model: + # use greedy sampler to make sure the generated results are fixed + vllm_output = vllm_model.generate_greedy(example_prompts, 5) + + # NOTE: vllm-ascend/DeepSeek-V3-Pruning is a random weight of + # DeepSeek-V3 with 2 hidden layers, thus the golden results seem + # inaccurate. This will only change if accuracy improves with the + # official weights of DeepSeek-V3. 
+ golden_results = [ + 'Hello, my name is feasibility伸 spazio debtor添', + 'The president of the United States is begg"""\n杭州风和 bestimm', + 'The capital of France is frequentlyশามalinkAllowed', + 'The future of AI is deleting俯احت怎么样了حراف', + ] + + assert len(golden_results) == len(vllm_output) + for i in range(len(vllm_output)): + assert golden_results[i] == vllm_output[i][1] + print(f"Generated text: {vllm_output[i][1]!r}") + + @pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0", reason="torchair graph is not supported on v0") -def test_e2e_deepseekv3_with_torchair(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as m: - m.setenv("VLLM_USE_MODELSCOPE", "True") - m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") +def test_e2e_deepseekv3_with_torchair(): + additional_config = { + "torchair_graph_config": { + "enabled": True, + }, + } + _deepseek_torchair_test_fixture(additional_config) - example_prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - dtype = "half" - max_tokens = 5 - # torchair is only work without chunked-prefill now - with VllmRunner( - "vllm-ascend/DeepSeek-V3-Pruning", - dtype=dtype, - tensor_parallel_size=4, - distributed_executor_backend="mp", - additional_config={ - "torchair_graph_config": { - "enabled": True, - }, - "ascend_scheduler_config": { - "enabled": True, - }, - "refresh": True, - }, - enforce_eager=False, - ) as vllm_model: - # use greedy sampler to make sure the generated results are fix - vllm_output = vllm_model.generate_greedy(example_prompts, - max_tokens) - # NOTE: vllm-ascend/DeepSeek-V3-Pruning is a random weight of - # DeepSeek-V3 with 2 hidden layers, thus the golden results seems - # inaccurate. This will only change if accuracy improves with the - # official weights of DeepSeek-V3. 
- golden_results = [ - 'Hello, my name is feasibility伸 spazio debtor添', - 'The president of the United States is begg"""\n杭州风和 bestimm', - 'The capital of France is frequentlyশามalinkAllowed', - 'The future of AI is deleting俯احت怎么样了حراف', - ] - assert len(golden_results) == len(vllm_output) - for i in range(len(vllm_output)): - assert golden_results[i] == vllm_output[i][1] - print(f"Generated text: {vllm_output[i][1]!r}") +@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0", + reason="torchair graph is not supported on v0") +def test_e2e_deepseekv3_with_torchair_ms_mla(): + additional_config = { + "torchair_graph_config": { + "enabled": True, + "enable_multistream_mla": True, + }, + } + _deepseek_torchair_test_fixture(additional_config) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index dbcf6ef..d9471b2 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -563,8 +563,6 @@ class AscendMLAImpl(MLAAttentionImpl): ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz - self.enable_multistream_mla = \ - ascend_config.torchair_graph_config.enable_multistream_mla # Adapt torch air graph mode with spec decoding. 
speculative_config = get_current_vllm_config().speculative_config @@ -863,6 +861,7 @@ class AscendMLAImpl(MLAAttentionImpl): sin: torch.Tensor, kv_cache: Tuple, slots: torch.Tensor, + enable_multistream_mla: bool = False, ): B = hidden_states.shape[0] @@ -874,7 +873,7 @@ class AscendMLAImpl(MLAAttentionImpl): cache_mode = "PA_NZ" if self.enable_kv_nz else "PA" with npu_stream_switch("mla_secondary", 0, - enabled=self.enable_multistream_mla): + enabled=enable_multistream_mla): k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( kv, self.kv_a_layernorm.weight, @@ -1034,6 +1033,7 @@ class AscendMLAImpl(MLAAttentionImpl): kv_cache: torch.Tensor, attn_metadata: M, output: Optional[torch.Tensor] = None, + enable_multistream_mla: bool = False, ) -> torch.Tensor: assert output is not None, "Output tensor must be provided." if attn_metadata is None: @@ -1093,22 +1093,22 @@ class AscendMLAImpl(MLAAttentionImpl): # KvRmsNormRopeCache and SingleRope. npu_wait_tensor(decode_hs_or_q_c, cos, - enabled=self.enable_multistream_mla) + enabled=enable_multistream_mla) npu_wait_tensor(decode_hs_or_q_c, sin, - enabled=self.enable_multistream_mla) + enabled=enable_multistream_mla) decode_ql_nope, decode_q_pe = \ self._q_proj_and_k_up_proj(decode_hs_or_q_c) if self.running_in_graph: decode_k_pe, decode_k_nope = self.exec_kv( hidden_states_or_kv_c_normed, cos, sin, kv_cache, - attn_metadata.slot_mapping) + attn_metadata.slot_mapping, enable_multistream_mla) with npu_stream_switch("mla_secondary", 0, - enabled=self.enable_multistream_mla): + enabled=enable_multistream_mla): npu_wait_tensor(decode_q_pe, decode_k_pe, - enabled=self.enable_multistream_mla) + enabled=enable_multistream_mla) decode_q_pe = self.rope_single(decode_q_pe, cos, sin) else: decode_q_pe[...], decode_k_pe[...] 
= self.rotary_emb( diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 908c60f..807c0a2 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -555,20 +555,21 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): hidden_states: torch.Tensor, kv_cache: Optional[torch.Tensor] = None, attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + enable_multistream_mla = (self.enable_multistream_mla + and attn_metadata is not None + and not attn_metadata.with_prefill_across_dp + and attn_metadata.num_decodes > 0) + forward_kwargs = {"enable_multistream_mla": enable_multistream_mla} if self.q_lora_rank is not None: ckq = self.q_a_proj(hidden_states)[0] - use_multistream_mla = (self.enable_multistream_mla - and attn_metadata is not None - and attn_metadata.num_decodes > 0) - npu_wait_tensor(hidden_states, ckq, enabled=use_multistream_mla) + npu_wait_tensor(hidden_states, ckq, enabled=enable_multistream_mla) with npu_stream_switch("mla_secondary", 0, - enabled=use_multistream_mla): + enabled=enable_multistream_mla): hidden_states_or_q_c = self.q_a_layernorm(ckq) else: hidden_states_or_q_c = hidden_states if self.torchair_graph_enabled: - forward_kwargs = {} if envs.VLLM_USE_V1: output_shape = hidden_states.shape output = torch.empty(output_shape,