Optimize some rejectsampler functions to make npu op launch non-blocking (#4587)
### What this PR does / why we need it?
- Vetorize the loop (but change not output) in some rejectsampler
functions include: `expand_pytorch`, `sample_recovered_tokens_pytorch`,
`rejection_random_sample_pytorch`, `sample_recovered_tokens`.
- Remove synchronize-launch torchnpu operator in them to accelerate
sampling + MTP postprocess.
### Does this PR introduce _any_ user-facing change?
- No
### How was this patch tested?
- We tested this change with the serve&bench command:
```
===== serve =====
vllm serve $LOCAL_CKPT_DIR \
--host 0.0.0.0 \
--port 8000 \
--data-parallel-size 4 \
--data-parallel-size-local 2 \
--data-parallel-address $MASTER_NODE_IP \
--data-parallel-start-rank $((2*VC_TASK_INDEX)) \
--data-parallel-rpc-port 13387 \
--tensor-parallel-size 8 \
--seed 1024 \
--enable-expert-parallel \
--served-model-name $NAME \
--max-model-len 4096 \
--max-num-seqs 16 \
--trust-remote-code \
--gpu-memory-utilization 0.90 \
$headless \
--speculative_config '{"method": "deepseek_mtp", "num_speculative_tokens": 1}' \
--additional-config '{"ascend_scheduler_config":{"enabled":false, "enable_chunked_prefill":true, "chunked_prefill_enabled":true}}'
==== bench =====
vllm bench serve --model $LOCAL_CKPT_DIR --served-model-name DeepseekV3ForCausalLM \
--dataset-name spec_bench --spec-bench-output-len 2048 \
--dataset-path question.jsonl \
--top-p 1.0 --temperature 0.8 \
--ignore-eos \
--num-prompts 64 --trust-remote-code --base-url "http://0.0.0.0:8000" --request-rate 64
```
- In this case, our rj optimization can reduce TPOT from 84.94ms to
64.61ms, about 23% gain.
## before
<img width="1068" height="830" alt="image"
src="https://github.com/user-attachments/assets/278ac878-b49d-4588-b87c-316ca4d537f5"
/>
## after
<img width="781" height="756" alt="image"
src="https://github.com/user-attachments/assets/0c6d37ad-ed77-40b3-a1be-4933c468365c"
/>
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: ZongYuan Zhan <zhanzy178@gmail.com>
Co-authored-by: Yizhou <136800916+yiz-liu@users.noreply.github.com>
This commit is contained in:
@@ -27,8 +27,22 @@ GREEDY_TEMPERATURE = 0.0
|
||||
MAX_SPEC_LEN = 8 # Used as MAX_NUM_TOKENS in expand_batch_to_tokens
|
||||
|
||||
|
||||
def mock_pin_memory(original_func):
|
||||
|
||||
def func_wo_pin_memory(*args, **kwargs):
|
||||
if kwargs.get('pin_memory', False):
|
||||
kwargs['pin_memory'] = False
|
||||
return original_func(*args, **kwargs)
|
||||
|
||||
return func_wo_pin_memory
|
||||
|
||||
|
||||
class TestAscendRejectionSampler(TestBase):
|
||||
|
||||
@patch('torch.arange', new=mock_pin_memory(torch.arange))
|
||||
@patch('torch.ones', new=mock_pin_memory(torch.ones))
|
||||
@patch('torch.full', new=mock_pin_memory(torch.full))
|
||||
@patch('torch.tensor', new=mock_pin_memory(torch.tensor))
|
||||
def test_rejection_greedy_sample_pytorch(self):
|
||||
"""Test greedy rejection sampling: stop when draft doesn't match, otherwise append bonus token"""
|
||||
batch_size = 2
|
||||
@@ -60,6 +74,10 @@ class TestAscendRejectionSampler(TestBase):
|
||||
assert output_token_ids[1, 0].item() == 20
|
||||
assert output_token_ids[1, 2].item() == PLACEHOLDER_TOKEN_ID
|
||||
|
||||
@patch('torch.arange', new=mock_pin_memory(torch.arange))
|
||||
@patch('torch.ones', new=mock_pin_memory(torch.ones))
|
||||
@patch('torch.full', new=mock_pin_memory(torch.full))
|
||||
@patch('torch.tensor', new=mock_pin_memory(torch.tensor))
|
||||
def test_rejection_random_sample_pytorch(self):
|
||||
"""Test random rejection sampling: accept based on uniform probability"""
|
||||
batch_size = 2
|
||||
@@ -104,6 +122,10 @@ class TestAscendRejectionSampler(TestBase):
|
||||
assert output_token_ids[0, 1].item() == 0
|
||||
assert output_token_ids[0, 2].item() == 100
|
||||
|
||||
@patch('torch.arange', new=mock_pin_memory(torch.arange))
|
||||
@patch('torch.ones', new=mock_pin_memory(torch.ones))
|
||||
@patch('torch.full', new=mock_pin_memory(torch.full))
|
||||
@patch('torch.tensor', new=mock_pin_memory(torch.tensor))
|
||||
def test_expand_pytorch(self):
|
||||
"""Test expand_pytorch functionality"""
|
||||
input_ptr = torch.tensor([10, 20, 30], dtype=torch.int32)
|
||||
@@ -122,6 +144,10 @@ class TestAscendRejectionSampler(TestBase):
|
||||
expected = torch.tensor([10, 10, 20, 20, 20, 30, 30])
|
||||
assert torch.equal(output_ptr, expected)
|
||||
|
||||
@patch('torch.arange', new=mock_pin_memory(torch.arange))
|
||||
@patch('torch.ones', new=mock_pin_memory(torch.ones))
|
||||
@patch('torch.full', new=mock_pin_memory(torch.full))
|
||||
@patch('torch.tensor', new=mock_pin_memory(torch.tensor))
|
||||
def test_expand_batch_to_tokens(self):
|
||||
"""Test expand_batch_to_tokens wrapper"""
|
||||
x = torch.tensor([10, 20, 30])
|
||||
@@ -154,6 +180,10 @@ class TestAscendRejectionSampler(TestBase):
|
||||
expected = torch.tensor([10, 10, 20, 20, 20, 30, 30])
|
||||
assert torch.equal(result, expected)
|
||||
|
||||
@patch('torch.arange', new=mock_pin_memory(torch.arange))
|
||||
@patch('torch.ones', new=mock_pin_memory(torch.ones))
|
||||
@patch('torch.full', new=mock_pin_memory(torch.full))
|
||||
@patch('torch.tensor', new=mock_pin_memory(torch.tensor))
|
||||
def test_sample_recovered_tokens_pytorch_ngram(self):
|
||||
"""Test recovered token sampling under n-gram mode"""
|
||||
output_token_ids = torch.empty(2, dtype=torch.int32)
|
||||
@@ -184,6 +214,10 @@ class TestAscendRejectionSampler(TestBase):
|
||||
assert output_token_ids[0].item() == 0
|
||||
assert output_token_ids[1].item() == 1
|
||||
|
||||
@patch('torch.arange', new=mock_pin_memory(torch.arange))
|
||||
@patch('torch.ones', new=mock_pin_memory(torch.ones))
|
||||
@patch('torch.full', new=mock_pin_memory(torch.full))
|
||||
@patch('torch.tensor', new=mock_pin_memory(torch.tensor))
|
||||
def test_sample_recovered_tokens_pytorch_autoregressive(self):
|
||||
"""Test recovered token sampling for autoregressive models"""
|
||||
output_token_ids = torch.empty(2, dtype=torch.int32)
|
||||
|
||||
Reference in New Issue
Block a user