[logprobs] Enable local deterministic logrprobs testing with strict threshold (#10994)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -86,7 +86,6 @@ suites = {
|
||||
TestFile("test_input_embeddings.py", 38),
|
||||
TestFile("test_io_struct.py", 8),
|
||||
TestFile("test_jinja_template_utils.py", 1),
|
||||
TestFile("test_logprobs.py", 55),
|
||||
TestFile("test_mamba_unittest.py", 4),
|
||||
TestFile("test_metrics.py", 32),
|
||||
TestFile("test_metrics_utils.py", 1),
|
||||
|
||||
@@ -1,42 +1,210 @@
|
||||
import io
|
||||
"""
|
||||
Logprobs Accuracy Test for SGLang
|
||||
|
||||
======================
|
||||
With deterministic/batch invariant kernels, we can ensure that SGLang produces exactly the same
|
||||
logprobs results for identical inputs. However, logprobs are highly sensitive to GPU hardware,
|
||||
kernels, torch versions, and other factors, so we cannot maintain a unified logprobs baseline
|
||||
across different machines.
|
||||
|
||||
This test is designed to be run locally by contributors to verify logprobs accuracy
|
||||
before making changes to related code.
|
||||
When submitting changes that affect logprobs computation, please:
|
||||
1. Generate baseline
|
||||
2. Run test
|
||||
3. Submit results
|
||||
|
||||
We really appreciate your effort and contribution to SGLang!
|
||||
|
||||
======================
|
||||
What does this test do?
|
||||
This test fetches 1000 samples from the ShareGPT dataset, generates logprobs for each sample,
|
||||
and saves them as a baseline. Then, by running the test mode, it validates the accuracy of
|
||||
logprobs by comparing them against the baseline.
|
||||
|
||||
This test ensures that:
|
||||
- the boundary of log probs requests are correct, eg, the index for tokens that required log probs are strictly followed
|
||||
- logprobs remain invariant between test runs, and also before and after your code changes;
|
||||
|
||||
======================
|
||||
Usage
|
||||
|
||||
Step 1: Generate Baseline (Before Code Changes)
|
||||
```bash
|
||||
python test/srt/test_logprobs.py gen
|
||||
```
|
||||
|
||||
Step 2: Test Against Baseline (After Code Changes)
|
||||
```bash
|
||||
python test/srt/test_logprobs.py test
|
||||
```
|
||||
This tests your changes against the locally generated baseline from Step 1.
|
||||
The test passes if the maximum and mean differences are within the tolerance thresholds.
|
||||
======================
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import random
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
write_github_step_summary,
|
||||
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
|
||||
# Configuration
|
||||
DENSE_MODEL_NAME = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
SHAREGPT_URL = (
|
||||
"https://huggingface.co/datasets/anon8231489123/"
|
||||
"ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
|
||||
)
|
||||
|
||||
# Dense model configuration
|
||||
DENSE_MODEL_NAME = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
if torch.version.hip is not None:
|
||||
print("Running on AMD ROCm GPU")
|
||||
DENSE_INPUT_PKL_URL = "https://huggingface.co/datasets/yushengsu/logprobs/resolve/main/sglang_baseline_2000_amd.pkl"
|
||||
DENSE_TOLERANCE_MAX_DIFF = 1.4
|
||||
DENSE_TOLERANCE_MEAN_DIFF = 0.1
|
||||
elif torch.version.cuda is not None:
|
||||
# Hardware-specific configuration
|
||||
if torch.version.cuda is not None:
|
||||
print("Running on NVIDIA CUDA GPU")
|
||||
DENSE_INPUT_PKL_URL = "https://huggingface.co/datasets/font-info/logprobs/resolve/main/sglang_baseline_2000.pkl"
|
||||
DENSE_TOLERANCE_MAX_DIFF = 1.5
|
||||
DENSE_TOLERANCE_MEAN_DIFF = 0.1
|
||||
DENSE_TOLERANCE_MAX_DIFF = 1e-5
|
||||
DENSE_TOLERANCE_MEAN_DIFF = 1e-5
|
||||
else:
|
||||
print("No GPU backend (CPU only)")
|
||||
raise ValueError("No GPU backend (CPU only)")
|
||||
|
||||
# Common configuration
|
||||
TOP_K = 20
|
||||
MAX_RETRIES = 3
|
||||
RETRY_DELAY = 2
|
||||
NUM_SAMPLES = 1000
|
||||
LOGPROB_SAMPLE_RATIO = 0.5
|
||||
TEMPERATURE = 1.0
|
||||
MAX_LEN = 20000
|
||||
|
||||
# Default output files
|
||||
DEFAULT_BASELINE_PKL = "sglang_baseline_local.pkl"
|
||||
DEFAULT_META_JSON = "baseline_meta_preview.json"
|
||||
|
||||
|
||||
def generate_baseline(
|
||||
baseline_file=DEFAULT_BASELINE_PKL,
|
||||
meta_file=DEFAULT_META_JSON,
|
||||
num_samples=NUM_SAMPLES,
|
||||
):
|
||||
"""Generate a local baseline for logprobs testing.
|
||||
|
||||
Args:
|
||||
baseline_file: Path to save the baseline pickle file
|
||||
meta_file: Path to save the metadata preview JSON file
|
||||
num_samples: Number of samples to generate
|
||||
"""
|
||||
print(f"SGLang version: {sgl.__version__}")
|
||||
print("Downloading ShareGPT dataset...")
|
||||
|
||||
# Download ShareGPT dataset
|
||||
try:
|
||||
response = requests.get(SHAREGPT_URL, timeout=30)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
print(f"Dataset size: {len(data)}")
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise Exception(f"Failed to download ShareGPT dataset: {e}") from e
|
||||
|
||||
# Filter and prepare texts
|
||||
texts = []
|
||||
for s in data:
|
||||
if "conversations" in s and len(s["conversations"]) > 0:
|
||||
try:
|
||||
text = s["conversations"][0]["value"]
|
||||
if isinstance(text, str) and len(text) <= MAX_LEN and len(text) >= 5500:
|
||||
texts.append(text)
|
||||
if len(texts) >= num_samples * 40: # Get more samples for filtering
|
||||
break
|
||||
except (KeyError, IndexError, TypeError) as e:
|
||||
print(f"Warning: Skipping invalid conversation data: {e}")
|
||||
continue
|
||||
|
||||
if not texts:
|
||||
raise ValueError("No valid texts found in the dataset")
|
||||
|
||||
print(f"Loading tokenizer for {DENSE_MODEL_NAME}...")
|
||||
tokenizer = AutoTokenizer.from_pretrained(DENSE_MODEL_NAME, use_fast=True)
|
||||
|
||||
rng = np.random.default_rng(42)
|
||||
|
||||
print(f"Launching SGLang Engine with {DENSE_MODEL_NAME}...")
|
||||
engine = sgl.Engine(
|
||||
model_path=DENSE_MODEL_NAME,
|
||||
attention_backend="flashinfer",
|
||||
enable_deterministic_inference=True,
|
||||
random_seed=42,
|
||||
skip_tokenizer_init=True,
|
||||
mem_fraction_static=0.8,
|
||||
max_running_requests=1,
|
||||
)
|
||||
|
||||
records = []
|
||||
prompt_lengths = []
|
||||
|
||||
try:
|
||||
for i, text in enumerate(texts):
|
||||
if len(records) >= num_samples:
|
||||
break
|
||||
|
||||
try:
|
||||
ids = tokenizer.encode(text, add_special_tokens=False)
|
||||
if len(ids) < 5:
|
||||
continue
|
||||
|
||||
start_pos = int(rng.integers(0, max(1, len(ids) - 3)))
|
||||
|
||||
outputs = engine.generate(
|
||||
input_ids=[ids],
|
||||
sampling_params={
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
"top_k": TOP_K,
|
||||
"max_new_tokens": 1,
|
||||
},
|
||||
return_logprob=True,
|
||||
logprob_start_len=start_pos,
|
||||
top_logprobs_num=TOP_K,
|
||||
)
|
||||
meta = outputs[0]["meta_info"]
|
||||
|
||||
records.append(
|
||||
dict(id=i, text=text, ids=ids, start_pos=start_pos, meta=meta)
|
||||
)
|
||||
prompt_lengths.append(len(ids))
|
||||
|
||||
if (i + 1) % 50 == 0:
|
||||
print(f"Processed {len(records)}/{num_samples} samples")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to process sample {i}: {e}")
|
||||
continue
|
||||
|
||||
if not records:
|
||||
raise RuntimeError(
|
||||
"Failed to generate any baseline records. Please check the warnings above for errors."
|
||||
)
|
||||
|
||||
# Save baseline files
|
||||
with open(baseline_file, "wb") as f:
|
||||
pickle.dump(records, f)
|
||||
with open(meta_file, "w", encoding="utf-8") as f:
|
||||
json.dump(records[:2], f, ensure_ascii=False, indent=2)
|
||||
|
||||
print(f"✅ Saved {len(records)} samples to {baseline_file}")
|
||||
print(f"✅ Meta preview saved to {meta_file}")
|
||||
|
||||
if prompt_lengths:
|
||||
avg_prompt_length = sum(prompt_lengths) / len(prompt_lengths)
|
||||
print(f"📊 Average prompt length: {avg_prompt_length:.2f} tokens")
|
||||
|
||||
finally:
|
||||
engine.shutdown()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
class TestLogprobsDense(unittest.TestCase):
|
||||
@@ -48,6 +216,8 @@ class TestLogprobsDense(unittest.TestCase):
|
||||
cls.engine = sgl.Engine(
|
||||
model_path=DENSE_MODEL_NAME,
|
||||
random_seed=42,
|
||||
attention_backend="flashinfer",
|
||||
enable_deterministic_inference=True,
|
||||
skip_tokenizer_init=True,
|
||||
mem_fraction_static=0.80,
|
||||
)
|
||||
@@ -58,31 +228,24 @@ class TestLogprobsDense(unittest.TestCase):
|
||||
cls.engine.shutdown()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def load_test_data(self):
|
||||
"""Load test data from Hugging Face dataset with retry mechanism."""
|
||||
print(f"Loading data from {DENSE_INPUT_PKL_URL}...")
|
||||
def load_test_data(self, baseline_file=None):
|
||||
"""Load test data from local baseline file. In test mode, only local baseline is supported."""
|
||||
if not baseline_file:
|
||||
raise ValueError("baseline_file is required in test mode")
|
||||
|
||||
for attempt in range(MAX_RETRIES):
|
||||
try:
|
||||
response = requests.get(DENSE_INPUT_PKL_URL, timeout=30)
|
||||
response.raise_for_status()
|
||||
if not os.path.exists(baseline_file):
|
||||
raise FileNotFoundError(
|
||||
f"Baseline file not found: {baseline_file}. Please run 'gen' mode first to generate the baseline."
|
||||
)
|
||||
|
||||
with io.BytesIO(response.content) as f:
|
||||
records = pickle.load(f)
|
||||
|
||||
if not records:
|
||||
raise ValueError("Empty dataset")
|
||||
|
||||
print(f"Successfully loaded {len(records)} records")
|
||||
return records
|
||||
|
||||
except Exception as e:
|
||||
print(f"Attempt {attempt + 1}/{MAX_RETRIES} failed: {e}")
|
||||
if attempt == MAX_RETRIES - 1:
|
||||
raise Exception(
|
||||
f"Failed to load data after {MAX_RETRIES} attempts: {e}"
|
||||
)
|
||||
time.sleep(RETRY_DELAY)
|
||||
print(f"Loading local baseline from {baseline_file}...")
|
||||
try:
|
||||
with open(baseline_file, "rb") as f:
|
||||
records = pickle.load(f)
|
||||
print(f"Successfully loaded {len(records)} records from local baseline")
|
||||
return records
|
||||
except (IOError, pickle.PickleError) as e:
|
||||
raise Exception(f"Failed to load local baseline: {e}") from e
|
||||
|
||||
def compare_meta(self, baseline_meta, sglang_meta):
|
||||
"""Compare metadata between two outputs and return max and mean differences."""
|
||||
@@ -102,19 +265,21 @@ class TestLogprobsDense(unittest.TestCase):
|
||||
common_tokens = baseline_token_map.keys() & sglang_token_map.keys()
|
||||
self.assertGreaterEqual(
|
||||
len(common_tokens),
|
||||
TOP_K / 2,
|
||||
TOP_K,
|
||||
f"there are only {len(common_tokens)} common topk tokens that matches",
|
||||
)
|
||||
for token_id in common_tokens:
|
||||
diffs.append(
|
||||
abs(baseline_token_map[token_id] - sglang_token_map[token_id])
|
||||
)
|
||||
if not diffs:
|
||||
return 0.0, 0.0
|
||||
return max(diffs), float(np.mean(diffs))
|
||||
|
||||
def test_logprobs_comparison(self):
|
||||
def test_logprobs_comparison(self, baseline_file=None):
|
||||
"""Test the logprobs comparison functionality with different parameter combinations."""
|
||||
# Load test data with retry mechanism
|
||||
records = self.load_test_data()
|
||||
records = self.load_test_data(baseline_file)
|
||||
|
||||
with self.subTest(
|
||||
config={
|
||||
@@ -224,15 +389,6 @@ class TestLogprobsDense(unittest.TestCase):
|
||||
f"Expected {logprob_count} samples with logprobs, got {logprob_returned_count}",
|
||||
)
|
||||
|
||||
# Write results to GitHub summary
|
||||
summary_content = f"""
|
||||
- **Configuration**: {{"num_samples": {NUM_SAMPLES}, "logprob_sample_ratio": {LOGPROB_SAMPLE_RATIO}, "temperature": {TEMPERATURE}}}
|
||||
- **Max of max Δ**: {max_of_max:.6g}
|
||||
- **Mean of mean Δ**: {mean_of_mean:.6g}
|
||||
- **Status**: {'✅ Passed' if max_of_max <= DENSE_TOLERANCE_MAX_DIFF and mean_of_mean <= DENSE_TOLERANCE_MEAN_DIFF else '❌ Failed'}
|
||||
"""
|
||||
write_github_step_summary(summary_content)
|
||||
|
||||
# Basic validation
|
||||
self.assertIsInstance(all_max, list)
|
||||
self.assertIsInstance(all_mean, list)
|
||||
@@ -261,5 +417,52 @@ class TestLogprobsDense(unittest.TestCase):
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function to handle command line arguments and run either generation or testing."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="SGLang Logprobs Test and Baseline Generation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"mode",
|
||||
choices=["gen", "test"],
|
||||
help="Mode to run: 'gen' to generate baseline, 'test' to run tests",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.mode == "gen":
|
||||
print("🚀 Generating baseline...")
|
||||
generate_baseline()
|
||||
print(f"\n✅ Baseline generation complete!")
|
||||
print(f"📁 Baseline saved to: {DEFAULT_BASELINE_PKL}")
|
||||
print(f"📁 Metadata preview saved to: {DEFAULT_META_JSON}")
|
||||
print(f"\n💡 Next steps:")
|
||||
print(f" 1. Make your code changes")
|
||||
print(f" 2. Run: python {__file__} test")
|
||||
|
||||
elif args.mode == "test":
|
||||
print("🧪 Running logprobs test...")
|
||||
if not os.path.exists(DEFAULT_BASELINE_PKL):
|
||||
print(f"❌ Baseline file not found: {DEFAULT_BASELINE_PKL}")
|
||||
print(f"💡 Generate baseline first by running:")
|
||||
print(f" python {__file__} gen")
|
||||
print(f" This will download ShareGPT data and generate a local baseline.")
|
||||
return 1
|
||||
|
||||
# Set environment variable for testing
|
||||
os.environ["RETURN_ORIGINAL_LOGPROB"] = "True"
|
||||
|
||||
# Create test instance and run
|
||||
test_instance = TestLogprobsDense()
|
||||
test_instance.setUpClass()
|
||||
try:
|
||||
test_instance.test_logprobs_comparison(baseline_file=DEFAULT_BASELINE_PKL)
|
||||
print("\n✅ Test completed successfully!")
|
||||
finally:
|
||||
test_instance.tearDownClass()
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
exit(main())
|
||||
|
||||
Reference in New Issue
Block a user