[Lint] Style: Convert vllm-ascend/ to ruff format (Batch #7) (#6023)

### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/quantization/compressed_tensors/compressed_tensors.py` |
| `vllm_ascend/quantization/quant_config.py` |
| `vllm_ascend/quantization/utils.py` |
| `vllm_ascend/quantization/w4a16.py` |
| `vllm_ascend/quantization/w4a4_flatquant_dynamic.py` |
| `vllm_ascend/quantization/w4a8_dynamic.py` |
| `vllm_ascend/quantization/w8a16.py` |
| `vllm_ascend/quantization/w8a8.py` |
| `vllm_ascend/quantization/w8a8_dynamic.py` |
| `vllm_ascend/quantization/w8a8_pdmix.py` |
| `vllm_ascend/quantization/w8a8mxfp8.py` |
| `vllm_ascend/sample/rejection_sampler.py` |
| `vllm_ascend/sample/sampler.py` |
| `vllm_ascend/worker/block_table.py` |
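For context, a minimal sketch of the two kinds of rewrite visible throughout the diff below: long call sites and signatures are exploded one argument per line with a trailing comma, and `Optional[X]` annotations become `X | None` (the latter most likely comes from ruff's pyupgrade-style lint fixes rather than `ruff format` itself). All names and values here are hypothetical, not code from this PR.

```python
# Hypothetical example, for illustration only -- not taken from this PR.

def rejection_step(
    draft_token_ids: list[int],
    target_token_ids: list[int],
    bonus_token_id: int | None = None,  # was: Optional[int] = None
) -> list[int]:
    """Accept draft tokens until the first mismatch; append the bonus token
    only if every draft token was accepted."""
    accepted: list[int] = []
    for draft, target in zip(draft_token_ids, target_token_ids):
        if draft != target:
            break
        accepted.append(draft)
    if bonus_token_id is not None and len(accepted) == len(draft_token_ids):
        accepted.append(bonus_token_id)
    return accepted


# Call sites that exceed the line limit are split one argument per line and end
# with a "magic" trailing comma, which keeps them exploded on reformatting.
tokens = rejection_step(
    [3, 7, 9],
    [3, 7, 2],
    bonus_token_id=11,
)
print(tokens)  # [3, 7]
```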

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main: 2c24bc6996

Signed-off-by: MrZ20 <2609716663@qq.com>
Committed by SILONG ZENG (via GitHub) on 2026-02-06 14:56:53 +08:00
Parent: d0bc16859c
Commit: 99aedaff63
20 changed files with 997 additions and 1307 deletions

`vllm_ascend/sample/rejection_sampler.py`

@@ -1,18 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
from vllm.triton_utils import HAS_TRITON, triton
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import (GREEDY_TEMPERATURE, MAX_SPEC_LEN,
PLACEHOLDER_TOKEN_ID,
generate_uniform_probs)
from vllm.v1.sample.rejection_sampler import (
GREEDY_TEMPERATURE,
MAX_SPEC_LEN,
PLACEHOLDER_TOKEN_ID,
generate_uniform_probs,
)
from vllm_ascend.ops.triton.reject_sample import (
cal_grid_and_block_size, expand_triton,
cal_grid_and_block_size,
expand_triton,
rejection_greedy_sample_with_triton,
rejection_random_sample_block_verify_kernel,
rejection_random_sample_kernel, sample_recovered_tokens_kernel)
rejection_random_sample_kernel,
sample_recovered_tokens_kernel,
)
from vllm_ascend.sample.sampler import apply_top_k_top_p
@@ -83,7 +88,7 @@ def rejection_sample(
# [batch_size]
cu_num_draft_tokens: torch.Tensor,
# [num_tokens, vocab_size]
draft_probs: Optional[torch.Tensor],
draft_probs: torch.Tensor | None,
# [num_tokens, vocab_size]
target_probs: torch.Tensor,
# [batch_size, 1]
@@ -126,15 +131,20 @@ def rejection_sample(
# Rejection sampling for greedy sampling requests.
target_argmax = target_probs.argmax(dim=-1)
if HAS_TRITON:
rejection_greedy_sample_with_triton(output_token_ids,
num_draft_tokens,
cu_num_draft_tokens,
draft_token_ids, target_argmax,
bonus_token_ids, is_greedy,
max_spec_len, grid, block_size)
rejection_greedy_sample_with_triton(
output_token_ids,
num_draft_tokens,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
is_greedy,
max_spec_len,
grid,
block_size,
)
else:
if min(num_draft_tokens) == 1 and max(
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
if min(num_draft_tokens) == 1 and max(num_draft_tokens) == 1 and sampling_metadata.all_greedy:
rejection_greedy_sample_spec_len_1_pytorch(
output_token_ids,
draft_token_ids,
@@ -179,7 +189,7 @@ def rejection_sample(
if not using_block_verify:
# Rejection sampling for random sampling requests.
if HAS_TRITON:
rejection_random_sample_kernel[(grid, )](
rejection_random_sample_kernel[(grid,)](
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
@@ -214,7 +224,7 @@ def rejection_sample(
else:
# MagicMTP: Improving acceptance rate with Block Verify.
if HAS_TRITON:
rejection_random_sample_block_verify_kernel[(grid, )](
rejection_random_sample_block_verify_kernel[(grid,)](
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
@@ -231,19 +241,20 @@ def rejection_sample(
BLOCK_SIZE=block_size,
)
else:
rejection_random_sample_block_verify_pytorch(output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
bonus_token_ids,
recovered_token_ids,
uniform_probs,
is_greedy,
max_spec_len,
vocab_size,
IS_NGRAM=draft_probs
is None)
rejection_random_sample_block_verify_pytorch(
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
bonus_token_ids,
recovered_token_ids,
uniform_probs,
is_greedy,
max_spec_len,
vocab_size,
IS_NGRAM=draft_probs is None,
)
return output_token_ids
@@ -277,13 +288,7 @@ def expand_batch_to_tokens(
assert cu_num_tokens.shape[0] == batch_size
expanded_x = x.new_empty(num_tokens)
if HAS_TRITON:
expand_triton(batch_size,
expanded_x,
x,
cu_num_tokens,
replace_from,
replace_to,
max_num_tokens=MAX_SPEC_LEN)
expand_triton(batch_size, expanded_x, x, cu_num_tokens, replace_from, replace_to, max_num_tokens=MAX_SPEC_LEN)
else:
expand_pytorch(
expanded_x,
@@ -301,7 +306,7 @@ def sample_recovered_tokens(
num_draft_tokens: list[int],
cu_num_draft_tokens: torch.Tensor,
draft_token_ids: torch.Tensor,
draft_probs: Optional[torch.Tensor],
draft_probs: torch.Tensor | None,
target_probs: torch.Tensor,
sampling_metadata: SamplingMetadata,
device: torch.device,
@@ -316,9 +321,7 @@ def sample_recovered_tokens(
)
q.exponential_()
num_draft_tensor = torch.tensor(num_draft_tokens,
pin_memory=True).to(device,
non_blocking=True)
num_draft_tensor = torch.tensor(num_draft_tokens, pin_memory=True).to(device, non_blocking=True)
has_draft_mask = num_draft_tensor > 0
for i, generator in sampling_metadata.generators.items():
@@ -357,10 +360,10 @@ def sample_recovered_tokens(
def rejection_greedy_sample_spec_len_1_pytorch(
output_token_ids, # [batch_size, 2]
draft_token_ids, # [num_tokens]
target_argmax, # [num_tokens]
bonus_token_ids, # [batch_size]
output_token_ids, # [batch_size, 2]
draft_token_ids, # [num_tokens]
target_argmax, # [num_tokens]
bonus_token_ids, # [batch_size]
):
batch_size = output_token_ids.size(0)
num_tokens = draft_token_ids.size(0)
@@ -368,73 +371,56 @@ def rejection_greedy_sample_spec_len_1_pytorch(
accept_req_mask = draft_token_ids == target_argmax
output_token_ids[:, 0] = target_argmax
bonus_token_ids = bonus_token_ids.squeeze(1)
output_token_ids[:, 1] = torch.where(accept_req_mask, bonus_token_ids,
output_token_ids[:, 1])
output_token_ids[:, 1] = torch.where(accept_req_mask, bonus_token_ids, output_token_ids[:, 1])
def rejection_greedy_sample_pytorch(
output_token_ids, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens, # [batch_size]
draft_token_ids, # [num_tokens]
target_argmax, # [num_tokens]
bonus_token_ids, # [batch_size]
draft_tokens_per_req, # [batch_size], list
max_spec_len,
is_greedy=None, # [batch_size] or None
output_token_ids, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens, # [batch_size]
draft_token_ids, # [num_tokens]
target_argmax, # [num_tokens]
bonus_token_ids, # [batch_size]
draft_tokens_per_req, # [batch_size], list
max_spec_len,
is_greedy=None, # [batch_size] or None
):
batch_size = output_token_ids.size(0)
num_tokens = draft_token_ids.size(0)
device = output_token_ids.device
draft_tokens_per_req = torch.tensor(draft_tokens_per_req).to(
device, non_blocking=True)
draft_tokens_per_req = torch.tensor(draft_tokens_per_req).to(device, non_blocking=True)
if is_greedy is None:
is_greedy = torch.ones(batch_size, dtype=torch.bool, device=device)
start_indices = cu_num_draft_tokens - draft_tokens_per_req
req_ids = torch.arange(batch_size, device=device)
token_req_ids = torch.repeat_interleave(req_ids, draft_tokens_per_req)
token_positions = torch.arange(
num_tokens, device=device) - start_indices[token_req_ids]
token_positions = torch.arange(num_tokens, device=device) - start_indices[token_req_ids]
# Find the first mismatch position of each request.
mismatch_global = (draft_token_ids != target_argmax)
mismatch_global = draft_token_ids != target_argmax
if max_spec_len == 0:
first_mismatch_pos_per_req = torch.zeros(batch_size,
dtype=torch.long,
device=device)
first_mismatch_pos_per_req = torch.zeros(batch_size, dtype=torch.long, device=device)
else:
# [bs, max_spec_len]
pos_matrix = torch.full((batch_size, max_spec_len),
-1,
dtype=torch.long,
device=device)
pos_matrix = torch.full((batch_size, max_spec_len), -1, dtype=torch.long, device=device)
pos_matrix[token_req_ids, token_positions] = token_positions
mismatch_matrix = torch.full((batch_size, max_spec_len),
False,
dtype=torch.bool,
device=device)
mismatch_matrix = torch.full((batch_size, max_spec_len), False, dtype=torch.bool, device=device)
mismatch_matrix[token_req_ids, token_positions] = mismatch_global
mismatch_positions = torch.where(mismatch_matrix, pos_matrix,
max_spec_len * 2)
mismatch_positions = torch.where(mismatch_matrix, pos_matrix, max_spec_len * 2)
first_mismatch_pos_per_req, _ = torch.min(mismatch_positions, dim=1)
no_mismatch_mask = (first_mismatch_pos_per_req == max_spec_len * 2)
first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[
no_mismatch_mask]
no_mismatch_mask = first_mismatch_pos_per_req == max_spec_len * 2
first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[no_mismatch_mask]
# Copy matched target tokens into output.
copy_len = torch.minimum(first_mismatch_pos_per_req + 1,
draft_tokens_per_req)
copy_indices = torch.arange(max_spec_len + 1,
device=device).expand(batch_size, -1)
copy_len = torch.minimum(first_mismatch_pos_per_req + 1, draft_tokens_per_req)
copy_indices = torch.arange(max_spec_len + 1, device=device).expand(batch_size, -1)
copy_mask = copy_indices < copy_len.unsqueeze(1)
greedy_mask = is_greedy.unsqueeze(1)
final_copy_mask = copy_mask & greedy_mask
global_idx = start_indices.unsqueeze(1) + copy_indices
output_token_ids[final_copy_mask] = target_argmax[
global_idx[final_copy_mask]].to(output_token_ids.dtype)
output_token_ids[final_copy_mask] = target_argmax[global_idx[final_copy_mask]].to(output_token_ids.dtype)
# Fill bonus token.
needs_bonus = is_greedy & (first_mismatch_pos_per_req
>= draft_tokens_per_req)
needs_bonus = is_greedy & (first_mismatch_pos_per_req >= draft_tokens_per_req)
if torch.any(needs_bonus):
bonus_rows = torch.where(needs_bonus)[0]
bonus_cols = draft_tokens_per_req[bonus_rows]
@@ -458,24 +444,24 @@ def rejection_random_sample_pytorch(
):
"""
This function implements the Speculative Decoding rejection sampling step.
Instead of looping through each request and each token (which causes high
Instead of looping through each request and each token (which causes high
overhead), it uses a fully vectorized approach:
1. **Index Mapping**: Converts the flattened 1D token arrays into a 2D
[batch_size, max_draft_len] grid using 'cu_num_draft_tokens' to handle
1. **Index Mapping**: Converts the flattened 1D token arrays into a 2D
[batch_size, max_draft_len] grid using 'cu_num_draft_tokens' to handle
variable-length sequences in the batch.
2. **Parallel Validation**: Calculates the acceptance condition
(target_prob / draft_prob >= uniform_sample) for ALL draft tokens
2. **Parallel Validation**: Calculates the acceptance condition
(target_prob / draft_prob >= uniform_sample) for ALL draft tokens
simultaneously across the entire batch.
3. **Short-circuit Simulation**: In the loop version, once a token is rejected,
subsequent tokens are ignored. Here, we simulate this by finding the
'first_reject_pos' using argmax on the rejection mask and creating a
3. **Short-circuit Simulation**: In the loop version, once a token is rejected,
subsequent tokens are ignored. Here, we simulate this by finding the
'first_reject_pos' using argmax on the rejection mask and creating a
'should_skip' mask for all indices after the first failure.
4. **Token Selection**: Uses 'torch.where' to select:
- Draft tokens (if accepted)
- Recovered tokens (at the point of first rejection)
- Bonus tokens (if all tokens in a sequence were accepted)
5. **Masking**: Ensures operations only apply to non-greedy requests and
5. **Masking**: Ensures operations only apply to non-greedy requests and
within valid sequence lengths.
"""
@@ -495,15 +481,12 @@ def rejection_random_sample_pytorch(
valid_mask = pos_indices < num_draft_per_batch[:, None]
global_token_indices = cu_start[:, None] + pos_indices
global_token_indices = global_token_indices.clamp(
0, draft_token_ids.shape[0] - 1)
draft_tokens = draft_token_ids[
global_token_indices] # [batch_size, max_draft_len]
global_token_indices = global_token_indices.clamp(0, draft_token_ids.shape[0] - 1)
draft_tokens = draft_token_ids[global_token_indices] # [batch_size, max_draft_len]
if IS_NGRAM:
ones_cpu = torch.ones(1, pin_memory=True, dtype=torch.float32)
draft_token_probs = ones_cpu.to(
device, non_blocking=True).expand_as(draft_tokens)
draft_token_probs = ones_cpu.to(device, non_blocking=True).expand_as(draft_tokens)
else:
flat_indices = global_token_indices.flatten()
flat_draft_tokens = draft_tokens.flatten()
@@ -518,24 +501,21 @@ def rejection_random_sample_pytorch(
uniform_token_probs = uniform_probs[global_token_indices]
recovered_tokens = recovered_token_ids[global_token_indices]
zero_threshold_cpu = torch.tensor([0.0],
pin_memory=True,
dtype=torch.float32)
zero_threshold_cpu = torch.tensor([0.0], pin_memory=True, dtype=torch.float32)
zero_threshold = zero_threshold_cpu.to(device, non_blocking=True)
acceptance_condition = (draft_token_probs > zero_threshold) & (
target_token_probs / draft_token_probs >= uniform_token_probs)
target_token_probs / draft_token_probs >= uniform_token_probs
)
first_rejection = (~acceptance_condition) & valid_mask
default_pos_cpu = torch.full([batch_size, 1],
max_draft_len,
pin_memory=True)
default_pos_cpu = torch.full([batch_size, 1], max_draft_len, pin_memory=True)
default_pos = default_pos_cpu.to(device, non_blocking=True)
first_reject_pos = torch.where(
first_rejection.any(dim=1, keepdim=True),
first_rejection.float().argmax(dim=1, keepdim=True), default_pos)
first_rejection.any(dim=1, keepdim=True), first_rejection.float().argmax(dim=1, keepdim=True), default_pos
)
pos_mask = pos_indices >= first_reject_pos
should_skip = pos_mask & valid_mask
@@ -543,16 +523,17 @@ def rejection_random_sample_pytorch(
non_greedy_mask = ~is_greedy
update_mask = non_greedy_mask[:, None] & valid_mask & (~should_skip)
first_reject_mask = (pos_indices == first_reject_pos
) & valid_mask & non_greedy_mask[:, None]
first_reject_mask = (pos_indices == first_reject_pos) & valid_mask & non_greedy_mask[:, None]
final_update_mask = update_mask | first_reject_mask
final_tokens = torch.where(
first_reject_mask, recovered_tokens,
torch.where(final_acceptance, draft_tokens,
output_token_ids[:, :max_draft_len]))
first_reject_mask,
recovered_tokens,
torch.where(final_acceptance, draft_tokens, output_token_ids[:, :max_draft_len]),
)
output_token_ids[:, :max_draft_len] = torch.where(
final_update_mask, final_tokens, output_token_ids[:, :max_draft_len])
final_update_mask, final_tokens, output_token_ids[:, :max_draft_len]
)
no_rejection = first_reject_pos.squeeze(1) >= num_draft_per_batch
should_add_bonus = non_greedy_mask & no_rejection
@@ -561,8 +542,7 @@ def rejection_random_sample_pytorch(
seq_len = output_token_ids.shape[1]
all_positions_cpu = torch.arange(seq_len, pin_memory=True)
all_positions = all_positions_cpu.to(
device, non_blocking=True)[None, :] # [1, seq_len]
all_positions = all_positions_cpu.to(device, non_blocking=True)[None, :] # [1, seq_len]
batch_bonus_positions = bonus_positions[:, None] # [batch_size, 1]
@@ -572,12 +552,11 @@ def rejection_random_sample_pytorch(
valid_bonus_pos = bonus_positions < (max_spec_len_device + 1)
final_bonus_mask = should_add_bonus & valid_bonus_pos
bonus_pos_match = (all_positions == batch_bonus_positions)
bonus_pos_match = all_positions == batch_bonus_positions
bonus_pos_mask = bonus_pos_match & final_bonus_mask[:, None]
bonus_values_expanded = bonus_token_ids.view(-1, 1).expand(-1, seq_len)
output_token_ids[:] = torch.where(bonus_pos_mask, bonus_values_expanded,
output_token_ids)
output_token_ids[:] = torch.where(bonus_pos_mask, bonus_values_expanded, output_token_ids)
def expand_pytorch(
@@ -589,17 +568,17 @@ def expand_pytorch(
MAX_NUM_TOKENS,
):
"""
This function broadcasts batch-level values (input_ptr) to token-level
positions (output_ptr) based on cumulative token offsets. It acts like
This function broadcasts batch-level values (input_ptr) to token-level
positions (output_ptr) based on cumulative token offsets. It acts like
a "scatter" or "repeat_interleave" operation but with custom logic:
1. **Range Broadcasting**: It creates a boolean matrix 'in_range' of size
[num_tokens, batch_size] that identifies which batch index each token
1. **Range Broadcasting**: It creates a boolean matrix 'in_range' of size
[num_tokens, batch_size] that identifies which batch index each token
belongs to by checking if the token index falls between cu_start and cu_end.
2. **Conditional Replacement**: Before expansion, it replaces specific values
2. **Conditional Replacement**: Before expansion, it replaces specific values
(e.g., padding or special markers) in the input to prepare the data.
3. **Matrix-based Mapping**: It uses 'torch.einsum' to perform a weighted
sum that effectively "picks" the correct batch value for every token position
3. **Matrix-based Mapping**: It uses 'torch.einsum' to perform a weighted
sum that effectively "picks" the correct batch value for every token position
simultaneously, avoiding a Python loop over the batch.
"""
device = cu_num_tokens_ptr.device
@@ -609,21 +588,16 @@ def expand_pytorch(
if batch_size == 0 or num_tokens == 0:
return
cu_start = torch.cat([
torch.tensor([0], pin_memory=True).to(device, non_blocking=True),
cu_num_tokens_ptr[:-1]
])
cu_start = torch.cat([torch.tensor([0], pin_memory=True).to(device, non_blocking=True), cu_num_tokens_ptr[:-1]])
cu_end = cu_num_tokens_ptr
token_indices = torch.arange(num_tokens,
device=device)[:, None] # [num_tokens, 1]
token_indices = torch.arange(num_tokens, device=device)[:, None] # [num_tokens, 1]
cu_start_exp = cu_start[None, :] # [1, batch_size]
cu_end_exp = cu_end[None, :] # [1, batch_size]
in_range = (token_indices >= cu_start_exp) & (token_indices < cu_end_exp)
replaced_input = torch.where(input_ptr == replace_from, replace_to,
input_ptr).float()
replaced_input = torch.where(input_ptr == replace_from, replace_to, input_ptr).float()
token_values = torch.einsum("tb,b->t", in_range.float(), replaced_input)
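To make the einsum-based broadcast described above concrete, here is a tiny standalone sketch (hypothetical values, not code from the commit): the `[num_tokens, batch_size]` membership mask, contracted against the per-request values, picks the owning request's value for every token position without a Python loop.

```python
import torch

cu_num_tokens = torch.tensor([2, 5, 6])        # cumulative token counts per request
batch_values = torch.tensor([7.0, 8.0, 9.0])   # one value per request to broadcast
num_tokens = int(cu_num_tokens[-1])

cu_start = torch.cat([torch.zeros(1, dtype=cu_num_tokens.dtype), cu_num_tokens[:-1]])
token_idx = torch.arange(num_tokens)[:, None]                       # [num_tokens, 1]
in_range = (token_idx >= cu_start[None, :]) & (token_idx < cu_num_tokens[None, :])

# Each row of in_range is one-hot over requests, so the weighted sum "picks"
# the owning request's value for every token.
expanded = torch.einsum("tb,b->t", in_range.float(), batch_values)
print(expanded)  # tensor([7., 7., 8., 8., 8., 9.])
```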
@@ -643,21 +617,21 @@ def sample_recovered_tokens_pytorch(
IS_NGRAM=False,
):
"""
When a draft token is rejected, we must sample a "recovered" token from
a modified distribution. This function calculates that distribution across
When a draft token is rejected, we must sample a "recovered" token from
a modified distribution. This function calculates that distribution across
the entire flattened batch.
1. **Token-to-Batch Mapping**: Using the cumulative draft token counts, it
determines which request in the batch each token belongs to. This is
1. **Token-to-Batch Mapping**: Using the cumulative draft token counts, it
determines which request in the batch each token belongs to. This is
necessary because 'q' (normalization factor) is stored per-request.
2. **Probability Adjustment**:
2. **Probability Adjustment**:
- If N-GRAM: It zeroes out the draft token's probability in the target.
- If Probabilistic: It calculates max(0, target_probs - draft_probs)
- If Probabilistic: It calculates max(0, target_probs - draft_probs)
as per the standard speculative decoding algorithm.
3. **Normalization & Sampling**: It divides the adjusted probabilities
by the normalization distribution 'q'. To remain vectorized, it
3. **Normalization & Sampling**: It divides the adjusted probabilities
by the normalization distribution 'q'. To remain vectorized, it
broadcasts 'q' from [batch_size, vocab] to [num_tokens, vocab].
4. **Argmax Selection**: It selects the best recovery token for every
4. **Argmax Selection**: It selects the best recovery token for every
position in one pass using torch.argmax.
"""
device = output_token_ids.device
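A minimal worked example of the probability adjustment and argmax-based sampling sketched in the docstring above (one token, five-word vocabulary; the numbers are invented for illustration):

```python
import torch

target_probs = torch.tensor([0.05, 0.40, 0.10, 0.30, 0.15])
draft_probs = torch.tensor([0.10, 0.10, 0.10, 0.50, 0.20])

# Probabilistic draft: the recovery distribution is max(0, target - draft).
adjusted = torch.clamp(target_probs - draft_probs, min=0.0)  # only index 1 keeps mass

# Dividing by i.i.d. Exponential noise and taking argmax samples from the
# (unnormalized) adjusted distribution -- the exponential-race / Gumbel trick.
q = torch.empty_like(adjusted).exponential_()
recovered_token = torch.argmax(adjusted / q)
print(int(recovered_token))  # 1, since every other entry is exactly zero
```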
@@ -666,10 +640,12 @@ def sample_recovered_tokens_pytorch(
if num_tokens == 0:
return
cu_start = torch.cat([
torch.tensor([0], pin_memory=True).to(device, non_blocking=True),
cu_num_draft_tokens[:-1],
])
cu_start = torch.cat(
[
torch.tensor([0], pin_memory=True).to(device, non_blocking=True),
cu_num_draft_tokens[:-1],
]
)
cu_end = cu_num_draft_tokens
token_indices = torch.arange(num_tokens, device=device) # [num_tokens]
@@ -678,8 +654,7 @@ def sample_recovered_tokens_pytorch(
cu_start_expanded = cu_start[None, :] # [1, batch_size]
cu_end_expanded = cu_end[None, :] # [1, batch_size]
in_range_mask = (token_indices_expanded >= cu_start_expanded) & (
token_indices_expanded < cu_end_expanded)
in_range_mask = (token_indices_expanded >= cu_start_expanded) & (token_indices_expanded < cu_end_expanded)
token_to_batch = torch.argmax(in_range_mask.int(), dim=1)
@@ -707,8 +682,7 @@ def sample_recovered_tokens_pytorch(
prob_over_q = prob / q_values_safe
prob_over_q = torch.where((q_values == 0) | torch.isinf(q_values), -1e10,
prob_over_q)
prob_over_q = torch.where((q_values == 0) | torch.isinf(q_values), -1e10, prob_over_q)
recovered_ids = torch.argmax(prob_over_q, dim=1)
@@ -742,14 +716,12 @@ def rejection_random_sample_block_verify_pytorch(
pos_indices = pos_indices_cpu.to(device, non_blocking=True)[None, :]
valid_mask = pos_indices < num_draft_per_batch
global_token_indices = cu_start[:, None] + pos_indices
global_token_indices = global_token_indices.clamp(
0, draft_token_ids.shape[0] - 1)
global_token_indices = global_token_indices.clamp(0, draft_token_ids.shape[0] - 1)
draft_tokens = draft_token_ids[global_token_indices]
if IS_NGRAM:
ones_cpu = torch.ones(1, pin_memory=True, dtype=torch.float32)
draft_token_probs = ones_cpu.to(
device, non_blocking=True).expand_as(draft_tokens)
draft_token_probs = ones_cpu.to(device, non_blocking=True).expand_as(draft_tokens)
else:
flat_indices = global_token_indices.flatten()
flat_draft_tokens = draft_tokens.flatten()
@@ -772,27 +744,21 @@ def rejection_random_sample_block_verify_pytorch(
last_accept_pos = torch.where(
legal_mask.any(dim=-1, keepdim=True),
(max_spec_len -
legal_mask.flip(dims=[-1]).float().argmax(dim=-1, keepdim=True) - 1),
-1)
(max_spec_len - legal_mask.flip(dims=[-1]).float().argmax(dim=-1, keepdim=True) - 1),
-1,
)
non_greedy_mask = (~is_greedy)[:, None]
accept_mask = (pos_indices
<= last_accept_pos) & valid_mask & non_greedy_mask
output_token_ids[:, :max_spec_len] = torch.where(
accept_mask, draft_tokens, output_token_ids[:, :max_spec_len])
accept_mask = (pos_indices <= last_accept_pos) & valid_mask & non_greedy_mask
output_token_ids[:, :max_spec_len] = torch.where(accept_mask, draft_tokens, output_token_ids[:, :max_spec_len])
reject_mask = (pos_indices
== last_accept_pos + 1) & valid_mask & non_greedy_mask
output_token_ids[:, :max_spec_len] = torch.where(
reject_mask, recovered_tokens, output_token_ids[:, :max_spec_len])
reject_mask = (pos_indices == last_accept_pos + 1) & valid_mask & non_greedy_mask
output_token_ids[:, :max_spec_len] = torch.where(reject_mask, recovered_tokens, output_token_ids[:, :max_spec_len])
bonus_mask = (last_accept_pos + 1 >= num_draft_per_batch) & non_greedy_mask
all_positions_cpu = torch.arange(max_spec_len + 1, pin_memory=True)
all_positions = all_positions_cpu.to(device, non_blocking=True)[None, :]
bonus_pos_match = (all_positions == num_draft_per_batch)
bonus_pos_match = all_positions == num_draft_per_batch
bonus_mask = bonus_mask & bonus_pos_match
bonus_values_expanded = bonus_token_ids.view(-1, 1).expand(
-1, max_spec_len + 1)
output_token_ids[:] = torch.where(bonus_mask, bonus_values_expanded,
output_token_ids)
bonus_values_expanded = bonus_token_ids.view(-1, 1).expand(-1, max_spec_len + 1)
output_token_ids[:] = torch.where(bonus_mask, bonus_values_expanded, output_token_ids)
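The block-verify path above locates the last accepted position per request with a flip-and-argmax trick rather than a loop. A small sketch of that pattern, with a hypothetical mask (not data from the commit):

```python
import torch

legal_mask = torch.tensor([
    [True, True, False, True],     # last accepted position: 3
    [True, False, False, False],   # last accepted position: 0
    [False, False, False, False],  # nothing accepted
])
max_spec_len = legal_mask.shape[1]

# argmax on the flipped mask finds the first True from the right, i.e. the last
# accepted slot; requests with no accepted token fall back to -1.
last_accept_pos = torch.where(
    legal_mask.any(dim=-1, keepdim=True),
    max_spec_len - legal_mask.flip(dims=[-1]).float().argmax(dim=-1, keepdim=True) - 1,
    torch.tensor(-1),
)
print(last_accept_pos.squeeze(1))  # tensor([ 3,  0, -1])
```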

`vllm_ascend/sample/sampler.py`

@@ -35,7 +35,6 @@ def random_sample(
class AscendSampler(Sampler):
def __init__(self, logprobs_mode=DEFAULT_LOGPROBS_MODE):
# TODO: support logprobs_mode in vllm-ascend
super().__init__(logprobs_mode=logprobs_mode)
@@ -62,7 +61,6 @@ class AscendSampler(Sampler):
class AscendTopKTopPSampler(TopKTopPSampler):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.apply_top_k_top_p = apply_top_k_top_p
@@ -135,4 +133,9 @@ def _apply_top_k_top_p_ascendc(
return logits
return torch.ops._C_ascend.npu_apply_top_k_top_p(logits, k=k, p=p)
apply_top_k_top_p = _apply_top_k_top_p_ascendc if get_ascend_device_type() in [AscendDeviceType.A2, AscendDeviceType.A3] else _apply_top_k_top_p_pytorch
apply_top_k_top_p = (
_apply_top_k_top_p_ascendc
if get_ascend_device_type() in [AscendDeviceType.A2, AscendDeviceType.A3]
else _apply_top_k_top_p_pytorch
)