[Refactor] Refactor sampler (#2050)
Refactor the Sampler implementation from the monkey-patch approach to
inheriting from the vLLM Sampler interface.
Next step: make the `TopKTopPSampler` op in vLLM support the custom-ops
registration mechanism.
- vLLM version: v0.10.0
- vLLM main: 61a6905ab0
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
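In short, the commit replaces runtime attribute patching with a subclass of the
vLLM interface. A minimal before/after sketch, condensed from the diffs below:

    # Before: monkey-patch vLLM's op at import time (removed by this commit).
    from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
    TopKTopPSampler.forward_native = topk_topp_forward_native

    # After: inherit from the vLLM interface and override the hook.
    class AscendTopKTopPSampler(TopKTopPSampler):
        def forward_native(self, logits, generators, k, p):
            ...  # NPU-optimized top-k/top-p sampling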
@@ -1,46 +0,0 @@
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# This file is a part of the vllm-ascend project.
-
-import importlib
-import os
-from unittest import mock
-
-import torch
-from vllm.v1.sample.ops import topk_topp_sampler
-
-from tests.ut.base import TestBase
-
-
-class TestTopKTopPSamplerOptimize(TestBase):
-
-    @mock.patch.dict(os.environ,
-                     {"VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION": "1"})
-    @mock.patch("torch_npu.npu_top_k_top_p")
-    def test_npu_topk_topp_called_when_optimized(self, mock_npu_op):
-        # We have to patch and reload because the patch will take effect
-        # only after VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE is set.
-        import vllm_ascend.patch.worker.patch_common.patch_sampler
-        importlib.reload(vllm_ascend.patch.worker.patch_common.patch_sampler)
-
-        mock_npu_op.return_value = (torch.randn(1, 3))
-        sampler = topk_topp_sampler.TopKTopPSampler()
-
-        logits = torch.tensor([[1.0, 2.0, 3.0]])
-        k = torch.tensor([2])
-        p = torch.tensor([0.9])
-        generators = {0: torch.Generator()}
-        generators[0].manual_seed(42)
-
-        sampler.forward_native(logits, generators, k, p)
-        mock_npu_op.assert_called_once_with(logits, p, k)
tests/ut/sample/test_sampler.py (new file, 32 lines)
@@ -0,0 +1,32 @@
+from unittest import mock
+
+import torch
+
+from tests.ut.base import TestBase
+from vllm_ascend.sample.sampler import AscendSampler, AscendTopKTopPSampler
+
+
+class TestAscendSampler(TestBase):
+
+    def test_init_with_raw_logprobs(self):
+        sampler = AscendSampler(logprobs_mode="raw_logprobs")
+        self.assertEqual(sampler.logprobs_mode, "raw_logprobs")
+        self.assertTrue(hasattr(sampler, 'topk_topp_sampler'))
+        self.assertIsInstance(sampler.topk_topp_sampler, AscendTopKTopPSampler)
+
+
+class TestAscendTopKTopPSampler(TestBase):
+
+    @mock.patch("torch_npu.npu_top_k_top_p")
+    def test_npu_topk_topp_called_when_optimized(self, mock_npu_op):
+        mock_npu_op.return_value = (torch.randn(1, 3))
+        sampler = AscendTopKTopPSampler()
+
+        logits = torch.tensor([[1.0, 2.0, 3.0]])
+        k = torch.tensor([2])
+        p = torch.tensor([0.9])
+        generators = {0: torch.Generator()}
+        generators[0].manual_seed(42)
+
+        sampler.forward_native(logits, generators, k, p)
+        mock_npu_op.assert_called_once_with(logits, p, k)
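Because `torch_npu.npu_top_k_top_p` is mocked, these tests run without Ascend
hardware; the final assertion also pins the NPU op's (logits, p, k) argument
order. A typical local run (assuming the repository's usual pytest setup):

    pytest tests/ut/sample/test_sampler.py -v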
@@ -128,11 +128,11 @@ env_variables: Dict[str, Callable[[], Any]] = {
     "VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE":
     lambda: int(
         os.getenv("VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE", 64)),
-    # Whether to enable the topk optimization. It's disabled by default for experimental support
-    # We'll make it enabled by default in the future.
+    # Whether to enable the topk optimization. It's enabled by default. Please set to False if you hit any issue.
+    # We'll remove this flag in the future once it's stable enough.
     "VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION":
     lambda: bool(
-        int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION", '0'))),
+        int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION", '1'))),

     # `LLMDataDistCMgrConnector` required variable. `DISAGGREGATED_PREFILL_RANK_TABLE_PATH` is
     # used for llmdatadist to build the communication topology for kv cache transfer, it is
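A minimal sketch of opting out of the new default (the variable name is taken
from the hunk above; set it before `vllm_ascend.envs` is first read):

    import os
    os.environ["VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION"] = "0"  # opt out

    from vllm_ascend import envs
    # bool(int("0")) is False, so the vLLM-native sampler path is used.
    assert not envs.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION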
@@ -88,21 +88,7 @@
 # Future Plan:
 #   Remove this patch once pytorch 2.7.0 is supported for vllm ascend.
 #
-# ** File: worker/patch_common/patch_sampler.py **
+# ** File: worker/patch_0_10_0/patch_sampler_gather_logprobs.py **
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-# 1. `vllm.v1.sample.sampler.Sampler.apply_top_k_top_p`
-#    Why:
-#       We need to use the patched `apply_top_k_top_p` in `sample`.
-#       The mainly reason to overwrite `apply_top_k_top_p` is
-#       to improve performance.
-#    How:
-#       Re-implementation the `apply_top_k_top_p` function by pytorch
-#    Related PR (if no, explain why):
-#       - https://github.com/vllm-project/vllm-ascend/pull/1732
-#    Future Plan:
-#       Revert it when the ascend scatter performance improves.
-#
-# ** File: worker/patch_common/patch_sampler.py **
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 # 1. `vllm.v1.sample.sampler.Sampler.gather_logprobs`
 #    Why:
@@ -21,4 +21,3 @@ import vllm_ascend.patch.worker.patch_common.patch_utils  # noqa isort:skip
 import vllm_ascend.patch.worker.patch_common.patch_distributed  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_linear  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_minicpm  # noqa
-import vllm_ascend.patch.worker.patch_common.patch_sampler  # noqa
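With that import gone, loading the worker patches no longer mutates vLLM's
sampler class. A quick way to confirm (a sketch, assuming both packages are
installed):

    from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
    import vllm_ascend.patch.worker.patch_common  # applies remaining patches
    # forward_native is still the vLLM-native implementation:
    print(TopKTopPSampler.forward_native.__module__)
    # -> 'vllm.v1.sample.ops.topk_topp_sampler'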
@@ -1,83 +0,0 @@
-#
-# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
-# SPDX-License-Identifier: Apache-2.0
-# This file is a part of the vllm-ascend project.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-from typing import Optional
-
-import torch
-import torch_npu
-from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
-
-from vllm_ascend import envs
-
-
-def apply_top_k_top_p(
-    logits: torch.Tensor,
-    k: torch.Tensor,
-    p: torch.Tensor,
-) -> torch.Tensor:
-    if p is not None and k is not None:
-        # npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p)
-        return torch_npu.npu_top_k_top_p(logits, p, k)
-
-    probs = logits.softmax(dim=-1)
-    probs_sort, _ = probs.sort(dim=-1, descending=False)
-
-    if k is not None:
-        top_k_count = probs_sort.size(1) - k.to(torch.long)  # shape: (batch, )
-        top_k_count = top_k_count.unsqueeze(dim=1)
-        top_k_cutoff = probs_sort.gather(-1, top_k_count)
-
-        # Make sure the no top-k rows are no-op.
-        no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
-        top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
-
-        elements_to_discard = probs < top_k_cutoff
-        logits.masked_fill_(elements_to_discard, -float("inf"))
-
-    if p is not None:
-        cumprob = torch.cumsum(probs_sort, dim=-1)
-        top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
-        top_p_mask[:, -1] = False  # at least one
-
-        top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
-        top_p_cutoff = probs_sort.gather(-1, top_p_count)
-        elements_to_discard = probs < top_p_cutoff
-        logits.masked_fill_(elements_to_discard, -float("inf"))
-
-    return logits
-
-
-def topk_topp_forward_native(
-    self,
-    logits: torch.Tensor,
-    generators: dict[int, torch.Generator],
-    k: Optional[torch.Tensor],
-    p: Optional[torch.Tensor],
-) -> torch.Tensor:
-    """
-    PyTorch-native implementation of top-k and top-p sampling.
-
-    The logits tensor may be updated in-place.
-    """
-    logits = apply_top_k_top_p(logits, k, p)
-    probs = logits.softmax(dim=-1, dtype=torch.float32)
-    return random_sample(probs, generators)
-
-
-if envs.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION:
-    TopKTopPSampler.forward_native = topk_topp_forward_native
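The pure-PyTorch fallback above (retained verbatim in the new
`AscendTopKTopPSampler` below) finds the top-k cutoff via an ascending sort.
A small CPU-only worked example, with illustrative numbers:

    import torch

    probs = torch.tensor([[0.1, 0.2, 0.3, 0.4]])
    k = torch.tensor([2])

    probs_sort, _ = probs.sort(dim=-1, descending=False)  # already ascending here
    top_k_count = (probs_sort.size(1) - k.to(torch.long)).unsqueeze(1)  # 4 - 2 = 2
    top_k_cutoff = probs_sort.gather(-1, top_k_count)  # probs_sort[:, 2] = 0.3
    print(probs < top_k_cutoff)  # tensor([[ True,  True, False, False]])
    # Exactly the k=2 largest probabilities (0.3 and 0.4) survive the mask.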
vllm_ascend/sample/sampler.py (new file, 62 lines)
@@ -0,0 +1,62 @@
+import torch
+import torch_npu
+from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
+from vllm.v1.sample.sampler import Sampler
+
+
+class AscendSampler(Sampler):
+
+    def __init__(self, logprobs_mode="raw_logprobs"):
+        # TODO: support logprobs_mode in vllm-ascend
+        super().__init__(logprobs_mode=logprobs_mode)
+        self.topk_topp_sampler = AscendTopKTopPSampler()
+
+
+class AscendTopKTopPSampler(TopKTopPSampler):
+
+    def _apply_top_k_top_p(
+        self,
+        logits: torch.Tensor,
+        k: torch.Tensor,
+        p: torch.Tensor,
+    ) -> torch.Tensor:
+        if p is not None and k is not None:
+            # npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p)
+            return torch_npu.npu_top_k_top_p(logits, p, k)
+
+        if p is None and k is None:
+            return logits
+
+        probs = logits.softmax(dim=-1)
+        probs_sort, _ = probs.sort(dim=-1, descending=False)
+
+        if k is not None:
+            top_k_count = probs_sort.size(1) - k.to(
+                torch.long)  # shape: (batch, )
+            top_k_count = top_k_count.unsqueeze(dim=1)
+            top_k_cutoff = probs_sort.gather(-1, top_k_count)
+
+            # Make sure the no top-k rows are no-op.
+            no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
+            top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
+
+            elements_to_discard = probs < top_k_cutoff
+            logits.masked_fill_(elements_to_discard, -float("inf"))
+
+        if p is not None:
+            cumprob = torch.cumsum(probs_sort, dim=-1)
+            top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
+            top_p_mask[:, -1] = False  # at least one
+
+            top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
+            top_p_cutoff = probs_sort.gather(-1, top_p_count)
+            elements_to_discard = probs < top_p_cutoff
+            logits.masked_fill_(elements_to_discard, -float("inf"))
+
+        return logits
+
+    def forward_native(self, logits, generators, k, p):
+        """Override pytorch native implementation to torch_npu"""
+        logits = self._apply_top_k_top_p(logits, k, p)
+        probs = logits.softmax(dim=-1, dtype=torch.float32)
+        return random_sample(probs, generators)
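A quick sanity check of the new class's native path (a sketch; assumes
torch_npu is importable so the module loads, and with p=None the
npu_top_k_top_p fast path is skipped, so everything runs through plain
PyTorch):

    import torch
    from vllm_ascend.sample.sampler import AscendTopKTopPSampler

    sampler = AscendTopKTopPSampler()
    logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
    out = sampler._apply_top_k_top_p(logits, k=torch.tensor([2]), p=None)
    print(out)  # tensor([[-inf, -inf, 3., 4.]]) -- only the top-2 logits survive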
@@ -64,7 +64,6 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
                              ModelRunnerOutput)
 from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.sample.sampler import Sampler
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -72,6 +71,7 @@ from vllm.v1.worker.utils import (bind_kv_cache, gather_mm_placeholders,
                                   sanity_check_mm_encoder_outputs,
                                   scatter_mm_placeholders)

+from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
@@ -165,7 +165,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.dp_rank = vllm_config.parallel_config.data_parallel_rank
         self.device = device
         self.dtype = self.model_config.dtype
-        self.sampler = Sampler()
+        if envs.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION:
+            # TODO: drop the env config to use ascend sampler by default
+            from vllm_ascend.sample.sampler import AscendSampler
+
+            self.sampler = AscendSampler()
+        else:
+            from vllm.v1.sample.sampler import Sampler
+
+            self.sampler = Sampler()

         # Lazy initialization, these will be set after __init__
         self.kv_caches: List[torch.Tensor] = []
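The same selection logic, pulled out as a free function for readability (a
hypothetical sketch, not code from this commit; `build_sampler` is an
illustrative name):

    def build_sampler():
        """Pick the Ascend-optimized sampler unless the user opted out."""
        from vllm_ascend import envs

        if envs.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION:
            from vllm_ascend.sample.sampler import AscendSampler
            return AscendSampler()
        from vllm.v1.sample.sampler import Sampler
        return Sampler()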