[perf][bugfix] improve performance of rejection sampler and eliminate HD synchronize in TopKTopPSampler (#4154)
### What this PR does / why we need it?
1. Use the optimized `apply_top_k_top_p` for the NPU platform in the rejection
sampler; avoiding the element scatter reduces TPOT by ~26 ms with bs=24
per DP.
2. <del>Avoid the D2H synchronization before calling `npu_top_k_top_p`
introduced by parameter validation, which improves inference speed with
`async_scheduling` enabled.</del> To eliminate the D2H
synchronization that parameter validation introduces before calling
`npu_top_k_top_p`, we drop this fused operator directly: its
performance gain is not significant once `async_scheduling` is enabled,
and it may introduce accuracy problems. (A minimal sketch of this
synchronization point follows the list.)
3. Refactor the implementation of `AscendTopKTopPSampler` to align with that
of vLLM.
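
As context for item 2, here is a minimal sketch of where the D2H synchronization comes from. The helper names are illustrative, not code from this PR: converting a device tensor to a Python `int` during parameter validation goes through `Tensor.item()`, which blocks the host until the device value is available, whereas a sort-based mask keeps everything on the device.

```python
import torch

def validate_k_blocking(top_k: torch.Tensor) -> bool:
    # int(...) on a device tensor calls Tensor.item(): a D2H copy that stalls
    # the host until the device stream has produced top_k.max().
    return 1 <= int(top_k.max()) <= 1024

def top_k_mask_on_device(logits: torch.Tensor, top_k: torch.Tensor) -> torch.Tensor:
    # Sort-based top-k masking that never moves top_k to the host, so the host
    # can keep queueing work ahead of the device under async_scheduling.
    cutoff_idx = (logits.size(-1) - top_k.to(torch.long)).unsqueeze(-1)
    cutoff = logits.sort(dim=-1).values.gather(-1, cutoff_idx)
    return logits.masked_fill(logits < cutoff, float("-inf"))
```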
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
E2E serving tests with `k=500` and `p=0.95`, combined with
`async_scheduling`, in single-node and wide-EP scenarios; an example request
is sketched below.
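
For reproducibility, a request of the kind used in those tests might look like the following (the server URL and model name are placeholders for whatever the vllm-ascend deployment serves; `top_k` is passed through `extra_body` as a vLLM extra sampling parameter):

```python
from openai import OpenAI

# Placeholder endpoint/model for a vllm-ascend server started with async_scheduling.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

resp = client.completions.create(
    model="placeholder-model",
    prompt="The quick brown fox",
    max_tokens=64,
    top_p=0.95,
    extra_body={"top_k": 500},  # non-OpenAI sampling param accepted by vLLM
)
print(resp.choices[0].text)
```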
- vLLM version: v0.11.0
- vLLM main: 83f478bb19
---------
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Co-authored-by: realliujiaxu <realliujiaxu@163.com>
@@ -1,7 +1,3 @@
from unittest import mock

import torch

from tests.ut.base import TestBase
from vllm_ascend.sample.sampler import AscendSampler, AscendTopKTopPSampler
@@ -13,23 +9,3 @@ class TestAscendSampler(TestBase):
        self.assertEqual(sampler.logprobs_mode, "raw_logprobs")
        self.assertTrue(hasattr(sampler, 'topk_topp_sampler'))
        self.assertIsInstance(sampler.topk_topp_sampler, AscendTopKTopPSampler)


class TestAscendTopKTopPSampler(TestBase):

    @mock.patch("vllm_ascend.sample.sampler.random_sample")
    @mock.patch("torch_npu.npu_top_k_top_p")
    def test_npu_topk_topp_called_when_optimized(self, mock_npu_op,
                                                 mock_random_sample):
        mock_npu_op.return_value = (torch.randn(1, 3))
        mock_random_sample.return_value = torch.randn(3)
        sampler = AscendTopKTopPSampler()

        logits = torch.tensor([[1.0, 2.0, 3.0]])
        k = torch.tensor([2])
        p = torch.tensor([0.9])
        generators = {0: torch.Generator()}
        generators[0].manual_seed(42)

        sampler.forward_native(logits, generators, k, p)
        mock_npu_op.assert_called_once_with(logits, p, k)
@@ -2,13 +2,11 @@
from typing import Optional

import torch
import torch_npu
from vllm.triton_utils import HAS_TRITON, tl, triton
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.sample.rejection_sampler import generate_uniform_probs

from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
from vllm_ascend.sample.sampler import apply_top_k_top_p

PLACEHOLDER_TOKEN_ID = -1
GREEDY_TEMPERATURE = -1
@@ -80,11 +78,6 @@ def apply_sampling_constraints(
        num_tokens,
    )

    if get_ascend_device_type(
    ) != AscendDeviceType._310P and top_p is not None and top_k is not None and 1 <= int(
            top_k.max()) <= 1024:
        return torch_npu.npu_top_k_top_p(logits, top_p.to(logits.dtype), top_k)
    else:
        # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
        # which is slow for large vocab sizes. This may cause performance issues.
        return apply_top_k_top_p(logits, top_k, top_p)
@@ -1,11 +1,9 @@
import torch
import torch_npu
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
from vllm.v1.sample.sampler import Sampler

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
                               global_stream, npu_stream_switch)
from vllm_ascend.utils import global_stream, npu_stream_switch

DEFAULT_LOGPROBS_MODE = "raw_logprobs"
@@ -65,25 +63,38 @@ class AscendSampler(Sampler):

class AscendTopKTopPSampler(TopKTopPSampler):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.apply_top_k_top_p = apply_top_k_top_p

    def set_q_event(self, q, event):
        # Pass in async exponential results.
        # Also pass in event to prevent synchronize errors.
        self.q = q
        self.async_event = event

    def _apply_top_k_top_p(
        self,
    def forward_native(self, logits, generators, k, p):
        """Override pytorch native implementation to torch_npu"""
        logits = self.apply_top_k_top_p(logits, k, p)
        logits_to_return = None
        if self.logprobs_mode == "processed_logits":
            logits_to_return = logits
        elif self.logprobs_mode == "processed_logprobs":
            logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)

        probs = logits.softmax(dim=-1, dtype=torch.float32)
        if get_ascend_config().enable_async_exponential == 1:
            # Add synchronize to prevent synchronize error.
            self.async_event.synchronize()
            return probs.div_(self.q).argmax(dim=-1).view(-1), logits_to_return
        return random_sample(probs, generators), logits_to_return


def apply_top_k_top_p(
    logits: torch.Tensor,
    k: torch.Tensor,
    p: torch.Tensor,
) -> torch.Tensor:
    # npu_top_k_top_p uses the operator aclnnApplyTopKTopP, but aclnnApplyTopKTopP currently does not support 310P
    if get_ascend_device_type(
    ) != AscendDeviceType._310P and p is not None and k is not None and 1 <= int(
            k.max()) <= 1024:
        # npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p)
        return torch_npu.npu_top_k_top_p(logits, p, k)

    if p is None and k is None:
        return logits
@@ -91,8 +102,7 @@ class AscendTopKTopPSampler(TopKTopPSampler):
    probs_sort, _ = probs.sort(dim=-1, descending=False)

    if k is not None:
        top_k_count = probs_sort.size(1) - k.to(
            torch.long)  # shape: (batch, )
        top_k_count = probs_sort.size(1) - k.to(torch.long)  # shape: (batch, )
        top_k_count = top_k_count.unsqueeze(dim=1)
        top_k_cutoff = probs_sort.gather(-1, top_k_count)
@@ -114,19 +124,3 @@ class AscendTopKTopPSampler(TopKTopPSampler):
    logits.masked_fill_(elements_to_discard, -float("inf"))

    return logits

    def forward_native(self, logits, generators, k, p):
        """Override pytorch native implementation to torch_npu"""
        logits = self._apply_top_k_top_p(logits, k, p)
        logits_to_return = None
        if self.logprobs_mode == "processed_logits":
            logits_to_return = logits
        elif self.logprobs_mode == "processed_logprobs":
            logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)

        probs = logits.softmax(dim=-1, dtype=torch.float32)
        if get_ascend_config().enable_async_exponential == 1:
            # Add synchronize to prevent synchronize error.
            self.async_event.synchronize()
            return probs.div_(self.q).argmax(dim=-1).view(-1), logits_to_return
        return random_sample(probs, generators), logits_to_return
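
A side note on the `enable_async_exponential` branch above: `probs.div_(self.q).argmax(dim=-1)` is the exponential-clocks (Gumbel-max style) sampling trick. Assuming `self.q` holds i.i.d. Exponential(1) noise generated asynchronously on a side stream, the argmax is distributed exactly as a categorical sample from `probs`. A CPU-only sketch (not code from this PR) that checks the distribution empirically:

```python
import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.3, 0.6])
counts = torch.zeros(3)
for _ in range(20000):
    q = torch.empty(3).exponential_()   # stands in for the pre-generated self.q
    counts[(probs / q).argmax()] += 1
print(counts / counts.sum())            # approaches tensor([0.1, 0.3, 0.6])
```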