diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md index ad8dc9405..0bc20b416 100644 --- a/docs/advanced_features/server_arguments.md +++ b/docs/advanced_features/server_arguments.md @@ -321,7 +321,6 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `--debug-tensor-dump-output-folder` | The output folder for debug tensor dumps. | None | | `--debug-tensor-dump-input-file` | The input file for debug tensor dumps. | None | | `--debug-tensor-dump-inject` | Enable injection of debug tensor dumps. | False | -| `--debug-tensor-dump-prefill-only` | Enable prefill-only mode for debug tensor dumps. | False | ## PD disaggregation diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index daf18e68c..78e3f2b9a 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -240,6 +240,7 @@ class GroupCoordinator: use_message_queue_broadcaster: bool = False, group_name: Optional[str] = None, torch_compile: Optional[bool] = None, + gloo_timeout: timedelta = timedelta(seconds=120 * 60), ): # Set group info group_name = group_name or "anonymous" @@ -259,7 +260,9 @@ class GroupCoordinator: ) # a group with `gloo` backend, to allow direct coordination between # processes through the CPU. - cpu_group = torch.distributed.new_group(ranks, backend="gloo") + cpu_group = torch.distributed.new_group( + ranks, backend="gloo", timeout=gloo_timeout + ) if self.rank in ranks: self.ranks = ranks self.world_size = len(ranks) diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index f7b1bc369..bf50d4b11 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -91,7 +91,6 @@ class Sampler(nn.Module): batch_next_token_ids = torch.argmax(logits, -1) if return_logprob: logprobs = torch.nn.functional.log_softmax(logits, dim=-1) - else: # If requested, cache probabilities from original logits before temperature scaling. if return_logprob and RETURN_ORIGINAL_LOGPROB: @@ -288,21 +287,29 @@ def multinomial_with_seed( """ n, m = inputs.shape col_indices = torch.arange(m, device=inputs.device).unsqueeze(0) - step_seed = seed * 19349663 ^ positions * 73856093 + step_seed = (seed * 19349663) ^ (positions * 73856093) seed_expanded = step_seed.unsqueeze(-1) - hashed = seed_expanded * 8589934591 ^ col_indices * 479001599 + hashed = (seed_expanded * 8589934591) ^ (col_indices * 479001599) uniform_samples = (hashed % (2**24)).float() / (2**24) - epsilon = 1e-9 - gumbel_noise = -torch.log(-torch.log(uniform_samples + epsilon) + epsilon) + epsilon = 1e-10 + uniform_samples = uniform_samples.clamp(epsilon, 1.0 - epsilon) + gumbel_noise = -torch.log(-torch.log(uniform_samples)) log_probs = torch.log(inputs + epsilon) perturbed_log_probs = log_probs + gumbel_noise return torch.argmax(perturbed_log_probs, dim=1, keepdim=True) -def sampling_from_probs_torch(probs: torch.Tensor): +def sampling_from_probs_torch( + probs: torch.Tensor, + sampling_seed: Optional[torch.Tensor] = None, + positions: Optional[torch.Tensor] = None, +): """A sampling implementation with native pytorch operations, without top-k, top-p, or min-p filtering.""" - sampled_index = torch.multinomial(probs, num_samples=1) + if sampling_seed is not None: + sampled_index = multinomial_with_seed(probs, sampling_seed, positions) + else: + sampled_index = torch.multinomial(probs, num_samples=1) batch_next_token_ids = sampled_index.view(-1).to(torch.int32) return batch_next_token_ids diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 68132991c..f8135767e 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -245,9 +245,11 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): input_token_ids_logprobs_idx=recv_obj.input_token_ids_logprobs_idx, output_token_ids_logprobs_val=recv_obj.output_token_ids_logprobs_val, output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx, + output_token_entropy_val=recv_obj.output_token_entropy_val, output_hidden_states=recv_obj.output_hidden_states, placeholder_tokens_idx=None, placeholder_tokens_val=None, + token_steps=recv_obj.token_steps, ) def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq): diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index e6dfa35c4..bb542b7bd 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -170,6 +170,9 @@ class GenerateReqInput(BaseReq): # (Internal) Whether to return bytes for image generation return_bytes: bool = False + # Whether to return entropy + return_entropy: bool = False + def contains_mm_input(self) -> bool: return ( has_valid_data(self.image_data) @@ -568,6 +571,7 @@ class GenerateReqInput(BaseReq): no_logs=self.no_logs, custom_labels=self.custom_labels, return_bytes=self.return_bytes, + return_entropy=self.return_entropy, ) @@ -633,6 +637,9 @@ class TokenizedGenerateReqInput(BaseReq): # (Internal) Whether to return bytes for image generation return_bytes: bool = False + # Whether to return entropy + return_entropy: bool = False + @dataclass class BatchTokenizedGenerateReqInput(BaseBatchReq): @@ -830,6 +837,7 @@ class BatchTokenIDOutput(BaseBatchReq): input_token_ids_logprobs_idx: List[List] output_token_ids_logprobs_val: List[List] output_token_ids_logprobs_idx: List[List] + output_token_entropy_val: List[float] # Hidden states output_hidden_states: List[List[float]] @@ -840,6 +848,9 @@ class BatchTokenIDOutput(BaseBatchReq): placeholder_tokens_idx: List[Optional[List[int]]] placeholder_tokens_val: List[Optional[List[int]]] + # The trainer step id. Used to know which step's weights are used for sampling. + token_steps: List[List[int]] = None + @dataclass class BatchMultimodalDecodeReq(BaseBatchReq): @@ -861,11 +872,14 @@ class BatchMultimodalDecodeReq(BaseBatchReq): completion_tokens: List[int] cached_tokens: List[int] - # Placeholder token info + # The information of placeholder tokens (e.g., image token) + # idx is the index of the token in the prompt after expansion. + # val is the length of padded tokens after expansion. placeholder_tokens_idx: List[Optional[List[int]]] placeholder_tokens_val: List[Optional[List[int]]] - return_bytes: bool = False + # The trainer step id. Used to know which step's weights are used for sampling. + token_steps: List[List[int]] = None @dataclass @@ -896,13 +910,20 @@ class BatchStrOutput(BaseBatchReq): input_token_ids_logprobs_idx: List[List] output_token_ids_logprobs_val: List[List] output_token_ids_logprobs_idx: List[List] + output_token_entropy_val: List[float] # Hidden states output_hidden_states: List[List[float]] + # The information of placeholder tokens (e.g., image token) + # idx is the index of the token in the prompt after expansion. + # val is the length of padded tokens after expansion. placeholder_tokens_idx: List[Optional[List[int]]] placeholder_tokens_val: List[Optional[List[int]]] + # The trainer step id. Used to know which step's weights are used for sampling. + token_steps: List[List[int]] = None + @dataclass class BatchMultimodalOutput(BaseBatchReq): @@ -979,6 +1000,8 @@ class UpdateWeightFromDiskReqInput(BaseReq): torch_empty_cache: bool = False # Whether to keep the scheduler paused after weight update keep_pause: bool = False + # The trainer step id. Used to know which step's weights are used for sampling. + token_step: int = 0 @dataclass @@ -1416,6 +1439,16 @@ class WatchLoadUpdateReq(BaseReq): loads: List[GetLoadReqOutput] +@dataclass +class LazyDumpTensorsReqInput(BaseReq): + pass + + +@dataclass +class LazyDumpTensorsReqOutput(BaseReq): + success: bool + + def _check_all_req_types(): """A helper function to check all request types are defined in this file.""" import inspect diff --git a/python/sglang/srt/managers/multi_tokenizer_mixin.py b/python/sglang/srt/managers/multi_tokenizer_mixin.py index 83c966ec6..302546e5f 100644 --- a/python/sglang/srt/managers/multi_tokenizer_mixin.py +++ b/python/sglang/srt/managers/multi_tokenizer_mixin.py @@ -190,6 +190,11 @@ def _handle_output_by_index(output, i): if output.output_token_ids_logprobs_idx else None ), + output_token_entropy_val=( + [output.output_token_entropy_val[i]] + if output.output_token_entropy_val + else None + ), output_hidden_states=( [output.output_hidden_states[i]] if output.output_hidden_states @@ -197,6 +202,7 @@ def _handle_output_by_index(output, i): ), placeholder_tokens_idx=None, placeholder_tokens_val=None, + token_steps=([output.token_steps[i]] if output.token_steps else None), ) elif isinstance(output, BatchEmbeddingOutput): new_output = BatchEmbeddingOutput( @@ -306,6 +312,11 @@ def _handle_output_by_index(output, i): if output.output_token_ids_logprobs_idx else None ), + output_token_entropy_val=( + [output.output_token_entropy_val[i]] + if output.output_token_entropy_val + else None + ), output_hidden_states=( [output.output_hidden_states[i]] if output.output_hidden_states @@ -313,6 +324,7 @@ def _handle_output_by_index(output, i): ), placeholder_tokens_idx=None, placeholder_tokens_val=None, + token_steps=([output.token_steps[i]] if output.token_steps else None), ) elif isinstance(output, BatchMultimodalOutput): new_output = BatchMultimodalOutput( diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index ef7a3f54e..ba3b09e1a 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -920,7 +920,8 @@ class SchedulerOutputProcessorMixin: input_token_ids_logprobs_idx, output_token_ids_logprobs_val, output_token_ids_logprobs_idx, - output_hidden_states, + output_token_entropy_val=None, + output_hidden_states=output_hidden_states, rids=rids, placeholder_tokens_idx=None, placeholder_tokens_val=None, diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index 330acef25..1f4a3b443 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -73,9 +73,6 @@ logger = logging.getLogger(__name__) # Dump tensors for debugging debug_tensor_dump_output_folder = None -debug_tensor_dump_prefill_only = False -# Skip all the other tensor dumps, only dump the target logits -debug_tensor_dump_only_target_logprobs = False debug_tensor_dump_inject = False debug_tensor_dump_layers = None debug_tensor_dump_test = False diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 9f34517f5..e20e35129 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -455,7 +455,6 @@ class ServerArgs: debug_tensor_dump_output_folder: Optional[str] = None debug_tensor_dump_input_file: Optional[str] = None debug_tensor_dump_inject: bool = False - debug_tensor_dump_prefill_only: bool = False # PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only) disaggregation_mode: Literal["null", "prefill", "decode"] = "null" @@ -2831,11 +2830,6 @@ class ServerArgs: default=ServerArgs.debug_tensor_dump_inject, help="Inject the outputs from jax as the input of every layer.", ) - parser.add_argument( - "--debug-tensor-dump-prefill-only", - action="store_true", - help="Only dump the tensors for prefill requests (i.e. batch size > 1).", - ) parser.add_argument( "--enable-dynamic-batch-tokenizer", action="store_true", diff --git a/python/sglang/test/test_cutlass_moe.py b/python/sglang/test/test_cutlass_moe.py index 56f276c81..377534a49 100755 --- a/python/sglang/test/test_cutlass_moe.py +++ b/python/sglang/test/test_cutlass_moe.py @@ -34,7 +34,7 @@ def get_model_config(tp_size: int): "topk": topk, "hidden_size": config.hidden_size, "shard_intermediate_size": shard_intermediate_size, - "dtype": config.torch_dtype, + "dtype": config.dtype, "block_shape": config.quantization_config["weight_block_size"], } diff --git a/python/sglang/test/test_deterministic_utils.py b/python/sglang/test/test_deterministic_utils.py index 0c1607686..c665c8033 100644 --- a/python/sglang/test/test_deterministic_utils.py +++ b/python/sglang/test/test_deterministic_utils.py @@ -1,8 +1,5 @@ -import time import unittest -import requests - from sglang.srt.utils import kill_process_tree from sglang.test.test_deterministic import BenchArgs, test_deterministic from sglang.test.test_utils import ( @@ -55,6 +52,7 @@ class TestDeterministicBase(CustomTestCase): args.n_start = 10 args.n_trials = 20 results = test_deterministic(args) + args.temperature = 0.5 # test for deterministic sampling for result in results: assert result == 1 @@ -65,6 +63,7 @@ class TestDeterministicBase(CustomTestCase): args.test_mode = "mixed" args.n_start = 10 args.n_trials = 20 + args.temperature = 0.5 # test for deterministic sampling results = test_deterministic(args) for result in results: assert result == 1 @@ -76,6 +75,7 @@ class TestDeterministicBase(CustomTestCase): args.test_mode = "prefix" args.n_start = 10 args.n_trials = 10 + args.temperature = 0.5 # test for deterministic sampling results = test_deterministic(args) for result in results: assert result == 1