[Feature] Merge branch 'Qwen3-Next' into main && Support Qwen3-Next (#222)

Signed-off-by: xyDong0223 <dongxinyu03@baidu.com>
Co-authored-by: xyDong0223 <dongxinyu03@baidu.com>
This commit is contained in:
chanzhennan
2026-02-28 11:15:50 +08:00
committed by GitHub
parent 153093d3b3
commit 82544aa0cc
17 changed files with 2668 additions and 1532 deletions

View File

@@ -1,12 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import Union
import kunlun_ops
import torch
import torch.nn as nn
from vllm.logger import init_logger
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -54,7 +53,7 @@ class RejectionSampler(nn.Module):
bonus_token_ids: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
'''
"""
Args:
metadata:
Metadata for spec decoding.
@@ -81,7 +80,7 @@ class RejectionSampler(nn.Module):
Returns:
output_token_ids (torch.Tensor):
A tensor containing the final output token IDs.
'''
"""
assert metadata.max_spec_len <= MAX_SPEC_LEN
# [num_tokens, vocab_size]
# NOTE(woosuk): `target_logits` can be updated in place inside the
@@ -124,11 +123,11 @@ class RejectionSampler(nn.Module):
"""
output_token_ids_np = output_token_ids.cpu().numpy()
# Create mask for valid tokens.
valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) &
(output_token_ids_np < vocab_size))
valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
output_token_ids_np < vocab_size
)
outputs = [
row[valid_mask[i]].tolist()
for i, row in enumerate(output_token_ids_np)
row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
]
return outputs
@@ -179,25 +178,15 @@ def rejection_sample(
if not sampling_metadata.all_random:
# Rejection sampling for greedy sampling requests.
target_argmax = target_probs.argmax(dim=-1)
if min(num_draft_tokens) == 1 and max(
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
rejection_greedy_sample_spec_len_1_pytorch(
output_token_ids,
draft_token_ids,
target_argmax,
bonus_token_ids,
)
else:
rejection_greedy_sample_pytorch(
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
num_draft_tokens,
max_spec_len,
is_greedy,
)
kunlun_ops.rejection_greedy_sample(
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
is_greedy,
max_spec_len,
)
if sampling_metadata.all_greedy:
return output_token_ids
@@ -222,8 +211,9 @@ def rejection_sample(
sampling_metadata,
device,
)
bonus_token_ids = bonus_token_ids.squeeze(1)
rejection_random_sample_pytorch(
kunlun_ops.rejection_random_sample(
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
@@ -235,8 +225,7 @@ def rejection_sample(
is_greedy,
max_spec_len,
vocab_size,
IS_NGRAM=draft_probs is None,
# num_warps=1,
no_draft_probs=draft_probs is None,
)
return output_token_ids
@@ -374,7 +363,7 @@ def generate_uniform_probs(
random values in the range [0, 1).
"""
uniform_probs = torch.rand(
(num_tokens, ),
(num_tokens,),
dtype=torch.float32,
device=device,
)
@@ -422,7 +411,7 @@ def sample_recovered_tokens(
q[i].exponential_(generator=generator)
recovered_token_ids = torch.empty_like(draft_token_ids)
sample_recovered_tokens_pytorch(
kunlun_ops.sample_recovered_tokens(
recovered_token_ids,
cu_num_draft_tokens,
draft_token_ids,
@@ -430,16 +419,16 @@ def sample_recovered_tokens(
target_probs,
q,
vocab_size,
IS_NGRAM=draft_probs is None,
no_draft_probs=draft_probs is None,
)
return recovered_token_ids
def rejection_greedy_sample_spec_len_1_pytorch(
output_token_ids, # [batch_size, 2]
draft_token_ids, # [num_tokens]
target_argmax, # [num_tokens]
bonus_token_ids, # [batch_size]
output_token_ids, # [batch_size, 2]
draft_token_ids, # [num_tokens]
target_argmax, # [num_tokens]
bonus_token_ids, # [batch_size]
):
batch_size = output_token_ids.size(0)
num_tokens = draft_token_ids.size(0)
@@ -447,73 +436,72 @@ def rejection_greedy_sample_spec_len_1_pytorch(
accept_req_mask = draft_token_ids == target_argmax
output_token_ids[:, 0] = target_argmax
bonus_token_ids = bonus_token_ids.squeeze(1)
output_token_ids[:, 1] = torch.where(accept_req_mask, bonus_token_ids,
output_token_ids[:, 1])
output_token_ids[:, 1] = torch.where(
accept_req_mask, bonus_token_ids, output_token_ids[:, 1]
)
def rejection_greedy_sample_pytorch(
output_token_ids, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens, # [batch_size]
draft_token_ids, # [num_tokens]
target_argmax, # [num_tokens]
bonus_token_ids, # [batch_size]
draft_tokens_per_req, # [batch_size], list
max_spec_len,
is_greedy=None, # [batch_size] or None
output_token_ids, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens, # [batch_size]
draft_token_ids, # [num_tokens]
target_argmax, # [num_tokens]
bonus_token_ids, # [batch_size]
draft_tokens_per_req, # [batch_size], list
max_spec_len,
is_greedy=None, # [batch_size] or None
):
batch_size = output_token_ids.size(0)
num_tokens = draft_token_ids.size(0)
device = output_token_ids.device
draft_tokens_per_req = torch.tensor(draft_tokens_per_req).to(
device, non_blocking=True)
device, non_blocking=True
)
if is_greedy is None:
is_greedy = torch.ones(batch_size, dtype=torch.bool, device=device)
start_indices = cu_num_draft_tokens - draft_tokens_per_req
req_ids = torch.arange(batch_size, device=device)
token_req_ids = torch.repeat_interleave(req_ids, draft_tokens_per_req)
token_positions = torch.arange(
num_tokens, device=device) - start_indices[token_req_ids]
token_positions = (
torch.arange(num_tokens, device=device) - start_indices[token_req_ids]
)
# Find the first mismatch position of each request.
mismatch_global = (draft_token_ids != target_argmax)
mismatch_global = draft_token_ids != target_argmax
if max_spec_len == 0:
first_mismatch_pos_per_req = torch.zeros(batch_size,
dtype=torch.long,
device=device)
first_mismatch_pos_per_req = torch.zeros(
batch_size, dtype=torch.long, device=device
)
else:
# [bs, max_spec_len]
pos_matrix = torch.full((batch_size, max_spec_len),
-1,
dtype=torch.long,
device=device)
pos_matrix = torch.full(
(batch_size, max_spec_len), -1, dtype=torch.long, device=device
)
pos_matrix[token_req_ids, token_positions] = token_positions
mismatch_matrix = torch.full((batch_size, max_spec_len),
False,
dtype=torch.bool,
device=device)
mismatch_matrix = torch.full(
(batch_size, max_spec_len), False, dtype=torch.bool, device=device
)
mismatch_matrix[token_req_ids, token_positions] = mismatch_global
mismatch_positions = torch.where(mismatch_matrix, pos_matrix,
max_spec_len * 2)
mismatch_positions = torch.where(mismatch_matrix, pos_matrix, max_spec_len * 2)
first_mismatch_pos_per_req, _ = torch.min(mismatch_positions, dim=1)
no_mismatch_mask = (first_mismatch_pos_per_req == max_spec_len * 2)
no_mismatch_mask = first_mismatch_pos_per_req == max_spec_len * 2
first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[
no_mismatch_mask]
no_mismatch_mask
]
# Copy matched target tokens into output.
copy_len = torch.minimum(first_mismatch_pos_per_req + 1,
draft_tokens_per_req)
copy_indices = torch.arange(max_spec_len + 1,
device=device).expand(batch_size, -1)
copy_len = torch.minimum(first_mismatch_pos_per_req + 1, draft_tokens_per_req)
copy_indices = torch.arange(max_spec_len + 1, device=device).expand(batch_size, -1)
copy_mask = copy_indices < copy_len.unsqueeze(1)
greedy_mask = is_greedy.unsqueeze(1)
final_copy_mask = copy_mask & greedy_mask
global_idx = start_indices.unsqueeze(1) + copy_indices
output_token_ids[final_copy_mask] = target_argmax[
global_idx[final_copy_mask]].to(output_token_ids.dtype)
output_token_ids[final_copy_mask] = target_argmax[global_idx[final_copy_mask]].to(
output_token_ids.dtype
)
# Fill bonus token.
needs_bonus = is_greedy & (first_mismatch_pos_per_req
>= draft_tokens_per_req)
needs_bonus = is_greedy & (first_mismatch_pos_per_req >= draft_tokens_per_req)
if torch.any(needs_bonus):
bonus_rows = torch.where(needs_bonus)[0]
bonus_cols = draft_tokens_per_req[bonus_rows]
@@ -556,11 +544,9 @@ def rejection_random_sample_pytorch(
if IS_NGRAM:
draft_prob = 1.0
else:
draft_prob = draft_probs[start_idx + pos,
draft_token_id].item()
draft_prob = draft_probs[start_idx + pos, draft_token_id].item()
target_prob = target_probs[start_idx + pos,
draft_token_id].item()
target_prob = target_probs[start_idx + pos, draft_token_id].item()
uniform_prob = uniform_probs[start_idx + pos].item()
if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
@@ -629,12 +615,11 @@ def sample_recovered_tokens_pytorch(
else:
draft_p = draft_probs[token_idx].clone()
target_p = target_probs[token_idx].clone()
prob = torch.maximum(target_p - draft_p,
torch.tensor(0.0, device=target_p.device))
prob = torch.maximum(
target_p - draft_p, torch.tensor(0.0, device=target_p.device)
)
q_values = torch.full((vocab_size, ),
float('-inf'),
device=q.device)
q_values = torch.full((vocab_size,), float("-inf"), device=q.device)
q_values[:vocab_size] = q[req_idx, :vocab_size]
recovered_id = torch.argmax(prob / q_values).item()
@@ -642,4 +627,3 @@ def sample_recovered_tokens_pytorch(
if IS_NGRAM:
target_probs[token_idx, draft_token_id] = orig_prob