[Feature] Support DeepSeek V3/R1/V3.2 (#78)
* [Feature] Support DeepSeek V3/R1/V3.2
* Fix gpt_oss
* Update README
* Update README

---------

Co-authored-by: hanhaowen <hanhaowen@baidu.com>
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from typing import Optional
+import os
-
 import torch
 import torch.nn as nn
 from packaging import version
@@ -24,6 +24,7 @@ class TopKTopPSampler(nn.Module):
 
     def __init__(self, logprobs_mode):
         super().__init__()
+        self.logprobs_mode = logprobs_mode
         logger.info_once(
             "Using FlashInfer for top-p & top-k sampling.")
         self.forward = self.forward_kunlun
@@ -40,9 +41,14 @@ class TopKTopPSampler(nn.Module):
 
         The logits tensor may be updated in-place.
         """
-        logits = apply_top_k_top_p(logits, k, p)
+        logits = self.apply_top_k_top_p(logits, k, p)
+        logits_to_return = None
+        if self.logprobs_mode == "processed_logits":
+            logits_to_return = logits
+        elif self.logprobs_mode == "processed_logprobs":
+            logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
         probs = logits.softmax(dim=-1, dtype=torch.float32)
-        return random_sample(probs, generators), None
+        return random_sample(probs, generators), logits_to_return
 
     def forward_kunlun(
         self,
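The logprobs_mode plumbing above changes forward_native to return the processed logits or log-probs next to the sampled tokens instead of a hard-coded None. A minimal standalone sketch of that branch, assuming a dummy [batch, vocab] tensor (the "other" mode string is illustrative, not from the diff):

import torch

logits = torch.randn(2, 8)  # stand-in logits after top-k/top-p masking
for logprobs_mode in ("processed_logits", "processed_logprobs", "other"):
    logits_to_return = None
    if logprobs_mode == "processed_logits":
        # Return the masked logits as-is.
        logits_to_return = logits
    elif logprobs_mode == "processed_logprobs":
        # Return log-probabilities over the masked distribution.
        logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    # The real code pairs this with random_sample(probs, generators).
    print(logprobs_mode, None if logits_to_return is None else logits_to_return.shape)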
@@ -52,16 +58,13 @@ class TopKTopPSampler(nn.Module):
         p: Optional[torch.Tensor],
     ) -> torch.Tensor:
         """More optimized implementation for top-k and top-p sampling."""
-        if k is None and p is None:
-            # We prefer `random_sample` over `flashinfer_sample` when sorting is
-            # not needed. This is because `random_sample` does not require
-            # CPU-GPU synchronization while `flashinfer_sample` does.
-            probs = logits.softmax(dim=-1, dtype=torch.float32)
-            return random_sample(probs, generators), None
-        if generators:
-            logger.warning_once("FlashInfer 0.2.3+ does not support "
-                                "per-request generators. Falling back to "
-                                "PyTorch-native implementation.")
+        if (k is None and p is None) or generators:
+            if generators:
+                logger.debug_once(
+                    "FlashInfer 0.2.3+ does not support "
+                    "per-request generators. Falling back to "
+                    "PyTorch-native implementation."
+                )
             return self.forward_native(logits, generators, k, p)
         # flashinfer sampling functions expect contiguous logits.
         # In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
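The rewritten guard folds the two early returns into one: requests with no top-k/top-p need no sorting, so random_sample avoids the CPU-GPU synchronization the fused kernels incur, and requests carrying per-request generators cannot use those kernels at all. A sketch of just that dispatch decision; pick_sampling_path is a hypothetical name, not part of the diff:

import torch

def pick_sampling_path(k, p, generators) -> str:
    # No filtering requested, or per-request torch.Generator objects present:
    # take the PyTorch-native fallback.
    if (k is None and p is None) or generators:
        return "forward_native"
    # Otherwise dispatch to the fused device kernels.
    return "fused_kernels"

print(pick_sampling_path(None, None, {}))                                      # forward_native
print(pick_sampling_path(torch.tensor([50]), None, {0: torch.Generator()}))    # forward_native
print(pick_sampling_path(torch.tensor([50]), torch.tensor([0.9]), {}))         # fused_kernels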
@@ -196,6 +199,7 @@ def flashinfer_sample(
             probs, top_k=k, deterministic=True)
     else:
         # Both top-k and top-p.
+        k = k.to(torch.int32)
         next_token_ids = xtorch_ops.top_k_top_p_sampling_from_probs(
             probs, top_k=k, top_p=p, deterministic=True)
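The added k.to(torch.int32) cast presumably matches the integer dtype the fused kernel expects for top_k. For orientation, here is a PyTorch-native sketch of the top-k + top-p masking that xtorch_ops.top_k_top_p_sampling_from_probs performs in one deterministic kernel; top_k_top_p_mask is a hypothetical helper with scalar k and p, while the real kernel takes per-request tensors and samples from probs directly:

import torch

def top_k_top_p_mask(logits: torch.Tensor, k: int, p: float) -> torch.Tensor:
    # Top-k: drop everything below the k-th largest logit.
    kth = logits.topk(k, dim=-1).values[..., -1:]
    logits = logits.masked_fill(logits < kth, float("-inf"))
    # Top-p: drop the sorted tail once cumulative probability exceeds p,
    # always keeping at least the most likely token.
    sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
    sorted_probs = sorted_logits.softmax(dim=-1)
    drop = (sorted_probs.cumsum(dim=-1) - sorted_probs) > p
    sorted_logits = sorted_logits.masked_fill(drop, float("-inf"))
    # Undo the sort to restore vocabulary order.
    return sorted_logits.gather(-1, sorted_idx.argsort(dim=-1))

probs = top_k_top_p_mask(torch.randn(2, 16), k=5, p=0.9).softmax(dim=-1)
next_token_ids = torch.multinomial(probs, num_samples=1)
print(next_token_ids.shape)  # torch.Size([2, 1])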