提交vllm0.11.0开发分支
This commit is contained in:
@@ -22,7 +22,7 @@ class TopKTopPSampler(nn.Module):
|
||||
Implementations may update the logits tensor in-place.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, logprobs_mode):
|
||||
super().__init__()
|
||||
logger.info_once(
|
||||
"Using FlashInfer for top-p & top-k sampling.")
|
||||
@@ -57,7 +57,7 @@ class TopKTopPSampler(nn.Module):
|
||||
# not needed. This is because `random_sample` does not require
|
||||
# CPU-GPU synchronization while `flashinfer_sample` does.
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
return random_sample(probs, generators)
|
||||
return random_sample(probs, generators), None
|
||||
if generators:
|
||||
logger.warning_once("FlashInfer 0.2.3+ does not support "
|
||||
"per-request generators. Falling back to "
|
||||
@@ -66,8 +66,7 @@ class TopKTopPSampler(nn.Module):
|
||||
# flashinfer sampling functions expect contiguous logits.
|
||||
# In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
|
||||
# because of slicing operation in logits_processor.
|
||||
return flashinfer_sample(logits.contiguous(), k, p, generators)
|
||||
|
||||
return flashinfer_sample(logits.contiguous(), k, p, generators), None
|
||||
|
||||
|
||||
def apply_top_k_top_p(
|
||||
@@ -195,4 +194,4 @@ def flashinfer_sample(
|
||||
next_token_ids = xtorch_ops.top_k_top_p_sampling_from_probs(
|
||||
probs, top_k=k, top_p=p, deterministic=True)
|
||||
|
||||
return next_token_ids.view(-1)
|
||||
return next_token_ids.view(-1)
|
||||
Reference in New Issue
Block a user