[Feature] Merge branch 'Qwen3-Next' into main && Support Qwen-next (#222)
Signed-off-by: xyDong0223 <dongxinyu03@baidu.com>
Co-authored-by: xyDong0223 <dongxinyu03@baidu.com>
@@ -1,12 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import Union

import kunlun_ops
import torch
import torch.nn as nn

from vllm.logger import init_logger

from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -54,7 +53,7 @@ class RejectionSampler(nn.Module):
bonus_token_ids: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
'''
"""
Args:
metadata:
Metadata for spec decoding.
@@ -81,7 +80,7 @@ class RejectionSampler(nn.Module):
Returns:
output_token_ids (torch.Tensor):
A tensor containing the final output token IDs.
'''
"""
assert metadata.max_spec_len <= MAX_SPEC_LEN
# [num_tokens, vocab_size]
# NOTE(woosuk): `target_logits` can be updated in place inside the
@@ -124,11 +123,11 @@ class RejectionSampler(nn.Module):
"""
output_token_ids_np = output_token_ids.cpu().numpy()
# Create mask for valid tokens.
valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) &
(output_token_ids_np < vocab_size))
valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
output_token_ids_np < vocab_size
)
outputs = [
row[valid_mask[i]].tolist()
for i, row in enumerate(output_token_ids_np)
row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
]
return outputs
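For context, a minimal standalone sketch of the mask-and-filter step above; the PLACEHOLDER_TOKEN_ID value and vocab_size here are assumed for illustration, not taken from this file:

import numpy as np

PLACEHOLDER_TOKEN_ID = -1  # assumed sentinel value for this sketch
vocab_size = 32000         # assumed vocabulary size for this sketch

output_token_ids_np = np.array([[5, 7, PLACEHOLDER_TOKEN_ID],
                                [9, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]])
valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (output_token_ids_np < vocab_size)
outputs = [row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)]
# outputs == [[5, 7], [9]]: accepted tokens per request, placeholders stripped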
@@ -179,25 +178,15 @@ def rejection_sample(
if not sampling_metadata.all_random:
# Rejection sampling for greedy sampling requests.
target_argmax = target_probs.argmax(dim=-1)
if min(num_draft_tokens) == 1 and max(
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
rejection_greedy_sample_spec_len_1_pytorch(
output_token_ids,
draft_token_ids,
target_argmax,
bonus_token_ids,
)
else:
rejection_greedy_sample_pytorch(
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
num_draft_tokens,
max_spec_len,
is_greedy,
)
kunlun_ops.rejection_greedy_sample(
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
is_greedy,
max_spec_len,
)
if sampling_metadata.all_greedy:
return output_token_ids
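The hunk above swaps the pure-PyTorch greedy fallbacks for the fused kunlun_ops.rejection_greedy_sample kernel. As a rough per-request reference for what the greedy path computes (a sketch under assumed 1-D inputs, not the kernel's actual implementation):

import torch

def greedy_rejection_reference(draft_token_ids, target_argmax, bonus_token_id):
    # Accept draft tokens while they match the target model's argmax;
    # emit the target token at each position, stop at the first mismatch,
    # and append the bonus token only if every draft token was accepted.
    out = []
    for draft, target in zip(draft_token_ids.tolist(), target_argmax.tolist()):
        out.append(target)
        if draft != target:
            return out
    out.append(int(bonus_token_id))
    return out

# greedy_rejection_reference(torch.tensor([3, 5]), torch.tensor([3, 6]), 9) == [3, 6]
# greedy_rejection_reference(torch.tensor([3, 5]), torch.tensor([3, 5]), 9) == [3, 5, 9]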
@@ -222,8 +211,9 @@ def rejection_sample(
sampling_metadata,
device,
)
bonus_token_ids = bonus_token_ids.squeeze(1)

rejection_random_sample_pytorch(
kunlun_ops.rejection_random_sample(
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
@@ -235,8 +225,7 @@ def rejection_sample(
is_greedy,
max_spec_len,
vocab_size,
IS_NGRAM=draft_probs is None,
# num_warps=1,
no_draft_probs=draft_probs is None,
)
return output_token_ids
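Here the random-sampling path moves to kunlun_ops.rejection_random_sample, with the IS_NGRAM flag renamed to no_draft_probs. The per-token acceptance test it implements is the standard speculative-decoding rule, visible in the pytorch fallback further down in this diff; a minimal sketch of that decision:

def accept_draft_token(target_prob: float, draft_prob: float, uniform_prob: float) -> bool:
    # Accept the draft token when u <= p_target / p_draft
    # (ngram drafts have no draft distribution and use p_draft = 1.0).
    if draft_prob <= 0:
        return False
    return target_prob / draft_prob >= uniform_prob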
@@ -374,7 +363,7 @@ def generate_uniform_probs(
random values in the range [0, 1).
"""
uniform_probs = torch.rand(
(num_tokens, ),
(num_tokens,),
dtype=torch.float32,
device=device,
)
@@ -422,7 +411,7 @@ def sample_recovered_tokens(
q[i].exponential_(generator=generator)

recovered_token_ids = torch.empty_like(draft_token_ids)
sample_recovered_tokens_pytorch(
kunlun_ops.sample_recovered_tokens(
recovered_token_ids,
cu_num_draft_tokens,
draft_token_ids,
@@ -430,16 +419,16 @@ def sample_recovered_tokens(
target_probs,
q,
vocab_size,
IS_NGRAM=draft_probs is None,
no_draft_probs=draft_probs is None,
)
return recovered_token_ids


def rejection_greedy_sample_spec_len_1_pytorch(
output_token_ids, # [batch_size, 2]
draft_token_ids, # [num_tokens]
target_argmax, # [num_tokens]
bonus_token_ids, # [batch_size]
output_token_ids, # [batch_size, 2]
draft_token_ids, # [num_tokens]
target_argmax, # [num_tokens]
bonus_token_ids, # [batch_size]
):
batch_size = output_token_ids.size(0)
num_tokens = draft_token_ids.size(0)
@@ -447,73 +436,72 @@ def rejection_greedy_sample_spec_len_1_pytorch(
accept_req_mask = draft_token_ids == target_argmax
output_token_ids[:, 0] = target_argmax
bonus_token_ids = bonus_token_ids.squeeze(1)
output_token_ids[:, 1] = torch.where(accept_req_mask, bonus_token_ids,
output_token_ids[:, 1])
output_token_ids[:, 1] = torch.where(
accept_req_mask, bonus_token_ids, output_token_ids[:, 1]
)


def rejection_greedy_sample_pytorch(
output_token_ids, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens, # [batch_size]
draft_token_ids, # [num_tokens]
target_argmax, # [num_tokens]
bonus_token_ids, # [batch_size]
draft_tokens_per_req, # [batch_size], list
max_spec_len,
is_greedy=None, # [batch_size] or None
output_token_ids, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens, # [batch_size]
draft_token_ids, # [num_tokens]
target_argmax, # [num_tokens]
bonus_token_ids, # [batch_size]
draft_tokens_per_req, # [batch_size], list
max_spec_len,
is_greedy=None, # [batch_size] or None
):
batch_size = output_token_ids.size(0)
num_tokens = draft_token_ids.size(0)
device = output_token_ids.device
draft_tokens_per_req = torch.tensor(draft_tokens_per_req).to(
device, non_blocking=True)
device, non_blocking=True
)
if is_greedy is None:
is_greedy = torch.ones(batch_size, dtype=torch.bool, device=device)

start_indices = cu_num_draft_tokens - draft_tokens_per_req
req_ids = torch.arange(batch_size, device=device)
token_req_ids = torch.repeat_interleave(req_ids, draft_tokens_per_req)
token_positions = torch.arange(
num_tokens, device=device) - start_indices[token_req_ids]
token_positions = (
torch.arange(num_tokens, device=device) - start_indices[token_req_ids]
)

# Find the first mismatch position of each request.
mismatch_global = (draft_token_ids != target_argmax)
mismatch_global = draft_token_ids != target_argmax
if max_spec_len == 0:
first_mismatch_pos_per_req = torch.zeros(batch_size,
dtype=torch.long,
device=device)
first_mismatch_pos_per_req = torch.zeros(
batch_size, dtype=torch.long, device=device
)
else:
# [bs, max_spec_len]
pos_matrix = torch.full((batch_size, max_spec_len),
-1,
dtype=torch.long,
device=device)
pos_matrix = torch.full(
(batch_size, max_spec_len), -1, dtype=torch.long, device=device
)
pos_matrix[token_req_ids, token_positions] = token_positions
mismatch_matrix = torch.full((batch_size, max_spec_len),
False,
dtype=torch.bool,
device=device)
mismatch_matrix = torch.full(
(batch_size, max_spec_len), False, dtype=torch.bool, device=device
)
mismatch_matrix[token_req_ids, token_positions] = mismatch_global
mismatch_positions = torch.where(mismatch_matrix, pos_matrix,
max_spec_len * 2)
mismatch_positions = torch.where(mismatch_matrix, pos_matrix, max_spec_len * 2)
first_mismatch_pos_per_req, _ = torch.min(mismatch_positions, dim=1)
no_mismatch_mask = (first_mismatch_pos_per_req == max_spec_len * 2)
no_mismatch_mask = first_mismatch_pos_per_req == max_spec_len * 2
first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[
no_mismatch_mask]
no_mismatch_mask
]

# Copy matched target tokens into output.
copy_len = torch.minimum(first_mismatch_pos_per_req + 1,
draft_tokens_per_req)
copy_indices = torch.arange(max_spec_len + 1,
device=device).expand(batch_size, -1)
copy_len = torch.minimum(first_mismatch_pos_per_req + 1, draft_tokens_per_req)
copy_indices = torch.arange(max_spec_len + 1, device=device).expand(batch_size, -1)
copy_mask = copy_indices < copy_len.unsqueeze(1)
greedy_mask = is_greedy.unsqueeze(1)
final_copy_mask = copy_mask & greedy_mask
global_idx = start_indices.unsqueeze(1) + copy_indices
output_token_ids[final_copy_mask] = target_argmax[
global_idx[final_copy_mask]].to(output_token_ids.dtype)
output_token_ids[final_copy_mask] = target_argmax[global_idx[final_copy_mask]].to(
output_token_ids.dtype
)
# Fill bonus token.
needs_bonus = is_greedy & (first_mismatch_pos_per_req
>= draft_tokens_per_req)
needs_bonus = is_greedy & (first_mismatch_pos_per_req >= draft_tokens_per_req)
if torch.any(needs_bonus):
bonus_rows = torch.where(needs_bonus)[0]
bonus_cols = draft_tokens_per_req[bonus_rows]
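The vectorized first-mismatch search above replaces a per-request loop. A small illustration of the idea with made-up values (two requests with three and two draft tokens respectively):

import torch

draft_token_ids = torch.tensor([3, 5, 8, 2, 4])   # flattened draft tokens for both requests
target_argmax = torch.tensor([3, 6, 8, 2, 4])     # target model argmax at the same positions
draft_tokens_per_req = torch.tensor([3, 2])
cu_num_draft_tokens = torch.tensor([3, 5])

start_indices = cu_num_draft_tokens - draft_tokens_per_req         # [0, 3]
token_req_ids = torch.repeat_interleave(torch.arange(2), draft_tokens_per_req)
token_positions = torch.arange(5) - start_indices[token_req_ids]   # position within each request

max_spec_len = 3
mismatch = draft_token_ids != target_argmax
pos_matrix = torch.full((2, max_spec_len), -1, dtype=torch.long)
pos_matrix[token_req_ids, token_positions] = token_positions
mismatch_matrix = torch.zeros((2, max_spec_len), dtype=torch.bool)
mismatch_matrix[token_req_ids, token_positions] = mismatch
first_mismatch, _ = torch.min(torch.where(mismatch_matrix, pos_matrix, max_spec_len * 2), dim=1)
# first_mismatch == tensor([1, 6]): request 0 rejects at position 1; request 1 hits the
# sentinel 2 * max_spec_len, meaning no mismatch, so it keeps all drafts plus the bonus token.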
@@ -556,11 +544,9 @@ def rejection_random_sample_pytorch(
if IS_NGRAM:
draft_prob = 1.0
else:
draft_prob = draft_probs[start_idx + pos,
draft_token_id].item()
draft_prob = draft_probs[start_idx + pos, draft_token_id].item()

target_prob = target_probs[start_idx + pos,
draft_token_id].item()
target_prob = target_probs[start_idx + pos, draft_token_id].item()
uniform_prob = uniform_probs[start_idx + pos].item()

if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
@@ -629,12 +615,11 @@ def sample_recovered_tokens_pytorch(
else:
draft_p = draft_probs[token_idx].clone()
target_p = target_probs[token_idx].clone()
prob = torch.maximum(target_p - draft_p,
torch.tensor(0.0, device=target_p.device))
prob = torch.maximum(
target_p - draft_p, torch.tensor(0.0, device=target_p.device)
)

q_values = torch.full((vocab_size, ),
float('-inf'),
device=q.device)
q_values = torch.full((vocab_size,), float("-inf"), device=q.device)
q_values[:vocab_size] = q[req_idx, :vocab_size]

recovered_id = torch.argmax(prob / q_values).item()
@@ -642,4 +627,3 @@ def sample_recovered_tokens_pytorch(

if IS_NGRAM:
target_probs[token_idx, draft_token_id] = orig_prob
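For reference, the recovered-token step whose pytorch fallback appears above draws from the residual distribution max(p_target - p_draft, 0) with the exponential-noise (Gumbel-max style) trick, which is why q is filled with exponential_() earlier in the diff. A minimal sketch with assumed tensors:

import torch

def sample_recovered_token_reference(target_p, draft_p, q):
    # target_p, draft_p: [vocab_size] probabilities; q: [vocab_size] i.i.d. Exp(1) noise.
    prob = torch.clamp(target_p - draft_p, min=0.0)   # unnormalized residual distribution
    return torch.argmax(prob / q).item()              # argmax(prob / q) samples from prob (normalized)

vocab_size = 8
target_p = torch.softmax(torch.randn(vocab_size), dim=-1)
draft_p = torch.softmax(torch.randn(vocab_size), dim=-1)
q = torch.empty(vocab_size).exponential_()
token_id = sample_recovered_token_reference(target_p, draft_p, q)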