release initial code

Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
Co-authored-by: parasol-aser <3848358+parasol-aser@users.noreply.github.com>
Co-authored-by: LiviaSun <33578456+ChuyueSun@users.noreply.github.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
Lianmin Zheng
2024-01-08 04:37:50 +00:00
parent f6d40df0ee
commit 22085081bb
145 changed files with 17802 additions and 2 deletions

21
.gitignore vendored

@@ -157,4 +157,23 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/
# MacOS
.DS_Store
*.json
# Vim
*.swp
# SGL
benchmark/mmlu/data
benchmark/mmlu/data.tar
benchmark/llava_bench/images
benchmark/llava_bench/mme_pack
*.jsonl
tmp*.txt
# Plots
*.png
*.pdf

3
.gitmodules vendored Normal file

@@ -0,0 +1,3 @@
[submodule "3rdparty/flashinfer"]
path = 3rdparty/flashinfer
url = git@github.com:flashinfer-ai/flashinfer.git

168
README.md

@@ -1 +1,167 @@
# sglang
# SGLang
SGLang is a structured generation language designed for large language models (LLMs).
It makes your interaction with LLMs faster and more controllable by co-designing the frontend language and the runtime system.
The core features of SGLang include:
- **A Flexible Front-End Language**: This allows for easy programming of LLM applications with multiple chained generation calls, advanced prompting techniques, control flow, multiple modalities, parallelism, and external interaction.
- **A High-Performance Runtime with RadixAttention**: This feature significantly accelerates the execution of complex LLM programs by automatic KV cache reuse across multiple calls. It also supports other common techniques like continuous batching and tensor parallelism.
## Contents
- [Install](#install)
- [Quick Start](#quick-start)
- [Frontend: Structured Generation Language (SGLang)](#frontend-structured-generation-language-sglang)
- [Backend: SGLang Runtime (SRT)](#backend-sglang-runtime-srt)
- [Benchmark And Performance](#benchmark-and-performance)
- [Roadmap](#roadmap)
- [Citation And Acknowledgment](#citation-and-acknowledgment)
## Install
### Method 1: With Pip
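Install from PyPI (assuming the published package provides the same `[all]` extra as the source install below):
```
pip install "sglang[all]"
```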
### Method 2: From Source
```
git clone git@github.com:sgl-project/sglang.git
cd sglang
pip install --upgrade pip
pip install -e "python[all]"
```
## Quick Start
The example below shows how to use SGLang to answer a multi-turn question.
### Using OpenAI Models
```python
from sglang import function, system, user, assistant, gen, set_default_backend, OpenAI
@function
def multi_turn_question(s, question_1, question_2):
s += system("You are a helpful assistant.")
s += user(question_1)
s += assistant(gen("answer_1", max_tokens=256))
s += user(question_2)
s += assistant(gen("answer_2", max_tokens=256))
set_default_backend(OpenAI("gpt-3.5-turbo"))
state = multi_turn_question.run(
question_1="What is the capital of the United States?",
question_2="List two local attractions.",
)
for m in state.messages():
print(m["role"], ":", m["content"])
```
### Using Local Models
First, launch a server with
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
Then, connect to the server and answer a multi-turn question.
```python
from sglang import function, system, user, assistant, gen, set_default_backend, RuntimeEndpoint
@function
def multi_turn_question(s, question_1, question_2):
s += system("You are a helpful assistant.")
s += user(question_1)
s += assistant(gen("answer_1", max_tokens=256))
s += user(question_2)
s += assistant(gen("answer_2", max_tokens=256))
set_default_backend(RuntimeEndpoint("http://localhost:30000"))
state = multi_turn_question.run(
question_1="What is the capital of the United States?",
question_2="List two local attractions.",
)
for m in state.messages():
print(m["role"], ":", m["content"])
```
### More Examples
You can find more examples at [examples/quick_start](examples/quick_start).
## Frontend: Structured Generation Language (SGLang)
### Control Flow
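You can mix ordinary Python control flow with generation calls. A minimal sketch (here the `choices` argument constrains `sgl.gen` to one of the listed strings):
```python
@sgl.function
def tool_use(s, question):
    s += "To answer this question: " + question + ", "
    s += "I need to use a " + sgl.gen("tool", choices=["calculator", "search engine"]) + ". "
    if s["tool"] == "calculator":
        s += "The math expression is " + sgl.gen("expression")
    elif s["tool"] == "search engine":
        s += "The key word to search is " + sgl.gen("word")
```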
### Parallelism
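Launch parallel generation calls with `fork`; a minimal sketch:
```python
@sgl.function
def tip_suggestion(s):
    s += (
        "Here are two tips for staying healthy: "
        "1. Balanced Diet. 2. Regular Exercise.\n\n"
    )
    forks = s.fork(2)  # the two branches run in parallel
    for i, f in enumerate(forks):
        f += f"Now, expand tip {i+1} into a paragraph:\n"
        f += sgl.gen("detailed_tip", max_tokens=256, stop="\n\n")
    s += "Tip 1:" + forks[0]["detailed_tip"] + "\n"
    s += "Tip 2:" + forks[1]["detailed_tip"] + "\n"
    s += "In summary" + sgl.gen("summary")
```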
### Multi Modality
```python
@sgl.function
def multi_turn_question(s, image_file, question):
s += sgl.user(sgl.image(image_file) + question)
s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))
```
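Invoke it like any other SGLang function (the image path here is just a placeholder):
```python
state = multi_turn_question.run(
    image_file="./images/cat.jpeg",  # placeholder path, supply your own image
    question="What is in the image?",
)
```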
### Batching
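Launch a batch of prompts with `run_batch`; a minimal sketch:
```python
@sgl.function
def text_qa(s, question):
    s += "Q: " + question + "\n"
    s += "A:" + sgl.gen("answer", stop="\n")

states = text_qa.run_batch(
    [
        {"question": "What is the capital of the United Kingdom?"},
        {"question": "What is the capital of France?"},
        {"question": "What is the capital of Japan?"},
    ],
)
```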
### Streaming
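Pass `stream=True` and iterate over the output; a sketch reusing `text_qa` from the batching example:
```python
state = text_qa.run(
    question="What is the capital of France?",
    temperature=0.1,
    stream=True,
)
for out in state.text_iter():
    print(out, end="", flush=True)
```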
### Other Backends
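Any backend object can be set as the default, or passed per call through the `backend` argument of `run`/`run_batch` (the benchmark scripts in this commit use the latter). A sketch with the local endpoint from the quick start:
```python
from sglang import RuntimeEndpoint

states = text_qa.run_batch(
    [{"question": "What is the capital of Japan?"}],
    backend=RuntimeEndpoint("http://localhost:30000"),
)
```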
## Backend: SGLang Runtime (SRT)
The SGLang Runtime (SRT) is designed to work best with the SGLang frontend.
However, it can also be used as a standalone API server.
In this case, RadixAttention can still greatly accelerate many use cases.
### Usage
Launch a server
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
Send a request
```
curl http://localhost:30000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"prompt": "Say this is a test",
"max_tokens": 16,
"temperature": 0
}'
```
### Additional Arguments
- Add `--tp 2` to enable tensor parallelism.
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --tp 2
```
### Supported Models
- Llama
- Mistral
- Mixtral
- LLaVA
## Benchmark And Performance
## Roadmap
- [ ] Function call
- [ ] Constrained decoding
- [ ] Quantization
- [ ] S-LoRA
- [ ] More models
## Citation And Acknowledgment
```
@misc{zheng2023efficiently,
title={Efficiently Programming Large Language Models using SGLang},
author={Lianmin Zheng and Liangsheng Yin and Zhiqiang Xie and Jeff Huang and Chuyue Sun and Cody Hao Yu and Shiyi Cao and Christos Kozyrakis and Ion Stoica and Joseph E. Gonzalez and Clark Barrett and Ying Sheng},
year={2023},
eprint={2312.07104},
archivePrefix={arXiv},
primaryClass={cs.AI}
}
```
We learned from the design and reused some code of the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [LMQL](https://github.com/eth-sri/lmql).

45
benchmark/dspy/README.md Normal file

@@ -0,0 +1,45 @@
## Install
```
pip3 install dspy-ai
```
Turn off the cache at https://github.com/stanfordnlp/dspy/blob/34d8420383ec752037aa271825c1d3bf391e1277/dsp/modules/cache_utils.py#L10 by setting:
```
cache_turn_on = False
```
## Benchmark SGLang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
```
python3 bench_dspy_intro.py --backend sglang
```
## Benchmark TGI
```
docker run --name tgi --rm -ti --gpus all --network host \
-v /home/ubuntu/model_weights/Llama-2-7b-chat-hf:/Llama-2-7b-chat-hf \
ghcr.io/huggingface/text-generation-inference:1.1.0 \
--model-id /Llama-2-7b-chat-hf --num-shard 1 --trust-remote-code \
--max-input-length 2048 --max-total-tokens 4096 \
--port 24000
```
```
python3 bench_dspy_intro.py --backend tgi
```
## Benchmark vLLM
```
python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```
```
python3 bench_dspy_intro.py --backend vllm
```

165
benchmark/dspy/bench_dspy_intro.py Normal file

@@ -0,0 +1,165 @@
"""
Adapted from
https://github.com/stanfordnlp/dspy/blob/34d8420383ec752037aa271825c1d3bf391e1277/intro.ipynb#L9
"""
import argparse
import dspy
from dspy.datasets import HotPotQA
class BasicQA(dspy.Signature):
"""Answer questions with short factoid answers."""
question = dspy.InputField()
answer = dspy.OutputField(desc="often between 1 and 5 words")
class GenerateAnswer(dspy.Signature):
"""Answer questions with short factoid answers."""
context = dspy.InputField(desc="may contain relevant facts")
question = dspy.InputField()
answer = dspy.OutputField(desc="often between 1 and 5 words")
class RAG(dspy.Module):
def __init__(self, num_passages=3):
super().__init__()
self.retrieve = dspy.Retrieve(k=num_passages)
self.generate_answer = dspy.ChainOfThought(GenerateAnswer)
def forward(self, question):
context = self.retrieve(question).passages
prediction = self.generate_answer(context=context, question=question)
return dspy.Prediction(context=context, answer=prediction.answer)
def main(args):
#lm = dspy.OpenAI(model='gpt-3.5-turbo')
if args.backend == "tgi":
lm = dspy.HFClientTGI(model="meta-llama/Llama-2-7b-chat-hf", port=args.port,
url="http://localhost")
elif args.backend == "sglang":
lm = dspy.HFClientSGLang(model="meta-llama/Llama-2-7b-chat-hf", port=args.port,
url="http://localhost")
elif args.backend == "vllm":
lm = dspy.HFClientVLLM(model="meta-llama/Llama-2-7b-chat-hf", port=args.port,
url="http://localhost")
else:
raise ValueError(f"Invalid backend: {args.backend}")
colbertv2_wiki17_abstracts = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts')
dspy.settings.configure(lm=lm, rm=colbertv2_wiki17_abstracts)
# Load the dataset.
dataset = HotPotQA(train_seed=1, train_size=20, eval_seed=2023, dev_size=args.dev_size,
test_size=0)
# Tell DSPy that the 'question' field is the input. Any other fields are labels and/or metadata.
trainset = [x.with_inputs('question') for x in dataset.train]
devset = [x.with_inputs('question') for x in dataset.dev]
print(len(trainset), len(devset))
train_example = trainset[0]
print(f"Question: {train_example.question}")
print(f"Answer: {train_example.answer}")
dev_example = devset[18]
print(f"Question: {dev_example.question}")
print(f"Answer: {dev_example.answer}")
print(f"Relevant Wikipedia Titles: {dev_example.gold_titles}")
print(f"For this dataset, training examples have input keys {train_example.inputs().keys()} and label keys {train_example.labels().keys()}")
print(f"For this dataset, dev examples have input keys {dev_example.inputs().keys()} and label keys {dev_example.labels().keys()}")
# Define the predictor.
generate_answer = dspy.Predict(BasicQA)
# Call the predictor on a particular input.
pred = generate_answer(question=dev_example.question)
# Print the input and the prediction.
print(f"Question: {dev_example.question}")
print(f"Predicted Answer: {pred.answer}")
lm.inspect_history(n=1)
# Define the predictor. Notice we're just changing the class. The signature BasicQA is unchanged.
generate_answer_with_chain_of_thought = dspy.ChainOfThought(BasicQA)
# Call the predictor on the same input.
pred = generate_answer_with_chain_of_thought(question=dev_example.question)
# Print the input, the chain of thought, and the prediction.
print(f"Question: {dev_example.question}")
print(f"Thought: {pred.rationale.split('.', 1)[1].strip()}")
print(f"Predicted Answer: {pred.answer}")
retrieve = dspy.Retrieve(k=3)
topK_passages = retrieve(dev_example.question).passages
print(f"Top {retrieve.k} passages for question: {dev_example.question} \n", '-' * 30, '\n')
for idx, passage in enumerate(topK_passages):
print(f'{idx+1}]', passage, '\n')
retrieve("When was the first FIFA World Cup held?").passages[0]
from dspy.teleprompt import BootstrapFewShot
# Validation logic: check that the predicted answer is correct.
# Also check that the retrieved context does actually contain that answer.
def validate_context_and_answer(example, pred, trace=None):
answer_EM = dspy.evaluate.answer_exact_match(example, pred)
answer_PM = dspy.evaluate.answer_passage_match(example, pred)
return answer_EM and answer_PM
# Set up a basic teleprompter, which will compile our RAG program.
teleprompter = BootstrapFewShot(metric=validate_context_and_answer)
# Compile!
compiled_rag = teleprompter.compile(RAG(), trainset=trainset)
# Ask any question you like to this simple RAG program.
my_question = "What castle did David Gregory inherit?"
# Get the prediction. This contains `pred.context` and `pred.answer`.
pred = compiled_rag(my_question)
# Print the contexts and the answer.
print(f"Question: {my_question}")
print(f"Predicted Answer: {pred.answer}")
print(f"Retrieved Contexts (truncated): {[c[:200] + '...' for c in pred.context]}")
from dspy.evaluate.evaluate import Evaluate
# Set up the `evaluate_on_hotpotqa` function. We'll use this many times below.
evaluate_on_hotpotqa = Evaluate(devset=devset, num_threads=args.num_threads, display_progress=True, display_table=5)
# Evaluate the `compiled_rag` program with the `answer_exact_match` metric.
metric = dspy.evaluate.answer_exact_match
evaluate_on_hotpotqa(compiled_rag, metric=metric)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int)
parser.add_argument("--num-threads", type=int, default=32)
parser.add_argument("--dev-size", type=int, default=150)
parser.add_argument("--backend", type=str, choices=["sglang", "tgi", "vllm"],
default="sglang")
args = parser.parse_args()
if args.port is None:
default_port = {
"vllm": 21000,
"lightllm": 22000,
"tgi": 24000,
"sglang": 30000,
}
args.port = default_port.get(args.backend, None)
main(args)

26
benchmark/generative_agents/README.md Normal file

@@ -0,0 +1,26 @@
## Run benchmark
Ensure that this benchmark is run in a serial manner (using --parallel 1) to preserve any potential dependencies between requests.
### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
```
python3 bench_sglang.py --num-events 1000 --parallel 1
```
### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```
```
python3 bench_other.py --num-events 1000 --backend vllm --parallel 1
```
### Benchmark guidance
```
python3 bench_other.py --num-events 1000 --backend guidance --parallel 1
```

231
benchmark/generative_agents/agent_functions.py Normal file

@@ -0,0 +1,231 @@
import sglang as sgl
# Below are the top five agent functions, which contribute ~70% of all LLM calls.
# reference: https://github.com/joonspk-research/generative_agents/
@sgl.function
def poignancy_event(s, persona_name, persona_iss, event):
s += "Here is a brief description of " + persona_name + ".\n"
s += persona_iss + "\n"
s += "On the scale of 1 to 10, where 1 is purely mundane (e.g., brushing teeth, making bed) and 10 is extremely poignant (e.g., a break up, college acceptance), rate the likely poignancy of the following event for"
s += persona_name + ".\n\n"
s += "Event: " + event
s += "Rate (return a number between 1 to 10):"
s += sgl.gen(name="Rate", max_tokens=2)
def poignancy_event_prompt(persona_name, persona_iss, event):
# return prompt and max_tokens
s = ""
s += "Here is a brief description of " + persona_name + ".\n"
s += persona_iss + "\n"
s += "On the scale of 1 to 10, where 1 is purely mundane (e.g., brushing teeth, making bed) and 10 is extremely poignant (e.g., a break up, college acceptance), rate the likely poignancy of the following event for"
s += persona_name + ".\n\n"
s += "Event: " + event
s += "Rate (return a number between 1 to 10):"
return {"prompt": s, "max_tokens": 2, "stop": None}
@sgl.function
def generate_event_triple(s, persona_name, action):
s += """Task: Turn the input into (subject, predicate, object).
Input: Sam Johnson is eating breakfast.
Output: (Dolores Murphy, eat, breakfast)
---
Input: Joon Park is brewing coffee.
Output: (Joon Park, brew, coffee)
---
Input: Jane Cook is sleeping.
Output: (Jane Cook, is, sleep)
---
Input: Michael Bernstein is writing email on a computer.
Output: (Michael Bernstein, write, email)
---
Input: Percy Liang is teaching students in a classroom.
Output: (Percy Liang, teach, students)
---
Input: Merrie Morris is running on a treadmill.
Output: (Merrie Morris, run, treadmill)
---"""
s += persona_name + "is" + action + ".\n"
s += "(" + persona_name + ","
s += sgl.gen(name="Triple", max_tokens=20, stop=")")
def generate_event_triple_prompt(persona_name, action):
s = ""
s += """Task: Turn the input into (subject, predicate, object).
Input: Sam Johnson is eating breakfast.
Output: (Dolores Murphy, eat, breakfast)
---
Input: Joon Park is brewing coffee.
Output: (Joon Park, brew, coffee)
---
Input: Jane Cook is sleeping.
Output: (Jane Cook, is, sleep)
---
Input: Michael Bernstein is writing email on a computer.
Output: (Michael Bernstein, write, email)
---
Input: Percy Liang is teaching students in a classroom.
Output: (Percy Liang, teach, students)
---
Input: Merrie Morris is running on a treadmill.
Output: (Merrie Morris, run, treadmill)
---"""
s += persona_name + "is" + action + ".\n"
s += "(" + persona_name + ","
return {"prompt": s, "max_tokens": 20, "stop": ")"}
@sgl.function
def generate_pronunciatio(s, action):
s += "Convert an action description to an emoji (important: use two or less emojis).\n"
s += "Action description: " + action + ".\n"
s += "Emoji:" + sgl.gen(name="Emoji", max_tokens=6)
def generate_pronunciatio_prompt(action):
s = ""
s += "Convert an action description to an emoji (important: use two or less emojis).\n"
s += "Action description: " + action + ".\n"
s += "Emoji:"
return {"prompt": s, "max_tokens": 6, "stop": None}
@sgl.function
def action_location_sector(
s,
persona_name,
living_sector,
living_sector_areas,
current_sector,
current_sector_areas,
daily_plan,
sector_options,
current_action,
next_action,
):
s += """Task -- choose an appropriate area from the area options for a task at hand.
Sam Kim lives in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen.
Sam Kim is currently in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen.
Area options: {Sam Kim's house, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}.
* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
* Must be one of the "Area options," verbatim.
For taking a walk, Sam Kim should go to the following area: {Johnson Park}
---
Jane Anderson lives in {Oak Hill College Student Dormatory} that has Jane Anderson's room.
Jane Anderson is currently in {Oak Hill College} that has a classroom, library
Area options: {Oak Hill College Student Dormatory, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}.
* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
* Must be one of the "Area options," verbatim.
For eating dinner, Jane Anderson should go to the following area: {Hobbs Cafe}
---"""
s += (persona_name + " lives in " + living_sector + " that has " +
living_sector_areas + ".\n")
s += (persona_name + " is currently in " + current_sector + " that has " +
current_sector_areas + ".\n")
s += daily_plan + ".\n"
s += "Area options: " + sector_options + ".\n"
s += """* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
* Must be one of the "Area options," verbatim.\n"""
s += (persona_name + " is " + current_action + ". For " + next_action +
", " + persona_name + " should go to the following area: {")
s += sgl.gen(name="Location", max_tokens=10, stop="}")
def action_location_sector_prompt(
persona_name,
living_sector,
living_sector_areas,
current_sector,
current_sector_areas,
daily_plan,
sector_options,
current_action,
next_action,
):
s = ""
s += """Task -- choose an appropriate area from the area options for a task at hand.
Sam Kim lives in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen.
Sam Kim is currently in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen.
Area options: {Sam Kim's house, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}.
* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
* Must be one of the "Area options," verbatim.
For taking a walk, Sam Kim should go to the following area: {Johnson Park}
---
Jane Anderson lives in {Oak Hill College Student Dormatory} that has Jane Anderson's room.
Jane Anderson is currently in {Oak Hill College} that has a classroom, library
Area options: {Oak Hill College Student Dormatory, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}.
* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
* Must be one of the "Area options," verbatim.
For eating dinner, Jane Anderson should go to the following area: {Hobbs Cafe}
---"""
s += (persona_name + " lives in " + living_sector + " that has " +
living_sector_areas + ".\n")
s += (persona_name + " is currently in " + current_sector + " that has " +
current_sector_areas + ".\n")
s += daily_plan + ".\n"
s += "Area options: " + sector_options + ".\n"
s += """* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
* Must be one of the "Area options," verbatim.\n"""
s += (persona_name + " is " + current_action + ". For " + next_action +
", " + persona_name + " should go to the following area: {")
return {"prompt": s, "max_tokens": 10, "stop": "}"}
@sgl.function
def action_location_object(s, persona_name, target_sector, target_sector_areas,
current_action, next_action):
s += """
Jane Anderson is in kitchen in Jane Anderson's house.
Jane Anderson is going to Jane Anderson's house that has the following areas: {kitchen, bedroom, bathroom}
Stay in the current area if the activity can be done there. Never go into other people's rooms unless necessary.
For cooking, Jane Anderson should go to the following area in Jane Anderson's house:
Answer: {kitchen}
---
Tom Watson is in common room in Tom Watson's apartment.
Tom Watson is going to Hobbs Cafe that has the following areas: {cafe}
Stay in the current area if the activity can be done there. Never go into other people's rooms unless necessary.
For getting coffee, Tom Watson should go to the following area in Hobbs Cafe:
Answer: {cafe}
---"""
s += (persona_name + " is going to " + target_sector +
" that has the following areas: {" + target_sector_areas + "}\n")
s += """* Stay in the current area if the activity can be done there.
* NEVER go into other people's rooms unless necessary."""
s += (persona_name + " is " + current_action + ". For " + next_action +
", " + persona_name + "should go to the following area in " +
target_sector)
s += " (MUST pick one of {" + target_sector_areas + "}):\n"
s += "Answer: {" + sgl.gen(name="Area", max_tokens=5, stop="}")
def action_location_object_prompt(persona_name, target_sector,
target_sector_areas, current_action,
next_action):
s = ""
s += """
Jane Anderson is in kitchen in Jane Anderson's house.
Jane Anderson is going to Jane Anderson's house that has the following areas: {kitchen, bedroom, bathroom}
Stay in the current area if the activity can be done there. Never go into other people's rooms unless necessary.
For cooking, Jane Anderson should go to the following area in Jane Anderson's house:
Answer: {kitchen}
---
Tom Watson is in common room in Tom Watson's apartment.
Tom Watson is going to Hobbs Cafe that has the following areas: {cafe}
Stay in the current area if the activity can be done there. Never go into other people's rooms unless necessary.
For getting coffee, Tom Watson should go to the following area in Hobbs Cafe:
Answer: {cafe}
---"""
s += (persona_name + " is going to " + target_sector +
" that has the following areas: {" + target_sector_areas + "}\n")
s += """* Stay in the current area if the activity can be done there.
* NEVER go into other people's rooms unless necessary."""
s += (persona_name + " is " + current_action + ". For " + next_action +
", " + persona_name + "should go to the following area in " +
target_sector)
s += " (MUST pick one of {" + target_sector_areas + "}):\n"
s += "Answer: {"
return {"prompt": s, "max_tokens": 5, "stop": "}"}

104
benchmark/generative_agents/bench_other.py Normal file

@@ -0,0 +1,104 @@
import argparse
from functools import partial
import json
import time
from pathlib import Path
from tqdm import tqdm
from sglang.test.test_utils import (
add_common_other_args_and_parse,
call_generate_lightllm,
call_generate_vllm,
call_generate_srt_raw,
)
from sglang.utils import read_jsonl, dump_state_text
from agent_functions import (
poignancy_event_prompt,
generate_event_triple_prompt,
generate_pronunciatio_prompt,
action_location_sector_prompt,
action_location_object_prompt,
)
def main(args):
lines = read_jsonl(args.data_path)[:args.num_events]
mapping = {
"poignancy_event": poignancy_event_prompt,
"generate_event_triple": generate_event_triple_prompt,
"generate_pronunciatio": generate_pronunciatio_prompt,
"action_location_sector": action_location_sector_prompt,
"action_location_object": action_location_object_prompt,
}
arguments = [mapping[k](**v) for l in lines for k, v in l.items()]
states = []
# Select backend
if args.backend == "lightllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_lightllm, url=url)
elif args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_vllm, url=url)
elif args.backend == "srt-raw":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_srt_raw, url=url)
elif args.backend == "guidance":
from guidance import models, gen
model = models.LlamaCpp(
str(Path.home()) + "/model_weights/Llama-2-7b-chat.gguf",
n_gpu_layers=-1,
n_ctx=4096,
)
def call_generate(prompt, temperature, max_tokens, stop):
out = model + prompt + gen(
name="result",
max_tokens=max_tokens,
temperature=temperature,
stop=stop,
)
return out["result"]
else:
raise ValueError(f"Invalid backend: {args.backend}")
def get_one_answer(arg):
answer = call_generate(**arg, temperature=0)
states.append(answer)
tic = time.time()
    # Agent calls are executed sequentially to preserve the dependencies between them.
for arg in tqdm(arguments):
get_one_answer(arg)
latency = time.time() - tic
print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "Generative Agents",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
            # count each group of weighted functions as a single agent request
"num_requests": len(arguments) / len(mapping),
"other": {
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="agent_calls.jsonl")
parser.add_argument("--num-events", type=int, default=10)
args = add_common_other_args_and_parse(parser)
main(args)

74
benchmark/generative_agents/bench_sglang.py Normal file

@@ -0,0 +1,74 @@
import argparse
import json
import time
import sglang as sgl
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import read_jsonl, dump_state_text
from agent_functions import (
poignancy_event,
generate_event_triple,
generate_pronunciatio,
action_location_sector,
action_location_object,
)
def main(args):
lines = read_jsonl(args.data_path)[:args.num_events]
mapping = {
"poignancy_event": poignancy_event,
"generate_event_triple": generate_event_triple,
"generate_pronunciatio": generate_pronunciatio,
"action_location_sector": action_location_sector,
"action_location_object": action_location_object,
}
arguments = [{mapping[k]: v for k, v in l.items()} for l in lines]
# Select backend
backend = select_sglang_backend(args)
sgl.set_default_backend(backend)
states = []
# Run requests
tic = time.time()
for a in arguments:
# only a single key in the dict
for func, arg in a.items():
result = func.run(**arg)
result.sync()
states.append(result)
latency = time.time() - tic
# Compute accuracy
print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "Generative Agents",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
            # count each group of weighted functions as a single agent request
"num_requests": len(arguments) / len(mapping),
"other": {
"num_events": args.num_events,
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="agent_calls.jsonl")
parser.add_argument("--num-events", type=int, default=10)
args = add_common_sglang_args_and_parse(parser)
main(args)

52
benchmark/gsm8k/README.md Normal file

@@ -0,0 +1,52 @@
## Download data
```
wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
```
## Run benchmark
### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
```
python3 bench_sglang.py --num-questions 200
```
### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```
```
python3 bench_other.py --num-questions 200 --backend vllm
```
### Benchmark lightllm
```
# A10G
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
```
```
python3 bench_other.py --num-questions 200 --backend lightllm
```
### Benchmark guidance
```
python3 bench_other.py --num-questions 200 --backend guidance --parallel 1
```
### Benchmark lmql
```
CUDA_VISIBLE_DEVICES=0,1 lmql serve-model meta-llama/Llama-2-7b-chat-hf --cuda --port 23000
```
```
python3 bench_other.py --num-questions 100 --backend lmql --parallel 2
```

168
benchmark/gsm8k/bench_other.py Normal file

@@ -0,0 +1,168 @@
import argparse
import ast
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import json
import re
import time
import numpy as np
from tqdm import tqdm
from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw
from sglang.utils import read_jsonl, dump_state_text
INVALID = -9999999
def get_one_example(lines, i, include_answer):
ret = "Question: " + lines[i]["question"] + "\nAnswer:"
if include_answer:
ret += " " + lines[i]["answer"]
return ret
def get_few_shot_examples(lines, k):
ret = ""
for i in range(k):
ret += get_one_example(lines, i, True) + "\n\n"
return ret
def get_answer_value(answer_str):
answer_str = answer_str.replace(",", "")
numbers = re.findall(r'\d+', answer_str)
if len(numbers) < 1:
return INVALID
try:
return ast.literal_eval(numbers[-1])
except SyntaxError:
return INVALID
def main(args):
lines = read_jsonl(args.data_path)
# Construct prompts
k = args.num_shot
few_shot_examples = get_few_shot_examples(lines, k)
questions = []
labels = []
for i in range(len(lines[:args.num_questions])):
questions.append(get_one_example(lines, i, False))
labels.append(get_answer_value(lines[i]["answer"]))
assert all(l != INVALID for l in labels)
states = [None] * len(labels)
# Select backend
if args.backend == "lightllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_lightllm, url=url)
elif args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_vllm, url=url)
elif args.backend == "srt-raw":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_srt_raw, url=url)
elif args.backend == "guidance":
from guidance import models, gen
model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096)
def call_generate(prompt, temperature, max_tokens, stop):
out = model + prompt + gen(name="answer",
max_tokens=max_tokens, temperature=temperature, stop=stop)
return out["answer"]
elif args.backend == "lmql":
import lmql
model = lmql.model(args.model_path,
endpoint=f"{args.host}:{args.port}")
@lmql.query(model=model)
async def program(question):
'''lmql
"""{question}[ANSWER]""" where len(TOKENS(ANSWER)) < 257 and STOPS_AT(ANSWER, "Question")
return ANSWER
'''
async def call_generate(prompt, temperature, max_tokens, stop):
return await program(question=prompt, temperature=0)
else:
raise ValueError(f"Invalid backend: {args.backend}")
# Run requests
if args.backend != "lmql":
# Use thread pool
def get_one_answer(i):
answer = call_generate(
prompt=few_shot_examples + questions[i],
temperature=0,
max_tokens=256,
stop="Question")
states[i] = answer
tic = time.time()
if args.parallel == 1:
for i in tqdm(range(len(questions))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
executor.map(get_one_answer, list(range(len(questions))))
else:
# Use asyncio
async def batched_call(batch_size):
for i in range(0, len(questions), batch_size):
tasks = []
for q in questions[i:i+batch_size]:
tasks.append(call_generate(few_shot_examples + q,
temperature=0, max_tokens=256, stop="Question"))
rets = await asyncio.gather(*tasks)
for j in range(len(rets)):
states[i+j] = rets[j]
tic = time.time()
asyncio.run(batched_call(batch_size=args.parallel))
latency = time.time() - tic
preds = []
for i in range(len(states)):
preds.append(get_answer_value(states[i]))
# Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels))
invalid = np.mean(np.array(preds) == INVALID)
print(f"Latency: {latency:.3f}")
print(f"Invalid: {invalid:.3f}")
print(f"Accuracy: {acc:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "gsm8k",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"accuracy": round(acc, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
}
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num-shot", type=int, default=5)
parser.add_argument("--data-path", type=str, default="test.jsonl")
parser.add_argument("--num-questions", type=int, default=200)
args = add_common_other_args_and_parse(parser)
main(args)

115
benchmark/gsm8k/bench_sglang.py Normal file

@@ -0,0 +1,115 @@
import argparse
import ast
import json
import re
import time
import numpy as np
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
from sglang.utils import read_jsonl, dump_state_text
INVALID = -9999999
def get_one_example(lines, i, include_answer):
ret = "Question: " + lines[i]["question"] + "\nAnswer:"
if include_answer:
ret += " " + lines[i]["answer"]
return ret
def get_few_shot_examples(lines, k):
ret = ""
for i in range(k):
ret += get_one_example(lines, i, True) + "\n\n"
return ret
def get_answer_value(answer_str):
answer_str = answer_str.replace(",", "")
numbers = re.findall(r'\d+', answer_str)
if len(numbers) < 1:
return INVALID
try:
return ast.literal_eval(numbers[-1])
except SyntaxError:
return INVALID
def main(args):
lines = read_jsonl(args.data_path)
# Construct prompts
k = args.num_shot
few_shot_examples = get_few_shot_examples(lines, k)
questions = []
labels = []
for i in range(len(lines[:args.num_questions])):
questions.append(get_one_example(lines, i, False))
labels.append(get_answer_value(lines[i]["answer"]))
assert all(l != INVALID for l in labels)
arguments = [{"question": q} for q in questions]
#####################################
######### SGL Program Begin #########
#####################################
import sglang as sgl
@sgl.function
def few_shot_gsm8k(s, question):
s += few_shot_examples + question
s += sgl.gen("answer", max_tokens=256, stop="Question")
#####################################
########## SGL Program End ##########
#####################################
# Select backend
backend = select_sglang_backend(args)
# Run requests
tic = time.time()
states = few_shot_gsm8k.run_batch(
arguments, temperature=0, backend=backend, num_threads=args.parallel)
latency = time.time() - tic
preds = []
for i in range(len(states)):
preds.append(get_answer_value(states[i]["answer"]))
# Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels))
invalid = np.mean(np.array(preds) == INVALID)
print(f"Latency: {latency:.3f}")
print(f"Invalid: {invalid:.3f}")
print(f"Accuracy: {acc:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "gsm8k",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"accuracy": round(acc, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
}
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num-shot", type=int, default=5)
parser.add_argument("--data-path", type=str, default="test.jsonl")
parser.add_argument("--num-questions", type=int, default=200)
args = add_common_sglang_args_and_parse(parser)
main(args)

52
benchmark/hellaswag/README.md Normal file

@@ -0,0 +1,52 @@
## Download data
```
wget https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl
```
## Run benchmark
### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
```
python3 bench_sglang.py --num-questions 200
```
### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```
```
python3 bench_other.py --num-questions 200 --backend vllm
```
### Benchmark lightllm
```
# A10G
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
```
```
python3 bench_other.py --num-questions 200 --backend lightllm
```
### Benchmark guidance
```
CUDA_VISIBLE_DEVICES=0,1 python3 bench_other.py --num-questions 200 --backend guidance --parallel 1
```
### Benchmark lmql
```
lmql serve-model meta-llama/Llama-2-7b-chat-hf --cuda --port 23000
```
```
python3 bench_other.py --num-questions 200 --backend lmql --port 23000 --parallel 1
```

140
benchmark/hellaswag/bench_other.py Normal file

@@ -0,0 +1,140 @@
import argparse
import asyncio
from concurrent.futures import ThreadPoolExecutor
import json
from functools import partial
import time
import numpy as np
from sglang.test.test_utils import add_common_other_args_and_parse, call_select_lightllm, call_select_vllm
from sglang.utils import read_jsonl
def get_one_example(lines, i, include_answer):
ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
if include_answer:
ret += lines[i]["endings"][lines[i]["label"]]
return ret
def get_few_shot_examples(lines, k):
ret = ""
for i in range(k):
ret += get_one_example(lines, i, True) + "\n\n"
return ret
def main(args):
lines = read_jsonl(args.data_path)
# Construct prompts
k = args.num_shot
few_shot_examples = get_few_shot_examples(lines, k)
questions = []
choices = []
labels = []
for i in range(len(lines[:args.num_questions])):
questions.append(get_one_example(lines, i, False))
choices.append(lines[i]["endings"])
labels.append(lines[i]["label"])
preds = [None] * len(labels)
# Select backend
if args.backend == "lightllm":
url = f"{args.host}:{args.port}/generate"
call_select = partial(call_select_lightllm, url=url)
elif args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
call_select = partial(call_select_vllm, url=url)
elif args.backend == "guidance":
from guidance import models, select
model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096)
def call_select(context, choices):
out = model + context + select(choices, name="answer")
return choices.index(out["answer"])
elif args.backend == "lmql":
import lmql
model = lmql.model("meta-llama/Llama-2-7b-chat-hf",
endpoint=f"{args.host}:{args.port}")
@lmql.query(model=model)
async def program(ctx, choices):
'''lmql
"""{ctx}[ANSWER]""" where ANSWER in set(choices)
return ANSWER
'''
async def call_select(context, choices):
answer = await program(ctx=context, choices=choices, temperature=0)
return choices.index(answer)
else:
raise ValueError(f"Invalid backend: {args.backend}")
# Run requests
if args.backend != "lmql":
# Use thread pool
def get_one_answer(i):
preds[i] = call_select(
context=few_shot_examples + questions[i],
choices=choices[i])
tic = time.time()
if args.parallel == 1:
for i in range(len(questions)):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
executor.map(get_one_answer, list(range(len(questions))))
else:
# Use asyncio
async def batched_call(batch_size):
for i in range(0, len(questions), batch_size):
tasks = []
for q, c in zip(questions[i:i+batch_size], choices[i:i+batch_size]):
tasks.append(call_select(
context=few_shot_examples + q,
choices=c))
rets = await asyncio.gather(*tasks)
for j in range(len(rets)):
preds[i+j] = rets[j]
tic = time.time()
asyncio.run(batched_call(batch_size=args.parallel))
latency = time.time() - tic
# Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels))
print(f"Latency: {latency:.3f}")
print(f"Accuracy: {acc:.3f}")
# Write results
with open(args.result_file, "a") as fout:
value = {
"task": "hellaswag",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"accuracy": round(acc, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
}
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num-shot", type=int, default=20)
parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl")
parser.add_argument("--num-questions", type=int, default=100)
args = add_common_other_args_and_parse(parser)
main(args)

96
benchmark/hellaswag/bench_sglang.py Normal file

@@ -0,0 +1,96 @@
import argparse
import json
import time
import numpy as np
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
from sglang.utils import read_jsonl
def get_one_example(lines, i, include_answer):
ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
if include_answer:
ret += lines[i]["endings"][lines[i]["label"]]
return ret
def get_few_shot_examples(lines, k):
ret = ""
for i in range(k):
ret += get_one_example(lines, i, True) + "\n\n"
return ret
def main(args):
lines = read_jsonl(args.data_path)
# Construct prompts
k = args.num_shot
few_shot_examples = get_few_shot_examples(lines, k)
questions = []
choices = []
labels = []
for i in range(len(lines[:args.num_questions])):
questions.append(get_one_example(lines, i, False))
choices.append(lines[i]["endings"])
labels.append(lines[i]["label"])
arguments = [
{"question": q, "choices": c}
for q, c in zip(questions, choices)
]
#####################################
######### SGL Program Begin #########
#####################################
import sglang as sgl
@sgl.function
def few_shot_hellaswag(s, question, choices):
s += few_shot_examples + question
s += sgl.select("answer", choices=choices)
#####################################
########## SGL Program End ##########
#####################################
# Select backend
backend = select_sglang_backend(args)
# Run requests
tic = time.time()
rets = few_shot_hellaswag.run_batch(
arguments, temperature=0, backend=backend, num_threads=args.parallel)
preds = [choices[i].index(rets[i]["answer"]) for i in range(len(rets))]
latency = time.time() - tic
# Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels))
print(f"Latency: {latency:.3f}")
print(f"Accuracy: {acc:.3f}")
# Write results
with open(args.result_file, "a") as fout:
value = {
"task": "hellaswag",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"accuracy": round(acc, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
}
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num-shot", type=int, default=20)
parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl")
parser.add_argument("--num-questions", type=int, default=100)
args = add_common_sglang_args_and_parse(parser)
main(args)

46
benchmark/latency_throughput/README.md Normal file

@@ -0,0 +1,46 @@
### Download data
```
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
```
### Performance
- Model: Llama-2-7b-chat-hf
- `--num-prompts 2000 --request-rate 200`
- On 4 A10 (24G) GPUs
| Backend | Throughput | Latency |
| ----------- | --------------- | -------- |
| srt | 5.82 requests/s | 343.54 s |
| vllm==0.2.6 | 3.93 requests/s | 509.08 s |
| vllm==0.2.7 | 5.02 requests/s | 398.25 s |
### SGLang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
```
python3 bench_throughput.py --backend srt --tokenizer meta-llama/Llama-2-7b-chat-hf --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 10 --request-rate 10 --port 30000
```
### vLLM
```
python3 -m vllm.entrypoints.api_server --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --swap-space 16
```
```
python3 bench_throughput.py --backend vllm --tokenizer meta-llama/Llama-2-7b-chat-hf --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 10 --request-rate 10
```
### LightLLM
```
python -m lightllm.server.api_server --model_dir ~/model_weights/Llama-2-7b-chat-hf --max_total_token_num 15600 --tokenizer_mode auto --port 22000
```
```
python3 bench_throughput.py --backend lightllm --tokenizer meta-llama/Llama-2-7b-chat-hf --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 10 --request-rate 10 --port 22000
```

254
benchmark/latency_throughput/bench_throughput.py Normal file

@@ -0,0 +1,254 @@
"""Benchmark online serving throughput.
On the server side, run one of the following commands:
(vLLM backend)
python -m vllm.entrypoints.api_server \
--model <your_model> --swap-space 16 \
--disable-log-requests
(TGI backend)
./launch_hf_server.sh <your_model>
On the client side, run:
python benchmarks/benchmark_serving.py \
--backend <backend> \
--tokenizer <your_model> --dataset <target_dataset> \
--request-rate <request_rate>
"""
import argparse
import asyncio
import json
import random
import time
from typing import AsyncGenerator, List, Tuple
from tqdm.asyncio import tqdm_asyncio
import aiohttp
import numpy as np
from transformers import PreTrainedTokenizerBase
from vllm.transformers_utils.tokenizer import get_tokenizer
# (prompt len, output len, latency)
REQUEST_LATENCY: List[Tuple[int, int, float]] = []
def sample_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, int, int]]:
# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [
data for data in dataset
if len(data["conversations"]) >= 2
]
# Only keep the first two turns of each conversation.
dataset = [
(data["conversations"][0]["value"], data["conversations"][1]["value"])
for data in dataset
]
# Tokenize the prompts and completions.
prompts = [prompt for prompt, _ in dataset]
prompt_token_ids = tokenizer(prompts).input_ids
completions = [completion for _, completion in dataset]
completion_token_ids = tokenizer(completions).input_ids
tokenized_dataset = []
for i in range(len(dataset)):
output_len = len(completion_token_ids[i])
tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
# Filter out too long sequences.
filtered_dataset: List[Tuple[str, int, int]] = []
for prompt, prompt_token_ids, output_len in tokenized_dataset:
prompt_len = len(prompt_token_ids)
if prompt_len < 4 or output_len < 4:
# Prune too short sequences.
# This is because TGI causes errors when the input or output length
# is too short.
continue
if prompt_len > 1024 or prompt_len + output_len > 2048:
# Prune too long sequences.
continue
filtered_dataset.append((prompt, prompt_len, output_len))
# Sample the requests.
sampled_requests = random.sample(filtered_dataset, num_requests)
return sampled_requests
async def get_request(
input_requests: List[Tuple[str, int, int]],
request_rate: float,
) -> AsyncGenerator[Tuple[str, int, int], None]:
input_requests = iter(input_requests)
for request in input_requests:
yield request
if request_rate == float("inf"):
# If the request rate is infinity, then we don't need to wait.
continue
# Sample the request interval from the exponential distribution.
interval = np.random.exponential(1.0 / request_rate)
# The next request will be sent after the interval.
await asyncio.sleep(interval)
async def send_request(
backend: str,
api_url: str,
prompt: str,
prompt_len: int,
output_len: int,
best_of: int,
use_beam_search: bool,
) -> None:
request_start_time = time.perf_counter()
headers = {"User-Agent": "Benchmark Client"}
if backend == "vllm":
pload = {
"prompt": prompt,
"n": 1,
"best_of": best_of,
"use_beam_search": use_beam_search,
"temperature": 0.0 if use_beam_search else 1.0,
"top_p": 1.0,
"max_tokens": output_len,
"ignore_eos": True,
"stream": False,
}
elif backend == "tgi":
assert not use_beam_search
params = {
"best_of": best_of,
"max_new_tokens": output_len,
"do_sample": True,
}
pload = {
"inputs": prompt,
"parameters": params,
}
elif backend == "srt":
assert not use_beam_search
params = {
"ignore_eos": True,
"max_new_tokens": output_len,
}
pload = {
"text": prompt,
"sampling_params": params,
}
elif backend == "lightllm":
assert not use_beam_search
params = {
"ignore_eos": True,
"max_new_tokens": output_len,
}
pload = {
"inputs": prompt,
"parameters": params,
}
else:
raise ValueError(f"Unknown backend: {backend}")
timeout = aiohttp.ClientTimeout(total=3 * 3600)
async with aiohttp.ClientSession(timeout=timeout) as session:
while True:
async with session.post(api_url, headers=headers, json=pload) as response:
chunks = []
async for chunk, _ in response.content.iter_chunks():
chunks.append(chunk)
output = b"".join(chunks).decode("utf-8")
output = json.loads(output)
# Re-send the request if it failed.
if "error" not in output:
break
request_end_time = time.perf_counter()
request_latency = request_end_time - request_start_time
REQUEST_LATENCY.append((prompt_len, output_len, request_latency))
async def benchmark(
backend: str,
api_url: str,
input_requests: List[Tuple[str, int, int]],
best_of: int,
use_beam_search: bool,
request_rate: float,
) -> None:
tasks: List[asyncio.Task] = []
async for request in get_request(input_requests, request_rate):
prompt, prompt_len, output_len = request
task = asyncio.create_task(send_request(backend, api_url, prompt,
prompt_len, output_len,
best_of, use_beam_search))
tasks.append(task)
await tqdm_asyncio.gather(*tasks)
def main(args: argparse.Namespace):
print(args)
random.seed(args.seed)
np.random.seed(args.seed)
api_url = f"http://{args.host}:{args.port}/generate"
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
benchmark_start_time = time.perf_counter()
asyncio.run(benchmark(args.backend, api_url, input_requests, args.best_of,
args.use_beam_search, args.request_rate))
benchmark_end_time = time.perf_counter()
benchmark_time = benchmark_end_time - benchmark_start_time
print(f"Total time: {benchmark_time:.2f} s")
print(f"Throughput: {args.num_prompts / benchmark_time:.2f} requests/s")
# Compute the latency statistics.
avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY])
print(f"Average latency: {avg_latency:.2f} s")
avg_per_token_latency = np.mean([
latency / (prompt_len + output_len)
for prompt_len, output_len, latency in REQUEST_LATENCY
])
print(f"Average latency per token: {avg_per_token_latency:.2f} s")
avg_per_output_token_latency = np.mean([
latency / output_len
for _, output_len, latency in REQUEST_LATENCY
])
print("Average latency per output token: "
f"{avg_per_output_token_latency:.2f} s")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Benchmark the online serving throughput.")
parser.add_argument("--backend", type=str, default="vllm",
choices=["vllm", "tgi", "srt", "lightllm"])
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--dataset", type=str, required=True,
help="Path to the dataset.")
parser.add_argument("--tokenizer", type=str, required=True,
help="Name or path of the tokenizer.")
parser.add_argument("--best-of", type=int, default=1,
help="Generates `best_of` sequences per prompt and "
"returns the best one.")
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument("--num-prompts", type=int, default=1000,
help="Number of prompts to process.")
parser.add_argument("--request-rate", type=float, default=float("inf"),
help="Number of requests per second. If this is inf, "
"then all the requests are sent at time 0. "
"Otherwise, we use Poisson process to synthesize "
"the request arrival times.")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument('--trust-remote-code', action='store_true',
help='trust remote code from huggingface')
args = parser.parse_args()
main(args)


@@ -0,0 +1,66 @@
import argparse
import random
import time
import requests
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=None)
parser.add_argument("--backend", type=str, default="srt")
args = parser.parse_args()
if args.port is None:
if args.backend == "srt":
args.port = 30000
elif args.backend == "vllm":
args.port = 21000
elif args.backend == "lightllm":
args.port = 22000
else:
raise ValueError(f"Invalid backend: {args.backend}")
url = f"{args.host}:{args.port}"
a = random.randint(0, 1 << 20)
max_new_tokens = 256
tic = time.time()
if args.backend == "srt":
response = requests.post(
url + "/generate",
json={
"text": f"{a}, ",
"sampling_params": {
"temperature": 0,
"max_new_tokens": max_new_tokens,
},
},
)
elif args.backend == "lightllm":
response = requests.post(
url + "/generate",
json={
"inputs": f"{a}, ",
"parameters": {
"temperature": 0,
"max_new_tokens": max_new_tokens,
},
},
)
elif args.backend == "vllm":
response = requests.post(
url + "/generate",
json={
"prompt": f"{a}, ",
"temperature": 0,
"max_tokens": max_new_tokens,
},
)
latency = time.time() - tic
ret = response.json()
print(ret)
speed = max_new_tokens / latency
print(f"latency: {latency:.2f} s, speed: {speed:.2f} token/s")

37
benchmark/line_retrieval/README.md Normal file

@@ -0,0 +1,37 @@
## Download data
```
wget https://raw.githubusercontent.com/merrymercy/merrymercy.github.io/master/files/random_words.json
python3 gen_data.py --number 1000
```
## Run benchmark
### Benchmark sglang
```
python3 -m sglang.launch_server --model-path codellama/CodeLlama-7b-hf --port 30000
```
```
python3 bench_sglang.py --src-index 600 --num-q 50 --parallel 1
```
### Results
```
# original
Accuracy: 0.940, latency: 332.83 s
# parallel encoding (no_adjust, offset = 1000)
Accuracy: 0.760, latency: 238.46 s
# parallel encoding (no_adjust, offset = 3000)
Accuracy: 0.760, latency: 238.46 s
# parallel encoding (no_adjust, offset = 0)
Accuracy: 0.520, latency: 238.46 s
# parallel encoding (adjust_cache)
Accuracy: 0.460, latency: 257.66 s
```

133
benchmark/line_retrieval/bench_sglang.py Normal file

@@ -0,0 +1,133 @@
import argparse
import json
import time
import re
import numpy as np
import sglang as sgl
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
from sglang.utils import dump_state_text
@sgl.function
def line_retrieval(s, prefix, suffix, body_0, body_1, body_2, body_3):
s += prefix + "\n"
contexts = [body_0, body_1, body_2, body_3]
position_ids_offset = [i * 1000 for i in range(len(contexts))]
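    # Fork into parallel branches so each context chunk is encoded
    # independently, shifted by a distinct position-id offset; these are the
    # "parallel encoding" variants whose accuracy is reported in the README.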
forks = s.fork(len(contexts), position_ids_offset)
forks += lambda i: contexts[i] + "\n"
forks.join(mode="concate_and_append")
s += "\n" + suffix
s += sgl.gen("answer", max_tokens=16)
def eval_model(args, line_obj, num_hoops, src_indices, dst_percents):
arguments = []
labels = []
sum_src_indices = []
sum_dst_indices = []
for i in range(len(src_indices)):
for j in range(len(dst_percents)):
src_index = src_indices[i]
dst_percent = dst_percents[j]
query_indices = line_obj["group_by_num_hoops"][str(num_hoops)]
query_indices = [q for q in query_indices if
all(l <= src_index for l in line_obj["links"][q]) and q < src_index]
dst_index = query_indices[min(int(len(query_indices) * dst_percent), len(query_indices)-1)]
label = line_obj["values"][dst_index]
body = line_obj["lines"][:src_index+1]
suffix = line_obj["suffix"].replace("???", line_obj["indices"][dst_index])
body_part_len = len(body) // 4
arguments.append({
"prefix": line_obj["prefix"],
"body_0": "\n".join(body[:body_part_len]),
"body_1": "\n".join(body[body_part_len: 2 * body_part_len]),
"body_2": "\n".join(body[2 * body_part_len: 3 * body_part_len]),
"body_3": "\n".join(body[3 * body_part_len:]),
"suffix": suffix,
})
labels.append(label)
sum_src_indices.append(src_index)
sum_dst_indices.append(dst_index)
# Select backend
backend = select_sglang_backend(args)
tic = time.time()
states = line_retrieval.run_batch(
arguments, temperature=0, backend=backend, num_threads=args.parallel)
latency = time.time() - tic
corrects = []
for i in range(len(arguments)):
output = states[i]["answer"]
prompt_len = states[i].get_meta_info("answer").get("prompt_length", -1)
label = labels[i]
# Try all numbers
        findall = re.findall(r"\d+", output)
if not findall:
response_number = output
else:
for response_number in findall:
if response_number == label:
break
correct = (response_number == label)
corrects.append(correct)
# Log results
summary = (
f"Line index: {sum_src_indices[i]} -> {sum_dst_indices[i]}, "
f"Prompt len: {prompt_len}, "
f"Correct: {correct}, "
f"Label: {label}, Predicted: {response_number}, "
)
print(summary)
accuracy = np.mean(corrects)
print(f"Accuracy: {accuracy:.3f}, latency: {latency:.2f} s")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "line_retrieval",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": len(arguments),
"other": {
"num_questions": len(arguments),
"parallel": args.parallel,
}
}
fout.write(json.dumps(value) + "\n")
def main(args):
line_obj = json.load(open(args.data_path, "r"))
num_hoops = args.num_hoops
for src_index in args.src_index:
src_indices = [src_index]
num_queries = args.num_queries_per_src
        dst_percents = [i / num_queries for i in range(num_queries)]
eval_model(args, line_obj, num_hoops, src_indices, dst_percents)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="lines_1000_0.0.json")
parser.add_argument("--src-index", type=int, nargs="+", default=[100])
parser.add_argument("--num-queries-per-src", type=int, default=10)
parser.add_argument("--num-hoops", type=int, default=1)
args = add_common_sglang_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,135 @@
"""
Generate line data for line retrieval task.
Usage:
python3 gen_data.py --number 1000
"""
import argparse
from collections import defaultdict
import json
from tqdm import tqdm
import numpy as np
def generate_lines(random_words, num_lines, redirect_ratio):
prefix = "Here is a list of lines, each with its corresponding REGISTER_CONTENT value. Please memorize them. Be prepared to provide the REGISTER_CONTENT value for a specific line index when I ask."
suffix = "The list has ended. Please give the final REGISTER_CONTENT value for a specific line after resovling the redirections and references. For example, the REGISTER_CONTENT of Line __idx0__ is __val0__. The REGISTER_CONTENT of Line __idx1__ is __val1__. The REGISTER_CONTENT of Line __idx2__ is __val2__. The REGISTER_CONTENT of Line ??? is"
# Raw lines
visited_indices = set([None])
visited_values = set([None])
lines = []
redirects = []
indices = []
values = []
for i in tqdm(range(num_lines)):
line_index = None
while line_index in visited_indices:
line_index = "-".join(np.random.choice(random_words, size=(2,)))
visited_indices.add(line_index)
line_value = np.random.randint(low=0, high=999999)
line_value = f"{line_value:06}"
line = f"Line {line_index}: The REGISTER_CONTENT is {line_value}."
lines.append(line)
redirects.append(None)
indices.append(line_index)
values.append(line_value)
# Add redirect
if redirect_ratio > 0:
num_redirect_lines = int(len(lines) * redirect_ratio)
redirect_indices = np.random.choice(np.arange(len(lines)),
size=(num_redirect_lines,), replace=False)
for i in redirect_indices:
target_idx = np.random.choice(min(i * 2 + 100, num_lines))
lines[i] = f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
redirects[i] = target_idx
# Build links and find sources
links = [[] for _ in range(num_lines)]
contains_ring = set()
for i in range(num_lines):
if redirects[i] is None:
continue
tmp_link = []
cur = i
visited = set()
while redirects[cur] is not None:
visited.add(cur)
tmp_link.append(redirects[cur])
cur = redirects[cur]
if cur in visited:
contains_ring.add(i)
tmp_link = None
break
values[i] = values[cur]
links[i] = tmp_link
# Group by num_links
group_by_num_hoops = defaultdict(list)
for i in range(num_lines):
if i in contains_ring:
continue
group_by_num_hoops[len(links[i]) + 1].append(i)
keys = sorted(list(group_by_num_hoops.keys()))
for num_links in keys:
print(f"#links: {num_links}, #lines: {len(group_by_num_hoops[num_links])}")
# Append few-shot examples
hoop1_candidates = list(group_by_num_hoops[1])
hoop1_candidate_keys = {c: max([c] + links[c]) for c in hoop1_candidates}
hoop1_candidates.sort(key=lambda c: hoop1_candidate_keys[c])
hoop2_candidates = list(group_by_num_hoops[2])
hoop2_candidate_keys = {c: max([c] + links[c]) for c in hoop2_candidates}
hoop2_candidates.sort(key=lambda c: hoop2_candidate_keys[c])
i = hoop1_candidates[5]
suffix = suffix.replace("__idx0__", indices[i]).replace("__val0__", values[i])
if len(hoop2_candidates):
i = hoop2_candidates[0]
suffix = suffix.replace("__idx1__", indices[i]).replace("__val1__", values[i])
i = hoop2_candidates[1]
suffix = suffix.replace("__idx2__", indices[i]).replace("__val2__", values[i])
else:
i = hoop1_candidates[1]
suffix = suffix.replace("__idx1__", indices[i]).replace("__val1__", values[i])
i = hoop1_candidates[10]
suffix = suffix.replace("__idx2__", indices[i]).replace("__val2__", values[i])
obj = {
"prefix": prefix,
"suffix": suffix,
"lines": lines,
"indices": indices,
"values": values,
"links": links,
"group_by_num_hoops": group_by_num_hoops,
"contains_ring": sorted(list(contains_ring)),
}
return obj
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--number", type=int)
parser.add_argument("--redirect-ratio", type=float, default=0.0)
args = parser.parse_args()
num_lines = args.number
random_words_filename = "random_words.json"
random_words = json.load(open(random_words_filename, "r"))
np.random.seed(42)
obj = generate_lines(random_words, num_lines, args.redirect_ratio)
fout = f"lines_{num_lines}_{args.redirect_ratio:.1f}.json"
with open(fout, "w") as fout:
json.dump(obj, fout, indent=2)

View File

@@ -0,0 +1,60 @@
## Download benchmark images
```
python3 download_images.py
```
Image benchmark source: https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild
### Other Dependencies
```
pip3 install "torch>=2.1.2" "transformers>=4.36" pillow
```
## Run benchmark
### Benchmark sglang
Launch a server
```
python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000
```
Run benchmark
```
# Run with local models
python3 bench_sglang.py --num-questions 60
# Run with OpenAI models
python3 bench_sglang.py --num-questions 60 --backend gpt-4-vision-preview
```
### Benchmark original LLaVA code
```
git clone git@github.com:haotian-liu/LLaVA.git
cd LLaVA
git reset --hard 9a26bd1435b4ac42c282757f2c16d34226575e96
pip3 install -e .
cd ~/sglang/benchmark/llava_bench
CUDA_VISIBLE_DEVICES=0 bash bench_hf_llava_bench.sh
```
### Benchmark llama.cpp
```
# Install
CMAKE_ARGS="-DLLAMA_CUBLAS=on" pip install llama-cpp-python
pip install sse_starlette starlette_context pydantic_settings
# Download weights
mkdir -p ~/model_weights/llava-v1.5-7b/
wget https://huggingface.co/mys/ggml_llava-v1.5-7b/resolve/main/ggml-model-f16.gguf -O ~/model_weights/llava-v1.5-7b/ggml-model-f16.gguf
wget https://huggingface.co/mys/ggml_llava-v1.5-7b/resolve/main/mmproj-model-f16.gguf -O ~/model_weights/llava-v1.5-7b/mmproj-model-f16.gguf
```
```
python3 -m llama_cpp.server --model ~/model_weights/llava-v1.5-7b/ggml-model-f16.gguf --clip_model_path ~/model_weights/llava-v1.5-7b/mmproj-model-f16.gguf --chat_format llava-1-5 --port 23000
OPENAI_BASE_URL=http://localhost:23000/v1 python3 bench_sglang.py --backend gpt-4-vision-preview --num-q 1
```
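For reference, `bench_sglang.py` wraps each question in a small vision program along these lines (a condensed sketch; the endpoint URL, image path, and question are placeholders):
```python
import sglang as sgl

@sgl.function
def image_qa(s, image_file, question):
    # Attach the image to the user turn, then generate in the assistant turn.
    s += sgl.user(sgl.image(image_file) + question)
    s += sgl.assistant(sgl.gen("answer", max_tokens=768))

sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
state = image_qa.run(image_file="images/001.jpg",
                     question="Please describe this image.",
                     temperature=0)
print(state["answer"])
```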

View File

@@ -0,0 +1,9 @@
#!/bin/bash
python -m llava.eval.model_vqa \
--model-path liuhaotian/llava-v1.5-7b \
--question-file ./questions.jsonl \
--image-folder ./images \
--answers-file ./answers_hf.jsonl \
--temperature 0 \
--conv-mode vicuna_v1

View File

@@ -0,0 +1,9 @@
#!/bin/bash
python -m llava.eval.model_vqa_loader \
--model-path liuhaotian/llava-v1.5-7b \
--question-file ./mme_pack/llava_mme_bench_replace.jsonl \
--image-folder ./mme_pack/MME_Benchmark_release_version \
--answers-file ./answers_hf_mme.jsonl \
--temperature 0 \
--conv-mode vicuna_v1

View File

@@ -0,0 +1,96 @@
import argparse
import json
import time
import os
import sglang as sgl
import tqdm
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
from sglang.utils import read_jsonl, dump_state_text
from PIL import Image
@sgl.function
def image_qa(s, image_file, question):
s += sgl.user(sgl.image(image_file) + question)
s += sgl.assistant(sgl.gen("answer", max_tokens=args.max_tokens))
def main(args):
lines = read_jsonl(args.question_file)[:args.num_questions]
arguments = [
{"image_file":
os.path.abspath(args.image_folder + "/" + l["image"]),
"question": l["text"]} for l in lines
]
#arguments = [
# {"image_file":
# Image.open(os.path.abspath(args.image_folder + "/" + l["image"])),
# "question": l["text"]} for l in lines
#]
states = [None] * len(lines)
# Select backend
backend = select_sglang_backend(args)
sgl.set_default_backend(backend)
# Run requests
tic = time.time()
if args.parallel == 1:
for i in tqdm.tqdm(range(len(lines))):
image_file = arguments[i]["image_file"]
question = arguments[i]["question"]
ret = image_qa.run(
image_file=image_file,
question=question,
temperature=0)
states[i] = ret
else:
states = image_qa.run_batch(
arguments,
temperature=0,
num_threads=args.parallel,
progress_bar=True)
latency = time.time() - tic
print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
print(f"Write output to {args.answer_file}")
with open(args.answer_file, "w") as fout:
for i in range(len(lines)):
value = {
"question_id": lines[i]["question_id"],
"prompt": lines[i]["text"],
"text": states[i]["answer"].strip(),
"model_id": backend.model_info["model_path"],
"answer_id": i,
"metadata": {},
}
fout.write(json.dumps(value) + "\n")
with open(args.result_file, "a") as fout:
value = {
"task": "llava_bench",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": len(lines),
"parallel": args.parallel,
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--question-file", type=str, default="questions.jsonl")
parser.add_argument("--answer-file", type=str, default="answers.jsonl")
parser.add_argument("--image-folder", type=str, default="./images")
parser.add_argument("--temperature", type=float, default=0.0)
parser.add_argument("--num-questions", type=int, default=None)
parser.add_argument("--max-tokens", type=int, default=768)
args = add_common_sglang_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,2 @@
MME_FOLDER=./mme_pack
python3 bench_sglang.py --num-questions 5000 --question-file $MME_FOLDER/llava_mme_bench_replace.jsonl --answer-file answer_mme.jsonl --image-folder $MME_FOLDER/MME_Benchmark_release_version --max-tokens 4

View File

@@ -0,0 +1,20 @@
import os
# Create the 'images' directory if it doesn't exist
if not os.path.exists('images'):
os.makedirs('images')
# Base URL
base_url = "https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild/resolve/main/images/"
# Loop through image numbers
for i in range(1, 25):
# Format the image number with leading zeros
image_number = str(i).zfill(3)
image_url = base_url + image_number + ".jpg"
image_path = "images/" + image_number + ".jpg"
# Download the image using wget
os.system(f"wget -O {image_path} {image_url}")
print("Download complete.")

View File

@@ -0,0 +1,27 @@
## Run benchmark
### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
```
python3 bench_sglang.py --num-questions 25 --parallel 8
python3 bench_sglang.py --num-questions 16 --parallel 1
```
### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```
```
python3 bench_other.py --backend vllm --num-questions 25
```
### Benchmark guidance
```
python3 bench_other.py --backend guidance --num-questions 25 --parallel 1
```
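`bench_sglang.py` evaluates all judging dimensions in parallel with the `fork` API, then merges the judgements and asks for a final score. A condensed sketch (the dimension texts are abbreviated placeholders):
```python
import sglang as sgl

dimension_prompts = ["Content: ...", "Organization and Structure: ..."]  # abbreviated

@sgl.function
def multi_dimension_judge(s, article):
    s += "Please evaluate the quality of the following article.\n"
    s += article + "\n\n"
    # One fork per dimension; all forks share the article prefix.
    forks = s.fork(len(dimension_prompts))
    for i in range(len(dimension_prompts)):
        forks[i] += ("USER: Please judge the quality based on the following metric. "
                     + dimension_prompts[i] + "\nJUDGE:")
        forks[i] += sgl.gen("judgement", max_tokens=256, stop="END")
    forks.join()
    # Merge the per-dimension judgements and ask for an overall score.
    for i in range(len(dimension_prompts)):
        s += forks[i]["judgement"].strip() + "\n"
    s += "In summary, on a scale of 1 to 10, I would give the article a score of"
    s += sgl.gen("score", max_tokens=2)
```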

View File

@@ -0,0 +1,120 @@
import argparse
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import json
import time
import numpy as np
from tqdm import tqdm
from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw
from sglang.utils import read_jsonl, dump_state_text
system_prompt = (
"Please serve as an impartial judge and rigorously evaluate the quality of the following article. Apply the most stringent standards possible, showing no leniency."
)
dimension_prompts = [
"Content: This refers to the essences of the essay. The substance should be well researched, accurate, relevant to the topic and should show a thorough understanding of the subject. The essay should also reflect a clear goal or purpose.",
"Organization and Structure: An essay needs to be properly structured with a clear introduction, body, and conclusion. The essay should flow naturally, with one paragraph leading seamlessly into the next.",
"Argument and Analysis: The argument made in the essay should be logical, coherent and clearly articulated. Each point made should be backed up by solid evidence and thorough analysis.",
"Clarity and Precision: The essay should be written in a clear and concise manner. The points made should be easily understood by the reader. The language used should also be precise and unambiguous.",
"Grammar and Punctuation: Proper use of grammar and punctuation is vital in an academic essay. Errors in grammar and punctuation not only distract the reader but can also negatively impact the meaning and interpretation of the content.",
"Referencing and Citation: An essay should contain proper citations and references for all sources used. This not only prevents accusations of plagiarism but also gives credit to the authors of the works that have contributed to the essay. The citation should adhere to a specific format as required by the academic institution or specified by the professor.",
]
def multi_dimension_judge(article, generate):
s = system_prompt
s += "\n```\n" + article + "\n```\n\n"
judges = []
for i in range(len(dimension_prompts)):
comp = generate(s +
"USER: Please judge the quality based on the following metric. " +
dimension_prompts[i] + " Please provide a single-paragraph judgement. " +
"Focus on the provided metric and do not say other things. "
'End your judgement paragraph with the word "END"\nJUDGE:',
max_tokens=256, stop="END")
judges.append(comp)
s += "I will judge the quality based on the following metrics.\n"
for i in range(len(dimension_prompts)):
s += dimension_prompts[i].split(":")[0] + ": " + judges[i].strip() + "\n"
s += "In summary, on a scale of 1 to 10, I would give the article a score of"
s += generate(s, max_tokens=2, stop=None)
return s
def main(args):
lines = read_jsonl(args.data_path)[:args.num_questions]
states = [None] * len(lines)
# Select backend
if args.backend == "lightllm":
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_lightllm, url=url, temperature=0)
elif args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_vllm, url=url, temperature=0)
elif args.backend == "srt-raw":
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_srt_raw, url=url, temperature=0)
elif args.backend == "guidance":
from guidance import models, gen
model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096)
def generate(prompt, max_tokens, stop):
out = model + prompt + gen(name="answer",
max_tokens=max_tokens, temperature=0, stop=stop)
return out["answer"]
# warmup
generate("Hello!", max_tokens=8, stop=None)
else:
raise ValueError(f"Invalid backend: {args.backend}")
# Run requests
def get_one_answer(i):
states[i] = multi_dimension_judge(lines[i], generate)
tic = time.time()
if args.parallel == 1:
for i in tqdm(range(len(lines))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
executor.map(get_one_answer, list(range(len(lines))))
latency = time.time() - tic
    print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "llm_judge",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
}
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="articles.jsonl")
parser.add_argument("--num-questions", type=int, default=20)
args = add_common_other_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,85 @@
import argparse
import json
import time
import numpy as np
import sglang as sgl
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
from sglang.utils import read_jsonl, dump_state_text
system_prompt = (
"Please serve as an impartial judge and rigorously evaluate the quality of the following article. Apply the most stringent standards possible, showing no leniency."
)
dimension_prompts = [
"Content: This refers to the essences of the essay. The substance should be well researched, accurate, relevant to the topic and should show a thorough understanding of the subject. The essay should also reflect a clear goal or purpose.",
"Organization and Structure: An essay needs to be properly structured with a clear introduction, body, and conclusion. The essay should flow naturally, with one paragraph leading seamlessly into the next.",
"Argument and Analysis: The argument made in the essay should be logical, coherent and clearly articulated. Each point made should be backed up by solid evidence and thorough analysis.",
"Clarity and Precision: The essay should be written in a clear and concise manner. The points made should be easily understood by the reader. The language used should also be precise and unambiguous.",
"Grammar and Punctuation: Proper use of grammar and punctuation is vital in an academic essay. Errors in grammar and punctuation not only distract the reader but can also negatively impact the meaning and interpretation of the content.",
"Referencing and Citation: An essay should contain proper citations and references for all sources used. This not only prevents accusations of plagiarism but also gives credit to the authors of the works that have contributed to the essay. The citation should adhere to a specific format as required by the academic institution or specified by the professor.",
]
@sgl.function
def multi_dimension_judge(s, article):
s += system_prompt
s += "\n```\n" + article + "\n```\n\n"
forks = s.fork(len(dimension_prompts))
for i in range(len(dimension_prompts)):
forks[i] += ("USER: Please judge the quality based on the following metric. " +
dimension_prompts[i] + " Please provide a single-paragraph judgement. " +
"Focus on the provided metric and do not say other things. "
'End your judgement paragraph with the word "END"\nJUDGE:')
forks[i] += sgl.gen("judgement", max_tokens=256, stop="END")
forks.join()
s += "I will judge the quality based on the following metrics.\n"
for i in range(len(dimension_prompts)):
s += dimension_prompts[i].split(":")[0] + ": " + forks[i]["judgement"].strip() + "\n"
s += "In summary, on a scale of 1 to 10, I would give the article a score of"
s += sgl.gen("score", max_tokens=2)
def main(args):
lines = read_jsonl(args.data_path)[:args.num_questions]
arguments = [{"article": l} for l in lines]
# Select backend
backend = select_sglang_backend(args)
# Run requests
tic = time.time()
states = multi_dimension_judge.run_batch(
arguments, temperature=0, backend=backend, num_threads=args.parallel)
latency = time.time() - tic
print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "llm_judge",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
}
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="articles.jsonl")
parser.add_argument("--num-questions", type=int, default=20)
args = add_common_sglang_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,33 @@
## Run benchmark
### Benchmark sglang
```
python3 -m sglang.launch_server --model-path codellama/CodeLlama-7b-instruct-hf --port 30000
```
```
python3 bench_sglang.py --num-questions 5 --parallel 1
```
### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model codellama/CodeLlama-7b-instruct-hf --disable-log-requests --port 21000 --gpu 0.97
```
```
python3 bench_other.py --backend vllm --num-questions 5
```
### Benchmark guidance
```
python3 bench_other.py --backend guidance --num-questions 5 --parallel 1
```
### Build dataset
```
pip install wikipedia
python3 build_dataset.py
```
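`bench_sglang.py` fills in each JSON field with a separate `sgl.gen` call, using `stop='"'` so generation ends exactly at the closing quote and the output stays well-formed. A condensed sketch:
```python
import sglang as sgl

@sgl.function
def json_decode(s, document):
    s += "Please extract the information of a city from the following wikipedia page.\n"
    s += document + "\n"
    s += '{\n'
    # Each field is a separate, tightly bounded generation call.
    s += '  "name": "' + sgl.gen("name", max_tokens=8, stop='"') + '",\n'
    s += '  "country": "' + sgl.gen("country", max_tokens=8, stop='"') + '"\n'
    s += '}\n'
```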

View File

@@ -0,0 +1,104 @@
import argparse
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import json
import time
from tqdm import tqdm
import numpy as np
from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw
from sglang.utils import read_jsonl, dump_state_text
def json_decode(document, generate):
s = "Please extract the information of a city from the following wikipedia page.\n"
s += "Page begin.\n" + document + "Page end.\n"
s += "Here is the name, country, and symbol of the city in JSON format.\n"
s += '{\n'
s += ' "name": "'
s += generate(s, max_tokens=8, stop='"') + '",\n'
s += ' "country": "'
s += generate(s, max_tokens=8, stop='"') + '",\n'
s += ' "air port code": "'
s += generate(s, max_tokens=8, stop='"') + '",\n'
s += ' "top 3 landmarks": "'
s += generate(s, max_tokens=24, stop='"') + '",\n'
s += '}\n'
return s
def main(args):
lines = read_jsonl(args.data_path)
arguments = []
for i in range(len(lines[:args.num_questions])):
arguments.append({
"document": lines[i]["document"],
})
states = [None] * len(arguments)
# Select backend
if args.backend == "lightllm":
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_lightllm, url=url, temperature=0)
elif args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_vllm, url=url, temperature=0)
elif args.backend == "srt-raw":
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_srt_raw, url=url, temperature=0)
elif args.backend == "guidance":
from guidance import models, gen
model = models.LlamaCpp("/home/ubuntu/model_weights/CodeLlama-7b-instruct-hf.gguf", n_gpu_layers=-1, n_ctx=11000)
def generate(prompt, max_tokens, stop):
out = model + prompt + gen(name="answer",
max_tokens=max_tokens, temperature=0, stop=stop)
return out["answer"]
# warmup
generate("Hello!", max_tokens=8, stop=None)
else:
raise ValueError(f"Invalid backend: {args.backend}")
# Run requests
def get_one_answer(i):
states[i] = json_decode(generate=generate, **arguments[i])
tic = time.time()
if args.parallel == 1:
for i in tqdm(range(len(arguments))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
executor.map(get_one_answer, list(range(len(arguments))))
latency = time.time() - tic
    print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "long_json_decode",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
}
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="questions.jsonl")
parser.add_argument("--num-questions", type=int, default=100)
args = add_common_other_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,68 @@
import argparse
import json
import time
import numpy as np
import sglang as sgl
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
from sglang.utils import read_jsonl, dump_state_text
@sgl.function
def json_decode(s, document):
s += "Please extract the information of a city from the following wikipedia page.\n"
s += "Page begin.\n" + document + "Page end.\n"
s += "Here is the name, country, and symbol of the city in JSON format.\n"
s += '{\n'
s += ' "name": "' + sgl.gen("name", max_tokens=8, stop='"') + '",\n'
s += ' "country": "' + sgl.gen("country", max_tokens=8, stop='"') + '",\n'
s += ' "air port code": "' + sgl.gen("air port code", max_tokens=8, stop='"') + '",\n'
s += ' "top 3 landmarks": "' + sgl.gen("landmarks", max_tokens=24, stop='"') + '",\n'
s += '}\n'
def main(args):
lines = read_jsonl(args.data_path)
arguments = []
for i in range(len(lines[:args.num_questions])):
arguments.append({
"document": lines[i]["document"],
})
# Select backend
backend = select_sglang_backend(args)
sgl.set_default_backend(backend)
# Run requests
tic = time.time()
states = json_decode.run_batch(
arguments, temperature=0, num_threads=args.parallel)
latency = time.time() - tic
    print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "long_json_decode",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
}
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="questions.jsonl")
parser.add_argument("--num-questions", type=int, default=10)
args = add_common_sglang_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,26 @@
import json
import transformers
import wikipedia
name = "meta-llama/Llama-2-7b-chat-hf"
t = transformers.AutoTokenizer.from_pretrained(name)
city_names = ["los angles", "london", "tokyo", "beijing", "singapore"]
for city_name in city_names:
content = str(wikipedia.page(city_name).content)
content = content.replace("\n\n", "\n")
tokens = t.encode(content)
truncate_len = int((10000 / len(tokens)) * len(content))
truncate_content = content[:truncate_len]
truncate_tokens = t.encode(truncate_content)
# Count token
print(f"city_name: {city_name}, #tokens: {len(tokens)}, #truncate tokens: {len(truncate_tokens)}")
with open("questions.jsonl", "a") as fout:
fout.write(json.dumps({"document": truncate_content}) + "\n")

56
benchmark/mmlu/README.md Normal file
View File

@@ -0,0 +1,56 @@
## Download data
```
wget https://people.eecs.berkeley.edu/~hendrycks/data.tar
tar xf data.tar
```
## Run benchmark
### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
```
python3 bench_sglang.py --nsub 10
```
### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```
```
python3 bench_other.py --nsub 10 --backend vllm
```
### Benchmark lightllm
```
# A10G
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
# V100
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 4500 --port 22000
```
```
python3 bench_other.py --nsub 10 --backend lightllm
```
### Benchmark guidance
```
python3 bench_other.py --nsub 10 --backend guidance --parallel 1
```
### Benchmark lmql
```
CUDA_VISIBLE_DEVICES=0,1 lmql serve-model meta-llama/Llama-2-7b-chat-hf --cuda --port 23000
```
```
python3 bench_other.py --nsub 10 --backend lmql --parallel 2
```
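Under the hood, `bench_sglang.py` binds the shared few-shot examples once and batches the test questions with `max_new_tokens=1`, since only the answer letter is needed. A condensed sketch (`few_shot_examples` and `prompts` are placeholders built from the dataset):
```python
import sglang as sgl

@sgl.function
def few_shot_mmlu(s, examples, question):
    s += examples + question + sgl.gen("answer")

few_shot_examples = "..."  # formatted dev-set examples (placeholder)
prompts = ["..."]          # test questions without answers (placeholder)

# The examples are bound once, so their prefix is shared across the batch;
# one generated token suffices for the answer letter A/B/C/D.
states = few_shot_mmlu.bind(examples=few_shot_examples).run_batch(
    [{"question": p} for p in prompts],
    temperature=0, max_new_tokens=1)
```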

View File

@@ -0,0 +1,202 @@
import argparse
import asyncio
from concurrent.futures import ThreadPoolExecutor
import json
from functools import partial
import os
import time
import numpy as np
import pandas as pd
import tiktoken
from tqdm import tqdm
from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw
choices = ["A", "B", "C", "D"]
tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
def format_subject(subject):
l = subject.split("_")
s = ""
for entry in l:
s += " " + entry
return s
def format_example(df, idx, include_answer=True):
prompt = df.iloc[idx, 0]
k = df.shape[1] - 2
for j in range(k):
prompt += "\n{}. {}".format(choices[j], df.iloc[idx, j+1])
prompt += "\nAnswer:"
if include_answer:
prompt += " {}\n\n".format(df.iloc[idx, k + 1])
return prompt
def gen_prompt(train_df, subject, k=-1):
prompt = "The following are multiple choice questions (with answers) about{}.\n\n".format(format_subject(subject))
if k == -1:
k = train_df.shape[0]
for i in range(k):
prompt += format_example(train_df, i)
return prompt
model_initialized = None
def evaluate(args, subject, dev_df, test_df):
prompts = []
labels = []
# Construct prompts
k = args.ntrain
train_prompt = gen_prompt(dev_df, subject, k)
while len(tokenizer.encode(train_prompt)) > 1536:
k -= 1
train_prompt = gen_prompt(dev_df, subject, k)
for i in range(test_df.shape[0]):
prompt_end = format_example(test_df, i, include_answer=False)
prompt = train_prompt + prompt_end
prompts.append(prompt)
label = test_df.iloc[i, test_df.shape[1]-1]
labels.append(label)
preds = [None] * len(prompts)
max_tokens = 1
# Select backend
global model_initialized
if args.backend == "lightllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_lightllm, url=url, stop=None)
elif args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_vllm, url=url, stop=None)
elif args.backend == "srt-raw":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_srt_raw, url=url, stop=None)
elif args.backend == "guidance":
from guidance import models, gen
if model_initialized is None:
model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096)
model_initialized = model
else:
model = model_initialized
def call_generate(prompt, temperature, max_tokens):
out = model + prompt + gen(name="answer",
max_tokens=max_tokens, temperature=0)
return out["answer"]
elif args.backend == "lmql":
import lmql
model = lmql.model("meta-llama/Llama-2-7b-chat-hf",
endpoint=f"{args.host}:{args.port}")
@lmql.query(model=model)
async def program(question):
'''lmql
"""{question}[ANSWER]""" where len(TOKENS(ANSWER)) < 2
return ANSWER
'''
async def call_generate(prompt, temperature, max_tokens):
return await program(question=prompt, temperature=temperature)
else:
raise ValueError(f"Invalid backend: {args.backend}")
# Run requests
if args.backend != "lmql":
# Use thread pool
def get_one_answer(i):
pred = call_generate(prompts[i], temperature=0,
max_tokens=max_tokens)
preds[i] = pred.strip()[0]
tic = time.time()
if args.parallel == 1:
for i in range(len(prompts)):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
executor.map(get_one_answer, list(range(len(prompts))))
else:
# Use asyncio
async def batched_call(batch_size):
for i in range(0, len(prompts), batch_size):
tasks = []
for p in prompts[i:i+batch_size]:
tasks.append(call_generate(p,
temperature=0, max_tokens=max_tokens))
rets = await asyncio.gather(*tasks)
for j in range(len(rets)):
preds[i+j] = rets[j].strip()[0]
tic = time.time()
asyncio.run(batched_call(batch_size=args.parallel))
latency = time.time() - tic
# Compute accuracy
cors = [pred == label for pred, label in zip(preds, labels)]
acc = np.mean(cors)
cors = np.array(cors)
print("Average accuracy {:.3f}, latency {:.2f}, #q: {} - {}".format(
acc, latency, len(prompts), subject))
return cors, acc, latency
def main(args):
subjects = sorted([f.split("_test.csv")[0] for f in os.listdir(os.path.join(args.data_dir, "test")) if "_test.csv" in f])
all_cors = []
all_latencies = []
num_requests = 0
for subject in tqdm(subjects[:args.nsub]):
dev_df = pd.read_csv(os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None)[:args.ntrain]
test_df = pd.read_csv(os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None)
cors, acc, latency = evaluate(args, subject, dev_df, test_df)
all_cors.append(cors)
all_latencies.append(latency)
num_requests += len(test_df)
total_latency = np.sum(all_latencies)
print("Total latency: {:.3f}".format(total_latency))
weighted_acc = np.mean(np.concatenate(all_cors))
print("Average accuracy: {:.3f}".format(weighted_acc))
# Write results
with open(args.result_file, "a") as fout:
value = {
"task": "mmlu",
"backend": args.backend,
"num_gpus": 1,
"latency": round(total_latency, 3),
"accuracy": round(weighted_acc, 3),
"num_requests": num_requests,
"other": {
"nsub": args.nsub,
"parallel": args.parallel,
}
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--ntrain", type=int, default=5)
parser.add_argument("--data_dir", type=str, default="data")
parser.add_argument("--nsub", type=int, default=60)
args = add_common_other_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,143 @@
import argparse
import json
import os
import time
import numpy as np
import pandas as pd
import tiktoken
from tqdm import tqdm
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
choices = ["A", "B", "C", "D"]
tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
def format_subject(subject):
l = subject.split("_")
s = ""
for entry in l:
s += " " + entry
return s
def format_example(df, idx, include_answer=True):
prompt = df.iloc[idx, 0]
k = df.shape[1] - 2
for j in range(k):
prompt += "\n{}. {}".format(choices[j], df.iloc[idx, j+1])
prompt += "\nAnswer:"
if include_answer:
prompt += " {}\n\n".format(df.iloc[idx, k + 1])
return prompt
def gen_prompt(train_df, subject, k=-1):
prompt = "The following are multiple choice questions (with answers) about{}.\n\n".format(format_subject(subject))
if k == -1:
k = train_df.shape[0]
for i in range(k):
prompt += format_example(train_df, i)
return prompt
def evaluate(args, subject, dev_df, test_df):
prompts = []
labels = []
k = args.ntrain
few_shot_examples = gen_prompt(dev_df, subject, k)
while len(tokenizer.encode(few_shot_examples)) > 1536:
k -= 1
few_shot_examples = gen_prompt(dev_df, subject, k)
for i in range(test_df.shape[0]):
prompt_end = format_example(test_df, i, include_answer=False)
prompts.append(prompt_end)
label = test_df.iloc[i, test_df.shape[1]-1]
labels.append(label)
arguments = [{"question": p} for p in prompts]
#####################################
######### SGL Program Begin #########
#####################################
import sglang as sgl
@sgl.function
def few_shot_mmlu(s, examples, question):
s += examples + question + sgl.gen("answer")
#####################################
########## SGL Program End ##########
#####################################
# Select backend
backend = select_sglang_backend(args)
tic = time.time()
states = few_shot_mmlu.bind(examples=few_shot_examples).run_batch(
arguments, temperature=0, max_new_tokens=1,
backend=backend, num_threads=args.parallel)
preds = [s["answer"].strip()[0] if len(s["answer"].strip()) > 0 else ""
for s in states]
latency = time.time() - tic
cors = [pred == label for pred, label in zip(preds, labels)]
acc = np.mean(cors)
cors = np.array(cors)
print("Average accuracy {:.3f}, latency {:.2f}, #q: {} - {}".format(
acc, latency, len(prompts), subject))
return cors, acc, latency
def main(args):
subjects = sorted([f.split("_test.csv")[0] for f in os.listdir(os.path.join(args.data_dir, "test")) if "_test.csv" in f])
all_cors = []
all_latencies = []
num_requests = 0
for subject in tqdm(subjects[:args.nsub]):
dev_df = pd.read_csv(os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None)[:args.ntrain]
test_df = pd.read_csv(os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None)
cors, acc, latency = evaluate(args, subject, dev_df, test_df)
all_cors.append(cors)
all_latencies.append(latency)
num_requests += len(test_df)
total_latency = np.sum(all_latencies)
print("Total latency: {:.3f}".format(total_latency))
weighted_acc = np.mean(np.concatenate(all_cors))
print("Average accuracy: {:.3f}".format(weighted_acc))
# Write results
with open(args.result_file, "a") as fout:
value = {
"task": "mmlu",
"backend": args.backend,
"num_gpus": 1,
"latency": round(total_latency, 3),
"accuracy": round(weighted_acc, 3),
"num_requests": num_requests,
"other": {
"nsub": args.nsub,
"parallel": args.parallel,
}
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--ntrain", "-k", type=int, default=5)
parser.add_argument("--data_dir", "-d", type=str, default="data")
parser.add_argument("--save_dir", "-s", type=str, default="results")
parser.add_argument("--nsub", type=int, default=60)
args = add_common_sglang_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,31 @@
## Run benchmark
### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
```
python3 bench_sglang.py --num-questions 80
```
### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```
```
python3 bench_other.py --num-questions 80 --backend vllm
```
### Benchmark lightllm
```
# A10G
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
```
```
python3 bench_other.py --num-questions 80 --backend lightllm
```
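The core program in `bench_sglang.py` keeps both turns of a question in one conversation state, so the first turn's prefix is reused when answering the second:
```python
import sglang as sgl

@sgl.function
def answer_mt_bench(s, question_1, question_2):
    s += sgl.system()
    s += sgl.user(question_1)
    s += sgl.assistant(sgl.gen("answer_1"))
    s += sgl.user(question_2)
    s += sgl.assistant(sgl.gen("answer_2"))
```
The answers are then collected with `run_batch(..., max_new_tokens=256)`, as in the script.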

View File

@@ -0,0 +1,116 @@
import argparse
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import json
import os
import time
import uuid
from fastchat.model import get_conversation_template
import requests
from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt
def load_questions(filename):
questions = []
with open(filename, "r") as fin:
for line in fin:
obj = json.loads(line)
questions.append(obj)
return questions
def write_answers(filename, model_id, questions, answers):
with open(os.path.expanduser(filename), "w") as fout:
for i in range(len(answers)):
ans_json = {
"question_id": questions[i]["question_id"],
"answer_id": uuid.uuid4().hex,
"model_id": model_id,
"choices": {
"index": 0,
"turns": [answers[i][0], answers[i][1]],
},
"tstamp": time.time(),
}
fout.write(json.dumps(ans_json) + "\n")
def main(args):
questions = load_questions(args.question_file)
questions = (questions * 10)[:args.num_questions]
max_tokens = 256
model_id = "llama-2-chat"
conv_main = get_conversation_template(model_id)
# Select backend
if args.backend == "lightllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_lightllm, url=url, stop=None)
elif args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_vllm, url=url, stop=None)
elif args.backend == "srt":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_srt, url=url, stop=None)
else:
raise ValueError(f"Invalid backend: {args.backend}")
answers = [None] * len(questions)
def get_answer(i):
conv = conv_main.copy()
cur_answers = []
for j in range(2):
q = questions[i]["turns"][j]
conv.append_message(conv.roles[0], q)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
output = call_generate(prompt,
temperature=0, max_tokens=max_tokens).strip()
cur_answers.append(output)
conv.update_last_message(output)
answers[i] = cur_answers
# Run requests
tic = time.time()
if args.parallel == 1:
for i in range(len(questions)):
get_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
executor.map(get_answer, list(range(len(questions))))
latency = time.time() - tic
print(f"#questions: {len(questions)}, Latency: {latency:.2f}")
# Write results
answer_file = args.answer_file or f"tmp_output_{args.backend}.txt"
write_answers(answer_file, model_id, questions, answers)
with open(args.result_file, "a") as fout:
value = {
"task": "mtbench",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
}
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--question-file", type=str, default="question.jsonl")
parser.add_argument("--answer-file", type=str, default=None)
parser.add_argument("--num-questions", type=int, default=80)
args = add_common_other_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,95 @@
import argparse
import json
import os
import time
import uuid
import sglang as sgl
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
def load_questions(filename):
questions = []
with open(filename, "r") as fin:
for line in fin:
obj = json.loads(line)
questions.append(obj)
return questions
def write_answers(filename, model_id, questions, answers):
with open(os.path.expanduser(filename), "w") as fout:
for i in range(len(answers)):
ans_json = {
"question_id": questions[i]["question_id"],
"answer_id": uuid.uuid4().hex,
"model_id": model_id,
"choices": {
"index": 0,
"turns": [answers[i][0], answers[i][1]],
},
"tstamp": time.time(),
}
fout.write(json.dumps(ans_json) + "\n")
@sgl.function
def answer_mt_bench(s, question_1, question_2):
s += sgl.system()
s += sgl.user(question_1)
s += sgl.assistant(sgl.gen("answer_1"))
s += sgl.user(question_2)
s += sgl.assistant(sgl.gen("answer_2"))
def main(args):
# Construct prompts
questions = load_questions(args.question_file)[:args.num_questions]
arguments = [
{"question_1": q["turns"][0], "question_2": q["turns"][1]}
for q in questions
]
# Select backend
backend = select_sglang_backend(args)
sgl.set_default_backend(backend)
# Run requests
tic = time.time()
rets = answer_mt_bench.run_batch(
arguments,
temperature=0,
max_new_tokens=256,
num_threads=args.parallel)
answers = [[s["answer_1"], s["answer_2"]] for s in rets]
latency = time.time() - tic
print(f"#questions: {len(questions)}, Latency: {latency:.2f}")
# Write results
model_id = backend.model_info["model_path"]
answer_file = args.answer_file or f"tmp_output_{args.backend}.txt"
write_answers(answer_file, model_id, questions, answers)
with open(args.result_file, "a") as fout:
value = {
"task": "mtbench",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
}
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--question-file", type=str, default="question.jsonl")
parser.add_argument("--answer-file", type=str, default=None)
parser.add_argument("--num-questions", type=int, default=80)
args = add_common_sglang_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,43 @@
## Download data
```
wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
```
## Run benchmark
### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
```
python3 bench_sglang.py --num-questions 64
python3 bench_sglang.py --num-questions 32 --parallel 1
```
### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```
```
python3 bench_other.py --num-questions 64 --backend vllm
```
### Benchmark lightllm
```
# A10G
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
```
```
python3 bench_other.py --num-questions 64 --backend lightllm
```
### Benchmark guidance
```
python3 bench_other.py --num-questions 8 --backend guidance --parallel 1
```
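`bench_sglang.py` implements the multi-chain (self-consistency) prompt with parallel forks: several reasoning chains are sampled independently, then merged for a majority vote. A condensed sketch with an abbreviated prompt list:
```python
import sglang as sgl

prompt_lib = [
    "Let us think step by step.",
    "Take a deep breath and break this down.",
]  # abbreviated; see bench_sglang.py for the full list

@sgl.function
def multi_chain_gsm8k(s, question, num_chains=2):
    s += "Question: " + question + "\n"
    # Sample independent reasoning chains in parallel forks.
    forks = s.fork(num_chains)
    for i in range(num_chains):
        forks[i] += ("Answer: " + prompt_lib[i % len(prompt_lib)] +
                     sgl.gen("chain", max_tokens=256, temperature=0.3, stop="Question"))
    forks.join()
    # Merge the chains and take a majority vote on the final answer.
    s += "Answer: Here are some possible solutions.\n"
    for i in range(num_chains):
        s += f"Solution {i + 1}: " + forks[i]["chain"].strip() + "\n\n"
    s += "By considering the above solutions and doing a majority vote, the final answer (a single integer) is"
    s += sgl.gen("answer", max_tokens=16)
```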

View File

@@ -0,0 +1,195 @@
import argparse
import ast
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import json
import re
import time
import numpy as np
from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw
from sglang.utils import read_jsonl, dump_state_text
INVALID = -9999999
def get_answer_value(answer_str):
answer_str = answer_str.replace(",", "")
numbers = re.findall(r'\d+', answer_str)
if len(numbers) < 1:
return INVALID
try:
return ast.literal_eval(numbers[-1])
except SyntaxError:
return INVALID
prompt_lib = [
"Let us think step by step.",
"Approach this methodically. Let's dissect the problem into smaller, more manageable parts.",
"It's important to proceed step by step, ensuring accuracy at each stage.",
"Take a deep breath and break this down.",
"A little bit of arithmetic and a logical approach will help us quickly arrive at the solution to this problem.",
"I am extremely good at math.",
]
def multi_chain_gsm8k(question, num_chains, call_generate):
s = "Question: " + question + "\n"
# s += call_generate(s + "Answer: " + prompt_lib[0], max_tokens=256,
# stop="Question", temperature=0)
# return s
comps = []
for i in range(num_chains):
comps.append(call_generate(s + "Answer: " + prompt_lib[i % num_chains],
max_tokens=256, temperature=0.3, stop="Question"))
s += "Answer: To answer this question, here are some possible solutions. "
s += "After considering all of them, I will do a majority vote.\n\n"
for i in range(num_chains):
s += f"Solution {i+1}: " + comps[i].strip() + "\n\n"
s += f"\nBy considering the above solutions and doing a majority vote, I think the final answer (a single integer number) is "
s += call_generate(s, max_tokens=16, temperature=0, stop=None)
return s
def main(args):
lines = read_jsonl(args.data_path)
# Construct prompts
    k = args.num_shot
    few_shot_examples = ""  # --num-shot defaults to 0; used as the prompt prefix in the lmql branch below
questions = []
labels = []
for i in range(len(lines[:args.num_questions])):
questions.append(lines[i]["question"])
labels.append(get_answer_value(lines[i]["answer"]))
assert all(l != INVALID for l in labels)
states = [None] * len(labels)
# Select backend
if args.backend == "lightllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_lightllm, url=url)
elif args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_vllm, url=url)
elif args.backend == "srt-raw":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_srt_raw, url=url)
elif args.backend == "guidance":
from guidance import models, gen
model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096)
def call_generate(prompt, temperature, max_tokens, stop):
out = model + prompt + gen(name="answer",
max_tokens=max_tokens, temperature=temperature, stop=stop)
return out["answer"]
#def multi_chain_gsm8k(question, num_chains, call_generate):
# s = model + "Question: " + question + "\n"
# comps = []
# for i in range(num_chains):
# comps.append(call_generate(s + "Answer: " + prompt_lib[i % num_chains],
# max_tokens=256, temperature=0.3, stop="Question"))
# s += "Answer: To answer this question, here are some possible solutions. "
# s += "After considering all of them, I will do a majority vote.\n\n"
# for i in range(num_chains):
# s += f"Solution {i+1}: " + comps[i].strip() + "\n\n"
# s += f"\nBy considering the above solutions and doing a majority vote, I think the final answer (a single integer number) is "
# return call_generate(s, max_tokens=16, temperature=0, stop=None)
elif args.backend == "lmql":
import lmql
model = lmql.model("meta-llama/Llama-2-7b-chat-hf",
endpoint=f"{args.host}:{args.port}")
@lmql.query(model=model)
async def program(question):
'''lmql
"""{question}[ANSWER]""" where len(TOKENS(ANSWER)) < 257 and STOPS_AT(ANSWER, "Question")
return ANSWER
'''
async def call_generate(prompt, temperature, max_tokens, stop):
return await program(question=prompt, temperature=0)
else:
raise ValueError(f"Invalid backend: {args.backend}")
# Run requests
if args.backend != "lmql":
# Use thread pool
def get_one_answer(i):
answer = multi_chain_gsm8k(questions[i], args.num_chains,
call_generate)
states[i] = answer
tic = time.time()
if args.parallel == 1:
for i in range(len(questions)):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
executor.map(get_one_answer, list(range(len(questions))))
else:
# Use asyncio
async def batched_call(batch_size):
for i in range(0, len(questions), batch_size):
tasks = []
for q in questions[i:i+batch_size]:
tasks.append(call_generate(few_shot_examples + q,
temperature=0, max_tokens=256, stop="Question"))
rets = await asyncio.gather(*tasks)
for j in range(len(rets)):
states[i+j] = get_answer_value(rets[j])
tic = time.time()
asyncio.run(batched_call(batch_size=args.parallel))
latency = time.time() - tic
preds = []
for i in range(len(states)):
preds.append(get_answer_value(states[i]))
# Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels))
invalid = np.mean(np.array(preds) == INVALID)
print(f"Latency: {latency:.3f}")
print(f"Invalid: {invalid:.3f}")
print(f"Accuracy: {acc:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "multi_chain_gsm8k",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"accuracy": round(acc, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
}
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num-shot", type=int, default=0)
parser.add_argument("--num-chains", type=int, default=5)
parser.add_argument("--data-path", type=str, default="test.jsonl")
parser.add_argument("--num-questions", type=int, default=50)
args = add_common_other_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,129 @@
import argparse
import ast
import json
import re
import time
import numpy as np
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
from sglang.utils import read_jsonl, dump_state_text
INVALID = -9999999
def get_answer_value(answer_str):
answer_str = answer_str.replace(",", "")
numbers = re.findall(r'\d+', answer_str)
if len(numbers) < 1:
return INVALID
try:
return ast.literal_eval(numbers[-1])
except SyntaxError:
return INVALID
prompt_lib = [
"Let us think step by step.",
"Approach this methodically. Let's dissect the problem into smaller, more manageable parts.",
"It's important to proceed step by step, ensuring accuracy at each stage.",
"Take a deep breath and break this down.",
"A little bit of arithmetic and a logical approach will help us quickly arrive at the solution to this problem.",
"I am extremely good at math.",
]
def main(args):
lines = read_jsonl(args.data_path)
# Construct prompts
#k = args.num_shot
#few_shot_examples = get_few_shot_examples(lines, k)
questions = []
labels = []
for i in range(len(lines[:args.num_questions])):
questions.append(lines[i]["question"])
labels.append(get_answer_value(lines[i]["answer"]))
assert all(l != INVALID for l in labels)
arguments = [{"question": q} for q in questions]
num_chains = args.num_chains
#####################################
######### SGL Program Begin #########
#####################################
import sglang as sgl
@sgl.function
def multi_chain_gsm8k(s, question):
s += "Question: " + question + "\n"
#s += "Answer: " + prompt_lib[0] + sgl.gen("answer", max_tokens=256, stop="Question",
# temperature=0)
#return
forks = s.fork(num_chains)
for i in range(num_chains):
forks[i] += ("Answer: " + prompt_lib[i % num_chains] +
sgl.gen(f"chain", max_tokens=256, temperature=0.3, stop="Question"))
forks.join()
s += "Answer: To answer this question, here are some possible solutions. "
s += "After considering all of them, I will do a majority vote.\n\n"
for i in range(num_chains):
s += f"Solution {i+1}: " + forks[i]["chain"].strip() + "\n\n"
s += f"\nBy considering the above solutions and doing a majority vote, I think the final answer (a single integer number) is "
s += sgl.gen("answer", max_tokens=16)
#####################################
########## SGL Program End ##########
#####################################
# Select backend
backend = select_sglang_backend(args)
# Run requests
tic = time.time()
states = multi_chain_gsm8k.run_batch(
arguments, temperature=0, backend=backend, num_threads=args.parallel)
latency = time.time() - tic
preds = []
for i in range(len(states)):
preds.append(get_answer_value(states[i]["answer"]))
# Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels))
invalid = np.mean(np.array(preds) == INVALID)
print(f"Latency: {latency:.3f}")
print(f"Invalid: {invalid:.3f}")
print(f"Accuracy: {acc:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "multi_chain_gsm8k",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"accuracy": round(acc, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
}
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num-shot", type=int, default=0)
parser.add_argument("--num-chains", type=int, default=5)
parser.add_argument("--data-path", type=str, default="test.jsonl")
parser.add_argument("--num-questions", type=int, default=50)
args = add_common_sglang_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,47 @@
## Run benchmark
### Benchmark sglang
```
python3 -m sglang.launch_server --model-path codellama/CodeLlama-7b-instruct-hf --port 30000
```
```
python3 bench_sglang.py --num-questions 10 --parallel 1
```
### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model codellama/CodeLlama-7b-instruct-hf --disable-log-requests --port 21000 --gpu 0.97
```
```
python3 bench_other.py --backend vllm --num-questions 64
```
### Benchmark guidance
```
python3 bench_other.py --backend guidance --num-questions 32 --parallel 1
```
### Build dataset
```
pip install PyPDF2
python3 build_dataset.py
```
```python
import PyPDF2
with open('llama2.pdf', 'rb') as file:
reader = PyPDF2.PdfReader(file)
text = ''
for page_num in range(len(reader.pages)):
text += reader.pages[page_num].extract_text()
with open('output.txt', 'w') as text_file:
text_file.write(text)
```
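`bench_sglang.py` encodes every document in its own fork and concatenates the results back into the main prompt before asking the question. A condensed sketch of that program:
```python
import sglang as sgl

@sgl.function
def multi_document_qa(s, docs, question):
    s += sgl.user_begin()
    s += "Please answer a question according to the given documents.\n"
    # Encode each document in a parallel fork, then concatenate the results.
    forks = s.fork(len(docs))
    forks += lambda i: docs[i]
    forks.join("concate_and_append")
    s += "\nQuestion: " + question + "\nAnswer in three words or fewer."
    s += sgl.user_end()
    s += sgl.assistant(sgl.gen("answer", max_tokens=16))
```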

View File

@@ -0,0 +1,126 @@
import argparse
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import json
import time
from tqdm import tqdm
import numpy as np
from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw
from sglang.utils import read_jsonl, dump_state_text
USER_PREFIX = "[INST] "
USER_SUFFIX = " [/INST]"
ASSISTANT_PREFIX = ""
ASSISTANT_SUFFIX = " </s><s>"
def multi_document_qa(docs, question, generate):
s = USER_PREFIX
s += "Pleaes answer a question according to given documents.\n"
s += "Question:" + question + "Documents begin.\n"
s += "".join(docs)
s += "\nDocuments end."
s += ("\n\nBased on the above documents, please answer this question:\n" + question + "\nAnswer in three words or fewer.")
s += USER_SUFFIX
s += ASSISTANT_PREFIX
answer = generate(s, max_tokens=16, stop=None)
return answer
def main(args):
lines = read_jsonl(args.data_path)
l = lines[0]
arguments = []
labels = []
num_docs = 10
if args.backend == "guidance":
num_docs = 7 # due to OOM
for i in range(len(l["questions"][:args.num_questions])):
arguments.append({
"docs": l["documents"][:num_docs],
"question": l["questions"][i],
})
labels.append(l["answers"][i])
states = [None] * len(arguments)
# Select backend
if args.backend == "lightllm":
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_lightllm, url=url, temperature=0)
elif args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_vllm, url=url, temperature=0)
elif args.backend == "srt-raw":
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_srt_raw, url=url, temperature=0)
elif args.backend == "guidance":
from guidance import models, gen
model = models.LlamaCpp("/home/ubuntu/model_weights/CodeLlama-7b-instruct-hf.gguf", n_gpu_layers=-1, n_ctx=11000)
def generate(prompt, max_tokens, stop):
out = model + prompt + gen(name="answer",
max_tokens=max_tokens, temperature=0, stop=stop)
return out["answer"]
# warmup
generate("Hello!", max_tokens=8, stop=None)
else:
raise ValueError(f"Invalid backend: {args.backend}")
# Run requests
def get_one_answer(i):
states[i] = multi_document_qa(generate=generate, **arguments[i])
tic = time.time()
if args.parallel == 1:
for i in tqdm(range(len(labels))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
executor.map(get_one_answer, list(range(len(labels))))
latency = time.time() - tic
# Compute accuracy
print(states)
correct = 0
for s, label in zip(states, labels):
answer = s.lower()
if all(x in answer for x in label.lower().split(" ")):
correct += 1
accuracy = correct / len(labels)
print(f"Accuracy: {accuracy:.3f}")
print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "multi_document_qa",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": args.num_questions,
"accuracy": accuracy,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
}
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="questions.jsonl")
parser.add_argument("--num-questions", type=int, default=100)
args = add_common_other_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,84 @@
import argparse
import json
import time
import numpy as np
import sglang as sgl
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
from sglang.utils import read_jsonl, dump_state_text
@sgl.function
def multi_document_qa(s, docs, question):
s += sgl.user_begin()
s += "Pleaes answer a question according to given documents.\n"
s += "Question:" + question + "Documents begin.\n"
forks = s.fork(len(docs))
forks += lambda i: docs[i]
forks.join("concate_and_append")
s += "\nDocuments end."
s += ("\n\nBased on the above documents, please answer this question:\n" + question + "\nAnswer in three words or fewer.")
s += sgl.user_end()
s += sgl.assistant(sgl.gen("answer", max_tokens=16))
def main(args):
lines = read_jsonl(args.data_path)
l = lines[0]
arguments = []
labels = []
for i in range(len(l["questions"][:args.num_questions])):
arguments.append({
"docs": l["documents"][:10],
"question": l["questions"][i],
})
labels.append(l["answers"][i])
# Select backend
backend = select_sglang_backend(args)
sgl.set_default_backend(backend)
# Run requests
tic = time.time()
states = multi_document_qa.run_batch(
arguments, temperature=0, num_threads=args.parallel)
latency = time.time() - tic
# Compute accuracy
print([s["answer"] for s in states])
correct = 0
for s, label in zip(states, labels):
answer = s["answer"].lower()
if all(x in answer for x in label.lower().split(" ")):
correct += 1
accuracy = correct / len(labels)
print(f"Accuracy: {accuracy:.3f}")
print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "multi_document_qa",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": args.num_questions,
"accuracy": accuracy,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
}
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="questions.jsonl")
parser.add_argument("--num-questions", type=int, default=100)
args = add_common_sglang_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,64 @@
import json
import transformers
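# readlines() keeps each trailing "\n", so joining with "\n" doubles every
# newline; the replace below restores single newlines while keeping the
# original paragraph breaks ("\n\n").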
content = "\n".join(
open("llama2.txt", 'r', encoding='utf-8', errors='ignore').readlines())
content = content.replace("\n\n", "\n")
# Count token
name = "meta-llama/Llama-2-7b-chat-hf"
t = transformers.AutoTokenizer.from_pretrained(name)
print(f"num tokens: {len(t.encode(content))}")
# Segment
SEP = "\n\n"
parts = content.split(SEP)
print(f"num segments: {len(parts)}")
segment_len = 1100
segments = []
tmp = []
tmp_len = 0
for i in range(len(parts)):
tmp.append(parts[i])
tmp_len += len(t.encode(parts[i]))
if tmp_len > segment_len:
segments.append(SEP.join(tmp))
tmp = []
tmp_len = 0
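# Note: a final partial chunk that never reaches segment_len is dropped.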
for i, s in enumerate(segments):
    print(i, len(t.encode(s)))
# Dump
with open("questions.jsonl", "w") as fout:
fout.write(json.dumps({
"documents": segments[:30],
"questions": [
"What is the name of the fine-tuned LLMs?",
"Which figure shows the helpfulness human evaluation results for Llama 2-Chat?",
"What is the number of parameters in the largest Llama 2 model?",
"What is the batch size of fine-tuning?",
"Where can we find the details of potential data contamination?",
"What is the full name of MPT?",
"What is the power consumption of RSC in Watt?",
"How many tokens of data do they train on?",
"Which model's release is delayed due to a lack of time to sufficiently red team?",
"Which activation function is used in Llama?"
],
"answers": [
"Llama 2 Chat",
"1",
"70 B",
"64",
"A 6",
"MosaicML",
"400",
"2 trillion",
"34 B",
"SwiGLU",
],
}) + "\n")

26
benchmark/react/README.md Normal file
View File

@@ -0,0 +1,26 @@
## Run benchmark
### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
```
python3 bench_sglang.py --num-questions 100
```
### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```
```
python3 bench_other.py --num-questions 100 --backend vllm
```
### Benchmark guidance
```
python3 bench_other.py --num-questions 100 --backend guidance --parallel 1
```

View File

@@ -0,0 +1,182 @@
import argparse
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import json
import time
from pathlib import Path
from tqdm import tqdm
from sglang.test.test_utils import (
add_common_other_args_and_parse,
call_generate_lightllm,
call_generate_vllm,
call_generate_srt_raw,
)
from sglang.utils import read_jsonl, dump_state_text
def get_prompt(question):
prompt = (
"""Solve a question answering task with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be three types:
(1) Search[entity], which searches the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to search.
(2) Lookup[keyword], which returns the next sentence containing keyword in the current passage.
(3) Finish[answer], which returns the answer and finishes the task.
Here are some examples.
Question: What is the elevation range for the area that the eastern sector of the Colorado orogeny extends into?
Thought 1: I need to search Colorado orogeny, find the area that the eastern sector of the Colorado orogeny extends into, then find the elevation range of the area.
Action 1: Search[Colorado orogeny]
Observation 1: The Colorado orogeny was an episode of mountain building (an orogeny) in Colorado and surrounding areas.
Thought 2: It does not mention the eastern sector. So I need to look up eastern sector.
Action 2: Lookup[eastern sector]
Observation 2: (Result 1 / 1) The eastern sector extends into the High Plains and is called the Central Plains orogeny.
Thought 3: The eastern sector of Colorado orogeny extends into the High Plains. So I need to search High Plains and find its elevation range.
Action 3: Search[High Plains]
Observation 3: High Plains refers to one of two distinct land regions:
Thought 4: I need to instead search High Plains (United States).
Action 4: Search[High Plains (United States)]
Observation 4: The High Plains are a subregion of the Great Plains. From east to west, the High Plains rise in elevation from around 1,800 to 7,000 ft (550 to 2,130 m).[3]
Thought 5: High Plains rise in elevation from around 1,800 to 7,000 ft, so the answer is 1,800 to 7,000 ft.
Action 5: Finish[1,800 to 7,000 ft]
Question: Musician and satirist Allie Goertz wrote a song about the "The Simpsons" character Milhouse, who Matt Groening named after who?
Thought 1: The question simplifies to "The Simpsons" character Milhouse is named after who. I only need to search Milhouse and find who it is named after.
Action 1: Search[Milhouse]
Observation 1: Milhouse Mussolini Van Houten is a recurring character in the Fox animated television series The Simpsons voiced by Pamela Hayden and created by Matt Groening.
Thought 2: The paragraph does not tell who Milhouse is named after, maybe I can look up "named after".
Action 2: Lookup[named after]
Observation 2: (Result 1 / 1) Milhouse was named after U.S. president Richard Nixon, whose middle name was Milhous.
Thought 3: Milhouse was named after U.S. president Richard Nixon, so the answer is Richard Nixon.
Action 3: Finish[Richard Nixon]
Question: Which documentary is about Finnish rock groups, Adam Clayton Powell or The Saimaa Gesture?
Thought 1: I need to search Adam Clayton Powell and The Saimaa Gesture, and find which documentary is about Finnish rock groups.
Action 1: Search[Adam Clayton Powell]
Observation 1: Could not find [Adam Clayton Powell]. Similar: ['Adam Clayton Powell III', 'Seventh Avenue (Manhattan)', 'Adam Clayton Powell Jr. State Office Building', 'Isabel Washington Powell', 'Adam Powell', 'Adam Clayton Powell (film)', 'Giancarlo Esposito'].
Thought 2: To find the documentary, I can search Adam Clayton Powell (film).
Action 2: Search[Adam Clayton Powell (film)]
Observation 2: Adam Clayton Powell is a 1989 American documentary film directed by Richard Kilberg.
The film is about the rise and fall of influential African-American politician Adam Clayton Powell Jr.[3][4] It was later aired as part of the PBS series The American Experience.
Thought 3: Adam Clayton Powell (film) is a documentary about an African-American politician, not Finnish rock groups. So the documentary about Finnish rock groups must instead be The Saimaa Gesture.
Action 3: Finish[The Saimaa Gesture]
Question: What profession does Nicholas Ray and Elia Kazan have in common?
Thought 1: I need to search Nicholas Ray and Elia Kazan, find their professions, then find the profession they have in common.
Action 1: Search[Nicholas Ray]
Observation 1: Nicholas Ray (born Raymond Nicholas Kienzle Jr., August 7, 1911 – June 16, 1979) was an American film director, screenwriter, and actor best known for the 1955 film Rebel Without a Cause.
Thought 2: Professions of Nicholas Ray are director, screenwriter, and actor. I need to search Elia Kazan next and find his professions.
Action 2: Search[Elia Kazan]
Observation 2: Elia Kazan was an American film and theatre director, producer, screenwriter and actor.
Thought 3: Professions of Elia Kazan are director, producer, screenwriter, and actor. So profession Nicholas Ray and Elia Kazan have in common is director, screenwriter, and actor.
Action 3: Finish[director, screenwriter, actor]
Question: Which magazine was started first Arthur's Magazine or First for Women?
Thought 1: I need to search Arthur's Magazine and First for Women, and find which was started first.
Action 1: Search[Arthur's Magazine]
Observation 1: Arthur's Magazine (1844-1846) was an American literary periodical published in Philadelphia in the 19th century.
Thought 2: Arthur's Magazine was started in 1844. I need to search First for Women next.
Action 2: Search[First for Women]
Observation 2: First for Women is a woman's magazine published by Bauer Media Group in the USA.[1] The magazine was started in 1989.
Thought 3: First for Women was started in 1989. 1844 (Arthur's Magazine) < 1989 (First for Women), so Arthur's Magazine was started first.
Action 3: Finish[Arthur's Magazine]
Question: Were Pavel Urysohn and Leonid Levin known for the same type of work?
Thought 1: I need to search Pavel Urysohn and Leonid Levin, find their types of work, then find if they are the same.
Action 1: Search[Pavel Urysohn]
Observation 1: Pavel Samuilovich Urysohn (February 3, 1898 – August 17, 1924) was a Soviet mathematician who is best known for his contributions in dimension theory.
Thought 2: Pavel Urysohn is a mathematician. I need to search Leonid Levin next and find its type of work.
Action 2: Search[Leonid Levin]
Observation 2: Leonid Anatolievich Levin is a Soviet-American mathematician and computer scientist.
Thought 3: Leonid Levin is a mathematician and computer scientist. So Pavel Urysohn and Leonid Levin have the same type of work.
Action 3: Finish[yes]
""" + question)
return prompt
def main(args):
lines = read_jsonl(args.data_path)[:args.num_questions]
arguments = [{
"question": k,
"triplets": v
} for l in lines for k, v in l.items()]
states = []
# Select backend
if args.backend == "lightllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_lightllm, url=url)
elif args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_vllm, url=url)
elif args.backend == "srt-raw":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_srt_raw, url=url)
elif args.backend == "guidance":
from guidance import models, gen
model = models.LlamaCpp(
str(Path.home()) + "/model_weights/Llama-2-7b-chat.gguf",
n_gpu_layers=-1,
n_ctx=4096,
)
def call_generate(prompt, temperature, max_tokens, stop):
out = (model + prompt + gen(
name="result",
max_tokens=max_tokens,
temperature=temperature,
stop=stop,
))
return out["result"]
else:
raise ValueError(f"Invalid backend: {args.backend}")
def run_single_agent(argument):
question = argument["question"]
triplets = argument["triplets"]
prompt = get_prompt(question)
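        # Replay the recorded ReAct trajectory: generate each "Thought i" turn,
        # then feed back the ground-truth thought/action/observation as context.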
for i in range(1, len(triplets) + 2):
prompt += "Thought " + str(i) + ":"
states.append(prompt)
answer = call_generate(prompt,
max_tokens=200,
temperature=0,
stop="Observation")
if i > len(triplets):
break
prompt += (triplets[i - 1]["thought"] + "\nAction " + str(i) +
":" + triplets[i - 1]["action"] + "\nObservation " +
str(i) + ":" + triplets[i - 1]["observation"] + "\n")
states.append(answer)
tic = time.time()
if args.parallel == 1:
for arg in tqdm(arguments):
run_single_agent(arg)
else:
with ThreadPoolExecutor(args.parallel) as executor:
executor.map(run_single_agent, arguments)
latency = time.time() - tic
print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "ReAct Agents",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": len(arguments),
"other": {
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="hotpotqa_100.jsonl")
parser.add_argument("--num-questions", type=int, default=10)
args = add_common_other_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,141 @@
import argparse
import json
import time
import sglang as sgl
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import read_jsonl, dump_state_text
@sgl.function
def webthink(s, question, triplets):
s += (
"""Solve a question answering task with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be three types:
(1) Search[entity], which searches the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to search.
(2) Lookup[keyword], which returns the next sentence containing keyword in the current passage.
(3) Finish[answer], which returns the answer and finishes the task.
Here are some examples.
Question: What is the elevation range for the area that the eastern sector of the Colorado orogeny extends into?
Thought 1: I need to search Colorado orogeny, find the area that the eastern sector of the Colorado orogeny extends into, then find the elevation range of the area.
Action 1: Search[Colorado orogeny]
Observation 1: The Colorado orogeny was an episode of mountain building (an orogeny) in Colorado and surrounding areas.
Thought 2: It does not mention the eastern sector. So I need to look up eastern sector.
Action 2: Lookup[eastern sector]
Observation 2: (Result 1 / 1) The eastern sector extends into the High Plains and is called the Central Plains orogeny.
Thought 3: The eastern sector of Colorado orogeny extends into the High Plains. So I need to search High Plains and find its elevation range.
Action 3: Search[High Plains]
Observation 3: High Plains refers to one of two distinct land regions:
Thought 4: I need to instead search High Plains (United States).
Action 4: Search[High Plains (United States)]
Observation 4: The High Plains are a subregion of the Great Plains. From east to west, the High Plains rise in elevation from around 1,800 to 7,000 ft (550 to 2,130 m).[3]
Thought 5: High Plains rise in elevation from around 1,800 to 7,000 ft, so the answer is 1,800 to 7,000 ft.
Action 5: Finish[1,800 to 7,000 ft]
Question: Musician and satirist Allie Goertz wrote a song about the "The Simpsons" character Milhouse, who Matt Groening named after who?
Thought 1: The question simplifies to "The Simpsons" character Milhouse is named after who. I only need to search Milhouse and find who it is named after.
Action 1: Search[Milhouse]
Observation 1: Milhouse Mussolini Van Houten is a recurring character in the Fox animated television series The Simpsons voiced by Pamela Hayden and created by Matt Groening.
Thought 2: The paragraph does not tell who Milhouse is named after, maybe I can look up "named after".
Action 2: Lookup[named after]
Observation 2: (Result 1 / 1) Milhouse was named after U.S. president Richard Nixon, whose middle name was Milhous.
Thought 3: Milhouse was named after U.S. president Richard Nixon, so the answer is Richard Nixon.
Action 3: Finish[Richard Nixon]
Question: Which documentary is about Finnish rock groups, Adam Clayton Powell or The Saimaa Gesture?
Thought 1: I need to search Adam Clayton Powell and The Saimaa Gesture, and find which documentary is about Finnish rock groups.
Action 1: Search[Adam Clayton Powell]
Observation 1: Could not find [Adam Clayton Powell]. Similar: ['Adam Clayton Powell III', 'Seventh Avenue (Manhattan)', 'Adam Clayton Powell Jr. State Office Building', 'Isabel Washington Powell', 'Adam Powell', 'Adam Clayton Powell (film)', 'Giancarlo Esposito'].
Thought 2: To find the documentary, I can search Adam Clayton Powell (film).
Action 2: Search[Adam Clayton Powell (film)]
Observation 2: Adam Clayton Powell is a 1989 American documentary film directed by Richard Kilberg.
The film is about the rise and fall of influential African-American politician Adam Clayton Powell Jr.[3][4] It was later aired as part of the PBS series The American Experience.
Thought 3: Adam Clayton Powell (film) is a documentary about an African-American politician, not Finnish rock groups. So the documentary about Finnish rock groups must instead be The Saimaa Gesture.
Action 3: Finish[The Saimaa Gesture]
Question: What profession does Nicholas Ray and Elia Kazan have in common?
Thought 1: I need to search Nicholas Ray and Elia Kazan, find their professions, then find the profession they have in common.
Action 1: Search[Nicholas Ray]
Observation 1: Nicholas Ray (born Raymond Nicholas Kienzle Jr., August 7, 1911 – June 16, 1979) was an American film director, screenwriter, and actor best known for the 1955 film Rebel Without a Cause.
Thought 2: Professions of Nicholas Ray are director, screenwriter, and actor. I need to search Elia Kazan next and find his professions.
Action 2: Search[Elia Kazan]
Observation 2: Elia Kazan was an American film and theatre director, producer, screenwriter and actor.
Thought 3: Professions of Elia Kazan are director, producer, screenwriter, and actor. So profession Nicholas Ray and Elia Kazan have in common is director, screenwriter, and actor.
Action 3: Finish[director, screenwriter, actor]
Question: Which magazine was started first Arthur's Magazine or First for Women?
Thought 1: I need to search Arthur's Magazine and First for Women, and find which was started first.
Action 1: Search[Arthur's Magazine]
Observation 1: Arthur's Magazine (1844-1846) was an American literary periodical published in Philadelphia in the 19th century.
Thought 2: Arthur's Magazine was started in 1844. I need to search First for Women next.
Action 2: Search[First for Women]
Observation 2: First for Women is a woman's magazine published by Bauer Media Group in the USA.[1] The magazine was started in 1989.
Thought 3: First for Women was started in 1989. 1844 (Arthur's Magazine) < 1989 (First for Women), so Arthur's Magazine was started first.
Action 3: Finish[Arthur's Magazine]
Question: Were Pavel Urysohn and Leonid Levin known for the same type of work?
Thought 1: I need to search Pavel Urysohn and Leonid Levin, find their types of work, then find if they are the same.
Action 1: Search[Pavel Urysohn]
Observation 1: Pavel Samuilovich Urysohn (February 3, 1898 – August 17, 1924) was a Soviet mathematician who is best known for his contributions in dimension theory.
Thought 2: Pavel Urysohn is a mathematician. I need to search Leonid Levin next and find its type of work.
Action 2: Search[Leonid Levin]
Observation 2: Leonid Anatolievich Levin is a Soviet-American mathematician and computer scientist.
Thought 3: Leonid Levin is a mathematician and computer scientist. So Pavel Urysohn and Leonid Levin have the same type of work.
Action 3: Finish[yes]
""" + question)
for i in range(1, len(triplets) + 2):
s += "Thought " + str(i) + ":"
ss = s.fork(1)
ss[0] += sgl.gen(name="thought_action", max_tokens=200, stop="Observation")
# ss.join()
# to verify the correctness of output, this should be collected
# print(ss[0]["thought_action"])
if i > len(triplets):
break
s += (triplets[i - 1]["thought"] + "\nAction " + str(i) + ":" +
triplets[i - 1]["action"] + "\nObservation " + str(i) + ":" +
triplets[i - 1]["observation"] + "\n")
def main(args):
lines = read_jsonl(args.data_path)[:args.num_questions]
arguments = [{
"question": k,
"triplets": v
} for l in lines for k, v in l.items()]
# Select backend
backend = select_sglang_backend(args)
sgl.set_default_backend(backend)
states = []
tic = time.time()
states = webthink.run_batch(arguments,
temperature=0,
num_threads=args.parallel)
latency = time.time() - tic
# Compute accuracy
print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "ReAct Agents",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": len(arguments),
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="hotpotqa_100.jsonl")
parser.add_argument("--num-questions", type=int, default=10)
args = add_common_sglang_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,27 @@
## Run benchmark
### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
```
python3 bench_sglang.py --num-questions 64
python3 bench_sglang.py --num-questions 32 --parallel 1
```
### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```
```
python3 bench_other.py --backend vllm --num-questions 64
```
### Benchmark guidance
```
python3 bench_other.py --backend guidance --num-questions 32 --parallel 1
```

View File

@@ -0,0 +1,124 @@
import argparse
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import json
import time
from tqdm import tqdm
import numpy as np
from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw
from sglang.utils import read_jsonl, dump_state_text
number = 5
def expand_tip(topic, tip, generate):
s = (
"""Please expand a tip for a topic into a detailed paragraph.
Topic: staying healthy
Tip: Regular Exercise
Paragraph: Incorporate physical activity into your daily routine. This doesn't necessarily mean intense gym workouts; it can be as simple as walking, cycling, or yoga. Regular exercise helps in maintaining a healthy weight, improves cardiovascular health, boosts mental health, and can enhance cognitive function, which is crucial for fields that require intense intellectual engagement.
Topic: building a campfire
Tip: Choose the Right Location
Paragraph: Always build your campfire in a safe spot. This means selecting a location that's away from trees, bushes, and other flammable materials. Ideally, use a fire ring if available. If you're building a fire pit, it should be on bare soil or on a bed of stones, not on grass or near roots which can catch fire underground. Make sure the area above is clear of low-hanging branches.
Topic: writing a blog post
Tip: structure your content effectively
Paragraph: A well-structured post is easier to read and more enjoyable. Start with an engaging introduction that hooks the reader and clearly states the purpose of your post. Use headings and subheadings to break up the text and guide readers through your content. Bullet points and numbered lists can make information more digestible. Ensure each paragraph flows logically into the next, and conclude with a summary or call-to-action that encourages reader engagement.
Topic: """ + topic + "\nTip: " + tip + "\nParagraph:")
return generate(s, max_tokens=128, stop=["\n\n"])
def suggest_tips(topic, generate):
s = "Please act as a helpful assistant. Your job is to provide users with useful tips on a specific topic.\n"
s += "USER: Give some tips for " + topic + ".\n"
s += ("ASSISTANT: Okay. Here are " + str(number) + " concise tips, each under 8 words:\n")
tips = []
for i in range(1, 1 + number):
s += f"{i}."
tip = generate(s, max_tokens=24, stop=[".", "\n"])
s += tip + ".\n"
tips.append(tip)
paragraphs = [expand_tip(topic, tip, generate=generate) for tip in tips]
for i in range(1, 1 + number):
s += f"Tip {i}:" + paragraphs[i-1] + "\n"
return s
def main(args):
lines = read_jsonl(args.data_path)[:args.num_questions]
states = [None] * len(lines)
# Select backend
if args.backend == "lightllm":
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_lightllm, url=url, temperature=0)
elif args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_vllm, url=url, temperature=0)
elif args.backend == "srt-raw":
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_srt_raw, url=url, temperature=0)
elif args.backend == "guidance":
from guidance import models, gen
model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096)
def generate(prompt, max_tokens, stop):
out = model + prompt + gen(name="answer",
max_tokens=max_tokens, temperature=0, stop=stop)
return out["answer"]
# warmup
generate("Hello!", max_tokens=8, stop=None)
else:
raise ValueError(f"Invalid backend: {args.backend}")
# Run requests
def get_one_answer(i):
states[i] = suggest_tips(lines[i]["topic"], generate)
tic = time.time()
if args.parallel == 1:
for i in tqdm(range(len(lines))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
executor.map(get_one_answer, list(range(len(lines))))
latency = time.time() - tic
# Compute accuracy
print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "tip_suggestion",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
}
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="topic.jsonl")
parser.add_argument("--num-questions", type=int, default=100)
args = add_common_other_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,91 @@
import argparse
import json
import time
import numpy as np
import sglang as sgl
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
from sglang.utils import read_jsonl, dump_state_text
number = 5
@sgl.function
def expand_tip(s, topic, tip):
s += (
"""Please expand a tip for a topic into a detailed paragraph.
Topic: staying healthy
Tip: Regular Exercise
Paragraph: Incorporate physical activity into your daily routine. This doesn't necessarily mean intense gym workouts; it can be as simple as walking, cycling, or yoga. Regular exercise helps in maintaining a healthy weight, improves cardiovascular health, boosts mental health, and can enhance cognitive function, which is crucial for fields that require intense intellectual engagement.
Topic: building a campfire
Tip: Choose the Right Location
Paragraph: Always build your campfire in a safe spot. This means selecting a location that's away from trees, bushes, and other flammable materials. Ideally, use a fire ring if available. If you're building a fire pit, it should be on bare soil or on a bed of stones, not on grass or near roots which can catch fire underground. Make sure the area above is clear of low-hanging branches.
Topic: writing a blog post
Tip: structure your content effectively
Paragraph: A well-structured post is easier to read and more enjoyable. Start with an engaging introduction that hooks the reader and clearly states the purpose of your post. Use headings and subheadings to break up the text and guide readers through your content. Bullet points and numbered lists can make information more digestible. Ensure each paragraph flows logically into the next, and conclude with a summary or call-to-action that encourages reader engagement.
Topic: """ + topic + "\nTip: " + tip + "\nParagraph:")
s += sgl.gen("paragraph", max_tokens=128, stop=["\n\n"], temperature=0)
@sgl.function
def suggest_tips(s, topic):
s += "Please act as a helpful assistant. Your job is to provide users with useful tips on a specific topic.\n"
s += "USER: Give some tips for " + topic + ".\n"
s += ("ASSISTANT: Okay. Here are " + str(number) + " concise tips, each under 8 words:\n")
paragraphs = []
for i in range(1, 1 + number):
s += f"{i}." + sgl.gen(f"tip_{i}", max_tokens=24, stop=[".", "\n"]) + ".\n"
paragraphs.append(expand_tip(topic=topic, tip=s[f"tip_{i}"]))
for i in range(1, 1 + number):
s += f"Tip {i}:" + paragraphs[i-1]["paragraph"] + "\n"
def main(args):
lines = read_jsonl(args.data_path)[:args.num_questions]
arguments = [
{"topic": l["topic"]} for l in lines
]
# Select backend
sgl.set_default_backend(select_sglang_backend(args))
# Run requests
tic = time.time()
states = suggest_tips.run_batch(
arguments, temperature=0, num_threads=args.parallel)
latency = time.time() - tic
# Compute accuracy
print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "tip_suggestion",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
}
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="topic.jsonl")
parser.add_argument("--num-questions", type=int, default=100)
args = add_common_sglang_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,43 @@
## Download data
```
wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
```
## Run benchmark
### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
```
python3 bench_sglang.py --num-questions 32 --parallel 16
python3 bench_sglang.py --num-questions 10 --parallel 1
```
### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```
```
python3 bench_other.py --num-questions 32 --backend vllm
```
### Benchmark lightllm
```
# A10G
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
```
```
python3 bench_other.py --num-questions 32 --backend lightllm
```
### Benchmark guidance
```
python3 bench_other.py --num-questions 32 --backend guidance --parallel 1
```
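### Summarize results
Each script appends one JSON record per run to the file given by `--result-file`. A minimal sketch for tabulating those records (the `result.jsonl` name is an assumption; point it at whatever path you passed):
```python
import json

# Print one summary row per recorded benchmark run.
with open("result.jsonl") as fin:
    for line in fin:
        r = json.loads(line)
        print(f'{r["task"]:<28} {r["backend"]:<10} '
              f'latency={r["latency"]:.3f}s accuracy={r.get("accuracy", "n/a")}')
```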

View File

@@ -0,0 +1,183 @@
import argparse
import ast
import asyncio
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import json
import re
import time
import numpy as np
from tqdm import tqdm
from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw
from sglang.utils import read_jsonl, dump_state_text
INVALID = -9999999
def get_answer_value(answer_str):
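    # Parse the last integer in the string; return INVALID if none is found.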
answer_str = answer_str.replace(",", "")
numbers = re.findall(r'\d+', answer_str)
if len(numbers) < 1:
return INVALID
try:
return ast.literal_eval(numbers[-1])
except SyntaxError:
return INVALID
def most_frequent_number(numbers):
if not numbers:
return None
frequency = Counter(numbers)
most_frequent = max(frequency, key=frequency.get)
return most_frequent
USER_PREFIX = "[INST] "
USER_SUFFIX = " [/INST]"
ASSISTANT_PREFIX = ""
ASSISTANT_SUFFIX = " </s><s>"
# Use a low temp to make the results more deterministic and the comparison more fair.
temp = 0.3
def propose_plan(s, question, num_branches, call_generate):
s += (USER_PREFIX +
"""Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """ + question + USER_SUFFIX)
s += ASSISTANT_PREFIX
comps = call_generate(s, max_tokens=256, temperature=temp, stop=None, n=num_branches)
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
def execute_plan(s, num_branches, call_generate):
s += (USER_PREFIX +
"""The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short.""" + USER_SUFFIX)
s += ASSISTANT_PREFIX
comps = call_generate(s, max_tokens=256, temperature=temp, stop=None, n=num_branches)
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
def reflect_solution(s, num_branches, call_generate):
s += (USER_PREFIX +
"""Okay. Now you evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness.""" + USER_SUFFIX)
s += ASSISTANT_PREFIX
comps = call_generate(s, max_tokens=256, temperature=temp, stop=None, n=num_branches)
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
def tree_search(question, num_branches, call_generate):
s = ""
solutions = []
plan_forks = propose_plan(s, question, num_branches, call_generate)
for plan in plan_forks:
sol_forks = execute_plan(plan, num_branches, call_generate)
for sol in sol_forks:
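            # The reflection transcripts are discarded; the majority vote below
            # parses answers from the execute-plan transcripts only.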
score_forks = reflect_solution(sol, num_branches, call_generate)
solutions.append(sol_forks)
return solutions
def main(args):
lines = read_jsonl(args.data_path)
# Construct prompts
num_branches = 3
questions = []
labels = []
for i in range(len(lines[:args.num_questions])):
questions.append(lines[i]["question"])
labels.append(get_answer_value(lines[i]["answer"]))
assert all(l != INVALID for l in labels)
arguments = [{"question": q, "num_branches": num_branches} for q in questions]
# Select backend
if args.backend == "lightllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_lightllm, url=url)
elif args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_vllm, url=url)
elif args.backend == "srt-raw":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_srt_raw, url=url)
elif args.backend == "guidance":
from guidance import models, gen
model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096)
def call_generate(prompt, temperature, max_tokens, stop, n):
if n == 1:
out = model + prompt + gen(name="answer",
max_tokens=max_tokens, temperature=temperature, stop=stop)
return out["answer"]
else:
rets = []
for i in range(n):
out = model + prompt + gen(name="answer",
max_tokens=max_tokens, temperature=temperature, stop=stop)
rets.append(out["answer"])
                return rets
    else:
        raise ValueError(f"Invalid backend: {args.backend}")
# Run requests
states = [None] * len(questions)
def get_one_answer(i):
states[i] = tree_search(**arguments[i], call_generate=call_generate)
tic = time.time()
if args.parallel == 1:
for i in tqdm(range(len(questions))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
executor.map(get_one_answer, list(range(len(questions))))
latency = time.time() - tic
answers_text = []
for s in states:
answers_text.append([x for xs in s for x in xs])
preds = []
for i in range(len(states)):
answers = [get_answer_value(v) for v in answers_text[i]]
preds.append(most_frequent_number(answers))
# Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels))
invalid = np.mean(np.array(preds) == INVALID)
print(f"Latency: {latency:.3f}")
print(f"Invalid: {invalid:.3f}")
print(f"Accuracy: {acc:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", answers_text)
with open(args.result_file, "a") as fout:
value = {
"task": "tree_of_thought_gsm8k",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"accuracy": round(acc, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
}
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="test.jsonl")
parser.add_argument("--num-questions", type=int, default=200)
args = add_common_other_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,147 @@
import argparse
import ast
from collections import Counter
import json
import re
import time
import numpy as np
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
from sglang.utils import read_jsonl, dump_state_text
import sglang as sgl
INVALID = -9999999
def get_answer_value(answer_str):
answer_str = answer_str.replace(",", "")
numbers = re.findall(r'\d+', answer_str)
if len(numbers) < 1:
return INVALID
try:
return ast.literal_eval(numbers[-1])
except SyntaxError:
return INVALID
def most_frequent_number(numbers):
if not numbers:
return None
frequency = Counter(numbers)
most_frequent = max(frequency, key=frequency.get)
return most_frequent
# Use a low temp to make the results more deterministic and the comparison more fair.
temp = 0.3
def propose_plan(s, question, num_branches):
s += sgl.user(
"""Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """ + question)
forks = s.fork(num_branches)
forks += sgl.assistant(sgl.gen("plan", max_tokens=256, temperature=temp))
return forks
def execute_plan(s, num_branches):
s += sgl.user(
"""The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short.""")
forks = s.fork(num_branches)
forks += sgl.assistant(sgl.gen("answer", max_tokens=256, temperature=temp))
return forks
def reflect_solution(s, num_branches):
s += sgl.user(
"""Okay. Now you evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness.""")
forks = s.fork(num_branches)
forks += sgl.assistant(sgl.gen("score", max_tokens=256, temperature=temp))
return forks
@sgl.function
def tree_search(s, question, num_branches):
forks_to_join = []
plan_forks = propose_plan(s, question, num_branches)
forks_to_join.append(plan_forks)
sol_states = []
for plan in plan_forks:
forks = execute_plan(plan, num_branches)
forks_to_join.append(forks)
sol_states.extend(forks)
for sol in sol_states:
forks = reflect_solution(sol, num_branches)
forks_to_join.append(forks)
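    # Join the deepest forks first so every child generation finishes before
    # its parent state is joined.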
for f in reversed(forks_to_join):
f.join()
def main(args):
lines = read_jsonl(args.data_path)
# Construct prompts
num_branches = 3
questions = []
labels = []
for i in range(len(lines[:args.num_questions])):
questions.append(lines[i]["question"])
labels.append(get_answer_value(lines[i]["answer"]))
assert all(l != INVALID for l in labels)
arguments = [{"question": q, "num_branches": num_branches} for q in questions]
# Select backend
backend = select_sglang_backend(args)
# Run requests
tic = time.time()
states = tree_search.run_batch(
arguments, temperature=0, backend=backend, num_threads=args.parallel)
latency = time.time() - tic
answers_text = []
for s in states:
answers_text.append([x for xs in s["answer"] for x in xs])
preds = []
for i in range(len(states)):
answers = [get_answer_value(v) for v in answers_text[i]]
preds.append(most_frequent_number(answers))
# Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels))
invalid = np.mean(np.array(preds) == INVALID)
print(f"Latency: {latency:.3f}")
print(f"Invalid: {invalid:.3f}")
print(f"Accuracy: {acc:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", answers_text)
with open(args.result_file, "a") as fout:
value = {
"task": "tree_of_thought_gsm8k",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"accuracy": round(acc, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
}
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="test.jsonl")
parser.add_argument("--num-questions", type=int, default=200)
args = add_common_sglang_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,43 @@
## Download data
```
wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
```
## Run benchmark
### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
```
python3 bench_sglang.py --num-questions 32 --parallel 8
python3 bench_sglang.py --num-questions 16 --parallel 1
```
### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```
```
python3 bench_other.py --num-questions 32 --backend vllm
```
### Benchmark lightllm
```
# A10G
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
```
```
python3 bench_other.py --num-questions 32 --backend lightllm
```
### Benchmark guidance
```
python3 bench_other.py --num-questions 8 --backend guidance --parallel 1
```

View File

@@ -0,0 +1,198 @@
import argparse
import ast
import asyncio
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import json
import re
import time
import numpy as np
from tqdm import tqdm
from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw
from sglang.utils import read_jsonl, dump_state_text
INVALID = -9999999
def get_answer_value(answer_str):
answer_str = answer_str.replace(",", "")
numbers = re.findall(r'\d+', answer_str)
if len(numbers) < 1:
return INVALID
try:
return ast.literal_eval(numbers[-1])
except SyntaxError:
return INVALID
def most_frequent_number(numbers):
if not numbers:
return None
frequency = Counter(numbers)
most_frequent = max(frequency, key=frequency.get)
return most_frequent
USER_PREFIX = "[INST] "
USER_SUFFIX = " [/INST]"
ASSISTANT_PREFIX = ""
ASSISTANT_SUFFIX = " </s><s>"
# Use a low temp to make the results more deterministic and the comparison more fair.
temp = 0.001
def propose_plan(s, question, num_branches, call_generate):
s += (USER_PREFIX +
"""Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """ + question + USER_SUFFIX)
s += ASSISTANT_PREFIX
comps = call_generate(s, max_tokens=256, temperature=temp, stop=None, n=num_branches)
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
def execute_plan(s, num_branches, call_generate):
s += (USER_PREFIX +
"""The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short.""" + USER_SUFFIX)
s += ASSISTANT_PREFIX
comps = call_generate(s, max_tokens=256, temperature=temp, stop=None, n=num_branches)
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
def reflect_solution(s, num_branches, call_generate):
s += (USER_PREFIX +
"""Okay. Now, evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness.""" + USER_SUFFIX)
s += ASSISTANT_PREFIX
comps = call_generate(s, max_tokens=256, temperature=temp, stop=None, n=num_branches)
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
def get_final_answer(s, num_branches, call_generate):
s += (USER_PREFIX +
"""Based on your reflection, do you change your mind? Now, give me the final answer after careful consideration.""" + USER_SUFFIX)
s += ASSISTANT_PREFIX
comps = call_generate(s, max_tokens=256, temperature=temp, stop=None, n=num_branches)
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
def tree_search(question, num_branches, call_generate):
plan_forks = propose_plan("", question, num_branches, call_generate)
sol_states = []
for plan in plan_forks:
forks = execute_plan(plan, num_branches, call_generate)
sol_states.extend(forks)
ref_states = []
for sol in sol_states:
forks = reflect_solution(sol, num_branches, call_generate)
ref_states.extend(forks)
solutions = []
for sol in ref_states:
ans = get_final_answer(sol, num_branches, call_generate)
solutions.append(ans)
return solutions
def main(args):
lines = read_jsonl(args.data_path)
# Construct prompts
num_branches = 2
questions = []
labels = []
for i in range(len(lines[:args.num_questions])):
questions.append(lines[i]["question"])
labels.append(get_answer_value(lines[i]["answer"]))
assert all(l != INVALID for l in labels)
arguments = [{"question": q, "num_branches": num_branches} for q in questions]
# Select backend
if args.backend == "lightllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_lightllm, url=url)
elif args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_vllm, url=url)
elif args.backend == "srt-raw":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_srt_raw, url=url)
elif args.backend == "guidance":
from guidance import models, gen
model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096)
def call_generate(prompt, temperature, max_tokens, stop, n):
if n == 1:
out = model + prompt + gen(name="answer",
max_tokens=max_tokens, temperature=temperature, stop=stop)
return out["answer"]
else:
rets = []
for i in range(n):
out = model + prompt + gen(name="answer",
max_tokens=max_tokens, temperature=temperature, stop=stop)
rets.append(out["answer"])
                return rets
    else:
        raise ValueError(f"Invalid backend: {args.backend}")
# Run requests
states = [None] * len(questions)
def get_one_answer(i):
states[i] = tree_search(**arguments[i], call_generate=call_generate)
tic = time.time()
if args.parallel == 1:
for i in tqdm(range(len(questions))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
executor.map(get_one_answer, list(range(len(questions))))
latency = time.time() - tic
answers_text = []
for s in states:
answers_text.append([x for xs in s for x in xs])
preds = []
for i in range(len(states)):
answers = [get_answer_value(v) for v in answers_text[i]]
preds.append(most_frequent_number(answers))
# Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels))
invalid = np.mean(np.array(preds) == INVALID)
print(f"Latency: {latency:.3f}")
print(f"Invalid: {invalid:.3f}")
print(f"Accuracy: {acc:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", answers_text)
with open(args.result_file, "a") as fout:
value = {
"task": "tree_of_thought_gsm8k",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"accuracy": round(acc, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
}
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="test.jsonl")
parser.add_argument("--num-questions", type=int, default=200)
args = add_common_other_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,157 @@
import argparse
import ast
from collections import Counter
import json
import re
import time
import numpy as np
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
from sglang.utils import read_jsonl, dump_state_text
import sglang as sgl
INVALID = -9999999
def get_answer_value(answer_str):
answer_str = answer_str.replace(",", "")
numbers = re.findall(r'\d+', answer_str)
if len(numbers) < 1:
return INVALID
try:
return ast.literal_eval(numbers[-1])
except SyntaxError:
return INVALID
def most_frequent_number(numbers):
if not numbers:
return None
frequency = Counter(numbers)
most_frequent = max(frequency, key=frequency.get)
return most_frequent
# Use a low temp to make the results more deterministic and the comparison more fair.
temp = 0.001
def propose_plan(s, question, num_branches):
s += sgl.user(
"""Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """ + question)
forks = s.fork(num_branches)
forks += sgl.assistant(sgl.gen("plan", max_tokens=256, temperature=temp))
return forks
def execute_plan(s, num_branches):
s += sgl.user(
"""The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short.""")
forks = s.fork(num_branches)
forks += sgl.assistant(sgl.gen("answer", max_tokens=256, temperature=temp))
return forks
def reflect_solution(s, num_branches):
s += sgl.user(
"""Okay. Now, evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness.""")
forks = s.fork(num_branches)
forks += sgl.assistant(sgl.gen("score", max_tokens=256, temperature=temp))
return forks
def get_final_answer(s, num_branches):
s += sgl.user(
"""Based on your reflection, do you change your mind? Now, give me the final answer after careful consideration.""")
forks = s.fork(num_branches)
forks += sgl.assistant(sgl.gen("final_answer", max_tokens=256, temperature=temp))
return forks
@sgl.function
def tree_search(s, question, num_branches):
plan_forks = propose_plan(s, question, num_branches)
sol_states = []
for plan in plan_forks:
forks = execute_plan(plan, num_branches)
sol_states.extend(forks)
ref_states = []
for sol in sol_states:
forks = reflect_solution(sol, num_branches)
ref_states.extend(forks)
solutions = []
for sol in ref_states:
forks = get_final_answer(sol, num_branches)
solutions.append(forks)
solutions = [[s.text() for s in forks] for forks in solutions]
return solutions
def main(args):
lines = read_jsonl(args.data_path)
# Construct prompts
num_branches = 2
questions = []
labels = []
for i in range(len(lines[:args.num_questions])):
questions.append(lines[i]["question"])
labels.append(get_answer_value(lines[i]["answer"]))
assert all(l != INVALID for l in labels)
arguments = [{"question": q, "num_branches": num_branches} for q in questions]
# Select backend
backend = select_sglang_backend(args)
# Run requests
tic = time.time()
states = tree_search.run_batch(
arguments, temperature=0, backend=backend, num_threads=args.parallel)
latency = time.time() - tic
answers_text = []
for s in states:
answers_text.append([x for xs in s.ret_value for x in xs])
preds = []
for i in range(len(states)):
answers = [get_answer_value(v) for v in answers_text[i]]
preds.append(most_frequent_number(answers))
# Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels))
invalid = np.mean(np.array(preds) == INVALID)
print(f"Latency: {latency:.3f}")
print(f"Invalid: {invalid:.3f}")
print(f"Accuracy: {acc:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", answers_text)
with open(args.result_file, "a") as fout:
value = {
"task": "tree_of_thought_gsm8k",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"accuracy": round(acc, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
}
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="test.jsonl")
parser.add_argument("--num-questions", type=int, default=200)
args = add_common_sglang_args_and_parse(parser)
main(args)

20
docs/flashinfer.md Normal file
View File

@@ -0,0 +1,20 @@
## Flashinfer Mode
[`flashinfer`](https://github.com/flashinfer-ai/flashinfer) is a kernel library for LLM serving; we use it here to support our attention computation.
### Install flashinfer
```bash
git submodule update --init --recursive
pip install 3rdparty/flashinfer/python
```
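To sanity-check the installation (optional):
```bash
python -c "import flashinfer"  # exits silently if the package is importable
```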
### Run Server With Flashinfer Mode
Enable it with the `--model-mode` argument on the command line.
Example:
```bash
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --model-mode flashinfer
```

63
docs/test_process.md Normal file
View File

@@ -0,0 +1,63 @@
## SRT Unit Tests
### Low-level API
```
cd sglang/test/srt/model
python3 test_llama_low_api.py
python3 test_llama_extend.py
python3 test_llava_low_api.py
python3 bench_llama_low_api.py
```
### High-level API
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
```
cd test/lang
python3 test_srt_backend.py
```
### Performance
#### MMLU
```
cd benchmark/mmlu
```
Follow README.md to download the data.
```
python3 bench_sglang.py --nsub 3
# Expected performance on A10G
# Total latency: 8.200
# Average accuracy: 0.413
```
### More Models
#### LLaVA
```
python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000
```
```
cd benchmark/llava_bench
python3 bench_sglang.py
```
## SGLang Unit Tests
```
export ANTHROPIC_API_KEY=
export OPENAI_API_KEY=
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
```
cd test/lang
python3 run_all.py
```

View File

@@ -0,0 +1,19 @@
from sglang import function, system, user, assistant, gen, set_default_backend, Anthropic
@function
def multi_turn_question(s, question_1, question_2):
s += user(question_1)
s += assistant(gen("answer_1", max_tokens=256))
s += user(question_2)
s += assistant(gen("answer_2", max_tokens=256))
set_default_backend(Anthropic("claude-2"))
state = multi_turn_question.run(
question_1="What is the capital of the United States?",
question_2="List two local attractions.",
)
for m in state.messages():
print(m["role"], ":", m["content"])

View File

@@ -0,0 +1,26 @@
from sglang import function, gen, set_default_backend, Anthropic
@function
def few_shot_qa(s, question):
s += (
"""
\n\nHuman: What is the capital of France?
\n\nAssistant: Paris
\n\nHuman: What is the capital of Germany?
\n\nAssistant: Berlin
\n\nHuman: What is the capital of Italy?
\n\nAssistant: Rome
""")
s += "\n\nHuman: " + question + "\n"
s += "\n\nAssistant:" + gen("answer", stop="\n", temperature=0)
set_default_backend(Anthropic("claude-2"))
state = few_shot_qa.run(question="What is the capital of the United States?")
answer = state["answer"].strip().lower()
assert "washington" in answer, f"answer: {state['answer']}"
print(state.text())

View File

@@ -0,0 +1,20 @@
from sglang import function, system, user, assistant, gen, set_default_backend, Anthropic
@function
def multi_turn_question(s, question_1, question_2):
s += user(question_1)
s += assistant(gen("answer_1", max_tokens=256))
s += user(question_2)
s += assistant(gen("answer_2", max_tokens=256))
set_default_backend(Anthropic("claude-2"))
state = multi_turn_question.run(
question_1="What is the capital of the United States?",
question_2="List two local attractions.",
stream=True
)
for out in state.text_iter():
print(out, end="", flush=True)

View File

@@ -0,0 +1,44 @@
import asyncio
import sglang as sgl
@sgl.function
def multi_turn_question(s, question_1, question_2):
s += sgl.system("You are a helpful assistant.")
s += sgl.user(question_1)
s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))
s += sgl.user(question_2)
s += sgl.assistant(sgl.gen("answer_2", max_tokens=256))
sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo"))
#sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
def stream_a_variable():
state = multi_turn_question.run(
question_1="What is the capital of the United States?",
question_2="List two local attractions.",
stream=True
)
for out in state.text_iter(var_name="answer_2"):
print(out, end="", flush=True)
print()
async def async_stream():
state = multi_turn_question.run(
question_1="What is the capital of the United States?",
question_2="List two local attractions.",
stream=True
)
async for out in state.text_async_iter(var_name="answer_2"):
print(out, end="", flush=True)
print()
if __name__ == "__main__":
#stream_a_variable()
asyncio.run(async_stream())

View File

@@ -0,0 +1,20 @@
from sglang import function, system, user, assistant, gen, set_default_backend, OpenAI
@function
def multi_turn_question(s, question_1, question_2):
s += system("You are a helpful assistant.")
s += user(question_1)
s += assistant(gen("answer_1", max_tokens=256))
s += user(question_2)
s += assistant(gen("answer_2", max_tokens=256))
set_default_backend(OpenAI("gpt-3.5-turbo"))
state = multi_turn_question.run(
question_1="What is the capital of the United States?",
question_2="List two local attractions.",
)
for m in state.messages():
print(m["role"], ":", m["content"])

View File

@@ -0,0 +1,26 @@
from sglang import function, gen, set_default_backend, OpenAI
@function
def few_shot_qa(s, question):
s += (
"""The following are questions with answers.
Q: What is the capital of France?
A: Paris
Q: What is the capital of Germany?
A: Berlin
Q: What is the capital of Italy?
A: Rome
""")
s += "Q: " + question + "\n"
s += "A:" + gen("answer", stop="\n", temperature=0)
set_default_backend(OpenAI("gpt-3.5-turbo-instruct"))
state = few_shot_qa.run(question="What is the capital of the United States?")
answer = state["answer"].strip().lower()
assert "washington" in answer, f"answer: {state['answer']}"
print(state.text())

View File

@@ -0,0 +1,21 @@
from sglang import function, system, user, assistant, gen, set_default_backend, OpenAI
@function
def multi_turn_question(s, question_1, question_2):
s += system("You are a helpful assistant.")
s += user(question_1)
s += assistant(gen("answer_1", max_tokens=256))
s += user(question_2)
s += assistant(gen("answer_2", max_tokens=256))
set_default_backend(OpenAI("gpt-3.5-turbo"))
state = multi_turn_question.run(
question_1="What is the capital of the United States?",
question_2="List two local attractions.",
stream=True
)
for out in state.text_iter():
print(out, end="", flush=True)

View File

@@ -0,0 +1,26 @@
from sglang import function, system, user, assistant, gen, set_default_backend, Runtime
@function
def multi_turn_question(s, question_1, question_2):
s += system("You are a helpful assistant.")
s += user(question_1)
s += assistant(gen("answer_1", max_tokens=256))
s += user(question_2)
s += assistant(gen("answer_2", max_tokens=256))
runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
#runtime = Runtime(model_path="mistralai/Mixtral-8x7B-Instruct-v0.1")
set_default_backend(runtime)
state = multi_turn_question.run(
question_1="What is the capital of the United States?",
question_2="List two local attractions.",
)
for m in state.messages():
print(m["role"], ":", m["content"])
runtime.shutdown()

View File

@@ -0,0 +1,28 @@
from sglang import function, gen, set_default_backend, Runtime
@function
def few_shot_qa(s, question):
s += (
"""The following are questions with answers.
Q: What is the capital of France?
A: Paris
Q: What is the capital of Germany?
A: Berlin
Q: What is the capital of Italy?
A: Rome
""")
s += "Q: " + question + "\n"
s += "A:" + gen("answer", stop="\n", temperature=0)
runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
set_default_backend(runtime)
state = few_shot_qa.run(question="What is the capital of the United States?")
answer = state["answer"].strip().lower()
assert "washington" in answer, f"answer: {state['answer']}"
print(state.text())
runtime.shutdown()

View File

@@ -0,0 +1,21 @@
from sglang import function, gen, set_default_backend, Runtime
@function
def regex_gen(s):
s += "Q: What is the IP address of the Google DNS servers?\n"
s += "A: " + gen(
"answer",
temperature=0,
regex=r"((25[0-5]|2[0-4]\d|[01]?\d\d?).){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
)
runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
set_default_backend(runtime)
state = regex_gen.run()
print(state.text())
runtime.shutdown()

View File

@@ -0,0 +1,26 @@
from sglang import function, system, user, assistant, gen, set_default_backend, Runtime
@function
def multi_turn_question(s, question_1, question_2):
s += system("You are a helpful assistant.")
s += user(question_1)
s += assistant(gen("answer_1", max_tokens=256))
s += user(question_2)
s += assistant(gen("answer_2", max_tokens=256))
runtime = Runtime("meta-llama/Llama-2-7b-chat-hf")
set_default_backend(runtime)
state = multi_turn_question.run(
question_1="What is the capital of the United States?",
question_2="List two local attractions.",
temperature=0,
stream=True,
)
for out in state.text_iter():
print(out, end="", flush=True)
print()
runtime.shutdown()

5
format.sh Normal file
View File

@@ -0,0 +1,5 @@
isort python
black python
isort test
black test

7
playground/launch_tgi.sh Normal file
View File

@@ -0,0 +1,7 @@
# Assuming the model is downloaded at /home/ubuntu/model_weights/Llama-2-7b-chat-hf
docker run --name tgi --rm -ti --gpus all --network host \
-v /home/ubuntu/model_weights/Llama-2-7b-chat-hf:/Llama-2-7b-chat-hf \
ghcr.io/huggingface/text-generation-inference:1.1.0 \
--model-id /Llama-2-7b-chat-hf --num-shard 1 --trust-remote-code \
--max-input-length 2048 --max-total-tokens 4096 \
--port 24000

View File

@@ -0,0 +1,7 @@
import transformers
import code
name = "meta-llama/Llama-2-7b-chat-hf"
t = transformers.AutoTokenizer.from_pretrained(name)
code.interact(local=locals())

31
python/pyproject.toml Normal file
View File

@@ -0,0 +1,31 @@
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "sglang"
version = "0.1.0"
description = "A structured generation langauge for LLMs."
readme = "README.md"
requires-python = ">=3.8"
license = {file = "LICENSE"}
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
]
dependencies = [
"requests",
]
[project.optional-dependencies]
srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5",
"interegular", "lark"]
openai = ["openai>=1.0"]
anthropic = ["anthropic"]
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
[tool.setuptools.packages.find]
exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"]
[tool.wheel]
exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"]

View File

@@ -0,0 +1,2 @@
from sglang.api import *
from sglang.global_config import global_config

161
python/sglang/api.py Normal file
View File

@@ -0,0 +1,161 @@
"""Public API"""
import re
from typing import Callable, List, Optional, Union
from sglang.backend.anthropic import Anthropic
from sglang.backend.base_backend import BaseBackend
from sglang.backend.openai import OpenAI
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.global_config import global_config
from sglang.lang.ir import (
SglExpr,
SglExprList,
SglFunction,
SglGen,
SglImage,
SglRoleBegin,
SglRoleEnd,
SglSelect,
)
from sglang.srt.server import Runtime
def function(func: Callable):
return SglFunction(func)
def set_default_backend(backend: BaseBackend):
global_config.default_backend = backend
def gen(
name: Optional[str] = None,
max_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
dtype: Optional[type] = None,
choices: Optional[List[str]] = None,
regex: Optional[str] = None,
):
if choices:
return SglSelect(name, choices, temperature)
    # Validate the regex early so an invalid pattern fails at definition time.
    if regex is not None:
        re.compile(regex)  # raises re.error on an invalid pattern
return SglGen(
name,
max_tokens,
stop,
temperature,
top_p,
top_k,
frequency_penalty,
presence_penalty,
dtype,
regex,
)
def gen_int(
name: Optional[str] = None,
max_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
):
return SglGen(
name,
max_tokens,
stop,
temperature,
top_p,
top_k,
frequency_penalty,
presence_penalty,
int,
None,
)
def gen_string(
name: Optional[str] = None,
max_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
):
return SglGen(
name,
max_tokens,
stop,
temperature,
top_p,
top_k,
frequency_penalty,
presence_penalty,
str,
None,
)
def image(expr: SglExpr):
return SglImage(expr)
def select(
name: Optional[str] = None,
choices: List[str] = None,
temperature: float = 0.0,
):
assert choices is not None
return SglSelect(name, choices, temperature)
def _role_common(name: str, expr: Optional[SglExpr] = None):
if expr is None:
return SglExprList([SglRoleBegin(name), SglRoleEnd(name)])
else:
return SglExprList([SglRoleBegin(name), expr, SglRoleEnd(name)])
def system(expr: Optional[SglExpr] = None):
return _role_common("system", expr)
def user(expr: Optional[SglExpr] = None):
return _role_common("user", expr)
def assistant(expr: Optional[SglExpr] = None):
return _role_common("assistant", expr)
def user_begin():
return SglRoleBegin("user")
def user_end():
return SglRoleEnd("user")
def assistant_begin():
return SglRoleBegin("assistant")
def assistant_end():
return SglRoleEnd("assistant")

View File

View File

@@ -0,0 +1,57 @@
from typing import List, Optional, Union
import numpy as np
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SamplingParams
try:
import anthropic
except ImportError as e:
anthropic = e
class Anthropic(BaseBackend):
def __init__(self, model_name):
super().__init__()
if isinstance(anthropic, Exception):
raise anthropic
self.model_name = model_name
self.chat_template = get_chat_template("claude")
def get_chat_template(self):
return self.chat_template
def generate(
self,
s: StreamExecutor,
sampling_params: SamplingParams,
):
prompt = s.text_
ret = anthropic.Anthropic().completions.create(
model=self.model_name,
prompt=prompt,
**sampling_params.to_anthropic_kwargs(),
)
comp = ret.completion
return comp, {}
def generate_stream(
self,
s: StreamExecutor,
sampling_params: SamplingParams,
):
prompt = s.text_
generator = anthropic.Anthropic().completions.create(
model=self.model_name,
prompt=prompt,
stream=True,
**sampling_params.to_anthropic_kwargs(),
)
for ret in generator:
yield ret.completion, {}

View File

@@ -0,0 +1,74 @@
from typing import Callable, List, Optional, Union
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SamplingParams
class BaseBackend:
def __init__(self) -> None:
self.support_concate_and_append = False
self.chat_template = get_chat_template("default")
def get_model_name(self):
raise NotImplementedError()
def get_chat_template(self):
return self.chat_template
def cache_prefix(self, prefix_str: str):
pass
def uncache_prefix(self, rid: str):
pass
def end_request(self, rid: Union[str, List[str]]):
pass
def begin_program(self, s: StreamExecutor):
pass
def end_program(self, s: Union[StreamExecutor, List[StreamExecutor]]):
pass
def commit_lazy_operations(self, s: StreamExecutor):
pass
def fork_program(
self,
src: StreamExecutor,
dst: List[StreamExecutor],
position_ids_offset: Optional[List[int]] = None,
):
pass
def fill_image(self, s: StreamExecutor):
pass
def generate(
self,
s: StreamExecutor,
sampling_params: SamplingParams,
):
raise NotImplementedError()
def generate_stream(
self,
s: StreamExecutor,
sampling_params: SamplingParams,
):
raise NotImplementedError()
def select(
self,
s: StreamExecutor,
choices: List[str],
temperature: float,
):
raise NotImplementedError()
def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
raise NotImplementedError()
def shutdown(self):
pass

View File

@@ -0,0 +1,349 @@
import functools
from enum import Enum, auto
from typing import Callable, List, Optional, Union
import numpy as np
import torch
import transformers
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import ProgramState
from sglang.utils import get_available_gpu_memory
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
StoppingCriteria,
StoppingCriteriaList,
)
from transformers.generation.logits_process import (
LogitsProcessorList,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
class StopReason(Enum):
EOS_TOKEN = auto()
STOP_STR = auto()
LENGTH = auto()
def load_model(
model_name: str,
device,
num_gpus,
max_gpu_memory,
model_kwargs=None,
tokenizer_kwargs=None,
):
model_kwargs = model_kwargs or {}
tokenizer_kwargs = tokenizer_kwargs or {}
if device == "cuda":
model_kwargs["torch_dtype"] = torch.float16
if num_gpus != 1:
model_kwargs["device_map"] = "auto"
if max_gpu_memory is None:
model_kwargs[
"device_map"
] = "sequential" # This is important for not the same VRAM sizes
available_gpu_memory = [
get_available_gpu_memory(i, False) for i in range(num_gpus)
]
model_kwargs["max_memory"] = {
i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
for i in range(num_gpus)
}
else:
model_kwargs["max_memory"] = {
i: max_gpu_memory for i in range(num_gpus)
}
elif device == "cpu":
model_kwargs["torch_dtype"] = torch.float32
else:
raise ValueError(f"Invalid device: {device}")
model = AutoModelForCausalLM.from_pretrained(
model_name, low_cpu_mem_usage=True, **model_kwargs
)
tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
if num_gpus == 1:
model.to(device).eval()
return model, tokenizer
def prepare_logits_processor(
temperature: float, repetition_penalty: float, top_p: float, top_k: int
) -> LogitsProcessorList:
processor_list = LogitsProcessorList()
    # TemperatureLogitsWarper doesn't accept 0.0, and 1.0 makes it a no-op, so we skip both cases.
if temperature >= 1e-5 and temperature != 1.0:
processor_list.append(TemperatureLogitsWarper(temperature))
if repetition_penalty > 1.0:
processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
if 1e-8 <= top_p < 1.0:
processor_list.append(TopPLogitsWarper(top_p))
if top_k > 0:
processor_list.append(TopKLogitsWarper(top_k))
return processor_list
@functools.lru_cache
def get_token_healing_mask(tokenizer, prompt_last_token):
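    # Token healing: the prompt's last token will be re-generated, so mask out
    # every vocab entry whose string form does not extend that token.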
last_str = tokenizer.convert_ids_to_tokens(prompt_last_token)
disallowed = torch.zeros(len(tokenizer), dtype=bool)
for s, t_id in tokenizer.get_vocab().items():
if not s.startswith(last_str):
disallowed[t_id] = 1
return disallowed
@functools.lru_cache
def get_int_token_mask(tokenizer):
disallowed = torch.zeros(len(tokenizer), dtype=bool)
for s, t_id in tokenizer.get_vocab().items():
s = s.replace("", "").strip()
if not (s.isdigit() or len(s) == 0 or s == ","):
disallowed[t_id] = 1
disallowed[tokenizer.eos_token_id] = 0
return disallowed
@torch.inference_mode()
def generate_stream(
model,
tokenizer,
prompt,
max_new_tokens,
stop: List[str],
temperature,
top_p,
token_healing,
logit_mask=None,
):
logits_processor = prepare_logits_processor(
temperature=temperature, repetition_penalty=1.0, top_p=top_p, top_k=0
)
device = model.device
input_ids = tokenizer.encode(prompt)
output_ids = list(input_ids)
prompt_len = len(prompt)
# Resolve stop
stop_token_ids = [tokenizer.eos_token_id]
# Token healing
token_healing = token_healing and len(input_ids) > 0
if token_healing:
token_healing_mask = get_token_healing_mask(tokenizer, input_ids[-1])
del output_ids[-1]
# Generate
past_key_values = None
stop_reason = None
for i in range(max_new_tokens):
# Forward
if i == 0: # prefill
out = model(torch.as_tensor([output_ids], device=device), use_cache=True)
else: # decoding
out = model(
input_ids=torch.as_tensor([[token]], device=device),
use_cache=True,
past_key_values=past_key_values,
)
logits = out.logits
past_key_values = out.past_key_values
# Logit mask
if token_healing and i == 0:
logits[0, -1, token_healing_mask] = -1e4
if logit_mask is not None:
logits[0, -1, logit_mask] = -1e4
# Sample next token
last_token_logits = logits_processor(None, logits[:, -1, :])[0]
if temperature < 1e-5 or top_p < 1e-8: # greedy
token = int(torch.argmax(last_token_logits))
else:
probs = torch.softmax(last_token_logits, dim=-1)
token = int(torch.multinomial(probs, num_samples=1))
output_ids.append(token)
# Stop condition
if token in stop_token_ids:
stop_reason = StopReason.EOS_TOKEN
break
output_str = tokenizer.decode(output_ids, skip_special_tokens=True)
for stop_str in stop:
pos = output_str[prompt_len:].find(stop_str)
if pos != -1:
stop_reason = StopReason.STOP_STR
output_str = output_str[: prompt_len + pos]
break
if stop_reason:
break
return output_str[prompt_len:]
class HuggingFaceTransformers(BaseBackend):
def __init__(
self,
model_name,
device="cuda",
num_gpus=1,
max_gpu_memory=None,
model_kwargs=None,
tokenizer_kwargs=None,
):
self.model_name = model_name
self.device = device
self.model, self.tokenizer = load_model(
model_name, device, num_gpus, max_gpu_memory, model_kwargs, tokenizer_kwargs
)
self.chat_template = get_chat_template_by_model_path(model_name)
def get_chat_template(self):
return self.chat_template
def cache_prefix(self, prefix_str: str):
pass
def uncache_prefix(self, rid: str):
pass
def end_request(self, rid: str):
pass
def begin_program(self, s: ProgramState):
pass
def end_program(self, s: ProgramState):
pass
def fill(self, s: ProgramState, text: str):
return False
def generate_internal(
self,
prompt: str,
max_tokens: int,
stop: Union[str, List[str]],
temperature: float,
top_p: float,
dtype: Optional[str] = None,
):
if dtype is None:
comp = generate_stream(
self.model,
self.tokenizer,
prompt,
max_new_tokens=max_tokens,
stop=stop,
temperature=temperature,
top_p=top_p,
token_healing=True,
)
elif dtype in [str, "str", "string"]:
comp = generate_stream(
self.model,
self.tokenizer,
prompt + '"',
max_new_tokens=max_tokens,
stop=['"'],
temperature=temperature,
top_p=top_p,
token_healing=False,
)
comp = '"' + comp + '"'
elif dtype in [int, "int"]:
logit_mask = get_int_token_mask(self.tokenizer)
comp = generate_stream(
self.model,
self.tokenizer,
prompt,
max_new_tokens=max_tokens,
stop=stop + [" ", ","],
temperature=temperature,
top_p=top_p,
token_healing=False,
logit_mask=logit_mask,
)
return comp
def generate(
self,
s: ProgramState,
max_tokens: int,
stop: Union[str, List[str]],
temperature: float,
top_p: float,
dtype: Optional[str] = None,
):
prompt = s.text
comp = self.generate_internal(
prompt, max_tokens, stop, temperature, top_p, dtype
)
return comp
def parallel_generate(
self,
s: ProgramState,
prefixes: List[str],
join_func: Callable,
max_tokens: int,
stop: Union[str, List[str]],
temperature: float,
top_p: float,
dtype: Optional[str] = None,
):
prompt = s.text
parallel_prompts = [prompt + prefix for prefix in prefixes]
comps = []
for i in range(len(parallel_prompts)):
comps.append(
self.generate_internal(
parallel_prompts[i], max_tokens, stop, temperature, top_p, dtype
)
)
joined = join_func([p + c for p, c in zip(prefixes, comps)])
return joined, comps
@torch.inference_mode()
def select(
self, s: ProgramState, choices: List[str], temperature: float, top_p: float
):
loss_fct = torch.nn.CrossEntropyLoss()
prompt = s.text
prompt_len = self.tokenizer.encode(prompt, return_tensors="pt").shape[1]
prompt_choices = [prompt + choice for choice in choices]
scores = []
for i in range(len(choices)):
choice_ids = self.tokenizer.encode(
prompt_choices[i], return_tensors="pt"
).to(self.model.device)
logits = self.model(choice_ids).logits
# score = -loss_fct(logits[0, :-1, :], choice_ids[0, 1:]).item()
logprobs = torch.log(torch.softmax(logits, dim=-1))
idx1 = torch.arange(0, logits.shape[1] - 1, device=logits.device)
idx2 = choice_ids[0, 1:]
selected_logprobs = logprobs[0, idx1, idx2]
score = selected_logprobs.mean().item()
scores.append(score)
decision = choices[np.argmax(scores)]
return decision, scores

View File

@@ -0,0 +1,241 @@
from typing import Callable, List, Optional, Union
import numpy as np
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SamplingParams
try:
import openai
import tiktoken
except ImportError as e:
openai = tiktoken = e
def create_logit_bias_int(tokenizer):
"""Get logit bias for integer numbers."""
int_token_ids = []
tokens = tokenizer._mergeable_ranks
for token, token_id in tokens.items():
s = tokenizer.decode([token_id])
if all([c.isdigit() for c in s]) or s in [" "]:
int_token_ids.append(token_id)
if len(int_token_ids) >= 300: # OpenAI API limit
break
special_tokens = tokenizer._special_tokens
mask = {t: 100 for t in int_token_ids[:299]}
mask[special_tokens["<|endoftext|>"]] = 100
return mask
CHAT_MODEL_NAMES = [
# GPT-4
"gpt-4",
"gpt-4-32k",
"gpt-4-1106-preview",
"gpt-4-vision-preview",
"gpt-4-0613",
"gpt-4-0314",
# GPT-3.5
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo-16k-0613",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-0301",
]
class OpenAI(BaseBackend):
def __init__(self, model_name, *args, **kwargs):
super().__init__()
        if isinstance(openai, Exception):
            raise openai
        self.client = openai.OpenAI(*args, **kwargs)
self.model_name = model_name
self.tokenizer = tiktoken.encoding_for_model(model_name)
self.logit_bias_int = create_logit_bias_int(self.tokenizer)
if model_name in CHAT_MODEL_NAMES:
self.is_chat_model = True
else:
self.is_chat_model = False
self.chat_template = get_chat_template("default")
def get_chat_template(self):
return self.chat_template
def generate(
self,
s: StreamExecutor,
sampling_params: SamplingParams,
):
if sampling_params.dtype is None:
if self.is_chat_model:
assert s.text_.endswith("ASSISTANT:")
prompt = s.messages_
else:
prompt = s.text_
kwargs = sampling_params.to_openai_kwargs()
comp = openai_completion(
client=self.client,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=prompt,
**kwargs,
)
elif sampling_params.dtype in [str, "str", "string"]:
kwargs = sampling_params.to_openai_kwargs()
kwargs.pop("stop")
comp = openai_completion(
client=self.client,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=s.text_ + '"',
stop='"',
**kwargs,
)
comp = '"' + comp + '"'
elif sampling_params.dtype in [int, "int"]:
kwargs = sampling_params.to_openai_kwargs()
kwargs.pop("stop")
comp = openai_completion(
client=self.client,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=s.text_,
logit_bias=self.logit_bias_int,
stop=[" "],
**kwargs,
)
else:
raise ValueError(f"Unknown dtype: {dtype}")
return comp, {}
def generate_stream(
self,
s: StreamExecutor,
sampling_params: SamplingParams,
):
if sampling_params.dtype is None:
if self.is_chat_model:
assert s.text_.endswith("ASSISTANT:")
prompt = s.messages_
else:
prompt = s.text_
kwargs = sampling_params.to_openai_kwargs()
generator = openai_completion_stream(
client=self.client,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=prompt,
**kwargs,
)
return generator
else:
raise ValueError(f"Unknown dtype: {dtype}")
def select(
self,
s: StreamExecutor,
choices: List[str],
temperature: float,
):
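        # "Voting" via logit_bias: at each step, only the next token of each
        # still-valid choice is allowed, and the choice whose token the model
        # picks gets credit; invalidated choices drop out of later steps.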
n_choices = len(choices)
token_ids = [self.tokenizer.encode(x) for x in choices]
scores = [0] * n_choices
valid = [len(x) > 0 for x in token_ids]
prompt_tokens = self.tokenizer.encode(s.text_)
max_len = max([len(x) for x in token_ids])
for step in range(max_len):
# Build logit bias
logit_bias = {}
for i in range(n_choices):
if valid[i]:
logit_bias[token_ids[i][step]] = 100
# Call API
ret = self.client.completions.create(
model=self.model_name,
prompt=prompt_tokens,
logit_bias=logit_bias,
max_tokens=1,
temperature=temperature,
)
ret_str = ret.choices[0].text
ret_token = self.tokenizer.encode(ret_str)[0]
# TODO:
# 1. return logits as the scores
# 2. compute logits of the full choice
# 3. consider chunk-based decoding
# Update valid
hit = False
for i in range(n_choices):
if valid[i]:
if step == len(token_ids[i]) - 1:
valid[i] = False
if ret_token == token_ids[i][step]:
scores[i] += 1
hit = True
else:
valid[i] = False
assert hit
if np.sum(valid) <= 1:
break
prompt_tokens.append(ret_token)
decision = choices[np.argmax(scores)]
return decision, scores
def openai_completion(client, is_chat=None, prompt=None, **kwargs):
try:
if is_chat:
if kwargs["stop"] is None:
kwargs.pop("stop")
ret = client.chat.completions.create(messages=prompt, **kwargs)
comp = ret.choices[0].message.content
else:
ret = client.completions.create(prompt=prompt, **kwargs)
if isinstance(prompt, (list, tuple)):
comp = [c.text for c in ret.choices]
else:
comp = ret.choices[0].text
except openai.OpenAIError as e:
print(f"OpenAI Error: {e}")
raise e
return comp
def openai_completion_stream(client, is_chat=None, prompt=None, **kwargs):
try:
if is_chat:
generator = client.chat.completions.create(
messages=prompt, stream=True, **kwargs
)
for ret in generator:
content = ret.choices[0].delta.content
yield content or "", {}
else:
generator = client.completions.create(prompt=prompt, stream=True, **kwargs)
for ret in generator:
content = ret.choices[0].text
yield content or "", {}
except openai.OpenAIError as e:
print(f"OpenAI Error: {e}")
raise e

View File

@@ -0,0 +1,171 @@
import json
from typing import Callable, List, Optional, Union
import numpy as np
import requests
from sglang.backend.base_backend import BaseBackend
from sglang.global_config import global_config
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SamplingParams, SglArgument
from sglang.utils import encode_image_base64, find_printable_text, http_request
class RuntimeEndpoint(BaseBackend):
def __init__(self, base_url):
super().__init__()
self.support_concate_and_append = True
self.base_url = base_url
res = http_request(self.base_url + "/get_model_info")
assert res.status_code == 200
self.model_info = res.json()
self.chat_template = get_chat_template_by_model_path(
self.model_info["model_path"]
)
def get_model_name(self):
return self.model_info["model_path"]
def get_chat_template(self):
return self.chat_template
def cache_prefix(self, prefix_str: str):
res = http_request(
self.base_url + "/generate",
json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
)
assert res.status_code == 200
def commit_lazy_operations(self, s: StreamExecutor):
res = http_request(
self.base_url + "/generate",
json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
)
assert res.status_code == 200
def fill_image(self, s: StreamExecutor):
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
self._add_images(s, data)
res = http_request(self.base_url + "/generate", json=data)
assert res.status_code == 200
def generate(
self,
s: StreamExecutor,
sampling_params: SamplingParams,
):
if sampling_params.dtype is None:
data = {
"text": s.text_,
"sampling_params": {
"skip_special_tokens": global_config.skip_special_tokens_in_output,
**sampling_params.to_srt_kwargs(),
},
}
elif sampling_params.dtype in [int, "int"]:
data = {
"text": s.text_,
"sampling_params": {
"skip_special_tokens": global_config.skip_special_tokens_in_output,
"dtype": "int",
**sampling_params.to_srt_kwargs(),
},
}
else:
raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
self._add_images(s, data)
res = http_request(self.base_url + "/generate", json=data)
obj = res.json()
comp = obj["text"]
return comp, obj["meta_info"]
def generate_stream(
self,
s: StreamExecutor,
sampling_params: SamplingParams,
):
if sampling_params.dtype is None:
data = {
"text": s.text_,
"sampling_params": {
"skip_special_tokens": global_config.skip_special_tokens_in_output,
**sampling_params.to_srt_kwargs(),
},
}
elif sampling_params.dtype in [int, "int"]:
data = {
"text": s.text_,
"sampling_params": {
"skip_special_tokens": global_config.skip_special_tokens_in_output,
"dtype": "int",
**sampling_params.to_srt_kwargs(),
},
}
else:
raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
data["stream"] = True
self._add_images(s, data)
response = http_request(self.base_url + "/generate", json=data, stream=True)
pos = 0
incomplete_text = ""
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
text = find_printable_text(data["text"][pos:])
meta_info = data["meta_info"]
pos += len(text)
incomplete_text = data["text"][pos:]
yield text, meta_info
if len(incomplete_text) > 0:
yield incomplete_text, meta_info
def select(
self,
s: StreamExecutor,
choices: List[str],
temperature: float,
):
assert temperature <= 1e-5
# Cache common prefix
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
self._add_images(s, data)
res = http_request(self.base_url + "/generate", json=data)
assert res.status_code == 200
prompt_len = res.json()["meta_info"]["prompt_tokens"]
# Compute logprob
data = {
"text": [s.text_ + c for c in choices],
"sampling_params": {"max_new_tokens": 0},
"return_normalized_logprob": True,
"normalized_logprob_start_len": prompt_len,
}
self._add_images(s, data)
res = http_request(self.base_url + "/generate", json=data)
assert res.status_code == 200
logps = [r["meta_info"]["normalized_logprob"] for r in res.json()]
decision = choices[np.argmax(logps)]
return decision, logps
def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
res = http_request(
self.base_url + "/concate_and_append_request",
json={"src_rids": src_rids, "dst_rid": dst_rid},
)
assert res.status_code == 200
def _add_images(self, s: StreamExecutor, data):
if s.images_:
            assert len(s.images_) == 1, "Only one image is supported."
data["image_data"] = s.images_[0][1]

View File

@@ -0,0 +1,190 @@
import re
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from itertools import repeat
from typing import List, Optional, Union
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SamplingParams
from sglang.utils import http_request
class TGI(BaseBackend):
def __init__(self, base_url):
super().__init__()
self.base_url = base_url
res = http_request(self.base_url + "/info")
assert res.status_code == 200
self.model_info = res.json()
self.chat_template = get_chat_template_by_model_path(
self.model_info["model_id"]
)
def get_model_name(self):
return self.model_info["model_id"]
def get_chat_template(self):
return self.chat_template
@staticmethod
def adapt_params(max_tokens, stop, sampling_params, **override_params):
temperature = sampling_params.temperature
do_sample = True
if temperature == 0:
do_sample = False
temperature = None
if stop is None:
stop = []
elif isinstance(stop, str):
stop = [stop]
top_p = sampling_params.top_p
if top_p == 0:
top_p = 0.001
if top_p == 1:
top_p = 0.999
top_k = sampling_params.top_k
if top_k == -1:
top_k = None
params = {
"decoder_input_details": False,
"details": False,
"do_sample": do_sample,
"max_new_tokens": max_tokens,
"stop": stop,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"return_full_text": False,
}
params.update(override_params)
return params
@staticmethod
def _extract_int(text):
words = re.split("\ |'|\/|\(|\)|\n|\.|,", text)
for word in words:
try:
int(word)
return word
except ValueError:
continue
raise ValueError
@staticmethod
def _extract_choice(choices, text):
        # FIXME: Currently this only supports the case where the choices are single words.
        words = re.split(r" |'|/|\(|\)|\n|\.|,", text)
for word in words:
if word in choices:
return word
raise ValueError
@staticmethod
def _truncate_to_stop(text, stop):
# The stop sequence may not be a single token. In this case TGI will generate
# too many tokens so we need to truncate the output.
if stop:
stop = [stop] if isinstance(stop, str) else stop
for stop_seq in stop:
pos = text.find(stop_seq)
if pos != -1:
return text[:pos]
return text
def _make_request(self, params):
res = http_request(self.base_url + "/generate", json=params)
if res.status_code != 200:
raise ValueError(f"Error from TGI backend: {res.text}")
return res.json()
def retry_for_expected(self, prompt, params, extract_fn, retry=5):
        # TGI does not support logit_bias (yet), so we have to use an inefficient hack.
failed = []
while retry > 0:
res_json = self._make_request(
{
"inputs": prompt,
"parameters": params,
}
)
text = res_json["generated_text"]
try:
return extract_fn(text)
except ValueError:
retry -= 1
failed.append(text)
msg = "=" * 20 + "\n"
msg += f"Prompt:\n{prompt}\n"
msg += "=" * 20 + "\n"
for i, text in enumerate(failed):
msg += f"====== Try {i+1}:\n{text}\n"
raise ValueError(
f"Model {self.model_info['model_id']} served by TGI backend does not generate"
"expected output. Please improve the prompt, increase the temperature, or "
f"use different models.\n{msg}"
)
def select(
self,
s: StreamExecutor,
choices: List[str],
sampling_params: SamplingParams,
):
decision = self.retry_for_expected(
s.text_,
self.adapt_params(16, [], sampling_params),
partial(self._extract_choice, choices),
)
return decision, [1 if choice == decision else 0 for choice in choices]
def generate(
self,
s: StreamExecutor,
max_tokens: int,
stop: Union[str, List[str]],
sampling_params: SamplingParams,
dtype: Optional[str] = None,
):
if dtype is None:
res_json = self._make_request(
{
"inputs": s.text_,
"parameters": self.adapt_params(max_tokens, stop, sampling_params),
}
)
return self._truncate_to_stop(res_json["generated_text"], stop), {}
if dtype in [str, "str", "string"]:
stop = ['"']
res_json = self._make_request(
{
"inputs": f'{s.text_}"',
"parameters": self.adapt_params(max_tokens, stop, sampling_params),
}
)
return (
'"' + self._truncate_to_stop(res_json["generated_text"], stop) + '"',
{},
)
if dtype in [int, "int"]:
return (
self.retry_for_expected(
s.text_,
self.adapt_params(max_tokens, stop, sampling_params),
self._extract_int,
),
{},
)
raise ValueError(f"Unknown dtype: {dtype}")

View File

@@ -0,0 +1,60 @@
"""Flush cache in the backend by sending random requests."""
import argparse
import random
import string
import time
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
import sglang as sgl
@sgl.function
def flush_radix_cache(s, prompt):
s += prompt + sgl.gen("flush", max_tokens=1, stop="END")
def main(args, max_total_tokens, context_length, print_flag):
backend = select_sglang_backend(args)
flush_length = int(context_length * 0.8)
batch_size = int(max_total_tokens / flush_length)
prompt_length = flush_length * 2
prompts = [
" ".join(random.choices(string.ascii_letters, k=int(prompt_length)))
for _ in range(batch_size)
]
arguments = [{"prompt": prompts[i]} for i in range(batch_size)]
start_time = time.time()
flush_radix_cache.run_batch(
arguments, temperature=0, backend=backend, num_threads=1
)
end_time = time.time()
if print_flag:
print(
f"Flush length: {flush_length}\n",
f"Prompt length: {prompt_length}\n",
f"Total Prompt letters: {batch_size * prompt_length}\n",
f"Flush radix cache latency: {end_time - start_time:.3f}",
sep="",
)
    # Sleep briefly so the backend can finish any in-flight work.
time.sleep(1)
def run_flush(args, max_total_tokens=20000, context_length=1024, print_flag=False):
main(args, max_total_tokens, context_length, print_flag=print_flag)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--max-total-tokens", type=int, default=20000)
parser.add_argument("--context-length", type=int, default=1024)
args = add_common_sglang_args_and_parse(parser)
random.seed(0)
main(args, args.max_total_tokens, args.context_length, print_flag=True)

View File

@@ -0,0 +1,28 @@
"""Global configurations"""
class GlobalConfig:
def __init__(self):
# Verbosity level
# 0: do not output anything
# 2: output final text after every run
self.verbosity = 0
self.default_backend = None
# Output configs
self.skip_special_tokens_in_output = True
# Optimization configs
self.eager_fill_image = False
self.enable_prefix_sharing = True
self.enable_parallel_encoding = True
self.enable_parallel_decoding = True
# Choices: ["no_adjust", "adjust_cache"]
# no_adjust: Do not adjust the position embedding of KV cache.
# adjust_cache: Adjust the position embedding of KV cache.
self.concate_and_append_mode = "no_adjust"
global_config = GlobalConfig()

View File

View File

@@ -0,0 +1,186 @@
from dataclasses import dataclass
from enum import Enum, auto
from typing import Callable, Dict, List, Tuple
class ChatTemplateStyle(Enum):
PLAIN = auto()
LLAMA2 = auto()
@dataclass
class ChatTemplate:
name: str
default_system_prompt: str
role_prefix_and_suffix: Dict[str, Tuple[str]]
image_token: str = "<image>"
style: ChatTemplateStyle = ChatTemplateStyle.PLAIN
def get_prefix_and_suffix(self, role, hist_messages):
if self.style == ChatTemplateStyle.PLAIN:
return self.role_prefix_and_suffix[role]
elif self.style == ChatTemplateStyle.LLAMA2:
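            # Llama-2 folds the system prompt into the first user turn:
            # "[INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user} [/INST]"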
if len(hist_messages) == 0 and role == "system":
return (
self.role_prefix_and_suffix["user"][0]
+ self.role_prefix_and_suffix["system"][0],
self.role_prefix_and_suffix["system"][1],
)
elif (
len(hist_messages) == 1
and role == "user"
and hist_messages[0]["content"] is not None
):
return ("", self.role_prefix_and_suffix["user"][1])
return self.role_prefix_and_suffix[role]
else:
raise ValueError(f"Invalid style: {self.style}")
def get_prompt(self, messages):
prompt = ""
for i in range(len(messages)):
role, content = messages[i]["role"], messages[i]["content"]
if role == "system" and content is None:
content = self.default_system_prompt
if content is None:
continue
prefix, suffix = self.get_prefix_and_suffix(role, messages[:i])
prompt += prefix + content + suffix
return prompt
chat_template_registry: Dict[str, ChatTemplate] = {}
matching_function_registry: List[Callable] = []
def register_chat_template(template):
chat_template_registry[template.name] = template
def register_chat_template_matching_function(func):
matching_function_registry.append(func)
def get_chat_template(name):
return chat_template_registry[name]
def get_chat_template_by_model_path(model_path):
for matching_func in matching_function_registry:
template = matching_func(model_path)
if template is not None:
return template
return get_chat_template("default")
register_chat_template(
ChatTemplate(
name="default",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("SYSTEM:", "\n"),
"user": ("USER:", "\n"),
"assistant": ("ASSISTANT:", "\n"),
},
)
)
register_chat_template(
ChatTemplate(
name="claude",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("", ""),
"user": ("\n\nHuman: ", ""),
"assistant": ("\n\nAssistant:", ""),
},
)
)
register_chat_template(
ChatTemplate(
name="chatml",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("<|im_start|>system\n", "\n<|im_end|>\n"),
"user": ("<|im_start|>user\n", "\n<|im_end|>\n"),
"assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
},
style=ChatTemplateStyle.PLAIN,
)
)
register_chat_template(
ChatTemplate(
name="vicuna_v1.1",
default_system_prompt=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
role_prefix_and_suffix={
"system": ("", " "),
"user": ("USER:", " "),
"assistant": ("ASSISTANT:", "</s>"),
},
image_token=" <image>\n",
)
)
register_chat_template(
ChatTemplate(
name="llama-2-chat",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("<<SYS>>\n", "\n<</SYS>>\n\n"),
"user": ("[INST] ", " [/INST]"),
"assistant": ("", " </s><s>"),
},
style=ChatTemplateStyle.LLAMA2,
)
)
@register_chat_template_matching_function
def match_vicuna(model_path: str):
if "vicuna" in model_path.lower():
return get_chat_template("vicuna_v1.1")
if "llava" in model_path.lower():
return get_chat_template("vicuna_v1.1")
@register_chat_template_matching_function
def match_llama2_chat(model_path: str):
model_path = model_path.lower()
if "llama-2" in model_path and "chat" in model_path:
return get_chat_template("llama-2-chat")
if (
"mistral" in model_path or "mixtral" in model_path
) and "instruct" in model_path:
return get_chat_template("llama-2-chat")
if "codellama" in model_path and "instruct" in model_path:
return get_chat_template("llama-2-chat")
@register_chat_template_matching_function
def match_chat_ml(model_path: str):
if "tinyllama" in model_path.lower():
return get_chat_template("chatml")
if __name__ == "__main__":
messages = [
{"role": "system", "content": None}, # None means default
# {"role": "system", "content": "You are a helpful, respectful and honest assistant."},
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hi!"},
{"role": "user", "content": "What can you do?"},
{"role": "assistant", "content": "I can chat with you."},
]
template = get_chat_template("llama-2-chat")
print(template.get_prompt(messages))

View File

@@ -0,0 +1,237 @@
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
from typing import List, Union
from sglang.global_config import global_config
from sglang.lang.interpreter import ProgramState, StreamExecutor, pin_program
from sglang.lang.ir import (
SamplingParams,
SglArgument,
SglConstantText,
SglExpr,
SglVariable,
)
def compile_func(function, backend):
tracer = function.trace(backend=backend)
compiler = CompiledFunction(tracer, function)
return compiler
class CompiledFunction:
def __init__(self, tracer, function):
self.function = function
self.last_node = CompGraphNode(tracer.last_node)
self.expr_to_node = {}
self.build_graph(tracer)
self.topological_sort()
def build_graph(self, tracer):
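        # BFS backward from the traced last node, materializing edges for both
        # sequential prev-nodes and cross-stream variable source-nodes.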
self.nodes = [self.last_node]
self.expr_to_node[tracer.last_node] = self.nodes[-1]
rename_pid = {}
visited = set([tracer.last_node])
head = 0
while head < len(self.nodes):
cur_node = self.nodes[head]
# add prev node
prev_node = cur_node.expr.prev_node
if prev_node is not None:
if prev_node not in visited:
visited.add(prev_node)
self.nodes.append(CompGraphNode(prev_node))
self.expr_to_node[prev_node] = self.nodes[-1]
cur_node.prev_node = self.expr_to_node[prev_node]
self.expr_to_node[prev_node].add_next_node(cur_node)
# add source node
if isinstance(cur_node.expr, SglVariable):
if cur_node.expr.name in tracer.variables:
source = tracer.variables[cur_node.expr.name].source
else:
source = cur_node.expr.source
if source not in visited:
visited.add(source)
self.nodes.append(CompGraphNode(source))
self.expr_to_node[source] = self.nodes[-1]
cur_node.source_node = self.expr_to_node[source]
self.expr_to_node[source].add_next_node(cur_node)
head += 1
# rename pid
if cur_node.expr.pid not in rename_pid:
rename_pid[cur_node.expr.pid] = len(rename_pid)
cur_node.expr.pid = rename_pid[cur_node.expr.pid]
def topological_sort(self):
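        # Kahn's algorithm: start from nodes with no prev/source dependency and
        # release a node once all of its dependencies have been emitted.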
prevd = {}
cand = Queue()
for x in self.nodes:
prevd[x] = (x.prev_node is not None) + (x.source_node is not None)
if prevd[x] == 0:
cand.put(x)
new_list = []
while cand.qsize() > 0:
head = cand.get()
new_list.append(head)
for x in head.next_nodes:
prevd[x] -= 1
if prevd[x] == 0:
cand.put(x)
self.nodes = new_list
def print_graph(
self,
):
for node in self.nodes:
print(node)
def run_internal(
self,
backend,
kwargs,
default_sampling_para,
):
stream_executor_ids = set([x.expr.pid for x in self.nodes])
stream_executors = {}
for x in stream_executor_ids:
arguments = kwargs if x == self.last_node.expr.pid else {}
stream_executors[x] = StreamExecutor(
backend, arguments, default_sampling_para, None, False
)
for node in self.nodes:
se_id = node.expr.pid
expr = node.expr
if isinstance(expr, SglVariable):
# Make a copy for SglVariable
expr = SglVariable(expr.name, expr.source)
expr.source_stream_executor = stream_executors[
node.source_node.expr.pid
]
elif isinstance(expr, SglArgument):
# Substitute SglArgument
expr = kwargs[expr.name]
stream_executors[se_id].submit(expr)
for stream_executor in stream_executors.values():
stream_executor.end()
return ProgramState(stream_executors[self.last_node.expr.pid])
def run(
self,
*,
max_new_tokens: int = 16,
stop: Union[str, List[str]] = (),
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
backend=None,
**kwargs,
):
backend = backend or global_config.default_backend
kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
kwargs.update(self.function.bind_arguments)
default_sampling_para = SamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
temperature=temperature,
top_p=top_p,
top_k=top_k,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
)
return self.run_internal(backend, kwargs, default_sampling_para)
def run_batch(
self,
batch_kwargs,
*,
max_new_tokens: int = 16,
stop: Union[str, List[str]] = (),
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
backend=None,
num_threads: Union[str, int] = "auto",
):
assert isinstance(batch_kwargs, (list, tuple))
if len(batch_kwargs) == 0:
return []
assert isinstance(batch_kwargs[0], dict)
backend = backend or global_config.default_backend
default_sampling_para = SamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
temperature=temperature,
top_p=top_p,
top_k=top_k,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
)
batch_kwargs = [
{k: SglArgument(k, v) for k, v in kwargs.items()} for kwargs in batch_kwargs
]
# Extract prefix by tracing and cache it
if len(batch_kwargs) > 1:
pin_program(self.function, backend)
# Run all programs
if num_threads == "auto":
num_threads = multiprocessing.cpu_count()
num_threads = min(num_threads, len(batch_kwargs))
if num_threads == 1:
rets = []
for arguments in batch_kwargs:
rets.append(
self.run_internal(backend, arguments, default_sampling_para)
)
else:
with ThreadPoolExecutor(num_threads) as executor:
futures = []
for arguments in batch_kwargs:
futures.append(
executor.submit(
self.run_internal, backend, arguments, default_sampling_para
)
)
rets = [f.result() for f in futures]
rets[-1].sync()
return rets
class CompGraphNode:
def __init__(
self, expr: SglExpr, prev_node=None, next_nodes=None, source_node=None
):
self.expr = expr
self.next_nodes = next_nodes or []
self.prev_node = prev_node
self.source_node = source_node
def add_next_node(self, other):
self.next_nodes.append(other)
def __repr__(self):
re = f"stream {self.expr.pid:2d}: "
re += f"%{self.expr.node_id} = "
if self.prev_node is not None:
re += f"%{self.prev_node.expr.node_id} + "
re += repr(self.expr)
return re

View File

@@ -0,0 +1,697 @@
"""The interpreter that executes SGL programs"""
import asyncio
import multiprocessing
import queue
import threading
import uuid
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional, Union
import tqdm
from sglang.global_config import global_config
from sglang.lang.ir import (
SglArgument,
SglCommitLazy,
SglConcateAndAppend,
SglConstantText,
SglExpr,
SglExprList,
SglFunction,
SglGen,
SglImage,
SglRoleBegin,
SglRoleEnd,
SglSelect,
SglVariable,
SglVarScopeBegin,
SglVarScopeEnd,
)
from sglang.utils import encode_image_base64
def run_internal(state, program, func_args, func_kwargs, sync):
try:
state.ret_value = program.func(state, *func_args, **func_kwargs)
finally:
state.stream_executor.end()
if sync:
state.stream_executor.sync()
if global_config.verbosity >= 2:
print(state.text())
def run_program(
program, backend, func_args, func_kwargs, default_sampling_para, stream, sync=False
):
assert backend is not None, "Please specify a backend"
func_kwargs.update(program.bind_arguments)
stream_executor = StreamExecutor(
backend, func_kwargs, default_sampling_para, chat_template=None, stream=stream
)
state = ProgramState(stream_executor)
if stream:
t = threading.Thread(
target=run_internal, args=(state, program, func_args, func_kwargs, sync)
)
t.start()
return state
else:
run_internal(state, program, func_args, func_kwargs, sync)
return state
def run_program_batch(
program,
backend,
batch_arguments,
default_sampling_para,
num_threads,
progress_bar,
):
# Extract prefix by tracing and cache it
if len(batch_arguments) > 1:
pin_program(program, backend)
# Run all programs
if num_threads == "auto":
num_threads = multiprocessing.cpu_count()
num_threads = min(num_threads, len(batch_arguments))
if num_threads == 1:
rets = []
for arguments in batch_arguments:
rets.append(
run_program(
program, backend, (), arguments, default_sampling_para, False, False
)
)
else:
if progress_bar:
pbar = tqdm.tqdm(total=len(batch_arguments))
with ThreadPoolExecutor(num_threads) as executor:
futures = []
for arguments in batch_arguments:
futures.append(
executor.submit(
run_program,
program,
backend,
(),
arguments,
default_sampling_para,
False,
False,
)
)
if progress_bar:
futures[-1].add_done_callback(lambda _: pbar.update())
rets = [f.result() for f in futures]
rets[-1].sync()
if progress_bar:
pbar.close()
return rets
def pin_program(program, backend):
if global_config.enable_prefix_sharing and program.pin_prefix_rid is None:
# TODO: handle multiple backends
from sglang.lang.tracer import extract_prefix_by_tracing
prefix = extract_prefix_by_tracing(program, backend)
if prefix and len(prefix) > 64:
prefix_rid = backend.cache_prefix(prefix)
program.pin_prefix_rid = prefix_rid
return prefix_rid
return None
def unpin_program(program, backend):
pass
class StreamExecutor:
"""A stream executor that executes SGL expressions in a background thread."""
def __init__(
self,
backend,
arguments,
default_sampling_para,
chat_template,
stream,
use_thread=True,
):
self.sid = uuid.uuid4().hex
self.backend = backend
self.arguments: Dict[str, Any] = arguments
self.default_sampling_para = default_sampling_para
self.stream = stream
if hasattr(backend, "endpoint"):
self.backend = backend.endpoint
self.variables = {} # Dict[name: str -> value: str]
self.variable_event = {} # Dict[name: str -> event: threading.Event]
self.meta_info = {} # Dict[name: str -> info: str]
self.is_finished = False
# For completion
self.text_ = "" # The full text
# For chat
self.messages_ = [] # The messages in the OpenAI API format
self.chat_template = chat_template or self.backend.get_chat_template()
self.cur_role = None
self.cur_role_begin_pos = None
# For vision
self.images_ = []
self.cur_images = []
# For fork/join
self.fork_start_text_pos = None
# Worker thread
self.use_thread = use_thread
if self.use_thread:
self.queue = queue.Queue()
self.worker = threading.Thread(target=self._thread_worker_func)
self.worker.start()
# For streaming
if stream:
self.stream_text_event = threading.Event()
self.stream_var_event = {}
else:
self.stream_text_event = None
self.stream_var_event = None
def submit(self, expr: SglExpr):
if isinstance(expr, (SglGen, SglSelect, SglVarScopeBegin)):
self.variable_event[expr.name] = threading.Event()
if self.stream:
self.stream_var_event[expr.name] = threading.Event()
elif isinstance(expr, SglExprList):
for e in expr.expr_list:
if isinstance(e, (SglGen, SglSelect, SglVarScopeBegin)):
self.variable_event[e.name] = threading.Event()
if self.stream:
self.stream_var_event[e.name] = threading.Event()
if self.use_thread:
self.queue.put(expr)
else:
self._execute(expr)
def sync(self):
if self.use_thread:
self.queue.join()
def get_var(self, name):
if name in self.variable_event:
self.variable_event[name].wait()
return self.variables[name]
def get_meta_info(self, name):
if name in self.variable_event:
self.variable_event[name].wait()
ret = self.meta_info.get(name, None)
return ret
def fork(self, number: int, position_ids_offset: Optional[List[int]] = None):
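        # For a real fork, flush lazy operations first so every child copies a
        # committed, consistent snapshot of this executor's state.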
if number > 1:
self.submit(SglCommitLazy())
self.sync()
number = int(number)
exes = [
StreamExecutor(
self.backend,
self.arguments,
self.default_sampling_para,
self.chat_template,
self.stream,
)
for _ in range(number)
]
for i in range(number):
exes[i].variables = dict(self.variables)
exes[i].text_ = str(self.text_)
exes[i].messages_ = list(self.messages_)
exes[i].cur_role = self.cur_role
exes[i].fork_start_text_pos = len(self.text_)
return exes
def text(self):
self.sync()
return self.text_
def messages(self):
self.sync()
return self.messages_
def end(self):
if self.use_thread:
if self.worker.is_alive():
self.queue.put(None)
self.backend.end_program(self)
def _thread_worker_func(self):
while True:
expr = self.queue.get()
if expr is None:
self.queue.task_done()
break
self._execute(expr)
self.queue.task_done()
if self.stream_text_event:
self.stream_text_event.set()
if self.stream_text_event:
self.stream_text_event.set()
self.is_finished = True
def _execute(self, other):
if isinstance(other, str):
other = SglConstantText(other)
assert isinstance(other, SglExpr), f"{other}"
if isinstance(other, (SglConstantText, SglArgument)):
self._execute_fill(other.value)
elif isinstance(other, SglGen):
self._execute_gen(other)
elif isinstance(other, SglSelect):
self._execute_select(other)
elif isinstance(other, SglExprList):
for x in other.expr_list:
self._execute(x)
elif isinstance(other, SglRoleBegin):
self._execute_role_begin(other)
elif isinstance(other, SglRoleEnd):
self._execute_role_end(other)
elif isinstance(other, SglImage):
self._execute_image(other)
elif isinstance(other, SglVariable):
self._execute_variable(other)
elif isinstance(other, SglVarScopeBegin):
self._execute_var_scope_begin(other)
elif isinstance(other, SglVarScopeEnd):
self._execute_var_scope_end(other)
elif isinstance(other, SglCommitLazy):
self._execute_commit_lazy_operations(other)
elif isinstance(other, SglConcateAndAppend):
if (
global_config.enable_parallel_encoding
and self.backend.support_concate_and_append
):
self._execute_concatenate_and_append_kv_cache(other)
else:
self._execute_concatenate_and_append_text(other)
else:
raise ValueError(f"Unknown type: {type(other)}")
def _execute_fill(self, value: str):
value = str(value)
self.text_ += value
def _execute_image(self, expr: SglImage):
path = expr.path
if isinstance(path, SglArgument):
path = path.value
base64_data = encode_image_base64(path)
self.images_.append((path, base64_data))
self.cur_images.append((path, base64_data))
self.text_ += self.chat_template.image_token
# if global_config.eager_fill_image:
# self.backend.fill_image(self)
def _execute_gen(self, expr: SglGen):
sampling_params = self._resolve_sampling_params(expr.sampling_params)
name = expr.name
if not self.stream:
comp, meta_info = self.backend.generate(
self, sampling_params=sampling_params
)
self.text_ += comp
self.variables[name] = comp
self.meta_info[name] = meta_info
self.variable_event[name].set()
else:
generator = self.backend.generate_stream(
self, sampling_params=sampling_params
)
self.stream_var_event[name].set()
self.variables[name] = ""
for comp, meta_info in generator:
self.text_ += comp
self.variables[name] += comp
self.stream_var_event[name].set()
self.stream_text_event.set()
self.meta_info[name] = meta_info
self.variable_event[name].set()
self.stream_var_event[name].set()
def _execute_select(self, expr: SglSelect):
decision, scores = self.backend.select(self, expr.choices, expr.temperature)
if expr.name is not None:
name = expr.name
self.variables[name] = decision
self.variable_event[name].set()
self.text_ += decision
def _execute_variable(self, expr: SglVariable):
src_executor = expr.source_stream_executor
value = src_executor.get_var(expr.name)
self._execute_fill(value)
def _execute_role_begin(self, expr: SglRoleBegin):
assert self.cur_role is None, "Nested roles are not allowed."
if len(self.messages_) == 0 and expr.role != "system":
# Insert the default system message
default_system = self.chat_template.default_system_prompt
if default_system:
self._execute_role_begin(SglRoleBegin("system"))
self._execute_fill(default_system)
self._execute_role_end(SglRoleEnd("system"))
self.cur_role = expr.role
prefix, _ = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
self._execute_fill(prefix)
self.cur_role_begin_pos = len(self.text_)
def _execute_role_end(self, expr: SglRoleEnd):
new_text = self.text_[self.cur_role_begin_pos :].lstrip()
_, suffix = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
self._execute_fill(suffix)
if self.cur_images:
# OpenAI vision API format
last_msg = {
"role": expr.role,
"content": [{"type": "text", "text": new_text}],
}
for (image_path, image_base64_data) in self.cur_images:
last_msg["content"].append(
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_base64_data}"
},
}
)
self.messages_.append(last_msg)
self.cur_images = []
else:
self.messages_.append({"role": expr.role, "content": new_text})
self.cur_role = None
def _execute_var_scope_begin(self, expr: SglVarScopeBegin):
self.variables[expr.name] = int(len(self.text_))
def _execute_var_scope_end(self, expr: SglVarScopeEnd):
self.variables[expr.name] = self.text_[self.variables[expr.name] :]
self.variable_event[expr.name].set()
def _execute_commit_lazy_operations(self, expr: SglCommitLazy):
self.backend.commit_lazy_operations(self)
def _execute_concatenate_and_append_text(self, expr: SglConcateAndAppend):
new_text = ""
for s in expr.states:
exe = s.stream_executor
exe.sync()
new_text += exe.text_[exe.fork_start_text_pos :]
self._execute_fill(new_text)
def _execute_concatenate_and_append_kv_cache(self, expr: SglConcateAndAppend):
self_len = len(self.text_)
for i, s in enumerate(expr.states):
exe = s.stream_executor
exe.submit(SglCommitLazy())
for i, s in enumerate(expr.states):
exe = s.stream_executor
exe.sync()
assert exe.fork_start_text_pos == self_len
self.text_ += exe.text_[exe.fork_start_text_pos :]
src_rids = [state.stream_executor.sid for state in expr.states]
self.backend.concatenate_and_append(src_rids, self.sid)
def _resolve_sampling_params(self, sampling_params):
clone = None
for item in [
"max_new_tokens",
"stop",
"temperature",
"top_p",
"top_k",
"frequency_penalty",
"presence_penalty",
"dtype",
"regex",
]:
value = getattr(sampling_params, item, None)
if value is not None:
if clone is None:
clone = self.default_sampling_para.clone()
setattr(clone, item, value)
return clone or self.default_sampling_para
def __del__(self):
self.end()
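The cross-thread synchronization above is plain `threading.Event` bookkeeping: `get_var` blocks until the background worker fills a variable and sets its event. A minimal, SGLang-free sketch of the same pattern:

```python
import threading

variables = {}
events = {"answer": threading.Event()}

def worker():
    # Simulates the background interpreter thread producing a value.
    variables["answer"] = "42"
    events["answer"].set()

threading.Thread(target=worker).start()
events["answer"].wait()     # this is what get_var does before reading
print(variables["answer"])  # -> 42
```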
class ProgramState:
"""The state of an SGL program."""
def __init__(self, stream_executor: StreamExecutor):
self.stream_executor = stream_executor
def _role_common(self, name: str, expr: Optional[SglExpr] = None):
if expr is not None:
self.stream_executor.submit(
SglExprList([SglRoleBegin(name), expr, SglRoleEnd(name)])
)
else:
@contextmanager
def role_scope():
self.stream_executor.submit(SglRoleBegin(name))
yield
self.stream_executor.submit(SglRoleEnd(name))
return role_scope()
def system(self, expr: Optional[SglExpr] = None):
return self._role_common("system", expr)
def user(self, expr: Optional[SglExpr] = None):
return self._role_common("user", expr)
def assistant(self, expr: Optional[SglExpr] = None):
return self._role_common("assistant", expr)
@contextmanager
def var_scope(self, name: str):
self.stream_executor.submit(SglVarScopeBegin(name))
yield
self.stream_executor.submit(SglVarScopeEnd(name))
def fork(self, number: int = 1, position_ids_offset: Optional[List[int]] = None):
stream_executors = self.stream_executor.fork(number, position_ids_offset)
states = [ProgramState(x) for x in stream_executors]
state_group = ProgramStateGroup(states, self)
return state_group
@contextmanager
def copy(self, position_ids_offset: Optional[List[int]] = None):
state_group = self.fork(1, position_ids_offset)
try:
yield state_group[0]
finally:
state_group.join()
def text(self):
return self.stream_executor.text()
def messages(self):
return self.stream_executor.messages()
def sync(self):
return self.stream_executor.sync()
def text_iter(self, var_name=None):
if self.stream_executor.stream:
prev = 0
if var_name is None:
event = self.stream_executor.stream_text_event
while True:
event.wait()
event.clear()
out = str(self.stream_executor.text_[prev:])
prev += len(out)
if out:
yield out
if self.stream_executor.is_finished:
break
else:
event = self.stream_executor.stream_var_event[var_name]
while True:
event.wait()
event.clear()
out = str(self.stream_executor.variables[var_name][prev:])
prev += len(out)
if out:
yield out
if self.stream_executor.variable_event[var_name].is_set():
break
else:
if var_name is None:
yield self.text()
else:
yield self.get_var(var_name)
async def text_async_iter(self, var_name=None):
loop = asyncio.get_running_loop()
if self.stream_executor.stream:
prev = 0
if var_name is None:
event = self.stream_executor.stream_text_event
while True:
await loop.run_in_executor(None, event.wait)
event.clear()
out = str(self.stream_executor.text_[prev:])
prev += len(out)
if out:
yield out
if self.stream_executor.is_finished:
break
else:
event = self.stream_executor.stream_var_event[var_name]
while True:
await loop.run_in_executor(None, event.wait)
event.clear()
out = str(self.stream_executor.variables[var_name][prev:])
prev += len(out)
if out:
yield out
if self.stream_executor.variable_event[var_name].is_set():
break
else:
if var_name is None:
yield self.text()
else:
yield self.get_var(var_name)
def get_var(self, name):
return self.stream_executor.get_var(name)
def get_meta_info(self, name):
return self.stream_executor.get_meta_info(name)
def __iadd__(self, other):
self.stream_executor.submit(other)
return self
def __getitem__(self, name):
return self.get_var(name)
def __del__(self):
self.stream_executor.end()
def __repr__(self) -> str:
msgs = self.messages()
ret = ""
for msg in msgs:
ret += msg["role"] + ":\n" + msg["content"] + "\n"
return ret
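For orientation, a hypothetical consumer of the streaming interface above; it assumes a local SRT server is already running, and the names `qa`/`answer` are illustrative:

```python
from sglang import RuntimeEndpoint, assistant, function, gen, set_default_backend, user

@function
def qa(s, question):
    s += user(question)
    s += assistant(gen("answer", max_tokens=64))

set_default_backend(RuntimeEndpoint("http://localhost:30000"))  # assumed server

state = qa.run(question="Name one prime number.", stream=True)
for chunk in state.text_iter("answer"):  # yields pieces as they are generated
    print(chunk, end="", flush=True)
```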
class ProgramStateGroup:
def __init__(
self, states: List[ProgramState], src_state: Optional[ProgramState] = None
):
self.states = states
self.src_state = src_state
def join(self, mode: str = "gather_variable"):
if mode == "gather_variable":
# Copy variables back
src_vars = self.src_state.stream_executor.variables
src_var_set = set(src_vars.keys())
for child_state in self.states:
child_state.stream_executor.sync()
child_vars = child_state.stream_executor.variables
new_vars = set(child_vars.keys()) - src_var_set
for k in new_vars:
if k in src_vars:
src_vars[k].append(child_vars[k])
else:
src_vars[k] = [child_vars[k]]
elif mode == "concate_and_append":
# Concatenate and append KV cache
self.src_state += SglConcateAndAppend(self.states)
# Need a sync here. Otherwise, `states` can be deleted.
self.src_state.stream_executor.sync()
else:
raise ValueError(f"Invalid join mode: {mode}")
for s in self.states:
s.stream_executor.end()
def __getitem__(self, i: int):
return self.states[i]
def __setitem__(self, i: int, value):
assert self.states[i] == value
def __iadd__(self, other):
if isinstance(other, Callable):
# lambda function
for i in range(len(self.states)):
self.states[i] += other(i)
elif isinstance(other, SglExpr):
for i in range(len(self.states)):
self.states[i] += other
elif isinstance(other, (list, tuple)):
for i in range(len(self.states)):
self.states[i] += other[i]
else:
raise ValueError(f"Invalid value: {other}")
return self
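A sketch of the fork/join flow these classes implement, again assuming a running local server; note how the `+=` overload accepts a lambda that is applied per branch, and how `join()` gathers each branch's new variables back into the parent as lists:

```python
from sglang import RuntimeEndpoint, function, gen, set_default_backend

@function
def two_tips(s, topic):
    s += "Give two tips about " + topic + ".\n"
    forks = s.fork(2)  # a ProgramStateGroup with two branches
    forks += lambda i: f"Tip {i + 1}:" + gen(f"tip_{i}", max_tokens=32) + "\n"
    forks.join()       # "gather_variable": copies tip_0 / tip_1 back into s

set_default_backend(RuntimeEndpoint("http://localhost:30000"))  # assumed server
state = two_tips.run(topic="cooking")
print(state["tip_0"], state["tip_1"])  # each comes back as a one-element list
```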

442
python/sglang/lang/ir.py Normal file

@@ -0,0 +1,442 @@
"""The intermediate representation."""
import dataclasses
import inspect
from typing import List, Optional, Union
from sglang.global_config import global_config
@dataclasses.dataclass
class SamplingParams:
max_new_tokens: int = 16
stop: Union[str, List[str]] = ()
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1 # -1 means disable
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
# for constrained generation, not included in to_xxx_kwargs
dtype: Optional[str] = None
regex: Optional[str] = None
def clone(self):
return SamplingParams(
self.max_new_tokens,
self.stop,
self.temperature,
self.top_p,
self.top_k,
self.frequency_penalty,
self.presence_penalty,
)
def to_openai_kwargs(self):
# OpenAI does not support top_k, so we drop it here
return {
"max_tokens": self.max_new_tokens,
"stop": self.stop or None,
"temperature": self.temperature,
"top_p": self.top_p,
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
}
def to_anthropic_kwargs(self):
# Anthropic does not support frequency_penalty or presence_penalty, so we drop them here
return {
"max_tokens_to_sample": self.max_new_tokens,
"stop_sequences": self.stop,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
}
def to_srt_kwargs(self):
return {
"max_new_tokens": self.max_new_tokens,
"stop": self.stop,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
"regex": self.regex,
}
class SglFunction:
def __init__(self, func, bind_arguments=None):
self.func = func
self.bind_arguments = bind_arguments or {}
self.pin_prefix_rid = None
# Parse arguments
argspec = inspect.getfullargspec(func)
assert argspec.args[0] == "s", 'The first argument must be "s"'
self.arg_names = argspec.args[1:]
def bind(self, **kwargs):
assert all(key in self.arg_names for key in kwargs)
new_bind_dict = {**self.bind_arguments, **kwargs}
return SglFunction(self.func, bind_arguments=new_bind_dict)
def run(
self,
*args,
max_new_tokens: int = 16,
stop: Union[str, List[str]] = (),
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
stream: bool = False,
backend=None,
**kwargs,
):
from sglang.lang.interpreter import run_program
default_sampling_para = SamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
temperature=temperature,
top_p=top_p,
top_k=top_k,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
)
backend = backend or global_config.default_backend
kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
return run_program(self, backend, args, kwargs, default_sampling_para, stream)
def run_batch(
self,
batch_kwargs,
*,
max_new_tokens: int = 16,
stop: Union[str, List[str]] = (),
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
backend=None,
num_threads: Union[str, int] = "auto",
progress_bar: bool = False,
):
from sglang.lang.interpreter import run_program_batch
assert isinstance(batch_kwargs, (list, tuple))
if len(batch_kwargs) == 0:
return []
assert isinstance(batch_kwargs[0], dict)
default_sampling_para = SamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
temperature=temperature,
top_p=top_p,
top_k=top_k,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
)
backend = backend or global_config.default_backend
batch_kwargs = [
{k: SglArgument(k, v) for k, v in kwargs.items()} for kwargs in batch_kwargs
]
return run_program_batch(
self,
backend,
batch_kwargs,
default_sampling_para,
num_threads,
progress_bar,
)
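A hypothetical `run_batch` call against the signature above; each dict supplies the keyword arguments of one program instance, and a default backend is assumed to be set:

```python
from sglang import function, gen

@function
def capital(s, country):
    s += "The capital of " + country + " is" + gen("capital", max_tokens=8)

# Assumes set_default_backend(...) was called earlier.
states = capital.run_batch(
    [{"country": "France"}, {"country": "Japan"}],
    temperature=0.0,
    progress_bar=True,
)
print([st["capital"] for st in states])
```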
def trace(self, *, backend=None, **kwargs):
from sglang.lang.tracer import trace_program
backend = backend or global_config.default_backend
return trace_program(self, kwargs, backend)
def pin(self, backend=None):
from sglang.lang.interpreter import pin_program
backend = backend or global_config.default_backend
return pin_program(self, backend)
def unpin(self, backend=None):
from sglang.lang.interpreter import unpin_program
backend = backend or global_config.default_backend
return unpin_program(self, backend)
def compile(self, *, backend=None):
from sglang.lang.compiler import compile_func
return compile_func(self, backend)
def __call__(self, *args, **kwargs):
from sglang.lang.tracer import TracingScope
tracing_scope = TracingScope.get_current_scope()
if tracing_scope is None:
return self.run(*args, **kwargs)
else:
kwargs["backend"] = tracing_scope.tracer_state.backend
return self.trace(*args, **kwargs)
class SglExpr:
node_ct = 0
def __init__(self):
self.node_id = SglExpr.node_ct
self.prev_node = None
self.pid = None
SglExpr.node_ct += 1
def __add__(self, other):
if isinstance(other, str):
other = SglConstantText(other)
assert isinstance(other, SglExpr)
return self.concatenate_ir(self, other)
def __radd__(self, other):
if isinstance(other, str):
other = SglConstantText(other)
assert isinstance(other, SglExpr), f"{other}"
return self.concatenate_ir(other, self)
def concatenate_ir(self, a, b):
if isinstance(a, SglExprList):
if isinstance(b, SglExprList):
return SglExprList(a.expr_list + b.expr_list)
else:
return SglExprList(a.expr_list + [b])
elif isinstance(b, SglExprList):
return SglExprList([a] + b.expr_list)
return SglExprList([a, b])
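The `+` overloads above keep a program as one flat `SglExprList` instead of a nested tree, e.g.:

```python
from sglang.lang.ir import SglConstantText, SglExprList

expr = SglConstantText("Hello, ") + SglConstantText("world")
assert isinstance(expr, SglExprList) and len(expr.expr_list) == 2

expr = expr + "!"  # the str is promoted to SglConstantText and appended flat
assert len(expr.expr_list) == 3
print(expr)  # ExprList([Constant('Hello, '), Constant('world'), Constant('!')])
```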
def print_graph_dfs(self):
ret = [""]
visited = set()
def dfs_print(x):
if x is None or x in visited:
return
visited.add(x)
# Print dependency
if x.prev_node is not None:
dfs_print(x.prev_node)
if isinstance(x, SglExprList):
for y in x.expr_list:
dfs_print(y)
# elif isinstance(x, SglRole):
# dfs_print(x.expr)
elif isinstance(x, SglVariable):
dfs_print(x.source)
# Print the node itself
if isinstance(x, (SglFork, SglGetForkItem)):
ret[0] += f"%{x.node_id} = {x}\n"
else:
if x.prev_node is not None:
ret[0] += (
f"%{x.node_id} = %{x.prev_node.node_id} + " + str(x) + "\n"
)
else:
ret[0] += f"%{x.node_id} = " + str(x) + "\n"
dfs_print(self)
return ret[0]
class SglExprList(SglExpr):
def __init__(self, expr_list: List[SglExpr]):
super().__init__()
self.expr_list = expr_list
def __repr__(self):
return f"ExprList({self.expr_list})"
class SglArgument(SglExpr):
def __init__(self, name: str, value: str):
super().__init__()
self.name = name
self.value = value
def __repr__(self):
return f"Argument(name={self.name}, value={repr(self.value)})"
def __len__(self):
return len(self.value)
def __getitem__(self, i):
return self.value[i]
def __int__(self):
return self.value
def __bool__(self):
return self.value
def __format__(self, *args):
raise TypeError(
"Cannot put argument inside a f-string. "
"This is not compatible with the tracer. "
)
class SglImage(SglExpr):
def __init__(self, path):
self.path = path
def __repr__(self) -> str:
return f"SglImage({self.path})"
class SglGen(SglExpr):
def __init__(
self,
name,
max_new_tokens,
stop,
temperature,
top_p,
top_k,
frequency_penalty,
presence_penalty,
dtype,
regex,
):
super().__init__()
self.name = name
self.sampling_params = SamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
temperature=temperature,
top_p=top_p,
top_k=top_k,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
dtype=dtype,
regex=regex,
)
def __repr__(self):
return f"Gen('{self.name}')"
class SglConstantText(SglExpr):
def __init__(self, value):
super().__init__()
self.value = value
def __repr__(self):
return f"Constant({repr(self.value)})"
class SglRoleBegin(SglExpr):
def __init__(self, role):
super().__init__()
self.role = role
def __repr__(self):
return f"RoleBegin({self.role})"
class SglRoleEnd(SglExpr):
def __init__(self, role):
super().__init__()
self.role = role
def __repr__(self):
return f"RoleEnd({self.role})"
class SglSelect(SglExpr):
def __init__(self, name, choices, temperature):
super().__init__()
self.name = name
self.choices = choices
self.temperature = temperature
def __repr__(self):
return f"Select({self.name}, choices={self.choices})"
class SglFork(SglExpr):
def __init__(self, number, position_ids_offset=None):
super().__init__()
self.number = number
self.position_ids_offset = position_ids_offset
def __repr__(self):
return (
f"Fork(%{self.prev_node.node_id}, number={self.number}, "
f"position_ids_offset={self.position_ids_offset})"
)
class SglGetForkItem(SglExpr):
def __init__(self, index):
super().__init__()
self.index = index
def __repr__(self):
return f"GetForkItem(%{self.prev_node.node_id}, index={self.index})"
class SglVariable(SglExpr):
def __init__(self, name, source):
super().__init__()
self.name = name
self.source = source
def __repr__(self):
return f"Variable('{self.name}', source=%{self.source.node_id})"
class SglVarScopeBegin(SglExpr):
def __init__(self, name):
super().__init__()
self.name = name
def __repr__(self):
return f"VarScopeBegin('{self.name}')"
class SglVarScopeEnd(SglExpr):
def __init__(self, name):
super().__init__()
self.name = name
def __repr__(self):
return f"VarScopeEnd('{self.name}')"
class SglConcateAndAppend(SglExpr):
def __init__(self, states):
super().__init__()
self.states = states
def __repr__(self):
return f"ConcatenateAndAppend('{self.states}')"
class SglCommitLazy(SglExpr):
def __init__(self):
super().__init__()
def __repr__(self):
return f"CommitLazy()"

python/sglang/lang/tracer.py Normal file

@@ -0,0 +1,279 @@
"""Tracing a program."""
import uuid
from typing import Any, Callable, Dict, List, Optional, Union
from sglang.backend.base_backend import BaseBackend
from sglang.global_config import global_config
from sglang.lang.interpreter import ProgramState, ProgramStateGroup
from sglang.lang.ir import (
SglArgument,
SglCommitLazy,
SglConcateAndAppend,
SglConstantText,
SglExpr,
SglExprList,
SglFork,
SglFunction,
SglGen,
SglGetForkItem,
SglRoleBegin,
SglRoleEnd,
SglSelect,
SglVariable,
SglVarScopeBegin,
SglVarScopeEnd,
)
class StopTracing(Exception):
pass
def extract_prefix_by_tracing(program, backend):
# Create dummy arguments
dummy_arguments = {name: SglArgument(name, None) for name in program.arg_names}
arguments = dummy_arguments
arguments.update(program.bind_arguments)
# Trace
tracer = TracerProgramState(backend, arguments, only_trace_prefix=True)
try:
with TracingScope(tracer):
tracer.ret_value = program.func(tracer, **arguments)
except StopTracing:
pass
# Run and cache prefix
prefix = ""
for expr in tracer.flatten_nodes():
if isinstance(expr, SglConstantText):
prefix += expr.value
else:
break
return prefix
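A sketch of what this recovers: the constant text before the first data-dependent expression. `BaseBackend` serves as a tracing stub here, and the function body is purely illustrative:

```python
from sglang import function, gen
from sglang.backend.base_backend import BaseBackend
from sglang.lang.tracer import extract_prefix_by_tracing

@function
def judge(s, claim):
    s += "You are a strict fact checker.\nClaim: "  # constant -> cacheable prefix
    s += claim                                      # dummy argument stops tracing
    s += gen("verdict", max_tokens=8)

print(repr(extract_prefix_by_tracing(judge, BaseBackend())))
# -> 'You are a strict fact checker.\nClaim: '
```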
def trace_program(program, arguments, backend):
# Create dummy backend
if backend is None:
backend = BaseBackend()
# Create dummy arguments
dummy_arguments = {
name: SglArgument(name, None)
for name in program.arg_names
if name not in arguments
}
arguments.update(dummy_arguments)
arguments.update(program.bind_arguments)
# Trace
tracer = TracerProgramState(backend, arguments, only_trace_prefix=False)
with TracingScope(tracer):
tracer.ret_value = program.func(tracer, **arguments)
return tracer
class TracerProgramState(ProgramState):
def __init__(self, backend, arguments, only_trace_prefix):
self.pid = uuid.uuid4().hex
self.backend = backend
self.arguments: Dict[str, Any] = arguments
self.only_trace_prefix = only_trace_prefix
if hasattr(backend, "endpoint"):
self.backend = backend.endpoint
self.nodes = []
self.last_node = None
self.variables = {}
self.ret_value = None
# For completion
# For chat
self.messages_ = []
self.cur_role = None
self.chat_template = self.backend.get_chat_template()
# For multi states
self.child_states = []
cur_scope = TracingScope.get_current_scope()
if cur_scope is not None:
cur_scope.add_child_state(self)
##################################
########### Public API ###########
##################################
def fork(self, number: int, position_ids_offset: Optional[List[int]] = None):
if self.only_trace_prefix:
raise StopTracing()
fork_node = SglFork(number)
fork_node.prev_node = self.last_node
states = [
TracerProgramState(self.backend, self.arguments, self.only_trace_prefix)
for _ in range(number)
]
for i in range(number):
node = SglGetForkItem(i)
node.prev_node = fork_node
states[i].last_node = node
states[i].variables = dict(self.variables)
states[i].messages_ = list(self.messages_)
states[i].cur_role = self.cur_role
states[i].chat_template = self.chat_template
state_group = ProgramStateGroup(states, self)
return state_group
##################################
########## Internal API ##########
##################################
def _append_node(self, other: SglExpr):
self.nodes.append(other)
other.prev_node = self.last_node
self.last_node = other
def _execute(self, other: SglExpr):
if isinstance(other, str):
other = SglConstantText(other)
other.pid = self.pid
if isinstance(other, SglConstantText):
self._execute_fill(other)
elif isinstance(other, SglGen):
self._execute_gen(other)
elif isinstance(other, SglSelect):
self._execute_select(other)
elif isinstance(other, SglExprList):
for x in other.expr_list:
self._execute(x)
elif isinstance(other, SglRoleBegin):
self._execute_role_begin(other)
elif isinstance(other, SglRoleEnd):
self._execute_role_end(other)
elif isinstance(other, SglVarScopeBegin):
self._execute_var_scope_begin(other)
elif isinstance(other, SglVarScopeEnd):
self._execute_var_scope_end(other)
else:
if self.only_trace_prefix:
raise StopTracing()
else:
self._append_node(other)
return self
def __iadd__(self, other):
self._execute(other)
return self
def _execute_fill(self, expr: SglConstantText):
if isinstance(expr, str):
expr = SglConstantText(expr)
self._append_node(expr)
def _execute_gen(self, expr: SglGen):
name = expr.name if expr.name is not None else "gen_" + str(len(self.variables))
new_node = SglVariable(name, source=expr)
self.variables[name] = new_node
self._append_node(expr)
def _execute_select(self, expr: SglSelect):
name = (
expr.name if expr.name is not None else "select_" + str(len(self.variables))
)
new_node = SglVariable(name, source=expr)
self.variables[name] = new_node
self._append_node(expr)
def _execute_role_begin(self, expr: SglRoleBegin):
assert self.cur_role is None, "Nested roles are not allowed."
if len(self.messages_) == 0 and expr.role != "system":
# Insert default system message
default_system = self.chat_template.default_system_prompt
if default_system:
self._execute_role_begin(SglRoleBegin("system"))
self._execute_fill(default_system)
self._execute_role_end(SglRoleEnd("system"))
self.cur_role = expr.role
prefix, suffix = self.chat_template.get_prefix_and_suffix(
expr.role, self.messages_
)
self._execute_fill(prefix)
def _execute_role_end(self, expr: SglRoleEnd):
prefix, suffix = self.chat_template.get_prefix_and_suffix(
expr.role, self.messages_
)
self._execute_fill(suffix)
self.messages_.append({"role": expr.role, "content": ""})
self.cur_role = None
def _execute_var_scope_end(self, expr: SglVarScopeEnd):
name = expr.name
new_node = SglVariable(name, source=self.last_node)
self.variables[name] = new_node
def get_var(self, name):
ret = self.arguments.get(name, None)
if ret is not None:
return ret
v = self.variables[name]
return SglVariable(v.name, v.source)
def flatten_nodes(self):
def traverse(cur):
if isinstance(cur, SglExprList):
for child in cur.expr_list:
traverse(child)
else:
ret.append(cur)
ret = []
for x in self.nodes:
traverse(x)
return ret
def __del__(self):
pass
class TracingScope:
cur_scope = None
def __init__(self, tracer_state: TracerProgramState):
self.tracer_state = tracer_state
self.last_scope = TracingScope.cur_scope
def __enter__(self):
TracingScope.cur_scope = self
return self
def __exit__(self, exc_type, exc_value, traceback):
TracingScope.cur_scope = self.last_scope
@staticmethod
def get_current_scope():
return TracingScope.cur_scope
def add_child_state(self, state: TracerProgramState):
cur_scope = self
while cur_scope is not None:
cur_scope.tracer_state.child_states.append(state)
cur_scope = cur_scope.last_scope

python/sglang/launch_server.py Normal file

@@ -0,0 +1,11 @@
import argparse
from sglang.srt.server import ServerArgs, launch_server
if __name__ == "__main__":
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
launch_server(server_args, None)

python/sglang/srt/constrained/fsm.py Normal file

@@ -0,0 +1,385 @@
# Adapted from:
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/fsm/fsm.py
from typing import List, NewType, Protocol
import interegular
from lark import Lark
# from outlines.fsm.parsing import PartialLark
from sglang.srt.constrained.regex import (
create_fsm_index_tokenizer,
make_deterministic_fsm,
)
from sglang.srt.constrained.tokenizer import Tokenizer
FSMState = NewType("FSMState", int)
class FSM(Protocol):
def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
...
def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
...
def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
...
def reset(self) -> None:
...
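All of the concrete FSMs below are consumed by the same loop shape: restrict the next token to `allowed_token_ids`, pick one, then advance with `next_state`. A backend-agnostic sketch, where `sample` is a stand-in for the real (masked) sampler:

```python
from typing import Callable, List

def constrained_decode(fsm, sample: Callable[[List[int]], int], max_tokens: int) -> List[int]:
    """Drive any FSM from the protocol above; `sample` picks one id from the allowed set."""
    state, out = FSMState(0), []  # the concrete FSMs below all start at state 0
    for _ in range(max_tokens):
        token_id = sample(fsm.allowed_token_ids(state))  # i.e. masked sampling
        out.append(token_id)
        state = fsm.next_state(state, token_id)
        if fsm.is_final_state(state):
            break
    return out
```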
class StopAtTokenFSM(FSM):
"""FSM to generate text until a specified token id is generated or
a specified number of tokens has been generated.
Text is usually produced until the EOS token is generated by the
model.
"""
def __init__(
self,
tokenizer: "Tokenizer",
stop_token_id: int,
):
self.stop_token_id = stop_token_id
self.num_tokens_generated = 0
self.vocabulary = tokenizer.vocabulary.values()
self.final_states = {1}
def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
"""Generate a list of allowed tokens for the next step.
When in the initial state we allow every token to be generated.
In the final state the only allowed token is `stop_token_id`.
Parameters
----------
state
The current state of the FSM.
idx
The index of the current input in the batch.
Returns
-------
A list that contains the tokens to mask.
"""
if state == 0:
return list(self.vocabulary)
else:
return [self.stop_token_id]
def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
"""Update the state of the FSM.
The FSM stays in the initial state `0` unless the specified stop token
has been generated or the maximum number of tokens has been reached. In
which case the FSM moves to the final state `1`.
Parameters
----------
state
The current state of the FSM.
token_id
The id of the token that was just generated.
idx
The index of the current input in the batch.
Returns
-------
The new state of the FSM.
"""
if idx == 0:
self.num_tokens_generated += 1
if token_id == self.stop_token_id:
return FSMState(1)
return FSMState(0)
def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
"""Determine whether the current state of the FSM is a final state."""
return state in self.final_states
def reset(self) -> None:
"""Reset the FSM to its initial state. Here this only resets the token counter."""
self.num_tokens_generated = 0
class RegexFSM(FSM):
"""FSM to generate text that is in the language of a regular expression."""
def __init__(
self,
regex_string: str,
tokenizer: "Tokenizer",
):
regex_pattern = interegular.parse_pattern(regex_string)
regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
(
self.states_to_token_maps,
self.empty_token_ids,
) = create_fsm_index_tokenizer(regex_fsm, tokenizer)
# We make sure that it is possible to generate strings in the language
# of the regular expression with the tokens present in the model's
# vocabulary.
if not any(
regex_fsm.finals.intersection(v.values())
for v in self.states_to_token_maps.values()
):
raise ValueError(
"The vocabulary does not allow us to build a sequence that matches the input regex"
)
self.final_states = regex_fsm.finals | {
-1
} # Include the EOS token in final states
self.num_tokens_generated = 0
self.vocabulary = tokenizer.vocabulary.values()
self.end_token_id = tokenizer.eos_token_id
def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
"""Generate a list of allowed tokens for the next step.
The initialization of the FSM builds an index which maps FSM states to a
map from authorized tokens to the state in which the FSM needs to move
if said token is generated. Therefore the authorized tokens at the
current state are the keys of the map returned by the value of the index
for current state.
If the current state is not contained in the index, this means that we are
in a final state of the FSM. We only authorize EOS tokens in the final
state.
Parameters
----------
state
The current state of the FSM.
idx
The index of the current input in the batch.
Returns
-------
A list that contains the tokens to mask.
"""
next_tokens_to_end_states = self.states_to_token_maps.get(state)
if next_tokens_to_end_states is None:
return [self.end_token_id]
else:
return list(next_tokens_to_end_states.keys())
def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
"""Update the state of the FSM.
We use the index to determine to which state the FSM should transition
given the token that was just generated.
Parameters
----------
state
The current state of the FSM.
token_id
The id of the token that was just generated.
idx
The index of the current input in the batch.
Returns
-------
The new state of the FSM.
"""
if idx == 0:
self.num_tokens_generated += 1
if token_id == self.end_token_id:
return FSMState(-1)
last_token_to_end_state = self.states_to_token_maps[state]
next_state = last_token_to_end_state.get(token_id)
if next_state is None:
next_state = -1
return FSMState(next_state)
def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
"""Determine whether the current state of the FSM is a final state."""
return state in self.final_states
def reset(self) -> None:
"""Reset the FSM to its initial state. Here this only resets the token counter."""
self.num_tokens_generated = 0
class CFGFSM(FSM):
"""FSM to generate text that is in the language of a context-free grammar."""
def __init__(
self,
cfg_string: str,
tokenizer: "Tokenizer",
):
# self.parser = PartialLark(cfg_string, parser="lalr")
self.parser = Lark(
cfg_string,
parser="lalr",
lexer="contextual",
propagate_positions=False,
maybe_placeholders=False,
regex=True,
)
self.terminal_regexps = dict()
for terminal in self.parser.terminals:
if terminal.pattern is not None:
self.terminal_regexps[terminal.name] = terminal.pattern.to_regexp()
self.terminal_regexps["$END"] = tokenizer.eos_token
self.tokenizer = tokenizer
self.num_tokens_generated = 0
self.generations: List[str] = []
self.regex_fsms: List[RegexFSM] = []
self.reset_state: List[bool] = []
self.allow_eos: List[bool] = []
self.done: List[bool] = []
def _set_next_regex_fsm(self, idx: int = 0) -> None:
"""Use the CFG incremental parser to set the next regex FSM.
Check what the CFG incremental parser proposes next.
If the only proposal is the EOS token,
we set the state to done and return.
If there are other proposals,
we set a new regex FSM and return.
"""
interactive = self.parser.parse_interactive(self.generations[idx])
interactive.exhaust_lexer()
options = {self.terminal_regexps[x] for x in interactive.accepts()}
if self.terminal_regexps["$END"] in options:
options.remove(self.terminal_regexps["$END"])
if len(options) == 0:
self.done[idx] = True
return
self.allow_eos[idx] = True
options.add("")
assert len(options) > 1
regex_string = r"(" + r"|".join([r"(" + x + r")" for x in options]) + r")"
args = (
regex_string,
self.tokenizer,
)
if len(self.regex_fsms) <= idx:
self.regex_fsms.append(RegexFSM(*args))
else:
self.regex_fsms[idx] = RegexFSM(*args)
self.reset_state[idx] = True
def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
"""Generate a list of allowed tokens for the next step.
Upon initialization, the CFG incremental parser is used to determine the first regex.
This regex is used for proposals until either:
- the regex is exhausted, and its only remaining option is the EOS token,
in which case we always transition to the next regex
- the regex can be exhausted, but the EOS token is not the only remaining option,
in which case we transition to the next regex with probability P (TODO)
or remove the possibility of generating the EOS token and continue with the current regex
The CFG incremental parser is allowed to propose the EOS token from any final state,
and once it is generated, the FSM will continue to always generate the EOS token.
Parameters
----------
state
The current state of the FSM.
idx
The index of the current input in the batch.
Returns
-------
A list that contains the tokens to mask.
"""
if len(self.generations) <= idx:
self.generations.append("")
self.reset_state.append(False)
self.allow_eos.append(False)
self.done.append(False)
if len(self.regex_fsms) > idx:
proposal = self.regex_fsms[idx].allowed_token_ids(state)
if self.tokenizer.eos_token_id not in proposal:
return proposal
if set(proposal) != {self.tokenizer.eos_token_id}:
if False: # TODO: THIS NEEDS TO BE SAMPLED
proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
return proposal
self._set_next_regex_fsm(idx)
if self.done[idx]:
return [self.tokenizer.eos_token_id]
if self.reset_state[idx]:
state = FSMState(0)
proposal = self.regex_fsms[idx].allowed_token_ids(state)
if self.allow_eos[idx]:
self.allow_eos[idx] = False
else:
proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
assert len(proposal) > 0
return proposal
def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
"""Update the state of the FSM.
Transitions the underlying regex FSM to its next state.
If at max tokens or EOS token, transition permanently to the final state.
Update stored partial generations for subsequent incremental parsing.
Parameters
----------
state
The current state of the FSM.
token_id
The id of the token that was just generated.
idx
The index of the current input in the batch.
Returns
-------
The new state of the FSM.
"""
if idx == 0:
self.num_tokens_generated += 1
if token_id == self.tokenizer.eos_token_id:
self.done[idx] = True
return FSMState(-1)
if self.reset_state[idx]:
self.reset_state[idx] = False
state = FSMState(0)
self.generations[idx] += self.tokenizer.decode([token_id])[0]
return self.regex_fsms[idx].next_state(state, token_id, idx)
def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
"""Return whether the current state of the FSM is a final state."""
return self.done[idx]
def reset(self) -> None:
"""Reset the FSM to its initial state, so it can be called on a fresh batch on inputs."""
self.num_tokens_generated = 0
self.generations = []
self.regex_fsms = []
self.reset_state = []
self.done = []

python/sglang/srt/constrained/fsm_cache.py Normal file

@@ -0,0 +1,41 @@
import threading
from sglang.srt.constrained.fsm import RegexFSM
from sglang.srt.constrained.tokenizer import TransformerTokenizer
def get_fsm(regex, tokenizer, fsm_cache_entry):
outlines_tokenizer = TransformerTokenizer(tokenizer)
fsm = RegexFSM(regex, outlines_tokenizer)
fsm_cache_entry.fsm = fsm
fsm_cache_entry.event.set()
class FSMCacheEntry:
def __init__(self):
self.fsm = None
self.event = threading.Event()
class FSMCache:
def __init__(self, tokenizer):
self.cache = {}
self.tokenizer = tokenizer
def init_fsm_in_background(self, regex):
if regex not in self.cache:
self.cache[regex] = FSMCacheEntry()
threading.Thread(
target=get_fsm,
args=(
regex,
self.tokenizer,
self.cache[regex],
),
).start()
def get_fsm(self, regex):
self.init_fsm_in_background(regex)
entry = self.cache[regex]
entry.event.wait()
return entry.fsm
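A hypothetical warm-up flow for this cache: kick off compilation as soon as the regex is known, then block only if it has not finished by decode time (tokenizer hashing needs the `datasets` package; the model name is illustrative):

```python
from transformers import AutoTokenizer

hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any HF tokenizer works
cache = FSMCache(hf_tokenizer)

cache.init_fsm_in_background(r"(yes|no)")  # returns immediately
# ... tokenize the prompt, schedule the request ...
fsm = cache.get_fsm(r"(yes|no)")           # waits on the entry's event if needed
print(fsm.allowed_token_ids(0))            # token ids that can start "yes"/"no"
```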

python/sglang/srt/constrained/regex.py Normal file

@@ -0,0 +1,586 @@
# Adapted from:
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/fsm/regex.py
from collections import namedtuple
from functools import lru_cache
from typing import Dict, Generator, List, Sequence, Set, Tuple
import numba
import numpy as np
from interegular.fsm import FSM, Alphabet, OblivionError, anything_else
from numba.typed.typedobjectutils import _nonoptional
from sglang.srt.constrained.tokenizer import Tokenizer
class BetterAlphabet(Alphabet):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert anything_else in self._symbol_mapping
self.anything_value = self._symbol_mapping[anything_else]
def __getitem__(self, item):
return self._symbol_mapping.get(item, self.anything_value)
def copy(self):
return BetterAlphabet(self._symbol_mapping.copy())
class BetterFSM(FSM):
flat_transition_map: Dict[Tuple[int, int], int]
trans_key_to_states: Dict[int, List[int]]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not isinstance(self.alphabet, BetterAlphabet):
self.__dict__["alphabet"] = BetterAlphabet(self.alphabet._symbol_mapping)
flat_transition_map = {}
trans_key_to_states = {}
for from_state, trans_map in self.map.items():
for trans_key, to_state in trans_map.items():
flat_transition_map[(from_state, trans_key)] = to_state
trans_key_to_states.setdefault(trans_key, set()).add(from_state)
self.__dict__["trans_key_to_states"] = trans_key_to_states
self.__dict__["flat_transition_map"] = flat_transition_map
self.__dict__["_fsm_info"] = None
def copy(self):
return BetterFSM(
alphabet=self.alphabet.copy(),
states=self.states.copy(),
initial=self.initial,
finals=self.finals.copy(),
map=self.map.copy(),
__no_validation__=True,
)
@property
def fsm_info(self):
if self._fsm_info is None:
flat_transition_map_items = np.fromiter(
((a[0], a[1], b) for a, b in self.flat_transition_map.items()),
dtype=np.dtype("i8, i8, i8"),
)
trans_key_to_states_items = np.fromiter(
((k, z) for k, v in self.trans_key_to_states.items() for z in v),
dtype=np.dtype("i8, i8"),
)
alphabet_symbol_mapping_items = np.fromiter(
(
it
for it in self.alphabet._symbol_mapping.items()
if it[0] != anything_else
),
dtype=np.dtype("U1, i8"),
)
nb_finals = np.fromiter(self.finals, dtype=np.dtype("i8"))
self.__dict__["_fsm_info"] = create_fsm_info(
self.initial,
nb_finals,
flat_transition_map_items,
trans_key_to_states_items,
self.alphabet.anything_value,
alphabet_symbol_mapping_items,
)
return self._fsm_info
nb_int_list_type = numba.types.ListType(numba.int64)
nb_int_pair_type = numba.types.UniTuple(numba.int64, 2)
nb_unichar_1_type = numba.types.UnicodeCharSeq(1)
@numba.njit(cache=True)
def create_fsm_info(
py_initial,
py_finals,
flat_transition_map_items,
trans_key_to_states_items,
py_anything_value,
alphabet_symbol_mapping_items,
):
trans_key_to_states = numba.typed.Dict.empty(numba.int64, nb_int_list_type)
for trans_key_and_state in trans_key_to_states_items:
trans_key_to_states.setdefault(
trans_key_and_state[0], numba.typed.List.empty_list(numba.int64)
).append(trans_key_and_state[1])
flat_transition_map = numba.typed.Dict.empty(nb_int_pair_type, numba.int64)
for trans_key_and_state in flat_transition_map_items:
flat_transition_map[
(trans_key_and_state[0], trans_key_and_state[1])
] = trans_key_and_state[2]
alphabet_symbol_map = numba.typed.Dict.empty(nb_unichar_1_type, numba.int64)
for symbol_and_trans_key in alphabet_symbol_mapping_items:
alphabet_symbol_map[symbol_and_trans_key[0]] = symbol_and_trans_key[1]
initial = numba.int64(py_initial)
finals = set()
for final in py_finals:
finals.add(final)
anything_value = numba.int64(py_anything_value)
return FSMInfo(
initial,
finals,
flat_transition_map,
trans_key_to_states,
anything_value,
alphabet_symbol_map,
)
FSMInfo = namedtuple(
"FSMInfo",
[
"initial",
"finals",
"transitions",
"trans_key_to_states",
"alphabet_anything_value",
"alphabet_symbol_mapping",
],
)
def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]:
"""Construct an equivalent FSM with deterministic state labels."""
old_to_new_trans_keys = {
trans_key: i
for i, (trans_key, _) in enumerate(
sorted(fsm.alphabet.by_transition.items(), key=lambda x: sorted(x[1]))
)
}
new_symbol_mapping = {
symbol: old_to_new_trans_keys[trans_key]
for symbol, trans_key in fsm.alphabet._symbol_mapping.items()
}
new_alphabet = BetterAlphabet(new_symbol_mapping)
new_map = {
from_state: {
old_to_new_trans_keys[trans_key]: to_state
for trans_key, to_state in trans_map.items()
}
for from_state, trans_map in fsm.map.items()
}
old_to_new_states = {}
old_to_new_states[fsm.initial] = 0
i = 0
seen = {fsm.initial}
old_state_queue = [fsm.initial]
while old_state_queue:
old_state = old_state_queue.pop(-1)
transitions = new_map[old_state]
sorted_transitions = sorted(transitions.items(), key=lambda v: v[0])
for _, old_state in sorted_transitions:
if old_state not in seen:
old_state_queue.append(old_state)
seen.add(old_state)
if old_state not in old_to_new_states:
i += 1
old_to_new_states[old_state] = i
new_map = dict(
sorted(
(
(
old_to_new_states[from_state],
dict(
sorted(
(
(trans_key, old_to_new_states[to_state])
for trans_key, to_state in trans_map.items()
),
key=lambda v: v[0],
)
),
)
for from_state, trans_map in new_map.items()
),
key=lambda v: v[0],
)
)
new_initial = 0
new_finals = frozenset(
sorted(old_to_new_states[old_state] for old_state in fsm.finals)
)
new_states = frozenset(sorted(new_map.keys()))
new_fsm = BetterFSM(new_alphabet, new_states, new_initial, new_finals, new_map)
return new_fsm, old_to_new_states
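For example, building a deterministic, relabeled FSM from a regex with `interegular` (the pattern is arbitrary):

```python
import interegular

pattern = interegular.parse_pattern(r"ab*c")
det_fsm, old_to_new = make_deterministic_fsm(pattern.to_fsm().reduce())

assert det_fsm.initial == 0  # the initial state is always relabeled to 0
print(sorted(det_fsm.states), sorted(det_fsm.finals))
```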
@numba.njit(nogil=True, cache=True)
def _walk_fsm(
fsm_transitions: Dict[Tuple[int, int], int],
alphabet_symbol_mapping: Dict[str, int],
alphabet_anything_value: int,
fsm_initial: int,
fsm_finals: Set[int],
input_string: str,
start_state: int,
full_match: bool = True,
) -> List[int]:
state = start_state
accepted_states: List[int] = numba.typed.List.empty_list(numba.int64)
last_final_idx: int = numba.uint64(0)
for i, symbol in enumerate(input_string):
trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)
new_state = fsm_transitions.get((state, trans_key))
if new_state is None:
if not full_match and last_final_idx > 0:
return accepted_states[:last_final_idx]
return numba.typed.List.empty_list(numba.int64)
state = new_state
if state in fsm_finals:
last_final_idx = numba.uint64(i + 1)
accepted_states.append(_nonoptional(state))
if full_match and last_final_idx - 1 != i:
return numba.typed.List.empty_list(numba.int64)
return accepted_states
def walk_fsm(
fsm: BetterFSM,
input_string: str,
start_state: int,
full_match: bool = True,
) -> List[int]:
fsm_finals = fsm.finals
state = start_state
accepted_states: List[int] = []
last_final_idx: int = 0
alphabet_symbol_mapping = fsm.alphabet._symbol_mapping
alphabet_anything_value = fsm.alphabet.anything_value
fsm_transitions = fsm.flat_transition_map
for i, symbol in enumerate(input_string):
trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)
new_state = fsm_transitions.get((state, trans_key))
if new_state is None:
if not full_match and last_final_idx > 0:
return accepted_states[:last_final_idx]
return []
state = new_state
if state in fsm_finals:
last_final_idx = i + 1
accepted_states.append(state)
if full_match and last_final_idx - 1 != i:
return []
return accepted_states
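Continuing with an `interegular`-built FSM, the pure-Python walker returns the visited states for a match and an empty list on rejection; a small sketch:

```python
import interegular

det_fsm, _ = make_deterministic_fsm(interegular.parse_pattern(r"ab*c").to_fsm().reduce())
print(walk_fsm(det_fsm, "abbc", det_fsm.initial))  # four states, one per symbol
print(walk_fsm(det_fsm, "abx", det_fsm.initial))   # [] -> "abx" is rejected
```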
def fsm_union(
fsms: Sequence[FSM],
) -> Tuple[FSM, Dict[int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]]]:
"""Construct an FSM representing the union of the FSMs in `fsms`.
This is an updated version of `interegular.fsm.FSM.union` made to return an
extra map of component FSMs to the sets of state transitions that
correspond to them in the new FSM.
"""
alphabet, new_to_old = Alphabet.union(*[fsm.alphabet for fsm in fsms])
indexed_fsms = tuple(enumerate(fsms))
initial = {i: fsm.initial for (i, fsm) in indexed_fsms}
# Dedicated function accepting a "superset" and returning the next
# "superset" obtained by following this transition in the new FSM
def follow(current_state, new_transition: int):
next = {}
for i, f in indexed_fsms:
old_transition = new_to_old[i][new_transition]
if (
i in current_state
and current_state[i] in f.map
and old_transition in f.map[current_state[i]]
):
next[i] = f.map[current_state[i]][old_transition]
if not next:
raise OblivionError
return next
states = [initial]
finals: Set[int] = set()
map: Dict[int, Dict[int, int]] = {}
# Map component FSMs to their new state-to-state transitions, finals, and a
# map translating component FSM states to aggregate FSM states
fsms_to_trans_finals: Dict[
int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]
] = {}
i = 0
while i < len(states):
state = states[i]
# Add to the finals of the aggregate FSM whenever we hit a final in a
# component FSM
if any(state.get(j, -1) in fsm.finals for (j, fsm) in indexed_fsms):
finals.add(i)
# Compute the map for this state
map[i] = {}
for transition in alphabet.by_transition:
try:
next = follow(state, transition)
except OblivionError:
# Reached an oblivion state; don't list it
continue
else:
try:
# TODO: Seems like this could--and should--be avoided
j = states.index(next)
except ValueError:
j = len(states)
states.append(next)
map[i][transition] = j
for fsm_id, fsm_state in next.items():
(
fsm_transitions,
fsm_finals,
fsm_old_to_new,
) = fsms_to_trans_finals.setdefault(fsm_id, (set(), set(), {}))
old_from = state[fsm_id]
old_to = fsm_state
fsm_old_to_new.setdefault(old_from, set()).add(i)
fsm_old_to_new.setdefault(old_to, set()).add(j)
fsm_transitions.add((i, j))
if fsm_state in fsms[fsm_id].finals:
fsm_finals.add(j)
i += 1
fsm = FSM(
alphabet=alphabet,
states=range(len(states)),
initial=0,
finals=finals,
map=map,
__no_validation__=True,
)
fsm, old_to_new_states = make_deterministic_fsm(fsm)
_fsms_to_trans_finals = {
fsm_id: (
{(old_to_new_states[s1], old_to_new_states[s2]) for s1, s2 in transitions},
{old_to_new_states[s] for s in finals},
{
old_state: {old_to_new_states[new_state] for new_state in new_states}
for old_state, new_states in old_to_new.items()
},
)
for fsm_id, (transitions, finals, old_to_new) in sorted(
fsms_to_trans_finals.items(), key=lambda x: x[0]
)
}
return (
fsm,
_fsms_to_trans_finals,
)
def get_sub_fsms_from_seq(
state_seq: Sequence[int],
fsms_to_trans_finals: Dict[
int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]
],
) -> Generator[Tuple[int, bool, bool], None, None]:
"""Get the indices of the sub-FSMs in `fsm` that could have matched the state sequence `state_seq`.
Parameters
----------
state_seq
A state sequence.
fsms_to_trans_finals
A map from FSM indices to tuples containing sets of their state transitions
and sets of the final/accept states.
Returns
-------
A generator returning tuples containing each sub-FSM index (in the order
they were union-ed to construct `fsm`) and booleans indicating whether or
not there is another valid transition from the last state in the sequence
for the associated sub-FSM (i.e. if the FSM can continue
accepting/matching) and whether or not the sequence ends in a final state
of the sub-FSM.
"""
state_seq_transitions = set(zip(state_seq[:-1], state_seq[1:]))
last_fsm_state = state_seq[-1]
yield from (
(
# The sub-FSM index
fsm_idx,
# Is there another possible transition in this sub-FSM?
any(last_fsm_state == from_s for (from_s, to_s) in transitions),
# Is this sub-FSM in a final state?
state_seq[-1] in finals,
)
for fsm_idx, (transitions, finals, _) in fsms_to_trans_finals.items()
if state_seq_transitions.issubset(transitions)
)
@numba.njit(cache=True, nogil=True)
def state_scan_tokens(
fsm_transitions: Dict[Tuple[int, int], int],
alphabet_symbol_mapping: Dict[str, int],
alphabet_anything_value: int,
fsm_initial: int,
fsm_finals: Set[int],
vocabulary: Dict[str, List[int]],
start_state: int,
) -> Set[Tuple[int, int]]:
res = set()
for token, token_ids in vocabulary.items():
state_seq = _walk_fsm(
fsm_transitions,
alphabet_symbol_mapping,
alphabet_anything_value,
fsm_initial,
fsm_finals,
token,
start_state,
False,
)
if state_seq is not None and len(state_seq) < len(token):
continue
for token_id in token_ids:
res.add((token_id, state_seq[-1]))
return res
def create_fsm_index_end_to_end(
fsm_info: FSMInfo,
vocabulary: Dict[str, List[int]],
) -> Dict[int, Set[Tuple[int, int]]]:
"""Create an FSM state-to-vocabulary map/index through end-to-end token parsing."""
# TODO: Consider using a `List` of `Set`s instead; that way we can JIT this
# code, too.
states_to_token_subsets: Dict[int, Set[Tuple[int, int]]] = {}
seen: Set[int] = set()
next_states = {fsm_info.initial}
while next_states:
start_state = next_states.pop()
token_ids_end_states = state_scan_tokens(
fsm_info.transitions,
fsm_info.alphabet_symbol_mapping,
fsm_info.alphabet_anything_value,
fsm_info.initial,
fsm_info.finals,
vocabulary,
start_state,
)
for token_id_and_end_state in token_ids_end_states:
states_to_token_subsets.setdefault(start_state, set()).add(
token_id_and_end_state
)
end_state = token_id_and_end_state[1]
if end_state not in seen:
next_states.add(end_state)
seen.add(start_state)
return states_to_token_subsets
# TODO: Cannot cache typed collections to disk, yet. See
# https://github.com/numba/numba/issues/4698
@lru_cache
def reduced_vocabulary(tokenizer: "Tokenizer"):
"""Create a map from decoded vocabulary tokens to lists of equivalent token ids."""
vocabulary = numba.typed.Dict.empty(
numba.types.string, numba.types.ListType(numba.int64)
)
empty_token_ids = set()
for token, token_idx in tokenizer.vocabulary.items():
if token in tokenizer.special_tokens:
continue
token_str = tokenizer.convert_token_to_string(token)
if token_str:
vocabulary.setdefault(
token_str,
numba.typed.List.empty_list(numba.int64),
).append(numba.int64(token_idx))
else:
empty_token_ids.add(numba.int64(token_idx))
return vocabulary, empty_token_ids
def create_fsm_index_tokenizer(
fsm: BetterFSM,
tokenizer: "Tokenizer",
) -> Tuple[Dict[int, Dict[int, int]], Set[int]]:
"""Construct an FMS index from a tokenizer.
This uses the end-to-end approach of `create_fsm_index_end_to_end`.
.. warning::
`fsm` needs to be deterministically ordered so that future caching makes sense.
"""
vocabulary, empty_token_ids = reduced_vocabulary(tokenizer)
states_to_token_subsets = create_fsm_index_end_to_end(fsm.fsm_info, vocabulary)
# Allow transitions to EOS from all terminal FSM states that are
# reachable
# TODO: Do we really need this anymore?
for state in fsm.fsm_info.finals:
subset = states_to_token_subsets.get(state)
if subset is not None:
subset.add((tokenizer.eos_token_id, state))
# Convert to token-to-end-state maps
states_to_token_subsets = {k: dict(v) for k, v in states_to_token_subsets.items()}
return states_to_token_subsets, empty_token_ids
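End to end, index construction looks roughly like this; `TransformerTokenizer` comes from `sglang.srt.constrained.tokenizer` (listed below), and hashing it for the `lru_cache` requires the `datasets` package:

```python
import interegular
from transformers import AutoTokenizer

from sglang.srt.constrained.tokenizer import TransformerTokenizer

tok = TransformerTokenizer(AutoTokenizer.from_pretrained("gpt2"))
pattern = interegular.parse_pattern(r"[0-9]{1,3}")
fsm, _ = make_deterministic_fsm(pattern.to_fsm().reduce())

index, empty_ids = create_fsm_index_tokenizer(fsm, tok)
# index[state][token_id] -> next state; decoding walks this map directly.
print(len(index[0]))  # number of token ids that can start a 1-3 digit string
```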

python/sglang/srt/constrained/tokenizer.py Normal file

@@ -0,0 +1,266 @@
# Adapted from:
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/tokenizer.py
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/transformers.py
from abc import abstractmethod
from typing import (
TYPE_CHECKING,
Dict,
Hashable,
List,
Optional,
Protocol,
Set,
Tuple,
Union,
)
import numpy as np
import torch
from numpy.typing import NDArray
class Tokenizer(Protocol, Hashable):
eos_token: str
eos_token_id: int
pad_token_id: int
vocabulary: Dict[str, int]
special_tokens: Set[int]
@abstractmethod
def encode(
self, prompt: Union[str, List[str]]
) -> Tuple[NDArray[np.int64], NDArray[np.int64]]:
"""Translate the input prompts into NumPy arrays of token ids and attention mask."""
...
@abstractmethod
def decode(self, token_ids: NDArray[np.int64]) -> List[str]:
"""Translate an array of token ids to a string or list of strings."""
...
@abstractmethod
def convert_token_to_string(self, token: str) -> str:
"""Convert a token to its equivalent string.
This is for instance useful for BPE tokenizers where whitespaces are
represented by the special character `Ġ`. This prevents matching a raw
token that includes `Ġ` with a string.
"""
...
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
__all__ = ["transformers"]
KVCacheType = Tuple[Tuple[torch.DoubleTensor, torch.DoubleTensor], ...]
def get_llama_tokenizer_types():
"""Get all the Llama tokenizer types/classes that need work-arounds.
When they can't be imported, a dummy class is created.
"""
try:
from transformers.models.llama import LlamaTokenizer
except ImportError:
class LlamaTokenizer: # type: ignore
pass
try:
from transformers.models.llama import LlamaTokenizerFast
except ImportError:
class LlamaTokenizerFast: # type: ignore
pass
try:
from transformers.models.code_llama import CodeLlamaTokenizer
except ImportError:
class CodeLlamaTokenizer: # type: ignore
pass
try:
from transformers.models.code_llama import CodeLlamaTokenizerFast
except ImportError:
class CodeLlamaTokenizerFast: # type: ignore
pass
return (
LlamaTokenizer,
LlamaTokenizerFast,
CodeLlamaTokenizer,
CodeLlamaTokenizerFast,
)
class Transformer:
"""Represents a `transformers` model."""
def __init__(
self,
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
):
self.device = model.device
self.model = model
self.tokenizer = tokenizer
@torch.inference_mode()
def forward(
self,
input_ids: torch.LongTensor,
attention_mask: torch.LongTensor,
past_key_values: Optional[Tuple] = None,
) -> Tuple[torch.FloatTensor, Optional[KVCacheType]]:
"""Compute a forward pass through the transformer model.
Parameters
----------
input_ids
The input token ids. Must be one or two dimensional.
attention_mask
The attention mask. Must be one or two dimensional.
past_key_values
A tuple of tuples containing the cached key and value tensors for each
attention head.
Returns
-------
The computed logits and the new cached key and value tensors.
"""
assert 0 < input_ids.ndim < 3
if past_key_values:
input_ids = input_ids[..., -1].unsqueeze(-1)
output = self.model(
input_ids,
attention_mask=attention_mask,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
past_key_values=past_key_values,
)
return output.logits, output.past_key_values
def __call__(
self,
input_ids: torch.LongTensor,
attention_mask: torch.LongTensor,
past_key_values: Optional[Tuple] = None,
) -> torch.FloatTensor:
logits, kv_cache = self.forward(input_ids, attention_mask, past_key_values)
next_token_logits = logits[..., -1, :]
return next_token_logits, kv_cache
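A short greedy-decoding sketch against this wrapper, reusing the KV cache across steps (the model choice is illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

hf_tok = AutoTokenizer.from_pretrained("gpt2")
model = Transformer(AutoModelForCausalLM.from_pretrained("gpt2"), hf_tok)

input_ids = torch.tensor([hf_tok.encode("The sky is")])
kv_cache = None
for _ in range(5):
    attention_mask = torch.ones_like(input_ids)
    logits, kv_cache = model(input_ids, attention_mask, kv_cache)
    next_id = logits.argmax(dim=-1, keepdim=True)  # greedy pick, shape [1, 1]
    input_ids = torch.cat([input_ids, next_id], dim=-1)
print(hf_tok.decode(input_ids[0]))
```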
class TransformerTokenizer(Tokenizer):
"""Represents a tokenizer for models in the `transformers` library."""
def __init__(self, tokenizer):
# TODO: Do something to make this hashable?
self.tokenizer = tokenizer
self.eos_token_id = self.tokenizer.eos_token_id
self.eos_token = self.tokenizer.eos_token
if not self.tokenizer.pad_token_id:
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
self.pad_token_id = self.eos_token_id
else:
self.pad_token_id = self.tokenizer.pad_token_id
self.pad_token = self.tokenizer.pad_token
self.special_tokens = set(self.tokenizer.all_special_tokens)
self.vocabulary = self.tokenizer.get_vocab()
self.is_llama = isinstance(self.tokenizer, get_llama_tokenizer_types())
def encode(
self, prompt: Union[str, List[str]], **kwargs
) -> Tuple[torch.LongTensor, torch.LongTensor]:
kwargs["padding"] = True
kwargs["return_tensors"] = "pt"
output = self.tokenizer(prompt, **kwargs)
return output["input_ids"], output["attention_mask"]
def decode(self, token_ids: torch.LongTensor) -> List[str]:
text = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
return text
def convert_token_to_string(self, token: str) -> str:
from transformers.file_utils import SPIECE_UNDERLINE
string = self.tokenizer.convert_tokens_to_string([token])
if self.is_llama:
# A hack to handle missing spaces in HF's Llama tokenizers
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
return " " + string
return string
def __eq__(self, other):
if isinstance(other, type(self)):
return False
# TODO(lsyin): is the lru_cache for the TransformerTokenizer useless?
# return other.model_name == self.model_name and other.kwargs == self.kwargs
return NotImplemented
def __hash__(self):
from datasets.fingerprint import Hasher
return hash(Hasher.hash(self.tokenizer))
def transformers(
model_name: str,
device: Optional[str] = None,
model_kwargs: dict = {},
tokenizer_kwargs: dict = {},
):
"""Instantiate a model from the `transformers` library and its tokenizer.
Parameters
----------
model_name
The name of the model as listed on Hugging Face's model page.
device
The device(s) on which the model should be loaded. This overrides
the `device_map` entry in `model_kwargs` when provided.
model_kwargs
A dictionary that contains the keyword arguments to pass to the
`from_pretrained` method when loading the model.
tokenizer_kwargs
A dictionary that contains the keyword arguments to pass to the
`from_pretrained` method when loading the tokenizer.
Returns
-------
A `Transformer` model instance.
"""
try:
from transformers import AutoModelForCausalLM, AutoTokenizer
except ImportError:
raise ImportError(
"The `transformers` library needs to be installed in order to use `transformers` models."
)
if device is not None:
model_kwargs["device_map"] = device
model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
tokenizer = TransformerTokenizer(model_name, **tokenizer_kwargs)
return Transformer(model, tokenizer)
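# Example usage (model name illustrative; assumes `Transformer` exposes the
# wrapped tokenizer as `.tokenizer`):
#
#   model = transformers("gpt2", device="cuda")
#   input_ids, attention_mask = model.tokenizer.encode("Hello")
#   next_token_logits, kv_cache = model(input_ids, attention_mask)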

View File

@@ -0,0 +1,164 @@
"""Utilities for Huggingface Transformers."""
import json
import os
import warnings
from typing import List, Optional, Tuple, Union
from huggingface_hub import snapshot_download
from sglang.srt.utils import is_multimodal_model
from transformers import (
AutoConfig,
AutoProcessor,
AutoTokenizer,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
)
def download_from_hf(model_path: str):
if os.path.exists(model_path):
return model_path
return snapshot_download(model_path, allow_patterns=["*.json", "*.bin", "*.model"])
def get_config_json(model_path: str):
with open(os.path.join(model_path, "config.json")) as f:
config = json.load(f)
return config
def get_config(model: str, trust_remote_code: bool, revision: Optional[str] = None):
config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code, revision=revision
)
return config
# Models don't use the same configuration key for determining the maximum
# context length. Store them here so we can sanely check them.
# NOTE: The ordering here is important. Some models have two of these and we
# have a preference for which value gets used.
CONTEXT_LENGTH_KEYS = [
"max_sequence_length",
"seq_length",
"max_position_embeddings",
"max_seq_len",
"model_max_length",
]
def get_context_length(config):
"""Get the context length of a model from a huggingface model config."""
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling:
rope_scaling_factor = config.rope_scaling["factor"]
else:
rope_scaling_factor = 1
for key in CONTEXT_LENGTH_KEYS:
val = getattr(config, key, None)
if val is not None:
return int(rope_scaling_factor * val)
return 2048
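# Worked example: a config with `max_position_embeddings = 4096` and
# `rope_scaling = {"type": "linear", "factor": 2.0}` yields a context length
# of int(2.0 * 4096) = 8192; if no key matches, the function falls back to 2048.
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(rope_scaling={"factor": 2.0}, max_position_embeddings=4096)
#   assert get_context_length(cfg) == 8192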
# A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file.
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"
def get_tokenizer(
tokenizer_name: str,
*args,
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
tokenizer_revision: Optional[str] = None,
**kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
"""Gets a tokenizer for the given model name via Huggingface."""
if is_multimodal_model(tokenizer_name):
processor = get_processor(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
tokenizer_revision=tokenizer_revision,
**kwargs,
)
tokenizer = processor.tokenizer
return tokenizer
if tokenizer_mode == "slow":
if kwargs.get("use_fast", False):
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
kwargs["use_fast"] = False
if (
"llama" in tokenizer_name.lower()
and kwargs.get("use_fast", True)
and tokenizer_name != _FAST_LLAMA_TOKENIZER
):
pass
# warnings.warn(
# "For some LLaMA V1 models, initializing the fast tokenizer may "
# "take a long time. To reduce the initialization time, consider "
# f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
# "tokenizer."
# )
try:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
tokenizer_revision=tokenizer_revision,
**kwargs,
)
except TypeError as e:
# The LLaMA tokenizer causes a protobuf error in some environments.
err_msg = (
"Failed to load the tokenizer. If you are using a LLaMA V1 model "
f"consider using '{_FAST_LLAMA_TOKENIZER}' instead of the "
"original tokenizer."
)
raise RuntimeError(err_msg) from e
except ValueError as e:
# If the error pertains to the tokenizer class not existing or not
# currently being imported, suggest using the --trust-remote-code flag.
if not trust_remote_code and (
"does not exist or is not currently imported." in str(e)
or "requires you to execute the tokenizer file" in str(e)
):
err_msg = (
"Failed to load the tokenizer. If the tokenizer is a custom "
"tokenizer not yet available in the HuggingFace transformers "
"library, consider setting `trust_remote_code=True` in LLM "
"or using the `--trust-remote-code` flag in the CLI."
)
raise RuntimeError(err_msg) from e
else:
raise e
if not isinstance(tokenizer, PreTrainedTokenizerFast):
warnings.warn(
"Using a slow tokenizer. This might cause a significant "
"slowdown. Consider using a fast tokenizer instead."
)
return tokenizer
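# Typical usage (model name illustrative):
#
#   tokenizer = get_tokenizer("meta-llama/Llama-2-7b-chat-hf")
#   input_ids = tokenizer("Hello world").input_ids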
def get_processor(
tokenizer_name: str,
*args,
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
tokenizer_revision: Optional[str] = None,
**kwargs,
):
processor = AutoProcessor.from_pretrained(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
tokenizer_revision=tokenizer_revision,
**kwargs,
)
return processor

View File

@@ -0,0 +1,181 @@
# Adapted from
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
import torch
import triton
import triton.language as tl
from sglang.srt.utils import wrap_kernel_launcher
@triton.jit
def _fwd_kernel(
Q,
K,
V,
sm_scale,
B_Start_Loc,
B_Seqlen,
Out,
stride_qbs,
stride_qh,
stride_kbs,
stride_kh,
stride_vbs,
stride_vh,
stride_obs,
stride_oh,
kv_group_num: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
cur_kv_head = cur_head // kv_group_num
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
block_start_loc = BLOCK_M * start_m
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
+ cur_head * stride_qh
+ offs_d[None, :]
)
off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]
off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :]
q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
k_ptrs = K + off_k
v_ptrs = V + off_v
    # initialize the running softmax statistics m_i (row max) and l_i (denominator)
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
other=0.0,
)
# mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
other=0.0,
)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
+ cur_head * stride_oh
+ offs_d[None, :]
)
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
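# Note: the loop in `_fwd_kernel` keeps flash-attention-style running
# statistics: `m_i` holds the per-row running maximum of the attention scores
# and `l_i` the running softmax denominator, so `acc` can be rescaled block by
# block without materializing the full attention matrix.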
cached_kernel = None
def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
BLOCK = 128
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
sm_scale = 1.0 / (Lq**0.5)
batch, head = b_seq_len.shape[0], q.shape[1]
kv_group_num = q.shape[1] // k.shape[1]
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
num_warps = 4 if Lk <= 64 else 8
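    # Reuse the previously compiled launcher, if one exists, to skip Triton's
    # kernel-dispatch overhead on subsequent calls.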
global cached_kernel
if cached_kernel:
cached_kernel(
grid,
num_warps,
q,
k,
v,
sm_scale,
b_start_loc,
b_seq_len,
o,
q.stride(0),
q.stride(1),
k.stride(0),
k.stride(1),
v.stride(0),
v.stride(1),
o.stride(0),
o.stride(1),
)
return
_fwd_kernel[grid](
q,
k,
v,
sm_scale,
b_start_loc,
b_seq_len,
o,
q.stride(0),
q.stride(1),
k.stride(0),
k.stride(1),
v.stride(0),
v.stride(1),
o.stride(0),
o.stride(1),
kv_group_num=kv_group_num,
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
cached_kernel = wrap_kernel_launcher(_fwd_kernel)
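# A shape-check sketch (illustrative; requires a CUDA device). Inputs are
# packed without padding: q/k/v have shape [total_tokens, num_heads, head_dim],
# and `b_start_loc` / `b_seq_len` give each sequence's offset and length within
# the packed token dimension.
#
#   import torch
#   H, D, L = 8, 128, 64
#   q = torch.randn(2 * L, H, D, device="cuda", dtype=torch.float16)
#   k, v, o = q.clone(), q.clone(), torch.empty_like(q)
#   b_start_loc = torch.tensor([0, L], device="cuda", dtype=torch.int32)
#   b_seq_len = torch.tensor([L, L], device="cuda", dtype=torch.int32)
#   context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len=L)
#   # `o` now holds the causal self-attention output for both sequences.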

Some files were not shown because too many files have changed in this diff.