diff --git a/.gitignore b/.gitignore index 68bc17f9f..49e810d4e 100644 --- a/.gitignore +++ b/.gitignore @@ -157,4 +157,23 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +.idea/ + +# MacOS +.DS_Store +*.json + +# Vim +*.swp + +# SGL +benchmark/mmlu/data +benchmark/mmlu/data.tar +benchmark/llava_bench/images +benchmark/llava_bench/mme_pack +*.jsonl +tmp*.txt + +# Plots +*.png +*.pdf diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000..1e0e96768 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "3rdparty/flashinfer"] + path = 3rdparty/flashinfer + url = git@github.com:flashinfer-ai/flashinfer.git diff --git a/README.md b/README.md index e4af36992..fc666595d 100644 --- a/README.md +++ b/README.md @@ -1 +1,167 @@ -# sglang \ No newline at end of file +# SGLang + +SGLang is a structured generation language designed for large language models (LLMs). +It makes your interaction with LLMs faster and more controllable by co-designing the frontend language and the runtime system. + +The core features of SGLang include: +- **A Flexible Front-End Language**: This allows for easy programming of LLM applications with multiple chained generation calls, advanced prompting techniques, control flow, multiple modalities, parallelism, and external interaction. +- **A High-Performance Runtime with RadixAttention**: This feature significantly accelerates the execution of complex LLM programs by automatic KV cache reuse across multiple calls. It also supports other common techniques like continuous batching and tensor parallelism. + +## Contents +- [Install](#install) +- [Quick Start](#quick-start) +- [Frontend: Structured Generation Langauge (SGLang)](#frontend-structured-generation-langauge-sglang) +- [Backend: SGLang Runtime (SRT)](#backend-sglang-runtime-srt) +- [Benchmark And Performance](#benchmark-and-performance) +- [Roadmap](#roadmap) +- [Citation And Acknowledgment](#citation-and-acknowledgment) + +## Install + +### Method 1: With Pip + +### Method 2: From Source +``` +git clone git@github.com:sgl-project/sglang.git +cd sglang + +pip install --upgrade pip +pip install -e "python[all]" +``` + +## Quick Start +The example below shows how to use sglang to answer a mulit-turn question. + +### Using OpenAI Models +```python +from sglang import function, system, user, assistant, gen, set_default_backend, OpenAI + +@function +def multi_turn_question(s, question_1, question_2): + s += system("You are a helpful assistant.") + s += user(question_1) + s += assistant(gen("answer_1", max_tokens=256)) + s += user(question_2) + s += assistant(gen("answer_2", max_tokens=256)) + +set_default_backend(OpenAI("gpt-3.5-turbo")) + +state = multi_turn_question.run( + question_1="What is the capital of the United States?", + question_2="List two local attractions.", +) + +for m in state.messages(): + print(m["role"], ":", m["content"]) +``` + +### Using Local Models +First, launch a server with +``` +python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 +``` + +Then, connect to the server and answer a multi-turn question. + +```python +from sglang import function, system, user, assistant, gen, set_default_backend, RuntimeEndpoint + +@function +def multi_turn_question(s, question_1, question_2): + s += system("You are a helpful assistant.") + s += user(question_1) + s += assistant(gen("answer_1", max_tokens=256)) + s += user(question_2) + s += assistant(gen("answer_2", max_tokens=256)) + +set_default_backend(RuntimeEndpoint("http://localhost:30000")) + +state = multi_turn_question.run( + question_1="What is the capital of the United States?", + question_2="List two local attractions.", +) + +for m in state.messages(): + print(m["role"], ":", m["content"]) +``` + +### More Examples + +You can find more examples at [examples/quick_start](examples/quick_start). + +## Frontend: Structured Generation Langauge (SGLang) + +### Control Flow + +### Parallelism + +### Multi Modality +```python +@sgl.function +def multi_turn_question(s, image_file, question): + s += sgl.user(sgl.image(image_file) + question) + s += sgl.assistant(sgl.gen("answer_1", max_tokens=256)) +``` + +### Batching + +### Streaming + +### Other Backends + +## Backend: SGLang Runtime (SRT) +The SGLang Runtime (SRT) is designed to work best with the SGLang frontend. +However, it can also be used as a standalone API server. +In this case, the RadixAttention can still greatly accelerate many use cases. + +### Usage +Launch a server +``` +python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 +``` + +Send a request +``` +curl http://localhost:30000/v1/completions \ + -H "Content-Type: application/json" \ + -d '{ + "prompt": "Say this is a test", + "max_tokens": 16, + "temperature": 0 + }' +``` + +### Additional Arguments +- Add `--tp 2` to enable tensor parallelism. +``` +python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --tp 2 +``` + +### Supported Models +- Llama +- Mistral +- Mixtral +- LLaVA + +## Benchmark And Performance + +## Roadmap +- [ ] Function call +- [ ] Constrained decoding +- [ ] Quantization +- [ ] S-LoRA +- [ ] More models + +## Citation And Acknowledgment +``` +@misc{zheng2023efficiently, + title={Efficiently Programming Large Language Models using SGLang}, + author={Lianmin Zheng and Liangsheng Yin and Zhiqiang Xie and Jeff Huang and Chuyue Sun and Cody Hao Yu and Shiyi Cao and Christos Kozyrakis and Ion Stoica and Joseph E. Gonzalez and Clark Barrett and Ying Sheng}, + year={2023}, + eprint={2312.07104}, + archivePrefix={arXiv}, + primaryClass={cs.AI} +} +``` + +We learned from the design and reused some code of the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [LMQL](https://github.com/eth-sri/lmql). diff --git a/benchmark/dspy/README.md b/benchmark/dspy/README.md new file mode 100644 index 000000000..aeb829697 --- /dev/null +++ b/benchmark/dspy/README.md @@ -0,0 +1,45 @@ +## Install + +``` +pip3 install dspy-ai +``` + +Turn off cache at https://github.com/stanfordnlp/dspy/blob/34d8420383ec752037aa271825c1d3bf391e1277/dsp/modules/cache_utils.py#L10. +``` +cache_turn_on = False +``` + +## Benchmark SGLang +``` +python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 +``` + +``` +python3 bench_dspy_intro.py --backend sglang +``` + + +## Benchmark TGI +``` +docker run --name tgi --rm -ti --gpus all --network host \ + -v /home/ubuntu/model_weights/Llama-2-7b-chat-hf:/Llama-2-7b-chat-hf \ + ghcr.io/huggingface/text-generation-inference:1.1.0 \ + --model-id /Llama-2-7b-chat-hf --num-shard 1 --trust-remote-code \ + --max-input-length 2048 --max-total-tokens 4096 \ + --port 24000 +``` + +``` +python3 bench_dspy_intro.py --backend tgi +``` + + + +## Benchmark vLLM +``` +python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000 +``` + +``` +python3 bench_dspy_intro.py --backend vllm +``` diff --git a/benchmark/dspy/bench_dspy_intro.py b/benchmark/dspy/bench_dspy_intro.py new file mode 100644 index 000000000..76606330a --- /dev/null +++ b/benchmark/dspy/bench_dspy_intro.py @@ -0,0 +1,165 @@ +""" +Adapted from +https://github.com/stanfordnlp/dspy/blob/34d8420383ec752037aa271825c1d3bf391e1277/intro.ipynb#L9 +""" +import argparse + +import dspy +from dspy.datasets import HotPotQA + + +class BasicQA(dspy.Signature): + """Answer questions with short factoid answers.""" + + question = dspy.InputField() + answer = dspy.OutputField(desc="often between 1 and 5 words") + + +class GenerateAnswer(dspy.Signature): + """Answer questions with short factoid answers.""" + + context = dspy.InputField(desc="may contain relevant facts") + question = dspy.InputField() + answer = dspy.OutputField(desc="often between 1 and 5 words") + + +class RAG(dspy.Module): + def __init__(self, num_passages=3): + super().__init__() + + self.retrieve = dspy.Retrieve(k=num_passages) + self.generate_answer = dspy.ChainOfThought(GenerateAnswer) + + def forward(self, question): + context = self.retrieve(question).passages + prediction = self.generate_answer(context=context, question=question) + return dspy.Prediction(context=context, answer=prediction.answer) + + +def main(args): + #lm = dspy.OpenAI(model='gpt-3.5-turbo') + if args.backend == "tgi": + lm = dspy.HFClientTGI(model="meta-llama/Llama-2-7b-chat-hf", port=args.port, + url="http://localhost") + elif args.backend == "sglang": + lm = dspy.HFClientSGLang(model="meta-llama/Llama-2-7b-chat-hf", port=args.port, + url="http://localhost") + elif args.backend == "vllm": + lm = dspy.HFClientVLLM(model="meta-llama/Llama-2-7b-chat-hf", port=args.port, + url="http://localhost") + else: + raise ValueError(f"Invalid backend: {args.backend}") + + colbertv2_wiki17_abstracts = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts') + dspy.settings.configure(lm=lm, rm=colbertv2_wiki17_abstracts) + + # Load the dataset. + dataset = HotPotQA(train_seed=1, train_size=20, eval_seed=2023, dev_size=args.dev_size, + test_size=0) + + # Tell DSPy that the 'question' field is the input. Any other fields are labels and/or metadata. + trainset = [x.with_inputs('question') for x in dataset.train] + devset = [x.with_inputs('question') for x in dataset.dev] + + print(len(trainset), len(devset)) + + train_example = trainset[0] + print(f"Question: {train_example.question}") + print(f"Answer: {train_example.answer}") + + dev_example = devset[18] + print(f"Question: {dev_example.question}") + print(f"Answer: {dev_example.answer}") + print(f"Relevant Wikipedia Titles: {dev_example.gold_titles}") + + print(f"For this dataset, training examples have input keys {train_example.inputs().keys()} and label keys {train_example.labels().keys()}") + print(f"For this dataset, dev examples have input keys {dev_example.inputs().keys()} and label keys {dev_example.labels().keys()}") + + # Define the predictor. + generate_answer = dspy.Predict(BasicQA) + + # Call the predictor on a particular input. + pred = generate_answer(question=dev_example.question) + + # Print the input and the prediction. + print(f"Question: {dev_example.question}") + print(f"Predicted Answer: {pred.answer}") + + lm.inspect_history(n=1) + + # Define the predictor. Notice we're just changing the class. The signature BasicQA is unchanged. + generate_answer_with_chain_of_thought = dspy.ChainOfThought(BasicQA) + + # Call the predictor on the same input. + pred = generate_answer_with_chain_of_thought(question=dev_example.question) + + # Print the input, the chain of thought, and the prediction. + print(f"Question: {dev_example.question}") + print(f"Thought: {pred.rationale.split('.', 1)[1].strip()}") + print(f"Predicted Answer: {pred.answer}") + + retrieve = dspy.Retrieve(k=3) + topK_passages = retrieve(dev_example.question).passages + + print(f"Top {retrieve.k} passages for question: {dev_example.question} \n", '-' * 30, '\n') + + for idx, passage in enumerate(topK_passages): + print(f'{idx+1}]', passage, '\n') + + retrieve("When was the first FIFA World Cup held?").passages[0] + + from dspy.teleprompt import BootstrapFewShot + + # Validation logic: check that the predicted answer is correct. + # Also check that the retrieved context does actually contain that answer. + def validate_context_and_answer(example, pred, trace=None): + answer_EM = dspy.evaluate.answer_exact_match(example, pred) + answer_PM = dspy.evaluate.answer_passage_match(example, pred) + return answer_EM and answer_PM + + # Set up a basic teleprompter, which will compile our RAG program. + teleprompter = BootstrapFewShot(metric=validate_context_and_answer) + + # Compile! + compiled_rag = teleprompter.compile(RAG(), trainset=trainset) + + # Ask any question you like to this simple RAG program. + my_question = "What castle did David Gregory inherit?" + + # Get the prediction. This contains `pred.context` and `pred.answer`. + pred = compiled_rag(my_question) + + # Print the contexts and the answer. + print(f"Question: {my_question}") + print(f"Predicted Answer: {pred.answer}") + print(f"Retrieved Contexts (truncated): {[c[:200] + '...' for c in pred.context]}") + + from dspy.evaluate.evaluate import Evaluate + + # Set up the `evaluate_on_hotpotqa` function. We'll use this many times below. + evaluate_on_hotpotqa = Evaluate(devset=devset, num_threads=args.num_threads, display_progress=True, display_table=5) + + # Evaluate the `compiled_rag` program with the `answer_exact_match` metric. + metric = dspy.evaluate.answer_exact_match + evaluate_on_hotpotqa(compiled_rag, metric=metric) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--port", type=int) + parser.add_argument("--num-threads", type=int, default=32) + parser.add_argument("--dev-size", type=int, default=150) + parser.add_argument("--backend", type=str, choices=["sglang", "tgi", "vllm"], + default="sglang") + args = parser.parse_args() + + if args.port is None: + default_port = { + "vllm": 21000, + "lightllm": 22000, + "tgi": 24000, + "sglang": 30000, + } + args.port = default_port.get(args.backend, None) + + main(args) diff --git a/benchmark/generative_agents/README.md b/benchmark/generative_agents/README.md new file mode 100644 index 000000000..190dd5093 --- /dev/null +++ b/benchmark/generative_agents/README.md @@ -0,0 +1,26 @@ +## Run benchmark + +Ensure that this benchmark is run in a serial manner (using --parallel 1) to preserve any potential dependencies between requests. + +### Benchmark sglang +``` +python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 +``` + +``` +python3 bench_sglang.py --num-events 1000 --parallel 1 +``` + +### Benchmark vllm +``` +python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000 +``` + +``` +python3 bench_other.py --num-events 1000 --backend vllm --parallel 1 +``` + +### Benchmark guidance +``` +python3 bench_other.py --num-events 1000 --backend guidance --parallel 1 +``` diff --git a/benchmark/generative_agents/agent_functions.py b/benchmark/generative_agents/agent_functions.py new file mode 100644 index 000000000..90d05cfd0 --- /dev/null +++ b/benchmark/generative_agents/agent_functions.py @@ -0,0 +1,231 @@ +import sglang as sgl + +# here are the top five agent functions contributing ~70% LLM calls +# reference: https://github.com/joonspk-research/generative_agents/ + + +@sgl.function +def poignancy_event(s, persona_name, persona_iss, event): + s += "Here is a brief description of " + persona_name + ".\n" + s += persona_iss + "\n" + s += "On the scale of 1 to 10, where 1 is purely mundane (e.g., brushing teeth, making bed) and 10 is extremely poignant (e.g., a break up, college acceptance), rate the likely poignancy of the following event for" + s += persona_name + ".\n\n" + s += "Event: " + event + s += "Rate (return a number between 1 to 10):" + s += sgl.gen(name="Rate", max_tokens=2) + + +def poignancy_event_prompt(persona_name, persona_iss, event): + # return prompt and max_tokens + s = "" + s += "Here is a brief description of " + persona_name + ".\n" + s += persona_iss + "\n" + s += "On the scale of 1 to 10, where 1 is purely mundane (e.g., brushing teeth, making bed) and 10 is extremely poignant (e.g., a break up, college acceptance), rate the likely poignancy of the following event for" + s += persona_name + ".\n\n" + s += "Event: " + event + s += "Rate (return a number between 1 to 10):" + return {"prompt": s, "max_tokens": 2, "stop": None} + + +@sgl.function +def generate_event_triple(s, persona_name, action): + s += """Task: Turn the input into (subject, predicate, object). +Input: Sam Johnson is eating breakfast. +Output: (Dolores Murphy, eat, breakfast) +--- +Input: Joon Park is brewing coffee. +Output: (Joon Park, brew, coffee) +--- +Input: Jane Cook is sleeping. +Output: (Jane Cook, is, sleep) +--- +Input: Michael Bernstein is writing email on a computer. +Output: (Michael Bernstein, write, email) +--- +Input: Percy Liang is teaching students in a classroom. +Output: (Percy Liang, teach, students) +--- +Input: Merrie Morris is running on a treadmill. +Output: (Merrie Morris, run, treadmill) +---""" + s += persona_name + "is" + action + ".\n" + s += "(" + persona_name + "," + s += sgl.gen(name="Triple", max_tokens=20, stop=")") + + +def generate_event_triple_prompt(persona_name, action): + s = "" + s += """Task: Turn the input into (subject, predicate, object). +Input: Sam Johnson is eating breakfast. +Output: (Dolores Murphy, eat, breakfast) +--- +Input: Joon Park is brewing coffee. +Output: (Joon Park, brew, coffee) +--- +Input: Jane Cook is sleeping. +Output: (Jane Cook, is, sleep) +--- +Input: Michael Bernstein is writing email on a computer. +Output: (Michael Bernstein, write, email) +--- +Input: Percy Liang is teaching students in a classroom. +Output: (Percy Liang, teach, students) +--- +Input: Merrie Morris is running on a treadmill. +Output: (Merrie Morris, run, treadmill) +---""" + s += persona_name + "is" + action + ".\n" + s += "(" + persona_name + "," + return {"prompt": s, "max_tokens": 20, "stop": ")"} + + +@sgl.function +def generate_pronunciatio(s, action): + s += "Convert an action description to an emoji (important: use two or less emojis).\n" + s += "Action description: " + action + ".\n" + s += "Emoji:" + sgl.gen(name="Emoji", max_tokens=6) + + +def generate_pronunciatio_prompt(action): + s = "" + s += "Convert an action description to an emoji (important: use two or less emojis).\n" + s += "Action description: " + action + ".\n" + s += "Emoji:" + return {"prompt": s, "max_tokens": 6, "stop": None} + + +@sgl.function +def action_location_sector( + s, + persona_name, + living_sector, + living_sector_areas, + current_sector, + current_sector_areas, + daily_plan, + sector_options, + current_action, + next_action, +): + s += """Task -- choose an appropriate area from the area options for a task at hand. +Sam Kim lives in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen. +Sam Kim is currently in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen. +Area options: {Sam Kim's house, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}. +* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place. +* Must be one of the "Area options," verbatim. +For taking a walk, Sam Kim should go to the following area: {Johnson Park} +--- +Jane Anderson lives in {Oak Hill College Student Dormatory} that has Jane Anderson's room. +Jane Anderson is currently in {Oak Hill College} that has a classroom, library +Area options: {Oak Hill College Student Dormatory, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}. +* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place. +* Must be one of the "Area options," verbatim. +For eating dinner, Jane Anderson should go to the following area: {Hobbs Cafe} +---""" + s += (persona_name + " lives in " + living_sector + " that has " + + living_sector_areas + ".\n") + s += (persona_name + " is currently in " + current_sector + " that has " + + current_sector_areas + ".\n") + s += daily_plan + ".\n" + s += "Area options: " + sector_options + ".\n" + s += """* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place. +* Must be one of the "Area options," verbatim.\n""" + s += (persona_name + " is " + current_action + ". For " + next_action + + ", " + persona_name + " should go to the following area: {") + s += sgl.gen(name="Location", max_tokens=10, stop="}") + + +def action_location_sector_prompt( + persona_name, + living_sector, + living_sector_areas, + current_sector, + current_sector_areas, + daily_plan, + sector_options, + current_action, + next_action, +): + s = "" + s += """Task -- choose an appropriate area from the area options for a task at hand. +Sam Kim lives in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen. +Sam Kim is currently in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen. +Area options: {Sam Kim's house, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}. +* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place. +* Must be one of the "Area options," verbatim. +For taking a walk, Sam Kim should go to the following area: {Johnson Park} +--- +Jane Anderson lives in {Oak Hill College Student Dormatory} that has Jane Anderson's room. +Jane Anderson is currently in {Oak Hill College} that has a classroom, library +Area options: {Oak Hill College Student Dormatory, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}. +* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place. +* Must be one of the "Area options," verbatim. +For eating dinner, Jane Anderson should go to the following area: {Hobbs Cafe} +---""" + s += (persona_name + " lives in " + living_sector + " that has " + + living_sector_areas + ".\n") + s += (persona_name + " is currently in " + current_sector + " that has " + + current_sector_areas + ".\n") + s += daily_plan + ".\n" + s += "Area options: " + sector_options + ".\n" + s += """* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place. +* Must be one of the "Area options," verbatim.\n""" + s += (persona_name + " is " + current_action + ". For " + next_action + + ", " + persona_name + " should go to the following area: {") + return {"prompt": s, "max_tokens": 10, "stop": "}"} + + +@sgl.function +def action_location_object(s, persona_name, target_sector, target_sector_areas, + current_action, next_action): + s += """ +Jane Anderson is in kitchen in Jane Anderson's house. +Jane Anderson is going to Jane Anderson's house that has the following areas: {kitchen, bedroom, bathroom} +Stay in the current area if the activity can be done there. Never go into other people's rooms unless necessary. +For cooking, Jane Anderson should go to the following area in Jane Anderson's house: +Answer: {kitchen} +--- +Tom Watson is in common room in Tom Watson's apartment. +Tom Watson is going to Hobbs Cafe that has the following areas: {cafe} +Stay in the current area if the activity can be done there. Never go into other people's rooms unless necessary. +For getting coffee, Tom Watson should go to the following area in Hobbs Cafe: +Answer: {cafe} +---""" + s += (persona_name + " is going to " + target_sector + + " that has the following areas: {" + target_sector_areas + "}\n") + s += """* Stay in the current area if the activity can be done there. +* NEVER go into other people's rooms unless necessary.""" + s += (persona_name + " is " + current_action + ". For " + next_action + + ", " + persona_name + "should go to the following area in " + + target_sector) + s += " (MUST pick one of {" + target_sector_areas + "}):\n" + s += "Answer: {" + sgl.gen(name="Area", max_tokens=5, stop="}") + + +def action_location_object_prompt(persona_name, target_sector, + target_sector_areas, current_action, + next_action): + s = "" + s += """ +Jane Anderson is in kitchen in Jane Anderson's house. +Jane Anderson is going to Jane Anderson's house that has the following areas: {kitchen, bedroom, bathroom} +Stay in the current area if the activity can be done there. Never go into other people's rooms unless necessary. +For cooking, Jane Anderson should go to the following area in Jane Anderson's house: +Answer: {kitchen} +--- +Tom Watson is in common room in Tom Watson's apartment. +Tom Watson is going to Hobbs Cafe that has the following areas: {cafe} +Stay in the current area if the activity can be done there. Never go into other people's rooms unless necessary. +For getting coffee, Tom Watson should go to the following area in Hobbs Cafe: +Answer: {cafe} +---""" + s += (persona_name + " is going to " + target_sector + + " that has the following areas: {" + target_sector_areas + "}\n") + s += """* Stay in the current area if the activity can be done there. +* NEVER go into other people's rooms unless necessary.""" + s += (persona_name + " is " + current_action + ". For " + next_action + + ", " + persona_name + "should go to the following area in " + + target_sector) + s += " (MUST pick one of {" + target_sector_areas + "}):\n" + s += "Answer: {" + return {"prompt": s, "max_tokens": 5, "stop": "}"} diff --git a/benchmark/generative_agents/bench_other.py b/benchmark/generative_agents/bench_other.py new file mode 100644 index 000000000..7cf8d40b8 --- /dev/null +++ b/benchmark/generative_agents/bench_other.py @@ -0,0 +1,104 @@ +import argparse +from functools import partial +import json +import time +from pathlib import Path + +from tqdm import tqdm +from sglang.test.test_utils import ( + add_common_other_args_and_parse, + call_generate_lightllm, + call_generate_vllm, + call_generate_srt_raw, +) +from sglang.utils import read_jsonl, dump_state_text + +from agent_functions import ( + poignancy_event_prompt, + generate_event_triple_prompt, + generate_pronunciatio_prompt, + action_location_sector_prompt, + action_location_object_prompt, +) + + +def main(args): + lines = read_jsonl(args.data_path)[:args.num_events] + mapping = { + "poignancy_event": poignancy_event_prompt, + "generate_event_triple": generate_event_triple_prompt, + "generate_pronunciatio": generate_pronunciatio_prompt, + "action_location_sector": action_location_sector_prompt, + "action_location_object": action_location_object_prompt, + } + + arguments = [mapping[k](**v) for l in lines for k, v in l.items()] + states = [] + + # Select backend + if args.backend == "lightllm": + url = f"{args.host}:{args.port}/generate" + call_generate = partial(call_generate_lightllm, url=url) + elif args.backend == "vllm": + url = f"{args.host}:{args.port}/generate" + call_generate = partial(call_generate_vllm, url=url) + elif args.backend == "srt-raw": + url = f"{args.host}:{args.port}/generate" + call_generate = partial(call_generate_srt_raw, url=url) + elif args.backend == "guidance": + from guidance import models, gen + + model = models.LlamaCpp( + str(Path.home()) + "/model_weights/Llama-2-7b-chat.gguf", + n_gpu_layers=-1, + n_ctx=4096, + ) + + def call_generate(prompt, temperature, max_tokens, stop): + out = model + prompt + gen( + name="result", + max_tokens=max_tokens, + temperature=temperature, + stop=stop, + ) + return out["result"] + + else: + raise ValueError(f"Invalid backend: {args.backend}") + + def get_one_answer(arg): + answer = call_generate(**arg, temperature=0) + states.append(answer) + + tic = time.time() + # we always sequentially execute agent calls to maintain its dependency + for arg in tqdm(arguments): + get_one_answer(arg) + latency = time.time() - tic + + print(f"Latency: {latency:.3f}") + + # Write results + dump_state_text(f"tmp_output_{args.backend}.txt", states) + + with open(args.result_file, "a") as fout: + value = { + "task": "Generative Agents", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + # to pack weighted functions as a single agent + "num_requests": len(arguments) / len(mapping), + "other": { + "parallel": args.parallel, + }, + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, default="agent_calls.jsonl") + parser.add_argument("--num-events", type=int, default=10) + args = add_common_other_args_and_parse(parser) + main(args) diff --git a/benchmark/generative_agents/bench_sglang.py b/benchmark/generative_agents/bench_sglang.py new file mode 100644 index 000000000..bf03c7686 --- /dev/null +++ b/benchmark/generative_agents/bench_sglang.py @@ -0,0 +1,74 @@ +import argparse +import json +import time + +import sglang as sgl +from sglang.test.test_utils import ( + add_common_sglang_args_and_parse, + select_sglang_backend, +) +from sglang.utils import read_jsonl, dump_state_text + +from agent_functions import ( + poignancy_event, + generate_event_triple, + generate_pronunciatio, + action_location_sector, + action_location_object, +) + + +def main(args): + lines = read_jsonl(args.data_path)[:args.num_events] + mapping = { + "poignancy_event": poignancy_event, + "generate_event_triple": generate_event_triple, + "generate_pronunciatio": generate_pronunciatio, + "action_location_sector": action_location_sector, + "action_location_object": action_location_object, + } + arguments = [{mapping[k]: v for k, v in l.items()} for l in lines] + + # Select backend + backend = select_sglang_backend(args) + sgl.set_default_backend(backend) + + states = [] + # Run requests + tic = time.time() + for a in arguments: + # only a single key in the dict + for func, arg in a.items(): + result = func.run(**arg) + result.sync() + states.append(result) + latency = time.time() - tic + + # Compute accuracy + print(f"Latency: {latency:.3f}") + + # Write results + dump_state_text(f"tmp_output_{args.backend}.txt", states) + + with open(args.result_file, "a") as fout: + value = { + "task": "Generative Agents", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + # to pack weighted functions as a single agent + "num_requests": len(arguments) / len(mapping), + "other": { + "num_events": args.num_events, + "parallel": args.parallel, + }, + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, default="agent_calls.jsonl") + parser.add_argument("--num-events", type=int, default=10) + args = add_common_sglang_args_and_parse(parser) + main(args) diff --git a/benchmark/gsm8k/README.md b/benchmark/gsm8k/README.md new file mode 100644 index 000000000..ffd2bcf9d --- /dev/null +++ b/benchmark/gsm8k/README.md @@ -0,0 +1,52 @@ +## Download data +``` +wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl +``` + +## Run benchmark + +### Benchmark sglang +``` +python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 +``` + +``` +python3 bench_sglang.py --num-questions 200 +``` + + +### Benchmark vllm +``` +python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000 +``` + +``` +python3 bench_other.py --num-questions 200 --backend vllm +``` + + +### Benchmark lightllm +``` +# A10G +python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000 +``` + +``` +python3 bench_other.py --num-questions 200 --backend lightllm +``` + + +### Benchmark guidance +``` +python3 bench_other.py --num-questions 200 --backend guidance --parallel 1 +``` + + +### Benchmark lmql +``` +CUDA_VISIBLE_DEVICES=0,1 lmql serve-model meta-llama/Llama-2-7b-chat-hf --cuda --port 23000 +``` + +``` +python3 bench_other.py --num-questions 100 --backend lmql --parallel 2 +``` diff --git a/benchmark/gsm8k/bench_other.py b/benchmark/gsm8k/bench_other.py new file mode 100644 index 000000000..297b534ff --- /dev/null +++ b/benchmark/gsm8k/bench_other.py @@ -0,0 +1,168 @@ +import argparse +import ast +import asyncio +from concurrent.futures import ThreadPoolExecutor +from functools import partial +import json +import re +import time + +import numpy as np +from tqdm import tqdm +from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw +from sglang.utils import read_jsonl, dump_state_text + + +INVALID = -9999999 + + +def get_one_example(lines, i, include_answer): + ret = "Question: " + lines[i]["question"] + "\nAnswer:" + if include_answer: + ret += " " + lines[i]["answer"] + return ret + + +def get_few_shot_examples(lines, k): + ret = "" + for i in range(k): + ret += get_one_example(lines, i, True) + "\n\n" + return ret + + +def get_answer_value(answer_str): + answer_str = answer_str.replace(",", "") + numbers = re.findall(r'\d+', answer_str) + if len(numbers) < 1: + return INVALID + try: + return ast.literal_eval(numbers[-1]) + except SyntaxError: + return INVALID + + +def main(args): + lines = read_jsonl(args.data_path) + + # Construct prompts + k = args.num_shot + few_shot_examples = get_few_shot_examples(lines, k) + + questions = [] + labels = [] + for i in range(len(lines[:args.num_questions])): + questions.append(get_one_example(lines, i, False)) + labels.append(get_answer_value(lines[i]["answer"])) + assert all(l != INVALID for l in labels) + + states = [None] * len(labels) + + # Select backend + if args.backend == "lightllm": + url = f"{args.host}:{args.port}/generate" + call_generate = partial(call_generate_lightllm, url=url) + elif args.backend == "vllm": + url = f"{args.host}:{args.port}/generate" + call_generate = partial(call_generate_vllm, url=url) + elif args.backend == "srt-raw": + url = f"{args.host}:{args.port}/generate" + call_generate = partial(call_generate_srt_raw, url=url) + elif args.backend == "guidance": + from guidance import models, gen + + model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096) + + def call_generate(prompt, temperature, max_tokens, stop): + out = model + prompt + gen(name="answer", + max_tokens=max_tokens, temperature=temperature, stop=stop) + return out["answer"] + + elif args.backend == "lmql": + import lmql + model = lmql.model(args.model_path, + endpoint=f"{args.host}:{args.port}") + + @lmql.query(model=model) + async def program(question): + '''lmql + """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < 257 and STOPS_AT(ANSWER, "Question") + return ANSWER + ''' + + async def call_generate(prompt, temperature, max_tokens, stop): + return await program(question=prompt, temperature=0) + + else: + raise ValueError(f"Invalid backend: {args.backend}") + + # Run requests + if args.backend != "lmql": + # Use thread pool + def get_one_answer(i): + answer = call_generate( + prompt=few_shot_examples + questions[i], + temperature=0, + max_tokens=256, + stop="Question") + states[i] = answer + + tic = time.time() + if args.parallel == 1: + for i in tqdm(range(len(questions))): + get_one_answer(i) + else: + with ThreadPoolExecutor(args.parallel) as executor: + executor.map(get_one_answer, list(range(len(questions)))) + else: + # Use asyncio + async def batched_call(batch_size): + for i in range(0, len(questions), batch_size): + tasks = [] + for q in questions[i:i+batch_size]: + tasks.append(call_generate(few_shot_examples + q, + temperature=0, max_tokens=256, stop="Question")) + rets = await asyncio.gather(*tasks) + for j in range(len(rets)): + states[i+j] = rets[j] + + tic = time.time() + asyncio.run(batched_call(batch_size=args.parallel)) + latency = time.time() - tic + + preds = [] + for i in range(len(states)): + preds.append(get_answer_value(states[i])) + + # Compute accuracy + acc = np.mean(np.array(preds) == np.array(labels)) + invalid = np.mean(np.array(preds) == INVALID) + print(f"Latency: {latency:.3f}") + print(f"Invalid: {invalid:.3f}") + print(f"Accuracy: {acc:.3f}") + + # Write results + dump_state_text(f"tmp_output_{args.backend}.txt", states) + + with open(args.result_file, "a") as fout: + value = { + "task": "gsm8k", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + "accuracy": round(acc, 3), + "num_requests": args.num_questions, + "other": { + "num_questions": args.num_questions, + "parallel": args.parallel, + } + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--num-shot", type=int, default=5) + parser.add_argument("--data-path", type=str, default="test.jsonl") + parser.add_argument("--num-questions", type=int, default=200) + args = add_common_other_args_and_parse(parser) + main(args) diff --git a/benchmark/gsm8k/bench_sglang.py b/benchmark/gsm8k/bench_sglang.py new file mode 100644 index 000000000..c5d76af31 --- /dev/null +++ b/benchmark/gsm8k/bench_sglang.py @@ -0,0 +1,115 @@ +import argparse +import ast +import json +import re +import time + +import numpy as np +from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend +from sglang.utils import read_jsonl, dump_state_text + + +INVALID = -9999999 + + +def get_one_example(lines, i, include_answer): + ret = "Question: " + lines[i]["question"] + "\nAnswer:" + if include_answer: + ret += " " + lines[i]["answer"] + return ret + + +def get_few_shot_examples(lines, k): + ret = "" + for i in range(k): + ret += get_one_example(lines, i, True) + "\n\n" + return ret + + +def get_answer_value(answer_str): + answer_str = answer_str.replace(",", "") + numbers = re.findall(r'\d+', answer_str) + if len(numbers) < 1: + return INVALID + try: + return ast.literal_eval(numbers[-1]) + except SyntaxError: + return INVALID + + +def main(args): + lines = read_jsonl(args.data_path) + + # Construct prompts + k = args.num_shot + few_shot_examples = get_few_shot_examples(lines, k) + + questions = [] + labels = [] + for i in range(len(lines[:args.num_questions])): + questions.append(get_one_example(lines, i, False)) + labels.append(get_answer_value(lines[i]["answer"])) + assert all(l != INVALID for l in labels) + arguments = [{"question": q} for q in questions] + + ##################################### + ######### SGL Program Begin ######### + ##################################### + + import sglang as sgl + + @sgl.function + def few_shot_gsm8k(s, question): + s += few_shot_examples + question + s += sgl.gen("answer", max_tokens=256, stop="Question") + + ##################################### + ########## SGL Program End ########## + ##################################### + + # Select backend + backend = select_sglang_backend(args) + + # Run requests + tic = time.time() + states = few_shot_gsm8k.run_batch( + arguments, temperature=0, backend=backend, num_threads=args.parallel) + latency = time.time() - tic + + preds = [] + for i in range(len(states)): + preds.append(get_answer_value(states[i]["answer"])) + + # Compute accuracy + acc = np.mean(np.array(preds) == np.array(labels)) + invalid = np.mean(np.array(preds) == INVALID) + print(f"Latency: {latency:.3f}") + print(f"Invalid: {invalid:.3f}") + print(f"Accuracy: {acc:.3f}") + + # Write results + dump_state_text(f"tmp_output_{args.backend}.txt", states) + + with open(args.result_file, "a") as fout: + value = { + "task": "gsm8k", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + "accuracy": round(acc, 3), + "num_requests": args.num_questions, + "other": { + "num_questions": args.num_questions, + "parallel": args.parallel, + } + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--num-shot", type=int, default=5) + parser.add_argument("--data-path", type=str, default="test.jsonl") + parser.add_argument("--num-questions", type=int, default=200) + args = add_common_sglang_args_and_parse(parser) + main(args) diff --git a/benchmark/hellaswag/README.md b/benchmark/hellaswag/README.md new file mode 100644 index 000000000..c2b7b2aa2 --- /dev/null +++ b/benchmark/hellaswag/README.md @@ -0,0 +1,52 @@ +## Download data +``` +wget https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl +``` + +## Run benchmark + +### Benchmark sglang +``` +python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 +``` + +``` +python3 bench_sglang.py --num-questions 200 +``` + + +### Benchmark vllm +``` +python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000 +``` + +``` +python3 bench_other.py --num-questions 200 --backend vllm +``` + + +### Benchmark lightllm +``` +# A10G +python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000 +``` + +``` +python3 bench_other.py --num-questions 200 --backend lightllm +``` + + +### Benchmark guidance +``` +CUDA_VISIBLE_DEVICES=0,1 python3 bench_other.py --num-questions 200 --backend guidance --parallel 1 +``` + + +### Benchmark lmql +``` +lmql serve-model meta-llama/Llama-2-7b-chat-hf --cuda --port 23000 +``` + +``` +python3 bench_other.py --num-questions 200 --backend lmql --port 23000 --parallel 1 +``` diff --git a/benchmark/hellaswag/bench_other.py b/benchmark/hellaswag/bench_other.py new file mode 100644 index 000000000..bdab5cecc --- /dev/null +++ b/benchmark/hellaswag/bench_other.py @@ -0,0 +1,140 @@ +import argparse +import asyncio +from concurrent.futures import ThreadPoolExecutor +import json +from functools import partial +import time + +import numpy as np +from sglang.test.test_utils import add_common_other_args_and_parse, call_select_lightllm, call_select_vllm +from sglang.utils import read_jsonl + + +def get_one_example(lines, i, include_answer): + ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " " + if include_answer: + ret += lines[i]["endings"][lines[i]["label"]] + return ret + + +def get_few_shot_examples(lines, k): + ret = "" + for i in range(k): + ret += get_one_example(lines, i, True) + "\n\n" + return ret + + +def main(args): + lines = read_jsonl(args.data_path) + + # Construct prompts + k = args.num_shot + few_shot_examples = get_few_shot_examples(lines, k) + + questions = [] + choices = [] + labels = [] + for i in range(len(lines[:args.num_questions])): + questions.append(get_one_example(lines, i, False)) + choices.append(lines[i]["endings"]) + labels.append(lines[i]["label"]) + + preds = [None] * len(labels) + + # Select backend + if args.backend == "lightllm": + url = f"{args.host}:{args.port}/generate" + call_select = partial(call_select_lightllm, url=url) + elif args.backend == "vllm": + url = f"{args.host}:{args.port}/generate" + call_select = partial(call_select_vllm, url=url) + elif args.backend == "guidance": + from guidance import models, select + + model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096) + + def call_select(context, choices): + out = model + context + select(choices, name="answer") + return choices.index(out["answer"]) + + elif args.backend == "lmql": + import lmql + model = lmql.model("meta-llama/Llama-2-7b-chat-hf", + endpoint=f"{args.host}:{args.port}") + + @lmql.query(model=model) + async def program(ctx, choices): + '''lmql + """{ctx}[ANSWER]""" where ANSWER in set(choices) + return ANSWER + ''' + + async def call_select(context, choices): + answer = await program(ctx=context, choices=choices, temperature=0) + return choices.index(answer) + + else: + raise ValueError(f"Invalid backend: {args.backend}") + + # Run requests + if args.backend != "lmql": + # Use thread pool + def get_one_answer(i): + preds[i] = call_select( + context=few_shot_examples + questions[i], + choices=choices[i]) + + tic = time.time() + if args.parallel == 1: + for i in range(len(questions)): + get_one_answer(i) + else: + with ThreadPoolExecutor(args.parallel) as executor: + executor.map(get_one_answer, list(range(len(questions)))) + else: + # Use asyncio + async def batched_call(batch_size): + for i in range(0, len(questions), batch_size): + tasks = [] + for q, c in zip(questions[i:i+batch_size], choices[i:i+batch_size]): + tasks.append(call_select( + context=few_shot_examples + q, + choices=c)) + rets = await asyncio.gather(*tasks) + for j in range(len(rets)): + preds[i+j] = rets[j] + + tic = time.time() + asyncio.run(batched_call(batch_size=args.parallel)) + + latency = time.time() - tic + + # Compute accuracy + acc = np.mean(np.array(preds) == np.array(labels)) + print(f"Latency: {latency:.3f}") + print(f"Accuracy: {acc:.3f}") + + # Write results + with open(args.result_file, "a") as fout: + value = { + "task": "hellaswag", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + "accuracy": round(acc, 3), + "num_requests": args.num_questions, + "other": { + "num_questions": args.num_questions, + "parallel": args.parallel, + } + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--num-shot", type=int, default=20) + parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl") + parser.add_argument("--num-questions", type=int, default=100) + args = add_common_other_args_and_parse(parser) + main(args) diff --git a/benchmark/hellaswag/bench_sglang.py b/benchmark/hellaswag/bench_sglang.py new file mode 100644 index 000000000..c94386395 --- /dev/null +++ b/benchmark/hellaswag/bench_sglang.py @@ -0,0 +1,96 @@ +import argparse +import json +import time + +import numpy as np +from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend +from sglang.utils import read_jsonl + + +def get_one_example(lines, i, include_answer): + ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " " + if include_answer: + ret += lines[i]["endings"][lines[i]["label"]] + return ret + + +def get_few_shot_examples(lines, k): + ret = "" + for i in range(k): + ret += get_one_example(lines, i, True) + "\n\n" + return ret + + +def main(args): + lines = read_jsonl(args.data_path) + + # Construct prompts + k = args.num_shot + few_shot_examples = get_few_shot_examples(lines, k) + + questions = [] + choices = [] + labels = [] + for i in range(len(lines[:args.num_questions])): + questions.append(get_one_example(lines, i, False)) + choices.append(lines[i]["endings"]) + labels.append(lines[i]["label"]) + arguments = [ + {"question": q, "choices": c} + for q, c in zip(questions, choices) + ] + + ##################################### + ######### SGL Program Begin ######### + ##################################### + + import sglang as sgl + + @sgl.function + def few_shot_hellaswag(s, question, choices): + s += few_shot_examples + question + s += sgl.select("answer", choices=choices) + + ##################################### + ########## SGL Program End ########## + ##################################### + + # Select backend + backend = select_sglang_backend(args) + + # Run requests + tic = time.time() + rets = few_shot_hellaswag.run_batch( + arguments, temperature=0, backend=backend, num_threads=args.parallel) + preds = [choices[i].index(rets[i]["answer"]) for i in range(len(rets))] + latency = time.time() - tic + + # Compute accuracy + acc = np.mean(np.array(preds) == np.array(labels)) + print(f"Latency: {latency:.3f}") + print(f"Accuracy: {acc:.3f}") + + # Write results + with open(args.result_file, "a") as fout: + value = { + "task": "hellaswag", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + "accuracy": round(acc, 3), + "num_requests": args.num_questions, + "other": { + "num_questions": args.num_questions, + "parallel": args.parallel, + } + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--num-shot", type=int, default=20) + parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl") + parser.add_argument("--num-questions", type=int, default=100) + args = add_common_sglang_args_and_parse(parser) + main(args) diff --git a/benchmark/latency_throughput/README.md b/benchmark/latency_throughput/README.md new file mode 100644 index 000000000..785b3289b --- /dev/null +++ b/benchmark/latency_throughput/README.md @@ -0,0 +1,46 @@ +### Download data +``` +wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json +``` + +### Performance + +- Model: Llama-2-7b-chat-hf +- `--num-prompts 2000 --request-rate 200` +- On 4 A10 (24G) GPUs + +| Backend | Throughput | Latency | +| ----------- | --------------- | -------- | +| srt | 5.82 requests/s | 343.54 s | +| vllm==0.2.6 | 3.93 requests/s | 509.08 s | +| vllm==0.2.7 | 5.02 requests/s | 398.25 s | + + +### SGLang +``` +python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 +``` + +``` +python3 bench_throughput.py --backend srt --tokenizer meta-llama/Llama-2-7b-chat-hf --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 10 --request-rate 10 --port 30000 +``` + + +### vLLM +``` +python3 -m vllm.entrypoints.api_server --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --swap-space 16 +``` + +``` +python3 bench_throughput.py --backend vllm --tokenizer meta-llama/Llama-2-7b-chat-hf --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 10 --request-rate 10 +``` + + +### LightLLM +``` +python -m lightllm.server.api_server --model_dir ~/model_weights/Llama-2-7b-chat-hf --max_total_token_num 15600 --tokenizer_mode auto --port 22000 +``` + +``` +python3 bench_throughput.py --backend lightllm --tokenizer meta-llama/Llama-2-7b-chat-hf --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 10 --request-rate 10 --port 22000 +``` diff --git a/benchmark/latency_throughput/bench_throughput.py b/benchmark/latency_throughput/bench_throughput.py new file mode 100644 index 000000000..63c88dc48 --- /dev/null +++ b/benchmark/latency_throughput/bench_throughput.py @@ -0,0 +1,254 @@ +"""Benchmark online serving throughput. + +On the server side, run one of the following commands: + (vLLM backend) + python -m vllm.entrypoints.api_server \ + --model --swap-space 16 \ + --disable-log-requests + + (TGI backend) + ./launch_hf_server.sh + +On the client side, run: + python benchmarks/benchmark_serving.py \ + --backend \ + --tokenizer --dataset \ + --request-rate +""" +import argparse +import asyncio +import json +import random +import time +from typing import AsyncGenerator, List, Tuple +from tqdm.asyncio import tqdm_asyncio + +import aiohttp +import numpy as np +from transformers import PreTrainedTokenizerBase +from vllm.transformers_utils.tokenizer import get_tokenizer + +# (prompt len, output len, latency) +REQUEST_LATENCY: List[Tuple[int, int, float]] = [] + + +def sample_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, +) -> List[Tuple[str, int, int]]: + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [ + data for data in dataset + if len(data["conversations"]) >= 2 + ] + # Only keep the first two turns of each conversation. + dataset = [ + (data["conversations"][0]["value"], data["conversations"][1]["value"]) + for data in dataset + ] + + # Tokenize the prompts and completions. + prompts = [prompt for prompt, _ in dataset] + prompt_token_ids = tokenizer(prompts).input_ids + completions = [completion for _, completion in dataset] + completion_token_ids = tokenizer(completions).input_ids + tokenized_dataset = [] + for i in range(len(dataset)): + output_len = len(completion_token_ids[i]) + tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len)) + + # Filter out too long sequences. + filtered_dataset: List[Tuple[str, int, int]] = [] + for prompt, prompt_token_ids, output_len in tokenized_dataset: + prompt_len = len(prompt_token_ids) + if prompt_len < 4 or output_len < 4: + # Prune too short sequences. + # This is because TGI causes errors when the input or output length + # is too short. + continue + if prompt_len > 1024 or prompt_len + output_len > 2048: + # Prune too long sequences. + continue + filtered_dataset.append((prompt, prompt_len, output_len)) + + # Sample the requests. + sampled_requests = random.sample(filtered_dataset, num_requests) + return sampled_requests + + +async def get_request( + input_requests: List[Tuple[str, int, int]], + request_rate: float, +) -> AsyncGenerator[Tuple[str, int, int], None]: + input_requests = iter(input_requests) + for request in input_requests: + yield request + + if request_rate == float("inf"): + # If the request rate is infinity, then we don't need to wait. + continue + # Sample the request interval from the exponential distribution. + interval = np.random.exponential(1.0 / request_rate) + # The next request will be sent after the interval. + await asyncio.sleep(interval) + + +async def send_request( + backend: str, + api_url: str, + prompt: str, + prompt_len: int, + output_len: int, + best_of: int, + use_beam_search: bool, +) -> None: + request_start_time = time.perf_counter() + + headers = {"User-Agent": "Benchmark Client"} + if backend == "vllm": + pload = { + "prompt": prompt, + "n": 1, + "best_of": best_of, + "use_beam_search": use_beam_search, + "temperature": 0.0 if use_beam_search else 1.0, + "top_p": 1.0, + "max_tokens": output_len, + "ignore_eos": True, + "stream": False, + } + elif backend == "tgi": + assert not use_beam_search + params = { + "best_of": best_of, + "max_new_tokens": output_len, + "do_sample": True, + } + pload = { + "inputs": prompt, + "parameters": params, + } + elif backend == "srt": + assert not use_beam_search + params = { + "ignore_eos": True, + "max_new_tokens": output_len, + } + pload = { + "text": prompt, + "sampling_params": params, + } + elif backend == "lightllm": + assert not use_beam_search + params = { + "ignore_eos": True, + "max_new_tokens": output_len, + } + pload = { + "inputs": prompt, + "parameters": params, + } + else: + raise ValueError(f"Unknown backend: {backend}") + + timeout = aiohttp.ClientTimeout(total=3 * 3600) + async with aiohttp.ClientSession(timeout=timeout) as session: + while True: + async with session.post(api_url, headers=headers, json=pload) as response: + chunks = [] + async for chunk, _ in response.content.iter_chunks(): + chunks.append(chunk) + output = b"".join(chunks).decode("utf-8") + output = json.loads(output) + + # Re-send the request if it failed. + if "error" not in output: + break + + request_end_time = time.perf_counter() + request_latency = request_end_time - request_start_time + REQUEST_LATENCY.append((prompt_len, output_len, request_latency)) + + +async def benchmark( + backend: str, + api_url: str, + input_requests: List[Tuple[str, int, int]], + best_of: int, + use_beam_search: bool, + request_rate: float, +) -> None: + tasks: List[asyncio.Task] = [] + async for request in get_request(input_requests, request_rate): + prompt, prompt_len, output_len = request + task = asyncio.create_task(send_request(backend, api_url, prompt, + prompt_len, output_len, + best_of, use_beam_search)) + tasks.append(task) + await tqdm_asyncio.gather(*tasks) + + +def main(args: argparse.Namespace): + print(args) + random.seed(args.seed) + np.random.seed(args.seed) + + api_url = f"http://{args.host}:{args.port}/generate" + tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code) + input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer) + + benchmark_start_time = time.perf_counter() + asyncio.run(benchmark(args.backend, api_url, input_requests, args.best_of, + args.use_beam_search, args.request_rate)) + benchmark_end_time = time.perf_counter() + benchmark_time = benchmark_end_time - benchmark_start_time + print(f"Total time: {benchmark_time:.2f} s") + print(f"Throughput: {args.num_prompts / benchmark_time:.2f} requests/s") + + # Compute the latency statistics. + avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY]) + print(f"Average latency: {avg_latency:.2f} s") + avg_per_token_latency = np.mean([ + latency / (prompt_len + output_len) + for prompt_len, output_len, latency in REQUEST_LATENCY + ]) + print(f"Average latency per token: {avg_per_token_latency:.2f} s") + avg_per_output_token_latency = np.mean([ + latency / output_len + for _, output_len, latency in REQUEST_LATENCY + ]) + print("Average latency per output token: " + f"{avg_per_output_token_latency:.2f} s") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Benchmark the online serving throughput.") + parser.add_argument("--backend", type=str, default="vllm", + choices=["vllm", "tgi", "srt", "lightllm"]) + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--dataset", type=str, required=True, + help="Path to the dataset.") + parser.add_argument("--tokenizer", type=str, required=True, + help="Name or path of the tokenizer.") + parser.add_argument("--best-of", type=int, default=1, + help="Generates `best_of` sequences per prompt and " + "returns the best one.") + parser.add_argument("--use-beam-search", action="store_true") + parser.add_argument("--num-prompts", type=int, default=1000, + help="Number of prompts to process.") + parser.add_argument("--request-rate", type=float, default=float("inf"), + help="Number of requests per second. If this is inf, " + "then all the requests are sent at time 0. " + "Otherwise, we use Poisson process to synthesize " + "the request arrival times.") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument('--trust-remote-code', action='store_true', + help='trust remote code from huggingface') + args = parser.parse_args() + main(args) diff --git a/benchmark/latency_throughput/test_latency.py b/benchmark/latency_throughput/test_latency.py new file mode 100644 index 000000000..a58c98851 --- /dev/null +++ b/benchmark/latency_throughput/test_latency.py @@ -0,0 +1,66 @@ +import argparse +import random +import time + +import requests + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="http://127.0.0.1") + parser.add_argument("--port", type=int, default=None) + parser.add_argument("--backend", type=str, default="srt") + args = parser.parse_args() + + if args.port is None: + if args.backend == "srt": + args.port = 30000 + elif args.backend == "vllm": + args.port = 21000 + elif args.backend == "lightllm": + args.port = 22000 + else: + raise ValueError(f"Invalid backend: {args.backend}") + + url = f"{args.host}:{args.port}" + a = random.randint(0, 1 << 20) + max_new_tokens = 256 + + tic = time.time() + if args.backend == "srt": + response = requests.post( + url + "/generate", + json={ + "text": f"{a}, ", + "sampling_params": { + "temperature": 0, + "max_new_tokens": max_new_tokens, + }, + }, + ) + elif args.backend == "lightllm": + response = requests.post( + url + "/generate", + json={ + "inputs": f"{a}, ", + "parameters": { + "temperature": 0, + "max_new_tokens": max_new_tokens, + }, + }, + ) + elif args.backend == "vllm": + response = requests.post( + url + "/generate", + json={ + "prompt": f"{a}, ", + "temperature": 0, + "max_tokens": max_new_tokens, + }, + ) + latency = time.time() - tic + + ret = response.json() + print(ret) + + speed = max_new_tokens / latency + print(f"latency: {latency:.2f} s, speed: {speed:.2f} token/s") diff --git a/benchmark/line_retrieval/README.md b/benchmark/line_retrieval/README.md new file mode 100644 index 000000000..3c81a63cb --- /dev/null +++ b/benchmark/line_retrieval/README.md @@ -0,0 +1,37 @@ +## Download data + +``` +wget https://raw.githubusercontent.com/merrymercy/merrymercy.github.io/master/files/random_words.json +python3 gen_data.py --number 1000 +``` + +## Run benchmark + +### Benchmark sglang +``` +python3 -m sglang.launch_server --model-path codellama/CodeLlama-7b-hf --port 30000 +``` + +``` +python3 bench_sglang.py --src-index 600 --num-q 50 --parallel 1 +``` + + +### + +``` +# original +Accuracy: 0.940, latency: 332.83 s + +# parallel encoding (no_adjust, offset = 1000) +Accuracy: 0.760, latency: 238.46 s + +# parallel encoding (no_adjust, offset = 3000) +Accuracy: 0.760, latency: 238.46 s + +# parallel encoding (no_adjust, offset = 0) +Accuracy: 0.520, latency: 238.46 s + +# parallel encoding (adjust_cache) +Accuracy: 0.460, latency: 257.66 s +``` diff --git a/benchmark/line_retrieval/bench_sglang.py b/benchmark/line_retrieval/bench_sglang.py new file mode 100644 index 000000000..5ac56a491 --- /dev/null +++ b/benchmark/line_retrieval/bench_sglang.py @@ -0,0 +1,133 @@ +import argparse +import json +import time +import re + +import numpy as np +import sglang as sgl +from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend +from sglang.utils import dump_state_text + + +@sgl.function +def line_retrieval(s, prefix, suffix, body_0, body_1, body_2, body_3): + s += prefix + "\n" + + contexts = [body_0, body_1, body_2, body_3] + position_ids_offset = [i * 1000 for i in range(len(contexts))] + forks = s.fork(len(contexts), position_ids_offset) + forks += lambda i: contexts[i] + "\n" + forks.join(mode="concate_and_append") + + s += "\n" + suffix + s += sgl.gen("answer", max_tokens=16) + + +def eval_model(args, line_obj, num_hoops, src_indices, dst_percents): + arguments = [] + labels = [] + sum_src_indices = [] + sum_dst_indices = [] + + for i in range(len(src_indices)): + for j in range(len(dst_percents)): + src_index = src_indices[i] + dst_percent = dst_percents[j] + + query_indices = line_obj["group_by_num_hoops"][str(num_hoops)] + query_indices = [q for q in query_indices if + all(l <= src_index for l in line_obj["links"][q]) and q < src_index] + dst_index = query_indices[min(int(len(query_indices) * dst_percent), len(query_indices)-1)] + label = line_obj["values"][dst_index] + + body = line_obj["lines"][:src_index+1] + suffix = line_obj["suffix"].replace("???", line_obj["indices"][dst_index]) + body_part_len = len(body) // 4 + + arguments.append({ + "prefix": line_obj["prefix"], + "body_0": "\n".join(body[:body_part_len]), + "body_1": "\n".join(body[body_part_len: 2 * body_part_len]), + "body_2": "\n".join(body[2 * body_part_len: 3 * body_part_len]), + "body_3": "\n".join(body[3 * body_part_len:]), + "suffix": suffix, + }) + labels.append(label) + sum_src_indices.append(src_index) + sum_dst_indices.append(dst_index) + + # Select backend + backend = select_sglang_backend(args) + + tic = time.time() + states = line_retrieval.run_batch( + arguments, temperature=0, backend=backend, num_threads=args.parallel) + latency = time.time() - tic + + corrects = [] + for i in range(len(arguments)): + output = states[i]["answer"] + prompt_len = states[i].get_meta_info("answer").get("prompt_length", -1) + label = labels[i] + + # Try all numbers + findall = re.findall("\d+", output) + if not findall: + response_number = output + else: + for response_number in findall: + if response_number == label: + break + + correct = (response_number == label) + corrects.append(correct) + + # Log results + summary = ( + f"Line index: {sum_src_indices[i]} -> {sum_dst_indices[i]}, " + f"Prompt len: {prompt_len}, " + f"Correct: {correct}, " + f"Label: {label}, Predicted: {response_number}, " + ) + print(summary) + + accuracy = np.mean(corrects) + print(f"Accuracy: {accuracy:.3f}, latency: {latency:.2f} s") + + # Write results + dump_state_text(f"tmp_output_{args.backend}.txt", states) + + with open(args.result_file, "a") as fout: + value = { + "task": "line_retrieval", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + "num_requests": len(arguments), + "other": { + "num_questions": len(arguments), + "parallel": args.parallel, + } + } + fout.write(json.dumps(value) + "\n") + + +def main(args): + line_obj = json.load(open(args.data_path, "r")) + + num_hoops = args.num_hoops + for src_index in args.src_index: + src_indices = [src_index] + num_queries = args.num_queries_per_src + dst_percents = [i * (1 / (num_queries)) for i in range(num_queries)] + eval_model(args, line_obj, num_hoops, src_indices, dst_percents) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, default="lines_1000_0.0.json") + parser.add_argument("--src-index", type=int, nargs="+", default=[100]) + parser.add_argument("--num-queries-per-src", type=int, default=10) + parser.add_argument("--num-hoops", type=int, default=1) + args = add_common_sglang_args_and_parse(parser) + main(args) diff --git a/benchmark/line_retrieval/gen_data.py b/benchmark/line_retrieval/gen_data.py new file mode 100644 index 000000000..d5a189e31 --- /dev/null +++ b/benchmark/line_retrieval/gen_data.py @@ -0,0 +1,135 @@ +""" +Generate line data for line retrieval task. + +Usage: +python3 gen_data.py --number 1000 +""" +import argparse +from collections import defaultdict +import json + +from tqdm import tqdm +import numpy as np + + +def generate_lines(random_words, num_lines, redirect_ratio): + prefix = "Here is a list of lines, each with its corresponding REGISTER_CONTENT value. Please memorize them. Be prepared to provide the REGISTER_CONTENT value for a specific line index when I ask." + suffix = "The list has ended. Please give the final REGISTER_CONTENT value for a specific line after resovling the redirections and references. For example, the REGISTER_CONTENT of Line __idx0__ is __val0__. The REGISTER_CONTENT of Line __idx1__ is __val1__. The REGISTER_CONTENT of Line __idx2__ is __val2__. The REGISTER_CONTENT of Line ??? is" + + # Raw lines + visited_indices = set([None]) + visited_values = set([None]) + + lines = [] + redirects = [] + indices = [] + values = [] + for i in tqdm(range(num_lines)): + line_index = None + while line_index in visited_indices: + line_index = "-".join(np.random.choice(random_words, size=(2,))) + visited_indices.add(line_index) + + line_value = np.random.randint(low=0, high=999999) + line_value = f"{line_value:06}" + + line = f"Line {line_index}: The REGISTER_CONTENT is {line_value}." + lines.append(line) + redirects.append(None) + indices.append(line_index) + values.append(line_value) + + # Add redirect + if redirect_ratio > 0: + num_redirect_lines = int(len(lines) * redirect_ratio) + redirect_indices = np.random.choice(np.arange(len(lines)), + size=(num_redirect_lines,), replace=False) + for i in redirect_indices: + target_idx = np.random.choice(min(i * 2 + 100, num_lines)) + lines[i] = f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}." + redirects[i] = target_idx + + # Build links and find sources + links = [[] for _ in range(num_lines)] + contains_ring = set() + for i in range(num_lines): + if redirects[i] is None: + continue + + tmp_link = [] + cur = i + visited = set() + while redirects[cur] is not None: + visited.add(cur) + tmp_link.append(redirects[cur]) + cur = redirects[cur] + + if cur in visited: + contains_ring.add(i) + tmp_link = None + break + values[i] = values[cur] + links[i] = tmp_link + + # Group by num_links + group_by_num_hoops = defaultdict(list) + for i in range(num_lines): + if i in contains_ring: + continue + group_by_num_hoops[len(links[i]) + 1].append(i) + + keys = sorted(list(group_by_num_hoops.keys())) + for num_links in keys: + print(f"#links: {num_links}, #lines: {len(group_by_num_hoops[num_links])}") + + # Append few-shot examples + hoop1_candidates = list(group_by_num_hoops[1]) + hoop1_candidate_keys = {c: max([c] + links[c]) for c in hoop1_candidates} + hoop1_candidates.sort(key=lambda c: hoop1_candidate_keys[c]) + hoop2_candidates = list(group_by_num_hoops[2]) + hoop2_candidate_keys = {c: max([c] + links[c]) for c in hoop2_candidates} + hoop2_candidates.sort(key=lambda c: hoop2_candidate_keys[c]) + + i = hoop1_candidates[5] + suffix = suffix.replace("__idx0__", indices[i]).replace("__val0__", values[i]) + if len(hoop2_candidates): + i = hoop2_candidates[0] + suffix = suffix.replace("__idx1__", indices[i]).replace("__val1__", values[i]) + i = hoop2_candidates[1] + suffix = suffix.replace("__idx2__", indices[i]).replace("__val2__", values[i]) + else: + i = hoop1_candidates[1] + suffix = suffix.replace("__idx1__", indices[i]).replace("__val1__", values[i]) + i = hoop1_candidates[10] + suffix = suffix.replace("__idx2__", indices[i]).replace("__val2__", values[i]) + + obj = { + "prefix": prefix, + "suffix": suffix, + "lines": lines, + "indices": indices, + "values": values, + "links": links, + "group_by_num_hoops": group_by_num_hoops, + "contains_ring": sorted(list(contains_ring)), + } + return obj + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--number", type=int) + parser.add_argument("--redirect-ratio", type=float, default=0.0) + args = parser.parse_args() + + num_lines = args.number + + random_words_filename = "random_words.json" + random_words = json.load(open(random_words_filename, "r")) + + np.random.seed(42) + obj = generate_lines(random_words, num_lines, args.redirect_ratio) + + fout = f"lines_{num_lines}_{args.redirect_ratio:.1f}.json" + with open(fout, "w") as fout: + json.dump(obj, fout, indent=2) diff --git a/benchmark/llava_bench/README.md b/benchmark/llava_bench/README.md new file mode 100644 index 000000000..bae4df648 --- /dev/null +++ b/benchmark/llava_bench/README.md @@ -0,0 +1,60 @@ +## Download benchmark images + +``` +python3 download_images.py +``` + +image benchmark source: https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild + +### Other Dependency +``` +pip3 install "torch>=2.1.2" "transformers>=4.36" pillow +``` + +## Run benchmark + +### Benchmark sglang +Launch a server +``` +python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000 +``` + +Run benchmark +``` +# Run with local models +python3 bench_sglang.py --num-questions 60 + +# Run with OpenAI models +python3 bench_sglang.py --num-questions 60 --backend gpt-4-vision-preview +``` + +### Bench LLaVA original code +``` +git clone git@github.com:haotian-liu/LLaVA.git +cd LLaVA +git reset --hard 9a26bd1435b4ac42c282757f2c16d34226575e96 +pip3 install -e . + +cd ~/sglang/benchmark/llava_bench +CUDA_VISIBLE_DEVICES=0 bash bench_hf_llava_bench.sh +``` + + +### Benchmark llama.cpp + +``` +# Install +CMAKE_ARGS="-DLLAMA_CUBLAS=on" pip install llama-cpp-python +pip install sse_starlette starlette_context pydantic_settings + +# Download weights +mkdir -p ~/model_weights/llava-v1.5-7b/ +wget https://huggingface.co/mys/ggml_llava-v1.5-7b/resolve/main/ggml-model-f16.gguf -O ~/model_weights/llava-v1.5-7b/ggml-model-f16.gguf +wget https://huggingface.co/mys/ggml_llava-v1.5-7b/resolve/main/mmproj-model-f16.gguf -O ~/model_weights/llava-v1.5-7b/mmproj-model-f16.gguf +``` + +``` +python3 -m llama_cpp.server --model ~/model_weights/llava-v1.5-7b/ggml-model-f16.gguf --clip_model_path ~/model_weights/llava-v1.5-7b/mmproj-model-f16.gguf --chat_format llava-1-5 --port 23000 + +OPENAI_BASE_URL=http://localhost:23000/v1 python3 bench_sglang.py --backend gpt-4-vision-preview --num-q 1 +``` diff --git a/benchmark/llava_bench/bench_hf_llava_bench.sh b/benchmark/llava_bench/bench_hf_llava_bench.sh new file mode 100644 index 000000000..a51a715b6 --- /dev/null +++ b/benchmark/llava_bench/bench_hf_llava_bench.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +python -m llava.eval.model_vqa \ + --model-path liuhaotian/llava-v1.5-7b \ + --question-file ./questions.jsonl \ + --image-folder ./images \ + --answers-file ./answers_hf.jsonl \ + --temperature 0 \ + --conv-mode vicuna_v1 diff --git a/benchmark/llava_bench/bench_hf_mme.sh b/benchmark/llava_bench/bench_hf_mme.sh new file mode 100644 index 000000000..6ed332fe0 --- /dev/null +++ b/benchmark/llava_bench/bench_hf_mme.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +python -m llava.eval.model_vqa_loader \ + --model-path liuhaotian/llava-v1.5-7b \ + --question-file ./mme_pack/llava_mme_bench_replace.jsonl \ + --image-folder ./mme_pack/MME_Benchmark_release_version \ + --answers-file ./answers_hf_mme.jsonl \ + --temperature 0 \ + --conv-mode vicuna_v1 diff --git a/benchmark/llava_bench/bench_sglang.py b/benchmark/llava_bench/bench_sglang.py new file mode 100644 index 000000000..d2c5d2aac --- /dev/null +++ b/benchmark/llava_bench/bench_sglang.py @@ -0,0 +1,96 @@ +import argparse +import json +import time +import os + +import sglang as sgl +import tqdm +from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend +from sglang.utils import read_jsonl, dump_state_text +from PIL import Image + + +@sgl.function +def image_qa(s, image_file, question): + s += sgl.user(sgl.image(image_file) + question) + s += sgl.assistant(sgl.gen("answer", max_tokens=args.max_tokens)) + + +def main(args): + lines = read_jsonl(args.question_file)[:args.num_questions] + arguments = [ + {"image_file": + os.path.abspath(args.image_folder + "/" + l["image"]), + "question": l["text"]} for l in lines + ] + #arguments = [ + # {"image_file": + # Image.open(os.path.abspath(args.image_folder + "/" + l["image"])), + # "question": l["text"]} for l in lines + #] + + states = [None] * len(lines) + + # Select backend + backend = select_sglang_backend(args) + sgl.set_default_backend(backend) + + # Run requests + tic = time.time() + if args.parallel == 1: + for i in tqdm.tqdm(range(len(lines))): + image_file = arguments[i]["image_file"] + question = arguments[i]["question"] + ret = image_qa.run( + image_file=image_file, + question=question, + temperature=0) + states[i] = ret + else: + states = image_qa.run_batch( + arguments, + temperature=0, + num_threads=args.parallel, + progress_bar=True) + latency = time.time() - tic + + print(f"Latency: {latency:.3f}") + + # Write results + dump_state_text(f"tmp_output_{args.backend}.txt", states) + + print(f"Write output to {args.answer_file}") + with open(args.answer_file, "w") as fout: + for i in range(len(lines)): + value = { + "question_id": lines[i]["question_id"], + "prompt": lines[i]["text"], + "text": states[i]["answer"].strip(), + "model_id": backend.model_info["model_path"], + "answer_id": i, + "metadata": {}, + } + fout.write(json.dumps(value) + "\n") + + with open(args.result_file, "a") as fout: + value = { + "task": "llava_bench", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + "num_requests": len(lines), + "parallel": args.parallel, + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--question-file", type=str, default="questions.jsonl") + parser.add_argument("--answer-file", type=str, default="answers.jsonl") + parser.add_argument("--image-folder", type=str, default="./images") + parser.add_argument("--temperature", type=float, default=0.0) + parser.add_argument("--num-questions", type=int, default=None) + parser.add_argument("--max-tokens", type=int, default=768) + args = add_common_sglang_args_and_parse(parser) + main(args) diff --git a/benchmark/llava_bench/bench_sglang_mme.sh b/benchmark/llava_bench/bench_sglang_mme.sh new file mode 100644 index 000000000..8c8c09cc0 --- /dev/null +++ b/benchmark/llava_bench/bench_sglang_mme.sh @@ -0,0 +1,2 @@ +MME_FOLDER=./mme_pack +python3 bench_sglang.py --num-questions 5000 --question-file $MME_FOLDER/llava_mme_bench_replace.jsonl --answer-file answer_mme.jsonl --image-folder $MME_FOLDER/MME_Benchmark_release_version --max-tokens 4 diff --git a/benchmark/llava_bench/download_images.py b/benchmark/llava_bench/download_images.py new file mode 100644 index 000000000..ac865fe2d --- /dev/null +++ b/benchmark/llava_bench/download_images.py @@ -0,0 +1,20 @@ +import os + +# Create the 'images' directory if it doesn't exist +if not os.path.exists('images'): + os.makedirs('images') + +# Base URL +base_url = "https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild/resolve/main/images/" + +# Loop through image numbers +for i in range(1, 25): + # Format the image number with leading zeros + image_number = str(i).zfill(3) + image_url = base_url + image_number + ".jpg" + image_path = "images/" + image_number + ".jpg" + + # Download the image using wget + os.system(f"wget -O {image_path} {image_url}") + +print("Download complete.") diff --git a/benchmark/llm_judge/README.md b/benchmark/llm_judge/README.md new file mode 100644 index 000000000..e4516bf10 --- /dev/null +++ b/benchmark/llm_judge/README.md @@ -0,0 +1,27 @@ +## Run benchmark + +### Benchmark sglang +``` +python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 +``` + +``` +python3 bench_sglang.py --num-questions 25 --parallel 8 +python3 bench_sglang.py --num-questions 16 --parallel 1 +``` + + +### Benchmark vllm +``` +python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000 +``` + +``` +python3 bench_other.py --backend vllm --num-questions 25 +``` + + +### Benchmark guidance +``` +python3 bench_other.py --backend guidance --num-questions 25 --parallel 1 +``` diff --git a/benchmark/llm_judge/bench_other.py b/benchmark/llm_judge/bench_other.py new file mode 100644 index 000000000..98d917bc7 --- /dev/null +++ b/benchmark/llm_judge/bench_other.py @@ -0,0 +1,120 @@ +import argparse +import asyncio +from concurrent.futures import ThreadPoolExecutor +from functools import partial +import json +import time + +import numpy as np +from tqdm import tqdm +from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw +from sglang.utils import read_jsonl, dump_state_text + + +system_prompt = ( +"Please serve as an impartial judge and rigorously evaluate the quality of the following article. Apply the most stringent standards possible, showing no leniency." +) + +dimension_prompts = [ +"Content: This refers to the essences of the essay. The substance should be well researched, accurate, relevant to the topic and should show a thorough understanding of the subject. The essay should also reflect a clear goal or purpose.", +"Organization and Structure: An essay needs to be properly structured with a clear introduction, body, and conclusion. The essay should flow naturally, with one paragraph leading seamlessly into the next.", +"Argument and Analysis: The argument made in the essay should be logical, coherent and clearly articulated. Each point made should be backed up by solid evidence and thorough analysis.", +"Clarity and Precision: The essay should be written in a clear and concise manner. The points made should be easily understood by the reader. The language used should also be precise and unambiguous.", +"Grammar and Punctuation: Proper use of grammar and punctuation is vital in an academic essay. Errors in grammar and punctuation not only distract the reader but can also negatively impact the meaning and interpretation of the content.", +"Referencing and Citation: An essay should contain proper citations and references for all sources used. This not only prevents accusations of plagiarism but also gives credit to the authors of the works that have contributed to the essay. The citation should adhere to a specific format as required by the academic institution or specified by the professor.", +] + + +def multi_dimension_judge(article, generate): + s = system_prompt + s += "\n```\n" + article + "\n```\n\n" + + judges = [] + for i in range(len(dimension_prompts)): + comp = generate(s + + "USER: Please judge the quality based on the following metric. " + + dimension_prompts[i] + " Please provide a single-paragraph judgement. " + + "Focus on the provided metric and do not say other things. " + 'End your judgement paragraph with the word "END"\nJUDGE:', + max_tokens=256, stop="END") + judges.append(comp) + + s += "I will judge the quality based on the following metrics.\n" + for i in range(len(dimension_prompts)): + s += dimension_prompts[i].split(":")[0] + ": " + judges[i].strip() + "\n" + + s += "In summary, on a scale of 1 to 10, I would give the article a score of" + s += generate(s, max_tokens=2, stop=None) + + return s + + +def main(args): + lines = read_jsonl(args.data_path)[:args.num_questions] + states = [None] * len(lines) + + # Select backend + if args.backend == "lightllm": + url = f"{args.host}:{args.port}/generate" + generate = partial(call_generate_lightllm, url=url, temperature=0) + elif args.backend == "vllm": + url = f"{args.host}:{args.port}/generate" + generate = partial(call_generate_vllm, url=url, temperature=0) + elif args.backend == "srt-raw": + url = f"{args.host}:{args.port}/generate" + generate = partial(call_generate_srt_raw, url=url, temperature=0) + elif args.backend == "guidance": + from guidance import models, gen + + model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096) + + def generate(prompt, max_tokens, stop): + out = model + prompt + gen(name="answer", + max_tokens=max_tokens, temperature=0, stop=stop) + return out["answer"] + + # warmup + generate("Hello!", max_tokens=8, stop=None) + else: + raise ValueError(f"Invalid backend: {args.backend}") + + # Run requests + def get_one_answer(i): + states[i] = multi_dimension_judge(lines[i], generate) + + tic = time.time() + if args.parallel == 1: + for i in tqdm(range(len(lines))): + get_one_answer(i) + else: + with ThreadPoolExecutor(args.parallel) as executor: + executor.map(get_one_answer, list(range(len(lines)))) + latency = time.time() - tic + + # Compute accuracy + print(f"Latency: {latency:.3f}") + + # Write results + dump_state_text(f"tmp_output_{args.backend}.txt", states) + + with open(args.result_file, "a") as fout: + value = { + "task": "llm_judge", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + "num_requests": args.num_questions, + "other": { + "num_questions": args.num_questions, + "parallel": args.parallel, + } + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, default="articles.jsonl") + parser.add_argument("--num-questions", type=int, default=20) + args = add_common_other_args_and_parse(parser) + main(args) diff --git a/benchmark/llm_judge/bench_sglang.py b/benchmark/llm_judge/bench_sglang.py new file mode 100644 index 000000000..0cef33700 --- /dev/null +++ b/benchmark/llm_judge/bench_sglang.py @@ -0,0 +1,85 @@ +import argparse +import json +import time + +import numpy as np +import sglang as sgl +from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend +from sglang.utils import read_jsonl, dump_state_text + + +system_prompt = ( +"Please serve as an impartial judge and rigorously evaluate the quality of the following article. Apply the most stringent standards possible, showing no leniency." +) + +dimension_prompts = [ +"Content: This refers to the essences of the essay. The substance should be well researched, accurate, relevant to the topic and should show a thorough understanding of the subject. The essay should also reflect a clear goal or purpose.", +"Organization and Structure: An essay needs to be properly structured with a clear introduction, body, and conclusion. The essay should flow naturally, with one paragraph leading seamlessly into the next.", +"Argument and Analysis: The argument made in the essay should be logical, coherent and clearly articulated. Each point made should be backed up by solid evidence and thorough analysis.", +"Clarity and Precision: The essay should be written in a clear and concise manner. The points made should be easily understood by the reader. The language used should also be precise and unambiguous.", +"Grammar and Punctuation: Proper use of grammar and punctuation is vital in an academic essay. Errors in grammar and punctuation not only distract the reader but can also negatively impact the meaning and interpretation of the content.", +"Referencing and Citation: An essay should contain proper citations and references for all sources used. This not only prevents accusations of plagiarism but also gives credit to the authors of the works that have contributed to the essay. The citation should adhere to a specific format as required by the academic institution or specified by the professor.", +] + + +@sgl.function +def multi_dimension_judge(s, article): + s += system_prompt + s += "\n```\n" + article + "\n```\n\n" + + forks = s.fork(len(dimension_prompts)) + for i in range(len(dimension_prompts)): + forks[i] += ("USER: Please judge the quality based on the following metric. " + + dimension_prompts[i] + " Please provide a single-paragraph judgement. " + + "Focus on the provided metric and do not say other things. " + 'End your judgement paragraph with the word "END"\nJUDGE:') + forks[i] += sgl.gen("judgement", max_tokens=256, stop="END") + forks.join() + + s += "I will judge the quality based on the following metrics.\n" + for i in range(len(dimension_prompts)): + s += dimension_prompts[i].split(":")[0] + ": " + forks[i]["judgement"].strip() + "\n" + + s += "In summary, on a scale of 1 to 10, I would give the article a score of" + s += sgl.gen("score", max_tokens=2) + + +def main(args): + lines = read_jsonl(args.data_path)[:args.num_questions] + arguments = [{"article": l} for l in lines] + + # Select backend + backend = select_sglang_backend(args) + + # Run requests + tic = time.time() + states = multi_dimension_judge.run_batch( + arguments, temperature=0, backend=backend, num_threads=args.parallel) + latency = time.time() - tic + + print(f"Latency: {latency:.3f}") + + # Write results + dump_state_text(f"tmp_output_{args.backend}.txt", states) + + with open(args.result_file, "a") as fout: + value = { + "task": "llm_judge", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + "num_requests": args.num_questions, + "other": { + "num_questions": args.num_questions, + "parallel": args.parallel, + } + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, default="articles.jsonl") + parser.add_argument("--num-questions", type=int, default=20) + args = add_common_sglang_args_and_parse(parser) + main(args) diff --git a/benchmark/long_json_decode/README.md b/benchmark/long_json_decode/README.md new file mode 100644 index 000000000..6d52030a5 --- /dev/null +++ b/benchmark/long_json_decode/README.md @@ -0,0 +1,33 @@ +## Run benchmark + +### Benchmark sglang +``` +python3 -m sglang.launch_server --model-path codellama/CodeLlama-7b-instruct-hf --port 30000 +``` + +``` +python3 bench_sglang.py --num-questions 5 --parallel 1 +``` + + +### Benchmark vllm +``` +python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model codellama/CodeLlama-7b-instruct-hf --disable-log-requests --port 21000 --gpu 0.97 +``` + +``` +python3 bench_other.py --backend vllm --num-questions 5 +``` + + +### Benchmark guidance +``` +python3 bench_other.py --backend guidance --num-questions 5 --parallel 1 +``` + + +### Build dataset +``` +pip install wikipedia +python3 build_dataset.py +``` diff --git a/benchmark/long_json_decode/bench_other.py b/benchmark/long_json_decode/bench_other.py new file mode 100644 index 000000000..7fa76213b --- /dev/null +++ b/benchmark/long_json_decode/bench_other.py @@ -0,0 +1,104 @@ +import argparse +import asyncio +from concurrent.futures import ThreadPoolExecutor +from functools import partial +import json +import time + +from tqdm import tqdm +import numpy as np +from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw +from sglang.utils import read_jsonl, dump_state_text + + +def json_decode(document, generate): + s = "Please extract the information of a city from the following wikipedia page.\n" + s += "Page begin.\n" + document + "Page end.\n" + s += "Here is the name, country, and symbol of the city in JSON format.\n" + s += '{\n' + s += ' "name": "' + s += generate(s, max_tokens=8, stop='"') + '",\n' + s += ' "country": "' + s += generate(s, max_tokens=8, stop='"') + '",\n' + s += ' "air port code": "' + s += generate(s, max_tokens=8, stop='"') + '",\n' + s += ' "top 3 landmarks": "' + s += generate(s, max_tokens=24, stop='"') + '",\n' + s += '}\n' + return s + + +def main(args): + lines = read_jsonl(args.data_path) + arguments = [] + for i in range(len(lines[:args.num_questions])): + arguments.append({ + "document": lines[i]["document"], + }) + states = [None] * len(arguments) + + # Select backend + if args.backend == "lightllm": + url = f"{args.host}:{args.port}/generate" + generate = partial(call_generate_lightllm, url=url, temperature=0) + elif args.backend == "vllm": + url = f"{args.host}:{args.port}/generate" + generate = partial(call_generate_vllm, url=url, temperature=0) + elif args.backend == "srt-raw": + url = f"{args.host}:{args.port}/generate" + generate = partial(call_generate_srt_raw, url=url, temperature=0) + elif args.backend == "guidance": + from guidance import models, gen + + model = models.LlamaCpp("/home/ubuntu/model_weights/CodeLlama-7b-instruct-hf.gguf", n_gpu_layers=-1, n_ctx=11000) + + def generate(prompt, max_tokens, stop): + out = model + prompt + gen(name="answer", + max_tokens=max_tokens, temperature=0, stop=stop) + return out["answer"] + + # warmup + generate("Hello!", max_tokens=8, stop=None) + else: + raise ValueError(f"Invalid backend: {args.backend}") + + # Run requests + def get_one_answer(i): + states[i] = json_decode(generate=generate, **arguments[i]) + + tic = time.time() + if args.parallel == 1: + for i in tqdm(range(len(arguments))): + get_one_answer(i) + else: + with ThreadPoolExecutor(args.parallel) as executor: + executor.map(get_one_answer, list(range(len(arguments)))) + latency = time.time() - tic + + # Compute accuracy + print(f"Latency: {latency:.3f}") + + # Write results + dump_state_text(f"tmp_output_{args.backend}.txt", states) + + with open(args.result_file, "a") as fout: + value = { + "task": "long_json_decode", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + "num_requests": args.num_questions, + "other": { + "num_questions": args.num_questions, + "parallel": args.parallel, + } + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, default="questions.jsonl") + parser.add_argument("--num-questions", type=int, default=100) + args = add_common_other_args_and_parse(parser) + main(args) diff --git a/benchmark/long_json_decode/bench_sglang.py b/benchmark/long_json_decode/bench_sglang.py new file mode 100644 index 000000000..4cb1e6f87 --- /dev/null +++ b/benchmark/long_json_decode/bench_sglang.py @@ -0,0 +1,68 @@ +import argparse +import json +import time + +import numpy as np +import sglang as sgl +from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend +from sglang.utils import read_jsonl, dump_state_text + + +@sgl.function +def json_decode(s, document): + s += "Please extract the information of a city from the following wikipedia page.\n" + s += "Page begin.\n" + document + "Page end.\n" + s += "Here is the name, country, and symbol of the city in JSON format.\n" + s += '{\n' + s += ' "name": "' + sgl.gen("name", max_tokens=8, stop='"') + '",\n' + s += ' "country": "' + sgl.gen("country", max_tokens=8, stop='"') + '",\n' + s += ' "air port code": "' + sgl.gen("air port code", max_tokens=8, stop='"') + '",\n' + s += ' "top 3 landmarks": "' + sgl.gen("landmarks", max_tokens=24, stop='"') + '",\n' + s += '}\n' + + +def main(args): + lines = read_jsonl(args.data_path) + arguments = [] + for i in range(len(lines[:args.num_questions])): + arguments.append({ + "document": lines[i]["document"], + }) + + # Select backend + backend = select_sglang_backend(args) + sgl.set_default_backend(backend) + + # Run requests + tic = time.time() + states = json_decode.run_batch( + arguments, temperature=0, num_threads=args.parallel) + latency = time.time() - tic + + # Compute accuracy + print(f"Latency: {latency:.3f}") + + # Write results + dump_state_text(f"tmp_output_{args.backend}.txt", states) + + with open(args.result_file, "a") as fout: + value = { + "task": "long_json_decode", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + "num_requests": args.num_questions, + "other": { + "num_questions": args.num_questions, + "parallel": args.parallel, + } + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, default="questions.jsonl") + parser.add_argument("--num-questions", type=int, default=10) + args = add_common_sglang_args_and_parse(parser) + main(args) diff --git a/benchmark/long_json_decode/build_dataset.py b/benchmark/long_json_decode/build_dataset.py new file mode 100644 index 000000000..78d479183 --- /dev/null +++ b/benchmark/long_json_decode/build_dataset.py @@ -0,0 +1,26 @@ +import json + +import transformers +import wikipedia + + +name = "meta-llama/Llama-2-7b-chat-hf" +t = transformers.AutoTokenizer.from_pretrained(name) +city_names = ["los angles", "london", "tokyo", "beijing", "singapore"] + + +for city_name in city_names: + content = str(wikipedia.page(city_name).content) + content = content.replace("\n\n", "\n") + + tokens = t.encode(content) + + truncate_len = int((10000 / len(tokens)) * len(content)) + truncate_content = content[:truncate_len] + truncate_tokens = t.encode(truncate_content) + + # Count token + print(f"city_name: {city_name}, #tokens: {len(tokens)}, #truncate tokens: {len(truncate_tokens)}") + + with open("questions.jsonl", "a") as fout: + fout.write(json.dumps({"document": truncate_content}) + "\n") diff --git a/benchmark/mmlu/README.md b/benchmark/mmlu/README.md new file mode 100644 index 000000000..25553ab4d --- /dev/null +++ b/benchmark/mmlu/README.md @@ -0,0 +1,56 @@ +## Download data +``` +wget https://people.eecs.berkeley.edu/~hendrycks/data.tar +tar xf data.tar +``` + +## Run benchmark + +### Benchmark sglang +``` +python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 +``` + +``` +python3 bench_sglang.py --nsub 10 +``` + + +### Benchmark vllm +``` +python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000 +``` + +``` +python3 bench_other.py --nsub 10 --backend vllm +``` + + +### Benchmark lightllm +``` +# A10G +python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000 + +# V100 +python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 4500 --port 22000 +``` + +``` +python3 bench_other.py --nsub 10 --backend lightllm +``` + + +### Benchmark guidance +``` +python3 bench_other.py --nsub 10 --backend guidance --parallel 1 +``` + + +### Benchmark lmql +``` +CUDA_VISIBLE_DEVICES=0,1 lmql serve-model meta-llama/Llama-2-7b-chat-hf --cuda --port 23000 +``` + +``` +python3 bench_other.py --nsub 10 --backend lmql --parallel 2 +``` diff --git a/benchmark/mmlu/bench_other.py b/benchmark/mmlu/bench_other.py new file mode 100644 index 000000000..5371bb958 --- /dev/null +++ b/benchmark/mmlu/bench_other.py @@ -0,0 +1,202 @@ +import argparse +import asyncio +from concurrent.futures import ThreadPoolExecutor +import json +from functools import partial +import os +import time + +import numpy as np +import pandas as pd +import tiktoken +from tqdm import tqdm +from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw + + +choices = ["A", "B", "C", "D"] + +tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo") + + +def format_subject(subject): + l = subject.split("_") + s = "" + for entry in l: + s += " " + entry + return s + +def format_example(df, idx, include_answer=True): + prompt = df.iloc[idx, 0] + k = df.shape[1] - 2 + for j in range(k): + prompt += "\n{}. {}".format(choices[j], df.iloc[idx, j+1]) + prompt += "\nAnswer:" + if include_answer: + prompt += " {}\n\n".format(df.iloc[idx, k + 1]) + return prompt + +def gen_prompt(train_df, subject, k=-1): + prompt = "The following are multiple choice questions (with answers) about{}.\n\n".format(format_subject(subject)) + if k == -1: + k = train_df.shape[0] + for i in range(k): + prompt += format_example(train_df, i) + return prompt + + +model_initialized = None + + +def evaluate(args, subject, dev_df, test_df): + prompts = [] + labels = [] + + # Construct prompts + k = args.ntrain + train_prompt = gen_prompt(dev_df, subject, k) + while len(tokenizer.encode(train_prompt)) > 1536: + k -= 1 + train_prompt = gen_prompt(dev_df, subject, k) + + for i in range(test_df.shape[0]): + prompt_end = format_example(test_df, i, include_answer=False) + prompt = train_prompt + prompt_end + prompts.append(prompt) + + label = test_df.iloc[i, test_df.shape[1]-1] + labels.append(label) + + preds = [None] * len(prompts) + max_tokens = 1 + + # Select backend + global model_initialized + + if args.backend == "lightllm": + url = f"{args.host}:{args.port}/generate" + call_generate = partial(call_generate_lightllm, url=url, stop=None) + elif args.backend == "vllm": + url = f"{args.host}:{args.port}/generate" + call_generate = partial(call_generate_vllm, url=url, stop=None) + elif args.backend == "srt-raw": + url = f"{args.host}:{args.port}/generate" + call_generate = partial(call_generate_srt_raw, url=url, stop=None) + elif args.backend == "guidance": + from guidance import models, gen + + if model_initialized is None: + model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096) + model_initialized = model + else: + model = model_initialized + + def call_generate(prompt, temperature, max_tokens): + out = model + prompt + gen(name="answer", + max_tokens=max_tokens, temperature=0) + return out["answer"] + + elif args.backend == "lmql": + import lmql + model = lmql.model("meta-llama/Llama-2-7b-chat-hf", + endpoint=f"{args.host}:{args.port}") + + @lmql.query(model=model) + async def program(question): + '''lmql + """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < 2 + return ANSWER + ''' + + async def call_generate(prompt, temperature, max_tokens): + return await program(question=prompt, temperature=temperature) + else: + raise ValueError(f"Invalid backend: {args.backend}") + + # Run requests + if args.backend != "lmql": + # Use thread pool + def get_one_answer(i): + pred = call_generate(prompts[i], temperature=0, + max_tokens=max_tokens) + preds[i] = pred.strip()[0] + + tic = time.time() + if args.parallel == 1: + for i in range(len(prompts)): + get_one_answer(i) + else: + with ThreadPoolExecutor(args.parallel) as executor: + executor.map(get_one_answer, list(range(len(prompts)))) + else: + # Use asyncio + async def batched_call(batch_size): + for i in range(0, len(prompts), batch_size): + tasks = [] + for p in prompts[i:i+batch_size]: + tasks.append(call_generate(p, + temperature=0, max_tokens=max_tokens)) + rets = await asyncio.gather(*tasks) + for j in range(len(rets)): + preds[i+j] = rets[j].strip()[0] + + tic = time.time() + asyncio.run(batched_call(batch_size=args.parallel)) + latency = time.time() - tic + + # Compute accuracy + cors = [pred == label for pred, label in zip(preds, labels)] + acc = np.mean(cors) + cors = np.array(cors) + + print("Average accuracy {:.3f}, latency {:.2f}, #q: {} - {}".format( + acc, latency, len(prompts), subject)) + + return cors, acc, latency + + +def main(args): + subjects = sorted([f.split("_test.csv")[0] for f in os.listdir(os.path.join(args.data_dir, "test")) if "_test.csv" in f]) + + all_cors = [] + all_latencies = [] + num_requests = 0 + + for subject in tqdm(subjects[:args.nsub]): + dev_df = pd.read_csv(os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None)[:args.ntrain] + test_df = pd.read_csv(os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None) + + cors, acc, latency = evaluate(args, subject, dev_df, test_df) + all_cors.append(cors) + all_latencies.append(latency) + num_requests += len(test_df) + + total_latency = np.sum(all_latencies) + print("Total latency: {:.3f}".format(total_latency)) + + weighted_acc = np.mean(np.concatenate(all_cors)) + print("Average accuracy: {:.3f}".format(weighted_acc)) + + # Write results + with open(args.result_file, "a") as fout: + value = { + "task": "mmlu", + "backend": args.backend, + "num_gpus": 1, + "latency": round(total_latency, 3), + "accuracy": round(weighted_acc, 3), + "num_requests": num_requests, + "other": { + "nsub": args.nsub, + "parallel": args.parallel, + } + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--ntrain", type=int, default=5) + parser.add_argument("--data_dir", type=str, default="data") + parser.add_argument("--nsub", type=int, default=60) + args = add_common_other_args_and_parse(parser) + main(args) diff --git a/benchmark/mmlu/bench_sglang.py b/benchmark/mmlu/bench_sglang.py new file mode 100644 index 000000000..543b4ad61 --- /dev/null +++ b/benchmark/mmlu/bench_sglang.py @@ -0,0 +1,143 @@ +import argparse +import json +import os +import time + +import numpy as np +import pandas as pd +import tiktoken +from tqdm import tqdm +from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend + + +choices = ["A", "B", "C", "D"] + +tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo") + + +def format_subject(subject): + l = subject.split("_") + s = "" + for entry in l: + s += " " + entry + return s + +def format_example(df, idx, include_answer=True): + prompt = df.iloc[idx, 0] + k = df.shape[1] - 2 + for j in range(k): + prompt += "\n{}. {}".format(choices[j], df.iloc[idx, j+1]) + prompt += "\nAnswer:" + if include_answer: + prompt += " {}\n\n".format(df.iloc[idx, k + 1]) + return prompt + +def gen_prompt(train_df, subject, k=-1): + prompt = "The following are multiple choice questions (with answers) about{}.\n\n".format(format_subject(subject)) + if k == -1: + k = train_df.shape[0] + for i in range(k): + prompt += format_example(train_df, i) + return prompt + +def evaluate(args, subject, dev_df, test_df): + prompts = [] + labels = [] + + k = args.ntrain + few_shot_examples = gen_prompt(dev_df, subject, k) + while len(tokenizer.encode(few_shot_examples)) > 1536: + k -= 1 + few_shot_examples = gen_prompt(dev_df, subject, k) + + for i in range(test_df.shape[0]): + prompt_end = format_example(test_df, i, include_answer=False) + prompts.append(prompt_end) + + label = test_df.iloc[i, test_df.shape[1]-1] + labels.append(label) + + arguments = [{"question": p} for p in prompts] + + ##################################### + ######### SGL Program Begin ######### + ##################################### + + import sglang as sgl + + @sgl.function + def few_shot_mmlu(s, examples, question): + s += examples + question + sgl.gen("answer") + + ##################################### + ########## SGL Program End ########## + ##################################### + + # Select backend + backend = select_sglang_backend(args) + + tic = time.time() + states = few_shot_mmlu.bind(examples=few_shot_examples).run_batch( + arguments, temperature=0, max_new_tokens=1, + backend=backend, num_threads=args.parallel) + preds = [s["answer"].strip()[0] if len(s["answer"].strip()) > 0 else "" + for s in states] + latency = time.time() - tic + + cors = [pred == label for pred, label in zip(preds, labels)] + acc = np.mean(cors) + cors = np.array(cors) + + print("Average accuracy {:.3f}, latency {:.2f}, #q: {} - {}".format( + acc, latency, len(prompts), subject)) + + return cors, acc, latency + + +def main(args): + subjects = sorted([f.split("_test.csv")[0] for f in os.listdir(os.path.join(args.data_dir, "test")) if "_test.csv" in f]) + + all_cors = [] + all_latencies = [] + num_requests = 0 + + for subject in tqdm(subjects[:args.nsub]): + dev_df = pd.read_csv(os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None)[:args.ntrain] + test_df = pd.read_csv(os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None) + + cors, acc, latency = evaluate(args, subject, dev_df, test_df) + all_cors.append(cors) + all_latencies.append(latency) + num_requests += len(test_df) + + total_latency = np.sum(all_latencies) + print("Total latency: {:.3f}".format(total_latency)) + + weighted_acc = np.mean(np.concatenate(all_cors)) + print("Average accuracy: {:.3f}".format(weighted_acc)) + + # Write results + with open(args.result_file, "a") as fout: + value = { + "task": "mmlu", + "backend": args.backend, + "num_gpus": 1, + "latency": round(total_latency, 3), + "accuracy": round(weighted_acc, 3), + "num_requests": num_requests, + "other": { + "nsub": args.nsub, + "parallel": args.parallel, + } + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--ntrain", "-k", type=int, default=5) + parser.add_argument("--data_dir", "-d", type=str, default="data") + parser.add_argument("--save_dir", "-s", type=str, default="results") + parser.add_argument("--nsub", type=int, default=60) + args = add_common_sglang_args_and_parse(parser) + main(args) diff --git a/benchmark/mtbench/README.md b/benchmark/mtbench/README.md new file mode 100644 index 000000000..e32a6eaab --- /dev/null +++ b/benchmark/mtbench/README.md @@ -0,0 +1,31 @@ +## Run benchmark + +### Benchmark sglang +``` +python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 +``` + +``` +python3 bench_sglang.py --num-questions 80 +``` + + +### Benchmark vllm +``` +python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000 +``` + +``` +python3 bench_other.py --num-questions 80 --backend vllm +``` + + +### Benchmark lightllm +``` +# A10G +python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000 +``` + +``` +python3 bench_other.py --num-questions 80 --backend lightllm +``` diff --git a/benchmark/mtbench/bench_other.py b/benchmark/mtbench/bench_other.py new file mode 100644 index 000000000..d4f0c0513 --- /dev/null +++ b/benchmark/mtbench/bench_other.py @@ -0,0 +1,116 @@ +import argparse +from concurrent.futures import ThreadPoolExecutor +from functools import partial +import json +import os +import time +import uuid + +from fastchat.model import get_conversation_template +import requests +from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt + + +def load_questions(filename): + questions = [] + with open(filename, "r") as fin: + for line in fin: + obj = json.loads(line) + questions.append(obj) + return questions + + +def write_answers(filename, model_id, questions, answers): + with open(os.path.expanduser(filename), "w") as fout: + for i in range(len(answers)): + ans_json = { + "question_id": questions[i]["question_id"], + "answer_id": uuid.uuid4().hex, + "model_id": model_id, + "choices": { + "index": 0, + "turns": [answers[i][0], answers[i][1]], + }, + "tstamp": time.time(), + } + fout.write(json.dumps(ans_json) + "\n") + + +def main(args): + questions = load_questions(args.question_file) + questions = (questions * 10)[:args.num_questions] + max_tokens = 256 + model_id = "llama-2-chat" + + conv_main = get_conversation_template(model_id) + + # Select backend + if args.backend == "lightllm": + url = f"{args.host}:{args.port}/generate" + call_generate = partial(call_generate_lightllm, url=url, stop=None) + elif args.backend == "vllm": + url = f"{args.host}:{args.port}/generate" + call_generate = partial(call_generate_vllm, url=url, stop=None) + elif args.backend == "srt": + url = f"{args.host}:{args.port}/generate" + call_generate = partial(call_generate_srt, url=url, stop=None) + else: + raise ValueError(f"Invalid backend: {args.backend}") + + answers = [None] * len(questions) + + def get_answer(i): + conv = conv_main.copy() + cur_answers = [] + for j in range(2): + q = questions[i]["turns"][j] + conv.append_message(conv.roles[0], q) + conv.append_message(conv.roles[1], None) + + prompt = conv.get_prompt() + output = call_generate(prompt, + temperature=0, max_tokens=max_tokens).strip() + + cur_answers.append(output) + conv.update_last_message(output) + + answers[i] = cur_answers + + # Run requests + tic = time.time() + if args.parallel == 1: + for i in range(len(questions)): + get_answer(i) + else: + with ThreadPoolExecutor(args.parallel) as executor: + executor.map(get_answer, list(range(len(questions)))) + latency = time.time() - tic + + print(f"#questions: {len(questions)}, Latency: {latency:.2f}") + + # Write results + answer_file = args.answer_file or f"tmp_output_{args.backend}.txt" + write_answers(answer_file, model_id, questions, answers) + + with open(args.result_file, "a") as fout: + value = { + "task": "mtbench", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + "num_requests": args.num_questions, + "other": { + "num_questions": args.num_questions, + "parallel": args.parallel, + } + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--question-file", type=str, default="question.jsonl") + parser.add_argument("--answer-file", type=str, default=None) + parser.add_argument("--num-questions", type=int, default=80) + args = add_common_other_args_and_parse(parser) + main(args) diff --git a/benchmark/mtbench/bench_sglang.py b/benchmark/mtbench/bench_sglang.py new file mode 100644 index 000000000..7727e03fc --- /dev/null +++ b/benchmark/mtbench/bench_sglang.py @@ -0,0 +1,95 @@ +import argparse +import json +import os +import time +import uuid + +import sglang as sgl +from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend + + +def load_questions(filename): + questions = [] + with open(filename, "r") as fin: + for line in fin: + obj = json.loads(line) + questions.append(obj) + return questions + + +def write_answers(filename, model_id, questions, answers): + with open(os.path.expanduser(filename), "w") as fout: + for i in range(len(answers)): + ans_json = { + "question_id": questions[i]["question_id"], + "answer_id": uuid.uuid4().hex, + "model_id": model_id, + "choices": { + "index": 0, + "turns": [answers[i][0], answers[i][1]], + }, + "tstamp": time.time(), + } + fout.write(json.dumps(ans_json) + "\n") + + +@sgl.function +def answer_mt_bench(s, question_1, question_2): + s += sgl.system() + s += sgl.user(question_1) + s += sgl.assistant(sgl.gen("answer_1")) + s += sgl.user(question_2) + s += sgl.assistant(sgl.gen("answer_2")) + + +def main(args): + # Construct prompts + questions = load_questions(args.question_file)[:args.num_questions] + arguments = [ + {"question_1": q["turns"][0], "question_2": q["turns"][1]} + for q in questions + ] + + # Select backend + backend = select_sglang_backend(args) + sgl.set_default_backend(backend) + + # Run requests + tic = time.time() + rets = answer_mt_bench.run_batch( + arguments, + temperature=0, + max_new_tokens=256, + num_threads=args.parallel) + answers = [[s["answer_1"], s["answer_2"]] for s in rets] + latency = time.time() - tic + + print(f"#questions: {len(questions)}, Latency: {latency:.2f}") + + # Write results + model_id = backend.model_info["model_path"] + answer_file = args.answer_file or f"tmp_output_{args.backend}.txt" + write_answers(answer_file, model_id, questions, answers) + + with open(args.result_file, "a") as fout: + value = { + "task": "mtbench", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + "num_requests": args.num_questions, + "other": { + "num_questions": args.num_questions, + "parallel": args.parallel, + } + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--question-file", type=str, default="question.jsonl") + parser.add_argument("--answer-file", type=str, default=None) + parser.add_argument("--num-questions", type=int, default=80) + args = add_common_sglang_args_and_parse(parser) + main(args) diff --git a/benchmark/multi_chain_reasoning/README.md b/benchmark/multi_chain_reasoning/README.md new file mode 100644 index 000000000..5859145eb --- /dev/null +++ b/benchmark/multi_chain_reasoning/README.md @@ -0,0 +1,43 @@ +## Download data +``` +wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl +``` + +## Run benchmark + +### Benchmark sglang +``` +python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 +``` + +``` +python3 bench_sglang.py --num-questions 64 +python3 bench_sglang.py --num-questions 32 --parallel 1 +``` + + +### Benchmark vllm +``` +python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000 +``` + +``` +python3 bench_other.py --num-questions 64 --backend vllm +``` + + +### Benchmark lightllm +``` +# A10G +python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000 +``` + +``` +python3 bench_other.py --num-questions 64 --backend lightllm +``` + + +### Benchmark guidance +``` +python3 bench_other.py --num-questions 8 --backend guidance --parallel 1 +``` diff --git a/benchmark/multi_chain_reasoning/bench_other.py b/benchmark/multi_chain_reasoning/bench_other.py new file mode 100644 index 000000000..6d2388762 --- /dev/null +++ b/benchmark/multi_chain_reasoning/bench_other.py @@ -0,0 +1,195 @@ +import argparse +import ast +import asyncio +from concurrent.futures import ThreadPoolExecutor +from functools import partial +import json +import re +import time + +import numpy as np +from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw +from sglang.utils import read_jsonl, dump_state_text + + +INVALID = -9999999 + + +def get_answer_value(answer_str): + answer_str = answer_str.replace(",", "") + numbers = re.findall(r'\d+', answer_str) + if len(numbers) < 1: + return INVALID + try: + return ast.literal_eval(numbers[-1]) + except SyntaxError: + return INVALID + + +prompt_lib = [ + "Let us think step by step.", + "Approach this methodically. Let's dissect the problem into smaller, more manageable parts.", + "It's important to proceed step by step, ensuring accuracy at each stage.", + "Take a deep breath and break this down.", + "A little bit of arithmetic and a logical approach will help us quickly arrive at the solution to this problem.", + "I am extremely good at math.", +] + + +def multi_chain_gsm8k(question, num_chains, call_generate): + s = "Question: " + question + "\n" + # s += call_generate(s + "Answer: " + prompt_lib[0], max_tokens=256, + # stop="Question", temperature=0) + # return s + + comps = [] + for i in range(num_chains): + comps.append(call_generate(s + "Answer: " + prompt_lib[i % num_chains], + max_tokens=256, temperature=0.3, stop="Question")) + + s += "Answer: To answer this question, here are some possible solutions. " + s += "After considering all of them, I will do a majority vote.\n\n" + for i in range(num_chains): + s += f"Solution {i+1}: " + comps[i].strip() + "\n\n" + s += f"\nBy considering the above solutions and doing a majority vote, I think the final answer (a single integer number) is " + s += call_generate(s, max_tokens=16, temperature=0, stop=None) + return s + + +def main(args): + lines = read_jsonl(args.data_path) + + # Construct prompts + k = args.num_shot + + questions = [] + labels = [] + for i in range(len(lines[:args.num_questions])): + questions.append(lines[i]["question"]) + labels.append(get_answer_value(lines[i]["answer"])) + assert all(l != INVALID for l in labels) + + states = [None] * len(labels) + + # Select backend + if args.backend == "lightllm": + url = f"{args.host}:{args.port}/generate" + call_generate = partial(call_generate_lightllm, url=url) + elif args.backend == "vllm": + url = f"{args.host}:{args.port}/generate" + call_generate = partial(call_generate_vllm, url=url) + elif args.backend == "srt-raw": + url = f"{args.host}:{args.port}/generate" + call_generate = partial(call_generate_srt_raw, url=url) + elif args.backend == "guidance": + from guidance import models, gen + + model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096) + + def call_generate(prompt, temperature, max_tokens, stop): + out = model + prompt + gen(name="answer", + max_tokens=max_tokens, temperature=temperature, stop=stop) + return out["answer"] + + #def multi_chain_gsm8k(question, num_chains, call_generate): + # s = model + "Question: " + question + "\n" + + # comps = [] + # for i in range(num_chains): + # comps.append(call_generate(s + "Answer: " + prompt_lib[i % num_chains], + # max_tokens=256, temperature=0.3, stop="Question")) + + # s += "Answer: To answer this question, here are some possible solutions. " + # s += "After considering all of them, I will do a majority vote.\n\n" + # for i in range(num_chains): + # s += f"Solution {i+1}: " + comps[i].strip() + "\n\n" + # s += f"\nBy considering the above solutions and doing a majority vote, I think the final answer (a single integer number) is " + # return call_generate(s, max_tokens=16, temperature=0, stop=None) + + elif args.backend == "lmql": + import lmql + model = lmql.model("meta-llama/Llama-2-7b-chat-hf", + endpoint=f"{args.host}:{args.port}") + + @lmql.query(model=model) + async def program(question): + '''lmql + """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < 257 and STOPS_AT(ANSWER, "Question") + return ANSWER + ''' + + async def call_generate(prompt, temperature, max_tokens, stop): + return await program(question=prompt, temperature=0) + + else: + raise ValueError(f"Invalid backend: {args.backend}") + + # Run requests + if args.backend != "lmql": + # Use thread pool + def get_one_answer(i): + answer = multi_chain_gsm8k(questions[i], args.num_chains, + call_generate) + states[i] = answer + + tic = time.time() + if args.parallel == 1: + for i in range(len(questions)): + get_one_answer(i) + else: + with ThreadPoolExecutor(args.parallel) as executor: + executor.map(get_one_answer, list(range(len(questions)))) + else: + # Use asyncio + async def batched_call(batch_size): + for i in range(0, len(questions), batch_size): + tasks = [] + for q in questions[i:i+batch_size]: + tasks.append(call_generate(few_shot_examples + q, + temperature=0, max_tokens=256, stop="Question")) + rets = await asyncio.gather(*tasks) + for j in range(len(rets)): + states[i+j] = get_answer_value(rets[j]) + + tic = time.time() + asyncio.run(batched_call(batch_size=args.parallel)) + latency = time.time() - tic + + preds = [] + for i in range(len(states)): + preds.append(get_answer_value(states[i])) + + # Compute accuracy + acc = np.mean(np.array(preds) == np.array(labels)) + invalid = np.mean(np.array(preds) == INVALID) + print(f"Latency: {latency:.3f}") + print(f"Invalid: {invalid:.3f}") + print(f"Accuracy: {acc:.3f}") + + # Write results + dump_state_text(f"tmp_output_{args.backend}.txt", states) + + with open(args.result_file, "a") as fout: + value = { + "task": "multi_chain_gsm8k", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + "accuracy": round(acc, 3), + "num_requests": args.num_questions, + "other": { + "num_questions": args.num_questions, + "parallel": args.parallel, + } + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--num-shot", type=int, default=0) + parser.add_argument("--num-chains", type=int, default=5) + parser.add_argument("--data-path", type=str, default="test.jsonl") + parser.add_argument("--num-questions", type=int, default=50) + args = add_common_other_args_and_parse(parser) + main(args) diff --git a/benchmark/multi_chain_reasoning/bench_sglang.py b/benchmark/multi_chain_reasoning/bench_sglang.py new file mode 100644 index 000000000..f10c7b2c3 --- /dev/null +++ b/benchmark/multi_chain_reasoning/bench_sglang.py @@ -0,0 +1,129 @@ +import argparse +import ast +import json +import re +import time + +import numpy as np +from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend +from sglang.utils import read_jsonl, dump_state_text + + +INVALID = -9999999 + + +def get_answer_value(answer_str): + answer_str = answer_str.replace(",", "") + numbers = re.findall(r'\d+', answer_str) + if len(numbers) < 1: + return INVALID + try: + return ast.literal_eval(numbers[-1]) + except SyntaxError: + return INVALID + + +prompt_lib = [ + "Let us think step by step.", + "Approach this methodically. Let's dissect the problem into smaller, more manageable parts.", + "It's important to proceed step by step, ensuring accuracy at each stage.", + "Take a deep breath and break this down.", + "A little bit of arithmetic and a logical approach will help us quickly arrive at the solution to this problem.", + "I am extremely good at math.", +] + + +def main(args): + lines = read_jsonl(args.data_path) + + # Construct prompts + #k = args.num_shot + #few_shot_examples = get_few_shot_examples(lines, k) + + questions = [] + labels = [] + for i in range(len(lines[:args.num_questions])): + questions.append(lines[i]["question"]) + labels.append(get_answer_value(lines[i]["answer"])) + assert all(l != INVALID for l in labels) + arguments = [{"question": q} for q in questions] + + num_chains = args.num_chains + + ##################################### + ######### SGL Program Begin ######### + ##################################### + + import sglang as sgl + + @sgl.function + def multi_chain_gsm8k(s, question): + s += "Question: " + question + "\n" + #s += "Answer: " + prompt_lib[0] + sgl.gen("answer", max_tokens=256, stop="Question", + # temperature=0) + #return + + forks = s.fork(num_chains) + for i in range(num_chains): + forks[i] += ("Answer: " + prompt_lib[i % num_chains] + + sgl.gen(f"chain", max_tokens=256, temperature=0.3, stop="Question")) + forks.join() + + s += "Answer: To answer this question, here are some possible solutions. " + s += "After considering all of them, I will do a majority vote.\n\n" + for i in range(num_chains): + s += f"Solution {i+1}: " + forks[i]["chain"].strip() + "\n\n" + s += f"\nBy considering the above solutions and doing a majority vote, I think the final answer (a single integer number) is " + s += sgl.gen("answer", max_tokens=16) + + ##################################### + ########## SGL Program End ########## + ##################################### + + # Select backend + backend = select_sglang_backend(args) + + # Run requests + tic = time.time() + states = multi_chain_gsm8k.run_batch( + arguments, temperature=0, backend=backend, num_threads=args.parallel) + latency = time.time() - tic + + preds = [] + for i in range(len(states)): + preds.append(get_answer_value(states[i]["answer"])) + + # Compute accuracy + acc = np.mean(np.array(preds) == np.array(labels)) + invalid = np.mean(np.array(preds) == INVALID) + print(f"Latency: {latency:.3f}") + print(f"Invalid: {invalid:.3f}") + print(f"Accuracy: {acc:.3f}") + + # Write results + dump_state_text(f"tmp_output_{args.backend}.txt", states) + + with open(args.result_file, "a") as fout: + value = { + "task": "multi_chain_gsm8k", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + "accuracy": round(acc, 3), + "num_requests": args.num_questions, + "other": { + "num_questions": args.num_questions, + "parallel": args.parallel, + } + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--num-shot", type=int, default=0) + parser.add_argument("--num-chains", type=int, default=5) + parser.add_argument("--data-path", type=str, default="test.jsonl") + parser.add_argument("--num-questions", type=int, default=50) + args = add_common_sglang_args_and_parse(parser) + main(args) diff --git a/benchmark/multi_document_qa/README.md b/benchmark/multi_document_qa/README.md new file mode 100644 index 000000000..96b0a3ad6 --- /dev/null +++ b/benchmark/multi_document_qa/README.md @@ -0,0 +1,47 @@ +## Run benchmark + +### Benchmark sglang +``` +python3 -m sglang.launch_server --model-path codellama/CodeLlama-7b-instruct-hf --port 30000 +``` + +``` +python3 bench_sglang.py --num-questions 10 --parallel 1 +``` + + +### Benchmark vllm +``` +python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model codellama/CodeLlama-7b-instruct-hf --disable-log-requests --port 21000 --gpu 0.97 +``` + +``` +python3 bench_other.py --backend vllm --num-questions 64 +``` + + +### Benchmark guidance +``` +python3 bench_other.py --backend guidance --num-questions 32 --parallel 1 +``` + + + +### Build dataset + +``` +pip install PyPDF2 +python3 build_dataset.py +``` + +```python +import PyPDF2 + +with open('llama2.pdf', 'rb') as file: + reader = PyPDF2.PdfReader(file) + text = '' + for page_num in range(len(reader.pages)): + text += reader.pages[page_num].extract_text() + with open('output.txt', 'w') as text_file: + text_file.write(text) +``` diff --git a/benchmark/multi_document_qa/bench_other.py b/benchmark/multi_document_qa/bench_other.py new file mode 100644 index 000000000..8d0f795d0 --- /dev/null +++ b/benchmark/multi_document_qa/bench_other.py @@ -0,0 +1,126 @@ +import argparse +import asyncio +from concurrent.futures import ThreadPoolExecutor +from functools import partial +import json +import time + +from tqdm import tqdm +import numpy as np +from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw +from sglang.utils import read_jsonl, dump_state_text + + +USER_PREFIX = "[INST] " +USER_SUFFIX = " [/INST]" +ASSISTANT_PREFIX = "" +ASSISTANT_SUFFIX = " " + + +def multi_document_qa(docs, question, generate): + s = USER_PREFIX + s += "Pleaes answer a question according to given documents.\n" + s += "Question:" + question + "Documents begin.\n" + + s += "".join(docs) + + s += "\nDocuments end." + s += ("\n\nBased on the above documents, please answer this question:\n" + question + "\nAnswer in three words or fewer.") + s += USER_SUFFIX + s += ASSISTANT_PREFIX + answer = generate(s, max_tokens=16, stop=None) + return answer + + +def main(args): + lines = read_jsonl(args.data_path) + l = lines[0] + arguments = [] + labels = [] + + num_docs = 10 + if args.backend == "guidance": + num_docs = 7 # due to OOM + + for i in range(len(l["questions"][:args.num_questions])): + arguments.append({ + "docs": l["documents"][:num_docs], + "question": l["questions"][i], + }) + labels.append(l["answers"][i]) + states = [None] * len(arguments) + + # Select backend + if args.backend == "lightllm": + url = f"{args.host}:{args.port}/generate" + generate = partial(call_generate_lightllm, url=url, temperature=0) + elif args.backend == "vllm": + url = f"{args.host}:{args.port}/generate" + generate = partial(call_generate_vllm, url=url, temperature=0) + elif args.backend == "srt-raw": + url = f"{args.host}:{args.port}/generate" + generate = partial(call_generate_srt_raw, url=url, temperature=0) + elif args.backend == "guidance": + from guidance import models, gen + + model = models.LlamaCpp("/home/ubuntu/model_weights/CodeLlama-7b-instruct-hf.gguf", n_gpu_layers=-1, n_ctx=11000) + + def generate(prompt, max_tokens, stop): + out = model + prompt + gen(name="answer", + max_tokens=max_tokens, temperature=0, stop=stop) + return out["answer"] + + # warmup + generate("Hello!", max_tokens=8, stop=None) + else: + raise ValueError(f"Invalid backend: {args.backend}") + + # Run requests + def get_one_answer(i): + states[i] = multi_document_qa(generate=generate, **arguments[i]) + + tic = time.time() + if args.parallel == 1: + for i in tqdm(range(len(labels))): + get_one_answer(i) + else: + with ThreadPoolExecutor(args.parallel) as executor: + executor.map(get_one_answer, list(range(len(labels)))) + latency = time.time() - tic + + # Compute accuracy + print(states) + correct = 0 + for s, label in zip(states, labels): + answer = s.lower() + if all(x in answer for x in label.lower().split(" ")): + correct += 1 + accuracy = correct / len(labels) + print(f"Accuracy: {accuracy:.3f}") + print(f"Latency: {latency:.3f}") + + # Write results + dump_state_text(f"tmp_output_{args.backend}.txt", states) + + with open(args.result_file, "a") as fout: + value = { + "task": "multi_document_qa", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + "num_requests": args.num_questions, + "accuracy": accuracy, + "other": { + "num_questions": args.num_questions, + "parallel": args.parallel, + } + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, default="questions.jsonl") + parser.add_argument("--num-questions", type=int, default=100) + args = add_common_other_args_and_parse(parser) + main(args) diff --git a/benchmark/multi_document_qa/bench_sglang.py b/benchmark/multi_document_qa/bench_sglang.py new file mode 100644 index 000000000..22bbfc5d8 --- /dev/null +++ b/benchmark/multi_document_qa/bench_sglang.py @@ -0,0 +1,84 @@ +import argparse +import json +import time + +import numpy as np +import sglang as sgl +from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend +from sglang.utils import read_jsonl, dump_state_text + + +@sgl.function +def multi_document_qa(s, docs, question): + s += sgl.user_begin() + s += "Pleaes answer a question according to given documents.\n" + s += "Question:" + question + "Documents begin.\n" + + forks = s.fork(len(docs)) + forks += lambda i: docs[i] + forks.join("concate_and_append") + + s += "\nDocuments end." + s += ("\n\nBased on the above documents, please answer this question:\n" + question + "\nAnswer in three words or fewer.") + s += sgl.user_end() + s += sgl.assistant(sgl.gen("answer", max_tokens=16)) + + +def main(args): + lines = read_jsonl(args.data_path) + l = lines[0] + arguments = [] + labels = [] + for i in range(len(l["questions"][:args.num_questions])): + arguments.append({ + "docs": l["documents"][:10], + "question": l["questions"][i], + }) + labels.append(l["answers"][i]) + + # Select backend + backend = select_sglang_backend(args) + sgl.set_default_backend(backend) + + # Run requests + tic = time.time() + states = multi_document_qa.run_batch( + arguments, temperature=0, num_threads=args.parallel) + latency = time.time() - tic + + # Compute accuracy + print([s["answer"] for s in states]) + correct = 0 + for s, label in zip(states, labels): + answer = s["answer"].lower() + if all(x in answer for x in label.lower().split(" ")): + correct += 1 + accuracy = correct / len(labels) + print(f"Accuracy: {accuracy:.3f}") + print(f"Latency: {latency:.3f}") + + # Write results + dump_state_text(f"tmp_output_{args.backend}.txt", states) + + with open(args.result_file, "a") as fout: + value = { + "task": "multi_document_qa", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + "num_requests": args.num_questions, + "accuracy": accuracy, + "other": { + "num_questions": args.num_questions, + "parallel": args.parallel, + } + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, default="questions.jsonl") + parser.add_argument("--num-questions", type=int, default=100) + args = add_common_sglang_args_and_parse(parser) + main(args) diff --git a/benchmark/multi_document_qa/build_dataset.py b/benchmark/multi_document_qa/build_dataset.py new file mode 100644 index 000000000..670cd3166 --- /dev/null +++ b/benchmark/multi_document_qa/build_dataset.py @@ -0,0 +1,64 @@ +import json + +import transformers + +content = "\n".join( + open("llama2.txt", 'r', encoding='utf-8', errors='ignore').readlines()) +content = content.replace("\n\n", "\n") + +# Count token +name = "meta-llama/Llama-2-7b-chat-hf" +t = transformers.AutoTokenizer.from_pretrained(name) +print(f"num tokens: {len(t.encode(content))}") + +# Segment +SEP = "\n\n" +parts = content.split(SEP) +print(f"num segments: {len(parts)}") + +segment_len = 1100 + +segments = [] +tmp = [] +tmp_len = 0 +for i in range(len(parts)): + tmp.append(parts[i]) + tmp_len += len(t.encode(parts[i])) + + if tmp_len > segment_len: + segments.append(SEP.join(tmp)) + tmp = [] + tmp_len = 0 + +for i, s in enumerate(segments): + print(i, len(t.encode(segments[i]))) + +# Dump +with open("questions.jsonl", "w") as fout: + fout.write(json.dumps({ + "documents": segments[:30], + "questions": [ + "What is the name of the fine-tuned LLMs?", + "Which figure shows the helpfulness human evaluation results for Llama 2-Chat?", + "What is the number of parameters in the largest Llama 2 model?", + "What is the batch size of fine-tuning?", + "Where can we find the details of potential data contamination?", + "What is the full name of MPT?", + "What is the power consumption of RSC in Watt?", + "How many tokens of data do they train on?", + "Which model's release is delayed due to a lack of time to sufficiently red team?", + "Which activation function is used in Llama?" + ], + "answers": [ + "Llama 2 Chat", + "1", + "70 B", + "64", + "A 6", + "MosaicML", + "400", + "2 trillion", + "34 B", + "SwiGLU", + ], + }) + "\n") diff --git a/benchmark/react/README.md b/benchmark/react/README.md new file mode 100644 index 000000000..47d230afb --- /dev/null +++ b/benchmark/react/README.md @@ -0,0 +1,26 @@ +## Run benchmark + +### Benchmark sglang +``` +python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 +``` + +``` +python3 bench_sglang.py --num-questions 100 +``` + + +### Benchmark vllm +``` +python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000 +``` + +``` +python3 bench_other.py --num-questions 100 --backend vllm +``` + + +### Benchmark guidance +``` +python3 bench_other.py --num-questions 100 --backend guidance --parallel 1 +``` diff --git a/benchmark/react/bench_other.py b/benchmark/react/bench_other.py new file mode 100644 index 000000000..bdf4c8f0d --- /dev/null +++ b/benchmark/react/bench_other.py @@ -0,0 +1,182 @@ +import argparse +from concurrent.futures import ThreadPoolExecutor +from functools import partial +import json +import time +from pathlib import Path + +from tqdm import tqdm +from sglang.test.test_utils import ( + add_common_other_args_and_parse, + call_generate_lightllm, + call_generate_vllm, + call_generate_srt_raw, +) +from sglang.utils import read_jsonl, dump_state_text + + +def get_prompt(question): + prompt = ( + """Solve a question answering task with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be three types: +(1) Search[entity], which searches the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to search. +(2) Lookup[keyword], which returns the next sentence containing keyword in the current passage. +(3) Finish[answer], which returns the answer and finishes the task. +Here are some examples. +Question: What is the elevation range for the area that the eastern sector of the Colorado orogeny extends into? +Thought 1: I need to search Colorado orogeny, find the area that the eastern sector of the Colorado orogeny extends into, then find the elevation range of the area. +Action 1: Search[Colorado orogeny] +Observation 1: The Colorado orogeny was an episode of mountain building (an orogeny) in Colorado and surrounding areas. +Thought 2: It does not mention the eastern sector. So I need to look up eastern sector. +Action 2: Lookup[eastern sector] +Observation 2: (Result 1 / 1) The eastern sector extends into the High Plains and is called the Central Plains orogeny. +Thought 3: The eastern sector of Colorado orogeny extends into the High Plains. So I need to search High Plains and find its elevation range. +Action 3: Search[High Plains] +Observation 3: High Plains refers to one of two distinct land regions: +Thought 4: I need to instead search High Plains (United States). +Action 4: Search[High Plains (United States)] +Observation 4: The High Plains are a subregion of the Great Plains. From east to west, the High Plains rise in elevation from around 1,800 to 7,000 ft (550 to 2,130 m).[3] +Thought 5: High Plains rise in elevation from around 1,800 to 7,000 ft, so the answer is 1,800 to 7,000 ft. +Action 5: Finish[1,800 to 7,000 ft] +Question: Musician and satirist Allie Goertz wrote a song about the "The Simpsons" character Milhouse, who Matt Groening named after who? +Thought 1: The question simplifies to "The Simpsons" character Milhouse is named after who. I only need to search Milhouse and find who it is named after. +Action 1: Search[Milhouse] +Observation 1: Milhouse Mussolini Van Houten is a recurring character in the Fox animated television series The Simpsons voiced by Pamela Hayden and created by Matt Groening. +Thought 2: The paragraph does not tell who Milhouse is named after, maybe I can look up "named after". +Action 2: Lookup[named after] +Observation 2: (Result 1 / 1) Milhouse was named after U.S. president Richard Nixon, whose middle name was Milhous. +Thought 3: Milhouse was named after U.S. president Richard Nixon, so the answer is Richard Nixon. +Action 3: Finish[Richard Nixon] +Question: Which documentary is about Finnish rock groups, Adam Clayton Powell or The Saimaa Gesture? +Thought 1: I need to search Adam Clayton Powell and The Saimaa Gesture, and find which documentary is about Finnish rock groups. +Action 1: Search[Adam Clayton Powell] +Observation 1: Could not find [Adam Clayton Powell]. Similar: ['Adam Clayton Powell III', 'Seventh Avenue (Manhattan)', 'Adam Clayton Powell Jr. State Office Building', 'Isabel Washington Powell', 'Adam Powell', 'Adam Clayton Powell (film)', 'Giancarlo Esposito']. +Thought 2: To find the documentary, I can search Adam Clayton Powell (film). +Action 2: Search[Adam Clayton Powell (film)] +Observation 2: Adam Clayton Powell is a 1989 American documentary film directed by Richard Kilberg. +The film is about the rise and fall of influential African-American politician Adam Clayton Powell Jr.[3][4] It was later aired as part of the PBS series The American Experience. +Thought 3: Adam Clayton Powell (film) is a documentary about an African-American politician, not Finnish rock groups. So the documentary about Finnish rock groups must instead be The Saimaa Gesture. +Action 3: Finish[The Saimaa Gesture] +Question: What profession does Nicholas Ray and Elia Kazan have in common? +Thought 1: I need to search Nicholas Ray and Elia Kazan, find their professions, then find the profession they have in common. +Action 1: Search[Nicholas Ray] +Observation 1: Nicholas Ray (born Raymond Nicholas Kienzle Jr., August 7, 1911 – June 16, 1979) was an American film director, screenwriter, and actor best known for the 1955 film Rebel Without a Cause. +Thought 2: Professions of Nicholas Ray are director, screenwriter, and actor. I need to search Elia Kazan next and find his professions. +Action 2: Search[Elia Kazan] +Observation 2: Elia Kazan was an American film and theatre director, producer, screenwriter and actor. +Thought 3: Professions of Elia Kazan are director, producer, screenwriter, and actor. So profession Nicholas Ray and Elia Kazan have in common is director, screenwriter, and actor. +Action 3: Finish[director, screenwriter, actor] +Question: Which magazine was started first Arthur's Magazine or First for Women? +Thought 1: I need to search Arthur's Magazine and First for Women, and find which was started first. +Action 1: Search[Arthur's Magazine] +Observation 1: Arthur's Magazine (1844-1846) was an American literary periodical published in Philadelphia in the 19th century. +Thought 2: Arthur's Magazine was started in 1844. I need to search First for Women next. +Action 2: Search[First for Women] +Observation 2: First for Women is a woman's magazine published by Bauer Media Group in the USA.[1] The magazine was started in 1989. +Thought 3: First for Women was started in 1989. 1844 (Arthur's Magazine) < 1989 (First for Women), so Arthur's Magazine was started first. +Action 3: Finish[Arthur's Magazine] +Question: Were Pavel Urysohn and Leonid Levin known for the same type of work? +Thought 1: I need to search Pavel Urysohn and Leonid Levin, find their types of work, then find if they are the same. +Action 1: Search[Pavel Urysohn] +Observation 1: Pavel Samuilovich Urysohn (February 3, 1898 â August 17, 1924) was a Soviet mathematician who is best known for his contributions in dimension theory. +Thought 2: Pavel Urysohn is a mathematician. I need to search Leonid Levin next and find its type of work. +Action 2: Search[Leonid Levin] +Observation 2: Leonid Anatolievich Levin is a Soviet-American mathematician and computer scientist. +Thought 3: Leonid Levin is a mathematician and computer scientist. So Pavel Urysohn and Leonid Levin have the same type of work. +Action 3: Finish[yes] +""" + question) + return prompt + + +def main(args): + lines = read_jsonl(args.data_path)[:args.num_questions] + arguments = [{ + "question": k, + "triplets": v + } for l in lines for k, v in l.items()] + + states = [] + + # Select backend + if args.backend == "lightllm": + url = f"{args.host}:{args.port}/generate" + call_generate = partial(call_generate_lightllm, url=url) + elif args.backend == "vllm": + url = f"{args.host}:{args.port}/generate" + call_generate = partial(call_generate_vllm, url=url) + elif args.backend == "srt-raw": + url = f"{args.host}:{args.port}/generate" + call_generate = partial(call_generate_srt_raw, url=url) + elif args.backend == "guidance": + from guidance import models, gen + + model = models.LlamaCpp( + str(Path.home()) + "/model_weights/Llama-2-7b-chat.gguf", + n_gpu_layers=-1, + n_ctx=4096, + ) + + def call_generate(prompt, temperature, max_tokens, stop): + out = (model + prompt + gen( + name="result", + max_tokens=max_tokens, + temperature=temperature, + stop=stop, + )) + return out["result"] + + else: + raise ValueError(f"Invalid backend: {args.backend}") + + def run_single_agent(argument): + question = argument["question"] + triplets = argument["triplets"] + prompt = get_prompt(question) + for i in range(1, len(triplets) + 2): + prompt += "Thought " + str(i) + ":" + states.append(prompt) + answer = call_generate(prompt, + max_tokens=200, + temperature=0, + stop="Observation") + if i > len(triplets): + break + prompt += (triplets[i - 1]["thought"] + "\nAction " + str(i) + + ":" + triplets[i - 1]["action"] + "\nObservation " + + str(i) + ":" + triplets[i - 1]["observation"] + "\n") + + states.append(answer) + + tic = time.time() + if args.parallel == 1: + for arg in tqdm(arguments): + run_single_agent(arg) + else: + with ThreadPoolExecutor(args.parallel) as executor: + executor.map(run_single_agent, arguments) + latency = time.time() - tic + + print(f"Latency: {latency:.3f}") + + # Write results + dump_state_text(f"tmp_output_{args.backend}.txt", states) + + with open(args.result_file, "a") as fout: + value = { + "task": "ReAct Agents", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + "num_requests": len(arguments), + "other": { + "parallel": args.parallel, + }, + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, default="hotpotqa_100.jsonl") + parser.add_argument("--num-questions", type=int, default=10) + args = add_common_other_args_and_parse(parser) + main(args) diff --git a/benchmark/react/bench_sglang.py b/benchmark/react/bench_sglang.py new file mode 100644 index 000000000..933de8e76 --- /dev/null +++ b/benchmark/react/bench_sglang.py @@ -0,0 +1,141 @@ +import argparse +import json +import time + +import sglang as sgl +from sglang.test.test_utils import ( + add_common_sglang_args_and_parse, + select_sglang_backend, +) +from sglang.utils import read_jsonl, dump_state_text + + +@sgl.function +def webthink(s, question, triplets): + s += ( + """Solve a question answering task with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be three types: +(1) Search[entity], which searches the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to search. +(2) Lookup[keyword], which returns the next sentence containing keyword in the current passage. +(3) Finish[answer], which returns the answer and finishes the task. +Here are some examples. +Question: What is the elevation range for the area that the eastern sector of the Colorado orogeny extends into? +Thought 1: I need to search Colorado orogeny, find the area that the eastern sector of the Colorado orogeny extends into, then find the elevation range of the area. +Action 1: Search[Colorado orogeny] +Observation 1: The Colorado orogeny was an episode of mountain building (an orogeny) in Colorado and surrounding areas. +Thought 2: It does not mention the eastern sector. So I need to look up eastern sector. +Action 2: Lookup[eastern sector] +Observation 2: (Result 1 / 1) The eastern sector extends into the High Plains and is called the Central Plains orogeny. +Thought 3: The eastern sector of Colorado orogeny extends into the High Plains. So I need to search High Plains and find its elevation range. +Action 3: Search[High Plains] +Observation 3: High Plains refers to one of two distinct land regions: +Thought 4: I need to instead search High Plains (United States). +Action 4: Search[High Plains (United States)] +Observation 4: The High Plains are a subregion of the Great Plains. From east to west, the High Plains rise in elevation from around 1,800 to 7,000 ft (550 to 2,130 m).[3] +Thought 5: High Plains rise in elevation from around 1,800 to 7,000 ft, so the answer is 1,800 to 7,000 ft. +Action 5: Finish[1,800 to 7,000 ft] +Question: Musician and satirist Allie Goertz wrote a song about the "The Simpsons" character Milhouse, who Matt Groening named after who? +Thought 1: The question simplifies to "The Simpsons" character Milhouse is named after who. I only need to search Milhouse and find who it is named after. +Action 1: Search[Milhouse] +Observation 1: Milhouse Mussolini Van Houten is a recurring character in the Fox animated television series The Simpsons voiced by Pamela Hayden and created by Matt Groening. +Thought 2: The paragraph does not tell who Milhouse is named after, maybe I can look up "named after". +Action 2: Lookup[named after] +Observation 2: (Result 1 / 1) Milhouse was named after U.S. president Richard Nixon, whose middle name was Milhous. +Thought 3: Milhouse was named after U.S. president Richard Nixon, so the answer is Richard Nixon. +Action 3: Finish[Richard Nixon] +Question: Which documentary is about Finnish rock groups, Adam Clayton Powell or The Saimaa Gesture? +Thought 1: I need to search Adam Clayton Powell and The Saimaa Gesture, and find which documentary is about Finnish rock groups. +Action 1: Search[Adam Clayton Powell] +Observation 1: Could not find [Adam Clayton Powell]. Similar: ['Adam Clayton Powell III', 'Seventh Avenue (Manhattan)', 'Adam Clayton Powell Jr. State Office Building', 'Isabel Washington Powell', 'Adam Powell', 'Adam Clayton Powell (film)', 'Giancarlo Esposito']. +Thought 2: To find the documentary, I can search Adam Clayton Powell (film). +Action 2: Search[Adam Clayton Powell (film)] +Observation 2: Adam Clayton Powell is a 1989 American documentary film directed by Richard Kilberg. +The film is about the rise and fall of influential African-American politician Adam Clayton Powell Jr.[3][4] It was later aired as part of the PBS series The American Experience. +Thought 3: Adam Clayton Powell (film) is a documentary about an African-American politician, not Finnish rock groups. So the documentary about Finnish rock groups must instead be The Saimaa Gesture. +Action 3: Finish[The Saimaa Gesture] +Question: What profession does Nicholas Ray and Elia Kazan have in common? +Thought 1: I need to search Nicholas Ray and Elia Kazan, find their professions, then find the profession they have in common. +Action 1: Search[Nicholas Ray] +Observation 1: Nicholas Ray (born Raymond Nicholas Kienzle Jr., August 7, 1911 – June 16, 1979) was an American film director, screenwriter, and actor best known for the 1955 film Rebel Without a Cause. +Thought 2: Professions of Nicholas Ray are director, screenwriter, and actor. I need to search Elia Kazan next and find his professions. +Action 2: Search[Elia Kazan] +Observation 2: Elia Kazan was an American film and theatre director, producer, screenwriter and actor. +Thought 3: Professions of Elia Kazan are director, producer, screenwriter, and actor. So profession Nicholas Ray and Elia Kazan have in common is director, screenwriter, and actor. +Action 3: Finish[director, screenwriter, actor] +Question: Which magazine was started first Arthur's Magazine or First for Women? +Thought 1: I need to search Arthur's Magazine and First for Women, and find which was started first. +Action 1: Search[Arthur's Magazine] +Observation 1: Arthur's Magazine (1844-1846) was an American literary periodical published in Philadelphia in the 19th century. +Thought 2: Arthur's Magazine was started in 1844. I need to search First for Women next. +Action 2: Search[First for Women] +Observation 2: First for Women is a woman's magazine published by Bauer Media Group in the USA.[1] The magazine was started in 1989. +Thought 3: First for Women was started in 1989. 1844 (Arthur's Magazine) < 1989 (First for Women), so Arthur's Magazine was started first. +Action 3: Finish[Arthur's Magazine] +Question: Were Pavel Urysohn and Leonid Levin known for the same type of work? +Thought 1: I need to search Pavel Urysohn and Leonid Levin, find their types of work, then find if they are the same. +Action 1: Search[Pavel Urysohn] +Observation 1: Pavel Samuilovich Urysohn (February 3, 1898 â August 17, 1924) was a Soviet mathematician who is best known for his contributions in dimension theory. +Thought 2: Pavel Urysohn is a mathematician. I need to search Leonid Levin next and find its type of work. +Action 2: Search[Leonid Levin] +Observation 2: Leonid Anatolievich Levin is a Soviet-American mathematician and computer scientist. +Thought 3: Leonid Levin is a mathematician and computer scientist. So Pavel Urysohn and Leonid Levin have the same type of work. +Action 3: Finish[yes] +""" + question) + for i in range(1, len(triplets) + 2): + s += "Thought " + str(i) + ":" + ss = s.fork(1) + ss[0] += sgl.gen(name="thought_action", max_tokens=200, stop="Observation") + # ss.join() + # to verify the correctness of output, this should be collected + # print(ss[0]["thought_action"]) + if i > len(triplets): + break + s += (triplets[i - 1]["thought"] + "\nAction " + str(i) + ":" + + triplets[i - 1]["action"] + "\nObservation " + str(i) + ":" + + triplets[i - 1]["observation"] + "\n") + + +def main(args): + lines = read_jsonl(args.data_path)[:args.num_questions] + arguments = [{ + "question": k, + "triplets": v + } for l in lines for k, v in l.items()] + + # Select backend + backend = select_sglang_backend(args) + sgl.set_default_backend(backend) + + states = [] + tic = time.time() + states = webthink.run_batch(arguments, + temperature=0, + num_threads=args.parallel) + latency = time.time() - tic + + # Compute accuracy + print(f"Latency: {latency:.3f}") + + # Write results + dump_state_text(f"tmp_output_{args.backend}.txt", states) + + with open(args.result_file, "a") as fout: + value = { + "task": "ReAct Agents", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + "num_requests": len(arguments), + "other": { + "num_questions": args.num_questions, + "parallel": args.parallel, + }, + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, default="hotpotqa_100.jsonl") + parser.add_argument("--num-questions", type=int, default=10) + args = add_common_sglang_args_and_parse(parser) + main(args) diff --git a/benchmark/tip_suggestion/README.md b/benchmark/tip_suggestion/README.md new file mode 100644 index 000000000..81ca47d04 --- /dev/null +++ b/benchmark/tip_suggestion/README.md @@ -0,0 +1,27 @@ +## Run benchmark + +### Benchmark sglang +``` +python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 +``` + +``` +python3 bench_sglang.py --num-questions 64 +python3 bench_sglang.py --num-questions 32 --parallel 1 +``` + + +### Benchmark vllm +``` +python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000 +``` + +``` +python3 bench_other.py --backend vllm --num-questions 64 +``` + + +### Benchmark guidance +``` +python3 bench_other.py --backend guidance --num-questions 32 --parallel 1 +``` diff --git a/benchmark/tip_suggestion/bench_other.py b/benchmark/tip_suggestion/bench_other.py new file mode 100644 index 000000000..0e974e49f --- /dev/null +++ b/benchmark/tip_suggestion/bench_other.py @@ -0,0 +1,124 @@ +import argparse +import asyncio +from concurrent.futures import ThreadPoolExecutor +from functools import partial +import json +import time + +from tqdm import tqdm +import numpy as np +from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw +from sglang.utils import read_jsonl, dump_state_text + + +number = 5 + + +def expand_tip(topic, tip, generate): + s = ( +"""Please expand a tip for a topic into a detailed paragraph. + +Topic: staying healthy +Tip: Regular Exercise +Paragraph: Incorporate physical activity into your daily routine. This doesn't necessarily mean intense gym workouts; it can be as simple as walking, cycling, or yoga. Regular exercise helps in maintaining a healthy weight, improves cardiovascular health, boosts mental health, and can enhance cognitive function, which is crucial for fields that require intense intellectual engagement. + +Topic: building a campfire +Tip: Choose the Right Location +Paragraph: Always build your campfire in a safe spot. This means selecting a location that's away from trees, bushes, and other flammable materials. Ideally, use a fire ring if available. If you're building a fire pit, it should be on bare soil or on a bed of stones, not on grass or near roots which can catch fire underground. Make sure the area above is clear of low-hanging branches. + +Topic: writing a blog post +Tip: structure your content effectively +Paragraph: A well-structured post is easier to read and more enjoyable. Start with an engaging introduction that hooks the reader and clearly states the purpose of your post. Use headings and subheadings to break up the text and guide readers through your content. Bullet points and numbered lists can make information more digestible. Ensure each paragraph flows logically into the next, and conclude with a summary or call-to-action that encourages reader engagement. + +Topic: """ + topic + "\nTip: " + tip + "\nParagraph:") + return generate(s, max_tokens=128, stop=["\n\n"]) + + +def suggest_tips(topic, generate): + s = "Please act as a helpful assistant. Your job is to provide users with useful tips on a specific topic.\n" + s += "USER: Give some tips for " + topic + ".\n" + s += ("ASSISTANT: Okay. Here are " + str(number) + " concise tips, each under 8 words:\n") + + tips = [] + for i in range(1, 1 + number): + s += f"{i}." + tip = generate(s, max_tokens=24, stop=[".", "\n"]) + s += tip + ".\n" + tips.append(tip) + + paragraphs = [expand_tip(topic, tip, generate=generate) for tip in tips] + + for i in range(1, 1 + number): + s += f"Tip {i}:" + paragraphs[i-1] + "\n" + return s + + +def main(args): + lines = read_jsonl(args.data_path)[:args.num_questions] + states = [None] * len(lines) + + # Select backend + if args.backend == "lightllm": + url = f"{args.host}:{args.port}/generate" + generate = partial(call_generate_lightllm, url=url, temperature=0) + elif args.backend == "vllm": + url = f"{args.host}:{args.port}/generate" + generate = partial(call_generate_vllm, url=url, temperature=0) + elif args.backend == "srt-raw": + url = f"{args.host}:{args.port}/generate" + generate = partial(call_generate_srt_raw, url=url, temperature=0) + elif args.backend == "guidance": + from guidance import models, gen + + model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096) + + def generate(prompt, max_tokens, stop): + out = model + prompt + gen(name="answer", + max_tokens=max_tokens, temperature=0, stop=stop) + return out["answer"] + + # warmup + generate("Hello!", max_tokens=8, stop=None) + else: + raise ValueError(f"Invalid backend: {args.backend}") + + # Run requests + def get_one_answer(i): + states[i] = suggest_tips(lines[i]["topic"], generate) + + tic = time.time() + if args.parallel == 1: + for i in tqdm(range(len(lines))): + get_one_answer(i) + else: + with ThreadPoolExecutor(args.parallel) as executor: + executor.map(get_one_answer, list(range(len(lines)))) + latency = time.time() - tic + + # Compute accuracy + print(f"Latency: {latency:.3f}") + + # Write results + dump_state_text(f"tmp_output_{args.backend}.txt", states) + + with open(args.result_file, "a") as fout: + value = { + "task": "tip_suggestion", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + "num_requests": args.num_questions, + "other": { + "num_questions": args.num_questions, + "parallel": args.parallel, + } + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, default="topic.jsonl") + parser.add_argument("--num-questions", type=int, default=100) + args = add_common_other_args_and_parse(parser) + main(args) diff --git a/benchmark/tip_suggestion/bench_sglang.py b/benchmark/tip_suggestion/bench_sglang.py new file mode 100644 index 000000000..5cd0f23cf --- /dev/null +++ b/benchmark/tip_suggestion/bench_sglang.py @@ -0,0 +1,91 @@ +import argparse +import json +import time + +import numpy as np +import sglang as sgl +from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend +from sglang.utils import read_jsonl, dump_state_text + + +number = 5 + + +@sgl.function +def expand_tip(s, topic, tip): + s += ( +"""Please expand a tip for a topic into a detailed paragraph. + +Topic: staying healthy +Tip: Regular Exercise +Paragraph: Incorporate physical activity into your daily routine. This doesn't necessarily mean intense gym workouts; it can be as simple as walking, cycling, or yoga. Regular exercise helps in maintaining a healthy weight, improves cardiovascular health, boosts mental health, and can enhance cognitive function, which is crucial for fields that require intense intellectual engagement. + +Topic: building a campfire +Tip: Choose the Right Location +Paragraph: Always build your campfire in a safe spot. This means selecting a location that's away from trees, bushes, and other flammable materials. Ideally, use a fire ring if available. If you're building a fire pit, it should be on bare soil or on a bed of stones, not on grass or near roots which can catch fire underground. Make sure the area above is clear of low-hanging branches. + +Topic: writing a blog post +Tip: structure your content effectively +Paragraph: A well-structured post is easier to read and more enjoyable. Start with an engaging introduction that hooks the reader and clearly states the purpose of your post. Use headings and subheadings to break up the text and guide readers through your content. Bullet points and numbered lists can make information more digestible. Ensure each paragraph flows logically into the next, and conclude with a summary or call-to-action that encourages reader engagement. + +Topic: """ + topic + "\nTip: " + tip + "\nParagraph:") + s += sgl.gen("paragraph", max_tokens=128, stop=["\n\n"], temperature=0) + + +@sgl.function +def suggest_tips(s, topic): + s += "Please act as a helpful assistant. Your job is to provide users with useful tips on a specific topic.\n" + s += "USER: Give some tips for " + topic + ".\n" + s += ("ASSISTANT: Okay. Here are " + str(number) + " concise tips, each under 8 words:\n") + + paragraphs = [] + for i in range(1, 1 + number): + s += f"{i}." + sgl.gen(f"tip_{i}", max_tokens=24, stop=[".", "\n"]) + ".\n" + paragraphs.append(expand_tip(topic=topic, tip=s[f"tip_{i}"])) + + for i in range(1, 1 + number): + s += f"Tip {i}:" + paragraphs[i-1]["paragraph"] + "\n" + + +def main(args): + lines = read_jsonl(args.data_path)[:args.num_questions] + arguments = [ + {"topic": l["topic"]} for l in lines + ] + + # Select backend + sgl.set_default_backend(select_sglang_backend(args)) + + # Run requests + tic = time.time() + states = suggest_tips.run_batch( + arguments, temperature=0, num_threads=args.parallel) + latency = time.time() - tic + + # Compute accuracy + print(f"Latency: {latency:.3f}") + + # Write results + dump_state_text(f"tmp_output_{args.backend}.txt", states) + + with open(args.result_file, "a") as fout: + value = { + "task": "tip_suggestion", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + "num_requests": args.num_questions, + "other": { + "num_questions": args.num_questions, + "parallel": args.parallel, + } + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, default="topic.jsonl") + parser.add_argument("--num-questions", type=int, default=100) + args = add_common_sglang_args_and_parse(parser) + main(args) diff --git a/benchmark/tree_of_thought/README.md b/benchmark/tree_of_thought/README.md new file mode 100644 index 000000000..760b24c46 --- /dev/null +++ b/benchmark/tree_of_thought/README.md @@ -0,0 +1,43 @@ +## Download data +``` +wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl +``` + +## Run benchmark + +### Benchmark sglang +``` +python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 +``` + +``` +python3 bench_sglang.py --num-questions 32 --parallel 16 +python3 bench_sglang.py --num-questions 10 --parallel 1 +``` + + +### Benchmark vllm +``` +python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000 +``` + +``` +python3 bench_other.py --num-questions 32 --backend vllm +``` + + +### Benchmark lightllm +``` +# A10G +python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000 +``` + +``` +python3 bench_other.py --num-questions 32 --backend lightllm +``` + + +### Benchmark guidance +``` +python3 bench_other.py --num-questions 32 --backend guidance --parallel 1 +``` diff --git a/benchmark/tree_of_thought/bench_other.py b/benchmark/tree_of_thought/bench_other.py new file mode 100644 index 000000000..ef61edc79 --- /dev/null +++ b/benchmark/tree_of_thought/bench_other.py @@ -0,0 +1,183 @@ +import argparse +import ast +import asyncio +from collections import Counter +from concurrent.futures import ThreadPoolExecutor +from functools import partial +import json +import re +import time + +import numpy as np +from tqdm import tqdm +from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw +from sglang.utils import read_jsonl, dump_state_text + + +INVALID = -9999999 + + +def get_answer_value(answer_str): + answer_str = answer_str.replace(",", "") + numbers = re.findall(r'\d+', answer_str) + if len(numbers) < 1: + return INVALID + try: + return ast.literal_eval(numbers[-1]) + except SyntaxError: + return INVALID + + +def most_frequent_number(numbers): + if not numbers: + return None + + frequency = Counter(numbers) + most_frequent = max(frequency, key=frequency.get) + return most_frequent + + +USER_PREFIX = "[INST] " +USER_SUFFIX = " [/INST]" +ASSISTANT_PREFIX = "" +ASSISTANT_SUFFIX = " " + +# Use a low temp to make the results more deterministic and the comparison more fair. +temp = 0.3 + + +def propose_plan(s, question, num_branches, call_generate): + s += (USER_PREFIX + +"""Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """ + question + USER_SUFFIX) + + s += ASSISTANT_PREFIX + comps = call_generate(s, max_tokens=256, temperature=temp, stop=None, n=num_branches) + return [s + comp + ASSISTANT_SUFFIX for comp in comps] + + +def execute_plan(s, num_branches, call_generate): + s += (USER_PREFIX + +"""The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short.""" + USER_SUFFIX) + s += ASSISTANT_PREFIX + comps = call_generate(s, max_tokens=256, temperature=temp, stop=None, n=num_branches) + return [s + comp + ASSISTANT_SUFFIX for comp in comps] + + +def reflect_solution(s, num_branches, call_generate): + s += (USER_PREFIX + +"""Okay. Now you evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness.""" + USER_SUFFIX) + s += ASSISTANT_PREFIX + comps = call_generate(s, max_tokens=256, temperature=temp, stop=None, n=num_branches) + return [s + comp + ASSISTANT_SUFFIX for comp in comps] + + +def tree_search(question, num_branches, call_generate): + s = "" + solutions = [] + + plan_forks = propose_plan(s, question, num_branches, call_generate) + for plan in plan_forks: + sol_forks = execute_plan(plan, num_branches, call_generate) + for sol in sol_forks: + score_forks = reflect_solution(sol, num_branches, call_generate) + solutions.append(sol_forks) + + return solutions + + +def main(args): + lines = read_jsonl(args.data_path) + + # Construct prompts + num_branches = 3 + questions = [] + labels = [] + for i in range(len(lines[:args.num_questions])): + questions.append(lines[i]["question"]) + labels.append(get_answer_value(lines[i]["answer"])) + assert all(l != INVALID for l in labels) + arguments = [{"question": q, "num_branches": num_branches} for q in questions] + + # Select backend + if args.backend == "lightllm": + url = f"{args.host}:{args.port}/generate" + call_generate = partial(call_generate_lightllm, url=url) + elif args.backend == "vllm": + url = f"{args.host}:{args.port}/generate" + call_generate = partial(call_generate_vllm, url=url) + elif args.backend == "srt-raw": + url = f"{args.host}:{args.port}/generate" + call_generate = partial(call_generate_srt_raw, url=url) + elif args.backend == "guidance": + from guidance import models, gen + + model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096) + + def call_generate(prompt, temperature, max_tokens, stop, n): + if n == 1: + out = model + prompt + gen(name="answer", + max_tokens=max_tokens, temperature=temperature, stop=stop) + return out["answer"] + else: + rets = [] + for i in range(n): + out = model + prompt + gen(name="answer", + max_tokens=max_tokens, temperature=temperature, stop=stop) + rets.append(out["answer"]) + return rets + + # Run requests + states = [None] * len(questions) + def get_one_answer(i): + states[i] = tree_search(**arguments[i], call_generate=call_generate) + + tic = time.time() + if args.parallel == 1: + for i in tqdm(range(len(questions))): + get_one_answer(i) + else: + with ThreadPoolExecutor(args.parallel) as executor: + executor.map(get_one_answer, list(range(len(questions)))) + latency = time.time() - tic + + answers_text = [] + for s in states: + answers_text.append([x for xs in s for x in xs]) + + preds = [] + for i in range(len(states)): + answers = [get_answer_value(v) for v in answers_text[i]] + preds.append(most_frequent_number(answers)) + + # Compute accuracy + acc = np.mean(np.array(preds) == np.array(labels)) + invalid = np.mean(np.array(preds) == INVALID) + print(f"Latency: {latency:.3f}") + print(f"Invalid: {invalid:.3f}") + print(f"Accuracy: {acc:.3f}") + + # Write results + dump_state_text(f"tmp_output_{args.backend}.txt", answers_text) + + with open(args.result_file, "a") as fout: + value = { + "task": "tree_of_thought_gsm8k", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + "accuracy": round(acc, 3), + "num_requests": args.num_questions, + "other": { + "num_questions": args.num_questions, + "parallel": args.parallel, + } + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, default="test.jsonl") + parser.add_argument("--num-questions", type=int, default=200) + args = add_common_other_args_and_parse(parser) + main(args) diff --git a/benchmark/tree_of_thought/bench_sglang.py b/benchmark/tree_of_thought/bench_sglang.py new file mode 100644 index 000000000..7b717a60c --- /dev/null +++ b/benchmark/tree_of_thought/bench_sglang.py @@ -0,0 +1,147 @@ +import argparse +import ast +from collections import Counter +import json +import re +import time + +import numpy as np +from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend +from sglang.utils import read_jsonl, dump_state_text +import sglang as sgl + + +INVALID = -9999999 + + +def get_answer_value(answer_str): + answer_str = answer_str.replace(",", "") + numbers = re.findall(r'\d+', answer_str) + if len(numbers) < 1: + return INVALID + try: + return ast.literal_eval(numbers[-1]) + except SyntaxError: + return INVALID + + +def most_frequent_number(numbers): + if not numbers: + return None + + frequency = Counter(numbers) + most_frequent = max(frequency, key=frequency.get) + return most_frequent + + +# Use a low temp to make the results more deterministic and the comparison more fair. +temp = 0.3 + + +def propose_plan(s, question, num_branches): + s += sgl.user( +"""Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """ + question) + forks = s.fork(num_branches) + forks += sgl.assistant(sgl.gen("plan", max_tokens=256, temperature=temp)) + return forks + + +def execute_plan(s, num_branches): + s += sgl.user( +"""The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short.""") + forks = s.fork(num_branches) + forks += sgl.assistant(sgl.gen("answer", max_tokens=256, temperature=temp)) + return forks + + +def reflect_solution(s, num_branches): + s += sgl.user( +"""Okay. Now you evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness.""") + forks = s.fork(num_branches) + forks += sgl.assistant(sgl.gen("score", max_tokens=256, temperature=temp)) + return forks + + +@sgl.function +def tree_search(s, question, num_branches): + forks_to_join = [] + + plan_forks = propose_plan(s, question, num_branches) + forks_to_join.append(plan_forks) + + sol_states = [] + for plan in plan_forks: + forks = execute_plan(plan, num_branches) + forks_to_join.append(forks) + sol_states.extend(forks) + + for sol in sol_states: + forks = reflect_solution(sol, num_branches) + forks_to_join.append(forks) + + for f in reversed(forks_to_join): + f.join() + + +def main(args): + lines = read_jsonl(args.data_path) + + # Construct prompts + num_branches = 3 + questions = [] + labels = [] + for i in range(len(lines[:args.num_questions])): + questions.append(lines[i]["question"]) + labels.append(get_answer_value(lines[i]["answer"])) + assert all(l != INVALID for l in labels) + arguments = [{"question": q, "num_branches": num_branches} for q in questions] + + # Select backend + backend = select_sglang_backend(args) + + # Run requests + tic = time.time() + states = tree_search.run_batch( + arguments, temperature=0, backend=backend, num_threads=args.parallel) + latency = time.time() - tic + answers_text = [] + for s in states: + answers_text.append([x for xs in s["answer"] for x in xs]) + + preds = [] + for i in range(len(states)): + answers = [get_answer_value(v) for v in answers_text[i]] + preds.append(most_frequent_number(answers)) + + # Compute accuracy + acc = np.mean(np.array(preds) == np.array(labels)) + invalid = np.mean(np.array(preds) == INVALID) + print(f"Latency: {latency:.3f}") + print(f"Invalid: {invalid:.3f}") + print(f"Accuracy: {acc:.3f}") + + # Write results + dump_state_text(f"tmp_output_{args.backend}.txt", answers_text) + + with open(args.result_file, "a") as fout: + value = { + "task": "tree_of_thought_gsm8k", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + "accuracy": round(acc, 3), + "num_requests": args.num_questions, + "other": { + "num_questions": args.num_questions, + "parallel": args.parallel, + } + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, default="test.jsonl") + parser.add_argument("--num-questions", type=int, default=200) + args = add_common_sglang_args_and_parse(parser) + main(args) diff --git a/benchmark/tree_of_thought_deep/README.md b/benchmark/tree_of_thought_deep/README.md new file mode 100644 index 000000000..52141914a --- /dev/null +++ b/benchmark/tree_of_thought_deep/README.md @@ -0,0 +1,43 @@ +## Download data +``` +wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl +``` + +## Run benchmark + +### Benchmark sglang +``` +python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 +``` + +``` +python3 bench_sglang.py --num-questions 32 --parallel 8 +python3 bench_sglang.py --num-questions 16 --parallel 1 +``` + + +### Benchmark vllm +``` +python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000 +``` + +``` +python3 bench_other.py --num-questions 32 --backend vllm +``` + + +### Benchmark lightllm +``` +# A10G +python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000 +``` + +``` +python3 bench_other.py --num-questions 32 --backend lightllm +``` + + +### Benchmark guidance +``` +python3 bench_other.py --num-questions 8 --backend guidance --parallel 1 +``` diff --git a/benchmark/tree_of_thought_deep/bench_other.py b/benchmark/tree_of_thought_deep/bench_other.py new file mode 100644 index 000000000..3ffe66664 --- /dev/null +++ b/benchmark/tree_of_thought_deep/bench_other.py @@ -0,0 +1,198 @@ +import argparse +import ast +import asyncio +from collections import Counter +from concurrent.futures import ThreadPoolExecutor +from functools import partial +import json +import re +import time + +import numpy as np +from tqdm import tqdm +from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw +from sglang.utils import read_jsonl, dump_state_text + + +INVALID = -9999999 + + +def get_answer_value(answer_str): + answer_str = answer_str.replace(",", "") + numbers = re.findall(r'\d+', answer_str) + if len(numbers) < 1: + return INVALID + try: + return ast.literal_eval(numbers[-1]) + except SyntaxError: + return INVALID + + +def most_frequent_number(numbers): + if not numbers: + return None + + frequency = Counter(numbers) + most_frequent = max(frequency, key=frequency.get) + return most_frequent + + +USER_PREFIX = "[INST] " +USER_SUFFIX = " [/INST]" +ASSISTANT_PREFIX = "" +ASSISTANT_SUFFIX = " " + +# Use a low temp to make the results more deterministic and the comparison more fair. +temp = 0.001 + + +def propose_plan(s, question, num_branches, call_generate): + s += (USER_PREFIX + +"""Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """ + question + USER_SUFFIX) + + s += ASSISTANT_PREFIX + comps = call_generate(s, max_tokens=256, temperature=temp, stop=None, n=num_branches) + return [s + comp + ASSISTANT_SUFFIX for comp in comps] + + +def execute_plan(s, num_branches, call_generate): + s += (USER_PREFIX + +"""The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short.""" + USER_SUFFIX) + s += ASSISTANT_PREFIX + comps = call_generate(s, max_tokens=256, temperature=temp, stop=None, n=num_branches) + return [s + comp + ASSISTANT_SUFFIX for comp in comps] + + +def reflect_solution(s, num_branches, call_generate): + s += (USER_PREFIX + +"""Okay. Now, evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness.""" + USER_SUFFIX) + s += ASSISTANT_PREFIX + comps = call_generate(s, max_tokens=256, temperature=temp, stop=None, n=num_branches) + return [s + comp + ASSISTANT_SUFFIX for comp in comps] + + +def get_final_answer(s, num_branches, call_generate): + s += (USER_PREFIX + +"""Based on your reflection, do you change your mind? Now, give me the final answer after careful consideration.""" + USER_SUFFIX) + s += ASSISTANT_PREFIX + comps = call_generate(s, max_tokens=256, temperature=temp, stop=None, n=num_branches) + return [s + comp + ASSISTANT_SUFFIX for comp in comps] + + +def tree_search(question, num_branches, call_generate): + plan_forks = propose_plan("", question, num_branches, call_generate) + + sol_states = [] + for plan in plan_forks: + forks = execute_plan(plan, num_branches, call_generate) + sol_states.extend(forks) + + ref_states = [] + for sol in sol_states: + forks = reflect_solution(sol, num_branches, call_generate) + ref_states.extend(forks) + + solutions = [] + for sol in ref_states: + ans = get_final_answer(sol, num_branches, call_generate) + solutions.append(ans) + + return solutions + + +def main(args): + lines = read_jsonl(args.data_path) + + # Construct prompts + num_branches = 2 + questions = [] + labels = [] + for i in range(len(lines[:args.num_questions])): + questions.append(lines[i]["question"]) + labels.append(get_answer_value(lines[i]["answer"])) + assert all(l != INVALID for l in labels) + arguments = [{"question": q, "num_branches": num_branches} for q in questions] + + # Select backend + if args.backend == "lightllm": + url = f"{args.host}:{args.port}/generate" + call_generate = partial(call_generate_lightllm, url=url) + elif args.backend == "vllm": + url = f"{args.host}:{args.port}/generate" + call_generate = partial(call_generate_vllm, url=url) + elif args.backend == "srt-raw": + url = f"{args.host}:{args.port}/generate" + call_generate = partial(call_generate_srt_raw, url=url) + elif args.backend == "guidance": + from guidance import models, gen + + model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096) + + def call_generate(prompt, temperature, max_tokens, stop, n): + if n == 1: + out = model + prompt + gen(name="answer", + max_tokens=max_tokens, temperature=temperature, stop=stop) + return out["answer"] + else: + rets = [] + for i in range(n): + out = model + prompt + gen(name="answer", + max_tokens=max_tokens, temperature=temperature, stop=stop) + rets.append(out["answer"]) + return rets + + # Run requests + states = [None] * len(questions) + def get_one_answer(i): + states[i] = tree_search(**arguments[i], call_generate=call_generate) + + tic = time.time() + if args.parallel == 1: + for i in tqdm(range(len(questions))): + get_one_answer(i) + else: + with ThreadPoolExecutor(args.parallel) as executor: + executor.map(get_one_answer, list(range(len(questions)))) + latency = time.time() - tic + + answers_text = [] + for s in states: + answers_text.append([x for xs in s for x in xs]) + + preds = [] + for i in range(len(states)): + answers = [get_answer_value(v) for v in answers_text[i]] + preds.append(most_frequent_number(answers)) + + # Compute accuracy + acc = np.mean(np.array(preds) == np.array(labels)) + invalid = np.mean(np.array(preds) == INVALID) + print(f"Latency: {latency:.3f}") + print(f"Invalid: {invalid:.3f}") + print(f"Accuracy: {acc:.3f}") + + # Write results + dump_state_text(f"tmp_output_{args.backend}.txt", answers_text) + + with open(args.result_file, "a") as fout: + value = { + "task": "tree_of_thought_gsm8k", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + "accuracy": round(acc, 3), + "num_requests": args.num_questions, + "other": { + "num_questions": args.num_questions, + "parallel": args.parallel, + } + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, default="test.jsonl") + parser.add_argument("--num-questions", type=int, default=200) + args = add_common_other_args_and_parse(parser) + main(args) diff --git a/benchmark/tree_of_thought_deep/bench_sglang.py b/benchmark/tree_of_thought_deep/bench_sglang.py new file mode 100644 index 000000000..e8b617597 --- /dev/null +++ b/benchmark/tree_of_thought_deep/bench_sglang.py @@ -0,0 +1,157 @@ +import argparse +import ast +from collections import Counter +import json +import re +import time + +import numpy as np +from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend +from sglang.utils import read_jsonl, dump_state_text +import sglang as sgl + + +INVALID = -9999999 + + +def get_answer_value(answer_str): + answer_str = answer_str.replace(",", "") + numbers = re.findall(r'\d+', answer_str) + if len(numbers) < 1: + return INVALID + try: + return ast.literal_eval(numbers[-1]) + except SyntaxError: + return INVALID + + +def most_frequent_number(numbers): + if not numbers: + return None + + frequency = Counter(numbers) + most_frequent = max(frequency, key=frequency.get) + return most_frequent + + +# Use a low temp to make the results more deterministic and the comparison more fair. +temp = 0.001 + + +def propose_plan(s, question, num_branches): + s += sgl.user( +"""Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """ + question) + forks = s.fork(num_branches) + forks += sgl.assistant(sgl.gen("plan", max_tokens=256, temperature=temp)) + return forks + + +def execute_plan(s, num_branches): + s += sgl.user( +"""The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short.""") + forks = s.fork(num_branches) + forks += sgl.assistant(sgl.gen("answer", max_tokens=256, temperature=temp)) + return forks + + +def reflect_solution(s, num_branches): + s += sgl.user( +"""Okay. Now, evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness.""") + forks = s.fork(num_branches) + forks += sgl.assistant(sgl.gen("score", max_tokens=256, temperature=temp)) + return forks + + +def get_final_answer(s, num_branches): + s += sgl.user( +"""Based on your reflection, do you change your mind? Now, give me the final answer after careful consideration.""") + forks = s.fork(num_branches) + forks += sgl.assistant(sgl.gen("final_answer", max_tokens=256, temperature=temp)) + return forks + + + +@sgl.function +def tree_search(s, question, num_branches): + plan_forks = propose_plan(s, question, num_branches) + + sol_states = [] + for plan in plan_forks: + forks = execute_plan(plan, num_branches) + sol_states.extend(forks) + + ref_states = [] + for sol in sol_states: + forks = reflect_solution(sol, num_branches) + ref_states.extend(forks) + + solutions = [] + for sol in ref_states: + forks = get_final_answer(sol, num_branches) + solutions.append(forks) + solutions = [[s.text() for s in forks] for forks in solutions] + + return solutions + +def main(args): + lines = read_jsonl(args.data_path) + + # Construct prompts + num_branches = 2 + questions = [] + labels = [] + for i in range(len(lines[:args.num_questions])): + questions.append(lines[i]["question"]) + labels.append(get_answer_value(lines[i]["answer"])) + assert all(l != INVALID for l in labels) + arguments = [{"question": q, "num_branches": num_branches} for q in questions] + + # Select backend + backend = select_sglang_backend(args) + + # Run requests + tic = time.time() + states = tree_search.run_batch( + arguments, temperature=0, backend=backend, num_threads=args.parallel) + latency = time.time() - tic + answers_text = [] + for s in states: + answers_text.append([x for xs in s.ret_value for x in xs]) + + preds = [] + for i in range(len(states)): + answers = [get_answer_value(v) for v in answers_text[i]] + preds.append(most_frequent_number(answers)) + + # Compute accuracy + acc = np.mean(np.array(preds) == np.array(labels)) + invalid = np.mean(np.array(preds) == INVALID) + print(f"Latency: {latency:.3f}") + print(f"Invalid: {invalid:.3f}") + print(f"Accuracy: {acc:.3f}") + + # Write results + dump_state_text(f"tmp_output_{args.backend}.txt", answers_text) + + with open(args.result_file, "a") as fout: + value = { + "task": "tree_of_thought_gsm8k", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + "accuracy": round(acc, 3), + "num_requests": args.num_questions, + "other": { + "num_questions": args.num_questions, + "parallel": args.parallel, + } + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, default="test.jsonl") + parser.add_argument("--num-questions", type=int, default=200) + args = add_common_sglang_args_and_parse(parser) + main(args) diff --git a/docs/flashinfer.md b/docs/flashinfer.md new file mode 100644 index 000000000..38567f74e --- /dev/null +++ b/docs/flashinfer.md @@ -0,0 +1,20 @@ +## Flashinfer Mode + +[`flashinfer`](https://github.com/flashinfer-ai/flashinfer) is a kernel library for LLM serving; we use it here to support our attention computation. + +### Install flashinfer + +```bash +git submodule update --init --recursive +pip install 3rdparty/flashinfer/python +``` + +### Run Sever With Flashinfer Mode + +Add through `--model_mode` argument from the command line. + +Example: + +```bash +python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --model-mode flashinfer +``` \ No newline at end of file diff --git a/docs/test_process.md b/docs/test_process.md new file mode 100644 index 000000000..fff46fd14 --- /dev/null +++ b/docs/test_process.md @@ -0,0 +1,63 @@ +## SRT Unit Tests + +### Low-level API +``` +cd sglang/test/srt/model + +python3 test_llama_low_api.py +python3 test_llama_extend.py +python3 test_llava_low_api.py +python3 bench_llama_low_api.py +``` + +### High-level API + +``` +python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 +``` + +``` +cd test/lang +python3 test_srt_backend.py +``` + +### Performance + +#### MMLU +``` +cd benchmark/mmlu +``` +Follow README.md to download the data. + +``` +python3 bench_sglang.py --nsub 3 + +# Expected performance on A10G +# Total latency: 8.200 +# Average accuracy: 0.413 +``` + +### More Models + +#### LLaVA + +``` +python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000 +``` + +``` +cd benchmark/llava_bench +python3 bench_sglang.py +``` + +## SGLang Unit Tests +``` +export ANTHROPIC_API_KEY= +export OPENAI_API_KEY= +python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 +``` + +``` +cd test/lang +python3 run_all.py +``` diff --git a/examples/quick_start/anthropic_example_chat.py b/examples/quick_start/anthropic_example_chat.py new file mode 100644 index 000000000..75e90ca5e --- /dev/null +++ b/examples/quick_start/anthropic_example_chat.py @@ -0,0 +1,19 @@ +from sglang import function, system, user, assistant, gen, set_default_backend, Anthropic + + +@function +def multi_turn_question(s, question_1, question_2): + s += user(question_1) + s += assistant(gen("answer_1", max_tokens=256)) + s += user(question_2) + s += assistant(gen("answer_2", max_tokens=256)) + +set_default_backend(Anthropic("claude-2")) + +state = multi_turn_question.run( + question_1="What is the capital of the United States?", + question_2="List two local attractions.", +) + +for m in state.messages(): + print(m["role"], ":", m["content"]) diff --git a/examples/quick_start/anthropic_example_complete.py b/examples/quick_start/anthropic_example_complete.py new file mode 100644 index 000000000..8648f2ff1 --- /dev/null +++ b/examples/quick_start/anthropic_example_complete.py @@ -0,0 +1,26 @@ +from sglang import function, gen, set_default_backend, Anthropic + + +@function +def few_shot_qa(s, question): + s += ( +""" +\n\nHuman: What is the capital of France? +\n\nAssistant: Paris +\n\nHuman: What is the capital of Germany? +\n\nAssistant: Berlin +\n\nHuman: What is the capital of Italy? +\n\nAssistant: Rome +""") + s += "\n\nHuman: " + question + "\n" + s += "\n\nAssistant:" + gen("answer", stop="\n", temperature=0) + + +set_default_backend(Anthropic("claude-2")) + +state = few_shot_qa.run(question="What is the capital of the United States?") +answer = state["answer"].strip().lower() + +assert "washington" in answer, f"answer: {state['answer']}" + +print(state.text()) diff --git a/examples/quick_start/anthropic_example_stream.py b/examples/quick_start/anthropic_example_stream.py new file mode 100644 index 000000000..e265e16c7 --- /dev/null +++ b/examples/quick_start/anthropic_example_stream.py @@ -0,0 +1,20 @@ +from sglang import function, system, user, assistant, gen, set_default_backend, Anthropic + + +@function +def multi_turn_question(s, question_1, question_2): + s += user(question_1) + s += assistant(gen("answer_1", max_tokens=256)) + s += user(question_2) + s += assistant(gen("answer_2", max_tokens=256)) + +set_default_backend(Anthropic("claude-2")) + +state = multi_turn_question.run( + question_1="What is the capital of the United States?", + question_2="List two local attractions.", + stream=True +) + +for out in state.text_iter(): + print(out, end="", flush=True) diff --git a/examples/quick_start/more_stream_methods.py b/examples/quick_start/more_stream_methods.py new file mode 100644 index 000000000..15a41483a --- /dev/null +++ b/examples/quick_start/more_stream_methods.py @@ -0,0 +1,44 @@ +import asyncio +import sglang as sgl + + +@sgl.function +def multi_turn_question(s, question_1, question_2): + s += sgl.system("You are a helpful assistant.") + s += sgl.user(question_1) + s += sgl.assistant(sgl.gen("answer_1", max_tokens=256)) + s += sgl.user(question_2) + s += sgl.assistant(sgl.gen("answer_2", max_tokens=256)) + + +sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo")) +#sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000")) + + +def stream_a_variable(): + state = multi_turn_question.run( + question_1="What is the capital of the United States?", + question_2="List two local attractions.", + stream=True + ) + + for out in state.text_iter(var_name="answer_2"): + print(out, end="", flush=True) + print() + + +async def async_stream(): + state = multi_turn_question.run( + question_1="What is the capital of the United States?", + question_2="List two local attractions.", + stream=True + ) + + async for out in state.text_async_iter(var_name="answer_2"): + print(out, end="", flush=True) + print() + + +if __name__ == "__main__": + #stream_a_variable() + asyncio.run(async_stream()) diff --git a/examples/quick_start/openai_example_chat.py b/examples/quick_start/openai_example_chat.py new file mode 100644 index 000000000..bdd5b171c --- /dev/null +++ b/examples/quick_start/openai_example_chat.py @@ -0,0 +1,20 @@ +from sglang import function, system, user, assistant, gen, set_default_backend, OpenAI + + +@function +def multi_turn_question(s, question_1, question_2): + s += system("You are a helpful assistant.") + s += user(question_1) + s += assistant(gen("answer_1", max_tokens=256)) + s += user(question_2) + s += assistant(gen("answer_2", max_tokens=256)) + +set_default_backend(OpenAI("gpt-3.5-turbo")) + +state = multi_turn_question.run( + question_1="What is the capital of the United States?", + question_2="List two local attractions.", +) + +for m in state.messages(): + print(m["role"], ":", m["content"]) diff --git a/examples/quick_start/openai_example_complete.py b/examples/quick_start/openai_example_complete.py new file mode 100644 index 000000000..fd74fba69 --- /dev/null +++ b/examples/quick_start/openai_example_complete.py @@ -0,0 +1,26 @@ +from sglang import function, gen, set_default_backend, OpenAI + + +@function +def few_shot_qa(s, question): + s += ( +"""The following are questions with answers. +Q: What is the capital of France? +A: Paris +Q: What is the capital of Germany? +A: Berlin +Q: What is the capital of Italy? +A: Rome +""") + s += "Q: " + question + "\n" + s += "A:" + gen("answer", stop="\n", temperature=0) + + +set_default_backend(OpenAI("gpt-3.5-turbo-instruct")) + +state = few_shot_qa.run(question="What is the capital of the United States?") +answer = state["answer"].strip().lower() + +assert "washington" in answer, f"answer: {state['answer']}" + +print(state.text()) diff --git a/examples/quick_start/openai_example_stream.py b/examples/quick_start/openai_example_stream.py new file mode 100644 index 000000000..0ed010701 --- /dev/null +++ b/examples/quick_start/openai_example_stream.py @@ -0,0 +1,21 @@ +from sglang import function, system, user, assistant, gen, set_default_backend, OpenAI + + +@function +def multi_turn_question(s, question_1, question_2): + s += system("You are a helpful assistant.") + s += user(question_1) + s += assistant(gen("answer_1", max_tokens=256)) + s += user(question_2) + s += assistant(gen("answer_2", max_tokens=256)) + +set_default_backend(OpenAI("gpt-3.5-turbo")) + +state = multi_turn_question.run( + question_1="What is the capital of the United States?", + question_2="List two local attractions.", + stream=True +) + +for out in state.text_iter(): + print(out, end="", flush=True) diff --git a/examples/quick_start/srt_example_chat.py b/examples/quick_start/srt_example_chat.py new file mode 100644 index 000000000..a5130dca3 --- /dev/null +++ b/examples/quick_start/srt_example_chat.py @@ -0,0 +1,26 @@ +from sglang import function, system, user, assistant, gen, set_default_backend, Runtime + + +@function +def multi_turn_question(s, question_1, question_2): + s += system("You are a helpful assistant.") + s += user(question_1) + s += assistant(gen("answer_1", max_tokens=256)) + s += user(question_2) + s += assistant(gen("answer_2", max_tokens=256)) + + +runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf") +#runtime = Runtime(model_path="mistralai/Mixtral-8x7B-Instruct-v0.1") +set_default_backend(runtime) + +state = multi_turn_question.run( + question_1="What is the capital of the United States?", + question_2="List two local attractions.", +) + +for m in state.messages(): + print(m["role"], ":", m["content"]) + + +runtime.shutdown() diff --git a/examples/quick_start/srt_example_complete.py b/examples/quick_start/srt_example_complete.py new file mode 100644 index 000000000..61e2facbf --- /dev/null +++ b/examples/quick_start/srt_example_complete.py @@ -0,0 +1,28 @@ +from sglang import function, gen, set_default_backend, Runtime + + +@function +def few_shot_qa(s, question): + s += ( +"""The following are questions with answers. +Q: What is the capital of France? +A: Paris +Q: What is the capital of Germany? +A: Berlin +Q: What is the capital of Italy? +A: Rome +""") + s += "Q: " + question + "\n" + s += "A:" + gen("answer", stop="\n", temperature=0) + + +runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf") +set_default_backend(runtime) + +state = few_shot_qa.run(question="What is the capital of the United States?") + +answer = state["answer"].strip().lower() +assert "washington" in answer, f"answer: {state['answer']}" +print(state.text()) + +runtime.shutdown() diff --git a/examples/quick_start/srt_example_regex.py b/examples/quick_start/srt_example_regex.py new file mode 100644 index 000000000..8f85aec5e --- /dev/null +++ b/examples/quick_start/srt_example_regex.py @@ -0,0 +1,21 @@ +from sglang import function, gen, set_default_backend, Runtime + + +@function +def regex_gen(s): + s += "Q: What is the IP address of the Google DNS servers?\n" + s += "A: " + gen( + "answer", + temperature=0, + regex=r"((25[0-5]|2[0-4]\d|[01]?\d\d?).){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)", + ) + + +runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf") +set_default_backend(runtime) + +state = regex_gen.run() + +print(state.text()) + +runtime.shutdown() diff --git a/examples/quick_start/srt_example_stream.py b/examples/quick_start/srt_example_stream.py new file mode 100644 index 000000000..8f03bd146 --- /dev/null +++ b/examples/quick_start/srt_example_stream.py @@ -0,0 +1,26 @@ +from sglang import function, system, user, assistant, gen, set_default_backend, Runtime + + +@function +def multi_turn_question(s, question_1, question_2): + s += system("You are a helpful assistant.") + s += user(question_1) + s += assistant(gen("answer_1", max_tokens=256)) + s += user(question_2) + s += assistant(gen("answer_2", max_tokens=256)) + +runtime = Runtime("meta-llama/Llama-2-7b-chat-hf") +set_default_backend(runtime) + +state = multi_turn_question.run( + question_1="What is the capital of the United States?", + question_2="List two local attractions.", + temperature=0, + stream=True, +) + +for out in state.text_iter(): + print(out, end="", flush=True) +print() + +runtime.shutdown() diff --git a/format.sh b/format.sh new file mode 100644 index 000000000..104db69bf --- /dev/null +++ b/format.sh @@ -0,0 +1,5 @@ +isort python +black python + +isort test +black test diff --git a/playground/launch_tgi.sh b/playground/launch_tgi.sh new file mode 100644 index 000000000..a32405cdd --- /dev/null +++ b/playground/launch_tgi.sh @@ -0,0 +1,7 @@ +# Assuming the model is downdloaded at /home/ubuntu/model_weights/Llama-2-7b-chat-hf +docker run --name tgi --rm -ti --gpus all --network host \ + -v /home/ubuntu/model_weights/Llama-2-7b-chat-hf:/Llama-2-7b-chat-hf \ + ghcr.io/huggingface/text-generation-inference:1.1.0 \ + --model-id /Llama-2-7b-chat-hf --num-shard 1 --trust-remote-code \ + --max-input-length 2048 --max-total-tokens 4096 \ + --port 24000 diff --git a/playground/load_tokenizer.py b/playground/load_tokenizer.py new file mode 100644 index 000000000..33a6a700d --- /dev/null +++ b/playground/load_tokenizer.py @@ -0,0 +1,7 @@ +import transformers +import code + +name = "meta-llama/Llama-2-7b-chat-hf" + +t = transformers.AutoTokenizer.from_pretrained(name) +code.interact(local=locals()) diff --git a/python/pyproject.toml b/python/pyproject.toml new file mode 100644 index 000000000..06ec94cab --- /dev/null +++ b/python/pyproject.toml @@ -0,0 +1,31 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "sglang" +version = "0.1.0" +description = "A structured generation langauge for LLMs." +readme = "README.md" +requires-python = ">=3.8" +license = {file = "LICENSE"} +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", +] +dependencies = [ + "requests", +] + +[project.optional-dependencies] +srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5", + "interegular", "lark"] +openai = ["openai>=1.0"] +anthropic = ["anthropic"] +all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"] + +[tool.setuptools.packages.find] +exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"] + +[tool.wheel] +exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"] diff --git a/python/sglang/__init__.py b/python/sglang/__init__.py new file mode 100644 index 000000000..4fe98b772 --- /dev/null +++ b/python/sglang/__init__.py @@ -0,0 +1,2 @@ +from sglang.api import * +from sglang.global_config import global_config diff --git a/python/sglang/api.py b/python/sglang/api.py new file mode 100644 index 000000000..2e35d8888 --- /dev/null +++ b/python/sglang/api.py @@ -0,0 +1,161 @@ +"""Public API""" +import re +from typing import Callable, List, Optional, Union + +from sglang.backend.anthropic import Anthropic +from sglang.backend.base_backend import BaseBackend +from sglang.backend.openai import OpenAI +from sglang.backend.runtime_endpoint import RuntimeEndpoint +from sglang.global_config import global_config +from sglang.lang.ir import ( + SglExpr, + SglExprList, + SglFunction, + SglGen, + SglImage, + SglRoleBegin, + SglRoleEnd, + SglSelect, +) +from sglang.srt.server import Runtime + + +def function(func: Callable): + return SglFunction(func) + + +def set_default_backend(backend: BaseBackend): + global_config.default_backend = backend + + +def gen( + name: Optional[str] = None, + max_tokens: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, + dtype: Optional[type] = None, + choices: Optional[List[str]] = None, + regex: Optional[str] = None, +): + if choices: + return SglSelect(name, choices, temperature) + + # check regex is valid + if regex is not None: + try: + re.compile(regex) + except re.error as e: + raise e + + return SglGen( + name, + max_tokens, + stop, + temperature, + top_p, + top_k, + frequency_penalty, + presence_penalty, + dtype, + regex, + ) + + +def gen_int( + name: Optional[str] = None, + max_tokens: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, +): + return SglGen( + name, + max_tokens, + stop, + temperature, + top_p, + top_k, + frequency_penalty, + presence_penalty, + int, + None, + ) + + +def gen_string( + name: Optional[str] = None, + max_tokens: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, +): + return SglGen( + name, + max_tokens, + stop, + temperature, + top_p, + top_k, + frequency_penalty, + presence_penalty, + str, + None, + ) + + +def image(expr: SglExpr): + return SglImage(expr) + + +def select( + name: Optional[str] = None, + choices: List[str] = None, + temperature: float = 0.0, +): + assert choices is not None + return SglSelect(name, choices, temperature) + + +def _role_common(name: str, expr: Optional[SglExpr] = None): + if expr is None: + return SglExprList([SglRoleBegin(name), SglRoleEnd(name)]) + else: + return SglExprList([SglRoleBegin(name), expr, SglRoleEnd(name)]) + + +def system(expr: Optional[SglExpr] = None): + return _role_common("system", expr) + + +def user(expr: Optional[SglExpr] = None): + return _role_common("user", expr) + + +def assistant(expr: Optional[SglExpr] = None): + return _role_common("assistant", expr) + + +def user_begin(): + return SglRoleBegin("user") + + +def user_end(): + return SglRoleEnd("user") + + +def assistant_begin(): + return SglRoleBegin("assistant") + + +def assistant_end(): + return SglRoleEnd("assistant") diff --git a/python/sglang/backend/__init__.py b/python/sglang/backend/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/sglang/backend/anthropic.py b/python/sglang/backend/anthropic.py new file mode 100644 index 000000000..77d7f5127 --- /dev/null +++ b/python/sglang/backend/anthropic.py @@ -0,0 +1,57 @@ +from typing import List, Optional, Union + +import numpy as np +from sglang.backend.base_backend import BaseBackend +from sglang.lang.chat_template import get_chat_template +from sglang.lang.interpreter import StreamExecutor +from sglang.lang.ir import SamplingParams + +try: + import anthropic +except ImportError as e: + anthropic = e + + +class Anthropic(BaseBackend): + def __init__(self, model_name): + super().__init__() + + if isinstance(anthropic, Exception): + raise anthropic + + self.model_name = model_name + self.chat_template = get_chat_template("claude") + + def get_chat_template(self): + return self.chat_template + + def generate( + self, + s: StreamExecutor, + sampling_params: SamplingParams, + ): + prompt = s.text_ + ret = anthropic.Anthropic().completions.create( + model=self.model_name, + prompt=prompt, + **sampling_params.to_anthropic_kwargs(), + ) + comp = ret.completion + + return comp, {} + + def generate_stream( + self, + s: StreamExecutor, + sampling_params: SamplingParams, + ): + prompt = s.text_ + generator = anthropic.Anthropic().completions.create( + model=self.model_name, + prompt=prompt, + stream=True, + **sampling_params.to_anthropic_kwargs(), + ) + + for ret in generator: + yield ret.completion, {} diff --git a/python/sglang/backend/base_backend.py b/python/sglang/backend/base_backend.py new file mode 100644 index 000000000..7f59f5b15 --- /dev/null +++ b/python/sglang/backend/base_backend.py @@ -0,0 +1,74 @@ +from typing import Callable, List, Optional, Union + +from sglang.lang.chat_template import get_chat_template +from sglang.lang.interpreter import StreamExecutor +from sglang.lang.ir import SamplingParams + + +class BaseBackend: + def __init__(self) -> None: + self.support_concate_and_append = False + self.chat_template = get_chat_template("default") + + def get_model_name(self): + raise NotImplementedError() + + def get_chat_template(self): + return self.chat_template + + def cache_prefix(self, prefix_str: str): + pass + + def uncache_prefix(self, rid: str): + pass + + def end_request(self, rid: Union[str, List[str]]): + pass + + def begin_program(self, s: StreamExecutor): + pass + + def end_program(self, s: Union[StreamExecutor, List[StreamExecutor]]): + pass + + def commit_lazy_operations(self, s: StreamExecutor): + pass + + def fork_program( + self, + src: StreamExecutor, + dst: List[StreamExecutor], + position_ids_offset: Optional[List[int]] = None, + ): + pass + + def fill_image(self, s: StreamExecutor): + pass + + def generate( + self, + s: StreamExecutor, + sampling_params: SamplingParams, + ): + raise NotImplementedError() + + def generate_stream( + self, + s: StreamExecutor, + sampling_params: SamplingParams, + ): + raise NotImplementedError() + + def select( + self, + s: StreamExecutor, + choices: List[str], + temperature: float, + ): + raise NotImplementedError() + + def concatenate_and_append(self, src_rids: List[str], dst_rid: str): + raise NotImplementedError() + + def shutdown(self): + pass diff --git a/python/sglang/backend/huggingface.py b/python/sglang/backend/huggingface.py new file mode 100644 index 000000000..acd6e251f --- /dev/null +++ b/python/sglang/backend/huggingface.py @@ -0,0 +1,349 @@ +import functools +from enum import Enum, auto +from typing import Callable, List, Optional, Union + +import numpy as np +import torch +import transformers +from sglang.backend.base_backend import BaseBackend +from sglang.lang.chat_template import get_chat_template_by_model_path +from sglang.lang.interpreter import ProgramState +from sglang.utils import get_available_gpu_memory +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + StoppingCriteria, + StoppingCriteriaList, +) +from transformersgl.generation.logits_process import ( + LogitsProcessorList, + RepetitionPenaltyLogitsProcessor, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, +) + + +class StopReason(Enum): + EOS_TOKEN = auto() + STOP_STR = auto() + LENGTH = auto() + + +def load_model( + model_name: str, + device, + num_gpus, + max_gpu_memory, + model_kwargs=None, + tokenizer_kwargs=None, +): + model_kwargs = model_kwargs or {} + tokenizer_kwargs = tokenizer_kwargs or {} + + if device == "cuda": + model_kwargs["torch_dtype"] = torch.float16 + if num_gpus != 1: + model_kwargs["device_map"] = "auto" + if max_gpu_memory is None: + model_kwargs[ + "device_map" + ] = "sequential" # This is important for not the same VRAM sizes + available_gpu_memory = [ + get_available_gpu_memory(i, False) for i in range(num_gpus) + ] + model_kwargs["max_memory"] = { + i: str(int(available_gpu_memory[i] * 0.85)) + "GiB" + for i in range(num_gpus) + } + else: + model_kwargs["max_memory"] = { + i: max_gpu_memory for i in range(num_gpus) + } + elif device == "cpu": + model_kwargs["torch_dtype"] = torch.float32 + else: + raise ValueError(f"Invalid device: {device}") + + model = AutoModelForCausalLM.from_pretrained( + model_name, low_cpu_mem_usage=True, **model_kwargs + ) + tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs) + + if num_gpus == 1: + model.to(device).eval() + + return model, tokenizer + + +def prepare_logits_processor( + temperature: float, repetition_penalty: float, top_p: float, top_k: int +) -> LogitsProcessorList: + processor_list = LogitsProcessorList() + # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases. + if temperature >= 1e-5 and temperature != 1.0: + processor_list.append(TemperatureLogitsWarper(temperature)) + if repetition_penalty > 1.0: + processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty)) + if 1e-8 <= top_p < 1.0: + processor_list.append(TopPLogitsWarper(top_p)) + if top_k > 0: + processor_list.append(TopKLogitsWarper(top_k)) + return processor_list + + +@functools.lru_cache +def get_token_healing_mask(tokenizer, prompt_last_token): + last_str = tokenizer.convert_ids_to_tokens(prompt_last_token) + disallowed = torch.zeros(len(tokenizer), dtype=bool) + for s, t_id in tokenizer.get_vocab().items(): + if not s.startswith(last_str): + disallowed[t_id] = 1 + return disallowed + + +@functools.lru_cache +def get_int_token_mask(tokenizer): + disallowed = torch.zeros(len(tokenizer), dtype=bool) + for s, t_id in tokenizer.get_vocab().items(): + s = s.replace("▁", "").strip() + if not (s.isdigit() or len(s) == 0 or s == ","): + disallowed[t_id] = 1 + disallowed[tokenizer.eos_token_id] = 0 + return disallowed + + +@torch.inference_mode() +def generate_stream( + model, + tokenizer, + prompt, + max_new_tokens, + stop: List[str], + temperature, + top_p, + token_healing, + logit_mask=None, +): + logits_processor = prepare_logits_processor( + temperature=temperature, repetition_penalty=1.0, top_p=top_p, top_k=0 + ) + device = model.device + input_ids = tokenizer.encode(prompt) + output_ids = list(input_ids) + prompt_len = len(prompt) + + # Resolve stop + stop_token_ids = [tokenizer.eos_token_id] + + # Token healing + token_healing = token_healing and len(input_ids) > 0 + if token_healing: + token_healing_mask = get_token_healing_mask(tokenizer, input_ids[-1]) + del output_ids[-1] + + # Generate + past_key_values = None + stop_reason = None + for i in range(max_new_tokens): + # Forward + if i == 0: # prefill + out = model(torch.as_tensor([output_ids], device=device), use_cache=True) + else: # decoding + out = model( + input_ids=torch.as_tensor([[token]], device=device), + use_cache=True, + past_key_values=past_key_values, + ) + logits = out.logits + past_key_values = out.past_key_values + + # Logit mask + if token_healing and i == 0: + logits[0, -1, token_healing_mask] = -1e4 + if logit_mask is not None: + logits[0, -1, logit_mask] = -1e4 + + # Sample next token + last_token_logits = logits_processor(None, logits[:, -1, :])[0] + if temperature < 1e-5 or top_p < 1e-8: # greedy + token = int(torch.argmax(last_token_logits)) + else: + probs = torch.softmax(last_token_logits, dim=-1) + token = int(torch.multinomial(probs, num_samples=1)) + output_ids.append(token) + + # Stop condition + if token in stop_token_ids: + stop_reason = StopReason.EOS_TOKEN + break + + output_str = tokenizer.decode(output_ids, skip_special_tokens=True) + for stop_str in stop: + pos = output_str[prompt_len:].find(stop_str) + if pos != -1: + stop_reason = StopReason.STOP_STR + output_str = output_str[: prompt_len + pos] + break + + if stop_reason: + break + + return output_str[prompt_len:] + + +class HuggingFaceTransformers(BaseBackend): + def __init__( + self, + model_name, + device="cuda", + num_gpus=1, + max_gpu_memory=None, + model_kwargs=None, + tokenizer_kwargs=None, + ): + self.model_name = model_name + self.device = device + + self.model, self.tokenizer = load_model( + model_name, device, num_gpus, max_gpu_memory, model_kwargs, tokenizer_kwargs + ) + + self.chat_template = get_chat_template_by_model_path(model_name) + + def get_chat_template(self): + return self.chat_template + + def cache_prefix(self, prefix_str: str): + pass + + def uncache_prefix(self, rid: str): + pass + + def end_request(self, rid: str): + pass + + def begin_program(self, s: ProgramState): + pass + + def end_program(self, s: ProgramState): + pass + + def fill(self, s: ProgramState, text: str): + return False + + def generate_internal( + self, + prompt: str, + max_tokens: int, + stop: Union[str, List[str]], + temperature: float, + top_p: float, + dtype: Optional[str] = None, + ): + if dtype is None: + comp = generate_stream( + self.model, + self.tokenizer, + prompt, + max_new_tokens=max_tokens, + stop=stop, + temperature=temperature, + top_p=top_p, + token_healing=True, + ) + elif dtype in [str, "str", "string"]: + comp = generate_stream( + self.model, + self.tokenizer, + prompt + '"', + max_new_tokens=max_tokens, + stop=['"'], + temperature=temperature, + top_p=top_p, + token_healing=False, + ) + comp = '"' + comp + '"' + elif dtype in [int, "int"]: + logit_mask = get_int_token_mask(self.tokenizer) + comp = generate_stream( + self.model, + self.tokenizer, + prompt, + max_new_tokens=max_tokens, + stop=stop + [" ", ","], + temperature=temperature, + top_p=top_p, + token_healing=False, + logit_mask=logit_mask, + ) + return comp + + def generate( + self, + s: ProgramState, + max_tokens: int, + stop: Union[str, List[str]], + temperature: float, + top_p: float, + dtype: Optional[str] = None, + ): + prompt = s.text + comp = self.generate_internal( + prompt, max_tokens, stop, temperature, top_p, dtype + ) + return comp + + def parallel_generate( + self, + s: ProgramState, + prefixes: List[str], + join_func: Callable, + max_tokens: int, + stop: Union[str, List[str]], + temperature: float, + top_p: float, + dtype: Optional[str] = None, + ): + prompt = s.text + parallel_prompts = [prompt + prefix for prefix in prefixes] + + comps = [] + for i in range(len(parallel_prompts)): + comps.append( + self.generate_internal( + parallel_prompts[i], max_tokens, stop, temperature, top_p, dtype + ) + ) + + joined = join_func([p + c for p, c in zip(prefixes, comps)]) + return joined, comps + + @torch.inference_mode() + def select( + self, s: ProgramState, choices: List[str], temperature: float, top_p: float + ): + loss_fct = torch.nn.CrossEntropyLoss() + prompt = s.text + + prompt_len = self.tokenizer.encode(prompt, return_tensors="pt").shape[1] + prompt_choices = [prompt + choice for choice in choices] + + scores = [] + for i in range(len(choices)): + choice_ids = self.tokenizer.encode( + prompt_choices[i], return_tensors="pt" + ).to(self.model.device) + logits = self.model(choice_ids).logits + + # score = -loss_fct(logits[0, :-1, :], choice_ids[0, 1:]).item() + + logprobs = torch.log(torch.softmax(logits, dim=-1)) + idx1 = torch.arange(0, logits.shape[1] - 1, device=logits.device) + idx2 = choice_ids[0, 1:] + selected_logprobs = logprobs[0, idx1, idx2] + score = selected_logprobs.mean().item() + scores.append(score) + + decision = choices[np.argmax(scores)] + return decision, scores diff --git a/python/sglang/backend/openai.py b/python/sglang/backend/openai.py new file mode 100644 index 000000000..b22e149c9 --- /dev/null +++ b/python/sglang/backend/openai.py @@ -0,0 +1,241 @@ +from typing import Callable, List, Optional, Union + +import numpy as np +from sglang.backend.base_backend import BaseBackend +from sglang.lang.chat_template import get_chat_template +from sglang.lang.interpreter import StreamExecutor +from sglang.lang.ir import SamplingParams + +try: + import openai + import tiktoken +except ImportError as e: + openai = tiktoken = e + + +def create_logit_bias_int(tokenizer): + """Get logit bias for integer numbers.""" + int_token_ids = [] + + tokens = tokenizer._mergeable_ranks + for token, token_id in tokens.items(): + s = tokenizer.decode([token_id]) + if all([c.isdigit() for c in s]) or s in [" "]: + int_token_ids.append(token_id) + if len(int_token_ids) >= 300: # OpenAI API limit + break + special_tokens = tokenizer._special_tokens + mask = {t: 100 for t in int_token_ids[:299]} + mask[special_tokens["<|endoftext|>"]] = 100 + return mask + + +CHAT_MODEL_NAMES = [ + # GPT-4 + "gpt-4", + "gpt-4-32k", + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4-0613", + "gpt-4-0314", + # GPT-3.5 + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-16k-0613", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-0301", +] + + +class OpenAI(BaseBackend): + def __init__(self, model_name, *args, **kwargs): + super().__init__() + self.client = openai.OpenAI(*args, **kwargs) + + if isinstance(openai, Exception): + raise e + + self.model_name = model_name + self.tokenizer = tiktoken.encoding_for_model(model_name) + self.logit_bias_int = create_logit_bias_int(self.tokenizer) + + if model_name in CHAT_MODEL_NAMES: + self.is_chat_model = True + else: + self.is_chat_model = False + + self.chat_template = get_chat_template("default") + + def get_chat_template(self): + return self.chat_template + + def generate( + self, + s: StreamExecutor, + sampling_params: SamplingParams, + ): + if sampling_params.dtype is None: + if self.is_chat_model: + assert s.text_.endswith("ASSISTANT:") + prompt = s.messages_ + else: + prompt = s.text_ + + kwargs = sampling_params.to_openai_kwargs() + comp = openai_completion( + client=self.client, + is_chat=self.is_chat_model, + model=self.model_name, + prompt=prompt, + **kwargs, + ) + elif sampling_params.dtype in [str, "str", "string"]: + kwargs = sampling_params.to_openai_kwargs() + kwargs.pop("stop") + comp = openai_completion( + client=self.client, + is_chat=self.is_chat_model, + model=self.model_name, + prompt=s.text_ + '"', + stop='"', + **kwargs, + ) + comp = '"' + comp + '"' + elif sampling_params.dtype in [int, "int"]: + kwargs = sampling_params.to_openai_kwargs() + kwargs.pop("stop") + comp = openai_completion( + client=self.client, + is_chat=self.is_chat_model, + model=self.model_name, + prompt=s.text_, + logit_bias=self.logit_bias_int, + stop=[" "], + **kwargs, + ) + else: + raise ValueError(f"Unknown dtype: {dtype}") + + return comp, {} + + def generate_stream( + self, + s: StreamExecutor, + sampling_params: SamplingParams, + ): + if sampling_params.dtype is None: + if self.is_chat_model: + assert s.text_.endswith("ASSISTANT:") + prompt = s.messages_ + else: + prompt = s.text_ + + kwargs = sampling_params.to_openai_kwargs() + generator = openai_completion_stream( + client=self.client, + is_chat=self.is_chat_model, + model=self.model_name, + prompt=prompt, + **kwargs, + ) + return generator + else: + raise ValueError(f"Unknown dtype: {dtype}") + + def select( + self, + s: StreamExecutor, + choices: List[str], + temperature: float, + ): + n_choices = len(choices) + token_ids = [self.tokenizer.encode(x) for x in choices] + scores = [0] * n_choices + valid = [len(x) > 0 for x in token_ids] + prompt_tokens = self.tokenizer.encode(s.text_) + + max_len = max([len(x) for x in token_ids]) + for step in range(max_len): + # Build logit bias + logit_bias = {} + for i in range(n_choices): + if valid[i]: + logit_bias[token_ids[i][step]] = 100 + + # Call API + ret = self.client.completions.create( + model=self.model_name, + prompt=prompt_tokens, + logit_bias=logit_bias, + max_tokens=1, + temperature=temperature, + ) + ret_str = ret.choices[0].text + ret_token = self.tokenizer.encode(ret_str)[0] + + # TODO: + # 1. return logits as the scores + # 2. compute logits of the full choice + # 3. consider chunk-based decoding + + # Update valid + hit = False + for i in range(n_choices): + if valid[i]: + if step == len(token_ids[i]) - 1: + valid[i] = False + + if ret_token == token_ids[i][step]: + scores[i] += 1 + hit = True + else: + valid[i] = False + assert hit + + if np.sum(valid) <= 1: + break + + prompt_tokens.append(ret_token) + + decision = choices[np.argmax(scores)] + return decision, scores + + +def openai_completion(client, is_chat=None, prompt=None, **kwargs): + try: + if is_chat: + if kwargs["stop"] is None: + kwargs.pop("stop") + ret = client.chat.completions.create(messages=prompt, **kwargs) + comp = ret.choices[0].message.content + else: + ret = client.completions.create(prompt=prompt, **kwargs) + if isinstance(prompt, (list, tuple)): + comp = [c.text for c in ret.choices] + else: + comp = ret.choices[0].text + except openai.OpenAIError as e: + print(f"OpenAI Error: {e}") + raise e + + return comp + + +def openai_completion_stream(client, is_chat=None, prompt=None, **kwargs): + try: + if is_chat: + generator = client.chat.completions.create( + messages=prompt, stream=True, **kwargs + ) + for ret in generator: + content = ret.choices[0].delta.content + yield content or "", {} + else: + generator = client.completions.create(prompt=prompt, stream=True, **kwargs) + for ret in generator: + content = ret.choices[0].text + yield content or "", {} + except openai.OpenAIError as e: + print(f"OpenAI Error: {e}") + raise e diff --git a/python/sglang/backend/runtime_endpoint.py b/python/sglang/backend/runtime_endpoint.py new file mode 100644 index 000000000..9d15be72a --- /dev/null +++ b/python/sglang/backend/runtime_endpoint.py @@ -0,0 +1,171 @@ +import json +from typing import Callable, List, Optional, Union + +import numpy as np +import requests +from sglang.backend.base_backend import BaseBackend +from sglang.global_config import global_config +from sglang.lang.chat_template import get_chat_template_by_model_path +from sglang.lang.interpreter import StreamExecutor +from sglang.lang.ir import SamplingParams, SglArgument +from sglang.utils import encode_image_base64, find_printable_text, http_request + + +class RuntimeEndpoint(BaseBackend): + def __init__(self, base_url): + super().__init__() + self.support_concate_and_append = True + + self.base_url = base_url + + res = http_request(self.base_url + "/get_model_info") + assert res.status_code == 200 + self.model_info = res.json() + + self.chat_template = get_chat_template_by_model_path( + self.model_info["model_path"] + ) + + def get_model_name(self): + return self.model_info["model_path"] + + def get_chat_template(self): + return self.chat_template + + def cache_prefix(self, prefix_str: str): + res = http_request( + self.base_url + "/generate", + json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}}, + ) + assert res.status_code == 200 + + def commit_lazy_operations(self, s: StreamExecutor): + res = http_request( + self.base_url + "/generate", + json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}}, + ) + assert res.status_code == 200 + + def fill_image(self, s: StreamExecutor): + data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}} + self._add_images(s, data) + res = http_request(self.base_url + "/generate", json=data) + assert res.status_code == 200 + + def generate( + self, + s: StreamExecutor, + sampling_params: SamplingParams, + ): + if sampling_params.dtype is None: + data = { + "text": s.text_, + "sampling_params": { + "skip_special_tokens": global_config.skip_special_tokens_in_output, + **sampling_params.to_srt_kwargs(), + }, + } + elif sampling_params.dtype in [int, "int"]: + data = { + "text": s.text_, + "sampling_params": { + "skip_special_tokens": global_config.skip_special_tokens_in_output, + "dtype": "int", + **sampling_params.to_srt_kwargs(), + }, + } + else: + raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}") + + self._add_images(s, data) + + res = http_request(self.base_url + "/generate", json=data) + obj = res.json() + comp = obj["text"] + return comp, obj["meta_info"] + + def generate_stream( + self, + s: StreamExecutor, + sampling_params: SamplingParams, + ): + if sampling_params.dtype is None: + data = { + "text": s.text_, + "sampling_params": { + "skip_special_tokens": global_config.skip_special_tokens_in_output, + **sampling_params.to_srt_kwargs(), + }, + } + elif sampling_params.dtype in [int, "int"]: + data = { + "text": s.text_, + "sampling_params": { + "skip_special_tokens": global_config.skip_special_tokens_in_output, + "dtype": "int", + **sampling_params.to_srt_kwargs(), + }, + } + else: + raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}") + + data["stream"] = True + self._add_images(s, data) + + response = http_request(self.base_url + "/generate", json=data, stream=True) + pos = 0 + + incomplete_text = "" + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode()) + text = find_printable_text(data["text"][pos:]) + meta_info = data["meta_info"] + pos += len(text) + incomplete_text = data["text"][pos:] + yield text, meta_info + + if len(incomplete_text) > 0: + yield incomplete_text, meta_info + + def select( + self, + s: StreamExecutor, + choices: List[str], + temperature: float, + ): + assert temperature <= 1e-5 + + # Cache common prefix + data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}} + self._add_images(s, data) + res = http_request(self.base_url + "/generate", json=data) + assert res.status_code == 200 + prompt_len = res.json()["meta_info"]["prompt_tokens"] + + # Compute logprob + data = { + "text": [s.text_ + c for c in choices], + "sampling_params": {"max_new_tokens": 0}, + "return_normalized_logprob": True, + "normalized_logprob_start_len": prompt_len, + } + self._add_images(s, data) + res = http_request(self.base_url + "/generate", json=data) + assert res.status_code == 200 + logps = [r["meta_info"]["normalized_logprob"] for r in res.json()] + + decision = choices[np.argmax(logps)] + return decision, logps + + def concatenate_and_append(self, src_rids: List[str], dst_rid: str): + res = http_request( + self.base_url + "/concate_and_append_request", + json={"src_rids": src_rids, "dst_rid": dst_rid}, + ) + assert res.status_code == 200 + + def _add_images(self, s: StreamExecutor, data): + if s.images_: + assert len(s.images_) == 1, "Only support one image." + data["image_data"] = s.images_[0][1] diff --git a/python/sglang/backend/tgi.py b/python/sglang/backend/tgi.py new file mode 100644 index 000000000..e5462218d --- /dev/null +++ b/python/sglang/backend/tgi.py @@ -0,0 +1,190 @@ +import re +from concurrent.futures import ThreadPoolExecutor +from functools import partial +from itertools import repeat +from typing import List, Optional, Union + +from sglang.backend.base_backend import BaseBackend +from sglang.lang.chat_template import get_chat_template_by_model_path +from sglang.lang.interpreter import StreamExecutor +from sglang.lang.ir import SamplingParams +from sglang.utils import http_request + + +class TGI(BaseBackend): + def __init__(self, base_url): + super().__init__() + + self.base_url = base_url + + res = http_request(self.base_url + "/info") + assert res.status_code == 200 + self.model_info = res.json() + self.chat_template = get_chat_template_by_model_path( + self.model_info["model_id"] + ) + + def get_model_name(self): + return self.model_info["model_id"] + + def get_chat_template(self): + return self.chat_template + + @staticmethod + def adapt_params(max_tokens, stop, sampling_params, **override_params): + temperature = sampling_params.temperature + do_sample = True + if temperature == 0: + do_sample = False + temperature = None + + if stop is None: + stop = [] + elif isinstance(stop, str): + stop = [stop] + + top_p = sampling_params.top_p + if top_p == 0: + top_p = 0.001 + if top_p == 1: + top_p = 0.999 + + top_k = sampling_params.top_k + if top_k == -1: + top_k = None + + params = { + "decoder_input_details": False, + "details": False, + "do_sample": do_sample, + "max_new_tokens": max_tokens, + "stop": stop, + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, + "return_full_text": False, + } + params.update(override_params) + return params + + @staticmethod + def _extract_int(text): + words = re.split("\ |'|\/|\(|\)|\n|\.|,", text) + for word in words: + try: + int(word) + return word + except ValueError: + continue + raise ValueError + + @staticmethod + def _extract_choice(choices, text): + # FIXME: Current only support the case where the choices are single words. + words = re.split("\ |'|\/|\(|\)|\n|\.|,", text) + for word in words: + if word in choices: + return word + raise ValueError + + @staticmethod + def _truncate_to_stop(text, stop): + # The stop sequence may not be a single token. In this case TGI will generate + # too many tokens so we need to truncate the output. + if stop: + stop = [stop] if isinstance(stop, str) else stop + for stop_seq in stop: + pos = text.find(stop_seq) + if pos != -1: + return text[:pos] + return text + + def _make_request(self, params): + res = http_request(self.base_url + "/generate", json=params) + if res.status_code != 200: + raise ValueError(f"Error from TGI backend: {res.text}") + return res.json() + + def retry_for_expected(self, prompt, params, extract_fn, retry=5): + # TGI does not support logis_bias (yet), so we have to use an inefficient hack. + failed = [] + while retry > 0: + res_json = self._make_request( + { + "inputs": prompt, + "parameters": params, + } + ) + text = res_json["generated_text"] + try: + return extract_fn(text) + except ValueError: + retry -= 1 + failed.append(text) + + msg = "=" * 20 + "\n" + msg += f"Prompt:\n{prompt}\n" + msg += "=" * 20 + "\n" + for i, text in enumerate(failed): + msg += f"====== Try {i+1}:\n{text}\n" + + raise ValueError( + f"Model {self.model_info['model_id']} served by TGI backend does not generate" + "expected output. Please improve the prompt, increase the temperature, or " + f"use different models.\n{msg}" + ) + + def select( + self, + s: StreamExecutor, + choices: List[str], + sampling_params: SamplingParams, + ): + decision = self.retry_for_expected( + s.text_, + self.adapt_params(16, [], sampling_params), + partial(self._extract_choice, choices), + ) + return decision, [1 if choice == decision else 0 for choice in choices] + + def generate( + self, + s: StreamExecutor, + max_tokens: int, + stop: Union[str, List[str]], + sampling_params: SamplingParams, + dtype: Optional[str] = None, + ): + if dtype is None: + res_json = self._make_request( + { + "inputs": s.text_, + "parameters": self.adapt_params(max_tokens, stop, sampling_params), + } + ) + return self._truncate_to_stop(res_json["generated_text"], stop), {} + + if dtype in [str, "str", "string"]: + stop = ['"'] + res_json = self._make_request( + { + "inputs": f'{s.text_}"', + "parameters": self.adapt_params(max_tokens, stop, sampling_params), + } + ) + return ( + '"' + self._truncate_to_stop(res_json["generated_text"], stop) + '"', + {}, + ) + + if dtype in [int, "int"]: + return ( + self.retry_for_expected( + s.text_, + self.adapt_params(max_tokens, stop, sampling_params), + self._extract_int, + ), + {}, + ) + + raise ValueError(f"Unknown dtype: {dtype}") diff --git a/python/sglang/flush_cache.py b/python/sglang/flush_cache.py new file mode 100644 index 000000000..6050ee22c --- /dev/null +++ b/python/sglang/flush_cache.py @@ -0,0 +1,60 @@ +"""Flush cache in the backend by sending random requests.""" +import argparse +import random +import string +import time + +from sglang.test.test_utils import ( + add_common_sglang_args_and_parse, + select_sglang_backend, +) + +import sglang as sgl + + +@sgl.function +def flush_radix_cache(s, prompt): + s += prompt + sgl.gen("flush", max_tokens=1, stop="END") + + +def main(args, max_total_tokens, context_length, print_flag): + backend = select_sglang_backend(args) + flush_length = int(context_length * 0.8) + batch_size = int(max_total_tokens / flush_length) + prompt_length = flush_length * 2 + prompts = [ + " ".join(random.choices(string.ascii_letters, k=int(prompt_length))) + for _ in range(batch_size) + ] + arguments = [{"prompt": prompts[i]} for i in range(batch_size)] + + start_time = time.time() + flush_radix_cache.run_batch( + arguments, temperature=0, backend=backend, num_threads=1 + ) + end_time = time.time() + + if print_flag: + print( + f"Flush length: {flush_length}\n", + f"Prompt length: {prompt_length}\n", + f"Total Prompt letters: {batch_size * prompt_length}\n", + f"Flush radix cache latency: {end_time - start_time:.3f}", + sep="", + ) + + # to prevent the backend still running + time.sleep(1) + + +def run_flush(args, max_total_tokens=20000, context_length=1024, print_flag=False): + main(args, max_total_tokens, context_length, print_flag=print_flag) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--max-total-tokens", type=int, default=20000) + parser.add_argument("--context-length", type=int, default=1024) + args = add_common_sglang_args_and_parse(parser) + random.seed(0) + main(args, args.max_total_tokens, args.context_length, print_flag=True) diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py new file mode 100644 index 000000000..36458b7b1 --- /dev/null +++ b/python/sglang/global_config.py @@ -0,0 +1,28 @@ +"""Global configurations""" + + +class GlobalConfig: + def __init__(self): + # Verbosity level + # 0: do not output anything + # 2: output final text after every run + self.verbosity = 0 + + self.default_backend = None + + # Output configs + self.skip_special_tokens_in_output = True + + # Optimization configs + self.eager_fill_image = False + self.enable_prefix_sharing = True + self.enable_parallel_encoding = True + self.enable_parallel_decoding = True + + # Choices: ["no_adjust", "adjust_cache"] + # no_adjust: Do not adjust the position embedding of KV cache. + # adjust_cache: Adjust the position embedding of KV cache. + self.concate_and_append_mode = "no_adjust" + + +global_config = GlobalConfig() diff --git a/python/sglang/lang/__init__.py b/python/sglang/lang/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py new file mode 100644 index 000000000..579cc845b --- /dev/null +++ b/python/sglang/lang/chat_template.py @@ -0,0 +1,186 @@ +from dataclasses import dataclass +from enum import Enum, auto +from typing import Callable, Dict, List, Tuple + + +class ChatTemplateStyle(Enum): + PLAIN = auto() + LLAMA2 = auto() + + +@dataclass +class ChatTemplate: + name: str + default_system_prompt: str + role_prefix_and_suffix: Dict[str, Tuple[str]] + image_token: str = "" + style: ChatTemplateStyle = ChatTemplateStyle.PLAIN + + def get_prefix_and_suffix(self, role, hist_messages): + if self.style == ChatTemplateStyle.PLAIN: + return self.role_prefix_and_suffix[role] + elif self.style == ChatTemplateStyle.LLAMA2: + if len(hist_messages) == 0 and role == "system": + return ( + self.role_prefix_and_suffix["user"][0] + + self.role_prefix_and_suffix["system"][0], + self.role_prefix_and_suffix["system"][1], + ) + elif ( + len(hist_messages) == 1 + and role == "user" + and hist_messages[0]["content"] is not None + ): + return ("", self.role_prefix_and_suffix["user"][1]) + return self.role_prefix_and_suffix[role] + else: + raise ValueError(f"Invalid style: {self.style}") + + def get_prompt(self, messages): + prompt = "" + for i in range(len(messages)): + role, content = messages[i]["role"], messages[i]["content"] + if role == "system" and content is None: + content = self.default_system_prompt + if content is None: + continue + + prefix, suffix = self.get_prefix_and_suffix(role, messages[:i]) + prompt += prefix + content + suffix + return prompt + + +chat_template_registry: Dict[str, ChatTemplate] = {} +matching_function_registry: List[Callable] = [] + + +def register_chat_template(template): + chat_template_registry[template.name] = template + + +def register_chat_template_matching_function(func): + matching_function_registry.append(func) + + +def get_chat_template(name): + return chat_template_registry[name] + + +def get_chat_template_by_model_path(model_path): + for matching_func in matching_function_registry: + template = matching_func(model_path) + if template is not None: + return template + return get_chat_template("default") + + +register_chat_template( + ChatTemplate( + name="default", + default_system_prompt=None, + role_prefix_and_suffix={ + "system": ("SYSTEM:", "\n"), + "user": ("USER:", "\n"), + "assistant": ("ASSISTANT:", "\n"), + }, + ) +) + + +register_chat_template( + ChatTemplate( + name="claude", + default_system_prompt=None, + role_prefix_and_suffix={ + "system": ("", ""), + "user": ("\n\nHuman: ", ""), + "assistant": ("\n\nAssistant:", ""), + }, + ) +) + + +register_chat_template( + ChatTemplate( + name="chatml", + default_system_prompt=None, + role_prefix_and_suffix={ + "system": ("<|im_start|>system\n", "\n<|im_end|>\n"), + "user": ("<|im_start|>user\n", "\n<|im_end|>\n"), + "assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"), + }, + style=ChatTemplateStyle.PLAIN, + ) +) + + +register_chat_template( + ChatTemplate( + name="vicuna_v1.1", + default_system_prompt=( + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions." + ), + role_prefix_and_suffix={ + "system": ("", " "), + "user": ("USER:", " "), + "assistant": ("ASSISTANT:", ""), + }, + image_token=" \n", + ) +) + + +register_chat_template( + ChatTemplate( + name="llama-2-chat", + default_system_prompt=None, + role_prefix_and_suffix={ + "system": ("<>\n", "\n<>\n\n"), + "user": ("[INST] ", " [/INST]"), + "assistant": ("", " "), + }, + style=ChatTemplateStyle.LLAMA2, + ) +) + + +@register_chat_template_matching_function +def match_vicuna(model_path: str): + if "vicuna" in model_path.lower(): + return get_chat_template("vicuna_v1.1") + if "llava" in model_path.lower(): + return get_chat_template("vicuna_v1.1") + + +@register_chat_template_matching_function +def match_llama2_chat(model_path: str): + model_path = model_path.lower() + if "llama-2" in model_path and "chat" in model_path: + return get_chat_template("llama-2-chat") + if ( + "mistral" in model_path or "mixtral" in model_path + ) and "instruct" in model_path: + return get_chat_template("llama-2-chat") + if "codellama" in model_path and "instruct" in model_path: + return get_chat_template("llama-2-chat") + + +@register_chat_template_matching_function +def match_chat_ml(model_path: str): + if "tinyllama" in model_path.lower(): + return get_chat_template("chatml") + + +if __name__ == "__main__": + messages = [ + {"role": "system", "content": None}, # None means default + # {"role": "system", "content": "You are a helpful, respectful and honest assistant."}, + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Hi!"}, + {"role": "user", "content": "What can you do?"}, + {"role": "assistant", "content": "I can chat with you."}, + ] + + template = get_chat_template("llama-2-chat") + print(template.get_prompt(messages)) diff --git a/python/sglang/lang/compiler.py b/python/sglang/lang/compiler.py new file mode 100644 index 000000000..0d1ba68c6 --- /dev/null +++ b/python/sglang/lang/compiler.py @@ -0,0 +1,237 @@ +import multiprocessing +from concurrent.futures import ThreadPoolExecutor +from queue import Queue +from typing import List, Union + +from sglang.global_config import global_config +from sglang.lang.interpreter import ProgramState, StreamExecutor, pin_program +from sglang.lang.ir import ( + SamplingParams, + SglArgument, + SglConstantText, + SglExpr, + SglVariable, +) + + +def compile_func(function, backend): + tracer = function.trace(backend=backend) + compiler = CompiledFunction(tracer, function) + return compiler + + +class CompiledFunction: + def __init__(self, tracer, function): + self.function = function + + self.last_node = CompGraphNode(tracer.last_node) + self.expr_to_node = {} + self.build_graph(tracer) + self.topological_sort() + + def build_graph(self, tracer): + self.nodes = [self.last_node] + self.expr_to_node[tracer.last_node] = self.nodes[-1] + + rename_pid = {} + + visited = set([tracer.last_node]) + head = 0 + while head < len(self.nodes): + cur_node = self.nodes[head] + + # add prev node + prev_node = cur_node.expr.prev_node + if prev_node is not None: + if prev_node not in visited: + visited.add(prev_node) + self.nodes.append(CompGraphNode(prev_node)) + self.expr_to_node[prev_node] = self.nodes[-1] + cur_node.prev_node = self.expr_to_node[prev_node] + self.expr_to_node[prev_node].add_next_node(cur_node) + + # add source node + if isinstance(cur_node.expr, SglVariable): + if cur_node.expr.name in tracer.variables: + source = tracer.variables[cur_node.expr.name].source + else: + source = cur_node.expr.source + if source not in visited: + visited.add(source) + self.nodes.append(CompGraphNode(source)) + self.expr_to_node[source] = self.nodes[-1] + cur_node.source_node = self.expr_to_node[source] + self.expr_to_node[source].add_next_node(cur_node) + head += 1 + + # rename pid + if cur_node.expr.pid not in rename_pid: + rename_pid[cur_node.expr.pid] = len(rename_pid) + cur_node.expr.pid = rename_pid[cur_node.expr.pid] + + def topological_sort(self): + prevd = {} + cand = Queue() + for x in self.nodes: + prevd[x] = (x.prev_node is not None) + (x.source_node is not None) + if prevd[x] == 0: + cand.put(x) + new_list = [] + while cand.qsize() > 0: + head = cand.get() + new_list.append(head) + for x in head.next_nodes: + prevd[x] -= 1 + if prevd[x] == 0: + cand.put(x) + self.nodes = new_list + + def print_graph( + self, + ): + for node in self.nodes: + print(node) + + def run_internal( + self, + backend, + kwargs, + default_sampling_para, + ): + stream_executor_ids = set([x.expr.pid for x in self.nodes]) + stream_executors = {} + for x in stream_executor_ids: + arguments = kwargs if x == self.last_node.expr.pid else {} + stream_executors[x] = StreamExecutor( + backend, arguments, default_sampling_para, None, False + ) + for node in self.nodes: + se_id = node.expr.pid + expr = node.expr + if isinstance(expr, SglVariable): + # Make a copy for SglVariable + expr = SglVariable(expr.name, expr.source) + expr.source_stream_executor = stream_executors[ + node.source_node.expr.pid + ] + elif isinstance(expr, SglArgument): + # Substitute SglArgument + expr = kwargs[expr.name] + stream_executors[se_id].submit(expr) + for stream_executor in stream_executors.values(): + stream_executor.end() + return ProgramState(stream_executors[self.last_node.expr.pid]) + + def run( + self, + *, + max_new_tokens: int = 16, + stop: Union[str, List[str]] = (), + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + backend=None, + **kwargs, + ): + backend = backend or global_config.default_backend + + kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()} + kwargs.update(self.function.bind_arguments) + + default_sampling_para = SamplingParams( + max_new_tokens=max_new_tokens, + stop=stop, + temperature=temperature, + top_p=top_p, + top_k=top_k, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + ) + + return self.run_internal(backend, kwargs, default_sampling_para) + + def run_batch( + self, + batch_kwargs, + *, + max_new_tokens: int = 16, + stop: Union[str, List[str]] = (), + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + backend=None, + num_threads: Union[str, int] = "auto", + ): + assert isinstance(batch_kwargs, (list, tuple)) + if len(batch_kwargs) == 0: + return [] + assert isinstance(batch_kwargs[0], dict) + + backend = backend or global_config.default_backend + + default_sampling_para = SamplingParams( + max_new_tokens=max_new_tokens, + stop=stop, + temperature=temperature, + top_p=top_p, + top_k=top_k, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + ) + batch_kwargs = [ + {k: SglArgument(k, v) for k, v in kwargs.items()} for kwargs in batch_kwargs + ] + + # Extract prefix by tracing and cache it + if len(batch_kwargs) > 1: + pin_program(self.function, backend) + + # Run all programs + if num_threads == "auto": + num_threads = multiprocessing.cpu_count() + num_threads = min(num_threads, len(batch_kwargs)) + + if num_threads == 1: + rets = [] + for arguments in batch_kwargs: + rets.append( + self.run_internal(backend, arguments, default_sampling_para) + ) + else: + with ThreadPoolExecutor(num_threads) as executor: + futures = [] + for arguments in batch_kwargs: + futures.append( + executor.submit( + self.run_internal, backend, arguments, default_sampling_para + ) + ) + rets = [f.result() for f in futures] + rets[-1].sync() + + return rets + + +class CompGraphNode: + def __init__( + self, expr: SglExpr, prev_node=None, next_nodes=None, source_node=None + ): + self.expr = expr + self.next_nodes = next_nodes or [] + self.prev_node = prev_node + self.source_node = source_node + + def add_next_node(self, other): + self.next_nodes.append(other) + + def __repr__(self): + re = f"stream {self.expr.pid:2d}: " + re += f"%{self.expr.node_id} = " + if self.prev_node is not None: + re += f"%{self.prev_node.expr.node_id} + " + re += repr(self.expr) + return re diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py new file mode 100644 index 000000000..175d2afa9 --- /dev/null +++ b/python/sglang/lang/interpreter.py @@ -0,0 +1,697 @@ +"""The interpreter that executes SGL programs""" + +import asyncio +import multiprocessing +import queue +import threading +import uuid +from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager +from typing import Any, Callable, Dict, List, Optional, Union + +import tqdm +from sglang.global_config import global_config +from sglang.lang.ir import ( + SglArgument, + SglCommitLazy, + SglConcateAndAppend, + SglConstantText, + SglExpr, + SglExprList, + SglFunction, + SglGen, + SglImage, + SglRoleBegin, + SglRoleEnd, + SglSelect, + SglVariable, + SglVarScopeBegin, + SglVarScopeEnd, +) +from sglang.utils import encode_image_base64 + + +def run_internal(state, program, func_args, func_kwargs, sync): + try: + state.ret_value = program.func(state, *func_args, **func_kwargs) + except Exception as e: + raise e + finally: + state.stream_executor.end() + + if sync: + state.stream_executor.sync() + + if global_config.verbosity >= 2: + print(state.text()) + + +def run_program( + program, backend, func_args, func_kwargs, default_sampling_para, stream, sync=False +): + assert backend is not None, "Please specify a backend" + func_kwargs.update(program.bind_arguments) + stream_executor = StreamExecutor( + backend, func_kwargs, default_sampling_para, chat_template=None, stream=stream + ) + state = ProgramState(stream_executor) + + if stream: + t = threading.Thread( + target=run_internal, args=(state, program, func_args, func_kwargs, sync) + ) + t.start() + return state + else: + run_internal(state, program, func_args, func_kwargs, sync) + return state + + +def run_program_batch( + program, + backend, + batch_arguments, + default_sampling_para, + num_threads, + progress_bar, +): + # Extract prefix by tracing and cache it + if len(batch_arguments) > 1: + pin_program(program, backend) + + # Run all programs + if num_threads == "auto": + num_threads = multiprocessing.cpu_count() + num_threads = min(num_threads, len(batch_arguments)) + + if num_threads == 1: + rets = [] + for arguments in batch_arguments: + rets.append( + run_program( + program, backend, (), arguments, default_sampling_para, False, False + ) + ) + else: + if progress_bar: + pbar = tqdm.tqdm(total=len(batch_arguments)) + + with ThreadPoolExecutor(num_threads) as executor: + futures = [] + for arguments in batch_arguments: + futures.append( + executor.submit( + run_program, + program, + backend, + (), + arguments, + default_sampling_para, + False, + False, + ) + ) + if progress_bar: + futures[-1].add_done_callback(lambda _: pbar.update()) + + rets = [f.result() for f in futures] + rets[-1].sync() + + if progress_bar: + pbar.close() + + return rets + + +def pin_program(program, backend): + if global_config.enable_prefix_sharing and program.pin_prefix_rid is None: + # TODO: handle multiple backends + from sglang.lang.tracer import extract_prefix_by_tracing + + prefix = extract_prefix_by_tracing(program, backend) + if prefix and len(prefix) > 64: + prefix_rid = backend.cache_prefix(prefix) + program.pin_prefix_rid = prefix_rid + return prefix_rid + return None + + +def unpin_program(program, backend): + pass + + +class StreamExecutor: + """A stream executor that executes SGL expressions in a background thread.""" + + def __init__( + self, + backend, + arguments, + default_sampling_para, + chat_template, + stream, + use_thread=True, + ): + self.sid = uuid.uuid4().hex + self.backend = backend + self.arguments: Dict[str, Any] = arguments + self.default_sampling_para = default_sampling_para + self.stream = stream + + if hasattr(backend, "endpoint"): + self.backend = backend.endpoint + + self.variables = {} # Dict[name: str -> value: str] + self.variable_event = {} # Dict[name: str -> event: threading.Event] + self.meta_info = {} # Dict[name: str -> info: str] + self.is_finished = False + + # For completion + self.text_ = "" # The full text + + # For chat + self.messages_ = [] # The messages in the OpenAI API format + self.chat_template = chat_template or self.backend.get_chat_template() + self.cur_role = None + self.cur_role_begin_pos = None + + # For vision + self.images_ = [] + self.cur_images = [] + + # For fork/join + self.fork_start_text_pos = None + + # Worker thread + self.use_thread = use_thread + if self.use_thread: + self.queue = queue.Queue() + self.worker = threading.Thread(target=self._thread_worker_func) + self.worker.start() + + # For streaming + if stream: + self.stream_text_event = threading.Event() + self.stream_var_event = {} + else: + self.stream_text_event = None + self.stream_var_event = None + + def submit(self, expr: SglExpr): + if isinstance(expr, (SglGen, SglSelect, SglVarScopeBegin)): + self.variable_event[expr.name] = threading.Event() + if self.stream: + self.stream_var_event[expr.name] = threading.Event() + elif isinstance(expr, SglExprList): + for e in expr.expr_list: + if isinstance(e, (SglGen, SglSelect, SglVarScopeBegin)): + self.variable_event[e.name] = threading.Event() + if self.stream: + self.stream_var_event[e.name] = threading.Event() + + if self.use_thread: + self.queue.put(expr) + else: + self._execute(expr) + + def sync(self): + if self.use_thread: + self.queue.join() + + def get_var(self, name): + if name in self.variable_event: + self.variable_event[name].wait() + return self.variables[name] + + def get_meta_info(self, name): + if name in self.variable_event: + self.variable_event[name].wait() + ret = self.meta_info.get(name, None) + return ret + + def fork(self, number: int, position_ids_offset: Optional[List[int]] = None): + if number > 1: + self.submit(SglCommitLazy()) + self.sync() + + number = int(number) + + exes = [ + StreamExecutor( + self.backend, + self.arguments, + self.default_sampling_para, + self.chat_template, + self.stream, + ) + for _ in range(number) + ] + for i in range(number): + exes[i].variables = dict(self.variables) + exes[i].text_ = str(self.text_) + exes[i].messages_ = list(self.messages_) + exes[i].cur_role = self.cur_role + exes[i].fork_start_text_pos = len(self.text_) + + return exes + + def text(self): + self.sync() + return self.text_ + + def messages(self): + self.sync() + return self.messages_ + + def end(self): + if self.use_thread: + if self.worker.is_alive(): + self.queue.put(None) + self.backend.end_program(self) + + def _thread_worker_func(self): + while True: + expr = self.queue.get() + if expr is None: + self.queue.task_done() + break + + self._execute(expr) + self.queue.task_done() + if self.stream_text_event: + self.stream_text_event.set() + + if self.stream_text_event: + self.stream_text_event.set() + + self.is_finished = True + + def _execute(self, other): + if isinstance(other, str): + other = SglConstantText(other) + + assert isinstance(other, SglExpr), f"{other}" + + if isinstance(other, (SglConstantText, SglArgument)): + self._execute_fill(other.value) + elif isinstance(other, SglGen): + self._execute_gen(other) + elif isinstance(other, SglSelect): + self._execute_select(other) + elif isinstance(other, SglExprList): + for x in other.expr_list: + self._execute(x) + elif isinstance(other, SglRoleBegin): + self._execute_role_begin(other) + elif isinstance(other, SglRoleEnd): + self._execute_role_end(other) + elif isinstance(other, SglImage): + self._execute_image(other) + elif isinstance(other, SglVariable): + self._execute_variable(other) + elif isinstance(other, SglVarScopeBegin): + self._execute_var_scope_begin(other) + elif isinstance(other, SglVarScopeEnd): + self._execute_var_scope_end(other) + elif isinstance(other, SglCommitLazy): + self._execute_commit_lazy_operations(other) + elif isinstance(other, SglConcateAndAppend): + if ( + global_config.enable_parallel_encoding + and self.backend.support_concate_and_append + ): + self._execute_concatenate_and_append_kv_cache(other) + else: + self._execute_concatenate_and_append_text(other) + else: + raise ValueError(f"Unknown type: {type(other)}") + + def _execute_fill(self, value: str): + value = str(value) + self.text_ += value + + def _execute_image(self, expr: SglImage): + path = expr.path + if isinstance(path, SglArgument): + path = path.value + + base64_data = encode_image_base64(path) + + self.images_.append((path, base64_data)) + self.cur_images.append((path, base64_data)) + self.text_ += self.chat_template.image_token + + # if global_config.eager_fill_image: + # self.backend.fill_image(self) + + def _execute_gen(self, expr: SglGen): + sampling_params = self._resolve_sampling_params(expr.sampling_params) + name = expr.name + + if not self.stream: + comp, meta_info = self.backend.generate( + self, sampling_params=sampling_params + ) + self.text_ += comp + + self.variables[name] = comp + self.meta_info[name] = meta_info + self.variable_event[name].set() + else: + generator = self.backend.generate_stream( + self, sampling_params=sampling_params + ) + + self.stream_var_event[name].set() + + self.variables[name] = "" + for comp, meta_info in generator: + self.text_ += comp + self.variables[name] += comp + self.stream_var_event[name].set() + self.stream_text_event.set() + + self.meta_info[name] = meta_info + + self.variable_event[name].set() + self.stream_var_event[name].set() + + def _execute_select(self, expr: SglSelect): + decision, scores = self.backend.select(self, expr.choices, expr.temperature) + if expr.name is not None: + name = expr.name + self.variables[name] = decision + self.variable_event[name].set() + self.text_ += decision + + def _execute_variable(self, expr: SglVariable): + src_executor = expr.source_stream_executor + value = src_executor.get_var(expr.name) + self._execute_fill(value) + + def _execute_role_begin(self, expr: SglRoleBegin): + assert self.cur_role is None, "Nested roles are not allowed." + + if len(self.messages_) == 0 and expr.role != "system": + # Insert the default system message + default_system = self.chat_template.default_system_prompt + if default_system: + self._execute_role_begin(SglRoleBegin("system")) + self._execute_fill(default_system) + self._execute_role_end(SglRoleEnd("system")) + + self.cur_role = expr.role + + prefix, _ = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_) + + self._execute_fill(prefix) + self.cur_role_begin_pos = len(self.text_) + + def _execute_role_end(self, expr: SglRoleEnd): + new_text = self.text_[self.cur_role_begin_pos :].lstrip() + + _, suffix = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_) + self._execute_fill(suffix) + + if self.cur_images: + # OpenAI vision API format + last_msg = { + "role": expr.role, + "content": [{"type": "text", "text": new_text}], + } + for (image_path, image_base64_data) in self.cur_images: + last_msg["content"].append( + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_base64_data}" + }, + } + ) + self.messages_.append(last_msg) + self.cur_images = [] + else: + self.messages_.append({"role": expr.role, "content": new_text}) + + self.cur_role = None + + def _execute_var_scope_begin(self, expr: SglVarScopeBegin): + self.variables[expr.name] = int(len(self.text_)) + + def _execute_var_scope_end(self, expr: SglVarScopeEnd): + self.variables[expr.name] = self.text_[self.variables[expr.name] :] + self.variable_event[expr.name].set() + + def _execute_commit_lazy_operations(self, expr: SglCommitLazy): + self.backend.commit_lazy_operations(self) + + def _execute_concatenate_and_append_text(self, expr: SglConcateAndAppend): + new_text = "" + for s in expr.states: + exe = s.stream_executor + exe.sync() + new_text += exe.text_[exe.fork_start_text_pos :] + + self._execute_fill(new_text) + + def _execute_concatenate_and_append_kv_cache(self, expr: SglConcateAndAppend): + self_len = len(self.text_) + + for i, s in enumerate(expr.states): + exe = s.stream_executor + exe.submit(SglCommitLazy()) + + for i, s in enumerate(expr.states): + exe = s.stream_executor + exe.sync() + assert exe.fork_start_text_pos == self_len + self.text_ += exe.text_[exe.fork_start_text_pos :] + + src_rids = [state.stream_executor.sid for state in expr.states] + self.backend.concatenate_and_append(src_rids, self.sid) + + def _resolve_sampling_params(self, sampling_params): + clone = None + for item in [ + "max_new_tokens", + "stop", + "temperature", + "top_p", + "top_k", + "frequency_penalty", + "presence_penalty", + "dtype", + "regex", + ]: + value = getattr(sampling_params, item, None) + if value is not None: + if clone is None: + clone = self.default_sampling_para.clone() + setattr(clone, item, value) + return clone or self.default_sampling_para + + def __del__(self): + self.end() + + +class ProgramState: + """The state of an SGL program.""" + + def __init__(self, stream_executor: StreamExecutor): + self.stream_executor = stream_executor + + def _role_common(self, name: str, expr: Optional[SglExpr] = None): + if expr is not None: + self.stream_executor.submit( + SglExprList([SglRoleBegin(name), expr, SglRoleEnd(name)]) + ) + else: + + @contextmanager + def role_scope(): + self.stream_executor.submit(SglRoleBegin(name)) + yield + self.stream_executor.submit(SglRoleEnd(name)) + + return role_scope() + + def system(self, expr: Optional[SglExpr] = None): + return self._role_common("system", expr) + + def user(self, expr: Optional[SglExpr] = None): + return self._role_common("user", expr) + + def assistant(self, expr: Optional[SglExpr] = None): + return self._role_common("assistant", expr) + + @contextmanager + def var_scope(self, name: str): + self.stream_executor.submit(SglVarScopeBegin(name)) + yield + self.stream_executor.submit(SglVarScopeEnd(name)) + + def fork(self, number: int = 1, position_ids_offset: Optional[List[int]] = None): + stream_executors = self.stream_executor.fork(number, position_ids_offset) + states = [ProgramState(x) for x in stream_executors] + state_group = ProgramStateGroup(states, self) + return state_group + + @contextmanager + def copy(self, position_ids_offset: Optional[List[int]] = None): + state_group = self.fork(1, position_ids_offset) + try: + yield state_group[0] + finally: + state_group.join() + + def text(self): + return self.stream_executor.text() + + def messages(self): + return self.stream_executor.messages() + + def sync(self): + return self.stream_executor.sync() + + def text_iter(self, var_name=None): + if self.stream_executor.stream: + prev = 0 + if var_name is None: + event = self.stream_executor.stream_text_event + while True: + event.wait() + event.clear() + out = str(self.stream_executor.text_[prev:]) + prev += len(out) + if out: + yield out + if self.stream_executor.is_finished: + break + else: + event = self.stream_executor.stream_var_event[var_name] + while True: + event.wait() + event.clear() + out = str(self.stream_executor.variables[var_name][prev:]) + prev += len(out) + if out: + yield out + if self.stream_executor.variable_event[var_name].is_set(): + break + else: + if var_name is None: + yield self.text() + else: + yield self.get_var(name) + + async def text_async_iter(self, var_name=None): + loop = asyncio.get_running_loop() + + if self.stream_executor.stream: + prev = 0 + if var_name is None: + event = self.stream_executor.stream_text_event + while True: + await loop.run_in_executor(None, event.wait) + event.clear() + out = str(self.stream_executor.text_[prev:]) + prev += len(out) + if out: + yield out + if self.stream_executor.is_finished: + break + else: + event = self.stream_executor.stream_var_event[var_name] + while True: + await loop.run_in_executor(None, event.wait) + event.clear() + out = str(self.stream_executor.variables[var_name][prev:]) + prev += len(out) + if out: + yield out + if self.stream_executor.variable_event[var_name].is_set(): + break + else: + if var_name is None: + yield self.text() + else: + yield self.get_var(name) + + def get_var(self, name): + return self.stream_executor.get_var(name) + + def get_meta_info(self, name): + return self.stream_executor.get_meta_info(name) + + def __iadd__(self, other): + self.stream_executor.submit(other) + return self + + def __getitem__(self, name): + return self.get_var(name) + + def __del__(self): + self.stream_executor.end() + + def __repr__(self) -> str: + msgs = self.messages() + ret = "" + for msg in msgs: + ret += msg["role"] + ":\n" + msg["content"] + "\n" + return ret + + +class ProgramStateGroup: + def __init__( + self, states: List[ProgramState], src_state: Optional[ProgramState] = None + ): + self.states = states + self.src_state = src_state + + def join(self, mode: str = "gather_variable"): + if mode == "gather_variable": + # Copy variables back + src_vars = self.src_state.stream_executor.variables + src_var_set = set(src_vars.keys()) + for child_state in self.states: + child_state.stream_executor.sync() + child_vars = child_state.stream_executor.variables + new_vars = set(child_vars.keys()) - src_var_set + + for k in new_vars: + if k in src_vars: + src_vars[k].append(child_vars[k]) + else: + src_vars[k] = [child_vars[k]] + elif mode == "concate_and_append": + # Concatenate and append KV cache + self.src_state += SglConcateAndAppend(self.states) + # Need a sync here. Otherwise, `states` can be deleted. + self.src_state.stream_executor.sync() + else: + raise ValueError(f"Invalid join mode: {mode}") + + for s in self.states: + s.stream_executor.end() + + def __getitem__(self, i: int): + return self.states[i] + + def __setitem__(self, i: int, value): + assert self.states[i] == value + + def __iadd__(self, other): + if isinstance(other, Callable): + # lambda function + for i in range(len(self.states)): + self.states[i] += other(i) + elif isinstance(other, SglExpr): + for i in range(len(self.states)): + self.states[i] += other + elif isinstance(other, (list, tuple)): + for i in range(len(self.states)): + self.states[i] += other[i] + else: + raise ValueError(f"Invalid value: {other}") + + return self diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py new file mode 100644 index 000000000..ddce5cf54 --- /dev/null +++ b/python/sglang/lang/ir.py @@ -0,0 +1,442 @@ +"""The intermediate representation.""" + +import dataclasses +import inspect +from typing import List, Optional, Union + +from sglang.global_config import global_config + + +@dataclasses.dataclass +class SamplingParams: + max_new_tokens: int = 16 + stop: Union[str, List[str]] = () + temperature: float = 1.0 + top_p: float = 1.0 + top_k: int = -1 # -1 means disable + frequency_penalty: float = 0.0 + presence_penalty: float = 0.0 + + # for constrained generation, not included in to_xxx_kwargs + dtype: Optional[str] = None + regex: Optional[str] = None + + def clone(self): + return SamplingParams( + self.max_new_tokens, + self.stop, + self.temperature, + self.top_p, + self.top_k, + self.frequency_penalty, + self.presence_penalty, + ) + + def to_openai_kwargs(self): + # OpenAI does not support top_k, so we drop it here + return { + "max_tokens": self.max_new_tokens, + "stop": self.stop or None, + "temperature": self.temperature, + "top_p": self.top_p, + "frequency_penalty": self.frequency_penalty, + "presence_penalty": self.presence_penalty, + } + + def to_anthropic_kwargs(self): + # Anthropic does not support frequency_penalty or presence_penalty, so we drop it here + return { + "max_tokens_to_sample": self.max_new_tokens, + "stop_sequences": self.stop, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + } + + def to_srt_kwargs(self): + return { + "max_new_tokens": self.max_new_tokens, + "stop": self.stop, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + "frequency_penalty": self.frequency_penalty, + "presence_penalty": self.presence_penalty, + "regex": self.regex, + } + + +class SglFunction: + def __init__(self, func, bind_arguments=None): + self.func = func + self.bind_arguments = bind_arguments or {} + self.pin_prefix_rid = None + + # Parse arguments + argspec = inspect.getfullargspec(func) + assert argspec.args[0] == "s", 'The first argument must be "s"' + self.arg_names = argspec.args[1:] + + def bind(self, **kwargs): + assert all(key in self.arg_names for key in kwargs) + + new_bind_dict = {**self.bind_arguments, **kwargs} + return SglFunction(self.func, bind_arguments=new_bind_dict) + + def run( + self, + *args, + max_new_tokens: int = 16, + stop: Union[str, List[str]] = (), + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + stream: bool = False, + backend=None, + **kwargs, + ): + from sglang.lang.interpreter import run_program + + default_sampling_para = SamplingParams( + max_new_tokens=max_new_tokens, + stop=stop, + temperature=temperature, + top_p=top_p, + top_k=top_k, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + ) + backend = backend or global_config.default_backend + kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()} + return run_program(self, backend, args, kwargs, default_sampling_para, stream) + + def run_batch( + self, + batch_kwargs, + *, + max_new_tokens: int = 16, + stop: Union[str, List[str]] = (), + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + backend=None, + num_threads: Union[str, int] = "auto", + progress_bar: bool = False, + ): + from sglang.lang.interpreter import run_program_batch + + assert isinstance(batch_kwargs, (list, tuple)) + if len(batch_kwargs) == 0: + return [] + assert isinstance(batch_kwargs[0], dict) + + default_sampling_para = SamplingParams( + max_new_tokens=max_new_tokens, + stop=stop, + temperature=temperature, + top_p=top_p, + top_k=top_k, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + ) + backend = backend or global_config.default_backend + batch_kwargs = [ + {k: SglArgument(k, v) for k, v in kwargs.items()} for kwargs in batch_kwargs + ] + return run_program_batch( + self, + backend, + batch_kwargs, + default_sampling_para, + num_threads, + progress_bar, + ) + + def trace(self, *, backend=None, **kwargs): + from sglang.lang.tracer import trace_program + + backend = backend or global_config.default_backend + return trace_program(self, kwargs, backend) + + def pin(self, backend=None): + from sglang.lang.interpreter import pin_program + + backend = backend or global_config.default_backend + return pin_program(self, backend) + + def unpin(self, backend=None): + from sglang.lang.interpreter import unpin_program + + backend = backend or global_config.default_backend + return unpin_program(self, backend) + + def compile(self, *, backend=None): + from sglang.lang.compiler import compile_func + + return compile_func(self, backend) + + def __call__(self, *args, **kwargs): + from sglang.lang.tracer import TracingScope + + tracing_scope = TracingScope.get_current_scope() + if tracing_scope is None: + return self.run(*args, **kwargs) + else: + kwargs["backend"] = tracing_scope.tracer_state.backend + return self.trace(*args, **kwargs) + + +class SglExpr: + node_ct = 0 + + def __init__(self): + self.node_id = SglExpr.node_ct + self.prev_node = None + self.pid = None + SglExpr.node_ct += 1 + + def __add__(self, other): + if isinstance(other, str): + other = SglConstantText(other) + assert isinstance(other, SglExpr) + + return self.concatenate_ir(self, other) + + def __radd__(self, other): + if isinstance(other, str): + other = SglConstantText(other) + assert isinstance(other, SglExpr), f"{other}" + + return self.concatenate_ir(other, self) + + def concatenate_ir(self, a, b): + if isinstance(a, SglExprList): + if isinstance(b, SglExprList): + return SglExprList(a.expr_list + b.expr_list) + else: + return SglExprList(a.expr_list + [b]) + elif isinstance(b, SglExprList): + return SglExprList([a] + b.expr_list) + + return SglExprList([a, b]) + + def print_graph_dfs(self): + ret = [""] + visited = set() + + def dfs_print(x): + if x is None or x in visited: + return + visited.add(x) + + # Print dependency + if x.prev_node is not None: + dfs_print(x.prev_node) + + if isinstance(x, SglExprList): + for y in x.expr_list: + dfs_print(y) + # elif isinstance(x, SglRole): + # dfs_print(x.expr) + elif isinstance(x, SglVariable): + dfs_print(x.source) + + # Print the node itself + if isinstance(x, (SglFork, SglGetForkItem)): + ret[0] += f"%{x.node_id} = {x}\n" + else: + if x.prev_node is not None: + ret[0] += ( + f"%{x.node_id} = %{x.prev_node.node_id} + " + str(x) + "\n" + ) + else: + ret[0] += f"%{x.node_id} = " + str(x) + "\n" + + dfs_print(self) + return ret[0] + + +class SglExprList(SglExpr): + def __init__(self, expr_list: List[SglExpr]): + super().__init__() + self.expr_list = expr_list + + def __repr__(self): + return f"ExprList({self.expr_list})" + + +class SglArgument(SglExpr): + def __init__(self, name: str, value: str): + super().__init__() + self.name = name + self.value = value + + def __repr__(self): + return f"Argument(name={self.name}, value={repr(self.value)})" + + def __len__(self): + return len(self.value) + + def __getitem__(self, i): + return self.value[i] + + def __int__(self): + return self.value + + def __bool__(self): + return self.value + + def __format__(self, *args): + raise TypeError( + "Cannot put argument inside a f-string. " + "This is not compatible with the tracer. " + ) + + +class SglImage(SglExpr): + def __init__(self, path): + self.path = path + + def __repr__(self) -> str: + return f"SglImage({self.path})" + + +class SglGen(SglExpr): + def __init__( + self, + name, + max_new_tokens, + stop, + temperature, + top_p, + top_k, + frequency_penalty, + presence_penalty, + dtype, + regex, + ): + super().__init__() + self.name = name + self.sampling_params = SamplingParams( + max_new_tokens=max_new_tokens, + stop=stop, + temperature=temperature, + top_p=top_p, + top_k=top_k, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + dtype=dtype, + regex=regex, + ) + + def __repr__(self): + return f"Gen('{self.name}')" + + +class SglConstantText(SglExpr): + def __init__(self, value): + super().__init__() + self.value = value + + def __repr__(self): + return f"Constant({repr(self.value)})" + + +class SglRoleBegin(SglExpr): + def __init__(self, role): + super().__init__() + self.role = role + + def __repr__(self): + return f"RoleBegin({self.role})" + + +class SglRoleEnd(SglExpr): + def __init__(self, role): + super().__init__() + self.role = role + + def __repr__(self): + return f"RoleEnd({self.role})" + + +class SglSelect(SglExpr): + def __init__(self, name, choices, temperature): + super().__init__() + self.name = name + self.choices = choices + self.temperature = temperature + + def __repr__(self): + return f"Select({self.name}, choices={self.choices})" + + +class SglFork(SglExpr): + def __init__(self, number, position_ids_offset=None): + super().__init__() + self.number = number + self.position_ids_offset = position_ids_offset + + def __repr__(self): + return ( + f"Fork(%{self.prev_node.node_id}, number={self.number}, " + f"position_ids_offset={self.position_ids_offset})" + ) + + +class SglGetForkItem(SglExpr): + def __init__(self, index): + super().__init__() + self.index = index + + def __repr__(self): + return f"GetForkItem(%{self.prev_node.node_id}, index={self.index})" + + +class SglVariable(SglExpr): + def __init__(self, name, source): + super().__init__() + self.name = name + self.source = source + + def __repr__(self): + return f"Variable('{self.name}', source=%{self.source.node_id})" + + +class SglVarScopeBegin(SglExpr): + def __init__(self, name): + super().__init__() + self.name = name + + def __repr__(self): + return f"VarScopeBegin('{self.name}')" + + +class SglVarScopeEnd(SglExpr): + def __init__(self, name): + super().__init__() + self.name = name + + def __repr__(self): + return f"VarScopeEnd('{self.name}')" + + +class SglConcateAndAppend(SglExpr): + def __init__(self, states): + super().__init__() + self.states = states + + def __repr__(self): + return f"ConcatenateAndAppend('{self.states}')" + + +class SglCommitLazy(SglExpr): + def __init__(self): + super().__init__() + + def __repr__(self): + return f"CommitLazy()" diff --git a/python/sglang/lang/tracer.py b/python/sglang/lang/tracer.py new file mode 100644 index 000000000..7e89320b5 --- /dev/null +++ b/python/sglang/lang/tracer.py @@ -0,0 +1,279 @@ +"""Tracing a program.""" +import uuid +from typing import Any, Callable, Dict, List, Optional, Union + +from sglang.backend.base_backend import BaseBackend +from sglang.global_config import global_config +from sglang.lang.interpreter import ProgramState, ProgramStateGroup +from sglang.lang.ir import ( + SglArgument, + SglCommitLazy, + SglConcateAndAppend, + SglConstantText, + SglExpr, + SglExprList, + SglFork, + SglFunction, + SglGen, + SglGetForkItem, + SglRoleBegin, + SglRoleEnd, + SglSelect, + SglVariable, + SglVarScopeBegin, + SglVarScopeEnd, +) + + +class StopTracing(Exception): + pass + + +def extract_prefix_by_tracing(program, backend): + # Create dummy arguments + dummy_arguments = {name: SglArgument(name, None) for name in program.arg_names} + arguments = dummy_arguments + arguments.update(program.bind_arguments) + + # Trace + tracer = TracerProgramState(backend, arguments, only_trace_prefix=True) + try: + with TracingScope(tracer): + tracer.ret_value = program.func(tracer, **arguments) + except StopTracing: + pass + + # Run and cache prefix + prefix = "" + for expr in tracer.flatten_nodes(): + if isinstance(expr, SglConstantText): + prefix += expr.value + else: + break + return prefix + + +def trace_program(program, arguments, backend): + # Create dummy backend + if backend is None: + backend = BaseBackend() + + # Create dummy arguments + dummy_arguments = { + name: SglArgument(name, None) + for name in program.arg_names + if name not in arguments + } + arguments.update(dummy_arguments) + arguments.update(program.bind_arguments) + + # Trace + tracer = TracerProgramState(backend, arguments, only_trace_prefix=False) + with TracingScope(tracer): + tracer.ret_value = program.func(tracer, **arguments) + return tracer + + +class TracerProgramState(ProgramState): + def __init__(self, backend, arguments, only_trace_prefix): + self.pid = uuid.uuid4().hex + self.backend = backend + self.arguments: Dict[str, Any] = arguments + self.only_trace_prefix = only_trace_prefix + + if hasattr(backend, "endpoint"): + self.backend = backend.endpoint + + self.nodes = [] + self.last_node = None + self.variables = {} + self.ret_value = None + + # For completion + + # For chat + self.messages_ = [] + self.cur_role = None + self.chat_template = self.backend.get_chat_template() + + # For multi states + self.child_states = [] + + cur_scope = TracingScope.get_current_scope() + if cur_scope is not None: + cur_scope.add_child_state(self) + + ################################## + ########### Public API ########### + ################################## + + def fork(self, number: int, position_ids_offset: Optional[List[int]] = None): + if self.only_trace_prefix: + raise StopTracing() + + fork_node = SglFork(number) + fork_node.prev_node = self.last_node + + states = [ + TracerProgramState(self.backend, self.arguments, self.only_trace_prefix) + for _ in range(number) + ] + + for i in range(number): + node = SglGetForkItem(i) + node.prev_node = fork_node + states[i].last_node = node + states[i].variables = dict(self.variables) + states[i].messages_ = list(self.messages_) + states[i].cur_role = self.cur_role + states[i].chat_template = self.chat_template + + state_group = ProgramStateGroup(states, self) + + return state_group + + ################################## + ########## Internal API ########## + ################################## + + def _append_node(self, other: SglExpr): + self.nodes.append(other) + other.prev_node = self.last_node + self.last_node = other + + def _execute(self, other: SglExpr): + if isinstance(other, str): + other = SglConstantText(other) + + other.pid = self.pid + + if isinstance(other, SglConstantText): + self._execute_fill(other) + elif isinstance(other, SglGen): + self._execute_gen(other) + elif isinstance(other, SglSelect): + self._execute_select(other) + elif isinstance(other, SglExprList): + for x in other.expr_list: + self._execute(x) + elif isinstance(other, SglRoleBegin): + self._execute_role_begin(other) + elif isinstance(other, SglRoleEnd): + self._execute_role_end(other) + elif isinstance(other, SglVarScopeBegin): + self._execute_var_scope_begin(other) + elif isinstance(other, SglVarScopeEnd): + self._execute_var_scope_end(other) + else: + if self.only_trace_prefix: + raise StopTracing() + else: + self._append_node(other) + + return self + + def __iadd__(self, other): + self._execute(other) + return self + + def _execute_fill(self, expr: SglConstantText): + if isinstance(expr, str): + expr = SglConstantText(expr) + self._append_node(expr) + + def _execute_gen(self, expr: SglGen): + name = expr.name if expr.name is not None else "gen_" + str(len(self.variables)) + new_node = SglVariable(name, source=expr) + self.variables[name] = new_node + self._append_node(expr) + + def _execute_select(self, expr: SglSelect): + name = ( + expr.name if expr.name is not None else "select_" + str(len(self.variables)) + ) + new_node = SglVariable(name, source=expr) + self.variables[name] = new_node + self._append_node(expr) + + def _execute_role_begin(self, expr: SglRoleBegin): + assert self.cur_role is None, "Nested roles are not allowed." + + if len(self.messages_) == 0 and expr.role != "system": + # Insert default system message + default_system = self.chat_template.default_system_prompt + if default_system: + self._execute_role_begin(SglRoleBegin("system")) + self._execute_fill(default_system) + self._execute_role_end(SglRoleEnd("system")) + + self.cur_role = expr.role + + prefix, suffix = self.chat_template.get_prefix_and_suffix( + expr.role, self.messages_ + ) + + self._execute_fill(prefix) + + def _execute_role_end(self, expr: SglRoleEnd): + prefix, suffix = self.chat_template.get_prefix_and_suffix( + expr.role, self.messages_ + ) + + self._execute_fill(suffix) + + self.messages_.append({"role": expr.role, "content": ""}) + + self.cur_role = None + + def _execute_var_scope_end(self, expr: SglVarScopeEnd): + new_node = SglVariable(name, source=self.last_node) + self.variables[name] = new_node + + def get_var(self, name): + ret = self.arguments.get(name, None) + if ret is not None: + return ret + + v = self.variables[name] + return SglVariable(v.name, v.source) + + def flatten_nodes(self): + def traverse(cur): + if isinstance(cur, SglExprList): + for child in cur.expr_list: + traverse(child) + else: + ret.append(cur) + + ret = [] + for x in self.nodes: + traverse(x) + return ret + + def __del__(self): + pass + + +class TracingScope: + cur_scope = None + + def __init__(self, tracer_state: TracerProgramState): + self.tracer_state = tracer_state + self.last_scope = TracingScope.cur_scope + + def __enter__(self): + TracingScope.cur_scope = self + return self + + def __exit__(self, exc_type, exc_value, traceback): + TracingScope.cur_scope = self.last_scope + + @staticmethod + def get_current_scope(): + return TracingScope.cur_scope + + def add_child_state(self, state: TracerProgramState): + cur_scope = self + while cur_scope != None: + cur_scope.tracer_state.child_states.append(state) + cur_scope = cur_scope.last_scope diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py new file mode 100644 index 000000000..9d63a2aed --- /dev/null +++ b/python/sglang/launch_server.py @@ -0,0 +1,11 @@ +import argparse + +from sglang.srt.server import ServerArgs, launch_server + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + ServerArgs.add_cli_args(parser) + args = parser.parse_args() + server_args = ServerArgs.from_cli_args(args) + + launch_server(server_args, None) diff --git a/python/sglang/srt/constrained/fsm.py b/python/sglang/srt/constrained/fsm.py new file mode 100644 index 000000000..ceec5d3e5 --- /dev/null +++ b/python/sglang/srt/constrained/fsm.py @@ -0,0 +1,385 @@ +# Adapted from: +# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/fsm/fsm.py +from typing import List, NewType, Protocol + +import interegular +from lark import Lark + +# from outlines.fsm.parsing import PartialLark +from sglang.srt.constrained.regex import ( + create_fsm_index_tokenizer, + make_deterministic_fsm, +) +from sglang.srt.constrained.tokenizer import Tokenizer + +FSMState = NewType("FSMState", int) + + +class FSM(Protocol): + def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]: + ... + + def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState: + ... + + def is_final_state(self, state: FSMState, idx: int = 0) -> bool: + ... + + def reset(self) -> None: + ... + + +class StopAtTokenFSM(FSM): + """FSM to generate text until a specified token id is generated or + a specified number of tokens has been generated. + + Text is usually produced until the EOS token is generated by the + model. + + """ + + def __init__( + self, + tokenizer: "Tokenizer", + stop_token_id: int, + ): + self.stop_token_id = stop_token_id + self.num_tokens_generated = 0 + self.vocabulary = tokenizer.vocabulary.values() + self.final_states = {1} + + def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]: + """Generate a list of allowed tokens for the next step. + + When in the initial state we allow every token to be generated. + In the final state the only allowed token is `stop_token_id`. + + Parameters + ---------- + state + The current state of the FSM. + idx + The index of the current input in the batch. + + Returns + ------- + A list that contains the tokens to mask. + + """ + if state == 0: + return list(self.vocabulary) + else: + return [self.stop_token_id] + + def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState: + """Update the state of the FSM. + + The FSM stays in the initial state `0` unless the specified stop token + has been generated or the maximum number of tokens has been reached. In + which case the FSM moves to the final state `1`. + + Parameters + ---------- + state + The current state of the FSM. + token_id + The id of the token that was just generated. + idx + The index of the current input in the batch. + + Returns + ------- + The new state of the FSM. + + """ + if idx == 0: + self.num_tokens_generated += 1 + + if token_id == self.stop_token_id: + return FSMState(1) + + return FSMState(0) + + def is_final_state(self, state: FSMState, idx: int = 0) -> bool: + """Determine whether the current state of the FSM is a final state.""" + return state in self.final_states + + def reset(self) -> None: + """Reset the FSM to its initial state. Here this only resets the token counter.""" + self.num_tokens_generated = 0 + + +class RegexFSM(FSM): + """FSM to generate text that is in the language of a regular expression.""" + + def __init__( + self, + regex_string: str, + tokenizer: "Tokenizer", + ): + regex_pattern = interegular.parse_pattern(regex_string) + regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) + ( + self.states_to_token_maps, + self.empty_token_ids, + ) = create_fsm_index_tokenizer(regex_fsm, tokenizer) + + # We make sure that it is possible to generate strings in the language + # of the regular expression with the tokens present in the model's + # vocabulary. + if not any( + regex_fsm.finals.intersection(v.values()) + for v in self.states_to_token_maps.values() + ): + raise ValueError( + "The vocabulary does not allow us to build a sequence that matches the input regex" + ) + + self.final_states = regex_fsm.finals | { + -1 + } # Include the EOS token in final states + self.num_tokens_generated = 0 + self.vocabulary = tokenizer.vocabulary.values() + self.end_token_id = tokenizer.eos_token_id + + def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]: + """Generate a list of allowed tokens for the next step. + + The initialization of the FSM builds an index which maps FSM states to a + map from authorized tokens to the state in which the FSM needs to move + if said token is generated. Therefore the authorized tokens at the + current state are the keys of the map returned by the value of the index + for current state. + + If the current state is not contained in the end this means that we are + in a final state of the FSM. We only authorize EOS tokens in the final + state. + + Parameters + ---------- + state + The current state of the FSM. + idx + The index of the current input in the batch. + + Returns + ------- + A list that contains the tokens to mask. + + """ + next_tokens_to_end_states = self.states_to_token_maps.get(state) + + if next_tokens_to_end_states is None: + return [self.end_token_id] + else: + return list(next_tokens_to_end_states.keys()) + + def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState: + """Update the state of the FSM. + + We use the index to determine to which state the FSM should transition + given the token that was just generated. + + Parameters + ---------- + state + The current state of the FSM. + token_id + The id of the token that was just generated. + idx + The index of the current input in the batch. + + Returns + ------- + The new state of the FSM. + + """ + if idx == 0: + self.num_tokens_generated += 1 + + if token_id == self.end_token_id: + return FSMState(-1) + + last_token_to_end_state = self.states_to_token_maps[state] + next_state = last_token_to_end_state.get(token_id) + if next_state is None: + next_state = -1 + + return FSMState(next_state) + + def is_final_state(self, state: FSMState, idx: int = 0) -> bool: + """Determine whether the current state of the FSM is a final state.""" + return state in self.final_states + + def reset(self) -> None: + """Reset the FSM to its initial state. Here this only resets the token counter.""" + self.num_tokens_generated = 0 + + +class CFGFSM(FSM): + """FSM to generate text that is in the language of a context-free grammar.""" + + def __init__( + self, + cfg_string: str, + tokenizer: "Tokenizer", + ): + # self.parser = PartialLark(cfg_string, parser="lalr") + self.parser = Lark( + cfg_string, + parser="lalr", + lexer="contextual", + propagate_positions=False, + maybe_placeholders=False, + regex=True, + ) + self.terminal_regexps = dict() + for terminal in self.parser.terminals: + if terminal.pattern is not None: + self.terminal_regexps[terminal.name] = terminal.pattern.to_regexp() + self.terminal_regexps["$END"] = tokenizer.eos_token + + self.tokenizer = tokenizer + self.num_tokens_generated = 0 + self.generations: List[str] = [] + self.regex_fsms: List[RegexFSM] = [] + self.reset_state: List[bool] = [] + self.allow_eos: List[bool] = [] + self.done: List[bool] = [] + + def _set_next_regex_fsm(self, idx: int = 0) -> None: + """Use the CFG incremental parser to set the next regex FSM. + + Check what the CFG incremental parser proposes next. + If the only proposal is the EOS token, + we set the state to done and return. + If there are other proposals, + we set a new regex FSM and return. + + """ + interactive = self.parser.parse_interactive(self.generations[idx]) + interactive.exhaust_lexer() + options = {self.terminal_regexps[x] for x in interactive.accepts()} + + if self.terminal_regexps["$END"] in options: + options.remove(self.terminal_regexps["$END"]) + if len(options) == 0: + self.done[idx] = True + return + self.allow_eos[idx] = True + options.add("") + assert len(options) > 1 + + regex_string = r"(" + r"|".join([r"(" + x + r")" for x in options]) + r")" + args = ( + regex_string, + self.tokenizer, + ) + if len(self.regex_fsms) <= idx: + self.regex_fsms.append(RegexFSM(*args)) + else: + self.regex_fsms[idx] = RegexFSM(*args) + self.reset_state[idx] = True + + def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]: + """Generate a list of allowed tokens for the next step. + + Upon initialization, the CFG incremental parser is used to determine the first regex. + + This regex is used for proposals until either: + - the regex is exhausted, and its only remaining option is the EOS token, + in which case we always transition to the next regex + - the regex can be exhausted, but the EOS token is not the only remaining option, + in which case we transition to the next regex with probability P (TODO) + or remove the possibility of generating the EOS token and continue with the current regex + + The CFG incremental parser is allowed to propose the EOS token from any final state, + and once it is generated, the FSM will continue to always generate the EOS token. + + Parameters + ---------- + state + The current state of the FSM. + idx + The index of the current input in the batch. + + Returns + ------- + A list that contains the tokens to mask. + + """ + if len(self.generations) <= idx: + self.generations.append("") + self.reset_state.append(False) + self.allow_eos.append(False) + self.done.append(False) + + if len(self.regex_fsms) > idx: + proposal = self.regex_fsms[idx].allowed_token_ids(state) + if self.tokenizer.eos_token_id not in proposal: + return proposal + if set(proposal) != {self.tokenizer.eos_token_id}: + if False: # TODO: THIS NEEDS TO BE SAMPLED + proposal = [x for x in proposal if x != self.tokenizer.eos_token_id] + return proposal + + self._set_next_regex_fsm(idx) + + if self.done[idx]: + return [self.tokenizer.eos_token_id] + + if self.reset_state[idx]: + state = FSMState(0) + + proposal = self.regex_fsms[idx].allowed_token_ids(state) + if self.allow_eos[idx]: + self.allow_eos[idx] = False + else: + proposal = [x for x in proposal if x != self.tokenizer.eos_token_id] + assert len(proposal) > 0 + return proposal + + def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState: + """Update the state of the FSM. + + Transitions the underlying regex FSM to its next state. + If at max tokens or EOS token, transition permanently to the final state. + Update stored partial generations for subsequent incremental parsing. + + Parameters + ---------- + state + The current state of the FSM. + token_id + The id of the token that was just generated. + idx + The index of the current input in the batch. + + Returns + ------- + The new state of the FSM. + """ + if idx == 0: + self.num_tokens_generated += 1 + if token_id == self.tokenizer.eos_token_id: + self.done[idx] = True + return FSMState(-1) + if self.reset_state[idx]: + self.reset_state[idx] = False + state = FSMState(0) + + self.generations[idx] += self.tokenizer.decode([token_id])[0] + + return self.regex_fsms[idx].next_state(state, token_id, idx) + + def is_final_state(self, state: FSMState, idx: int = 0) -> bool: + """Return whether the current state of the FSM is a final state.""" + return self.done[idx] + + def reset(self) -> None: + """Reset the FSM to its initial state, so it can be called on a fresh batch on inputs.""" + self.num_tokens_generated = 0 + self.generations = [] + self.regex_fsms = [] + self.reset_state = [] + self.done = [] diff --git a/python/sglang/srt/constrained/fsm_cache.py b/python/sglang/srt/constrained/fsm_cache.py new file mode 100644 index 000000000..bd6c6a073 --- /dev/null +++ b/python/sglang/srt/constrained/fsm_cache.py @@ -0,0 +1,41 @@ +import threading + +from sglang.srt.constrained.fsm import RegexFSM +from sglang.srt.constrained.tokenizer import TransformerTokenizer + + +def get_fsm(regex, tokenizer, fsm_cache_entry): + outlines_tokenizer = TransformerTokenizer(tokenizer) + fsm = RegexFSM(regex, outlines_tokenizer) + fsm_cache_entry.fsm = fsm + fsm_cache_entry.event.set() + + +class FSMCacheEntry: + def __init__(self): + self.fsm = None + self.event = threading.Event() + + +class FSMCache: + def __init__(self, tokenizer): + self.cache = {} + self.tokenizer = tokenizer + + def init_fsm_in_background(self, regex): + if regex not in self.cache: + self.cache[regex] = FSMCacheEntry() + threading.Thread( + target=get_fsm, + args=( + regex, + self.tokenizer, + self.cache[regex], + ), + ).start() + + def get_fsm(self, regex): + self.init_fsm_in_background(regex) + entry = self.cache[regex] + entry.event.wait() + return entry.fsm diff --git a/python/sglang/srt/constrained/regex.py b/python/sglang/srt/constrained/regex.py new file mode 100644 index 000000000..0f1f89ff5 --- /dev/null +++ b/python/sglang/srt/constrained/regex.py @@ -0,0 +1,586 @@ +# Adapted from: +# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/fsm/regex.py +from collections import namedtuple +from functools import lru_cache +from typing import Dict, Generator, List, Sequence, Set, Tuple + +import numba +import numpy as np +from interegular.fsm import FSM, Alphabet, OblivionError, anything_else +from numba.typed.typedobjectutils import _nonoptional +from sglang.srt.constrained.tokenizer import Tokenizer + + +class BetterAlphabet(Alphabet): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert anything_else in self._symbol_mapping + self.anything_value = self._symbol_mapping[anything_else] + + def __getitem__(self, item): + return self._symbol_mapping.get(item, self.anything_value) + + def copy(self): + return BetterAlphabet(self._symbol_mapping.copy()) + + +class BetterFSM(FSM): + flat_transition_map: Dict[Tuple[int, int], int] + trans_key_to_states: Dict[int, List[int]] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + if not isinstance(self.alphabet, BetterAlphabet): + self.__dict__["alphabet"] = BetterAlphabet(self.alphabet._symbol_mapping) + + flat_transition_map = {} + trans_key_to_states = {} + for from_state, trans_map in self.map.items(): + for trans_key, to_state in trans_map.items(): + flat_transition_map[(from_state, trans_key)] = to_state + trans_key_to_states.setdefault(trans_key, set()).add(from_state) + + self.__dict__["trans_key_to_states"] = trans_key_to_states + self.__dict__["flat_transition_map"] = flat_transition_map + self.__dict__["_fsm_info"] = None + + def copy(self): + return BetterFSM( + alphabet=self.alphabet.copy(), + states=self.states.copy(), + initial=self.initial, + finals=self.finals.copy(), + map=self.map.copy(), + __no_validation__=True, + ) + + @property + def fsm_info(self): + if self._fsm_info is None: + flat_transition_map_items = np.fromiter( + ((a[0], a[1], b) for a, b in self.flat_transition_map.items()), + dtype=np.dtype("i8, i8, i8"), + ) + trans_key_to_states_items = np.fromiter( + ((k, z) for k, v in self.trans_key_to_states.items() for z in v), + dtype=np.dtype("i8, i8"), + ) + alphabet_symbol_mapping_items = np.fromiter( + ( + it + for it in self.alphabet._symbol_mapping.items() + if it[0] != anything_else + ), + dtype=np.dtype("U1, i8"), + ) + nb_finals = np.fromiter(self.finals, dtype=np.dtype("i8")) + self.__dict__["_fsm_info"] = create_fsm_info( + self.initial, + nb_finals, + flat_transition_map_items, + trans_key_to_states_items, + self.alphabet.anything_value, + alphabet_symbol_mapping_items, + ) + + return self._fsm_info + + +nb_int_list_type = numba.types.ListType(numba.int64) +nb_int_pair_type = numba.types.UniTuple(numba.int64, 2) +nb_unichar_1_type = numba.types.UnicodeCharSeq(1) + + +@numba.njit(cache=True) +def create_fsm_info( + py_initial, + py_finals, + flat_transition_map_items, + trans_key_to_states_items, + py_anything_value, + alphabet_symbol_mapping_items, +): + trans_key_to_states = numba.typed.Dict.empty(numba.int64, nb_int_list_type) + for trans_key_and_state in trans_key_to_states_items: + trans_key_to_states.setdefault( + trans_key_and_state[0], numba.typed.List.empty_list(numba.int64) + ).append(trans_key_and_state[1]) + + flat_transition_map = numba.typed.Dict.empty(nb_int_pair_type, numba.int64) + for trans_key_and_state in flat_transition_map_items: + flat_transition_map[ + (trans_key_and_state[0], trans_key_and_state[1]) + ] = trans_key_and_state[2] + + alphabet_symbol_map = numba.typed.Dict.empty(nb_unichar_1_type, numba.int64) + for symbol_and_trans_key in alphabet_symbol_mapping_items: + alphabet_symbol_map[symbol_and_trans_key[0]] = symbol_and_trans_key[1] + + initial = numba.int64(py_initial) + + finals = set() + for final in py_finals: + finals.add(final) + + anything_value = numba.int64(py_anything_value) + + return FSMInfo( + initial, + finals, + flat_transition_map, + trans_key_to_states, + anything_value, + alphabet_symbol_map, + ) + + +FSMInfo = namedtuple( + "FSMInfo", + [ + "initial", + "finals", + "transitions", + "trans_key_to_states", + "alphabet_anything_value", + "alphabet_symbol_mapping", + ], +) + + +def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]: + """Construct an equivalent FSM with deterministic state labels.""" + old_to_new_trans_keys = { + trans_key: i + for i, (trans_key, _) in enumerate( + sorted(fsm.alphabet.by_transition.items(), key=lambda x: sorted(x[1])) + ) + } + + new_symbol_mapping = { + symbol: old_to_new_trans_keys[trans_key] + for symbol, trans_key in fsm.alphabet._symbol_mapping.items() + } + + new_alphabet = BetterAlphabet(new_symbol_mapping) + + new_map = { + from_state: { + old_to_new_trans_keys[trans_key]: to_state + for trans_key, to_state in trans_map.items() + } + for from_state, trans_map in fsm.map.items() + } + + old_to_new_states = {} + old_to_new_states[fsm.initial] = 0 + + i = 0 + seen = {fsm.initial} + old_state_queue = [fsm.initial] + while old_state_queue: + old_state = old_state_queue.pop(-1) + transitions = new_map[old_state] + sorted_transitions = sorted(transitions.items(), key=lambda v: v[0]) + for _, old_state in sorted_transitions: + if old_state not in seen: + old_state_queue.append(old_state) + seen.add(old_state) + if old_state not in old_to_new_states: + i += 1 + old_to_new_states[old_state] = i + + new_map = dict( + sorted( + ( + ( + old_to_new_states[from_state], + dict( + sorted( + ( + (trans_key, old_to_new_states[to_state]) + for trans_key, to_state in trans_map.items() + ), + key=lambda v: v[0], + ) + ), + ) + for from_state, trans_map in new_map.items() + ), + key=lambda v: v[0], + ) + ) + + new_initial = 0 + new_finals = frozenset( + sorted(old_to_new_states[old_state] for old_state in fsm.finals) + ) + new_states = frozenset(sorted(new_map.keys())) + + new_fsm = BetterFSM(new_alphabet, new_states, new_initial, new_finals, new_map) + + return new_fsm, old_to_new_states + + +@numba.njit(nogil=True, cache=True) +def _walk_fsm( + fsm_transitions: Dict[Tuple[int, int], int], + alphabet_symbol_mapping: Dict[str, int], + alphabet_anything_value: int, + fsm_initial: int, + fsm_finals: Set[int], + input_string: str, + start_state: int, + full_match: bool = True, +) -> List[int]: + state = start_state + accepted_states: List[int] = numba.typed.List.empty_list(numba.int64) + last_final_idx: int = numba.uint64(0) + + for i, symbol in enumerate(input_string): + trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value) + + new_state = fsm_transitions.get((state, trans_key)) + + if new_state is None: + if not full_match and last_final_idx > 0: + return accepted_states[:last_final_idx] + + return numba.typed.List.empty_list(numba.int64) + + state = new_state + + if state in fsm_finals: + last_final_idx = numba.uint64(i + 1) + + accepted_states.append(_nonoptional(state)) + + if full_match and last_final_idx - 1 != i: + return numba.typed.List.empty_list(numba.int64) + + return accepted_states + + +def walk_fsm( + fsm: BetterFSM, + input_string: str, + start_state: int, + full_match: bool = True, +) -> List[int]: + fsm_finals = fsm.finals + + state = start_state + accepted_states: List[int] = [] + last_final_idx: int = 0 + + alphabet_symbol_mapping = fsm.alphabet._symbol_mapping + alphabet_anything_value = fsm.alphabet.anything_value + fsm_transitions = fsm.flat_transition_map + + for i, symbol in enumerate(input_string): + trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value) + + new_state = fsm_transitions.get((state, trans_key)) + + if new_state is None: + if not full_match and last_final_idx > 0: + return accepted_states[:last_final_idx] + + return [] + + state = new_state + + if state in fsm_finals: + last_final_idx = i + 1 + + accepted_states.append(state) + + if full_match and last_final_idx - 1 != i: + return [] + + return accepted_states + + +def fsm_union( + fsms: Sequence[FSM], +) -> Tuple[FSM, Dict[int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]]]: + """Construct an FSM representing the union of the FSMs in `fsms`. + + This is an updated version of `interegular.fsm.FSM.union` made to return an + extra map of component FSMs to the sets of state transitions that + correspond to them in the new FSM. + + """ + + alphabet, new_to_old = Alphabet.union(*[fsm.alphabet for fsm in fsms]) + + indexed_fsms = tuple(enumerate(fsms)) + + initial = {i: fsm.initial for (i, fsm) in indexed_fsms} + + # Dedicated function accepting a "superset" and returning the next + # "superset" obtained by following this transition in the new FSM + def follow(current_state, new_transition: int): + next = {} + for i, f in indexed_fsms: + old_transition = new_to_old[i][new_transition] + if ( + i in current_state + and current_state[i] in f.map + and old_transition in f.map[current_state[i]] + ): + next[i] = f.map[current_state[i]][old_transition] + if not next: + raise OblivionError + return next + + states = [initial] + finals: Set[int] = set() + map: Dict[int, Dict[int, int]] = {} + + # Map component FSMs to their new state-to-state transitions, finals, and a + # map translating component FSM states to aggregate FSM states + fsms_to_trans_finals: Dict[ + int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]] + ] = {} + + i = 0 + while i < len(states): + state = states[i] + + # Add to the finals of the aggregate FSM whenever we hit a final in a + # component FSM + if any(state.get(j, -1) in fsm.finals for (j, fsm) in indexed_fsms): + finals.add(i) + + # Compute the map for this state + map[i] = {} + for transition in alphabet.by_transition: + try: + next = follow(state, transition) + except OblivionError: + # Reached an oblivion state; don't list it + continue + else: + try: + # TODO: Seems like this could--and should--be avoided + j = states.index(next) + except ValueError: + j = len(states) + states.append(next) + + map[i][transition] = j + + for fsm_id, fsm_state in next.items(): + ( + fsm_transitions, + fsm_finals, + fsm_old_to_new, + ) = fsms_to_trans_finals.setdefault(fsm_id, (set(), set(), {})) + old_from = state[fsm_id] + old_to = fsm_state + fsm_old_to_new.setdefault(old_from, set()).add(i) + fsm_old_to_new.setdefault(old_to, set()).add(j) + fsm_transitions.add((i, j)) + if fsm_state in fsms[fsm_id].finals: + fsm_finals.add(j) + + i += 1 + + fsm = FSM( + alphabet=alphabet, + states=range(len(states)), + initial=0, + finals=finals, + map=map, + __no_validation__=True, + ) + + fsm, old_to_new_states = make_deterministic_fsm(fsm) + _fsms_to_trans_finals = { + fsm_id: ( + {(old_to_new_states[s1], old_to_new_states[s2]) for s1, s2 in transitions}, + {old_to_new_states[s] for s in finals}, + { + old_state: {old_to_new_states[new_state] for new_state in new_states} + for old_state, new_states in old_to_new.items() + }, + ) + for fsm_id, (transitions, finals, old_to_new) in sorted( + fsms_to_trans_finals.items(), key=lambda x: x[0] + ) + } + + return ( + fsm, + _fsms_to_trans_finals, + ) + + +def get_sub_fsms_from_seq( + state_seq: Sequence[int], + fsms_to_trans_finals: Dict[ + int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]] + ], +) -> Generator[Tuple[int, bool, bool], None, None]: + """Get the indices of the sub-FSMs in `fsm` that could have matched the state sequence `state_seq`. + + Parameters + ---------- + state_seq + A state sequence. + fsms_to_trans_finals + A map from FSM indices to tuples containing sets of their state transitions + and sets of the final/accept states. + + Returns + ------- + A generator returning tuples containing each sub-FSM index (in the order + they were union-ed to construct `fsm`) and booleans indicating whether or + not there is another valid transition from the last state in the sequence + for the associated sub-FSM (i.e. if the FSM can continue + accepting/matching) and whether or not the sequence ends in a final state + of the sub-FSM. + """ + state_seq_transitions = set(zip(state_seq[:-1], state_seq[1:])) + last_fsm_state = state_seq[-1] + yield from ( + ( + # The sub-FMS index + fsm_idx, + # Is there another possible transition in this sub-FSM? + any(last_fsm_state == from_s for (from_s, to_s) in transitions), + # Is this sub-FSM in a final state? + state_seq[-1] in finals, + ) + for fsm_idx, (transitions, finals, _) in fsms_to_trans_finals.items() + if state_seq_transitions.issubset(transitions) + ) + + +@numba.njit(cache=True, nogil=True) +def state_scan_tokens( + fsm_transitions: Dict[Tuple[int, int], int], + alphabet_symbol_mapping: Dict[str, int], + alphabet_anything_value: int, + fsm_initial: int, + fsm_finals: Set[int], + vocabulary: Dict[str, List[int]], + start_state: int, +) -> Set[Tuple[int, int]]: + res = set() + + for token, token_ids in vocabulary.items(): + state_seq = _walk_fsm( + fsm_transitions, + alphabet_symbol_mapping, + alphabet_anything_value, + fsm_initial, + fsm_finals, + token, + start_state, + False, + ) + + if state_seq is not None and len(state_seq) < len(token): + continue + + for token_id in token_ids: + res.add((token_id, state_seq[-1])) + + return res + + +def create_fsm_index_end_to_end( + fsm_info: FSMInfo, + vocabulary: Dict[str, List[int]], +) -> Dict[int, Set[Tuple[int, int]]]: + """Create an FSM state-to-vocabulary map/index through end-to-end token parsing.""" + + # TODO: Consider using a `List` of `Set`s instead; that way we can JIT this + # code, too. + states_to_token_subsets: Dict[int, Set[Tuple[int, int]]] = {} + seen: Set[int] = set() + next_states = {fsm_info.initial} + + while next_states: + start_state = next_states.pop() + + token_ids_end_states = state_scan_tokens( + fsm_info.transitions, + fsm_info.alphabet_symbol_mapping, + fsm_info.alphabet_anything_value, + fsm_info.initial, + fsm_info.finals, + vocabulary, + start_state, + ) + + for token_id_and_end_state in token_ids_end_states: + states_to_token_subsets.setdefault(start_state, set()).add( + token_id_and_end_state + ) + end_state = token_id_and_end_state[1] + if end_state not in seen: + next_states.add(end_state) + + seen.add(start_state) + + return states_to_token_subsets + + +# TODO: Cannot cache typed collections to disk, yet. See +# https://github.com/numba/numba/issues/4698 +@lru_cache +def reduced_vocabulary(tokenizer: "Tokenizer"): + """Create a map from decoded vocabulary tokens to lists of equivalent token ids.""" + vocabulary = numba.typed.Dict.empty( + numba.types.string, numba.types.ListType(numba.int64) + ) + empty_token_ids = set() + for token, token_idx in tokenizer.vocabulary.items(): + if token in tokenizer.special_tokens: + continue + + token_str = tokenizer.convert_token_to_string(token) + + if token_str: + vocabulary.setdefault( + token_str, + numba.typed.List.empty_list(numba.int64), + ).append(numba.int64(token_idx)) + else: + empty_token_ids.add(numba.int64(token_idx)) + + return vocabulary, empty_token_ids + + +def create_fsm_index_tokenizer( + fsm: BetterFSM, + tokenizer: "Tokenizer", +) -> Tuple[Dict[int, Dict[int, int]], Set[int]]: + """Construct an FMS index from a tokenizer. + + This uses the end-to-end approach of `create_fsm_index_end_to_end`. + + .. warning:: + + `fsm` needs to be deterministically ordered so that future caching makes sense. + + """ + vocabulary, empty_token_ids = reduced_vocabulary(tokenizer) + + states_to_token_subsets = create_fsm_index_end_to_end(fsm.fsm_info, vocabulary) + + # Allow transitions to EOS from all terminals FSM states that are + # reachable + # TODO: Do we really need this anymore? + for state in fsm.fsm_info.finals: + subset = states_to_token_subsets.get(state) + if subset is not None: + subset.add((tokenizer.eos_token_id, state)) + + # Convert to token-to-end-state maps + states_to_token_subsets = {k: dict(v) for k, v in states_to_token_subsets.items()} + + return states_to_token_subsets, empty_token_ids diff --git a/python/sglang/srt/constrained/tokenizer.py b/python/sglang/srt/constrained/tokenizer.py new file mode 100644 index 000000000..ac1c8ebed --- /dev/null +++ b/python/sglang/srt/constrained/tokenizer.py @@ -0,0 +1,266 @@ +# Adapted from: +# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/tokenizer.py +# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/transformers.py +from abc import abstractmethod +from typing import ( + TYPE_CHECKING, + Dict, + Hashable, + List, + Optional, + Protocol, + Set, + Tuple, + Union, +) + +import numpy as np +import torch +from numpy.typing import NDArray + + +class Tokenizer(Protocol, Hashable): + eos_token: str + eos_token_id: int + pad_token_id: int + vocabulary: Dict[str, int] + special_tokens: Set[int] + + @abstractmethod + def encode( + self, prompt: Union[str, List[str]] + ) -> Tuple[NDArray[np.int64], NDArray[np.int64]]: + """Translate the input prompts into NumPy arrays of token ids and attention mask.""" + ... + + @abstractmethod + def decode(self, token_ids: NDArray[np.int64]) -> List[str]: + """Translate an array of token ids to a string or list of strings.""" + ... + + @abstractmethod + def convert_token_to_string(self, token: str) -> str: + """Convert a token to its equivalent string. + + This is for instance useful for BPE tokenizers where whitespaces are + represented by the special characted `Ġ`. This prevents matching a raw + token that includes `Ġ` with a string. + + """ + ... + + +if TYPE_CHECKING: + from transformers import PreTrainedModel, PreTrainedTokenizer + +__all__ = ["transformers"] + + +KVCacheType = Tuple[Tuple[torch.DoubleTensor, torch.DoubleTensor], ...] + + +def get_llama_tokenizer_types(): + """Get all the Llama tokenizer types/classes that need work-arounds. + + When they can't be imported, a dummy class is created. + + """ + try: + from transformers.models.llama import LlamaTokenizer + except ImportError: + + class LlamaTokenizer: # type: ignore + pass + + try: + from transformers.models.llama import LlamaTokenizerFast + except ImportError: + + class LlamaTokenizerFast: # type: ignore + pass + + try: + from transformers.models.code_llama import CodeLlamaTokenizer + except ImportError: + + class CodeLlamaTokenizer: # type: ignore + pass + + try: + from transformers.models.code_llama import CodeLlamaTokenizerFast + except ImportError: + + class CodeLlamaTokenizerFast: # type: ignore + pass + + return ( + LlamaTokenizer, + LlamaTokenizerFast, + CodeLlamaTokenizer, + CodeLlamaTokenizerFast, + ) + + +class Transformer: + """Represents a `transformers` model.""" + + def __init__( + self, + model: "PreTrainedModel", + tokenizer: "PreTrainedTokenizer", + ): + self.device = model.device + self.model = model + self.tokenizer = tokenizer + + @torch.inference_mode + def forward( + self, + input_ids: torch.LongTensor, + attention_mask: torch.LongTensor, + past_key_values: Optional[Tuple] = None, + ) -> Tuple[torch.FloatTensor, Optional[KVCacheType]]: + """Compute a forward pass through the transformer model. + + Parameters + ---------- + input_ids + The input token ids. Must be one or two dimensional. + attention_mask + The attention mask. Must be one or two dimensional. + past_key_values + A tuple of tuples containing the cached key and value tensors for each + attention head. + + Returns + ------- + The computed logits and the new cached key and value tensors. + + """ + assert 0 < input_ids.ndim < 3 + + if past_key_values: + input_ids = input_ids[..., -1].unsqueeze(-1) + + output = self.model( + input_ids, + attention_mask=attention_mask, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + past_key_values=past_key_values, + ) + + return output.logits, output.past_key_values + + def __call__( + self, + input_ids: torch.LongTensor, + attention_mask: torch.LongTensor, + past_key_values: Optional[Tuple] = None, + ) -> torch.FloatTensor: + logits, kv_cache = self.forward(input_ids, attention_mask, past_key_values) + next_token_logits = logits[..., -1, :] + + return next_token_logits, kv_cache + + +class TransformerTokenizer(Tokenizer): + """Represents a tokenizer for models in the `transformers` library.""" + + def __init__(self, tokenizer): + # TODO: Do something to make this hashable? + self.tokenizer = tokenizer + self.eos_token_id = self.tokenizer.eos_token_id + self.eos_token = self.tokenizer.eos_token + + if not self.tokenizer.pad_token_id: + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + self.pad_token_id = self.eos_token_id + else: + self.pad_token_id = self.tokenizer.pad_token_id + self.pad_token = self.tokenizer.pad_token + + self.special_tokens = set(self.tokenizer.all_special_tokens) + + self.vocabulary = self.tokenizer.get_vocab() + self.is_llama = isinstance(self.tokenizer, get_llama_tokenizer_types()) + + def encode( + self, prompt: Union[str, List[str]], **kwargs + ) -> Tuple[torch.LongTensor, torch.LongTensor]: + kwargs["padding"] = True + kwargs["return_tensors"] = "pt" + output = self.tokenizer(prompt, **kwargs) + return output["input_ids"], output["attention_mask"] + + def decode(self, token_ids: torch.LongTensor) -> List[str]: + text = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True) + return text + + def convert_token_to_string(self, token: str) -> str: + from transformers.file_utils import SPIECE_UNDERLINE + + string = self.tokenizer.convert_tokens_to_string([token]) + + if self.is_llama: + # A hack to handle missing spaces to HF's Llama tokenizers + if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": + return " " + string + + return string + + def __eq__(self, other): + if isinstance(other, type(self)): + return False + # TODO(lsyin): the lru_cache for the TransoformerTokenizer is useless ? + # return other.model_name == self.model_name and other.kwargs == self.kwargs + return NotImplemented + + def __hash__(self): + from datasets.fingerprint import Hasher + + return hash(Hasher.hash(self.tokenizer)) + + +def transformers( + model_name: str, + device: Optional[str] = None, + model_kwargs: dict = {}, + tokenizer_kwargs: dict = {}, +): + """Instantiate a model from the `transformers` library and its tokenizer. + + Parameters + ---------- + model_name + The name of the model as listed on Hugging Face's model page. + device + The device(s) on which the model should be loaded. This overrides + the `device_map` entry in `model_kwargs` when provided. + model_kwargs + A dictionary that contains the keyword arguments to pass to the + `from_pretrained` method when loading the model. + tokenizer_kwargs + A dictionary that contains the keyword arguments to pass to the + `from_pretrained` method when loading the tokenizer. + + Returns + ------- + A `TransformersModel` model instance. + + """ + try: + from transformers import AutoModelForCausalLM + except ImportError: + raise ImportError( + "The `transformers` library needs to be installed in order to use `transformers` models." + ) + + if device is not None: + model_kwargs["device_map"] = device + + model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs) + tokenizer = TransformerTokenizer(model_name, **tokenizer_kwargs) + + return Transformer(model, tokenizer) diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py new file mode 100644 index 000000000..fde8457a3 --- /dev/null +++ b/python/sglang/srt/hf_transformers_utils.py @@ -0,0 +1,164 @@ +"""Utilities for Huggingface Transformers.""" + +import json +import os +import warnings +from typing import List, Optional, Tuple, Union + +from huggingface_hub import snapshot_download +from sglang.srt.utils import is_multimodal_model +from transformers import ( + AutoConfig, + AutoProcessor, + AutoTokenizer, + PreTrainedTokenizer, + PreTrainedTokenizerFast, +) + + +def download_from_hf(model_path: str): + if os.path.exists(model_path): + return model_path + + return snapshot_download(model_path, allow_patterns=["*.json", "*.bin", "*.model"]) + + +def get_config_json(model_path: str): + with open(os.path.join(model_path, "config.json")) as f: + config = json.load(f) + return config + + +def get_config(model: str, trust_remote_code: bool, revision: Optional[str] = None): + config = AutoConfig.from_pretrained( + model, trust_remote_code=trust_remote_code, revision=revision + ) + return config + + +# Models don't use the same configuration key for determining the maximum +# context length. Store them here so we can sanely check them. +# NOTE: The ordering here is important. Some models have two of these and we +# have a preference for which value gets used. +CONTEXT_LENGTH_KEYS = [ + "max_sequence_length", + "seq_length", + "max_position_embeddings", + "max_seq_len", + "model_max_length", +] + + +def get_context_length(config): + """Get the context length of a model from a huggingface model config.""" + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling: + rope_scaling_factor = config.rope_scaling["factor"] + else: + rope_scaling_factor = 1 + + for key in CONTEXT_LENGTH_KEYS: + val = getattr(config, key, None) + if val is not None: + return int(rope_scaling_factor * val) + return 2048 + + +# A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file. +_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" + + +def get_tokenizer( + tokenizer_name: str, + *args, + tokenizer_mode: str = "auto", + trust_remote_code: bool = False, + tokenizer_revision: Optional[str] = None, + **kwargs, +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + """Gets a tokenizer for the given model name via Huggingface.""" + if is_multimodal_model(tokenizer_name): + processor = get_processor( + tokenizer_name, + *args, + trust_remote_code=trust_remote_code, + tokenizer_revision=tokenizer_revision, + **kwargs, + ) + tokenizer = processor.tokenizer + return tokenizer + + if tokenizer_mode == "slow": + if kwargs.get("use_fast", False): + raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") + kwargs["use_fast"] = False + + if ( + "llama" in tokenizer_name.lower() + and kwargs.get("use_fast", True) + and tokenizer_name != _FAST_LLAMA_TOKENIZER + ): + pass + # warnings.warn( + # "For some LLaMA V1 models, initializing the fast tokenizer may " + # "take a long time. To reduce the initialization time, consider " + # f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " + # "tokenizer." + # ) + try: + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, + *args, + trust_remote_code=trust_remote_code, + tokenizer_revision=tokenizer_revision, + **kwargs, + ) + except TypeError as e: + # The LLaMA tokenizer causes a protobuf error in some environments. + err_msg = ( + "Failed to load the tokenizer. If you are using a LLaMA V1 model " + f"consider using '{_FAST_LLAMA_TOKENIZER}' instead of the " + "original tokenizer." + ) + raise RuntimeError(err_msg) from e + except ValueError as e: + # If the error pertains to the tokenizer class not existing or not + # currently being imported, suggest using the --trust-remote-code flag. + if not trust_remote_code and ( + "does not exist or is not currently imported." in str(e) + or "requires you to execute the tokenizer file" in str(e) + ): + err_msg = ( + "Failed to load the tokenizer. If the tokenizer is a custom " + "tokenizer not yet available in the HuggingFace transformers " + "library, consider setting `trust_remote_code=True` in LLM " + "or using the `--trust-remote-code` flag in the CLI." + ) + raise RuntimeError(err_msg) from e + else: + raise e + + if not isinstance(tokenizer, PreTrainedTokenizerFast): + warnings.warn( + "Using a slow tokenizer. This might cause a significant " + "slowdown. Consider using a fast tokenizer instead." + ) + return tokenizer + + +def get_processor( + tokenizer_name: str, + *args, + tokenizer_mode: str = "auto", + trust_remote_code: bool = False, + tokenizer_revision: Optional[str] = None, + **kwargs, +): + processor = AutoProcessor.from_pretrained( + tokenizer_name, + *args, + trust_remote_code=trust_remote_code, + tokenizer_revision=tokenizer_revision, + **kwargs, + ) + return processor diff --git a/python/sglang/srt/layers/context_flashattention_nopad.py b/python/sglang/srt/layers/context_flashattention_nopad.py new file mode 100644 index 000000000..6159e9a51 --- /dev/null +++ b/python/sglang/srt/layers/context_flashattention_nopad.py @@ -0,0 +1,181 @@ +# Adapted from +# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1 +import torch +import triton +import triton.language as tl +from sglang.srt.utils import wrap_kernel_launcher + + +@triton.jit +def _fwd_kernel( + Q, + K, + V, + sm_scale, + B_Start_Loc, + B_Seqlen, + Out, + stride_qbs, + stride_qh, + stride_kbs, + stride_kh, + stride_vbs, + stride_vh, + stride_obs, + stride_oh, + kv_group_num: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // kv_group_num + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] + ) + off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] + off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] + + q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + + k_ptrs = K + off_k + v_ptrs = V + off_v + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load( + k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, + other=0.0, + ) + # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load( + v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, + other=0.0, + ) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_d[None, :] + ) + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) + + +cached_kernel = None + + +def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): + BLOCK = 128 + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + + sm_scale = 1.0 / (Lq**0.5) + batch, head = b_seq_len.shape[0], q.shape[1] + kv_group_num = q.shape[1] // k.shape[1] + + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) + num_warps = 4 if Lk <= 64 else 8 + + global cached_kernel + if cached_kernel: + cached_kernel( + grid, + num_warps, + q, + k, + v, + sm_scale, + b_start_loc, + b_seq_len, + o, + q.stride(0), + q.stride(1), + k.stride(0), + k.stride(1), + v.stride(0), + v.stride(1), + o.stride(0), + o.stride(1), + ) + return + + _fwd_kernel[grid]( + q, + k, + v, + sm_scale, + b_start_loc, + b_seq_len, + o, + q.stride(0), + q.stride(1), + k.stride(0), + k.stride(1), + v.stride(0), + v.stride(1), + o.stride(0), + o.stride(1), + kv_group_num=kv_group_num, + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + cached_kernel = wrap_kernel_launcher(_fwd_kernel) diff --git a/python/sglang/srt/layers/extend_attention.py b/python/sglang/srt/layers/extend_attention.py new file mode 100644 index 000000000..18f403ae6 --- /dev/null +++ b/python/sglang/srt/layers/extend_attention.py @@ -0,0 +1,371 @@ +import torch +import triton +import triton.language as tl +from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd + + +@triton.jit +def _fwd_kernel( + Q_Extend, + K_Extend, + V_Extend, + O_Extend, + K_Buffer, + V_Buffer, + Req_to_tokens, + B_req_idx, + B_Seq_Len, + B_Start_Loc_Extend, + B_Seq_Len_Extend, + sm_scale, + kv_group_num, + stride_qbs, + stride_qh, + stride_kbs, + stride_kh, + stride_vbs, + stride_vh, + stride_obs, + stride_oh, + stride_buf_kbs, + stride_buf_kh, + stride_buf_vbs, + stride_buf_vh, + stride_req_to_tokens_b, + BLOCK_DMODEL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_seq = tl.program_id(0) + cur_head = tl.program_id(1) + cur_block_m = tl.program_id(2) + cur_kv_head = cur_head // kv_group_num + + cur_seq_len = tl.load(B_Seq_Len + cur_seq) + cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq) + cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend + + cur_seq_prefix_start_in_loc = 0 + cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq) + cur_batch_req_idx = tl.load(B_req_idx + cur_seq) + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = tl.arange(0, BLOCK_M) + mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend + offs_q = ( + (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None]) + * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] + ) + q = tl.load(Q_Extend + offs_q, mask=mask_m[:, None], other=0.0) + + # stage1: compute scores with prefix + offs_n = tl.arange(0, BLOCK_N) + + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + deno = tl.zeros([BLOCK_M], dtype=tl.float32) + e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + + for start_n in range(0, cur_seq_len_prefix, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + mask_n = (start_n + offs_n) < cur_seq_len_prefix + offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + ( + cur_seq_prefix_start_in_loc + start_n + offs_n + ) + offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0) + + # load k in transposed way + offs_buf_k = ( + offs_kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_d[:, None] + ) + k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf")) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + deno = deno * re_scale + tl.sum(p, 1) + + offs_buf_v = ( + offs_kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + + offs_d[None, :] + ) + v = tl.load(V_Buffer + offs_buf_v, mask=mask_n[:, None], other=0.0) + p = p.to(v.dtype) + acc = acc * re_scale[:, None] + tl.dot(p, v) + + e_max = n_e_max + + # stage2: compute the trianlge part + + cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M) + for start_n in range(0, cur_block_m_end, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + mask_n = (start_n + offs_n) < cur_block_m_end + + # load k in transposed way + offs_k = ( + (cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs + + cur_kv_head * stride_kh + + offs_d[:, None] + ) + k = tl.load(K_Extend + offs_k, mask=mask_n[None, :], other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= ( + start_n + offs_n[None, :] + ) + mask_causual &= mask_m[:, None] & mask_n[None, :] + qk = tl.where(mask_causual, qk, float("-inf")) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + deno = deno * re_scale + tl.sum(p, 1) + + offs_v = ( + (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs + + cur_kv_head * stride_vh + + offs_d[None, :] + ) + v = tl.load(V_Extend + offs_v, mask=mask_n[:, None], other=0.0) + p = p.to(v.dtype) + acc = acc * re_scale[:, None] + tl.dot(p, v) + + e_max = n_e_max + + offs_o = ( + (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None]) + * stride_obs + + cur_head * stride_oh + + offs_d[None, :] + ) + tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None]) + + +def extend_attention_fwd( + q_extend, + k_extend, + v_extend, + o_extend, + k_buffer, + v_buffer, + req_to_tokens, + b_req_idx, + b_start_loc, + b_seq_len, + b_seq_len_prefix, + b_start_loc_extend, + b_seq_len_extend, + max_len_in_batch, + max_len_extend, +): + """ + q_extend, k_extend, v_extend, o_extend: contiguous tensors + + k_buffer, v_buffer: (prefix + extend) tensors in mem_manager + """ + BLOCK_M, BLOCK_N = 128, 128 + Lq, Lk, Lv, Lo = ( + q_extend.shape[-1], + k_extend.shape[-1], + v_extend.shape[-1], + o_extend.shape[-1], + ) + assert Lq == Lk and Lk == Lv and Lv == Lo + assert Lq in {16, 32, 64, 128} + + sm_scale = 1.0 / (Lq**0.5) + batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1] + kv_group_num = q_extend.shape[1] // k_extend.shape[1] + + grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M)) + num_warps = 4 if Lk <= 64 else 8 + num_stages = 1 + + _fwd_kernel[grid]( + q_extend, + k_extend, + v_extend, + o_extend, + k_buffer, + v_buffer, + req_to_tokens, + b_req_idx, + b_seq_len, + b_start_loc_extend, + b_seq_len_extend, + sm_scale, + kv_group_num, + q_extend.stride(0), + q_extend.stride(1), + k_extend.stride(0), + k_extend.stride(1), + v_extend.stride(0), + v_extend.stride(1), + o_extend.stride(0), + o_extend.stride(1), + k_buffer.stride(0), + k_buffer.stride(1), + v_buffer.stride(0), + v_buffer.stride(1), + req_to_tokens.stride(0), + BLOCK_DMODEL=Lq, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + num_warps=num_warps, + num_stages=num_stages, + ) + + +def redundant_attention( + q_extend, + k_extend, + v_extend, + o_extend, + k_buffer, + v_buffer, + req_to_tokens, + b_req_idx, + b_start_loc, + b_seq_len, + b_seq_len_prefix, + max_len_in_batch, +): + total_token_num = k_buffer.shape[0] + B, H_Q, D = b_req_idx.shape[0], q_extend.shape[-2], q_extend.shape[-1] + q_buffer = torch.empty( + (total_token_num, H_Q, D), dtype=q_extend.dtype, device=q_extend.device + ) + + pt = 0 + for i in range(B): + cur_seq_len_extend = b_seq_len[i] - b_seq_len_prefix[i] + pl, pr = b_start_loc[i] + b_seq_len_prefix[i], b_start_loc[i] + b_seq_len[i] + q_buffer[pl:pr] = q_extend[pt : pt + cur_seq_len_extend] + pt += cur_seq_len_extend + + o_buffer = torch.empty_like(q_buffer) + context_attention_fwd( + q_buffer, k_buffer, v_buffer, o_buffer, b_start_loc, b_seq_len, max_len_in_batch + ) + + pt = 0 + for i in range(B): + cur_seq_len_extend = b_seq_len[i] - b_seq_len_prefix[i] + pl, pr = b_start_loc[i] + b_seq_len_prefix[i], b_start_loc[i] + b_seq_len[i] + o_extend[pt : pt + cur_seq_len_extend] = o_buffer[pl:pr] + pt += cur_seq_len_extend + + +def test(): + torch.manual_seed(0) + + B, N_CTX, H_Q, H_KV, D = 19, 12331, 12, 4, 128 + dtype = torch.float16 + + b_seq_len_prefix = torch.randint( + 1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda" + ) + b_seq_len_extend = torch.randint( + 1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda" + ) + b_seq_len = b_seq_len_prefix + b_seq_len_extend + max_len_in_batch = torch.max(b_seq_len, 0)[0].item() + + b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda") + req_to_tokens = torch.empty((B, max_len_in_batch), dtype=torch.int32, device="cuda") + b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda") + b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0) + b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda") + b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) + for i in range(B): + req_to_tokens[i, : b_seq_len[i]] = torch.arange( + b_start_loc[i], b_start_loc[i] + b_seq_len[i] + ) + + total_token_num = torch.sum(b_seq_len).item() + extend_token_num = torch.sum(b_seq_len_extend).item() + k_buffer = torch.empty( + (total_token_num, H_KV, D), dtype=dtype, device="cuda" + ).normal_(mean=0.1, std=0.2) + v_buffer = torch.empty( + (total_token_num, H_KV, D), dtype=dtype, device="cuda" + ).normal_(mean=0.1, std=0.2) + + k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda") + v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda") + q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") + for i in range(B): + extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i] + extend_end_in_buffer = b_start_loc[i] + b_seq_len[i] + extend_start = b_start_loc_extend[i] + extend_end = b_start_loc_extend[i] + b_seq_len_extend[i] + k_extend[extend_start:extend_end] = k_buffer[ + extend_start_in_buffer:extend_end_in_buffer + ] + v_extend[extend_start:extend_end] = v_buffer[ + extend_start_in_buffer:extend_end_in_buffer + ] + q_extend[extend_start:extend_end] = torch.empty( + (b_seq_len_extend[i], H_Q, D), dtype=dtype, device="cuda" + ).normal_(mean=0.1, std=0.2) + + o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") + o_redundant = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") + + b_seq_len_extend = b_seq_len - b_seq_len_prefix + b_start_loc_extend = torch.zeros_like(b_seq_len) + b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) + max_len_extend = torch.max(b_seq_len_extend, 0)[0].item() + extend_attention_fwd( + q_extend, + k_extend, + v_extend, + o_extend, + k_buffer, + v_buffer, + req_to_tokens, + b_req_idx, + b_start_loc, + b_seq_len, + b_seq_len_prefix, + b_start_loc_extend, + b_seq_len_extend, + max_len_in_batch, + max_len_extend, + ) + + redundant_attention( + q_extend, + k_extend, + v_extend, + o_redundant, + k_buffer, + v_buffer, + req_to_tokens, + b_req_idx, + b_start_loc, + b_seq_len, + b_seq_len_prefix, + max_len_in_batch, + ) + + print("Mean: ", torch.mean(torch.abs(o_extend - o_redundant))) + print("Max: ", torch.max(torch.abs(o_extend - o_redundant))) + + assert torch.allclose(o_extend, o_redundant, rtol=1e-2) + + +if __name__ == "__main__": + test() diff --git a/python/sglang/srt/layers/get_selected_logprob.py b/python/sglang/srt/layers/get_selected_logprob.py new file mode 100644 index 000000000..60e5b3ba2 --- /dev/null +++ b/python/sglang/srt/layers/get_selected_logprob.py @@ -0,0 +1,79 @@ +import torch +import triton +import triton.language as tl +from sglang.srt.utils import wrap_kernel_launcher + + +@triton.jit +def _fwd_segmented_gather( + all_logits, + len_add_1, + cum_len, + input_ids, + logprobs, + max_seq_len, + voc_size: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + cur_req = tl.program_id(0) + cur_l = tl.load(len_add_1 + cur_req) + cum_l = tl.load(cum_len + cur_req) + + for i in range(0, (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE): + off = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = off < cur_l - 1 + + idx = tl.load(input_ids + cum_l - cur_l + off + 1, mask=mask) + data = tl.load(all_logits + (cum_l - cur_l + off) * voc_size + idx, mask=mask) + tl.store(logprobs + cum_l - cur_l - cur_req + off, data, mask=mask) + + +cached_kernel = None + + +def get_selected_logprob(all_logits, len_add_1, input_ids, logprobs): + cum_len = torch.cumsum(len_add_1, dtype=torch.int32, dim=0) + voc_size = all_logits.shape[1] + grid = (len_add_1.shape[0], 1, 1) + max_seq_len = len_add_1.max().item() + + global cached_kernel + if cached_kernel: + cached_kernel( + grid, + 4, + all_logits, + len_add_1, + cum_len, + input_ids, + logprobs, + max_seq_len, + ) + return + + _fwd_segmented_gather[grid]( + all_logits, + len_add_1, + cum_len, + input_ids, + logprobs, + max_seq_len, + voc_size, + BLOCK_SIZE=128, + ) + cached_kernel = wrap_kernel_launcher(_fwd_segmented_gather) + + +if __name__ == "__main__": + all_logits = torch.tensor( + # s s s + [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]], + dtype=torch.float32, + device="cuda", + ) + len_add_1 = torch.tensor([2, 3], dtype=torch.int32, device="cuda") + input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda") + logprobs = torch.empty((3), dtype=torch.float32, device="cuda") + get_selected_logprobs(all_logits, len_add_1, input_ids, logprobs) + print(logprobs) + # assert logprobs == [2, 2, 4] diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py new file mode 100644 index 000000000..315d71869 --- /dev/null +++ b/python/sglang/srt/layers/logits_processor.py @@ -0,0 +1,77 @@ +import torch +from sglang.srt.layers.get_selected_logprob import get_selected_logprob +from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata +from torch import nn +from vllm.model_executor.parallel_utils.communication_op import ( + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, +) + + +class LogitsProcessor(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.tp_size = get_tensor_model_parallel_world_size() + + def forward(self, input_ids, hidden_states, weight, input_metadata): + if not input_metadata.return_normalized_logprob: + if input_metadata.forward_mode == ForwardMode.DECODE: + last_hidden = hidden_states + else: + last_index = ( + torch.cumsum( + input_metadata.seq_lens - input_metadata.prefix_lens, + dim=0, + dtype=torch.long, + ) + - 1 + ) + last_hidden = hidden_states[last_index] + hidden_states = None + + last_logits = torch.matmul(last_hidden, weight.T) + if self.tp_size > 1: + last_logits = tensor_model_parallel_all_gather(last_logits) + last_logits = last_logits[:, : self.config.vocab_size] + return last_logits, None + else: + assert input_metadata.forward_mode != ForwardMode.DECODE + last_index = ( + torch.cumsum( + input_metadata.seq_lens - input_metadata.prefix_lens, + dim=0, + dtype=torch.long, + ) + - 1 + ) + + logits = torch.matmul(hidden_states, weight.T) + if self.tp_size > 1: + logits = tensor_model_parallel_all_gather(logits) + logits = logits[:, : self.config.vocab_size] + all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6) + + normalized_logprobs = compute_normalized_logprobs( + all_logprobs, + input_metadata.seq_lens - input_metadata.prefix_lens, + input_ids, + ) + + last_logits = logits[last_index] + return last_logits, normalized_logprobs + + +def compute_normalized_logprobs(all_logprobs, len_add_1, input_ids): + # assert all_logprobs.shape[0] == torch.sum(len_add_1) == input_ids.shape[0] + logprobs = torch.zeros( + (all_logprobs.shape[0] - len_add_1.shape[0]), dtype=torch.float32, device="cuda" + ) + get_selected_logprob(all_logprobs, len_add_1, input_ids, logprobs) + cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32) + end = torch.cumsum(len_add_1.sub_(1), dim=0) + start = torch.cat((torch.tensor([0], device="cuda"), end[:-1]), 0) + end.sub_(1) + sum_logp = cumsum[end] - cumsum[start] + logprobs[start] + res = sum_logp / len_add_1 + return res diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py new file mode 100644 index 000000000..41bc23b43 --- /dev/null +++ b/python/sglang/srt/layers/radix_attention.py @@ -0,0 +1,158 @@ +from typing import List + +import torch +from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd +from sglang.srt.layers.extend_attention import extend_attention_fwd +from sglang.srt.layers.token_attention import token_attention_fwd +from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata +from torch import nn +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) + + +class RadixAttention(nn.Module): + def __init__( + self, + num_heads, + head_dim, + scaling, + num_kv_heads, + layer_id, + ): + super().__init__() + + self.tp_q_head_num = num_heads + self.tp_k_head_num = num_kv_heads + self.tp_v_head_num = num_kv_heads + self.head_dim = head_dim + self.layer_id = layer_id + + from sglang.srt.managers.router.model_runner import global_model_mode + + self.use_flashinfer = "flashinfer" in global_model_mode + + if self.use_flashinfer: + self.prefill_forward = self.prefill_forward_flashinfer + self.extend_forward = self.prefill_forward_flashinfer + self.decode_forward = self.decode_forward_flashinfer + else: + self.prefill_forward = self.prefill_forward_triton + self.extend_forward = self.extend_forward_triton + self.decode_forward = self.decode_forward_triton + + def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata): + o = torch.empty_like(q) + + context_attention_fwd( + q.view(-1, self.tp_q_head_num, self.head_dim), + k, + v, + o.view(-1, self.tp_q_head_num, self.head_dim), + input_metadata.start_loc, + input_metadata.seq_lens, + input_metadata.max_seq_len, + ) + self.store_kv_cache(k, v, input_metadata) + + return o + + def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata): + o = torch.empty_like(q) + self.store_kv_cache(k, v, input_metadata) + + extend_attention_fwd( + q.view(-1, self.tp_q_head_num, self.head_dim), + k.contiguous(), + v.contiguous(), + o.view(-1, self.tp_q_head_num, self.head_dim), + input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id), + input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id), + input_metadata.req_to_token_pool.req_to_token, + input_metadata.req_pool_indices, + input_metadata.start_loc, + input_metadata.seq_lens, + input_metadata.prefix_lens, + input_metadata.extend_start_loc, + input_metadata.extend_seq_lens, + input_metadata.max_seq_len, + input_metadata.max_extend_len, + ) + + return o + + def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata): + o = torch.empty_like(q) + self.store_kv_cache(k, v, input_metadata) + + token_attention_fwd( + q.view(-1, self.tp_q_head_num, self.head_dim), + input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id), + input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id), + o.view(-1, self.tp_q_head_num, self.head_dim), + input_metadata.req_to_token_pool.req_to_token, + input_metadata.req_pool_indices, + input_metadata.start_loc, + input_metadata.seq_lens, + input_metadata.max_seq_len, + input_metadata.other_kv_index, + input_metadata.total_num_tokens, + ) + + return o + + def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata): + self.store_kv_cache(k, v, input_metadata) + + o = input_metadata.prefill_wrapper.forward( + q.contiguous().view(-1, self.tp_q_head_num, self.head_dim), + input_metadata.qo_indptr, + input_metadata.token_to_kv_pool.kv_data[self.layer_id], + input_metadata.kv_indptr, + input_metadata.kv_indices, + input_metadata.kv_last_page_len, + allow_fp16_qk_reduction=True, + ) + + return o.view(-1, self.tp_q_head_num * self.head_dim) + + def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata): + self.store_kv_cache(k, v, input_metadata) + + o = input_metadata.decode_wrapper.forward( + q.contiguous().view(-1, self.tp_q_head_num, self.head_dim), + input_metadata.token_to_kv_pool.kv_data[self.layer_id], + input_metadata.kv_indptr, + input_metadata.kv_indices, + input_metadata.kv_last_page_len, + ) + + return o.view(-1, self.tp_q_head_num * self.head_dim) + + def forward(self, q, k, v, input_metadata: InputMetadata): + k = k.view(-1, self.tp_k_head_num, self.head_dim) + v = v.view(-1, self.tp_v_head_num, self.head_dim) + + if input_metadata.forward_mode == ForwardMode.PREFILL: + return self.prefill_forward(q, k, v, input_metadata) + elif input_metadata.forward_mode == ForwardMode.EXTEND: + return self.extend_forward(q, k, v, input_metadata) + elif input_metadata.forward_mode == ForwardMode.DECODE: + return self.decode_forward(q, k, v, input_metadata) + + def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata): + key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id) + value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id) + if input_metadata.out_cache_loc is not None: + key_buffer[input_metadata.out_cache_loc] = cache_k + value_buffer[input_metadata.out_cache_loc] = cache_v + elif input_metadata.out_cache_cont_start is not None: + key_buffer[ + input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end + ] = cache_k + value_buffer[ + input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end + ] = cache_v + else: + raise RuntimeError() diff --git a/python/sglang/srt/layers/token_attention.py b/python/sglang/srt/layers/token_attention.py new file mode 100644 index 000000000..8ac4ed959 --- /dev/null +++ b/python/sglang/srt/layers/token_attention.py @@ -0,0 +1,324 @@ +# Adapted from +# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py +# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py +import torch +import triton +import triton.language as tl +from sglang.srt.utils import wrap_kernel_launcher + + +@triton.jit +def _fwd_kernel_stage1( + Q, + K_Buffer, + sm_scale, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, + Att_Out, + stride_req_to_tokens_b, + stride_qbs, + stride_qh, + stride_buf_kbs, + stride_buf_kh, + att_stride_h, + kv_group_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_n = tl.program_id(2) + + cur_kv_head = cur_head // kv_group_num + + offs_d = tl.arange(0, BLOCK_DMODEL) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + + cur_batch_start_index = 0 + cur_batch_end_index = cur_batch_seq_len + + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d + + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + + block_stard_index = start_n * BLOCK_N + block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0) + + for start_mark in range(0, block_mask, 1): + q = tl.load(Q + off_q + start_mark) + offs_n_new = cur_batch_start_index + offs_n + k_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, + mask=offs_n_new < cur_batch_end_index, + other=0, + ) + offs_buf_k = ( + k_loc[:, None] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_d[None, :] + ) + k = tl.load( + K_Buffer + offs_buf_k, + mask=offs_n_new[:, None] < cur_batch_end_index, + other=0.0, + ) + att_value = tl.sum(q[None, :] * k, 1) + att_value *= sm_scale + off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) + tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index) + + +@triton.jit +def _fwd_kernel_stage2( + Logics, + V_Buffer, + Out, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, + stride_logic_h, + stride_buf_vbs, + stride_buf_vh, + stride_obs, + stride_oh, + stride_req_to_token_b, + other_kv_index, # To fix a NAN issue + kv_group_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_kv_head = cur_head // kv_group_num + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch) + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + + offs_buf_v = cur_kv_head * stride_buf_vh + offs_d[None, :] + v_ptrs = V_Buffer + offs_buf_v + + e_max = float("-inf") + e_sum = 0.0 + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, cur_batch_seq_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + v_index = tl.load( + Req_to_tokens + + cur_batch_req_idx * stride_req_to_token_b + + (start_n + offs_n), + mask=(start_n + offs_n) < cur_batch_seq_len, + other=other_kv_index, + ) + + qk = tl.load( + Logics + + cur_head * stride_logic_h + + (cur_batch_start_loc + start_n + offs_n), + mask=start_n + offs_n < cur_batch_seq_len, + other=float("-inf"), + ) + + n_e_max = tl.maximum(tl.max(qk, 0), e_max) + old_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max) + e_sum = e_sum * old_scale + tl.sum(p, 0) + v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs) + acc = acc * old_scale + tl.sum(p[:, None] * v, 0) + e_max = n_e_max + + acc = acc / e_sum + off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d + out_ptrs = Out + off_o + tl.store(out_ptrs, acc) + + +cached_kernel_stage1 = None +cached_kernel_stage2 = None + + +def _token_att_m_fwd( + q, + k_buffer, + att_out, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, + max_len_in_batch, +): + BLOCK = 32 + # shape constraints + Lq, Lk = q.shape[-1], k_buffer.shape[-1] + assert Lq == Lk + assert Lk in {16, 32, 64, 128} + sm_scale = 1.0 / (Lk**0.5) + + batch, head_num = B_req_idx.shape[0], q.shape[1] + + grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK)) + kv_group_num = q.shape[1] // k_buffer.shape[1] + + if kv_group_num == 1: + num_warps = 4 + else: + num_warps = 2 + + global cached_kernel_stage1 + if cached_kernel_stage1: + cached_kernel_stage1( + grid, + num_warps, + q, + k_buffer, + sm_scale, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, + att_out, + Req_to_tokens.stride(0), + q.stride(0), + q.stride(1), + k_buffer.stride(0), + k_buffer.stride(1), + att_out.stride(0), + ) + return + + _fwd_kernel_stage1[grid]( + q, + k_buffer, + sm_scale, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, + att_out, + Req_to_tokens.stride(0), + q.stride(0), + q.stride(1), + k_buffer.stride(0), + k_buffer.stride(1), + att_out.stride(0), + kv_group_num=kv_group_num, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + cached_kernel_stage1 = wrap_kernel_launcher(_fwd_kernel_stage1) + + +def _token_softmax_reducev_fwd( + logics, + v_buffer, + o, + req_to_tokens, + b_req_idx, + b_start_loc, + b_seq_len, + other_kv_index, +): + BLOCK = 64 + batch, head = b_seq_len.shape[0], logics.shape[0] + grid = (batch, head, 1) + kv_group_num = logics.shape[0] // v_buffer.shape[1] + + num_warps = 1 + + global cached_kernel_stage2 + if cached_kernel_stage2: + cached_kernel_stage2( + grid, + num_warps, + logics, + v_buffer, + o, + req_to_tokens, + b_req_idx, + b_start_loc, + b_seq_len, + logics.stride(0), + v_buffer.stride(0), + v_buffer.stride(1), + o.stride(0), + o.stride(1), + req_to_tokens.stride(0), + other_kv_index, + ) + return + + _fwd_kernel_stage2[grid]( + logics, + v_buffer, + o, + req_to_tokens, + b_req_idx, + b_start_loc, + b_seq_len, + logics.stride(0), + v_buffer.stride(0), + v_buffer.stride(1), + o.stride(0), + o.stride(1), + req_to_tokens.stride(0), + other_kv_index, + kv_group_num=kv_group_num, + BLOCK_DMODEL=v_buffer.shape[-1], + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=3, + ) + cached_kernel_stage2 = wrap_kernel_launcher(_fwd_kernel_stage2) + + +def token_attention_fwd( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_req_idx, + b_start_loc, + b_seq_len, + max_len_in_batch, + other_kv_index, + total_num_tokens, + att_m=None, +): + if att_m is None: + att_m = torch.empty( + (q.shape[-2], total_num_tokens), dtype=q.dtype, device="cuda" + ) + + _token_att_m_fwd( + q, + k_buffer, + att_m, + req_to_token, + b_req_idx, + b_start_loc, + b_seq_len, + max_len_in_batch, + ) + _token_softmax_reducev_fwd( + att_m, + v_buffer, + o, + req_to_token, + b_req_idx, + b_start_loc, + b_seq_len, + other_kv_index, + ) diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py new file mode 100644 index 000000000..27fdeb749 --- /dev/null +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -0,0 +1,85 @@ +import asyncio + +import uvloop +import zmq +import zmq.asyncio +from sglang.srt.hf_transformers_utils import get_tokenizer +from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut +from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.utils import get_exception_traceback + +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + + +class DetokenizerManager: + def __init__( + self, + server_args: ServerArgs, + port_args: PortArgs, + ): + context = zmq.asyncio.Context(2) + self.recv_from_router = context.socket(zmq.PULL) + self.recv_from_router.bind(f"tcp://127.0.0.1:{port_args.detokenizer_port}") + + self.send_to_tokenizer = context.socket(zmq.PUSH) + self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{port_args.tokenizer_port}") + + self.tokenizer = get_tokenizer( + server_args.tokenizer_path, + tokenizer_mode=server_args.tokenizer_mode, + trust_remote_code=server_args.trust_remote_code, + ) + + async def handle_loop(self): + while True: + recv_obj = await self.recv_from_router.recv_pyobj() + + if isinstance(recv_obj, BatchTokenIDOut): + output_tokens = recv_obj.output_tokens + + # TODO(lmzheng): handle skip_special_tokens per request + output_strs = self.tokenizer.batch_decode( + output_tokens, + skip_special_tokens=recv_obj.skip_special_tokens[0], + ) + + # Trim stop str + # TODO(lmzheng): handle the case where multiple stop strs are hit + for i in range(len(output_strs)): + if recv_obj.hit_stop_str[i] is not None: + pos = output_strs[i].find(recv_obj.hit_stop_str[i]) + if pos != -1: + output_strs[i] = output_strs[i][:pos] + + if len(output_tokens[i]) > 0: + first_token = self.tokenizer.convert_ids_to_tokens( + int(output_tokens[i][0]) + ) + if first_token.startswith("▁"): + output_strs[i] = " " + output_strs[i] + + self.send_to_tokenizer.send_pyobj( + BatchStrOut( + recv_obj.rids, + output_strs, + recv_obj.meta_info, + recv_obj.finished, + ) + ) + else: + raise ValueError(f"Invalid object: {recv_obj}") + + +def start_detokenizer_process( + server_args: ServerArgs, + port_args: PortArgs, + pipe_writer, +): + try: + manager = DetokenizerManager(server_args, port_args) + except Exception as e: + pipe_writer.send(get_exception_traceback()) + raise + pipe_writer.send("init ok") + loop = asyncio.get_event_loop() + loop.run_until_complete(manager.handle_loop()) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py new file mode 100644 index 000000000..a6dc1a380 --- /dev/null +++ b/python/sglang/srt/managers/io_struct.py @@ -0,0 +1,88 @@ +import uuid +from dataclasses import dataclass +from typing import Dict, List, Optional, Union + +from sglang.srt.sampling_params import SamplingParams + + +@dataclass +class GenerateReqInput: + text: Union[List[str], str] + image_data: Optional[Union[List[str], str]] = None + sampling_params: Union[List[Dict], Dict] = None + rid: Optional[Union[List[str], str]] = None + return_normalized_logprob: Optional[Union[List[bool], bool]] = None + normalized_logprob_start_len: Optional[Union[List[int], int]] = None + stream: bool = False + + def post_init(self): + is_single = isinstance(self.text, str) + + if is_single: + if self.sampling_params is None: + self.sampling_params = {} + if self.rid is None: + self.rid = uuid.uuid4().hex + if self.return_normalized_logprob is None: + self.return_normalized_logprob = False + if self.normalized_logprob_start_len is None: + self.normalized_logprob_start_len = 0 + else: + num = len(self.text) + + if self.image_data is None: + self.image_data = [None] * num + elif not isinstance(self.image_data, list): + self.image_data = [self.image_data] * num + + if self.sampling_params is None: + self.sampling_params = [{}] * num + elif not isinstance(self.sampling_params, list): + self.sampling_params = [self.sampling_params] * num + + if self.rid is None: + self.rid = [uuid.uuid4().hex for _ in range(num)] + else: + assert isinstance(self.rid, list) + + if self.return_normalized_logprob is None: + self.return_normalized_logprob = [False] * num + elif not isinstance(self.return_normalized_logprob, list): + self.return_normalized_logprob = [self.return_normalized_logprob] * num + + if self.normalized_logprob_start_len is None: + self.normalized_logprob_start_len = [0] * num + elif not isinstance(self.normalized_logprob_start_len, list): + self.normalized_logprob_start_len = [ + self.normalized_logprob_start_len + ] * num + + +@dataclass +class TokenizedGenerateReqInput: + rid: str + input_ids: List[int] + pixel_values: List[float] + image_hash: int + sampling_params: SamplingParams + return_normalized_logprob: bool + normalized_logprob_start_len: int + stream: bool + + +@dataclass +class BatchTokenIDOut: + rids: List[str] + output_tokens: List[List[int]] + hit_stop_str: List[Optional[str]] + skip_special_tokens: List[bool] + meta_info: List[Dict] + finished: List[bool] + + +@dataclass +class BatchStrOut: + rids: List[str] + output_str: List[str] + meta_info: List[Dict] + finished: List[bool] diff --git a/python/sglang/srt/managers/openai_protocol.py b/python/sglang/srt/managers/openai_protocol.py new file mode 100644 index 000000000..daa4ac9dc --- /dev/null +++ b/python/sglang/srt/managers/openai_protocol.py @@ -0,0 +1,12 @@ +from dataclasses import dataclass +from typing import Any, List, Optional, Union + + +@dataclass +class CompletionRequest: + prompt: Union[str, List[Any]] + model: str = "default" + temperature: Optional[float] = 0.7 + max_tokens: Optional[int] = 16 + n: Optional[int] = 1 + stop: Optional[Union[str, List[str]]] = None diff --git a/python/sglang/srt/managers/router/infer_batch.py b/python/sglang/srt/managers/router/infer_batch.py new file mode 100644 index 000000000..c98d6d519 --- /dev/null +++ b/python/sglang/srt/managers/router/infer_batch.py @@ -0,0 +1,326 @@ +from enum import Enum, auto +from typing import List + +import numpy as np +import torch +from sglang.srt.managers.router.radix_cache import RadixCache +from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool + + +class ForwardMode(Enum): + PREFILL = auto() + EXTEND = auto() + DECODE = auto() + + +class FinishReason(Enum): + LENGTH = auto() + EOS_TOKEN = auto() + STOP_STR = auto() + + +class Req: + def __init__(self, rid): + self.rid = rid + self.input_ids = [] + self.output_ids = [] + self.pixel_values = None + self.image_offset = 0 + self.sampling_params = None + self.return_normalized_logprob = False + self.normalized_logprob_start_len = 0 + self.stream = False + + self.tokenizer = None + self.finished = False + self.finish_reason = None + self.hit_stop_str = None + + self.adjust_input_len = 0 + self.prefix_indices = [] + + self.normalized_logprob = None + + # for constrained decoding + self.regex_fsm = None + self.regex_fsm_state = None + + def max_new_tokens(self): + return self.sampling_params.max_new_tokens + + def check_finished(self): + if self.finished: + return + + if len(self.output_ids) >= self.sampling_params.max_new_tokens: + self.finished = True + self.finish_reason = FinishReason.LENGTH + return + + if ( + self.output_ids[-1] == self.tokenizer.eos_token_id + and self.sampling_params.ignore_eos == False + ): + self.finished = True + self.finish_reason = FinishReason.EOS_TOKEN + return + + if len(self.sampling_params.stop_strs) > 0: + tail_str = self.tokenizer.decode( + self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :] + ) + + for stop_str in self.sampling_params.stop_strs: + if stop_str in tail_str: + self.finished = True + self.finish_reason = FinishReason.STOP_STR + self.hit_stop_str = stop_str + return + + def __repr__(self): + return f"rid(n={self.rid}, " f"input_ids={self.input_ids}, " + + +class Batch: + def __init__( + self, + reqs: List[Req], + req_to_token_pool: ReqToTokenPool, + token_to_kv_pool: TokenToKVPool, + tree_cache: RadixCache, + ): + self.reqs = reqs + self.req_to_token_pool = req_to_token_pool + self.token_to_kv_pool = token_to_kv_pool + self.tree_cache = tree_cache + + self.return_normalized_logprob = any( + req.return_normalized_logprob for req in reqs + ) + + def is_empty(self): + return len(self.reqs) == 0 + + def init_extend_batch(self, vocab_size: int, int_token_logit_bias: torch.Tensor): + device = "cuda" + bs = len(self.reqs) + reqs = self.reqs + input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs] + prefix_indices = [r.prefix_indices for r in reqs] + + # Handle prefix + flatten_input_ids = [] + extend_lens = [] + prefix_lens = [] + seq_lens = [] + + req_pool_indices = self.req_to_token_pool.alloc(bs) + req_pool_indices_cpu = req_pool_indices.cpu().numpy() + for i in range(bs): + flatten_input_ids.extend(input_ids[i]) + extend_lens.append(len(input_ids[i])) + + if len(prefix_indices[i]) == 0: + prefix_lens.append(0) + else: + prefix_lens.append(len(prefix_indices[i])) + self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][ + : len(prefix_indices[i]) + ] = prefix_indices[i] + + seq_lens.append(prefix_lens[-1] + extend_lens[-1]) + + position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device) + + # Alloc mem + seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens) + extend_num_tokens = seq_lens.sum() - prefix_lens.sum() + out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens) + if out_cache_loc is None: + self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free) + out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens) + + if out_cache_loc is None: + print("Prefill out of memory.") + self.tree_cache.pretty_print() + exit() + + pt = 0 + for i in range(bs): + self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][ + prefix_lens[i] : prefix_lens[i] + extend_lens[i] + ] = out_cache_loc[pt : pt + extend_lens[i]] + pt += extend_lens[i] + + # Handle logit bias + logit_bias = torch.zeros((bs, vocab_size), dtype=torch.float32, device=device) + for i in range(bs): + if reqs[i].sampling_params.dtype == "int": + logit_bias[i] = int_token_logit_bias + + # Set fields + self.input_ids = torch.tensor( + flatten_input_ids, dtype=torch.int32, device=device + ) + self.pixel_values = [r.pixel_values for r in reqs] + self.image_offsets = [ + r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens) + ] + self.req_pool_indices = req_pool_indices + self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device) + self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device) + self.position_ids_offsets = position_ids_offsets + self.extend_num_tokens = extend_num_tokens + self.out_cache_loc = out_cache_loc + + self.temperatures = torch.tensor( + [r.sampling_params.temperature for r in reqs], + dtype=torch.float, + device=device, + ).view(-1, 1) + self.top_ps = torch.tensor( + [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device + ).view(-1, 1) + self.top_ks = torch.tensor( + [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device + ).view(-1, 1) + self.frequency_penalties = torch.tensor( + [r.sampling_params.frequency_penalty for r in reqs], + dtype=torch.float, + device=device, + ) + self.presence_penalties = torch.tensor( + [r.sampling_params.presence_penalty for r in reqs], + dtype=torch.float, + device=device, + ) + self.logit_bias = logit_bias + + def update_for_decode(self, input_ids=None): + if input_ids is None: + input_ids = [ + r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs + ] + self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda") + self.seq_lens.add_(1) + self.prefix_lens = None + + # Alloc mem + bs = len(self.reqs) + alloc_res = self.token_to_kv_pool.alloc_contiguous(bs) + if alloc_res is None: + self.out_cache_loc = self.token_to_kv_pool.alloc(bs) + + if self.out_cache_loc is None: + self.tree_cache.evict(bs, self.token_to_kv_pool.free) + self.out_cache_loc = self.token_to_kv_pool.alloc(bs) + + if self.out_cache_loc is None: + print("Decode out of memory.") + self.tree_cache.pretty_print() + exit() + + self.out_cache_cont_start = None + self.out_cache_cont_end = None + else: + self.out_cache_loc = alloc_res[0] + self.out_cache_cont_start = alloc_res[1] + self.out_cache_cont_end = alloc_res[2] + + self.req_to_token_pool.req_to_token[ + self.req_pool_indices, self.seq_lens - 1 + ] = self.out_cache_loc + + def filter_batch(self, unfinished_indices: List[int]): + self.reqs = [self.reqs[i] for i in unfinished_indices] + new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda") + self.seq_lens = self.seq_lens[new_indices] + self.input_ids = None + self.req_pool_indices = self.req_pool_indices[new_indices] + self.prefix_lens = None + self.position_ids_offsets = self.position_ids_offsets[new_indices] + self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None + + for item in [ + "temperatures", + "top_ps", + "top_ks", + "frequency_penalties", + "presence_penalties", + "logit_bias", + ]: + setattr(self, item, getattr(self, item)[new_indices]) + + def merge(self, other): + self.reqs.extend(other.reqs) + + self.req_pool_indices = torch.concat( + [self.req_pool_indices, other.req_pool_indices] + ) + self.seq_lens = torch.concat([self.seq_lens, other.seq_lens]) + self.prefix_lens = None + self.position_ids_offsets = torch.concat( + [self.position_ids_offsets, other.position_ids_offsets] + ) + self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None + + for item in [ + "temperatures", + "top_ps", + "top_ks", + "frequency_penalties", + "presence_penalties", + "logit_bias", + ]: + setattr( + self, item, torch.concat([getattr(self, item), getattr(other, item)]) + ) + + def sample(self, logits: torch.Tensor): + # Post process logits + logits = logits.contiguous() + logits.div_(self.temperatures) + logits.add_(self.logit_bias) + + has_regex = any(req.regex_fsm is not None for req in self.reqs) + if has_regex: + allowed_mask = torch.empty_like(logits[0], dtype=torch.bool) + for i, req in enumerate(self.reqs): + if req.regex_fsm is not None: + allowed_mask.zero_() + allowed_mask[ + req.regex_fsm.allowed_token_ids(req.regex_fsm_state) + ] = 1 + logits[i].masked_fill_(~allowed_mask, float("-inf")) + + # TODO(lmzheng): apply penalty + probs = torch.softmax(logits, dim=-1) + probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks) + sampled_index = torch.multinomial(probs_sort, num_samples=1) + batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view( + -1 + ) + batch_next_token_probs = torch.gather( + probs_sort, dim=1, index=sampled_index + ).view(-1) + + if has_regex: + batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy() + for i, req in enumerate(self.reqs): + if req.regex_fsm is not None: + req.regex_fsm_state = req.regex_fsm.next_state( + req.regex_fsm_state, batch_next_token_ids_cpu[i] + ) + + return batch_next_token_ids, batch_next_token_probs + + +def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor): + probs_sort, probs_idx = probs.sort(dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0 + probs_sort[ + torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks + ] = 0.0 + probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0]) + return probs_sort, probs_idx diff --git a/python/sglang/srt/managers/router/manager.py b/python/sglang/srt/managers/router/manager.py new file mode 100644 index 000000000..8ac027dfa --- /dev/null +++ b/python/sglang/srt/managers/router/manager.py @@ -0,0 +1,71 @@ +import asyncio +import logging +from typing import List, Tuple + +import uvloop +import zmq +import zmq.asyncio +from sglang.srt.managers.router.model_rpc import ModelRpcClient +from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.utils import get_exception_traceback + +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + + +class RouterManager: + def __init__(self, model_client: ModelRpcClient, port_args: PortArgs): + # Init communication + context = zmq.asyncio.Context(2) + self.recv_from_tokenizer = context.socket(zmq.PULL) + self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}") + + self.send_to_detokenizer = context.socket(zmq.PUSH) + self.send_to_detokenizer.connect( + f"tcp://127.0.0.1:{port_args.detokenizer_port}" + ) + + # Init status + self.model_client = model_client + self.recv_reqs = [] + + async def loop_for_forward(self): + while True: + next_step_input = list(self.recv_reqs) + self.recv_reqs = [] + out_pyobjs = await self.model_client.step(next_step_input) + + for obj in out_pyobjs: + self.send_to_detokenizer.send_pyobj(obj) + + # await for a while to accept input requests + await asyncio.sleep(0.001) + + async def loop_for_recv_requests(self): + while True: + recv_req = await self.recv_from_tokenizer.recv_pyobj() + self.recv_reqs.append(recv_req) + + +def start_router_process( + server_args: ServerArgs, + port_args: PortArgs, + pipe_writer, +): + logging.basicConfig( + level=getattr(logging, server_args.log_level.upper()), + format="%(message)s", + ) + + try: + model_client = ModelRpcClient(server_args, port_args) + router = RouterManager(model_client, port_args) + except Exception: + pipe_writer.send(get_exception_traceback()) + raise + + pipe_writer.send("init ok") + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.create_task(router.loop_for_recv_requests()) + loop.run_until_complete(router.loop_for_forward()) diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py new file mode 100644 index 000000000..78080fcb3 --- /dev/null +++ b/python/sglang/srt/managers/router/model_rpc.py @@ -0,0 +1,497 @@ +import asyncio +import logging +import multiprocessing +import time +from concurrent.futures import ThreadPoolExecutor +from enum import Enum, auto +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import rpyc +import torch +from rpyc.utils.classic import obtain +from rpyc.utils.server import ThreadedServer +from sglang.srt.constrained.fsm_cache import FSMCache +from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer +from sglang.srt.managers.io_struct import BatchTokenIDOut, TokenizedGenerateReqInput +from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req +from sglang.srt.managers.router.model_runner import ModelRunner +from sglang.srt.managers.router.radix_cache import RadixCache +from sglang.srt.managers.router.scheduler import Scheduler +from sglang.srt.model_config import ModelConfig +from sglang.srt.sampling_params import SamplingParams +from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.utils import ( + get_exception_traceback, + get_int_token_logit_bias, + is_multimodal_model, + set_random_seed, +) + +logger = logging.getLogger("model_rpc") + + +class ModelRpcServer(rpyc.Service): + def exposed_init_model( + self, + tp_rank: int, + server_args: ServerArgs, + port_args: PortArgs, + ): + server_args, port_args = [obtain(x) for x in [server_args, port_args]] + + # Copy arguments + self.model_mode = server_args.model_mode + self.tp_rank = tp_rank + self.tp_size = server_args.tp_size + self.schedule_heuristic = server_args.schedule_heuristic + + # Init model and tokenizer + self.model_config = ModelConfig( + server_args.model_path, server_args.trust_remote_code + ) + self.model_runner = ModelRunner( + self.model_config, + server_args.mem_fraction_static, + tp_rank, + server_args.tp_size, + port_args.nccl_port, + server_args.load_format, + server_args.trust_remote_code, + server_args.model_mode, + ) + if is_multimodal_model(server_args.model_path): + self.processor = get_processor( + server_args.tokenizer_path, + tokenizer_mode=server_args.tokenizer_mode, + trust_remote_code=server_args.trust_remote_code, + ) + self.tokenizer = self.processor.tokenizer + else: + self.tokenizer = get_tokenizer( + server_args.tokenizer_path, + tokenizer_mode=server_args.tokenizer_mode, + trust_remote_code=server_args.trust_remote_code, + ) + self.eos_token_id = self.tokenizer.eos_token_id + self.max_total_num_token = self.model_runner.max_total_num_token + self.max_num_running_seq = self.max_total_num_token // 2 + self.max_prefill_num_token = max( + self.model_config.context_len, self.max_total_num_token // 6 + ) + self.int_token_logit_bias = torch.tensor( + get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size) + ) + set_random_seed(server_args.random_seed) + logger.info( + f"Rank {self.tp_rank}: " + f"max_total_num_token={self.max_total_num_token}, " + f"max_prefill_num_token={self.max_prefill_num_token}, " + f"context_len={self.model_config.context_len}, " + f"model_mode={self.model_mode}" + ) + + # Init cache + self.tree_cache = RadixCache(disable="no-cache" in self.model_mode) + self.scheduler = Scheduler( + self.schedule_heuristic, + self.max_num_running_seq, + self.max_prefill_num_token, + self.max_total_num_token, + self.tree_cache, + ) + self.req_to_token_pool = self.model_runner.req_to_token_pool + self.token_to_kv_pool = self.model_runner.token_to_kv_pool + + # Init running status + self.forward_queue: List[Req] = [] + self.running_batch: Batch = None + self.out_pyobjs = [] + self.decode_forward_ct = 0 + self.stream_interval = 2 + + # Init the FSM cache for constrained generation + self.regex_fsm_cache = FSMCache(self.tokenizer) + + def exposed_step(self, recv_reqs): + if self.tp_size != 1: + recv_reqs = obtain(recv_reqs) + + try: + # Recv requests + for recv_req in recv_reqs: + if isinstance(recv_req, TokenizedGenerateReqInput): + self.handle_generate_request(recv_req) + else: + raise ValueError(f"Invalid request: {recv_req}") + + # Forward + self.forward_step() + except Exception: + logger.error("Exception in ModelRpcClient:\n" + get_exception_traceback()) + + # Return results + ret = self.out_pyobjs + self.out_pyobjs = [] + return ret + + @torch.inference_mode() + def forward_step(self): + new_batch = self.get_new_fill_batch() + + if new_batch is not None: + # Run new fill batch + self.forward_fill_batch(new_batch) + + if not new_batch.is_empty(): + if self.running_batch is None: + self.running_batch = new_batch + else: + self.running_batch.merge(new_batch) + else: + # Run decode batch + if self.running_batch is not None: + # Run a few decode batches continuously for reducing overhead + for _ in range(10): + self.forward_decode_batch(self.running_batch) + + if self.running_batch.is_empty(): + self.running_batch = None + break + + if self.running_batch is not None and self.tp_rank == 0: + if self.decode_forward_ct >= 20: + self.decode_forward_ct = 0 + num_used = self.max_total_num_token - ( + self.token_to_kv_pool.available_size() + + self.tree_cache.evictable_size() + ) + logger.info( + f"#running-req: {len(self.running_batch.reqs)}, " + f"#token: {num_used}, " + f"token usage: {num_used / self.max_total_num_token:.2f}, " + f"#queue-req: {len(self.forward_queue)}" + ) + + def handle_generate_request( + self, + recv_req: TokenizedGenerateReqInput, + ): + req = Req(recv_req.rid) + req.input_ids = recv_req.input_ids + req.pixel_values = recv_req.pixel_values + if req.pixel_values is not None: + pad_value = [ + (recv_req.image_hash) % self.model_config.vocab_size, + (recv_req.image_hash >> 16) % self.model_config.vocab_size, + (recv_req.image_hash >> 32) % self.model_config.vocab_size, + (recv_req.image_hash >> 64) % self.model_config.vocab_size, + ] + req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids( + req.input_ids, pad_value + ) + req.sampling_params = recv_req.sampling_params + req.return_normalized_logprob = recv_req.return_normalized_logprob + req.normalized_logprob_start_len = recv_req.normalized_logprob_start_len + req.stream = recv_req.stream + req.tokenizer = self.tokenizer + + # init the regex fsm + if req.sampling_params.regex is not None: + req.regex_fsm_state = 0 + req.regex_fsm = self.regex_fsm_cache.get_fsm(req.sampling_params.regex) + + # Truncate long prompts + req.input_ids = req.input_ids[: self.model_config.context_len - 1] + req.sampling_params.max_new_tokens = min( + req.sampling_params.max_new_tokens, + self.model_config.context_len - 1 - len(req.input_ids), + ) + self.forward_queue.append(req) + + def get_new_fill_batch(self): + if ( + self.running_batch is not None + and len(self.running_batch.reqs) > self.max_num_running_seq + ): + return None + + for req in self.forward_queue: + prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids) + if req.return_normalized_logprob: + prefix_indices = prefix_indices[: req.normalized_logprob_start_len] + req.adjust_input_len = len(req.input_ids) - len(prefix_indices) + req.prefix_indices = prefix_indices + req.last_node = last_node + + # Get priority queue + self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue) + + # Add requests if there is available space + can_run_list = [] + new_batch_total_tokens = 0 + new_batch_input_tokens = 0 + new_batch_prefix_tokens = 0 + + available_size = ( + self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size() + ) + new_ratio = self.scheduler.new_token_estimation_ratio() + if self.running_batch: + available_size -= sum( + [ + (r.max_new_tokens() - len(r.output_ids)) * new_ratio + for r in self.running_batch.reqs + ] + ) + + for req in self.forward_queue: + if req.return_normalized_logprob: + # Need at least two tokens to compute normalized logprob + if req.adjust_input_len < 2: + delta = 2 - req.adjust_input_len + req.adjust_input_len += delta + req.prefix_indices = req.prefix_indices[:-delta] + if req.image_offset is not None: + req.image_offset += delta + if req.adjust_input_len == 0 and req.max_new_tokens() > 0: + # Need at least one token to compute logits + req.adjust_input_len = 1 + req.prefix_indices = req.prefix_indices[:-1] + if req.image_offset is not None: + req.image_offset += 1 + + if ( + req.adjust_input_len + req.max_new_tokens() + new_batch_total_tokens + < available_size + and req.adjust_input_len + new_batch_input_tokens + < self.max_prefill_num_token + ): + delta = self.tree_cache.inc_ref_counter(req.last_node) + available_size += delta + + if not ( + req.adjust_input_len + req.max_new_tokens() + new_batch_total_tokens + < available_size + ): + delta = self.tree_cache.dec_ref_counter(req.last_node) + available_size += delta + else: + self.token_to_kv_pool.add_refs(req.prefix_indices) + can_run_list.append(req) + new_batch_total_tokens += ( + req.adjust_input_len + req.max_new_tokens() + ) + new_batch_input_tokens += req.adjust_input_len + + if len(can_run_list) == 0: + return None + + if self.tp_rank == 0: + logger.info( + f"new fill batch. #seq: {len(can_run_list)}. " + f"#cached_token: {sum(len(x.prefix_indices) for x in can_run_list)}. " + f"#new_token: {new_batch_input_tokens}. " + f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. " + f"#running_req: {0 if self.running_batch is None else len(self.running_batch.reqs)}" + ) + + new_batch = Batch( + can_run_list, + self.req_to_token_pool, + self.token_to_kv_pool, + self.tree_cache, + ) + self.forward_queue = [x for x in self.forward_queue if x not in can_run_list] + return new_batch + + def forward_fill_batch(self, batch: Batch): + # Build batch tensors + batch.init_extend_batch(self.model_config.vocab_size, self.int_token_logit_bias) + if batch.extend_num_tokens != 0: + # Forward + logits, normalized_logprobs = self.model_runner.forward( + batch, ForwardMode.EXTEND, batch.return_normalized_logprob + ) + # print("extend logits", logits) + if normalized_logprobs is not None: + normalized_logprobs = normalized_logprobs.cpu().tolist() + + next_token_ids, next_token_probs = batch.sample(logits) + next_token_ids = next_token_ids.cpu().tolist() + else: + next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs) + normalized_logprobs = None + + # Check finish condition + reqs = batch.reqs + for i in range(len(reqs)): + reqs[i].output_ids = [next_token_ids[i]] + reqs[i].check_finished() + + if normalized_logprobs is not None: + reqs[i].normalized_logprob = normalized_logprobs[i] + + self.handle_finished_requests(batch) + + def forward_decode_batch(self, batch: Batch): + # Update batch tensors + self.decode_forward_ct += 1 + batch.update_for_decode() + + # Forward + logits = self.model_runner.forward(batch, ForwardMode.DECODE) + next_token_ids, next_token_probs = batch.sample(logits) + next_token_ids = next_token_ids.cpu().tolist() + + # Check finish condition + reqs = batch.reqs + for i in range(len(reqs)): + reqs[i].output_ids.append(next_token_ids[i]) + reqs[i].check_finished() + + self.handle_finished_requests(batch) + + def handle_finished_requests(self, batch: Batch): + output_rids = [] + output_tokens = [] + output_hit_stop_str = [] + output_skip_special_tokens = [] + output_meta_info = [] + output_finished = [] + finished_indices = [] + unfinished_indices = [] + for i, req in enumerate(batch.reqs): + if req.finished: + finished_indices.append(i) + else: + unfinished_indices.append(i) + + if req.finished or ( + req.stream and self.decode_forward_ct % self.stream_interval == 0 + ): + output_rids.append(req.rid) + output_tokens.append(req.output_ids) + output_hit_stop_str.append(req.hit_stop_str) + output_skip_special_tokens.append( + req.sampling_params.skip_special_tokens + ) + meta_info = { + "prompt_tokens": len(req.input_ids), + "completion_tokens": len(req.output_ids), + } + if req.return_normalized_logprob: + meta_info["normalized_logprob"] = req.normalized_logprob + output_meta_info.append(meta_info) + output_finished.append(req.finished) + + # Send to detokenizer + if output_rids: + self.out_pyobjs.append( + BatchTokenIDOut( + output_rids, + output_tokens, + output_hit_stop_str, + output_skip_special_tokens, + output_meta_info, + output_finished, + ) + ) + + # Remove finished reqs + if finished_indices: + # Update radix cache + req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist() + for i in finished_indices: + req = batch.reqs[i] + req_pool_idx = req_pool_indices_cpu[i] + token_ids = tuple(req.input_ids + req.output_ids) + seq_len = len(token_ids) - 1 + indices = self.req_to_token_pool.req_to_token[req_pool_idx, :seq_len] + prefix_len = self.tree_cache.insert(token_ids, indices.clone()) + + self.token_to_kv_pool.free(indices[:prefix_len]) + self.req_to_token_pool.free(req_pool_idx) + self.tree_cache.dec_ref_counter(req.last_node) + + # Update batch tensors + if unfinished_indices: + batch.filter_batch(unfinished_indices) + else: + batch.reqs = [] + + +class ModelRpcClient: + def __init__(self, server_args: ServerArgs, port_args: PortArgs): + tp_size = server_args.tp_size + + if tp_size == 1: + # Init model + self.model_server = ModelRpcServer() + self.model_server.exposed_init_model(0, server_args, port_args) + + # Wrap functions + def async_wrap(f): + async def _func(*args, **kwargs): + return f(*args, **kwargs) + + return _func + + self.step = async_wrap(self.model_server.exposed_step) + else: + with ThreadPoolExecutor(tp_size) as executor: + # Launch model processes + rets = executor.map(start_model_process, port_args.model_rpc_ports) + self.model_servers = [x[0] for x in rets] + self.procs = [x[1] for x in rets] + + # Init model + def init_model(i): + return self.model_servers[i].init_model(i, server_args, port_args) + + rets = [obtain(x) for x in executor.map(init_model, range(tp_size))] + + # Wrap functions + def async_wrap(func_name): + fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers] + + async def _func(*args, **kwargs): + tasks = [f(*args, **kwargs) for f in fs] + await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks]) + return obtain(tasks[0].value) + + return _func + + self.step = async_wrap("step") + + +def start_model_process(port): + def _init_service(port): + t = ThreadedServer( + ModelRpcServer(), + port=port, + protocol_config={"allow_pickle": True, "sync_request_timeout": 600}, + ) + t.start() + + proc = multiprocessing.Process(target=_init_service, args=(port,)) + proc.start() + time.sleep(1) + + repeat_count = 0 + while repeat_count < 20: + try: + con = rpyc.connect( + "localhost", + port, + config={"allow_pickle": True, "sync_request_timeout": 600}, + ) + break + except ConnectionRefusedError: + time.sleep(1) + repeat_count += 1 + if repeat_count == 20: + raise RuntimeError("init rpc env error!") + + assert proc.is_alive() + return con.root, proc diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py new file mode 100644 index 000000000..2a42eb362 --- /dev/null +++ b/python/sglang/srt/managers/router/model_runner.py @@ -0,0 +1,458 @@ +from dataclasses import dataclass +from enum import Enum, auto +from typing import List + +import numpy as np +import torch +from sglang.srt.managers.router.infer_batch import Batch, ForwardMode +from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool +from sglang.srt.utils import is_multimodal_model +from sglang.utils import get_available_gpu_memory +from vllm.model_executor.layers.quantization.awq import AWQConfig +from vllm.model_executor.model_loader import _set_default_torch_dtype +from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel + +# for model_mode +global_model_mode: List[str] = [] + + +@dataclass +class InputMetadata: + model_runner: "ModelRunner" + forward_mode: ForwardMode + batch_size: int + total_num_tokens: int + max_seq_len: int + req_pool_indices: torch.Tensor + start_loc: torch.Tensor + seq_lens: torch.Tensor + prefix_lens: torch.Tensor + positions: torch.Tensor + req_to_token_pool: ReqToTokenPool + token_to_kv_pool: TokenToKVPool + + # for extend + extend_seq_lens: torch.Tensor = None + extend_start_loc: torch.Tensor = None + max_extend_len: int = 0 + + out_cache_loc: torch.Tensor = None + out_cache_cont_start: torch.Tensor = None + out_cache_cont_end: torch.Tensor = None + + other_kv_index: torch.Tensor = None + return_normalized_logprob: bool = False + + # for flashinfer + use_flashinfer: bool = False + qo_indptr: torch.Tensor = None + kv_indptr: torch.Tensor = None + kv_indices: torch.Tensor = None + kv_last_page_len: torch.Tensor = None + prefill_wrapper = None + decode_wrapper = None + + def init_flashinfer_args(self, tp_size): + self.kv_indptr = torch.zeros( + (self.batch_size + 1,), dtype=torch.int32, device="cuda" + ) + self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0) + self.kv_indices = torch.cat( + [ + self.req_to_token_pool.req_to_token[ + self.req_pool_indices[i].item(), : self.seq_lens[i].item() + ] + for i in range(self.batch_size) + ], + dim=0, + ).contiguous() + self.kv_last_page_len = torch.ones( + (self.batch_size,), dtype=torch.int32, device="cuda" + ) + + from flashinfer.ops import ( + BatchDecodeWithPagedKVCacheWrapper, + BatchPrefillWithPagedKVCacheWrapper, + ) + + if ( + self.forward_mode == ForwardMode.PREFILL + or self.forward_mode == ForwardMode.EXTEND + ): + self.qo_indptr = torch.zeros( + (self.batch_size + 1,), dtype=torch.int32, device="cuda" + ) + self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0) + self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper() + self.prefill_wrapper.begin_forward( + self.qo_indptr, + self.batch_size, + self.model_runner.model_config.num_attention_heads // tp_size, + self.model_runner.model_config.num_key_value_heads // tp_size, + ) + else: + self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper() + self.decode_wrapper.begin_forward( + self.kv_indptr, + self.kv_last_page_len, + self.batch_size, + self.model_runner.model_config.num_attention_heads // tp_size, + self.model_runner.model_config.num_key_value_heads // tp_size, + self.model_runner.model_config.head_dim, + 1, + "NONE", + "float16", + ) + + def init_extend_args(self): + self.extend_seq_lens = self.seq_lens - self.prefix_lens + self.extend_start_loc = torch.zeros_like(self.seq_lens) + self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], 0) + self.max_extend_len = int(torch.max(self.extend_seq_lens)) + + @classmethod + def create( + cls, + model_runner, + tp_size, + forward_mode, + req_pool_indices, + seq_lens, + prefix_lens, + position_ids_offsets, + out_cache_loc, + out_cache_cont_start=None, + out_cache_cont_end=None, + return_normalized_logprob=False, + ): + batch_size = len(req_pool_indices) + start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") + start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0) + total_num_tokens = int(torch.sum(seq_lens)) + max_seq_len = int(torch.max(seq_lens)) + + if forward_mode == ForwardMode.DECODE: + positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64) + other_kv_index = model_runner.req_to_token_pool.req_to_token[ + req_pool_indices[0], seq_lens[0] - 1 + ].item() + else: + seq_lens_np = seq_lens.cpu().numpy() + prefix_lens_np = prefix_lens.cpu().numpy() + position_ids_offsets_np = position_ids_offsets.cpu().numpy() + positions = torch.tensor( + np.concatenate( + [ + np.arange( + prefix_lens_np[i] + position_ids_offsets_np[i], + seq_lens_np[i] + position_ids_offsets_np[i], + ) + for i in range(batch_size) + ], + axis=0, + ), + device="cuda", + ) + other_kv_index = None + + ret = cls( + model_runner=model_runner, + forward_mode=forward_mode, + batch_size=batch_size, + total_num_tokens=total_num_tokens, + max_seq_len=max_seq_len, + req_pool_indices=req_pool_indices, + start_loc=start_loc, + seq_lens=seq_lens, + prefix_lens=prefix_lens, + positions=positions, + req_to_token_pool=model_runner.req_to_token_pool, + token_to_kv_pool=model_runner.token_to_kv_pool, + out_cache_loc=out_cache_loc, + out_cache_cont_start=out_cache_cont_start, + out_cache_cont_end=out_cache_cont_end, + return_normalized_logprob=return_normalized_logprob, + other_kv_index=other_kv_index, + ) + + if forward_mode == ForwardMode.EXTEND: + ret.init_extend_args() + + ret.use_flashinfer = "flashinfer" in model_runner.model_mode + if ret.use_flashinfer: + ret.init_flashinfer_args(tp_size) + + return ret + + +class ModelRunner: + def __init__( + self, + model_config, + mem_fraction_static, + tp_rank, + tp_size, + nccl_port, + load_format="auto", + trust_remote_code=True, + model_mode: List[str] = (), + ): + self.model_config = model_config + self.mem_fraction_static = mem_fraction_static + self.tp_rank = tp_rank + self.tp_size = tp_size + self.nccl_port = nccl_port + self.load_format = load_format + self.trust_remote_code = trust_remote_code + self.model_mode = model_mode + + global global_model_mode + global_model_mode = model_mode + + # Init torch distributed + torch.cuda.set_device(self.tp_rank) + torch.distributed.init_process_group( + backend="nccl", + world_size=self.tp_size, + rank=self.tp_rank, + init_method=f"tcp://127.0.0.1:{self.nccl_port}", + ) + + # A small all_reduce for warmup. + if self.tp_size > 1: + torch.distributed.all_reduce(torch.zeros(1).cuda()) + initialize_model_parallel(tensor_model_parallel_size=self.tp_size) + + total_gpu_memory = get_available_gpu_memory( + self.tp_rank, distributed=self.tp_size > 1 + ) * (1 << 30) + self.load_model() + self.init_memory_pool(total_gpu_memory) + + self.is_multimodal_model = is_multimodal_model(self.model_config) + + def load_model(self): + """See also vllm/model_executor/model_loader.py::get_model""" + from sglang.srt.models.llama2 import LlamaForCausalLM + from sglang.srt.models.llava import LlavaLlamaForCausalLM + from sglang.srt.models.mixtral import MixtralForCausalLM + + # Select model class + architectures = getattr(self.model_config.hf_config, "architectures", []) + + model_class = None + for arch in architectures: + if arch == "LlamaForCausalLM": + model_class = LlamaForCausalLM + break + if arch == "MistralForCausalLM": + model_class = LlamaForCausalLM + break + if arch == "LlavaLlamaForCausalLM": + model_class = LlavaLlamaForCausalLM + break + if arch == "MixtralForCausalLM": + model_class = MixtralForCausalLM + break + if model_class is None: + raise ValueError(f"Unsupported architectures: {architectures}") + + # Load weights + linear_method = None + with _set_default_torch_dtype(torch.float16): + with torch.device("cuda"): + hf_quant_config = getattr( + self.model_config.hf_config, "quantization_config", None + ) + if hf_quant_config is not None: + # TODO: config quantization awq etc + quant_config = AWQConfig.from_config(hf_quant_config) + print(f"quant_config: {quant_config}") + linear_method = quant_config.get_linear_method() + model = model_class( + config=self.model_config.hf_config, linear_method=linear_method + ) + model.load_weights( + self.model_config.path, + cache_dir=None, + load_format=self.load_format, + revision=None, + ) + self.model = model + + def profile_max_num_token(self, total_gpu_memory): + available_gpu_memory = get_available_gpu_memory( + self.tp_rank, distributed=self.tp_size > 1 + ) * (1 << 30) + head_dim = ( + self.model_config.hidden_size // self.model_config.num_attention_heads + ) + head_num = self.model_config.num_key_value_heads // self.tp_size + cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2 + rest_memory = available_gpu_memory - total_gpu_memory * ( + 1 - self.mem_fraction_static + ) + max_num_token = int(rest_memory // cell_size) + return max_num_token + + def init_memory_pool(self, total_gpu_memory): + self.max_total_num_token = self.profile_max_num_token(total_gpu_memory) + self.req_to_token_pool = ReqToTokenPool( + int(self.max_total_num_token / self.model_config.context_len * 256), + self.model_config.context_len + 8, + ) + self.token_to_kv_pool = TokenToKVPool( + self.max_total_num_token, + dtype=torch.float16, + head_num=self.model_config.num_key_value_heads // self.tp_size, + head_dim=self.model_config.hidden_size + // self.model_config.num_attention_heads, + layer_num=self.model_config.num_hidden_layers, + ) + + @torch.inference_mode() + def forward_prefill( + self, + input_ids, + req_pool_indices, + seq_lens, + prefix_lens, + position_ids_offsets, + out_cache_loc, + return_normalized_logprob, + ): + input_metadata = InputMetadata.create( + self, + forward_mode=ForwardMode.PREFILL, + tp_size=self.tp_size, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + prefix_lens=prefix_lens, + position_ids_offsets=position_ids_offsets, + out_cache_loc=out_cache_loc, + return_normalized_logprob=return_normalized_logprob, + ) + return self.model.forward(input_ids, input_metadata.positions, input_metadata) + + @torch.inference_mode() + def forward_extend( + self, + input_ids, + req_pool_indices, + seq_lens, + prefix_lens, + position_ids_offsets, + out_cache_loc, + return_normalized_logprob, + ): + input_metadata = InputMetadata.create( + self, + forward_mode=ForwardMode.EXTEND, + tp_size=self.tp_size, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + prefix_lens=prefix_lens, + position_ids_offsets=position_ids_offsets, + out_cache_loc=out_cache_loc, + return_normalized_logprob=return_normalized_logprob, + ) + return self.model.forward(input_ids, input_metadata.positions, input_metadata) + + @torch.inference_mode() + def forward_decode( + self, + input_ids, + req_pool_indices, + seq_lens, + prefix_lens, + position_ids_offsets, + out_cache_loc, + out_cache_cont_start, + out_cache_cont_end, + ): + input_metadata = InputMetadata.create( + self, + forward_mode=ForwardMode.DECODE, + tp_size=self.tp_size, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + prefix_lens=prefix_lens, + position_ids_offsets=position_ids_offsets, + out_cache_loc=out_cache_loc, + out_cache_cont_start=out_cache_cont_start, + out_cache_cont_end=out_cache_cont_end, + ) + return self.model.forward(input_ids, input_metadata.positions, input_metadata)[ + 0 + ] + + @torch.inference_mode() + def forward_extend_multi_modal( + self, + input_ids, + pixel_values, + image_offsets, + req_pool_indices, + seq_lens, + prefix_lens, + position_ids_offsets, + out_cache_loc, + return_normalized_logprob, + ): + input_metadata = InputMetadata.create( + self, + forward_mode=ForwardMode.EXTEND, + tp_size=self.tp_size, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + prefix_lens=prefix_lens, + position_ids_offsets=position_ids_offsets, + out_cache_loc=out_cache_loc, + return_normalized_logprob=return_normalized_logprob, + ) + return self.model.forward( + input_ids, + input_metadata.positions, + input_metadata, + pixel_values, + image_offsets, + ) + + def forward( + self, batch: Batch, forward_mode: ForwardMode, return_normalized_logprob=False + ): + if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND: + kwargs = { + "input_ids": batch.input_ids, + "pixel_values": batch.pixel_values, + "image_offsets": batch.image_offsets, + "req_pool_indices": batch.req_pool_indices, + "seq_lens": batch.seq_lens, + "prefix_lens": batch.prefix_lens, + "position_ids_offsets": batch.position_ids_offsets, + "out_cache_loc": batch.out_cache_loc, + } + kwargs["return_normalized_logprob"] = return_normalized_logprob + return self.forward_extend_multi_modal(**kwargs) + else: + kwargs = { + "input_ids": batch.input_ids, + "req_pool_indices": batch.req_pool_indices, + "seq_lens": batch.seq_lens, + "prefix_lens": batch.prefix_lens, + "position_ids_offsets": batch.position_ids_offsets, + "out_cache_loc": batch.out_cache_loc, + } + + if forward_mode == ForwardMode.DECODE: + kwargs["out_cache_cont_start"] = batch.out_cache_cont_start + kwargs["out_cache_cont_end"] = batch.out_cache_cont_end + return self.forward_decode(**kwargs) + elif forward_mode == ForwardMode.EXTEND: + kwargs["return_normalized_logprob"] = return_normalized_logprob + return self.forward_extend(**kwargs) + elif forward_mode == ForwardMode.PREFILL: + kwargs["return_normalized_logprob"] = return_normalized_logprob + return self.forward_prefill(**kwargs) + else: + raise ValueError(f"Invaid forward mode: {forward_mode}") diff --git a/python/sglang/srt/managers/router/radix_cache.py b/python/sglang/srt/managers/router/radix_cache.py new file mode 100644 index 000000000..a70f00da6 --- /dev/null +++ b/python/sglang/srt/managers/router/radix_cache.py @@ -0,0 +1,220 @@ +import heapq +import time +from collections import defaultdict +from dataclasses import dataclass +from typing import Tuple + +import torch + + +class TreeNode: + def __init__(self): + self.children = defaultdict(TreeNode) + self.parent = None + self.value = None + self.ref_counter = 0 + self.last_access_time = time.time() + + def __lt__(self, other): + return self.last_access_time < other.last_access_time + + +def match(key, seq): + i = 0 + for k, w in zip(key, seq): + if k != w: + break + i += 1 + return i + + +class RadixCache: + def __init__(self, disable=False): + self.root_node = TreeNode() + self.root_node.value = [] + self.root_node.ref_counter = 1 + self.evictable_size_ = 0 + + self.disable = disable + + ##### Public API ##### + def match_prefix(self, key): + if self.disable: + return [], self.root_node + + value = [] + last_node = [self.root_node] + self._match_prefix_helper(self.root_node, key, value, last_node) + if value: + value = torch.concat(value) + return value, last_node[0] + + def insert(self, key, value=None): + if self.disable: + return len(key) + + if value is None: + value = [x for x in key] + return self._insert_helper(self.root_node, key, value) + + def pretty_print(self): + self._print_helper(self.root_node, 0) + print(f"#tokens: {self.total_size()}") + + def total_size(self): + return self._total_size_helper(self.root_node) + + def evict(self, num_tokens, evict_callback): + if self.disable: + raise RuntimeError() + + leaves = self._collect_leaves() + heapq.heapify(leaves) + + num_evicted = 0 + while num_evicted < num_tokens and len(leaves): + x = heapq.heappop(leaves) + + if x == self.root_node: + break + if x.ref_counter > 0: + continue + + num_evicted += evict_callback(x.value) + self._delete_leaf(x) + + if len(x.parent.children) == 0: + heapq.heappush(leaves, x.parent) + + def inc_ref_counter(self, node): + delta = 0 + while node != self.root_node: + if node.ref_counter == 0: + self.evictable_size_ -= len(node.value) + delta -= len(node.value) + node.ref_counter += 1 + node = node.parent + return delta + + def dec_ref_counter(self, node): + delta = 0 + while node != self.root_node: + if node.ref_counter == 1: + self.evictable_size_ += len(node.value) + delta += len(node.value) + node.ref_counter -= 1 + node = node.parent + return delta + + def evictable_size(self): + return self.evictable_size_ + + ##### Internal Helper Functions ##### + def _match_prefix_helper(self, node, key, value, last_node): + node.last_access_time = time.time() + + for c_key, child in node.children.items(): + prefix_len = match(c_key, key) + if prefix_len != 0: + if prefix_len == len(key) and prefix_len != len(c_key): + new_node = self._split_node(c_key, child, prefix_len) + value.append(new_node.value) + last_node[0] = new_node + else: + value.append(child.value[:prefix_len]) + last_node[0] = child + self._match_prefix_helper(child, key[prefix_len:], value, last_node) + break + + def _split_node(self, key, child, split_len): + # new_node -> child + new_node = TreeNode() + new_node.children = {key[split_len:]: child} + new_node.parent = child.parent + new_node.ref_counter = child.ref_counter + new_node.value = child.value[:split_len] + child.parent = new_node + child.value = child.value[split_len:] + new_node.parent.children[key[:split_len]] = new_node + del new_node.parent.children[key] + return new_node + + def _insert_helper(self, node, key, value): + node.last_access_time = time.time() + + for c_key, child in node.children.items(): + prefix_len = match(c_key, key) + + if prefix_len == len(c_key): + if prefix_len == len(key): + return prefix_len + else: + key = key[prefix_len:] + value = value[prefix_len:] + return prefix_len + self._insert_helper(child, key, value) + + if prefix_len: + new_node = self._split_node(c_key, child, prefix_len) + return prefix_len + self._insert_helper( + new_node, key[prefix_len:], value[prefix_len:] + ) + + if len(key): + new_node = TreeNode() + new_node.parent = node + new_node.value = value + node.children[key] = new_node + self.evictable_size_ += len(value) + return 0 + + def _print_helper(self, node, indent): + for key, child in node.children.items(): + print(" " * indent, len(key), key[:10], f"r={child.ref_counter}") + self._print_helper(child, indent=indent + 2) + + def _delete_leaf(self, node): + for k, v in node.parent.children.items(): + if v == node: + break + del node.parent.children[k] + self.evictable_size_ -= len(k) + + def _total_size_helper(self, node): + x = len(node.value) + for child in node.children.values(): + x += self._total_size_helper(child) + return x + + def _collect_leaves(self): + ret_list = [] + + def dfs_(cur_node): + if len(cur_node.children) == 0: + ret_list.append(cur_node) + + for x in cur_node.children.values(): + dfs_(x) + + dfs_(self.root_node) + return ret_list + + +if __name__ == "__main__": + tree = RadixCache(disable=False) + + tree.insert("Hello") + tree.insert("Hello") + tree.insert("Hello_L.A.!") + # tree.insert("Hello_world! Happy") + # tree.insert("I love you!") + tree.pretty_print() + + # print(tree.match_prefix("I love you! aha")) + + # def evict_callback(x): + # print("evict", x) + # return len(x) + + # tree.evict(5, evict_callback) + # tree.evict(10, evict_callback) + # tree.pretty_print() diff --git a/python/sglang/srt/managers/router/scheduler.py b/python/sglang/srt/managers/router/scheduler.py new file mode 100644 index 000000000..1376b329a --- /dev/null +++ b/python/sglang/srt/managers/router/scheduler.py @@ -0,0 +1,73 @@ +import random +from collections import defaultdict + + +class Scheduler: + def __init__( + self, + schedule_heuristic, + max_running_seq, + max_prefill_num_token, + max_total_num_token, + tree_cache, + ): + self.schedule_heuristic = schedule_heuristic + self.max_running_seq = max_running_seq + self.max_prefill_num_token = max_prefill_num_token + self.max_total_num_token = max_total_num_token + self.tree_cache = tree_cache + + def new_token_estimation_ratio(self): + return 0.4 if self.schedule_heuristic != "fcfs" else 0.5 + + def get_priority_queue(self, forward_queue): + if self.schedule_heuristic == "lpm": + # longest prefix match + forward_queue.sort(key=lambda x: -len(x.prefix_indices)) + return forward_queue + elif self.schedule_heuristic == "random": + random.shuffle(forward_queue) + return forward_queue + elif self.schedule_heuristic == "fcfs": + return forward_queue + elif self.schedule_heuristic == "weight": + last_node_to_reqs = defaultdict(list) + for req in forward_queue: + last_node_to_reqs[req.last_node].append(req) + for node in last_node_to_reqs: + last_node_to_reqs[node].sort(key=lambda x: -len(x.prefix_indices)) + + node_to_weight = defaultdict(int) + self._calc_weight_recursive( + self.tree_cache.root_node, last_node_to_reqs, node_to_weight + ) + + tmp_queue = [] + self._get_weight_priority_recursive( + self.tree_cache.root_node, node_to_weight, last_node_to_reqs, tmp_queue + ) + assert len(tmp_queue) == len(forward_queue) + return tmp_queue + else: + raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}") + + def _calc_weight_recursive(self, cur_node, last_node_to_reqs, node_to_weight): + node_to_weight[cur_node] = 1 + if cur_node in last_node_to_reqs: + node_to_weight[cur_node] += len(last_node_to_reqs[cur_node]) + for child in cur_node.children.values(): + self._calc_weight_recursive(child, last_node_to_reqs, node_to_weight) + node_to_weight[cur_node] += node_to_weight[child] + + def _get_weight_priority_recursive( + self, cur_node, node_to_wight, last_node_to_reqs, tmp_queue + ): + visit_list = [child for child in cur_node.children.values()] + visit_list.sort(key=lambda x: -node_to_wight[x]) + # for node in visit_list: + # print(f"{node_to_wight[node]} {len(node.value) if node.value is not None else 0}") + for child in visit_list: + self._get_weight_priority_recursive( + child, node_to_wight, last_node_to_reqs, tmp_queue + ) + tmp_queue.extend(last_node_to_reqs[cur_node]) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py new file mode 100644 index 000000000..313b778c6 --- /dev/null +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -0,0 +1,219 @@ +import asyncio +import concurrent.futures +import dataclasses +import os +from typing import List + +import numpy as np +import transformers +import uvloop +import zmq +import zmq.asyncio +from sglang.srt.hf_transformers_utils import ( + get_config, + get_context_length, + get_processor, + get_tokenizer, +) +from sglang.srt.managers.io_struct import ( + BatchStrOut, + GenerateReqInput, + TokenizedGenerateReqInput, +) +from sglang.srt.sampling_params import SamplingParams +from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.utils import get_exception_traceback, is_multimodal_model, load_image + +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + + +@dataclasses.dataclass +class ReqState: + out_list: List + finished: bool + event: asyncio.Event + lock: asyncio.Lock + + +global global_processor + + +def init_global_processor(server_args: ServerArgs): + global global_processor + transformers.logging.set_verbosity_error() + global_processor = get_processor( + server_args.tokenizer_path, + tokenizer_mode=server_args.tokenizer_mode, + trust_remote_code=server_args.trust_remote_code, + ) + + +def get_pixel_values(image_data, processor=None): + try: + processor = processor or global_processor + image = load_image(image_data) + image_hash = hash(image_data) + pixel_values = processor.image_processor(image)["pixel_values"][0] + pixel_values = pixel_values.astype(np.float16) + return pixel_values, image_hash + except Exception: + print("Exception in TokenizerManager:\n" + get_exception_traceback()) + + +class TokenizerManager: + def __init__( + self, + server_args: ServerArgs, + port_args: PortArgs, + ): + context = zmq.asyncio.Context(2) + self.recv_from_detokenizer = context.socket(zmq.PULL) + self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}") + + self.send_to_router = context.socket(zmq.PUSH) + self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.router_port}") + + self.model_path = server_args.model_path + self.hf_config = get_config( + self.model_path, trust_remote_code=server_args.trust_remote_code + ) + self.context_len = get_context_length(self.hf_config) + + if is_multimodal_model(self.model_path): + self.processor = get_processor( + server_args.tokenizer_path, + tokenizer_mode=server_args.tokenizer_mode, + trust_remote_code=server_args.trust_remote_code, + ) + self.tokenizer = self.processor.tokenizer + os.environ["TOKENIZERS_PARALLELISM"] = "false" + self.executor = concurrent.futures.ProcessPoolExecutor( + initializer=init_global_processor, initargs=(server_args,) + ) + else: + self.tokenizer = get_tokenizer( + server_args.tokenizer_path, + tokenizer_mode=server_args.tokenizer_mode, + trust_remote_code=server_args.trust_remote_code, + ) + + self.to_create_loop = True + self.rid_to_state = {} # Dict[str -> ReqState] + + async def get_pixel_values(self, image_data): + if self.executor is not None: + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self.executor, get_pixel_values, image_data + ) + else: + return get_pixel_values(image_data, self.processor) + + async def generate_request(self, obj: GenerateReqInput): + if self.to_create_loop: + await self.create_handle_loop() + + is_single = isinstance(obj.text, str) + + if is_single: + rid = obj.rid + input_ids = self.tokenizer.encode(obj.text) + sampling_params = SamplingParams(**obj.sampling_params) + if sampling_params.max_new_tokens != 0: + sampling_params.normalize(self.tokenizer) + sampling_params.verify() + if obj.image_data is None: + pixel_values, image_hash = None, None + else: + pixel_values, image_hash = await self.get_pixel_values(obj.image_data) + tokenized_obj = TokenizedGenerateReqInput( + rid=rid, + input_ids=input_ids, + pixel_values=pixel_values, + image_hash=image_hash, + sampling_params=sampling_params, + return_normalized_logprob=obj.return_normalized_logprob, + normalized_logprob_start_len=obj.normalized_logprob_start_len, + stream=obj.stream, + ) + self.send_to_router.send_pyobj(tokenized_obj) + + lock = asyncio.Lock() + event = asyncio.Event() + state = ReqState([], False, event, lock) + self.rid_to_state[rid] = state + + while True: + await event.wait() + yield state.out_list[-1] + state.out_list = [] + if state.finished: + del self.rid_to_state[rid] + break + event.clear() + else: + assert obj.stream is False + bs = len(obj.text) + for i in range(bs): + rid = obj.rid[i] + input_ids = self.tokenizer.encode(obj.text[i]) + sampling_params = SamplingParams(**obj.sampling_params[i]) + if sampling_params.max_new_tokens != 0: + sampling_params.normalize(self.tokenizer) + sampling_params.verify() + if obj.image_data[i] is None: + pixel_values, image_hash = None, None + else: + pixel_values, image_hash = await self.get_pixel_values( + obj.image_data[i] + ) + tokenized_obj = TokenizedGenerateReqInput( + rid=rid, + input_ids=input_ids, + pixel_values=pixel_values, + image_hash=image_hash, + sampling_params=sampling_params, + return_normalized_logprob=obj.return_normalized_logprob[i], + normalized_logprob_start_len=obj.normalized_logprob_start_len[i], + stream=obj.stream, + ) + self.send_to_router.send_pyobj(tokenized_obj) + + lock = asyncio.Lock() + event = asyncio.Event() + state = ReqState([], False, event, lock) + self.rid_to_state[rid] = state + + output_list = [] + for i in range(bs): + rid = obj.rid[i] + state = self.rid_to_state[rid] + await state.event.wait() + output_list.append(state.out_list[-1]) + assert state.finished + del self.rid_to_state[rid] + + yield output_list + + async def create_handle_loop(self): + self.to_create_loop = False + loop = asyncio.get_event_loop() + loop.create_task(self.handle_loop()) + + async def handle_loop(self): + while True: + recv_obj = await self.recv_from_detokenizer.recv_pyobj() + + if isinstance(recv_obj, BatchStrOut): + for i, rid in enumerate(recv_obj.rids): + recv_obj.meta_info[i]["id"] = rid + out_dict = { + "text": recv_obj.output_str[i], + "meta_info": recv_obj.meta_info[i], + } + state = self.rid_to_state[rid] + state.out_list.append(out_dict) + state.finished = recv_obj.finished[i] + state.event.set() + else: + raise ValueError(f"Invalid object: {recv_obj}") diff --git a/python/sglang/srt/memory_pool.py b/python/sglang/srt/memory_pool.py new file mode 100644 index 000000000..99f86d7b0 --- /dev/null +++ b/python/sglang/srt/memory_pool.py @@ -0,0 +1,111 @@ +"""Memory pool.""" +import logging + +import torch + +logger = logging.getLogger(__name__) + + +class ReqToTokenPool: + def __init__(self, size, max_context_len): + self.mem_state = torch.ones((size,), dtype=torch.bool, device="cuda") + self.can_use_mem_size = size + self.req_to_token = torch.empty( + (size, max_context_len), dtype=torch.int32, device="cuda" + ) + + def alloc(self, need_size): + if need_size > self.can_use_mem_size: + return None + + select_index = torch.nonzero(self.mem_state).squeeze(1)[:need_size] + self.mem_state[select_index] = 0 + self.can_use_mem_size -= need_size + return select_index.to(torch.int32) + + def free(self, free_index): + if isinstance(free_index, (int,)): + self.can_use_mem_size += 1 + else: + self.can_use_mem_size += free_index.shape[0] + self.mem_state[free_index] = 1 + + # if self.can_use_mem_size == len(self.mem_state): + # print(f"ReqToTokenPool: freed all. size = {self.can_use_mem_size}.") + + def clear(self): + self.mem_state.fill_(1) + self.can_use_mem_size = len(self.mem_state) + + +class TokenToKVPool: + def __init__(self, size, dtype, head_num, head_dim, layer_num): + self.mem_state = torch.zeros((size,), dtype=torch.int16, device="cuda") + self.alloc_ct = 0 + + # [size, key/value, head_num, head_dim] for each layer + self.kv_data = [ + torch.empty((size, 2, head_num, head_dim), dtype=dtype, device="cuda") + for _ in range(layer_num) + ] + + def get_key_buffer(self, layer_id): + return self.kv_data[layer_id][:, 0] + + def get_value_buffer(self, layer_id): + return self.kv_data[layer_id][:, 1] + + def alloc(self, need_size): + select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size] + if select_index.shape[0] < need_size: + return None + + self.add_refs(select_index) + return select_index.to(torch.int32) + + def alloc_contiguous(self, need_size): + empty_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size] + if empty_index.shape[0] < need_size: + return None + empty_size = len(empty_index) + loc_sum = ( + empty_index[need_size - 1 :] - empty_index[: empty_size - (need_size - 1)] + ) + can_used_loc = empty_index[: empty_size - (need_size - 1)][ + loc_sum == need_size - 1 + ] + if can_used_loc.shape[0] == 0: + return None + + start_loc = can_used_loc[0].item() + select_index = torch.arange(start_loc, start_loc + need_size, device="cuda") + self.add_refs(select_index) + return select_index.to(torch.int32), start_loc, start_loc + need_size + + def free(self, free_index): + return self.decrease_refs(free_index) + + def used_size(self): + return len(torch.nonzero(self.mem_state).squeeze(1)) + + def available_size(self): + return torch.sum(self.mem_state == 0).item() + + def add_refs(self, token_index: torch.Tensor): + self.alloc_ct += len(token_index) + self.mem_state[token_index] += 1 + + def decrease_refs(self, token_index: torch.Tensor): + self.alloc_ct -= len(token_index) + self.mem_state[token_index] -= 1 + + num_freed = torch.sum(self.mem_state[token_index] == 0) + + # if self.alloc_ct == 0: + # print(f"TokenToKVPool: freed all. size = {len(self.mem_state)}.") + + return num_freed + + def clear(self): + self.mem_state.fill_(0) + self.alloc_ct = 0 diff --git a/python/sglang/srt/model_config.py b/python/sglang/srt/model_config.py new file mode 100644 index 000000000..9c7b41e8c --- /dev/null +++ b/python/sglang/srt/model_config.py @@ -0,0 +1,27 @@ +import os +from typing import Optional, Union + +import torch +from sglang.srt.hf_transformers_utils import get_config, get_context_length + + +class ModelConfig: + def __init__( + self, + path: str, + trust_remote_code: bool = True, + revision: Optional[str] = None, + ) -> None: + self.path = path + self.trust_remote_code = trust_remote_code + self.revision = revision + self.hf_config = get_config(self.path, trust_remote_code, revision) + + # Unify the config keys for hf_config + self.context_len = get_context_length(self.hf_config) + self.head_dim = self.hf_config.hidden_size // self.hf_config.num_attention_heads + self.num_key_value_heads = self.hf_config.num_key_value_heads + self.num_attention_heads = self.hf_config.num_attention_heads + self.hidden_size = self.hf_config.hidden_size + self.num_hidden_layers = self.hf_config.num_hidden_layers + self.vocab_size = self.hf_config.vocab_size diff --git a/python/sglang/srt/models/llama2.py b/python/sglang/srt/models/llama2.py new file mode 100644 index 000000000..5f728c5a8 --- /dev/null +++ b/python/sglang/srt/models/llama2.py @@ -0,0 +1,316 @@ +# Adapted from +# https://github.com/vllm-project/vllm/blob/671af2b1c0b3ed6d856d37c21a561cc429a10701/vllm/model_executor/models/llama.py#L1 +"""Inference-only LLaMA model compatible with HuggingFace weights.""" +from typing import Any, Dict, List, Optional, Tuple + +import torch +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.managers.router.model_runner import InputMetadata +from torch import nn +from transformers import LlamaConfig +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_world_size, +) +from vllm.model_executor.weight_utils import ( + default_weight_loader, + hf_model_weights_iterator, +) + + +class LlamaMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + linear_method=linear_method, + ) + self.down_proj = RowParallelLinear( + intermediate_size, hidden_size, bias=False, linear_method=linear_method + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." + ) + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class LlamaAttention(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + layer_id: int = 0, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + linear_method=linear_method, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + linear_method=linear_method, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, input_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class LlamaDecoderLayer(nn.Module): + def __init__( + self, + config: LlamaConfig, + layer_id: int = 0, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + self.self_attn = LlamaAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + layer_id=layer_id, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + linear_method=linear_method, + ) + self.mlp = LlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + linear_method=linear_method, + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + input_metadata=input_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class LlamaModel(nn.Module): + def __init__( + self, + config: LlamaConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList( + [ + LlamaDecoderLayer(config, i, linear_method) + for i in range(config.num_hidden_layers) + ] + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + skip_embed: bool = False, + ) -> torch.Tensor: + if not skip_embed: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_ids + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + input_metadata, + residual, + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class LlamaForCausalLM(nn.Module): + def __init__( + self, + config: LlamaConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.config = config + self.linear_method = linear_method + self.model = LlamaModel(config, linear_method) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.logits_processor = LogitsProcessor(config) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + skip_embed: bool = False, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, input_metadata, skip_embed) + return self.logits_processor( + input_ids, hidden_states, self.lm_head.weight, input_metadata + ) + + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision + ): + if "rotary_emb.inv_freq" in name or "projector" in name: + continue + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py new file mode 100644 index 000000000..3fbf04adf --- /dev/null +++ b/python/sglang/srt/models/llava.py @@ -0,0 +1,213 @@ +"""Inference-only LLaVa model compatible with HuggingFace weights.""" +import json +import os +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +from sglang.srt.managers.router.infer_batch import ForwardMode +from sglang.srt.managers.router.model_runner import InputMetadata +from sglang.srt.models.llama2 import LlamaForCausalLM +from torch import nn +from transformers import CLIPImageProcessor, CLIPVisionModel, LlavaConfig +from transformers.models.llava.modeling_llava import LlavaMultiModalProjector +from vllm.model_executor.layers.linear import LinearMethodBase +from vllm.model_executor.weight_utils import ( + default_weight_loader, + hf_model_weights_iterator, +) + + +class LlavaLlamaForCausalLM(nn.Module): + def __init__( + self, + config: LlavaConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.config = config + self.vision_tower = None + self.config.vision_config.hidden_size = config.mm_hidden_size + self.config.text_config.hidden_size = config.hidden_size + self.multi_modal_projector = LlavaMultiModalProjector(config) + self.language_model = LlamaForCausalLM(config, linear_method) + + def pad_input_ids(self, input_ids, pad_value): + pad_ids = pad_value * ( + (self.image_feature_len + len(pad_value)) // len(pad_value) + ) + offset = input_ids.index(self.config.image_token_index) + # old_len + pad_len - 1, because we need to remove image_token_id + new_input_ids = ( + input_ids[:offset] + + pad_ids[: self.image_feature_len] + + input_ids[offset + 1 :] + ) + return new_input_ids, offset + + def forward( + self, + input_ids: torch.LongTensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + pixel_values: Optional[List[Optional[np.array]]] = None, + image_offsets: Optional[List[int]] = None, + ) -> torch.Tensor: + if input_metadata.forward_mode == ForwardMode.EXTEND: + bs = input_metadata.batch_size + + # Embed text input + input_embeds = self.language_model.model.embed_tokens(input_ids) + + # Embed vision input + need_vision = ( + (positions[input_metadata.extend_start_loc] < self.image_feature_len) + .cpu() + .numpy() + ) + # FIXME: We need to substract the length of the system prompt + has_pixel = np.array([pixel_values[i] is not None for i in range(bs)]) + need_vision = need_vision & has_pixel + + if need_vision.any(): + pixel_values = torch.tensor( + np.array([pixel_values[i] for i in range(bs) if need_vision[i]]), + device=self.vision_tower.device, + ) + + image_outputs = self.vision_tower( + pixel_values, output_hidden_states=True + ) + # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated. + + selected_image_feature = image_outputs.hidden_states[ + self.vision_feature_layer + ] + if self.vision_feature_select_strategy in ["default", "patch"]: + selected_image_feature = selected_image_feature[:, 1:] + elif self.vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + else: + raise ValueError( + f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}" + ) + image_features = self.multi_modal_projector(selected_image_feature) + + extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy() + pt = 0 + for i in range(bs): + if not need_vision[i]: + continue + + start_idx = extend_start_loc_cpu[i] + pad_len, pad_dim = image_features[pt].shape + dim = input_embeds.shape[1] + assert ( + pad_dim == dim + ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim) + # Fill in the placeholder for the image + try: + input_embeds[ + start_idx + + image_offsets[i] : start_idx + + image_offsets[i] + + pad_len + ] = image_features[pt] + except RuntimeError as e: + print(f"RuntimeError in llava image encoding: {e}") + print(input_embeds.shape) + print(start_idx, image_offsets[i]) + pt += 1 + + return self.language_model( + input_embeds, positions, input_metadata, skip_embed=True + ) + elif input_metadata.forward_mode == ForwardMode.DECODE: + return self.language_model( + input_ids, positions, input_metadata, skip_embed=False + ) + + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): + # load clip vision model by cfg['mm_vision_tower']: + # huggingface_name or path_of_clip_relative_to_llava_model_dir + vision_path = self.config.mm_vision_tower + self.vision_tower = CLIPVisionModel.from_pretrained( + vision_path, torch_dtype=torch.float16 + ).cuda() + self.vision_tower.eval() + + self.vision_feature_layer = self.config.mm_vision_select_layer + self.vision_feature_select_strategy = self.config.mm_vision_select_feature + self.image_size = self.vision_tower.config.image_size + self.patch_size = self.vision_tower.config.patch_size + self.image_feature_len = int((self.image_size / self.patch_size) ** 2) + if self.vision_feature_select_strategy == "patch": + pass + elif self.vision_feature_select_strategy == "cls_patch": + self.image_feature_len += 1 + else: + raise ValueError(f"Unexpected select feature: {self.select_feature}") + + # load mm_projector + # TODO: support TP? + projector_weights = { + "model.mm_projector.0": "multi_modal_projector.linear_1", + "model.mm_projector.2": "multi_modal_projector.linear_2", + } + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision + ): + # FIXME: why projector weights read two times? + if "projector" in name: + for weight_name, param_name in projector_weights.items(): + if weight_name in name: + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + # load language model + self.language_model.load_weights( + model_name_or_path, cache_dir, load_format, revision + ) + + monkey_path_clip_vision_embed_forward() + + +first_call = True + + +def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + + # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G. + global first_call + if first_call: + self.patch_embedding.cpu().float() + first_call = False + pixel_values = pixel_values.to(dtype=torch.float32, device="cpu") + patch_embeds = self.patch_embedding(pixel_values).cuda().half() + + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +def monkey_path_clip_vision_embed_forward(): + import transformers + + setattr( + transformers.models.clip.modeling_clip.CLIPVisionEmbeddings, + "forward", + clip_vision_embed_forward, + ) diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py new file mode 100644 index 000000000..83844cd67 --- /dev/null +++ b/python/sglang/srt/models/mixtral.py @@ -0,0 +1,378 @@ +# Adapted from +# https://github.com/vllm-project/vllm/blob/d0215a58e78572d91dadafe9d832a2db89b09a13/vllm/model_executor/models/mixtral.py#L1 +"""Inference-only Mixtral model.""" +from typing import List, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.managers.router.model_runner import InputMetadata +from torch import nn +from transformers import MixtralConfig +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + LinearMethodBase, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_reduce, +) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.model_executor.weight_utils import ( + default_weight_loader, + hf_model_weights_iterator, +) + + +class MixtralMLP(nn.Module): + def __init__( + self, + num_experts: int, + hidden_size: int, + intermediate_size: int, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.num_experts = num_experts + self.ffn_dim = intermediate_size + self.hidden_dim = hidden_size + + self.w1 = ReplicatedLinear( + self.hidden_dim, self.ffn_dim, bias=False, linear_method=linear_method + ) + self.w2 = ReplicatedLinear( + self.ffn_dim, self.hidden_dim, bias=False, linear_method=linear_method + ) + self.w3 = ReplicatedLinear( + self.hidden_dim, self.ffn_dim, bias=False, linear_method=linear_method + ) + + # TODO: Use vllm's SiluAndMul + self.act_fn = nn.SiLU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + w1_out, _ = self.w1(hidden_states) + w1_out = self.act_fn(w1_out) + w3_out, _ = self.w3(hidden_states) + current_hidden_states = w1_out * w3_out + current_hidden_states, _ = self.w2(current_hidden_states) + return current_hidden_states + + +class MixtralMoE(nn.Module): + def __init__( + self, + config: MixtralConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + self.rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + self.num_total_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + if self.tp_size > self.num_total_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {self.num_total_experts}." + ) + # Split experts equally between ranks + self.expert_indicies = np.array_split( + range(self.num_total_experts), self.tp_size + )[self.rank].tolist() + if not self.expert_indicies: + raise ValueError(f"Rank {self.rank} has no experts assigned to it.") + + self.experts = nn.ModuleList( + [ + MixtralMLP( + self.num_total_experts, + config.hidden_size, + config.intermediate_size, + linear_method=linear_method, + ) + if idx in self.expert_indicies + else None + for idx in range(self.num_total_experts) + ] + ) + self.gate = ReplicatedLinear( + config.hidden_size, self.num_total_experts, bias=False, linear_method=None + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + router_logits, _ = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk( + routing_weights, self.top_k, dim=-1 + ) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + + final_hidden_states = None + for expert_idx in self.expert_indicies: + expert_layer = self.experts[expert_idx] + expert_mask = selected_experts == expert_idx + expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True) + + current_hidden_states = expert_layer(hidden_states).mul_(expert_weights) + if final_hidden_states is None: + final_hidden_states = current_hidden_states + else: + final_hidden_states.add_(current_hidden_states) + + return tensor_model_parallel_all_reduce(final_hidden_states) + + +class MixtralAttention(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + layer_id: int = 0, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + linear_method: Optional[LinearMethodBase] = None, + sliding_window: Optional[int] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.sliding_window = sliding_window + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + linear_method=linear_method, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + linear_method=linear_method, + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=int(self.rope_theta), + is_neox_style=True, + ) + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, input_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class MixtralDecoderLayer(nn.Module): + def __init__( + self, + config: MixtralConfig, + layer_id: int = 0, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 10000) + self.self_attn = MixtralAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + layer_id=layer_id, + rope_theta=rope_theta, + sliding_window=config.sliding_window, + linear_method=linear_method, + ) + self.block_sparse_moe = MixtralMoE(config=config, linear_method=linear_method) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + input_metadata=input_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.block_sparse_moe(hidden_states) + return hidden_states, residual + + +class MixtralModel(nn.Module): + def __init__( + self, + config: MixtralConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + # config.num_hidden_layers=16 + self.layers = nn.ModuleList( + [ + MixtralDecoderLayer(config, i, linear_method=linear_method) + for i in range(config.num_hidden_layers) + ] + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + skip_embed: bool = False, + ) -> torch.Tensor: + if not skip_embed: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_ids + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, hidden_states, input_metadata, residual + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class MixtralForCausalLM(nn.Module): + def __init__( + self, + config: MixtralConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.config = config + self.linear_method = linear_method + self.model = MixtralModel(config, linear_method) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.logits_processor = LogitsProcessor(config) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + skip_embed: bool = False, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, input_metadata, skip_embed) + return self.logits_processor( + input_ids, hidden_states, self.lm_head.weight, input_metadata + ) + + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision, fall_back_to_pt=False + ): + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip experts that are not assigned to this worker. + if "block_sparse_moe.experts." in name and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/python/sglang/srt/sampling_params.py b/python/sglang/srt/sampling_params.py new file mode 100644 index 000000000..e80766984 --- /dev/null +++ b/python/sglang/srt/sampling_params.py @@ -0,0 +1,81 @@ +"""Sampling parameters for text generation.""" +from typing import List, Optional, Union + +_SAMPLING_EPS = 1e-6 + + +class SamplingParams: + def __init__( + self, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + stop: Optional[Union[str, List[str]]] = None, + max_new_tokens: int = 16, + ignore_eos: bool = False, + skip_special_tokens: bool = True, + dtype: Optional[str] = None, + regex: Optional[str] = None, + ) -> None: + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + self.frequency_penalty = frequency_penalty + self.presence_penalty = presence_penalty + self.stop_strs = stop + self.max_new_tokens = max_new_tokens + self.ignore_eos = ignore_eos + self.skip_special_tokens = skip_special_tokens + self.dtype = dtype + self.regex = regex + + # Process some special cases + if self.temperature < _SAMPLING_EPS: + self.temperature = 1.0 + self.top_k = 1 + if self.top_k == -1: + self.top_k = 1 << 30 # whole vocabulary + if self.dtype == "int": + self.stop_strs = [" ", "\n"] + + def verify(self): + if self.temperature < 0.0: + raise ValueError( + f"temperature must be non-negative, got {self.temperature}." + ) + if not 0.0 < self.top_p <= 1.0: + raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.") + if self.top_k < -1 or self.top_k == 0: + raise ValueError( + f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}." + ) + if not -2.0 <= self.frequency_penalty <= 2.0: + raise ValueError( + "frequency_penalty must be in [-2, 2], got " + f"{self.frequency_penalty}." + ) + if not -2.0 <= self.presence_penalty <= 2.0: + raise ValueError( + "presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}." + ) + if self.max_new_tokens < 0: + raise ValueError( + f"max_new_tokens must be at least 0, got {self.max_new_tokens}." + ) + + def normalize(self, tokenizer): + # Process stop strings + if self.stop_strs is None: + self.stop_strs = [] + self.stop_str_max_len = 0 + else: + if isinstance(self.stop_strs, str): + self.stop_strs = [self.stop_strs] + + stop_str_max_len = 0 + for stop_str in self.stop_strs: + stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False) + stop_str_max_len = max(stop_str_max_len, len(stop_str_ids)) + self.stop_str_max_len = stop_str_max_len diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py new file mode 100644 index 000000000..5f2c2f289 --- /dev/null +++ b/python/sglang/srt/server.py @@ -0,0 +1,222 @@ +"""SRT: SGLang Runtime""" +import argparse +import asyncio +import dataclasses +import json +import multiprocessing as mp +import sys +import threading +import time +from typing import List, Optional + +# Fix a Python bug +setattr(threading, "_register_atexit", lambda *args, **kwargs: None) + +import psutil +import requests +import uvicorn +import uvloop +from fastapi import FastAPI +from fastapi.responses import StreamingResponse +from sglang.backend.runtime_endpoint import RuntimeEndpoint +from sglang.srt.managers.detokenizer_manager import start_detokenizer_process +from sglang.srt.managers.io_struct import GenerateReqInput +from sglang.srt.managers.openai_protocol import CompletionRequest +from sglang.srt.managers.router.manager import start_router_process +from sglang.srt.managers.tokenizer_manager import TokenizerManager +from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.utils import alloc_usable_network_port + +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + + +app = FastAPI() +tokenizer_manager = None + + +@app.get("/get_model_info") +async def get_model_info(): + result = { + "model_path": tokenizer_manager.model_path, + } + return result + + +@app.post("/generate") +async def generate_request(obj: GenerateReqInput): + obj.post_init() + result_generator = tokenizer_manager.generate_request(obj) + + if obj.stream: + + async def stream_results(): + async for out in result_generator: + yield (json.dumps(out) + "\0").encode("utf-8") + + return StreamingResponse(stream_results(), media_type="text/event-stream") + else: + ret = await result_generator.__anext__() + return ret + + +@app.post("/v1/completions") +async def v1_completions(obj: CompletionRequest): + assert obj.n == 1 + obj = GenerateReqInput( + text=obj.prompt, + sampling_params={ + "temperature": obj.temperature, + "max_new_tokens": obj.max_tokens, + "stop": obj.stop, + }, + ) + ret = await generate_request(obj) + return { + "choices": [{"text": ret["text"]}], + } + + +def launch_server(server_args, pipe_finish_writer): + global tokenizer_manager + + # Allocate ports + can_use_ports = alloc_usable_network_port( + num=4 + server_args.tp_size, used_list=(server_args.port,) + ) + port_args = PortArgs( + tokenizer_port=can_use_ports[0], + router_port=can_use_ports[1], + detokenizer_port=can_use_ports[2], + nccl_port=can_use_ports[3], + model_rpc_ports=can_use_ports[4:], + ) + + # Launch processes + tokenizer_manager = TokenizerManager(server_args, port_args) + pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False) + pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False) + + proc_router = mp.Process( + target=start_router_process, + args=( + server_args, + port_args, + pipe_router_writer, + ), + ) + proc_router.start() + proc_detoken = mp.Process( + target=start_detokenizer_process, + args=( + server_args, + port_args, + pipe_detoken_writer, + ), + ) + proc_detoken.start() + + # Wait for the model to finish loading + router_init_state = pipe_router_reader.recv() + detoken_init_state = pipe_detoken_reader.recv() + + if router_init_state != "init ok" or detoken_init_state != "init ok": + proc_router.kill() + proc_detoken.kill() + print("router init state:", router_init_state) + print("detoken init state:", detoken_init_state) + sys.exit(1) + + assert proc_router.is_alive() and proc_detoken.is_alive() + + def launch_server(): + # Launch api server + uvicorn.run( + app, + host=server_args.host, + port=server_args.port, + log_level=server_args.log_level, + timeout_keep_alive=5, + loop="uvloop", + ) + + t = threading.Thread(target=launch_server) + t.start() + + if pipe_finish_writer: + url = server_args.url() + + success = False + for i in range(60): + try: + res = requests.get(url + "/get_model_info", timeout=5) + success = True + break + except requests.exceptions.RequestException as e: + time.sleep(1) + + if success: + pipe_finish_writer.send("init ok") + else: + pipe_finish_writer.send(str(e)) + + +class Runtime: + def __init__( + self, + model_path: str, + tokenizer_path: Optional[str] = None, + load_format: str = "auto", + tokenizer_mode: str = "auto", + trust_remote_code: bool = True, + mem_fraction_static: float = 0.9, + tp_size: int = 1, + model_mode: List[str] = (), + schedule_heuristic: str = "lpm", + random_seed: int = 42, + log_level: str = "warning", + ): + host = "127.0.0.1" + port = alloc_usable_network_port(1)[0] + server_args = ServerArgs( + model_path=model_path, + tokenizer_path=tokenizer_path, + host=host, + port=port, + load_format=load_format, + tokenizer_mode=tokenizer_mode, + trust_remote_code=trust_remote_code, + mem_fraction_static=mem_fraction_static, + tp_size=tp_size, + model_mode=model_mode, + schedule_heuristic=schedule_heuristic, + random_seed=random_seed, + log_level=log_level, + ) + self.url = server_args.url() + + self.pid = None + pipe_reader, pipe_writer = mp.Pipe(duplex=False) + proc = mp.Process(target=launch_server, args=(server_args, pipe_writer)) + proc.start() + self.pid = proc.pid + + init_state = pipe_reader.recv() + if init_state != "init ok": + self.shutdown() + raise RuntimeError("Launch failed") + + self.endpoint = RuntimeEndpoint(self.url) + + def shutdown(self): + if self.pid is not None: + parent = psutil.Process(self.pid) + children = parent.children(recursive=True) + for child in children: + child.kill() + psutil.wait_procs(children, timeout=5) + parent.kill() + parent.wait(timeout=5) + self.pid = None + + def __del__(self): + self.shutdown() diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py new file mode 100644 index 000000000..6f35a8c82 --- /dev/null +++ b/python/sglang/srt/server_args.py @@ -0,0 +1,138 @@ +import argparse +import dataclasses +from typing import List, Optional + + +@dataclasses.dataclass +class ServerArgs: + model_path: str + tokenizer_path: Optional[str] = None + host: str = "127.0.0.1" + port: int = 30000 + load_format: str = "auto" + tokenizer_mode: str = "auto" + trust_remote_code: bool = True + mem_fraction_static: float = 0.91 + tp_size: int = 1 + model_mode: List[str] = () + schedule_heuristic: str = "lpm" + random_seed: int = 42 + disable_log_stats: bool = False + log_stats_interval: int = 10 + log_level: str = "info" + + def __post_init__(self): + if self.tokenizer_path is None: + self.tokenizer_path = self.model_path + + @staticmethod + def add_cli_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--model-path", + type=str, + help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.", + required=True, + ) + parser.add_argument( + "--tokenizer-path", + type=str, + default=ServerArgs.tokenizer_path, + help="The path of the tokenizer.", + ) + parser.add_argument("--host", type=str, default=ServerArgs.host) + parser.add_argument("--port", type=int, default=ServerArgs.port) + parser.add_argument( + "--load-format", + type=str, + default=ServerArgs.load_format, + choices=["auto", "pt", "safetensors", "npcache", "dummy"], + help="The format of the model weights to load. " + '"auto" will try to load the weights in the safetensors format ' + "and fall back to the pytorch bin format if safetensors format " + "is not available. " + '"pt" will load the weights in the pytorch bin format. ' + '"safetensors" will load the weights in the safetensors format. ' + '"npcache" will load the weights in pytorch format and store ' + "a numpy cache to speed up the loading. " + '"dummy" will initialize the weights with random values, ' + "which is mainly for profiling.", + ) + parser.add_argument( + "--tokenizer-mode", + type=str, + default=ServerArgs.tokenizer_mode, + choices=["auto", "slow"], + help="Tokenizer mode. 'auto' will use the fast " + "tokenizer if available, and 'slow' will " + "always use the slow tokenizer.", + ) + parser.add_argument( + "--trust-remote-code", + action="store_true", + help="Whether or not to allow for custom models defined on the Hub in their own modeling files.", + ) + parser.add_argument( + "--mem-fraction-static", + type=float, + default=ServerArgs.mem_fraction_static, + help="The fraction of the memory used for static allocation (model weights and KV cache memory pool)", + ) + parser.add_argument( + "--tp-size", + type=int, + default=ServerArgs.tp_size, + help="Tensor parallelism degree.", + ) + parser.add_argument( + "--model-mode", + type=str, + default=[], + nargs="+", + help="Model mode: [flashinfer, no-cache, aggressive-new-fill]", + ) + parser.add_argument( + "--schedule-heuristic", + type=str, + default=ServerArgs.schedule_heuristic, + help="Schudule mode: [lpm, weight, random, fcfs]", + ) + parser.add_argument( + "--random-seed", + type=int, + default=ServerArgs.random_seed, + help="Random seed.", + ) + parser.add_argument( + "--log-level", + type=str, + default=ServerArgs.log_level, + help="Log level", + ) + parser.add_argument( + "--disable-log-stats", + action="store_true", + help="Disable logging throughput stats.", + ) + parser.add_argument( + "--log-stats-interval", + type=int, + default=ServerArgs.log_stats_interval, + help="Log stats interval in second.", + ) + + @classmethod + def from_cli_args(cls, args: argparse.Namespace): + attrs = [attr.name for attr in dataclasses.fields(cls)] + return cls(**{attr: getattr(args, attr) for attr in attrs}) + + def url(self): + return f"http://{self.host}:{self.port}" + + +@dataclasses.dataclass +class PortArgs: + tokenizer_port: int + router_port: int + detokenizer_port: int + nccl_port: int + model_rpc_ports: List[int] diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py new file mode 100644 index 000000000..212533f1c --- /dev/null +++ b/python/sglang/srt/utils.py @@ -0,0 +1,217 @@ +import base64 +import os +import random +import socket +import sys +import time +import traceback +from io import BytesIO + +import numpy as np +import requests +import torch +import torch.distributed as dist + +is_show_cost_time = False + + +def mark_cost_time(func_name): + def inner_func(func): + def time_func(*args, **kwargs): + if dist.get_rank() in [0, 1] and is_show_cost_time: + torch.cuda.synchronize() + start_time = time.time() + ans = func(*args, **kwargs) + torch.cuda.synchronize() + print(func_name, "cost time:", (time.time() - start_time) * 1000) + return ans + else: + torch.cuda.synchronize() + ans = func(*args, **kwargs) + torch.cuda.synchronize() + return ans + + return time_func + + return inner_func + + +time_mark = {} + + +def mark_start(key): + torch.cuda.synchronize() + global time_mark + time_mark[key] = time.time() + return + + +def mark_end(key, print_min_cost=0.0): + torch.cuda.synchronize() + global time_mark + cost_time = (time.time() - time_mark[key]) * 1000 + if cost_time > print_min_cost: + print(f"cost {key}:", cost_time) + + +def calculate_time(show=False, min_cost_ms=0.0): + def wrapper(func): + def inner_func(*args, **kwargs): + torch.cuda.synchronize() + if show: + start_time = time.time() + result = func(*args, **kwargs) + torch.cuda.synchronize() + if show: + cost_time = (time.time() - start_time) * 1000 + if cost_time > min_cost_ms: + print(f"Function {func.__name__} took {cost_time} ms to run.") + return result + + return inner_func + + return wrapper + + +def set_random_seed(seed: int) -> None: + random.seed(seed) + + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def alloc_usable_network_port(num, used_list=()): + port_list = [] + for port in range(10000, 65536): + if port in used_list: + continue + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind(("", port)) + port_list.append(port) + except socket.error: + pass + + if len(port_list) == num: + return port_list + return None + + +def get_exception_traceback(): + etype, value, tb = sys.exc_info() + err_str = "".join(traceback.format_exception(etype, value, tb)) + return err_str + + +def get_int_token_logit_bias(tokenizer, vocab_size): + from transformers import LlamaTokenizer, LlamaTokenizerFast + + logit_bias = np.zeros(vocab_size, dtype=np.float32) + for t_id in range(vocab_size): + ss = tokenizer.decode(t_id).strip() + if not (ss.isdigit() or len(ss) == 0 or t_id == tokenizer.eos_token_id): + logit_bias[t_id] = -1e5 + # else: + # print(ss, t_id) + + return logit_bias + + +def wrap_kernel_launcher(kernel): + """A faster launcher for triton kernels.""" + import torch.distributed as dist + + if dist.is_initialized(): + rank = dist.get_rank() + else: + rank = 0 + + kernels = kernel.cache[rank].values() + kernel = next(iter(kernels)) + + # Different trition versions use different low-level names + if hasattr(kernel, "cu_function"): + kfunction = kernel.cu_function + else: + kfunction = kernel.function + + if hasattr(kernel, "c_wrapper"): + run = kernel.c_wrapper + else: + run = kernel.run + + add_cluster_dim = True + + def ret_func(grid, num_warps, *args): + nonlocal add_cluster_dim + + try: + if add_cluster_dim: + run( + grid[0], + grid[1], + grid[2], + num_warps, + 1, + 1, + 1, + 1, + kernel.shared, + 0, + kfunction, + None, + None, + kernel, + *args, + ) + else: + run( + grid[0], + grid[1], + grid[2], + num_warps, + kernel.shared, + 0, + kfunction, + None, + None, + kernel, + *args, + ) + except TypeError: + add_cluster_dim = not add_cluster_dim + ret_func(grid, num_warps, *args) + + return ret_func + + +def is_multimodal_model(model): + if isinstance(model, str): + return "llava" in model + from sglang.srt.model_config import ModelConfig + + if isinstance(model, ModelConfig): + return "llava" in model.path.lower() + raise Exception("unrecognized type") + + +def load_image(image_file): + from PIL import Image + + image = None + + if image_file.startswith("http://") or image_file.startswith("https://"): + timeout = int(os.getenv("REQUEST_TIMEOUT", "3")) + response = requests.get(image_file, timeout=timeout) + image = Image.open(BytesIO(response.content)) + elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")): + image = Image.open(image_file) + elif image_file.startswith("data:"): + image_file = image_url.split(",")[1] + image = Image.open(BytesIO(base64.b64decode(image_file))) + else: + image = Image.open(BytesIO(base64.b64decode(image_file))) + + return image diff --git a/python/sglang/test/test_programs.py b/python/sglang/test/test_programs.py new file mode 100644 index 000000000..805e50c1e --- /dev/null +++ b/python/sglang/test/test_programs.py @@ -0,0 +1,324 @@ +""" +This file contains the SGL programs used for unit testing. +""" + +import json +import re + +import sglang as sgl + + +def test_few_shot_qa(): + @sgl.function + def few_shot_qa(s, question): + s += "The following are questions with answers.\n\n" + s += "Q: What is the capital of France?\n" + s += "A: Paris\n" + s += "Q: What is the capital of Germany?\n" + s += "A: Berlin\n" + s += "Q: What is the capital of Italy?\n" + s += "A: Rome\n" + s += "Q: " + question + "\n" + s += "A:" + sgl.gen("answer", stop="\n", temperature=0) + + ret = few_shot_qa.run(question="What is the capital of the United States?") + assert "washington" in ret["answer"].strip().lower(), f"answer: {ret['answer']}" + + rets = few_shot_qa.run_batch( + [ + {"question": "What is the capital of Japan?"}, + {"question": "What is the capital of the United Kingdom?"}, + {"question": "What is the capital city of China?"}, + ], + temperature=0.1, + ) + answers = [x["answer"].strip().lower() for x in rets] + assert answers == ["tokyo", "london", "beijing"], f"answers: {answers}" + + +def test_mt_bench(): + @sgl.function + def answer_mt_bench(s, question_1, question_2): + s += sgl.system("You are a helpful assistant.") + s += sgl.user(question_1) + s += sgl.assistant(sgl.gen("answer_1")) + with s.user(): + s += question_2 + with s.assistant(): + s += sgl.gen("answer_2") + + question_1 = "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions." + question_2 = ( + "Rewrite your previous response. Start every sentence with the letter A." + ) + ret = answer_mt_bench.run( + question_1=question_1, question_2=question_2, temperature=0.7, max_new_tokens=64 + ) + assert len(ret.messages()) in [4, 5] + + +def test_select(check_answer): + @sgl.function + def true_or_false(s, statement): + s += "Determine whether the statement below is True, False, or Unknown.\n" + s += "Statement: The capital of France is Pairs.\n" + s += "Answer: True\n" + s += "Statement: " + statement + "\n" + s += "Answer:" + sgl.select("answer", ["True", "False", "Unknown"]) + + ret = true_or_false.run( + statement="The capital of Germany is Berlin.", + ) + if check_answer: + assert ret["answer"] == "True", ret.text + else: + assert ret["answer"] in ["True", "False", "Unknown"] + + ret = true_or_false.run( + statement="The capital of Canada is Tokyo.", + ) + if check_answer: + assert ret["answer"] == "False", ret.text + else: + assert ret["answer"] in ["True", "False", "Unknown"] + + ret = true_or_false.run( + statement="Purple is a better color than green.", + ) + if check_answer: + assert ret["answer"] == "Unknown", ret.text + else: + assert ret["answer"] in ["True", "False", "Unknown"] + + +def test_decode_int(): + @sgl.function + def decode_int(s): + s += "The number of hours in a day is " + sgl.gen_int("hours") + "\n" + s += "The number of days in a year is " + sgl.gen_int("days") + "\n" + + ret = decode_int.run(temperature=0.1) + assert int(ret["hours"]) == 24, ret.text + assert int(ret["days"]) == 365, ret.text + + +def test_decode_json(): + @sgl.function + def decode_json(s): + s += "Generate a JSON object to describe the basic information of a city.\n" + + with s.var_scope("json_output"): + s += "{\n" + s += ' "name": ' + sgl.gen_string() + ",\n" + s += ' "population": ' + sgl.gen_int() + ",\n" + s += ' "area": ' + sgl.gen(dtype=int) + ",\n" + s += ' "country": ' + sgl.gen_string() + ",\n" + s += ' "timezone": ' + sgl.gen(dtype=str) + "\n" + s += "}" + + ret = decode_json.run() + js_obj = json.loads(ret["json_output"]) + assert isinstance(js_obj["name"], str) + assert isinstance(js_obj["population"], int) + + +def test_expert_answer(): + @sgl.function + def expert_answer(s, question): + s += "Question: " + question + "\n" + s += ( + "A good person to answer this question is" + + sgl.gen("expert", stop=[".", "\n"]) + + ".\n" + ) + s += ( + "For example," + + s["expert"] + + " would answer that " + + sgl.gen("answer", stop=".") + + "." + ) + + ret = expert_answer.run(question="What is the capital of France?", temperature=0.1) + assert "paris" in ret.text().lower() + + +def test_tool_use(): + def calculate(expression): + return f"{eval(expression)}" + + @sgl.function + def tool_use(s, lhs, rhs): + s += "Please perform computations using a calculator. You can use calculate(expression) to get the results.\n" + s += "For example,\ncalculate(1+2)=3\ncalculate(3*4)=12\n" + s += "Question: What is the product of " + lhs + " and " + rhs + "?\n" + s += ( + "Answer: The answer is calculate(" + + sgl.gen("expression", stop=")") + + ") = " + ) + with s.var_scope("answer"): + s += calculate(s["expression"]) + + lhs, rhs = 257, 983 + ret = tool_use(lhs=lhs, rhs=rhs, temperature=0) + assert int(ret["answer"]) == lhs * rhs + + +def test_react(): + @sgl.function + def react(s, question): + s += """ +Question: Which country does the founder of Microsoft live in? +Thought 1: I need to search for the founder of Microsoft. +Action 1: Search [Founder of Microsoft]. +Observation 1: The founder of Microsoft is Bill Gates. +Thought 2: I need to search for the country where Bill Gates lives in. +Action 2: Search [Where does Bill Gates live]. +Observation 2: Bill Gates lives in the United States. +Thought 3: The answer is the United States. +Action 3: Finish [United States].\n +""" + + s += "Question: " + question + "\n" + + for i in range(1, 5): + s += f"Thought {i}:" + sgl.gen(stop=[".", "\n"]) + ".\n" + s += f"Action {i}: " + sgl.select(f"action_{i}", ["Search", "Finish"]) + + if s[f"action_{i}"] == "Search": + s += " [" + sgl.gen(stop="]") + "].\n" + s += f"Observation {i}:" + sgl.gen(stop=[".", "\n"]) + ".\n" + else: + s += " [" + sgl.gen("answer", stop="]") + "].\n" + break + + ret = react.run( + question="What country does the creator of Linux live in?", + temperature=0.1, + ) + answer = ret["answer"].lower() + assert "finland" in answer or "states" in answer + + +def test_parallel_decoding(): + max_tokens = 64 + number = 5 + + @sgl.function + def parallel_decoding(s, topic): + s += "Act as a helpful assistant.\n" + s += "USER: Give some tips for " + topic + ".\n" + s += ( + "ASSISTANT: Okay. Here are " + + str(number) + + " concise tips, each under 8 words:\n" + ) + + # Generate skeleton + for i in range(1, 1 + number): + s += f"{i}." + sgl.gen(max_tokens=16, stop=[".", "\n"]) + ".\n" + + # Generate detailed tips + forks = s.fork(number) + for i in range(number): + forks[ + i + ] += f"Now, I expand tip {i+1} into a detailed paragraph:\nTip {i+1}:" + forks[i] += sgl.gen("detailed_tip", max_tokens, stop=["\n\n"]) + forks.join() + + # Concatenate tips and summarize + s += "Here are these tips with detailed explanation:\n" + for i in range(number): + s += f"Tip {i+1}:" + forks[i]["detailed_tip"] + "\n" + + s += "\nIn summary," + sgl.gen("summary", max_tokens=512) + + ret = parallel_decoding.run(topic="writing a good blog post", temperature=0.3) + + +def test_parallel_encoding(check_answer=True): + max_tokens = 64 + + @sgl.function + def parallel_encoding(s, question, context_0, context_1, context_2): + s += "USER: I will ask a question based on some statements.\n" + s += "ASSISTANT: Sure. I will give the answer.\n" + s += "USER: Please memorize these statements.\n" + + contexts = [context_0, context_1, context_2] + + forks = s.fork(len(contexts)) + forks += lambda i: f"Statement {i}: " + contexts[i] + "\n" + forks.join(mode="concate_and_append") + + s += "Now, please answer the following question. " "Do not list options." + s += "\nQuestion: " + question + "\n" + s += "ASSISTANT:" + sgl.gen("answer", max_tokens=max_tokens) + + ret = parallel_encoding.run( + question="Who is the father of Julian?", + context_0="Ethan is the father of Liam.", + context_1="Noah is the father of Julian.", + context_2="Oliver is the father of Carlos.", + temperature=0, + ) + answer = ret["answer"] + + if check_answer: + assert "Noah" in answer + + +def test_image_qa(): + @sgl.function + def image_qa(s, question): + s += sgl.user(sgl.image("image.png") + question) + s += sgl.assistant(sgl.gen("answer")) + + state = image_qa.run( + question="Please describe this image in simple words.", + temperature=0, + max_new_tokens=64, + ) + assert "taxi" in state.messages()[-1]["content"] + + +def test_stream(): + @sgl.function + def qa(s, question): + s += sgl.user(question) + s += sgl.assistant(sgl.gen("answer")) + + ret = qa( + question="Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions.", + stream=True, + ) + out = "" + for chunk in ret.text_iter(): + out += chunk + + ret = qa( + question="Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions.", + stream=True, + ) + out = "" + for chunk in ret.text_iter("answer"): + out += chunk + + +def test_regex(): + regex = r"((25[0-5]|2[0-4]\d|[01]?\d\d?).){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)" + + @sgl.function + def regex_gen(s): + s += "Q: What is the IP address of the Google DNS servers?\n" + s += "A: " + sgl.gen( + "answer", + temperature=0, + regex=regex, + ) + + state = regex_gen.run() + answer = state["answer"] + assert re.match(regex, answer) diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py new file mode 100644 index 000000000..161953159 --- /dev/null +++ b/python/sglang/test/test_utils.py @@ -0,0 +1,141 @@ +"""Common utilities for testing and benchmarking""" +import numpy as np +import requests +from sglang.backend.openai import OpenAI +from sglang.backend.runtime_endpoint import RuntimeEndpoint +from sglang.global_config import global_config + + +def call_generate_lightllm(prompt, temperature, max_tokens, stop, url): + data = { + "inputs": prompt, + "parameters": { + "temperature": temperature, + "max_new_tokens": max_tokens, + "stop_sequences": stop, + }, + } + res = requests.post(url, json=data) + assert res.status_code == 200 + pred = res.json()["generated_text"][0] + return pred + + +def call_generate_vllm(prompt, temperature, max_tokens, stop, url, n=1): + data = { + "prompt": prompt, + "temperature": temperature, + "max_tokens": max_tokens, + "stop": stop, + "n": n, + } + res = requests.post(url, json=data) + assert res.status_code == 200 + if n == 1: + pred = res.json()["text"][0][len(prompt) :] + else: + pred = [x[len(prompt) :] for x in res.json()["text"]] + return pred + + +def call_generate_srt_raw(prompt, temperature, max_tokens, stop, url): + data = { + "text": prompt, + "sampling_params": { + "temperature": temperature, + "max_new_tokens": max_tokens, + "stop": stop, + }, + } + res = requests.post(url, json=data) + assert res.status_code == 200 + obj = res.json() + pred = obj["text"] + return pred + + +def call_select_lightllm(context, choices, url): + scores = [] + for i in range(len(choices)): + data = { + "inputs": context + choices[i], + "parameters": { + "max_new_tokens": 1, + }, + } + res = requests.post(url, json=data) + assert res.status_code == 200 + scores.append(0) + return np.argmax(scores) + + +def call_select_vllm(context, choices, url): + scores = [] + for i in range(len(choices)): + data = { + "prompt": context + choices[i], + "max_tokens": 1, + "prompt_logprobs": 1, + } + res = requests.post(url, json=data) + assert res.status_code == 200 + scores.append(res.json()["prompt_score"]) + return np.argmax(scores) + + """ + Modify vllm/entrypoints/api_server.py + + if final_output.prompt_logprobs is not None: + score = np.mean([prob[t_id] for t_id, prob in zip(final_output.prompt_token_ids[1:], final_output.prompt_logprobs[1:])]) + ret["prompt_score"] = score + """ + + +def add_common_other_args_and_parse(parser): + parser.add_argument("--parallel", type=int, default=96) + parser.add_argument("--host", type=str, default="http://127.0.0.1") + parser.add_argument("--port", type=int, default=None) + parser.add_argument( + "--backend", + type=str, + required=True, + choices=["vllm", "lightllm", "guidance", "lmql", "srt-raw", "llama.cpp"], + ) + parser.add_argument( + "--model-path", type=str, default="meta-llama/Llama-2-7b-chat-hf" + ) + parser.add_argument("--result-file", type=str, default="result.jsonl") + args = parser.parse_args() + + if args.port is None: + default_port = { + "vllm": 21000, + "lightllm": 22000, + "lmql": 23000, + "srt-raw": 30000, + } + args.port = default_port.get(args.backend, None) + return args + + +def add_common_sglang_args_and_parse(parser): + parser.add_argument("--parallel", type=int, default=64) + parser.add_argument("--host", type=str, default="http://127.0.0.1") + parser.add_argument("--port", type=int, default=30000) + parser.add_argument("--backend", type=str, default="srt") + parser.add_argument("--result-file", type=str, default="result.jsonl") + args = parser.parse_args() + return args + + +def select_sglang_backend(args): + if args.backend.startswith("srt"): + if args.backend == "srt-no-parallel": + global_config.enable_parallel_decoding = False + global_config.enable_parallel_encoding = False + backend = RuntimeEndpoint(f"{args.host}:{args.port}") + elif args.backend.startswith("gpt"): + backend = OpenAI(args.backend) + else: + raise ValueError(f"Invalid backend: {args.backend}") + return backend diff --git a/python/sglang/utils.py b/python/sglang/utils.py new file mode 100644 index 000000000..c3a40b0ce --- /dev/null +++ b/python/sglang/utils.py @@ -0,0 +1,179 @@ +"""Common utilities.""" + +import base64 +import json +import threading +import urllib.request +from io import BytesIO +from json import dumps + +import requests + + +def get_available_gpu_memory(gpu_id, distributed=True): + """ + Get available memory for cuda:gpu_id device. + When distributed is True, the available memory is the minimum available memory of all GPUs. + """ + import torch + + num_gpus = torch.cuda.device_count() + assert gpu_id < num_gpus + + if torch.cuda.current_device() != gpu_id: + print( + f"WARN: current device is not {gpu_id}, but {torch.cuda.current_device()}, ", + "which may cause useless memory allocation for torch CUDA context.", + ) + + free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id) + + if distributed: + tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to( + torch.device("cuda", gpu_id) + ) + torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN) + free_gpu_memory = tensor.item() + + return free_gpu_memory / (1 << 30) + + +def is_same_type(values): + """Return whether the elements in values are of the same type.""" + if len(values) <= 1: + return True + else: + t = type(values[0]) + return all(isinstance(v, t) for v in values[1:]) + + +def read_jsonl(filename: str): + """Read a JSONL file.""" + rets = [] + with open(filename) as fin: + for line in fin: + if line.startswith("#"): + continue + rets.append(json.loads(line)) + return rets + + +def dump_state_text(filename, states, mode="w"): + """Dump program state in a text file.""" + from sglang.lang.interpreter import ProgramState + + with open(filename, mode) as fout: + for i, s in enumerate(states): + if isinstance(s, str): + pass + elif isinstance(s, ProgramState): + s = s.text().strip() + else: + s = str(s) + + fout.write( + "=" * 40 + f" {i} " + "=" * 40 + "\n" + s + "\n" + "=" * 80 + "\n\n" + ) + + +class HttpResponse: + def __init__(self, resp): + self.resp = resp + + def json(self): + return json.loads(self.resp.read()) + + @property + def status_code(self): + return self.resp.status + + +def http_request(url, json=None, stream=False): + """A faster version of requests.post with low-level urllib API.""" + if stream: + return requests.post(url, json=json, stream=True) + else: + req = urllib.request.Request(url) + req.add_header("Content-Type", "application/json; charset=utf-8") + if json is None: + data = None + else: + data = bytes(dumps(json), encoding="utf-8") + resp = urllib.request.urlopen(req, data=data) + return HttpResponse(resp) + + +def encode_image_base64(image_path): + """Encode an image in base64.""" + if isinstance(image_path, str): + with open(image_path, "rb") as image_file: + data = image_file.read() + return base64.b64encode(data).decode("utf-8") + elif isinstance(image_path, bytes): + return base64.b64encode(image_path).decode("utf-8") + else: + # image_path is PIL.WebPImagePlugin.WebPImageFile + image = image_path + buffered = BytesIO() + image.save(buffered, format="PNG") + return base64.b64encode(buffered.getvalue()).decode("utf-8") + + +def _is_chinese_char(cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) # + or (cp >= 0x20000 and cp <= 0x2A6DF) # + or (cp >= 0x2A700 and cp <= 0x2B73F) # + or (cp >= 0x2B740 and cp <= 0x2B81F) # + or (cp >= 0x2B820 and cp <= 0x2CEAF) # + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) # + ): # + return True + + return False + + +def find_printable_text(text): + """Returns the longest printable substring of text that contains only entire words.""" + # Borrowed from https://github.com/huggingface/transformers/blob/061580c82c2db1de9139528243e105953793f7a2/src/transformers/generation/streamers.py#L99 + + # After the symbol for a new line, we flush the cache. + if text.endswith("\n"): + return text + # If the last token is a CJK character, we print the characters. + elif len(text) > 0 and _is_chinese_char(ord(text[-1])): + return text + # Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words, + # which may change with the subsequent token -- there are probably smarter ways to do this!) + else: + return text[: text.rfind(" ") + 1] + + +def run_with_timeout(func, args=(), kwargs=None, timeout=None): + """Run a function with timeout.""" + ret_value = [] + + def _target_func(): + ret_value.append(func(*args, **(kwargs or {}))) + + t = threading.Thread(target=_target_func) + t.start() + t.join(timeout=timeout) + if t.is_alive(): + raise TimeoutError() + + if not ret_value: + raise RuntimeError() + + return ret_value[0] diff --git a/test/killall_python.sh b/test/killall_python.sh new file mode 100644 index 000000000..ae9de8701 --- /dev/null +++ b/test/killall_python.sh @@ -0,0 +1 @@ +kill -9 $(ps aux | grep 'python' | grep -v 'grep' | awk '{print $2}') diff --git a/test/lang/run_all.py b/test/lang/run_all.py new file mode 100644 index 000000000..cb5da1585 --- /dev/null +++ b/test/lang/run_all.py @@ -0,0 +1,60 @@ +import argparse +import glob +import multiprocessing +import os +import time +import unittest + +from sglang.utils import run_with_timeout + + +def run_unittest_files(files, args): + for filename in files: + + def func(): + print(filename) + ret = unittest.main(module=None, argv=["", "-vb"] + [filename]) + + p = multiprocessing.Process(target=func) + + def run_one_file(): + p.start() + p.join() + + try: + run_with_timeout(run_one_file, timeout=args.time_limit_per_file) + if p.exitcode != 0: + return False + except TimeoutError: + p.terminate() + time.sleep(5) + print( + f"\nTimeout after {args.time_limit_per_file} seconds " + f"when running {filename}" + ) + return False + + return True + + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser() + arg_parser.add_argument( + "--time-limit-per-file", + type=int, + default=1000, + help="The time limit for running one file in seconds.", + ) + args = arg_parser.parse_args() + + files = glob.glob("**/test_*.py", recursive=True) + + tic = time.time() + success = run_unittest_files(files, args) + + if success: + print(f"Success. Time elapsed: {time.time() - tic:.2f}s") + else: + print(f"Fail. Time elapsed: {time.time() - tic:.2f}s") + + exit(0 if success else -1) diff --git a/test/lang/test_anthropic_backend.py b/test/lang/test_anthropic_backend.py new file mode 100644 index 000000000..988007289 --- /dev/null +++ b/test/lang/test_anthropic_backend.py @@ -0,0 +1,35 @@ +import json +import unittest + +from sglang.test.test_programs import test_mt_bench, test_stream + +from sglang import Anthropic, set_default_backend + + +class TestAnthropicBackend(unittest.TestCase): + backend = None + chat_backend = None + + def setUp(self): + cls = type(self) + + if cls.backend is None: + cls.backend = Anthropic("claude-2") + set_default_backend(cls.backend) + + def test_mt_bench(self): + test_mt_bench() + + def test_stream(self): + test_stream() + + +if __name__ == "__main__": + unittest.main(warnings="ignore") + + # from sglang.global_config import global_config + + # global_config.verbosity = 2 + # t = TestAnthropicBackend() + # t.setUp() + # t.test_mt_bench() diff --git a/test/lang/test_bind_pin.py b/test/lang/test_bind_pin.py new file mode 100644 index 000000000..38e5daa41 --- /dev/null +++ b/test/lang/test_bind_pin.py @@ -0,0 +1,54 @@ +import unittest + +from sglang.backend.runtime_endpoint import RuntimeEndpoint + +import sglang as sgl + + +class TestBind(unittest.TestCase): + backend = None + + def setUp(self): + cls = type(self) + + if cls.backend is None: + cls.backend = RuntimeEndpoint(base_url="http://localhost:30000") + + def test_bind(self): + @sgl.function + def few_shot_qa(s, prompt, question): + s += prompt + s += "Q: What is the capital of France?\n" + s += "A: Paris\n" + s += "Q: " + question + "\n" + s += "A:" + sgl.gen("answer", stop="\n") + + few_shot_qa_2 = few_shot_qa.bind( + prompt="The following are questions with answers.\n\n" + ) + + tracer = few_shot_qa_2.trace() + print(tracer.last_node.print_graph_dfs() + "\n") + + def test_pin(self): + @sgl.function + def few_shot_qa(s, prompt, question): + s += prompt + s += "Q: What is the capital of France?\n" + s += "A: Paris\n" + s += "Q: " + question + "\n" + s += "A:" + sgl.gen("answer", stop="\n") + + few_shot_qa_2 = few_shot_qa.bind( + prompt="Answer the following questions as if you were a 5-year-old kid.\n\n" + ) + few_shot_qa_2.pin(self.backend) + few_shot_qa_2.unpin(self.backend) + + +if __name__ == "__main__": + unittest.main(warnings="ignore") + + # t = TestBind() + # t.setUp() + # t.test_pin() diff --git a/test/lang/test_openai_backend.py b/test/lang/test_openai_backend.py new file mode 100644 index 000000000..f5590e13d --- /dev/null +++ b/test/lang/test_openai_backend.py @@ -0,0 +1,91 @@ +import unittest + +from sglang.test.test_programs import ( + test_decode_int, + test_decode_json, + test_expert_answer, + test_few_shot_qa, + test_image_qa, + test_mt_bench, + test_parallel_decoding, + test_parallel_encoding, + test_react, + test_select, + test_stream, + test_tool_use, +) + +from sglang import OpenAI, set_default_backend + + +class TestOpenAIBackend(unittest.TestCase): + backend = None + chat_backend = None + chat_vision_backend = None + + def setUp(self): + cls = type(self) + + if cls.backend is None: + cls.backend = OpenAI("gpt-3.5-turbo-instruct") + cls.chat_backend = OpenAI("gpt-3.5-turbo") + cls.chat_vision_backend = OpenAI("gpt-4-vision-preview") + + def test_few_shot_qa(self): + set_default_backend(self.backend) + test_few_shot_qa() + + def test_mt_bench(self): + set_default_backend(self.chat_backend) + test_mt_bench() + + def test_select(self): + set_default_backend(self.backend) + test_select(check_answer=True) + + def test_decode_int(self): + set_default_backend(self.backend) + test_decode_int() + + def test_decode_json(self): + set_default_backend(self.backend) + test_decode_json() + + def test_expert_answer(self): + set_default_backend(self.backend) + test_expert_answer() + + def test_tool_use(self): + set_default_backend(self.backend) + test_tool_use() + + def test_react(self): + set_default_backend(self.backend) + test_react() + + def test_parallel_decoding(self): + set_default_backend(self.backend) + test_parallel_decoding() + + def test_parallel_encoding(self): + set_default_backend(self.backend) + test_parallel_encoding() + + def test_image_qa(self): + set_default_backend(self.chat_vision_backend) + test_image_qa() + + def test_stream(self): + set_default_backend(self.backend) + test_stream() + + +if __name__ == "__main__": + unittest.main(warnings="ignore") + + # from sglang.global_config import global_config + + # global_config.verbosity = 2 + # t = TestOpenAIBackend() + # t.setUp() + # t.test_decode_json() diff --git a/test/lang/test_srt_backend.py b/test/lang/test_srt_backend.py new file mode 100644 index 000000000..374c53db9 --- /dev/null +++ b/test/lang/test_srt_backend.py @@ -0,0 +1,74 @@ +""" +python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 +""" +import json +import unittest + +from sglang.test.test_programs import ( + test_decode_int, + test_decode_json, + test_expert_answer, + test_few_shot_qa, + test_mt_bench, + test_parallel_decoding, + test_parallel_encoding, + test_react, + test_regex, + test_select, + test_stream, + test_tool_use, +) + +import sglang as sgl + + +class TestSRTBackend(unittest.TestCase): + backend = None + + def setUp(self): + cls = type(self) + + if cls.backend is None: + cls.backend = sgl.RuntimeEndpoint(base_url="http://localhost:30000") + sgl.set_default_backend(cls.backend) + + def test_few_shot_qa(self): + test_few_shot_qa() + + def test_mt_bench(self): + test_mt_bench() + + def test_select(self): + test_select(check_answer=False) + + def test_decode_int(self): + test_decode_int() + + def test_expert_answer(self): + test_expert_answer() + + def test_tool_use(self): + test_tool_use() + + def test_parallel_decoding(self): + test_parallel_decoding() + + def test_stream(self): + test_stream() + + def test_regex(self): + test_regex() + + # def test_parallel_encoding(self): + # test_parallel_encoding(check_answer=False) + + +if __name__ == "__main__": + unittest.main(warnings="ignore") + + # from sglang.global_config import global_config + + # global_config.verbosity = 2 + # t = TestSRTBackend() + # t.setUp() + # t.test_regex() diff --git a/test/lang/test_tracing.py b/test/lang/test_tracing.py new file mode 100644 index 000000000..4b162983a --- /dev/null +++ b/test/lang/test_tracing.py @@ -0,0 +1,132 @@ +import unittest + +from sglang.backend.base_backend import BaseBackend +from sglang.lang.chat_template import get_chat_template + +import sglang as sgl + + +class TestTracing(unittest.TestCase): + def test_few_shot_qa(self): + @sgl.function + def few_shot_qa(s, question): + s += "The following are questions with answers.\n\n" + s += "Q: What is the capital of France?\n" + s += "A: Paris\n" + s += "Q: " + question + "\n" + s += "A:" + sgl.gen("answer", stop="\n") + + tracer = few_shot_qa.trace() + print(tracer.last_node.print_graph_dfs() + "\n") + + def test_select(self): + @sgl.function + def capital(s): + s += "The capital of France is" + s += sgl.select("capital", ["Paris. ", "London. "]) + s += "It is a city" + sgl.gen("description", stop=".") + + tracer = capital.trace() + print(tracer.last_node.print_graph_dfs() + "\n") + + def test_raise_warning(self): + @sgl.function + def wrong(s, question): + s += f"I want to ask {question}" + + try: + tracer = wrong.trace() + raised = False + except TypeError: + raised = True + + assert raised + + def test_multi_function(self): + @sgl.function + def expand(s, tip): + s += ( + "Please expand the following tip into a detailed paragraph:" + + tip + + "\n" + ) + s += sgl.gen("detailed_tip") + + @sgl.function + def tip_suggestion(s, topic): + s += "Here are 2 tips for " + topic + ".\n" + + s += "1." + sgl.gen("tip_1", stop=["\n", ":", "."]) + "\n" + s += "2." + sgl.gen("tip_2", stop=["\n", ":", "."]) + "\n" + + branch1 = expand(tip=s["tip_1"]) + branch2 = expand(tip=s["tip_2"]) + + s += "Tip 1: " + branch1["detailed_tip"] + "\n" + s += "Tip 2: " + branch2["detailed_tip"] + "\n" + s += "In summary" + sgl.gen("summary") + + compiled = tip_suggestion.compile() + compiled.print_graph() + + sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct")) + state = compiled.run(topic="staying healthy") + print(state.text() + "\n") + + states = compiled.run_batch( + [ + {"topic": "staying healthy"}, + {"topic": "staying happy"}, + {"topic": "earning money"}, + ], + temperature=0, + ) + for s in states: + print(s.text() + "\n") + + def test_role(self): + @sgl.function + def multi_turn_chat(s): + s += sgl.user("Who are you?") + s += sgl.assistant(sgl.gen("answer_1")) + s += sgl.user("Who created you?") + s += sgl.assistant(sgl.gen("answer_2")) + + backend = BaseBackend() + backend.chat_template = get_chat_template("llama-2-chat") + + compiled = multi_turn_chat.compile(backend=backend) + compiled.print_graph() + + def test_fork(self): + @sgl.function + def tip_suggestion(s): + s += ( + "Here are three tips for staying healthy: " + "1. Balanced Diet; " + "2. Regular Exercise; " + "3. Adequate Sleep\n" + ) + + forks = s.fork(3) + for i in range(3): + forks[i] += f"Now, expand tip {i+1} into a paragraph:\n" + forks[i] += sgl.gen(f"detailed_tip") + + s += "Tip 1:" + forks[0]["detailed_tip"] + "\n" + s += "Tip 2:" + forks[1]["detailed_tip"] + "\n" + s += "Tip 3:" + forks[2]["detailed_tip"] + "\n" + s += "In summary" + sgl.gen("summary") + + tracer = tip_suggestion.trace() + print(tracer.last_node.print_graph_dfs()) + + a = tip_suggestion.run(backend=sgl.OpenAI("gpt-3.5-turbo-instruct")) + print(a.text()) + + +if __name__ == "__main__": + unittest.main(warnings="ignore") + + # t = TestTracing() + # t.test_fork() diff --git a/test/srt/model/bench_llama_low_api.py b/test/srt/model/bench_llama_low_api.py new file mode 100644 index 000000000..3e3534709 --- /dev/null +++ b/test/srt/model/bench_llama_low_api.py @@ -0,0 +1,274 @@ +import multiprocessing as mp +import time +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from sglang.srt.managers.router.model_runner import ModelRunner +from sglang.srt.model_config import ModelConfig + + +@dataclass +class BenchBatch: + req_to_token_pool: torch.Tensor + token_to_kv_pool: torch.Tensor + + input_ids: torch.Tensor = None + position_ids_offsets: torch.Tensor = None + seq_lens: torch.Tensor = None + prefix_lens: torch.Tensor = None + req_pool_indices: torch.Tensor = None + out_cache_loc: torch.Tensor = None + out_cache_cont_start: torch.Tensor = None + out_cache_cont_end: torch.Tensor = None + + def __init__(self, model_runner: ModelRunner): + self.req_to_token_pool = model_runner.req_to_token_pool + self.token_to_kv_pool = model_runner.token_to_kv_pool + + def init_prefill_batch(self, input_ids, batch_size, seq_len): + self.input_ids = input_ids + self.position_ids_offsets = torch.zeros( + batch_size, dtype=torch.int32, device="cuda" + ) + self.seq_lens = torch.full( + (batch_size,), seq_len, dtype=torch.int32, device="cuda" + ) + self.prefix_lens = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + self.req_pool_indices = self.req_to_token_pool.alloc(batch_size) + self.out_cache_loc = self.token_to_kv_pool.alloc(batch_size * seq_len) + + for i in range(batch_size): + n_idx = self.req_pool_indices[i].item() + self.req_to_token_pool.req_to_token[n_idx, :seq_len] = self.out_cache_loc[ + i * seq_len : (i + 1) * seq_len + ] + + def update_extend( + self, input_ids, batch_size, prefix_len, extend_len, prefix_req_idx + ): + self.input_ids = input_ids + self.position_ids_offsets = torch.zeros( + batch_size, dtype=torch.int32, device="cuda" + ) + self.seq_lens = torch.full( + (batch_size,), prefix_len + extend_len, dtype=torch.int32, device="cuda" + ) + self.prefix_lens = torch.full( + (batch_size,), prefix_len, dtype=torch.int32, device="cuda" + ) + self.req_pool_indices = self.req_to_token_pool.alloc(batch_size) + self.out_cache_loc = self.token_to_kv_pool.alloc(batch_size * extend_len) + + req_to_token = self.req_to_token_pool.req_to_token + fork_num = batch_size // prefix_req_idx.shape[0] + for i in range(batch_size): + p_idx = prefix_req_idx[i // fork_num].item() + n_idx = self.req_pool_indices[i].item() + req_to_token[n_idx, :prefix_len] = req_to_token[p_idx, :prefix_len] + req_to_token[ + n_idx, prefix_len : prefix_len + extend_len + ] = self.out_cache_loc[i * extend_len : (i + 1) * extend_len] + + def update_decode(self, predict_ids, batch_size): + assert predict_ids.shape[0] == batch_size + assert batch_size == self.req_pool_indices.shape[0] + + self.input_ids = predict_ids.reshape(-1) + self.prefix_lens = None + ( + self.out_cache_loc, + self.out_cache_cont_start, + self.out_cache_cont_end, + ) = self.token_to_kv_pool.alloc_contiguous(batch_size) + self.req_to_token_pool.req_to_token[ + self.req_pool_indices, self.seq_lens + ] = self.out_cache_loc + self.seq_lens.add_(1) + + +def prefill(model_runner: ModelRunner, batch: BenchBatch): + logits, _ = model_runner.forward_extend( + batch.input_ids, + batch.req_pool_indices, + batch.seq_lens, + batch.prefix_lens, + batch.position_ids_offsets, + batch.out_cache_loc, + False, + ) + + prob_out = torch.softmax(logits, dim=-1) + predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) + predict_ids = predict_ids.detach().cpu().numpy() + + return predict_ids + + +def extend(model_runner: ModelRunner, batch: BenchBatch): + logits, _ = model_runner.forward_extend( + batch.input_ids, + batch.req_pool_indices, + batch.seq_lens, + batch.prefix_lens, + batch.position_ids_offsets, + batch.out_cache_loc, + True, + ) + + prob_out = torch.softmax(logits, dim=-1) + predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) + predict_ids = predict_ids.detach().cpu().numpy() + + return predict_ids + + +def decode(model_runner: ModelRunner, batch: BenchBatch): + logits = model_runner.forward_decode( + batch.input_ids, + batch.req_pool_indices, + batch.seq_lens, + None, + batch.position_ids_offsets, + None, + batch.out_cache_cont_start, + batch.out_cache_cont_end, + ) + + prob_out = torch.softmax(logits, dim=-1) + predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) + predict_ids = predict_ids.detach().cpu().numpy() + + return predict_ids + + +def bench_generate_worker( + model_path, + tp_rank, + tp_size, + shared_num, + unique_num, + shared_len, + unique_len, + decode_len, + model_mode, +): + assert unique_num % shared_num == 0 + + model_config = ModelConfig(path=model_path) + model_runner = ModelRunner( + model_config=model_config, + mem_fraction_static=0.8, + tp_rank=tp_rank, + tp_size=tp_size, + nccl_port=28888, + model_mode=model_mode, + ) + + batch = BenchBatch(model_runner) + + # warm up + for _ in range(1): + input_ids = torch.randint( + low=5, high=100, size=(shared_num * shared_len,) + ).cuda() + batch.init_prefill_batch(input_ids, shared_num, shared_len) + _ = prefill(model_runner, batch) + + input_ids = torch.randint( + low=5, high=100, size=(unique_num * unique_len,) + ).cuda() + batch.update_extend( + input_ids, unique_num, shared_len, unique_len, batch.req_pool_indices + ) + predict_ids = extend(model_runner, batch) + + for i in range(decode_len): + predict_ids = torch.from_numpy(predict_ids).cuda() + batch.update_decode(predict_ids, unique_num) + predict_ids = decode(model_runner, batch) + + model_runner.req_to_token_pool.clear() + model_runner.token_to_kv_pool.clear() + + if tp_size > 1: + dist.barrier() + + prefill_start = time.time() + input_ids = torch.randint(low=5, high=100, size=(shared_num * shared_len,)).cuda() + batch.init_prefill_batch(input_ids, shared_num, shared_len) + _ = prefill(model_runner, batch) + if tp_rank == 0: + print(f"prefill: {(time.time() - prefill_start) * 1000:.2f} ms") + + extend_start = time.time() + input_ids = torch.randint(low=5, high=100, size=(unique_num * unique_len,)).cuda() + batch.update_extend( + input_ids, unique_num, shared_len, unique_len, batch.req_pool_indices + ) + predict_ids = extend(model_runner, batch) + if tp_rank == 0: + print(f"extend: {(time.time() - extend_start) * 1000:.2f} ms") + + for i in range(decode_len): + decode_start = time.time() + predict_ids = torch.from_numpy(predict_ids).cuda() + batch.update_decode(predict_ids, unique_num) + predict_ids = decode(model_runner, batch) + if tp_rank == 0: + print(f"decode {i}: {(time.time() - decode_start) * 1000:.2f} ms") + + +def bench_generate( + model_path, + tp_size, + shared_num, + unique_num, + shared_len, + unique_len, + decode_len, + model_mode, +): + print( + f"tp_size: {tp_size}, " + f"shared_num: {shared_num}, " + f"unique_num: {unique_num}, " + f"shared_len: {shared_len}, " + f"unique_len: {unique_len}, " + f"decode_len: {decode_len}, " + f"model_mode: {model_mode}" + ) + workers = [] + for tp_rank in range(tp_size): + proc = mp.Process( + target=bench_generate_worker, + args=( + model_path, + tp_rank, + tp_size, + shared_num, + unique_num, + shared_len, + unique_len, + decode_len, + model_mode, + ), + ) + proc.start() + workers.append(proc) + + for proc in workers: + proc.join() + + +if __name__ == "__main__": + bench_generate( + model_path="meta-llama/Llama-2-7b-chat-hf", + tp_size=1, + shared_num=1, + unique_num=32, + shared_len=256, + unique_len=256, + decode_len=8, + model_mode=[], + ) diff --git a/test/srt/model/reference_hf.py b/test/srt/model/reference_hf.py new file mode 100644 index 000000000..e63866f02 --- /dev/null +++ b/test/srt/model/reference_hf.py @@ -0,0 +1,80 @@ +import argparse +import os + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + + +@torch.inference_mode() +def normal_text(args): + t = AutoTokenizer.from_pretrained(args.model_path) + m = AutoModelForCausalLM.from_pretrained( + args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + m.cuda() + + print(m) + + prompts = [ + "The capital of France is", + "The capital of the United Kindom is", + "Today is a sunny day and I like", + ] + max_new_tokens = 32 + + for p in prompts: + if isinstance(p, str): + input_ids = t.encode(p, return_tensors="pt").cuda() + else: + input_ids = torch.tensor([p], device="cuda") + + output_ids = m.generate( + input_ids, do_sample=False, max_new_tokens=max_new_tokens + ) + output_str = t.decode(output_ids[0]) + print(output_str) + + prefill_logits = m.forward(input_ids).logits[0][-1] + print("prefill logits", prefill_logits) + + +@torch.inference_mode() +def synthetic_tokens(args): + t = AutoTokenizer.from_pretrained(args.model_path) + m = AutoModelForCausalLM.from_pretrained( + args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + m.cuda() + print(m) + + input_len = 256 + output_len = 8 + prompts = [list(range(5, 5 + input_len))] + + for p in prompts: + input_ids = p + for i in range(output_len + 1): + prefill_logits = m.forward(torch.tensor([input_ids], device="cuda")).logits[ + 0 + ][-1] + + if i == 0: + print("prefill logits", prefill_logits) + else: + print("decode", i - 1, prefill_logits) + + input_ids.append(torch.argmax(prefill_logits).item()) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--model-path", + type=str, + default="TinyLlama/TinyLlama-1.1B-Chat-v0.4", + # default="meta-llama/Llama-2-7b-chat-hf", + ) + args = parser.parse_args() + + normal_text(args) + # synthetic_tokens(args) diff --git a/test/srt/model/test_llama_extend.py b/test/srt/model/test_llama_extend.py new file mode 100644 index 000000000..fdd7bbb13 --- /dev/null +++ b/test/srt/model/test_llama_extend.py @@ -0,0 +1,108 @@ +import multiprocessing +import os +import time + +import numpy as np +import torch +import torch.distributed as dist +import transformers +from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req +from sglang.srt.managers.router.model_runner import ModelRunner +from sglang.srt.model_config import ModelConfig +from sglang.srt.sampling_params import SamplingParams + + +def test_generate_worker(model_path, tp_rank, tp_size): + model_config = ModelConfig(path=model_path) + model = ModelRunner(model_config, 0.8, tp_rank, tp_size, 28888) + tokenizer = transformers.AutoTokenizer.from_pretrained(model_path) + + # Input + prompts = [ + "The capital of France is", + "Today is a sunny day and I like", + ] + sampling_params = SamplingParams(temperature=0) + + cut_num = 4 + + reqs = [] + for i in range(len(prompts)): + req = Req(i) + req.input_ids = tokenizer.encode(prompts[i])[:cut_num] + req.sampling_params = sampling_params + reqs.append(req) + + # Prefill + batch = Batch(reqs, model.req_to_token_pool, model.token_to_kv_pool, None) + batch.init_extend_batch(model.model_config.vocab_size(), None) + logits, _ = model.forward(batch, ForwardMode.EXTEND) + next_token_ids, next_token_probs = batch.sample(logits) + print("extend logits (first)", logits) + + # Extend + for i in range(len(prompts)): + req = reqs[i] + req.input_ids += tokenizer.encode(prompts[i])[cut_num:] + req.prefix_indices = model.req_to_token_pool.req_to_token[ + batch.req_pool_indices[i], :cut_num + ] + batch = Batch(reqs, model.req_to_token_pool, model.token_to_kv_pool, None) + batch.init_extend_batch(model.model_config.vocab_size(), None) + logits, _ = model.forward(batch, ForwardMode.EXTEND) + next_token_ids, next_token_probs = batch.sample(logits) + + print("extend logits", logits) + print( + "next_token_ids", next_token_ids, [tokenizer.decode(x) for x in next_token_ids] + ) + + # Decode + for i in range(6): + batch.update_for_decode(next_token_ids.cpu().numpy()) + logits = model.forward(batch, ForwardMode.DECODE) + next_token_ids, next_token_probs = batch.sample(logits) + + print( + "next_token_ids", + next_token_ids, + [tokenizer.decode(x) for x in next_token_ids], + ) + + +def test_generate(model_path, tp_size): + workers = [] + for tp_rank in range(tp_size): + proc = multiprocessing.Process( + target=test_generate_worker, + args=( + model_path, + tp_rank, + tp_size, + ), + ) + proc.start() + workers.append(proc) + + for proc in workers: + proc.join() + + +if __name__ == "__main__": + os.environ["TOKENIZERS_PARALLELISM"] = "false" + test_generate("TinyLlama/TinyLlama-1.1B-Chat-v0.4", 1) + + # Reference output for TinyLlama-1.1B-Chat-v0.4 + # extend logits (first) tensor([[-10.0312, -9.5000, 0.8896, ..., -4.9375, -3.2402, -3.3633], + # [ -9.1797, -10.2500, 2.7168, ..., -4.3359, -4.0664, -4.1289]], + # device='cuda:0', dtype=torch.float16) + # extend logits tensor([[-8.3125, -7.1172, 3.3359, ..., -4.9531, -4.1289, -3.4121], + # [-9.6406, -9.0547, 4.0195, ..., -5.3086, -4.7188, -4.4609]], + # device='cuda:0', dtype=torch.float16) + # next_token_ids tensor([3681, 304], device='cuda:0') ['Paris', 'to'] + # next_token_ids tensor([29889, 748], device='cuda:0') ['.', 'go'] + # next_token_ids tensor([ 13, 363], device='cuda:0') ['\n', 'for'] + # next_token_ids tensor([1576, 263], device='cuda:0') ['The', 'a'] + # next_token_ids tensor([7483, 6686], device='cuda:0') ['capital', 'walk'] + # next_token_ids tensor([310, 297], device='cuda:0') ['of', 'in'] + # next_token_ids tensor([278, 278], device='cuda:0') ['the', 'the'] diff --git a/test/srt/model/test_llama_low_api.py b/test/srt/model/test_llama_low_api.py new file mode 100644 index 000000000..e556ec7eb --- /dev/null +++ b/test/srt/model/test_llama_low_api.py @@ -0,0 +1,209 @@ +import multiprocessing +import time + +import numpy as np +import torch +import torch.distributed as dist +from sglang.srt.managers.router.model_runner import ModelRunner +from sglang.srt.model_config import ModelConfig + + +def test_generate_worker( + model_path, tp_rank, tp_size, batch_size, input_len, output_len +): + model_config = ModelConfig(path=model_path) + model = ModelRunner(model_config, 0.8, tp_rank, tp_size, 28888) + + # Prepare data + input_ids = np.vstack([np.arange(5, input_len + 5) for _ in range(batch_size)]) + input_ids = input_ids.reshape(-1) + input_ids = torch.tensor(input_ids).cuda() + + def init_batch_data(model, batch_size, input_len): + req_pool_indices = model.req_to_token_pool.alloc(batch_size) + seq_lens = torch.full( + (batch_size,), input_len, dtype=torch.int32, device="cuda" + ) + prefix_lens = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + position_ids_offsets = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + + out_cache_loc = model.token_to_kv_pool.alloc(batch_size * input_len) + for i in range(batch_size): + req_idx = req_pool_indices[i].item() + model.req_to_token_pool.req_to_token[req_idx, :input_len] = out_cache_loc[ + i * input_len : (i + 1) * input_len + ] + + return ( + req_pool_indices, + seq_lens, + prefix_lens, + position_ids_offsets, + out_cache_loc, + ) + + def prefill(print_logits): + nonlocal predict_ids + + logits, _ = model.forward_prefill( + input_ids, + req_pool_indices, + seq_lens, + prefix_lens, + position_ids_offsets, + out_cache_loc, + False, + ) + prob_out = torch.softmax(logits, dim=-1) + predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) + predict_ids = predict_ids.detach().cpu().numpy() + + if print_logits and tp_rank == 0: + print("prefill logits", logits, logits.shape) + + def decode(print_logits): + nonlocal predict_ids + + ( + out_cache_loc, + out_cache_cont_start, + out_cache_cont_end, + ) = model.token_to_kv_pool.alloc_contiguous(batch_size) + model.req_to_token_pool.req_to_token[req_pool_indices, seq_lens] = out_cache_loc + seq_lens.add_(1) + logits = model.forward_decode( + torch.from_numpy(predict_ids).cuda().reshape(-1), + req_pool_indices, + seq_lens, + None, + position_ids_offsets, + None, + out_cache_cont_start, + out_cache_cont_end, + ) + prob_out = torch.softmax(logits, dim=-1) + predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) + predict_ids = predict_ids.detach().cpu().numpy() + if print_logits and tp_rank == 0: + print("decode", i, logits) + + # Warm up + ( + req_pool_indices, + seq_lens, + prefix_lens, + position_ids_offsets, + out_cache_loc, + ) = init_batch_data(model, batch_size, input_len) + predict_ids = None + + prefill(True) + for i in range(output_len): + decode(True) + + for i in range(batch_size): + req_idx = req_pool_indices[i].item() + model.token_to_kv_pool.free( + model.req_to_token_pool.req_to_token[req_idx, : seq_lens[i]] + ) + model.req_to_token_pool.free(req_pool_indices) + + # Benchmark + if tp_size > 1: + dist.barrier() + start_time = prefill_start_time = time.time() + + ( + req_pool_indices, + seq_lens, + prefix_lens, + position_ids_offsets, + out_cache_loc, + ) = init_batch_data(model, batch_size, input_len) + + prefill(False) + + if tp_rank == 0: + print(f"prefill cost: {(time.time() - prefill_start_time) * 1000:.2f} ms") + + for i in range(output_len): + step_start = time.time() + + decode(False) + + step_end = time.time() + + if i % 100 == 0 or i == output_len - 1: + if tp_rank == 0: + print(f"step {i} cost: {(step_end - step_start) * 1000:.2f} ms") + + end_time = time.time() + + if tp_rank == 0: + print(f"total cost: {(end_time - start_time) * 1000:.2f}") + + +def test_generate(model_path, tp_size, batch_size, input_len, output_len): + workers = [] + for tp_rank in range(tp_size): + proc = multiprocessing.Process( + target=test_generate_worker, + args=( + model_path, + tp_rank, + tp_size, + batch_size, + input_len, + output_len, + ), + ) + proc.start() + workers.append(proc) + + for proc in workers: + proc.join() + + +if __name__ == "__main__": + test_generate("TinyLlama/TinyLlama-1.1B-Chat-v0.4", 1, 1, 256, 8) + # test_generate("meta-llama/Llama-2-7b-chat-hf", 1, 16, 256, 8) + + # Reference output for TinyLlama-1.1B-Chat-v0.4 (1, 32, 8) + # prefill logits tensor([[-1.3380e-03, 4.4702e-01, 2.9082e+00, ..., -1.8398e+00, + # 1.8281e+00, 2.1816e+00]], device='cuda:0') + # decode 0 tensor([[-0.3904, 0.8784, 3.6934, ..., -2.4473, 1.5811, 2.0098]], + # device='cuda:0') + # decode 1 tensor([[-0.3552, 0.0635, 2.5781, ..., -2.5820, 1.3047, 1.7607]], + # device='cuda:0') + # decode 2 tensor([[-1.5645, -1.1963, 3.8145, ..., -2.9766, 1.0244, 1.0645]], + # device='cuda:0') + # decode 3 tensor([[-1.3682, -0.6548, 4.2734, ..., -2.8711, 1.1172, 1.1494]], + # device='cuda:0') + # decode 4 tensor([[-1.0205, -0.0060, 4.4844, ..., -2.7090, 1.6143, 1.8135]], + # device='cuda:0') + # decode 5 tensor([[ 0.4260, 1.6006, 4.3633, ..., -2.2480, 2.5547, 2.8379]], + # device='cuda:0') + # decode 6 tensor([[ 0.7095, 2.1816, 5.0078, ..., -2.1309, 3.0293, 3.0840]], + # device='cuda:0') + # decode 7 tensor([[-0.2883, 1.1289, 4.7188, ..., -2.4023, 2.1055, 2.1836]], + # device='cuda:0') + + # Reference output for TinyLlama-1.1B-Chat-v0.4 (1, 256, 8) + # prefill logits tensor([[-2.5840, -2.7227, 6.8047, ..., -2.3613, 0.1224, 0.5952]], + # device='cuda:0') + # decode 0 tensor([[-0.6235, -0.7690, 9.2891, ..., -1.4922, 2.8008, 2.9531]], + # device='cuda:0') + # decode 1 tensor([[-1.3662, -1.4648, 7.1250, ..., -1.7861, 1.7363, 1.8857]], + # device='cuda:0') + # decode 2 tensor([[-0.8540, -0.5947, 9.1328, ..., -2.1211, 2.9707, 2.8945]], + # device='cuda:0') + # decode 3 tensor([[ 0.0652, 1.0312, 8.1250, ..., -2.0586, 3.4727, 3.6172]], + # device='cuda:0') + # decode 4 tensor([[-0.0459, 1.0098, 9.1406, ..., -2.1797, 3.8320, 3.9355]], + # device='cuda:0') + # decode 5 tensor([[ 0.2964, 1.3564, 9.8828, ..., -2.1602, 4.1836, 4.2422]], + # device='cuda:0') + # decode 6 tensor([[ 0.6475, 1.8105, 10.1250, ..., -2.0098, 4.2578, 4.4062]], + # device='cuda:0') + # decode 7 tensor([[ 0.4985, 1.4746, 9.9062, ..., -1.9141, 3.9863, 4.3047]], + # device='cuda:0') diff --git a/test/srt/model/test_llava_low_api.py b/test/srt/model/test_llava_low_api.py new file mode 100644 index 000000000..00cdd622f --- /dev/null +++ b/test/srt/model/test_llava_low_api.py @@ -0,0 +1,161 @@ +import multiprocessing +import time + +import numpy as np +import torch +import torch.distributed as dist +from sglang.srt.hf_transformers_utils import get_processor +from sglang.srt.managers.router.infer_batch import ForwardMode +from sglang.srt.managers.router.model_runner import InputMetadata, ModelRunner +from sglang.srt.model_config import ModelConfig +from sglang.srt.utils import load_image + + +def init_batch_data(model, batch_size, input_len): + req_pool_indices = model.req_to_token_pool.alloc(batch_size) + seq_lens = torch.full((batch_size,), input_len, dtype=torch.int32, device="cuda") + prefix_lens = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + position_ids_offsets = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + + out_cache_loc = model.token_to_kv_pool.alloc(batch_size * input_len) + for i in range(batch_size): + model.req_to_token_pool.req_to_token[i, :input_len] = out_cache_loc[ + i * input_len : (i + 1) * input_len + ] + + return ( + req_pool_indices, + seq_lens, + prefix_lens, + position_ids_offsets, + out_cache_loc, + ) + + +def prefill(model, tp_rank, params, print_logits): + logits, _ = model.forward_extend_multi_modal( + *params, + False, + ) + prob_out = torch.softmax(logits, dim=-1) + predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) + predict_ids = predict_ids.detach().cpu().numpy() + + if print_logits and tp_rank == 0: + print("prefill logits", logits, logits.shape) + + return predict_ids + + +def decode(step, model, tp_rank, batch_size, predict_ids, params, print_logits): + ( + req_pool_indices, + seq_lens, + prefix_lens, + position_ids_offsets, + out_cache_loc, + ) = params + + ( + out_cache_loc, + out_cache_cont_start, + out_cache_cont_end, + ) = model.token_to_kv_pool.alloc_contiguous(batch_size) + model.req_to_token_pool.req_to_token[req_pool_indices, seq_lens] = out_cache_loc + seq_lens.add_(1) + logits = model.forward_decode( + torch.from_numpy(predict_ids).cuda().reshape(-1), + req_pool_indices, + seq_lens, + None, + position_ids_offsets, + None, + out_cache_cont_start, + out_cache_cont_end, + ) + prob_out = torch.softmax(logits, dim=-1) + predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) + predict_ids = predict_ids.detach().cpu().numpy() + if print_logits and tp_rank == 0: + print("decode", step, logits) + return predict_ids + + +def test_generate_worker( + model_path, + tp_rank, + tp_size, +): + model_config = ModelConfig(path=model_path) + model = ModelRunner(model_config, 0.8, tp_rank, tp_size, 28888) + # print(model.model) + + # Prepare data + prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: \nDescribe this picture ASSISTANT:" + image_path = "/home/ubuntu/sglang/test/lang/image.png" + image = load_image(image_path) + + processor = get_processor("llava-hf/llava-1.5-7b-hf") + input_ids = processor.tokenizer.encode(prompt) + pixel_values = processor.image_processor(image)["pixel_values"] + input_ids, offset = model.model.pad_input_ids( + input_ids, + [ + 0, + ], + ) + + params = init_batch_data(model, 1, len(input_ids)) + + # inference + output_ids = [] + prefill_params = ( + torch.tensor(np.array(input_ids)).cuda(), + np.array(pixel_values), + [offset], + *params, + ) + predict_ids = prefill(model, tp_rank=0, params=prefill_params, print_logits=False) + output_ids.append(predict_ids[0][0]) + for i in range(16): + predict_ids = decode( + i, + model, + tp_rank=0, + batch_size=1, + predict_ids=predict_ids, + params=params, + print_logits=False, + ) + output_ids.append(predict_ids[0][0]) + + # detokenization + output = processor.tokenizer.batch_decode( + [output_ids], skip_special_tokens=True, clean_up_tokenization_spaces=False + )[0] + assert ( + output + == "The image features a man standing on the back of a yellow taxi cab, holding" + ) + + +def test_generate(model_path, tp_size): + workers = [] + for tp_rank in range(tp_size): + proc = multiprocessing.Process( + target=test_generate_worker, + args=( + model_path, + tp_rank, + tp_size, + ), + ) + proc.start() + workers.append(proc) + + for proc in workers: + proc.join() + + +if __name__ == "__main__": + test_generate("liuhaotian/llava-v1.5-7b", 1) diff --git a/test/srt/test_flashinfer.py b/test/srt/test_flashinfer.py new file mode 100644 index 000000000..3fef32e99 --- /dev/null +++ b/test/srt/test_flashinfer.py @@ -0,0 +1,163 @@ +import flashinfer +import pytest +import torch +from sglang.srt.layers.extend_attention import extend_attention_fwd +from sglang.srt.layers.token_attention import token_attention_fwd + + +@pytest.mark.parametrize("batch_size", [12, 37, 67]) +@pytest.mark.parametrize("kv_len", [54, 97]) +@pytest.mark.parametrize("qo_len", [37, 17]) +@pytest.mark.parametrize("num_kv_heads", [4]) +@pytest.mark.parametrize("num_qo_heads", [4, 32]) +@pytest.mark.parametrize("head_dim", [128]) +@pytest.mark.parametrize("use_wrapper", [True, False]) +def test_batch_prefill_with_paged_kv_cache( + batch_size, + kv_len, + qo_len, + num_kv_heads, + num_qo_heads, + head_dim, + use_wrapper, +): + q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half() + q_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len + total_tokens = kv_len * batch_size + kv_data = torch.randn(total_tokens, 2, num_kv_heads, 1, head_dim).to(0).half() + kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len + kv_indices = torch.arange(0, total_tokens).to(0).int() + kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32).to(0) + + # init args for triton kernel + k_extend = ( + kv_data.view(batch_size, kv_len, 2, -1)[:, -qo_len:, 0] + .contiguous() + .view(-1, num_kv_heads, head_dim) + ) + v_extend = ( + kv_data.view(batch_size, kv_len, 2, -1)[:, -qo_len:, 1] + .contiguous() + .view(-1, num_kv_heads, head_dim) + ) + o_triton = torch.empty_like(q) + k_buffer = kv_data[:, 0].view(-1, num_kv_heads, head_dim).contiguous() + v_buffer = kv_data[:, 1].view(-1, num_kv_heads, head_dim).contiguous() + req_to_token = torch.arange(0, total_tokens).to(0).int().view(batch_size, kv_len) + b_req_idx = torch.arange(0, batch_size).to(0).int() + b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32).to(0) + b_start_loc_extend = torch.arange(0, batch_size).to(0).int() * qo_len + b_seq_len_extend = torch.full((batch_size,), qo_len, dtype=torch.int32).to(0) + max_len_in_batch = kv_len + max_len_extend = qo_len + + extend_attention_fwd( + q, + k_extend, + v_extend, + o_triton, + k_buffer, + v_buffer, + req_to_token, + b_req_idx, + None, # b_start_loc = None + b_seq_len, + None, # b_seq_len_prefix = None + b_start_loc_extend, + b_seq_len_extend, + max_len_in_batch, + max_len_extend, + ) + + if use_wrapper: + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper() + wrapper.begin_forward(q_indptr, batch_size, num_qo_heads, num_kv_heads) + o = wrapper.forward( + q, q_indptr, kv_data, kv_indptr, kv_indices, kv_last_page_len + ) + else: + o = flashinfer.batch_prefill_with_paged_kv_cache( + q, + q_indptr, + kv_data, + kv_indptr, + kv_indices, + kv_last_page_len, + ) + + print("Mean: ", torch.mean(torch.abs(o - o_triton))) + print("Max: ", torch.max(torch.abs(o - o_triton))) + assert torch.allclose(o, o_triton, rtol=1e-2, atol=1e-3) + + +@pytest.mark.parametrize("batch_size", [12, 17, 37]) +@pytest.mark.parametrize("kv_len", [54, 127, 537]) +@pytest.mark.parametrize("num_kv_heads", [32]) +@pytest.mark.parametrize("num_qo_heads", [32]) +@pytest.mark.parametrize("head_dim", [128]) +def test_batch_decode_with_paged_kv_cache( + batch_size, + kv_len, + num_kv_heads, + num_qo_heads, + head_dim, +): + # note(lsyin): when pytest, the number of heads cannot change, because triton kernel has a cache + # to test different shape of decode, change the parameters in the __main__, and run decode only once + + q = torch.randn(batch_size, num_qo_heads, head_dim).to(0).half() + total_tokens = kv_len * batch_size + kv_data = torch.randn(total_tokens, 2, num_kv_heads, 1, head_dim).to(0).half() + kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len + kv_indices = torch.arange(0, total_tokens).to(0).int() + kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32).to(0) + + # init args for triton kernel + k_buffer = kv_data[:, 0].view(-1, num_kv_heads, head_dim).contiguous() + v_buffer = kv_data[:, 1].view(-1, num_kv_heads, head_dim).contiguous() + o_triton = torch.empty_like(q) + req_to_token = ( + torch.arange(0, kv_len * batch_size).to(0).int().view(batch_size, kv_len) + ) + b_req_idx = torch.arange(0, batch_size).to(0).int() + b_start_loc = torch.arange(0, batch_size).to(0).int() * kv_len + b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32).to(0) + max_len_in_batch = kv_len + other_kv_index = 0 + token_attention_fwd( + q, + k_buffer, + v_buffer, + o_triton, + req_to_token, + b_req_idx, + b_start_loc, + b_seq_len, + max_len_in_batch, + other_kv_index, + total_tokens, + ) + + wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper() + wrapper.begin_forward( + kv_indptr, + kv_last_page_len, + batch_size, + num_qo_heads, + num_kv_heads, + head_dim, + 1, + "NONE", + "float16", + ) + o = wrapper.forward(q, kv_data, kv_indptr, kv_indices, kv_last_page_len) + + print("Mean: ", torch.mean(torch.abs(o - o_triton))) + print("Max: ", torch.max(torch.abs(o - o_triton))) + assert torch.allclose(o, o_triton, rtol=1e-2, atol=2e-3) + + +if __name__ == "__main__": + test_batch_prefill_with_paged_kv_cache(12, 54, 37, 8, 8, 128, False) + test_batch_prefill_with_paged_kv_cache(37, 1111, 456, 32, 32, 128, True) + test_batch_decode_with_paged_kv_cache(12, 54, 4, 32, 128) diff --git a/test/srt/test_httpserver_concurrent.py b/test/srt/test_httpserver_concurrent.py new file mode 100644 index 000000000..855e51f33 --- /dev/null +++ b/test/srt/test_httpserver_concurrent.py @@ -0,0 +1,56 @@ +""" +python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000 + +Output: +The capital of France is Paris.\nThe capital of the United States is Washington, D.C. + +The capital of the United Kindom is London.\nThe capital of the United Kingdom is London.\nThe capital of +""" + +import argparse +import asyncio +import json +import time + +import aiohttp +import requests + + +async def send_request(url, data, delay=0): + await asyncio.sleep(delay) + async with aiohttp.ClientSession() as session: + async with session.post(url, json=data) as resp: + output = await resp.json() + return output + + +async def main(args): + url = f"{args.host}:{args.port}" + task1 = send_request( + url + "/generate", + { + "text": "The capital of France is", + "sampling_params": {"temperature": 0, "max_new_tokens": 128}, + }, + delay=1, + ) + + task2 = send_request( + url + "/generate", + { + "text": "The capital of the United Kindom is", + "sampling_params": {"temperature": 0, "max_new_tokens": 128}, + }, + ) + + rets = await asyncio.gather(task1, task2) + print(rets) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="http://127.0.0.1") + parser.add_argument("--port", type=int, default=30000) + args = parser.parse_args() + + asyncio.run(main(args)) diff --git a/test/srt/test_httpserver_decode.py b/test/srt/test_httpserver_decode.py new file mode 100644 index 000000000..a79ffb6e4 --- /dev/null +++ b/test/srt/test_httpserver_decode.py @@ -0,0 +1,31 @@ +""" +python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000 + +Output: +The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo +""" + +import argparse +import time + +import requests + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="http://127.0.0.1") + parser.add_argument("--port", type=int, default=30000) + args = parser.parse_args() + + url = f"{args.host}:{args.port}" + + response = requests.post( + url + "/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": 32, + }, + }, + ) + print(response.json()) diff --git a/test/srt/test_httpserver_decode_stream.py b/test/srt/test_httpserver_decode_stream.py new file mode 100644 index 000000000..048ee363f --- /dev/null +++ b/test/srt/test_httpserver_decode_stream.py @@ -0,0 +1,42 @@ +""" +python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000 + +Output: +The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo +""" + +import argparse +import json +import time + +import requests + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="http://127.0.0.1") + parser.add_argument("--port", type=int, default=30000) + args = parser.parse_args() + + url = f"{args.host}:{args.port}" + + response = requests.post( + url + "/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": 1024, + }, + "stream": True, + }, + stream=True, + ) + + prev = 0 + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode()) + output = data["text"].strip() + print(output[prev:], end="", flush=True) + prev = len(output) + print("") diff --git a/test/srt/test_httpserver_llava.py b/test/srt/test_httpserver_llava.py new file mode 100644 index 000000000..042f4229d --- /dev/null +++ b/test/srt/test_httpserver_llava.py @@ -0,0 +1,84 @@ +""" +python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000 + +Output: +The image features a man standing on the back of a yellow taxi cab, holding +""" + +import argparse +import asyncio +import json +import time + +import aiohttp +import requests + + +async def send_request(url, data, delay=0): + await asyncio.sleep(delay) + async with aiohttp.ClientSession() as session: + async with session.post(url, json=data) as resp: + output = await resp.json() + return output + + +async def test_concurrent(args): + url = f"{args.host}:{args.port}" + + response = [] + for i in range(8): + response.append( + send_request( + url + "/generate", + { + "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: \nDescribe this picture ASSISTANT:", + "image_data": "/home/ubuntu/sglang/test/lang/image.png", + "sampling_params": { + "temperature": 0, + "max_new_tokens": 16, + }, + }, + ) + ) + + rets = await asyncio.gather(*response) + for ret in rets: + print(ret["text"]) + + +def test_streaming(args): + url = f"{args.host}:{args.port}" + + response = requests.post( + url + "/generate", + json={ + "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: \nDescribe this picture ASSISTANT:", + "image_data": "/home/ubuntu/sglang/test/lang/image.png", + "sampling_params": { + "temperature": 0, + "max_new_tokens": 128, + }, + "stream": True, + }, + stream=True, + ) + + prev = 0 + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode()) + output = data["text"].strip() + print(output[prev:], end="", flush=True) + prev = len(output) + print("") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="http://127.0.0.1") + parser.add_argument("--port", type=int, default=30000) + args = parser.parse_args() + + asyncio.run(test_concurrent(args)) + + test_streaming(args) diff --git a/test/srt/test_httpserver_reuse.py b/test/srt/test_httpserver_reuse.py new file mode 100644 index 000000000..c3f786589 --- /dev/null +++ b/test/srt/test_httpserver_reuse.py @@ -0,0 +1,43 @@ +""" +python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000 + +Output: +The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo +""" + +import argparse +import time + +import requests + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="http://127.0.0.1") + parser.add_argument("--port", type=int, default=30000) + args = parser.parse_args() + + url = f"{args.host}:{args.port}" + + response = requests.post( + url + "/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": 32, + }, + }, + ) + print(response.json()) + + response = requests.post( + url + "/generate", + json={ + "text": "The capital of France is Paris.\nThe capital of the United States is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": 32, + }, + }, + ) + print(response.json())