diff --git a/examples/offline_data_parallel.py b/examples/offline_data_parallel.py index e497a13..c5d0b3e 100644 --- a/examples/offline_data_parallel.py +++ b/examples/offline_data_parallel.py @@ -54,17 +54,16 @@ Multi-node: --master-port=13345 """ -import os -from time import sleep import contextlib import gc +import os +from time import sleep import torch - from vllm import LLM, SamplingParams -from vllm.utils import get_open_port from vllm.distributed.parallel_state import ( # noqa E402 destroy_distributed_environment, destroy_model_parallel) +from vllm.utils import get_open_port os.environ["VLLM_USE_MODELSCOPE"] = "True" os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" diff --git a/tests/ut/test_ascend_config.py b/tests/ut/test_ascend_config.py index 34a5cca..ec00c0d 100644 --- a/tests/ut/test_ascend_config.py +++ b/tests/ut/test_ascend_config.py @@ -236,3 +236,71 @@ class TestAscendConfig(TestBase): for model_type, expected_output in test_cases: self.assertEqual(_check_torchair_supported(model_type), expected_output) + + @_clean_up_ascend_config + def test_ascend_config_load_error(self): + test_vllm_config = VllmConfig() + # graph_batch_sizes should be a list.
+ with self.assertRaises(TypeError): + test_vllm_config.additional_config = { + "torchair_graph_config": { + "graph_batch_sizes": "fake_size", + }, + "refresh": True + } + init_ascend_config(test_vllm_config) + + # use_cached_graph should not be enabled without torchair graph mode + with self.assertRaises(RuntimeError): + test_vllm_config.additional_config = { + "torchair_graph_config": { + "enabled": False, + "use_cached_graph": True, + }, + "refresh": True + } + init_ascend_config(test_vllm_config) + + # graph_batch_sizes_init should not be enabled without torchair graph mode + with self.assertRaises(RuntimeError): + test_vllm_config.additional_config = { + "torchair_graph_config": { + "enabled": False, + "graph_batch_sizes_init": True, + }, + "refresh": True + } + init_ascend_config(test_vllm_config) + + # enable_multistream_mla should not be enabled without torchair graph mode + with self.assertRaises(RuntimeError): + test_vllm_config.additional_config = { + "torchair_graph_config": { + "enabled": False, + "enable_multistream_mla": True, + }, + "refresh": True + } + init_ascend_config(test_vllm_config) + + # enable_multistream_moe should not be enabled without torchair graph mode + with self.assertRaises(RuntimeError): + test_vllm_config.additional_config = { + "torchair_graph_config": { + "enabled": False, + "enable_multistream_moe": True, + }, + "refresh": True + } + init_ascend_config(test_vllm_config) + + # enable_kv_nz should not be enabled without torchair graph mode + with self.assertRaises(RuntimeError): + test_vllm_config.additional_config = { + "torchair_graph_config": { + "enabled": False, + "enable_kv_nz": True, + }, + "refresh": True + } + init_ascend_config(test_vllm_config) diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 4bc6e88..b8fd24e 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -76,6 +76,31 @@ class TorchairGraphConfig: raise ValueError( "graph_batch_sizes_init is only valid 
when graph_batch_sizes is empty" ) + if not self.enabled: + if self.use_cached_graph: + raise RuntimeError( + "use_cached_graph is valid only when Torchair graph mode is enabled" + ) + if self.graph_batch_sizes: + raise RuntimeError( + "graph_batch_sizes is valid only when Torchair graph mode is enabled" + ) + if self.graph_batch_sizes_init: + raise RuntimeError( + "graph_batch_sizes_init is valid only when Torchair graph mode is enabled" + ) + if self.enable_multistream_mla: + raise RuntimeError( + "enable_multistream_mla is valid only when Torchair graph mode is enabled" + ) + if self.enable_multistream_moe: + raise RuntimeError( + "enable_multistream_moe is valid only when Torchair graph mode is enabled" + ) + if self.enable_kv_nz: + raise RuntimeError( + "enable_kv_nz is valid only when Torchair graph mode is enabled" + ) class AscendSchedulerConfig: diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 8886972..e1c2b1c 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -313,7 +313,8 @@ class CustomDeepseekV2MoE(nn.Module): ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled self.enable_multistream_moe = \ - ascend_config.torchair_graph_config.enable_multistream_moe + ascend_config.torchair_graph_config.enable_multistream_moe and \ + self.torchair_graph_enabled self.gate = ReplicatedLinear(config.hidden_size, config.n_routed_experts, diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 712d3a1..80d2140 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -1232,7 +1232,8 @@ class AscendFusedMoE(FusedMoE): self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled self.enable_multistream_moe = \ - ascend_config.torchair_graph_config.enable_multistream_moe + ascend_config.torchair_graph_config.enable_multistream_moe and \ + self.torchair_graph_enabled if 
self.scoring_func != "softmax" and not self.use_grouped_topk: raise ValueError("Only softmax scoring function is supported for "