From 9b67c87b1475fe7dc79442efc8f2fcbce7b728cf Mon Sep 17 00:00:00 2001 From: wangxiyuan Date: Wed, 30 Jul 2025 08:47:22 +0800 Subject: [PATCH] [Refactor]Refactor sampler (#2050) Refactor the Sampler implementation from monkey-patching to inheriting from the vLLM Sampler interface. Next step: make the `TopKTopPSampler` op in vLLM support the custom-op registration mechanism - vLLM version: v0.10.0 - vLLM main: https://github.com/vllm-project/vllm/commit/61a6905ab036fd00eafdb1b0ca130d5feccfe686 Signed-off-by: wangxiyuan --- .../worker/patch_common/test_patch_sampler.py | 46 ---------- tests/ut/sample/test_sampler.py | 32 +++++++ vllm_ascend/envs.py | 6 +- vllm_ascend/patch/__init__.py | 16 +--- .../patch/worker/patch_common/__init__.py | 1 - .../worker/patch_common/patch_sampler.py | 83 ------------------- vllm_ascend/sample/sampler.py | 62 ++++++++++++++ vllm_ascend/worker/model_runner_v1.py | 12 ++- 8 files changed, 108 insertions(+), 150 deletions(-) delete mode 100644 tests/ut/patch/worker/patch_common/test_patch_sampler.py create mode 100644 tests/ut/sample/test_sampler.py delete mode 100644 vllm_ascend/patch/worker/patch_common/patch_sampler.py create mode 100644 vllm_ascend/sample/sampler.py diff --git a/tests/ut/patch/worker/patch_common/test_patch_sampler.py b/tests/ut/patch/worker/patch_common/test_patch_sampler.py deleted file mode 100644 index fc9fbd1..0000000 --- a/tests/ut/patch/worker/patch_common/test_patch_sampler.py +++ /dev/null @@ -1,46 +0,0 @@ -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. - -import importlib -import os -from unittest import mock - -import torch -from vllm.v1.sample.ops import topk_topp_sampler - -from tests.ut.base import TestBase - - -class TestTopKTopPSamplerOptimize(TestBase): - - @mock.patch.dict(os.environ, - {"VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION": "1"}) - @mock.patch("torch_npu.npu_top_k_top_p") - def test_npu_topk_topp_called_when_optimized(self, mock_npu_op): - # We have to patch and reload because the patch will take effect - # only after VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE is set. - import vllm_ascend.patch.worker.patch_common.patch_sampler - importlib.reload(vllm_ascend.patch.worker.patch_common.patch_sampler) - - mock_npu_op.return_value = (torch.randn(1, 3)) - sampler = topk_topp_sampler.TopKTopPSampler() - - logits = torch.tensor([[1.0, 2.0, 3.0]]) - k = torch.tensor([2]) - p = torch.tensor([0.9]) - generators = {0: torch.Generator()} - generators[0].manual_seed(42) - - sampler.forward_native(logits, generators, k, p) - mock_npu_op.assert_called_once_with(logits, p, k) diff --git a/tests/ut/sample/test_sampler.py b/tests/ut/sample/test_sampler.py new file mode 100644 index 0000000..98a83e6 --- /dev/null +++ b/tests/ut/sample/test_sampler.py @@ -0,0 +1,32 @@ +from unittest import mock + +import torch + +from tests.ut.base import TestBase +from vllm_ascend.sample.sampler import AscendSampler, AscendTopKTopPSampler + + +class TestAscendSampler(TestBase): + + def test_init_with_raw_logprobs(self): + sampler = AscendSampler(logprobs_mode="raw_logprobs") + self.assertEqual(sampler.logprobs_mode, "raw_logprobs") + self.assertTrue(hasattr(sampler, 'topk_topp_sampler')) + self.assertIsInstance(sampler.topk_topp_sampler, AscendTopKTopPSampler) + + +class TestAscendTopKTopPSampler(TestBase): + + @mock.patch("torch_npu.npu_top_k_top_p") + def 
test_npu_topk_topp_called_when_optimized(self, mock_npu_op): + mock_npu_op.return_value = (torch.randn(1, 3)) + sampler = AscendTopKTopPSampler() + + logits = torch.tensor([[1.0, 2.0, 3.0]]) + k = torch.tensor([2]) + p = torch.tensor([0.9]) + generators = {0: torch.Generator()} + generators[0].manual_seed(42) + + sampler.forward_native(logits, generators, k, p) + mock_npu_op.assert_called_once_with(logits, p, k) diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index eb0223c..a7bb9fa 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -128,11 +128,11 @@ env_variables: Dict[str, Callable[[], Any]] = { "VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE": lambda: int( os.getenv("VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE", 64)), - # Whether to enable the topk optimization. It's disabled by default for experimental support - # We'll make it enabled by default in the future. + # Whether to enable the top-k/top-p sampling optimization. It's enabled by default. Please set it to 0 if you hit any issues. + # We'll remove this flag in the future once the optimization is stable enough. "VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION": lambda: bool( - int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION", '0'))), + int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION", '1'))), # `LLMDataDistCMgrConnector` required variable. `DISAGGREGATED_PREFILL_RANK_TABLE_PATH` is # used for llmdatadist to build the communication topology for kv cache transfer, it is diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index 3446c45..f22d948 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -88,21 +88,7 @@ # Future Plan: # Remove this patch once pytorch 2.7.0 is supported for vllm ascend. # -# ** File: worker/patch_common/patch_sampler.py ** -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# 1. `vllm.v1.sample.sampler.Sampler.apply_top_k_top_p` -# Why: -# We need to use the patched `apply_top_k_top_p` in `sample`. 
-# The mainly reason to overwrite `apply_top_k_top_p` is -# to improve performance. -# How: -# Re-implementation the `apply_top_k_top_p` function by pytorch -# Related PR (if no, explain why): -# - https://github.com/vllm-project/vllm-ascend/pull/1732 -# Future Plan: -# Revert it when the ascend scatter performance improves. -# -# ** File: worker/patch_common/patch_sampler.py ** +# ** File: worker/patch_0_10_0/patch_sampler_gather_logprobs.py ** # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 1. `vllm.v1.sample.sampler.Sampler.gather_logprobs` # Why: diff --git a/vllm_ascend/patch/worker/patch_common/__init__.py b/vllm_ascend/patch/worker/patch_common/__init__.py index 8eebcdf..2533d13 100644 --- a/vllm_ascend/patch/worker/patch_common/__init__.py +++ b/vllm_ascend/patch/worker/patch_common/__init__.py @@ -21,4 +21,3 @@ import vllm_ascend.patch.worker.patch_common.patch_utils # noqa isort:skip import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa import vllm_ascend.patch.worker.patch_common.patch_linear # noqa import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa -import vllm_ascend.patch.worker.patch_common.patch_sampler # noqa diff --git a/vllm_ascend/patch/worker/patch_common/patch_sampler.py b/vllm_ascend/patch/worker/patch_common/patch_sampler.py deleted file mode 100644 index e745bf0..0000000 --- a/vllm_ascend/patch/worker/patch_common/patch_sampler.py +++ /dev/null @@ -1,83 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# SPDX-License-Identifier: Apache-2.0 -# This file is a part of the vllm-ascend project. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from typing import Optional - -import torch -import torch_npu -from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample - -from vllm_ascend import envs - - -def apply_top_k_top_p( - logits: torch.Tensor, - k: torch.Tensor, - p: torch.Tensor, -) -> torch.Tensor: - if p is not None and k is not None: - # npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p) - return torch_npu.npu_top_k_top_p(logits, p, k) - - probs = logits.softmax(dim=-1) - probs_sort, _ = probs.sort(dim=-1, descending=False) - - if k is not None: - top_k_count = probs_sort.size(1) - k.to(torch.long) # shape: (batch, ) - top_k_count = top_k_count.unsqueeze(dim=1) - top_k_cutoff = probs_sort.gather(-1, top_k_count) - - # Make sure the no top-k rows are no-op. 
- no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1) - top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf")) - - elements_to_discard = probs < top_k_cutoff - logits.masked_fill_(elements_to_discard, -float("inf")) - - if p is not None: - cumprob = torch.cumsum(probs_sort, dim=-1) - top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1) - top_p_mask[:, -1] = False # at least one - - top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1) - top_p_cutoff = probs_sort.gather(-1, top_p_count) - elements_to_discard = probs < top_p_cutoff - logits.masked_fill_(elements_to_discard, -float("inf")) - - return logits - - -def topk_topp_forward_native( - self, - logits: torch.Tensor, - generators: dict[int, torch.Generator], - k: Optional[torch.Tensor], - p: Optional[torch.Tensor], -) -> torch.Tensor: - """ - PyTorch-native implementation of top-k and top-p sampling. - - The logits tensor may be updated in-place. - """ - logits = apply_top_k_top_p(logits, k, p) - probs = logits.softmax(dim=-1, dtype=torch.float32) - return random_sample(probs, generators) - - -if envs.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION: - TopKTopPSampler.forward_native = topk_topp_forward_native diff --git a/vllm_ascend/sample/sampler.py b/vllm_ascend/sample/sampler.py new file mode 100644 index 0000000..862bd03 --- /dev/null +++ b/vllm_ascend/sample/sampler.py @@ -0,0 +1,62 @@ +import torch +import torch_npu +from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample +from vllm.v1.sample.sampler import Sampler + + +class AscendSampler(Sampler): + + def __init__(self, logprobs_mode="raw_logprobs"): + # TODO: support logprobs_mode in vllm-ascend + super().__init__(logprobs_mode=logprobs_mode) + self.topk_topp_sampler = AscendTopKTopPSampler() + + +class AscendTopKTopPSampler(TopKTopPSampler): + + def _apply_top_k_top_p( + self, + logits: torch.Tensor, + k: torch.Tensor, + p: torch.Tensor, + ) -> torch.Tensor: + if p is not None and k is not None: + # npu_top_k_top_p's parameter order is 
(logits, p, k), not (logits, k, p) + return torch_npu.npu_top_k_top_p(logits, p, k) + + if p is None and k is None: + return logits + + probs = logits.softmax(dim=-1) + probs_sort, _ = probs.sort(dim=-1, descending=False) + + if k is not None: + top_k_count = probs_sort.size(1) - k.to( + torch.long) # shape: (batch, ) + top_k_count = top_k_count.unsqueeze(dim=1) + top_k_cutoff = probs_sort.gather(-1, top_k_count) + + # Make sure the no top-k rows are no-op. + no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1) + top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf")) + + elements_to_discard = probs < top_k_cutoff + logits.masked_fill_(elements_to_discard, -float("inf")) + + if p is not None: + cumprob = torch.cumsum(probs_sort, dim=-1) + top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1) + top_p_mask[:, -1] = False # at least one + + top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1) + top_p_cutoff = probs_sort.gather(-1, top_p_count) + elements_to_discard = probs < top_p_cutoff + logits.masked_fill_(elements_to_discard, -float("inf")) + + return logits + + def forward_native(self, logits, generators, k, p): + """Override pytorch native implementation to torch_npu""" + logits = self._apply_top_k_top_p(logits, k, p) + probs = logits.softmax(dim=-1, dtype=torch.float32) + return random_sample(probs, generators) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 2bee8dd..886ccb8 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -64,7 +64,6 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ModelRunnerOutput) from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.sample.sampler import Sampler from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin @@ -72,6 
+71,7 @@ from vllm.v1.worker.utils import (bind_kv_cache, gather_mm_placeholders, sanity_check_mm_encoder_outputs, scatter_mm_placeholders) +from vllm_ascend import envs from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.attention.attention_mask import AttentionMaskBuilder @@ -165,7 +165,15 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.dp_rank = vllm_config.parallel_config.data_parallel_rank self.device = device self.dtype = self.model_config.dtype - self.sampler = Sampler() + if envs.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION: + # TODO: drop the env config to use ascend sampler by default + from vllm_ascend.sample.sampler import AscendSampler + + self.sampler = AscendSampler() + else: + from vllm.v1.sample.sampler import Sampler + + self.sampler = Sampler() # Lazy initialization, these will be set after __init__ self.kv_caches: List[torch.Tensor] = []