[minor] sync code on python/sglang/test/test_deterministic.py and improve ci tests (#11777)

Co-authored-by: Stefan He <hebiaobuaa@gmail.com>
Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
This commit is contained in:
Lianmin Zheng
2025-10-17 14:25:22 -07:00
committed by GitHub
parent 20b8d2306c
commit b9a54e0968
9 changed files with 264 additions and 39 deletions

View File

@@ -879,6 +879,8 @@ class BatchMultimodalDecodeReq(BaseBatchReq):
placeholder_tokens_idx: List[Optional[List[int]]]
placeholder_tokens_val: List[Optional[List[int]]]
return_bytes: List[bool]
# The trainer step id. Used to know which step's weights are used for sampling.
token_steps: List[List[int]] = None

View File

@@ -150,6 +150,9 @@ class SchedulerStats:
engine_startup_time: float = 0.0
engine_load_weights_time: float = 0.0
# CUDA graph
is_cuda_graph: float = 0.0
class SchedulerMetricsCollector:
@@ -499,6 +502,13 @@ class SchedulerMetricsCollector:
labelnames=list(labels.keys()) + ["stage"],
)
self.is_cuda_graph = Gauge(
name="sglang:is_cuda_graph",
documentation="Whether the batch is using CUDA graph.",
labelnames=labels.keys(),
multiprocess_mode="mostrecent",
)
def _log_gauge(self, gauge, data: Union[int, float]) -> None:
# Convenience function for logging to gauge.
gauge.labels(**self.labels).set(data)
@@ -574,6 +584,9 @@ class SchedulerMetricsCollector:
self.engine_load_weights_time, stats.engine_load_weights_time
)
# CUDA graph
self._log_gauge(self.is_cuda_graph, stats.is_cuda_graph)
self.last_log_time = time.perf_counter()
def log_grammar_stats(self, grammar_stats) -> None:

View File

@@ -509,6 +509,11 @@ class ServerArgs:
"""
Orchestrates the handling of various server arguments, ensuring proper configuration and validation.
"""
if self.model_path.lower() in ["none", "dummy"]:
# Skip for dummy models
return
# Handle deprecated arguments.
self._handle_deprecated_args()

View File

@@ -66,7 +66,7 @@ class MockModelRunner:
enable_memory_saver=False,
)
# Required by torch native backend
self.server_args = ServerArgs(model_path="fake_model_path")
self.server_args = ServerArgs(model_path="dummy")
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")

View File

