release initial code
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
Co-authored-by: parasol-aser <3848358+parasol-aser@users.noreply.github.com>
Co-authored-by: LiviaSun <33578456+ChuyueSun@users.noreply.github.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
.gitignore (vendored): 21 lines changed
@@ -157,4 +157,23 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/

# MacOS
.DS_Store
*.json

# Vim
*.swp

# SGL
benchmark/mmlu/data
benchmark/mmlu/data.tar
benchmark/llava_bench/images
benchmark/llava_bench/mme_pack
*.jsonl
tmp*.txt

# Plots
*.png
*.pdf
.gitmodules (vendored, new file): 3 lines
@@ -0,0 +1,3 @@
[submodule "3rdparty/flashinfer"]
	path = 3rdparty/flashinfer
	url = git@github.com:flashinfer-ai/flashinfer.git
README.md: 168 lines changed
@@ -1 +1,167 @@
# SGLang

SGLang is a structured generation language designed for large language models (LLMs).
It makes your interaction with LLMs faster and more controllable by co-designing the frontend language and the runtime system.

The core features of SGLang include:
- **A Flexible Front-End Language**: This allows for easy programming of LLM applications with multiple chained generation calls, advanced prompting techniques, control flow, multiple modalities, parallelism, and external interaction.
- **A High-Performance Runtime with RadixAttention**: This feature significantly accelerates the execution of complex LLM programs through automatic KV cache reuse across multiple calls. It also supports other common techniques like continuous batching and tensor parallelism.
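The prefix-reuse idea behind RadixAttention can be pictured with a toy prefix cache. This is an illustration only, not the actual implementation (the real runtime stores KV tensors in a radix tree): when two program calls share a prompt prefix, only the new suffix needs fresh computation.

```python
# Toy illustration of the prefix-reuse idea behind RadixAttention.
# Illustration only; not the actual SGLang implementation.

cache = {}  # maps a token prefix -> "already computed"

def longest_cached_prefix(tokens: tuple) -> tuple:
    """Return the longest prefix of `tokens` already in the cache."""
    for end in range(len(tokens), 0, -1):
        if tokens[:end] in cache:
            return tokens[:end]
    return ()

def run(tokens: tuple) -> int:
    """Return how many tokens need new 'KV computation'."""
    hit = longest_cached_prefix(tokens)
    new_work = len(tokens) - len(hit)
    # Record all prefixes of this call so later calls can reuse them.
    for end in range(len(hit) + 1, len(tokens) + 1):
        cache[tokens[:end]] = True
    return new_work

first = run(("You", "are", "a", "helpful", "assistant", "Q1"))   # computes all 6
second = run(("You", "are", "a", "helpful", "assistant", "Q2"))  # reuses the shared 5-token prefix
```

Here `first` is 6 (nothing cached yet) while `second` is 1, because the five-token system-prompt prefix is reused.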
## Contents
- [Install](#install)
- [Quick Start](#quick-start)
- [Frontend: Structured Generation Language (SGLang)](#frontend-structured-generation-language-sglang)
- [Backend: SGLang Runtime (SRT)](#backend-sglang-runtime-srt)
- [Benchmark And Performance](#benchmark-and-performance)
- [Roadmap](#roadmap)
- [Citation And Acknowledgment](#citation-and-acknowledgment)

## Install

### Method 1: With Pip
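A minimal sketch, assuming the package is published on PyPI under the name `sglang` with an `[all]` extra (check the project's releases for the exact package name):

```
pip install "sglang[all]"
```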
### Method 2: From Source
```
git clone git@github.com:sgl-project/sglang.git
cd sglang

pip install --upgrade pip
pip install -e "python[all]"
```

## Quick Start
The example below shows how to use SGLang to answer a multi-turn question.

### Using OpenAI Models
```python
from sglang import function, system, user, assistant, gen, set_default_backend, OpenAI

@function
def multi_turn_question(s, question_1, question_2):
    s += system("You are a helpful assistant.")
    s += user(question_1)
    s += assistant(gen("answer_1", max_tokens=256))
    s += user(question_2)
    s += assistant(gen("answer_2", max_tokens=256))

set_default_backend(OpenAI("gpt-3.5-turbo"))

state = multi_turn_question.run(
    question_1="What is the capital of the United States?",
    question_2="List two local attractions.",
)

for m in state.messages():
    print(m["role"], ":", m["content"])
```
### Using Local Models
First, launch a server with
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```

Then, connect to the server and answer a multi-turn question.

```python
from sglang import function, system, user, assistant, gen, set_default_backend, RuntimeEndpoint

@function
def multi_turn_question(s, question_1, question_2):
    s += system("You are a helpful assistant.")
    s += user(question_1)
    s += assistant(gen("answer_1", max_tokens=256))
    s += user(question_2)
    s += assistant(gen("answer_2", max_tokens=256))

set_default_backend(RuntimeEndpoint("http://localhost:30000"))

state = multi_turn_question.run(
    question_1="What is the capital of the United States?",
    question_2="List two local attractions.",
)

for m in state.messages():
    print(m["role"], ":", m["content"])
```

### More Examples

You can find more examples at [examples/quick_start](examples/quick_start).
## Frontend: Structured Generation Language (SGLang)

### Control Flow
### Parallelism
### Multi Modality
```python
@sgl.function
def multi_turn_question(s, image_file, question):
    s += sgl.user(sgl.image(image_file) + question)
    s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))
```

### Batching
### Streaming
### Other Backends

## Backend: SGLang Runtime (SRT)
The SGLang Runtime (SRT) is designed to work best with the SGLang frontend.
However, it can also be used as a standalone API server.
In this case, RadixAttention can still greatly accelerate many use cases.

### Usage
Launch a server
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```

Send a request
```
curl http://localhost:30000/v1/completions \
  -H "Content-Type: application/json" \
  -d '{
    "prompt": "Say this is a test",
    "max_tokens": 16,
    "temperature": 0
  }'
```
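The same request can be sent from Python. A minimal sketch using the `requests` library (an assumption, not part of this commit; it presumes the server above is running on port 30000):

```python
def completion_request(prompt, max_tokens=16, temperature=0):
    """Build the JSON payload for the /v1/completions endpoint,
    mirroring the curl example above."""
    return {"prompt": prompt, "max_tokens": max_tokens, "temperature": temperature}

def send(payload, base_url="http://localhost:30000"):
    import requests  # assumes the `requests` package is installed
    resp = requests.post(f"{base_url}/v1/completions", json=payload)
    return resp.json()

payload = completion_request("Say this is a test")
# print(send(payload))  # requires the server above to be running
```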
### Additional Arguments
- Add `--tp 2` to enable tensor parallelism.
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --tp 2
```

### Supported Models
- Llama
- Mistral
- Mixtral
- LLaVA

## Benchmark And Performance

## Roadmap
- [ ] Function call
- [ ] Constrained decoding
- [ ] Quantization
- [ ] S-LoRA
- [ ] More models

## Citation And Acknowledgment
```
@misc{zheng2023efficiently,
      title={Efficiently Programming Large Language Models using SGLang},
      author={Lianmin Zheng and Liangsheng Yin and Zhiqiang Xie and Jeff Huang and Chuyue Sun and Cody Hao Yu and Shiyi Cao and Christos Kozyrakis and Ion Stoica and Joseph E. Gonzalez and Clark Barrett and Ying Sheng},
      year={2023},
      eprint={2312.07104},
      archivePrefix={arXiv},
      primaryClass={cs.AI}
}
```

We learned from the design and reused some code of the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [LMQL](https://github.com/eth-sri/lmql).
benchmark/dspy/README.md (new file): 45 lines
@@ -0,0 +1,45 @@
## Install

```
pip3 install dspy-ai
```

Turn off the cache at https://github.com/stanfordnlp/dspy/blob/34d8420383ec752037aa271825c1d3bf391e1277/dsp/modules/cache_utils.py#L10.
```
cache_turn_on = False
```

## Benchmark SGLang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```

```
python3 bench_dspy_intro.py --backend sglang
```

## Benchmark TGI
```
docker run --name tgi --rm -ti --gpus all --network host \
  -v /home/ubuntu/model_weights/Llama-2-7b-chat-hf:/Llama-2-7b-chat-hf \
  ghcr.io/huggingface/text-generation-inference:1.1.0 \
  --model-id /Llama-2-7b-chat-hf --num-shard 1 --trust-remote-code \
  --max-input-length 2048 --max-total-tokens 4096 \
  --port 24000
```

```
python3 bench_dspy_intro.py --backend tgi
```

## Benchmark vLLM
```
python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```

```
python3 bench_dspy_intro.py --backend vllm
```
benchmark/dspy/bench_dspy_intro.py (new file): 165 lines
@@ -0,0 +1,165 @@
"""
Adapted from
https://github.com/stanfordnlp/dspy/blob/34d8420383ec752037aa271825c1d3bf391e1277/intro.ipynb#L9
"""
import argparse

import dspy
from dspy.datasets import HotPotQA


class BasicQA(dspy.Signature):
    """Answer questions with short factoid answers."""

    question = dspy.InputField()
    answer = dspy.OutputField(desc="often between 1 and 5 words")


class GenerateAnswer(dspy.Signature):
    """Answer questions with short factoid answers."""

    context = dspy.InputField(desc="may contain relevant facts")
    question = dspy.InputField()
    answer = dspy.OutputField(desc="often between 1 and 5 words")


class RAG(dspy.Module):
    def __init__(self, num_passages=3):
        super().__init__()

        self.retrieve = dspy.Retrieve(k=num_passages)
        self.generate_answer = dspy.ChainOfThought(GenerateAnswer)

    def forward(self, question):
        context = self.retrieve(question).passages
        prediction = self.generate_answer(context=context, question=question)
        return dspy.Prediction(context=context, answer=prediction.answer)


def main(args):
    # lm = dspy.OpenAI(model='gpt-3.5-turbo')
    if args.backend == "tgi":
        lm = dspy.HFClientTGI(model="meta-llama/Llama-2-7b-chat-hf", port=args.port,
                              url="http://localhost")
    elif args.backend == "sglang":
        lm = dspy.HFClientSGLang(model="meta-llama/Llama-2-7b-chat-hf", port=args.port,
                                 url="http://localhost")
    elif args.backend == "vllm":
        lm = dspy.HFClientVLLM(model="meta-llama/Llama-2-7b-chat-hf", port=args.port,
                               url="http://localhost")
    else:
        raise ValueError(f"Invalid backend: {args.backend}")

    colbertv2_wiki17_abstracts = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts')
    dspy.settings.configure(lm=lm, rm=colbertv2_wiki17_abstracts)

    # Load the dataset.
    dataset = HotPotQA(train_seed=1, train_size=20, eval_seed=2023, dev_size=args.dev_size,
                       test_size=0)

    # Tell DSPy that the 'question' field is the input. Any other fields are labels and/or metadata.
    trainset = [x.with_inputs('question') for x in dataset.train]
    devset = [x.with_inputs('question') for x in dataset.dev]

    print(len(trainset), len(devset))

    train_example = trainset[0]
    print(f"Question: {train_example.question}")
    print(f"Answer: {train_example.answer}")

    dev_example = devset[18]
    print(f"Question: {dev_example.question}")
    print(f"Answer: {dev_example.answer}")
    print(f"Relevant Wikipedia Titles: {dev_example.gold_titles}")

    print(f"For this dataset, training examples have input keys {train_example.inputs().keys()} and label keys {train_example.labels().keys()}")
    print(f"For this dataset, dev examples have input keys {dev_example.inputs().keys()} and label keys {dev_example.labels().keys()}")

    # Define the predictor.
    generate_answer = dspy.Predict(BasicQA)

    # Call the predictor on a particular input.
    pred = generate_answer(question=dev_example.question)

    # Print the input and the prediction.
    print(f"Question: {dev_example.question}")
    print(f"Predicted Answer: {pred.answer}")

    lm.inspect_history(n=1)

    # Define the predictor. Notice we're just changing the class. The signature BasicQA is unchanged.
    generate_answer_with_chain_of_thought = dspy.ChainOfThought(BasicQA)

    # Call the predictor on the same input.
    pred = generate_answer_with_chain_of_thought(question=dev_example.question)

    # Print the input, the chain of thought, and the prediction.
    print(f"Question: {dev_example.question}")
    print(f"Thought: {pred.rationale.split('.', 1)[1].strip()}")
    print(f"Predicted Answer: {pred.answer}")

    retrieve = dspy.Retrieve(k=3)
    topK_passages = retrieve(dev_example.question).passages

    print(f"Top {retrieve.k} passages for question: {dev_example.question} \n", '-' * 30, '\n')

    for idx, passage in enumerate(topK_passages):
        print(f'{idx+1}]', passage, '\n')

    retrieve("When was the first FIFA World Cup held?").passages[0]

    from dspy.teleprompt import BootstrapFewShot

    # Validation logic: check that the predicted answer is correct.
    # Also check that the retrieved context does actually contain that answer.
    def validate_context_and_answer(example, pred, trace=None):
        answer_EM = dspy.evaluate.answer_exact_match(example, pred)
        answer_PM = dspy.evaluate.answer_passage_match(example, pred)
        return answer_EM and answer_PM

    # Set up a basic teleprompter, which will compile our RAG program.
    teleprompter = BootstrapFewShot(metric=validate_context_and_answer)

    # Compile!
    compiled_rag = teleprompter.compile(RAG(), trainset=trainset)

    # Ask any question you like to this simple RAG program.
    my_question = "What castle did David Gregory inherit?"

    # Get the prediction. This contains `pred.context` and `pred.answer`.
    pred = compiled_rag(my_question)

    # Print the contexts and the answer.
    print(f"Question: {my_question}")
    print(f"Predicted Answer: {pred.answer}")
    print(f"Retrieved Contexts (truncated): {[c[:200] + '...' for c in pred.context]}")

    from dspy.evaluate.evaluate import Evaluate

    # Set up the `evaluate_on_hotpotqa` function. We'll use this many times below.
    evaluate_on_hotpotqa = Evaluate(devset=devset, num_threads=args.num_threads, display_progress=True, display_table=5)

    # Evaluate the `compiled_rag` program with the `answer_exact_match` metric.
    metric = dspy.evaluate.answer_exact_match
    evaluate_on_hotpotqa(compiled_rag, metric=metric)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int)
    parser.add_argument("--num-threads", type=int, default=32)
    parser.add_argument("--dev-size", type=int, default=150)
    parser.add_argument("--backend", type=str, choices=["sglang", "tgi", "vllm"],
                        default="sglang")
    args = parser.parse_args()

    if args.port is None:
        default_port = {
            "vllm": 21000,
            "lightllm": 22000,
            "tgi": 24000,
            "sglang": 30000,
        }
        args.port = default_port.get(args.backend, None)

    main(args)
benchmark/generative_agents/README.md (new file): 26 lines
@@ -0,0 +1,26 @@
## Run benchmark

Ensure that this benchmark is run serially (using `--parallel 1`) to preserve any potential dependencies between requests.

### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```

```
python3 bench_sglang.py --num-events 1000 --parallel 1
```

### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```

```
python3 bench_other.py --num-events 1000 --backend vllm --parallel 1
```

### Benchmark guidance
```
python3 bench_other.py --num-events 1000 --backend guidance --parallel 1
```
benchmark/generative_agents/agent_functions.py (new file): 231 lines
@@ -0,0 +1,231 @@
import sglang as sgl

# Here are the top five agent functions, contributing ~70% of LLM calls.
# Reference: https://github.com/joonspk-research/generative_agents/


@sgl.function
def poignancy_event(s, persona_name, persona_iss, event):
    s += "Here is a brief description of " + persona_name + ".\n"
    s += persona_iss + "\n"
    s += "On the scale of 1 to 10, where 1 is purely mundane (e.g., brushing teeth, making bed) and 10 is extremely poignant (e.g., a break up, college acceptance), rate the likely poignancy of the following event for "
    s += persona_name + ".\n\n"
    s += "Event: " + event
    s += "Rate (return a number between 1 to 10):"
    s += sgl.gen(name="Rate", max_tokens=2)


def poignancy_event_prompt(persona_name, persona_iss, event):
    # Return the prompt and max_tokens.
    s = ""
    s += "Here is a brief description of " + persona_name + ".\n"
    s += persona_iss + "\n"
    s += "On the scale of 1 to 10, where 1 is purely mundane (e.g., brushing teeth, making bed) and 10 is extremely poignant (e.g., a break up, college acceptance), rate the likely poignancy of the following event for "
    s += persona_name + ".\n\n"
    s += "Event: " + event
    s += "Rate (return a number between 1 to 10):"
    return {"prompt": s, "max_tokens": 2, "stop": None}


@sgl.function
def generate_event_triple(s, persona_name, action):
    s += """Task: Turn the input into (subject, predicate, object).
Input: Sam Johnson is eating breakfast.
Output: (Dolores Murphy, eat, breakfast)
---
Input: Joon Park is brewing coffee.
Output: (Joon Park, brew, coffee)
---
Input: Jane Cook is sleeping.
Output: (Jane Cook, is, sleep)
---
Input: Michael Bernstein is writing email on a computer.
Output: (Michael Bernstein, write, email)
---
Input: Percy Liang is teaching students in a classroom.
Output: (Percy Liang, teach, students)
---
Input: Merrie Morris is running on a treadmill.
Output: (Merrie Morris, run, treadmill)
---"""
    s += persona_name + " is " + action + ".\n"
    s += "(" + persona_name + ","
    s += sgl.gen(name="Triple", max_tokens=20, stop=")")


def generate_event_triple_prompt(persona_name, action):
    s = ""
    s += """Task: Turn the input into (subject, predicate, object).
Input: Sam Johnson is eating breakfast.
Output: (Dolores Murphy, eat, breakfast)
---
Input: Joon Park is brewing coffee.
Output: (Joon Park, brew, coffee)
---
Input: Jane Cook is sleeping.
Output: (Jane Cook, is, sleep)
---
Input: Michael Bernstein is writing email on a computer.
Output: (Michael Bernstein, write, email)
---
Input: Percy Liang is teaching students in a classroom.
Output: (Percy Liang, teach, students)
---
Input: Merrie Morris is running on a treadmill.
Output: (Merrie Morris, run, treadmill)
---"""
    s += persona_name + " is " + action + ".\n"
    s += "(" + persona_name + ","
    return {"prompt": s, "max_tokens": 20, "stop": ")"}


@sgl.function
def generate_pronunciatio(s, action):
    s += "Convert an action description to an emoji (important: use two or less emojis).\n"
    s += "Action description: " + action + ".\n"
    s += "Emoji:" + sgl.gen(name="Emoji", max_tokens=6)


def generate_pronunciatio_prompt(action):
    s = ""
    s += "Convert an action description to an emoji (important: use two or less emojis).\n"
    s += "Action description: " + action + ".\n"
    s += "Emoji:"
    return {"prompt": s, "max_tokens": 6, "stop": None}


@sgl.function
def action_location_sector(
    s,
    persona_name,
    living_sector,
    living_sector_areas,
    current_sector,
    current_sector_areas,
    daily_plan,
    sector_options,
    current_action,
    next_action,
):
    s += """Task -- choose an appropriate area from the area options for a task at hand.
Sam Kim lives in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen.
Sam Kim is currently in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen.
Area options: {Sam Kim's house, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}.
* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
* Must be one of the "Area options," verbatim.
For taking a walk, Sam Kim should go to the following area: {Johnson Park}
---
Jane Anderson lives in {Oak Hill College Student Dormatory} that has Jane Anderson's room.
Jane Anderson is currently in {Oak Hill College} that has a classroom, library
Area options: {Oak Hill College Student Dormatory, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}.
* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
* Must be one of the "Area options," verbatim.
For eating dinner, Jane Anderson should go to the following area: {Hobbs Cafe}
---"""
    s += (persona_name + " lives in " + living_sector + " that has " +
          living_sector_areas + ".\n")
    s += (persona_name + " is currently in " + current_sector + " that has " +
          current_sector_areas + ".\n")
    s += daily_plan + ".\n"
    s += "Area options: " + sector_options + ".\n"
    s += """* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
* Must be one of the "Area options," verbatim.\n"""
    s += (persona_name + " is " + current_action + ". For " + next_action +
          ", " + persona_name + " should go to the following area: {")
    s += sgl.gen(name="Location", max_tokens=10, stop="}")


def action_location_sector_prompt(
    persona_name,
    living_sector,
    living_sector_areas,
    current_sector,
    current_sector_areas,
    daily_plan,
    sector_options,
    current_action,
    next_action,
):
    s = ""
    s += """Task -- choose an appropriate area from the area options for a task at hand.
Sam Kim lives in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen.
Sam Kim is currently in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen.
Area options: {Sam Kim's house, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}.
* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
* Must be one of the "Area options," verbatim.
For taking a walk, Sam Kim should go to the following area: {Johnson Park}
---
Jane Anderson lives in {Oak Hill College Student Dormatory} that has Jane Anderson's room.
Jane Anderson is currently in {Oak Hill College} that has a classroom, library
Area options: {Oak Hill College Student Dormatory, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}.
* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
* Must be one of the "Area options," verbatim.
For eating dinner, Jane Anderson should go to the following area: {Hobbs Cafe}
---"""
    s += (persona_name + " lives in " + living_sector + " that has " +
          living_sector_areas + ".\n")
    s += (persona_name + " is currently in " + current_sector + " that has " +
          current_sector_areas + ".\n")
    s += daily_plan + ".\n"
    s += "Area options: " + sector_options + ".\n"
    s += """* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
* Must be one of the "Area options," verbatim.\n"""
    s += (persona_name + " is " + current_action + ". For " + next_action +
          ", " + persona_name + " should go to the following area: {")
    return {"prompt": s, "max_tokens": 10, "stop": "}"}


@sgl.function
def action_location_object(s, persona_name, target_sector, target_sector_areas,
                           current_action, next_action):
    s += """
Jane Anderson is in kitchen in Jane Anderson's house.
Jane Anderson is going to Jane Anderson's house that has the following areas: {kitchen, bedroom, bathroom}
Stay in the current area if the activity can be done there. Never go into other people's rooms unless necessary.
For cooking, Jane Anderson should go to the following area in Jane Anderson's house:
Answer: {kitchen}
---
Tom Watson is in common room in Tom Watson's apartment.
Tom Watson is going to Hobbs Cafe that has the following areas: {cafe}
Stay in the current area if the activity can be done there. Never go into other people's rooms unless necessary.
For getting coffee, Tom Watson should go to the following area in Hobbs Cafe:
Answer: {cafe}
---"""
    s += (persona_name + " is going to " + target_sector +
          " that has the following areas: {" + target_sector_areas + "}\n")
    s += """* Stay in the current area if the activity can be done there.
* NEVER go into other people's rooms unless necessary."""
    s += (persona_name + " is " + current_action + ". For " + next_action +
          ", " + persona_name + " should go to the following area in " +
          target_sector)
    s += " (MUST pick one of {" + target_sector_areas + "}):\n"
    s += "Answer: {" + sgl.gen(name="Area", max_tokens=5, stop="}")


def action_location_object_prompt(persona_name, target_sector,
                                  target_sector_areas, current_action,
                                  next_action):
    s = ""
    s += """
Jane Anderson is in kitchen in Jane Anderson's house.
Jane Anderson is going to Jane Anderson's house that has the following areas: {kitchen, bedroom, bathroom}
Stay in the current area if the activity can be done there. Never go into other people's rooms unless necessary.
For cooking, Jane Anderson should go to the following area in Jane Anderson's house:
Answer: {kitchen}
---
Tom Watson is in common room in Tom Watson's apartment.
Tom Watson is going to Hobbs Cafe that has the following areas: {cafe}
Stay in the current area if the activity can be done there. Never go into other people's rooms unless necessary.
For getting coffee, Tom Watson should go to the following area in Hobbs Cafe:
Answer: {cafe}
---"""
    s += (persona_name + " is going to " + target_sector +
          " that has the following areas: {" + target_sector_areas + "}\n")
    s += """* Stay in the current area if the activity can be done there.
* NEVER go into other people's rooms unless necessary."""
    s += (persona_name + " is " + current_action + ". For " + next_action +
          ", " + persona_name + " should go to the following area in " +
          target_sector)
    s += " (MUST pick one of {" + target_sector_areas + "}):\n"
    s += "Answer: {"
    return {"prompt": s, "max_tokens": 5, "stop": "}"}
benchmark/generative_agents/bench_other.py (new file): 104 lines
@@ -0,0 +1,104 @@
import argparse
import json
import time
from functools import partial
from pathlib import Path

from tqdm import tqdm
from sglang.test.test_utils import (
    add_common_other_args_and_parse,
    call_generate_lightllm,
    call_generate_vllm,
    call_generate_srt_raw,
)
from sglang.utils import read_jsonl, dump_state_text

from agent_functions import (
    poignancy_event_prompt,
    generate_event_triple_prompt,
    generate_pronunciatio_prompt,
    action_location_sector_prompt,
    action_location_object_prompt,
)


def main(args):
    lines = read_jsonl(args.data_path)[:args.num_events]
    mapping = {
        "poignancy_event": poignancy_event_prompt,
        "generate_event_triple": generate_event_triple_prompt,
        "generate_pronunciatio": generate_pronunciatio_prompt,
        "action_location_sector": action_location_sector_prompt,
        "action_location_object": action_location_object_prompt,
    }

    arguments = [mapping[k](**v) for l in lines for k, v in l.items()]
    states = []

    # Select backend
    if args.backend == "lightllm":
        url = f"{args.host}:{args.port}/generate"
        call_generate = partial(call_generate_lightllm, url=url)
    elif args.backend == "vllm":
        url = f"{args.host}:{args.port}/generate"
        call_generate = partial(call_generate_vllm, url=url)
    elif args.backend == "srt-raw":
        url = f"{args.host}:{args.port}/generate"
        call_generate = partial(call_generate_srt_raw, url=url)
    elif args.backend == "guidance":
        from guidance import models, gen

        model = models.LlamaCpp(
            str(Path.home()) + "/model_weights/Llama-2-7b-chat.gguf",
            n_gpu_layers=-1,
            n_ctx=4096,
        )

        def call_generate(prompt, temperature, max_tokens, stop):
            out = model + prompt + gen(
                name="result",
                max_tokens=max_tokens,
                temperature=temperature,
                stop=stop,
            )
            return out["result"]
    else:
        raise ValueError(f"Invalid backend: {args.backend}")

    def get_one_answer(arg):
        answer = call_generate(**arg, temperature=0)
        states.append(answer)

    tic = time.time()
    # Always execute agent calls sequentially to maintain their dependencies.
    for arg in tqdm(arguments):
        get_one_answer(arg)
    latency = time.time() - tic

    print(f"Latency: {latency:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "Generative Agents",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            # Pack the weighted functions as a single agent.
            "num_requests": len(arguments) / len(mapping),
            "other": {
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, default="agent_calls.jsonl")
    parser.add_argument("--num-events", type=int, default=10)
    args = add_common_other_args_and_parse(parser)
    main(args)
74
benchmark/generative_agents/bench_sglang.py
Normal file
@@ -0,0 +1,74 @@
import argparse
import json
import time

import sglang as sgl
from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)
from sglang.utils import read_jsonl, dump_state_text

from agent_functions import (
    poignancy_event,
    generate_event_triple,
    generate_pronunciatio,
    action_location_sector,
    action_location_object,
)


def main(args):
    lines = read_jsonl(args.data_path)[:args.num_events]
    mapping = {
        "poignancy_event": poignancy_event,
        "generate_event_triple": generate_event_triple,
        "generate_pronunciatio": generate_pronunciatio,
        "action_location_sector": action_location_sector,
        "action_location_object": action_location_object,
    }
    arguments = [{mapping[k]: v for k, v in l.items()} for l in lines]

    # Select backend
    backend = select_sglang_backend(args)
    sgl.set_default_backend(backend)

    states = []
    # Run requests
    tic = time.time()
    for a in arguments:
        # each dict contains only a single key
        for func, arg in a.items():
            result = func.run(**arg)
            result.sync()
            states.append(result)
    latency = time.time() - tic

    print(f"Latency: {latency:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "Generative Agents",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            # each event bundles len(mapping) function calls, so count them as one request
            "num_requests": len(arguments) / len(mapping),
            "other": {
                "num_events": args.num_events,
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, default="agent_calls.jsonl")
    parser.add_argument("--num-events", type=int, default=10)
    args = add_common_sglang_args_and_parse(parser)
    main(args)
52
benchmark/gsm8k/README.md
Normal file
@@ -0,0 +1,52 @@
## Download data
```
wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
```

## Run benchmark

### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
```
python3 bench_sglang.py --num-questions 200
```

### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```
```
python3 bench_other.py --num-questions 200 --backend vllm
```

### Benchmark lightllm
```
# A10G
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
```
```
python3 bench_other.py --num-questions 200 --backend lightllm
```

### Benchmark guidance
```
python3 bench_other.py --num-questions 200 --backend guidance --parallel 1
```

### Benchmark lmql
```
CUDA_VISIBLE_DEVICES=0,1 lmql serve-model meta-llama/Llama-2-7b-chat-hf --cuda --port 23000
```
```
python3 bench_other.py --num-questions 100 --backend lmql --parallel 2
```
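Each benchmark run appends one JSON record per line to the shared result file, so runs across backends can be compared afterward. A minimal sketch of reading those records back (the field names mirror the dicts written by `bench_sglang.py`/`bench_other.py`; the helper name is ours, not part of the repo):

```python
import json

def summarize(lines):
    """Parse appended benchmark records (one JSON object per line) and print a summary."""
    rows = [json.loads(line) for line in lines if line.strip()]
    for r in rows:
        print(f'{r["task"]:>8} {r["backend"]:>10} '
              f'latency={r["latency"]:.1f}s accuracy={r.get("accuracy")}')
    return rows
```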
168
benchmark/gsm8k/bench_other.py
Normal file
@@ -0,0 +1,168 @@
import argparse
import ast
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import json
import re
import time

import numpy as np
from tqdm import tqdm
from sglang.test.test_utils import (
    add_common_other_args_and_parse,
    call_generate_lightllm,
    call_generate_vllm,
    call_generate_srt_raw,
)
from sglang.utils import read_jsonl, dump_state_text


INVALID = -9999999


def get_one_example(lines, i, include_answer):
    ret = "Question: " + lines[i]["question"] + "\nAnswer:"
    if include_answer:
        ret += " " + lines[i]["answer"]
    return ret


def get_few_shot_examples(lines, k):
    ret = ""
    for i in range(k):
        ret += get_one_example(lines, i, True) + "\n\n"
    return ret


def get_answer_value(answer_str):
    answer_str = answer_str.replace(",", "")
    numbers = re.findall(r'\d+', answer_str)
    if len(numbers) < 1:
        return INVALID
    try:
        return ast.literal_eval(numbers[-1])
    except SyntaxError:
        return INVALID


def main(args):
    lines = read_jsonl(args.data_path)

    # Construct prompts
    k = args.num_shot
    few_shot_examples = get_few_shot_examples(lines, k)

    questions = []
    labels = []
    for i in range(len(lines[:args.num_questions])):
        questions.append(get_one_example(lines, i, False))
        labels.append(get_answer_value(lines[i]["answer"]))
    assert all(l != INVALID for l in labels)

    states = [None] * len(labels)

    # Select backend
    if args.backend == "lightllm":
        url = f"{args.host}:{args.port}/generate"
        call_generate = partial(call_generate_lightllm, url=url)
    elif args.backend == "vllm":
        url = f"{args.host}:{args.port}/generate"
        call_generate = partial(call_generate_vllm, url=url)
    elif args.backend == "srt-raw":
        url = f"{args.host}:{args.port}/generate"
        call_generate = partial(call_generate_srt_raw, url=url)
    elif args.backend == "guidance":
        from guidance import models, gen

        model = models.LlamaCpp(
            "/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
            n_gpu_layers=-1,
            n_ctx=4096,
        )

        def call_generate(prompt, temperature, max_tokens, stop):
            out = model + prompt + gen(
                name="answer",
                max_tokens=max_tokens,
                temperature=temperature,
                stop=stop,
            )
            return out["answer"]

    elif args.backend == "lmql":
        import lmql
        model = lmql.model(args.model_path,
                           endpoint=f"{args.host}:{args.port}")

        @lmql.query(model=model)
        async def program(question):
            '''lmql
            """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < 257 and STOPS_AT(ANSWER, "Question")
            return ANSWER
            '''

        async def call_generate(prompt, temperature, max_tokens, stop):
            return await program(question=prompt, temperature=0)

    else:
        raise ValueError(f"Invalid backend: {args.backend}")

    # Run requests
    if args.backend != "lmql":
        # Use thread pool
        def get_one_answer(i):
            answer = call_generate(
                prompt=few_shot_examples + questions[i],
                temperature=0,
                max_tokens=256,
                stop="Question")
            states[i] = answer

        tic = time.time()
        if args.parallel == 1:
            for i in tqdm(range(len(questions))):
                get_one_answer(i)
        else:
            with ThreadPoolExecutor(args.parallel) as executor:
                executor.map(get_one_answer, list(range(len(questions))))
    else:
        # Use asyncio
        async def batched_call(batch_size):
            for i in range(0, len(questions), batch_size):
                tasks = []
                for q in questions[i:i+batch_size]:
                    tasks.append(call_generate(
                        few_shot_examples + q,
                        temperature=0, max_tokens=256, stop="Question"))
                rets = await asyncio.gather(*tasks)
                for j in range(len(rets)):
                    states[i+j] = rets[j]

        tic = time.time()
        asyncio.run(batched_call(batch_size=args.parallel))
    latency = time.time() - tic

    preds = []
    for i in range(len(states)):
        preds.append(get_answer_value(states[i]))

    # Compute accuracy
    acc = np.mean(np.array(preds) == np.array(labels))
    invalid = np.mean(np.array(preds) == INVALID)
    print(f"Latency: {latency:.3f}")
    print(f"Invalid: {invalid:.3f}")
    print(f"Accuracy: {acc:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "gsm8k",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "accuracy": round(acc, 3),
            "num_requests": args.num_questions,
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            }
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-shot", type=int, default=5)
    parser.add_argument("--data-path", type=str, default="test.jsonl")
    parser.add_argument("--num-questions", type=int, default=200)
    args = add_common_other_args_and_parse(parser)
    main(args)
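For reference, the answer-extraction rule used by the GSM8K scripts can be exercised standalone: commas are stripped, then the last run of digits in the model output is taken as the numeric answer. This is a restatement of `get_answer_value` from the benchmark code above:

```python
import ast
import re

INVALID = -9999999  # sentinel used by the benchmark for unparseable answers

def get_answer_value(answer_str):
    # Strip thousands separators, then take the last run of digits.
    answer_str = answer_str.replace(",", "")
    numbers = re.findall(r"\d+", answer_str)
    if not numbers:
        return INVALID
    try:
        return ast.literal_eval(numbers[-1])
    except SyntaxError:
        return INVALID
```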
115
benchmark/gsm8k/bench_sglang.py
Normal file
@@ -0,0 +1,115 @@
import argparse
import ast
import json
import re
import time

import numpy as np
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
from sglang.utils import read_jsonl, dump_state_text


INVALID = -9999999


def get_one_example(lines, i, include_answer):
    ret = "Question: " + lines[i]["question"] + "\nAnswer:"
    if include_answer:
        ret += " " + lines[i]["answer"]
    return ret


def get_few_shot_examples(lines, k):
    ret = ""
    for i in range(k):
        ret += get_one_example(lines, i, True) + "\n\n"
    return ret


def get_answer_value(answer_str):
    answer_str = answer_str.replace(",", "")
    numbers = re.findall(r'\d+', answer_str)
    if len(numbers) < 1:
        return INVALID
    try:
        return ast.literal_eval(numbers[-1])
    except SyntaxError:
        return INVALID


def main(args):
    lines = read_jsonl(args.data_path)

    # Construct prompts
    k = args.num_shot
    few_shot_examples = get_few_shot_examples(lines, k)

    questions = []
    labels = []
    for i in range(len(lines[:args.num_questions])):
        questions.append(get_one_example(lines, i, False))
        labels.append(get_answer_value(lines[i]["answer"]))
    assert all(l != INVALID for l in labels)
    arguments = [{"question": q} for q in questions]

    #####################################
    ######### SGL Program Begin #########
    #####################################

    import sglang as sgl

    @sgl.function
    def few_shot_gsm8k(s, question):
        s += few_shot_examples + question
        s += sgl.gen("answer", max_tokens=256, stop="Question")

    #####################################
    ########## SGL Program End ##########
    #####################################

    # Select backend
    backend = select_sglang_backend(args)

    # Run requests
    tic = time.time()
    states = few_shot_gsm8k.run_batch(
        arguments, temperature=0, backend=backend, num_threads=args.parallel)
    latency = time.time() - tic

    preds = []
    for i in range(len(states)):
        preds.append(get_answer_value(states[i]["answer"]))

    # Compute accuracy
    acc = np.mean(np.array(preds) == np.array(labels))
    invalid = np.mean(np.array(preds) == INVALID)
    print(f"Latency: {latency:.3f}")
    print(f"Invalid: {invalid:.3f}")
    print(f"Accuracy: {acc:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "gsm8k",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "accuracy": round(acc, 3),
            "num_requests": args.num_questions,
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            }
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-shot", type=int, default=5)
    parser.add_argument("--data-path", type=str, default="test.jsonl")
    parser.add_argument("--num-questions", type=int, default=200)
    args = add_common_sglang_args_and_parse(parser)
    main(args)
52
benchmark/hellaswag/README.md
Normal file
@@ -0,0 +1,52 @@
## Download data
```
wget https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl
```

## Run benchmark

### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
```
python3 bench_sglang.py --num-questions 200
```

### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```
```
python3 bench_other.py --num-questions 200 --backend vllm
```

### Benchmark lightllm
```
# A10G
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
```
```
python3 bench_other.py --num-questions 200 --backend lightllm
```

### Benchmark guidance
```
CUDA_VISIBLE_DEVICES=0,1 python3 bench_other.py --num-questions 200 --backend guidance --parallel 1
```

### Benchmark lmql
```
lmql serve-model meta-llama/Llama-2-7b-chat-hf --cuda --port 23000
```
```
python3 bench_other.py --num-questions 200 --backend lmql --port 23000 --parallel 1
```
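Whichever backend runs the HellaSwag benchmark, the selection step reduces to the same idea: score each candidate ending given the context and return the index of the highest-scoring one (this is what `sgl.select` and the guidance/lmql paths do). A minimal sketch, where `logprob_fn` is a hypothetical stand-in for a backend's scoring call:

```python
def select_ending(context, endings, logprob_fn):
    """Return the index of the ending with the highest score under logprob_fn."""
    scores = [logprob_fn(context, ending) for ending in endings]
    return scores.index(max(scores))
```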
140
benchmark/hellaswag/bench_other.py
Normal file
@@ -0,0 +1,140 @@
import argparse
import asyncio
from concurrent.futures import ThreadPoolExecutor
import json
from functools import partial
import time

import numpy as np
from sglang.test.test_utils import add_common_other_args_and_parse, call_select_lightllm, call_select_vllm
from sglang.utils import read_jsonl


def get_one_example(lines, i, include_answer):
    ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
    if include_answer:
        ret += lines[i]["endings"][lines[i]["label"]]
    return ret


def get_few_shot_examples(lines, k):
    ret = ""
    for i in range(k):
        ret += get_one_example(lines, i, True) + "\n\n"
    return ret


def main(args):
    lines = read_jsonl(args.data_path)

    # Construct prompts
    k = args.num_shot
    few_shot_examples = get_few_shot_examples(lines, k)

    questions = []
    choices = []
    labels = []
    for i in range(len(lines[:args.num_questions])):
        questions.append(get_one_example(lines, i, False))
        choices.append(lines[i]["endings"])
        labels.append(lines[i]["label"])

    preds = [None] * len(labels)

    # Select backend
    if args.backend == "lightllm":
        url = f"{args.host}:{args.port}/generate"
        call_select = partial(call_select_lightllm, url=url)
    elif args.backend == "vllm":
        url = f"{args.host}:{args.port}/generate"
        call_select = partial(call_select_vllm, url=url)
    elif args.backend == "guidance":
        from guidance import models, select

        model = models.LlamaCpp(
            "/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
            n_gpu_layers=-1,
            n_ctx=4096,
        )

        def call_select(context, choices):
            out = model + context + select(choices, name="answer")
            return choices.index(out["answer"])

    elif args.backend == "lmql":
        import lmql
        model = lmql.model("meta-llama/Llama-2-7b-chat-hf",
                           endpoint=f"{args.host}:{args.port}")

        @lmql.query(model=model)
        async def program(ctx, choices):
            '''lmql
            """{ctx}[ANSWER]""" where ANSWER in set(choices)
            return ANSWER
            '''

        async def call_select(context, choices):
            answer = await program(ctx=context, choices=choices, temperature=0)
            return choices.index(answer)

    else:
        raise ValueError(f"Invalid backend: {args.backend}")

    # Run requests
    if args.backend != "lmql":
        # Use thread pool
        def get_one_answer(i):
            preds[i] = call_select(
                context=few_shot_examples + questions[i],
                choices=choices[i])

        tic = time.time()
        if args.parallel == 1:
            for i in range(len(questions)):
                get_one_answer(i)
        else:
            with ThreadPoolExecutor(args.parallel) as executor:
                executor.map(get_one_answer, list(range(len(questions))))
    else:
        # Use asyncio
        async def batched_call(batch_size):
            for i in range(0, len(questions), batch_size):
                tasks = []
                for q, c in zip(questions[i:i+batch_size], choices[i:i+batch_size]):
                    tasks.append(call_select(
                        context=few_shot_examples + q,
                        choices=c))
                rets = await asyncio.gather(*tasks)
                for j in range(len(rets)):
                    preds[i+j] = rets[j]

        tic = time.time()
        asyncio.run(batched_call(batch_size=args.parallel))

    latency = time.time() - tic

    # Compute accuracy
    acc = np.mean(np.array(preds) == np.array(labels))
    print(f"Latency: {latency:.3f}")
    print(f"Accuracy: {acc:.3f}")

    # Write results
    with open(args.result_file, "a") as fout:
        value = {
            "task": "hellaswag",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "accuracy": round(acc, 3),
            "num_requests": args.num_questions,
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            }
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-shot", type=int, default=20)
    parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl")
    parser.add_argument("--num-questions", type=int, default=100)
    args = add_common_other_args_and_parse(parser)
    main(args)
96
benchmark/hellaswag/bench_sglang.py
Normal file
@@ -0,0 +1,96 @@
import argparse
import json
import time

import numpy as np
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
from sglang.utils import read_jsonl


def get_one_example(lines, i, include_answer):
    ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
    if include_answer:
        ret += lines[i]["endings"][lines[i]["label"]]
    return ret


def get_few_shot_examples(lines, k):
    ret = ""
    for i in range(k):
        ret += get_one_example(lines, i, True) + "\n\n"
    return ret


def main(args):
    lines = read_jsonl(args.data_path)

    # Construct prompts
    k = args.num_shot
    few_shot_examples = get_few_shot_examples(lines, k)

    questions = []
    choices = []
    labels = []
    for i in range(len(lines[:args.num_questions])):
        questions.append(get_one_example(lines, i, False))
        choices.append(lines[i]["endings"])
        labels.append(lines[i]["label"])
    arguments = [
        {"question": q, "choices": c}
        for q, c in zip(questions, choices)
    ]

    #####################################
    ######### SGL Program Begin #########
    #####################################

    import sglang as sgl

    @sgl.function
    def few_shot_hellaswag(s, question, choices):
        s += few_shot_examples + question
        s += sgl.select("answer", choices=choices)

    #####################################
    ########## SGL Program End ##########
    #####################################

    # Select backend
    backend = select_sglang_backend(args)

    # Run requests
    tic = time.time()
    rets = few_shot_hellaswag.run_batch(
        arguments, temperature=0, backend=backend, num_threads=args.parallel)
    preds = [choices[i].index(rets[i]["answer"]) for i in range(len(rets))]
    latency = time.time() - tic

    # Compute accuracy
    acc = np.mean(np.array(preds) == np.array(labels))
    print(f"Latency: {latency:.3f}")
    print(f"Accuracy: {acc:.3f}")

    # Write results
    with open(args.result_file, "a") as fout:
        value = {
            "task": "hellaswag",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "accuracy": round(acc, 3),
            "num_requests": args.num_questions,
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            }
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-shot", type=int, default=20)
    parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl")
    parser.add_argument("--num-questions", type=int, default=100)
    args = add_common_sglang_args_and_parse(parser)
    main(args)
46
benchmark/latency_throughput/README.md
Normal file
@@ -0,0 +1,46 @@
### Download data
```
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
```

### Performance

- Model: Llama-2-7b-chat-hf
- `--num-prompts 2000 --request-rate 200`
- On 4 A10 (24G) GPUs

| Backend     | Throughput      | Latency  |
| ----------- | --------------- | -------- |
| srt         | 5.82 requests/s | 343.54 s |
| vllm==0.2.6 | 3.93 requests/s | 509.08 s |
| vllm==0.2.7 | 5.02 requests/s | 398.25 s |

### SGLang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
```
python3 bench_throughput.py --backend srt --tokenizer meta-llama/Llama-2-7b-chat-hf --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 10 --request-rate 10 --port 30000
```

### vLLM
```
python3 -m vllm.entrypoints.api_server --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --swap-space 16
```
```
python3 bench_throughput.py --backend vllm --tokenizer meta-llama/Llama-2-7b-chat-hf --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 10 --request-rate 10
```

### LightLLM
```
python -m lightllm.server.api_server --model_dir ~/model_weights/Llama-2-7b-chat-hf --max_total_token_num 15600 --tokenizer_mode auto --port 22000
```
```
python3 bench_throughput.py --backend lightllm --tokenizer meta-llama/Llama-2-7b-chat-hf --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 10 --request-rate 10 --port 22000
```
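The `--request-rate` flag drives a Poisson arrival process: between requests, the client sleeps for an exponentially distributed gap with mean `1/rate` (see `get_request` in `bench_throughput.py`, which draws `np.random.exponential(1.0 / request_rate)`). A small sketch of the resulting arrival schedule, with a helper name of our own:

```python
import numpy as np

def arrival_times(num_requests, request_rate, seed=0):
    """Cumulative send times for a Poisson arrival process at `request_rate` req/s."""
    rng = np.random.default_rng(seed)
    # Exponential inter-arrival gaps with mean 1/rate, as in get_request().
    gaps = rng.exponential(1.0 / request_rate, size=num_requests)
    return np.cumsum(gaps)
```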
254
benchmark/latency_throughput/bench_throughput.py
Normal file
@@ -0,0 +1,254 @@
"""Benchmark online serving throughput.

On the server side, run one of the following commands:
(vLLM backend)
    python -m vllm.entrypoints.api_server \
        --model <your_model> --swap-space 16 \
        --disable-log-requests

(TGI backend)
    ./launch_hf_server.sh <your_model>

On the client side, run:
    python benchmarks/benchmark_serving.py \
        --backend <backend> \
        --tokenizer <your_model> --dataset <target_dataset> \
        --request-rate <request_rate>
"""
import argparse
import asyncio
import json
import random
import time
from typing import AsyncGenerator, List, Tuple
from tqdm.asyncio import tqdm_asyncio

import aiohttp
import numpy as np
from transformers import PreTrainedTokenizerBase
from vllm.transformers_utils.tokenizer import get_tokenizer

# (prompt len, output len, latency)
REQUEST_LATENCY: List[Tuple[int, int, float]] = []


def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, int, int]]:
    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [
        data for data in dataset
        if len(data["conversations"]) >= 2
    ]
    # Only keep the first two turns of each conversation.
    dataset = [
        (data["conversations"][0]["value"], data["conversations"][1]["value"])
        for data in dataset
    ]

    # Tokenize the prompts and completions.
    prompts = [prompt for prompt, _ in dataset]
    prompt_token_ids = tokenizer(prompts).input_ids
    completions = [completion for _, completion in dataset]
    completion_token_ids = tokenizer(completions).input_ids
    tokenized_dataset = []
    for i in range(len(dataset)):
        output_len = len(completion_token_ids[i])
        tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))

    # Filter out too long sequences.
    filtered_dataset: List[Tuple[str, int, int]] = []
    for prompt, prompt_token_ids, output_len in tokenized_dataset:
        prompt_len = len(prompt_token_ids)
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            # This is because TGI causes errors when the input or output length
            # is too short.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    # Sample the requests.
    sampled_requests = random.sample(filtered_dataset, num_requests)
    return sampled_requests


async def get_request(
    input_requests: List[Tuple[str, int, int]],
    request_rate: float,
) -> AsyncGenerator[Tuple[str, int, int], None]:
    input_requests = iter(input_requests)
    for request in input_requests:
        yield request

        if request_rate == float("inf"):
            # If the request rate is infinity, then we don't need to wait.
            continue
        # Sample the request interval from the exponential distribution.
        interval = np.random.exponential(1.0 / request_rate)
        # The next request will be sent after the interval.
        await asyncio.sleep(interval)


async def send_request(
    backend: str,
    api_url: str,
    prompt: str,
    prompt_len: int,
    output_len: int,
    best_of: int,
    use_beam_search: bool,
) -> None:
    request_start_time = time.perf_counter()

    headers = {"User-Agent": "Benchmark Client"}
    if backend == "vllm":
        pload = {
            "prompt": prompt,
            "n": 1,
            "best_of": best_of,
            "use_beam_search": use_beam_search,
            "temperature": 0.0 if use_beam_search else 1.0,
            "top_p": 1.0,
            "max_tokens": output_len,
            "ignore_eos": True,
            "stream": False,
        }
    elif backend == "tgi":
        assert not use_beam_search
        params = {
            "best_of": best_of,
            "max_new_tokens": output_len,
            "do_sample": True,
        }
        pload = {
            "inputs": prompt,
            "parameters": params,
        }
    elif backend == "srt":
        assert not use_beam_search
        params = {
            "ignore_eos": True,
            "max_new_tokens": output_len,
        }
        pload = {
            "text": prompt,
            "sampling_params": params,
        }
    elif backend == "lightllm":
        assert not use_beam_search
        params = {
            "ignore_eos": True,
            "max_new_tokens": output_len,
        }
        pload = {
            "inputs": prompt,
            "parameters": params,
        }
    else:
        raise ValueError(f"Unknown backend: {backend}")

    timeout = aiohttp.ClientTimeout(total=3 * 3600)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        while True:
            async with session.post(api_url, headers=headers, json=pload) as response:
                chunks = []
                async for chunk, _ in response.content.iter_chunks():
                    chunks.append(chunk)
            output = b"".join(chunks).decode("utf-8")
            output = json.loads(output)

            # Re-send the request if it failed.
            if "error" not in output:
                break

    request_end_time = time.perf_counter()
    request_latency = request_end_time - request_start_time
    REQUEST_LATENCY.append((prompt_len, output_len, request_latency))


async def benchmark(
    backend: str,
    api_url: str,
    input_requests: List[Tuple[str, int, int]],
    best_of: int,
    use_beam_search: bool,
    request_rate: float,
) -> None:
    tasks: List[asyncio.Task] = []
    async for request in get_request(input_requests, request_rate):
        prompt, prompt_len, output_len = request
        task = asyncio.create_task(send_request(backend, api_url, prompt,
                                                prompt_len, output_len,
                                                best_of, use_beam_search))
        tasks.append(task)
    await tqdm_asyncio.gather(*tasks)


def main(args: argparse.Namespace):
    print(args)
    random.seed(args.seed)
    np.random.seed(args.seed)

    api_url = f"http://{args.host}:{args.port}/generate"
    tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
    input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)

    benchmark_start_time = time.perf_counter()
    asyncio.run(benchmark(args.backend, api_url, input_requests, args.best_of,
                          args.use_beam_search, args.request_rate))
    benchmark_end_time = time.perf_counter()
    benchmark_time = benchmark_end_time - benchmark_start_time
    print(f"Total time: {benchmark_time:.2f} s")
    print(f"Throughput: {args.num_prompts / benchmark_time:.2f} requests/s")

    # Compute the latency statistics.
    avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY])
    print(f"Average latency: {avg_latency:.2f} s")
    avg_per_token_latency = np.mean([
        latency / (prompt_len + output_len)
        for prompt_len, output_len, latency in REQUEST_LATENCY
    ])
    print(f"Average latency per token: {avg_per_token_latency:.2f} s")
    avg_per_output_token_latency = np.mean([
        latency / output_len
        for _, output_len, latency in REQUEST_LATENCY
    ])
    print("Average latency per output token: "
          f"{avg_per_output_token_latency:.2f} s")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark the online serving throughput.")
    parser.add_argument("--backend", type=str, default="vllm",
                        choices=["vllm", "tgi", "srt", "lightllm"])
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--dataset", type=str, required=True,
                        help="Path to the dataset.")
    parser.add_argument("--tokenizer", type=str, required=True,
                        help="Name or path of the tokenizer.")
    parser.add_argument("--best-of", type=int, default=1,
                        help="Generates `best_of` sequences per prompt and "
                             "returns the best one.")
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument("--num-prompts", type=int, default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--request-rate", type=float, default=float("inf"),
                        help="Number of requests per second. If this is inf, "
|
||||
"then all the requests are sent at time 0. "
|
||||
"Otherwise, we use Poisson process to synthesize "
|
||||
"the request arrival times.")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument('--trust-remote-code', action='store_true',
|
||||
help='trust remote code from huggingface')
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
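The `--request-rate` help text above says arrival times are synthesized from a Poisson process, i.e. the gaps between consecutive requests are exponentially distributed with mean `1 / request_rate`. A minimal illustrative sketch of drawing such gaps (this helper and its use of NumPy's `default_rng` are our own, not part of the benchmark script):

```python
import numpy as np

def poisson_arrival_gaps(request_rate: float, num_requests: int, seed: int = 0):
    """Sample inter-arrival gaps (seconds) for a Poisson arrival process.

    With arrival rate `request_rate` (requests/s), the gap between two
    consecutive requests is exponentially distributed with mean 1/request_rate.
    An infinite rate means all requests are sent at time 0.
    """
    if request_rate == float("inf"):
        return [0.0] * num_requests
    rng = np.random.default_rng(seed)
    return rng.exponential(1.0 / request_rate, size=num_requests).tolist()

# At 4 requests/s the gaps average about 0.25 s.
gaps = poisson_arrival_gaps(request_rate=4.0, num_requests=1000)
print(len(gaps), min(gaps) >= 0)
```

Each benchmark request would then be issued after sleeping for its sampled gap, which is what makes the measured throughput reflect a realistic arrival pattern rather than a single burst.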
66
benchmark/latency_throughput/test_latency.py
Normal file
@@ -0,0 +1,66 @@
import argparse
import random
import time

import requests

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=None)
    parser.add_argument("--backend", type=str, default="srt")
    args = parser.parse_args()

    if args.port is None:
        if args.backend == "srt":
            args.port = 30000
        elif args.backend == "vllm":
            args.port = 21000
        elif args.backend == "lightllm":
            args.port = 22000
        else:
            raise ValueError(f"Invalid backend: {args.backend}")

    url = f"{args.host}:{args.port}"
    a = random.randint(0, 1 << 20)
    max_new_tokens = 256

    tic = time.time()
    if args.backend == "srt":
        response = requests.post(
            url + "/generate",
            json={
                "text": f"{a}, ",
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": max_new_tokens,
                },
            },
        )
    elif args.backend == "lightllm":
        response = requests.post(
            url + "/generate",
            json={
                "inputs": f"{a}, ",
                "parameters": {
                    "temperature": 0,
                    "max_new_tokens": max_new_tokens,
                },
            },
        )
    elif args.backend == "vllm":
        response = requests.post(
            url + "/generate",
            json={
                "prompt": f"{a}, ",
                "temperature": 0,
                "max_tokens": max_new_tokens,
            },
        )
    latency = time.time() - tic

    ret = response.json()
    print(ret)

    speed = max_new_tokens / latency
    print(f"latency: {latency:.2f} s, speed: {speed:.2f} token/s")
37
benchmark/line_retrieval/README.md
Normal file
@@ -0,0 +1,37 @@
## Download data

```
wget https://raw.githubusercontent.com/merrymercy/merrymercy.github.io/master/files/random_words.json
python3 gen_data.py --number 1000
```

## Run benchmark

### Benchmark sglang
```
python3 -m sglang.launch_server --model-path codellama/CodeLlama-7b-hf --port 30000
```

```
python3 bench_sglang.py --src-index 600 --num-q 50 --parallel 1
```

### Results

```
# original
Accuracy: 0.940, latency: 332.83 s

# parallel encoding (no_adjust, offset = 1000)
Accuracy: 0.760, latency: 238.46 s

# parallel encoding (no_adjust, offset = 3000)
Accuracy: 0.760, latency: 238.46 s

# parallel encoding (no_adjust, offset = 0)
Accuracy: 0.520, latency: 238.46 s

# parallel encoding (adjust_cache)
Accuracy: 0.460, latency: 257.66 s
```
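For reference, the generated data follows a fixed textual format (see `gen_data.py` in this commit): a raw line stores a six-digit `REGISTER_CONTENT` value under a two-word index, and a redirect line points at another line's index. A small illustrative sketch of that format (the word list and sampled values here are made up):

```python
import numpy as np

rng = np.random.default_rng(0)
random_words = ["alpha", "bravo", "charlie", "delta"]

# A raw line: "Line <idx>: The REGISTER_CONTENT is <6-digit value>."
index_a = "-".join(rng.choice(random_words, size=2))
value_a = f"{rng.integers(0, 1_000_000):06}"
raw_line = f"Line {index_a}: The REGISTER_CONTENT is {value_a}."

# A redirect line: its value is resolved by following the reference.
index_b = "-".join(rng.choice(random_words, size=2))
redirect_line = f"Line {index_b}: The REGISTER_CONTENT is the same as Line {index_a}."

print(raw_line)
print(redirect_line)
```

Resolving chains of such redirects is what gives a query its "hoop" count in the benchmark.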
133
benchmark/line_retrieval/bench_sglang.py
Normal file
@@ -0,0 +1,133 @@
import argparse
import json
import re
import time

import numpy as np

import sglang as sgl
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
from sglang.utils import dump_state_text


@sgl.function
def line_retrieval(s, prefix, suffix, body_0, body_1, body_2, body_3):
    s += prefix + "\n"

    contexts = [body_0, body_1, body_2, body_3]
    position_ids_offset = [i * 1000 for i in range(len(contexts))]
    forks = s.fork(len(contexts), position_ids_offset)
    forks += lambda i: contexts[i] + "\n"
    forks.join(mode="concate_and_append")

    s += "\n" + suffix
    s += sgl.gen("answer", max_tokens=16)


def eval_model(args, line_obj, num_hoops, src_indices, dst_percents):
    arguments = []
    labels = []
    sum_src_indices = []
    sum_dst_indices = []

    for i in range(len(src_indices)):
        for j in range(len(dst_percents)):
            src_index = src_indices[i]
            dst_percent = dst_percents[j]

            query_indices = line_obj["group_by_num_hoops"][str(num_hoops)]
            query_indices = [q for q in query_indices if
                             all(l <= src_index for l in line_obj["links"][q]) and q < src_index]
            dst_index = query_indices[min(int(len(query_indices) * dst_percent), len(query_indices) - 1)]
            label = line_obj["values"][dst_index]

            body = line_obj["lines"][:src_index + 1]
            suffix = line_obj["suffix"].replace("???", line_obj["indices"][dst_index])
            body_part_len = len(body) // 4

            arguments.append({
                "prefix": line_obj["prefix"],
                "body_0": "\n".join(body[:body_part_len]),
                "body_1": "\n".join(body[body_part_len: 2 * body_part_len]),
                "body_2": "\n".join(body[2 * body_part_len: 3 * body_part_len]),
                "body_3": "\n".join(body[3 * body_part_len:]),
                "suffix": suffix,
            })
            labels.append(label)
            sum_src_indices.append(src_index)
            sum_dst_indices.append(dst_index)

    # Select backend
    backend = select_sglang_backend(args)

    tic = time.time()
    states = line_retrieval.run_batch(
        arguments, temperature=0, backend=backend, num_threads=args.parallel)
    latency = time.time() - tic

    corrects = []
    for i in range(len(arguments)):
        output = states[i]["answer"]
        prompt_len = states[i].get_meta_info("answer").get("prompt_length", -1)
        label = labels[i]

        # Try all numbers in the response.
        findall = re.findall(r"\d+", output)
        if not findall:
            response_number = output
        else:
            for response_number in findall:
                if response_number == label:
                    break

        correct = response_number == label
        corrects.append(correct)

        # Log results
        summary = (
            f"Line index: {sum_src_indices[i]} -> {sum_dst_indices[i]}, "
            f"Prompt len: {prompt_len}, "
            f"Correct: {correct}, "
            f"Label: {label}, Predicted: {response_number}, "
        )
        print(summary)

    accuracy = np.mean(corrects)
    print(f"Accuracy: {accuracy:.3f}, latency: {latency:.2f} s")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "line_retrieval",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "num_requests": len(arguments),
            "other": {
                "num_questions": len(arguments),
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")


def main(args):
    line_obj = json.load(open(args.data_path, "r"))

    num_hoops = args.num_hoops
    for src_index in args.src_index:
        src_indices = [src_index]
        num_queries = args.num_queries_per_src
        dst_percents = [i / num_queries for i in range(num_queries)]
        eval_model(args, line_obj, num_hoops, src_indices, dst_percents)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, default="lines_1000_0.0.json")
    parser.add_argument("--src-index", type=int, nargs="+", default=[100])
    parser.add_argument("--num-queries-per-src", type=int, default=10)
    parser.add_argument("--num-hoops", type=int, default=1)
    args = add_common_sglang_args_and_parse(parser)
    main(args)
135
benchmark/line_retrieval/gen_data.py
Normal file
@@ -0,0 +1,135 @@
"""
Generate line data for the line retrieval task.

Usage:
python3 gen_data.py --number 1000
"""
import argparse
import json
from collections import defaultdict

import numpy as np
from tqdm import tqdm


def generate_lines(random_words, num_lines, redirect_ratio):
    prefix = "Here is a list of lines, each with its corresponding REGISTER_CONTENT value. Please memorize them. Be prepared to provide the REGISTER_CONTENT value for a specific line index when I ask."
    suffix = "The list has ended. Please give the final REGISTER_CONTENT value for a specific line after resolving the redirections and references. For example, the REGISTER_CONTENT of Line __idx0__ is __val0__. The REGISTER_CONTENT of Line __idx1__ is __val1__. The REGISTER_CONTENT of Line __idx2__ is __val2__. The REGISTER_CONTENT of Line ??? is"

    # Raw lines
    visited_indices = set([None])
    visited_values = set([None])

    lines = []
    redirects = []
    indices = []
    values = []
    for i in tqdm(range(num_lines)):
        line_index = None
        while line_index in visited_indices:
            line_index = "-".join(np.random.choice(random_words, size=(2,)))
        visited_indices.add(line_index)

        line_value = np.random.randint(low=0, high=999999)
        line_value = f"{line_value:06}"

        line = f"Line {line_index}: The REGISTER_CONTENT is {line_value}."
        lines.append(line)
        redirects.append(None)
        indices.append(line_index)
        values.append(line_value)

    # Add redirects
    if redirect_ratio > 0:
        num_redirect_lines = int(len(lines) * redirect_ratio)
        redirect_indices = np.random.choice(np.arange(len(lines)),
                                            size=(num_redirect_lines,), replace=False)
        for i in redirect_indices:
            target_idx = np.random.choice(min(i * 2 + 100, num_lines))
            lines[i] = f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
            redirects[i] = target_idx

    # Build links and find sources
    links = [[] for _ in range(num_lines)]
    contains_ring = set()
    for i in range(num_lines):
        if redirects[i] is None:
            continue

        tmp_link = []
        cur = i
        visited = set()
        while redirects[cur] is not None:
            visited.add(cur)
            tmp_link.append(redirects[cur])
            cur = redirects[cur]

            if cur in visited:
                contains_ring.add(i)
                tmp_link = None
                break
        values[i] = values[cur]
        links[i] = tmp_link

    # Group by the number of links
    group_by_num_hoops = defaultdict(list)
    for i in range(num_lines):
        if i in contains_ring:
            continue
        group_by_num_hoops[len(links[i]) + 1].append(i)

    keys = sorted(list(group_by_num_hoops.keys()))
    for num_links in keys:
        print(f"#links: {num_links}, #lines: {len(group_by_num_hoops[num_links])}")

    # Append few-shot examples
    hoop1_candidates = list(group_by_num_hoops[1])
    hoop1_candidate_keys = {c: max([c] + links[c]) for c in hoop1_candidates}
    hoop1_candidates.sort(key=lambda c: hoop1_candidate_keys[c])
    hoop2_candidates = list(group_by_num_hoops[2])
    hoop2_candidate_keys = {c: max([c] + links[c]) for c in hoop2_candidates}
    hoop2_candidates.sort(key=lambda c: hoop2_candidate_keys[c])

    i = hoop1_candidates[5]
    suffix = suffix.replace("__idx0__", indices[i]).replace("__val0__", values[i])
    if len(hoop2_candidates):
        i = hoop2_candidates[0]
        suffix = suffix.replace("__idx1__", indices[i]).replace("__val1__", values[i])
        i = hoop2_candidates[1]
        suffix = suffix.replace("__idx2__", indices[i]).replace("__val2__", values[i])
    else:
        i = hoop1_candidates[1]
        suffix = suffix.replace("__idx1__", indices[i]).replace("__val1__", values[i])
        i = hoop1_candidates[10]
        suffix = suffix.replace("__idx2__", indices[i]).replace("__val2__", values[i])

    obj = {
        "prefix": prefix,
        "suffix": suffix,
        "lines": lines,
        "indices": indices,
        "values": values,
        "links": links,
        "group_by_num_hoops": group_by_num_hoops,
        "contains_ring": sorted(list(contains_ring)),
    }
    return obj


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--number", type=int)
    parser.add_argument("--redirect-ratio", type=float, default=0.0)
    args = parser.parse_args()

    num_lines = args.number

    random_words_filename = "random_words.json"
    random_words = json.load(open(random_words_filename, "r"))

    np.random.seed(42)
    obj = generate_lines(random_words, num_lines, args.redirect_ratio)

    fout = f"lines_{num_lines}_{args.redirect_ratio:.1f}.json"
    with open(fout, "w") as fout:
        json.dump(obj, fout, indent=2)
60
benchmark/llava_bench/README.md
Normal file
@@ -0,0 +1,60 @@
## Download benchmark images

```
python3 download_images.py
```

Image benchmark source: https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild

### Other dependencies
```
pip3 install "torch>=2.1.2" "transformers>=4.36" pillow
```

## Run benchmark

### Benchmark sglang
Launch a server:
```
python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000
```

Run the benchmark:
```
# Run with local models
python3 bench_sglang.py --num-questions 60

# Run with OpenAI models
python3 bench_sglang.py --num-questions 60 --backend gpt-4-vision-preview
```

### Benchmark the original LLaVA code
```
git clone git@github.com:haotian-liu/LLaVA.git
cd LLaVA
git reset --hard 9a26bd1435b4ac42c282757f2c16d34226575e96
pip3 install -e .

cd ~/sglang/benchmark/llava_bench
CUDA_VISIBLE_DEVICES=0 bash bench_hf_llava_bench.sh
```

### Benchmark llama.cpp

```
# Install
CMAKE_ARGS="-DLLAMA_CUBLAS=on" pip install llama-cpp-python
pip install sse_starlette starlette_context pydantic_settings

# Download weights
mkdir -p ~/model_weights/llava-v1.5-7b/
wget https://huggingface.co/mys/ggml_llava-v1.5-7b/resolve/main/ggml-model-f16.gguf -O ~/model_weights/llava-v1.5-7b/ggml-model-f16.gguf
wget https://huggingface.co/mys/ggml_llava-v1.5-7b/resolve/main/mmproj-model-f16.gguf -O ~/model_weights/llava-v1.5-7b/mmproj-model-f16.gguf
```

```
python3 -m llama_cpp.server --model ~/model_weights/llava-v1.5-7b/ggml-model-f16.gguf --clip_model_path ~/model_weights/llava-v1.5-7b/mmproj-model-f16.gguf --chat_format llava-1-5 --port 23000

OPENAI_BASE_URL=http://localhost:23000/v1 python3 bench_sglang.py --backend gpt-4-vision-preview --num-q 1
```
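The benchmark script below reads one JSON object per line from `questions.jsonl` and uses its `question_id`, `image`, and `text` fields. A minimal illustrative record (the values here are made up; only the field names match what `bench_sglang.py` consumes):

```python
import json

# One record of questions.jsonl in the shape bench_sglang.py expects.
record = {
    "question_id": 0,
    "image": "001.jpg",                              # joined with --image-folder
    "text": "What is unusual about this image?",     # the question prompt
}

# JSONL round-trip: one json.dumps() line per question.
line = json.dumps(record)
parsed = json.loads(line)
print(parsed["image"])
```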
9
benchmark/llava_bench/bench_hf_llava_bench.sh
Normal file
@@ -0,0 +1,9 @@
#!/bin/bash

python -m llava.eval.model_vqa \
    --model-path liuhaotian/llava-v1.5-7b \
    --question-file ./questions.jsonl \
    --image-folder ./images \
    --answers-file ./answers_hf.jsonl \
    --temperature 0 \
    --conv-mode vicuna_v1
9
benchmark/llava_bench/bench_hf_mme.sh
Normal file
@@ -0,0 +1,9 @@
#!/bin/bash

python -m llava.eval.model_vqa_loader \
    --model-path liuhaotian/llava-v1.5-7b \
    --question-file ./mme_pack/llava_mme_bench_replace.jsonl \
    --image-folder ./mme_pack/MME_Benchmark_release_version \
    --answers-file ./answers_hf_mme.jsonl \
    --temperature 0 \
    --conv-mode vicuna_v1
96
benchmark/llava_bench/bench_sglang.py
Normal file
@@ -0,0 +1,96 @@
import argparse
import json
import os
import time

import tqdm
from PIL import Image

import sglang as sgl
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
from sglang.utils import read_jsonl, dump_state_text


@sgl.function
def image_qa(s, image_file, question):
    s += sgl.user(sgl.image(image_file) + question)
    s += sgl.assistant(sgl.gen("answer", max_tokens=args.max_tokens))


def main(args):
    lines = read_jsonl(args.question_file)[:args.num_questions]
    arguments = [
        {"image_file": os.path.abspath(args.image_folder + "/" + l["image"]),
         "question": l["text"]} for l in lines
    ]
    # arguments = [
    #     {"image_file":
    #          Image.open(os.path.abspath(args.image_folder + "/" + l["image"])),
    #      "question": l["text"]} for l in lines
    # ]

    states = [None] * len(lines)

    # Select backend
    backend = select_sglang_backend(args)
    sgl.set_default_backend(backend)

    # Run requests
    tic = time.time()
    if args.parallel == 1:
        for i in tqdm.tqdm(range(len(lines))):
            image_file = arguments[i]["image_file"]
            question = arguments[i]["question"]
            ret = image_qa.run(
                image_file=image_file,
                question=question,
                temperature=0)
            states[i] = ret
    else:
        states = image_qa.run_batch(
            arguments,
            temperature=0,
            num_threads=args.parallel,
            progress_bar=True)
    latency = time.time() - tic

    print(f"Latency: {latency:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    print(f"Write output to {args.answer_file}")
    with open(args.answer_file, "w") as fout:
        for i in range(len(lines)):
            value = {
                "question_id": lines[i]["question_id"],
                "prompt": lines[i]["text"],
                "text": states[i]["answer"].strip(),
                "model_id": backend.model_info["model_path"],
                "answer_id": i,
                "metadata": {},
            }
            fout.write(json.dumps(value) + "\n")

    with open(args.result_file, "a") as fout:
        value = {
            "task": "llava_bench",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "num_requests": len(lines),
            "parallel": args.parallel,
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--question-file", type=str, default="questions.jsonl")
    parser.add_argument("--answer-file", type=str, default="answers.jsonl")
    parser.add_argument("--image-folder", type=str, default="./images")
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--num-questions", type=int, default=None)
    parser.add_argument("--max-tokens", type=int, default=768)
    args = add_common_sglang_args_and_parse(parser)
    main(args)
2
benchmark/llava_bench/bench_sglang_mme.sh
Normal file
@@ -0,0 +1,2 @@
MME_FOLDER=./mme_pack
python3 bench_sglang.py --num-questions 5000 --question-file $MME_FOLDER/llava_mme_bench_replace.jsonl --answer-file answer_mme.jsonl --image-folder $MME_FOLDER/MME_Benchmark_release_version --max-tokens 4
20
benchmark/llava_bench/download_images.py
Normal file
@@ -0,0 +1,20 @@
import os

# Create the 'images' directory if it doesn't exist.
if not os.path.exists('images'):
    os.makedirs('images')

# Base URL
base_url = "https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild/resolve/main/images/"

# Loop through the image numbers.
for i in range(1, 25):
    # Format the image number with leading zeros.
    image_number = str(i).zfill(3)
    image_url = base_url + image_number + ".jpg"
    image_path = "images/" + image_number + ".jpg"

    # Download the image using wget.
    os.system(f"wget -O {image_path} {image_url}")

print("Download complete.")
27
benchmark/llm_judge/README.md
Normal file
@@ -0,0 +1,27 @@
## Run benchmark

### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```

```
python3 bench_sglang.py --num-questions 25 --parallel 8
python3 bench_sglang.py --num-questions 16 --parallel 1
```

### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```

```
python3 bench_other.py --backend vllm --num-questions 25
```

### Benchmark guidance
```
python3 bench_other.py --backend guidance --num-questions 25 --parallel 1
```
120
benchmark/llm_judge/bench_other.py
Normal file
@@ -0,0 +1,120 @@
import argparse
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import json
import time

import numpy as np
from tqdm import tqdm

from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw
from sglang.utils import read_jsonl, dump_state_text


system_prompt = (
    "Please serve as an impartial judge and rigorously evaluate the quality of the following article. Apply the most stringent standards possible, showing no leniency."
)

dimension_prompts = [
    "Content: This refers to the essence of the essay. The substance should be well researched, accurate, relevant to the topic and should show a thorough understanding of the subject. The essay should also reflect a clear goal or purpose.",
    "Organization and Structure: An essay needs to be properly structured with a clear introduction, body, and conclusion. The essay should flow naturally, with one paragraph leading seamlessly into the next.",
    "Argument and Analysis: The argument made in the essay should be logical, coherent and clearly articulated. Each point made should be backed up by solid evidence and thorough analysis.",
    "Clarity and Precision: The essay should be written in a clear and concise manner. The points made should be easily understood by the reader. The language used should also be precise and unambiguous.",
    "Grammar and Punctuation: Proper use of grammar and punctuation is vital in an academic essay. Errors in grammar and punctuation not only distract the reader but can also negatively impact the meaning and interpretation of the content.",
    "Referencing and Citation: An essay should contain proper citations and references for all sources used. This not only prevents accusations of plagiarism but also gives credit to the authors of the works that have contributed to the essay. The citation should adhere to a specific format as required by the academic institution or specified by the professor.",
]


def multi_dimension_judge(article, generate):
    s = system_prompt
    s += "\n```\n" + article + "\n```\n\n"

    judges = []
    for i in range(len(dimension_prompts)):
        comp = generate(s +
                        "USER: Please judge the quality based on the following metric. " +
                        dimension_prompts[i] + " Please provide a single-paragraph judgement. " +
                        "Focus on the provided metric and do not say other things. "
                        'End your judgement paragraph with the word "END"\nJUDGE:',
                        max_tokens=256, stop="END")
        judges.append(comp)

    s += "I will judge the quality based on the following metrics.\n"
    for i in range(len(dimension_prompts)):
        s += dimension_prompts[i].split(":")[0] + ": " + judges[i].strip() + "\n"

    s += "In summary, on a scale of 1 to 10, I would give the article a score of"
    s += generate(s, max_tokens=2, stop=None)

    return s


def main(args):
    lines = read_jsonl(args.data_path)[:args.num_questions]
    states = [None] * len(lines)

    # Select backend
    if args.backend == "lightllm":
        url = f"{args.host}:{args.port}/generate"
        generate = partial(call_generate_lightllm, url=url, temperature=0)
    elif args.backend == "vllm":
        url = f"{args.host}:{args.port}/generate"
        generate = partial(call_generate_vllm, url=url, temperature=0)
    elif args.backend == "srt-raw":
        url = f"{args.host}:{args.port}/generate"
        generate = partial(call_generate_srt_raw, url=url, temperature=0)
    elif args.backend == "guidance":
        from guidance import models, gen

        model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096)

        def generate(prompt, max_tokens, stop):
            out = model + prompt + gen(name="answer",
                                       max_tokens=max_tokens, temperature=0, stop=stop)
            return out["answer"]

        # Warm up the model.
        generate("Hello!", max_tokens=8, stop=None)
    else:
        raise ValueError(f"Invalid backend: {args.backend}")

    # Run requests
    def get_one_answer(i):
        states[i] = multi_dimension_judge(lines[i], generate)

    tic = time.time()
    if args.parallel == 1:
        for i in tqdm(range(len(lines))):
            get_one_answer(i)
    else:
        with ThreadPoolExecutor(args.parallel) as executor:
            executor.map(get_one_answer, list(range(len(lines))))
    latency = time.time() - tic

    print(f"Latency: {latency:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "llm_judge",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "num_requests": args.num_questions,
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, default="articles.jsonl")
    parser.add_argument("--num-questions", type=int, default=20)
    args = add_common_other_args_and_parse(parser)
    main(args)
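The judging scripts here end the prompt with "on a scale of 1 to 10, I would give the article a score of" and generate at most two tokens, so the score comes back as free text. A small illustrative helper for recovering the integer score from such a completion (this parser is our own addition, not part of the benchmark):

```python
import re
from typing import Optional

def parse_score(completion: str) -> Optional[int]:
    """Extract the first standalone 1-10 value from a short completion."""
    match = re.search(r"\b(10|[1-9])\b", completion)
    return int(match.group(1)) if match else None

print(parse_score(" 8\n"))   # → 8
print(parse_score("ten"))    # no digits → None
```

Trying `10` before the single digits in the alternation avoids truncating "10" to "1".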
85
benchmark/llm_judge/bench_sglang.py
Normal file
@@ -0,0 +1,85 @@
import argparse
import json
import time

import numpy as np

import sglang as sgl
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
from sglang.utils import read_jsonl, dump_state_text


system_prompt = (
    "Please serve as an impartial judge and rigorously evaluate the quality of the following article. Apply the most stringent standards possible, showing no leniency."
)

dimension_prompts = [
    "Content: This refers to the essence of the essay. The substance should be well researched, accurate, relevant to the topic and should show a thorough understanding of the subject. The essay should also reflect a clear goal or purpose.",
    "Organization and Structure: An essay needs to be properly structured with a clear introduction, body, and conclusion. The essay should flow naturally, with one paragraph leading seamlessly into the next.",
    "Argument and Analysis: The argument made in the essay should be logical, coherent and clearly articulated. Each point made should be backed up by solid evidence and thorough analysis.",
    "Clarity and Precision: The essay should be written in a clear and concise manner. The points made should be easily understood by the reader. The language used should also be precise and unambiguous.",
    "Grammar and Punctuation: Proper use of grammar and punctuation is vital in an academic essay. Errors in grammar and punctuation not only distract the reader but can also negatively impact the meaning and interpretation of the content.",
    "Referencing and Citation: An essay should contain proper citations and references for all sources used. This not only prevents accusations of plagiarism but also gives credit to the authors of the works that have contributed to the essay. The citation should adhere to a specific format as required by the academic institution or specified by the professor.",
]


@sgl.function
def multi_dimension_judge(s, article):
    s += system_prompt
    s += "\n```\n" + article + "\n```\n\n"

    forks = s.fork(len(dimension_prompts))
    for i in range(len(dimension_prompts)):
        forks[i] += ("USER: Please judge the quality based on the following metric. " +
                     dimension_prompts[i] + " Please provide a single-paragraph judgement. " +
                     "Focus on the provided metric and do not say other things. "
                     'End your judgement paragraph with the word "END"\nJUDGE:')
        forks[i] += sgl.gen("judgement", max_tokens=256, stop="END")
    forks.join()

    s += "I will judge the quality based on the following metrics.\n"
    for i in range(len(dimension_prompts)):
        s += dimension_prompts[i].split(":")[0] + ": " + forks[i]["judgement"].strip() + "\n"

    s += "In summary, on a scale of 1 to 10, I would give the article a score of"
    s += sgl.gen("score", max_tokens=2)


def main(args):
    lines = read_jsonl(args.data_path)[:args.num_questions]
    arguments = [{"article": l} for l in lines]

    # Select backend
    backend = select_sglang_backend(args)

    # Run requests
    tic = time.time()
    states = multi_dimension_judge.run_batch(
        arguments, temperature=0, backend=backend, num_threads=args.parallel)
    latency = time.time() - tic

    print(f"Latency: {latency:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "llm_judge",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "num_requests": args.num_questions,
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, default="articles.jsonl")
    parser.add_argument("--num-questions", type=int, default=20)
|
||||
args = add_common_sglang_args_and_parse(parser)
|
||||
main(args)
|
||||
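The fork/join pattern in `multi_dimension_judge` above evaluates all judging dimensions concurrently and then merges the branches back into the main prompt. A minimal plain-Python sketch of that control flow, with a stub in place of the actual `sgl.gen` call (the stub text and dimension names are illustrative, not from the benchmark):

```python
from concurrent.futures import ThreadPoolExecutor

dimensions = ["Content", "Organization", "Clarity"]

def judge_one(dim):
    # Stub standing in for an LLM call (sgl.gen); a real fork would
    # generate a judgement paragraph terminated by the stop word "END".
    return f"{dim}: the article is adequate END"

# fork: one branch per dimension, executed in parallel; join: collect all.
with ThreadPoolExecutor(len(dimensions)) as pool:
    judgements = list(pool.map(judge_one, dimensions))

# Merge the branches back into the main prompt, stripping the stop word.
summary = "\n".join(j.rsplit(" END", 1)[0] for j in judgements)
print(len(judgements))  # → 3
```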
33
benchmark/long_json_decode/README.md
Normal file
@@ -0,0 +1,33 @@
## Run benchmark

### Benchmark sglang
```
python3 -m sglang.launch_server --model-path codellama/CodeLlama-7b-instruct-hf --port 30000
```

```
python3 bench_sglang.py --num-questions 5 --parallel 1
```

### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model codellama/CodeLlama-7b-instruct-hf --disable-log-requests --port 21000 --gpu 0.97
```

```
python3 bench_other.py --backend vllm --num-questions 5
```

### Benchmark guidance
```
python3 bench_other.py --backend guidance --num-questions 5 --parallel 1
```

### Build dataset
```
pip install wikipedia
python3 build_dataset.py
```
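The scripts in this benchmark decode a JSON object one field at a time, stopping each generation at the closing quote so the surrounding punctuation stays under the program's control. A self-contained sketch of that pattern with a canned stub backend (the field names and canned values here are illustrative only):

```python
import json

# Canned replies standing in for a real LLM backend.
canned = iter(["London", "United Kingdom", "LHR"])

def generate(prefix, max_tokens, stop):
    # A real backend would continue `prefix`; stop='"' truncates the
    # generation at the closing quote of the current field.
    return next(canned)

s = '{\n'
s += '  "name": "' + generate(s, max_tokens=8, stop='"') + '",\n'
s += '  "country": "' + generate(s, max_tokens=8, stop='"') + '",\n'
s += '  "airport_code": "' + generate(s, max_tokens=8, stop='"') + '"\n'
s += '}\n'

obj = json.loads(s)  # the assembled text is always well-formed JSON
print(obj["country"])  # → United Kingdom
```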
104
benchmark/long_json_decode/bench_other.py
Normal file
@@ -0,0 +1,104 @@
import argparse
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import json
import time

from tqdm import tqdm
import numpy as np
from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw
from sglang.utils import read_jsonl, dump_state_text


def json_decode(document, generate):
    s = "Please extract the information of a city from the following wikipedia page.\n"
    s += "Page begin.\n" + document + "Page end.\n"
    s += "Here is the name, country, and symbol of the city in JSON format.\n"
    s += '{\n'
    s += ' "name": "'
    s += generate(s, max_tokens=8, stop='"') + '",\n'
    s += ' "country": "'
    s += generate(s, max_tokens=8, stop='"') + '",\n'
    s += ' "air port code": "'
    s += generate(s, max_tokens=8, stop='"') + '",\n'
    s += ' "top 3 landmarks": "'
    s += generate(s, max_tokens=24, stop='"') + '",\n'
    s += '}\n'
    return s


def main(args):
    lines = read_jsonl(args.data_path)
    arguments = []
    for i in range(len(lines[:args.num_questions])):
        arguments.append({
            "document": lines[i]["document"],
        })
    states = [None] * len(arguments)

    # Select backend
    if args.backend == "lightllm":
        url = f"{args.host}:{args.port}/generate"
        generate = partial(call_generate_lightllm, url=url, temperature=0)
    elif args.backend == "vllm":
        url = f"{args.host}:{args.port}/generate"
        generate = partial(call_generate_vllm, url=url, temperature=0)
    elif args.backend == "srt-raw":
        url = f"{args.host}:{args.port}/generate"
        generate = partial(call_generate_srt_raw, url=url, temperature=0)
    elif args.backend == "guidance":
        from guidance import models, gen

        model = models.LlamaCpp("/home/ubuntu/model_weights/CodeLlama-7b-instruct-hf.gguf", n_gpu_layers=-1, n_ctx=11000)

        def generate(prompt, max_tokens, stop):
            out = model + prompt + gen(name="answer",
                                       max_tokens=max_tokens, temperature=0, stop=stop)
            return out["answer"]

        # warmup
        generate("Hello!", max_tokens=8, stop=None)
    else:
        raise ValueError(f"Invalid backend: {args.backend}")

    # Run requests
    def get_one_answer(i):
        states[i] = json_decode(generate=generate, **arguments[i])

    tic = time.time()
    if args.parallel == 1:
        for i in tqdm(range(len(arguments))):
            get_one_answer(i)
    else:
        with ThreadPoolExecutor(args.parallel) as executor:
            executor.map(get_one_answer, list(range(len(arguments))))
    latency = time.time() - tic

    # Print latency
    print(f"Latency: {latency:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "long_json_decode",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "num_requests": args.num_questions,
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            }
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, default="questions.jsonl")
    parser.add_argument("--num-questions", type=int, default=100)
    args = add_common_other_args_and_parse(parser)
    main(args)
68
benchmark/long_json_decode/bench_sglang.py
Normal file
@@ -0,0 +1,68 @@
import argparse
import json
import time

import numpy as np
import sglang as sgl
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
from sglang.utils import read_jsonl, dump_state_text


@sgl.function
def json_decode(s, document):
    s += "Please extract the information of a city from the following wikipedia page.\n"
    s += "Page begin.\n" + document + "Page end.\n"
    s += "Here is the name, country, and symbol of the city in JSON format.\n"
    s += '{\n'
    s += ' "name": "' + sgl.gen("name", max_tokens=8, stop='"') + '",\n'
    s += ' "country": "' + sgl.gen("country", max_tokens=8, stop='"') + '",\n'
    s += ' "air port code": "' + sgl.gen("air port code", max_tokens=8, stop='"') + '",\n'
    s += ' "top 3 landmarks": "' + sgl.gen("landmarks", max_tokens=24, stop='"') + '",\n'
    s += '}\n'


def main(args):
    lines = read_jsonl(args.data_path)
    arguments = []
    for i in range(len(lines[:args.num_questions])):
        arguments.append({
            "document": lines[i]["document"],
        })

    # Select backend
    backend = select_sglang_backend(args)
    sgl.set_default_backend(backend)

    # Run requests
    tic = time.time()
    states = json_decode.run_batch(
        arguments, temperature=0, num_threads=args.parallel)
    latency = time.time() - tic

    # Print latency
    print(f"Latency: {latency:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "long_json_decode",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "num_requests": args.num_questions,
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            }
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, default="questions.jsonl")
    parser.add_argument("--num-questions", type=int, default=10)
    args = add_common_sglang_args_and_parse(parser)
    main(args)
26
benchmark/long_json_decode/build_dataset.py
Normal file
@@ -0,0 +1,26 @@
import json

import transformers
import wikipedia


name = "meta-llama/Llama-2-7b-chat-hf"
t = transformers.AutoTokenizer.from_pretrained(name)
city_names = ["los angles", "london", "tokyo", "beijing", "singapore"]


for city_name in city_names:
    content = str(wikipedia.page(city_name).content)
    content = content.replace("\n\n", "\n")

    tokens = t.encode(content)

    truncate_len = int((10000 / len(tokens)) * len(content))
    truncate_content = content[:truncate_len]
    truncate_tokens = t.encode(truncate_content)

    # Count tokens
    print(f"city_name: {city_name}, #tokens: {len(tokens)}, #truncate tokens: {len(truncate_tokens)}")

    with open("questions.jsonl", "a") as fout:
        fout.write(json.dumps({"document": truncate_content}) + "\n")
56
benchmark/mmlu/README.md
Normal file
@@ -0,0 +1,56 @@
## Download data
```
wget https://people.eecs.berkeley.edu/~hendrycks/data.tar
tar xf data.tar
```

## Run benchmark

### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```

```
python3 bench_sglang.py --nsub 10
```

### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```

```
python3 bench_other.py --nsub 10 --backend vllm
```

### Benchmark lightllm
```
# A10G
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000

# V100
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 4500 --port 22000
```

```
python3 bench_other.py --nsub 10 --backend lightllm
```

### Benchmark guidance
```
python3 bench_other.py --nsub 10 --backend guidance --parallel 1
```

### Benchmark lmql
```
CUDA_VISIBLE_DEVICES=0,1 lmql serve-model meta-llama/Llama-2-7b-chat-hf --cuda --port 23000
```

```
python3 bench_other.py --nsub 10 --backend lmql --parallel 2
```
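Both MMLU drivers below build a few-shot prompt by concatenating formatted dev-set examples (question, lettered options, and answer) in front of each test question, which ends with a bare "Answer:" so the model emits a single letter. A dependency-free sketch of that formatting with toy questions (no pandas; the rows here are illustrative):

```python
choices = ["A", "B", "C", "D"]

def format_example(row, include_answer=True):
    # row = (question, option_a, option_b, option_c, option_d, answer)
    prompt = row[0]
    for j, option in enumerate(row[1:5]):
        prompt += "\n{}. {}".format(choices[j], option)
    prompt += "\nAnswer:"
    if include_answer:
        prompt += " {}\n\n".format(row[5])
    return prompt

dev = [("2 + 2 = ?", "3", "4", "5", "6", "B")]
test = ("3 + 3 = ?", "5", "6", "7", "8", "B")

prompt = "The following are multiple choice questions (with answers) about math.\n\n"
for row in dev:
    prompt += format_example(row)                     # few-shot example with answer
prompt += format_example(test, include_answer=False)  # test question, no answer

print(prompt.endswith("Answer:"))  # → True
```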
202
benchmark/mmlu/bench_other.py
Normal file
@@ -0,0 +1,202 @@
import argparse
import asyncio
from concurrent.futures import ThreadPoolExecutor
import json
from functools import partial
import os
import time

import numpy as np
import pandas as pd
import tiktoken
from tqdm import tqdm
from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw


choices = ["A", "B", "C", "D"]

tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")


def format_subject(subject):
    l = subject.split("_")
    s = ""
    for entry in l:
        s += " " + entry
    return s


def format_example(df, idx, include_answer=True):
    prompt = df.iloc[idx, 0]
    k = df.shape[1] - 2
    for j in range(k):
        prompt += "\n{}. {}".format(choices[j], df.iloc[idx, j+1])
    prompt += "\nAnswer:"
    if include_answer:
        prompt += " {}\n\n".format(df.iloc[idx, k + 1])
    return prompt


def gen_prompt(train_df, subject, k=-1):
    prompt = "The following are multiple choice questions (with answers) about{}.\n\n".format(format_subject(subject))
    if k == -1:
        k = train_df.shape[0]
    for i in range(k):
        prompt += format_example(train_df, i)
    return prompt


model_initialized = None


def evaluate(args, subject, dev_df, test_df):
    prompts = []
    labels = []

    # Construct prompts
    k = args.ntrain
    train_prompt = gen_prompt(dev_df, subject, k)
    while len(tokenizer.encode(train_prompt)) > 1536:
        k -= 1
        train_prompt = gen_prompt(dev_df, subject, k)

    for i in range(test_df.shape[0]):
        prompt_end = format_example(test_df, i, include_answer=False)
        prompt = train_prompt + prompt_end
        prompts.append(prompt)

        label = test_df.iloc[i, test_df.shape[1]-1]
        labels.append(label)

    preds = [None] * len(prompts)
    max_tokens = 1

    # Select backend
    global model_initialized

    if args.backend == "lightllm":
        url = f"{args.host}:{args.port}/generate"
        call_generate = partial(call_generate_lightllm, url=url, stop=None)
    elif args.backend == "vllm":
        url = f"{args.host}:{args.port}/generate"
        call_generate = partial(call_generate_vllm, url=url, stop=None)
    elif args.backend == "srt-raw":
        url = f"{args.host}:{args.port}/generate"
        call_generate = partial(call_generate_srt_raw, url=url, stop=None)
    elif args.backend == "guidance":
        from guidance import models, gen

        if model_initialized is None:
            model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096)
            model_initialized = model
        else:
            model = model_initialized

        def call_generate(prompt, temperature, max_tokens):
            out = model + prompt + gen(name="answer",
                                       max_tokens=max_tokens, temperature=0)
            return out["answer"]

    elif args.backend == "lmql":
        import lmql
        model = lmql.model("meta-llama/Llama-2-7b-chat-hf",
                           endpoint=f"{args.host}:{args.port}")

        @lmql.query(model=model)
        async def program(question):
            '''lmql
            """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < 2
            return ANSWER
            '''

        async def call_generate(prompt, temperature, max_tokens):
            return await program(question=prompt, temperature=temperature)
    else:
        raise ValueError(f"Invalid backend: {args.backend}")

    # Run requests
    if args.backend != "lmql":
        # Use thread pool
        def get_one_answer(i):
            pred = call_generate(prompts[i], temperature=0,
                                 max_tokens=max_tokens)
            preds[i] = pred.strip()[0]

        tic = time.time()
        if args.parallel == 1:
            for i in range(len(prompts)):
                get_one_answer(i)
        else:
            with ThreadPoolExecutor(args.parallel) as executor:
                executor.map(get_one_answer, list(range(len(prompts))))
    else:
        # Use asyncio
        async def batched_call(batch_size):
            for i in range(0, len(prompts), batch_size):
                tasks = []
                for p in prompts[i:i+batch_size]:
                    tasks.append(call_generate(p,
                        temperature=0, max_tokens=max_tokens))
                rets = await asyncio.gather(*tasks)
                for j in range(len(rets)):
                    preds[i+j] = rets[j].strip()[0]

        tic = time.time()
        asyncio.run(batched_call(batch_size=args.parallel))
    latency = time.time() - tic

    # Compute accuracy
    cors = [pred == label for pred, label in zip(preds, labels)]
    acc = np.mean(cors)
    cors = np.array(cors)

    print("Average accuracy {:.3f}, latency {:.2f}, #q: {} - {}".format(
        acc, latency, len(prompts), subject))

    return cors, acc, latency


def main(args):
    subjects = sorted([f.split("_test.csv")[0] for f in os.listdir(os.path.join(args.data_dir, "test")) if "_test.csv" in f])

    all_cors = []
    all_latencies = []
    num_requests = 0

    for subject in tqdm(subjects[:args.nsub]):
        dev_df = pd.read_csv(os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None)[:args.ntrain]
        test_df = pd.read_csv(os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None)

        cors, acc, latency = evaluate(args, subject, dev_df, test_df)
        all_cors.append(cors)
        all_latencies.append(latency)
        num_requests += len(test_df)

    total_latency = np.sum(all_latencies)
    print("Total latency: {:.3f}".format(total_latency))

    weighted_acc = np.mean(np.concatenate(all_cors))
    print("Average accuracy: {:.3f}".format(weighted_acc))

    # Write results
    with open(args.result_file, "a") as fout:
        value = {
            "task": "mmlu",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(total_latency, 3),
            "accuracy": round(weighted_acc, 3),
            "num_requests": num_requests,
            "other": {
                "nsub": args.nsub,
                "parallel": args.parallel,
            }
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ntrain", type=int, default=5)
    parser.add_argument("--data_dir", type=str, default="data")
    parser.add_argument("--nsub", type=int, default=60)
    args = add_common_other_args_and_parse(parser)
    main(args)
143
benchmark/mmlu/bench_sglang.py
Normal file
@@ -0,0 +1,143 @@
import argparse
import json
import os
import time

import numpy as np
import pandas as pd
import tiktoken
from tqdm import tqdm
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend


choices = ["A", "B", "C", "D"]

tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")


def format_subject(subject):
    l = subject.split("_")
    s = ""
    for entry in l:
        s += " " + entry
    return s


def format_example(df, idx, include_answer=True):
    prompt = df.iloc[idx, 0]
    k = df.shape[1] - 2
    for j in range(k):
        prompt += "\n{}. {}".format(choices[j], df.iloc[idx, j+1])
    prompt += "\nAnswer:"
    if include_answer:
        prompt += " {}\n\n".format(df.iloc[idx, k + 1])
    return prompt


def gen_prompt(train_df, subject, k=-1):
    prompt = "The following are multiple choice questions (with answers) about{}.\n\n".format(format_subject(subject))
    if k == -1:
        k = train_df.shape[0]
    for i in range(k):
        prompt += format_example(train_df, i)
    return prompt


def evaluate(args, subject, dev_df, test_df):
    prompts = []
    labels = []

    k = args.ntrain
    few_shot_examples = gen_prompt(dev_df, subject, k)
    while len(tokenizer.encode(few_shot_examples)) > 1536:
        k -= 1
        few_shot_examples = gen_prompt(dev_df, subject, k)

    for i in range(test_df.shape[0]):
        prompt_end = format_example(test_df, i, include_answer=False)
        prompts.append(prompt_end)

        label = test_df.iloc[i, test_df.shape[1]-1]
        labels.append(label)

    arguments = [{"question": p} for p in prompts]

    #####################################
    ######### SGL Program Begin #########
    #####################################

    import sglang as sgl

    @sgl.function
    def few_shot_mmlu(s, examples, question):
        s += examples + question + sgl.gen("answer")

    #####################################
    ########## SGL Program End ##########
    #####################################

    # Select backend
    backend = select_sglang_backend(args)

    tic = time.time()
    states = few_shot_mmlu.bind(examples=few_shot_examples).run_batch(
        arguments, temperature=0, max_new_tokens=1,
        backend=backend, num_threads=args.parallel)
    preds = [s["answer"].strip()[0] if len(s["answer"].strip()) > 0 else ""
             for s in states]
    latency = time.time() - tic

    cors = [pred == label for pred, label in zip(preds, labels)]
    acc = np.mean(cors)
    cors = np.array(cors)

    print("Average accuracy {:.3f}, latency {:.2f}, #q: {} - {}".format(
        acc, latency, len(prompts), subject))

    return cors, acc, latency


def main(args):
    subjects = sorted([f.split("_test.csv")[0] for f in os.listdir(os.path.join(args.data_dir, "test")) if "_test.csv" in f])

    all_cors = []
    all_latencies = []
    num_requests = 0

    for subject in tqdm(subjects[:args.nsub]):
        dev_df = pd.read_csv(os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None)[:args.ntrain]
        test_df = pd.read_csv(os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None)

        cors, acc, latency = evaluate(args, subject, dev_df, test_df)
        all_cors.append(cors)
        all_latencies.append(latency)
        num_requests += len(test_df)

    total_latency = np.sum(all_latencies)
    print("Total latency: {:.3f}".format(total_latency))

    weighted_acc = np.mean(np.concatenate(all_cors))
    print("Average accuracy: {:.3f}".format(weighted_acc))

    # Write results
    with open(args.result_file, "a") as fout:
        value = {
            "task": "mmlu",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(total_latency, 3),
            "accuracy": round(weighted_acc, 3),
            "num_requests": num_requests,
            "other": {
                "nsub": args.nsub,
                "parallel": args.parallel,
            }
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ntrain", "-k", type=int, default=5)
    parser.add_argument("--data_dir", "-d", type=str, default="data")
    parser.add_argument("--save_dir", "-s", type=str, default="results")
    parser.add_argument("--nsub", type=int, default=60)
    args = add_common_sglang_args_and_parse(parser)
    main(args)
31
benchmark/mtbench/README.md
Normal file
@@ -0,0 +1,31 @@
## Run benchmark

### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```

```
python3 bench_sglang.py --num-questions 80
```

### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```

```
python3 bench_other.py --num-questions 80 --backend vllm
```

### Benchmark lightllm
```
# A10G
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
```

```
python3 bench_other.py --num-questions 80 --backend lightllm
```
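The drivers below run each MT-bench question as a two-turn conversation: the first answer is appended back into the conversation before the second question is asked, so turn two is conditioned on turn one. A sketch of that loop with a stub model and a deliberately simplified template (the real scripts use fastchat's conversation template, not this format):

```python
def stub_model(prompt):
    # Stands in for a real generate() call: echo the latest user turn.
    last_user = prompt.rsplit("USER: ", 1)[1].split("\n")[0]
    return "answer to: " + last_user

turns = ["What is 1+1?", "And doubled?"]
prompt, answers = "", []
for q in turns:
    prompt += f"USER: {q}\nASSISTANT: "
    out = stub_model(prompt)
    answers.append(out)
    prompt += out + "\n"  # the answer becomes context for the next turn

print(answers[1])  # → answer to: And doubled?
```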
116
benchmark/mtbench/bench_other.py
Normal file
@@ -0,0 +1,116 @@
import argparse
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import json
import os
import time
import uuid

from fastchat.model import get_conversation_template
import requests
from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt


def load_questions(filename):
    questions = []
    with open(filename, "r") as fin:
        for line in fin:
            obj = json.loads(line)
            questions.append(obj)
    return questions


def write_answers(filename, model_id, questions, answers):
    with open(os.path.expanduser(filename), "w") as fout:
        for i in range(len(answers)):
            ans_json = {
                "question_id": questions[i]["question_id"],
                "answer_id": uuid.uuid4().hex,
                "model_id": model_id,
                "choices": {
                    "index": 0,
                    "turns": [answers[i][0], answers[i][1]],
                },
                "tstamp": time.time(),
            }
            fout.write(json.dumps(ans_json) + "\n")


def main(args):
    questions = load_questions(args.question_file)
    questions = (questions * 10)[:args.num_questions]
    max_tokens = 256
    model_id = "llama-2-chat"

    conv_main = get_conversation_template(model_id)

    # Select backend
    if args.backend == "lightllm":
        url = f"{args.host}:{args.port}/generate"
        call_generate = partial(call_generate_lightllm, url=url, stop=None)
    elif args.backend == "vllm":
        url = f"{args.host}:{args.port}/generate"
        call_generate = partial(call_generate_vllm, url=url, stop=None)
    elif args.backend == "srt":
        url = f"{args.host}:{args.port}/generate"
        call_generate = partial(call_generate_srt, url=url, stop=None)
    else:
        raise ValueError(f"Invalid backend: {args.backend}")

    answers = [None] * len(questions)

    def get_answer(i):
        conv = conv_main.copy()
        cur_answers = []
        for j in range(2):
            q = questions[i]["turns"][j]
            conv.append_message(conv.roles[0], q)
            conv.append_message(conv.roles[1], None)

            prompt = conv.get_prompt()
            output = call_generate(prompt,
                                   temperature=0, max_tokens=max_tokens).strip()

            cur_answers.append(output)
            conv.update_last_message(output)

        answers[i] = cur_answers

    # Run requests
    tic = time.time()
    if args.parallel == 1:
        for i in range(len(questions)):
            get_answer(i)
    else:
        with ThreadPoolExecutor(args.parallel) as executor:
            executor.map(get_answer, list(range(len(questions))))
    latency = time.time() - tic

    print(f"#questions: {len(questions)}, Latency: {latency:.2f}")

    # Write results
    answer_file = args.answer_file or f"tmp_output_{args.backend}.txt"
    write_answers(answer_file, model_id, questions, answers)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "mtbench",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "num_requests": args.num_questions,
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            }
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--question-file", type=str, default="question.jsonl")
    parser.add_argument("--answer-file", type=str, default=None)
    parser.add_argument("--num-questions", type=int, default=80)
    args = add_common_other_args_and_parse(parser)
    main(args)
95
benchmark/mtbench/bench_sglang.py
Normal file
@@ -0,0 +1,95 @@
import argparse
import json
import os
import time
import uuid

import sglang as sgl
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend


def load_questions(filename):
    questions = []
    with open(filename, "r") as fin:
        for line in fin:
            obj = json.loads(line)
            questions.append(obj)
    return questions


def write_answers(filename, model_id, questions, answers):
    with open(os.path.expanduser(filename), "w") as fout:
        for i in range(len(answers)):
            ans_json = {
                "question_id": questions[i]["question_id"],
                "answer_id": uuid.uuid4().hex,
                "model_id": model_id,
                "choices": {
                    "index": 0,
                    "turns": [answers[i][0], answers[i][1]],
                },
                "tstamp": time.time(),
            }
            fout.write(json.dumps(ans_json) + "\n")


@sgl.function
def answer_mt_bench(s, question_1, question_2):
    s += sgl.system()
    s += sgl.user(question_1)
    s += sgl.assistant(sgl.gen("answer_1"))
    s += sgl.user(question_2)
    s += sgl.assistant(sgl.gen("answer_2"))


def main(args):
    # Construct prompts
    questions = load_questions(args.question_file)[:args.num_questions]
    arguments = [
        {"question_1": q["turns"][0], "question_2": q["turns"][1]}
        for q in questions
    ]

    # Select backend
    backend = select_sglang_backend(args)
    sgl.set_default_backend(backend)

    # Run requests
    tic = time.time()
    rets = answer_mt_bench.run_batch(
        arguments,
        temperature=0,
        max_new_tokens=256,
        num_threads=args.parallel)
    answers = [[s["answer_1"], s["answer_2"]] for s in rets]
    latency = time.time() - tic

    print(f"#questions: {len(questions)}, Latency: {latency:.2f}")

    # Write results
    model_id = backend.model_info["model_path"]
    answer_file = args.answer_file or f"tmp_output_{args.backend}.txt"
    write_answers(answer_file, model_id, questions, answers)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "mtbench",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "num_requests": args.num_questions,
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            }
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--question-file", type=str, default="question.jsonl")
    parser.add_argument("--answer-file", type=str, default=None)
    parser.add_argument("--num-questions", type=int, default=80)
    args = add_common_sglang_args_and_parse(parser)
    main(args)
43
benchmark/multi_chain_reasoning/README.md
Normal file
@@ -0,0 +1,43 @@
## Download data
```
wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
```

## Run benchmark

### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```

```
python3 bench_sglang.py --num-questions 64
python3 bench_sglang.py --num-questions 32 --parallel 1
```

### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```

```
python3 bench_other.py --num-questions 64 --backend vllm
```

### Benchmark lightllm
```
# A10G
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
```

```
python3 bench_other.py --num-questions 64 --backend lightllm
```

### Benchmark guidance
```
python3 bench_other.py --num-questions 8 --backend guidance --parallel 1
```
195
benchmark/multi_chain_reasoning/bench_other.py
Normal file
@@ -0,0 +1,195 @@
import argparse
import ast
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import json
import re
import time

import numpy as np
from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw
from sglang.utils import read_jsonl, dump_state_text


INVALID = -9999999


def get_answer_value(answer_str):
    answer_str = answer_str.replace(",", "")
    numbers = re.findall(r'\d+', answer_str)
    if len(numbers) < 1:
        return INVALID
    try:
        return ast.literal_eval(numbers[-1])
    except SyntaxError:
        return INVALID
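The extractor above strips thousands separators and keeps the last integer found in the text, falling back to the `INVALID` sentinel. A standalone check of the same logic:

```python
import ast
import re

INVALID = -9999999


def get_answer_value(answer_str):
    # Strip thousands separators, then keep the last integer in the text.
    answer_str = answer_str.replace(",", "")
    numbers = re.findall(r"\d+", answer_str)
    if len(numbers) < 1:
        return INVALID
    try:
        return ast.literal_eval(numbers[-1])
    except SyntaxError:
        return INVALID


print(get_answer_value("The final answer is 1,234."))  # → 1234
```

Because only the last integer is kept, an answer string that restates intermediate numbers still resolves to the model's final figure.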
prompt_lib = [
    "Let us think step by step.",
    "Approach this methodically. Let's dissect the problem into smaller, more manageable parts.",
    "It's important to proceed step by step, ensuring accuracy at each stage.",
    "Take a deep breath and break this down.",
    "A little bit of arithmetic and a logical approach will help us quickly arrive at the solution to this problem.",
    "I am extremely good at math.",
]


def multi_chain_gsm8k(question, num_chains, call_generate):
    s = "Question: " + question + "\n"
    # s += call_generate(s + "Answer: " + prompt_lib[0], max_tokens=256,
    #                    stop="Question", temperature=0)
    # return s

    comps = []
    for i in range(num_chains):
        comps.append(call_generate(s + "Answer: " + prompt_lib[i % num_chains],
            max_tokens=256, temperature=0.3, stop="Question"))

    s += "Answer: To answer this question, here are some possible solutions. "
    s += "After considering all of them, I will do a majority vote.\n\n"
    for i in range(num_chains):
        s += f"Solution {i+1}: " + comps[i].strip() + "\n\n"
    s += "\nBy considering the above solutions and doing a majority vote, I think the final answer (a single integer number) is "
    s += call_generate(s, max_tokens=16, temperature=0, stop=None)
    return s
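Note that `multi_chain_gsm8k` delegates the majority vote itself to the model via one final generation call. A programmatic alternative — not what this benchmark does — would extract an answer from each chain and vote in code:

```python
from collections import Counter


def majority_vote(chain_answers):
    # Pick the most common extracted answer across chains. Among equal
    # counts, Counter.most_common keeps first-encountered order (CPython).
    return Counter(chain_answers).most_common(1)[0][0]


print(majority_vote([7, 8, 8, 9, 8]))  # → 8
```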
def main(args):
    lines = read_jsonl(args.data_path)

    # Construct prompts
    k = args.num_shot

    questions = []
    labels = []
    for i in range(len(lines[:args.num_questions])):
        questions.append(lines[i]["question"])
        labels.append(get_answer_value(lines[i]["answer"]))
    assert all(l != INVALID for l in labels)

    states = [None] * len(labels)

    # Select backend
    if args.backend == "lightllm":
        url = f"{args.host}:{args.port}/generate"
        call_generate = partial(call_generate_lightllm, url=url)
    elif args.backend == "vllm":
        url = f"{args.host}:{args.port}/generate"
        call_generate = partial(call_generate_vllm, url=url)
    elif args.backend == "srt-raw":
        url = f"{args.host}:{args.port}/generate"
        call_generate = partial(call_generate_srt_raw, url=url)
    elif args.backend == "guidance":
        from guidance import models, gen

        model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096)

        def call_generate(prompt, temperature, max_tokens, stop):
            out = model + prompt + gen(name="answer",
                max_tokens=max_tokens, temperature=temperature, stop=stop)
            return out["answer"]

        # def multi_chain_gsm8k(question, num_chains, call_generate):
        #     s = model + "Question: " + question + "\n"
        #
        #     comps = []
        #     for i in range(num_chains):
        #         comps.append(call_generate(s + "Answer: " + prompt_lib[i % num_chains],
        #             max_tokens=256, temperature=0.3, stop="Question"))
        #
        #     s += "Answer: To answer this question, here are some possible solutions. "
        #     s += "After considering all of them, I will do a majority vote.\n\n"
        #     for i in range(num_chains):
        #         s += f"Solution {i+1}: " + comps[i].strip() + "\n\n"
        #     s += "\nBy considering the above solutions and doing a majority vote, I think the final answer (a single integer number) is "
        #     return call_generate(s, max_tokens=16, temperature=0, stop=None)
    elif args.backend == "lmql":
        import lmql
        model = lmql.model("meta-llama/Llama-2-7b-chat-hf",
            endpoint=f"{args.host}:{args.port}")

        @lmql.query(model=model)
        async def program(question):
            '''lmql
            """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < 257 and STOPS_AT(ANSWER, "Question")
            return ANSWER
            '''

        async def call_generate(prompt, temperature, max_tokens, stop):
            return await program(question=prompt, temperature=0)

    else:
        raise ValueError(f"Invalid backend: {args.backend}")

    # Run requests
    if args.backend != "lmql":
        # Use thread pool
        def get_one_answer(i):
            answer = multi_chain_gsm8k(questions[i], args.num_chains,
                call_generate)
            states[i] = answer

        tic = time.time()
        if args.parallel == 1:
            for i in range(len(questions)):
                get_one_answer(i)
        else:
            with ThreadPoolExecutor(args.parallel) as executor:
                executor.map(get_one_answer, list(range(len(questions))))
    else:
        # Use asyncio
        async def batched_call(batch_size):
            for i in range(0, len(questions), batch_size):
                tasks = []
                for q in questions[i:i+batch_size]:
                    tasks.append(call_generate(q,
                        temperature=0, max_tokens=256, stop="Question"))
                rets = await asyncio.gather(*tasks)
                for j in range(len(rets)):
                    states[i+j] = rets[j]

        tic = time.time()
        asyncio.run(batched_call(batch_size=args.parallel))
    latency = time.time() - tic

    preds = []
    for i in range(len(states)):
        preds.append(get_answer_value(states[i]))

    # Compute accuracy
    acc = np.mean(np.array(preds) == np.array(labels))
    invalid = np.mean(np.array(preds) == INVALID)
    print(f"Latency: {latency:.3f}")
    print(f"Invalid: {invalid:.3f}")
    print(f"Accuracy: {acc:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "multi_chain_gsm8k",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "accuracy": round(acc, 3),
            "num_requests": args.num_questions,
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            }
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-shot", type=int, default=0)
    parser.add_argument("--num-chains", type=int, default=5)
    parser.add_argument("--data-path", type=str, default="test.jsonl")
    parser.add_argument("--num-questions", type=int, default=50)
    args = add_common_other_args_and_parse(parser)
    main(args)
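The two summary metrics printed at the end of the script reduce to elementwise comparisons over the extracted predictions. A dependency-free equivalent (the sample values are illustrative):

```python
INVALID = -9999999


def score(preds, labels):
    # Mirrors the benchmark's two summary metrics: exact-match accuracy
    # and the fraction of predictions with no extractable integer.
    acc = sum(p == l for p, l in zip(preds, labels)) / len(labels)
    invalid = sum(p == INVALID for p in preds) / len(preds)
    return acc, invalid


acc, invalid = score([8, 42, INVALID, 7], [8, 41, 3, 7])
print(acc, invalid)  # → 0.5 0.25
```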
129
benchmark/multi_chain_reasoning/bench_sglang.py
Normal file
@@ -0,0 +1,129 @@
import argparse
import ast
import json
import re
import time

import numpy as np
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
from sglang.utils import read_jsonl, dump_state_text


INVALID = -9999999


def get_answer_value(answer_str):
    answer_str = answer_str.replace(",", "")
    numbers = re.findall(r'\d+', answer_str)
    if len(numbers) < 1:
        return INVALID
    try:
        return ast.literal_eval(numbers[-1])
    except SyntaxError:
        return INVALID


prompt_lib = [
    "Let us think step by step.",
    "Approach this methodically. Let's dissect the problem into smaller, more manageable parts.",
    "It's important to proceed step by step, ensuring accuracy at each stage.",
    "Take a deep breath and break this down.",
    "A little bit of arithmetic and a logical approach will help us quickly arrive at the solution to this problem.",
    "I am extremely good at math.",
]


def main(args):
    lines = read_jsonl(args.data_path)

    # Construct prompts
    # k = args.num_shot
    # few_shot_examples = get_few_shot_examples(lines, k)

    questions = []
    labels = []
    for i in range(len(lines[:args.num_questions])):
        questions.append(lines[i]["question"])
        labels.append(get_answer_value(lines[i]["answer"]))
    assert all(l != INVALID for l in labels)
    arguments = [{"question": q} for q in questions]

    num_chains = args.num_chains

    #####################################
    ######### SGL Program Begin #########
    #####################################

    import sglang as sgl

    @sgl.function
    def multi_chain_gsm8k(s, question):
        s += "Question: " + question + "\n"
        # s += "Answer: " + prompt_lib[0] + sgl.gen("answer", max_tokens=256, stop="Question",
        #     temperature=0)
        # return

        forks = s.fork(num_chains)
        for i in range(num_chains):
            forks[i] += ("Answer: " + prompt_lib[i % num_chains] +
                sgl.gen("chain", max_tokens=256, temperature=0.3, stop="Question"))
        forks.join()

        s += "Answer: To answer this question, here are some possible solutions. "
        s += "After considering all of them, I will do a majority vote.\n\n"
        for i in range(num_chains):
            s += f"Solution {i+1}: " + forks[i]["chain"].strip() + "\n\n"
        s += "\nBy considering the above solutions and doing a majority vote, I think the final answer (a single integer number) is "
        s += sgl.gen("answer", max_tokens=16)

    #####################################
    ########## SGL Program End ##########
    #####################################
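In the program above, `s.fork(num_chains)` creates parallel continuations of the shared prompt state and `forks.join()` waits for all of them before the joined state continues. Outside SGLang, the same fan-out/fan-in shape can be sketched with a thread pool (the `generate` stub is illustrative, not a real backend):

```python
from concurrent.futures import ThreadPoolExecutor


def fork_join(prefix, branches, generate):
    # Fan out: each branch extends a copy of the shared prefix in parallel.
    with ThreadPoolExecutor(len(branches)) as pool:
        chains = list(pool.map(lambda b: generate(prefix + b), branches))
    # Fan in: after joining, the caller sees every branch's completion.
    return chains


chains = fork_join(
    "Question: 2+2?\n",
    ["Answer: Let us think step by step.", "Answer: I am extremely good at math."],
    lambda prompt: f"<completion of {len(prompt)} chars>",
)
```

The runtime advantage in SGLang is that the forked branches share the KV cache of the common prefix instead of recomputing it per branch.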
    # Select backend
    backend = select_sglang_backend(args)

    # Run requests
    tic = time.time()
    states = multi_chain_gsm8k.run_batch(
        arguments, temperature=0, backend=backend, num_threads=args.parallel)
    latency = time.time() - tic

    preds = []
    for i in range(len(states)):
        preds.append(get_answer_value(states[i]["answer"]))

    # Compute accuracy
    acc = np.mean(np.array(preds) == np.array(labels))
    invalid = np.mean(np.array(preds) == INVALID)
    print(f"Latency: {latency:.3f}")
    print(f"Invalid: {invalid:.3f}")
    print(f"Accuracy: {acc:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "multi_chain_gsm8k",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "accuracy": round(acc, 3),
            "num_requests": args.num_questions,
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            }
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-shot", type=int, default=0)
    parser.add_argument("--num-chains", type=int, default=5)
    parser.add_argument("--data-path", type=str, default="test.jsonl")
    parser.add_argument("--num-questions", type=int, default=50)
    args = add_common_sglang_args_and_parse(parser)
    main(args)
47
benchmark/multi_document_qa/README.md
Normal file
@@ -0,0 +1,47 @@
## Run benchmark

### Benchmark sglang
```
python3 -m sglang.launch_server --model-path codellama/CodeLlama-7b-instruct-hf --port 30000
```

```
python3 bench_sglang.py --num-questions 10 --parallel 1
```

### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model codellama/CodeLlama-7b-instruct-hf --disable-log-requests --port 21000 --gpu 0.97
```

```
python3 bench_other.py --backend vllm --num-questions 64
```

### Benchmark guidance
```
python3 bench_other.py --backend guidance --num-questions 32 --parallel 1
```

### Build dataset

```
pip install PyPDF2
python3 build_dataset.py
```

```python
import PyPDF2

with open('llama2.pdf', 'rb') as file:
    reader = PyPDF2.PdfReader(file)
    text = ''
    for page_num in range(len(reader.pages)):
        text += reader.pages[page_num].extract_text()
with open('output.txt', 'w') as text_file:
    text_file.write(text)
```
126
benchmark/multi_document_qa/bench_other.py
Normal file
@@ -0,0 +1,126 @@
import argparse
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import json
import time

from tqdm import tqdm
import numpy as np
from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw
from sglang.utils import read_jsonl, dump_state_text


USER_PREFIX = "[INST] "
USER_SUFFIX = " [/INST]"
ASSISTANT_PREFIX = ""
ASSISTANT_SUFFIX = " </s><s>"


def multi_document_qa(docs, question, generate):
    s = USER_PREFIX
    s += "Please answer a question according to given documents.\n"
    s += "Question:" + question + "Documents begin.\n"

    s += "".join(docs)

    s += "\nDocuments end."
    s += ("\n\nBased on the above documents, please answer this question:\n" + question + "\nAnswer in three words or fewer.")
    s += USER_SUFFIX
    s += ASSISTANT_PREFIX
    answer = generate(s, max_tokens=16, stop=None)
    return answer
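The `USER_PREFIX`/`USER_SUFFIX` constants follow the Llama-2 instruction-tuned chat template, where one user turn is wrapped in `[INST] ... [/INST]`. A minimal turn built with the same constants:

```python
USER_PREFIX = "[INST] "
USER_SUFFIX = " [/INST]"


def user_turn(text):
    # Wrap one user message in the Llama-2 instruction markers.
    return USER_PREFIX + text + USER_SUFFIX


prompt = user_turn("What is 2+2?")
print(prompt)  # → [INST] What is 2+2? [/INST]
```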
def main(args):
    lines = read_jsonl(args.data_path)
    l = lines[0]
    arguments = []
    labels = []

    num_docs = 10
    if args.backend == "guidance":
        num_docs = 7  # due to OOM

    for i in range(len(l["questions"][:args.num_questions])):
        arguments.append({
            "docs": l["documents"][:num_docs],
            "question": l["questions"][i],
        })
        labels.append(l["answers"][i])
    states = [None] * len(arguments)

    # Select backend
    if args.backend == "lightllm":
        url = f"{args.host}:{args.port}/generate"
        generate = partial(call_generate_lightllm, url=url, temperature=0)
    elif args.backend == "vllm":
        url = f"{args.host}:{args.port}/generate"
        generate = partial(call_generate_vllm, url=url, temperature=0)
    elif args.backend == "srt-raw":
        url = f"{args.host}:{args.port}/generate"
        generate = partial(call_generate_srt_raw, url=url, temperature=0)
    elif args.backend == "guidance":
        from guidance import models, gen

        model = models.LlamaCpp("/home/ubuntu/model_weights/CodeLlama-7b-instruct-hf.gguf", n_gpu_layers=-1, n_ctx=11000)

        def generate(prompt, max_tokens, stop):
            out = model + prompt + gen(name="answer",
                max_tokens=max_tokens, temperature=0, stop=stop)
            return out["answer"]

        # warmup
        generate("Hello!", max_tokens=8, stop=None)
    else:
        raise ValueError(f"Invalid backend: {args.backend}")
    # Run requests
    def get_one_answer(i):
        states[i] = multi_document_qa(generate=generate, **arguments[i])

    tic = time.time()
    if args.parallel == 1:
        for i in tqdm(range(len(labels))):
            get_one_answer(i)
    else:
        with ThreadPoolExecutor(args.parallel) as executor:
            executor.map(get_one_answer, list(range(len(labels))))
    latency = time.time() - tic

    # Compute accuracy
    print(states)
    correct = 0
    for s, label in zip(states, labels):
        answer = s.lower()
        if all(x in answer for x in label.lower().split(" ")):
            correct += 1
    accuracy = correct / len(labels)
    print(f"Accuracy: {accuracy:.3f}")
    print(f"Latency: {latency:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "multi_document_qa",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "num_requests": args.num_questions,
            "accuracy": accuracy,
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            }
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, default="questions.jsonl")
    parser.add_argument("--num-questions", type=int, default=100)
    args = add_common_other_args_and_parse(parser)
    main(args)
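Grading in this benchmark is deliberately lenient: a prediction counts as correct when every whitespace-separated token of the label appears somewhere in the lowercased answer. A standalone check of that rule:

```python
def is_correct(prediction, label):
    # Correct iff every token of the label occurs in the prediction
    # (substring match, case-insensitive).
    answer = prediction.lower()
    return all(x in answer for x in label.lower().split(" "))


print(is_correct("It is Llama 2 Chat.", "Llama 2 Chat"))  # → True
print(is_correct("around 70b parameters", "70 B"))        # → True
```

Note the second example: because matching is per-token substring, "70 B" matches "70b" — which makes formatting differences forgivable but can also admit spurious matches for very short labels.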
84
benchmark/multi_document_qa/bench_sglang.py
Normal file
@@ -0,0 +1,84 @@
import argparse
import json
import time

import numpy as np
import sglang as sgl
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
from sglang.utils import read_jsonl, dump_state_text


@sgl.function
def multi_document_qa(s, docs, question):
    s += sgl.user_begin()
    s += "Please answer a question according to given documents.\n"
    s += "Question:" + question + "Documents begin.\n"

    forks = s.fork(len(docs))
    forks += lambda i: docs[i]
    forks.join("concate_and_append")

    s += "\nDocuments end."
    s += ("\n\nBased on the above documents, please answer this question:\n" + question + "\nAnswer in three words or fewer.")
    s += sgl.user_end()
    s += sgl.assistant(sgl.gen("answer", max_tokens=16))
def main(args):
    lines = read_jsonl(args.data_path)
    l = lines[0]
    arguments = []
    labels = []
    for i in range(len(l["questions"][:args.num_questions])):
        arguments.append({
            "docs": l["documents"][:10],
            "question": l["questions"][i],
        })
        labels.append(l["answers"][i])

    # Select backend
    backend = select_sglang_backend(args)
    sgl.set_default_backend(backend)

    # Run requests
    tic = time.time()
    states = multi_document_qa.run_batch(
        arguments, temperature=0, num_threads=args.parallel)
    latency = time.time() - tic

    # Compute accuracy
    print([s["answer"] for s in states])
    correct = 0
    for s, label in zip(states, labels):
        answer = s["answer"].lower()
        if all(x in answer for x in label.lower().split(" ")):
            correct += 1
    accuracy = correct / len(labels)
    print(f"Accuracy: {accuracy:.3f}")
    print(f"Latency: {latency:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "multi_document_qa",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "num_requests": args.num_questions,
            "accuracy": accuracy,
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            }
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, default="questions.jsonl")
    parser.add_argument("--num-questions", type=int, default=100)
    args = add_common_sglang_args_and_parse(parser)
    main(args)
64
benchmark/multi_document_qa/build_dataset.py
Normal file
@@ -0,0 +1,64 @@
import json

import transformers

content = "\n".join(
    open("llama2.txt", 'r', encoding='utf-8', errors='ignore').readlines())
content = content.replace("\n\n", "\n")

# Count tokens
name = "meta-llama/Llama-2-7b-chat-hf"
t = transformers.AutoTokenizer.from_pretrained(name)
print(f"num tokens: {len(t.encode(content))}")

# Segment
SEP = "\n\n"
parts = content.split(SEP)
print(f"num segments: {len(parts)}")

segment_len = 1100

segments = []
tmp = []
tmp_len = 0
for i in range(len(parts)):
    tmp.append(parts[i])
    tmp_len += len(t.encode(parts[i]))

    if tmp_len > segment_len:
        segments.append(SEP.join(tmp))
        tmp = []
        tmp_len = 0

for i, s in enumerate(segments):
    print(i, len(t.encode(segments[i])))
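The loop above greedily packs consecutive parts until a token budget is exceeded, then flushes. The same logic with a stand-in tokenizer (word count instead of a real BPE encoder) for a quick check:

```python
def segment(parts, budget, count_tokens):
    # Greedily pack consecutive parts; flush once the budget is exceeded.
    # Note: a trailing under-budget remainder is dropped, as in the
    # script above.
    segments, tmp, tmp_len = [], [], 0
    for p in parts:
        tmp.append(p)
        tmp_len += count_tokens(p)
        if tmp_len > budget:
            segments.append("\n\n".join(tmp))
            tmp, tmp_len = [], 0
    return segments


parts = ["a b c", "d e", "f g h i", "j"]
segs = segment(parts, budget=4, count_tokens=lambda s: len(s.split()))
```

With a budget of 4 "tokens", the first flush happens after "d e" and the second after "j", so each segment slightly overshoots the budget rather than undershooting it.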
# Dump
with open("questions.jsonl", "w") as fout:
    fout.write(json.dumps({
        "documents": segments[:30],
        "questions": [
            "What is the name of the fine-tuned LLMs?",
            "Which figure shows the helpfulness human evaluation results for Llama 2-Chat?",
            "What is the number of parameters in the largest Llama 2 model?",
            "What is the batch size of fine-tuning?",
            "Where can we find the details of potential data contamination?",
            "What is the full name of MPT?",
            "What is the power consumption of RSC in Watt?",
            "How many tokens of data do they train on?",
            "Which model's release is delayed due to a lack of time to sufficiently red team?",
            "Which activation function is used in Llama?"
        ],
        "answers": [
            "Llama 2 Chat",
            "1",
            "70 B",
            "64",
            "A 6",
            "MosaicML",
            "400",
            "2 trillion",
            "34 B",
            "SwiGLU",
        ],
    }) + "\n")
26
benchmark/react/README.md
Normal file
@@ -0,0 +1,26 @@
## Run benchmark

### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```

```
python3 bench_sglang.py --num-questions 100
```

### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```

```
python3 bench_other.py --num-questions 100 --backend vllm
```

### Benchmark guidance
```
python3 bench_other.py --num-questions 100 --backend guidance --parallel 1
```
182
benchmark/react/bench_other.py
Normal file
@@ -0,0 +1,182 @@
import argparse
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import json
import time
from pathlib import Path

from tqdm import tqdm
from sglang.test.test_utils import (
    add_common_other_args_and_parse,
    call_generate_lightllm,
    call_generate_vllm,
    call_generate_srt_raw,
)
from sglang.utils import read_jsonl, dump_state_text


def get_prompt(question):
    prompt = (
        """Solve a question answering task with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be three types:
(1) Search[entity], which searches the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to search.
(2) Lookup[keyword], which returns the next sentence containing keyword in the current passage.
(3) Finish[answer], which returns the answer and finishes the task.
Here are some examples.
Question: What is the elevation range for the area that the eastern sector of the Colorado orogeny extends into?
Thought 1: I need to search Colorado orogeny, find the area that the eastern sector of the Colorado orogeny extends into, then find the elevation range of the area.
Action 1: Search[Colorado orogeny]
Observation 1: The Colorado orogeny was an episode of mountain building (an orogeny) in Colorado and surrounding areas.
Thought 2: It does not mention the eastern sector. So I need to look up eastern sector.
Action 2: Lookup[eastern sector]
Observation 2: (Result 1 / 1) The eastern sector extends into the High Plains and is called the Central Plains orogeny.
Thought 3: The eastern sector of Colorado orogeny extends into the High Plains. So I need to search High Plains and find its elevation range.
Action 3: Search[High Plains]
Observation 3: High Plains refers to one of two distinct land regions:
Thought 4: I need to instead search High Plains (United States).
Action 4: Search[High Plains (United States)]
Observation 4: The High Plains are a subregion of the Great Plains. From east to west, the High Plains rise in elevation from around 1,800 to 7,000 ft (550 to 2,130 m).[3]
Thought 5: High Plains rise in elevation from around 1,800 to 7,000 ft, so the answer is 1,800 to 7,000 ft.
Action 5: Finish[1,800 to 7,000 ft]
Question: Musician and satirist Allie Goertz wrote a song about the "The Simpsons" character Milhouse, who Matt Groening named after who?
Thought 1: The question simplifies to "The Simpsons" character Milhouse is named after who. I only need to search Milhouse and find who it is named after.
Action 1: Search[Milhouse]
Observation 1: Milhouse Mussolini Van Houten is a recurring character in the Fox animated television series The Simpsons voiced by Pamela Hayden and created by Matt Groening.
Thought 2: The paragraph does not tell who Milhouse is named after, maybe I can look up "named after".
Action 2: Lookup[named after]
Observation 2: (Result 1 / 1) Milhouse was named after U.S. president Richard Nixon, whose middle name was Milhous.
Thought 3: Milhouse was named after U.S. president Richard Nixon, so the answer is Richard Nixon.
Action 3: Finish[Richard Nixon]
Question: Which documentary is about Finnish rock groups, Adam Clayton Powell or The Saimaa Gesture?
Thought 1: I need to search Adam Clayton Powell and The Saimaa Gesture, and find which documentary is about Finnish rock groups.
Action 1: Search[Adam Clayton Powell]
Observation 1: Could not find [Adam Clayton Powell]. Similar: ['Adam Clayton Powell III', 'Seventh Avenue (Manhattan)', 'Adam Clayton Powell Jr. State Office Building', 'Isabel Washington Powell', 'Adam Powell', 'Adam Clayton Powell (film)', 'Giancarlo Esposito'].
Thought 2: To find the documentary, I can search Adam Clayton Powell (film).
Action 2: Search[Adam Clayton Powell (film)]
Observation 2: Adam Clayton Powell is a 1989 American documentary film directed by Richard Kilberg.
The film is about the rise and fall of influential African-American politician Adam Clayton Powell Jr.[3][4] It was later aired as part of the PBS series The American Experience.
Thought 3: Adam Clayton Powell (film) is a documentary about an African-American politician, not Finnish rock groups. So the documentary about Finnish rock groups must instead be The Saimaa Gesture.
Action 3: Finish[The Saimaa Gesture]
Question: What profession does Nicholas Ray and Elia Kazan have in common?
Thought 1: I need to search Nicholas Ray and Elia Kazan, find their professions, then find the profession they have in common.
Action 1: Search[Nicholas Ray]
Observation 1: Nicholas Ray (born Raymond Nicholas Kienzle Jr., August 7, 1911 – June 16, 1979) was an American film director, screenwriter, and actor best known for the 1955 film Rebel Without a Cause.
Thought 2: Professions of Nicholas Ray are director, screenwriter, and actor. I need to search Elia Kazan next and find his professions.
Action 2: Search[Elia Kazan]
Observation 2: Elia Kazan was an American film and theatre director, producer, screenwriter and actor.
Thought 3: Professions of Elia Kazan are director, producer, screenwriter, and actor. So profession Nicholas Ray and Elia Kazan have in common is director, screenwriter, and actor.
Action 3: Finish[director, screenwriter, actor]
Question: Which magazine was started first Arthur's Magazine or First for Women?
Thought 1: I need to search Arthur's Magazine and First for Women, and find which was started first.
Action 1: Search[Arthur's Magazine]
Observation 1: Arthur's Magazine (1844-1846) was an American literary periodical published in Philadelphia in the 19th century.
Thought 2: Arthur's Magazine was started in 1844. I need to search First for Women next.
Action 2: Search[First for Women]
Observation 2: First for Women is a woman's magazine published by Bauer Media Group in the USA.[1] The magazine was started in 1989.
Thought 3: First for Women was started in 1989. 1844 (Arthur's Magazine) < 1989 (First for Women), so Arthur's Magazine was started first.
Action 3: Finish[Arthur's Magazine]
Question: Were Pavel Urysohn and Leonid Levin known for the same type of work?
Thought 1: I need to search Pavel Urysohn and Leonid Levin, find their types of work, then find if they are the same.
Action 1: Search[Pavel Urysohn]
Observation 1: Pavel Samuilovich Urysohn (February 3, 1898 – August 17, 1924) was a Soviet mathematician who is best known for his contributions in dimension theory.
Thought 2: Pavel Urysohn is a mathematician. I need to search Leonid Levin next and find its type of work.
Action 2: Search[Leonid Levin]
Observation 2: Leonid Anatolievich Levin is a Soviet-American mathematician and computer scientist.
Thought 3: Leonid Levin is a mathematician and computer scientist. So Pavel Urysohn and Leonid Levin have the same type of work.
Action 3: Finish[yes]
""" + question)
    return prompt
def main(args):
|
||||
lines = read_jsonl(args.data_path)[:args.num_questions]
|
||||
arguments = [{
|
||||
"question": k,
|
||||
"triplets": v
|
||||
} for l in lines for k, v in l.items()]
|
||||
|
||||
states = []
|
||||
|
||||
# Select backend
|
||||
if args.backend == "lightllm":
|
||||
url = f"{args.host}:{args.port}/generate"
|
||||
call_generate = partial(call_generate_lightllm, url=url)
|
||||
elif args.backend == "vllm":
|
||||
url = f"{args.host}:{args.port}/generate"
|
||||
call_generate = partial(call_generate_vllm, url=url)
|
||||
elif args.backend == "srt-raw":
|
||||
url = f"{args.host}:{args.port}/generate"
|
||||
call_generate = partial(call_generate_srt_raw, url=url)
|
||||
elif args.backend == "guidance":
|
||||
from guidance import models, gen
|
||||
|
||||
model = models.LlamaCpp(
|
||||
str(Path.home()) + "/model_weights/Llama-2-7b-chat.gguf",
|
||||
n_gpu_layers=-1,
|
||||
n_ctx=4096,
|
||||
)
|
||||
|
||||
def call_generate(prompt, temperature, max_tokens, stop):
|
||||
out = (model + prompt + gen(
|
||||
name="result",
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
stop=stop,
|
||||
))
|
||||
return out["result"]
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid backend: {args.backend}")
|
||||
|
||||
def run_single_agent(argument):
|
||||
question = argument["question"]
|
||||
triplets = argument["triplets"]
|
||||
prompt = get_prompt(question)
|
||||
for i in range(1, len(triplets) + 2):
|
||||
prompt += "Thought " + str(i) + ":"
|
||||
states.append(prompt)
|
||||
answer = call_generate(prompt,
|
||||
max_tokens=200,
|
||||
temperature=0,
|
||||
stop="Observation")
|
||||
if i > len(triplets):
|
||||
break
|
||||
prompt += (triplets[i - 1]["thought"] + "\nAction " + str(i) +
|
||||
":" + triplets[i - 1]["action"] + "\nObservation " +
|
||||
str(i) + ":" + triplets[i - 1]["observation"] + "\n")
|
||||
|
||||
states.append(answer)
|
||||
|
||||
tic = time.time()
|
||||
if args.parallel == 1:
|
||||
for arg in tqdm(arguments):
|
||||
run_single_agent(arg)
|
||||
else:
|
||||
with ThreadPoolExecutor(args.parallel) as executor:
|
||||
executor.map(run_single_agent, arguments)
|
||||
latency = time.time() - tic
|
||||
|
||||
print(f"Latency: {latency:.3f}")
|
||||
|
||||
# Write results
|
||||
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
||||
|
||||
with open(args.result_file, "a") as fout:
|
||||
value = {
|
||||
"task": "ReAct Agents",
|
||||
"backend": args.backend,
|
||||
"num_gpus": 1,
|
||||
"latency": round(latency, 3),
|
||||
"num_requests": len(arguments),
|
||||
"other": {
|
||||
"parallel": args.parallel,
|
||||
},
|
||||
}
|
||||
fout.write(json.dumps(value) + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data-path", type=str, default="hotpotqa_100.jsonl")
|
||||
parser.add_argument("--num-questions", type=int, default=10)
|
||||
args = add_common_other_args_and_parse(parser)
|
||||
main(args)
|
||||
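For reference, the prompt-accumulation pattern used by `run_single_agent` above — the model proposes a Thought/Action continuation, then the recorded triplet is appended as ground truth — can be sketched standalone. The helper name and the stub generator's output below are illustrative, not the benchmark's real model:

```python
def run_react(question, triplets, generate):
    """Replay a ReAct trace: at each turn the model proposes Thought/Action
    text, then the recorded thought/action/observation is appended."""
    prompt = "Question: " + question + "\n"
    answer = None
    for i in range(1, len(triplets) + 2):
        prompt += f"Thought {i}:"
        answer = generate(prompt, stop="Observation")
        if i > len(triplets):
            break  # final turn: keep the model's own completion
        t = triplets[i - 1]
        prompt += (t["thought"] + f"\nAction {i}:" + t["action"] +
                   f"\nObservation {i}:" + t["observation"] + "\n")
    return answer


# Stub generator so the sketch runs without a model server.
def fake_generate(prompt, stop):
    return " The answer is 42.\nAction: Finish[42]"


trace = [{"thought": " Search first.", "action": " Search[x]", "observation": " ..."}]
final = run_react("What is 6 * 7?", trace, fake_generate)
```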
141
benchmark/react/bench_sglang.py
Normal file
@@ -0,0 +1,141 @@
import argparse
import json
import time

import sglang as sgl
from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)
from sglang.utils import read_jsonl, dump_state_text


@sgl.function
def webthink(s, question, triplets):
    s += (
        """Solve a question answering task with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be three types:
(1) Search[entity], which searches the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to search.
(2) Lookup[keyword], which returns the next sentence containing keyword in the current passage.
(3) Finish[answer], which returns the answer and finishes the task.
Here are some examples.
Question: What is the elevation range for the area that the eastern sector of the Colorado orogeny extends into?
Thought 1: I need to search Colorado orogeny, find the area that the eastern sector of the Colorado orogeny extends into, then find the elevation range of the area.
Action 1: Search[Colorado orogeny]
Observation 1: The Colorado orogeny was an episode of mountain building (an orogeny) in Colorado and surrounding areas.
Thought 2: It does not mention the eastern sector. So I need to look up eastern sector.
Action 2: Lookup[eastern sector]
Observation 2: (Result 1 / 1) The eastern sector extends into the High Plains and is called the Central Plains orogeny.
Thought 3: The eastern sector of Colorado orogeny extends into the High Plains. So I need to search High Plains and find its elevation range.
Action 3: Search[High Plains]
Observation 3: High Plains refers to one of two distinct land regions:
Thought 4: I need to instead search High Plains (United States).
Action 4: Search[High Plains (United States)]
Observation 4: The High Plains are a subregion of the Great Plains. From east to west, the High Plains rise in elevation from around 1,800 to 7,000 ft (550 to 2,130 m).[3]
Thought 5: High Plains rise in elevation from around 1,800 to 7,000 ft, so the answer is 1,800 to 7,000 ft.
Action 5: Finish[1,800 to 7,000 ft]
Question: Musician and satirist Allie Goertz wrote a song about the "The Simpsons" character Milhouse, who Matt Groening named after who?
Thought 1: The question simplifies to "The Simpsons" character Milhouse is named after who. I only need to search Milhouse and find who it is named after.
Action 1: Search[Milhouse]
Observation 1: Milhouse Mussolini Van Houten is a recurring character in the Fox animated television series The Simpsons voiced by Pamela Hayden and created by Matt Groening.
Thought 2: The paragraph does not tell who Milhouse is named after, maybe I can look up "named after".
Action 2: Lookup[named after]
Observation 2: (Result 1 / 1) Milhouse was named after U.S. president Richard Nixon, whose middle name was Milhous.
Thought 3: Milhouse was named after U.S. president Richard Nixon, so the answer is Richard Nixon.
Action 3: Finish[Richard Nixon]
Question: Which documentary is about Finnish rock groups, Adam Clayton Powell or The Saimaa Gesture?
Thought 1: I need to search Adam Clayton Powell and The Saimaa Gesture, and find which documentary is about Finnish rock groups.
Action 1: Search[Adam Clayton Powell]
Observation 1: Could not find [Adam Clayton Powell]. Similar: ['Adam Clayton Powell III', 'Seventh Avenue (Manhattan)', 'Adam Clayton Powell Jr. State Office Building', 'Isabel Washington Powell', 'Adam Powell', 'Adam Clayton Powell (film)', 'Giancarlo Esposito'].
Thought 2: To find the documentary, I can search Adam Clayton Powell (film).
Action 2: Search[Adam Clayton Powell (film)]
Observation 2: Adam Clayton Powell is a 1989 American documentary film directed by Richard Kilberg.
The film is about the rise and fall of influential African-American politician Adam Clayton Powell Jr.[3][4] It was later aired as part of the PBS series The American Experience.
Thought 3: Adam Clayton Powell (film) is a documentary about an African-American politician, not Finnish rock groups. So the documentary about Finnish rock groups must instead be The Saimaa Gesture.
Action 3: Finish[The Saimaa Gesture]
Question: What profession does Nicholas Ray and Elia Kazan have in common?
Thought 1: I need to search Nicholas Ray and Elia Kazan, find their professions, then find the profession they have in common.
Action 1: Search[Nicholas Ray]
Observation 1: Nicholas Ray (born Raymond Nicholas Kienzle Jr., August 7, 1911 – June 16, 1979) was an American film director, screenwriter, and actor best known for the 1955 film Rebel Without a Cause.
Thought 2: Professions of Nicholas Ray are director, screenwriter, and actor. I need to search Elia Kazan next and find his professions.
Action 2: Search[Elia Kazan]
Observation 2: Elia Kazan was an American film and theatre director, producer, screenwriter and actor.
Thought 3: Professions of Elia Kazan are director, producer, screenwriter, and actor. So profession Nicholas Ray and Elia Kazan have in common is director, screenwriter, and actor.
Action 3: Finish[director, screenwriter, actor]
Question: Which magazine was started first Arthur's Magazine or First for Women?
Thought 1: I need to search Arthur's Magazine and First for Women, and find which was started first.
Action 1: Search[Arthur's Magazine]
Observation 1: Arthur's Magazine (1844-1846) was an American literary periodical published in Philadelphia in the 19th century.
Thought 2: Arthur's Magazine was started in 1844. I need to search First for Women next.
Action 2: Search[First for Women]
Observation 2: First for Women is a woman's magazine published by Bauer Media Group in the USA.[1] The magazine was started in 1989.
Thought 3: First for Women was started in 1989. 1844 (Arthur's Magazine) < 1989 (First for Women), so Arthur's Magazine was started first.
Action 3: Finish[Arthur's Magazine]
Question: Were Pavel Urysohn and Leonid Levin known for the same type of work?
Thought 1: I need to search Pavel Urysohn and Leonid Levin, find their types of work, then find if they are the same.
Action 1: Search[Pavel Urysohn]
Observation 1: Pavel Samuilovich Urysohn (February 3, 1898 – August 17, 1924) was a Soviet mathematician who is best known for his contributions in dimension theory.
Thought 2: Pavel Urysohn is a mathematician. I need to search Leonid Levin next and find its type of work.
Action 2: Search[Leonid Levin]
Observation 2: Leonid Anatolievich Levin is a Soviet-American mathematician and computer scientist.
Thought 3: Leonid Levin is a mathematician and computer scientist. So Pavel Urysohn and Leonid Levin have the same type of work.
Action 3: Finish[yes]
""" + question)
    for i in range(1, len(triplets) + 2):
        s += "Thought " + str(i) + ":"
        ss = s.fork(1)
        ss[0] += sgl.gen(name="thought_action", max_tokens=200, stop="Observation")
        # ss.join()
        # To verify the correctness of the output, collect it here:
        # print(ss[0]["thought_action"])
        if i > len(triplets):
            break
        s += (triplets[i - 1]["thought"] + "\nAction " + str(i) + ":" +
              triplets[i - 1]["action"] + "\nObservation " + str(i) + ":" +
              triplets[i - 1]["observation"] + "\n")


def main(args):
    lines = read_jsonl(args.data_path)[:args.num_questions]
    arguments = [{
        "question": k,
        "triplets": v
    } for l in lines for k, v in l.items()]

    # Select backend
    backend = select_sglang_backend(args)
    sgl.set_default_backend(backend)

    states = []
    tic = time.time()
    states = webthink.run_batch(arguments,
                                temperature=0,
                                num_threads=args.parallel)
    latency = time.time() - tic

    print(f"Latency: {latency:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "ReAct Agents",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "num_requests": len(arguments),
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, default="hotpotqa_100.jsonl")
    parser.add_argument("--num-questions", type=int, default=10)
    args = add_common_sglang_args_and_parse(parser)
    main(args)
27
benchmark/tip_suggestion/README.md
Normal file
@@ -0,0 +1,27 @@
## Run benchmark

### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```

```
python3 bench_sglang.py --num-questions 64
python3 bench_sglang.py --num-questions 32 --parallel 1
```

### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```

```
python3 bench_other.py --backend vllm --num-questions 64
```

### Benchmark guidance
```
python3 bench_other.py --backend guidance --num-questions 32 --parallel 1
```
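Each benchmark run appends one JSON object to the `--result-file`. A minimal sketch of the record shape, with illustrative values (the keys match the `value` dict written by `bench_sglang.py`):

```python
import json

# One result line as written by the tip_suggestion benchmarks (values illustrative).
record = {
    "task": "tip_suggestion",
    "backend": "vllm",
    "num_gpus": 1,
    "latency": 12.345,
    "num_requests": 64,
    "other": {"num_questions": 64, "parallel": 16},
}

line = json.dumps(record)
# The result file is plain JSONL, so each line round-trips through json.loads.
parsed = json.loads(line)
```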
124
benchmark/tip_suggestion/bench_other.py
Normal file
@@ -0,0 +1,124 @@
import argparse
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import json
import time

from tqdm import tqdm

from sglang.test.test_utils import (
    add_common_other_args_and_parse,
    call_generate_lightllm,
    call_generate_vllm,
    call_generate_srt_raw,
)
from sglang.utils import read_jsonl, dump_state_text


number = 5


def expand_tip(topic, tip, generate):
    s = (
        """Please expand a tip for a topic into a detailed paragraph.

Topic: staying healthy
Tip: Regular Exercise
Paragraph: Incorporate physical activity into your daily routine. This doesn't necessarily mean intense gym workouts; it can be as simple as walking, cycling, or yoga. Regular exercise helps in maintaining a healthy weight, improves cardiovascular health, boosts mental health, and can enhance cognitive function, which is crucial for fields that require intense intellectual engagement.

Topic: building a campfire
Tip: Choose the Right Location
Paragraph: Always build your campfire in a safe spot. This means selecting a location that's away from trees, bushes, and other flammable materials. Ideally, use a fire ring if available. If you're building a fire pit, it should be on bare soil or on a bed of stones, not on grass or near roots which can catch fire underground. Make sure the area above is clear of low-hanging branches.

Topic: writing a blog post
Tip: structure your content effectively
Paragraph: A well-structured post is easier to read and more enjoyable. Start with an engaging introduction that hooks the reader and clearly states the purpose of your post. Use headings and subheadings to break up the text and guide readers through your content. Bullet points and numbered lists can make information more digestible. Ensure each paragraph flows logically into the next, and conclude with a summary or call-to-action that encourages reader engagement.

Topic: """ + topic + "\nTip: " + tip + "\nParagraph:")
    return generate(s, max_tokens=128, stop=["\n\n"])


def suggest_tips(topic, generate):
    s = "Please act as a helpful assistant. Your job is to provide users with useful tips on a specific topic.\n"
    s += "USER: Give some tips for " + topic + ".\n"
    s += ("ASSISTANT: Okay. Here are " + str(number) + " concise tips, each under 8 words:\n")

    tips = []
    for i in range(1, 1 + number):
        s += f"{i}."
        tip = generate(s, max_tokens=24, stop=[".", "\n"])
        s += tip + ".\n"
        tips.append(tip)

    paragraphs = [expand_tip(topic, tip, generate=generate) for tip in tips]

    for i in range(1, 1 + number):
        s += f"Tip {i}:" + paragraphs[i - 1] + "\n"
    return s


def main(args):
    lines = read_jsonl(args.data_path)[:args.num_questions]
    states = [None] * len(lines)

    # Select backend
    if args.backend == "lightllm":
        url = f"{args.host}:{args.port}/generate"
        generate = partial(call_generate_lightllm, url=url, temperature=0)
    elif args.backend == "vllm":
        url = f"{args.host}:{args.port}/generate"
        generate = partial(call_generate_vllm, url=url, temperature=0)
    elif args.backend == "srt-raw":
        url = f"{args.host}:{args.port}/generate"
        generate = partial(call_generate_srt_raw, url=url, temperature=0)
    elif args.backend == "guidance":
        from guidance import models, gen

        model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
                                n_gpu_layers=-1, n_ctx=4096)

        def generate(prompt, max_tokens, stop):
            out = model + prompt + gen(name="answer",
                                       max_tokens=max_tokens, temperature=0, stop=stop)
            return out["answer"]

        # Warmup
        generate("Hello!", max_tokens=8, stop=None)
    else:
        raise ValueError(f"Invalid backend: {args.backend}")

    # Run requests
    def get_one_answer(i):
        states[i] = suggest_tips(lines[i]["topic"], generate)

    tic = time.time()
    if args.parallel == 1:
        for i in tqdm(range(len(lines))):
            get_one_answer(i)
    else:
        with ThreadPoolExecutor(args.parallel) as executor:
            executor.map(get_one_answer, list(range(len(lines))))
    latency = time.time() - tic

    print(f"Latency: {latency:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "tip_suggestion",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "num_requests": args.num_questions,
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, default="topic.jsonl")
    parser.add_argument("--num-questions", type=int, default=100)
    args = add_common_other_args_and_parse(parser)
    main(args)
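The `suggest_tips`/`expand_tip` pair above is a two-stage pipeline: generate `number` short tips one by one, then expand each tip with an independent call and stitch the paragraphs back into the transcript. A self-contained sketch of that control flow with a stub generator (the helper name and stub text are illustrative):

```python
def suggest_and_expand(topic, generate, number=3):
    s = f"Give {number} tips for {topic}:\n"
    tips = []
    # Stage 1: tips are generated sequentially, each conditioned on the prior ones.
    for i in range(1, number + 1):
        s += f"{i}."
        tip = generate(s, stop=".")
        s += tip + ".\n"
        tips.append(tip)
    # Stage 2: each tip is expanded independently (a parallelizable fan-out).
    paragraphs = [generate("Expand: " + tip, stop="\n\n") for tip in tips]
    for i, para in enumerate(paragraphs, 1):
        s += f"Tip {i}:{para}\n"
    return s


out = suggest_and_expand("testing", lambda prompt, stop: " stub")
```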
91
benchmark/tip_suggestion/bench_sglang.py
Normal file
@@ -0,0 +1,91 @@
import argparse
import json
import time

import sglang as sgl
from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)
from sglang.utils import read_jsonl, dump_state_text


number = 5


@sgl.function
def expand_tip(s, topic, tip):
    s += (
        """Please expand a tip for a topic into a detailed paragraph.

Topic: staying healthy
Tip: Regular Exercise
Paragraph: Incorporate physical activity into your daily routine. This doesn't necessarily mean intense gym workouts; it can be as simple as walking, cycling, or yoga. Regular exercise helps in maintaining a healthy weight, improves cardiovascular health, boosts mental health, and can enhance cognitive function, which is crucial for fields that require intense intellectual engagement.

Topic: building a campfire
Tip: Choose the Right Location
Paragraph: Always build your campfire in a safe spot. This means selecting a location that's away from trees, bushes, and other flammable materials. Ideally, use a fire ring if available. If you're building a fire pit, it should be on bare soil or on a bed of stones, not on grass or near roots which can catch fire underground. Make sure the area above is clear of low-hanging branches.

Topic: writing a blog post
Tip: structure your content effectively
Paragraph: A well-structured post is easier to read and more enjoyable. Start with an engaging introduction that hooks the reader and clearly states the purpose of your post. Use headings and subheadings to break up the text and guide readers through your content. Bullet points and numbered lists can make information more digestible. Ensure each paragraph flows logically into the next, and conclude with a summary or call-to-action that encourages reader engagement.

Topic: """ + topic + "\nTip: " + tip + "\nParagraph:")
    s += sgl.gen("paragraph", max_tokens=128, stop=["\n\n"], temperature=0)


@sgl.function
def suggest_tips(s, topic):
    s += "Please act as a helpful assistant. Your job is to provide users with useful tips on a specific topic.\n"
    s += "USER: Give some tips for " + topic + ".\n"
    s += ("ASSISTANT: Okay. Here are " + str(number) + " concise tips, each under 8 words:\n")

    paragraphs = []
    for i in range(1, 1 + number):
        s += f"{i}." + sgl.gen(f"tip_{i}", max_tokens=24, stop=[".", "\n"]) + ".\n"
        paragraphs.append(expand_tip(topic=topic, tip=s[f"tip_{i}"]))

    for i in range(1, 1 + number):
        s += f"Tip {i}:" + paragraphs[i - 1]["paragraph"] + "\n"


def main(args):
    lines = read_jsonl(args.data_path)[:args.num_questions]
    arguments = [
        {"topic": l["topic"]} for l in lines
    ]

    # Select backend
    sgl.set_default_backend(select_sglang_backend(args))

    # Run requests
    tic = time.time()
    states = suggest_tips.run_batch(
        arguments, temperature=0, num_threads=args.parallel)
    latency = time.time() - tic

    print(f"Latency: {latency:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "tip_suggestion",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "num_requests": args.num_questions,
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, default="topic.jsonl")
    parser.add_argument("--num-questions", type=int, default=100)
    args = add_common_sglang_args_and_parse(parser)
    main(args)
43
benchmark/tree_of_thought/README.md
Normal file
@@ -0,0 +1,43 @@
## Download data
```
wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
```

## Run benchmark

### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```

```
python3 bench_sglang.py --num-questions 32 --parallel 16
python3 bench_sglang.py --num-questions 10 --parallel 1
```

### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```

```
python3 bench_other.py --num-questions 32 --backend vllm
```

### Benchmark lightllm
```
# A10G
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
```

```
python3 bench_other.py --num-questions 32 --backend lightllm
```

### Benchmark guidance
```
python3 bench_other.py --num-questions 32 --backend guidance --parallel 1
```
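The bench scripts below use a fixed branching factor of 3: one propose step yields 3 plans, each plan is executed into 3 solutions, and each solution gets a reflection step. The per-question call and candidate counts follow directly (plain arithmetic, no backend needed):

```python
num_branches = 3

plan_calls = 1                       # one propose_plan call (returns 3 plans)
execute_calls = num_branches         # one execute_plan call per plan
reflect_calls = num_branches ** 2    # one reflect_solution call per solution

candidate_solutions = num_branches ** 2  # 3 plans x 3 executions = 9 candidates
total_calls = plan_calls + execute_calls + reflect_calls  # 13 generation calls
```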
183
benchmark/tree_of_thought/bench_other.py
Normal file
@@ -0,0 +1,183 @@
import argparse
import ast
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import json
import re
import time

import numpy as np
from tqdm import tqdm

from sglang.test.test_utils import (
    add_common_other_args_and_parse,
    call_generate_lightllm,
    call_generate_vllm,
    call_generate_srt_raw,
)
from sglang.utils import read_jsonl, dump_state_text


INVALID = -9999999


def get_answer_value(answer_str):
    answer_str = answer_str.replace(",", "")
    numbers = re.findall(r"\d+", answer_str)
    if len(numbers) < 1:
        return INVALID
    try:
        return ast.literal_eval(numbers[-1])
    except SyntaxError:
        return INVALID


def most_frequent_number(numbers):
    if not numbers:
        return None

    frequency = Counter(numbers)
    most_frequent = max(frequency, key=frequency.get)
    return most_frequent


USER_PREFIX = "[INST] "
USER_SUFFIX = " [/INST]"
ASSISTANT_PREFIX = ""
ASSISTANT_SUFFIX = " </s><s>"

# Use a low temperature to make the results more deterministic and the comparison more fair.
temp = 0.3


def propose_plan(s, question, num_branches, call_generate):
    s += (USER_PREFIX +
          """Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """ + question + USER_SUFFIX)

    s += ASSISTANT_PREFIX
    comps = call_generate(s, max_tokens=256, temperature=temp, stop=None, n=num_branches)
    return [s + comp + ASSISTANT_SUFFIX for comp in comps]


def execute_plan(s, num_branches, call_generate):
    s += (USER_PREFIX +
          """The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short.""" + USER_SUFFIX)
    s += ASSISTANT_PREFIX
    comps = call_generate(s, max_tokens=256, temperature=temp, stop=None, n=num_branches)
    return [s + comp + ASSISTANT_SUFFIX for comp in comps]


def reflect_solution(s, num_branches, call_generate):
    s += (USER_PREFIX +
          """Okay. Now you evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness.""" + USER_SUFFIX)
    s += ASSISTANT_PREFIX
    comps = call_generate(s, max_tokens=256, temperature=temp, stop=None, n=num_branches)
    return [s + comp + ASSISTANT_SUFFIX for comp in comps]


def tree_search(question, num_branches, call_generate):
    s = ""
    solutions = []

    plan_forks = propose_plan(s, question, num_branches, call_generate)
    for plan in plan_forks:
        sol_forks = execute_plan(plan, num_branches, call_generate)
        for sol in sol_forks:
            # Reflection scores are generated but not used for pruning in this benchmark.
            score_forks = reflect_solution(sol, num_branches, call_generate)
        solutions.append(sol_forks)

    return solutions


def main(args):
    lines = read_jsonl(args.data_path)

    # Construct prompts
    num_branches = 3
    questions = []
    labels = []
    for i in range(len(lines[:args.num_questions])):
        questions.append(lines[i]["question"])
        labels.append(get_answer_value(lines[i]["answer"]))
    assert all(l != INVALID for l in labels)
    arguments = [{"question": q, "num_branches": num_branches} for q in questions]

    # Select backend
    if args.backend == "lightllm":
        url = f"{args.host}:{args.port}/generate"
        call_generate = partial(call_generate_lightllm, url=url)
    elif args.backend == "vllm":
        url = f"{args.host}:{args.port}/generate"
        call_generate = partial(call_generate_vllm, url=url)
    elif args.backend == "srt-raw":
        url = f"{args.host}:{args.port}/generate"
        call_generate = partial(call_generate_srt_raw, url=url)
    elif args.backend == "guidance":
        from guidance import models, gen

        model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
                                n_gpu_layers=-1, n_ctx=4096)

        def call_generate(prompt, temperature, max_tokens, stop, n):
            if n == 1:
                out = model + prompt + gen(name="answer",
                                           max_tokens=max_tokens, temperature=temperature, stop=stop)
                return out["answer"]
            else:
                rets = []
                for i in range(n):
                    out = model + prompt + gen(name="answer",
                                               max_tokens=max_tokens, temperature=temperature, stop=stop)
                    rets.append(out["answer"])
                return rets
    else:
        raise ValueError(f"Invalid backend: {args.backend}")

    # Run requests
    states = [None] * len(questions)

    def get_one_answer(i):
        states[i] = tree_search(**arguments[i], call_generate=call_generate)

    tic = time.time()
    if args.parallel == 1:
        for i in tqdm(range(len(questions))):
            get_one_answer(i)
    else:
        with ThreadPoolExecutor(args.parallel) as executor:
            executor.map(get_one_answer, list(range(len(questions))))
    latency = time.time() - tic

    answers_text = []
    for s in states:
        answers_text.append([x for xs in s for x in xs])

    preds = []
    for i in range(len(states)):
        answers = [get_answer_value(v) for v in answers_text[i]]
        preds.append(most_frequent_number(answers))

    # Compute accuracy
    acc = np.mean(np.array(preds) == np.array(labels))
    invalid = np.mean(np.array(preds) == INVALID)
    print(f"Latency: {latency:.3f}")
    print(f"Invalid: {invalid:.3f}")
    print(f"Accuracy: {acc:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", answers_text)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "tree_of_thought_gsm8k",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "accuracy": round(acc, 3),
            "num_requests": args.num_questions,
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, default="test.jsonl")
    parser.add_argument("--num-questions", type=int, default=200)
    args = add_common_other_args_and_parse(parser)
    main(args)
147
benchmark/tree_of_thought/bench_sglang.py
Normal file
147
benchmark/tree_of_thought/bench_sglang.py
Normal file
@@ -0,0 +1,147 @@
import argparse
import ast
from collections import Counter
import json
import re
import time

import numpy as np
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
from sglang.utils import read_jsonl, dump_state_text
import sglang as sgl


INVALID = -9999999


def get_answer_value(answer_str):
    answer_str = answer_str.replace(",", "")
    numbers = re.findall(r'\d+', answer_str)
    if len(numbers) < 1:
        return INVALID
    try:
        return ast.literal_eval(numbers[-1])
    except SyntaxError:
        return INVALID


def most_frequent_number(numbers):
    if not numbers:
        return None

    frequency = Counter(numbers)
    most_frequent = max(frequency, key=frequency.get)
    return most_frequent


# Use a low temp to make the results more deterministic and the comparison more fair.
temp = 0.3


def propose_plan(s, question, num_branches):
    s += sgl.user(
        """Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """ + question)
    forks = s.fork(num_branches)
    forks += sgl.assistant(sgl.gen("plan", max_tokens=256, temperature=temp))
    return forks


def execute_plan(s, num_branches):
    s += sgl.user(
        """The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short.""")
    forks = s.fork(num_branches)
    forks += sgl.assistant(sgl.gen("answer", max_tokens=256, temperature=temp))
    return forks


def reflect_solution(s, num_branches):
    s += sgl.user(
        """Okay. Now you evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness.""")
    forks = s.fork(num_branches)
    forks += sgl.assistant(sgl.gen("score", max_tokens=256, temperature=temp))
    return forks


@sgl.function
def tree_search(s, question, num_branches):
    forks_to_join = []

    plan_forks = propose_plan(s, question, num_branches)
    forks_to_join.append(plan_forks)

    sol_states = []
    for plan in plan_forks:
        forks = execute_plan(plan, num_branches)
        forks_to_join.append(forks)
        sol_states.extend(forks)

    for sol in sol_states:
        forks = reflect_solution(sol, num_branches)
        forks_to_join.append(forks)

    for f in reversed(forks_to_join):
        f.join()


def main(args):
    lines = read_jsonl(args.data_path)

    # Construct prompts
    num_branches = 3
    questions = []
    labels = []
    for i in range(len(lines[:args.num_questions])):
        questions.append(lines[i]["question"])
        labels.append(get_answer_value(lines[i]["answer"]))
    assert all(l != INVALID for l in labels)
    arguments = [{"question": q, "num_branches": num_branches} for q in questions]

    # Select backend
    backend = select_sglang_backend(args)

    # Run requests
    tic = time.time()
    states = tree_search.run_batch(
        arguments, temperature=0, backend=backend, num_threads=args.parallel)
    latency = time.time() - tic
    answers_text = []
    for s in states:
        answers_text.append([x for xs in s["answer"] for x in xs])

    preds = []
    for i in range(len(states)):
        answers = [get_answer_value(v) for v in answers_text[i]]
        preds.append(most_frequent_number(answers))

    # Compute accuracy
    acc = np.mean(np.array(preds) == np.array(labels))
    invalid = np.mean(np.array(preds) == INVALID)
    print(f"Latency: {latency:.3f}")
    print(f"Invalid: {invalid:.3f}")
    print(f"Accuracy: {acc:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", answers_text)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "tree_of_thought_gsm8k",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "accuracy": round(acc, 3),
            "num_requests": args.num_questions,
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            }
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, default="test.jsonl")
    parser.add_argument("--num-questions", type=int, default=200)
    args = add_common_sglang_args_and_parse(parser)
    main(args)
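The answer-extraction and majority-voting helpers shared by these benchmark scripts can be exercised on their own. Below is a minimal standalone sketch of the same logic (it substitutes `int` for `ast.literal_eval`; the helper names simply mirror the scripts above):

```python
import re
from collections import Counter

INVALID = -9999999  # sentinel for unparseable answers, as in the scripts above


def get_answer_value(answer_str):
    # Strip thousands separators, then take the last number in the string.
    numbers = re.findall(r"\d+", answer_str.replace(",", ""))
    return int(numbers[-1]) if numbers else INVALID


def most_frequent_number(numbers):
    # Majority vote over the extracted answers.
    return max(Counter(numbers), key=Counter(numbers).get) if numbers else None


print(get_answer_value("So the total is 1,234 dollars."))  # 1234
print(most_frequent_number([72, 68, 72]))                  # 72
```

Because only the last number is kept, a trailing "final answer" sentence is extracted correctly even when intermediate numbers appear earlier in the completion.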
43
benchmark/tree_of_thought_deep/README.md
Normal file
@@ -0,0 +1,43 @@
## Download data
```
wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
```

## Run benchmark

### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```

```
python3 bench_sglang.py --num-questions 32 --parallel 8
python3 bench_sglang.py --num-questions 16 --parallel 1
```

### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```

```
python3 bench_other.py --num-questions 32 --backend vllm
```

### Benchmark lightllm
```
# A10G
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
```

```
python3 bench_other.py --num-questions 32 --backend lightllm
```

### Benchmark guidance
```
python3 bench_other.py --num-questions 8 --backend guidance --parallel 1
```
198
benchmark/tree_of_thought_deep/bench_other.py
Normal file
@@ -0,0 +1,198 @@
import argparse
import ast
import asyncio
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import json
import re
import time

import numpy as np
from tqdm import tqdm
from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw
from sglang.utils import read_jsonl, dump_state_text


INVALID = -9999999


def get_answer_value(answer_str):
    answer_str = answer_str.replace(",", "")
    numbers = re.findall(r'\d+', answer_str)
    if len(numbers) < 1:
        return INVALID
    try:
        return ast.literal_eval(numbers[-1])
    except SyntaxError:
        return INVALID


def most_frequent_number(numbers):
    if not numbers:
        return None

    frequency = Counter(numbers)
    most_frequent = max(frequency, key=frequency.get)
    return most_frequent


USER_PREFIX = "[INST] "
USER_SUFFIX = " [/INST]"
ASSISTANT_PREFIX = ""
ASSISTANT_SUFFIX = " </s><s>"

# Use a low temp to make the results more deterministic and the comparison more fair.
temp = 0.001


def propose_plan(s, question, num_branches, call_generate):
    s += (USER_PREFIX +
        """Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """ + question + USER_SUFFIX)

    s += ASSISTANT_PREFIX
    comps = call_generate(s, max_tokens=256, temperature=temp, stop=None, n=num_branches)
    return [s + comp + ASSISTANT_SUFFIX for comp in comps]


def execute_plan(s, num_branches, call_generate):
    s += (USER_PREFIX +
        """The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short.""" + USER_SUFFIX)
    s += ASSISTANT_PREFIX
    comps = call_generate(s, max_tokens=256, temperature=temp, stop=None, n=num_branches)
    return [s + comp + ASSISTANT_SUFFIX for comp in comps]


def reflect_solution(s, num_branches, call_generate):
    s += (USER_PREFIX +
        """Okay. Now, evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness.""" + USER_SUFFIX)
    s += ASSISTANT_PREFIX
    comps = call_generate(s, max_tokens=256, temperature=temp, stop=None, n=num_branches)
    return [s + comp + ASSISTANT_SUFFIX for comp in comps]


def get_final_answer(s, num_branches, call_generate):
    s += (USER_PREFIX +
        """Based on your reflection, do you change your mind? Now, give me the final answer after careful consideration.""" + USER_SUFFIX)
    s += ASSISTANT_PREFIX
    comps = call_generate(s, max_tokens=256, temperature=temp, stop=None, n=num_branches)
    return [s + comp + ASSISTANT_SUFFIX for comp in comps]


def tree_search(question, num_branches, call_generate):
    plan_forks = propose_plan("", question, num_branches, call_generate)

    sol_states = []
    for plan in plan_forks:
        forks = execute_plan(plan, num_branches, call_generate)
        sol_states.extend(forks)

    ref_states = []
    for sol in sol_states:
        forks = reflect_solution(sol, num_branches, call_generate)
        ref_states.extend(forks)

    solutions = []
    for sol in ref_states:
        ans = get_final_answer(sol, num_branches, call_generate)
        solutions.append(ans)

    return solutions


def main(args):
    lines = read_jsonl(args.data_path)

    # Construct prompts
    num_branches = 2
    questions = []
    labels = []
    for i in range(len(lines[:args.num_questions])):
        questions.append(lines[i]["question"])
        labels.append(get_answer_value(lines[i]["answer"]))
    assert all(l != INVALID for l in labels)
    arguments = [{"question": q, "num_branches": num_branches} for q in questions]

    # Select backend
    if args.backend == "lightllm":
        url = f"{args.host}:{args.port}/generate"
        call_generate = partial(call_generate_lightllm, url=url)
    elif args.backend == "vllm":
        url = f"{args.host}:{args.port}/generate"
        call_generate = partial(call_generate_vllm, url=url)
    elif args.backend == "srt-raw":
        url = f"{args.host}:{args.port}/generate"
        call_generate = partial(call_generate_srt_raw, url=url)
    elif args.backend == "guidance":
        from guidance import models, gen

        model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096)

        def call_generate(prompt, temperature, max_tokens, stop, n):
            if n == 1:
                out = model + prompt + gen(name="answer",
                    max_tokens=max_tokens, temperature=temperature, stop=stop)
                return out["answer"]
            else:
                rets = []
                for i in range(n):
                    out = model + prompt + gen(name="answer",
                        max_tokens=max_tokens, temperature=temperature, stop=stop)
                    rets.append(out["answer"])
                return rets

    # Run requests
    states = [None] * len(questions)
    def get_one_answer(i):
        states[i] = tree_search(**arguments[i], call_generate=call_generate)

    tic = time.time()
    if args.parallel == 1:
        for i in tqdm(range(len(questions))):
            get_one_answer(i)
    else:
        with ThreadPoolExecutor(args.parallel) as executor:
            executor.map(get_one_answer, list(range(len(questions))))
    latency = time.time() - tic

    answers_text = []
    for s in states:
        answers_text.append([x for xs in s for x in xs])

    preds = []
    for i in range(len(states)):
        answers = [get_answer_value(v) for v in answers_text[i]]
        preds.append(most_frequent_number(answers))

    # Compute accuracy
    acc = np.mean(np.array(preds) == np.array(labels))
    invalid = np.mean(np.array(preds) == INVALID)
    print(f"Latency: {latency:.3f}")
    print(f"Invalid: {invalid:.3f}")
    print(f"Accuracy: {acc:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", answers_text)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "tree_of_thought_gsm8k",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "accuracy": round(acc, 3),
            "num_requests": args.num_questions,
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            }
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, default="test.jsonl")
    parser.add_argument("--num-questions", type=int, default=200)
    args = add_common_other_args_and_parse(parser)
    main(args)
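As a rough cost sketch for this deeper pipeline (assuming, as the code above suggests, that each of the four stages, plan, execute, reflect, and final answer, issues one n-way generation call per live state), the number of requests per question grows geometrically with `num_branches`:

```python
num_branches = 2  # matches the setting in main() above

# One call at the root, then one call per state produced by the previous stage:
# 1 (plan) + b (execute) + b**2 (reflect) + b**3 (final answer).
requests_per_question = sum(num_branches ** depth for depth in range(4))

# Each final-answer call returns num_branches completions, so the
# majority vote runs over b**4 leaf answers.
leaf_answers = num_branches ** 4

print(requests_per_question, leaf_answers)  # 15 16
```

This geometric blow-up is why the deep variant defaults to `num_branches = 2` while the shallower `tree_of_thought` benchmark uses 3.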
157
benchmark/tree_of_thought_deep/bench_sglang.py
Normal file
@@ -0,0 +1,157 @@
import argparse
import ast
from collections import Counter
import json
import re
import time

import numpy as np
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
from sglang.utils import read_jsonl, dump_state_text
import sglang as sgl


INVALID = -9999999


def get_answer_value(answer_str):
    answer_str = answer_str.replace(",", "")
    numbers = re.findall(r'\d+', answer_str)
    if len(numbers) < 1:
        return INVALID
    try:
        return ast.literal_eval(numbers[-1])
    except SyntaxError:
        return INVALID


def most_frequent_number(numbers):
    if not numbers:
        return None

    frequency = Counter(numbers)
    most_frequent = max(frequency, key=frequency.get)
    return most_frequent


# Use a low temp to make the results more deterministic and the comparison more fair.
temp = 0.001


def propose_plan(s, question, num_branches):
    s += sgl.user(
        """Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """ + question)
    forks = s.fork(num_branches)
    forks += sgl.assistant(sgl.gen("plan", max_tokens=256, temperature=temp))
    return forks


def execute_plan(s, num_branches):
    s += sgl.user(
        """The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short.""")
    forks = s.fork(num_branches)
    forks += sgl.assistant(sgl.gen("answer", max_tokens=256, temperature=temp))
    return forks


def reflect_solution(s, num_branches):
    s += sgl.user(
        """Okay. Now, evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness.""")
    forks = s.fork(num_branches)
    forks += sgl.assistant(sgl.gen("score", max_tokens=256, temperature=temp))
    return forks


def get_final_answer(s, num_branches):
    s += sgl.user(
        """Based on your reflection, do you change your mind? Now, give me the final answer after careful consideration.""")
    forks = s.fork(num_branches)
    forks += sgl.assistant(sgl.gen("final_answer", max_tokens=256, temperature=temp))
    return forks


@sgl.function
def tree_search(s, question, num_branches):
    plan_forks = propose_plan(s, question, num_branches)

    sol_states = []
    for plan in plan_forks:
        forks = execute_plan(plan, num_branches)
        sol_states.extend(forks)

    ref_states = []
    for sol in sol_states:
        forks = reflect_solution(sol, num_branches)
        ref_states.extend(forks)

    solutions = []
    for sol in ref_states:
        forks = get_final_answer(sol, num_branches)
        solutions.append(forks)
    solutions = [[s.text() for s in forks] for forks in solutions]

    return solutions


def main(args):
    lines = read_jsonl(args.data_path)

    # Construct prompts
    num_branches = 2
    questions = []
    labels = []
    for i in range(len(lines[:args.num_questions])):
        questions.append(lines[i]["question"])
        labels.append(get_answer_value(lines[i]["answer"]))
    assert all(l != INVALID for l in labels)
    arguments = [{"question": q, "num_branches": num_branches} for q in questions]

    # Select backend
    backend = select_sglang_backend(args)

    # Run requests
    tic = time.time()
    states = tree_search.run_batch(
        arguments, temperature=0, backend=backend, num_threads=args.parallel)
    latency = time.time() - tic
    answers_text = []
    for s in states:
        answers_text.append([x for xs in s.ret_value for x in xs])

    preds = []
    for i in range(len(states)):
        answers = [get_answer_value(v) for v in answers_text[i]]
        preds.append(most_frequent_number(answers))

    # Compute accuracy
    acc = np.mean(np.array(preds) == np.array(labels))
    invalid = np.mean(np.array(preds) == INVALID)
    print(f"Latency: {latency:.3f}")
    print(f"Invalid: {invalid:.3f}")
    print(f"Accuracy: {acc:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", answers_text)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "tree_of_thought_gsm8k",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "accuracy": round(acc, 3),
            "num_requests": args.num_questions,
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            }
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, default="test.jsonl")
    parser.add_argument("--num-questions", type=int, default=200)
    args = add_common_sglang_args_and_parse(parser)
    main(args)
20
docs/flashinfer.md
Normal file
@@ -0,0 +1,20 @@
## Flashinfer Mode

[`flashinfer`](https://github.com/flashinfer-ai/flashinfer) is a kernel library for LLM serving; we use it here to support our attention computation.

### Install flashinfer

```bash
git submodule update --init --recursive
pip install 3rdparty/flashinfer/python
```

### Run Server with Flashinfer Mode

Enable it with the `--model-mode` argument on the command line.

Example:

```bash
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --model-mode flashinfer
```
63
docs/test_process.md
Normal file
@@ -0,0 +1,63 @@
## SRT Unit Tests

### Low-level API
```
cd sglang/test/srt/model

python3 test_llama_low_api.py
python3 test_llama_extend.py
python3 test_llava_low_api.py
python3 bench_llama_low_api.py
```

### High-level API

```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```

```
cd test/lang
python3 test_srt_backend.py
```

### Performance

#### MMLU
```
cd benchmark/mmlu
```
Follow README.md to download the data.

```
python3 bench_sglang.py --nsub 3

# Expected performance on A10G
# Total latency: 8.200
# Average accuracy: 0.413
```

### More Models

#### LLaVA

```
python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000
```

```
cd benchmark/llava_bench
python3 bench_sglang.py
```

## SGLang Unit Tests
```
export ANTHROPIC_API_KEY=
export OPENAI_API_KEY=
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```

```
cd test/lang
python3 run_all.py
```
19
examples/quick_start/anthropic_example_chat.py
Normal file
@@ -0,0 +1,19 @@
from sglang import function, system, user, assistant, gen, set_default_backend, Anthropic


@function
def multi_turn_question(s, question_1, question_2):
    s += user(question_1)
    s += assistant(gen("answer_1", max_tokens=256))
    s += user(question_2)
    s += assistant(gen("answer_2", max_tokens=256))

set_default_backend(Anthropic("claude-2"))

state = multi_turn_question.run(
    question_1="What is the capital of the United States?",
    question_2="List two local attractions.",
)

for m in state.messages():
    print(m["role"], ":", m["content"])
26
examples/quick_start/anthropic_example_complete.py
Normal file
@@ -0,0 +1,26 @@
from sglang import function, gen, set_default_backend, Anthropic


@function
def few_shot_qa(s, question):
    s += (
"""
\n\nHuman: What is the capital of France?
\n\nAssistant: Paris
\n\nHuman: What is the capital of Germany?
\n\nAssistant: Berlin
\n\nHuman: What is the capital of Italy?
\n\nAssistant: Rome
""")
    s += "\n\nHuman: " + question + "\n"
    s += "\n\nAssistant:" + gen("answer", stop="\n", temperature=0)


set_default_backend(Anthropic("claude-2"))

state = few_shot_qa.run(question="What is the capital of the United States?")
answer = state["answer"].strip().lower()

assert "washington" in answer, f"answer: {state['answer']}"

print(state.text())
20
examples/quick_start/anthropic_example_stream.py
Normal file
@@ -0,0 +1,20 @@
from sglang import function, system, user, assistant, gen, set_default_backend, Anthropic


@function
def multi_turn_question(s, question_1, question_2):
    s += user(question_1)
    s += assistant(gen("answer_1", max_tokens=256))
    s += user(question_2)
    s += assistant(gen("answer_2", max_tokens=256))

set_default_backend(Anthropic("claude-2"))

state = multi_turn_question.run(
    question_1="What is the capital of the United States?",
    question_2="List two local attractions.",
    stream=True
)

for out in state.text_iter():
    print(out, end="", flush=True)
44
examples/quick_start/more_stream_methods.py
Normal file
@@ -0,0 +1,44 @@
import asyncio
import sglang as sgl


@sgl.function
def multi_turn_question(s, question_1, question_2):
    s += sgl.system("You are a helpful assistant.")
    s += sgl.user(question_1)
    s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))
    s += sgl.user(question_2)
    s += sgl.assistant(sgl.gen("answer_2", max_tokens=256))


sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo"))
#sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))


def stream_a_variable():
    state = multi_turn_question.run(
        question_1="What is the capital of the United States?",
        question_2="List two local attractions.",
        stream=True
    )

    for out in state.text_iter(var_name="answer_2"):
        print(out, end="", flush=True)
    print()


async def async_stream():
    state = multi_turn_question.run(
        question_1="What is the capital of the United States?",
        question_2="List two local attractions.",
        stream=True
    )

    async for out in state.text_async_iter(var_name="answer_2"):
        print(out, end="", flush=True)
    print()


if __name__ == "__main__":
    #stream_a_variable()
    asyncio.run(async_stream())
20
examples/quick_start/openai_example_chat.py
Normal file
@@ -0,0 +1,20 @@
from sglang import function, system, user, assistant, gen, set_default_backend, OpenAI


@function
def multi_turn_question(s, question_1, question_2):
    s += system("You are a helpful assistant.")
    s += user(question_1)
    s += assistant(gen("answer_1", max_tokens=256))
    s += user(question_2)
    s += assistant(gen("answer_2", max_tokens=256))

set_default_backend(OpenAI("gpt-3.5-turbo"))

state = multi_turn_question.run(
    question_1="What is the capital of the United States?",
    question_2="List two local attractions.",
)

for m in state.messages():
    print(m["role"], ":", m["content"])
26
examples/quick_start/openai_example_complete.py
Normal file
@@ -0,0 +1,26 @@
from sglang import function, gen, set_default_backend, OpenAI


@function
def few_shot_qa(s, question):
    s += (
"""The following are questions with answers.
Q: What is the capital of France?
A: Paris
Q: What is the capital of Germany?
A: Berlin
Q: What is the capital of Italy?
A: Rome
""")
    s += "Q: " + question + "\n"
    s += "A:" + gen("answer", stop="\n", temperature=0)


set_default_backend(OpenAI("gpt-3.5-turbo-instruct"))

state = few_shot_qa.run(question="What is the capital of the United States?")
answer = state["answer"].strip().lower()

assert "washington" in answer, f"answer: {state['answer']}"

print(state.text())
21
examples/quick_start/openai_example_stream.py
Normal file
@@ -0,0 +1,21 @@
from sglang import function, system, user, assistant, gen, set_default_backend, OpenAI


@function
def multi_turn_question(s, question_1, question_2):
    s += system("You are a helpful assistant.")
    s += user(question_1)
    s += assistant(gen("answer_1", max_tokens=256))
    s += user(question_2)
    s += assistant(gen("answer_2", max_tokens=256))

set_default_backend(OpenAI("gpt-3.5-turbo"))

state = multi_turn_question.run(
    question_1="What is the capital of the United States?",
    question_2="List two local attractions.",
    stream=True
)

for out in state.text_iter():
    print(out, end="", flush=True)
26
examples/quick_start/srt_example_chat.py
Normal file
@@ -0,0 +1,26 @@
from sglang import function, system, user, assistant, gen, set_default_backend, Runtime


@function
def multi_turn_question(s, question_1, question_2):
    s += system("You are a helpful assistant.")
    s += user(question_1)
    s += assistant(gen("answer_1", max_tokens=256))
    s += user(question_2)
    s += assistant(gen("answer_2", max_tokens=256))


runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
#runtime = Runtime(model_path="mistralai/Mixtral-8x7B-Instruct-v0.1")
set_default_backend(runtime)

state = multi_turn_question.run(
    question_1="What is the capital of the United States?",
    question_2="List two local attractions.",
)

for m in state.messages():
    print(m["role"], ":", m["content"])


runtime.shutdown()
28
examples/quick_start/srt_example_complete.py
Normal file
@@ -0,0 +1,28 @@
from sglang import function, gen, set_default_backend, Runtime


@function
def few_shot_qa(s, question):
    s += (
"""The following are questions with answers.
Q: What is the capital of France?
A: Paris
Q: What is the capital of Germany?
A: Berlin
Q: What is the capital of Italy?
A: Rome
""")
    s += "Q: " + question + "\n"
    s += "A:" + gen("answer", stop="\n", temperature=0)


runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
set_default_backend(runtime)

state = few_shot_qa.run(question="What is the capital of the United States?")

answer = state["answer"].strip().lower()
assert "washington" in answer, f"answer: {state['answer']}"
print(state.text())

runtime.shutdown()
21
examples/quick_start/srt_example_regex.py
Normal file
@@ -0,0 +1,21 @@
from sglang import function, gen, set_default_backend, Runtime


@function
def regex_gen(s):
    s += "Q: What is the IP address of the Google DNS servers?\n"
    s += "A: " + gen(
        "answer",
        temperature=0,
        regex=r"((25[0-5]|2[0-4]\d|[01]?\d\d?).){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
    )


runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
set_default_backend(runtime)

state = regex_gen.run()

print(state.text())

runtime.shutdown()
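The IP-address pattern in the example above can be sanity-checked with plain `re`, outside SGLang (note that the source leaves the `.` between octets unescaped, so it matches any character, not only a literal dot):

```python
import re

# Same pattern as in regex_gen above.
pattern = r"((25[0-5]|2[0-4]\d|[01]?\d\d?).){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

print(bool(re.fullmatch(pattern, "8.8.8.8")))    # True
print(bool(re.fullmatch(pattern, "999.1.1.1")))  # False: 999 is not a valid octet
```

Constrained decoding then guarantees that whatever the model generates for `"answer"` is a full match of this pattern.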
26
examples/quick_start/srt_example_stream.py
Normal file
@@ -0,0 +1,26 @@
from sglang import function, system, user, assistant, gen, set_default_backend, Runtime


@function
def multi_turn_question(s, question_1, question_2):
    s += system("You are a helpful assistant.")
    s += user(question_1)
    s += assistant(gen("answer_1", max_tokens=256))
    s += user(question_2)
    s += assistant(gen("answer_2", max_tokens=256))


runtime = Runtime("meta-llama/Llama-2-7b-chat-hf")
set_default_backend(runtime)

state = multi_turn_question.run(
    question_1="What is the capital of the United States?",
    question_2="List two local attractions.",
    temperature=0,
    stream=True,
)

for out in state.text_iter():
    print(out, end="", flush=True)
print()

runtime.shutdown()
7
playground/launch_tgi.sh
Normal file
@@ -0,0 +1,7 @@
# Assuming the model is downloaded at /home/ubuntu/model_weights/Llama-2-7b-chat-hf
docker run --name tgi --rm -ti --gpus all --network host \
  -v /home/ubuntu/model_weights/Llama-2-7b-chat-hf:/Llama-2-7b-chat-hf \
  ghcr.io/huggingface/text-generation-inference:1.1.0 \
  --model-id /Llama-2-7b-chat-hf --num-shard 1 --trust-remote-code \
  --max-input-length 2048 --max-total-tokens 4096 \
  --port 24000
7
playground/load_tokenizer.py
Normal file
@@ -0,0 +1,7 @@
import transformers
import code

name = "meta-llama/Llama-2-7b-chat-hf"

t = transformers.AutoTokenizer.from_pretrained(name)
code.interact(local=locals())
31
python/pyproject.toml
Normal file
@@ -0,0 +1,31 @@
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "sglang"
version = "0.1.0"
description = "A structured generation language for LLMs."
readme = "README.md"
requires-python = ">=3.8"
license = {file = "LICENSE"}
classifiers = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: Apache Software License",
]
dependencies = [
    "requests",
]

[project.optional-dependencies]
srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5",
       "interegular", "lark"]
openai = ["openai>=1.0"]
anthropic = ["anthropic"]
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]

[tool.setuptools.packages.find]
exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"]

[tool.wheel]
exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"]
2
python/sglang/__init__.py
Normal file
@@ -0,0 +1,2 @@
from sglang.api import *
from sglang.global_config import global_config
161
python/sglang/api.py
Normal file
@@ -0,0 +1,161 @@
"""Public API"""
import re
from typing import Callable, List, Optional, Union

from sglang.backend.anthropic import Anthropic
from sglang.backend.base_backend import BaseBackend
from sglang.backend.openai import OpenAI
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.global_config import global_config
from sglang.lang.ir import (
    SglExpr,
    SglExprList,
    SglFunction,
    SglGen,
    SglImage,
    SglRoleBegin,
    SglRoleEnd,
    SglSelect,
)
from sglang.srt.server import Runtime


def function(func: Callable):
    return SglFunction(func)


def set_default_backend(backend: BaseBackend):
    global_config.default_backend = backend


def gen(
    name: Optional[str] = None,
    max_tokens: Optional[int] = None,
    stop: Optional[Union[str, List[str]]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    frequency_penalty: Optional[float] = None,
    presence_penalty: Optional[float] = None,
    dtype: Optional[type] = None,
    choices: Optional[List[str]] = None,
    regex: Optional[str] = None,
):
    if choices:
        return SglSelect(name, choices, temperature)

    # Check that the regex is valid
    if regex is not None:
        try:
            re.compile(regex)
        except re.error as e:
            raise e

    return SglGen(
        name,
        max_tokens,
        stop,
        temperature,
        top_p,
        top_k,
        frequency_penalty,
        presence_penalty,
        dtype,
        regex,
    )


def gen_int(
    name: Optional[str] = None,
    max_tokens: Optional[int] = None,
    stop: Optional[Union[str, List[str]]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    frequency_penalty: Optional[float] = None,
    presence_penalty: Optional[float] = None,
):
    return SglGen(
        name,
        max_tokens,
        stop,
        temperature,
        top_p,
        top_k,
        frequency_penalty,
        presence_penalty,
        int,
        None,
    )


def gen_string(
    name: Optional[str] = None,
    max_tokens: Optional[int] = None,
    stop: Optional[Union[str, List[str]]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    frequency_penalty: Optional[float] = None,
    presence_penalty: Optional[float] = None,
):
    return SglGen(
        name,
        max_tokens,
        stop,
        temperature,
        top_p,
        top_k,
        frequency_penalty,
        presence_penalty,
        str,
        None,
    )


def image(expr: SglExpr):
    return SglImage(expr)


def select(
    name: Optional[str] = None,
    choices: List[str] = None,
    temperature: float = 0.0,
):
    assert choices is not None
    return SglSelect(name, choices, temperature)


def _role_common(name: str, expr: Optional[SglExpr] = None):
    if expr is None:
        return SglExprList([SglRoleBegin(name), SglRoleEnd(name)])
    else:
        return SglExprList([SglRoleBegin(name), expr, SglRoleEnd(name)])


def system(expr: Optional[SglExpr] = None):
    return _role_common("system", expr)


def user(expr: Optional[SglExpr] = None):
    return _role_common("user", expr)


def assistant(expr: Optional[SglExpr] = None):
    return _role_common("assistant", expr)


def user_begin():
    return SglRoleBegin("user")


def user_end():
    return SglRoleEnd("user")


def assistant_begin():
    return SglRoleBegin("assistant")


def assistant_end():
    return SglRoleEnd("assistant")
0
python/sglang/backend/__init__.py
Normal file
57
python/sglang/backend/anthropic.py
Normal file
@@ -0,0 +1,57 @@
from typing import List, Optional, Union

import numpy as np
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SamplingParams

try:
    import anthropic
except ImportError as e:
    anthropic = e


class Anthropic(BaseBackend):
    def __init__(self, model_name):
        super().__init__()

        if isinstance(anthropic, Exception):
            raise anthropic

        self.model_name = model_name
        self.chat_template = get_chat_template("claude")

    def get_chat_template(self):
        return self.chat_template

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SamplingParams,
    ):
        prompt = s.text_
        ret = anthropic.Anthropic().completions.create(
            model=self.model_name,
            prompt=prompt,
            **sampling_params.to_anthropic_kwargs(),
        )
        comp = ret.completion

        return comp, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SamplingParams,
    ):
        prompt = s.text_
        generator = anthropic.Anthropic().completions.create(
            model=self.model_name,
            prompt=prompt,
            stream=True,
            **sampling_params.to_anthropic_kwargs(),
        )

        for ret in generator:
            yield ret.completion, {}
74
python/sglang/backend/base_backend.py
Normal file
@@ -0,0 +1,74 @@
from typing import Callable, List, Optional, Union

from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SamplingParams


class BaseBackend:
    def __init__(self) -> None:
        self.support_concate_and_append = False
        self.chat_template = get_chat_template("default")

    def get_model_name(self):
        raise NotImplementedError()

    def get_chat_template(self):
        return self.chat_template

    def cache_prefix(self, prefix_str: str):
        pass

    def uncache_prefix(self, rid: str):
        pass

    def end_request(self, rid: Union[str, List[str]]):
        pass

    def begin_program(self, s: StreamExecutor):
        pass

    def end_program(self, s: Union[StreamExecutor, List[StreamExecutor]]):
        pass

    def commit_lazy_operations(self, s: StreamExecutor):
        pass

    def fork_program(
        self,
        src: StreamExecutor,
        dst: List[StreamExecutor],
        position_ids_offset: Optional[List[int]] = None,
    ):
        pass

    def fill_image(self, s: StreamExecutor):
        pass

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SamplingParams,
    ):
        raise NotImplementedError()

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SamplingParams,
    ):
        raise NotImplementedError()

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        temperature: float,
    ):
        raise NotImplementedError()

    def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
        raise NotImplementedError()

    def shutdown(self):
        pass
349
python/sglang/backend/huggingface.py
Normal file
@@ -0,0 +1,349 @@
import functools
from enum import Enum, auto
from typing import Callable, List, Optional, Union

import numpy as np
import torch
import transformers
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import ProgramState
from sglang.utils import get_available_gpu_memory
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)
from transformers.generation.logits_process import (
    LogitsProcessorList,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)


class StopReason(Enum):
    EOS_TOKEN = auto()
    STOP_STR = auto()
    LENGTH = auto()


def load_model(
    model_name: str,
    device,
    num_gpus,
    max_gpu_memory,
    model_kwargs=None,
    tokenizer_kwargs=None,
):
    model_kwargs = model_kwargs or {}
    tokenizer_kwargs = tokenizer_kwargs or {}

    if device == "cuda":
        model_kwargs["torch_dtype"] = torch.float16
        if num_gpus != 1:
            model_kwargs["device_map"] = "auto"
            if max_gpu_memory is None:
                model_kwargs[
                    "device_map"
                ] = "sequential"  # This is important for GPUs with different VRAM sizes
                available_gpu_memory = [
                    get_available_gpu_memory(i, False) for i in range(num_gpus)
                ]
                model_kwargs["max_memory"] = {
                    i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
                    for i in range(num_gpus)
                }
            else:
                model_kwargs["max_memory"] = {
                    i: max_gpu_memory for i in range(num_gpus)
                }
    elif device == "cpu":
        model_kwargs["torch_dtype"] = torch.float32
    else:
        raise ValueError(f"Invalid device: {device}")

    model = AutoModelForCausalLM.from_pretrained(
        model_name, low_cpu_mem_usage=True, **model_kwargs
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)

    if num_gpus == 1:
        model.to(device).eval()

    return model, tokenizer


def prepare_logits_processor(
    temperature: float, repetition_penalty: float, top_p: float, top_k: int
) -> LogitsProcessorList:
    processor_list = LogitsProcessorList()
    # TemperatureLogitsWarper doesn't accept 0.0; 1.0 makes it a no-op, so we skip both cases.
    if temperature >= 1e-5 and temperature != 1.0:
        processor_list.append(TemperatureLogitsWarper(temperature))
    if repetition_penalty > 1.0:
        processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
    if 1e-8 <= top_p < 1.0:
        processor_list.append(TopPLogitsWarper(top_p))
    if top_k > 0:
        processor_list.append(TopKLogitsWarper(top_k))
    return processor_list


@functools.lru_cache
def get_token_healing_mask(tokenizer, prompt_last_token):
    last_str = tokenizer.convert_ids_to_tokens(prompt_last_token)
    disallowed = torch.zeros(len(tokenizer), dtype=bool)
    for s, t_id in tokenizer.get_vocab().items():
        if not s.startswith(last_str):
            disallowed[t_id] = 1
    return disallowed


@functools.lru_cache
def get_int_token_mask(tokenizer):
    disallowed = torch.zeros(len(tokenizer), dtype=bool)
    for s, t_id in tokenizer.get_vocab().items():
        s = s.replace("▁", "").strip()
        if not (s.isdigit() or len(s) == 0 or s == ","):
            disallowed[t_id] = 1
    disallowed[tokenizer.eos_token_id] = 0
    return disallowed


@torch.inference_mode()
def generate_stream(
    model,
    tokenizer,
    prompt,
    max_new_tokens,
    stop: List[str],
    temperature,
    top_p,
    token_healing,
    logit_mask=None,
):
    logits_processor = prepare_logits_processor(
        temperature=temperature, repetition_penalty=1.0, top_p=top_p, top_k=0
    )
    device = model.device
    input_ids = tokenizer.encode(prompt)
    output_ids = list(input_ids)
    prompt_len = len(prompt)

    # Resolve stop
    stop_token_ids = [tokenizer.eos_token_id]

    # Token healing
    token_healing = token_healing and len(input_ids) > 0
    if token_healing:
        token_healing_mask = get_token_healing_mask(tokenizer, input_ids[-1])
        del output_ids[-1]

    # Generate
    past_key_values = None
    stop_reason = None
    for i in range(max_new_tokens):
        # Forward
        if i == 0:  # prefill
            out = model(torch.as_tensor([output_ids], device=device), use_cache=True)
        else:  # decoding
            out = model(
                input_ids=torch.as_tensor([[token]], device=device),
                use_cache=True,
                past_key_values=past_key_values,
            )
        logits = out.logits
        past_key_values = out.past_key_values

        # Logit mask
        if token_healing and i == 0:
            logits[0, -1, token_healing_mask] = -1e4
        if logit_mask is not None:
            logits[0, -1, logit_mask] = -1e4

        # Sample next token
        last_token_logits = logits_processor(None, logits[:, -1, :])[0]
        if temperature < 1e-5 or top_p < 1e-8:  # greedy
            token = int(torch.argmax(last_token_logits))
        else:
            probs = torch.softmax(last_token_logits, dim=-1)
            token = int(torch.multinomial(probs, num_samples=1))
        output_ids.append(token)

        # Stop condition
        if token in stop_token_ids:
            stop_reason = StopReason.EOS_TOKEN
            break

        output_str = tokenizer.decode(output_ids, skip_special_tokens=True)
        for stop_str in stop:
            pos = output_str[prompt_len:].find(stop_str)
            if pos != -1:
                stop_reason = StopReason.STOP_STR
                output_str = output_str[: prompt_len + pos]
                break

        if stop_reason:
            break

    return output_str[prompt_len:]


class HuggingFaceTransformers(BaseBackend):
    def __init__(
        self,
        model_name,
        device="cuda",
        num_gpus=1,
        max_gpu_memory=None,
        model_kwargs=None,
        tokenizer_kwargs=None,
    ):
        self.model_name = model_name
        self.device = device

        self.model, self.tokenizer = load_model(
            model_name, device, num_gpus, max_gpu_memory, model_kwargs, tokenizer_kwargs
        )

        self.chat_template = get_chat_template_by_model_path(model_name)

    def get_chat_template(self):
        return self.chat_template

    def cache_prefix(self, prefix_str: str):
        pass

    def uncache_prefix(self, rid: str):
        pass

    def end_request(self, rid: str):
        pass

    def begin_program(self, s: ProgramState):
        pass

    def end_program(self, s: ProgramState):
        pass

    def fill(self, s: ProgramState, text: str):
        return False

    def generate_internal(
        self,
        prompt: str,
        max_tokens: int,
        stop: Union[str, List[str]],
        temperature: float,
        top_p: float,
        dtype: Optional[str] = None,
    ):
        if dtype is None:
            comp = generate_stream(
                self.model,
                self.tokenizer,
                prompt,
                max_new_tokens=max_tokens,
                stop=stop,
                temperature=temperature,
                top_p=top_p,
                token_healing=True,
            )
        elif dtype in [str, "str", "string"]:
            comp = generate_stream(
                self.model,
                self.tokenizer,
                prompt + '"',
                max_new_tokens=max_tokens,
                stop=['"'],
                temperature=temperature,
                top_p=top_p,
                token_healing=False,
            )
            comp = '"' + comp + '"'
        elif dtype in [int, "int"]:
            logit_mask = get_int_token_mask(self.tokenizer)
            comp = generate_stream(
                self.model,
                self.tokenizer,
                prompt,
                max_new_tokens=max_tokens,
                stop=stop + [" ", ","],
                temperature=temperature,
                top_p=top_p,
                token_healing=False,
                logit_mask=logit_mask,
            )
        return comp

    def generate(
        self,
        s: ProgramState,
        max_tokens: int,
        stop: Union[str, List[str]],
        temperature: float,
        top_p: float,
        dtype: Optional[str] = None,
    ):
        prompt = s.text
        comp = self.generate_internal(
            prompt, max_tokens, stop, temperature, top_p, dtype
        )
        return comp

    def parallel_generate(
        self,
        s: ProgramState,
        prefixes: List[str],
        join_func: Callable,
        max_tokens: int,
        stop: Union[str, List[str]],
        temperature: float,
        top_p: float,
        dtype: Optional[str] = None,
    ):
        prompt = s.text
        parallel_prompts = [prompt + prefix for prefix in prefixes]

        comps = []
        for i in range(len(parallel_prompts)):
            comps.append(
                self.generate_internal(
                    parallel_prompts[i], max_tokens, stop, temperature, top_p, dtype
                )
            )

        joined = join_func([p + c for p, c in zip(prefixes, comps)])
        return joined, comps

    @torch.inference_mode()
    def select(
        self, s: ProgramState, choices: List[str], temperature: float, top_p: float
    ):
        loss_fct = torch.nn.CrossEntropyLoss()
        prompt = s.text

        prompt_len = self.tokenizer.encode(prompt, return_tensors="pt").shape[1]
        prompt_choices = [prompt + choice for choice in choices]

        scores = []
        for i in range(len(choices)):
            choice_ids = self.tokenizer.encode(
                prompt_choices[i], return_tensors="pt"
            ).to(self.model.device)
            logits = self.model(choice_ids).logits

            # score = -loss_fct(logits[0, :-1, :], choice_ids[0, 1:]).item()

            logprobs = torch.log(torch.softmax(logits, dim=-1))
            idx1 = torch.arange(0, logits.shape[1] - 1, device=logits.device)
            idx2 = choice_ids[0, 1:]
            selected_logprobs = logprobs[0, idx1, idx2]
            score = selected_logprobs.mean().item()
            scores.append(score)

        decision = choices[np.argmax(scores)]
        return decision, scores
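The `get_token_healing_mask` helper above disallows every vocabulary token whose string form does not extend the last prompt token. The idea can be sketched without torch or a real tokenizer; the toy vocabulary below is hypothetical:

```python
# Torch-free sketch of the token-healing mask: drop the last prompt
# token, then only allow tokens whose string starts with that token's
# string, so the model "re-completes" it into a longer token.
toy_vocab = {"hel": 0, "hello": 1, "help": 2, "world": 3, "he": 4}

def token_healing_mask(vocab: dict, last_token_str: str) -> list:
    # True means the token is disallowed at the healed position.
    disallowed = [True] * len(vocab)
    for s, t_id in vocab.items():
        if s.startswith(last_token_str):
            disallowed[t_id] = False
    return disallowed

mask = token_healing_mask(toy_vocab, "hel")
allowed = sorted(s for s, i in toy_vocab.items() if not mask[i])
print(allowed)  # only tokens that extend "hel" survive
```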
241
python/sglang/backend/openai.py
Normal file
@@ -0,0 +1,241 @@
from typing import Callable, List, Optional, Union

import numpy as np
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SamplingParams

try:
    import openai
    import tiktoken
except ImportError as e:
    openai = tiktoken = e


def create_logit_bias_int(tokenizer):
    """Get logit bias for integer numbers."""
    int_token_ids = []

    tokens = tokenizer._mergeable_ranks
    for token, token_id in tokens.items():
        s = tokenizer.decode([token_id])
        if all([c.isdigit() for c in s]) or s in [" "]:
            int_token_ids.append(token_id)
            if len(int_token_ids) >= 300:  # OpenAI API limit
                break
    special_tokens = tokenizer._special_tokens
    mask = {t: 100 for t in int_token_ids[:299]}
    mask[special_tokens["<|endoftext|>"]] = 100
    return mask


CHAT_MODEL_NAMES = [
    # GPT-4
    "gpt-4",
    "gpt-4-32k",
    "gpt-4-1106-preview",
    "gpt-4-vision-preview",
    "gpt-4-0613",
    "gpt-4-0314",
    # GPT-3.5
    "gpt-3.5-turbo",
    "gpt-3.5-turbo-16k",
    "gpt-3.5-turbo-1106",
    "gpt-3.5-turbo-16k-0613",
    "gpt-3.5-turbo-0613",
    "gpt-3.5-turbo-0301",
]


class OpenAI(BaseBackend):
    def __init__(self, model_name, *args, **kwargs):
        super().__init__()

        if isinstance(openai, Exception):
            raise openai

        self.client = openai.OpenAI(*args, **kwargs)

        self.model_name = model_name
        self.tokenizer = tiktoken.encoding_for_model(model_name)
        self.logit_bias_int = create_logit_bias_int(self.tokenizer)

        if model_name in CHAT_MODEL_NAMES:
            self.is_chat_model = True
        else:
            self.is_chat_model = False

        self.chat_template = get_chat_template("default")

    def get_chat_template(self):
        return self.chat_template

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SamplingParams,
    ):
        if sampling_params.dtype is None:
            if self.is_chat_model:
                assert s.text_.endswith("ASSISTANT:")
                prompt = s.messages_
            else:
                prompt = s.text_

            kwargs = sampling_params.to_openai_kwargs()
            comp = openai_completion(
                client=self.client,
                is_chat=self.is_chat_model,
                model=self.model_name,
                prompt=prompt,
                **kwargs,
            )
        elif sampling_params.dtype in [str, "str", "string"]:
            kwargs = sampling_params.to_openai_kwargs()
            kwargs.pop("stop")
            comp = openai_completion(
                client=self.client,
                is_chat=self.is_chat_model,
                model=self.model_name,
                prompt=s.text_ + '"',
                stop='"',
                **kwargs,
            )
            comp = '"' + comp + '"'
        elif sampling_params.dtype in [int, "int"]:
            kwargs = sampling_params.to_openai_kwargs()
            kwargs.pop("stop")
            comp = openai_completion(
                client=self.client,
                is_chat=self.is_chat_model,
                model=self.model_name,
                prompt=s.text_,
                logit_bias=self.logit_bias_int,
                stop=[" "],
                **kwargs,
            )
        else:
            raise ValueError(f"Unknown dtype: {sampling_params.dtype}")

        return comp, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SamplingParams,
    ):
        if sampling_params.dtype is None:
            if self.is_chat_model:
                assert s.text_.endswith("ASSISTANT:")
                prompt = s.messages_
            else:
                prompt = s.text_

            kwargs = sampling_params.to_openai_kwargs()
            generator = openai_completion_stream(
                client=self.client,
                is_chat=self.is_chat_model,
                model=self.model_name,
                prompt=prompt,
                **kwargs,
            )
            return generator
        else:
            raise ValueError(f"Unknown dtype: {sampling_params.dtype}")

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        temperature: float,
    ):
        n_choices = len(choices)
        token_ids = [self.tokenizer.encode(x) for x in choices]
        scores = [0] * n_choices
        valid = [len(x) > 0 for x in token_ids]
        prompt_tokens = self.tokenizer.encode(s.text_)

        max_len = max([len(x) for x in token_ids])
        for step in range(max_len):
            # Build logit bias
            logit_bias = {}
            for i in range(n_choices):
                if valid[i]:
                    logit_bias[token_ids[i][step]] = 100

            # Call the API
            ret = self.client.completions.create(
                model=self.model_name,
                prompt=prompt_tokens,
                logit_bias=logit_bias,
                max_tokens=1,
                temperature=temperature,
            )
            ret_str = ret.choices[0].text
            ret_token = self.tokenizer.encode(ret_str)[0]

            # TODO:
            # 1. return logits as the scores
            # 2. compute logits of the full choice
            # 3. consider chunk-based decoding

            # Update valid
            hit = False
            for i in range(n_choices):
                if valid[i]:
                    if step == len(token_ids[i]) - 1:
                        valid[i] = False

                    if ret_token == token_ids[i][step]:
                        scores[i] += 1
                        hit = True
                    else:
                        valid[i] = False
            assert hit

            if np.sum(valid) <= 1:
                break

            prompt_tokens.append(ret_token)

        decision = choices[np.argmax(scores)]
        return decision, scores


def openai_completion(client, is_chat=None, prompt=None, **kwargs):
    try:
        if is_chat:
            if kwargs["stop"] is None:
                kwargs.pop("stop")
            ret = client.chat.completions.create(messages=prompt, **kwargs)
            comp = ret.choices[0].message.content
        else:
            ret = client.completions.create(prompt=prompt, **kwargs)
            if isinstance(prompt, (list, tuple)):
                comp = [c.text for c in ret.choices]
            else:
                comp = ret.choices[0].text
    except openai.OpenAIError as e:
        print(f"OpenAI Error: {e}")
        raise e

    return comp


def openai_completion_stream(client, is_chat=None, prompt=None, **kwargs):
    try:
        if is_chat:
            generator = client.chat.completions.create(
                messages=prompt, stream=True, **kwargs
            )
            for ret in generator:
                content = ret.choices[0].delta.content
                yield content or "", {}
        else:
            generator = client.completions.create(prompt=prompt, stream=True, **kwargs)
            for ret in generator:
                content = ret.choices[0].text
                yield content or "", {}
    except openai.OpenAIError as e:
        print(f"OpenAI Error: {e}")
        raise e
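The `OpenAI.select` method above scores each choice incrementally: at every step, only the still-valid candidates' next tokens are allowed through `logit_bias`, the returned token earns a point for each candidate it matches, and mismatched candidates drop out. That loop can be sketched without any API call by mocking the per-step model output; the token ids below are made up:

```python
# Torch-free, API-free sketch of the incremental scoring loop in
# OpenAI.select. `token_ids` is the tokenization of each choice;
# `forced_tokens` stands in for what the model would return each step.
def select_sketch(token_ids, forced_tokens):
    n = len(token_ids)
    scores = [0] * n
    valid = [len(t) > 0 for t in token_ids]
    max_len = max(len(t) for t in token_ids)
    for step in range(max_len):
        ret_token = forced_tokens[step]
        hit = False
        for i in range(n):
            if valid[i]:
                if step == len(token_ids[i]) - 1:
                    valid[i] = False  # candidate fully consumed
                if ret_token == token_ids[i][step]:
                    scores[i] += 1
                    hit = True
                else:
                    valid[i] = False  # diverged from the sampled prefix
        assert hit
        if sum(valid) <= 1:
            break  # at most one candidate still alive
    return scores

# Two hypothetical choices tokenized as [10] and [11, 12]; the mocked
# model emits token 11, so the second choice wins the first step.
print(select_sketch([[10], [11, 12]], [11, 12]))
```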
171
python/sglang/backend/runtime_endpoint.py
Normal file
@@ -0,0 +1,171 @@
import json
from typing import Callable, List, Optional, Union

import numpy as np
import requests
from sglang.backend.base_backend import BaseBackend
from sglang.global_config import global_config
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SamplingParams, SglArgument
from sglang.utils import encode_image_base64, find_printable_text, http_request


class RuntimeEndpoint(BaseBackend):
    def __init__(self, base_url):
        super().__init__()
        self.support_concate_and_append = True

        self.base_url = base_url

        res = http_request(self.base_url + "/get_model_info")
        assert res.status_code == 200
        self.model_info = res.json()

        self.chat_template = get_chat_template_by_model_path(
            self.model_info["model_path"]
        )

    def get_model_name(self):
        return self.model_info["model_path"]

    def get_chat_template(self):
        return self.chat_template

    def cache_prefix(self, prefix_str: str):
        res = http_request(
            self.base_url + "/generate",
            json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
        )
        assert res.status_code == 200

    def commit_lazy_operations(self, s: StreamExecutor):
        res = http_request(
            self.base_url + "/generate",
            json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
        )
        assert res.status_code == 200

    def fill_image(self, s: StreamExecutor):
        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        self._add_images(s, data)
        res = http_request(self.base_url + "/generate", json=data)
        assert res.status_code == 200

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SamplingParams,
    ):
        if sampling_params.dtype is None:
            data = {
                "text": s.text_,
                "sampling_params": {
                    "skip_special_tokens": global_config.skip_special_tokens_in_output,
                    **sampling_params.to_srt_kwargs(),
                },
            }
        elif sampling_params.dtype in [int, "int"]:
            data = {
                "text": s.text_,
                "sampling_params": {
                    "skip_special_tokens": global_config.skip_special_tokens_in_output,
                    "dtype": "int",
                    **sampling_params.to_srt_kwargs(),
                },
            }
        else:
            raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")

        self._add_images(s, data)

        res = http_request(self.base_url + "/generate", json=data)
        obj = res.json()
        comp = obj["text"]
        return comp, obj["meta_info"]

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SamplingParams,
    ):
        if sampling_params.dtype is None:
            data = {
                "text": s.text_,
                "sampling_params": {
                    "skip_special_tokens": global_config.skip_special_tokens_in_output,
                    **sampling_params.to_srt_kwargs(),
                },
            }
        elif sampling_params.dtype in [int, "int"]:
            data = {
                "text": s.text_,
                "sampling_params": {
                    "skip_special_tokens": global_config.skip_special_tokens_in_output,
                    "dtype": "int",
                    **sampling_params.to_srt_kwargs(),
                },
            }
        else:
            raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")

        data["stream"] = True
        self._add_images(s, data)

        response = http_request(self.base_url + "/generate", json=data, stream=True)
        pos = 0

        incomplete_text = ""
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                text = find_printable_text(data["text"][pos:])
                meta_info = data["meta_info"]
                pos += len(text)
                incomplete_text = data["text"][pos:]
                yield text, meta_info

        if len(incomplete_text) > 0:
            yield incomplete_text, meta_info

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        temperature: float,
    ):
        assert temperature <= 1e-5

        # Cache the common prefix
        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        self._add_images(s, data)
        res = http_request(self.base_url + "/generate", json=data)
        assert res.status_code == 200
        prompt_len = res.json()["meta_info"]["prompt_tokens"]

        # Compute the normalized logprob of each choice
        data = {
            "text": [s.text_ + c for c in choices],
            "sampling_params": {"max_new_tokens": 0},
            "return_normalized_logprob": True,
            "normalized_logprob_start_len": prompt_len,
        }
        self._add_images(s, data)
        res = http_request(self.base_url + "/generate", json=data)
        assert res.status_code == 200
logps = [r["meta_info"]["normalized_logprob"] for r in res.json()]
|
||||
|
||||
decision = choices[np.argmax(logps)]
|
||||
return decision, logps
|
||||
|
||||
def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
|
||||
res = http_request(
|
||||
self.base_url + "/concate_and_append_request",
|
||||
json={"src_rids": src_rids, "dst_rid": dst_rid},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
|
||||
def _add_images(self, s: StreamExecutor, data):
|
||||
if s.images_:
|
||||
assert len(s.images_) == 1, "Only support one image."
|
||||
data["image_data"] = s.images_[0][1]
|
||||
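The `select` method above appends each candidate choice to the prompt, asks the backend for a normalized log probability, and picks the highest-scoring choice with an argmax. A minimal standalone sketch of that final step (with made-up logprob values, and a plain-Python argmax in place of `np.argmax`) is:

```python
# Hypothetical sketch of the logprob-based selection in select():
# given candidate choices and their normalized log probabilities,
# return the choice with the highest score plus the raw scores.
def select_by_logprob(choices, logps):
    best = max(range(len(choices)), key=lambda i: logps[i])  # argmax
    return choices[best], logps

decision, logps = select_by_logprob(["yes", "no"], [-0.2, -1.5])
```

Here `decision` is `"yes"` because -0.2 > -1.5; the real implementation gets `logps` from the `/generate` endpoint rather than a literal list.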
190	python/sglang/backend/tgi.py	Normal file
@@ -0,0 +1,190 @@
import re
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from itertools import repeat
from typing import List, Optional, Union

from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SamplingParams
from sglang.utils import http_request


class TGI(BaseBackend):
    def __init__(self, base_url):
        super().__init__()

        self.base_url = base_url

        res = http_request(self.base_url + "/info")
        assert res.status_code == 200
        self.model_info = res.json()
        self.chat_template = get_chat_template_by_model_path(
            self.model_info["model_id"]
        )

    def get_model_name(self):
        return self.model_info["model_id"]

    def get_chat_template(self):
        return self.chat_template

    @staticmethod
    def adapt_params(max_tokens, stop, sampling_params, **override_params):
        temperature = sampling_params.temperature
        do_sample = True
        if temperature == 0:
            do_sample = False
            temperature = None

        if stop is None:
            stop = []
        elif isinstance(stop, str):
            stop = [stop]

        top_p = sampling_params.top_p
        if top_p == 0:
            top_p = 0.001
        if top_p == 1:
            top_p = 0.999

        top_k = sampling_params.top_k
        if top_k == -1:
            top_k = None

        params = {
            "decoder_input_details": False,
            "details": False,
            "do_sample": do_sample,
            "max_new_tokens": max_tokens,
            "stop": stop,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "return_full_text": False,
        }
        params.update(override_params)
        return params

    @staticmethod
    def _extract_int(text):
        words = re.split(r" |'|/|\(|\)|\n|\.|,", text)
        for word in words:
            try:
                int(word)
                return word
            except ValueError:
                continue
        raise ValueError

    @staticmethod
    def _extract_choice(choices, text):
        # FIXME: Currently only supports the case where the choices are single words.
        words = re.split(r" |'|/|\(|\)|\n|\.|,", text)
        for word in words:
            if word in choices:
                return word
        raise ValueError

    @staticmethod
    def _truncate_to_stop(text, stop):
        # The stop sequence may not be a single token. In this case TGI will generate
        # too many tokens, so we need to truncate the output.
        if stop:
            stop = [stop] if isinstance(stop, str) else stop
            for stop_seq in stop:
                pos = text.find(stop_seq)
                if pos != -1:
                    return text[:pos]
        return text

    def _make_request(self, params):
        res = http_request(self.base_url + "/generate", json=params)
        if res.status_code != 200:
            raise ValueError(f"Error from TGI backend: {res.text}")
        return res.json()

    def retry_for_expected(self, prompt, params, extract_fn, retry=5):
        # TGI does not support logit_bias (yet), so we have to use an inefficient hack.
        failed = []
        while retry > 0:
            res_json = self._make_request(
                {
                    "inputs": prompt,
                    "parameters": params,
                }
            )
            text = res_json["generated_text"]
            try:
                return extract_fn(text)
            except ValueError:
                retry -= 1
                failed.append(text)

        msg = "=" * 20 + "\n"
        msg += f"Prompt:\n{prompt}\n"
        msg += "=" * 20 + "\n"
        for i, text in enumerate(failed):
            msg += f"====== Try {i+1}:\n{text}\n"

        raise ValueError(
            f"Model {self.model_info['model_id']} served by TGI backend does not generate "
            "expected output. Please improve the prompt, increase the temperature, or "
            f"use different models.\n{msg}"
        )

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        sampling_params: SamplingParams,
    ):
        decision = self.retry_for_expected(
            s.text_,
            self.adapt_params(16, [], sampling_params),
            partial(self._extract_choice, choices),
        )
        return decision, [1 if choice == decision else 0 for choice in choices]

    def generate(
        self,
        s: StreamExecutor,
        max_tokens: int,
        stop: Union[str, List[str]],
        sampling_params: SamplingParams,
        dtype: Optional[str] = None,
    ):
        if dtype is None:
            res_json = self._make_request(
                {
                    "inputs": s.text_,
                    "parameters": self.adapt_params(max_tokens, stop, sampling_params),
                }
            )
            return self._truncate_to_stop(res_json["generated_text"], stop), {}

        if dtype in [str, "str", "string"]:
            stop = ['"']
            res_json = self._make_request(
                {
                    "inputs": f'{s.text_}"',
                    "parameters": self.adapt_params(max_tokens, stop, sampling_params),
                }
            )
            return (
                '"' + self._truncate_to_stop(res_json["generated_text"], stop) + '"',
                {},
            )

        if dtype in [int, "int"]:
            return (
                self.retry_for_expected(
                    s.text_,
                    self.adapt_params(max_tokens, stop, sampling_params),
                    self._extract_int,
                ),
                {},
            )

        raise ValueError(f"Unknown dtype: {dtype}")
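Because TGI lacks logit biasing, `retry_for_expected` above repeatedly samples until an extraction function can parse the output. A self-contained sketch of that loop, with a stubbed generator standing in for the hypothetical `/generate` call, is:

```python
import re

# Sketch of the retry-until-parse pattern: keep sampling until extract_fn
# succeeds, collecting failures for the eventual error message.
def retry_for_expected(generate_fn, extract_fn, retry=5):
    failed = []
    while retry > 0:
        text = generate_fn()
        try:
            return extract_fn(text)
        except ValueError:
            retry -= 1
            failed.append(text)
    raise ValueError(f"No parsable output after {len(failed)} tries")

def extract_int(text):
    # Same word-splitting heuristic as TGI._extract_int
    for word in re.split(r" |'|/|\(|\)|\n|\.|,", text):
        try:
            int(word)
            return word
        except ValueError:
            continue
    raise ValueError

outputs = iter(["maybe", "the answer is 42."])  # stubbed model outputs
result = retry_for_expected(lambda: next(outputs), extract_int)  # "42"
```

The first output contains no integer and is discarded; the second parses, so the loop stops after two calls.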
60	python/sglang/flush_cache.py	Normal file
@@ -0,0 +1,60 @@
"""Flush cache in the backend by sending random requests."""
import argparse
import random
import string
import time

from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)

import sglang as sgl


@sgl.function
def flush_radix_cache(s, prompt):
    s += prompt + sgl.gen("flush", max_tokens=1, stop="END")


def main(args, max_total_tokens, context_length, print_flag):
    backend = select_sglang_backend(args)
    flush_length = int(context_length * 0.8)
    batch_size = int(max_total_tokens / flush_length)
    prompt_length = flush_length * 2
    prompts = [
        " ".join(random.choices(string.ascii_letters, k=int(prompt_length)))
        for _ in range(batch_size)
    ]
    arguments = [{"prompt": prompts[i]} for i in range(batch_size)]

    start_time = time.time()
    flush_radix_cache.run_batch(
        arguments, temperature=0, backend=backend, num_threads=1
    )
    end_time = time.time()

    if print_flag:
        print(
            f"Flush length: {flush_length}\n",
            f"Prompt length: {prompt_length}\n",
            f"Total prompt letters: {batch_size * prompt_length}\n",
            f"Flush radix cache latency: {end_time - start_time:.3f}",
            sep="",
        )

    # Sleep briefly in case the backend is still processing requests.
    time.sleep(1)


def run_flush(args, max_total_tokens=20000, context_length=1024, print_flag=False):
    main(args, max_total_tokens, context_length, print_flag=print_flag)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--max-total-tokens", type=int, default=20000)
    parser.add_argument("--context-length", type=int, default=1024)
    args = add_common_sglang_args_and_parse(parser)
    random.seed(0)
    main(args, args.max_total_tokens, args.context_length, print_flag=True)
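The sizing arithmetic in `main` aims to evict the radix cache by filling roughly 80% of the context window per request and sending enough such requests to cover the whole token pool. With the script's defaults the numbers work out as:

```python
# Batch sizing as computed by flush_cache's main() with its default arguments.
max_total_tokens, context_length = 20000, 1024
flush_length = int(context_length * 0.8)            # tokens to flush per request
batch_size = int(max_total_tokens / flush_length)   # requests needed to cover the pool
prompt_length = flush_length * 2                    # letters per random prompt
print(flush_length, batch_size, prompt_length)
```

So a default run issues 24 requests of roughly 819 flush tokens each, with 1638-letter random prompts.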
28	python/sglang/global_config.py	Normal file
@@ -0,0 +1,28 @@
"""Global configurations"""


class GlobalConfig:
    def __init__(self):
        # Verbosity level
        # 0: do not output anything
        # 2: output final text after every run
        self.verbosity = 0

        self.default_backend = None

        # Output configs
        self.skip_special_tokens_in_output = True

        # Optimization configs
        self.eager_fill_image = False
        self.enable_prefix_sharing = True
        self.enable_parallel_encoding = True
        self.enable_parallel_decoding = True

        # Choices: ["no_adjust", "adjust_cache"]
        # no_adjust: Do not adjust the position embedding of the KV cache.
        # adjust_cache: Adjust the position embedding of the KV cache.
        self.concate_and_append_mode = "no_adjust"


global_config = GlobalConfig()
0	python/sglang/lang/__init__.py	Normal file
186	python/sglang/lang/chat_template.py	Normal file
@@ -0,0 +1,186 @@
from dataclasses import dataclass
from enum import Enum, auto
from typing import Callable, Dict, List, Tuple


class ChatTemplateStyle(Enum):
    PLAIN = auto()
    LLAMA2 = auto()


@dataclass
class ChatTemplate:
    name: str
    default_system_prompt: str
    role_prefix_and_suffix: Dict[str, Tuple[str]]
    image_token: str = "<image>"
    style: ChatTemplateStyle = ChatTemplateStyle.PLAIN

    def get_prefix_and_suffix(self, role, hist_messages):
        if self.style == ChatTemplateStyle.PLAIN:
            return self.role_prefix_and_suffix[role]
        elif self.style == ChatTemplateStyle.LLAMA2:
            if len(hist_messages) == 0 and role == "system":
                return (
                    self.role_prefix_and_suffix["user"][0]
                    + self.role_prefix_and_suffix["system"][0],
                    self.role_prefix_and_suffix["system"][1],
                )
            elif (
                len(hist_messages) == 1
                and role == "user"
                and hist_messages[0]["content"] is not None
            ):
                return ("", self.role_prefix_and_suffix["user"][1])
            return self.role_prefix_and_suffix[role]
        else:
            raise ValueError(f"Invalid style: {self.style}")

    def get_prompt(self, messages):
        prompt = ""
        for i in range(len(messages)):
            role, content = messages[i]["role"], messages[i]["content"]
            if role == "system" and content is None:
                content = self.default_system_prompt
                if content is None:
                    continue

            prefix, suffix = self.get_prefix_and_suffix(role, messages[:i])
            prompt += prefix + content + suffix
        return prompt


chat_template_registry: Dict[str, ChatTemplate] = {}
matching_function_registry: List[Callable] = []


def register_chat_template(template):
    chat_template_registry[template.name] = template


def register_chat_template_matching_function(func):
    matching_function_registry.append(func)


def get_chat_template(name):
    return chat_template_registry[name]


def get_chat_template_by_model_path(model_path):
    for matching_func in matching_function_registry:
        template = matching_func(model_path)
        if template is not None:
            return template
    return get_chat_template("default")


register_chat_template(
    ChatTemplate(
        name="default",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("SYSTEM:", "\n"),
            "user": ("USER:", "\n"),
            "assistant": ("ASSISTANT:", "\n"),
        },
    )
)


register_chat_template(
    ChatTemplate(
        name="claude",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("", ""),
            "user": ("\n\nHuman: ", ""),
            "assistant": ("\n\nAssistant:", ""),
        },
    )
)


register_chat_template(
    ChatTemplate(
        name="chatml",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("<|im_start|>system\n", "\n<|im_end|>\n"),
            "user": ("<|im_start|>user\n", "\n<|im_end|>\n"),
            "assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
        },
        style=ChatTemplateStyle.PLAIN,
    )
)


register_chat_template(
    ChatTemplate(
        name="vicuna_v1.1",
        default_system_prompt=(
            "A chat between a curious user and an artificial intelligence assistant. "
            "The assistant gives helpful, detailed, and polite answers to the user's questions."
        ),
        role_prefix_and_suffix={
            "system": ("", " "),
            "user": ("USER:", " "),
            "assistant": ("ASSISTANT:", "</s>"),
        },
        image_token=" <image>\n",
    )
)


register_chat_template(
    ChatTemplate(
        name="llama-2-chat",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("<<SYS>>\n", "\n<</SYS>>\n\n"),
            "user": ("[INST] ", " [/INST]"),
            "assistant": ("", " </s><s>"),
        },
        style=ChatTemplateStyle.LLAMA2,
    )
)


@register_chat_template_matching_function
def match_vicuna(model_path: str):
    if "vicuna" in model_path.lower():
        return get_chat_template("vicuna_v1.1")
    if "llava" in model_path.lower():
        return get_chat_template("vicuna_v1.1")


@register_chat_template_matching_function
def match_llama2_chat(model_path: str):
    model_path = model_path.lower()
    if "llama-2" in model_path and "chat" in model_path:
        return get_chat_template("llama-2-chat")
    if (
        "mistral" in model_path or "mixtral" in model_path
    ) and "instruct" in model_path:
        return get_chat_template("llama-2-chat")
    if "codellama" in model_path and "instruct" in model_path:
        return get_chat_template("llama-2-chat")


@register_chat_template_matching_function
def match_chat_ml(model_path: str):
    if "tinyllama" in model_path.lower():
        return get_chat_template("chatml")


if __name__ == "__main__":
    messages = [
        {"role": "system", "content": None},  # None means default
        # {"role": "system", "content": "You are a helpful, respectful and honest assistant."},
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hi!"},
        {"role": "user", "content": "What can you do?"},
        {"role": "assistant", "content": "I can chat with you."},
    ]

    template = get_chat_template("llama-2-chat")
    print(template.get_prompt(messages))
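In the PLAIN style, `ChatTemplate.get_prompt` simply wraps each message in its role's (prefix, suffix) pair and concatenates the results. A standalone sketch of that loop, using the "default" template's role pairs and skipping messages with `None` content, is:

```python
# Simplified sketch of get_prompt() for the PLAIN chat template style,
# using the (prefix, suffix) pairs from the "default" template above.
role_prefix_and_suffix = {
    "system": ("SYSTEM:", "\n"),
    "user": ("USER:", "\n"),
    "assistant": ("ASSISTANT:", "\n"),
}

def get_prompt(messages):
    prompt = ""
    for m in messages:
        if m["content"] is None:  # e.g. no default system prompt
            continue
        prefix, suffix = role_prefix_and_suffix[m["role"]]
        prompt += prefix + m["content"] + suffix
    return prompt

p = get_prompt([
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi!"},
])
# p == "USER:Hello!\nASSISTANT:Hi!\n"
```

The LLAMA2 style differs only in `get_prefix_and_suffix`, which folds the system prompt into the first user turn.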
237	python/sglang/lang/compiler.py	Normal file
@@ -0,0 +1,237 @@
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
from typing import List, Union

from sglang.global_config import global_config
from sglang.lang.interpreter import ProgramState, StreamExecutor, pin_program
from sglang.lang.ir import (
    SamplingParams,
    SglArgument,
    SglConstantText,
    SglExpr,
    SglVariable,
)


def compile_func(function, backend):
    tracer = function.trace(backend=backend)
    compiler = CompiledFunction(tracer, function)
    return compiler


class CompiledFunction:
    def __init__(self, tracer, function):
        self.function = function

        self.last_node = CompGraphNode(tracer.last_node)
        self.expr_to_node = {}
        self.build_graph(tracer)
        self.topological_sort()

    def build_graph(self, tracer):
        self.nodes = [self.last_node]
        self.expr_to_node[tracer.last_node] = self.nodes[-1]

        rename_pid = {}

        visited = set([tracer.last_node])
        head = 0
        while head < len(self.nodes):
            cur_node = self.nodes[head]

            # add prev node
            prev_node = cur_node.expr.prev_node
            if prev_node is not None:
                if prev_node not in visited:
                    visited.add(prev_node)
                    self.nodes.append(CompGraphNode(prev_node))
                    self.expr_to_node[prev_node] = self.nodes[-1]
                cur_node.prev_node = self.expr_to_node[prev_node]
                self.expr_to_node[prev_node].add_next_node(cur_node)

            # add source node
            if isinstance(cur_node.expr, SglVariable):
                if cur_node.expr.name in tracer.variables:
                    source = tracer.variables[cur_node.expr.name].source
                else:
                    source = cur_node.expr.source
                if source not in visited:
                    visited.add(source)
                    self.nodes.append(CompGraphNode(source))
                    self.expr_to_node[source] = self.nodes[-1]
                cur_node.source_node = self.expr_to_node[source]
                self.expr_to_node[source].add_next_node(cur_node)
            head += 1

            # rename pid
            if cur_node.expr.pid not in rename_pid:
                rename_pid[cur_node.expr.pid] = len(rename_pid)
            cur_node.expr.pid = rename_pid[cur_node.expr.pid]

    def topological_sort(self):
        prevd = {}
        cand = Queue()
        for x in self.nodes:
            prevd[x] = (x.prev_node is not None) + (x.source_node is not None)
            if prevd[x] == 0:
                cand.put(x)
        new_list = []
        while cand.qsize() > 0:
            head = cand.get()
            new_list.append(head)
            for x in head.next_nodes:
                prevd[x] -= 1
                if prevd[x] == 0:
                    cand.put(x)
        self.nodes = new_list

    def print_graph(self):
        for node in self.nodes:
            print(node)

    def run_internal(
        self,
        backend,
        kwargs,
        default_sampling_para,
    ):
        stream_executor_ids = set([x.expr.pid for x in self.nodes])
        stream_executors = {}
        for x in stream_executor_ids:
            arguments = kwargs if x == self.last_node.expr.pid else {}
            stream_executors[x] = StreamExecutor(
                backend, arguments, default_sampling_para, None, False
            )
        for node in self.nodes:
            se_id = node.expr.pid
            expr = node.expr
            if isinstance(expr, SglVariable):
                # Make a copy for SglVariable
                expr = SglVariable(expr.name, expr.source)
                expr.source_stream_executor = stream_executors[
                    node.source_node.expr.pid
                ]
            elif isinstance(expr, SglArgument):
                # Substitute SglArgument
                expr = kwargs[expr.name]
            stream_executors[se_id].submit(expr)
        for stream_executor in stream_executors.values():
            stream_executor.end()
        return ProgramState(stream_executors[self.last_node.expr.pid])

    def run(
        self,
        *,
        max_new_tokens: int = 16,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        backend=None,
        **kwargs,
    ):
        backend = backend or global_config.default_backend

        kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
        kwargs.update(self.function.bind_arguments)

        default_sampling_para = SamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )

        return self.run_internal(backend, kwargs, default_sampling_para)

    def run_batch(
        self,
        batch_kwargs,
        *,
        max_new_tokens: int = 16,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        backend=None,
        num_threads: Union[str, int] = "auto",
    ):
        assert isinstance(batch_kwargs, (list, tuple))
        if len(batch_kwargs) == 0:
            return []
        assert isinstance(batch_kwargs[0], dict)

        backend = backend or global_config.default_backend

        default_sampling_para = SamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )
        batch_kwargs = [
            {k: SglArgument(k, v) for k, v in kwargs.items()} for kwargs in batch_kwargs
        ]

        # Extract prefix by tracing and cache it
        if len(batch_kwargs) > 1:
            pin_program(self.function, backend)

        # Run all programs
        if num_threads == "auto":
            num_threads = multiprocessing.cpu_count()
        num_threads = min(num_threads, len(batch_kwargs))

        if num_threads == 1:
            rets = []
            for arguments in batch_kwargs:
                rets.append(
                    self.run_internal(backend, arguments, default_sampling_para)
                )
        else:
            with ThreadPoolExecutor(num_threads) as executor:
                futures = []
                for arguments in batch_kwargs:
                    futures.append(
                        executor.submit(
                            self.run_internal, backend, arguments, default_sampling_para
                        )
                    )
                rets = [f.result() for f in futures]
            rets[-1].sync()

        return rets


class CompGraphNode:
    def __init__(
        self, expr: SglExpr, prev_node=None, next_nodes=None, source_node=None
    ):
        self.expr = expr
        self.next_nodes = next_nodes or []
        self.prev_node = prev_node
        self.source_node = source_node

    def add_next_node(self, other):
        self.next_nodes.append(other)

    def __repr__(self):
        re = f"stream {self.expr.pid:2d}: "
        re += f"%{self.expr.node_id} = "
        if self.prev_node is not None:
            re += f"%{self.prev_node.expr.node_id} + "
        re += repr(self.expr)
        return re
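`CompiledFunction.topological_sort` is Kahn's algorithm: nodes whose predecessor count (prev link plus source link) is zero are emitted first, and each emitted node decrements its successors' counts. A generic sketch of the same scheme over hypothetical node names and edge maps:

```python
from queue import Queue

# Kahn's algorithm, matching the structure of topological_sort():
# preds maps each node to its in-degree, nexts to its successors.
def topo_sort(nodes, preds, nexts):
    indeg = {x: preds[x] for x in nodes}
    cand = Queue()
    for x in nodes:
        if indeg[x] == 0:
            cand.put(x)
    order = []
    while cand.qsize() > 0:
        head = cand.get()
        order.append(head)
        for y in nexts.get(head, []):
            indeg[y] -= 1
            if indeg[y] == 0:
                cand.put(y)
    return order

order = topo_sort(["a", "b", "c"],
                  {"a": 0, "b": 1, "c": 1},
                  {"a": ["b"], "b": ["c"]})
# order == ["a", "b", "c"]
```

In the compiler, the edges come from `prev_node`/`source_node` links and `next_nodes` lists on `CompGraphNode`.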
697	python/sglang/lang/interpreter.py	Normal file
@@ -0,0 +1,697 @@
|
||||
"""The interpreter that executes SGL programs"""
|
||||
|
||||
import asyncio
|
||||
import multiprocessing
|
||||
import queue
|
||||
import threading
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import tqdm
|
||||
from sglang.global_config import global_config
|
||||
from sglang.lang.ir import (
|
||||
SglArgument,
|
||||
SglCommitLazy,
|
||||
SglConcateAndAppend,
|
||||
SglConstantText,
|
||||
SglExpr,
|
||||
SglExprList,
|
||||
SglFunction,
|
||||
SglGen,
|
||||
SglImage,
|
||||
SglRoleBegin,
|
||||
SglRoleEnd,
|
||||
SglSelect,
|
||||
SglVariable,
|
||||
SglVarScopeBegin,
|
||||
SglVarScopeEnd,
|
||||
)
|
||||
from sglang.utils import encode_image_base64
|
||||
|
||||
|
||||
def run_internal(state, program, func_args, func_kwargs, sync):
|
||||
try:
|
||||
state.ret_value = program.func(state, *func_args, **func_kwargs)
|
||||
except Exception as e:
|
||||
raise e
|
||||
finally:
|
||||
state.stream_executor.end()
|
||||
|
||||
if sync:
|
||||
state.stream_executor.sync()
|
||||
|
||||
if global_config.verbosity >= 2:
|
||||
print(state.text())
|
||||
|
||||
|
||||
def run_program(
|
||||
program, backend, func_args, func_kwargs, default_sampling_para, stream, sync=False
|
||||
):
|
||||
assert backend is not None, "Please specify a backend"
|
||||
func_kwargs.update(program.bind_arguments)
|
||||
stream_executor = StreamExecutor(
|
||||
backend, func_kwargs, default_sampling_para, chat_template=None, stream=stream
|
||||
)
|
||||
state = ProgramState(stream_executor)
|
||||
|
||||
if stream:
|
||||
t = threading.Thread(
|
||||
target=run_internal, args=(state, program, func_args, func_kwargs, sync)
|
||||
)
|
||||
t.start()
|
||||
return state
|
||||
else:
|
||||
run_internal(state, program, func_args, func_kwargs, sync)
|
||||
return state
|
||||
|
||||
|
||||
def run_program_batch(
|
||||
program,
|
||||
backend,
|
||||
batch_arguments,
|
||||
default_sampling_para,
|
||||
num_threads,
|
||||
progress_bar,
|
||||
):
|
||||
# Extract prefix by tracing and cache it
|
||||
if len(batch_arguments) > 1:
|
||||
pin_program(program, backend)
|
||||
|
||||
# Run all programs
|
||||
if num_threads == "auto":
|
||||
num_threads = multiprocessing.cpu_count()
|
||||
num_threads = min(num_threads, len(batch_arguments))
|
||||
|
||||
if num_threads == 1:
|
||||
rets = []
|
||||
for arguments in batch_arguments:
|
||||
rets.append(
|
||||
run_program(
|
||||
program, backend, (), arguments, default_sampling_para, False, False
|
||||
)
|
||||
)
|
||||
else:
|
||||
if progress_bar:
|
||||
pbar = tqdm.tqdm(total=len(batch_arguments))
|
||||
|
||||
with ThreadPoolExecutor(num_threads) as executor:
|
||||
futures = []
|
||||
for arguments in batch_arguments:
|
||||
futures.append(
|
||||
executor.submit(
|
||||
run_program,
|
||||
program,
|
||||
backend,
|
||||
(),
|
||||
arguments,
|
||||
default_sampling_para,
|
||||
False,
|
||||
False,
|
||||
)
|
||||
)
|
||||
if progress_bar:
|
||||
futures[-1].add_done_callback(lambda _: pbar.update())
|
||||
|
||||
rets = [f.result() for f in futures]
|
||||
rets[-1].sync()
|
||||
|
||||
if progress_bar:
|
||||
pbar.close()
|
||||
|
||||
return rets
|
||||
|
||||
|
||||
def pin_program(program, backend):
|
||||
if global_config.enable_prefix_sharing and program.pin_prefix_rid is None:
|
||||
# TODO: handle multiple backends
|
||||
from sglang.lang.tracer import extract_prefix_by_tracing
|
||||
|
||||
prefix = extract_prefix_by_tracing(program, backend)
|
||||
if prefix and len(prefix) > 64:
|
||||
prefix_rid = backend.cache_prefix(prefix)
|
||||
program.pin_prefix_rid = prefix_rid
|
||||
return prefix_rid
|
||||
return None
|
||||
|
||||
|
||||
def unpin_program(program, backend):
|
||||
pass
|
||||
|
||||
|
||||
class StreamExecutor:
|
||||
"""A stream executor that executes SGL expressions in a background thread."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
backend,
|
||||
arguments,
|
||||
default_sampling_para,
|
||||
chat_template,
|
||||
stream,
|
||||
use_thread=True,
|
||||
):
|
||||
self.sid = uuid.uuid4().hex
|
||||
self.backend = backend
|
||||
self.arguments: Dict[str, Any] = arguments
|
||||
self.default_sampling_para = default_sampling_para
|
||||
self.stream = stream
|
||||
|
||||
if hasattr(backend, "endpoint"):
|
||||
self.backend = backend.endpoint
|
||||
|
||||
self.variables = {} # Dict[name: str -> value: str]
|
||||
self.variable_event = {} # Dict[name: str -> event: threading.Event]
|
||||
self.meta_info = {} # Dict[name: str -> info: str]
|
||||
self.is_finished = False
|
||||
|
||||
# For completion
|
||||
self.text_ = "" # The full text
|
||||
|
||||
# For chat
|
||||
self.messages_ = [] # The messages in the OpenAI API format
|
||||
self.chat_template = chat_template or self.backend.get_chat_template()
|
||||
self.cur_role = None
|
||||
self.cur_role_begin_pos = None
|
||||
|
||||
# For vision
|
||||
self.images_ = []
|
||||
self.cur_images = []
|
||||
|
||||
# For fork/join
|
||||
self.fork_start_text_pos = None
|
||||
|
||||
# Worker thread
|
||||
self.use_thread = use_thread
|
||||
if self.use_thread:
|
||||
self.queue = queue.Queue()
|
||||
self.worker = threading.Thread(target=self._thread_worker_func)
|
||||
self.worker.start()
|
||||
|
||||
# For streaming
|
||||
if stream:
|
||||
self.stream_text_event = threading.Event()
|
||||
self.stream_var_event = {}
|
||||
else:
|
||||
self.stream_text_event = None
|
||||
self.stream_var_event = None
|
||||
|
||||
    def submit(self, expr: SglExpr):
        if isinstance(expr, (SglGen, SglSelect, SglVarScopeBegin)):
            self.variable_event[expr.name] = threading.Event()
            if self.stream:
                self.stream_var_event[expr.name] = threading.Event()
        elif isinstance(expr, SglExprList):
            for e in expr.expr_list:
                if isinstance(e, (SglGen, SglSelect, SglVarScopeBegin)):
                    self.variable_event[e.name] = threading.Event()
                    if self.stream:
                        self.stream_var_event[e.name] = threading.Event()

        if self.use_thread:
            self.queue.put(expr)
        else:
            self._execute(expr)

    def sync(self):
        if self.use_thread:
            self.queue.join()

    def get_var(self, name):
        if name in self.variable_event:
            self.variable_event[name].wait()
        return self.variables[name]

    def get_meta_info(self, name):
        if name in self.variable_event:
            self.variable_event[name].wait()
        ret = self.meta_info.get(name, None)
        return ret

    def fork(self, number: int, position_ids_offset: Optional[List[int]] = None):
        if number > 1:
            self.submit(SglCommitLazy())
        self.sync()

        number = int(number)

        exes = [
            StreamExecutor(
                self.backend,
                self.arguments,
                self.default_sampling_para,
                self.chat_template,
                self.stream,
            )
            for _ in range(number)
        ]
        for i in range(number):
            exes[i].variables = dict(self.variables)
            exes[i].text_ = str(self.text_)
            exes[i].messages_ = list(self.messages_)
            exes[i].cur_role = self.cur_role
            exes[i].fork_start_text_pos = len(self.text_)

        return exes

    def text(self):
        self.sync()
        return self.text_

    def messages(self):
        self.sync()
        return self.messages_

    def end(self):
        if self.use_thread:
            if self.worker.is_alive():
                self.queue.put(None)
        self.backend.end_program(self)

    def _thread_worker_func(self):
        while True:
            expr = self.queue.get()
            if expr is None:
                self.queue.task_done()
                break

            self._execute(expr)
            self.queue.task_done()
            if self.stream_text_event:
                self.stream_text_event.set()

        if self.stream_text_event:
            self.stream_text_event.set()

        self.is_finished = True

    def _execute(self, other):
        if isinstance(other, str):
            other = SglConstantText(other)

        assert isinstance(other, SglExpr), f"{other}"

        if isinstance(other, (SglConstantText, SglArgument)):
            self._execute_fill(other.value)
        elif isinstance(other, SglGen):
            self._execute_gen(other)
        elif isinstance(other, SglSelect):
            self._execute_select(other)
        elif isinstance(other, SglExprList):
            for x in other.expr_list:
                self._execute(x)
        elif isinstance(other, SglRoleBegin):
            self._execute_role_begin(other)
        elif isinstance(other, SglRoleEnd):
            self._execute_role_end(other)
        elif isinstance(other, SglImage):
            self._execute_image(other)
        elif isinstance(other, SglVariable):
            self._execute_variable(other)
        elif isinstance(other, SglVarScopeBegin):
            self._execute_var_scope_begin(other)
        elif isinstance(other, SglVarScopeEnd):
            self._execute_var_scope_end(other)
        elif isinstance(other, SglCommitLazy):
            self._execute_commit_lazy_operations(other)
        elif isinstance(other, SglConcateAndAppend):
            if (
                global_config.enable_parallel_encoding
                and self.backend.support_concate_and_append
            ):
                self._execute_concatenate_and_append_kv_cache(other)
            else:
                self._execute_concatenate_and_append_text(other)
        else:
            raise ValueError(f"Unknown type: {type(other)}")

    def _execute_fill(self, value: str):
        value = str(value)
        self.text_ += value

    def _execute_image(self, expr: SglImage):
        path = expr.path
        if isinstance(path, SglArgument):
            path = path.value

        base64_data = encode_image_base64(path)

        self.images_.append((path, base64_data))
        self.cur_images.append((path, base64_data))
        self.text_ += self.chat_template.image_token

        # if global_config.eager_fill_image:
        #     self.backend.fill_image(self)

    def _execute_gen(self, expr: SglGen):
        sampling_params = self._resolve_sampling_params(expr.sampling_params)
        name = expr.name

        if not self.stream:
            comp, meta_info = self.backend.generate(
                self, sampling_params=sampling_params
            )
            self.text_ += comp

            self.variables[name] = comp
            self.meta_info[name] = meta_info
            self.variable_event[name].set()
        else:
            generator = self.backend.generate_stream(
                self, sampling_params=sampling_params
            )

            self.stream_var_event[name].set()

            self.variables[name] = ""
            for comp, meta_info in generator:
                self.text_ += comp
                self.variables[name] += comp
                self.stream_var_event[name].set()
                self.stream_text_event.set()

            self.meta_info[name] = meta_info

            self.variable_event[name].set()
            self.stream_var_event[name].set()

    def _execute_select(self, expr: SglSelect):
        decision, scores = self.backend.select(self, expr.choices, expr.temperature)
        if expr.name is not None:
            name = expr.name
            self.variables[name] = decision
            self.variable_event[name].set()
        self.text_ += decision

    def _execute_variable(self, expr: SglVariable):
        src_executor = expr.source_stream_executor
        value = src_executor.get_var(expr.name)
        self._execute_fill(value)

    def _execute_role_begin(self, expr: SglRoleBegin):
        assert self.cur_role is None, "Nested roles are not allowed."

        if len(self.messages_) == 0 and expr.role != "system":
            # Insert the default system message
            default_system = self.chat_template.default_system_prompt
            if default_system:
                self._execute_role_begin(SglRoleBegin("system"))
                self._execute_fill(default_system)
                self._execute_role_end(SglRoleEnd("system"))

        self.cur_role = expr.role

        prefix, _ = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)

        self._execute_fill(prefix)
        self.cur_role_begin_pos = len(self.text_)

    def _execute_role_end(self, expr: SglRoleEnd):
        new_text = self.text_[self.cur_role_begin_pos :].lstrip()

        _, suffix = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
        self._execute_fill(suffix)

        if self.cur_images:
            # OpenAI vision API format
            last_msg = {
                "role": expr.role,
                "content": [{"type": "text", "text": new_text}],
            }
            for image_path, image_base64_data in self.cur_images:
                last_msg["content"].append(
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{image_base64_data}"
                        },
                    }
                )
            self.messages_.append(last_msg)
            self.cur_images = []
        else:
            self.messages_.append({"role": expr.role, "content": new_text})

        self.cur_role = None

    def _execute_var_scope_begin(self, expr: SglVarScopeBegin):
        self.variables[expr.name] = int(len(self.text_))

    def _execute_var_scope_end(self, expr: SglVarScopeEnd):
        self.variables[expr.name] = self.text_[self.variables[expr.name] :]
        self.variable_event[expr.name].set()

    def _execute_commit_lazy_operations(self, expr: SglCommitLazy):
        self.backend.commit_lazy_operations(self)

    def _execute_concatenate_and_append_text(self, expr: SglConcateAndAppend):
        new_text = ""
        for s in expr.states:
            exe = s.stream_executor
            exe.sync()
            new_text += exe.text_[exe.fork_start_text_pos :]

        self._execute_fill(new_text)

    def _execute_concatenate_and_append_kv_cache(self, expr: SglConcateAndAppend):
        self_len = len(self.text_)

        for i, s in enumerate(expr.states):
            exe = s.stream_executor
            exe.submit(SglCommitLazy())

        for i, s in enumerate(expr.states):
            exe = s.stream_executor
            exe.sync()
            assert exe.fork_start_text_pos == self_len
            self.text_ += exe.text_[exe.fork_start_text_pos :]

        src_rids = [state.stream_executor.sid for state in expr.states]
        self.backend.concatenate_and_append(src_rids, self.sid)

    def _resolve_sampling_params(self, sampling_params):
        clone = None
        for item in [
            "max_new_tokens",
            "stop",
            "temperature",
            "top_p",
            "top_k",
            "frequency_penalty",
            "presence_penalty",
            "dtype",
            "regex",
        ]:
            value = getattr(sampling_params, item, None)
            if value is not None:
                if clone is None:
                    clone = self.default_sampling_para.clone()
                setattr(clone, item, value)
        return clone or self.default_sampling_para

    def __del__(self):
        self.end()


class ProgramState:
    """The state of an SGL program."""

    def __init__(self, stream_executor: StreamExecutor):
        self.stream_executor = stream_executor

    def _role_common(self, name: str, expr: Optional[SglExpr] = None):
        if expr is not None:
            self.stream_executor.submit(
                SglExprList([SglRoleBegin(name), expr, SglRoleEnd(name)])
            )
        else:

            @contextmanager
            def role_scope():
                self.stream_executor.submit(SglRoleBegin(name))
                yield
                self.stream_executor.submit(SglRoleEnd(name))

            return role_scope()

    def system(self, expr: Optional[SglExpr] = None):
        return self._role_common("system", expr)

    def user(self, expr: Optional[SglExpr] = None):
        return self._role_common("user", expr)

    def assistant(self, expr: Optional[SglExpr] = None):
        return self._role_common("assistant", expr)

    @contextmanager
    def var_scope(self, name: str):
        self.stream_executor.submit(SglVarScopeBegin(name))
        yield
        self.stream_executor.submit(SglVarScopeEnd(name))

    def fork(self, number: int = 1, position_ids_offset: Optional[List[int]] = None):
        stream_executors = self.stream_executor.fork(number, position_ids_offset)
        states = [ProgramState(x) for x in stream_executors]
        state_group = ProgramStateGroup(states, self)
        return state_group

    @contextmanager
    def copy(self, position_ids_offset: Optional[List[int]] = None):
        state_group = self.fork(1, position_ids_offset)
        try:
            yield state_group[0]
        finally:
            state_group.join()

    def text(self):
        return self.stream_executor.text()

    def messages(self):
        return self.stream_executor.messages()

    def sync(self):
        return self.stream_executor.sync()

    def text_iter(self, var_name=None):
        if self.stream_executor.stream:
            prev = 0
            if var_name is None:
                event = self.stream_executor.stream_text_event
                while True:
                    event.wait()
                    event.clear()
                    out = str(self.stream_executor.text_[prev:])
                    prev += len(out)
                    if out:
                        yield out
                    if self.stream_executor.is_finished:
                        break
            else:
                event = self.stream_executor.stream_var_event[var_name]
                while True:
                    event.wait()
                    event.clear()
                    out = str(self.stream_executor.variables[var_name][prev:])
                    prev += len(out)
                    if out:
                        yield out
                    if self.stream_executor.variable_event[var_name].is_set():
                        break
        else:
            if var_name is None:
                yield self.text()
            else:
                yield self.get_var(var_name)

    async def text_async_iter(self, var_name=None):
        loop = asyncio.get_running_loop()

        if self.stream_executor.stream:
            prev = 0
            if var_name is None:
                event = self.stream_executor.stream_text_event
                while True:
                    await loop.run_in_executor(None, event.wait)
                    event.clear()
                    out = str(self.stream_executor.text_[prev:])
                    prev += len(out)
                    if out:
                        yield out
                    if self.stream_executor.is_finished:
                        break
            else:
                event = self.stream_executor.stream_var_event[var_name]
                while True:
                    await loop.run_in_executor(None, event.wait)
                    event.clear()
                    out = str(self.stream_executor.variables[var_name][prev:])
                    prev += len(out)
                    if out:
                        yield out
                    if self.stream_executor.variable_event[var_name].is_set():
                        break
        else:
            if var_name is None:
                yield self.text()
            else:
                yield self.get_var(var_name)

    def get_var(self, name):
        return self.stream_executor.get_var(name)

    def get_meta_info(self, name):
        return self.stream_executor.get_meta_info(name)

    def __iadd__(self, other):
        self.stream_executor.submit(other)
        return self

    def __getitem__(self, name):
        return self.get_var(name)

    def __del__(self):
        self.stream_executor.end()

    def __repr__(self) -> str:
        msgs = self.messages()
        ret = ""
        for msg in msgs:
            ret += msg["role"] + ":\n" + msg["content"] + "\n"
        return ret


class ProgramStateGroup:
    def __init__(
        self, states: List[ProgramState], src_state: Optional[ProgramState] = None
    ):
        self.states = states
        self.src_state = src_state

    def join(self, mode: str = "gather_variable"):
        if mode == "gather_variable":
            # Copy variables back
            src_vars = self.src_state.stream_executor.variables
            src_var_set = set(src_vars.keys())
            for child_state in self.states:
                child_state.stream_executor.sync()
                child_vars = child_state.stream_executor.variables
                new_vars = set(child_vars.keys()) - src_var_set

                for k in new_vars:
                    if k in src_vars:
                        src_vars[k].append(child_vars[k])
                    else:
                        src_vars[k] = [child_vars[k]]
        elif mode == "concate_and_append":
            # Concatenate and append KV cache
            self.src_state += SglConcateAndAppend(self.states)
            # Need a sync here. Otherwise, `states` can be deleted.
            self.src_state.stream_executor.sync()
        else:
            raise ValueError(f"Invalid join mode: {mode}")

        for s in self.states:
            s.stream_executor.end()

    def __getitem__(self, i: int):
        return self.states[i]

    def __setitem__(self, i: int, value):
        assert self.states[i] == value

    def __iadd__(self, other):
        if isinstance(other, Callable):
            # lambda function
            for i in range(len(self.states)):
                self.states[i] += other(i)
        elif isinstance(other, SglExpr):
            for i in range(len(self.states)):
                self.states[i] += other
        elif isinstance(other, (list, tuple)):
            for i in range(len(self.states)):
                self.states[i] += other[i]
        else:
            raise ValueError(f"Invalid value: {other}")

        return self
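The `join("gather_variable")` mode above copies variables produced in forked child states back into the source state, turning each newly created variable into a list with one entry per child. A minimal self-contained sketch of that gathering logic (the plain dicts stand in for the executors' `variables` attributes; `gather_variable` is a hypothetical helper, not part of the codebase):

```python
# Simulate ProgramStateGroup.join("gather_variable"): variables created inside
# forked children are collected back into lists on the source state.
def gather_variable(src_vars, children_vars):
    # Snapshot the source's variable names as they were at fork time.
    src_var_set = set(src_vars.keys())
    for child_vars in children_vars:
        # Only variables the child created after the fork are gathered.
        new_vars = set(child_vars.keys()) - src_var_set
        for k in new_vars:
            if k in src_vars:
                src_vars[k].append(child_vars[k])
            else:
                src_vars[k] = [child_vars[k]]
    return src_vars


src = {"question": "1+1=?"}
children = [
    {"question": "1+1=?", "answer": "2"},
    {"question": "1+1=?", "answer": "two"},
]
result = gather_variable(src, children)
assert result["answer"] == ["2", "two"]  # one entry per forked child
assert result["question"] == "1+1=?"     # pre-fork variables are untouched
```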
442  python/sglang/lang/ir.py  Normal file
@@ -0,0 +1,442 @@
"""The intermediate representation."""
|
||||
|
||||
import dataclasses
|
||||
import inspect
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from sglang.global_config import global_config
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SamplingParams:
|
||||
max_new_tokens: int = 16
|
||||
stop: Union[str, List[str]] = ()
|
||||
temperature: float = 1.0
|
||||
top_p: float = 1.0
|
||||
top_k: int = -1 # -1 means disable
|
||||
frequency_penalty: float = 0.0
|
||||
presence_penalty: float = 0.0
|
||||
|
||||
# for constrained generation, not included in to_xxx_kwargs
|
||||
dtype: Optional[str] = None
|
||||
regex: Optional[str] = None
|
||||
|
||||
def clone(self):
|
||||
return SamplingParams(
|
||||
self.max_new_tokens,
|
||||
self.stop,
|
||||
self.temperature,
|
||||
self.top_p,
|
||||
self.top_k,
|
||||
self.frequency_penalty,
|
||||
self.presence_penalty,
|
||||
)
|
||||
|
||||
def to_openai_kwargs(self):
|
||||
# OpenAI does not support top_k, so we drop it here
|
||||
return {
|
||||
"max_tokens": self.max_new_tokens,
|
||||
"stop": self.stop or None,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"frequency_penalty": self.frequency_penalty,
|
||||
"presence_penalty": self.presence_penalty,
|
||||
}
|
||||
|
||||
def to_anthropic_kwargs(self):
|
||||
# Anthropic does not support frequency_penalty or presence_penalty, so we drop it here
|
||||
return {
|
||||
"max_tokens_to_sample": self.max_new_tokens,
|
||||
"stop_sequences": self.stop,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"top_k": self.top_k,
|
||||
}
|
||||
|
||||
def to_srt_kwargs(self):
|
||||
return {
|
||||
"max_new_tokens": self.max_new_tokens,
|
||||
"stop": self.stop,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"top_k": self.top_k,
|
||||
"frequency_penalty": self.frequency_penalty,
|
||||
"presence_penalty": self.presence_penalty,
|
||||
"regex": self.regex,
|
||||
}
|
||||
|
||||
|
||||
class SglFunction:
    def __init__(self, func, bind_arguments=None):
        self.func = func
        self.bind_arguments = bind_arguments or {}
        self.pin_prefix_rid = None

        # Parse arguments
        argspec = inspect.getfullargspec(func)
        assert argspec.args[0] == "s", 'The first argument must be "s"'
        self.arg_names = argspec.args[1:]

    def bind(self, **kwargs):
        assert all(key in self.arg_names for key in kwargs)

        new_bind_dict = {**self.bind_arguments, **kwargs}
        return SglFunction(self.func, bind_arguments=new_bind_dict)

    def run(
        self,
        *args,
        max_new_tokens: int = 16,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        stream: bool = False,
        backend=None,
        **kwargs,
    ):
        from sglang.lang.interpreter import run_program

        default_sampling_para = SamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )
        backend = backend or global_config.default_backend
        kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
        return run_program(self, backend, args, kwargs, default_sampling_para, stream)

    def run_batch(
        self,
        batch_kwargs,
        *,
        max_new_tokens: int = 16,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        backend=None,
        num_threads: Union[str, int] = "auto",
        progress_bar: bool = False,
    ):
        from sglang.lang.interpreter import run_program_batch

        assert isinstance(batch_kwargs, (list, tuple))
        if len(batch_kwargs) == 0:
            return []
        assert isinstance(batch_kwargs[0], dict)

        default_sampling_para = SamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )
        backend = backend or global_config.default_backend
        batch_kwargs = [
            {k: SglArgument(k, v) for k, v in kwargs.items()} for kwargs in batch_kwargs
        ]
        return run_program_batch(
            self,
            backend,
            batch_kwargs,
            default_sampling_para,
            num_threads,
            progress_bar,
        )

    def trace(self, *, backend=None, **kwargs):
        from sglang.lang.tracer import trace_program

        backend = backend or global_config.default_backend
        return trace_program(self, kwargs, backend)

    def pin(self, backend=None):
        from sglang.lang.interpreter import pin_program

        backend = backend or global_config.default_backend
        return pin_program(self, backend)

    def unpin(self, backend=None):
        from sglang.lang.interpreter import unpin_program

        backend = backend or global_config.default_backend
        return unpin_program(self, backend)

    def compile(self, *, backend=None):
        from sglang.lang.compiler import compile_func

        return compile_func(self, backend)

    def __call__(self, *args, **kwargs):
        from sglang.lang.tracer import TracingScope

        tracing_scope = TracingScope.get_current_scope()
        if tracing_scope is None:
            return self.run(*args, **kwargs)
        else:
            kwargs["backend"] = tracing_scope.tracer_state.backend
            return self.trace(*args, **kwargs)


class SglExpr:
    node_ct = 0

    def __init__(self):
        self.node_id = SglExpr.node_ct
        self.prev_node = None
        self.pid = None
        SglExpr.node_ct += 1

    def __add__(self, other):
        if isinstance(other, str):
            other = SglConstantText(other)
        assert isinstance(other, SglExpr)

        return self.concatenate_ir(self, other)

    def __radd__(self, other):
        if isinstance(other, str):
            other = SglConstantText(other)
        assert isinstance(other, SglExpr), f"{other}"

        return self.concatenate_ir(other, self)

    def concatenate_ir(self, a, b):
        if isinstance(a, SglExprList):
            if isinstance(b, SglExprList):
                return SglExprList(a.expr_list + b.expr_list)
            else:
                return SglExprList(a.expr_list + [b])
        elif isinstance(b, SglExprList):
            return SglExprList([a] + b.expr_list)

        return SglExprList([a, b])

    def print_graph_dfs(self):
        ret = [""]
        visited = set()

        def dfs_print(x):
            if x is None or x in visited:
                return
            visited.add(x)

            # Print dependency
            if x.prev_node is not None:
                dfs_print(x.prev_node)

            if isinstance(x, SglExprList):
                for y in x.expr_list:
                    dfs_print(y)
            # elif isinstance(x, SglRole):
            #     dfs_print(x.expr)
            elif isinstance(x, SglVariable):
                dfs_print(x.source)

            # Print the node itself
            if isinstance(x, (SglFork, SglGetForkItem)):
                ret[0] += f"%{x.node_id} = {x}\n"
            else:
                if x.prev_node is not None:
                    ret[0] += (
                        f"%{x.node_id} = %{x.prev_node.node_id} + " + str(x) + "\n"
                    )
                else:
                    ret[0] += f"%{x.node_id} = " + str(x) + "\n"

        dfs_print(self)
        return ret[0]


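The `__add__`/`__radd__`/`concatenate_ir` logic above keeps chained `+` operations flat: combining any mix of expressions always yields a single `SglExprList` rather than a nested tree. A minimal sketch of that flattening behavior, using simplified stand-in classes (`Expr`, `ExprList`, `Const` are hypothetical, not part of the codebase):

```python
# Simplified stand-ins for SglExpr / SglExprList / SglConstantText, kept just
# detailed enough to show how `+` flattens chained expressions into one list.
class Expr:
    def __add__(self, other):
        if isinstance(other, str):
            other = Const(other)
        return self.concat(self, other)

    def __radd__(self, other):  # handles `"str" + expr`
        if isinstance(other, str):
            other = Const(other)
        return self.concat(other, self)

    def concat(self, a, b):
        # Merge into one flat list instead of nesting lists inside lists.
        if isinstance(a, ExprList):
            if isinstance(b, ExprList):
                return ExprList(a.expr_list + b.expr_list)
            return ExprList(a.expr_list + [b])
        elif isinstance(b, ExprList):
            return ExprList([a] + b.expr_list)
        return ExprList([a, b])


class ExprList(Expr):
    def __init__(self, expr_list):
        self.expr_list = expr_list


class Const(Expr):
    def __init__(self, value):
        self.value = value


# Chaining three pieces yields one flat three-element list, not a nested tree.
chained = "Q: " + Const("What is 1+1?") + "\nA:"
assert isinstance(chained, ExprList)
assert len(chained.expr_list) == 3
```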
class SglExprList(SglExpr):
    def __init__(self, expr_list: List[SglExpr]):
        super().__init__()
        self.expr_list = expr_list

    def __repr__(self):
        return f"ExprList({self.expr_list})"


class SglArgument(SglExpr):
    def __init__(self, name: str, value: str):
        super().__init__()
        self.name = name
        self.value = value

    def __repr__(self):
        return f"Argument(name={self.name}, value={repr(self.value)})"

    def __len__(self):
        return len(self.value)

    def __getitem__(self, i):
        return self.value[i]

    def __int__(self):
        return int(self.value)

    def __bool__(self):
        return bool(self.value)

    def __format__(self, *args):
        raise TypeError(
            "Cannot put an argument inside an f-string. "
            "This is not compatible with the tracer."
        )


class SglImage(SglExpr):
    def __init__(self, path):
        super().__init__()
        self.path = path

    def __repr__(self) -> str:
        return f"SglImage({self.path})"


class SglGen(SglExpr):
    def __init__(
        self,
        name,
        max_new_tokens,
        stop,
        temperature,
        top_p,
        top_k,
        frequency_penalty,
        presence_penalty,
        dtype,
        regex,
    ):
        super().__init__()
        self.name = name
        self.sampling_params = SamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            dtype=dtype,
            regex=regex,
        )

    def __repr__(self):
        return f"Gen('{self.name}')"


class SglConstantText(SglExpr):
    def __init__(self, value):
        super().__init__()
        self.value = value

    def __repr__(self):
        return f"Constant({repr(self.value)})"


class SglRoleBegin(SglExpr):
    def __init__(self, role):
        super().__init__()
        self.role = role

    def __repr__(self):
        return f"RoleBegin({self.role})"


class SglRoleEnd(SglExpr):
    def __init__(self, role):
        super().__init__()
        self.role = role

    def __repr__(self):
        return f"RoleEnd({self.role})"


class SglSelect(SglExpr):
    def __init__(self, name, choices, temperature):
        super().__init__()
        self.name = name
        self.choices = choices
        self.temperature = temperature

    def __repr__(self):
        return f"Select({self.name}, choices={self.choices})"


class SglFork(SglExpr):
    def __init__(self, number, position_ids_offset=None):
        super().__init__()
        self.number = number
        self.position_ids_offset = position_ids_offset

    def __repr__(self):
        return (
            f"Fork(%{self.prev_node.node_id}, number={self.number}, "
            f"position_ids_offset={self.position_ids_offset})"
        )


class SglGetForkItem(SglExpr):
    def __init__(self, index):
        super().__init__()
        self.index = index

    def __repr__(self):
        return f"GetForkItem(%{self.prev_node.node_id}, index={self.index})"


class SglVariable(SglExpr):
    def __init__(self, name, source):
        super().__init__()
        self.name = name
        self.source = source

    def __repr__(self):
        return f"Variable('{self.name}', source=%{self.source.node_id})"


class SglVarScopeBegin(SglExpr):
    def __init__(self, name):
        super().__init__()
        self.name = name

    def __repr__(self):
        return f"VarScopeBegin('{self.name}')"


class SglVarScopeEnd(SglExpr):
    def __init__(self, name):
        super().__init__()
        self.name = name

    def __repr__(self):
        return f"VarScopeEnd('{self.name}')"


class SglConcateAndAppend(SglExpr):
    def __init__(self, states):
        super().__init__()
        self.states = states

    def __repr__(self):
        return f"ConcatenateAndAppend('{self.states}')"


class SglCommitLazy(SglExpr):
    def __init__(self):
        super().__init__()

    def __repr__(self):
        return "CommitLazy()"
279  python/sglang/lang/tracer.py  Normal file
@@ -0,0 +1,279 @@
"""Tracing a program."""
|
||||
import uuid
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
from sglang.backend.base_backend import BaseBackend
|
||||
from sglang.global_config import global_config
|
||||
from sglang.lang.interpreter import ProgramState, ProgramStateGroup
|
||||
from sglang.lang.ir import (
|
||||
SglArgument,
|
||||
SglCommitLazy,
|
||||
SglConcateAndAppend,
|
||||
SglConstantText,
|
||||
SglExpr,
|
||||
SglExprList,
|
||||
SglFork,
|
||||
SglFunction,
|
||||
SglGen,
|
||||
SglGetForkItem,
|
||||
SglRoleBegin,
|
||||
SglRoleEnd,
|
||||
SglSelect,
|
||||
SglVariable,
|
||||
SglVarScopeBegin,
|
||||
SglVarScopeEnd,
|
||||
)
|
||||
|
||||
|
||||
class StopTracing(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def extract_prefix_by_tracing(program, backend):
    # Create dummy arguments
    dummy_arguments = {name: SglArgument(name, None) for name in program.arg_names}
    arguments = dummy_arguments
    arguments.update(program.bind_arguments)

    # Trace
    tracer = TracerProgramState(backend, arguments, only_trace_prefix=True)
    try:
        with TracingScope(tracer):
            tracer.ret_value = program.func(tracer, **arguments)
    except StopTracing:
        pass

    # Run and cache prefix
    prefix = ""
    for expr in tracer.flatten_nodes():
        if isinstance(expr, SglConstantText):
            prefix += expr.value
        else:
            break
    return prefix


def trace_program(program, arguments, backend):
    # Create dummy backend
    if backend is None:
        backend = BaseBackend()

    # Create dummy arguments
    dummy_arguments = {
        name: SglArgument(name, None)
        for name in program.arg_names
        if name not in arguments
    }
    arguments.update(dummy_arguments)
    arguments.update(program.bind_arguments)

    # Trace
    tracer = TracerProgramState(backend, arguments, only_trace_prefix=False)
    with TracingScope(tracer):
        tracer.ret_value = program.func(tracer, **arguments)
    return tracer


class TracerProgramState(ProgramState):
    def __init__(self, backend, arguments, only_trace_prefix):
        self.pid = uuid.uuid4().hex
        self.backend = backend
        self.arguments: Dict[str, Any] = arguments
        self.only_trace_prefix = only_trace_prefix

        if hasattr(backend, "endpoint"):
            self.backend = backend.endpoint

        self.nodes = []
        self.last_node = None
        self.variables = {}
        self.ret_value = None

        # For completion

        # For chat
        self.messages_ = []
        self.cur_role = None
        self.chat_template = self.backend.get_chat_template()

        # For multi states
        self.child_states = []

        cur_scope = TracingScope.get_current_scope()
        if cur_scope is not None:
            cur_scope.add_child_state(self)

    ##################################
    ########### Public API ###########
    ##################################

    def fork(self, number: int, position_ids_offset: Optional[List[int]] = None):
        if self.only_trace_prefix:
            raise StopTracing()

        fork_node = SglFork(number)
        fork_node.prev_node = self.last_node

        states = [
            TracerProgramState(self.backend, self.arguments, self.only_trace_prefix)
            for _ in range(number)
        ]

        for i in range(number):
            node = SglGetForkItem(i)
            node.prev_node = fork_node
            states[i].last_node = node
            states[i].variables = dict(self.variables)
            states[i].messages_ = list(self.messages_)
            states[i].cur_role = self.cur_role
            states[i].chat_template = self.chat_template

        state_group = ProgramStateGroup(states, self)

        return state_group

##################################
|
||||
########## Internal API ##########
|
||||
##################################
|
||||
|
||||
def _append_node(self, other: SglExpr):
|
||||
self.nodes.append(other)
|
||||
other.prev_node = self.last_node
|
||||
self.last_node = other
|
||||
|
||||
def _execute(self, other: SglExpr):
|
||||
if isinstance(other, str):
|
||||
other = SglConstantText(other)
|
||||
|
||||
other.pid = self.pid
|
||||
|
||||
if isinstance(other, SglConstantText):
|
||||
self._execute_fill(other)
|
||||
elif isinstance(other, SglGen):
|
||||
self._execute_gen(other)
|
||||
elif isinstance(other, SglSelect):
|
||||
self._execute_select(other)
|
||||
elif isinstance(other, SglExprList):
|
||||
for x in other.expr_list:
|
||||
self._execute(x)
|
||||
elif isinstance(other, SglRoleBegin):
|
||||
self._execute_role_begin(other)
|
||||
elif isinstance(other, SglRoleEnd):
|
||||
self._execute_role_end(other)
|
||||
elif isinstance(other, SglVarScopeBegin):
|
||||
self._execute_var_scope_begin(other)
|
||||
elif isinstance(other, SglVarScopeEnd):
|
||||
self._execute_var_scope_end(other)
|
||||
else:
|
||||
if self.only_trace_prefix:
|
||||
raise StopTracing()
|
||||
else:
|
||||
self._append_node(other)
|
||||
|
||||
return self
|
||||
|
||||
def __iadd__(self, other):
|
||||
self._execute(other)
|
||||
return self
|
||||
|
||||
def _execute_fill(self, expr: SglConstantText):
|
||||
if isinstance(expr, str):
|
||||
expr = SglConstantText(expr)
|
||||
self._append_node(expr)
|
||||
|
||||
def _execute_gen(self, expr: SglGen):
|
||||
name = expr.name if expr.name is not None else "gen_" + str(len(self.variables))
|
||||
new_node = SglVariable(name, source=expr)
|
||||
self.variables[name] = new_node
|
||||
self._append_node(expr)
|
||||
|
||||
def _execute_select(self, expr: SglSelect):
|
||||
name = (
|
||||
expr.name if expr.name is not None else "select_" + str(len(self.variables))
|
||||
)
|
||||
new_node = SglVariable(name, source=expr)
|
||||
self.variables[name] = new_node
|
||||
self._append_node(expr)
|
||||
|
||||
def _execute_role_begin(self, expr: SglRoleBegin):
|
||||
assert self.cur_role is None, "Nested roles are not allowed."
|
||||
|
||||
if len(self.messages_) == 0 and expr.role != "system":
|
||||
# Insert default system message
|
||||
default_system = self.chat_template.default_system_prompt
|
||||
if default_system:
|
||||
self._execute_role_begin(SglRoleBegin("system"))
|
||||
self._execute_fill(default_system)
|
||||
self._execute_role_end(SglRoleEnd("system"))
|
||||
|
||||
self.cur_role = expr.role
|
||||
|
||||
prefix, suffix = self.chat_template.get_prefix_and_suffix(
|
||||
expr.role, self.messages_
|
||||
)
|
||||
|
||||
self._execute_fill(prefix)
|
||||
|
||||
def _execute_role_end(self, expr: SglRoleEnd):
|
||||
prefix, suffix = self.chat_template.get_prefix_and_suffix(
|
||||
expr.role, self.messages_
|
||||
)
|
||||
|
||||
self._execute_fill(suffix)
|
||||
|
||||
self.messages_.append({"role": expr.role, "content": ""})
|
||||
|
||||
self.cur_role = None
|
||||
|
||||
    def _execute_var_scope_end(self, expr: SglVarScopeEnd):
        new_node = SglVariable(expr.name, source=self.last_node)
        self.variables[expr.name] = new_node

    def get_var(self, name):
        ret = self.arguments.get(name, None)
        if ret is not None:
            return ret

        v = self.variables[name]
        return SglVariable(v.name, v.source)

    def flatten_nodes(self):
        def traverse(cur):
            if isinstance(cur, SglExprList):
                for child in cur.expr_list:
                    traverse(child)
            else:
                ret.append(cur)

        ret = []
        for x in self.nodes:
            traverse(x)
        return ret

    def __del__(self):
        pass


class TracingScope:
    cur_scope = None

    def __init__(self, tracer_state: TracerProgramState):
        self.tracer_state = tracer_state
        self.last_scope = TracingScope.cur_scope

    def __enter__(self):
        TracingScope.cur_scope = self
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        TracingScope.cur_scope = self.last_scope

    @staticmethod
    def get_current_scope():
        return TracingScope.cur_scope

    def add_child_state(self, state: TracerProgramState):
        cur_scope = self
        while cur_scope is not None:
            cur_scope.tracer_state.child_states.append(state)
            cur_scope = cur_scope.last_scope
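The TracingScope machinery above is a small but reusable pattern: a class-level "current scope" pointer plus parent links, so that registering a child state visits every enclosing scope. A dependency-free sketch of the same pattern (`Scope` and `add_child` are illustrative names, not the SGLang API):

```python
# Minimal sketch of the TracingScope pattern: entering a scope pushes it
# onto a class-level pointer, and registering a child walks the parent
# chain so every enclosing scope records it.
class Scope:
    cur_scope = None

    def __init__(self):
        self.children = []
        self.last_scope = None

    def __enter__(self):
        self.last_scope = Scope.cur_scope
        Scope.cur_scope = self
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        Scope.cur_scope = self.last_scope

    def add_child(self, child):
        # Visit this scope and all of its ancestors.
        scope = self
        while scope is not None:
            scope.children.append(child)
            scope = scope.last_scope


with Scope() as outer:
    with Scope() as inner:
        inner.add_child("state")

assert outer.children == ["state"]  # the outer scope saw the child too
assert inner.children == ["state"]
```

Nesting works because each scope remembers the scope that was current when it was entered, which is exactly how `TracingScope.last_scope` is used above.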
11  python/sglang/launch_server.py  Normal file
@@ -0,0 +1,11 @@
import argparse

from sglang.srt.server import ServerArgs, launch_server

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)

    launch_server(server_args, None)
385  python/sglang/srt/constrained/fsm.py  Normal file
@@ -0,0 +1,385 @@
# Adapted from:
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/fsm/fsm.py
from typing import List, NewType, Protocol

import interegular
from lark import Lark

# from outlines.fsm.parsing import PartialLark
from sglang.srt.constrained.regex import (
    create_fsm_index_tokenizer,
    make_deterministic_fsm,
)
from sglang.srt.constrained.tokenizer import Tokenizer

FSMState = NewType("FSMState", int)


class FSM(Protocol):
    def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
        ...

    def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
        ...

    def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
        ...

    def reset(self) -> None:
        ...

class StopAtTokenFSM(FSM):
    """FSM to generate text until a specified token id is generated or
    a specified number of tokens has been generated.

    Text is usually produced until the EOS token is generated by the
    model.

    """

    def __init__(
        self,
        tokenizer: "Tokenizer",
        stop_token_id: int,
    ):
        self.stop_token_id = stop_token_id
        self.num_tokens_generated = 0
        self.vocabulary = tokenizer.vocabulary.values()
        self.final_states = {1}

    def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
        """Generate a list of allowed tokens for the next step.

        When in the initial state we allow every token to be generated.
        In the final state the only allowed token is `stop_token_id`.

        Parameters
        ----------
        state
            The current state of the FSM.
        idx
            The index of the current input in the batch.

        Returns
        -------
        A list that contains the tokens to mask.

        """
        if state == 0:
            return list(self.vocabulary)
        else:
            return [self.stop_token_id]

    def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
        """Update the state of the FSM.

        The FSM stays in the initial state `0` unless the specified stop token
        has been generated or the maximum number of tokens has been reached, in
        which case the FSM moves to the final state `1`.

        Parameters
        ----------
        state
            The current state of the FSM.
        token_id
            The id of the token that was just generated.
        idx
            The index of the current input in the batch.

        Returns
        -------
        The new state of the FSM.

        """
        if idx == 0:
            self.num_tokens_generated += 1

        if token_id == self.stop_token_id:
            return FSMState(1)

        return FSMState(0)

    def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
        """Determine whether the current state of the FSM is a final state."""
        return state in self.final_states

    def reset(self) -> None:
        """Reset the FSM to its initial state. Here this only resets the token counter."""
        self.num_tokens_generated = 0


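`StopAtTokenFSM` reduces to a two-state machine: in state 0 everything is allowed, and once the stop token appears the FSM enters final state 1, where only the stop token may be emitted again. A condensed, dependency-free copy with a stub tokenizer (a stand-in, not the real SGLang `Tokenizer`) shows the contract:

```python
# Condensed two-state stop-token FSM; StubTokenizer is illustrative.
class StubTokenizer:
    vocabulary = {"a": 0, "b": 1, "<eos>": 2}


class MiniStopAtTokenFSM:
    def __init__(self, tokenizer, stop_token_id):
        self.stop_token_id = stop_token_id
        self.vocabulary = tokenizer.vocabulary.values()
        self.final_states = {1}

    def allowed_token_ids(self, state):
        # State 0: anything goes. State 1: only the stop token.
        return list(self.vocabulary) if state == 0 else [self.stop_token_id]

    def next_state(self, state, token_id):
        return 1 if token_id == self.stop_token_id else 0

    def is_final_state(self, state):
        return state in self.final_states


fsm = MiniStopAtTokenFSM(StubTokenizer(), stop_token_id=2)
assert sorted(fsm.allowed_token_ids(0)) == [0, 1, 2]  # anything goes
assert fsm.next_state(0, token_id=1) == 0             # keep generating
assert fsm.next_state(0, token_id=2) == 1             # stop token hit
assert fsm.is_final_state(1) and fsm.allowed_token_ids(1) == [2]
```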
class RegexFSM(FSM):
    """FSM to generate text that is in the language of a regular expression."""

    def __init__(
        self,
        regex_string: str,
        tokenizer: "Tokenizer",
    ):
        regex_pattern = interegular.parse_pattern(regex_string)
        regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
        (
            self.states_to_token_maps,
            self.empty_token_ids,
        ) = create_fsm_index_tokenizer(regex_fsm, tokenizer)

        # We make sure that it is possible to generate strings in the language
        # of the regular expression with the tokens present in the model's
        # vocabulary.
        if not any(
            regex_fsm.finals.intersection(v.values())
            for v in self.states_to_token_maps.values()
        ):
            raise ValueError(
                "The vocabulary does not allow us to build a sequence that matches the input regex"
            )

        self.final_states = regex_fsm.finals | {
            -1
        }  # Include the EOS token in final states
        self.num_tokens_generated = 0
        self.vocabulary = tokenizer.vocabulary.values()
        self.end_token_id = tokenizer.eos_token_id

    def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
        """Generate a list of allowed tokens for the next step.

        The initialization of the FSM builds an index which maps FSM states to a
        map from authorized tokens to the state in which the FSM needs to move
        if said token is generated. Therefore the authorized tokens at the
        current state are the keys of the map returned by the value of the index
        for the current state.

        If the current state is not contained in the index, this means that we
        are in a final state of the FSM. We only authorize EOS tokens in the
        final state.

        Parameters
        ----------
        state
            The current state of the FSM.
        idx
            The index of the current input in the batch.

        Returns
        -------
        A list that contains the tokens to mask.

        """
        next_tokens_to_end_states = self.states_to_token_maps.get(state)

        if next_tokens_to_end_states is None:
            return [self.end_token_id]
        else:
            return list(next_tokens_to_end_states.keys())

    def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
        """Update the state of the FSM.

        We use the index to determine to which state the FSM should transition
        given the token that was just generated.

        Parameters
        ----------
        state
            The current state of the FSM.
        token_id
            The id of the token that was just generated.
        idx
            The index of the current input in the batch.

        Returns
        -------
        The new state of the FSM.

        """
        if idx == 0:
            self.num_tokens_generated += 1

        if token_id == self.end_token_id:
            return FSMState(-1)

        last_token_to_end_state = self.states_to_token_maps[state]
        next_state = last_token_to_end_state.get(token_id)
        if next_state is None:
            next_state = -1

        return FSMState(next_state)

    def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
        """Determine whether the current state of the FSM is a final state."""
        return state in self.final_states

    def reset(self) -> None:
        """Reset the FSM to its initial state. Here this only resets the token counter."""
        self.num_tokens_generated = 0


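At decode time, `RegexFSM.allowed_token_ids` and `next_state` are plain lookups into `states_to_token_maps`, a `{state: {token_id: next_state}}` index; EOS is implied whenever a state has no outgoing entries. A hand-built index sketches the walk (the regex `ab?`, the token ids, and `eos_token_id` are all illustrative, not produced by a real tokenizer):

```python
# Hand-built index for the toy regex "ab?" over vocabulary {"a": 10, "b": 11}.
states_to_token_maps = {
    0: {10: 1},  # from the start, only "a" is legal
    1: {11: 2},  # after "a", "b" may follow
}
eos_token_id = 99  # illustrative


def allowed(state):
    # A state with no outgoing transitions only admits EOS.
    nxt = states_to_token_maps.get(state)
    return [eos_token_id] if nxt is None else list(nxt.keys())


def step(state, token_id):
    if token_id == eos_token_id:
        return -1  # the sentinel final state
    return states_to_token_maps[state].get(token_id, -1)


assert allowed(0) == [10]                       # must start with "a"
assert step(0, 10) == 1
assert allowed(1) == [11]
assert allowed(step(1, 11)) == [eos_token_id]   # state 2 has no map: EOS only
```

Note that, as in the class above, a final-but-non-dead state (state 1 here) only offers its token transitions; EOS becomes available once the map for a state is absent.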
class CFGFSM(FSM):
    """FSM to generate text that is in the language of a context-free grammar."""

    def __init__(
        self,
        cfg_string: str,
        tokenizer: "Tokenizer",
    ):
        # self.parser = PartialLark(cfg_string, parser="lalr")
        self.parser = Lark(
            cfg_string,
            parser="lalr",
            lexer="contextual",
            propagate_positions=False,
            maybe_placeholders=False,
            regex=True,
        )
        self.terminal_regexps = dict()
        for terminal in self.parser.terminals:
            if terminal.pattern is not None:
                self.terminal_regexps[terminal.name] = terminal.pattern.to_regexp()
        self.terminal_regexps["$END"] = tokenizer.eos_token

        self.tokenizer = tokenizer
        self.num_tokens_generated = 0
        self.generations: List[str] = []
        self.regex_fsms: List[RegexFSM] = []
        self.reset_state: List[bool] = []
        self.allow_eos: List[bool] = []
        self.done: List[bool] = []

    def _set_next_regex_fsm(self, idx: int = 0) -> None:
        """Use the CFG incremental parser to set the next regex FSM.

        Check what the CFG incremental parser proposes next.
        If the only proposal is the EOS token,
        we set the state to done and return.
        If there are other proposals,
        we set a new regex FSM and return.

        """
        interactive = self.parser.parse_interactive(self.generations[idx])
        interactive.exhaust_lexer()
        options = {self.terminal_regexps[x] for x in interactive.accepts()}

        if self.terminal_regexps["$END"] in options:
            options.remove(self.terminal_regexps["$END"])
            if len(options) == 0:
                self.done[idx] = True
                return
            self.allow_eos[idx] = True
            options.add("")
            assert len(options) > 1

        regex_string = r"(" + r"|".join([r"(" + x + r")" for x in options]) + r")"
        args = (
            regex_string,
            self.tokenizer,
        )
        if len(self.regex_fsms) <= idx:
            self.regex_fsms.append(RegexFSM(*args))
        else:
            self.regex_fsms[idx] = RegexFSM(*args)
        self.reset_state[idx] = True

    def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
        """Generate a list of allowed tokens for the next step.

        Upon initialization, the CFG incremental parser is used to determine the first regex.

        This regex is used for proposals until either:
        - the regex is exhausted, and its only remaining option is the EOS token,
          in which case we always transition to the next regex
        - the regex can be exhausted, but the EOS token is not the only remaining option,
          in which case we transition to the next regex with probability P (TODO)
          or remove the possibility of generating the EOS token and continue with the current regex

        The CFG incremental parser is allowed to propose the EOS token from any final state,
        and once it is generated, the FSM will continue to always generate the EOS token.

        Parameters
        ----------
        state
            The current state of the FSM.
        idx
            The index of the current input in the batch.

        Returns
        -------
        A list that contains the tokens to mask.

        """
        if len(self.generations) <= idx:
            self.generations.append("")
            self.reset_state.append(False)
            self.allow_eos.append(False)
            self.done.append(False)

        if len(self.regex_fsms) > idx:
            proposal = self.regex_fsms[idx].allowed_token_ids(state)
            if self.tokenizer.eos_token_id not in proposal:
                return proposal
            if set(proposal) != {self.tokenizer.eos_token_id}:
                if False:  # TODO: THIS NEEDS TO BE SAMPLED
                    proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
                    return proposal

        self._set_next_regex_fsm(idx)

        if self.done[idx]:
            return [self.tokenizer.eos_token_id]

        if self.reset_state[idx]:
            state = FSMState(0)

        proposal = self.regex_fsms[idx].allowed_token_ids(state)
        if self.allow_eos[idx]:
            self.allow_eos[idx] = False
        else:
            proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
            assert len(proposal) > 0
        return proposal

    def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
        """Update the state of the FSM.

        Transitions the underlying regex FSM to its next state.
        If at max tokens or EOS token, transition permanently to the final state.
        Update stored partial generations for subsequent incremental parsing.

        Parameters
        ----------
        state
            The current state of the FSM.
        token_id
            The id of the token that was just generated.
        idx
            The index of the current input in the batch.

        Returns
        -------
        The new state of the FSM.
        """
        if idx == 0:
            self.num_tokens_generated += 1
        if token_id == self.tokenizer.eos_token_id:
            self.done[idx] = True
            return FSMState(-1)
        if self.reset_state[idx]:
            self.reset_state[idx] = False
            state = FSMState(0)

        self.generations[idx] += self.tokenizer.decode([token_id])[0]

        return self.regex_fsms[idx].next_state(state, token_id, idx)

    def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
        """Return whether the current state of the FSM is a final state."""
        return self.done[idx]

    def reset(self) -> None:
        """Reset the FSM to its initial state, so it can be called on a fresh batch of inputs."""
        self.num_tokens_generated = 0
        self.generations = []
        self.regex_fsms = []
        self.reset_state = []
        self.allow_eos = []
        self.done = []
41  python/sglang/srt/constrained/fsm_cache.py  Normal file
@@ -0,0 +1,41 @@
import threading

from sglang.srt.constrained.fsm import RegexFSM
from sglang.srt.constrained.tokenizer import TransformerTokenizer


def get_fsm(regex, tokenizer, fsm_cache_entry):
    outlines_tokenizer = TransformerTokenizer(tokenizer)
    fsm = RegexFSM(regex, outlines_tokenizer)
    fsm_cache_entry.fsm = fsm
    fsm_cache_entry.event.set()


class FSMCacheEntry:
    def __init__(self):
        self.fsm = None
        self.event = threading.Event()


class FSMCache:
    def __init__(self, tokenizer):
        self.cache = {}
        self.tokenizer = tokenizer

    def init_fsm_in_background(self, regex):
        if regex not in self.cache:
            self.cache[regex] = FSMCacheEntry()
            threading.Thread(
                target=get_fsm,
                args=(
                    regex,
                    self.tokenizer,
                    self.cache[regex],
                ),
            ).start()

    def get_fsm(self, regex):
        self.init_fsm_in_background(regex)
        entry = self.cache[regex]
        entry.event.wait()
        return entry.fsm
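`FSMCache` hides FSM compilation latency behind a worker thread: the first request for a regex kicks off a build, and `get_fsm` blocks on a `threading.Event` until the worker publishes the result; later requests hit the cache. A generic, dependency-free sketch of the same pattern (names are illustrative; the lock is an addition the original omits, making the check-then-insert race explicit):

```python
import threading


class _Entry:
    def __init__(self):
        self.value = None
        self.event = threading.Event()


class BackgroundCache:
    """Build values for keys in background threads; readers block until ready."""

    def __init__(self, build_fn):
        self.build_fn = build_fn
        self.cache = {}
        self.lock = threading.Lock()

    def _build(self, key, entry):
        entry.value = self.build_fn(key)
        entry.event.set()  # publish: wake any waiting readers

    def start(self, key):
        with self.lock:  # not in the original; guards concurrent first requests
            if key not in self.cache:
                self.cache[key] = _Entry()
                threading.Thread(
                    target=self._build, args=(key, self.cache[key])
                ).start()

    def get(self, key):
        self.start(key)
        entry = self.cache[key]
        entry.event.wait()  # block until the worker has published
        return entry.value


cache = BackgroundCache(lambda k: k.upper())
assert cache.get("regex") == "REGEX"
assert cache.get("regex") == "REGEX"  # second hit returns the cached value
```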
586  python/sglang/srt/constrained/regex.py  Normal file
@@ -0,0 +1,586 @@
# Adapted from:
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/fsm/regex.py
from collections import namedtuple
from functools import lru_cache
from typing import Dict, Generator, List, Sequence, Set, Tuple

import numba
import numpy as np
from interegular.fsm import FSM, Alphabet, OblivionError, anything_else
from numba.typed.typedobjectutils import _nonoptional
from sglang.srt.constrained.tokenizer import Tokenizer


class BetterAlphabet(Alphabet):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert anything_else in self._symbol_mapping
        self.anything_value = self._symbol_mapping[anything_else]

    def __getitem__(self, item):
        return self._symbol_mapping.get(item, self.anything_value)

    def copy(self):
        return BetterAlphabet(self._symbol_mapping.copy())


class BetterFSM(FSM):
    flat_transition_map: Dict[Tuple[int, int], int]
    trans_key_to_states: Dict[int, List[int]]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if not isinstance(self.alphabet, BetterAlphabet):
            self.__dict__["alphabet"] = BetterAlphabet(self.alphabet._symbol_mapping)

        flat_transition_map = {}
        trans_key_to_states = {}
        for from_state, trans_map in self.map.items():
            for trans_key, to_state in trans_map.items():
                flat_transition_map[(from_state, trans_key)] = to_state
                trans_key_to_states.setdefault(trans_key, set()).add(from_state)

        self.__dict__["trans_key_to_states"] = trans_key_to_states
        self.__dict__["flat_transition_map"] = flat_transition_map
        self.__dict__["_fsm_info"] = None

    def copy(self):
        return BetterFSM(
            alphabet=self.alphabet.copy(),
            states=self.states.copy(),
            initial=self.initial,
            finals=self.finals.copy(),
            map=self.map.copy(),
            __no_validation__=True,
        )

    @property
    def fsm_info(self):
        if self._fsm_info is None:
            flat_transition_map_items = np.fromiter(
                ((a[0], a[1], b) for a, b in self.flat_transition_map.items()),
                dtype=np.dtype("i8, i8, i8"),
            )
            trans_key_to_states_items = np.fromiter(
                ((k, z) for k, v in self.trans_key_to_states.items() for z in v),
                dtype=np.dtype("i8, i8"),
            )
            alphabet_symbol_mapping_items = np.fromiter(
                (
                    it
                    for it in self.alphabet._symbol_mapping.items()
                    if it[0] != anything_else
                ),
                dtype=np.dtype("U1, i8"),
            )
            nb_finals = np.fromiter(self.finals, dtype=np.dtype("i8"))
            self.__dict__["_fsm_info"] = create_fsm_info(
                self.initial,
                nb_finals,
                flat_transition_map_items,
                trans_key_to_states_items,
                self.alphabet.anything_value,
                alphabet_symbol_mapping_items,
            )

        return self._fsm_info


nb_int_list_type = numba.types.ListType(numba.int64)
nb_int_pair_type = numba.types.UniTuple(numba.int64, 2)
nb_unichar_1_type = numba.types.UnicodeCharSeq(1)


@numba.njit(cache=True)
def create_fsm_info(
    py_initial,
    py_finals,
    flat_transition_map_items,
    trans_key_to_states_items,
    py_anything_value,
    alphabet_symbol_mapping_items,
):
    trans_key_to_states = numba.typed.Dict.empty(numba.int64, nb_int_list_type)
    for trans_key_and_state in trans_key_to_states_items:
        trans_key_to_states.setdefault(
            trans_key_and_state[0], numba.typed.List.empty_list(numba.int64)
        ).append(trans_key_and_state[1])

    flat_transition_map = numba.typed.Dict.empty(nb_int_pair_type, numba.int64)
    for trans_key_and_state in flat_transition_map_items:
        flat_transition_map[
            (trans_key_and_state[0], trans_key_and_state[1])
        ] = trans_key_and_state[2]

    alphabet_symbol_map = numba.typed.Dict.empty(nb_unichar_1_type, numba.int64)
    for symbol_and_trans_key in alphabet_symbol_mapping_items:
        alphabet_symbol_map[symbol_and_trans_key[0]] = symbol_and_trans_key[1]

    initial = numba.int64(py_initial)

    finals = set()
    for final in py_finals:
        finals.add(final)

    anything_value = numba.int64(py_anything_value)

    return FSMInfo(
        initial,
        finals,
        flat_transition_map,
        trans_key_to_states,
        anything_value,
        alphabet_symbol_map,
    )


FSMInfo = namedtuple(
    "FSMInfo",
    [
        "initial",
        "finals",
        "transitions",
        "trans_key_to_states",
        "alphabet_anything_value",
        "alphabet_symbol_mapping",
    ],
)


def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]:
    """Construct an equivalent FSM with deterministic state labels."""
    old_to_new_trans_keys = {
        trans_key: i
        for i, (trans_key, _) in enumerate(
            sorted(fsm.alphabet.by_transition.items(), key=lambda x: sorted(x[1]))
        )
    }

    new_symbol_mapping = {
        symbol: old_to_new_trans_keys[trans_key]
        for symbol, trans_key in fsm.alphabet._symbol_mapping.items()
    }

    new_alphabet = BetterAlphabet(new_symbol_mapping)

    new_map = {
        from_state: {
            old_to_new_trans_keys[trans_key]: to_state
            for trans_key, to_state in trans_map.items()
        }
        for from_state, trans_map in fsm.map.items()
    }

    old_to_new_states = {}
    old_to_new_states[fsm.initial] = 0

    i = 0
    seen = {fsm.initial}
    old_state_queue = [fsm.initial]
    while old_state_queue:
        old_state = old_state_queue.pop(-1)
        transitions = new_map[old_state]
        sorted_transitions = sorted(transitions.items(), key=lambda v: v[0])
        for _, old_state in sorted_transitions:
            if old_state not in seen:
                old_state_queue.append(old_state)
                seen.add(old_state)
            if old_state not in old_to_new_states:
                i += 1
                old_to_new_states[old_state] = i

    new_map = dict(
        sorted(
            (
                (
                    old_to_new_states[from_state],
                    dict(
                        sorted(
                            (
                                (trans_key, old_to_new_states[to_state])
                                for trans_key, to_state in trans_map.items()
                            ),
                            key=lambda v: v[0],
                        )
                    ),
                )
                for from_state, trans_map in new_map.items()
            ),
            key=lambda v: v[0],
        )
    )

    new_initial = 0
    new_finals = frozenset(
        sorted(old_to_new_states[old_state] for old_state in fsm.finals)
    )
    new_states = frozenset(sorted(new_map.keys()))

    new_fsm = BetterFSM(new_alphabet, new_states, new_initial, new_finals, new_map)

    return new_fsm, old_to_new_states

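`make_deterministic_fsm` assigns new state ids in a canonical traversal order (sorted transition keys, stack-based visit) so that structurally identical FSMs receive identical labels regardless of how interegular numbered them. The relabeling step can be sketched over a plain transition map (toy states, not interegular objects):

```python
# Dependency-free sketch of the canonical relabeling loop above:
# walk from the initial state, visiting successors in sorted
# transition-key order, and hand out ids in discovery order.
def relabel(initial, transition_map):
    old_to_new = {initial: 0}
    seen = {initial}
    stack = [initial]
    i = 0
    while stack:
        state = stack.pop()
        for _, nxt in sorted(transition_map.get(state, {}).items()):
            if nxt not in seen:
                stack.append(nxt)
                seen.add(nxt)
            if nxt not in old_to_new:
                i += 1
                old_to_new[nxt] = i
    return old_to_new


# Arbitrary original labels map to the same canonical ids.
m1 = {"s": {0: "x", 1: "y"}, "x": {}, "y": {}}
assert relabel("s", m1) == {"s": 0, "x": 1, "y": 2}
```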
@numba.njit(nogil=True, cache=True)
def _walk_fsm(
    fsm_transitions: Dict[Tuple[int, int], int],
    alphabet_symbol_mapping: Dict[str, int],
    alphabet_anything_value: int,
    fsm_initial: int,
    fsm_finals: Set[int],
    input_string: str,
    start_state: int,
    full_match: bool = True,
) -> List[int]:
    state = start_state
    accepted_states: List[int] = numba.typed.List.empty_list(numba.int64)
    last_final_idx: int = numba.uint64(0)

    for i, symbol in enumerate(input_string):
        trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)

        new_state = fsm_transitions.get((state, trans_key))

        if new_state is None:
            if not full_match and last_final_idx > 0:
                return accepted_states[:last_final_idx]

            return numba.typed.List.empty_list(numba.int64)

        state = new_state

        if state in fsm_finals:
            last_final_idx = numba.uint64(i + 1)

        accepted_states.append(_nonoptional(state))

    if full_match and last_final_idx - 1 != i:
        return numba.typed.List.empty_list(numba.int64)

    return accepted_states


def walk_fsm(
    fsm: BetterFSM,
    input_string: str,
    start_state: int,
    full_match: bool = True,
) -> List[int]:
    fsm_finals = fsm.finals

    state = start_state
    accepted_states: List[int] = []
    last_final_idx: int = 0

    alphabet_symbol_mapping = fsm.alphabet._symbol_mapping
    alphabet_anything_value = fsm.alphabet.anything_value
    fsm_transitions = fsm.flat_transition_map

    for i, symbol in enumerate(input_string):
        trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)

        new_state = fsm_transitions.get((state, trans_key))

        if new_state is None:
            if not full_match and last_final_idx > 0:
                return accepted_states[:last_final_idx]

            return []

        state = new_state

        if state in fsm_finals:
            last_final_idx = i + 1

        accepted_states.append(state)

    if full_match and last_final_idx - 1 != i:
        return []

    return accepted_states

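`walk_fsm` consumes the input one symbol at a time through the flat transition map, remembering the last position at which the FSM sat in a final state so that non-full matches can be trimmed back to the longest accepted prefix. The same loop over plain dicts (a toy FSM for the single string "ab"; the transition keys are raw characters here, whereas the real function goes through the alphabet mapping):

```python
# Dependency-free version of the walk loop: transitions is
# {(state, symbol): next_state}, finals is the accept-state set.
def walk(transitions, finals, input_string, start_state, full_match=True):
    state = start_state
    accepted, last_final_idx = [], 0
    i = -1  # stays -1 for empty input
    for i, symbol in enumerate(input_string):
        state = transitions.get((state, symbol))
        if state is None:  # dead end
            if not full_match and last_final_idx:
                return accepted[:last_final_idx]  # longest accepted prefix
            return []
        if state in finals:
            last_final_idx = i + 1
        accepted.append(state)
    if full_match and last_final_idx - 1 != i:
        return []  # consumed everything but did not end in a final state
    return accepted


trans = {(0, "a"): 1, (1, "b"): 2}
finals = {2}
assert walk(trans, finals, "ab", 0) == [1, 2]                     # full match
assert walk(trans, finals, "a", 0) == []                          # not accepted alone
assert walk(trans, finals, "abx", 0, full_match=False) == [1, 2]  # prefix match
```

The prefix-trimming behavior is what lets `state_scan_tokens` below keep tokens that only partially advance the FSM.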
def fsm_union(
    fsms: Sequence[FSM],
) -> Tuple[FSM, Dict[int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]]]:
    """Construct an FSM representing the union of the FSMs in `fsms`.

    This is an updated version of `interegular.fsm.FSM.union` made to return an
    extra map of component FSMs to the sets of state transitions that
    correspond to them in the new FSM.

    """

    alphabet, new_to_old = Alphabet.union(*[fsm.alphabet for fsm in fsms])

    indexed_fsms = tuple(enumerate(fsms))

    initial = {i: fsm.initial for (i, fsm) in indexed_fsms}

    # Dedicated function accepting a "superset" and returning the next
    # "superset" obtained by following this transition in the new FSM
    def follow(current_state, new_transition: int):
        next = {}
        for i, f in indexed_fsms:
            old_transition = new_to_old[i][new_transition]
            if (
                i in current_state
                and current_state[i] in f.map
                and old_transition in f.map[current_state[i]]
            ):
                next[i] = f.map[current_state[i]][old_transition]
        if not next:
            raise OblivionError
        return next

    states = [initial]
    finals: Set[int] = set()
    map: Dict[int, Dict[int, int]] = {}

    # Map component FSMs to their new state-to-state transitions, finals, and a
    # map translating component FSM states to aggregate FSM states
    fsms_to_trans_finals: Dict[
        int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]
    ] = {}

    i = 0
    while i < len(states):
        state = states[i]

        # Add to the finals of the aggregate FSM whenever we hit a final in a
        # component FSM
        if any(state.get(j, -1) in fsm.finals for (j, fsm) in indexed_fsms):
            finals.add(i)

        # Compute the map for this state
        map[i] = {}
        for transition in alphabet.by_transition:
            try:
                next = follow(state, transition)
            except OblivionError:
                # Reached an oblivion state; don't list it
                continue
            else:
                try:
                    # TODO: Seems like this could--and should--be avoided
                    j = states.index(next)
                except ValueError:
                    j = len(states)
                    states.append(next)

                map[i][transition] = j

                for fsm_id, fsm_state in next.items():
                    (
                        fsm_transitions,
                        fsm_finals,
                        fsm_old_to_new,
                    ) = fsms_to_trans_finals.setdefault(fsm_id, (set(), set(), {}))
                    old_from = state[fsm_id]
                    old_to = fsm_state
                    fsm_old_to_new.setdefault(old_from, set()).add(i)
                    fsm_old_to_new.setdefault(old_to, set()).add(j)
                    fsm_transitions.add((i, j))
                    if fsm_state in fsms[fsm_id].finals:
                        fsm_finals.add(j)

        i += 1

    fsm = FSM(
        alphabet=alphabet,
        states=range(len(states)),
        initial=0,
        finals=finals,
        map=map,
        __no_validation__=True,
    )

    fsm, old_to_new_states = make_deterministic_fsm(fsm)
    _fsms_to_trans_finals = {
        fsm_id: (
            {(old_to_new_states[s1], old_to_new_states[s2]) for s1, s2 in transitions},
            {old_to_new_states[s] for s in finals},
            {
                old_state: {old_to_new_states[new_state] for new_state in new_states}
                for old_state, new_states in old_to_new.items()
            },
        )
        for fsm_id, (transitions, finals, old_to_new) in sorted(
            fsms_to_trans_finals.items(), key=lambda x: x[0]
        )
    }

    return (
        fsm,
        _fsms_to_trans_finals,
    )


def get_sub_fsms_from_seq(
    state_seq: Sequence[int],
    fsms_to_trans_finals: Dict[
        int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]
    ],
) -> Generator[Tuple[int, bool, bool], None, None]:
    """Get the indices of the sub-FSMs in `fsm` that could have matched the state sequence `state_seq`.

    Parameters
    ----------
    state_seq
        A state sequence.
    fsms_to_trans_finals
        A map from FSM indices to tuples containing sets of their state transitions
        and sets of the final/accept states.

    Returns
    -------
    A generator returning tuples containing each sub-FSM index (in the order
    they were union-ed to construct `fsm`) and booleans indicating whether or
    not there is another valid transition from the last state in the sequence
    for the associated sub-FSM (i.e. if the FSM can continue
    accepting/matching) and whether or not the sequence ends in a final state
    of the sub-FSM.
    """
    state_seq_transitions = set(zip(state_seq[:-1], state_seq[1:]))
    last_fsm_state = state_seq[-1]
    yield from (
        (
            # The sub-FSM index
            fsm_idx,
            # Is there another possible transition in this sub-FSM?
            any(last_fsm_state == from_s for (from_s, to_s) in transitions),
            # Is this sub-FSM in a final state?
            state_seq[-1] in finals,
        )
        for fsm_idx, (transitions, finals, _) in fsms_to_trans_finals.items()
        if state_seq_transitions.issubset(transitions)
    )


|
||||
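To make the return values concrete, here is a pure-Python sketch of the matching logic described in the docstring above, using hypothetical toy transition/final sets (no numba, not the real FSM types):

```python
# Toy illustration of the sub-FSM matching logic (hypothetical data). Each
# sub-FSM is described by its transition set and its final states, the same
# shape stored per FSM index in `fsms_to_trans_finals`.
def sub_fsms_from_seq(state_seq, fsms_to_trans_finals):
    seq_transitions = set(zip(state_seq[:-1], state_seq[1:]))
    last = state_seq[-1]
    for idx, (transitions, finals) in fsms_to_trans_finals.items():
        if seq_transitions.issubset(transitions):
            # Can this sub-FSM keep consuming symbols from the last state?
            can_continue = any(last == src for (src, _dst) in transitions)
            # Did the sequence end in one of its accept states?
            in_final = last in finals
            yield idx, can_continue, in_final


# Two toy sub-FSMs over shared union states 0-2:
fsms = {
    0: ({(0, 1), (1, 2)}, {2}),  # accepts after reaching state 2
    1: ({(0, 1)}, {1}),          # accepts after reaching state 1
}
result = list(sub_fsms_from_seq([0, 1], fsms))
```

For the sequence `[0, 1]`, sub-FSM 0 can still continue (via `1 -> 2`) but is not final, while sub-FSM 1 is final and cannot continue.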
@numba.njit(cache=True, nogil=True)
def state_scan_tokens(
    fsm_transitions: Dict[Tuple[int, int], int],
    alphabet_symbol_mapping: Dict[str, int],
    alphabet_anything_value: int,
    fsm_initial: int,
    fsm_finals: Set[int],
    vocabulary: Dict[str, List[int]],
    start_state: int,
) -> Set[Tuple[int, int]]:
    res = set()

    for token, token_ids in vocabulary.items():
        state_seq = _walk_fsm(
            fsm_transitions,
            alphabet_symbol_mapping,
            alphabet_anything_value,
            fsm_initial,
            fsm_finals,
            token,
            start_state,
            False,
        )

        if state_seq is not None and len(state_seq) < len(token):
            continue

        for token_id in token_ids:
            res.add((token_id, state_seq[-1]))

    return res

def create_fsm_index_end_to_end(
    fsm_info: FSMInfo,
    vocabulary: Dict[str, List[int]],
) -> Dict[int, Set[Tuple[int, int]]]:
    """Create an FSM state-to-vocabulary map/index through end-to-end token parsing."""

    # TODO: Consider using a `List` of `Set`s instead; that way we can JIT this
    # code, too.
    states_to_token_subsets: Dict[int, Set[Tuple[int, int]]] = {}
    seen: Set[int] = set()
    next_states = {fsm_info.initial}

    while next_states:
        start_state = next_states.pop()

        token_ids_end_states = state_scan_tokens(
            fsm_info.transitions,
            fsm_info.alphabet_symbol_mapping,
            fsm_info.alphabet_anything_value,
            fsm_info.initial,
            fsm_info.finals,
            vocabulary,
            start_state,
        )

        for token_id_and_end_state in token_ids_end_states:
            states_to_token_subsets.setdefault(start_state, set()).add(
                token_id_and_end_state
            )
            end_state = token_id_and_end_state[1]
            if end_state not in seen:
                next_states.add(end_state)

        seen.add(start_state)

    return states_to_token_subsets

# TODO: Cannot cache typed collections to disk, yet. See
# https://github.com/numba/numba/issues/4698
@lru_cache
def reduced_vocabulary(tokenizer: "Tokenizer"):
    """Create a map from decoded vocabulary tokens to lists of equivalent token ids."""
    vocabulary = numba.typed.Dict.empty(
        numba.types.string, numba.types.ListType(numba.int64)
    )
    empty_token_ids = set()
    for token, token_idx in tokenizer.vocabulary.items():
        if token in tokenizer.special_tokens:
            continue

        token_str = tokenizer.convert_token_to_string(token)

        if token_str:
            vocabulary.setdefault(
                token_str,
                numba.typed.List.empty_list(numba.int64),
            ).append(numba.int64(token_idx))
        else:
            empty_token_ids.add(numba.int64(token_idx))

    return vocabulary, empty_token_ids

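The grouping performed by `reduced_vocabulary` is easy to see with plain dicts. This is a hedged sketch with a hypothetical toy vocabulary in which two distinct tokens decode to the same string (the real function uses numba typed containers and the tokenizer's `convert_token_to_string`):

```python
# Plain-dict sketch of the grouping done by `reduced_vocabulary` above
# (hypothetical data): tokens that decode to the same string share one entry,
# special tokens are skipped, and tokens decoding to "" are set aside.
def reduce_vocab(vocab, special_tokens, to_string):
    grouped, empty_ids = {}, set()
    for token, token_id in vocab.items():
        if token in special_tokens:
            continue
        s = to_string(token)
        if s:
            grouped.setdefault(s, []).append(token_id)
        else:
            empty_ids.add(token_id)
    return grouped, empty_ids


vocab = {"Ġhi": 0, " hi": 1, "</s>": 2, "": 3}
grouped, empty_ids = reduce_vocab(vocab, {"</s>"}, lambda t: t.replace("Ġ", " "))
```

Here both `"Ġhi"` and `" hi"` decode to `" hi"`, so their ids end up in the same list; `"</s>"` is skipped as a special token, and the empty token's id is returned separately.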
def create_fsm_index_tokenizer(
    fsm: BetterFSM,
    tokenizer: "Tokenizer",
) -> Tuple[Dict[int, Dict[int, int]], Set[int]]:
    """Construct an FSM index from a tokenizer.

    This uses the end-to-end approach of `create_fsm_index_end_to_end`.

    .. warning::

        `fsm` needs to be deterministically ordered so that future caching makes sense.

    """
    vocabulary, empty_token_ids = reduced_vocabulary(tokenizer)

    states_to_token_subsets = create_fsm_index_end_to_end(fsm.fsm_info, vocabulary)

    # Allow transitions to EOS from all terminal FSM states that are
    # reachable
    # TODO: Do we really need this anymore?
    for state in fsm.fsm_info.finals:
        subset = states_to_token_subsets.get(state)
        if subset is not None:
            subset.add((tokenizer.eos_token_id, state))

    # Convert to token-to-end-state maps
    states_to_token_subsets = {k: dict(v) for k, v in states_to_token_subsets.items()}

    return states_to_token_subsets, empty_token_ids
266
python/sglang/srt/constrained/tokenizer.py
Normal file
@@ -0,0 +1,266 @@
# Adapted from:
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/tokenizer.py
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/transformers.py
from abc import abstractmethod
from typing import (
    TYPE_CHECKING,
    Dict,
    Hashable,
    List,
    Optional,
    Protocol,
    Set,
    Tuple,
    Union,
)

import numpy as np
import torch
from numpy.typing import NDArray


class Tokenizer(Protocol, Hashable):
    eos_token: str
    eos_token_id: int
    pad_token_id: int
    vocabulary: Dict[str, int]
    special_tokens: Set[int]

    @abstractmethod
    def encode(
        self, prompt: Union[str, List[str]]
    ) -> Tuple[NDArray[np.int64], NDArray[np.int64]]:
        """Translate the input prompts into NumPy arrays of token ids and attention mask."""
        ...

    @abstractmethod
    def decode(self, token_ids: NDArray[np.int64]) -> List[str]:
        """Translate an array of token ids to a string or list of strings."""
        ...

    @abstractmethod
    def convert_token_to_string(self, token: str) -> str:
        """Convert a token to its equivalent string.

        This is for instance useful for BPE tokenizers where whitespaces are
        represented by the special character `Ġ`. This prevents matching a raw
        token that includes `Ġ` with a string.

        """
        ...

if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer

__all__ = ["transformers"]


KVCacheType = Tuple[Tuple[torch.DoubleTensor, torch.DoubleTensor], ...]


def get_llama_tokenizer_types():
    """Get all the Llama tokenizer types/classes that need work-arounds.

    When they can't be imported, a dummy class is created.

    """
    try:
        from transformers.models.llama import LlamaTokenizer
    except ImportError:

        class LlamaTokenizer:  # type: ignore
            pass

    try:
        from transformers.models.llama import LlamaTokenizerFast
    except ImportError:

        class LlamaTokenizerFast:  # type: ignore
            pass

    try:
        from transformers.models.code_llama import CodeLlamaTokenizer
    except ImportError:

        class CodeLlamaTokenizer:  # type: ignore
            pass

    try:
        from transformers.models.code_llama import CodeLlamaTokenizerFast
    except ImportError:

        class CodeLlamaTokenizerFast:  # type: ignore
            pass

    return (
        LlamaTokenizer,
        LlamaTokenizerFast,
        CodeLlamaTokenizer,
        CodeLlamaTokenizerFast,
    )

class Transformer:
    """Represents a `transformers` model."""

    def __init__(
        self,
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
    ):
        self.device = model.device
        self.model = model
        self.tokenizer = tokenizer

    @torch.inference_mode()
    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
        past_key_values: Optional[Tuple] = None,
    ) -> Tuple[torch.FloatTensor, Optional[KVCacheType]]:
        """Compute a forward pass through the transformer model.

        Parameters
        ----------
        input_ids
            The input token ids. Must be one or two dimensional.
        attention_mask
            The attention mask. Must be one or two dimensional.
        past_key_values
            A tuple of tuples containing the cached key and value tensors for each
            attention head.

        Returns
        -------
        The computed logits and the new cached key and value tensors.

        """
        assert 0 < input_ids.ndim < 3

        if past_key_values:
            input_ids = input_ids[..., -1].unsqueeze(-1)

        output = self.model(
            input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
            past_key_values=past_key_values,
        )

        return output.logits, output.past_key_values

    def __call__(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
        past_key_values: Optional[Tuple] = None,
    ) -> torch.FloatTensor:
        logits, kv_cache = self.forward(input_ids, attention_mask, past_key_values)
        next_token_logits = logits[..., -1, :]

        return next_token_logits, kv_cache

class TransformerTokenizer(Tokenizer):
    """Represents a tokenizer for models in the `transformers` library."""

    def __init__(self, tokenizer):
        # TODO: Do something to make this hashable?
        self.tokenizer = tokenizer
        self.eos_token_id = self.tokenizer.eos_token_id
        self.eos_token = self.tokenizer.eos_token

        if not self.tokenizer.pad_token_id:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
            self.pad_token_id = self.eos_token_id
        else:
            self.pad_token_id = self.tokenizer.pad_token_id
            self.pad_token = self.tokenizer.pad_token

        self.special_tokens = set(self.tokenizer.all_special_tokens)

        self.vocabulary = self.tokenizer.get_vocab()
        self.is_llama = isinstance(self.tokenizer, get_llama_tokenizer_types())

    def encode(
        self, prompt: Union[str, List[str]], **kwargs
    ) -> Tuple[torch.LongTensor, torch.LongTensor]:
        kwargs["padding"] = True
        kwargs["return_tensors"] = "pt"
        output = self.tokenizer(prompt, **kwargs)
        return output["input_ids"], output["attention_mask"]

    def decode(self, token_ids: torch.LongTensor) -> List[str]:
        text = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
        return text

    def convert_token_to_string(self, token: str) -> str:
        from transformers.file_utils import SPIECE_UNDERLINE

        string = self.tokenizer.convert_tokens_to_string([token])

        if self.is_llama:
            # A hack to handle missing spaces in HF's Llama tokenizers
            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
                return " " + string

        return string

    def __eq__(self, other):
        if isinstance(other, type(self)):
            return False
        # TODO(lsyin): the lru_cache for the TransformerTokenizer is useless?
        # return other.model_name == self.model_name and other.kwargs == self.kwargs
        return NotImplemented

    def __hash__(self):
        from datasets.fingerprint import Hasher

        return hash(Hasher.hash(self.tokenizer))

def transformers(
    model_name: str,
    device: Optional[str] = None,
    model_kwargs: dict = {},
    tokenizer_kwargs: dict = {},
):
    """Instantiate a model from the `transformers` library and its tokenizer.

    Parameters
    ----------
    model_name
        The name of the model as listed on Hugging Face's model page.
    device
        The device(s) on which the model should be loaded. This overrides
        the `device_map` entry in `model_kwargs` when provided.
    model_kwargs
        A dictionary that contains the keyword arguments to pass to the
        `from_pretrained` method when loading the model.
    tokenizer_kwargs
        A dictionary that contains the keyword arguments to pass to the
        `from_pretrained` method when loading the tokenizer.

    Returns
    -------
    A `Transformer` model instance.

    """
    try:
        from transformers import AutoModelForCausalLM, AutoTokenizer
    except ImportError:
        raise ImportError(
            "The `transformers` library needs to be installed in order to use `transformers` models."
        )

    if device is not None:
        model_kwargs["device_map"] = device

    model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
    # `TransformerTokenizer` wraps an already-instantiated tokenizer, so load
    # it from the hub first instead of passing the model name directly.
    tokenizer = TransformerTokenizer(
        AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
    )

    return Transformer(model, tokenizer)
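The `Tokenizer` protocol above only requires a handful of attributes and methods (`vocabulary`, `special_tokens`, `encode`, `decode`, `convert_token_to_string`). A minimal character-level stand-in makes the contract concrete; this is a hypothetical toy, not a class from the codebase, and real implementations wrap a `transformers` tokenizer:

```python
# Minimal character-level stand-in for the `Tokenizer` protocol
# (hypothetical toy implementation, with no `transformers` dependency).
class CharTokenizer:
    def __init__(self, alphabet):
        self.vocabulary = {ch: i for i, ch in enumerate(alphabet)}
        self.special_tokens = set()
        self.eos_token = "<eos>"
        self.eos_token_id = len(alphabet)
        self.pad_token_id = self.eos_token_id

    def encode(self, prompt):
        # Return (token ids, attention mask), mirroring the protocol's shape.
        ids = [self.vocabulary[ch] for ch in prompt]
        return ids, [1] * len(ids)

    def decode(self, token_ids):
        inv = {i: ch for ch, i in self.vocabulary.items()}
        return ["".join(inv[i] for i in token_ids)]

    def convert_token_to_string(self, token):
        # Character tokens already are their own strings.
        return token


tok = CharTokenizer("ab")
ids, mask = tok.encode("ab")
```

Such a stub is enough to drive the FSM-index code in `constrained/`, which only touches the protocol surface.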
164
python/sglang/srt/hf_transformers_utils.py
Normal file
@@ -0,0 +1,164 @@
"""Utilities for Huggingface Transformers."""

import json
import os
import warnings
from typing import List, Optional, Tuple, Union

from huggingface_hub import snapshot_download
from sglang.srt.utils import is_multimodal_model
from transformers import (
    AutoConfig,
    AutoProcessor,
    AutoTokenizer,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
)


def download_from_hf(model_path: str):
    if os.path.exists(model_path):
        return model_path

    return snapshot_download(model_path, allow_patterns=["*.json", "*.bin", "*.model"])


def get_config_json(model_path: str):
    with open(os.path.join(model_path, "config.json")) as f:
        config = json.load(f)
    return config


def get_config(model: str, trust_remote_code: bool, revision: Optional[str] = None):
    config = AutoConfig.from_pretrained(
        model, trust_remote_code=trust_remote_code, revision=revision
    )
    return config


# Models don't use the same configuration key for determining the maximum
# context length. Store them here so we can sanely check them.
# NOTE: The ordering here is important. Some models have two of these and we
# have a preference for which value gets used.
CONTEXT_LENGTH_KEYS = [
    "max_sequence_length",
    "seq_length",
    "max_position_embeddings",
    "max_seq_len",
    "model_max_length",
]


def get_context_length(config):
    """Get the context length of a model from a huggingface model config."""
    rope_scaling = getattr(config, "rope_scaling", None)
    if rope_scaling:
        rope_scaling_factor = config.rope_scaling["factor"]
    else:
        rope_scaling_factor = 1

    for key in CONTEXT_LENGTH_KEYS:
        val = getattr(config, key, None)
        if val is not None:
            return int(rope_scaling_factor * val)
    return 2048

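The key-preference and RoPE-scaling logic in `get_context_length` can be exercised without a real HF config; this hedged sketch replicates the lookup with `SimpleNamespace` stand-ins for the config object:

```python
# Stand-alone replica of the context-length lookup above, driven by stub
# config objects (SimpleNamespace) rather than real HF configs.
from types import SimpleNamespace

KEYS = [
    "max_sequence_length",
    "seq_length",
    "max_position_embeddings",
    "max_seq_len",
    "model_max_length",
]


def context_length(config, default=2048):
    # RoPE scaling multiplies the base context length when present.
    if getattr(config, "rope_scaling", None):
        factor = config.rope_scaling["factor"]
    else:
        factor = 1
    for key in KEYS:
        val = getattr(config, key, None)
        if val is not None:
            return int(factor * val)
    return default


plain = SimpleNamespace(max_position_embeddings=4096)
scaled = SimpleNamespace(max_position_embeddings=4096, rope_scaling={"factor": 2.0})
```

A config with `max_position_embeddings=4096` resolves to 4096, the same config with a RoPE factor of 2.0 resolves to 8192, and a config with none of the keys falls back to the 2048 default.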
|
||||
# A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file.
|
||||
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"
|
||||
|
||||
|
||||
def get_tokenizer(
|
||||
tokenizer_name: str,
|
||||
*args,
|
||||
tokenizer_mode: str = "auto",
|
||||
trust_remote_code: bool = False,
|
||||
tokenizer_revision: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
|
||||
"""Gets a tokenizer for the given model name via Huggingface."""
|
||||
if is_multimodal_model(tokenizer_name):
|
||||
processor = get_processor(
|
||||
tokenizer_name,
|
||||
*args,
|
||||
trust_remote_code=trust_remote_code,
|
||||
tokenizer_revision=tokenizer_revision,
|
||||
**kwargs,
|
||||
)
|
||||
tokenizer = processor.tokenizer
|
||||
return tokenizer
|
||||
|
||||
if tokenizer_mode == "slow":
|
||||
if kwargs.get("use_fast", False):
|
||||
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
|
||||
kwargs["use_fast"] = False
|
||||
|
||||
if (
|
||||
"llama" in tokenizer_name.lower()
|
||||
and kwargs.get("use_fast", True)
|
||||
and tokenizer_name != _FAST_LLAMA_TOKENIZER
|
||||
):
|
||||
pass
|
||||
# warnings.warn(
|
||||
# "For some LLaMA V1 models, initializing the fast tokenizer may "
|
||||
# "take a long time. To reduce the initialization time, consider "
|
||||
# f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
|
||||
# "tokenizer."
|
||||
# )
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
tokenizer_name,
|
||||
*args,
|
||||
trust_remote_code=trust_remote_code,
|
||||
tokenizer_revision=tokenizer_revision,
|
||||
**kwargs,
|
||||
)
|
||||
except TypeError as e:
|
||||
# The LLaMA tokenizer causes a protobuf error in some environments.
|
||||
err_msg = (
|
||||
"Failed to load the tokenizer. If you are using a LLaMA V1 model "
|
||||
f"consider using '{_FAST_LLAMA_TOKENIZER}' instead of the "
|
||||
"original tokenizer."
|
||||
)
|
||||
raise RuntimeError(err_msg) from e
|
||||
except ValueError as e:
|
||||
# If the error pertains to the tokenizer class not existing or not
|
||||
# currently being imported, suggest using the --trust-remote-code flag.
|
||||
if not trust_remote_code and (
|
||||
"does not exist or is not currently imported." in str(e)
|
||||
or "requires you to execute the tokenizer file" in str(e)
|
||||
):
|
||||
err_msg = (
|
||||
"Failed to load the tokenizer. If the tokenizer is a custom "
|
||||
"tokenizer not yet available in the HuggingFace transformers "
|
||||
"library, consider setting `trust_remote_code=True` in LLM "
|
||||
"or using the `--trust-remote-code` flag in the CLI."
|
||||
)
|
||||
raise RuntimeError(err_msg) from e
|
||||
else:
|
||||
raise e
|
||||
|
||||
if not isinstance(tokenizer, PreTrainedTokenizerFast):
|
||||
warnings.warn(
|
||||
"Using a slow tokenizer. This might cause a significant "
|
||||
"slowdown. Consider using a fast tokenizer instead."
|
||||
)
|
||||
return tokenizer
|
||||
|
||||
|
||||
def get_processor(
|
||||
tokenizer_name: str,
|
||||
*args,
|
||||
tokenizer_mode: str = "auto",
|
||||
trust_remote_code: bool = False,
|
||||
tokenizer_revision: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
tokenizer_name,
|
||||
*args,
|
||||
trust_remote_code=trust_remote_code,
|
||||
tokenizer_revision=tokenizer_revision,
|
||||
**kwargs,
|
||||
)
|
||||
return processor
|
||||
181
python/sglang/srt/layers/context_flashattention_nopad.py
Normal file
@@ -0,0 +1,181 @@
# Adapted from
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
import torch
import triton
import triton.language as tl
from sglang.srt.utils import wrap_kernel_launcher


@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    sm_scale,
    B_Start_Loc,
    B_Seqlen,
    Out,
    stride_qbs,
    stride_qh,
    stride_kbs,
    stride_kh,
    stride_vbs,
    stride_vh,
    stride_obs,
    stride_oh,
    kv_group_num: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    cur_kv_head = cur_head // kv_group_num

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_q = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
        + cur_head * stride_qh
        + offs_d[None, :]
    )
    off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]
    off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :]

    q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)

    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)

    for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(
            k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
            mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
            other=0.0,
        )
        # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale
        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        m_i_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_i_new)
        beta = tl.exp(m_ij - m_i_new)
        l_i_new = alpha * l_i + beta * l_ij
        # -- update output accumulator --
        # scale p
        p_scale = beta / l_i_new
        p = p * p_scale[:, None]
        # scale acc
        acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(
            v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
            mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
            other=0.0,
        )

        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new
    # initialize pointers to output
    off_o = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
        + cur_head * stride_oh
        + offs_d[None, :]
    )
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)

cached_kernel = None


def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
    BLOCK = 128
    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk and Lk == Lv
    assert Lk in {16, 32, 64, 128}

    sm_scale = 1.0 / (Lq**0.5)
    batch, head = b_seq_len.shape[0], q.shape[1]
    kv_group_num = q.shape[1] // k.shape[1]

    grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
    num_warps = 4 if Lk <= 64 else 8

    global cached_kernel
    if cached_kernel:
        cached_kernel(
            grid,
            num_warps,
            q,
            k,
            v,
            sm_scale,
            b_start_loc,
            b_seq_len,
            o,
            q.stride(0),
            q.stride(1),
            k.stride(0),
            k.stride(1),
            v.stride(0),
            v.stride(1),
            o.stride(0),
            o.stride(1),
        )
        return

    _fwd_kernel[grid](
        q,
        k,
        v,
        sm_scale,
        b_start_loc,
        b_seq_len,
        o,
        q.stride(0),
        q.stride(1),
        k.stride(0),
        k.stride(1),
        v.stride(0),
        v.stride(1),
        o.stride(0),
        o.stride(1),
        kv_group_num=kv_group_num,
        BLOCK_M=BLOCK,
        BLOCK_DMODEL=Lk,
        BLOCK_N=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )
    cached_kernel = wrap_kernel_launcher(_fwd_kernel)
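The `m_i`/`l_i`/`acc` rescaling inside the kernel's `BLOCK_N` loop is the standard online-softmax recurrence: a running max, a running normalizer, and a running weighted sum that get rescaled whenever a new chunk raises the max. This hedged pure-Python sketch applies the same recurrence to scalar scores/values and checks it against a direct softmax-weighted average:

```python
# Pure-Python sketch of the online-softmax recurrence used in the kernel's
# inner loop: m_i = running max, l_i = running normalizer, acc = running
# softmax-weighted sum, processed chunk by chunk like the BLOCK_N loop.
import math


def online_softmax_avg(scores, values, chunk=2):
    m_i, l_i, acc = float("-inf"), 0.0, 0.0
    for start in range(0, len(scores), chunk):
        s = scores[start:start + chunk]
        v = values[start:start + chunk]
        m_ij = max(s)
        p = [math.exp(x - m_ij) for x in s]
        l_ij = sum(p)
        m_new = max(m_i, m_ij)
        alpha = math.exp(m_i - m_new)   # rescale factor for the old running stats
        beta = math.exp(m_ij - m_new)   # rescale factor for the new chunk
        l_new = alpha * l_i + beta * l_ij
        # Rescale the old accumulator and fold in the new chunk, mirroring
        # `acc_scale` and `p_scale` in the kernel above.
        acc = acc * (l_i / l_new * alpha) + sum(
            (beta / l_new) * p_j * v_j for p_j, v_j in zip(p, v)
        )
        m_i, l_i = m_new, l_new
    return acc


scores = [0.1, 2.0, -1.0, 0.5]
values = [1.0, 2.0, 3.0, 4.0]
m = max(scores)
direct = sum(math.exp(s - m) * v for s, v in zip(scores, values)) / sum(
    math.exp(s - m) for s in scores
)
```

Because the recurrence keeps `acc` normalized by the current `l_i` at every step, the chunked result matches the direct computation regardless of chunk size.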