Sync changes on io_struct.py and deterministic ops (#11498)
@@ -321,7 +321,6 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `--debug-tensor-dump-output-folder` | The output folder for debug tensor dumps. | None |
 | `--debug-tensor-dump-input-file` | The input file for debug tensor dumps. | None |
 | `--debug-tensor-dump-inject` | Enable injection of debug tensor dumps. | False |
-| `--debug-tensor-dump-prefill-only` | Enable prefill-only mode for debug tensor dumps. | False |
 
 ## PD disaggregation
 
@@ -240,6 +240,7 @@ class GroupCoordinator:
         use_message_queue_broadcaster: bool = False,
         group_name: Optional[str] = None,
         torch_compile: Optional[bool] = None,
+        gloo_timeout: timedelta = timedelta(seconds=120 * 60),
     ):
         # Set group info
         group_name = group_name or "anonymous"
@@ -259,7 +260,9 @@ class GroupCoordinator:
             )
             # a group with `gloo` backend, to allow direct coordination between
             # processes through the CPU.
-            cpu_group = torch.distributed.new_group(ranks, backend="gloo")
+            cpu_group = torch.distributed.new_group(
+                ranks, backend="gloo", timeout=gloo_timeout
+            )
             if self.rank in ranks:
                 self.ranks = ranks
                 self.world_size = len(ranks)
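The two hunks above thread a configurable `gloo_timeout` through `GroupCoordinator` so long-blocking CPU-side coordination does not trip the backend default timeout. A minimal sketch of the underlying PyTorch pattern, assuming a default process group is already initialized (the rank list here is illustrative):

```python
from datetime import timedelta

import torch.distributed as dist

# Assumes dist.init_process_group(...) has already run.
gloo_timeout = timedelta(seconds=120 * 60)  # the 2-hour default added above

# A side channel over `gloo` lets ranks coordinate through the CPU even when
# the primary backend is NCCL; the explicit timeout overrides the backend
# default (typically 30 minutes).
cpu_group = dist.new_group(ranks=[0, 1], backend="gloo", timeout=gloo_timeout)
```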
@@ -91,7 +91,6 @@ class Sampler(nn.Module):
             batch_next_token_ids = torch.argmax(logits, -1)
             if return_logprob:
                 logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
-
         else:
             # If requested, cache probabilities from original logits before temperature scaling.
             if return_logprob and RETURN_ORIGINAL_LOGPROB:
@@ -288,21 +287,29 @@ def multinomial_with_seed(
     """
     n, m = inputs.shape
     col_indices = torch.arange(m, device=inputs.device).unsqueeze(0)
-    step_seed = seed * 19349663 ^ positions * 73856093
+    step_seed = (seed * 19349663) ^ (positions * 73856093)
     seed_expanded = step_seed.unsqueeze(-1)
-    hashed = seed_expanded * 8589934591 ^ col_indices * 479001599
+    hashed = (seed_expanded * 8589934591) ^ (col_indices * 479001599)
     uniform_samples = (hashed % (2**24)).float() / (2**24)
-    epsilon = 1e-9
-    gumbel_noise = -torch.log(-torch.log(uniform_samples + epsilon) + epsilon)
+    epsilon = 1e-10
+    uniform_samples = uniform_samples.clamp(epsilon, 1.0 - epsilon)
+    gumbel_noise = -torch.log(-torch.log(uniform_samples))
     log_probs = torch.log(inputs + epsilon)
     perturbed_log_probs = log_probs + gumbel_noise
     return torch.argmax(perturbed_log_probs, dim=1, keepdim=True)
 
 
-def sampling_from_probs_torch(probs: torch.Tensor):
+def sampling_from_probs_torch(
+    probs: torch.Tensor,
+    sampling_seed: Optional[torch.Tensor] = None,
+    positions: Optional[torch.Tensor] = None,
+):
     """A sampling implementation with native pytorch operations, without
     top-k, top-p, or min-p filtering."""
-    sampled_index = torch.multinomial(probs, num_samples=1)
+    if sampling_seed is not None:
+        sampled_index = multinomial_with_seed(probs, sampling_seed, positions)
+    else:
+        sampled_index = torch.multinomial(probs, num_samples=1)
     batch_next_token_ids = sampled_index.view(-1).to(torch.int32)
     return batch_next_token_ids
 
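Two things are worth noting in this hunk. The added parentheses in the hash mixing do not change behavior (`*` binds tighter than `^` in Python); they only make the intent explicit. The behavioral fix is the clamp: the old `-log(-log(u + eps) + eps)` could produce NaN whenever `u + eps` exceeded 1, whereas clamping `u` into `(eps, 1 - eps)` keeps both logs finite. A self-contained sketch of the seeded Gumbel-max scheme, mirroring the new code (the driver values at the bottom are illustrative):

```python
import torch


def multinomial_with_seed(
    inputs: torch.Tensor, seed: torch.Tensor, positions: torch.Tensor
) -> torch.Tensor:
    """Draw one sample per row, deterministic in (seed, position)."""
    _, m = inputs.shape
    col_indices = torch.arange(m, device=inputs.device).unsqueeze(0)
    # Mix the per-request seed with the decode position, then with the
    # column index, to get one pseudo-random uniform per vocabulary entry.
    step_seed = (seed * 19349663) ^ (positions * 73856093)
    hashed = (step_seed.unsqueeze(-1) * 8589934591) ^ (col_indices * 479001599)
    uniform_samples = (hashed % (2**24)).float() / (2**24)
    epsilon = 1e-10
    uniform_samples = uniform_samples.clamp(epsilon, 1.0 - epsilon)
    # Gumbel-max trick: argmax(log p + Gumbel noise) ~ Categorical(p).
    gumbel_noise = -torch.log(-torch.log(uniform_samples))
    perturbed = torch.log(inputs + epsilon) + gumbel_noise
    return torch.argmax(perturbed, dim=1, keepdim=True)


probs = torch.softmax(torch.randn(4, 32), dim=-1)
seed = torch.tensor([7, 7, 7, 7])
positions = torch.tensor([0, 1, 2, 3])
assert torch.equal(
    multinomial_with_seed(probs, seed, positions),
    multinomial_with_seed(probs, seed, positions),
)  # same seed and positions -> identical draws, independent of RNG state
```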
@@ -245,9 +245,11 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
                 input_token_ids_logprobs_idx=recv_obj.input_token_ids_logprobs_idx,
                 output_token_ids_logprobs_val=recv_obj.output_token_ids_logprobs_val,
                 output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx,
+                output_token_entropy_val=recv_obj.output_token_entropy_val,
                 output_hidden_states=recv_obj.output_hidden_states,
                 placeholder_tokens_idx=None,
                 placeholder_tokens_val=None,
+                token_steps=recv_obj.token_steps,
             )
 
     def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
@@ -170,6 +170,9 @@ class GenerateReqInput(BaseReq):
     # (Internal) Whether to return bytes for image generation
     return_bytes: bool = False
 
+    # Whether to return entropy
+    return_entropy: bool = False
+
     def contains_mm_input(self) -> bool:
         return (
             has_valid_data(self.image_data)
@@ -568,6 +571,7 @@ class GenerateReqInput(BaseReq):
             no_logs=self.no_logs,
             custom_labels=self.custom_labels,
             return_bytes=self.return_bytes,
+            return_entropy=self.return_entropy,
         )
 
 
@@ -633,6 +637,9 @@ class TokenizedGenerateReqInput(BaseReq):
     # (Internal) Whether to return bytes for image generation
     return_bytes: bool = False
 
+    # Whether to return entropy
+    return_entropy: bool = False
+
 
 @dataclass
 class BatchTokenizedGenerateReqInput(BaseBatchReq):
@@ -830,6 +837,7 @@ class BatchTokenIDOutput(BaseBatchReq):
     input_token_ids_logprobs_idx: List[List]
     output_token_ids_logprobs_val: List[List]
     output_token_ids_logprobs_idx: List[List]
+    output_token_entropy_val: List[float]
 
     # Hidden states
     output_hidden_states: List[List[float]]
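`output_token_entropy_val` adds one scalar per generated token to the batched outputs. For reference, the Shannon entropy of a next-token distribution can be computed from the logits like this (a generic sketch, not necessarily the exact code path sglang uses):

```python
import torch


def token_entropy(logits: torch.Tensor) -> torch.Tensor:
    """H = -sum(p * log p) per row, computed via log_softmax for stability."""
    logprobs = torch.log_softmax(logits, dim=-1)
    return -(logprobs.exp() * logprobs).sum(dim=-1)


logits = torch.randn(3, 128)  # three next-token distributions
print(token_entropy(logits))  # one entropy value per position
```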
@@ -840,6 +848,9 @@ class BatchTokenIDOutput(BaseBatchReq):
     placeholder_tokens_idx: List[Optional[List[int]]]
     placeholder_tokens_val: List[Optional[List[int]]]
 
+    # The trainer step id. Used to know which step's weights are used for sampling.
+    token_steps: List[List[int]] = None
+
 
 @dataclass
 class BatchMultimodalDecodeReq(BaseBatchReq):
@@ -861,11 +872,14 @@ class BatchMultimodalDecodeReq(BaseBatchReq):
     completion_tokens: List[int]
     cached_tokens: List[int]
 
-    # Placeholder token info
+    # The information of placeholder tokens (e.g., image token)
+    # idx is the index of the token in the prompt after expansion.
+    # val is the length of padded tokens after expansion.
     placeholder_tokens_idx: List[Optional[List[int]]]
     placeholder_tokens_val: List[Optional[List[int]]]
 
-    return_bytes: bool = False
+    # The trainer step id. Used to know which step's weights are used for sampling.
+    token_steps: List[List[int]] = None
 
 
 @dataclass
@@ -896,13 +910,20 @@ class BatchStrOutput(BaseBatchReq):
     input_token_ids_logprobs_idx: List[List]
     output_token_ids_logprobs_val: List[List]
     output_token_ids_logprobs_idx: List[List]
+    output_token_entropy_val: List[float]
 
     # Hidden states
     output_hidden_states: List[List[float]]
 
+    # The information of placeholder tokens (e.g., image token)
+    # idx is the index of the token in the prompt after expansion.
+    # val is the length of padded tokens after expansion.
     placeholder_tokens_idx: List[Optional[List[int]]]
     placeholder_tokens_val: List[Optional[List[int]]]
 
+    # The trainer step id. Used to know which step's weights are used for sampling.
+    token_steps: List[List[int]] = None
+
 
 @dataclass
 class BatchMultimodalOutput(BaseBatchReq):
@@ -979,6 +1000,8 @@ class UpdateWeightFromDiskReqInput(BaseReq):
     torch_empty_cache: bool = False
     # Whether to keep the scheduler paused after weight update
     keep_pause: bool = False
+    # The trainer step id. Used to know which step's weights are used for sampling.
+    token_step: int = 0
 
 
 @dataclass
@@ -1416,6 +1439,16 @@ class WatchLoadUpdateReq(BaseReq):
     loads: List[GetLoadReqOutput]
 
 
+@dataclass
+class LazyDumpTensorsReqInput(BaseReq):
+    pass
+
+
+@dataclass
+class LazyDumpTensorsReqOutput(BaseReq):
+    success: bool
+
+
 def _check_all_req_types():
     """A helper function to check all request types are defined in this file."""
     import inspect
@@ -190,6 +190,11 @@ def _handle_output_by_index(output, i):
                 if output.output_token_ids_logprobs_idx
                 else None
             ),
+            output_token_entropy_val=(
+                [output.output_token_entropy_val[i]]
+                if output.output_token_entropy_val
+                else None
+            ),
             output_hidden_states=(
                 [output.output_hidden_states[i]]
                 if output.output_hidden_states
@@ -197,6 +202,7 @@ def _handle_output_by_index(output, i):
             ),
             placeholder_tokens_idx=None,
             placeholder_tokens_val=None,
+            token_steps=([output.token_steps[i]] if output.token_steps else None),
         )
     elif isinstance(output, BatchEmbeddingOutput):
         new_output = BatchEmbeddingOutput(
@@ -306,6 +312,11 @@ def _handle_output_by_index(output, i):
                 if output.output_token_ids_logprobs_idx
                 else None
             ),
+            output_token_entropy_val=(
+                [output.output_token_entropy_val[i]]
+                if output.output_token_entropy_val
+                else None
+            ),
             output_hidden_states=(
                 [output.output_hidden_states[i]]
                 if output.output_hidden_states
@@ -313,6 +324,7 @@ def _handle_output_by_index(output, i):
             ),
             placeholder_tokens_idx=None,
             placeholder_tokens_val=None,
+            token_steps=([output.token_steps[i]] if output.token_steps else None),
         )
     elif isinstance(output, BatchMultimodalOutput):
         new_output = BatchMultimodalOutput(
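Both hunks in `_handle_output_by_index` repeat one idiom: when splitting a batched output into a single request's view, wrap element `i` in a singleton list if the batch-level field is populated, else pass `None`. A condensed sketch of the pattern (names illustrative):

```python
def slice_field(batch_field, i):
    """Per-request view of a batched output field.

    Batched fields are either falsy (feature disabled for this batch) or
    hold one entry per request; slicing keeps the list-of-requests shape.
    """
    return [batch_field[i]] if batch_field else None


batch_entropy = [[0.3], [1.2], [0.7]]  # one list per request
print(slice_field(batch_entropy, 1))   # [[1.2]]
print(slice_field(None, 1))            # None
```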
@@ -920,7 +920,8 @@ class SchedulerOutputProcessorMixin:
                 input_token_ids_logprobs_idx,
                 output_token_ids_logprobs_val,
                 output_token_ids_logprobs_idx,
-                output_hidden_states,
+                output_token_entropy_val=None,
+                output_hidden_states=output_hidden_states,
                 rids=rids,
                 placeholder_tokens_idx=None,
                 placeholder_tokens_val=None,
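The switch from positional `output_hidden_states` to keyword form here is load-bearing: with `output_token_entropy_val` now sitting between the logprobs fields and the hidden states in the output dataclass, a positional argument would silently bind to the new field. A toy illustration of the hazard (fields hypothetical):

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Output:
    logprobs: List[float]
    entropy: Optional[list] = None       # newly inserted field
    hidden_states: Optional[list] = None


bad = Output([0.1], [[0.5]])                 # [[0.5]] silently lands in `entropy`
good = Output([0.1], hidden_states=[[0.5]])  # keyword args survive insertions
assert bad.hidden_states is None and good.entropy is None
```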
@@ -73,9 +73,6 @@ logger = logging.getLogger(__name__)
 
 # Dump tensors for debugging
 debug_tensor_dump_output_folder = None
-debug_tensor_dump_prefill_only = False
-# Skip all the other tensor dumps, only dump the target logits
-debug_tensor_dump_only_target_logprobs = False
 debug_tensor_dump_inject = False
 debug_tensor_dump_layers = None
 debug_tensor_dump_test = False
@@ -455,7 +455,6 @@ class ServerArgs:
     debug_tensor_dump_output_folder: Optional[str] = None
     debug_tensor_dump_input_file: Optional[str] = None
     debug_tensor_dump_inject: bool = False
-    debug_tensor_dump_prefill_only: bool = False
 
     # PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
     disaggregation_mode: Literal["null", "prefill", "decode"] = "null"
@@ -2831,11 +2830,6 @@ class ServerArgs:
             default=ServerArgs.debug_tensor_dump_inject,
             help="Inject the outputs from jax as the input of every layer.",
         )
-        parser.add_argument(
-            "--debug-tensor-dump-prefill-only",
-            action="store_true",
-            help="Only dump the tensors for prefill requests (i.e. batch size > 1).",
-        )
         parser.add_argument(
             "--enable-dynamic-batch-tokenizer",
             action="store_true",
@@ -34,7 +34,7 @@ def get_model_config(tp_size: int):
         "topk": topk,
         "hidden_size": config.hidden_size,
         "shard_intermediate_size": shard_intermediate_size,
-        "dtype": config.torch_dtype,
+        "dtype": config.dtype,
         "block_shape": config.quantization_config["weight_block_size"],
     }
 
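The `config.torch_dtype` → `config.dtype` change tracks recent `transformers` releases, where `dtype` replaces the deprecated `torch_dtype` attribute on configs. If the benchmark needed to run against both old and new versions, a tolerant lookup would be (model id illustrative):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("Qwen/Qwen2-0.5B")

# Prefer the new attribute; fall back to the deprecated name on older versions.
dtype = getattr(config, "dtype", None) or getattr(config, "torch_dtype", None)
```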
@@ -1,8 +1,5 @@
-import time
 import unittest
 
-import requests
-
 from sglang.srt.utils import kill_process_tree
 from sglang.test.test_deterministic import BenchArgs, test_deterministic
 from sglang.test.test_utils import (
@@ -55,6 +52,7 @@ class TestDeterministicBase(CustomTestCase):
         args.n_start = 10
         args.n_trials = 20
+        args.temperature = 0.5  # test for deterministic sampling
         results = test_deterministic(args)
         for result in results:
             assert result == 1
@@ -65,6 +63,7 @@ class TestDeterministicBase(CustomTestCase):
         args.test_mode = "mixed"
         args.n_start = 10
         args.n_trials = 20
+        args.temperature = 0.5  # test for deterministic sampling
         results = test_deterministic(args)
         for result in results:
             assert result == 1
@@ -76,6 +75,7 @@ class TestDeterministicBase(CustomTestCase):
         args.test_mode = "prefix"
         args.n_start = 10
         args.n_trials = 10
+        args.temperature = 0.5  # test for deterministic sampling
         results = test_deterministic(args)
         for result in results:
             assert result == 1
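Setting `args.temperature = 0.5` in all three tests matters: at temperature 0 the sampler reduces to `torch.argmax`, which is trivially deterministic, so only a positive temperature actually exercises the seeded multinomial path added above. As a standalone analogy, repeatable sampling requires pinning the randomness source one way or another; the hash-based scheme does this per request without touching global RNG state, whereas plain `torch.multinomial` needs an explicit generator:

```python
import torch

probs = torch.softmax(torch.randn(2, 50) / 0.5, dim=-1)  # temperature 0.5

# Two generators seeded identically reproduce the same draw; the seeded
# hash in multinomial_with_seed achieves the same effect per request
# without mutating any global or generator state.
g1 = torch.Generator().manual_seed(42)
g2 = torch.Generator().manual_seed(42)
a = torch.multinomial(probs, num_samples=1, generator=g1)
b = torch.multinomial(probs, num_samples=1, generator=g2)
assert torch.equal(a, b)
```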