diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md index cdb0908..c67f340 100644 --- a/docs/source/user_guide/configuration/additional_config.md +++ b/docs/source/user_guide/configuration/additional_config.md @@ -43,6 +43,7 @@ The details of each config option are as follows: | Name | Type | Default | Description | | ---- | ---- | ------- | ----------- | | `enabled` | bool | `False` | Whether to enable torchair graph mode. Currently only DeepSeek series models and PanguProMoE are supported to use torchair graph mode | +| `mode` | str | `''` | The torchair compilation mode. When using reduce-overhead mode for torchair, `mode` needs to be set | | `enable_multistream_mla`| bool | `False` | Whether to put vector ops of MLA to another stream. This option only takes effects on models using MLA (e.g., DeepSeek). | | `enable_multistream_moe`| bool | `False` | Whether to enable multistream shared expert. This option only takes effects on DeepSeek moe models. 
| | `enable_view_optimize` | bool | `True` | Whether to enable torchair view optimization | diff --git a/tests/ut/test_ascend_config.py b/tests/ut/test_ascend_config.py index 622b751..ad252b4 100644 --- a/tests/ut/test_ascend_config.py +++ b/tests/ut/test_ascend_config.py @@ -46,6 +46,7 @@ class TestAscendConfig(TestBase): torchair_graph_config = ascend_config.torchair_graph_config self.assertFalse(torchair_graph_config.enabled) + self.assertEqual(torchair_graph_config.mode, '') self.assertFalse(torchair_graph_config.use_cached_graph) self.assertEqual(torchair_graph_config.graph_batch_sizes, []) self.assertFalse(torchair_graph_config.graph_batch_sizes_init) @@ -294,6 +295,17 @@ class TestAscendConfig(TestBase): } init_ascend_config(test_vllm_config) + # mode should not be configured without torchair graph mode + with self.assertRaises(RuntimeError): + test_vllm_config.additional_config = { + "torchair_graph_config": { + "enabled": False, + "mode": 'max-autotune', + }, + "refresh": True + } + init_ascend_config(test_vllm_config) + # enable_kv_nz should not be enabled without torchair graph mode with self.assertRaises(RuntimeError): test_vllm_config.additional_config = { diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 2a2ac7b..597ff5b 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -70,6 +70,7 @@ class TorchairGraphConfig: def __init__(self, torchair_graph_config): self.enabled = torchair_graph_config.get("enabled", False) + self.mode = torchair_graph_config.get("mode", '') self.use_cached_graph = torchair_graph_config.get( "use_cached_graph", False) self.graph_batch_sizes = torchair_graph_config.get( @@ -91,6 +92,9 @@ class TorchairGraphConfig: "graph_batch_sizes_init is only valid when graph_batch_sizes is empty" ) if not self.enabled: + if self.mode: + raise RuntimeError( + "mode is valid only when Torchair graph mode is enabled") if self.use_cached_graph: raise RuntimeError( "use_cached_graph is valid 
only when Torchair graph mode is enabled" diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index 2a0dc15..7d2f605 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -324,6 +324,8 @@ class NPUTorchairModelRunner(NPUModelRunner): communication_adaptation_310p() config = torchair.CompilerConfig() + if get_ascend_config().torchair_graph_config.mode: + config.mode = get_ascend_config().torchair_graph_config.mode config.experimental_config.frozen_parameter = True # enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to # disable it on 300I Duo platform now.