Sync changes on io_struct.py and deterministic ops (#11498)

This commit is contained in:
Lianmin Zheng
2025-10-12 16:03:10 -07:00
committed by GitHub
parent 0aa65f94f1
commit 2ac46e94ef
11 changed files with 73 additions and 25 deletions

View File

@@ -321,7 +321,6 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--debug-tensor-dump-output-folder` | The output folder for debug tensor dumps. | None | | `--debug-tensor-dump-output-folder` | The output folder for debug tensor dumps. | None |
| `--debug-tensor-dump-input-file` | The input file for debug tensor dumps. | None | | `--debug-tensor-dump-input-file` | The input file for debug tensor dumps. | None |
| `--debug-tensor-dump-inject` | Enable injection of debug tensor dumps. | False | | `--debug-tensor-dump-inject` | Enable injection of debug tensor dumps. | False |
| `--debug-tensor-dump-prefill-only` | Enable prefill-only mode for debug tensor dumps. | False |
## PD disaggregation ## PD disaggregation

View File

@@ -240,6 +240,7 @@ class GroupCoordinator:
use_message_queue_broadcaster: bool = False, use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None, group_name: Optional[str] = None,
torch_compile: Optional[bool] = None, torch_compile: Optional[bool] = None,
gloo_timeout: timedelta = timedelta(seconds=120 * 60),
): ):
# Set group info # Set group info
group_name = group_name or "anonymous" group_name = group_name or "anonymous"
@@ -259,7 +260,9 @@ class GroupCoordinator:
) )
# a group with `gloo` backend, to allow direct coordination between # a group with `gloo` backend, to allow direct coordination between
# processes through the CPU. # processes through the CPU.
cpu_group = torch.distributed.new_group(ranks, backend="gloo") cpu_group = torch.distributed.new_group(
ranks, backend="gloo", timeout=gloo_timeout
)
if self.rank in ranks: if self.rank in ranks:
self.ranks = ranks self.ranks = ranks
self.world_size = len(ranks) self.world_size = len(ranks)

View File

@@ -91,7 +91,6 @@ class Sampler(nn.Module):
batch_next_token_ids = torch.argmax(logits, -1) batch_next_token_ids = torch.argmax(logits, -1)
if return_logprob: if return_logprob:
logprobs = torch.nn.functional.log_softmax(logits, dim=-1) logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
else: else:
# If requested, cache probabilities from original logits before temperature scaling. # If requested, cache probabilities from original logits before temperature scaling.
if return_logprob and RETURN_ORIGINAL_LOGPROB: if return_logprob and RETURN_ORIGINAL_LOGPROB:
@@ -288,21 +287,29 @@ def multinomial_with_seed(
""" """
n, m = inputs.shape n, m = inputs.shape
col_indices = torch.arange(m, device=inputs.device).unsqueeze(0) col_indices = torch.arange(m, device=inputs.device).unsqueeze(0)
step_seed = seed * 19349663 ^ positions * 73856093 step_seed = (seed * 19349663) ^ (positions * 73856093)
seed_expanded = step_seed.unsqueeze(-1) seed_expanded = step_seed.unsqueeze(-1)
hashed = seed_expanded * 8589934591 ^ col_indices * 479001599 hashed = (seed_expanded * 8589934591) ^ (col_indices * 479001599)
uniform_samples = (hashed % (2**24)).float() / (2**24) uniform_samples = (hashed % (2**24)).float() / (2**24)
epsilon = 1e-9 epsilon = 1e-10
gumbel_noise = -torch.log(-torch.log(uniform_samples + epsilon) + epsilon) uniform_samples = uniform_samples.clamp(epsilon, 1.0 - epsilon)
gumbel_noise = -torch.log(-torch.log(uniform_samples))
log_probs = torch.log(inputs + epsilon) log_probs = torch.log(inputs + epsilon)
perturbed_log_probs = log_probs + gumbel_noise perturbed_log_probs = log_probs + gumbel_noise
return torch.argmax(perturbed_log_probs, dim=1, keepdim=True) return torch.argmax(perturbed_log_probs, dim=1, keepdim=True)
def sampling_from_probs_torch(probs: torch.Tensor): def sampling_from_probs_torch(
probs: torch.Tensor,
sampling_seed: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None,
):
"""A sampling implementation with native pytorch operations, without """A sampling implementation with native pytorch operations, without
top-k, top-p, or min-p filtering.""" top-k, top-p, or min-p filtering."""
sampled_index = torch.multinomial(probs, num_samples=1) if sampling_seed is not None:
sampled_index = multinomial_with_seed(probs, sampling_seed, positions)
else:
sampled_index = torch.multinomial(probs, num_samples=1)
batch_next_token_ids = sampled_index.view(-1).to(torch.int32) batch_next_token_ids = sampled_index.view(-1).to(torch.int32)
return batch_next_token_ids return batch_next_token_ids

View File

@@ -245,9 +245,11 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
input_token_ids_logprobs_idx=recv_obj.input_token_ids_logprobs_idx, input_token_ids_logprobs_idx=recv_obj.input_token_ids_logprobs_idx,
output_token_ids_logprobs_val=recv_obj.output_token_ids_logprobs_val, output_token_ids_logprobs_val=recv_obj.output_token_ids_logprobs_val,
output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx, output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx,
output_token_entropy_val=recv_obj.output_token_entropy_val,
output_hidden_states=recv_obj.output_hidden_states, output_hidden_states=recv_obj.output_hidden_states,
placeholder_tokens_idx=None, placeholder_tokens_idx=None,
placeholder_tokens_val=None, placeholder_tokens_val=None,
token_steps=recv_obj.token_steps,
) )
def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq): def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):

View File

@@ -170,6 +170,9 @@ class GenerateReqInput(BaseReq):
# (Internal) Whether to return bytes for image generation # (Internal) Whether to return bytes for image generation
return_bytes: bool = False return_bytes: bool = False
# Whether to return entropy
return_entropy: bool = False
def contains_mm_input(self) -> bool: def contains_mm_input(self) -> bool:
return ( return (
has_valid_data(self.image_data) has_valid_data(self.image_data)
@@ -568,6 +571,7 @@ class GenerateReqInput(BaseReq):
no_logs=self.no_logs, no_logs=self.no_logs,
custom_labels=self.custom_labels, custom_labels=self.custom_labels,
return_bytes=self.return_bytes, return_bytes=self.return_bytes,
return_entropy=self.return_entropy,
) )
@@ -633,6 +637,9 @@ class TokenizedGenerateReqInput(BaseReq):
# (Internal) Whether to return bytes for image generation # (Internal) Whether to return bytes for image generation
return_bytes: bool = False return_bytes: bool = False
# Whether to return entropy
return_entropy: bool = False
@dataclass @dataclass
class BatchTokenizedGenerateReqInput(BaseBatchReq): class BatchTokenizedGenerateReqInput(BaseBatchReq):
@@ -830,6 +837,7 @@ class BatchTokenIDOutput(BaseBatchReq):
input_token_ids_logprobs_idx: List[List] input_token_ids_logprobs_idx: List[List]
output_token_ids_logprobs_val: List[List] output_token_ids_logprobs_val: List[List]
output_token_ids_logprobs_idx: List[List] output_token_ids_logprobs_idx: List[List]
output_token_entropy_val: List[float]
# Hidden states # Hidden states
output_hidden_states: List[List[float]] output_hidden_states: List[List[float]]
@@ -840,6 +848,9 @@ class BatchTokenIDOutput(BaseBatchReq):
placeholder_tokens_idx: List[Optional[List[int]]] placeholder_tokens_idx: List[Optional[List[int]]]
placeholder_tokens_val: List[Optional[List[int]]] placeholder_tokens_val: List[Optional[List[int]]]
# The trainer step id. Used to know which step's weights are used for sampling.
token_steps: List[List[int]] = None
@dataclass @dataclass
class BatchMultimodalDecodeReq(BaseBatchReq): class BatchMultimodalDecodeReq(BaseBatchReq):
@@ -861,11 +872,14 @@ class BatchMultimodalDecodeReq(BaseBatchReq):
completion_tokens: List[int] completion_tokens: List[int]
cached_tokens: List[int] cached_tokens: List[int]
# Placeholder token info # The information of placeholder tokens (e.g., image token)
# idx is the index of the token in the prompt after expansion.
# val is the length of padded tokens after expansion.
placeholder_tokens_idx: List[Optional[List[int]]] placeholder_tokens_idx: List[Optional[List[int]]]
placeholder_tokens_val: List[Optional[List[int]]] placeholder_tokens_val: List[Optional[List[int]]]
return_bytes: bool = False # The trainer step id. Used to know which step's weights are used for sampling.
token_steps: List[List[int]] = None
@dataclass @dataclass
@@ -896,13 +910,20 @@ class BatchStrOutput(BaseBatchReq):
input_token_ids_logprobs_idx: List[List] input_token_ids_logprobs_idx: List[List]
output_token_ids_logprobs_val: List[List] output_token_ids_logprobs_val: List[List]
output_token_ids_logprobs_idx: List[List] output_token_ids_logprobs_idx: List[List]
output_token_entropy_val: List[float]
# Hidden states # Hidden states
output_hidden_states: List[List[float]] output_hidden_states: List[List[float]]
# The information of placeholder tokens (e.g., image token)
# idx is the index of the token in the prompt after expansion.
# val is the length of padded tokens after expansion.
placeholder_tokens_idx: List[Optional[List[int]]] placeholder_tokens_idx: List[Optional[List[int]]]
placeholder_tokens_val: List[Optional[List[int]]] placeholder_tokens_val: List[Optional[List[int]]]
# The trainer step id. Used to know which step's weights are used for sampling.
token_steps: List[List[int]] = None
@dataclass @dataclass
class BatchMultimodalOutput(BaseBatchReq): class BatchMultimodalOutput(BaseBatchReq):
@@ -979,6 +1000,8 @@ class UpdateWeightFromDiskReqInput(BaseReq):
torch_empty_cache: bool = False torch_empty_cache: bool = False
# Whether to keep the scheduler paused after weight update # Whether to keep the scheduler paused after weight update
keep_pause: bool = False keep_pause: bool = False
# The trainer step id. Used to know which step's weights are used for sampling.
token_step: int = 0
@dataclass @dataclass
@@ -1416,6 +1439,16 @@ class WatchLoadUpdateReq(BaseReq):
loads: List[GetLoadReqOutput] loads: List[GetLoadReqOutput]
@dataclass
class LazyDumpTensorsReqInput(BaseReq):
pass
@dataclass
class LazyDumpTensorsReqOutput(BaseReq):
success: bool
def _check_all_req_types(): def _check_all_req_types():
"""A helper function to check all request types are defined in this file.""" """A helper function to check all request types are defined in this file."""
import inspect import inspect

View File

@@ -190,6 +190,11 @@ def _handle_output_by_index(output, i):
if output.output_token_ids_logprobs_idx if output.output_token_ids_logprobs_idx
else None else None
), ),
output_token_entropy_val=(
[output.output_token_entropy_val[i]]
if output.output_token_entropy_val
else None
),
output_hidden_states=( output_hidden_states=(
[output.output_hidden_states[i]] [output.output_hidden_states[i]]
if output.output_hidden_states if output.output_hidden_states
@@ -197,6 +202,7 @@ def _handle_output_by_index(output, i):
), ),
placeholder_tokens_idx=None, placeholder_tokens_idx=None,
placeholder_tokens_val=None, placeholder_tokens_val=None,
token_steps=([output.token_steps[i]] if output.token_steps else None),
) )
elif isinstance(output, BatchEmbeddingOutput): elif isinstance(output, BatchEmbeddingOutput):
new_output = BatchEmbeddingOutput( new_output = BatchEmbeddingOutput(
@@ -306,6 +312,11 @@ def _handle_output_by_index(output, i):
if output.output_token_ids_logprobs_idx if output.output_token_ids_logprobs_idx
else None else None
), ),
output_token_entropy_val=(
[output.output_token_entropy_val[i]]
if output.output_token_entropy_val
else None
),
output_hidden_states=( output_hidden_states=(
[output.output_hidden_states[i]] [output.output_hidden_states[i]]
if output.output_hidden_states if output.output_hidden_states
@@ -313,6 +324,7 @@ def _handle_output_by_index(output, i):
), ),
placeholder_tokens_idx=None, placeholder_tokens_idx=None,
placeholder_tokens_val=None, placeholder_tokens_val=None,
token_steps=([output.token_steps[i]] if output.token_steps else None),
) )
elif isinstance(output, BatchMultimodalOutput): elif isinstance(output, BatchMultimodalOutput):
new_output = BatchMultimodalOutput( new_output = BatchMultimodalOutput(

View File

@@ -920,7 +920,8 @@ class SchedulerOutputProcessorMixin:
input_token_ids_logprobs_idx, input_token_ids_logprobs_idx,
output_token_ids_logprobs_val, output_token_ids_logprobs_val,
output_token_ids_logprobs_idx, output_token_ids_logprobs_idx,
output_hidden_states, output_token_entropy_val=None,
output_hidden_states=output_hidden_states,
rids=rids, rids=rids,
placeholder_tokens_idx=None, placeholder_tokens_idx=None,
placeholder_tokens_val=None, placeholder_tokens_val=None,

View File

@@ -73,9 +73,6 @@ logger = logging.getLogger(__name__)
# Dump tensors for debugging # Dump tensors for debugging
debug_tensor_dump_output_folder = None debug_tensor_dump_output_folder = None
debug_tensor_dump_prefill_only = False
# Skip all the other tensor dumps, only dump the target logits
debug_tensor_dump_only_target_logprobs = False
debug_tensor_dump_inject = False debug_tensor_dump_inject = False
debug_tensor_dump_layers = None debug_tensor_dump_layers = None
debug_tensor_dump_test = False debug_tensor_dump_test = False

View File

@@ -455,7 +455,6 @@ class ServerArgs:
debug_tensor_dump_output_folder: Optional[str] = None debug_tensor_dump_output_folder: Optional[str] = None
debug_tensor_dump_input_file: Optional[str] = None debug_tensor_dump_input_file: Optional[str] = None
debug_tensor_dump_inject: bool = False debug_tensor_dump_inject: bool = False
debug_tensor_dump_prefill_only: bool = False
# PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only) # PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
disaggregation_mode: Literal["null", "prefill", "decode"] = "null" disaggregation_mode: Literal["null", "prefill", "decode"] = "null"
@@ -2831,11 +2830,6 @@ class ServerArgs:
default=ServerArgs.debug_tensor_dump_inject, default=ServerArgs.debug_tensor_dump_inject,
help="Inject the outputs from jax as the input of every layer.", help="Inject the outputs from jax as the input of every layer.",
) )
parser.add_argument(
"--debug-tensor-dump-prefill-only",
action="store_true",
help="Only dump the tensors for prefill requests (i.e. batch size > 1).",
)
parser.add_argument( parser.add_argument(
"--enable-dynamic-batch-tokenizer", "--enable-dynamic-batch-tokenizer",
action="store_true", action="store_true",

View File

@@ -34,7 +34,7 @@ def get_model_config(tp_size: int):
"topk": topk, "topk": topk,
"hidden_size": config.hidden_size, "hidden_size": config.hidden_size,
"shard_intermediate_size": shard_intermediate_size, "shard_intermediate_size": shard_intermediate_size,
"dtype": config.torch_dtype, "dtype": config.dtype,
"block_shape": config.quantization_config["weight_block_size"], "block_shape": config.quantization_config["weight_block_size"],
} }

View File

@@ -1,8 +1,5 @@
import time
import unittest import unittest
import requests
from sglang.srt.utils import kill_process_tree from sglang.srt.utils import kill_process_tree
from sglang.test.test_deterministic import BenchArgs, test_deterministic from sglang.test.test_deterministic import BenchArgs, test_deterministic
from sglang.test.test_utils import ( from sglang.test.test_utils import (
@@ -55,6 +52,7 @@ class TestDeterministicBase(CustomTestCase):
args.n_start = 10 args.n_start = 10
args.n_trials = 20 args.n_trials = 20
results = test_deterministic(args) results = test_deterministic(args)
args.temperature = 0.5 # test for deterministic sampling
for result in results: for result in results:
assert result == 1 assert result == 1
@@ -65,6 +63,7 @@ class TestDeterministicBase(CustomTestCase):
args.test_mode = "mixed" args.test_mode = "mixed"
args.n_start = 10 args.n_start = 10
args.n_trials = 20 args.n_trials = 20
args.temperature = 0.5 # test for deterministic sampling
results = test_deterministic(args) results = test_deterministic(args)
for result in results: for result in results:
assert result == 1 assert result == 1
@@ -76,6 +75,7 @@ class TestDeterministicBase(CustomTestCase):
args.test_mode = "prefix" args.test_mode = "prefix"
args.n_start = 10 args.n_start = 10
args.n_trials = 10 args.n_trials = 10
args.temperature = 0.5 # test for deterministic sampling
results = test_deterministic(args) results = test_deterministic(args)
for result in results: for result in results:
assert result == 1 assert result == 1