diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index f36dc5ca2..5bc9cc1d8 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -5,7 +5,7 @@ import torch
 import torch.distributed as dist
 from torch import nn
 
-from sglang.srt.distributed import get_tensor_model_parallel_group
+from sglang.srt.distributed import get_tp_group
 from sglang.srt.layers.dp_attention import get_attention_tp_group
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -30,7 +30,7 @@ class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
         self.use_nan_detection = global_server_args_dict["enable_nan_detection"]
-        self.tp_sync_group = get_tensor_model_parallel_group().device_group
+        self.tp_sync_group = get_tp_group().device_group
 
         if global_server_args_dict["enable_dp_attention"]:
             self.tp_sync_group = get_attention_tp_group().device_group
@@ -59,7 +59,7 @@ class Sampler(nn.Module):
 
         # Apply the custom logit processors if registered in the sampling info.
         if sampling_info.has_custom_logit_processor:
-            self._apply_custom_logit_processor(logits, sampling_info)
+            apply_custom_logit_processor(logits, sampling_info)
 
         if self.use_nan_detection and torch.any(torch.isnan(logits)):
             logger.warning("Detected errors during sampling! NaN in the logits.")
@@ -81,49 +81,39 @@ class Sampler(nn.Module):
             probs = logits
             del logits
 
-        if global_server_args_dict["sampling_backend"] == "flashinfer":
-            if return_logprob:
-                # NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
-                # https://github.com/flashinfer-ai/flashinfer/issues/708
-                # so we use the torch implementation.
-
-                # NOTE: OpenAI's logprobs is independent of top-p, we use the
-                # same rule.
-                logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
-
-            max_top_k_round, batch_size = 32, probs.shape[0]
-            if sampling_info.need_min_p_sampling:
-                probs = top_k_renorm_prob(probs, sampling_info.top_ks)
-                probs = top_p_renorm_prob(probs, sampling_info.top_ps)
-                batch_next_token_ids = min_p_sampling_from_probs(
-                    probs, sampling_info.min_ps
-                )
-            else:
-                # Check Nan will throw exception, only check when crash_on_warnings is True
-                check_nan = self.use_nan_detection and crash_on_warnings()
-                batch_next_token_ids = top_k_top_p_sampling_from_probs(
-                    probs.contiguous(),
+        if True:  # Keep this redundant check to simplify some internal code sync
+            if global_server_args_dict["sampling_backend"] == "flashinfer":
+                if sampling_info.need_min_p_sampling:
+                    probs = top_k_renorm_prob(probs, sampling_info.top_ks)
+                    probs = top_p_renorm_prob(probs, sampling_info.top_ps)
+                    batch_next_token_ids = min_p_sampling_from_probs(
+                        probs, sampling_info.min_ps
+                    )
+                else:
+                    batch_next_token_ids = top_k_top_p_sampling_from_probs(
+                        probs,
+                        sampling_info.top_ks,
+                        sampling_info.top_ps,
+                        filter_apply_order="joint",
+                        check_nan=self.use_nan_detection,
+                    )
+            elif global_server_args_dict["sampling_backend"] == "pytorch":
+                # A slower fallback implementation with torch native operations.
+                batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
+                    probs,
                     sampling_info.top_ks,
                     sampling_info.top_ps,
-                    filter_apply_order="joint",
-                    check_nan=check_nan,
+                    sampling_info.min_ps,
+                    sampling_info.need_min_p_sampling,
+                )
+            else:
+                raise ValueError(
+                    f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
                 )
-        elif global_server_args_dict["sampling_backend"] == "pytorch":
-            # A slower fallback implementation with torch native operations.
-            batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
-                probs,
-                sampling_info.top_ks,
-                sampling_info.top_ps,
-                sampling_info.min_ps,
-                sampling_info.need_min_p_sampling,
-            )
-
-            if return_logprob:
-                logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
-        else:
-            raise ValueError(
-                f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
-            )
+
+            if return_logprob:
+                # clamp to avoid -inf
+                logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
 
         # Attach logprobs to logits_output (in-place modification)
         if return_logprob:
@@ -160,39 +150,6 @@ class Sampler(nn.Module):
 
         return batch_next_token_ids
 
-    def _apply_custom_logit_processor(
-        self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
-    ):
-        """Apply custom logit processors to the logits.
-        This function will modify the logits in-place."""
-
-        assert logits.shape[0] == len(sampling_batch_info), (
-            f"The batch size of logits ({logits.shape[0]}) does not match the batch size of "
-            f"sampling_batch_info ({len(sampling_batch_info)})"
-        )
-
-        for _, (
-            processor,
-            batch_mask,
-        ) in sampling_batch_info.custom_logit_processor.items():
-            # Get the batch indices that need to be processed
-            batch_indices = batch_mask.nonzero(as_tuple=True)[0]
-
-            assert batch_mask.shape[0] == len(sampling_batch_info), (
-                f"The number of batch mask ({batch_mask.shape[0]}) does not match the number of "
-                f"sampling_batch_info ({len(sampling_batch_info)})"
-            )
-
-            # Apply the processor to the logits
-            logits[batch_mask] = processor(
-                logits[batch_mask],
-                [sampling_batch_info.custom_params[i] for i in batch_indices],
-            )
-
-            logger.debug(
-                f"Custom logit processor {processor.__class__.__name__} is applied."
-            )
-
 
 def top_k_top_p_min_p_sampling_from_probs_torch(
     probs: torch.Tensor,
@@ -221,6 +178,14 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     return batch_next_token_ids
 
 
+def sampling_from_probs_torch(probs: torch.Tensor):
+    """A sampling implementation with native pytorch operations, without
+    top-k, top-p, or min-p filtering."""
+    sampled_index = torch.multinomial(probs, num_samples=1)
+    batch_next_token_ids = sampled_index.view(-1).to(torch.int32)
+    return batch_next_token_ids
+
+
 def top_p_normalize_probs_torch(
     probs: torch.Tensor,
     top_ps: torch.Tensor,
@@ -259,3 +224,44 @@ def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List
             output_token_ids_logprobs_idx.append([])
 
     return output_token_ids_logprobs_val, output_token_ids_logprobs_idx
+
+
+def apply_custom_logit_processor(
+    logits: torch.Tensor,
+    sampling_batch_info: SamplingBatchInfo,
+    num_tokens_in_batch: int = 1,
+):
+    """Apply custom logit processors to the logits.
+    This function will modify the logits in-place.
+    num_tokens_in_batch is needed to support spec decoding, where each batch can contain multiple
+    tokens. By default, we assume each batch contains only 1 token.
+    """
+
+    assert logits.shape[0] == len(sampling_batch_info) * num_tokens_in_batch, (
+        f"The batch size of logits ({logits.shape[0]}) does not match the batch size of "
+        f"sampling_batch_info ({len(sampling_batch_info)}) x num_tokens_in_batch "
+        f"({num_tokens_in_batch})"
+    )
+
+    for _, (
+        processor,
+        batch_mask,
+    ) in sampling_batch_info.custom_logit_processor.items():
+        # Get the batch indices that need to be processed
+        batch_indices = batch_mask.nonzero(as_tuple=True)[0]
+
+        assert batch_mask.shape[0] == len(sampling_batch_info), (
+            f"The length of batch_mask ({batch_mask.shape[0]}) does not match the batch size of "
+            f"sampling_batch_info ({len(sampling_batch_info)})"
+        )
+        batch_mask = torch.repeat_interleave(batch_mask, num_tokens_in_batch)
+
+        # Apply the processor to the logits
+        logits[batch_mask] = processor(
+            logits[batch_mask],
+            [sampling_batch_info.custom_params[i] for i in batch_indices],
+        )
+
+        logger.debug(
+            f"Custom logit processor {processor.__class__.__name__} is applied."
+        )
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index f4b451dfb..24821007b 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -852,7 +852,7 @@ class TokenizerManager:
             obj.load_format = self.server_args.load_format
             logger.info("Start update_weights. Load format=%s", obj.load_format)
 
-        if True:
+        if True:  # Keep this redundant check to simplify some internal code sync
             # Hold the lock if it is not async. This means that weight sync
             # cannot run while requests are in progress.
             async with self.model_update_lock.writer_lock:
diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py
index dc4d8f9df..f08a50611 100644
--- a/python/sglang/srt/models/llama.py
+++ b/python/sglang/srt/models/llama.py
@@ -17,7 +17,7 @@
 """Inference-only LLaMA model compatible with HuggingFace weights."""
 
 import logging
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
index 7f169ef04..238f3a7cd 100644
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -9,10 +9,12 @@ import torch
 
 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
+from sglang.srt.sampling.sampling_params import TOP_K_ALL
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
 
+
 logger = logging.getLogger(__name__)
 
 
@@ -27,6 +29,12 @@ class SamplingBatchInfo:
     # Whether all requests use greedy sampling
     is_all_greedy: bool
 
+    # Whether any request uses top_p sampling
+    need_top_p_sampling: bool
+
+    # Whether any request uses top_k sampling
+    need_top_k_sampling: bool
+
     # Whether any request needs min_p sampling
     need_min_p_sampling: bool
 
@@ -133,6 +141,8 @@
             top_ks=top_ks,
             min_ps=min_ps,
             is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
+            need_top_p_sampling=any(r.sampling_params.top_p != 1.0 for r in reqs),
+            need_top_k_sampling=any(r.sampling_params.top_k != TOP_K_ALL for r in reqs),
             need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
             vocab_size=vocab_size,
             penalizer_orchestrator=penalizer_orchestrator,
@@ -167,7 +177,7 @@ class SamplingBatchInfo:
 
         # Apply the mask
         for i, grammar in enumerate(self.grammars):
-            if grammar and not grammar.finished:
+            if grammar and not grammar.finished and not grammar.is_terminated():
                 grammar.fill_vocab_mask(self.vocab_mask, i)
 
         # Move the mask to the device if needed
@@ -308,4 +318,6 @@ class SamplingBatchInfo:
             setattr(self, item, torch.cat([self_val, other_val]))
 
         self.is_all_greedy &= other.is_all_greedy
+        self.need_top_p_sampling |= other.need_top_p_sampling
+        self.need_top_k_sampling |= other.need_top_k_sampling
         self.need_min_p_sampling |= other.need_min_p_sampling
diff --git a/python/sglang/srt/sampling/sampling_params.py b/python/sglang/srt/sampling/sampling_params.py
index 4c505fe7a..87436f86d 100644
--- a/python/sglang/srt/sampling/sampling_params.py
+++ b/python/sglang/srt/sampling/sampling_params.py
@@ -16,6 +16,7 @@
 from typing import Any, Dict, List, Optional, Union
 
 _SAMPLING_EPS = 1e-6
+TOP_K_ALL = 1 << 30
 
 
 class SamplingParams:
@@ -84,7 +85,7 @@ class SamplingParams:
             self.temperature = 1.0
             self.top_k = 1
         if self.top_k == -1:
-            self.top_k = 1 << 30  # whole vocabulary
+            self.top_k = TOP_K_ALL  # whole vocabulary
 
     def verify(self):
         if self.temperature < 0.0:
diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py
index b417de7ce..814a7f95d 100644
--- a/python/sglang/srt/two_batch_overlap.py
+++ b/python/sglang/srt/two_batch_overlap.py
@@ -1,17 +1,15 @@
 import dataclasses
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
+from typing import Dict, List, Optional, Sequence
 
 import torch
 
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.communicator import (
     CommunicateContext,
-    CommunicateSimpleFn,
     CommunicateSummableTensorPairFn,
     ScatterMode,
 )
-from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.quantization.deep_gemm import configure_deep_gemm_num_sms
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -20,9 +18,6 @@ from sglang.srt.operations import execute_operations, execute_overlapped_operati
 from sglang.srt.operations_strategy import OperationsStrategy
 from sglang.srt.utils import BumpAllocator, DeepEPMode, get_bool_env_var
 
-if TYPE_CHECKING:
-    from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
-
 _tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
 
 logger = logging.getLogger(__name__)
@@ -46,7 +41,7 @@ def compute_split_seq_index(
         assert num_tokens == 0
         return 0
     else:
-        raise NotImplementedError
+        raise NotImplementedError()
 
 
 def _split_array_by_half_sum(arr: Sequence[int]) -> int:
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 32c9f865a..a669b4302 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -1928,16 +1928,18 @@ def next_power_of_2(n: int):
 setattr(triton, "next_power_of_2", next_power_of_2)
 
 
-@contextmanager
-def empty_context(*args, **kwargs):
-    try:
-        # Setup code goes here
-        yield
-    finally:
-        # Cleanup code goes here
+class EmptyContextManager:
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
         pass
 
 
+def empty_context(*args, **kwargs):
+    return EmptyContextManager()
+
+
 def add_prefix(name: str, prefix: str) -> str:
     """Add a weight path prefix to a module name.
 
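
Reviewer note on the new apply_custom_logit_processor in sampler.py above: the num_tokens_in_batch argument exists because, with speculative decoding, each request contributes several rows of logits, so the per-request batch_mask has to be expanded row-wise before indexing. A minimal sketch of that expansion in plain PyTorch (shapes and values are illustrative only, not part of the diff):

import torch

# Two requests; only the first has a custom logit processor attached.
batch_mask = torch.tensor([True, False])

# With speculative decoding each request contributes 3 rows of logits,
# so logits has shape [2 * 3, vocab_size].
num_tokens_in_batch = 3
logits = torch.randn(2 * num_tokens_in_batch, 8)

# repeat_interleave turns the per-request mask into a per-row mask:
# [True, False] -> [True, True, True, False, False, False]
row_mask = torch.repeat_interleave(batch_mask, num_tokens_in_batch)

# Only the rows belonging to request 0 are handed to the processor.
assert logits[row_mask].shape == (num_tokens_in_batch, 8)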