[Refactor]Refactor sampler (#2050)

Refactor Sampler implementation from patch way to inherit from vLLM
Sampler interface.

Next step: Make the op `TopKTopPSampler` in vLLM support custom ops
register mechanism

- vLLM version: v0.10.0
- vLLM main:
61a6905ab0

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan
2025-07-30 08:47:22 +08:00
committed by GitHub
parent b6a7f07c70
commit 9b67c87b14
8 changed files with 108 additions and 150 deletions

View File

@@ -1,46 +0,0 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
import importlib
import os
from unittest import mock
import torch
from vllm.v1.sample.ops import topk_topp_sampler
from tests.ut.base import TestBase
class TestTopKTopPSamplerOptimize(TestBase):
@mock.patch.dict(os.environ,
{"VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION": "1"})
@mock.patch("torch_npu.npu_top_k_top_p")
def test_npu_topk_topp_called_when_optimized(self, mock_npu_op):
# We have to patch and reload because the patch will take effect
# only after VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE is set.
import vllm_ascend.patch.worker.patch_common.patch_sampler
importlib.reload(vllm_ascend.patch.worker.patch_common.patch_sampler)
mock_npu_op.return_value = (torch.randn(1, 3))
sampler = topk_topp_sampler.TopKTopPSampler()
logits = torch.tensor([[1.0, 2.0, 3.0]])
k = torch.tensor([2])
p = torch.tensor([0.9])
generators = {0: torch.Generator()}
generators[0].manual_seed(42)
sampler.forward_native(logits, generators, k, p)
mock_npu_op.assert_called_once_with(logits, p, k)

View File

@@ -0,0 +1,32 @@
from unittest import mock
import torch
from tests.ut.base import TestBase
from vllm_ascend.sample.sampler import AscendSampler, AscendTopKTopPSampler
class TestAscendSampler(TestBase):
def test_init_with_raw_logprobs(self):
sampler = AscendSampler(logprobs_mode="raw_logprobs")
self.assertEqual(sampler.logprobs_mode, "raw_logprobs")
self.assertTrue(hasattr(sampler, 'topk_topp_sampler'))
self.assertIsInstance(sampler.topk_topp_sampler, AscendTopKTopPSampler)
class TestAscendTopKTopPSampler(TestBase):
@mock.patch("torch_npu.npu_top_k_top_p")
def test_npu_topk_topp_called_when_optimized(self, mock_npu_op):
mock_npu_op.return_value = (torch.randn(1, 3))
sampler = AscendTopKTopPSampler()
logits = torch.tensor([[1.0, 2.0, 3.0]])
k = torch.tensor([2])
p = torch.tensor([0.9])
generators = {0: torch.Generator()}
generators[0].manual_seed(42)
sampler.forward_native(logits, generators, k, p)
mock_npu_op.assert_called_once_with(logits, p, k)

View File

@@ -128,11 +128,11 @@ env_variables: Dict[str, Callable[[], Any]] = {
"VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE":
lambda: int(
os.getenv("VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE", 64)),
# Whether to enable the topk optimization. It's disabled by default for experimental support
# We'll make it enabled by default in the future.
# Whether to enable the topk optimization. It's enabled by default. Please set to False if you hit any issue.
# We'll remove this flag in the future once it's stable enough.
"VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION":
lambda: bool(
int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION", '0'))),
int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION", '1'))),
# `LLMDataDistCMgrConnector` required variable. `DISAGGREGATED_PREFILL_RANK_TABLE_PATH` is
# used for llmdatadist to build the communication topology for kv cache transfer, it is

View File

@@ -88,21 +88,7 @@
# Future Plan:
# Remove this patch once pytorch 2.7.0 is supported for vllm ascend.
#
# ** File: worker/patch_common/patch_sampler.py **
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.v1.sample.sampler.Sampler.apply_top_k_top_p`
# Why:
# We need to use the patched `apply_top_k_top_p` in `sample`.
# The mainly reason to overwrite `apply_top_k_top_p` is
# to improve performance.
# How
# Re-implementation the `apply_top_k_top_p` function by pytorch
# Related PR (if no, explain why):
# - https://github.com/vllm-project/vllm-ascend/pull/1732
# Future Plan:
# Revert it when the ascend scatter performance improves.
#
# ** File: worker/patch_common/patch_sampler.py **
# ** File: worker/patch_0_10_0/patch_sampler_gather_logprobs.py **
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.v1.sample.sampler.Sampler.gather_logprobs`
# Why:

View File

@@ -21,4 +21,3 @@ import vllm_ascend.patch.worker.patch_common.patch_utils # noqa isort:skip
import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
import vllm_ascend.patch.worker.patch_common.patch_linear # noqa
import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa
import vllm_ascend.patch.worker.patch_common.patch_sampler # noqa

View File

@@ -1,83 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional
import torch
import torch_npu
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
from vllm_ascend import envs
def apply_top_k_top_p(
logits: torch.Tensor,
k: torch.Tensor,
p: torch.Tensor,
) -> torch.Tensor:
if p is not None and k is not None:
# npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p)
return torch_npu.npu_top_k_top_p(logits, p, k)
probs = logits.softmax(dim=-1)
probs_sort, _ = probs.sort(dim=-1, descending=False)
if k is not None:
top_k_count = probs_sort.size(1) - k.to(torch.long) # shape: (batch, )
top_k_count = top_k_count.unsqueeze(dim=1)
top_k_cutoff = probs_sort.gather(-1, top_k_count)
# Make sure the no top-k rows are no-op.
no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
elements_to_discard = probs < top_k_cutoff
logits.masked_fill_(elements_to_discard, -float("inf"))
if p is not None:
cumprob = torch.cumsum(probs_sort, dim=-1)
top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
top_p_mask[:, -1] = False # at least one
top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
top_p_cutoff = probs_sort.gather(-1, top_p_count)
elements_to_discard = probs < top_p_cutoff
logits.masked_fill_(elements_to_discard, -float("inf"))
return logits
def topk_topp_forward_native(
self,
logits: torch.Tensor,
generators: dict[int, torch.Generator],
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
) -> torch.Tensor:
"""
PyTorch-native implementation of top-k and top-p sampling.
The logits tensor may be updated in-place.
"""
logits = apply_top_k_top_p(logits, k, p)
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators)
if envs.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION:
TopKTopPSampler.forward_native = topk_topp_forward_native

View File

@@ -0,0 +1,62 @@
import torch
import torch_npu
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
from vllm.v1.sample.sampler import Sampler
class AscendSampler(Sampler):
def __init__(self, logprobs_mode="raw_logprobs"):
# TODO: support logprobs_mode in vllm-ascend
super().__init__(logprobs_mode=logprobs_mode)
self.topk_topp_sampler = AscendTopKTopPSampler()
class AscendTopKTopPSampler(TopKTopPSampler):
def _apply_top_k_top_p(
self,
logits: torch.Tensor,
k: torch.Tensor,
p: torch.Tensor,
) -> torch.Tensor:
if p is not None and k is not None:
# npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p)
return torch_npu.npu_top_k_top_p(logits, p, k)
if p is None and k is None:
return logits
probs = logits.softmax(dim=-1)
probs_sort, _ = probs.sort(dim=-1, descending=False)
if k is not None:
top_k_count = probs_sort.size(1) - k.to(
torch.long) # shape: (batch, )
top_k_count = top_k_count.unsqueeze(dim=1)
top_k_cutoff = probs_sort.gather(-1, top_k_count)
# Make sure the no top-k rows are no-op.
no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
elements_to_discard = probs < top_k_cutoff
logits.masked_fill_(elements_to_discard, -float("inf"))
if p is not None:
cumprob = torch.cumsum(probs_sort, dim=-1)
top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
top_p_mask[:, -1] = False # at least one
top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
top_p_cutoff = probs_sort.gather(-1, top_p_count)
elements_to_discard = probs < top_p_cutoff
logits.masked_fill_(elements_to_discard, -float("inf"))
return logits
def forward_native(self, logits, generators, k, p):
"""Override pytorch native implementation to torch_npu"""
logits = self._apply_top_k_top_p(logits, k, p)
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators)

View File

@@ -64,7 +64,6 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -72,6 +71,7 @@ from vllm.v1.worker.utils import (bind_kv_cache, gather_mm_placeholders,
sanity_check_mm_encoder_outputs,
scatter_mm_placeholders)
from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
@@ -165,7 +165,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.device = device
self.dtype = self.model_config.dtype
self.sampler = Sampler()
if envs.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION:
# TODO: drop the env config to use ascend sampler by default
from vllm_ascend.sample.sampler import AscendSampler
self.sampler = AscendSampler()
else:
from vllm.v1.sample.sampler import Sampler
self.sampler = Sampler()
# Lazy initialization, these will be set after __init__
self.kv_caches: List[torch.Tensor] = []