@@ -2,7 +2,17 @@
Batch the same prompt in random batch sizes, and test if the results are consistent across different trials.
Usage:
python3 -m sglang.test.test_deterministic --n-trials <numer_of_trials> --test-mode <single|mixed|prefix> --profile
# Single mode: test determinism with varying batch sizes
python3 -m sglang.test.test_deterministic --n-trials 50 --test-mode single
# Mixed mode: test with mixed prompts
python3 -m sglang.test.test_deterministic --n-trials 50 --test-mode mixed
# Prefix mode: test with shared prefixes
python3 -m sglang.test.test_deterministic --n-start 1 --n-trials 50 --test-mode prefix
# Radix Cache Consistency mode: test radix cache determinism (cached vs uncached prefill)
python3 -m sglang.test.test_deterministic --test-mode radix_cache
"""
import argparse
@@ -67,7 +77,12 @@ class BenchArgs:
"--test-mode",
type=str,
default=BenchArgs.test_mode,
choices=["single", "mixed", "prefix"],
choices=[
"single",
"mixed",
"prefix",
"radix_cache",
],
)
parser.add_argument("--profile", action="store_true")
parser.add_argument(
@@ -83,26 +98,50 @@ class BenchArgs:
def send_single(
args,
batch_size: int,
batch_size: int = 1,
profile: bool = False,
profile_steps: int = 3,
profile_by_stage: bool = False,
return_full_response: bool = False,
input_ids: List[int] = None,
max_new_tokens: int = None,
):
base_url = f"http://{args.host}:{args.port}"
prompt = [PROMPT_1] * batch_size
json_data = {
"text": prompt,
"sampling_params": {
"temperature": args.temperature,
"max_new_tokens": args.max_new_tokens,
"frequency_penalty": args.frequency_penalty,
"presence_penalty": args.presence_penalty,
},
"return_logprob": args.return_logprob,
"stream": args.stream,
}
# Use input_ids if provided, otherwise use text prompts
if input_ids is not None:
json_data = {
"input_ids": input_ids,
"sampling_params": {
"temperature": args.temperature,
"max_new_tokens": (
max_new_tokens
if max_new_tokens is not None
else args.max_new_tokens
),
"frequency_penalty": args.frequency_penalty,
"presence_penalty": args.presence_penalty,
},
"return_logprob": args.return_logprob,
"stream": args.stream,
}
else:
prompt = [PROMPT_1] * batch_size
json_data = {
"text": prompt,
"sampling_params": {
"temperature": args.temperature,
"max_new_tokens": (
max_new_tokens
if max_new_tokens is not None
else args.max_new_tokens
),
"frequency_penalty": args.frequency_penalty,
"presence_penalty": args.presence_penalty,
},
"return_logprob": args.return_logprob,
"stream": args.stream,
}
if args.sampling_seed is not None:
# sglang server cannot parse None value for sampling_seed
@@ -119,6 +158,11 @@ def send_single(
stream=args.stream,
)
if response.status_code != 200:
ret = response.json()
print(f"Error: {ret}")
return None
if args.stream:
for chunk in response.iter_lines(decode_unicode=False):
chunk = chunk.decode("utf-8")
@@ -128,13 +172,13 @@ def send_single(
ret = json.loads(chunk[5:].strip("\n"))
else:
ret = response.json()
ret = ret[0]
if response.status_code != 200:
print(ret)
return -1
ret = ret[0] if isinstance(ret, list) else ret
return ret["text"]
if return_full_response:
return ret
else:
return ret["text"]
def send_mixed(args, batch_size: int):
@@ -235,7 +279,6 @@ def test_deterministic(args):
text = text.replace("\n", " ")
print(f"Trial {i} with batch size {batch_size}: {text}")
texts.append(text)
print(f"Total samples: {len(texts)}, Unique samples: {len(set(texts))}")
return [len(set(texts))]
@@ -297,6 +340,163 @@ def test_deterministic(args):
results.append(len(set(outputs[i])))
return results
elif args.test_mode == "radix_cache":
# Radix mode requires logprobs to compare results
args.return_logprob = True
print("\n=== Prefill Cache Consistency Test ===")
print(
"This test verifies prefill request produces consistent logprobs w/ and w/o cache.\n"
)
# We noticed that we cannot call flush cache before any request, otherwise it will hang.
warmup_response = send_single(
args, input_ids=[1] * 64, max_new_tokens=65, return_full_response=True
)
# Flush cache first to make sure there is no cache hit from previous tests
flush_response = requests.post(f"http://{args.host}:{args.port}/flush_cache")
print(f"Step 1: Generating random 64 token IDs...")
# Use a reasonable token ID range (e.g., 1-50000 for most tokenizers)
# Avoid special tokens like 0 (padding), 1 (BOS), 2 (EOS)
# set seed for random.randint
random.seed(42)
initial_token_ids = [random.randint(100, 50000) for _ in range(64)]
print(f"✓ Using {len(initial_token_ids)} initial tokens")
print(f" Initial token IDs: {initial_token_ids}")
print(
f"\nStep 2: Generating 2 tokens from {len(initial_token_ids)} token prefix..."
)
first_response = send_single(
args,
input_ids=initial_token_ids,
max_new_tokens=100,
return_full_response=True,
)
first_output_text = first_response["text"]
first_output_token_ids = first_response["output_ids"]
first_output_logprobs = first_response["meta_info"]["output_token_logprobs"]
expected_token_id = first_output_token_ids[-1]
expected_logprob = first_output_logprobs[-1][0]
print(f"✓ Generated {len(first_output_token_ids)} tokens")
print(f' Output text: "{first_output_text}"')
print(
f"\nStep 3: Generating with radix cache (164 tokens prefill, should hit > 128 tokens cache, based on page size)..."
)
prefix_token_ids = initial_token_ids + first_output_token_ids[:-1]
print(
f" Prefix: {len(initial_token_ids)} initial + 64 generated = {len(prefix_token_ids)} tokens"
)
print(f"Using Prompt: {prefix_token_ids}")
cached_response = send_single(
args,
input_ids=prefix_token_ids,
max_new_tokens=1,
return_full_response=True,
)
cached_logprobs = cached_response["meta_info"]["output_token_logprobs"]
cached_token_data = cached_logprobs[0]
cached_logprob = cached_token_data[0]
cached_token_id = cached_token_data[1]
print(f"✓ Generated with cache:")
print(f" Token ID: {cached_token_id}")
print(f" Logprob: {cached_logprob:.10f}")
print(f"\nStep 4: Flushing cache...")
flush_response = requests.post(f"http://{args.host}:{args.port}/flush_cache")
print(
f"\nStep 5: Generating without cache (same 164 tokens prefill, no cache)..."
)
print(f"Using Prompt: {prefix_token_ids}")
uncached_response = send_single(
args,
input_ids=prefix_token_ids,
max_new_tokens=1,
return_full_response=True,
)
uncached_logprobs = uncached_response["meta_info"]["output_token_logprobs"]
uncached_token_data = uncached_logprobs[0]
uncached_logprob = uncached_token_data[0]
uncached_token_id = uncached_token_data[1]
print(f"✓ Generated without cache:")
print(f" Token ID: {uncached_token_id}")
print(f" Logprob: {uncached_logprob:.10f}")
# Step 6: Compare results
print(f"\n{'='*60}")
print("Comparison 1: Decode (Request 1) vs Prefill with Cache (Request 2)")
print("=" * 60)
# Compare first request (decode) vs second request (prefill with cache)
# We expect them to be different (different kernels)
decode_vs_prefill_token_match = expected_token_id == cached_token_id
decode_vs_prefill_logprob_match = expected_logprob == cached_logprob
print(
f" Decode token (Request 1): ID={expected_token_id}, logprob={expected_logprob:.10f}"
)
print(
f" Prefill w/ cache token (Request 2): ID={cached_token_id}, logprob={cached_logprob:.10f}"
)
print(
f" Token ID match: {'✓ YES' if decode_vs_prefill_token_match else '✗ NO'}"
)
print(
f" Logprob match: {'✓ YES' if decode_vs_prefill_logprob_match else '✗ NO'}"
)
if not decode_vs_prefill_logprob_match:
diff = abs(expected_logprob - cached_logprob)
print(f" Logprob difference: {diff:.10e}")
print(f" Note: We expect these to be DIFFERENT (decode vs prefill kernels)")
print(f"\n{'='*60}")
print(
"Comparison 2: Cached Prefill (Request 2) vs Uncached Prefill (Request 3)"
)
print("=" * 60)
# Main test: compare cached vs uncached prefill (should be identical)
token_match = cached_token_id == uncached_token_id
logprob_match = cached_logprob == uncached_logprob
print(
f" Cached prefill token (Request 2): ID={cached_token_id}, logprob={cached_logprob:.10f}"
)
print(
f" Uncached prefill token (Request 3): ID={uncached_token_id}, logprob={uncached_logprob:.10f}"
)
print(f" Token ID match: {'✓ YES' if token_match else '✗ NO'}")
if not token_match:
print(f" Cached: {cached_token_id}")
print(f" Uncached: {uncached_token_id}")
print(f" Logprob match: {'✓ YES' if logprob_match else '✗ NO'}")
if not logprob_match:
print(f" Cached: {cached_logprob:.10f}")
print(f" Uncached: {uncached_logprob:.10f}")
diff = abs(cached_logprob - uncached_logprob)
print(f" Difference: {diff:.10e}")
print(f" Note: We expect these to be IDENTICAL (both prefill kernels)")
print(f"\n{'='*60}")
if token_match and logprob_match:
print("✓✓✓ TEST PASSED - Radix cache is consistent! ✓✓✓")
return [1]
else:
print("✗✗✗ TEST FAILED - Radix cache produces different results! ✗✗✗")
return [0]
else:
raise ValueError(f"Invalid test mode: {args.test_mode}")