From ad9b711f8965c864a7003d1b2787f58067fd78cc Mon Sep 17 00:00:00 2001
From: zhenwenqi2024 <155598497+zhenwenqi2024@users.noreply.github.com>
Date: Tue, 6 Jan 2026 22:48:21 +0800
Subject: [PATCH] [Bugfix] fix dcp_only bug and add e2e accuracy test for dcp
 only and pcp only (#5565)

### What this PR does / why we need it?
Fix the dcp-only bug and add e2e accuracy tests for dcp-only and pcp-only.
This PR fixes the accuracy failure that occurs when
decode_context_parallel_size > 1 and prefill_context_parallel_size = 1.

### Does this PR introduce _any_ user-facing change?
NO

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main: https://github.com/vllm-project/vllm/commit/7157596103666ee7ccb7008acee8bff8a8ff1731

---------

Signed-off-by: zhenwenqi2024
---
 .../multicard/long_sequence/test_accuracy.py | 114 ++++++++++++++++++
 vllm_ascend/worker/model_runner_v1.py        |  20 +--
 vllm_ascend/xlite/xlite_worker.py            |   4 +
 3 files changed, 128 insertions(+), 10 deletions(-)

diff --git a/tests/e2e/multicard/long_sequence/test_accuracy.py b/tests/e2e/multicard/long_sequence/test_accuracy.py
index bcff9da3..61d99c97 100644
--- a/tests/e2e/multicard/long_sequence/test_accuracy.py
+++ b/tests/e2e/multicard/long_sequence/test_accuracy.py
@@ -96,3 +96,117 @@ def test_models_long_sequence_output_between_tp_and_cp(
         name_0="vllm_eager_outputs",
         name_1="vllm_context_parallel_outputs",
     )
+
+
+model = "vllm-ascend/DeepSeek-V2-Lite-W8A8"
+
+
+@pytest.mark.parametrize("max_tokens", [10])
+def test_accuracy_dcp_only_graph(max_tokens: int, ) -> None:
+    prompts = [
+        "The president of the United States is", "The capital of France is"
+    ]
+    cp_kwargs = {
+        "tensor_parallel_size": 2,
+        "decode_context_parallel_size": 2,
+        "prefill_context_parallel_size": 1,
+        "enable_expert_parallel": True,
+        "compilation_config": {
+            "cudagraph_mode": "FULL_DECODE_ONLY",
+            "cudagraph_capture_sizes": [4, 8, 24, 48, 60]
+        },
+        "quantization": "ascend",
+        "max_model_len": 1024,
+    }
+    tp_kwargs = {
+        "tensor_parallel_size": 4,
+        "enable_expert_parallel": True,
+        "enforce_eager": True,
+        "quantization": "ascend",
+        "max_model_len": 1024,
+    }
+    with VllmRunner(model, **cp_kwargs) as runner:  # type: ignore
+        vllm_context_parallel_outputs = runner.generate_greedy(
+            prompts, max_tokens)
+
+    with VllmRunner(model, **tp_kwargs) as runner:  # type: ignore
+        vllm_eager_outputs = runner.generate_greedy(prompts, max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=vllm_eager_outputs,
+        outputs_1_lst=vllm_context_parallel_outputs,
+        name_0="vllm_eager_outputs",
+        name_1="vllm_dcp_only_graph_outputs",
+    )
+
+
+@pytest.mark.parametrize("max_tokens", [10])
+def test_accuracy_dcp_only_eager(max_tokens: int, ) -> None:
+    prompts = [
+        "The president of the United States is", "The capital of France is"
+    ]
+    cp_kwargs = {
+        "tensor_parallel_size": 2,
+        "decode_context_parallel_size": 2,
+        "prefill_context_parallel_size": 1,
+        "enable_expert_parallel": True,
+        "enforce_eager": True,
+        "quantization": "ascend",
+        "max_model_len": 1024,
+    }
+    tp_kwargs = {
+        "tensor_parallel_size": 4,
+        "enable_expert_parallel": True,
+        "enforce_eager": True,
+        "quantization": "ascend",
+        "max_model_len": 1024,
+    }
+    with VllmRunner(model, **cp_kwargs) as runner:  # type: ignore
+        vllm_context_parallel_outputs = runner.generate_greedy(
+            prompts, max_tokens)
+
+    with VllmRunner(model, **tp_kwargs) as runner:  # type: ignore
+        vllm_eager_outputs = runner.generate_greedy(prompts, max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=vllm_eager_outputs,
+        outputs_1_lst=vllm_context_parallel_outputs,
+        name_0="vllm_eager_outputs",
+        name_1="vllm_dcp_only_eager_outputs",
+    )
+
+
+@pytest.mark.parametrize("max_tokens", [10])
+def test_accuracy_pcp_only(max_tokens: int, ) -> None:
+    prompts = [
+        "The president of the United States is", "The capital of France is"
+    ]
+    cp_kwargs = {
+        "tensor_parallel_size": 2,
+        "decode_context_parallel_size": 1,
+        "prefill_context_parallel_size": 2,
+        "enable_expert_parallel": True,
+        "enforce_eager": True,
+        "quantization": "ascend",
+        "max_model_len": 1024,
+    }
+    tp_kwargs = {
+        "tensor_parallel_size": 4,
+        "enable_expert_parallel": True,
+        "enforce_eager": True,
+        "quantization": "ascend",
+        "max_model_len": 1024,
+    }
+    with VllmRunner(model, **cp_kwargs) as runner:  # type: ignore
+        vllm_context_parallel_outputs = runner.generate_greedy(
+            prompts, max_tokens)
+
+    with VllmRunner(model, **tp_kwargs) as runner:  # type: ignore
+        vllm_eager_outputs = runner.generate_greedy(prompts, max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=vllm_eager_outputs,
+        outputs_1_lst=vllm_context_parallel_outputs,
+        name_0="vllm_eager_outputs",
+        name_1="vllm_pcp_only_outputs",
+    )
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 0eab7345..d4d901b2 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -935,22 +935,21 @@ class NPUModelRunner(GPUModelRunner):
             blk_table_tensor = blk_table.get_device_tensor()
             slot_mapping = blk_table.slot_mapping.gpu[:
                                                       maybe_pcp_full_tokens]
-            if self.pcp_size * self.dcp_size == 1:
+            if self.pcp_size == 1:
                 slot_mapping[
                     total_num_scheduled_tokens:num_input_tokens].fill_(-1)
-                slot_mapping = blk_table.slot_mapping.gpu
             if self.pcp_size * self.dcp_size > 1:
                 self.long_seq_metadata = self.pcp_manager.generate_pcp_metadata(
                     total_num_scheduled_tokens, self.query_lens,
                     self.attn_mask, self.input_batch)
                 blk_table.slot_mapping.gpu[maybe_pcp_full_tokens:].fill_(-1)
-                slot_mapping = slot_mapping[:maybe_pcp_full_tokens]
-                slot_mapping = self.pcp_manager.get_padded_slot_mapping(
-                    total_num_scheduled_tokens,
-                    slot_mapping,
-                )
-                blk_table.slot_mapping.gpu[:self.pcp_manager.
-                                           num_actual_tokens_pcp_padded] = slot_mapping
+                if self.pcp_size > 1:
+                    slot_mapping = self.pcp_manager.get_padded_slot_mapping(
+                        total_num_scheduled_tokens,
+                        slot_mapping,
+                    )
+                    blk_table.slot_mapping.gpu[:self.pcp_manager.
+                                               num_actual_tokens_pcp_padded] = slot_mapping
 
             # NOTE: This is a temporary hack, now in GPUModelRunner, this prepare_inputs
             # has been split to multiple parts, and there are 3 parts that is related to this
@@ -3034,7 +3033,7 @@ def _torch_cuda_wrapper():
         torch.cuda.synchronize = torch.npu.synchronize
         torch.cuda.mem_get_info = torch.npu.mem_get_info
         yield
-    except Exception:
+    except Exception as e:
         torch.cuda.Event = _EventPlaceholder
         torch.cuda.Stream = _StreamPlaceholder
         torch.cuda.default_stream = _StreamPlaceholder
@@ -3042,6 +3041,7 @@ def _torch_cuda_wrapper():
         torch.cuda.stream = _StreamPlaceholder
         torch.cuda.synchronize = _StreamPlaceholder
         torch.cuda.mem_get_info = _StreamPlaceholder
+        raise RuntimeError(f"NPUModelRunner init failed, error is {e}")
     finally:
         # if anything goes wrong, just patch it with a placeholder
         torch.cuda.Event = _EventPlaceholder
diff --git a/vllm_ascend/xlite/xlite_worker.py b/vllm_ascend/xlite/xlite_worker.py
index 83fa1571..a9479a82 100644
--- a/vllm_ascend/xlite/xlite_worker.py
+++ b/vllm_ascend/xlite/xlite_worker.py
@@ -13,6 +13,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from vllm.v1.worker.workspace import init_workspace_manager
+
 from vllm_ascend.worker.worker import NPUWorker
 from vllm_ascend.xlite.xlite_model_runner import XliteModelRunner
 
@@ -23,4 +25,6 @@ class XliteWorker(NPUWorker):
     def init_device(self):
         """Override init_device to init xlite model runner"""
         self.device = self._init_device()
+        num_ubatches = 1
+        init_workspace_manager(self.device, num_ubatches)
         self.model_runner = XliteModelRunner(self.vllm_config, self.device)
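
Note for reviewers: the corrected slot-mapping flow in the model_runner_v1.py hunk above can be read as the self-contained sketch below. The `fill_slot_mapping` wrapper and the `PcpManagerStub` are illustrative assumptions for this note, not vllm-ascend APIs; only the branching mirrors the patch.

```python
import torch


class PcpManagerStub:
    """Illustrative stand-in for self.pcp_manager (assumption, not the real class)."""

    def __init__(self, num_actual_tokens_pcp_padded: int):
        self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded

    def get_padded_slot_mapping(self, total_num_scheduled_tokens: int,
                                slot_mapping: torch.Tensor) -> torch.Tensor:
        # Placeholder: the real method re-pads the slots for pcp ranks.
        return slot_mapping[:self.num_actual_tokens_pcp_padded]


def fill_slot_mapping(slot_mapping_gpu: torch.Tensor,
                      total_num_scheduled_tokens: int, num_input_tokens: int,
                      maybe_pcp_full_tokens: int, pcp_size: int, dcp_size: int,
                      pcp_manager: PcpManagerStub) -> None:
    slot_mapping = slot_mapping_gpu[:maybe_pcp_full_tokens]
    if pcp_size == 1:
        # dcp-only (and plain tp): only the graph-padding tail is invalidated.
        # Before the fix this branch was gated on pcp_size * dcp_size == 1,
        # so a dcp-only run skipped it and fell through to the pcp re-padding below.
        slot_mapping[total_num_scheduled_tokens:num_input_tokens].fill_(-1)
    if pcp_size * dcp_size > 1:
        # Context parallelism: drop tokens beyond the locally owned range.
        slot_mapping_gpu[maybe_pcp_full_tokens:].fill_(-1)
        if pcp_size > 1:
            # Only prefill context parallelism re-pads the slot mapping.
            padded = pcp_manager.get_padded_slot_mapping(
                total_num_scheduled_tokens, slot_mapping)
            slot_mapping_gpu[:pcp_manager.num_actual_tokens_pcp_padded] = padded
```

The key point is that the pcp re-padding now runs only when pcp_size > 1, so a dcp-only configuration (decode_context_parallel_size > 1, prefill_context_parallel_size = 1) keeps its original slot mapping and only masks the padded tail, which is what the new dcp-only e2e tests exercise.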