### What this PR does / why we need it?
Reformats the files listed below (import layout, call-site wrapping, and docstring cleanup) to the repository's current code style. There are no functional changes.

**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/quantization/compressed_tensors/compressed_tensors.py` |
| `vllm_ascend/quantization/quant_config.py` |
| `vllm_ascend/quantization/utils.py` |
| `vllm_ascend/quantization/w4a16.py` |
| `vllm_ascend/quantization/w4a4_flatquant_dynamic.py` |
| `vllm_ascend/quantization/w4a8_dynamic.py` |
| `vllm_ascend/quantization/w8a16.py` |
| `vllm_ascend/quantization/w8a8.py` |
| `vllm_ascend/quantization/w8a8_dynamic.py` |
| `vllm_ascend/quantization/w8a8_pdmix.py` |
| `vllm_ascend/quantization/w8a8mxfp8.py` |
| `vllm_ascend/sample/rejection_sampler.py` |
| `vllm_ascend/sample/sampler.py` |
| `vllm_ascend/worker/block_table.py` |
### Does this PR introduce _any_ user-facing change?
No. The changes are formatting-only.
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main: 2c24bc6996

Signed-off-by: MrZ20 <2609716663@qq.com>
--- a/vllm_ascend/sample/rejection_sampler.py
+++ b/vllm_ascend/sample/rejection_sampler.py
@@ -1,18 +1,23 @@
 # SPDX-License-Identifier: Apache-2.0
-from typing import Optional
 
 import torch
 from vllm.triton_utils import HAS_TRITON, triton
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.sample.rejection_sampler import (GREEDY_TEMPERATURE, MAX_SPEC_LEN,
-                                              PLACEHOLDER_TOKEN_ID,
-                                              generate_uniform_probs)
+from vllm.v1.sample.rejection_sampler import (
+    GREEDY_TEMPERATURE,
+    MAX_SPEC_LEN,
+    PLACEHOLDER_TOKEN_ID,
+    generate_uniform_probs,
+)
 
 from vllm_ascend.ops.triton.reject_sample import (
-    cal_grid_and_block_size, expand_triton,
+    cal_grid_and_block_size,
+    expand_triton,
     rejection_greedy_sample_with_triton,
     rejection_random_sample_block_verify_kernel,
-    rejection_random_sample_kernel, sample_recovered_tokens_kernel)
+    rejection_random_sample_kernel,
+    sample_recovered_tokens_kernel,
+)
 from vllm_ascend.sample.sampler import apply_top_k_top_p
 
 
@@ -83,7 +88,7 @@ def rejection_sample(
     # [batch_size]
     cu_num_draft_tokens: torch.Tensor,
     # [num_tokens, vocab_size]
-    draft_probs: Optional[torch.Tensor],
+    draft_probs: torch.Tensor | None,
     # [num_tokens, vocab_size]
     target_probs: torch.Tensor,
     # [batch_size, 1]
@@ -126,15 +131,20 @@ def rejection_sample(
     # Rejection sampling for greedy sampling requests.
     target_argmax = target_probs.argmax(dim=-1)
     if HAS_TRITON:
-        rejection_greedy_sample_with_triton(output_token_ids,
-                                            num_draft_tokens,
-                                            cu_num_draft_tokens,
-                                            draft_token_ids, target_argmax,
-                                            bonus_token_ids, is_greedy,
-                                            max_spec_len, grid, block_size)
+        rejection_greedy_sample_with_triton(
+            output_token_ids,
+            num_draft_tokens,
+            cu_num_draft_tokens,
+            draft_token_ids,
+            target_argmax,
+            bonus_token_ids,
+            is_greedy,
+            max_spec_len,
+            grid,
+            block_size,
+        )
     else:
-        if min(num_draft_tokens) == 1 and max(
-                num_draft_tokens) == 1 and sampling_metadata.all_greedy:
+        if min(num_draft_tokens) == 1 and max(num_draft_tokens) == 1 and sampling_metadata.all_greedy:
             rejection_greedy_sample_spec_len_1_pytorch(
                 output_token_ids,
                 draft_token_ids,
@@ -179,7 +189,7 @@ def rejection_sample(
     if not using_block_verify:
         # Rejection sampling for random sampling requests.
         if HAS_TRITON:
-            rejection_random_sample_kernel[(grid, )](
+            rejection_random_sample_kernel[(grid,)](
                 output_token_ids,
                 cu_num_draft_tokens,
                 draft_token_ids,
@@ -214,7 +224,7 @@ def rejection_sample(
     else:
         # MagicMTP: Improving acceptance rate with Block Verify.
         if HAS_TRITON:
-            rejection_random_sample_block_verify_kernel[(grid, )](
+            rejection_random_sample_block_verify_kernel[(grid,)](
                 output_token_ids,
                 cu_num_draft_tokens,
                 draft_token_ids,
@@ -231,19 +241,20 @@ def rejection_sample(
                 BLOCK_SIZE=block_size,
             )
         else:
-            rejection_random_sample_block_verify_pytorch(output_token_ids,
-                                                         cu_num_draft_tokens,
-                                                         draft_token_ids,
-                                                         draft_probs,
-                                                         target_probs,
-                                                         bonus_token_ids,
-                                                         recovered_token_ids,
-                                                         uniform_probs,
-                                                         is_greedy,
-                                                         max_spec_len,
-                                                         vocab_size,
-                                                         IS_NGRAM=draft_probs
-                                                         is None)
+            rejection_random_sample_block_verify_pytorch(
+                output_token_ids,
+                cu_num_draft_tokens,
+                draft_token_ids,
+                draft_probs,
+                target_probs,
+                bonus_token_ids,
+                recovered_token_ids,
+                uniform_probs,
+                is_greedy,
+                max_spec_len,
+                vocab_size,
+                IS_NGRAM=draft_probs is None,
+            )
 
     return output_token_ids
 
 
@@ -277,13 +288,7 @@ def expand_batch_to_tokens(
     assert cu_num_tokens.shape[0] == batch_size
     expanded_x = x.new_empty(num_tokens)
     if HAS_TRITON:
-        expand_triton(batch_size,
-                      expanded_x,
-                      x,
-                      cu_num_tokens,
-                      replace_from,
-                      replace_to,
-                      max_num_tokens=MAX_SPEC_LEN)
+        expand_triton(batch_size, expanded_x, x, cu_num_tokens, replace_from, replace_to, max_num_tokens=MAX_SPEC_LEN)
     else:
         expand_pytorch(
             expanded_x,
@@ -301,7 +306,7 @@ def sample_recovered_tokens(
     num_draft_tokens: list[int],
     cu_num_draft_tokens: torch.Tensor,
     draft_token_ids: torch.Tensor,
-    draft_probs: Optional[torch.Tensor],
+    draft_probs: torch.Tensor | None,
     target_probs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
     device: torch.device,
@@ -316,9 +321,7 @@ def sample_recovered_tokens(
     )
     q.exponential_()
 
-    num_draft_tensor = torch.tensor(num_draft_tokens,
-                                    pin_memory=True).to(device,
-                                                        non_blocking=True)
+    num_draft_tensor = torch.tensor(num_draft_tokens, pin_memory=True).to(device, non_blocking=True)
     has_draft_mask = num_draft_tensor > 0
 
     for i, generator in sampling_metadata.generators.items():
@@ -357,10 +360,10 @@ def sample_recovered_tokens(
 
 
 def rejection_greedy_sample_spec_len_1_pytorch(
-        output_token_ids,  # [batch_size, 2]
-        draft_token_ids,  # [num_tokens]
-        target_argmax,  # [num_tokens]
-        bonus_token_ids,  # [batch_size]
+    output_token_ids,  # [batch_size, 2]
+    draft_token_ids,  # [num_tokens]
+    target_argmax,  # [num_tokens]
+    bonus_token_ids,  # [batch_size]
 ):
     batch_size = output_token_ids.size(0)
     num_tokens = draft_token_ids.size(0)
@@ -368,73 +371,56 @@ def rejection_greedy_sample_spec_len_1_pytorch(
     accept_req_mask = draft_token_ids == target_argmax
     output_token_ids[:, 0] = target_argmax
     bonus_token_ids = bonus_token_ids.squeeze(1)
-    output_token_ids[:, 1] = torch.where(accept_req_mask, bonus_token_ids,
-                                         output_token_ids[:, 1])
+    output_token_ids[:, 1] = torch.where(accept_req_mask, bonus_token_ids, output_token_ids[:, 1])
 
 
 def rejection_greedy_sample_pytorch(
-        output_token_ids,  # [batch_size, max_spec_len + 1]
-        cu_num_draft_tokens,  # [batch_size]
-        draft_token_ids,  # [num_tokens]
-        target_argmax,  # [num_tokens]
-        bonus_token_ids,  # [batch_size]
-        draft_tokens_per_req,  # [batch_size], list
-        max_spec_len,
-        is_greedy=None,  # [batch_size] or None
+    output_token_ids,  # [batch_size, max_spec_len + 1]
+    cu_num_draft_tokens,  # [batch_size]
+    draft_token_ids,  # [num_tokens]
+    target_argmax,  # [num_tokens]
+    bonus_token_ids,  # [batch_size]
+    draft_tokens_per_req,  # [batch_size], list
+    max_spec_len,
+    is_greedy=None,  # [batch_size] or None
 ):
     batch_size = output_token_ids.size(0)
     num_tokens = draft_token_ids.size(0)
     device = output_token_ids.device
-    draft_tokens_per_req = torch.tensor(draft_tokens_per_req).to(
-        device, non_blocking=True)
+    draft_tokens_per_req = torch.tensor(draft_tokens_per_req).to(device, non_blocking=True)
     if is_greedy is None:
         is_greedy = torch.ones(batch_size, dtype=torch.bool, device=device)
 
     start_indices = cu_num_draft_tokens - draft_tokens_per_req
     req_ids = torch.arange(batch_size, device=device)
     token_req_ids = torch.repeat_interleave(req_ids, draft_tokens_per_req)
-    token_positions = torch.arange(
-        num_tokens, device=device) - start_indices[token_req_ids]
+    token_positions = torch.arange(num_tokens, device=device) - start_indices[token_req_ids]
 
     # Find the first mismatch position of each request.
-    mismatch_global = (draft_token_ids != target_argmax)
+    mismatch_global = draft_token_ids != target_argmax
     if max_spec_len == 0:
-        first_mismatch_pos_per_req = torch.zeros(batch_size,
-                                                 dtype=torch.long,
-                                                 device=device)
+        first_mismatch_pos_per_req = torch.zeros(batch_size, dtype=torch.long, device=device)
     else:
         # [bs, max_spec_len]
-        pos_matrix = torch.full((batch_size, max_spec_len),
-                                -1,
-                                dtype=torch.long,
-                                device=device)
+        pos_matrix = torch.full((batch_size, max_spec_len), -1, dtype=torch.long, device=device)
         pos_matrix[token_req_ids, token_positions] = token_positions
-        mismatch_matrix = torch.full((batch_size, max_spec_len),
-                                     False,
-                                     dtype=torch.bool,
-                                     device=device)
+        mismatch_matrix = torch.full((batch_size, max_spec_len), False, dtype=torch.bool, device=device)
         mismatch_matrix[token_req_ids, token_positions] = mismatch_global
-        mismatch_positions = torch.where(mismatch_matrix, pos_matrix,
-                                         max_spec_len * 2)
+        mismatch_positions = torch.where(mismatch_matrix, pos_matrix, max_spec_len * 2)
         first_mismatch_pos_per_req, _ = torch.min(mismatch_positions, dim=1)
-        no_mismatch_mask = (first_mismatch_pos_per_req == max_spec_len * 2)
-        first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[
-            no_mismatch_mask]
+        no_mismatch_mask = first_mismatch_pos_per_req == max_spec_len * 2
+        first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[no_mismatch_mask]
 
     # Copy matched target tokens into output.
-    copy_len = torch.minimum(first_mismatch_pos_per_req + 1,
-                             draft_tokens_per_req)
-    copy_indices = torch.arange(max_spec_len + 1,
-                                device=device).expand(batch_size, -1)
+    copy_len = torch.minimum(first_mismatch_pos_per_req + 1, draft_tokens_per_req)
+    copy_indices = torch.arange(max_spec_len + 1, device=device).expand(batch_size, -1)
     copy_mask = copy_indices < copy_len.unsqueeze(1)
     greedy_mask = is_greedy.unsqueeze(1)
     final_copy_mask = copy_mask & greedy_mask
     global_idx = start_indices.unsqueeze(1) + copy_indices
-    output_token_ids[final_copy_mask] = target_argmax[
-        global_idx[final_copy_mask]].to(output_token_ids.dtype)
+    output_token_ids[final_copy_mask] = target_argmax[global_idx[final_copy_mask]].to(output_token_ids.dtype)
     # Fill bonus token.
-    needs_bonus = is_greedy & (first_mismatch_pos_per_req
-                               >= draft_tokens_per_req)
+    needs_bonus = is_greedy & (first_mismatch_pos_per_req >= draft_tokens_per_req)
     if torch.any(needs_bonus):
         bonus_rows = torch.where(needs_bonus)[0]
         bonus_cols = draft_tokens_per_req[bonus_rows]
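For reviewers skimming the reformatted `rejection_greedy_sample_pytorch` above: its core trick is to find each request's first draft/target mismatch without a Python loop, by scattering per-token comparison results into a `[batch_size, max_spec_len]` grid and taking a row-wise minimum. A standalone sketch of just that trick, with made-up toy tensors (illustrative only, not the project code):

```python
import torch

# Two requests with 2 and 3 draft tokens, flattened into 1D arrays.
draft_token_ids = torch.tensor([5, 9, 7, 3, 3])
target_argmax = torch.tensor([5, 8, 7, 3, 4])  # greedy target tokens
draft_tokens_per_req = torch.tensor([2, 3])
cu_num_draft_tokens = torch.cumsum(draft_tokens_per_req, dim=0)  # [2, 5]
batch_size, max_spec_len, num_tokens = 2, 3, 5

# Map every flat token to (request id, position within that request).
start_indices = cu_num_draft_tokens - draft_tokens_per_req
req_ids = torch.arange(batch_size)
token_req_ids = torch.repeat_interleave(req_ids, draft_tokens_per_req)
token_positions = torch.arange(num_tokens) - start_indices[token_req_ids]

# Scatter per-token mismatch flags into a dense grid; a row-wise min over the
# mismatch positions then yields the first divergence per request, with a
# sentinel (max_spec_len * 2) marking rows where every draft token matched.
mismatch_global = draft_token_ids != target_argmax
pos_matrix = torch.full((batch_size, max_spec_len), -1, dtype=torch.long)
pos_matrix[token_req_ids, token_positions] = token_positions
mismatch_matrix = torch.zeros(batch_size, max_spec_len, dtype=torch.bool)
mismatch_matrix[token_req_ids, token_positions] = mismatch_global
mismatch_positions = torch.where(mismatch_matrix, pos_matrix, max_spec_len * 2)
first_mismatch_pos_per_req, _ = torch.min(mismatch_positions, dim=1)
print(first_mismatch_pos_per_req)  # tensor([1, 2])
```

Rows that keep the sentinel had no mismatch; the real function then overwrites them with the request's draft count so that the bonus token is appended.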
@@ -458,24 +444,24 @@ def rejection_random_sample_pytorch(
 ):
     """
     This function implements the Speculative Decoding rejection sampling step.
-    Instead of looping through each request and each token (which causes high
+    Instead of looping through each request and each token (which causes high
     overhead), it uses a fully vectorized approach:
 
-    1. **Index Mapping**: Converts the flattened 1D token arrays into a 2D
-       [batch_size, max_draft_len] grid using 'cu_num_draft_tokens' to handle
+    1. **Index Mapping**: Converts the flattened 1D token arrays into a 2D
+       [batch_size, max_draft_len] grid using 'cu_num_draft_tokens' to handle
       variable-length sequences in the batch.
-    2. **Parallel Validation**: Calculates the acceptance condition
-       (target_prob / draft_prob >= uniform_sample) for ALL draft tokens
+    2. **Parallel Validation**: Calculates the acceptance condition
+       (target_prob / draft_prob >= uniform_sample) for ALL draft tokens
      simultaneously across the entire batch.
-    3. **Short-circuit Simulation**: In the loop version, once a token is rejected,
-       subsequent tokens are ignored. Here, we simulate this by finding the
-       'first_reject_pos' using argmax on the rejection mask and creating a
+    3. **Short-circuit Simulation**: In the loop version, once a token is rejected,
+       subsequent tokens are ignored. Here, we simulate this by finding the
+       'first_reject_pos' using argmax on the rejection mask and creating a
      'should_skip' mask for all indices after the first failure.
    4. **Token Selection**: Uses 'torch.where' to select:
       - Draft tokens (if accepted)
       - Recovered tokens (at the point of first rejection)
       - Bonus tokens (if all tokens in a sequence were accepted)
-    5. **Masking**: Ensures operations only apply to non-greedy requests and
+    5. **Masking**: Ensures operations only apply to non-greedy requests and
      within valid sequence lengths.
     """
 
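Steps 2-3 of this docstring are the heart of the vectorized path. A standalone illustration with made-up shapes and probabilities (not the project code):

```python
import torch

torch.manual_seed(0)
batch_size, max_draft_len = 2, 3
draft_token_probs = torch.tensor([[0.9, 0.5, 0.8], [0.7, 0.6, 0.4]])
target_token_probs = torch.tensor([[0.95, 0.1, 0.9], [0.8, 0.7, 0.1]])
uniform_token_probs = torch.rand(batch_size, max_draft_len)
valid_mask = torch.tensor([[True, True, False], [True, True, True]])  # request 0 has only 2 drafts

# Step 2: accept a token iff target_prob / draft_prob >= u, for all tokens at once.
acceptance = target_token_probs / draft_token_probs >= uniform_token_probs

# Step 3: emulate the sequential short-circuit. argmax returns the index of the
# first True in each row of the rejection mask; rows with no rejection fall
# back to max_draft_len, so nothing in them is skipped.
first_rejection = (~acceptance) & valid_mask
first_reject_pos = torch.where(
    first_rejection.any(dim=1, keepdim=True),
    first_rejection.float().argmax(dim=1, keepdim=True),
    torch.full((batch_size, 1), max_draft_len),
)
should_skip = (torch.arange(max_draft_len)[None, :] >= first_reject_pos) & valid_mask
```

The real function additionally guards against zero draft probabilities before the division, as shown in the next hunk.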
@@ -495,15 +481,12 @@ def rejection_random_sample_pytorch(
 
     valid_mask = pos_indices < num_draft_per_batch[:, None]
     global_token_indices = cu_start[:, None] + pos_indices
-    global_token_indices = global_token_indices.clamp(
-        0, draft_token_ids.shape[0] - 1)
-    draft_tokens = draft_token_ids[
-        global_token_indices]  # [batch_size, max_draft_len]
+    global_token_indices = global_token_indices.clamp(0, draft_token_ids.shape[0] - 1)
+    draft_tokens = draft_token_ids[global_token_indices]  # [batch_size, max_draft_len]
 
     if IS_NGRAM:
         ones_cpu = torch.ones(1, pin_memory=True, dtype=torch.float32)
-        draft_token_probs = ones_cpu.to(
-            device, non_blocking=True).expand_as(draft_tokens)
+        draft_token_probs = ones_cpu.to(device, non_blocking=True).expand_as(draft_tokens)
     else:
         flat_indices = global_token_indices.flatten()
         flat_draft_tokens = draft_tokens.flatten()
@@ -518,24 +501,21 @@ def rejection_random_sample_pytorch(
     uniform_token_probs = uniform_probs[global_token_indices]
     recovered_tokens = recovered_token_ids[global_token_indices]
 
-    zero_threshold_cpu = torch.tensor([0.0],
-                                      pin_memory=True,
-                                      dtype=torch.float32)
+    zero_threshold_cpu = torch.tensor([0.0], pin_memory=True, dtype=torch.float32)
     zero_threshold = zero_threshold_cpu.to(device, non_blocking=True)
 
     acceptance_condition = (draft_token_probs > zero_threshold) & (
-        target_token_probs / draft_token_probs >= uniform_token_probs)
+        target_token_probs / draft_token_probs >= uniform_token_probs
+    )
 
     first_rejection = (~acceptance_condition) & valid_mask
 
-    default_pos_cpu = torch.full([batch_size, 1],
-                                 max_draft_len,
-                                 pin_memory=True)
+    default_pos_cpu = torch.full([batch_size, 1], max_draft_len, pin_memory=True)
     default_pos = default_pos_cpu.to(device, non_blocking=True)
 
     first_reject_pos = torch.where(
-        first_rejection.any(dim=1, keepdim=True),
-        first_rejection.float().argmax(dim=1, keepdim=True), default_pos)
+        first_rejection.any(dim=1, keepdim=True), first_rejection.float().argmax(dim=1, keepdim=True), default_pos
+    )
     pos_mask = pos_indices >= first_reject_pos
     should_skip = pos_mask & valid_mask
 
@@ -543,16 +523,17 @@ def rejection_random_sample_pytorch(
     non_greedy_mask = ~is_greedy
     update_mask = non_greedy_mask[:, None] & valid_mask & (~should_skip)
 
-    first_reject_mask = (pos_indices == first_reject_pos
-                         ) & valid_mask & non_greedy_mask[:, None]
+    first_reject_mask = (pos_indices == first_reject_pos) & valid_mask & non_greedy_mask[:, None]
     final_update_mask = update_mask | first_reject_mask
     final_tokens = torch.where(
-        first_reject_mask, recovered_tokens,
-        torch.where(final_acceptance, draft_tokens,
-                    output_token_ids[:, :max_draft_len]))
+        first_reject_mask,
+        recovered_tokens,
+        torch.where(final_acceptance, draft_tokens, output_token_ids[:, :max_draft_len]),
+    )
 
     output_token_ids[:, :max_draft_len] = torch.where(
-        final_update_mask, final_tokens, output_token_ids[:, :max_draft_len])
+        final_update_mask, final_tokens, output_token_ids[:, :max_draft_len]
+    )
 
     no_rejection = first_reject_pos.squeeze(1) >= num_draft_per_batch
     should_add_bonus = non_greedy_mask & no_rejection
@@ -561,8 +542,7 @@ def rejection_random_sample_pytorch(
 
     seq_len = output_token_ids.shape[1]
     all_positions_cpu = torch.arange(seq_len, pin_memory=True)
-    all_positions = all_positions_cpu.to(
-        device, non_blocking=True)[None, :]  # [1, seq_len]
+    all_positions = all_positions_cpu.to(device, non_blocking=True)[None, :]  # [1, seq_len]
 
     batch_bonus_positions = bonus_positions[:, None]  # [batch_size, 1]
 
@@ -572,12 +552,11 @@ def rejection_random_sample_pytorch(
     valid_bonus_pos = bonus_positions < (max_spec_len_device + 1)
     final_bonus_mask = should_add_bonus & valid_bonus_pos
 
-    bonus_pos_match = (all_positions == batch_bonus_positions)
+    bonus_pos_match = all_positions == batch_bonus_positions
     bonus_pos_mask = bonus_pos_match & final_bonus_mask[:, None]
 
     bonus_values_expanded = bonus_token_ids.view(-1, 1).expand(-1, seq_len)
-    output_token_ids[:] = torch.where(bonus_pos_mask, bonus_values_expanded,
-                                      output_token_ids)
+    output_token_ids[:] = torch.where(bonus_pos_mask, bonus_values_expanded, output_token_ids)
 
 
 def expand_pytorch(
@@ -589,17 +568,17 @@ def expand_pytorch(
     MAX_NUM_TOKENS,
 ):
     """
-    This function broadcasts batch-level values (input_ptr) to token-level
-    positions (output_ptr) based on cumulative token offsets. It acts like
+    This function broadcasts batch-level values (input_ptr) to token-level
+    positions (output_ptr) based on cumulative token offsets. It acts like
     a "scatter" or "repeat_interleave" operation but with custom logic:
 
-    1. **Range Broadcasting**: It creates a boolean matrix 'in_range' of size
-       [num_tokens, batch_size] that identifies which batch index each token
+    1. **Range Broadcasting**: It creates a boolean matrix 'in_range' of size
+       [num_tokens, batch_size] that identifies which batch index each token
      belongs to by checking if the token index falls between cu_start and cu_end.
-    2. **Conditional Replacement**: Before expansion, it replaces specific values
+    2. **Conditional Replacement**: Before expansion, it replaces specific values
      (e.g., padding or special markers) in the input to prepare the data.
-    3. **Matrix-based Mapping**: It uses 'torch.einsum' to perform a weighted
-       sum that effectively "picks" the correct batch value for every token position
+    3. **Matrix-based Mapping**: It uses 'torch.einsum' to perform a weighted
+       sum that effectively "picks" the correct batch value for every token position
      simultaneously, avoiding a Python loop over the batch.
     """
     device = cu_num_tokens_ptr.device
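The einsum mapping described in this docstring, reduced to a standalone toy example (all inputs made up):

```python
import torch

# Request i owns flat token positions [cu_start[i], cu_num_tokens[i]).
cu_num_tokens = torch.tensor([2, 5, 6])
batch_values = torch.tensor([10.0, 20.0, 30.0])
num_tokens, batch_size = int(cu_num_tokens[-1]), batch_values.numel()

cu_start = torch.cat([torch.zeros(1, dtype=cu_num_tokens.dtype), cu_num_tokens[:-1]])
token_indices = torch.arange(num_tokens)[:, None]  # [num_tokens, 1]
in_range = (token_indices >= cu_start[None, :]) & (token_indices < cu_num_tokens[None, :])

# Each row of in_range is one-hot over the batch, so this weighted sum simply
# picks the owning request's value for every token position, loop-free.
token_values = torch.einsum("tb,b->t", in_range.float(), batch_values)
print(token_values)  # tensor([10., 10., 20., 20., 20., 30.])
```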
@@ -609,21 +588,16 @@ def expand_pytorch(
     if batch_size == 0 or num_tokens == 0:
         return
 
-    cu_start = torch.cat([
-        torch.tensor([0], pin_memory=True).to(device, non_blocking=True),
-        cu_num_tokens_ptr[:-1]
-    ])
+    cu_start = torch.cat([torch.tensor([0], pin_memory=True).to(device, non_blocking=True), cu_num_tokens_ptr[:-1]])
     cu_end = cu_num_tokens_ptr
 
-    token_indices = torch.arange(num_tokens,
-                                 device=device)[:, None]  # [num_tokens, 1]
+    token_indices = torch.arange(num_tokens, device=device)[:, None]  # [num_tokens, 1]
     cu_start_exp = cu_start[None, :]  # [1, batch_size]
     cu_end_exp = cu_end[None, :]  # [1, batch_size]
 
     in_range = (token_indices >= cu_start_exp) & (token_indices < cu_end_exp)
 
-    replaced_input = torch.where(input_ptr == replace_from, replace_to,
-                                 input_ptr).float()
+    replaced_input = torch.where(input_ptr == replace_from, replace_to, input_ptr).float()
 
     token_values = torch.einsum("tb,b->t", in_range.float(), replaced_input)
 
@@ -643,21 +617,21 @@ def sample_recovered_tokens_pytorch(
     IS_NGRAM=False,
 ):
     """
-    When a draft token is rejected, we must sample a "recovered" token from
-    a modified distribution. This function calculates that distribution across
+    When a draft token is rejected, we must sample a "recovered" token from
+    a modified distribution. This function calculates that distribution across
     the entire flattened batch.
 
-    1. **Token-to-Batch Mapping**: Using the cumulative draft token counts, it
-       determines which request in the batch each token belongs to. This is
+    1. **Token-to-Batch Mapping**: Using the cumulative draft token counts, it
+       determines which request in the batch each token belongs to. This is
      necessary because 'q' (normalization factor) is stored per-request.
-    2. **Probability Adjustment**:
+    2. **Probability Adjustment**:
       - If N-GRAM: It zeroes out the draft token's probability in the target.
-      - If Probabilistic: It calculates max(0, target_probs - draft_probs)
+      - If Probabilistic: It calculates max(0, target_probs - draft_probs)
        as per the standard speculative decoding algorithm.
-    3. **Normalization & Sampling**: It divides the adjusted probabilities
-       by the normalization distribution 'q'. To remain vectorized, it
+    3. **Normalization & Sampling**: It divides the adjusted probabilities
+       by the normalization distribution 'q'. To remain vectorized, it
      broadcasts 'q' from [batch_size, vocab] to [num_tokens, vocab].
-    4. **Argmax Selection**: It selects the best recovery token for every
+    4. **Argmax Selection**: It selects the best recovery token for every
      position in one pass using torch.argmax.
     """
     device = output_token_ids.device
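Step 2's probabilistic adjustment, shown standalone for a toy four-token vocabulary (the function above additionally divides by the exponential-noise distribution 'q' and takes an argmax; nothing below is project code):

```python
import torch

target_probs = torch.tensor([0.5, 0.2, 0.2, 0.1])
draft_probs = torch.tensor([0.1, 0.6, 0.2, 0.1])

# max(0, p_target - p_draft), renormalized: the recovered distribution keeps
# only the mass where the target model is more confident than the draft model.
adjusted = torch.clamp(target_probs - draft_probs, min=0)
recovered_dist = adjusted / adjusted.sum()
print(recovered_dist)  # tensor([1., 0., 0., 0.])
```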
@@ -666,10 +640,12 @@ def sample_recovered_tokens_pytorch(
     if num_tokens == 0:
         return
 
-    cu_start = torch.cat([
-        torch.tensor([0], pin_memory=True).to(device, non_blocking=True),
-        cu_num_draft_tokens[:-1],
-    ])
+    cu_start = torch.cat(
+        [
+            torch.tensor([0], pin_memory=True).to(device, non_blocking=True),
+            cu_num_draft_tokens[:-1],
+        ]
+    )
     cu_end = cu_num_draft_tokens
 
     token_indices = torch.arange(num_tokens, device=device)  # [num_tokens]
@@ -678,8 +654,7 @@ def sample_recovered_tokens_pytorch(
     cu_start_expanded = cu_start[None, :]  # [1, batch_size]
     cu_end_expanded = cu_end[None, :]  # [1, batch_size]
 
-    in_range_mask = (token_indices_expanded >= cu_start_expanded) & (
-        token_indices_expanded < cu_end_expanded)
+    in_range_mask = (token_indices_expanded >= cu_start_expanded) & (token_indices_expanded < cu_end_expanded)
 
     token_to_batch = torch.argmax(in_range_mask.int(), dim=1)
 
@@ -707,8 +682,7 @@ def sample_recovered_tokens_pytorch(
 
     prob_over_q = prob / q_values_safe
 
-    prob_over_q = torch.where((q_values == 0) | torch.isinf(q_values), -1e10,
-                              prob_over_q)
+    prob_over_q = torch.where((q_values == 0) | torch.isinf(q_values), -1e10, prob_over_q)
 
     recovered_ids = torch.argmax(prob_over_q, dim=1)
 
@@ -742,14 +716,12 @@ def rejection_random_sample_block_verify_pytorch(
     pos_indices = pos_indices_cpu.to(device, non_blocking=True)[None, :]
     valid_mask = pos_indices < num_draft_per_batch
     global_token_indices = cu_start[:, None] + pos_indices
-    global_token_indices = global_token_indices.clamp(
-        0, draft_token_ids.shape[0] - 1)
+    global_token_indices = global_token_indices.clamp(0, draft_token_ids.shape[0] - 1)
     draft_tokens = draft_token_ids[global_token_indices]
 
     if IS_NGRAM:
         ones_cpu = torch.ones(1, pin_memory=True, dtype=torch.float32)
-        draft_token_probs = ones_cpu.to(
-            device, non_blocking=True).expand_as(draft_tokens)
+        draft_token_probs = ones_cpu.to(device, non_blocking=True).expand_as(draft_tokens)
     else:
         flat_indices = global_token_indices.flatten()
         flat_draft_tokens = draft_tokens.flatten()
@@ -772,27 +744,21 @@ def rejection_random_sample_block_verify_pytorch(
 
     last_accept_pos = torch.where(
         legal_mask.any(dim=-1, keepdim=True),
-        (max_spec_len -
-         legal_mask.flip(dims=[-1]).float().argmax(dim=-1, keepdim=True) - 1),
-        -1)
+        (max_spec_len - legal_mask.flip(dims=[-1]).float().argmax(dim=-1, keepdim=True) - 1),
+        -1,
+    )
     non_greedy_mask = (~is_greedy)[:, None]
 
-    accept_mask = (pos_indices
-                   <= last_accept_pos) & valid_mask & non_greedy_mask
-    output_token_ids[:, :max_spec_len] = torch.where(
-        accept_mask, draft_tokens, output_token_ids[:, :max_spec_len])
+    accept_mask = (pos_indices <= last_accept_pos) & valid_mask & non_greedy_mask
+    output_token_ids[:, :max_spec_len] = torch.where(accept_mask, draft_tokens, output_token_ids[:, :max_spec_len])
 
-    reject_mask = (pos_indices
-                   == last_accept_pos + 1) & valid_mask & non_greedy_mask
-    output_token_ids[:, :max_spec_len] = torch.where(
-        reject_mask, recovered_tokens, output_token_ids[:, :max_spec_len])
+    reject_mask = (pos_indices == last_accept_pos + 1) & valid_mask & non_greedy_mask
+    output_token_ids[:, :max_spec_len] = torch.where(reject_mask, recovered_tokens, output_token_ids[:, :max_spec_len])
 
     bonus_mask = (last_accept_pos + 1 >= num_draft_per_batch) & non_greedy_mask
     all_positions_cpu = torch.arange(max_spec_len + 1, pin_memory=True)
     all_positions = all_positions_cpu.to(device, non_blocking=True)[None, :]
-    bonus_pos_match = (all_positions == num_draft_per_batch)
+    bonus_pos_match = all_positions == num_draft_per_batch
     bonus_mask = bonus_mask & bonus_pos_match
-    bonus_values_expanded = bonus_token_ids.view(-1, 1).expand(
-        -1, max_spec_len + 1)
-    output_token_ids[:] = torch.where(bonus_mask, bonus_values_expanded,
-                                      output_token_ids)
+    bonus_values_expanded = bonus_token_ids.view(-1, 1).expand(-1, max_spec_len + 1)
+    output_token_ids[:] = torch.where(bonus_mask, bonus_values_expanded, output_token_ids)
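The block-verify variant above needs the last accepted position rather than the first rejected one; it gets it by flipping the mask and applying argmax, which always returns the first maximal index. A standalone sketch with a toy mask (not project code):

```python
import torch

max_spec_len = 4
legal_mask = torch.tensor([[True, False, True, False],
                           [False, False, False, False]])

# flip + argmax finds the first True from the right, i.e. the last True overall;
# rows with no True at all fall back to -1.
last_accept_pos = torch.where(
    legal_mask.any(dim=-1, keepdim=True),
    max_spec_len - legal_mask.flip(dims=[-1]).float().argmax(dim=-1, keepdim=True) - 1,
    torch.tensor(-1),
)
print(last_accept_pos.tolist())  # [[2], [-1]]
```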
--- a/vllm_ascend/sample/sampler.py
+++ b/vllm_ascend/sample/sampler.py
@@ -35,7 +35,6 @@ def random_sample(
 
 
 class AscendSampler(Sampler):
-
     def __init__(self, logprobs_mode=DEFAULT_LOGPROBS_MODE):
         # TODO: support logprobs_mode in vllm-ascend
         super().__init__(logprobs_mode=logprobs_mode)
@@ -62,7 +61,6 @@ class AscendSampler(Sampler):
 
 
 class AscendTopKTopPSampler(TopKTopPSampler):
-
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self.apply_top_k_top_p = apply_top_k_top_p
@@ -135,4 +133,9 @@ def _apply_top_k_top_p_ascendc(
         return logits
     return torch.ops._C_ascend.npu_apply_top_k_top_p(logits, k=k, p=p)
 
-apply_top_k_top_p = _apply_top_k_top_p_ascendc if get_ascend_device_type() in [AscendDeviceType.A2, AscendDeviceType.A3] else _apply_top_k_top_p_pytorch
+apply_top_k_top_p = (
+    _apply_top_k_top_p_ascendc
+    if get_ascend_device_type() in [AscendDeviceType.A2, AscendDeviceType.A3]
+    else _apply_top_k_top_p_pytorch
+)