From 1e495e08470b6dc56645081f644831e0c620dfa5 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Tue, 3 Sep 2024 06:31:45 -0700
Subject: [PATCH] [Fix] Fix select by ensuring each request has at least one token (#1318)

---
Review notes (reviewer commentary placed after the separator above, so it is
not part of the commit message):

 * schedule_batch.py: adjust_max_prefix_ids() now starts from input_len - 1
   instead of input_len. While that workaround is in place, the
   min(max_prefix_len, input_len - 1) inside the max_new_tokens > 0 branch
   cannot change the value; it only matters again once the FIXME is resolved.
   The logprob_start_len cap moved below the normalized-logprob cap (min() is
   order-independent, so the result is unchanged), and the new
   max(max_prefix_len, 0) keeps the slice bound from going negative.

 * utils.py: fetch_and_cache_jsonl() streams the download directly into
   cache_file and treats any existing cache_file as a complete cache. An
   interrupted download therefore leaves a partial file that later calls will
   try to parse. Consider writing to a temporary name and renaming on success,
   and passing a timeout to requests.get(). The default cache_file name does
   not depend on url, so different URLs fetched with the default share one
   cache file.

 * test_srt_backend.py: the accuracy check is a bare assert, which is skipped
   under python -O. The 0.71 threshold is presumably tied to the model this
   test class launches -- TODO confirm.

 python/sglang/srt/managers/schedule_batch.py |  9 ++-
 python/sglang/test/test_programs.py          | 68 ++++++++++++++++++++
 python/sglang/utils.py                       | 39 +++++++++++
 test/lang/test_srt_backend.py                |  7 ++
 4 files changed, 120 insertions(+), 3 deletions(-)

diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index f5b9c9eb2..c80cf2e27 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -178,19 +178,22 @@ class Req:
     def adjust_max_prefix_ids(self):
         self.fill_ids = self.origin_input_ids + self.output_ids
         input_len = len(self.fill_ids)
-        max_prefix_len = input_len
+
+        # FIXME: To work around some bugs in logprob computation, we need to ensure each
+        # request has at least one token. Later, we can relax this requirement and use `input_len`.
+        max_prefix_len = input_len - 1
 
         if self.sampling_params.max_new_tokens > 0:
             # Need at least one token to compute logits
             max_prefix_len = min(max_prefix_len, input_len - 1)
 
         if self.return_logprob:
-            max_prefix_len = min(max_prefix_len, self.logprob_start_len)
-
             if self.normalized_prompt_logprob is None:
                 # Need at least two tokens to compute normalized logprob
                 max_prefix_len = min(max_prefix_len, input_len - 2)
+            max_prefix_len = min(max_prefix_len, self.logprob_start_len)
 
+        max_prefix_len = max(max_prefix_len, 0)
         return self.fill_ids[:max_prefix_len]
 
     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
diff --git a/python/sglang/test/test_programs.py b/python/sglang/test/test_programs.py
index ce4025585..bdecdff2f 100644
--- a/python/sglang/test/test_programs.py
+++ b/python/sglang/test/test_programs.py
@@ -2,8 +2,12 @@
 
 import json
 import re
+import time
+
+import numpy as np
 
 import sglang as sgl
+from sglang.utils import fetch_and_cache_jsonl
 
 
 def test_few_shot_qa():
@@ -447,3 +451,67 @@ def test_chat_completion_speculative():
     )
 
     gen_character_spec().sync()
