[logprobs] Enable local deterministic logrprobs testing with strict threshold (#10994)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -86,7 +86,6 @@ suites = {
|
|||||||
TestFile("test_input_embeddings.py", 38),
|
TestFile("test_input_embeddings.py", 38),
|
||||||
TestFile("test_io_struct.py", 8),
|
TestFile("test_io_struct.py", 8),
|
||||||
TestFile("test_jinja_template_utils.py", 1),
|
TestFile("test_jinja_template_utils.py", 1),
|
||||||
TestFile("test_logprobs.py", 55),
|
|
||||||
TestFile("test_mamba_unittest.py", 4),
|
TestFile("test_mamba_unittest.py", 4),
|
||||||
TestFile("test_metrics.py", 32),
|
TestFile("test_metrics.py", 32),
|
||||||
TestFile("test_metrics_utils.py", 1),
|
TestFile("test_metrics_utils.py", 1),
|
||||||
|
|||||||
@@ -1,42 +1,210 @@
|
|||||||
import io
|
"""
|
||||||
|
Logprobs Accuracy Test for SGLang
|
||||||
|
|
||||||
|
======================
|
||||||
|
With deterministic/batch invariant kernels, we can ensure that SGLang produces exactly the same
|
||||||
|
logprobs results for identical inputs. However, logprobs are highly sensitive to GPU hardware,
|
||||||
|
kernels, torch versions, and other factors, so we cannot maintain a unified logprobs baseline
|
||||||
|
across different machines.
|
||||||
|
|
||||||
|
This test is designed to be run locally by contributors to verify logprobs accuracy
|
||||||
|
before making changes to related code.
|
||||||
|
When submitting changes that affect logprobs computation, please:
|
||||||
|
1. Generate baseline
|
||||||
|
2. Run test
|
||||||
|
3. Submit results
|
||||||
|
|
||||||
|
We really appreciate your effort and contribution to SGLang!
|
||||||
|
|
||||||
|
======================
|
||||||
|
What does this test do?
|
||||||
|
This test fetches 1000 samples from the ShareGPT dataset, generates logprobs for each sample,
|
||||||
|
and saves them as a baseline. Then, by running the test mode, it validates the accuracy of
|
||||||
|
logprobs by comparing them against the baseline.
|
||||||
|
|
||||||
|
This test ensures that:
|
||||||
|
- the boundary of log probs requests are correct, eg, the index for tokens that required log probs are strictly followed
|
||||||
|
- logprobs remain invariant between test runs, and also before and after your code changes;
|
||||||
|
|
||||||
|
======================
|
||||||
|
Usage
|
||||||
|
|
||||||
|
Step 1: Generate Baseline (Before Code Changes)
|
||||||
|
```bash
|
||||||
|
python test/srt/test_logprobs.py gen
|
||||||
|
```
|
||||||
|
|
||||||
|
Step 2: Test Against Baseline (After Code Changes)
|
||||||
|
```bash
|
||||||
|
python test/srt/test_logprobs.py test
|
||||||
|
```
|
||||||
|
This tests your changes against the locally generated baseline from Step 1.
|
||||||
|
The test passes if the maximum and mean differences are within the tolerance thresholds.
|
||||||
|
======================
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import random
|
import random
|
||||||
import time
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
import sglang as sgl
|
import sglang as sgl
|
||||||
from sglang.test.test_utils import (
|
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
|
||||||
write_github_step_summary,
|
# Configuration
|
||||||
|
DENSE_MODEL_NAME = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
|
SHAREGPT_URL = (
|
||||||
|
"https://huggingface.co/datasets/anon8231489123/"
|
||||||
|
"ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Dense model configuration
|
# Hardware-specific configuration
|
||||||
DENSE_MODEL_NAME = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
if torch.version.cuda is not None:
|
||||||
if torch.version.hip is not None:
|
|
||||||
print("Running on AMD ROCm GPU")
|
|
||||||
DENSE_INPUT_PKL_URL = "https://huggingface.co/datasets/yushengsu/logprobs/resolve/main/sglang_baseline_2000_amd.pkl"
|
|
||||||
DENSE_TOLERANCE_MAX_DIFF = 1.4
|
|
||||||
DENSE_TOLERANCE_MEAN_DIFF = 0.1
|
|
||||||
elif torch.version.cuda is not None:
|
|
||||||
print("Running on NVIDIA CUDA GPU")
|
print("Running on NVIDIA CUDA GPU")
|
||||||
DENSE_INPUT_PKL_URL = "https://huggingface.co/datasets/font-info/logprobs/resolve/main/sglang_baseline_2000.pkl"
|
DENSE_TOLERANCE_MAX_DIFF = 1e-5
|
||||||
DENSE_TOLERANCE_MAX_DIFF = 1.5
|
DENSE_TOLERANCE_MEAN_DIFF = 1e-5
|
||||||
DENSE_TOLERANCE_MEAN_DIFF = 0.1
|
|
||||||
else:
|
else:
|
||||||
print("No GPU backend (CPU only)")
|
print("No GPU backend (CPU only)")
|
||||||
|
raise ValueError("No GPU backend (CPU only)")
|
||||||
|
|
||||||
# Common configuration
|
# Common configuration
|
||||||
TOP_K = 20
|
TOP_K = 20
|
||||||
MAX_RETRIES = 3
|
|
||||||
RETRY_DELAY = 2
|
|
||||||
NUM_SAMPLES = 1000
|
NUM_SAMPLES = 1000
|
||||||
LOGPROB_SAMPLE_RATIO = 0.5
|
LOGPROB_SAMPLE_RATIO = 0.5
|
||||||
TEMPERATURE = 1.0
|
TEMPERATURE = 1.0
|
||||||
|
MAX_LEN = 20000
|
||||||
|
|
||||||
|
# Default output files
|
||||||
|
DEFAULT_BASELINE_PKL = "sglang_baseline_local.pkl"
|
||||||
|
DEFAULT_META_JSON = "baseline_meta_preview.json"
|
||||||
|
|
||||||
|
|
||||||
|
def generate_baseline(
|
||||||
|
baseline_file=DEFAULT_BASELINE_PKL,
|
||||||
|
meta_file=DEFAULT_META_JSON,
|
||||||
|
num_samples=NUM_SAMPLES,
|
||||||
|
):
|
||||||
|
"""Generate a local baseline for logprobs testing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
baseline_file: Path to save the baseline pickle file
|
||||||
|
meta_file: Path to save the metadata preview JSON file
|
||||||
|
num_samples: Number of samples to generate
|
||||||
|
"""
|
||||||
|
print(f"SGLang version: {sgl.__version__}")
|
||||||
|
print("Downloading ShareGPT dataset...")
|
||||||
|
|
||||||
|
# Download ShareGPT dataset
|
||||||
|
try:
|
||||||
|
response = requests.get(SHAREGPT_URL, timeout=30)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
print(f"Dataset size: {len(data)}")
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
raise Exception(f"Failed to download ShareGPT dataset: {e}") from e
|
||||||
|
|
||||||
|
# Filter and prepare texts
|
||||||
|
texts = []
|
||||||
|
for s in data:
|
||||||
|
if "conversations" in s and len(s["conversations"]) > 0:
|
||||||
|
try:
|
||||||
|
text = s["conversations"][0]["value"]
|
||||||
|
if isinstance(text, str) and len(text) <= MAX_LEN and len(text) >= 5500:
|
||||||
|
texts.append(text)
|
||||||
|
if len(texts) >= num_samples * 40: # Get more samples for filtering
|
||||||
|
break
|
||||||
|
except (KeyError, IndexError, TypeError) as e:
|
||||||
|
print(f"Warning: Skipping invalid conversation data: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not texts:
|
||||||
|
raise ValueError("No valid texts found in the dataset")
|
||||||
|
|
||||||
|
print(f"Loading tokenizer for {DENSE_MODEL_NAME}...")
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(DENSE_MODEL_NAME, use_fast=True)
|
||||||
|
|
||||||
|
rng = np.random.default_rng(42)
|
||||||
|
|
||||||
|
print(f"Launching SGLang Engine with {DENSE_MODEL_NAME}...")
|
||||||
|
engine = sgl.Engine(
|
||||||
|
model_path=DENSE_MODEL_NAME,
|
||||||
|
attention_backend="flashinfer",
|
||||||
|
enable_deterministic_inference=True,
|
||||||
|
random_seed=42,
|
||||||
|
skip_tokenizer_init=True,
|
||||||
|
mem_fraction_static=0.8,
|
||||||
|
max_running_requests=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
records = []
|
||||||
|
prompt_lengths = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
for i, text in enumerate(texts):
|
||||||
|
if len(records) >= num_samples:
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
ids = tokenizer.encode(text, add_special_tokens=False)
|
||||||
|
if len(ids) < 5:
|
||||||
|
continue
|
||||||
|
|
||||||
|
start_pos = int(rng.integers(0, max(1, len(ids) - 3)))
|
||||||
|
|
||||||
|
outputs = engine.generate(
|
||||||
|
input_ids=[ids],
|
||||||
|
sampling_params={
|
||||||
|
"temperature": 1.0,
|
||||||
|
"top_p": 1.0,
|
||||||
|
"top_k": TOP_K,
|
||||||
|
"max_new_tokens": 1,
|
||||||
|
},
|
||||||
|
return_logprob=True,
|
||||||
|
logprob_start_len=start_pos,
|
||||||
|
top_logprobs_num=TOP_K,
|
||||||
|
)
|
||||||
|
meta = outputs[0]["meta_info"]
|
||||||
|
|
||||||
|
records.append(
|
||||||
|
dict(id=i, text=text, ids=ids, start_pos=start_pos, meta=meta)
|
||||||
|
)
|
||||||
|
prompt_lengths.append(len(ids))
|
||||||
|
|
||||||
|
if (i + 1) % 50 == 0:
|
||||||
|
print(f"Processed {len(records)}/{num_samples} samples")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Failed to process sample {i}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not records:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Failed to generate any baseline records. Please check the warnings above for errors."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save baseline files
|
||||||
|
with open(baseline_file, "wb") as f:
|
||||||
|
pickle.dump(records, f)
|
||||||
|
with open(meta_file, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(records[:2], f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
print(f"✅ Saved {len(records)} samples to {baseline_file}")
|
||||||
|
print(f"✅ Meta preview saved to {meta_file}")
|
||||||
|
|
||||||
|
if prompt_lengths:
|
||||||
|
avg_prompt_length = sum(prompt_lengths) / len(prompt_lengths)
|
||||||
|
print(f"📊 Average prompt length: {avg_prompt_length:.2f} tokens")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
engine.shutdown()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
class TestLogprobsDense(unittest.TestCase):
|
class TestLogprobsDense(unittest.TestCase):
|
||||||
@@ -48,6 +216,8 @@ class TestLogprobsDense(unittest.TestCase):
|
|||||||
cls.engine = sgl.Engine(
|
cls.engine = sgl.Engine(
|
||||||
model_path=DENSE_MODEL_NAME,
|
model_path=DENSE_MODEL_NAME,
|
||||||
random_seed=42,
|
random_seed=42,
|
||||||
|
attention_backend="flashinfer",
|
||||||
|
enable_deterministic_inference=True,
|
||||||
skip_tokenizer_init=True,
|
skip_tokenizer_init=True,
|
||||||
mem_fraction_static=0.80,
|
mem_fraction_static=0.80,
|
||||||
)
|
)
|
||||||
@@ -58,31 +228,24 @@ class TestLogprobsDense(unittest.TestCase):
|
|||||||
cls.engine.shutdown()
|
cls.engine.shutdown()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def load_test_data(self):
|
def load_test_data(self, baseline_file=None):
|
||||||
"""Load test data from Hugging Face dataset with retry mechanism."""
|
"""Load test data from local baseline file. In test mode, only local baseline is supported."""
|
||||||
print(f"Loading data from {DENSE_INPUT_PKL_URL}...")
|
if not baseline_file:
|
||||||
|
raise ValueError("baseline_file is required in test mode")
|
||||||
|
|
||||||
for attempt in range(MAX_RETRIES):
|
if not os.path.exists(baseline_file):
|
||||||
try:
|
raise FileNotFoundError(
|
||||||
response = requests.get(DENSE_INPUT_PKL_URL, timeout=30)
|
f"Baseline file not found: {baseline_file}. Please run 'gen' mode first to generate the baseline."
|
||||||
response.raise_for_status()
|
)
|
||||||
|
|
||||||
with io.BytesIO(response.content) as f:
|
print(f"Loading local baseline from {baseline_file}...")
|
||||||
records = pickle.load(f)
|
try:
|
||||||
|
with open(baseline_file, "rb") as f:
|
||||||
if not records:
|
records = pickle.load(f)
|
||||||
raise ValueError("Empty dataset")
|
print(f"Successfully loaded {len(records)} records from local baseline")
|
||||||
|
return records
|
||||||
print(f"Successfully loaded {len(records)} records")
|
except (IOError, pickle.PickleError) as e:
|
||||||
return records
|
raise Exception(f"Failed to load local baseline: {e}") from e
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Attempt {attempt + 1}/{MAX_RETRIES} failed: {e}")
|
|
||||||
if attempt == MAX_RETRIES - 1:
|
|
||||||
raise Exception(
|
|
||||||
f"Failed to load data after {MAX_RETRIES} attempts: {e}"
|
|
||||||
)
|
|
||||||
time.sleep(RETRY_DELAY)
|
|
||||||
|
|
||||||
def compare_meta(self, baseline_meta, sglang_meta):
|
def compare_meta(self, baseline_meta, sglang_meta):
|
||||||
"""Compare metadata between two outputs and return max and mean differences."""
|
"""Compare metadata between two outputs and return max and mean differences."""
|
||||||
@@ -102,19 +265,21 @@ class TestLogprobsDense(unittest.TestCase):
|
|||||||
common_tokens = baseline_token_map.keys() & sglang_token_map.keys()
|
common_tokens = baseline_token_map.keys() & sglang_token_map.keys()
|
||||||
self.assertGreaterEqual(
|
self.assertGreaterEqual(
|
||||||
len(common_tokens),
|
len(common_tokens),
|
||||||
TOP_K / 2,
|
TOP_K,
|
||||||
f"there are only {len(common_tokens)} common topk tokens that matches",
|
f"there are only {len(common_tokens)} common topk tokens that matches",
|
||||||
)
|
)
|
||||||
for token_id in common_tokens:
|
for token_id in common_tokens:
|
||||||
diffs.append(
|
diffs.append(
|
||||||
abs(baseline_token_map[token_id] - sglang_token_map[token_id])
|
abs(baseline_token_map[token_id] - sglang_token_map[token_id])
|
||||||
)
|
)
|
||||||
|
if not diffs:
|
||||||
|
return 0.0, 0.0
|
||||||
return max(diffs), float(np.mean(diffs))
|
return max(diffs), float(np.mean(diffs))
|
||||||
|
|
||||||
def test_logprobs_comparison(self):
|
def test_logprobs_comparison(self, baseline_file=None):
|
||||||
"""Test the logprobs comparison functionality with different parameter combinations."""
|
"""Test the logprobs comparison functionality with different parameter combinations."""
|
||||||
# Load test data with retry mechanism
|
# Load test data with retry mechanism
|
||||||
records = self.load_test_data()
|
records = self.load_test_data(baseline_file)
|
||||||
|
|
||||||
with self.subTest(
|
with self.subTest(
|
||||||
config={
|
config={
|
||||||
@@ -224,15 +389,6 @@ class TestLogprobsDense(unittest.TestCase):
|
|||||||
f"Expected {logprob_count} samples with logprobs, got {logprob_returned_count}",
|
f"Expected {logprob_count} samples with logprobs, got {logprob_returned_count}",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Write results to GitHub summary
|
|
||||||
summary_content = f"""
|
|
||||||
- **Configuration**: {{"num_samples": {NUM_SAMPLES}, "logprob_sample_ratio": {LOGPROB_SAMPLE_RATIO}, "temperature": {TEMPERATURE}}}
|
|
||||||
- **Max of max Δ**: {max_of_max:.6g}
|
|
||||||
- **Mean of mean Δ**: {mean_of_mean:.6g}
|
|
||||||
- **Status**: {'✅ Passed' if max_of_max <= DENSE_TOLERANCE_MAX_DIFF and mean_of_mean <= DENSE_TOLERANCE_MEAN_DIFF else '❌ Failed'}
|
|
||||||
"""
|
|
||||||
write_github_step_summary(summary_content)
|
|
||||||
|
|
||||||
# Basic validation
|
# Basic validation
|
||||||
self.assertIsInstance(all_max, list)
|
self.assertIsInstance(all_max, list)
|
||||||
self.assertIsInstance(all_mean, list)
|
self.assertIsInstance(all_mean, list)
|
||||||
@@ -261,5 +417,52 @@ class TestLogprobsDense(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main function to handle command line arguments and run either generation or testing."""
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="SGLang Logprobs Test and Baseline Generation"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"mode",
|
||||||
|
choices=["gen", "test"],
|
||||||
|
help="Mode to run: 'gen' to generate baseline, 'test' to run tests",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.mode == "gen":
|
||||||
|
print("🚀 Generating baseline...")
|
||||||
|
generate_baseline()
|
||||||
|
print(f"\n✅ Baseline generation complete!")
|
||||||
|
print(f"📁 Baseline saved to: {DEFAULT_BASELINE_PKL}")
|
||||||
|
print(f"📁 Metadata preview saved to: {DEFAULT_META_JSON}")
|
||||||
|
print(f"\n💡 Next steps:")
|
||||||
|
print(f" 1. Make your code changes")
|
||||||
|
print(f" 2. Run: python {__file__} test")
|
||||||
|
|
||||||
|
elif args.mode == "test":
|
||||||
|
print("🧪 Running logprobs test...")
|
||||||
|
if not os.path.exists(DEFAULT_BASELINE_PKL):
|
||||||
|
print(f"❌ Baseline file not found: {DEFAULT_BASELINE_PKL}")
|
||||||
|
print(f"💡 Generate baseline first by running:")
|
||||||
|
print(f" python {__file__} gen")
|
||||||
|
print(f" This will download ShareGPT data and generate a local baseline.")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
# Set environment variable for testing
|
||||||
|
os.environ["RETURN_ORIGINAL_LOGPROB"] = "True"
|
||||||
|
|
||||||
|
# Create test instance and run
|
||||||
|
test_instance = TestLogprobsDense()
|
||||||
|
test_instance.setUpClass()
|
||||||
|
try:
|
||||||
|
test_instance.test_logprobs_comparison(baseline_file=DEFAULT_BASELINE_PKL)
|
||||||
|
print("\n✅ Test completed successfully!")
|
||||||
|
finally:
|
||||||
|
test_instance.tearDownClass()
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
exit(main())
|
||||||
|
|||||||
Reference in New Issue
Block a user