Json Decode && Multi-Turns (#4)
Submodule 3rdparty/flashinfer updated: 00cf5f46fd...88b9496e1a
benchmark/json_regex_decode/README.md (new file, 61 lines)
@@ -0,0 +1,61 @@
## Run benchmark

### Build dataset

```
pip install wikipedia
python3 build_dataset.py
```

### Dependencies

```
llama_cpp_python 0.2.19
guidance 0.1.10
vllm 0.2.5
outlines 0.0.22
```

### Benchmark sglang

Run llama-7b:

```
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```

Run mixtral-8x7b (if you hit a CUDA out-of-memory error, try reducing `--mem-fraction-static`):

```
python3 -m sglang.launch_server --model-path mistralai/Mixtral-8x7B-Instruct-v0.1 --port 30000 --tp-size 8
```

Benchmark:

```
python3 bench_sglang.py --num-questions 10
```

### Benchmark vllm

Run llama-7b (the vLLM backend is served through `outlines.serve.serve` to get regex-constrained decoding):

```
python3 -m outlines.serve.serve --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```

Benchmark:

```
python3 bench_other.py --backend vllm --num-questions 10
```

### Benchmark guidance

Run llama-7b and benchmark:

```
python3 bench_other.py --backend guidance --num-questions 10 --parallel 1
```
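For reference, each request in this benchmark constrains the model to fill in a small city record; the decoded object has roughly this shape (values are illustrative, not taken from any run):

```
# Illustrative only: the kind of object the regex-constrained decode targets.
city = {
    "name": "Paris",
    "country": "France",
    "latitude": 48.85,
    "population": 2102650,
    "top 3 landmarks": ["Eiffel Tower", "Louvre", "Arc de Triomphe"],
}
```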
benchmark/json_regex_decode/bench_other.py (new file, 125 lines)
@@ -0,0 +1,125 @@
import argparse
import json
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial

from sglang.test.test_utils import (
    add_common_other_args_and_parse,
    call_generate_outlines,
)
from sglang.utils import dump_state_text, read_jsonl
from sglang.lang.ir import REGEX_INT, REGEX_STRING, REGEX_FLOAT
from tqdm import tqdm

REGEX_LIST = r"\[(" + REGEX_STRING + ", )*" + REGEX_STRING + r"\]"


# fmt: off
def json_decode(document, generate):
    s = "Please extract the information of a city from the following wikipedia page.\n"
    s += "Page begin.\n" + document + "Page end.\n"
    s += "Here is the name, country, and symbol of the city in JSON format.\n"
    s += "{\n"
    s += ' "name": '
    s += generate(s, max_tokens=8, regex=REGEX_STRING + ",") + "\n"
    s += ' "country": '
    s += generate(s, max_tokens=8, regex=REGEX_STRING + ",") + "\n"
    s += ' "latitude": '
    s += generate(s, max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
    s += ' "population": '
    s += generate(s, max_tokens=8, regex=REGEX_INT + ",") + "\n"
    s += ' "top 3 landmarks": '
    s += generate(s, max_tokens=24, regex=REGEX_LIST) + "\n"
    s += "}\n"

    return s
# fmt: on


def main(args):
    lines = read_jsonl(args.data_path)
    arguments = []
    for i in range(len(lines[: args.num_questions])):
        arguments.append(
            {
                "document": lines[i]["document"],
            }
        )
    states = [None] * len(arguments)

    # Select backend
    if args.backend == "vllm":
        url = f"{args.host}:{args.port}/generate"
        generate = partial(call_generate_outlines, url=url, temperature=0)
    elif args.backend == "guidance":
        from guidance import gen, models

        model = models.LlamaCpp(
            "/home/ubuntu/model_weights/Llama-2-7b-chat-hf/ggml-model-f16.gguf",
            n_gpu_layers=-1,
            n_ctx=4096,
        )

        def generate(prompt, max_tokens, stop=None, regex=None):
            out = (
                model
                + prompt
                + gen(
                    name="answer",
                    max_tokens=max_tokens,
                    temperature=0,
                    stop=stop,
                    regex=regex,
                )
            )
            return out["answer"]

        # warmup
        for _ in range(3):
            generate("Hello!" * 10, max_tokens=64, stop=None)
    else:
        raise ValueError(f"Invalid backend: {args.backend}")

    # Run requests
    def get_one_answer(i):
        states[i] = json_decode(generate=generate, **arguments[i])

    tic = time.time()
    if args.parallel == 1:
        for i in tqdm(range(len(arguments))):
            get_one_answer(i)
    else:
        with ThreadPoolExecutor(args.parallel) as executor:
            rets = executor.map(get_one_answer, list(range(len(arguments))))
            for _ in rets:
                pass

    latency = time.time() - tic

    # Compute accuracy
    print(f"Latency: {latency:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "json_regex_decode",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "num_requests": args.num_questions,
            "other": {
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, default="questions.jsonl")
    parser.add_argument("--num-questions", type=int, default=20)
    args = add_common_other_args_and_parse(parser)
    main(args)
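The `REGEX_LIST` pattern above composes the primitive patterns this commit adds to `sglang/lang/ir.py` (where `REGEX_STRING` is `r"\"[\w\d\s]*\""`, as shown later in the diff). A quick standalone sanity check of what the constrained decode is allowed to emit:

```
import re

REGEX_STRING = r"\"[\w\d\s]*\""  # same pattern as in sglang/lang/ir.py
REGEX_LIST = r"\[(" + REGEX_STRING + ", )*" + REGEX_STRING + r"\]"

assert re.fullmatch(REGEX_STRING + ",", '"Tokyo",')
assert re.fullmatch(REGEX_LIST, '["Eiffel Tower", "Louvre", "Arc de Triomphe"]')
```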
benchmark/json_regex_decode/bench_sglang.py (new file, 100 lines)
@@ -0,0 +1,100 @@
import argparse
import json
import time

import sglang as sgl
from sglang.lang.ir import REGEX_INT, REGEX_STRING, REGEX_FLOAT
from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)
from sglang.utils import dump_state_text, read_jsonl

REGEX_LIST = r"\[(" + REGEX_STRING + ", )*" + REGEX_STRING + r"\]"


# fmt: off
@sgl.function
def json_warm_up(s):
    s += "The information about Hogwarts is in the following JSON format.\n"
    with s.var_scope("json_output"):
        s += "{\n"
        s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STRING + ",") + "\n"
        s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STRING + ",") + "\n"
        s += ' "latitude": ' + sgl.gen("latitude", max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
        s += ' "population": ' + sgl.gen("population", max_tokens=8, regex=REGEX_INT + ",") + "\n"
        s += ' "top 3 landmarks": ' + sgl.gen("landmarks", max_tokens=24, regex=REGEX_LIST) + "\n"
        s += "}\n"
    print(f'The warm-up json result is:\n{s["json_output"]}')
# fmt: on


# fmt: off
@sgl.function
def json_decode(s, document):
    s += "Please extract the information of a city from the following wikipedia page.\n"
    s += "Page begin.\n" + document + "Page end.\n"
    s += "Here is the name, country, and symbol of the city in JSON format.\n"
    with s.var_scope("json_output"):
        s += "{\n"
        s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STRING + ",") + "\n"
        s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STRING + ",") + "\n"
        s += ' "latitude": ' + sgl.gen("latitude", max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
        s += ' "population": ' + sgl.gen("population", max_tokens=8, regex=REGEX_INT + ",") + "\n"
        s += ' "top 3 landmarks": ' + sgl.gen("landmarks", max_tokens=24, regex=REGEX_LIST) + "\n"
        s += "}\n"
# fmt: on


def main(args):
    lines = read_jsonl(args.data_path)
    arguments = []
    for i in range(len(lines[: args.num_questions])):
        arguments.append(
            {
                "document": lines[i]["document"],
            }
        )

    # Select backend
    backend = select_sglang_backend(args)
    sgl.set_default_backend(backend)

    # Warm up
    json_warm_up.run().sync()

    # Run requests
    tic = time.time()
    states = json_decode.run_batch(arguments, temperature=0, num_threads=args.parallel)
    for state in states:
        state.sync()
    latency = time.time() - tic

    # Compute accuracy
    print(f"Latency: {latency:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(f"tmp_{args.backend}_json_results.txt", "w") as fout:
        for state in states:
            fout.write(state["json_output"] + "\n")

    with open(args.result_file, "a") as fout:
        value = {
            "task": "json_regex_decode",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "num_requests": args.num_questions,
            "other": {
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, default="questions.jsonl")
    parser.add_argument("--num-questions", type=int, default=20)
    args = add_common_sglang_args_and_parse(parser)
    main(args)
benchmark/json_regex_decode/build_dataset.py (new file, 58 lines)
@@ -0,0 +1,58 @@
import json

import transformers
import wikipedia

model_path = "meta-llama/Llama-2-7b-chat-hf"
t = transformers.AutoTokenizer.from_pretrained(model_path)
city_names = [
    "los angeles",
    "london",
    "tokyo",
    "beijing",
    "singapore",
    "paris",
    "dubai",
    "sydney",
    "moscow",
    "rome",
    "toronto",
    "rio de janeiro",
    "istanbul",
    "berlin",
    "auckland",
    "buenos aires",
    "mexico city",
    "mumbai",
    "seoul",
    "bangkok",
    "cairo",
    "athens",
    "jerusalem",
]


def get_content(city_name):
    content = str(wikipedia.page(city_name).content)
    content = content.replace("\n\n", "\n")

    tokens = t.encode(content)

    expected_tokens = 3000
    truncate_len = int((expected_tokens / len(tokens)) * len(content))
    truncate_content = content[:truncate_len]
    truncate_tokens = t.encode(truncate_content)

    # Count tokens
    print(
        f"city_name: {city_name}, #tokens: {len(tokens)}, #truncate tokens: {len(truncate_tokens)}"
    )

    return truncate_content


if __name__ == "__main__":
    with open("questions.jsonl", "w") as fout:
        for city_name in city_names:
            truncate_content = get_content(city_name)
            fout.write(json.dumps({"document": truncate_content}) + "\n")
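The resulting `questions.jsonl` holds one JSON object per line with a single `document` field, which is the format `read_jsonl` consumes in the two benchmark scripts above. A minimal check, assuming the dataset was built as shown:

```
import json

with open("questions.jsonl") as fin:
    docs = [json.loads(line)["document"] for line in fin]
print(f"{len(docs)} documents, first one starts with: {docs[0][:60]!r}")
```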
benchmark/multi_turns/README.md (new file, 66 lines)
@@ -0,0 +1,66 @@
### Benchmark sglang

Run llama-7b:

```
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```

Run mixtral-8x7b (if you hit a CUDA out-of-memory error, try reducing `--mem-fraction-static`):

```
python3 -m sglang.launch_server --model-path mistralai/Mixtral-8x7B-Instruct-v0.1 --port 30000 --tp-size 8
```

Benchmark (short output):

```
python3 bench_sglang.py --tokenizer meta-llama/Llama-2-7b-chat-hf
```

Benchmark (long output):

```
python3 bench_sglang.py --tokenizer meta-llama/Llama-2-7b-chat-hf --long
```

### Benchmark vLLM

Run llama-7b:

```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```

Run mixtral-8x7b:

```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model mistralai/Mixtral-8x7B-Instruct-v0.1 --disable-log-requests --port 21000 --tensor-parallel-size 8
```

Benchmark (short output):

```
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend vllm
```

Benchmark (long output):

```
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend vllm --long
```

### Benchmark guidance

Benchmark llama-7b (short output):

```
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend guidance --parallel 1
```

Benchmark llama-7b (long output):

```
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend guidance --parallel 1 --long
```
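Both benchmark scripts below implement the same multi-turn pattern: every turn appends a randomly generated prompt to the running conversation and then generates a fixed number of new tokens on top of the full history, so each successive request repeats an ever-growing prefix. A minimal sketch of that loop (the real code lives in `bench_other.py` and `bench_sglang.py`):

```
# Sketch of one multi-turn conversation; qas comes from data_gen.gen_arguments
# and generate is whichever backend call the benchmark selected.
history = ""
for qa in qas:
    history += qa["prompt"]
    history += generate(history, max_tokens=qa["new_tokens"])
```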
benchmark/multi_turns/bench_other.py (new file, 133 lines)
@@ -0,0 +1,133 @@
import json
import time
from argparse import ArgumentParser
from concurrent.futures import ThreadPoolExecutor

import requests
from sglang.test.test_utils import add_common_other_args_and_parse
from sglang.utils import dump_state_text
from tqdm import tqdm
from vllm.transformers_utils.tokenizer import get_tokenizer

from data_gen import gen_arguments


def get_generate(args):
    # Select backend
    if args.backend == "vllm":
        url = f"{args.host}:{args.port}/generate"

        def generate(prompt, max_tokens, stop=None, temperature=0, url=url, n=1):
            data = {
                "prompt": prompt,
                "temperature": temperature,
                "max_tokens": max_tokens,
                "ignore_eos": True,
                "stop": stop,
                "stream": False,
                "n": n,
            }
            res = requests.post(url, json=data)
            assert res.status_code == 200
            return res.json()["text"][0][len(prompt) :]

    elif args.backend == "guidance":
        from guidance import gen, models

        model = models.LlamaCpp(
            "/home/ubuntu/model_weights/Llama-2-7b-chat-hf/ggml-model-f16.gguf",
            n_gpu_layers=-1,
            n_ctx=4096,
        )

        def generate(prompt, max_tokens, stop=None):
            out = (
                model
                + prompt
                + gen(name="answer", max_tokens=max_tokens, temperature=0, stop=stop)
            )
            return out["answer"]

        # warmup
        for _ in range(3):
            generate("Hello!" * 10, max_tokens=64, stop=None)
    else:
        raise ValueError(f"Invalid backend: {args.backend}")

    return generate


def multi_turns(generate, qas):
    s = ""
    for qa in qas:
        s += qa["prompt"]
        s += generate(s, max_tokens=qa["new_tokens"])

    return s


def main(args):
    print(args)

    tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)

    multi_qas = gen_arguments(args, tokenizer)

    states = [None] * args.num_qa

    generate = get_generate(args)

    def get_one_answer(i):
        states[i] = multi_turns(generate=generate, **multi_qas[i])

    tic = time.time()
    if args.parallel == 1:
        for i in tqdm(range(len(multi_qas))):
            get_one_answer(i)
    else:
        with ThreadPoolExecutor(args.parallel) as executor:
            rets = executor.map(get_one_answer, list(range(len(multi_qas))))
            for _ in rets:
                pass

    latency = time.time() - tic

    # Compute accuracy
    print(f"Latency: {latency:.3f}")

    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "multi_turns",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "num_requests": args.num_qa,
            "num_turns": args.turns,
            "other": {
                "parallel": args.parallel,
                "output_mode": "long" if args.long else "short",
            },
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--turns", type=int, default=4)
    parser.add_argument("--num-qa", type=int, default=20)
    parser.add_argument("--min-len-q", type=int, default=256)
    parser.add_argument("--max-len-q", type=int, default=512)
    parser.add_argument("--min-len-a", type=int, default=4)
    parser.add_argument("--max-len-a", type=int, default=8)
    parser.add_argument("--tokenizer", type=str, required=True)
    parser.add_argument("--trust-remote-code", action="store_true")
    parser.add_argument("--long", action="store_true")
    args = add_common_other_args_and_parse(parser)

    if args.long:
        args.min_len_a = 256
        args.max_len_a = 512
        args.num_qa = 20
    main(args)
benchmark/multi_turns/bench_sglang.py (new file, 77 lines)
@@ -0,0 +1,77 @@
import json
import time
from argparse import ArgumentParser

import sglang as sgl
from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)
from sglang.utils import dump_state_text
from vllm.transformers_utils.tokenizer import get_tokenizer

from data_gen import gen_arguments


@sgl.function
def multi_turns(s, qas):
    for qa in qas:
        s += qa["prompt"]
        s += sgl.gen(max_tokens=qa["new_tokens"], ignore_eos=True)


def main(args):
    print(args)

    tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)

    multi_qas = gen_arguments(args, tokenizer)

    backend = select_sglang_backend(args)

    tic = time.time()
    states = multi_turns.run_batch(
        multi_qas, temperature=0, backend=backend, num_threads=args.parallel
    )
    for state in states:
        state.sync()
    latency = time.time() - tic

    print(f"Latency: {latency:.3f}")

    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "multi_turns",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "num_requests": args.num_qa,
            "num_turns": args.turns,
            "other": {
                "parallel": args.parallel,
                "output_mode": "long" if args.long else "short",
            },
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--turns", type=int, default=4)
    parser.add_argument("--num-qa", type=int, default=20)
    parser.add_argument("--min-len-q", type=int, default=256)
    parser.add_argument("--max-len-q", type=int, default=512)
    parser.add_argument("--min-len-a", type=int, default=4)
    parser.add_argument("--max-len-a", type=int, default=8)
    parser.add_argument("--tokenizer", type=str, required=True)
    parser.add_argument("--trust-remote-code", action="store_true")
    parser.add_argument("--long", action="store_true")
    args = add_common_sglang_args_and_parse(parser)

    if args.long:
        args.min_len_a = 256
        args.max_len_a = 512
        args.num_qa = 20
    main(args)
benchmark/multi_turns/data_gen.py (new file, 29 lines)
@@ -0,0 +1,29 @@
import random
import string

random.seed(42)


def gen_prompt(tokenizer, token_num):
    cha_set = string.ascii_letters + string.digits
    ret = "".join(random.choices(cha_set, k=token_num))
    while len(tokenizer(ret).input_ids) < token_num:
        ret += random.choice(cha_set)
    return ret


def gen_arguments(args, tokenizer):
    multi_qas = [{"qas": []} for _ in range(args.num_qa)]
    for i in range(args.num_qa):
        qas = multi_qas[i]["qas"]
        for _ in range(args.turns):
            prompt_len = random.randint(args.min_len_q, args.max_len_q)
            new_tokens = random.randint(args.min_len_a, args.max_len_a)
            qas.append(
                {
                    "prompt": gen_prompt(tokenizer, prompt_len),
                    "new_tokens": new_tokens,
                }
            )

    return multi_qas
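A quick look at the structure `gen_arguments` returns, using a stand-in `Namespace` for the CLI args and a generic Hugging Face tokenizer (the benchmarks pass the Llama-2 tokenizer instead); this is a sketch, not part of the benchmark:

```
from argparse import Namespace

from transformers import AutoTokenizer

from data_gen import gen_arguments

args = Namespace(num_qa=2, turns=3, min_len_q=8, max_len_q=16, min_len_a=4, max_len_a=8)
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in; any HF tokenizer works here
multi_qas = gen_arguments(args, tokenizer)
# multi_qas -> [{"qas": [{"prompt": "<random text>", "new_tokens": 6}, ...]}, ...]
```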
@@ -37,6 +37,7 @@ def gen(
    top_k: Optional[int] = None,
    frequency_penalty: Optional[float] = None,
    presence_penalty: Optional[float] = None,
+   ignore_eos: Optional[bool] = None,
    dtype: Optional[type] = None,
    choices: Optional[List[str]] = None,
    regex: Optional[str] = None,
@@ -60,6 +61,7 @@ def gen(
        top_k,
        frequency_penalty,
        presence_penalty,
+       ignore_eos,
        dtype,
        regex,
    )
@@ -74,6 +76,7 @@ def gen_int(
    top_k: Optional[int] = None,
    frequency_penalty: Optional[float] = None,
    presence_penalty: Optional[float] = None,
+   ignore_eos: Optional[bool] = None,
):
    return SglGen(
        name,
@@ -84,6 +87,7 @@ def gen_int(
        top_k,
        frequency_penalty,
        presence_penalty,
+       ignore_eos,
        int,
        None,
    )
@@ -98,6 +102,7 @@ def gen_string(
    top_k: Optional[int] = None,
    frequency_penalty: Optional[float] = None,
    presence_penalty: Optional[float] = None,
+   ignore_eos: Optional[bool] = None,
):
    return SglGen(
        name,
@@ -108,6 +113,7 @@ def gen_string(
        top_k,
        frequency_penalty,
        presence_penalty,
+       ignore_eos,
        str,
        None,
    )
@@ -4,7 +4,7 @@ import numpy as np
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import SamplingParams
+from sglang.lang.ir import SglSamplingParams

try:
    import anthropic
@@ -28,7 +28,7 @@ class Anthropic(BaseBackend):
    def generate(
        self,
        s: StreamExecutor,
-       sampling_params: SamplingParams,
+       sampling_params: SglSamplingParams,
    ):
        prompt = s.text_
        ret = anthropic.Anthropic().completions.create(
@@ -43,7 +43,7 @@ class Anthropic(BaseBackend):
    def generate_stream(
        self,
        s: StreamExecutor,
-       sampling_params: SamplingParams,
+       sampling_params: SglSamplingParams,
    ):
        prompt = s.text_
        generator = anthropic.Anthropic().completions.create(
@@ -2,7 +2,7 @@ from typing import Callable, List, Optional, Union

from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import SamplingParams
+from sglang.lang.ir import SglSamplingParams


class BaseBackend:
@@ -48,14 +48,14 @@ class BaseBackend:
    def generate(
        self,
        s: StreamExecutor,
-       sampling_params: SamplingParams,
+       sampling_params: SglSamplingParams,
    ):
        raise NotImplementedError()

    def generate_stream(
        self,
        s: StreamExecutor,
-       sampling_params: SamplingParams,
+       sampling_params: SglSamplingParams,
    ):
        raise NotImplementedError()
@@ -4,7 +4,7 @@ import numpy as np
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import SamplingParams
+from sglang.lang.ir import SglSamplingParams

try:
    import openai
@@ -73,7 +73,7 @@ class OpenAI(BaseBackend):
    def generate(
        self,
        s: StreamExecutor,
-       sampling_params: SamplingParams,
+       sampling_params: SglSamplingParams,
    ):
        if sampling_params.dtype is None:
            if self.is_chat_model:
@@ -122,7 +122,7 @@ class OpenAI(BaseBackend):
    def generate_stream(
        self,
        s: StreamExecutor,
-       sampling_params: SamplingParams,
+       sampling_params: SglSamplingParams,
    ):
        if sampling_params.dtype is None:
            if self.is_chat_model:
@@ -7,7 +7,7 @@ from sglang.backend.base_backend import BaseBackend
from sglang.global_config import global_config
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import SamplingParams, SglArgument
+from sglang.lang.ir import SglSamplingParams, SglArgument
from sglang.utils import encode_image_base64, find_printable_text, http_request


@@ -55,7 +55,7 @@ class RuntimeEndpoint(BaseBackend):
    def generate(
        self,
        s: StreamExecutor,
-       sampling_params: SamplingParams,
+       sampling_params: SglSamplingParams,
    ):
        if sampling_params.dtype is None:
            data = {
@@ -87,7 +87,7 @@ class RuntimeEndpoint(BaseBackend):
    def generate_stream(
        self,
        s: StreamExecutor,
-       sampling_params: SamplingParams,
+       sampling_params: SglSamplingParams,
    ):
        if sampling_params.dtype is None:
            data = {
@@ -7,7 +7,7 @@ from typing import List, Optional, Union
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import SamplingParams
+from sglang.lang.ir import SglSamplingParams
from sglang.utils import http_request


@@ -138,7 +138,7 @@ class TGI(BaseBackend):
        self,
        s: StreamExecutor,
        choices: List[str],
-       sampling_params: SamplingParams,
+       sampling_params: SglSamplingParams,
    ):
        decision = self.retry_for_expected(
            s.text_,
@@ -152,7 +152,7 @@ class TGI(BaseBackend):
        s: StreamExecutor,
        max_tokens: int,
        stop: Union[str, List[str]],
-       sampling_params: SamplingParams,
+       sampling_params: SglSamplingParams,
        dtype: Optional[str] = None,
    ):
        if dtype is None:
@@ -6,7 +6,7 @@ from typing import List, Union
from sglang.global_config import global_config
from sglang.lang.interpreter import ProgramState, StreamExecutor, pin_program
from sglang.lang.ir import (
-   SamplingParams,
+   SglSamplingParams,
    SglArgument,
    SglConstantText,
    SglExpr,
@@ -140,7 +140,7 @@ class CompiledFunction:
        kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
        kwargs.update(self.function.bind_arguments)

-       default_sampling_para = SamplingParams(
+       default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
@@ -173,7 +173,7 @@ class CompiledFunction:

        backend = backend or global_config.default_backend

-       default_sampling_para = SamplingParams(
+       default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
@@ -292,7 +292,7 @@ class StreamExecutor:

        assert isinstance(other, SglExpr), f"{other}"

-       if isinstance(other, (SglConstantText, SglArgument)):
+       if isinstance(other, SglConstantText):
            self._execute_fill(other.value)
        elif isinstance(other, SglGen):
            self._execute_gen(other)
@@ -332,8 +332,6 @@ class StreamExecutor:

    def _execute_image(self, expr: SglImage):
        path = expr.path
-       if isinstance(path, SglArgument):
-           path = path.value

        base64_data = encode_image_base64(path)

@@ -419,7 +417,7 @@ class StreamExecutor:
                "role": expr.role,
                "content": [{"type": "text", "text": new_text}],
            }
-           for (image_path, image_base64_data) in self.cur_images:
+           for image_path, image_base64_data in self.cur_images:
                last_msg["content"].append(
                    {
                        "type": "image_url",
@@ -480,6 +478,7 @@ class StreamExecutor:
            "top_k",
            "frequency_penalty",
            "presence_penalty",
+           "ignore_eos",
            "dtype",
            "regex",
        ]:
@@ -13,7 +13,7 @@ REGEX_STRING = r"\"[\w\d\s]*\""  # bugs with regex r"\".*\"" in interegular pkg


@dataclasses.dataclass
-class SamplingParams:
+class SglSamplingParams:
    max_new_tokens: int = 16
    stop: Union[str, List[str]] = ()
    temperature: float = 1.0
@@ -21,13 +21,14 @@ class SamplingParams:
    top_k: int = -1  # -1 means disable
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
+   ignore_eos: bool = False

    # for constrained generation, not included in to_xxx_kwargs
    dtype: Optional[str] = None
    regex: Optional[str] = None

    def clone(self):
-       return SamplingParams(
+       return SglSamplingParams(
            self.max_new_tokens,
            self.stop,
            self.temperature,
@@ -67,6 +68,7 @@ class SamplingParams:
            "top_k": self.top_k,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
+           "ignore_eos": self.ignore_eos,
            "regex": self.regex,
        }

@@ -98,13 +100,14 @@ class SglFunction:
        top_k: int = -1,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
+       ignore_eos: bool = False,
        stream: bool = False,
        backend=None,
        **kwargs,
    ):
        from sglang.lang.interpreter import run_program

-       default_sampling_para = SamplingParams(
+       default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
@@ -112,9 +115,9 @@ class SglFunction:
            top_k=top_k,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
+           ignore_eos=ignore_eos,
        )
        backend = backend or global_config.default_backend
        kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
        return run_program(self, backend, args, kwargs, default_sampling_para, stream)

    def run_batch(
@@ -128,6 +131,7 @@ class SglFunction:
        top_k: int = -1,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
+       ignore_eos: bool = False,
        backend=None,
        num_threads: Union[str, int] = "auto",
        progress_bar: bool = False,
@@ -139,7 +143,7 @@ class SglFunction:
            return []
        assert isinstance(batch_kwargs[0], dict)

-       default_sampling_para = SamplingParams(
+       default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
@@ -147,11 +151,9 @@ class SglFunction:
            top_k=top_k,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
+           ignore_eos=ignore_eos,
        )
        backend = backend or global_config.default_backend
        batch_kwargs = [
            {k: SglArgument(k, v) for k, v in kwargs.items()} for kwargs in batch_kwargs
        ]
        return run_program_batch(
            self,
            backend,
@@ -321,12 +323,13 @@ class SglGen(SglExpr):
        top_k,
        frequency_penalty,
        presence_penalty,
+       ignore_eos,
        dtype,
        regex,
    ):
        super().__init__()
        self.name = name
-       self.sampling_params = SamplingParams(
+       self.sampling_params = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
@@ -334,6 +337,7 @@ class SglGen(SglExpr):
            top_k=top_k,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
+           ignore_eos=ignore_eos,
            dtype=dtype,
            regex=regex,
        )
@@ -40,7 +40,8 @@ def extract_prefix_by_tracing(program, backend):
    try:
        with TracingScope(tracer):
            tracer.ret_value = program.func(tracer, **arguments)
-   except StopTracing:
+   except (StopTracing, TypeError):
+       # Some exceptions may not be caught
        pass

    # Run and cache prefix
python/sglang/srt/backend_config.py (new file, 12 lines)
@@ -0,0 +1,12 @@
"""
Backend configurations, may vary with different serving platforms.
"""
from dataclasses import dataclass


@dataclass
class BackendConfig:
    extend_dependency_time: float = 0.03


GLOBAL_BACKEND_CONFIG = BackendConfig()
@@ -1,6 +1,5 @@
import asyncio
import logging
-from typing import List, Tuple

import uvloop
import zmq
@@ -8,6 +7,7 @@ import zmq.asyncio
from sglang.srt.managers.router.model_rpc import ModelRpcClient
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_exception_traceback
+from sglang.srt.backend_config import GLOBAL_BACKEND_CONFIG

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -28,6 +28,9 @@ class RouterManager:
        self.model_client = model_client
        self.recv_reqs = []

+       # Init Some Configs
+       self.extend_dependency_time = GLOBAL_BACKEND_CONFIG.extend_dependency_time
+
    async def loop_for_forward(self):
        while True:
            next_step_input = list(self.recv_reqs)
@@ -37,7 +40,12 @@ class RouterManager:
            for obj in out_pyobjs:
                self.send_to_detokenizer.send_pyobj(obj)

-           # await for a while to accept input requests
+           # async sleep for receiving the subsequent request, and avoiding cache miss
+           if len(out_pyobjs) != 0:
+               has_finished = any([obj.finished for obj in out_pyobjs])
+               if has_finished:
+                   await asyncio.sleep(self.extend_dependency_time)
+
            await asyncio.sleep(0.001)

    async def loop_for_recv_requests(self):
@@ -19,7 +19,6 @@ from sglang.srt.managers.router.model_runner import ModelRunner
from sglang.srt.managers.router.radix_cache import RadixCache
from sglang.srt.managers.router.scheduler import Scheduler
from sglang.srt.model_config import ModelConfig
-from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    get_exception_traceback,
@@ -158,6 +157,18 @@ class ModelRpcServer(rpyc.Service):
                if self.running_batch.is_empty():
                    self.running_batch = None
                    break
+               else:
+                   # check the available size
+                   available_size = (
+                       self.token_to_kv_pool.available_size()
+                       + self.tree_cache.evictable_size()
+                   )
+                   if available_size != self.max_total_num_token:
+                       logger.warning(
+                           "Warning: "
+                           f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n"
+                           "KV cache pool leak detected!"
+                       )

            if self.running_batch is not None and self.tp_rank == 0:
                if self.decode_forward_ct >= 20:
@@ -408,7 +419,9 @@ class ModelRpcServer(rpyc.Service):
        token_ids = tuple(req.input_ids + req.output_ids)
        seq_len = len(token_ids) - 1
        indices = self.req_to_token_pool.req_to_token[req_pool_idx, :seq_len]
-       prefix_len = self.tree_cache.insert(token_ids, indices.clone())
+       prefix_len = self.tree_cache.insert(
+           token_ids[:seq_len], indices.clone()
+       )

        self.token_to_kv_pool.free(indices[:prefix_len])
        self.req_to_token_pool.free(req_pool_idx)
@@ -18,7 +18,7 @@ class Scheduler:
        self.tree_cache = tree_cache

    def new_token_estimation_ratio(self):
-       return 0.4 if self.schedule_heuristic != "fcfs" else 0.5
+       return 0.5 if self.schedule_heuristic != "fcfs" else 0.6

    def get_priority_queue(self, forward_queue):
        if self.schedule_heuristic == "lpm":
@@ -7,13 +7,13 @@ _SAMPLING_EPS = 1e-6
class SamplingParams:
    def __init__(
        self,
-       max_new_tokens: int = 16,
-       stop: Optional[Union[str, List[str]]] = None,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
+       stop: Optional[Union[str, List[str]]] = None,
+       max_new_tokens: int = 16,
        ignore_eos: bool = False,
        skip_special_tokens: bool = True,
        dtype: Optional[str] = None,
@@ -24,6 +24,8 @@ class ServerArgs:
    def __post_init__(self):
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path
+       if self.tp_size > 1:
+           self.mem_fraction_static = 0.8

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
@@ -38,6 +38,26 @@ def call_generate_vllm(prompt, temperature, max_tokens, stop, url, n=1):
    return pred


+def call_generate_outlines(
+    prompt, temperature, max_tokens, url, stop=[], regex=None, n=1
+):
+    data = {
+        "prompt": prompt,
+        "temperature": temperature,
+        "max_tokens": max_tokens,
+        "stop": stop,
+        "regex": regex,
+        "n": n,
+    }
+    res = requests.post(url, json=data)
+    assert res.status_code == 200
+    if n == 1:
+        pred = res.json()["text"][0][len(prompt) :]
+    else:
+        pred = [x[len(prompt) :] for x in res.json()["text"]]
+    return pred
+
+
def call_generate_srt_raw(prompt, temperature, max_tokens, stop, url):
    data = {
        "text": prompt,
@@ -67,7 +67,7 @@ def dump_state_text(filename, states, mode="w"):
        if isinstance(s, str):
            pass
        elif isinstance(s, ProgramState):
-           s = s.text().strip()
+           s = s.text()
        else:
            s = str(s)