Slightly improve the sampler to skip unnecessary steps (#6956)
This commit is contained in:
@@ -5,7 +5,7 @@ import torch
|
||||
import torch.distributed as dist
|
||||
from torch import nn
|
||||
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_group
|
||||
from sglang.srt.distributed import get_tp_group
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_group
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
@@ -30,7 +30,7 @@ class Sampler(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.use_nan_detection = global_server_args_dict["enable_nan_detection"]
|
||||
self.tp_sync_group = get_tensor_model_parallel_group().device_group
|
||||
self.tp_sync_group = get_tp_group().device_group
|
||||
|
||||
if global_server_args_dict["enable_dp_attention"]:
|
||||
self.tp_sync_group = get_attention_tp_group().device_group
|
||||
@@ -59,7 +59,7 @@ class Sampler(nn.Module):
|
||||
|
||||
# Apply the custom logit processors if registered in the sampling info.
|
||||
if sampling_info.has_custom_logit_processor:
|
||||
self._apply_custom_logit_processor(logits, sampling_info)
|
||||
apply_custom_logit_processor(logits, sampling_info)
|
||||
|
||||
if self.use_nan_detection and torch.any(torch.isnan(logits)):
|
||||
logger.warning("Detected errors during sampling! NaN in the logits.")
|
||||
@@ -81,49 +81,39 @@ class Sampler(nn.Module):
|
||||
probs = logits
|
||||
del logits
|
||||
|
||||
if global_server_args_dict["sampling_backend"] == "flashinfer":
|
||||
if return_logprob:
|
||||
# NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
|
||||
# https://github.com/flashinfer-ai/flashinfer/issues/708
|
||||
# so we use the torch implementation.
|
||||
# NOTE: OpenAI's logprobs is independent of top-p, we use the
|
||||
# same rule.
|
||||
logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
|
||||
|
||||
max_top_k_round, batch_size = 32, probs.shape[0]
|
||||
if sampling_info.need_min_p_sampling:
|
||||
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
|
||||
probs = top_p_renorm_prob(probs, sampling_info.top_ps)
|
||||
batch_next_token_ids = min_p_sampling_from_probs(
|
||||
probs, sampling_info.min_ps
|
||||
)
|
||||
else:
|
||||
# Check Nan will throw exception, only check when crash_on_warnings is True
|
||||
check_nan = self.use_nan_detection and crash_on_warnings()
|
||||
batch_next_token_ids = top_k_top_p_sampling_from_probs(
|
||||
probs.contiguous(),
|
||||
if True: # Keep this redundant check to simplify some internal code sync
|
||||
if global_server_args_dict["sampling_backend"] == "flashinfer":
|
||||
if sampling_info.need_min_p_sampling:
|
||||
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
|
||||
probs = top_p_renorm_prob(probs, sampling_info.top_ps)
|
||||
batch_next_token_ids = min_p_sampling_from_probs(
|
||||
probs, sampling_info.min_ps
|
||||
)
|
||||
else:
|
||||
batch_next_token_ids = top_k_top_p_sampling_from_probs(
|
||||
probs,
|
||||
sampling_info.top_ks,
|
||||
sampling_info.top_ps,
|
||||
filter_apply_order="joint",
|
||||
check_nan=self.use_nan_detection,
|
||||
)
|
||||
elif global_server_args_dict["sampling_backend"] == "pytorch":
|
||||
# A slower fallback implementation with torch native operations.
|
||||
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
|
||||
probs,
|
||||
sampling_info.top_ks,
|
||||
sampling_info.top_ps,
|
||||
filter_apply_order="joint",
|
||||
check_nan=check_nan,
|
||||
sampling_info.min_ps,
|
||||
sampling_info.need_min_p_sampling,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
|
||||
)
|
||||
|
||||
elif global_server_args_dict["sampling_backend"] == "pytorch":
|
||||
# A slower fallback implementation with torch native operations.
|
||||
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
|
||||
probs,
|
||||
sampling_info.top_ks,
|
||||
sampling_info.top_ps,
|
||||
sampling_info.min_ps,
|
||||
sampling_info.need_min_p_sampling,
|
||||
)
|
||||
|
||||
if return_logprob:
|
||||
logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
|
||||
)
|
||||
if return_logprob:
|
||||
# clamp to avoid -inf
|
||||
logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
|
||||
|
||||
# Attach logprobs to logits_output (in-place modification)
|
||||
if return_logprob:
|
||||
@@ -160,39 +150,6 @@ class Sampler(nn.Module):
|
||||
|
||||
return batch_next_token_ids
|
||||
|
||||
def _apply_custom_logit_processor(
|
||||
self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
|
||||
):
|
||||
"""Apply custom logit processors to the logits.
|
||||
This function will modify the logits in-place."""
|
||||
|
||||
assert logits.shape[0] == len(sampling_batch_info), (
|
||||
f"The batch size of logits ({logits.shape[0]}) does not match the batch size of "
|
||||
f"sampling_batch_info ({len(sampling_batch_info)})"
|
||||
)
|
||||
|
||||
for _, (
|
||||
processor,
|
||||
batch_mask,
|
||||
) in sampling_batch_info.custom_logit_processor.items():
|
||||
# Get the batch indices that need to be processed
|
||||
batch_indices = batch_mask.nonzero(as_tuple=True)[0]
|
||||
|
||||
assert batch_mask.shape[0] == len(sampling_batch_info), (
|
||||
f"The number of batch mask ({batch_mask.shape[0]}) does not match the number of "
|
||||
f"sampling_batch_info ({len(sampling_batch_info)})"
|
||||
)
|
||||
|
||||
# Apply the processor to the logits
|
||||
logits[batch_mask] = processor(
|
||||
logits[batch_mask],
|
||||
[sampling_batch_info.custom_params[i] for i in batch_indices],
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Custom logit processor {processor.__class__.__name__} is applied."
|
||||
)
|
||||
|
||||
|
||||
def top_k_top_p_min_p_sampling_from_probs_torch(
|
||||
probs: torch.Tensor,
|
||||
@@ -221,6 +178,14 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
|
||||
return batch_next_token_ids
|
||||
|
||||
|
||||
def sampling_from_probs_torch(probs: torch.Tensor):
|
||||
"""A sampling implementation with native pytorch operations, without
|
||||
top-k, top-p, or min-p filtering."""
|
||||
sampled_index = torch.multinomial(probs, num_samples=1)
|
||||
batch_next_token_ids = sampled_index.view(-1).to(torch.int32)
|
||||
return batch_next_token_ids
|
||||
|
||||
|
||||
def top_p_normalize_probs_torch(
|
||||
probs: torch.Tensor,
|
||||
top_ps: torch.Tensor,
|
||||
@@ -259,3 +224,44 @@ def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List
|
||||
output_token_ids_logprobs_idx.append([])
|
||||
|
||||
return output_token_ids_logprobs_val, output_token_ids_logprobs_idx
|
||||
|
||||
|
||||
def apply_custom_logit_processor(
|
||||
logits: torch.Tensor,
|
||||
sampling_batch_info: SamplingBatchInfo,
|
||||
num_tokens_in_batch: int = 1,
|
||||
):
|
||||
"""Apply custom logit processors to the logits.
|
||||
This function will modify the logits in-place.
|
||||
num_tokens_in_batch is needed to support spec decoding, where each batch can contain multiple
|
||||
tokens. By default, we assume each batch contains only 1 token.
|
||||
"""
|
||||
|
||||
assert logits.shape[0] == len(sampling_batch_info) * num_tokens_in_batch, (
|
||||
f"The batch size of logits ({logits.shape[0]}) does not match the batch size of "
|
||||
f"sampling_batch_info ({len(sampling_batch_info)}) x num_tokens_in_batch "
|
||||
f"({num_tokens_in_batch})"
|
||||
)
|
||||
|
||||
for _, (
|
||||
processor,
|
||||
batch_mask,
|
||||
) in sampling_batch_info.custom_logit_processor.items():
|
||||
# Get the batch indices that need to be processed
|
||||
batch_indices = batch_mask.nonzero(as_tuple=True)[0]
|
||||
|
||||
assert batch_mask.shape[0] == len(sampling_batch_info), (
|
||||
f"The number of batch mask ({batch_mask.shape[0]}) does not match the number of "
|
||||
f"sampling_batch_info ({len(sampling_batch_info)})"
|
||||
)
|
||||
batch_mask = torch.repeat_interleave(batch_mask, num_tokens_in_batch)
|
||||
|
||||
# Apply the processor to the logits
|
||||
logits[batch_mask] = processor(
|
||||
logits[batch_mask],
|
||||
[sampling_batch_info.custom_params[i] for i in batch_indices],
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Custom logit processor {processor.__class__.__name__} is applied."
|
||||
)
|
||||
|
||||
@@ -852,7 +852,7 @@ class TokenizerManager:
|
||||
obj.load_format = self.server_args.load_format
|
||||
logger.info("Start update_weights. Load format=%s", obj.load_format)
|
||||
|
||||
if True:
|
||||
if True: # Keep this redundant check to simplify some internal code sync
|
||||
# Hold the lock if it is not async. This means that weight sync
|
||||
# cannot run while requests are in progress.
|
||||
async with self.model_update_lock.writer_lock:
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
@@ -9,10 +9,12 @@ import torch
|
||||
|
||||
import sglang.srt.sampling.penaltylib as penaltylib
|
||||
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
|
||||
from sglang.srt.sampling.sampling_params import TOP_K_ALL
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -27,6 +29,12 @@ class SamplingBatchInfo:
|
||||
# Whether all requests use greedy sampling
|
||||
is_all_greedy: bool
|
||||
|
||||
# Whether any requests use top_p sampling
|
||||
need_top_p_sampling: bool
|
||||
|
||||
# Whether any requests use top_k sampling
|
||||
need_top_k_sampling: bool
|
||||
|
||||
# Whether any request needs min_p sampling
|
||||
need_min_p_sampling: bool
|
||||
|
||||
@@ -133,6 +141,8 @@ class SamplingBatchInfo:
|
||||
top_ks=top_ks,
|
||||
min_ps=min_ps,
|
||||
is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
|
||||
need_top_p_sampling=any(r.sampling_params.top_p != 1.0 for r in reqs),
|
||||
need_top_k_sampling=any(r.sampling_params.top_k != TOP_K_ALL for r in reqs),
|
||||
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
|
||||
vocab_size=vocab_size,
|
||||
penalizer_orchestrator=penalizer_orchestrator,
|
||||
@@ -167,7 +177,7 @@ class SamplingBatchInfo:
|
||||
|
||||
# Apply the mask
|
||||
for i, grammar in enumerate(self.grammars):
|
||||
if grammar and not grammar.finished:
|
||||
if grammar and not grammar.finished and not grammar.is_terminated():
|
||||
grammar.fill_vocab_mask(self.vocab_mask, i)
|
||||
|
||||
# Move the mask to the device if needed
|
||||
@@ -308,4 +318,6 @@ class SamplingBatchInfo:
|
||||
setattr(self, item, torch.cat([self_val, other_val]))
|
||||
|
||||
self.is_all_greedy &= other.is_all_greedy
|
||||
self.need_top_p_sampling |= other.need_top_p_sampling
|
||||
self.need_top_k_sampling |= other.need_top_k_sampling
|
||||
self.need_min_p_sampling |= other.need_min_p_sampling
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
_SAMPLING_EPS = 1e-6
|
||||
TOP_K_ALL = 1 << 30
|
||||
|
||||
|
||||
class SamplingParams:
|
||||
@@ -84,7 +85,7 @@ class SamplingParams:
|
||||
self.temperature = 1.0
|
||||
self.top_k = 1
|
||||
if self.top_k == -1:
|
||||
self.top_k = 1 << 30 # whole vocabulary
|
||||
self.top_k = TOP_K_ALL # whole vocabulary
|
||||
|
||||
def verify(self):
|
||||
if self.temperature < 0.0:
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
|
||||
from typing import Dict, List, Optional, Sequence
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
from sglang.srt.layers.communicator import (
|
||||
CommunicateContext,
|
||||
CommunicateSimpleFn,
|
||||
CommunicateSummableTensorPairFn,
|
||||
ScatterMode,
|
||||
)
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
|
||||
from sglang.srt.layers.quantization.deep_gemm import configure_deep_gemm_num_sms
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
@@ -20,9 +18,6 @@ from sglang.srt.operations import execute_operations, execute_overlapped_operati
|
||||
from sglang.srt.operations_strategy import OperationsStrategy
|
||||
from sglang.srt.utils import BumpAllocator, DeepEPMode, get_bool_env_var
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||
|
||||
_tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -46,7 +41,7 @@ def compute_split_seq_index(
|
||||
assert num_tokens == 0
|
||||
return 0
|
||||
else:
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def _split_array_by_half_sum(arr: Sequence[int]) -> int:
|
||||
|
||||
@@ -1928,16 +1928,18 @@ def next_power_of_2(n: int):
|
||||
setattr(triton, "next_power_of_2", next_power_of_2)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def empty_context(*args, **kwargs):
|
||||
try:
|
||||
# Setup code goes here
|
||||
yield
|
||||
finally:
|
||||
# Cleanup code goes here
|
||||
class EmptyContextManager:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
pass
|
||||
|
||||
|
||||
def empty_context(*args, **kwargs):
|
||||
return EmptyContextManager()
|
||||
|
||||
|
||||
def add_prefix(name: str, prefix: str) -> str:
|
||||
"""Add a weight path prefix to a module name.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user