+
+
+def test_hellaswag_select():
+    """Benchmark the accuracy of sgl.select on the HellaSwag dataset."""
+
+    url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
+    lines = fetch_and_cache_jsonl(url)
+
+    # Construct prompts
+    def get_one_example(lines, i, include_answer):
+        ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
+        if include_answer:
+            ret += lines[i]["endings"][lines[i]["label"]]
+        return ret
+
+    def get_few_shot_examples(lines, k):
+        ret = ""
+        for i in range(k):
+            ret += get_one_example(lines, i, True) + "\n\n"
+        return ret
+
+    num_questions = 200
+    num_shots = 20
+    few_shot_examples = get_few_shot_examples(lines, num_shots)
+
+    questions = []
+    choices = []
+    labels = []
+    for i in range(len(lines[:num_questions])):
+        questions.append(get_one_example(lines, i, False))
+        choices.append(lines[i]["endings"])
+        labels.append(lines[i]["label"])
+    arguments = [{"question": q, "choices": c} for q, c in zip(questions, choices)]
+
+    #####################################
+    ######### SGL Program Begin #########
+    #####################################
+
+    import sglang as sgl
+
+    @sgl.function
+    def few_shot_hellaswag(s, question, choices):
+        s += few_shot_examples + question
+        s += sgl.select("answer", choices=choices)
+
+    #####################################
+    ########## SGL Program End ##########
+    #####################################
+
+    # Run requests
+    tic = time.time()
+    rets = few_shot_hellaswag.run_batch(
+        arguments,
+        temperature=0,
+        num_threads=64,
+        progress_bar=True,
+    )
+    preds = [choices[i].index(rets[i]["answer"]) for i in range(len(rets))]
+    latency = time.time() - tic
+
+    # Compute accuracy
+    accuracy = np.mean(np.array(preds) == np.array(labels))
+
+    return accuracy, latency
diff --git a/python/sglang/utils.py b/python/sglang/utils.py
index c880d259d..b212f6caa 100644
--- a/python/sglang/utils.py
+++ b/python/sglang/utils.py
@@ -4,6 +4,7 @@ import base64
 import importlib
 import json
 import logging
+import os
 import signal
 import sys
 import traceback
@@ -15,6 +16,7 @@ from typing import Union
 
 import numpy as np
 import requests
+from tqdm import tqdm
 
 logger = logging.getLogger(__name__)
 
@@ -260,3 +262,40 @@ class LazyImport:
     def __call__(self, *args, **kwargs):
         module = self._load()
         return module(*args, **kwargs)
+
+
+def fetch_and_cache_jsonl(url, cache_file="cached_data.jsonl"):
+    """Read and cache a jsonl file from a url."""
+
+    # Check if the cache file already exists
+    if os.path.exists(cache_file):
+        print("Loading data from cache...")
+        with open(cache_file, "r") as f:
+            data = [json.loads(line) for line in f]
+    else:
+        print("Downloading data from URL...")
+        # Stream the response to show the progress bar
+        response = requests.get(url, stream=True)
+        response.raise_for_status()  # Check for request errors
+
+        # Total size of the file in bytes
+        total_size = int(response.headers.get("content-length", 0))
+        chunk_size = 1024  # Download in chunks of 1KB
+
+        # Use tqdm to display the progress bar
+        with open(cache_file, "wb") as f, tqdm(
+            desc=cache_file,
+            total=total_size,
+            unit="B",
+            unit_scale=True,
+            unit_divisor=1024,
+        ) as bar:
+            for chunk in response.iter_content(chunk_size=chunk_size):
+                f.write(chunk)
+                bar.update(len(chunk))
+
+        # Convert the data to a list of dictionaries
+        with open(cache_file, "r") as f:
+            data = [json.loads(line) for line in f]
+
+    return data
diff --git a/test/lang/test_srt_backend.py b/test/lang/test_srt_backend.py
index fcd86ae3d..62c595928 100644
--- a/test/lang/test_srt_backend.py
+++ b/test/lang/test_srt_backend.py
@@ -7,6 +7,7 @@ from sglang.test.test_programs import (
     test_dtype_gen,
     test_expert_answer,
     test_few_shot_qa,
+    test_hellaswag_select,
     test_mt_bench,
     test_parallel_decoding,
     test_regex,
@@ -62,6 +63,12 @@ class TestSRTBackend(unittest.TestCase):
     def test_dtype_gen(self):
         test_dtype_gen()
 
+    def test_hellaswag_select(self):
+        # Run twice to capture more bugs
+        for _ in range(2):
+            accuracy, latency = test_hellaswag_select()
+            assert accuracy > 0.71
+
 
 if __name__ == "__main__":
     unittest.main()