From d13fb0766e1805089499ee3dfb0329949d1905ac Mon Sep 17 00:00:00 2001 From: Pr0Wh1teGivee <132029610+Pr0Wh1teGivee@users.noreply.github.com> Date: Fri, 11 Jul 2025 15:32:02 +0800 Subject: [PATCH] [Perf] add patch to optimize apply_topk_topp (#1732) ### What this PR does / why we need it? Performance optimization for apply_top_k_top_p ### Does this PR introduce _any_ user-facing change? Use VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION to enable this feature ### How was this patch tested? e2e & ut - vLLM version: v0.9.2 - vLLM main: https://github.com/vllm-project/vllm/commit/6a9e6b2abf88181f93a1959fe16291c3f1696329 Signed-off-by: Pr0Wh1teGivee --- .../test_offline_inference_distributed.py | 22 ++++ .../e2e/singlecard/test_offline_inference.py | 24 ++++ tests/e2e/singlecard/test_sampler.py | 109 ++++++++++++++++++ .../worker/patch_common/test_patch_sampler.py | 46 ++++++++ vllm_ascend/envs.py | 5 + vllm_ascend/patch/__init__.py | 14 +++ .../patch/worker/patch_common/__init__.py | 1 + .../worker/patch_common/patch_sampler.py | 83 +++++++++++++ 8 files changed, 304 insertions(+) create mode 100644 tests/e2e/singlecard/test_sampler.py create mode 100644 tests/ut/patch/worker/patch_common/test_patch_sampler.py create mode 100644 vllm_ascend/patch/worker/patch_common/patch_sampler.py diff --git a/tests/e2e/multicard/test_offline_inference_distributed.py b/tests/e2e/multicard/test_offline_inference_distributed.py index 6ccdbd7..58d0bf0 100644 --- a/tests/e2e/multicard/test_offline_inference_distributed.py +++ b/tests/e2e/multicard/test_offline_inference_distributed.py @@ -145,3 +145,25 @@ def test_models_distributed_pangu(): distributed_executor_backend="mp", ) as vllm_model: vllm_model.generate_greedy(example_prompts, max_tokens) + + +@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION": "1"}) +def test_models_distributed_topk() -> None: + example_prompts = [ + "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.", + "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.", + "Compare and contrast artificial intelligence with human intelligence in terms of processing information.", + ] + dtype = "half" + sampling_params = SamplingParams(max_tokens=5, + temperature=0.0, + top_k=50, + top_p=0.9) + + with VllmRunner( + "deepseek-ai/DeepSeek-V2-Lite", + dtype=dtype, + tensor_parallel_size=4, + distributed_executor_backend="mp", + ) as vllm_model: + vllm_model.generate(example_prompts, sampling_params) diff --git a/tests/e2e/singlecard/test_offline_inference.py b/tests/e2e/singlecard/test_offline_inference.py index f6f9b04..26acb94 100644 --- a/tests/e2e/singlecard/test_offline_inference.py +++ b/tests/e2e/singlecard/test_offline_inference.py @@ -21,9 +21,12 @@ Run `pytest tests/test_offline_inference.py`. """ import os +from unittest.mock import patch import pytest +import vllm # noqa: F401 from modelscope import snapshot_download # type: ignore[import-untyped] +from vllm import SamplingParams from vllm.assets.image import ImageAsset import vllm_ascend # noqa: F401 @@ -103,3 +106,24 @@ def test_multimodal(model, prompt_template, vllm_runner): vllm_model.generate_greedy(prompts=prompts, images=images, max_tokens=64) + + +@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION": "1"}) +def test_models_topk() -> None: + example_prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + sampling_params = SamplingParams(max_tokens=5, + temperature=0.0, + top_k=50, + top_p=0.9) + + with VllmRunner("Qwen/Qwen2.5-0.5B-Instruct", + max_model_len=8192, + dtype="float16", + enforce_eager=True, + gpu_memory_utilization=0.7) as vllm_model: + vllm_model.generate(example_prompts, sampling_params) diff --git a/tests/e2e/singlecard/test_sampler.py b/tests/e2e/singlecard/test_sampler.py new file mode 100644 index 0000000..93b999d --- /dev/null +++ b/tests/e2e/singlecard/test_sampler.py @@ -0,0 +1,109 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# Adapted from vllm/tests/entrypoints/llm/test_guided_generate.py +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import Optional + +import torch + +# Set tolerance to 1 for quant ops +DEFAULT_ATOL = 1e-3 +DEFAULT_RTOL = 1e-3 + + +def apply_top_k_top_p( + logits: torch.Tensor, + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], +) -> torch.Tensor: + """Apply top-k and top-p masks to the logits. + + If a top-p is used, this function will sort the logits tensor, + which can be slow for large batches. + + The logits tensor may be updated in-place. + """ + logits_sort, logits_idx = logits.sort(dim=-1, descending=False) + + if k is not None: + # Apply top-k. + top_k_mask = logits_sort.size(1) - k.to(torch.long) # shape: B + # Get all the top_k values. + top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1)) + top_k_mask = logits_sort < top_k_mask + logits_sort.masked_fill_(top_k_mask, -float("inf")) + + if p is not None: + # Apply top-p. + probs_sort = logits_sort.softmax(dim=-1) + probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort) + top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) + # at least one + top_p_mask[:, -1] = False + logits_sort.masked_fill_(top_p_mask, -float("inf")) + + # Re-sort the probabilities. + logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) + return logits + + +def apply_top_k_top_p_new( + logits: torch.Tensor, + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], +) -> torch.Tensor: + batch_size, vocab_size = logits.shape + logits_sort, logits_idx = logits.sort(dim=-1, descending=False) + + # Apply top-k. + boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1)) + top_k_mask = logits_sort < boundary + logits_sort.masked_fill_(top_k_mask, -float("inf")) + + if p is not None: + # Apply top-p. + cutoff = top_k_mask.sum(dim=-1).min() + probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:] + probs_sum = probs_sort.cumsum(dim=-1) + top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1) + top_p_mask[:, -1] = True + strides = torch.arange(0, + batch_size * vocab_size, + vocab_size, + device=logits.device) + flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1) + valid_idx = torch.masked_select(flatten_idx, top_p_mask) + logits_flatten = logits.flatten() + valid_logits = torch.index_select(logits_flatten, 0, valid_idx) + logits = torch.empty_like(logits_flatten).fill_(-float("inf")) + logits[valid_idx] = valid_logits + return logits.reshape(batch_size, vocab_size) + + +# test with leading dimension and merge seqlen and batch_size as num_tokens +@torch.inference_mode() +def test_apply_top_k_top_p() -> None: + logits = torch.randn((128, 7168)).npu() + k = torch.Tensor([-1]).int().npu() + p = torch.Tensor([1]).int().npu() + logits_new = apply_top_k_top_p_new(logits, k, p) + logits_old = apply_top_k_top_p(logits, k, p) + # Compare the results. + torch.testing.assert_close(logits_new, + logits_old, + atol=DEFAULT_ATOL, + rtol=DEFAULT_RTOL) diff --git a/tests/ut/patch/worker/patch_common/test_patch_sampler.py b/tests/ut/patch/worker/patch_common/test_patch_sampler.py new file mode 100644 index 0000000..fc9fbd1 --- /dev/null +++ b/tests/ut/patch/worker/patch_common/test_patch_sampler.py @@ -0,0 +1,46 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. + +import importlib +import os +from unittest import mock + +import torch +from vllm.v1.sample.ops import topk_topp_sampler + +from tests.ut.base import TestBase + + +class TestTopKTopPSamplerOptimize(TestBase): + + @mock.patch.dict(os.environ, + {"VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION": "1"}) + @mock.patch("torch_npu.npu_top_k_top_p") + def test_npu_topk_topp_called_when_optimized(self, mock_npu_op): + # We have to patch and reload because the patch will take effect + # only after VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE is set. + import vllm_ascend.patch.worker.patch_common.patch_sampler + importlib.reload(vllm_ascend.patch.worker.patch_common.patch_sampler) + + mock_npu_op.return_value = (torch.randn(1, 3)) + sampler = topk_topp_sampler.TopKTopPSampler() + + logits = torch.tensor([[1.0, 2.0, 3.0]]) + k = torch.tensor([2]) + p = torch.tensor([0.9]) + generators = {0: torch.Generator()} + generators[0].manual_seed(42) + + sampler.forward_native(logits, generators, k, p) + mock_npu_op.assert_called_once_with(logits, p, k) diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index df25840..5ea6aa9 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -128,6 +128,11 @@ env_variables: Dict[str, Callable[[], Any]] = { "VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE": lambda: int( os.getenv("VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE", 64)), + # Whether to enable the topk optimization. It's disabled by default for experimental support + # We'll make it enabled by default in the future. + "VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION": + lambda: bool( + int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION", '0'))), } # end-env-vars-definition diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index b054fc6..391e41d 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -129,3 +129,17 @@ # This is the problem in vllm-ascend # Future Plan: # Remove this patch once pytorch 2.7.0 is supported for vllm ascend. +# +# ** File: worker/patch_common/patch_sampler.py ** +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# 1. `vllm.v1.sample.sampler.Sampler.apply_top_k_top_p` +# Why: +# We need to use the patched `apply_top_k_top_p` in `sample`. +# The mainly reason to overwrite `apply_top_k_top_p` is +# to improve performance. +# How: +# Re-implementation the `apply_top_k_top_p` function by pytorch +# Related PR (if no, explain why): +# - https://github.com/vllm-project/vllm-ascend/pull/1732 +# Future Plan: +# Revert it when the ascend scatter performance improves. diff --git a/vllm_ascend/patch/worker/patch_common/__init__.py b/vllm_ascend/patch/worker/patch_common/__init__.py index 7617809..d78b6dc 100644 --- a/vllm_ascend/patch/worker/patch_common/__init__.py +++ b/vllm_ascend/patch/worker/patch_common/__init__.py @@ -21,4 +21,5 @@ import vllm_ascend.patch.worker.patch_common.patch_utils # noqa isort:skip import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker # noqa +import vllm_ascend.patch.worker.patch_common.patch_sampler # noqa import vllm_ascend.patch.worker.patch_common.patch_spec_decode_worker # noqa diff --git a/vllm_ascend/patch/worker/patch_common/patch_sampler.py b/vllm_ascend/patch/worker/patch_common/patch_sampler.py new file mode 100644 index 0000000..e745bf0 --- /dev/null +++ b/vllm_ascend/patch/worker/patch_common/patch_sampler.py @@ -0,0 +1,83 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Optional + +import torch +import torch_npu +from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample + +from vllm_ascend import envs + + +def apply_top_k_top_p( + logits: torch.Tensor, + k: torch.Tensor, + p: torch.Tensor, +) -> torch.Tensor: + if p is not None and k is not None: + # npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p) + return torch_npu.npu_top_k_top_p(logits, p, k) + + probs = logits.softmax(dim=-1) + probs_sort, _ = probs.sort(dim=-1, descending=False) + + if k is not None: + top_k_count = probs_sort.size(1) - k.to(torch.long) # shape: (batch, ) + top_k_count = top_k_count.unsqueeze(dim=1) + top_k_cutoff = probs_sort.gather(-1, top_k_count) + + # Make sure the no top-k rows are no-op. + no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1) + top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf")) + + elements_to_discard = probs < top_k_cutoff + logits.masked_fill_(elements_to_discard, -float("inf")) + + if p is not None: + cumprob = torch.cumsum(probs_sort, dim=-1) + top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1) + top_p_mask[:, -1] = False # at least one + + top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1) + top_p_cutoff = probs_sort.gather(-1, top_p_count) + elements_to_discard = probs < top_p_cutoff + logits.masked_fill_(elements_to_discard, -float("inf")) + + return logits + + +def topk_topp_forward_native( + self, + logits: torch.Tensor, + generators: dict[int, torch.Generator], + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], +) -> torch.Tensor: + """ + PyTorch-native implementation of top-k and top-p sampling. + + The logits tensor may be updated in-place. + """ + logits = apply_top_k_top_p(logits, k, p) + probs = logits.softmax(dim=-1, dtype=torch.float32) + return random_sample(probs, generators) + + +if envs.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION: + TopKTopPSampler.forward_native = topk_topp_forward_native