From b69b04d3a91cdc312031c466556d42001df841f3 Mon Sep 17 00:00:00 2001
From: Ronald
Date: Thu, 18 Dec 2025 15:51:54 +0800
Subject: [PATCH] implement model runner v2 basic framework (#5051)

### What this PR does / why we need it?
This PR aims to implement the basic framework of model runner v2 in vllm-ascend; end-to-end functionality is not guaranteed by this PR.

### Does this PR introduce _any_ user-facing change?
Yes. Use `envs.VLLM_USE_V2_MODEL_RUNNER` to decide whether model_runner_v2 is used.

### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9

---------

Signed-off-by: Ronald1995
---
 .../developer_guide/contribution/testing.md  |   2 +-
 .../developer_guide/feature_guide/patch.md   |   2 +-
 .../developer_guide/contribution/testing.po  |   8 +-
 .../developer_guide/feature_guide/patch.po   |   4 +-
 tests/ut/test_platform.py                    |   2 +-
 tests/ut/worker/test_worker_v1.py            | 170 ++++-----
 vllm_ascend/platform.py                      |   2 +-
 vllm_ascend/worker/v2/__init__.py            |   0
 vllm_ascend/worker/v2/aclgraph_utils.py      |  71 ++++
 vllm_ascend/worker/v2/attn_utils.py          | 160 ++++
 vllm_ascend/worker/v2/input_batch.py         |  37 ++
 vllm_ascend/worker/v2/model_runner.py        | 346 ++++++++++++++++++
 vllm_ascend/worker/v2/states.py              |  88 +++++
 vllm_ascend/worker/v2/utils.py               |  33 ++
 .../worker/{worker_v1.py => worker.py}       |  12 +-
 vllm_ascend/xlite/xlite_worker.py            |   4 +-
 16 files changed, 843 insertions(+), 98 deletions(-)
 create mode 100644 vllm_ascend/worker/v2/__init__.py
 create mode 100644 vllm_ascend/worker/v2/aclgraph_utils.py
 create mode 100644 vllm_ascend/worker/v2/attn_utils.py
 create mode 100644 vllm_ascend/worker/v2/input_batch.py
 create mode 100644 vllm_ascend/worker/v2/model_runner.py
 create mode 100644 vllm_ascend/worker/v2/states.py
 create mode 100644 vllm_ascend/worker/v2/utils.py
 rename vllm_ascend/worker/{worker_v1.py => worker.py} (97%)

diff --git a/docs/source/developer_guide/contribution/testing.md b/docs/source/developer_guide/contribution/testing.md
index b4dea166..0c30c67c 100644
--- a/docs/source/developer_guide/contribution/testing.md
+++ b/docs/source/developer_guide/contribution/testing.md
@@ -142,7 +142,7 @@ pip install -r requirements-dev.txt
 
 There are several principles to follow when writing unit tests:
 
-- The test file path should be consistent with the source file and start with the `test_` prefix, such as: `vllm_ascend/worker/worker_v1.py` --> `tests/ut/worker/test_worker_v1.py`
+- The test file path should be consistent with the source file and start with the `test_` prefix, such as: `vllm_ascend/worker/worker.py` --> `tests/ut/worker/test_worker.py`
 - The vLLM Ascend test uses unittest framework. See [here](https://docs.python.org/3/library/unittest.html#module-unittest) to understand how to write unit tests.
 - All unit tests can be run on CPUs, so you must mock the device-related function to host.
 - Example: [tests/ut/test_ascend_config.py](https://github.com/vllm-project/vllm-ascend/blob/main/tests/ut/test_ascend_config.py).
diff --git a/docs/source/developer_guide/feature_guide/patch.md b/docs/source/developer_guide/feature_guide/patch.md
index cd360017..56d5f0ee 100644
--- a/docs/source/developer_guide/feature_guide/patch.md
+++ b/docs/source/developer_guide/feature_guide/patch.md
@@ -29,7 +29,7 @@ vllm_ascend
 - **platform**: The patch code in this directory is for patching the code in vLLM main process. It's called by `vllm_ascend/platform::NPUPlatform::pre_register_and_update` very early when vLLM is initialized.
- For online mode, vLLM process calls the platform patch in `vllm/vllm/engine/arg_utils.py::AsyncEngineArgs.add_cli_args` when parsing the cli args. - For offline mode, vLLM process calls the platform patch in `vllm/vllm/engine/arg_utils.py::EngineArgs.create_engine_config` when parsing the input parameters. -- **worker**: The patch code in this directory is for patching the code in vLLM worker process. It's called by `vllm_ascend/worker/worker_v1::NPUWorker::__init__` when the vLLM worker process is initialized. +- **worker**: The patch code in this directory is for patching the code in vLLM worker process. It's called by `vllm_ascend/worker/worker::NPUWorker::__init__` when the vLLM worker process is initialized. - For both online and offline mode, vLLM engine core process calls the worker patch in `vllm/vllm/worker/worker_base.py::WorkerWrapperBase.init_worker` when initializing the worker process. ## How to write a patch diff --git a/docs/source/locale/zh_CN/LC_MESSAGES/developer_guide/contribution/testing.po b/docs/source/locale/zh_CN/LC_MESSAGES/developer_guide/contribution/testing.po index 7f581029..be76daa7 100644 --- a/docs/source/locale/zh_CN/LC_MESSAGES/developer_guide/contribution/testing.po +++ b/docs/source/locale/zh_CN/LC_MESSAGES/developer_guide/contribution/testing.po @@ -77,11 +77,11 @@ msgstr "编写单元测试时需要遵循几个原则:" #: ../../developer_guide/contribution/testing.md:143 msgid "" "The test file path should be consistent with source file and start with " -"`test_` prefix, such as: `vllm_ascend/worker/worker_v1.py` --> " -"`tests/ut/worker/test_worker_v1.py`" +"`test_` prefix, such as: `vllm_ascend/worker/worker.py` --> " +"`tests/ut/worker/test_worker.py`" msgstr "" -"测试文件的路径应与源文件保持一致,并以 `test_` 前缀开头,例如:`vllm_ascend/worker/worker_v1.py` --> " -"`tests/ut/worker/test_worker_v1.py`" +"测试文件的路径应与源文件保持一致,并以 `test_` 前缀开头,例如:`vllm_ascend/worker/worker.py` --> " +"`tests/ut/worker/test_worker.py`" #: ../../developer_guide/contribution/testing.md:144 msgid "" diff --git a/docs/source/locale/zh_CN/LC_MESSAGES/developer_guide/feature_guide/patch.po b/docs/source/locale/zh_CN/LC_MESSAGES/developer_guide/feature_guide/patch.po index 1e7daa4c..f7feb2ee 100644 --- a/docs/source/locale/zh_CN/LC_MESSAGES/developer_guide/feature_guide/patch.po +++ b/docs/source/locale/zh_CN/LC_MESSAGES/developer_guide/feature_guide/patch.po @@ -107,11 +107,11 @@ msgstr "" msgid "" "**worker**: The patch code in this directory is for patching the code in " "vLLM worker process. It's called by " -"`vllm_ascend/worker/worker_v1::NPUWorker::__init__` when the vLLM worker " +"`vllm_ascend/worker/worker::NPUWorker::__init__` when the vLLM worker " "process is initialized." 
msgstr "" "**worker**:此目录中的补丁代码用于修补 vLLM worker 进程中的代码。在初始化 vLLM worker 进程时,会被 " -"`vllm_ascend/worker/worker_v1::NPUWorker::__init__` 调用。" +"`vllm_ascend/worker/worker::NPUWorker::__init__` 调用。" #: ../../developer_guide/feature_guide/patch.md:37 msgid "" diff --git a/tests/ut/test_platform.py b/tests/ut/test_platform.py index a5257ce9..1ad608b9 100644 --- a/tests/ut/test_platform.py +++ b/tests/ut/test_platform.py @@ -445,7 +445,7 @@ class TestNPUPlatform(TestBase): self.assertEqual( vllm_config.parallel_config.worker_cls, - "vllm_ascend.worker.worker_v1.NPUWorker", + "vllm_ascend.worker.worker.NPUWorker", ) test_ascend_config = TestNPUPlatform.mock_vllm_ascend_config() diff --git a/tests/ut/worker/test_worker_v1.py b/tests/ut/worker/test_worker_v1.py index 5aa37049..765a3aa0 100644 --- a/tests/ut/worker/test_worker_v1.py +++ b/tests/ut/worker/test_worker_v1.py @@ -47,13 +47,13 @@ class TestNPUWorker(TestBase): @patch("vllm_ascend.utils.adapt_patch") @patch("vllm_ascend.ops") - @patch("vllm_ascend.worker.worker_v1._register_atb_extensions") - @patch("vllm_ascend.worker.worker_v1.register_ascend_customop") - @patch("vllm_ascend.worker.worker_v1.get_ascend_config") - @patch("vllm_ascend.worker.worker_v1.init_ascend_config") - @patch("vllm_ascend.worker.worker_v1.check_ascend_device_type") + @patch("vllm_ascend.worker.worker._register_atb_extensions") + @patch("vllm_ascend.worker.worker.register_ascend_customop") + @patch("vllm_ascend.worker.worker.get_ascend_config") + @patch("vllm_ascend.worker.worker.init_ascend_config") + @patch("vllm_ascend.worker.worker.check_ascend_device_type") @patch(init_cached_hf_modules_path) - @patch("vllm_ascend.worker.worker_v1.NPUWorker._init_profiler") + @patch("vllm_ascend.worker.worker.NPUWorker._init_profiler") def test_init_npu_worker_normal_case( self, mock_init_profiler, @@ -74,7 +74,7 @@ class TestNPUWorker(TestBase): mock_get_ascend_config.return_value = mock_ascend_config # Import and create NPUWorker instance - from vllm_ascend.worker.worker_v1 import NPUWorker + from vllm_ascend.worker.worker import NPUWorker worker = NPUWorker( vllm_config=self.vllm_config_mock, @@ -101,13 +101,13 @@ class TestNPUWorker(TestBase): @patch("vllm_ascend.utils.adapt_patch") @patch("vllm_ascend.ops") - @patch("vllm_ascend.worker.worker_v1._register_atb_extensions") - @patch("vllm_ascend.worker.worker_v1.register_ascend_customop") - @patch("vllm_ascend.worker.worker_v1.get_ascend_config") - @patch("vllm_ascend.worker.worker_v1.init_ascend_config") - @patch("vllm_ascend.worker.worker_v1.check_ascend_device_type") + @patch("vllm_ascend.worker.worker._register_atb_extensions") + @patch("vllm_ascend.worker.worker.register_ascend_customop") + @patch("vllm_ascend.worker.worker.get_ascend_config") + @patch("vllm_ascend.worker.worker.init_ascend_config") + @patch("vllm_ascend.worker.worker.check_ascend_device_type") @patch(init_cached_hf_modules_path) - @patch("vllm_ascend.worker.worker_v1.NPUWorker._init_profiler") + @patch("vllm_ascend.worker.worker.NPUWorker._init_profiler") def test_init_npu_worker_with_trust_remote_code( self, mock_init_profiler, @@ -129,7 +129,7 @@ class TestNPUWorker(TestBase): mock_get_ascend_config.return_value = mock_ascend_config # Create NPUWorker instance - from vllm_ascend.worker.worker_v1 import NPUWorker + from vllm_ascend.worker.worker import NPUWorker _ = NPUWorker( vllm_config=self.vllm_config_mock, @@ -144,13 +144,13 @@ class TestNPUWorker(TestBase): @patch("vllm_ascend.utils.adapt_patch") @patch("vllm_ascend.ops") - 
@patch("vllm_ascend.worker.worker_v1._register_atb_extensions") - @patch("vllm_ascend.worker.worker_v1.register_ascend_customop") - @patch("vllm_ascend.worker.worker_v1.get_ascend_config") - @patch("vllm_ascend.worker.worker_v1.init_ascend_config") - @patch("vllm_ascend.worker.worker_v1.check_ascend_device_type") + @patch("vllm_ascend.worker.worker._register_atb_extensions") + @patch("vllm_ascend.worker.worker.register_ascend_customop") + @patch("vllm_ascend.worker.worker.get_ascend_config") + @patch("vllm_ascend.worker.worker.init_ascend_config") + @patch("vllm_ascend.worker.worker.check_ascend_device_type") @patch(init_cached_hf_modules_path) - @patch("vllm_ascend.worker.worker_v1.NPUWorker._init_profiler") + @patch("vllm_ascend.worker.worker.NPUWorker._init_profiler") def test_init_npu_worker_with_custom_cache_dtype( self, mock_init_profiler, @@ -172,7 +172,7 @@ class TestNPUWorker(TestBase): mock_get_ascend_config.return_value = mock_ascend_config # Create NPUWorker instance - from vllm_ascend.worker.worker_v1 import NPUWorker + from vllm_ascend.worker.worker import NPUWorker with patch("vllm.utils.torch_utils.STR_DTYPE_TO_TORCH_DTYPE", {"float32": torch.float32}): @@ -189,7 +189,7 @@ class TestNPUWorker(TestBase): def test_initialize_cache(self): """Test initialize_cache method""" - from vllm_ascend.worker.worker_v1 import NPUWorker + from vllm_ascend.worker.worker import NPUWorker # Create a simple worker mock with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None): @@ -203,11 +203,11 @@ class TestNPUWorker(TestBase): self.assertEqual(worker.cache_config.num_gpu_blocks, 100) self.assertEqual(worker.cache_config.num_cpu_blocks, 50) - @patch("vllm_ascend.worker.worker_v1.CaMemAllocator") + @patch("vllm_ascend.worker.worker.CaMemAllocator") @patch.dict("os.environ", {"VLLM_ASCEND_ENABLE_NZ": "0"}) def test_wake_up_mode_enabled(self, mock_allocator_class): """Test wake_up method when sleep mode is enabled""" - from vllm_ascend.worker.worker_v1 import NPUWorker + from vllm_ascend.worker.worker import NPUWorker # Setup mock mock_allocator = MagicMock() @@ -236,12 +236,12 @@ class TestNPUWorker(TestBase): mock_allocator.wake_up.assert_called_once_with(tags=["test_tag"]) @patch( - "vllm_ascend.worker.worker_v1.NPUWorker._init_worker_distributed_environment" + "vllm_ascend.worker.worker.NPUWorker._init_worker_distributed_environment" ) - @patch("vllm_ascend.worker.worker_v1.NPUPlatform") + @patch("vllm_ascend.worker.worker.NPUPlatform") def test_init_device(self, mock_platform, mock_init_dist_env): """Test _init_device method""" - from vllm_ascend.worker.worker_v1 import NPUWorker + from vllm_ascend.worker.worker import NPUWorker # Setup mock mock_platform.mem_get_info.return_value = (1000, 2000) @@ -279,7 +279,7 @@ class TestNPUWorker(TestBase): def test_profile_start_stop(self): """Test profile method start and stop""" - from vllm_ascend.worker.worker_v1 import NPUWorker + from vllm_ascend.worker.worker import NPUWorker # Create worker mock with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None): @@ -297,7 +297,7 @@ class TestNPUWorker(TestBase): def test_profile_no_profiler_raises_error(self): """Test profile method raises exception when profiler is not available""" - from vllm_ascend.worker.worker_v1 import NPUWorker + from vllm_ascend.worker.worker import NPUWorker # Create worker mock with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None): @@ -310,12 +310,12 @@ class TestNPUWorker(TestBase): self.assertIn("Profiler is not enabled", str(cm.exception)) - 
@patch("vllm_ascend.worker.worker_v1.envs_vllm") - @patch("vllm_ascend.worker.worker_v1.envs_ascend") + @patch("vllm_ascend.worker.worker.envs_vllm") + @patch("vllm_ascend.worker.worker.envs_ascend") def test_profile_and_msmonitor_both_enabled_raises_error( self, mock_envs_vllm, mock_envs_ascend): """Test profile method raises exception when both profiler and msmonitor are enabled""" - from vllm_ascend.worker.worker_v1 import NPUWorker + from vllm_ascend.worker.worker import NPUWorker mock_envs_vllm.VLLM_TORCH_PROFILER_DIR = "/path/to/traces" mock_envs_ascend.MSMONITOR_USE_DAEMON = 1 @@ -334,7 +334,7 @@ class TestNPUWorker(TestBase): def test_lora_methods(self): """Test LoRA related methods""" - from vllm_ascend.worker.worker_v1 import NPUWorker + from vllm_ascend.worker.worker import NPUWorker # Create worker mock with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None): @@ -364,7 +364,7 @@ class TestNPUWorker(TestBase): def test_get_methods(self): """Test various get methods""" - from vllm_ascend.worker.worker_v1 import NPUWorker + from vllm_ascend.worker.worker import NPUWorker # Create worker mock with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None): @@ -394,7 +394,7 @@ class TestNPUWorker(TestBase): def test_execute_dummy_batch(self): """Test execute_dummy_batch method""" - from vllm_ascend.worker.worker_v1 import NPUWorker + from vllm_ascend.worker.worker import NPUWorker # Create worker mock with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None): @@ -412,8 +412,8 @@ class TestNPUWorker(TestBase): mock_model_runner._dummy_run.assert_called_once_with( num_tokens=mock_decode_token_per_req, uniform_decode=True) - @patch("vllm_ascend.worker.worker_v1.envs_vllm") - @patch("vllm_ascend.worker.worker_v1.logger") + @patch("vllm_ascend.worker.worker.envs_vllm") + @patch("vllm_ascend.worker.worker.logger") @patch("torch_npu.profiler._ExperimentalConfig") @patch("torch_npu.profiler.profile") @patch("torch_npu.profiler.tensorboard_trace_handler") @@ -434,7 +434,7 @@ class TestNPUWorker(TestBase): mock_envs_vllm, ): """Test _init_profiler method - profiler enabled case with stack and memory profiling enabled""" - from vllm_ascend.worker.worker_v1 import NPUWorker + from vllm_ascend.worker.worker import NPUWorker # Set environment variables to enable profiler mock_envs_vllm.VLLM_TORCH_PROFILER_DIR = "/path/to/traces" @@ -510,10 +510,10 @@ class TestNPUWorker(TestBase): # Verify return value self.assertEqual(result, mock_profiler_instance) - @patch("vllm_ascend.worker.worker_v1.envs_vllm") + @patch("vllm_ascend.worker.worker.envs_vllm") def test_init_profiler_disabled(self, mock_envs_vllm): """Test _init_profiler method - profiler disabled case""" - from vllm_ascend.worker.worker_v1 import NPUWorker + from vllm_ascend.worker.worker import NPUWorker # Set environment variable to disable profiler mock_envs_vllm.VLLM_TORCH_PROFILER_DIR = None @@ -528,10 +528,10 @@ class TestNPUWorker(TestBase): # Verify returns None self.assertIsNone(result) - @patch("vllm_ascend.worker.worker_v1.envs_vllm") + @patch("vllm_ascend.worker.worker.envs_vllm") def test_init_profiler_empty_dir(self, mock_envs_vllm): """Test _init_profiler method - empty directory string case""" - from vllm_ascend.worker.worker_v1 import NPUWorker + from vllm_ascend.worker.worker import NPUWorker # Set environment variable to empty string mock_envs_vllm.VLLM_TORCH_PROFILER_DIR = "" @@ -546,12 +546,12 @@ class TestNPUWorker(TestBase): # Verify returns None (empty string is considered false) 
self.assertIsNone(result) - @patch("vllm_ascend.worker.worker_v1.NPUPlatform.clear_npu_memory") - @patch("vllm_ascend.worker.worker_v1.NPUPlatform.empty_cache") - @patch("vllm_ascend.worker.worker_v1.NPUPlatform.mem_get_info") + @patch("vllm_ascend.worker.worker.NPUPlatform.clear_npu_memory") + @patch("vllm_ascend.worker.worker.NPUPlatform.empty_cache") + @patch("vllm_ascend.worker.worker.NPUPlatform.mem_get_info") @patch("torch_npu.npu.memory_stats") @patch("torch_npu.npu.mem_get_info") - @patch("vllm_ascend.worker.worker_v1.logger") + @patch("vllm_ascend.worker.worker.logger") def test_determine_available_memory_normal_case( self, mock_logger, @@ -562,7 +562,7 @@ class TestNPUWorker(TestBase): mock_platform_clear_npu_memory, ): """Test determine_available_memory normal case (no non-torch memory allocation)""" - from vllm_ascend.worker.worker_v1 import NPUWorker + from vllm_ascend.worker.worker import NPUWorker # Setup mock - test case without non-torch memory allocation mock_platform_mem_get_info.side_effect = [ @@ -627,9 +627,9 @@ class TestNPUWorker(TestBase): # Verify log output mock_logger.info.assert_called_once() - @patch("vllm_ascend.worker.worker_v1.NPUPlatform.clear_npu_memory") - @patch("vllm_ascend.worker.worker_v1.NPUPlatform.empty_cache") - @patch("vllm_ascend.worker.worker_v1.NPUPlatform.mem_get_info") + @patch("vllm_ascend.worker.worker.NPUPlatform.clear_npu_memory") + @patch("vllm_ascend.worker.worker.NPUPlatform.empty_cache") + @patch("vllm_ascend.worker.worker.NPUPlatform.mem_get_info") @patch("torch_npu.npu.memory_stats") @patch("torch_npu.npu.mem_get_info") def test_determine_available_memory_with_non_torch_allocations( @@ -641,7 +641,7 @@ class TestNPUWorker(TestBase): mock_platform_clear_npu_memory, ): """Test determine_available_memory with significant non-torch memory allocation""" - from vllm_ascend.worker.worker_v1 import NPUWorker + from vllm_ascend.worker.worker import NPUWorker # Setup mock - test case with large non-torch memory allocation mock_platform_mem_get_info.side_effect = [ @@ -693,12 +693,12 @@ class TestNPUWorker(TestBase): expected_result = max(0, int(10000 * 0.9 - 5500)) self.assertEqual(result, expected_result) - @patch("vllm_ascend.worker.worker_v1.NPUPlatform.clear_npu_memory") - @patch("vllm_ascend.worker.worker_v1.NPUPlatform.mem_get_info") + @patch("vllm_ascend.worker.worker.NPUPlatform.clear_npu_memory") + @patch("vllm_ascend.worker.worker.NPUPlatform.mem_get_info") def test_determine_available_memory_memory_profiling_error( self, mock_platform_mem_get_info, mock_platform_clear_npu_memory): """Test determine_available_memory throws exception on memory profiling error""" - from vllm_ascend.worker.worker_v1 import NPUWorker + from vllm_ascend.worker.worker import NPUWorker # Setup mock: initial memory less than current free memory (error case) mock_platform_mem_get_info.side_effect = [ @@ -720,9 +720,9 @@ class TestNPUWorker(TestBase): self.assertIn("Error in memory profiling", str(cm.exception)) - @patch("vllm_ascend.worker.worker_v1.NPUPlatform.clear_npu_memory") - @patch("vllm_ascend.worker.worker_v1.NPUPlatform.empty_cache") - @patch("vllm_ascend.worker.worker_v1.NPUPlatform.mem_get_info") + @patch("vllm_ascend.worker.worker.NPUPlatform.clear_npu_memory") + @patch("vllm_ascend.worker.worker.NPUPlatform.empty_cache") + @patch("vllm_ascend.worker.worker.NPUPlatform.mem_get_info") @patch("torch_npu.npu.memory_stats") @patch("torch_npu.npu.mem_get_info") def test_determine_available_memory_negative_result( @@ -734,7 +734,7 @@ class 
TestNPUWorker(TestBase): mock_platform_clear_npu_memory, ): """Test determine_available_memory returns 0 when result is negative""" - from vllm_ascend.worker.worker_v1 import NPUWorker + from vllm_ascend.worker.worker import NPUWorker # Setup mock: high peak memory causes negative available memory mock_platform_mem_get_info.side_effect = [ @@ -787,12 +787,12 @@ class TestNPUWorker(TestBase): """Test execute_model method - first rank case""" from vllm.v1.outputs import ModelRunnerOutput - from vllm_ascend.worker.worker_v1 import NPUWorker + from vllm_ascend.worker.worker import NPUWorker # Create worker mock with ( patch.object(NPUWorker, "__init__", lambda x, **kwargs: None), - patch("vllm_ascend.worker.worker_v1.get_pp_group") as + patch("vllm_ascend.worker.worker.get_pp_group") as mock_get_pp_group, ): worker = NPUWorker() @@ -822,14 +822,14 @@ class TestNPUWorker(TestBase): mock_scheduler_output, None) self.assertEqual(result, mock_model_output) - @patch("vllm_ascend.worker.worker_v1.get_pp_group") - @patch("vllm_ascend.worker.worker_v1.get_tp_group") + @patch("vllm_ascend.worker.worker.get_pp_group") + @patch("vllm_ascend.worker.worker.get_tp_group") def test_execute_model_middle_rank(self, mock_get_tp_group, mock_get_pp_group): """Test execute_model method - middle rank case""" from vllm.sequence import IntermediateTensors - from vllm_ascend.worker.worker_v1 import NPUWorker + from vllm_ascend.worker.worker import NPUWorker # Create worker mock with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None): @@ -882,12 +882,12 @@ class TestNPUWorker(TestBase): """Test execute_model method - external_launcher mode""" from vllm.v1.outputs import ModelRunnerOutput - from vllm_ascend.worker.worker_v1 import NPUWorker + from vllm_ascend.worker.worker import NPUWorker # Create worker mock with ( patch.object(NPUWorker, "__init__", lambda x, **kwargs: None), - patch("vllm_ascend.worker.worker_v1.get_pp_group") as + patch("vllm_ascend.worker.worker.get_pp_group") as mock_get_pp_group, ): worker = NPUWorker() @@ -915,10 +915,10 @@ class TestNPUWorker(TestBase): # In external_launcher mode, it doesn't enter middle processing logic, returns result directly self.assertEqual(result, mock_model_output) - @patch("vllm_ascend.worker.worker_v1.CaMemAllocator") + @patch("vllm_ascend.worker.worker.CaMemAllocator") def test_load_model_with_sleep_mode(self, mock_allocator_class): """Test load_model method - with sleep mode enabled""" - from vllm_ascend.worker.worker_v1 import NPUWorker + from vllm_ascend.worker.worker import NPUWorker # Create worker mock with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None): @@ -947,7 +947,7 @@ class TestNPUWorker(TestBase): def test_load_model_without_sleep_mode(self): """Test load_model method - without sleep mode enabled""" - from vllm_ascend.worker.worker_v1 import NPUWorker + from vllm_ascend.worker.worker import NPUWorker # Create worker mock with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None): @@ -963,10 +963,10 @@ class TestNPUWorker(TestBase): # Verify calls worker.model_runner.load_model.assert_called_once() - @patch("vllm_ascend.worker.worker_v1.CaMemAllocator") + @patch("vllm_ascend.worker.worker.CaMemAllocator") def test_load_model_sleep_mode_assertion_error(self, mock_allocator_class): """Test load_model method - assertion error in sleep mode""" - from vllm_ascend.worker.worker_v1 import NPUWorker + from vllm_ascend.worker.worker import NPUWorker # Create worker mock with patch.object(NPUWorker, "__init__", lambda x, **kwargs: 
None): @@ -987,14 +987,14 @@ class TestNPUWorker(TestBase): self.assertIn("Sleep mode can only be", str(cm.exception)) - @patch("vllm_ascend.worker.worker_v1.NPUPlatform.seed_everything") - @patch("vllm_ascend.worker.worker_v1.logger") - @patch("vllm_ascend.worker.worker_v1.NPUWorker._warm_up_atb") + @patch("vllm_ascend.worker.worker.NPUPlatform.seed_everything") + @patch("vllm_ascend.worker.worker.logger") + @patch("vllm_ascend.worker.worker.NPUWorker._warm_up_atb") def test_compile_or_warm_up_model_with_eager_mode(self, mock_warm_up_atb, mock_logger, mock_seed_everything): """Test compile_or_warm_up_model method - eager mode""" - from vllm_ascend.worker.worker_v1 import NPUWorker + from vllm_ascend.worker.worker import NPUWorker # Create worker mock with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None): @@ -1036,13 +1036,13 @@ class TestNPUWorker(TestBase): # Verify atb warm up mock_warm_up_atb.assert_called_once() - @patch("vllm_ascend.worker.worker_v1.NPUPlatform.seed_everything") - @patch("vllm_ascend.worker.worker_v1.logger") - @patch("vllm_ascend.worker.worker_v1.NPUWorker._warm_up_atb") + @patch("vllm_ascend.worker.worker.NPUPlatform.seed_everything") + @patch("vllm_ascend.worker.worker.logger") + @patch("vllm_ascend.worker.worker.NPUWorker._warm_up_atb") def test_compile_or_warm_up_model_with_graph_capture( self, mock_warm_up_atb, mock_logger, mock_seed_everything): """Test compile_or_warm_up_model method - with graph capture enabled""" - from vllm_ascend.worker.worker_v1 import NPUWorker + from vllm_ascend.worker.worker import NPUWorker # Create worker mock with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None): @@ -1076,11 +1076,11 @@ class TestNPUWorker(TestBase): # Verify atb warm up mock_warm_up_atb.assert_called_once() - @patch("vllm_ascend.worker.worker_v1.CaMemAllocator") + @patch("vllm_ascend.worker.worker.CaMemAllocator") def test_initialize_from_config_with_sleep_mode(self, mock_allocator_class): """Test initialize_from_config method - with sleep mode enabled""" - from vllm_ascend.worker.worker_v1 import NPUWorker + from vllm_ascend.worker.worker import NPUWorker # Create worker mock with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None): @@ -1111,7 +1111,7 @@ class TestNPUWorker(TestBase): def test_initialize_from_config_without_sleep_mode(self): """Test initialize_from_config method - without sleep mode enabled""" - from vllm_ascend.worker.worker_v1 import NPUWorker + from vllm_ascend.worker.worker import NPUWorker # Create worker mock with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None): @@ -1131,16 +1131,16 @@ class TestNPUWorker(TestBase): worker.model_runner.initialize_kv_cache.assert_called_once_with( mock_kv_cache_config) - @patch("vllm_ascend.worker.worker_v1.get_pp_group") - @patch("vllm_ascend.worker.worker_v1.get_tp_group") - @patch("vllm_ascend.worker.worker_v1.EMPTY_MODEL_RUNNER_OUTPUT") + @patch("vllm_ascend.worker.worker.get_pp_group") + @patch("vllm_ascend.worker.worker.get_tp_group") + @patch("vllm_ascend.worker.worker.EMPTY_MODEL_RUNNER_OUTPUT") def test_execute_model_kv_connector_not_finished(self, mock_empty_output, mock_get_tp_group, mock_get_pp_group): """Test execute_model method - kv_connector_output not finished sending/recving case""" from vllm.sequence import IntermediateTensors - from vllm_ascend.worker.worker_v1 import NPUWorker + from vllm_ascend.worker.worker import NPUWorker # Create worker mock with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None): diff --git 
a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 5f227172..92a3f253 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -289,7 +289,7 @@ class NPUPlatform(Platform): ) parallel_config.worker_cls = "vllm_ascend.xlite.xlite_worker.XliteWorker" else: - parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker" + parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker" refresh_block_size(vllm_config) diff --git a/vllm_ascend/worker/v2/__init__.py b/vllm_ascend/worker/v2/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vllm_ascend/worker/v2/aclgraph_utils.py b/vllm_ascend/worker/v2/aclgraph_utils.py new file mode 100644 index 00000000..b6460fd0 --- /dev/null +++ b/vllm_ascend/worker/v2/aclgraph_utils.py @@ -0,0 +1,71 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from contextlib import contextmanager +from typing import Any + +import torch +import torch.nn as nn +from vllm.config import VllmConfig +from vllm.v1.attention.backends.utils import AttentionMetadataBuilder +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.worker.gpu.block_table import BlockTables +from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager +from vllm.v1.worker.gpu.cudagraph_utils import \ + prepare_inputs_to_capture as prepare_inputs_to_capture_gpu +from vllm.v1.worker.gpu.input_batch import InputBuffers + +from vllm_ascend.worker.v2.utils import torch_cuda_wrapper + + +class AclGraphManager(CudaGraphManager): + """ACL Graph Manager for Ascend NPUs.""" + + def __init__(self, vllm_config: VllmConfig, device: torch.device): + with torch_cuda_wrapper(): + super().__init__(vllm_config, device) + + def capture_graph( + self, + num_tokens: int, + model: nn.Module, + input_buffers: InputBuffers, + block_tables: BlockTables, + attn_metadata_builders: list[AttentionMetadataBuilder], + kv_cache_config: KVCacheConfig, + ) -> None: + with (torch_cuda_wrapper(), prepare_capture_inputs_wrapper()): + super().capture_graph( + num_tokens, + model, + input_buffers, + block_tables, + attn_metadata_builders, + kv_cache_config, + ) + + +@contextmanager +def prepare_capture_inputs_wrapper(): + """Context manager to override input preparation for NPU graph capture.""" + # TODO(Ronald1995): make prepare_inputs_to_capture as static method + # in CudaGraphManager. + global prepare_inputs_to_capture_gpu + try: + ori_func = prepare_inputs_to_capture_gpu + prepare_inputs_to_capture_gpu = prepare_inputs_to_capture + yield + finally: + prepare_inputs_to_capture_gpu = ori_func + + +def prepare_inputs_to_capture( + num_reqs: int, + num_tokens: int, + input_buffers: InputBuffers, + block_tables: BlockTables, + attn_metadata_builders: list[AttentionMetadataBuilder], + max_model_len: int, + kv_cache_config: KVCacheConfig, +) -> dict[str, Any]: + # TODO(Ronald1995): Implement NPU specific input preparation. 
+ return {} diff --git a/vllm_ascend/worker/v2/attn_utils.py b/vllm_ascend/worker/v2/attn_utils.py new file mode 100644 index 00000000..ad2caa55 --- /dev/null +++ b/vllm_ascend/worker/v2/attn_utils.py @@ -0,0 +1,160 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Sequence +from typing import Any + +import numpy as np +import torch +from vllm.config import VllmConfig +from vllm.config.model import ModelDType +from vllm.v1.attention.backends.utils import AttentionMetadataBuilder +from vllm.v1.kv_cache_interface import EncoderOnlyAttentionSpec, KVCacheConfig + +from vllm_ascend.attention.attention_mask import AttentionMaskBuilder +from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, + AscendPrefillContextParallelMetadata) + +_ATTENTION_MASK_BUILDER = None + + +def get_attn_mask_builder(device: torch.device): + """Get attention mask builder which only have one instance.""" + global _ATTENTION_MASK_BUILDER + if _ATTENTION_MASK_BUILDER is None: + _ATTENTION_MASK_BUILDER = AttentionMaskBuilder(device) + return _ATTENTION_MASK_BUILDER + + +def build_attn_metadata( + attn_metadata_builders: list[AttentionMetadataBuilder], + num_reqs: int, + num_tokens: int, + query_start_loc_gpu: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_cpu: torch.Tensor, + num_computed_tokens_cpu: torch.Tensor, + block_tables: Sequence[torch.Tensor], + slot_mappings: torch.Tensor, + kv_cache_config: KVCacheConfig, + decode_token_per_req: int, + actual_seq_lengths_q: list[int], + positions: torch.Tensor | None = None, + attn_mask: torch.Tensor + | None = None, + spec_attn_mask: torch.Tensor | None = None, + attn_state: Any | None = None, + is_only_prefill: bool = False, + graph_pad_size: int = -1, + num_input_tokens: int = 0, + prefill_context_parallel_metadata: AscendPrefillContextParallelMetadata + | None = None, +) -> dict[str, Any]: + """Build attention metadata for Ascend NPUs.""" + # TODO(Ronald1995): optimize AscendCommonAttentionMetadata. 
+ max_query_len = int(query_start_loc_cpu.max()) + + attn_metadata: dict[str, Any] = {} + kv_cache_groups = kv_cache_config.kv_cache_groups + for i, kv_cache_spec in enumerate(kv_cache_groups): + block_table = block_tables[i] + slot_mapping = slot_mappings[i] + + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=query_start_loc_gpu, + query_start_loc_cpu=query_start_loc_cpu, + seq_lens_cpu=seq_lens_cpu[:num_reqs], + seq_lens=seq_lens[:num_reqs], + num_computed_tokens_cpu=num_computed_tokens_cpu, + num_reqs=num_reqs, + num_actual_tokens=num_tokens, + max_query_len=max_query_len, + decode_token_per_req=decode_token_per_req, + block_table_tensor=block_table, + slot_mapping=slot_mapping, + actual_seq_lengths_q=actual_seq_lengths_q, + positions=positions, + attn_mask=attn_mask, + spec_attn_mask=spec_attn_mask, + attn_state=attn_state, + is_only_prefill=is_only_prefill, + graph_pad_size=graph_pad_size, + num_input_tokens=num_input_tokens, + prefill_context_parallel_metadata=prefill_context_parallel_metadata, + ) + + attn_metadata_builder = attn_metadata_builders[i] + metadata = attn_metadata_builder.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, # type: ignore + ) + for layer_name in kv_cache_spec.layer_names: + attn_metadata[layer_name] = metadata + return attn_metadata + + +def build_attn_state( + vllm_config: VllmConfig, + seq_lens_np: np.ndarray, + num_reqs, + num_scheduled_tokens, + num_valid_tokens, +): + """Build attention state for npu's attention backend.""" + if vllm_config.model_config.runner_type == "pooling": + if isinstance( + vllm_config.kv_cache_config.kv_cache_groups[0].kv_cache_spec, + EncoderOnlyAttentionSpec, + ): + attn_state = AscendAttentionState.PrefillNoCache + else: + attn_state = AscendAttentionState.PrefillCacheHit + elif np.array_equal(seq_lens_np[:num_reqs], num_scheduled_tokens): + attn_state = AscendAttentionState.PrefillNoCache + # We assume it is the decode stage, where prefill occurs + # but only one token is not hit in cache. + elif np.all(num_scheduled_tokens == 1): + attn_state = AscendAttentionState.DecodeOnly + if (vllm_config.speculative_config + and vllm_config.speculative_config.method == 'mtp'): + # SpecDecoding now supports seq_len=1 and seq_len=2 + # In Prefilling Decoding Disaggregation scenario, SpecDecoding + # need to supports seq_len=1 + attn_state = AscendAttentionState.SpecDecoding + # Speculative decoding. + elif np.all(num_valid_tokens == 1): + if (vllm_config.speculative_config + and vllm_config.speculative_config.method == 'mtp'): + attn_state = AscendAttentionState.SpecDecoding + else: + attn_state = AscendAttentionState.ChunkedPrefill + # splitfuse + elif vllm_config.scheduler_config.enable_chunked_prefill: + attn_state = AscendAttentionState.ChunkedPrefill + else: + attn_state = AscendAttentionState.PrefillCacheHit + return attn_state + + +def make_attention_mask( + vllm_config: VllmConfig, + attn_state: AscendAttentionState, + dtype: ModelDType | torch.dtype, + device: torch.device, +) -> torch.Tensor: + """make attention mask for npu's attention backend.""" + attn_mask_builder = get_attn_mask_builder(device) + # pcp situation. + if attn_mask_builder is None: + raise ValueError("Attn mask builder is None") + # Pooling situation. + if vllm_config.model_config.runner_type == "pooling": + return attn_mask_builder.get_attn_mask(2048, torch.bool) + + # TODO(Ronald1995) cosidering pcp. 
+ if vllm_config.model_config.use_mla: + # mla prefill + if attn_state != AscendAttentionState.DecodeOnly: + return attn_mask_builder.get_mla_mask(dtype) + return attn_mask_builder.get_splitfuse_attn_mask() diff --git a/vllm_ascend/worker/v2/input_batch.py b/vllm_ascend/worker/v2/input_batch.py new file mode 100644 index 00000000..843ec658 --- /dev/null +++ b/vllm_ascend/worker/v2/input_batch.py @@ -0,0 +1,37 @@ +import numpy as np +import torch +from vllm.v1.worker.gpu.input_batch import InputBuffers + + +class AscendInputBuffers(InputBuffers): + """Input buffers for Ascend NPUs.""" + + def __init__( + self, + max_num_reqs: int, + max_num_tokens: int, + inputs_embeds_size: int, + vocab_size: int, + dtype: torch.dtype, + device: torch.device, + pin_memory: bool, + ): + super().__init__( + max_num_reqs, + max_num_tokens, + inputs_embeds_size, + vocab_size, + dtype, + device, + pin_memory, + ) + # Create seq_lens_cpu and seq_lens_np. + # npu's attention backend still needs seq_lens on CPU side. + self.seq_lens_cpu: torch.Tensor = torch.zeros( + max_num_reqs, + dtype=torch.int32, + device="cpu", + ) + # seq_len_np and seq_lens_cpu share the same memory. + # define seq_lens_np for easier calculation with numpy. + self.seq_lens_np: np.ndarray = self.seq_lens_cpu.numpy() diff --git a/vllm_ascend/worker/v2/model_runner.py b/vllm_ascend/worker/v2/model_runner.py new file mode 100644 index 00000000..f85304d6 --- /dev/null +++ b/vllm_ascend/worker/v2/model_runner.py @@ -0,0 +1,346 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import numpy as np +import torch +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput +from vllm.v1.worker.gpu.input_batch import (InputBatch, + combine_sampled_and_draft_tokens, + prepare_pos_seq_lens, + prepare_prefill_inputs) +from vllm.v1.worker.gpu.model_runner import GPUModelRunner +from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata +from vllm.v1.worker.gpu.sample.output import SamplerOutput + +from vllm_ascend.worker.v2.aclgraph_utils import AclGraphManager +from vllm_ascend.worker.v2.attn_utils import (build_attn_metadata, + build_attn_state, + make_attention_mask) +from vllm_ascend.worker.v2.input_batch import AscendInputBuffers +from vllm_ascend.worker.v2.states import AscendRequestState +from vllm_ascend.worker.v2.utils import torch_cuda_wrapper + +logger = init_logger(__name__) + + +class NPUModelRunner(GPUModelRunner): + """Model runner for Ascend NPUs.""" + + def __init__(self, vllm_config: VllmConfig, device: torch.device): + with torch_cuda_wrapper(): + super().__init__(vllm_config, device) + + # because we will override these attribute, delete these attribute to + # make sure it's collected by python gc immediately. + del self.cudagraph_manager + del self.req_states + del self.input_buffers + + # NPU specific initializations can be added below. + self.cudagraph_manager: AclGraphManager = AclGraphManager( + vllm_config, + device, + ) + # AscendRequestState has extra `num_computed_tokens_cpu` attribute. + # so reinitialize req_states here. 
+ self.req_states: AscendRequestState = AscendRequestState( + max_num_reqs=self.max_num_reqs, + max_model_len=self.max_model_len, + max_num_batched_tokens=self.max_num_tokens, + num_speculative_steps=self.num_speculative_steps, + vocab_size=self.vocab_size, + device=self.device, + pin_memory=self.pin_memory, + ) + # AscendInputBuffers has extra `seq_lens_cpu` attribute. + # so reinitialize input_buffers here. + self.input_buffers: AscendInputBuffers = AscendInputBuffers( + max_num_reqs=self.max_num_reqs, + max_num_tokens=self.max_num_tokens, + inputs_embeds_size=self.inputs_embeds_size, + vocab_size=self.vocab_size, + dtype=self.dtype, + device=self.device, + pin_memory=self.pin_memory, + ) + + # actual seq lengths for query (used in attention backends). + self.actual_seq_lengths_q: list[int] = [] + # decode token per request (used in attention backends). + self.decode_token_per_req = 1 + + # there attributes are for async scheduling with speculative decoding. + # because npu attention backend still need to use seq_lens_cpu, + # we need to copy num_rejected_tokens back to cpu to help + # update actual seq_lens_cpu. gpu attention backend do not need these + # attributes, cause their attention backends do not use seq_lens_cpu. + # and seq_lens_cpu is deprecated in gpu_model_runner_v2. + self.num_rejected_tokens_event = None + self.num_rejectd_tokens_cpu = None + self.num_rejected_token_stream = None + if self.use_async_scheduling and self.do_spec_decode: + self.num_rejected_tokens_event = torch.npu.Event() + self.num_rejected_token_stream = torch.npu.Stream() + self.num_rejectd_tokens_cpu = torch.empty( + self.max_num_reqs, + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory, + ) + + def prepare_inputs( + self, + scheduler_output: SchedulerOutput, + num_tokens_after_padding: int, + ) -> InputBatch: + """Override GPUModelRunner.prepare_inputs for Ascend NPUs. + npu attention backends need seq_lens_cpu to work. + so we need to prepare seq_lens_cpu here. + """ + num_tokens = scheduler_output.total_num_scheduled_tokens + assert num_tokens > 0 + num_reqs = len(scheduler_output.num_scheduled_tokens) + + # Decode first, then prefill. + # batch_idx -> req_id + req_ids = sorted( + scheduler_output.num_scheduled_tokens.keys(), + key=lambda k: scheduler_output.num_scheduled_tokens[k], + ) + + self._update_seq_lens_cpu(scheduler_output, req_ids) + + num_scheduled_tokens = np.array( + [scheduler_output.num_scheduled_tokens[i] for i in req_ids], + dtype=np.int32) + num_valid_tokens = num_scheduled_tokens + if scheduler_output.scheduled_spec_decode_tokens: + num_valid_tokens = np.array( + [ + num_tokens - len( + scheduler_output.scheduled_spec_decode_tokens.get( + i, [])) + for num_tokens, i in zip(num_scheduled_tokens, req_ids) + ], + dtype=np.int32, + ) + attn_state = build_attn_state( + self.vllm_config, + self.input_buffers.seq_lens_np, + num_reqs, + num_scheduled_tokens, + num_valid_tokens, + ) + attn_mask = make_attention_mask( + self.vllm_config, + attn_state, + self.dtype, + self.device, + ) + + idx_mapping_list = [ + self.req_states.req_id_to_index[req_id] for req_id in req_ids + ] + idx_mapping = self.input_buffers.idx_mapping + idx_mapping.np[:num_reqs] = idx_mapping_list + idx_mapping_np = idx_mapping.np[:num_reqs] + # add `idx_mapping_cpu` here, because vllm-ascend's self.req_states. + # num_computed_tokens_cpu is actually cpu's tensor, while it's a gpu's + # tensor in vllm gpu_model_runner_v2. 
+ idx_mapping_cpu = idx_mapping.cpu[:num_reqs] + idx_mapping_npu = idx_mapping.copy_to_gpu(num_reqs) + + # Get the number of draft tokens for each request. + if not scheduler_output.scheduled_spec_decode_tokens: + # No draft token scheduled (common case). + total_num_draft_tokens = 0 + total_num_logits = num_reqs + cu_num_logits = torch.arange(num_reqs + 1, + device=self.device, + dtype=torch.int32) + else: + draft_tokens = scheduler_output.scheduled_spec_decode_tokens + num_draft_tokens = np.array( + [ + len(draft_tokens[req_id]) if req_id in draft_tokens else 0 + for req_id in req_ids + ], + dtype=np.int32, + ) + total_num_draft_tokens = int(num_draft_tokens.sum()) + total_num_logits = num_reqs + total_num_draft_tokens + + np.cumsum( + num_draft_tokens + 1, + out=self.input_buffers.cu_num_logits.np[1:num_reqs + 1], + ) + cu_num_logits = self.input_buffers.cu_num_logits.copy_to_gpu( + num_reqs + 1) + + # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks] + block_tables = self.block_tables.gather_block_tables(idx_mapping_npu) + + # Get query_start_loc. + np.cumsum( + num_scheduled_tokens, + out=self.input_buffers.query_start_loc.np[1:num_reqs + 1], + ) + # Pad for full CUDA graph mode. + # Some attention backends like FA3 require query_start_loc to be non-decreasing. + self.input_buffers.query_start_loc.np[num_reqs + 1:] = num_tokens + self.input_buffers.query_start_loc.copy_to_gpu() + query_start_loc_gpu = self.input_buffers.query_start_loc.gpu[: + num_reqs + + 1] + query_start_loc_cpu = self.input_buffers.query_start_loc.cpu[: + num_reqs + + 1] + query_start_loc_np = self.input_buffers.query_start_loc.np[:num_reqs + + 1] + + # Get prefill tokens. + prepare_prefill_inputs( + self.input_buffers.input_ids, + self.req_states.next_prefill_tokens, + idx_mapping_npu, + query_start_loc_gpu, + self.req_states.prefill_token_ids.gpu, + self.req_states.prefill_len.gpu, + self.req_states.num_computed_tokens, + ) + + # Prepare positions and seq_lens. + prepare_pos_seq_lens( + idx_mapping_npu, + query_start_loc_gpu, + self.req_states.num_computed_tokens, + self.input_buffers.positions, + self.input_buffers.seq_lens, + ) + seq_lens = self.input_buffers.seq_lens[:num_reqs] + + # Some input token ids are directly read from the last sampled tokens + # and draft tokens. Also, get the logits indices to sample tokens from. + logits_indices = combine_sampled_and_draft_tokens( + self.input_buffers.input_ids, + idx_mapping_npu, + self.req_states.last_sampled_tokens, + query_start_loc_gpu, + seq_lens, + self.req_states.prefill_len.gpu, + self.req_states.draft_tokens, + cu_num_logits, + total_num_logits, + ) + + # Compute slot mappings: [num_kv_cache_groups, num_tokens] + slot_mappings = self.block_tables.compute_slot_mappings( + query_start_loc_gpu, self.input_buffers.positions[:num_tokens]) + + # Layer name -> attention metadata. + # TODO(Ronald1995): try to add a new method `build_attn_metadata` in + # vllm gpu_model_runner_v2, maybe we don't overwrite `prepare_inputs` + # method like this. + attn_metadata = build_attn_metadata( + attn_metadata_builders=self.attn_metadata_builders, + num_reqs=num_reqs, + num_tokens=num_tokens, + query_start_loc_gpu=query_start_loc_gpu, + query_start_loc_cpu=query_start_loc_cpu, + seq_lens=self.input_buffers.seq_lens, + seq_lens_cpu=self.input_buffers.seq_lens_cpu, + actual_seq_lengths_q=self.actual_seq_lengths_q, + num_computed_tokens_cpu=self.req_states. 
+ num_computed_tokens_cpu[idx_mapping_cpu], + block_tables=block_tables, + slot_mappings=slot_mappings, + kv_cache_config=self.kv_cache_config, + decode_token_per_req=self.decode_token_per_req, + attn_mask=attn_mask, + attn_state=attn_state, + ) + + input_ids = self.input_buffers.input_ids[:num_tokens_after_padding] + positions = self.input_buffers.positions[:num_tokens_after_padding] + return InputBatch( + req_ids=req_ids, + num_reqs=num_reqs, + idx_mapping=idx_mapping_npu, + idx_mapping_np=idx_mapping_np, + num_scheduled_tokens=num_scheduled_tokens, + num_tokens=num_tokens, + num_tokens_after_padding=num_tokens_after_padding, + num_draft_tokens=total_num_draft_tokens, + query_start_loc=query_start_loc_gpu, + query_start_loc_np=query_start_loc_np, + seq_lens=seq_lens, + seq_lens_np=self.input_buffers.seq_lens_np, + input_ids=input_ids, + positions=positions, + attn_metadata=attn_metadata, + logits_indices=logits_indices, + cu_num_logits=cu_num_logits, + ) + + def sample( + self, + hidden_states: torch.Tensor, + input_batch: InputBatch, + sampling_metadata: SamplingMetadata, + grammar_output: GrammarOutput | None, + ) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]: + """Override GPUModelRunner.sample for Ascend NPUs. + when using async scheduling with speculative decoding, + we need to copy mpu's num_rejected tensor to cpu. + these operations aren't needed in gpu_model_runner_v2, + because gpu attention backends do not use seq_lens_cpu anymore. + """ + sampler_output, num_sampled, num_rejected = super().sample( + hidden_states, + input_batch, + sampling_metadata, + grammar_output, + ) + if self.num_rejected_tokens_event is not None: + # npu attention backend still need to use seq_lens_cpu, + # when doing speculative decoding with async_scheduling, + # we need to copy num_rejected_tokens back to cpu. + default_stream = torch.cuda.current_stream() + assert self.num_rejected_token_stream is not None + assert self.num_rejectd_tokens_cpu is not None + with torch.npu.stream(self.num_rejected_token_stream): + self.num_rejected_token_stream.wait_stream(default_stream) + self.num_rejectd_tokens_cpu.copy_( + num_rejected, + non_blocking=True, + ) + self.num_rejected_tokens_event.record() + return sampler_output, num_sampled, num_rejected + + def _update_seq_lens_cpu( + self, + scheduler_output: SchedulerOutput, + req_ids: list[str], + ): + num_scheduled_tokens = scheduler_output.num_scheduled_tokens + + # update num_computed_tokens_cpu + # TODO(Ronald1995): update num_computed_tokens_cpu by considering + # num_rejectd_tokens. 
+        for req_id, num_computed_token in zip(
+                scheduler_output.scheduled_cached_reqs.req_ids,
+                scheduler_output.scheduled_cached_reqs.num_computed_tokens,
+        ):
+            req_index = self.req_states.req_id_to_index[req_id]
+            self.req_states.num_computed_tokens_cpu[
+                req_index] = num_computed_token
+
+        # update seq_lens_cpu
+        for i, req_id in enumerate(req_ids):
+            req_index = self.req_states.req_id_to_index[req_id]
+            num_computed_tokens = self.req_states.num_computed_tokens_cpu[
+                req_index]
+            self.input_buffers.seq_lens_cpu[
+                i] = num_computed_tokens + num_scheduled_tokens[req_id]
diff --git a/vllm_ascend/worker/v2/states.py b/vllm_ascend/worker/v2/states.py
new file mode 100644
index 00000000..1364c869
--- /dev/null
+++ b/vllm_ascend/worker/v2/states.py
@@ -0,0 +1,88 @@
+from contextlib import contextmanager
+
+import torch
+from vllm.v1.utils import CpuGpuBuffer
+from vllm.v1.worker.gpu.states import RequestState, UvaBuffer
+
+
+class AscendRequestState(RequestState):
+    """Request state for Ascend NPUs."""
+
+    def __init__(
+        self,
+        max_num_reqs: int,
+        max_model_len: int,
+        max_num_batched_tokens: int,
+        num_speculative_steps: int,
+        vocab_size: int,
+        device: torch.device,
+        pin_memory: bool,
+    ):
+        with uva_wrapper():
+            super().__init__(
+                max_num_reqs,
+                max_model_len,
+                max_num_batched_tokens,
+                num_speculative_steps,
+                vocab_size,
+                device,
+                pin_memory,
+            )
+        # Because we override this attribute below, delete it first so the
+        # original buffer is collected by the Python GC immediately.
+        del self.prefill_token_ids
+        # vllm gpu_model_runner_v2 deprecates the seq_lens_cpu attribute,
+        # because most attention backends no longer need it.
+        # However, the Ascend attention backend must use seq_lens_cpu,
+        # so we keep num_computed_tokens_cpu here; seq_lens_cpu needs to be
+        # calculated outside as num_computed_tokens_cpu + decode_token_per_req.
+        self.num_computed_tokens_cpu: torch.Tensor = torch.zeros(
+            self.max_num_reqs,
+            dtype=torch.int32,
+            device="cpu",
+        )
+        # NOTE(Ronald1995): Ascend NPUs do not support UVA yet,
+        # so we use CpuGpuBuffer to allocate the prefill_token_ids buffer.
+        self.prefill_token_ids: CpuGpuBuffer = self._make_buffer(  # type: ignore
+            (self.max_num_reqs, self.max_model_len),
+            dtype=torch.int32)
+
+    def add_request(
+        self,
+        req_id,
+        prompt_len,
+        prefill_token_ids,
+        num_computed_tokens,
+        sampling_params,
+        lora_request,
+    ):
+
+        super().add_request(
+            req_id,
+            prompt_len,
+            prefill_token_ids,
+            num_computed_tokens,
+            sampling_params,
+            lora_request,
+        )
+        req_idx = self.req_id_to_index[req_id]
+        self.num_computed_tokens_cpu[req_idx] = num_computed_tokens
+
+
+@contextmanager
+def uva_wrapper():
+    """Context manager to disable UVA for Ascend NPUs."""
+
+    class UvaBufferWrapper:
+
+        def __init__(self, *args, **kwargs):
+            pass
+
+    # TODO(Ronald1995): rectify this when the NPU supports UVA.
+    global UvaBuffer
+    ori_class = UvaBuffer
+    try:
+        UvaBuffer = UvaBufferWrapper
+        yield
+    finally:
+        UvaBuffer = ori_class
diff --git a/vllm_ascend/worker/v2/utils.py b/vllm_ascend/worker/v2/utils.py
new file mode 100644
index 00000000..20efb75f
--- /dev/null
+++ b/vllm_ascend/worker/v2/utils.py
@@ -0,0 +1,33 @@
+from contextlib import contextmanager
+
+import torch
+
+
+@contextmanager
+def torch_cuda_wrapper():
+    ori_event = torch.cuda.Event
+    ori_stream = torch.cuda.Stream
+    ori_default_stream = torch.cuda.default_stream
+    ori_current_stream = torch.cuda.current_stream
+    ori_graph_pool_handle = torch.cuda.graph_pool_handle
+    ori_cuda_graph_cls = torch.cuda.CUDAGraph
+    ori_cuda_graph_func = torch.cuda.graph
+    try:
+        torch.cuda.Event = torch.npu.Event
+        torch.cuda.Stream = torch.npu.Stream
+        torch.cuda.default_stream = torch.npu.default_stream
+        torch.cuda.current_stream = torch.npu.current_stream
+        torch.cuda.graph_pool_handle = torch.npu.graph_pool_handle
+        torch.cuda.CUDAGraph = torch.npu.NpuGraph
+        torch.cuda.graph = torch.npu.graph
+        yield
+    finally:
+        # Revert the torch.cuda attributes so that calling CUDA ops in an
+        # NPU environment still raises an error afterwards.
+        torch.cuda.Event = ori_event
+        torch.cuda.Stream = ori_stream
+        torch.cuda.default_stream = ori_default_stream
+        torch.cuda.current_stream = ori_current_stream
+        torch.cuda.graph_pool_handle = ori_graph_pool_handle
+        torch.cuda.CUDAGraph = ori_cuda_graph_cls
+        torch.cuda.graph = ori_cuda_graph_func
diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker.py
similarity index 97%
rename from vllm_ascend/worker/worker_v1.py
rename to vllm_ascend/worker/worker.py
index f05ef69a..de23efa7 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker.py
@@ -137,6 +137,8 @@ class NPUWorker(WorkerBase):
         if "UnquantizedLinearMethod" in WEIGHT_LOADER_V2_SUPPORTED:
             WEIGHT_LOADER_V2_SUPPORTED.remove("UnquantizedLinearMethod")
 
+        self.use_v2_model_runner = envs_vllm.VLLM_USE_V2_MODEL_RUNNER
+
     def sleep(self, level: int = 1) -> None:
         free_bytes_before_sleep = NPUPlatform.mem_get_info()[0]
         # Save the buffers before level 2 sleep
@@ -230,7 +232,15 @@ class NPUWorker(WorkerBase):
         # for more details
         self.device = self._init_device()
         # Init ModelRunner here, so that we have access to self.device.
-        self.model_runner = NPUModelRunner(self.vllm_config, self.device)
+        if self.use_v2_model_runner:
+            logger.error(
+                "npu model runner v2 is still under development; it may not work correctly yet."
+            )
+            from vllm_ascend.worker.v2.model_runner import \
+                NPUModelRunner as NPUModelRunnerV2
+            self.model_runner = NPUModelRunnerV2(self.vllm_config, self.device)
+        else:
+            self.model_runner = NPUModelRunner(self.vllm_config, self.device)
 
     @torch.inference_mode()
     def determine_available_memory(self) -> int:
diff --git a/vllm_ascend/xlite/xlite_worker.py b/vllm_ascend/xlite/xlite_worker.py
index 54537c92..83fa1571 100644
--- a/vllm_ascend/xlite/xlite_worker.py
+++ b/vllm_ascend/xlite/xlite_worker.py
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from vllm_ascend.worker.worker_v1 import NPUWorker +from vllm_ascend.worker.worker import NPUWorker from vllm_ascend.xlite.xlite_model_runner import XliteModelRunner @@ -23,4 +23,4 @@ class XliteWorker(NPUWorker): def init_device(self): """Override init_device to init xlite model runner""" self.device = self._init_device() - self.model_runner = XliteModelRunner(self.vllm_config, self.device) \ No newline at end of file + self.model_runner = XliteModelRunner(self.vllm_config, self.device)
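
Usage sketch (illustrative, not part of the patch): the new worker path keys off `envs_vllm.VLLM_USE_V2_MODEL_RUNNER`. The snippet below is a minimal sketch assuming the flag is exposed as an environment variable of the same name in `vllm.envs`; the model name is a placeholder, and per the PR description the v2 path is not yet guaranteed to work end to end.

import os

# Assumption: the vllm.envs flag maps to an environment variable with the
# same name; set it before the engine (and NPUWorker) is constructed.
os.environ["VLLM_USE_V2_MODEL_RUNNER"] = "1"

from vllm import LLM

# On Ascend, NPUPlatform resolves worker_cls to
# "vllm_ascend.worker.worker.NPUWorker"; with the flag set above, that worker
# constructs vllm_ascend.worker.v2.model_runner.NPUModelRunner instead of the
# default runner. The model name below is a placeholder.
llm = LLM(model="Qwen/Qwen2.5-7B-Instruct")
print(llm.generate(["Hello, NPU!"])[0].outputs[0].text)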