Add Logprobs unit test with a loose threshold (#10230)
Co-authored-by: Yusheng Su <yushengsu.thu@gmail.com>
Co-authored-by: Chayenne <zhaochen20@outlook.com>
Co-authored-by: Ryan <ryan@ryanmini.mynetworksettings.com>
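Assuming SGLang's test dependencies and a supported GPU backend are available, the new test file defines its own unittest entry point and should be runnable directly, e.g.:

    python3 test/srt/test_logprobs.py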
@@ -81,6 +81,7 @@ suites = {
         TestFile("test_input_embeddings.py", 38),
         TestFile("test_io_struct.py", 8),
         TestFile("test_jinja_template_utils.py", 1),
+        TestFile("test_logprobs.py", 55),
         TestFile("test_metrics.py", 32),
         TestFile("test_metrics_utils.py", 1),
         TestFile("test_mla.py", 167),
test/srt/test_logprobs.py (new file, 265 lines)
@@ -0,0 +1,265 @@
import io
import os
import pickle
import random
import time
import unittest

import numpy as np
import requests
import torch

import sglang as sgl
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    write_github_step_summary,
)

# Dense model configuration
DENSE_MODEL_NAME = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
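# Baselines and tolerances are backend-specific: the AMD and NVIDIA baseline pickles
# were recorded separately, and the thresholds below are intentionally loose so that
# small numerical differences between kernels/backends do not cause flaky failures.
# Note: on a CPU-only host no baseline URL or tolerance is defined, so this test
# effectively requires a GPU backend.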
if torch.version.hip is not None:
    print("Running on AMD ROCm GPU")
    DENSE_INPUT_PKL_URL = "https://huggingface.co/datasets/yushengsu/logprobs/resolve/main/sglang_baseline_2000_amd.pkl"
    DENSE_TOLERANCE_MAX_DIFF = 1.4
    DENSE_TOLERANCE_MEAN_DIFF = 0.1
elif torch.version.cuda is not None:
    print("Running on NVIDIA CUDA GPU")
    DENSE_INPUT_PKL_URL = "https://huggingface.co/datasets/font-info/logprobs/resolve/main/sglang_baseline_2000.pkl"
    DENSE_TOLERANCE_MAX_DIFF = 1.5
    DENSE_TOLERANCE_MEAN_DIFF = 0.1
else:
    print("No GPU backend (CPU only)")

# Common configuration
TOP_K = 20
MAX_RETRIES = 3
RETRY_DELAY = 2
NUM_SAMPLES = 1000
LOGPROB_SAMPLE_RATIO = 0.5
TEMPERATURE = 1.0

class TestLogprobsDense(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        """Set up the test class - initialize the engine once for all tests."""
        print(f"Launching SGLang Engine with {DENSE_MODEL_NAME}...")
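        # skip_tokenizer_init=True makes the engine accept and return raw token IDs,
        # which is why the test below feeds the pre-tokenized "ids" from the baseline
        # pickle directly via input_ids; random_seed is fixed for reproducibility.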
        cls.engine = sgl.Engine(
            model_path=DENSE_MODEL_NAME,
            random_seed=42,
            skip_tokenizer_init=True,
            mem_fraction_static=0.85,
        )

    @classmethod
    def tearDownClass(cls):
        """Clean up after all tests - shutdown the engine."""
        cls.engine.shutdown()
        torch.cuda.empty_cache()
    def load_test_data(self):
        """Load test data from Hugging Face dataset with retry mechanism."""
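        # Each record in the pickle is expected to hold the fields used below:
        #   "ids"       - prompt token IDs fed to the engine,
        #   "start_pos" - logprob_start_len for that prompt,
        #   "meta"      - baseline meta_info with input_top_logprobs / output_top_logprobs.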
print(f"Loading data from {DENSE_INPUT_PKL_URL}...")
|
||||
|
||||
for attempt in range(MAX_RETRIES):
|
||||
try:
|
||||
response = requests.get(DENSE_INPUT_PKL_URL, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
with io.BytesIO(response.content) as f:
|
||||
records = pickle.load(f)
|
||||
|
||||
if not records:
|
||||
raise ValueError("Empty dataset")
|
||||
|
||||
print(f"Successfully loaded {len(records)} records")
|
||||
return records
|
||||
|
||||
except Exception as e:
|
||||
print(f"Attempt {attempt + 1}/{MAX_RETRIES} failed: {e}")
|
||||
if attempt == MAX_RETRIES - 1:
|
||||
raise Exception(
|
||||
f"Failed to load data after {MAX_RETRIES} attempts: {e}"
|
||||
)
|
||||
time.sleep(RETRY_DELAY)
|
||||
|
||||
    def compare_meta(self, baseline_meta, sglang_meta):
        """Compare metadata between two outputs and return max and mean differences."""
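        # Each element of input_top_logprobs / output_top_logprobs is a per-position
        # list of (logprob, token_id, token_text) entries for the top-k candidates.
        # Only token IDs present in both the baseline and the sglang top-k lists are
        # compared, and at least half of the top-k entries must overlap.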
        diffs = []
        for key in ["input_top_logprobs", "output_top_logprobs"]:
            baseline_logprobs, sglang_logprobs = baseline_meta[key], sglang_meta[key]
            self.assertEqual(
                len(baseline_logprobs),
                len(sglang_logprobs),
                f"Length of {key} is not equal; sglang did not return the correct number of logprob entries (should be top {TOP_K})",
            )
            for baseline_entry, sglang_entry in zip(baseline_logprobs, sglang_logprobs):
                if not baseline_entry or not sglang_entry:
                    continue
                baseline_token_map = {tid: lp for lp, tid, _ in baseline_entry}
                sglang_token_map = {tid: lp for lp, tid, _ in sglang_entry}
                common_tokens = baseline_token_map.keys() & sglang_token_map.keys()
                self.assertGreaterEqual(
                    len(common_tokens),
                    TOP_K / 2,
                    f"only {len(common_tokens)} of the top-{TOP_K} tokens are common to baseline and sglang",
                )
                for token_id in common_tokens:
                    diffs.append(
                        abs(baseline_token_map[token_id] - sglang_token_map[token_id])
                    )
        return max(diffs), float(np.mean(diffs))
    def test_logprobs_comparison(self):
        """Compare the engine's top-k logprobs against the recorded baseline for the configured sampling setup."""
        # Load test data with retry mechanism
        records = self.load_test_data()
        with self.subTest(
            config={
                "num_samples": NUM_SAMPLES,
                "logprob_sample_ratio": LOGPROB_SAMPLE_RATIO,
                "temperature": TEMPERATURE,
            }
        ):

            # Sample records for this config
            test_records = random.sample(records, k=min(NUM_SAMPLES, len(records)))
            random.shuffle(test_records)

            # Calculate how many samples should return logprobs
            logprob_count = int(len(test_records) * LOGPROB_SAMPLE_RATIO)
            print(
                f"Testing with {len(test_records)} samples, temperature={TEMPERATURE}"
            )
            print(
                f"Will return logprobs for {logprob_count} samples (ratio: {LOGPROB_SAMPLE_RATIO})"
            )

            all_max, all_mean = [], []
            logprob_returned_count = 0

            # Process all records at once
            input_ids = [rec["ids"] for rec in test_records]
            logprob_start_lens = [rec["start_pos"] for rec in test_records]

            # Determine which samples should return logprobs (randomly selected)
            logprob_indices = set(
                random.sample(range(len(test_records)), logprob_count)
            )
            return_logprob_array = [
                sample_idx in logprob_indices for sample_idx in range(len(test_records))
            ]

            # Sampling param per request
            sampling_params = [
                {
                    "temperature": TEMPERATURE,
                    "top_p": 1.0,
                    "top_k": TOP_K,
                    "max_new_tokens": 1,
                }
                for _ in test_records
            ]
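            # Single batched generate call: input_ids, sampling_params, return_logprob,
            # and logprob_start_len are parallel per-request lists, so only the randomly
            # chosen subset of the batch asks for top-k logprobs.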
            outputs = self.engine.generate(
                input_ids=input_ids,
                sampling_params=sampling_params,
                return_logprob=return_logprob_array,
                logprob_start_len=logprob_start_lens,
                top_logprobs_num=TOP_K,
            )
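            # Each output is a dict whose "meta_info" carries input_top_logprobs /
            # output_top_logprobs only for the requests that asked for them; the
            # remaining requests are expected to come back without logprob fields.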
            for sample_idx, (rec, output) in enumerate(zip(test_records, outputs)):
                # Only compare logprobs for samples that should have them
                if sample_idx in logprob_indices:
                    # Safe access to meta_info and input_top_logprobs
                    meta_info = output.get("meta_info")
                    input_top_logprobs = (
                        meta_info.get("input_top_logprobs") if meta_info else None
                    )

                    self.assertIsNotNone(
                        input_top_logprobs,
                        f"return_logprob was enabled for sample {sample_idx}, but input_top_logprobs is None",
                    )
                    baseline_meta = rec["meta"]
                    sglang_meta = meta_info

                    max_diff, mean_diff = self.compare_meta(baseline_meta, sglang_meta)
                    all_max.append(max_diff)
                    all_mean.append(mean_diff)
                    logprob_returned_count += 1
                else:
                    # Verify that logprobs were not returned for this sample
                    meta_info = output.get("meta_info")
                    input_top_logprobs = (
                        meta_info.get("input_top_logprobs") if meta_info else None
                    )
                    output_token_ids_logprobs = (
                        meta_info.get("output_token_ids_logprobs")
                        if meta_info
                        else None
                    )

                    self.assertFalse(
                        input_top_logprobs,
                        f"return_logprob was disabled for sample {sample_idx}, so it should not have logprobs; output_token_ids_logprobs: {output_token_ids_logprobs}",
                    )
            max_of_max = max(all_max) if all_max else 0.0
            mean_of_mean = np.mean(all_mean) if all_mean else 0.0

            print(f"max Δ={max_of_max:.6g}")
            print(f"mean Δ={mean_of_mean:.6g}")
            print(
                f"logprobs returned for {logprob_returned_count} samples (expected: {logprob_count})"
            )

            # Verify correct number of logprobs returned
            self.assertEqual(
                logprob_returned_count,
                logprob_count,
                f"Expected {logprob_count} samples with logprobs, got {logprob_returned_count}",
            )

            # Write results to GitHub summary
            summary_content = f"""
- **Configuration**: {{"num_samples": {NUM_SAMPLES}, "logprob_sample_ratio": {LOGPROB_SAMPLE_RATIO}, "temperature": {TEMPERATURE}}}
- **Max of max Δ**: {max_of_max:.6g}
- **Mean of mean Δ**: {mean_of_mean:.6g}
- **Status**: {'✅ Passed' if max_of_max <= DENSE_TOLERANCE_MAX_DIFF and mean_of_mean <= DENSE_TOLERANCE_MEAN_DIFF else '❌ Failed'}
"""
            write_github_step_summary(summary_content)
            # Basic validation
            self.assertIsInstance(all_max, list)
            self.assertIsInstance(all_mean, list)
            self.assertGreater(
                len(all_max),
                0,
                f"No test samples processed for config {{'num_samples': {NUM_SAMPLES}, 'logprob_sample_ratio': {LOGPROB_SAMPLE_RATIO}, 'temperature': {TEMPERATURE}}}",
            )

            # Tolerance checks with clear error messages
            failed_samples = []
            for sample_idx, (max_diff, mean_diff) in enumerate(zip(all_max, all_mean)):
                if max_diff > DENSE_TOLERANCE_MAX_DIFF:
                    failed_samples.append(
                        f"Sample {sample_idx}: max_diff={max_diff:.6g} > {DENSE_TOLERANCE_MAX_DIFF}"
                    )
                if mean_diff > DENSE_TOLERANCE_MEAN_DIFF:
                    failed_samples.append(
                        f"Sample {sample_idx}: mean_diff={mean_diff:.6g} > {DENSE_TOLERANCE_MEAN_DIFF}"
                    )

            if failed_samples:
                self.fail(
                    f"Config {{'num_samples': {NUM_SAMPLES}, 'logprob_sample_ratio': {LOGPROB_SAMPLE_RATIO}, 'temperature': {TEMPERATURE}}} - Tolerance exceeded in {len(failed_samples)} samples:\n"
                    + "\n".join(failed_samples[:5])
                )


if __name__ == "__main__":
    unittest.main()