Sync changes on io_struct.py and deterministic ops (#11498)
This commit is contained in:
@@ -240,6 +240,7 @@ class GroupCoordinator:
|
||||
use_message_queue_broadcaster: bool = False,
|
||||
group_name: Optional[str] = None,
|
||||
torch_compile: Optional[bool] = None,
|
||||
gloo_timeout: timedelta = timedelta(seconds=120 * 60),
|
||||
):
|
||||
# Set group info
|
||||
group_name = group_name or "anonymous"
|
||||
@@ -259,7 +260,9 @@ class GroupCoordinator:
|
||||
)
|
||||
# a group with `gloo` backend, to allow direct coordination between
|
||||
# processes through the CPU.
|
||||
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
|
||||
cpu_group = torch.distributed.new_group(
|
||||
ranks, backend="gloo", timeout=gloo_timeout
|
||||
)
|
||||
if self.rank in ranks:
|
||||
self.ranks = ranks
|
||||
self.world_size = len(ranks)
|
||||
|
||||
@@ -91,7 +91,6 @@ class Sampler(nn.Module):
|
||||
batch_next_token_ids = torch.argmax(logits, -1)
|
||||
if return_logprob:
|
||||
logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
||||
|
||||
else:
|
||||
# If requested, cache probabilities from original logits before temperature scaling.
|
||||
if return_logprob and RETURN_ORIGINAL_LOGPROB:
|
||||
@@ -288,21 +287,29 @@ def multinomial_with_seed(
|
||||
"""
|
||||
n, m = inputs.shape
|
||||
col_indices = torch.arange(m, device=inputs.device).unsqueeze(0)
|
||||
step_seed = seed * 19349663 ^ positions * 73856093
|
||||
step_seed = (seed * 19349663) ^ (positions * 73856093)
|
||||
seed_expanded = step_seed.unsqueeze(-1)
|
||||
hashed = seed_expanded * 8589934591 ^ col_indices * 479001599
|
||||
hashed = (seed_expanded * 8589934591) ^ (col_indices * 479001599)
|
||||
uniform_samples = (hashed % (2**24)).float() / (2**24)
|
||||
epsilon = 1e-9
|
||||
gumbel_noise = -torch.log(-torch.log(uniform_samples + epsilon) + epsilon)
|
||||
epsilon = 1e-10
|
||||
uniform_samples = uniform_samples.clamp(epsilon, 1.0 - epsilon)
|
||||
gumbel_noise = -torch.log(-torch.log(uniform_samples))
|
||||
log_probs = torch.log(inputs + epsilon)
|
||||
perturbed_log_probs = log_probs + gumbel_noise
|
||||
return torch.argmax(perturbed_log_probs, dim=1, keepdim=True)
|
||||
|
||||
|
||||
def sampling_from_probs_torch(probs: torch.Tensor):
|
||||
def sampling_from_probs_torch(
|
||||
probs: torch.Tensor,
|
||||
sampling_seed: Optional[torch.Tensor] = None,
|
||||
positions: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""A sampling implementation with native pytorch operations, without
|
||||
top-k, top-p, or min-p filtering."""
|
||||
sampled_index = torch.multinomial(probs, num_samples=1)
|
||||
if sampling_seed is not None:
|
||||
sampled_index = multinomial_with_seed(probs, sampling_seed, positions)
|
||||
else:
|
||||
sampled_index = torch.multinomial(probs, num_samples=1)
|
||||
batch_next_token_ids = sampled_index.view(-1).to(torch.int32)
|
||||
return batch_next_token_ids
|
||||
|
||||
|
||||
@@ -245,9 +245,11 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
|
||||
input_token_ids_logprobs_idx=recv_obj.input_token_ids_logprobs_idx,
|
||||
output_token_ids_logprobs_val=recv_obj.output_token_ids_logprobs_val,
|
||||
output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx,
|
||||
output_token_entropy_val=recv_obj.output_token_entropy_val,
|
||||
output_hidden_states=recv_obj.output_hidden_states,
|
||||
placeholder_tokens_idx=None,
|
||||
placeholder_tokens_val=None,
|
||||
token_steps=recv_obj.token_steps,
|
||||
)
|
||||
|
||||
def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
|
||||
|
||||
@@ -170,6 +170,9 @@ class GenerateReqInput(BaseReq):
|
||||
# (Internal) Whether to return bytes for image generation
|
||||
return_bytes: bool = False
|
||||
|
||||
# Whether to return entropy
|
||||
return_entropy: bool = False
|
||||
|
||||
def contains_mm_input(self) -> bool:
|
||||
return (
|
||||
has_valid_data(self.image_data)
|
||||
@@ -568,6 +571,7 @@ class GenerateReqInput(BaseReq):
|
||||
no_logs=self.no_logs,
|
||||
custom_labels=self.custom_labels,
|
||||
return_bytes=self.return_bytes,
|
||||
return_entropy=self.return_entropy,
|
||||
)
|
||||
|
||||
|
||||
@@ -633,6 +637,9 @@ class TokenizedGenerateReqInput(BaseReq):
|
||||
# (Internal) Whether to return bytes for image generation
|
||||
return_bytes: bool = False
|
||||
|
||||
# Whether to return entropy
|
||||
return_entropy: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchTokenizedGenerateReqInput(BaseBatchReq):
|
||||
@@ -830,6 +837,7 @@ class BatchTokenIDOutput(BaseBatchReq):
|
||||
input_token_ids_logprobs_idx: List[List]
|
||||
output_token_ids_logprobs_val: List[List]
|
||||
output_token_ids_logprobs_idx: List[List]
|
||||
output_token_entropy_val: List[float]
|
||||
|
||||
# Hidden states
|
||||
output_hidden_states: List[List[float]]
|
||||
@@ -840,6 +848,9 @@ class BatchTokenIDOutput(BaseBatchReq):
|
||||
placeholder_tokens_idx: List[Optional[List[int]]]
|
||||
placeholder_tokens_val: List[Optional[List[int]]]
|
||||
|
||||
# The trainer step id. Used to know which step's weights are used for sampling.
|
||||
token_steps: List[List[int]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchMultimodalDecodeReq(BaseBatchReq):
|
||||
@@ -861,11 +872,14 @@ class BatchMultimodalDecodeReq(BaseBatchReq):
|
||||
completion_tokens: List[int]
|
||||
cached_tokens: List[int]
|
||||
|
||||
# Placeholder token info
|
||||
# The information of placeholder tokens (e.g., image token)
|
||||
# idx is the index of the token in the prompt after expansion.
|
||||
# val is the length of padded tokens after expansion.
|
||||
placeholder_tokens_idx: List[Optional[List[int]]]
|
||||
placeholder_tokens_val: List[Optional[List[int]]]
|
||||
|
||||
return_bytes: bool = False
|
||||
# The trainer step id. Used to know which step's weights are used for sampling.
|
||||
token_steps: List[List[int]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -896,13 +910,20 @@ class BatchStrOutput(BaseBatchReq):
|
||||
input_token_ids_logprobs_idx: List[List]
|
||||
output_token_ids_logprobs_val: List[List]
|
||||
output_token_ids_logprobs_idx: List[List]
|
||||
output_token_entropy_val: List[float]
|
||||
|
||||
# Hidden states
|
||||
output_hidden_states: List[List[float]]
|
||||
|
||||
# The information of placeholder tokens (e.g., image token)
|
||||
# idx is the index of the token in the prompt after expansion.
|
||||
# val is the length of padded tokens after expansion.
|
||||
placeholder_tokens_idx: List[Optional[List[int]]]
|
||||
placeholder_tokens_val: List[Optional[List[int]]]
|
||||
|
||||
# The trainer step id. Used to know which step's weights are used for sampling.
|
||||
token_steps: List[List[int]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchMultimodalOutput(BaseBatchReq):
|
||||
@@ -979,6 +1000,8 @@ class UpdateWeightFromDiskReqInput(BaseReq):
|
||||
torch_empty_cache: bool = False
|
||||
# Whether to keep the scheduler paused after weight update
|
||||
keep_pause: bool = False
|
||||
# The trainer step id. Used to know which step's weights are used for sampling.
|
||||
token_step: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -1416,6 +1439,16 @@ class WatchLoadUpdateReq(BaseReq):
|
||||
loads: List[GetLoadReqOutput]
|
||||
|
||||
|
||||
@dataclass
|
||||
class LazyDumpTensorsReqInput(BaseReq):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class LazyDumpTensorsReqOutput(BaseReq):
|
||||
success: bool
|
||||
|
||||
|
||||
def _check_all_req_types():
|
||||
"""A helper function to check all request types are defined in this file."""
|
||||
import inspect
|
||||
|
||||
@@ -190,6 +190,11 @@ def _handle_output_by_index(output, i):
|
||||
if output.output_token_ids_logprobs_idx
|
||||
else None
|
||||
),
|
||||
output_token_entropy_val=(
|
||||
[output.output_token_entropy_val[i]]
|
||||
if output.output_token_entropy_val
|
||||
else None
|
||||
),
|
||||
output_hidden_states=(
|
||||
[output.output_hidden_states[i]]
|
||||
if output.output_hidden_states
|
||||
@@ -197,6 +202,7 @@ def _handle_output_by_index(output, i):
|
||||
),
|
||||
placeholder_tokens_idx=None,
|
||||
placeholder_tokens_val=None,
|
||||
token_steps=([output.token_steps[i]] if output.token_steps else None),
|
||||
)
|
||||
elif isinstance(output, BatchEmbeddingOutput):
|
||||
new_output = BatchEmbeddingOutput(
|
||||
@@ -306,6 +312,11 @@ def _handle_output_by_index(output, i):
|
||||
if output.output_token_ids_logprobs_idx
|
||||
else None
|
||||
),
|
||||
output_token_entropy_val=(
|
||||
[output.output_token_entropy_val[i]]
|
||||
if output.output_token_entropy_val
|
||||
else None
|
||||
),
|
||||
output_hidden_states=(
|
||||
[output.output_hidden_states[i]]
|
||||
if output.output_hidden_states
|
||||
@@ -313,6 +324,7 @@ def _handle_output_by_index(output, i):
|
||||
),
|
||||
placeholder_tokens_idx=None,
|
||||
placeholder_tokens_val=None,
|
||||
token_steps=([output.token_steps[i]] if output.token_steps else None),
|
||||
)
|
||||
elif isinstance(output, BatchMultimodalOutput):
|
||||
new_output = BatchMultimodalOutput(
|
||||
|
||||
@@ -920,7 +920,8 @@ class SchedulerOutputProcessorMixin:
|
||||
input_token_ids_logprobs_idx,
|
||||
output_token_ids_logprobs_val,
|
||||
output_token_ids_logprobs_idx,
|
||||
output_hidden_states,
|
||||
output_token_entropy_val=None,
|
||||
output_hidden_states=output_hidden_states,
|
||||
rids=rids,
|
||||
placeholder_tokens_idx=None,
|
||||
placeholder_tokens_val=None,
|
||||
|
||||
@@ -73,9 +73,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# Dump tensors for debugging
|
||||
debug_tensor_dump_output_folder = None
|
||||
debug_tensor_dump_prefill_only = False
|
||||
# Skip all the other tensor dumps, only dump the target logits
|
||||
debug_tensor_dump_only_target_logprobs = False
|
||||
debug_tensor_dump_inject = False
|
||||
debug_tensor_dump_layers = None
|
||||
debug_tensor_dump_test = False
|
||||
|
||||
@@ -455,7 +455,6 @@ class ServerArgs:
|
||||
debug_tensor_dump_output_folder: Optional[str] = None
|
||||
debug_tensor_dump_input_file: Optional[str] = None
|
||||
debug_tensor_dump_inject: bool = False
|
||||
debug_tensor_dump_prefill_only: bool = False
|
||||
|
||||
# PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
|
||||
disaggregation_mode: Literal["null", "prefill", "decode"] = "null"
|
||||
@@ -2831,11 +2830,6 @@ class ServerArgs:
|
||||
default=ServerArgs.debug_tensor_dump_inject,
|
||||
help="Inject the outputs from jax as the input of every layer.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug-tensor-dump-prefill-only",
|
||||
action="store_true",
|
||||
help="Only dump the tensors for prefill requests (i.e. batch size > 1).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-dynamic-batch-tokenizer",
|
||||
action="store_true",
|
||||
|
||||
@@ -34,7 +34,7 @@ def get_model_config(tp_size: int):
|
||||
"topk": topk,
|
||||
"hidden_size": config.hidden_size,
|
||||
"shard_intermediate_size": shard_intermediate_size,
|
||||
"dtype": config.torch_dtype,
|
||||
"dtype": config.dtype,
|
||||
"block_shape": config.quantization_config["weight_block_size"],
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_deterministic import BenchArgs, test_deterministic
|
||||
from sglang.test.test_utils import (
|
||||
@@ -55,6 +52,7 @@ class TestDeterministicBase(CustomTestCase):
|
||||
args.n_start = 10
|
||||
args.n_trials = 20
|
||||
results = test_deterministic(args)
|
||||
args.temperature = 0.5 # test for deterministic sampling
|
||||
for result in results:
|
||||
assert result == 1
|
||||
|
||||
@@ -65,6 +63,7 @@ class TestDeterministicBase(CustomTestCase):
|
||||
args.test_mode = "mixed"
|
||||
args.n_start = 10
|
||||
args.n_trials = 20
|
||||
args.temperature = 0.5 # test for deterministic sampling
|
||||
results = test_deterministic(args)
|
||||
for result in results:
|
||||
assert result == 1
|
||||
@@ -76,6 +75,7 @@ class TestDeterministicBase(CustomTestCase):
|
||||
args.test_mode = "prefix"
|
||||
args.n_start = 10
|
||||
args.n_trials = 10
|
||||
args.temperature = 0.5 # test for deterministic sampling
|
||||
results = test_deterministic(args)
|
||||
for result in results:
|
||||
assert result == 1
|
||||
|
||||
Reference in New Issue
Block a user