diff --git a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_torchair_correctness.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_torchair_correctness.py new file mode 100644 index 0000000..16825f0 --- /dev/null +++ b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_torchair_correctness.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import os + +import pytest +from vllm import SamplingParams + +from tests.e2e.conftest import VllmRunner +from vllm_ascend.ascend_config import clear_ascend_config + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + + +@pytest.fixture +def sampling_config(): + return SamplingParams(temperature=0, max_tokens=256, ignore_eos=False) + + +@pytest.fixture +def model_name(): + return "wemaster/deepseek_mtp_main_random_bf16" + + +def test_mtp_torchair_correctness( + sampling_config: SamplingParams, + model_name: str, +): + example_prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + ''' + Compare the outputs of a original LLM and a speculative LLM + should be the same when using mtp speculative decoding. + ''' + clear_ascend_config() + with VllmRunner(model_name, + tensor_parallel_size=1, + gpu_memory_utilization=0.7, + max_model_len=256, + enforce_eager=False, + additional_config={ + "torchair_graph_config": { + "enabled": True, + "use_cached_graph": False, + "graph_batch_sizes": [1, 2, 4], + }, + }) as ref_llm: + ref_outputs = ref_llm.generate(example_prompts, sampling_config) + clear_ascend_config() + with VllmRunner(model_name, + tensor_parallel_size=1, + max_num_seqs=256, + gpu_memory_utilization=0.7, + distributed_executor_backend="mp", + enable_expert_parallel=True, + speculative_config={ + "method": "deepseek_mtp", + "num_speculative_tokens": 1, + }, + enforce_eager=False, + max_model_len=2000, + additional_config={ + "torchair_graph_config": { + "enabled": True, + "use_cached_graph": False, + "graph_batch_sizes": [1, 2, 4], + } + }) as spec_llm: + spec_outputs = spec_llm.generate(example_prompts, sampling_config) + + matches = 0 + misses = 0 + for ref_output, spec_output in zip(ref_outputs, spec_outputs): + ref_token_ids = ref_output[0][0] + spec_token_ids = spec_output[0][0] + if ref_token_ids == spec_token_ids[:len(ref_token_ids)]: + matches += 1 + else: + misses += 1 + print(f"ref_output: {ref_output[1][0]}") + print(f"spec_output: {spec_output[1][0]}") + + # Heuristic: expect at least 66% of the prompts to match exactly + # Upon failure, inspect the outputs to check for inaccuracy. + assert matches > int(0.66 * len(ref_outputs)) + del spec_llm + clear_ascend_config() diff --git a/tests/ut/torchair/ops/test_torchair_fused_moe.py b/tests/ut/torchair/ops/test_torchair_fused_moe.py index 7a6fc10..19df5dc 100644 --- a/tests/ut/torchair/ops/test_torchair_fused_moe.py +++ b/tests/ut/torchair/ops/test_torchair_fused_moe.py @@ -23,6 +23,8 @@ from pytest_mock import MockerFixture from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase from vllm_ascend.ascend_forward_context import _get_fused_moe_state +from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod +from vllm_ascend.quantization.quantizer import W8A8Quantizer from vllm_ascend.torchair.ops.torchair_fused_moe import ( TorchairAscendFusedMoE, TorchairAscendUnquantizedFusedMoEMethod) from vllm_ascend.utils import AscendSocVersion, adapt_patch # noqa E402 @@ -233,12 +235,28 @@ class TestTorchairAscendFusedMoe: mock_quant_config = MagicMock() mock_quant_method = MockFusedMoEMethod() mock_quant_config.get_quant_method.return_value = mock_quant_method + mock_quant_config.is_layer_skipped_ascend.return_value = False + with patch( + 'vllm_ascend.quantization.quantizer.AscendQuantizer.get_quantizer', + return_value=W8A8Quantizer): + moe = TorchairAscendFusedMoE(**default_moe_config, + quant_config=mock_quant_config) + + assert moe.quant_method is not None + assert isinstance(moe.quant_method, AscendFusedMoEMethod) + + def test_init_with_mixed_quant(self, mock_dist_env, default_moe_config): + mock_quant_config = MagicMock() + mock_quant_method = MockFusedMoEMethod() + mock_quant_config.get_quant_method.return_value = mock_quant_method + mock_quant_config.is_layer_skipped_ascend.return_value = True moe = TorchairAscendFusedMoE(**default_moe_config, quant_config=mock_quant_config) assert moe.quant_method is not None - assert moe.quant_method == mock_quant_method + assert isinstance(moe.quant_method, + TorchairAscendUnquantizedFusedMoEMethod) @pytest.mark.parametrize( "others_param", diff --git a/vllm_ascend/torchair/ops/torchair_fused_moe.py b/vllm_ascend/torchair/ops/torchair_fused_moe.py index 42e8659..bd2be21 100644 --- a/vllm_ascend/torchair/ops/torchair_fused_moe.py +++ b/vllm_ascend/torchair/ops/torchair_fused_moe.py @@ -45,6 +45,7 @@ from vllm_ascend.distributed.communication_op import \ from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer from vllm_ascend.ops.sequence_parallel import MetadataForPadding +from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor from vllm_ascend.utils import (AscendSocVersion, dispose_tensor, get_all_reduce_merge_state, @@ -1055,7 +1056,13 @@ class TorchairAscendFusedMoE(FusedMoE): self.quant_method = TorchairAscendUnquantizedFusedMoEMethod( self.moe) else: - self.quant_method = quant_config.get_quant_method(self, prefix) + if quant_config.is_layer_skipped_ascend( + prefix, quant_config.packed_modules_mapping): + self.quant_method = TorchairAscendUnquantizedFusedMoEMethod( + self.moe) + else: + self.quant_method = AscendFusedMoEMethod( + quant_config, prefix, quant_config.packed_modules_mapping) assert self.quant_method is not None diff --git a/vllm_ascend/worker/mtp_proposer_v1.py b/vllm_ascend/worker/mtp_proposer_v1.py index 848da93..e5d555b 100644 --- a/vllm_ascend/worker/mtp_proposer_v1.py +++ b/vllm_ascend/worker/mtp_proposer_v1.py @@ -18,6 +18,8 @@ from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP +from vllm_ascend.torchair.models.torchair_deepseek_mtp import \ + TorchairDeepSeekMTP from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable @@ -266,8 +268,12 @@ class MtpProposer: with set_default_torch_dtype( draft_model_config.dtype), set_current_vllm_config( self.vllm_config): - self.model = CustomDeepSeekMTP( - vllm_config=self.vllm_config).to(target_device) + if self.torchair_graph_enabled: + self.model = TorchairDeepSeekMTP( + vllm_config=self.vllm_config).to(target_device) + else: + self.model = CustomDeepSeekMTP( + vllm_config=self.vllm_config).to(target_device) draft_attn_layer_names = ( get_layers_from_vllm_config(self.vllm_config, Attention).keys() -