diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index df3c87838..51356c803 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -86,7 +86,6 @@ suites = { TestFile("test_input_embeddings.py", 38), TestFile("test_io_struct.py", 8), TestFile("test_jinja_template_utils.py", 1), - TestFile("test_logprobs.py", 55), TestFile("test_mamba_unittest.py", 4), TestFile("test_metrics.py", 32), TestFile("test_metrics_utils.py", 1), diff --git a/test/srt/test_logprobs.py b/test/srt/test_logprobs.py index 6af92e633..ba0663430 100644 --- a/test/srt/test_logprobs.py +++ b/test/srt/test_logprobs.py @@ -1,42 +1,210 @@ -import io +""" +Logprobs Accuracy Test for SGLang + +====================== +With deterministic/batch invariant kernels, we can ensure that SGLang produces exactly the same +logprobs results for identical inputs. However, logprobs are highly sensitive to GPU hardware, +kernels, torch versions, and other factors, so we cannot maintain a unified logprobs baseline +across different machines. + +This test is designed to be run locally by contributors to verify logprobs accuracy +before making changes to related code. +When submitting changes that affect logprobs computation, please: +1. Generate baseline +2. Run test +3. Submit results + +We really appreciate your effort and contribution to SGLang! + +====================== +What does this test do? +This test fetches 1000 samples from the ShareGPT dataset, generates logprobs for each sample, +and saves them as a baseline. Then, by running the test mode, it validates the accuracy of +logprobs by comparing them against the baseline. + +This test ensures that: +- the boundary of log probs requests are correct, eg, the index for tokens that required log probs are strictly followed +- logprobs remain invariant between test runs, and also before and after your code changes; + +====================== +Usage + +Step 1: Generate Baseline (Before Code Changes) +```bash +python test/srt/test_logprobs.py gen +``` + +Step 2: Test Against Baseline (After Code Changes) +```bash +python test/srt/test_logprobs.py test +``` +This tests your changes against the locally generated baseline from Step 1. +The test passes if the maximum and mean differences are within the tolerance thresholds. +====================== +""" + +import argparse +import json import os import pickle import random -import time import unittest import numpy as np import requests import torch +from transformers import AutoTokenizer import sglang as sgl -from sglang.test.test_utils import ( - DEFAULT_SMALL_MODEL_NAME_FOR_TEST, - write_github_step_summary, +from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST + +# Configuration +DENSE_MODEL_NAME = DEFAULT_SMALL_MODEL_NAME_FOR_TEST +SHAREGPT_URL = ( + "https://huggingface.co/datasets/anon8231489123/" + "ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json" ) -# Dense model configuration -DENSE_MODEL_NAME = DEFAULT_SMALL_MODEL_NAME_FOR_TEST -if torch.version.hip is not None: - print("Running on AMD ROCm GPU") - DENSE_INPUT_PKL_URL = "https://huggingface.co/datasets/yushengsu/logprobs/resolve/main/sglang_baseline_2000_amd.pkl" - DENSE_TOLERANCE_MAX_DIFF = 1.4 - DENSE_TOLERANCE_MEAN_DIFF = 0.1 -elif torch.version.cuda is not None: +# Hardware-specific configuration +if torch.version.cuda is not None: print("Running on NVIDIA CUDA GPU") - DENSE_INPUT_PKL_URL = "https://huggingface.co/datasets/font-info/logprobs/resolve/main/sglang_baseline_2000.pkl" - DENSE_TOLERANCE_MAX_DIFF = 1.5 - DENSE_TOLERANCE_MEAN_DIFF = 0.1 + DENSE_TOLERANCE_MAX_DIFF = 1e-5 + DENSE_TOLERANCE_MEAN_DIFF = 1e-5 else: print("No GPU backend (CPU only)") + raise ValueError("No GPU backend (CPU only)") # Common configuration TOP_K = 20 -MAX_RETRIES = 3 -RETRY_DELAY = 2 NUM_SAMPLES = 1000 LOGPROB_SAMPLE_RATIO = 0.5 TEMPERATURE = 1.0 +MAX_LEN = 20000 + +# Default output files +DEFAULT_BASELINE_PKL = "sglang_baseline_local.pkl" +DEFAULT_META_JSON = "baseline_meta_preview.json" + + +def generate_baseline( + baseline_file=DEFAULT_BASELINE_PKL, + meta_file=DEFAULT_META_JSON, + num_samples=NUM_SAMPLES, +): + """Generate a local baseline for logprobs testing. + + Args: + baseline_file: Path to save the baseline pickle file + meta_file: Path to save the metadata preview JSON file + num_samples: Number of samples to generate + """ + print(f"SGLang version: {sgl.__version__}") + print("Downloading ShareGPT dataset...") + + # Download ShareGPT dataset + try: + response = requests.get(SHAREGPT_URL, timeout=30) + response.raise_for_status() + data = response.json() + print(f"Dataset size: {len(data)}") + except requests.exceptions.RequestException as e: + raise Exception(f"Failed to download ShareGPT dataset: {e}") from e + + # Filter and prepare texts + texts = [] + for s in data: + if "conversations" in s and len(s["conversations"]) > 0: + try: + text = s["conversations"][0]["value"] + if isinstance(text, str) and len(text) <= MAX_LEN and len(text) >= 5500: + texts.append(text) + if len(texts) >= num_samples * 40: # Get more samples for filtering + break + except (KeyError, IndexError, TypeError) as e: + print(f"Warning: Skipping invalid conversation data: {e}") + continue + + if not texts: + raise ValueError("No valid texts found in the dataset") + + print(f"Loading tokenizer for {DENSE_MODEL_NAME}...") + tokenizer = AutoTokenizer.from_pretrained(DENSE_MODEL_NAME, use_fast=True) + + rng = np.random.default_rng(42) + + print(f"Launching SGLang Engine with {DENSE_MODEL_NAME}...") + engine = sgl.Engine( + model_path=DENSE_MODEL_NAME, + attention_backend="flashinfer", + enable_deterministic_inference=True, + random_seed=42, + skip_tokenizer_init=True, + mem_fraction_static=0.8, + max_running_requests=1, + ) + + records = [] + prompt_lengths = [] + + try: + for i, text in enumerate(texts): + if len(records) >= num_samples: + break + + try: + ids = tokenizer.encode(text, add_special_tokens=False) + if len(ids) < 5: + continue + + start_pos = int(rng.integers(0, max(1, len(ids) - 3))) + + outputs = engine.generate( + input_ids=[ids], + sampling_params={ + "temperature": 1.0, + "top_p": 1.0, + "top_k": TOP_K, + "max_new_tokens": 1, + }, + return_logprob=True, + logprob_start_len=start_pos, + top_logprobs_num=TOP_K, + ) + meta = outputs[0]["meta_info"] + + records.append( + dict(id=i, text=text, ids=ids, start_pos=start_pos, meta=meta) + ) + prompt_lengths.append(len(ids)) + + if (i + 1) % 50 == 0: + print(f"Processed {len(records)}/{num_samples} samples") + + except Exception as e: + print(f"Warning: Failed to process sample {i}: {e}") + continue + + if not records: + raise RuntimeError( + "Failed to generate any baseline records. Please check the warnings above for errors." + ) + + # Save baseline files + with open(baseline_file, "wb") as f: + pickle.dump(records, f) + with open(meta_file, "w", encoding="utf-8") as f: + json.dump(records[:2], f, ensure_ascii=False, indent=2) + + print(f"โœ… Saved {len(records)} samples to {baseline_file}") + print(f"โœ… Meta preview saved to {meta_file}") + + if prompt_lengths: + avg_prompt_length = sum(prompt_lengths) / len(prompt_lengths) + print(f"๐Ÿ“Š Average prompt length: {avg_prompt_length:.2f} tokens") + + finally: + engine.shutdown() + torch.cuda.empty_cache() class TestLogprobsDense(unittest.TestCase): @@ -48,6 +216,8 @@ class TestLogprobsDense(unittest.TestCase): cls.engine = sgl.Engine( model_path=DENSE_MODEL_NAME, random_seed=42, + attention_backend="flashinfer", + enable_deterministic_inference=True, skip_tokenizer_init=True, mem_fraction_static=0.80, ) @@ -58,31 +228,24 @@ class TestLogprobsDense(unittest.TestCase): cls.engine.shutdown() torch.cuda.empty_cache() - def load_test_data(self): - """Load test data from Hugging Face dataset with retry mechanism.""" - print(f"Loading data from {DENSE_INPUT_PKL_URL}...") + def load_test_data(self, baseline_file=None): + """Load test data from local baseline file. In test mode, only local baseline is supported.""" + if not baseline_file: + raise ValueError("baseline_file is required in test mode") - for attempt in range(MAX_RETRIES): - try: - response = requests.get(DENSE_INPUT_PKL_URL, timeout=30) - response.raise_for_status() + if not os.path.exists(baseline_file): + raise FileNotFoundError( + f"Baseline file not found: {baseline_file}. Please run 'gen' mode first to generate the baseline." + ) - with io.BytesIO(response.content) as f: - records = pickle.load(f) - - if not records: - raise ValueError("Empty dataset") - - print(f"Successfully loaded {len(records)} records") - return records - - except Exception as e: - print(f"Attempt {attempt + 1}/{MAX_RETRIES} failed: {e}") - if attempt == MAX_RETRIES - 1: - raise Exception( - f"Failed to load data after {MAX_RETRIES} attempts: {e}" - ) - time.sleep(RETRY_DELAY) + print(f"Loading local baseline from {baseline_file}...") + try: + with open(baseline_file, "rb") as f: + records = pickle.load(f) + print(f"Successfully loaded {len(records)} records from local baseline") + return records + except (IOError, pickle.PickleError) as e: + raise Exception(f"Failed to load local baseline: {e}") from e def compare_meta(self, baseline_meta, sglang_meta): """Compare metadata between two outputs and return max and mean differences.""" @@ -102,19 +265,21 @@ class TestLogprobsDense(unittest.TestCase): common_tokens = baseline_token_map.keys() & sglang_token_map.keys() self.assertGreaterEqual( len(common_tokens), - TOP_K / 2, + TOP_K, f"there are only {len(common_tokens)} common topk tokens that matches", ) for token_id in common_tokens: diffs.append( abs(baseline_token_map[token_id] - sglang_token_map[token_id]) ) + if not diffs: + return 0.0, 0.0 return max(diffs), float(np.mean(diffs)) - def test_logprobs_comparison(self): + def test_logprobs_comparison(self, baseline_file=None): """Test the logprobs comparison functionality with different parameter combinations.""" # Load test data with retry mechanism - records = self.load_test_data() + records = self.load_test_data(baseline_file) with self.subTest( config={ @@ -224,15 +389,6 @@ class TestLogprobsDense(unittest.TestCase): f"Expected {logprob_count} samples with logprobs, got {logprob_returned_count}", ) - # Write results to GitHub summary - summary_content = f""" -- **Configuration**: {{"num_samples": {NUM_SAMPLES}, "logprob_sample_ratio": {LOGPROB_SAMPLE_RATIO}, "temperature": {TEMPERATURE}}} -- **Max of max ฮ”**: {max_of_max:.6g} -- **Mean of mean ฮ”**: {mean_of_mean:.6g} -- **Status**: {'โœ… Passed' if max_of_max <= DENSE_TOLERANCE_MAX_DIFF and mean_of_mean <= DENSE_TOLERANCE_MEAN_DIFF else 'โŒ Failed'} -""" - write_github_step_summary(summary_content) - # Basic validation self.assertIsInstance(all_max, list) self.assertIsInstance(all_mean, list) @@ -261,5 +417,52 @@ class TestLogprobsDense(unittest.TestCase): ) +def main(): + """Main function to handle command line arguments and run either generation or testing.""" + parser = argparse.ArgumentParser( + description="SGLang Logprobs Test and Baseline Generation" + ) + parser.add_argument( + "mode", + choices=["gen", "test"], + help="Mode to run: 'gen' to generate baseline, 'test' to run tests", + ) + + args = parser.parse_args() + + if args.mode == "gen": + print("๐Ÿš€ Generating baseline...") + generate_baseline() + print(f"\nโœ… Baseline generation complete!") + print(f"๐Ÿ“ Baseline saved to: {DEFAULT_BASELINE_PKL}") + print(f"๐Ÿ“ Metadata preview saved to: {DEFAULT_META_JSON}") + print(f"\n๐Ÿ’ก Next steps:") + print(f" 1. Make your code changes") + print(f" 2. Run: python {__file__} test") + + elif args.mode == "test": + print("๐Ÿงช Running logprobs test...") + if not os.path.exists(DEFAULT_BASELINE_PKL): + print(f"โŒ Baseline file not found: {DEFAULT_BASELINE_PKL}") + print(f"๐Ÿ’ก Generate baseline first by running:") + print(f" python {__file__} gen") + print(f" This will download ShareGPT data and generate a local baseline.") + return 1 + + # Set environment variable for testing + os.environ["RETURN_ORIGINAL_LOGPROB"] = "True" + + # Create test instance and run + test_instance = TestLogprobsDense() + test_instance.setUpClass() + try: + test_instance.test_logprobs_comparison(baseline_file=DEFAULT_BASELINE_PKL) + print("\nโœ… Test completed successfully!") + finally: + test_instance.tearDownClass() + + return 0 + + if __name__ == "__main__": - unittest.main() + exit(main())