Optimize rejection sampler functions to make NPU op launches non-blocking (#4587)

### What this PR does / why we need it?
- Vectorize the loops (without changing their outputs) in several rejection
sampler functions: `expand_pytorch`, `sample_recovered_tokens_pytorch`,
`rejection_random_sample_pytorch`, and `sample_recovered_tokens`.
- Remove synchronously launched torch_npu operators from them to accelerate
sampling + MTP postprocessing; a sketch of the pattern follows below.
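
For background on the second bullet, here is a minimal, hypothetical sketch of
the general PyTorch pattern involved (not code from this PR; on Ascend the same
pattern applies through `torch_npu`):
```
import torch

# Hypothetical example: device selection here is an assumption; on Ascend
# the same pattern applies with torch_npu and device "npu".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.randn(1024, device=device)

# Blocking pattern: .item() stalls the host until the device finishes.
total = x.sum().item()

# Non-blocking pattern: allocate the host tensor in pinned (page-locked)
# memory, then enqueue an asynchronous copy; the host keeps launching ops.
host = torch.tensor([1, 2, 3], pin_memory=torch.cuda.is_available())
dev = host.to(device, non_blocking=True)
```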

### Does this PR introduce _any_ user-facing change?
- No

### How was this patch tested?
- We tested this change with the following serve & bench commands:
```
===== serve =====
vllm serve $LOCAL_CKPT_DIR \
        --host 0.0.0.0 \
        --port 8000 \
        --data-parallel-size 4 \
        --data-parallel-size-local 2 \
        --data-parallel-address $MASTER_NODE_IP \
        --data-parallel-start-rank $((2*VC_TASK_INDEX)) \
        --data-parallel-rpc-port 13387 \
        --tensor-parallel-size 8 \
        --seed 1024 \
        --enable-expert-parallel \
        --served-model-name $NAME \
        --max-model-len 4096 \
        --max-num-seqs 16 \
        --trust-remote-code \
        --gpu-memory-utilization 0.90 \
        $headless \
        --speculative_config '{"method": "deepseek_mtp", "num_speculative_tokens": 1}' \
        --additional-config '{"ascend_scheduler_config":{"enabled":false, "enable_chunked_prefill":true, "chunked_prefill_enabled":true}}'

===== bench =====
vllm bench serve --model $LOCAL_CKPT_DIR  --served-model-name DeepseekV3ForCausalLM \
--dataset-name spec_bench --spec-bench-output-len 2048 \
--dataset-path question.jsonl \
--top-p 1.0 --temperature 0.8 \
--ignore-eos \
--num-prompts 64  --trust-remote-code --base-url "http://0.0.0.0:8000" --request-rate 64
```
- In this case, the rejection-sampler optimization reduces TPOT from 84.94 ms
to 64.61 ms, about a 24% gain.

## Before
<img width="1068" height="830" alt="image"
src="https://github.com/user-attachments/assets/278ac878-b49d-4588-b87c-316ca4d537f5"
/>

## After
<img width="781" height="756" alt="image"
src="https://github.com/user-attachments/assets/0c6d37ad-ed77-40b3-a1be-4933c468365c"
/>

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: ZongYuan Zhan <zhanzy178@gmail.com>
Co-authored-by: Yizhou <136800916+yiz-liu@users.noreply.github.com>
Author: ZongYuan Zhan
Date: 2025-12-29 14:10:39 +08:00
Committed by: GitHub
Commit: d8e15dae6c (parent: 3e67e8276c)
2 changed files with 260 additions and 75 deletions

===== File 1: rejection sampler unit tests =====

@@ -27,8 +27,22 @@ GREEDY_TEMPERATURE = 0.0
 MAX_SPEC_LEN = 8  # Used as MAX_NUM_TOKENS in expand_batch_to_tokens
 
 
+def mock_pin_memory(original_func):
+
+    def func_wo_pin_memory(*args, **kwargs):
+        if kwargs.get('pin_memory', False):
+            kwargs['pin_memory'] = False
+        return original_func(*args, **kwargs)
+
+    return func_wo_pin_memory
+
+
 class TestAscendRejectionSampler(TestBase):
 
+    @patch('torch.arange', new=mock_pin_memory(torch.arange))
+    @patch('torch.ones', new=mock_pin_memory(torch.ones))
+    @patch('torch.full', new=mock_pin_memory(torch.full))
+    @patch('torch.tensor', new=mock_pin_memory(torch.tensor))
     def test_rejection_greedy_sample_pytorch(self):
         """Test greedy rejection sampling: stop when draft doesn't match, otherwise append bonus token"""
         batch_size = 2

@@ -60,6 +74,10 @@ class TestAscendRejectionSampler(TestBase):
         assert output_token_ids[1, 0].item() == 20
         assert output_token_ids[1, 2].item() == PLACEHOLDER_TOKEN_ID
 
+    @patch('torch.arange', new=mock_pin_memory(torch.arange))
+    @patch('torch.ones', new=mock_pin_memory(torch.ones))
+    @patch('torch.full', new=mock_pin_memory(torch.full))
+    @patch('torch.tensor', new=mock_pin_memory(torch.tensor))
     def test_rejection_random_sample_pytorch(self):
         """Test random rejection sampling: accept based on uniform probability"""
         batch_size = 2

@@ -104,6 +122,10 @@ class TestAscendRejectionSampler(TestBase):
         assert output_token_ids[0, 1].item() == 0
         assert output_token_ids[0, 2].item() == 100
 
+    @patch('torch.arange', new=mock_pin_memory(torch.arange))
+    @patch('torch.ones', new=mock_pin_memory(torch.ones))
+    @patch('torch.full', new=mock_pin_memory(torch.full))
+    @patch('torch.tensor', new=mock_pin_memory(torch.tensor))
     def test_expand_pytorch(self):
         """Test expand_pytorch functionality"""
         input_ptr = torch.tensor([10, 20, 30], dtype=torch.int32)

@@ -122,6 +144,10 @@ class TestAscendRejectionSampler(TestBase):
         expected = torch.tensor([10, 10, 20, 20, 20, 30, 30])
         assert torch.equal(output_ptr, expected)
 
+    @patch('torch.arange', new=mock_pin_memory(torch.arange))
+    @patch('torch.ones', new=mock_pin_memory(torch.ones))
+    @patch('torch.full', new=mock_pin_memory(torch.full))
+    @patch('torch.tensor', new=mock_pin_memory(torch.tensor))
     def test_expand_batch_to_tokens(self):
         """Test expand_batch_to_tokens wrapper"""
         x = torch.tensor([10, 20, 30])

@@ -154,6 +180,10 @@ class TestAscendRejectionSampler(TestBase):
         expected = torch.tensor([10, 10, 20, 20, 20, 30, 30])
         assert torch.equal(result, expected)
 
+    @patch('torch.arange', new=mock_pin_memory(torch.arange))
+    @patch('torch.ones', new=mock_pin_memory(torch.ones))
+    @patch('torch.full', new=mock_pin_memory(torch.full))
+    @patch('torch.tensor', new=mock_pin_memory(torch.tensor))
     def test_sample_recovered_tokens_pytorch_ngram(self):
         """Test recovered token sampling under n-gram mode"""
         output_token_ids = torch.empty(2, dtype=torch.int32)

@@ -184,6 +214,10 @@ class TestAscendRejectionSampler(TestBase):
         assert output_token_ids[0].item() == 0
         assert output_token_ids[1].item() == 1
 
+    @patch('torch.arange', new=mock_pin_memory(torch.arange))
+    @patch('torch.ones', new=mock_pin_memory(torch.ones))
+    @patch('torch.full', new=mock_pin_memory(torch.full))
+    @patch('torch.tensor', new=mock_pin_memory(torch.tensor))
     def test_sample_recovered_tokens_pytorch_autoregressive(self):
         """Test recovered token sampling for autoregressive models"""
         output_token_ids = torch.empty(2, dtype=torch.int32)
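
An aside on the `mock_pin_memory` helper added above: `pin_memory=True`
presumably fails on CI hosts without an accelerator runtime, so the tests wrap
the tensor factories to drop the flag. A self-contained illustration of the
same patching pattern (simplified from the diff):
```
import torch
from unittest.mock import patch

def mock_pin_memory(original_func):
    # Strip pin_memory=True so factory calls also succeed on hosts
    # without pinned-memory support.
    def func_wo_pin_memory(*args, **kwargs):
        if kwargs.get('pin_memory', False):
            kwargs['pin_memory'] = False
        return original_func(*args, **kwargs)
    return func_wo_pin_memory

with patch('torch.tensor', new=mock_pin_memory(torch.tensor)):
    t = torch.tensor([0], pin_memory=True)  # silently unpinned
assert not t.is_pinned()
```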

===== File 2: rejection sampler implementation =====

@@ -303,31 +303,32 @@ def expand_batch_to_tokens(
 def sample_recovered_tokens(
     max_spec_len: int,
     num_draft_tokens: list[int],
     # [batch_size]
     cu_num_draft_tokens: torch.Tensor,
     # [num_tokens]
     draft_token_ids: torch.Tensor,
     # [num_tokens, vocab_size]
     draft_probs: Optional[torch.Tensor],
     # [num_tokens, vocab_size]
     target_probs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
     device: torch.device,
 ) -> torch.Tensor:
     # NOTE(woosuk): Create only one distribution for each request.
     batch_size = len(num_draft_tokens)
     vocab_size = target_probs.shape[-1]
     q = torch.empty(
         (batch_size, vocab_size),
         dtype=torch.float32,
         device=device,
     )
     q.exponential_()
+    num_draft_tensor = torch.tensor(num_draft_tokens,
+                                    pin_memory=True).to(device,
+                                                        non_blocking=True)
+    has_draft_mask = num_draft_tensor > 0
     for i, generator in sampling_metadata.generators.items():
         # Do not generate random numbers for requests with no draft tokens.
         # This can be important for reproducibility.
-        if num_draft_tokens[i] > 0:
-            q[i].exponential_(generator=generator)
+        temp_q = torch.empty_like(q[i])
+        temp_q.exponential_(generator=generator)
+        q[i] = torch.where(has_draft_mask[i], temp_q, q[i])
     recovered_token_ids = torch.empty_like(draft_token_ids)
     if HAS_TRITON:
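
A note on the hunk above: the per-request branch is replaced by an
unconditional `exponential_()` draw plus `torch.where`, so the same sequence of
kernels is launched for every request with a seeded generator. A toy version of
the pattern (made-up shapes and values):
```
import torch

q = torch.rand(3, 5)                        # [batch, vocab] toy tensor
has_draft = torch.tensor([2, 0, 1]) > 0     # which requests have draft tokens
generators = {i: torch.Generator().manual_seed(100 + i) for i in range(3)}

for i, gen in generators.items():
    temp = torch.empty_like(q[i])
    temp.exponential_(generator=gen)        # drawn unconditionally
    # torch.where discards the draw for requests without draft tokens,
    # so their q rows keep the previously generated default values.
    q[i] = torch.where(has_draft[i], temp, q[i])
```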
@@ -459,45 +460,128 @@ def rejection_random_sample_pytorch(
     vocab_size,
     IS_NGRAM=False,
 ):
+    """
+    This function implements the Speculative Decoding rejection sampling step.
+    Instead of looping through each request and each token (which causes high
+    overhead), it uses a fully vectorized approach:
+    1. **Index Mapping**: Converts the flattened 1D token arrays into a 2D
+       [batch_size, max_draft_len] grid using 'cu_num_draft_tokens' to handle
+       variable-length sequences in the batch.
+    2. **Parallel Validation**: Calculates the acceptance condition
+       (target_prob / draft_prob >= uniform_sample) for ALL draft tokens
+       simultaneously across the entire batch.
+    3. **Short-circuit Simulation**: In the loop version, once a token is
+       rejected, subsequent tokens are ignored. Here, we simulate this by
+       finding the 'first_reject_pos' using argmax on the rejection mask and
+       creating a 'should_skip' mask for all indices after the first failure.
+    4. **Token Selection**: Uses 'torch.where' to select:
+       - Draft tokens (if accepted)
+       - Recovered tokens (at the point of first rejection)
+       - Bonus tokens (if all tokens in a sequence were accepted)
+    5. **Masking**: Ensures operations only apply to non-greedy requests and
+       within valid sequence lengths.
+    """
     batch_size = output_token_ids.shape[0]
     device = output_token_ids.device
-    for req_idx in range(batch_size):
-        if is_greedy[req_idx]:
-            continue
+    zero_cpu = torch.tensor([0], pin_memory=True)
+    zero_device = zero_cpu.to(device, non_blocking=True)
-        if req_idx == 0:
-            start_idx = 0
-        else:
-            start_idx = cu_num_draft_tokens[req_idx - 1].item()
-        end_idx = cu_num_draft_tokens[req_idx].item()
-        num_draft_tokens = end_idx - start_idx
+    cu_start = torch.cat([zero_device, cu_num_draft_tokens[:-1]])
+    cu_end = cu_num_draft_tokens
+    num_draft_per_batch = cu_end - cu_start
-        rejected = False
-        for pos in range(num_draft_tokens):
-            if not rejected:
-                draft_token_id = draft_token_ids[start_idx + pos].item()
+    max_draft_len = max_spec_len
+    pos_indices_cpu = torch.arange(max_draft_len, pin_memory=True)
+    pos_indices = pos_indices_cpu.to(device, non_blocking=True)[None, :]
-                if IS_NGRAM:
-                    draft_prob = 1.0
-                else:
-                    draft_prob = draft_probs[start_idx + pos,
-                                             draft_token_id].item()
+    valid_mask = pos_indices < num_draft_per_batch[:, None]
+    global_token_indices = cu_start[:, None] + pos_indices
+    global_token_indices = global_token_indices.clamp(
+        0, draft_token_ids.shape[0] - 1)
+    draft_tokens = draft_token_ids[
+        global_token_indices]  # [batch_size, max_draft_len]
-                target_prob = target_probs[start_idx + pos,
-                                           draft_token_id].item()
-                uniform_prob = uniform_probs[start_idx + pos].item()
+    if IS_NGRAM:
+        ones_cpu = torch.ones(1, pin_memory=True, dtype=torch.float32)
+        draft_token_probs = ones_cpu.to(
+            device, non_blocking=True).expand_as(draft_tokens)
+    else:
+        flat_indices = global_token_indices.flatten()
+        flat_draft_tokens = draft_tokens.flatten()
+        flat_draft_probs = draft_probs[flat_indices, flat_draft_tokens]
+        draft_token_probs = flat_draft_probs.view(batch_size, max_draft_len)
-                if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
-                    token_id = draft_token_id
-                else:
-                    rejected = True
-                    token_id = recovered_token_ids[start_idx + pos].item()
+    flat_indices = global_token_indices.flatten()
+    flat_draft_tokens = draft_tokens.flatten()
+    flat_target_probs = target_probs[flat_indices, flat_draft_tokens]
+    target_token_probs = flat_target_probs.view(batch_size, max_draft_len)
-                output_token_ids[req_idx, pos] = token_id
+    uniform_token_probs = uniform_probs[global_token_indices]
+    recovered_tokens = recovered_token_ids[global_token_indices]
-        if not rejected:
-            bonus_token_id = bonus_token_ids[req_idx].item()
-            output_token_ids[req_idx, num_draft_tokens] = bonus_token_id
+    zero_threshold_cpu = torch.tensor([0.0],
+                                      pin_memory=True,
+                                      dtype=torch.float32)
+    zero_threshold = zero_threshold_cpu.to(device, non_blocking=True)
+    acceptance_condition = (draft_token_probs > zero_threshold) & (
+        target_token_probs / draft_token_probs >= uniform_token_probs)
+    first_rejection = (~acceptance_condition) & valid_mask
+    default_pos_cpu = torch.full([batch_size, 1],
+                                 max_draft_len,
+                                 pin_memory=True)
+    default_pos = default_pos_cpu.to(device, non_blocking=True)
+    first_reject_pos = torch.where(
+        first_rejection.any(dim=1, keepdim=True),
+        first_rejection.float().argmax(dim=1, keepdim=True), default_pos)
+    pos_mask = pos_indices >= first_reject_pos
+    should_skip = pos_mask & valid_mask
+    final_acceptance = acceptance_condition & (~should_skip)
+    non_greedy_mask = ~is_greedy
+    update_mask = non_greedy_mask[:, None] & valid_mask & (~should_skip)
+    first_reject_mask = (pos_indices == first_reject_pos
+                         ) & valid_mask & non_greedy_mask[:, None]
+    final_update_mask = update_mask | first_reject_mask
+    final_tokens = torch.where(
+        first_reject_mask, recovered_tokens,
+        torch.where(final_acceptance, draft_tokens,
+                    output_token_ids[:, :max_draft_len]))
+    output_token_ids[:, :max_draft_len] = torch.where(
+        final_update_mask, final_tokens, output_token_ids[:, :max_draft_len])
+    no_rejection = first_reject_pos.squeeze(1) >= num_draft_per_batch
+    should_add_bonus = non_greedy_mask & no_rejection
+    bonus_positions = num_draft_per_batch  # [batch_size]
+    seq_len = output_token_ids.shape[1]
+    all_positions_cpu = torch.arange(seq_len, pin_memory=True)
+    all_positions = all_positions_cpu.to(
+        device, non_blocking=True)[None, :]  # [1, seq_len]
+    batch_bonus_positions = bonus_positions[:, None]  # [batch_size, 1]
+    max_spec_len_cpu = torch.tensor([max_spec_len], pin_memory=True)
+    max_spec_len_device = max_spec_len_cpu.to(device, non_blocking=True)
+    valid_bonus_pos = bonus_positions < (max_spec_len_device + 1)
+    final_bonus_mask = should_add_bonus & valid_bonus_pos
+    bonus_pos_match = (all_positions == batch_bonus_positions)
+    bonus_pos_mask = bonus_pos_match & final_bonus_mask[:, None]
+    bonus_values_expanded = bonus_token_ids.view(-1, 1).expand(-1, seq_len)
+    output_token_ids[:] = torch.where(bonus_pos_mask, bonus_values_expanded,
+                                      output_token_ids)
 
 
 def expand_pytorch(
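
Step 3 in the docstring above (simulating the loop's short-circuit) is the
least obvious part, so here is an isolated toy demonstration, with invented
values, of finding the first rejected position per row via `argmax` on a
boolean mask:
```
import torch

# accept[i, j] == True means draft token j of request i passed the test.
accept = torch.tensor([[True, True, False, True],
                       [True, True, True, True]])
valid = torch.tensor([[True, True, True, False],   # request 0: 3 drafts
                      [True, True, True, True]])   # request 1: 4 drafts

reject = (~accept) & valid
max_len = accept.shape[1]
# argmax over the float mask returns the FIRST index holding the max (1.0),
# i.e. the first rejection; rows with no rejection fall back to max_len.
first_reject = torch.where(reject.any(dim=1, keepdim=True),
                           reject.float().argmax(dim=1, keepdim=True),
                           torch.full((accept.shape[0], 1), max_len))
print(first_reject.squeeze(1))  # tensor([2, 4])
```
Positions at or after `first_reject` can then be masked out, exactly as the
loop's `rejected` flag would have skipped them.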
@@ -508,21 +592,48 @@ def expand_pytorch(
     replace_to,
     MAX_NUM_TOKENS,
 ):
-    batch_size = len(input_ptr)
+    """
+    This function broadcasts batch-level values (input_ptr) to token-level
+    positions (output_ptr) based on cumulative token offsets. It acts like
+    a "scatter" or "repeat_interleave" operation but with custom logic:
+    1. **Range Broadcasting**: It creates a boolean matrix 'in_range' of size
+       [num_tokens, batch_size] that identifies which batch index each token
+       belongs to by checking if the token index falls between cu_start and
+       cu_end.
+    2. **Conditional Replacement**: Before expansion, it replaces specific
+       values (e.g., padding or special markers) in the input to prepare the
+       data.
+    3. **Matrix-based Mapping**: It uses 'torch.einsum' to perform a weighted
+       sum that effectively "picks" the correct batch value for every token
+       position simultaneously, avoiding a Python loop over the batch.
+    """
+    device = cu_num_tokens_ptr.device
+    batch_size = input_ptr.shape[0]
+    num_tokens = output_ptr.shape[0]
-    for req_idx in range(batch_size):
-        start_idx = 0 if req_idx == 0 else cu_num_tokens_ptr[req_idx - 1]
-        end_idx = cu_num_tokens_ptr[req_idx]
-        num_tokens = end_idx - start_idx
+    if batch_size == 0 or num_tokens == 0:
+        return
-        src_val = input_ptr[req_idx]
-        src_val = replace_to if src_val == replace_from else src_val
+    cu_start = torch.cat([
+        torch.tensor([0], pin_memory=True).to(device, non_blocking=True),
+        cu_num_tokens_ptr[:-1]
+    ])
+    cu_end = cu_num_tokens_ptr
-        offset = torch.arange(MAX_NUM_TOKENS, device=num_tokens.device)
-        mask = offset < num_tokens
+    token_indices = torch.arange(num_tokens,
+                                 device=device)[:, None]  # [num_tokens, 1]
+    cu_start_exp = cu_start[None, :]  # [1, batch_size]
+    cu_end_exp = cu_end[None, :]  # [1, batch_size]
-        output_slice = start_idx + offset[mask]
-        output_ptr[output_slice] = src_val
+    in_range = (token_indices >= cu_start_exp) & (token_indices < cu_end_exp)
+    replaced_input = torch.where(input_ptr == replace_from, replace_to,
+                                 input_ptr).float()
+    token_values = torch.einsum("tb,b->t", in_range.float(), replaced_input)
+    needs_update = in_range.any(dim=1)
+    output_ptr[:] = torch.where(needs_update, token_values, output_ptr)
 
 
 def sample_recovered_tokens_pytorch(
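
The `in_range` + `einsum` combination above is essentially a masked
`repeat_interleave`; a toy run (invented values, mirroring the unit test's
expected output) shows why the weighted sum picks the right batch value per
token:
```
import torch

values = torch.tensor([10., 20., 30.])      # one value per request
cu = torch.tensor([2, 5, 7])                 # cumulative token counts
cu_start = torch.cat([torch.tensor([0]), cu[:-1]])

tok = torch.arange(int(cu[-1]))[:, None]     # [num_tokens, 1]
in_range = (tok >= cu_start[None, :]) & (tok < cu[None, :])
# Each row of in_range is one-hot over requests, so the weighted sum
# simply selects the owning request's value for every token position.
out = torch.einsum("tb,b->t", in_range.float(), values)
print(out)  # tensor([10., 10., 20., 20., 20., 30., 30.])
```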
@@ -535,37 +646,77 @@ def sample_recovered_tokens_pytorch(
     vocab_size,
     IS_NGRAM=False,
 ):
-    batch_size = len(cu_num_draft_tokens)
+    """
+    When a draft token is rejected, we must sample a "recovered" token from
+    a modified distribution. This function calculates that distribution across
+    the entire flattened batch.
+    1. **Token-to-Batch Mapping**: Using the cumulative draft token counts, it
+       determines which request in the batch each token belongs to. This is
+       necessary because 'q' (normalization factor) is stored per-request.
+    2. **Probability Adjustment**:
+       - If N-GRAM: It zeroes out the draft token's probability in the target.
+       - If Probabilistic: It calculates max(0, target_probs - draft_probs)
+         as per the standard speculative decoding algorithm.
+    3. **Normalization & Sampling**: It divides the adjusted probabilities
+       by the normalization distribution 'q'. To remain vectorized, it
+       broadcasts 'q' from [batch_size, vocab] to [num_tokens, vocab].
+    4. **Argmax Selection**: It selects the best recovery token for every
+       position in one pass using torch.argmax.
+    """
+    device = output_token_ids.device
+    num_tokens = output_token_ids.shape[0]
-    for req_idx in range(batch_size):
-        start_idx = 0 if req_idx == 0 else cu_num_draft_tokens[req_idx - 1]
-        end_idx = cu_num_draft_tokens[req_idx]
-        num_draft_tokens = end_idx - start_idx
+    if num_tokens == 0:
+        return
-        for pos in range(num_draft_tokens):
-            token_idx = start_idx + pos
+    cu_start = torch.cat([
+        torch.tensor([0], pin_memory=True).to(device, non_blocking=True),
+        cu_num_draft_tokens[:-1],
+    ])
+    cu_end = cu_num_draft_tokens
-            if IS_NGRAM:
-                draft_token_id = draft_token_ids[token_idx]
-                orig_prob = target_probs[token_idx, draft_token_id].item()
-                target_probs[token_idx, draft_token_id] = 0
-                prob = target_probs[token_idx].clone()
-            else:
-                draft_p = draft_probs[token_idx].clone()
-                target_p = target_probs[token_idx].clone()
-                prob = torch.maximum(target_p - draft_p,
-                                     torch.tensor(0.0, device=target_p.device))
+    token_indices = torch.arange(num_tokens, device=device)  # [num_tokens]
-            q_values = torch.full((vocab_size, ),
-                                  float('-inf'),
-                                  device=q.device)
-            q_values[:vocab_size] = q[req_idx, :vocab_size]
+    token_indices_expanded = token_indices[:, None]  # [num_tokens, 1]
+    cu_start_expanded = cu_start[None, :]  # [1, batch_size]
+    cu_end_expanded = cu_end[None, :]  # [1, batch_size]
-            recovered_id = torch.argmax(prob / q_values).item()
-            output_token_ids[token_idx] = recovered_id
+    in_range_mask = (token_indices_expanded >= cu_start_expanded) & (
+        token_indices_expanded < cu_end_expanded)
-            if IS_NGRAM:
-                target_probs[token_idx, draft_token_id] = orig_prob
+    token_to_batch = torch.argmax(in_range_mask.int(), dim=1)
+    has_match = in_range_mask.any(dim=1)
+    token_to_batch = torch.where(has_match, token_to_batch, 0)
+    if IS_NGRAM:
+        token_indices = torch.arange(num_tokens, device=device)
+        modified_target_probs = target_probs.clone()
+        modified_target_probs[token_indices, draft_token_ids] = 0
+        prob = modified_target_probs
+    else:
+        prob = torch.maximum(
+            target_probs - draft_probs,
+            torch.tensor(0.0, pin_memory=True).to(device, non_blocking=True),
+        )
+    q_values = q[token_to_batch]  # [num_tokens, vocab_size]
+    epsilon = 1e-10
+    q_values_safe = torch.where(q_values == 0, epsilon, q_values)
+    q_values_safe = torch.where(torch.isinf(q_values), epsilon, q_values_safe)
+    prob_over_q = prob / q_values_safe
+    prob_over_q = torch.where((q_values == 0) | torch.isinf(q_values), -1e10,
+                              prob_over_q)
+    recovered_ids = torch.argmax(prob_over_q, dim=1)
+    output_token_ids[:] = recovered_ids
 
 
 @triton.jit(do_not_specialize=["max_spec_len"])
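
A final note on the math in this last hunk: dividing the adjusted
probabilities by exponential noise `q` and taking `argmax` is the
exponential-race variant of the Gumbel-max trick, so `argmax(prob / q)`
samples index `i` with probability `prob[i] / prob.sum()`. A quick empirical
check with toy numbers:
```
import torch

target = torch.tensor([0.5, 0.4, 0.1])
draft = torch.tensor([0.2, 0.1, 0.7])
prob = torch.clamp(target - draft, min=0.0)  # residual: [0.3, 0.3, 0.0]

counts = torch.zeros(3)
for _ in range(10000):
    q = torch.empty(3).exponential_()        # Exp(1) noise per vocab entry
    counts[torch.argmax(prob / q)] += 1
print(counts / 10000)  # roughly [0.5, 0.5, 0.0]
```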