[Feature] support deepseek v3/r1/v3.2 (#78)

* [Feature] support deepseek v3/r1/v3.2

* fix gpt_oss

* update readme

* update readme

---------

Co-authored-by: hanhaowen <hanhaowen@baidu.com>
Author: baoqian426
Date: 2026-01-05 22:55:35 +08:00
Committed by: GitHub
Parent: 07bc24a555
Commit: ee0f50e68f
27 changed files with 5760 additions and 621 deletions


@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import os
import torch
import torch.nn as nn
from packaging import version
@@ -24,6 +24,7 @@ class TopKTopPSampler(nn.Module):
    def __init__(self, logprobs_mode):
        super().__init__()
        self.logprobs_mode = logprobs_mode
        logger.info_once(
            "Using FlashInfer for top-p & top-k sampling.")
        self.forward = self.forward_kunlun
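
For orientation, a hypothetical call sequence for the sampler wired up above; the shapes, the k/p values, and the assumption that the TopKTopPSampler defined in this file is importable are illustrative only, with just the argument order inferred from the hunks below.

import torch

# Hypothetical usage sketch; shapes and values are made up.
sampler = TopKTopPSampler(logprobs_mode="processed_logprobs")
logits = torch.randn(4, 32000)                       # [num_seqs, vocab_size]
k = torch.tensor([40, 40, 100, 20])                  # per-request top-k
p = torch.tensor([0.90, 1.00, 0.95, 0.80])           # per-request top-p
generators = {0: torch.Generator().manual_seed(42)}  # request 0 is seeded
# forward() is routed to forward_kunlun; because per-request generators are
# present, it falls back to forward_native and returns a
# (token ids, processed logprobs) pair.
next_token_ids, processed_logprobs = sampler(logits, generators, k, p)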
@@ -40,9 +41,14 @@ class TopKTopPSampler(nn.Module):
         The logits tensor may be updated in-place.
         """
-        logits = apply_top_k_top_p(logits, k, p)
+        logits = self.apply_top_k_top_p(logits, k, p)
+        logits_to_return = None
+        if self.logprobs_mode == "processed_logits":
+            logits_to_return = logits
+        elif self.logprobs_mode == "processed_logprobs":
+            logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
         probs = logits.softmax(dim=-1, dtype=torch.float32)
-        return random_sample(probs, generators), None
+        return random_sample(probs, generators), logits_to_return

     def forward_kunlun(
         self,
@@ -52,16 +58,13 @@
         p: Optional[torch.Tensor],
     ) -> torch.Tensor:
         """More optimized implementation for top-k and top-p sampling."""
-        if k is None and p is None:
-            # We prefer `random_sample` over `flashinfer_sample` when sorting is
-            # not needed. This is because `random_sample` does not require
-            # CPU-GPU synchronization while `flashinfer_sample` does.
-            probs = logits.softmax(dim=-1, dtype=torch.float32)
-            return random_sample(probs, generators), None
-        if generators:
-            logger.warning_once("FlashInfer 0.2.3+ does not support "
-                                "per-request generators. Falling back to "
-                                "PyTorch-native implementation.")
+        if (k is None and p is None) or generators:
+            if generators:
+                logger.debug_once(
+                    "FlashInfer 0.2.3+ does not support "
+                    "per-request generators. Falling back to "
+                    "PyTorch-native implementation."
+                )
             return self.forward_native(logits, generators, k, p)
         # flashinfer sampling functions expect contiguous logits.
         # In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
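
As a reading aid, here is a condensed, self-contained sketch of what the native fallback path amounts to. It is not the code touched here: apply_top_k_top_p and random_sample live elsewhere in vLLM, and the masking and exponential-noise sampling below are simplified stand-ins.

from typing import Optional
import torch

def _mask_top_k_top_p(logits: torch.Tensor,
                      k: Optional[torch.Tensor],
                      p: Optional[torch.Tensor]) -> torch.Tensor:
    # Stand-in for apply_top_k_top_p: set logits outside the per-row
    # top-k and/or top-p (nucleus) set to -inf.
    sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
    if k is not None:
        ranks = torch.arange(logits.size(-1), device=logits.device)
        sorted_logits.masked_fill_(ranks >= k.unsqueeze(-1), float("-inf"))
    if p is not None:
        probs = sorted_logits.softmax(dim=-1)
        cum = probs.cumsum(dim=-1)
        # Drop a token only once the cumulative mass before it already
        # exceeds p, so the threshold-crossing token itself is kept.
        sorted_logits.masked_fill_((cum - probs) > p.unsqueeze(-1), float("-inf"))
    return torch.full_like(logits, float("-inf")).scatter_(-1, sorted_idx, sorted_logits)

def _sample(probs: torch.Tensor,
            generators: dict[int, torch.Generator]) -> torch.Tensor:
    # Exponential-race sampling: argmax(probs / Exp(1)) picks token i with
    # probability probs[i]; seeded rows redraw their noise from their generator.
    q = torch.empty_like(probs).exponential_(1.0)
    for i, gen in generators.items():
        q[i].exponential_(1.0, generator=gen)
    return probs.div(q).argmax(dim=-1)

def native_top_k_top_p(logits, generators, k, p, logprobs_mode="raw_logprobs"):
    # Mirrors the fallback flow above: mask, optionally keep processed
    # logits/logprobs for return, then sample one token per row.
    logits = _mask_top_k_top_p(logits, k, p)
    logits_to_return = None
    if logprobs_mode == "processed_logits":
        logits_to_return = logits
    elif logprobs_mode == "processed_logprobs":
        logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    return _sample(probs, generators), logits_to_return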
@@ -196,6 +199,7 @@ def flashinfer_sample(
             probs, top_k=k, deterministic=True)
         else:
             # Both top-k and top-p.
+            k = k.to(torch.int32)
             next_token_ids = xtorch_ops.top_k_top_p_sampling_from_probs(
                 probs, top_k=k, top_p=p, deterministic=True)
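
The added int32 cast is easy to miss; a minimal sketch of the call site with made-up values follows (the xtorch_ops import path and device handling are assumptions).

import torch
import xtorch_ops  # kunlun kernel module named in the diff; import path assumed

probs = torch.randn(4, 32000).softmax(dim=-1).contiguous()  # [num_seqs, vocab_size]
k = torch.tensor([40, 50, 100, 20])         # torch.tensor builds int64 by default
p = torch.tensor([0.90, 0.95, 0.80, 1.00])
# The fused kernel is assumed to require int32 top-k values, hence the cast.
next_token_ids = xtorch_ops.top_k_top_p_sampling_from_probs(
    probs, top_k=k.to(torch.int32), top_p=p, deterministic=True)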