sglangv0.5.2 & support Qwen3-Next-80B-A3B-Instruct
This commit is contained in:
308
scripts/playground/bench_speculative.py
Normal file
308
scripts/playground/bench_speculative.py
Normal file
@@ -0,0 +1,308 @@
|
||||
"""
|
||||
Usage:
|
||||
# single GPU
|
||||
python3 bench_speculative.py --model-path meta-llama/Llama-2-7b-chat-hf --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B
|
||||
|
||||
# multiple GPU
|
||||
python3 bench_speculative.py --model-path deepseek-ai/DeepSeek-V3 --speculative-draft-model-path lmsys/DeepSeek-V3-NextN --tp-size 8 --trust-remote-code --batch-size 1 4 8 16 32 --steps 0 1 2 --topk 0 1 2 4 --num_draft_tokens 0 2 4 8
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from sglang.bench_serving import (
|
||||
DatasetRow,
|
||||
benchmark,
|
||||
sample_mmmu_requests,
|
||||
set_global_args,
|
||||
)
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
kill_process_tree,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
def node0_print(msg):
|
||||
if server_args.node_rank == 0:
|
||||
print(msg)
|
||||
|
||||
|
||||
prompts = [
|
||||
"Human: Give me a fully functional FastAPI server. Show the full, long python code without stop.\n\nAssistant:",
|
||||
"Human: Imagine you are an experienced Ethereum developer tasked with creating a smart contract for a blockchain messenger. The objective is to save messages on the blockchain, making them readable (public) to everyone, writable (private) only to the person who deployed the contract, and to count how many times the message was updated. Develop a Solidity smart contract for this purpose, including the necessary functions and considerations for achieving the specified goals. Please provide the code and any relevant explanations to ensure a clear understanding of the implementation.\n\nAssistant:",
|
||||
"Human: Write a travel blog post to Hawaii.\n\nAssistant:",
|
||||
"Human: I want you to act as an English translator, spelling corrector and improver. I will speak to you in any language and you will detect the language, translate it and answer in the corrected and improved version of my text, in English. I want you to replace my simplified A0-level words and sentences with more beautiful and elegant, upper level English words and sentences. Keep the meaning same, but make them more literary. My first sentence is 'istanbulu cok seviyom burada olmak cok guzel'. Answer in more than 5000 words.\n\nAssistant:",
|
||||
"Human: I want you to act as a storyteller. You will come up with entertaining stories that are engaging, imaginative and captivating for the audience. It can be fairy tales, educational stories or any other type of stories which has the potential to capture people's attention and imagination. Depending on the target audience, you may choose specific themes or topics for your storytelling session e.g., if it’s children then you can talk about animals; If it’s adults then history-based tales might engage them better etc. Answer in more than 5000 words. My first request is 'I need an interesting story on perseverance.'\n\nAssistant:",
|
||||
"Human: Solve x^2 = -1. Think step-by-step. Give me a long detailed explanation. \n\nAssistant:",
|
||||
"Human: Tell me about the president of the USA in wikipedia style.\n\nAssistant:",
|
||||
"Human: Hello? Who are you? Write code, math, and poem to explanin yourself.\n\nAssistant:",
|
||||
]
|
||||
|
||||
|
||||
class FakeTokenizer:
|
||||
def encode(self, text: str, add_special_tokens: bool = False):
|
||||
return []
|
||||
|
||||
|
||||
def send_one_batch(base_url, num_prompts, batch_size, tokenizer, is_multimodal):
|
||||
# format: (prompt, input_len, output len). We set input_len as a dummy value 0.
|
||||
if is_multimodal:
|
||||
input_requests = sample_mmmu_requests(
|
||||
num_prompts,
|
||||
tokenizer,
|
||||
512,
|
||||
apply_chat_template=False,
|
||||
)
|
||||
backend = "sglang-oai-chat"
|
||||
api_url = f"{base_url}/v1/chat/completions"
|
||||
else:
|
||||
padded_prompts = (prompts * ((num_prompts + len(prompts) - 1) // len(prompts)))[
|
||||
:num_prompts
|
||||
]
|
||||
input_requests: List[DatasetRow] = [
|
||||
DatasetRow(p, 0, 512) for p in padded_prompts
|
||||
]
|
||||
backend = "sglang"
|
||||
api_url = f"{base_url}/generate"
|
||||
|
||||
# We need to set some dummy values in order to call `benchmark` below.
|
||||
args = SimpleNamespace(
|
||||
disable_ignore_eos=False,
|
||||
disable_stream=False,
|
||||
return_logprob=False,
|
||||
backend=backend,
|
||||
dataset_name="custom",
|
||||
num_prompts=None,
|
||||
sharegpt_output_len=None,
|
||||
random_input_len=None,
|
||||
random_output_len=None,
|
||||
random_range_ratio=None,
|
||||
output_file=None,
|
||||
warmup_requests=1,
|
||||
output_details=False,
|
||||
)
|
||||
set_global_args(args)
|
||||
|
||||
# Run benchmark
|
||||
results = asyncio.run(
|
||||
benchmark(
|
||||
backend=backend,
|
||||
api_url=api_url,
|
||||
base_url=base_url,
|
||||
model_id="default",
|
||||
tokenizer=tokenizer,
|
||||
input_requests=input_requests,
|
||||
request_rate=float("inf"),
|
||||
max_concurrency=batch_size,
|
||||
disable_tqdm=False,
|
||||
lora_names=None,
|
||||
extra_request_body={},
|
||||
profile=None,
|
||||
)
|
||||
)
|
||||
|
||||
assert results["completed"] == len(input_requests)
|
||||
acc_length = results["accept_length"] or 1.0
|
||||
avg_output_token = results["total_output_tokens"] / results["completed"]
|
||||
|
||||
server_info = requests.get(base_url + "/get_server_info").json()
|
||||
# We use 20% percentile instead of median on purpose
|
||||
step_time = np.percentile(
|
||||
server_info["internal_states"][0]["step_time_dict"][str(batch_size)], 20
|
||||
)
|
||||
speed = 1 / step_time * acc_length
|
||||
|
||||
return (
|
||||
round(acc_length, 3),
|
||||
round(step_time, 5),
|
||||
round(speed, 3),
|
||||
avg_output_token,
|
||||
)
|
||||
|
||||
|
||||
def main(args, server_args):
|
||||
base_url = "http://127.0.0.1:20000"
|
||||
|
||||
configs = []
|
||||
for batch_size in args.batch_size:
|
||||
for steps in args.steps:
|
||||
for topk in args.topk:
|
||||
for num_draft_tokens in args.num_draft_tokens:
|
||||
if steps * topk + 1 < num_draft_tokens:
|
||||
continue
|
||||
|
||||
if (steps == 0 or topk == 0 or num_draft_tokens == 0) and (
|
||||
steps + topk + num_draft_tokens != 0
|
||||
):
|
||||
# steps == 0 and topk == 0 and num_draft_tokens == 0 is a special case for non-speculative decoding.
|
||||
continue
|
||||
|
||||
configs.append((batch_size, steps, topk, num_draft_tokens))
|
||||
|
||||
for i in range(args.start, args.end or len(configs)):
|
||||
batch_size, steps, topk, num_draft_tokens = configs[i]
|
||||
|
||||
node0_print(
|
||||
f"Start {i=}: {batch_size=}, {steps=}, {topk=}, {num_draft_tokens=}"
|
||||
)
|
||||
|
||||
# Create an LLM.
|
||||
if steps == 0:
|
||||
other_args = []
|
||||
else:
|
||||
other_args = [
|
||||
"--speculative-num-steps",
|
||||
steps,
|
||||
"--speculative-eagle-topk",
|
||||
topk,
|
||||
"--speculative-num-draft-tokens",
|
||||
num_draft_tokens,
|
||||
]
|
||||
if server_args.speculative_draft_model_path is not None:
|
||||
other_args.extend(
|
||||
[
|
||||
"--speculative-draft-model-path",
|
||||
server_args.speculative_draft_model_path,
|
||||
"--speculative-algorithm",
|
||||
server_args.speculative_algorithm,
|
||||
]
|
||||
)
|
||||
|
||||
other_args.extend(
|
||||
[
|
||||
"--cuda-graph-max-bs",
|
||||
batch_size,
|
||||
"--mem-fraction-static",
|
||||
server_args.mem_fraction_static,
|
||||
"--tp-size",
|
||||
server_args.tp_size,
|
||||
"--max-running-requests",
|
||||
batch_size,
|
||||
]
|
||||
)
|
||||
|
||||
if server_args.trust_remote_code:
|
||||
other_args.extend(
|
||||
[
|
||||
"--trust-remote-code",
|
||||
]
|
||||
)
|
||||
|
||||
if server_args.attention_backend:
|
||||
other_args.extend(
|
||||
[
|
||||
"--attention-backend",
|
||||
server_args.attention_backend,
|
||||
]
|
||||
)
|
||||
|
||||
if server_args.quantization:
|
||||
other_args.extend(
|
||||
[
|
||||
"--quantization",
|
||||
server_args.quantization,
|
||||
]
|
||||
)
|
||||
|
||||
process = popen_launch_server(
|
||||
args.model_path,
|
||||
base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=other_args,
|
||||
env={
|
||||
"SGLANG_RECORD_STEP_TIME": "1",
|
||||
**os.environ,
|
||||
},
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.model_path, trust_remote_code=server_args.trust_remote_code
|
||||
)
|
||||
|
||||
try:
|
||||
# Warmup
|
||||
send_one_batch(
|
||||
base_url, batch_size, batch_size, tokenizer, args.is_multimodal
|
||||
)
|
||||
|
||||
# Benchmark
|
||||
acc_length, step_time, speed, completion_tokens = send_one_batch(
|
||||
base_url,
|
||||
max(args.num_prompts, batch_size),
|
||||
batch_size,
|
||||
tokenizer,
|
||||
args.is_multimodal,
|
||||
)
|
||||
finally:
|
||||
kill_process_tree(process.pid)
|
||||
|
||||
node0_print(
|
||||
f"Finish {i=}: {batch_size=}, {steps=}, {topk=}, {num_draft_tokens=}, {speed=:.2f} token/s, step_time={step_time * 1000:.2f} ms"
|
||||
)
|
||||
|
||||
record = {
|
||||
"batch_size": batch_size,
|
||||
"steps": steps,
|
||||
"topk": topk,
|
||||
"num_draft_tokens": num_draft_tokens,
|
||||
"acc_length": acc_length,
|
||||
"step_time": step_time,
|
||||
"speed": speed,
|
||||
"completion_tokens": completion_tokens,
|
||||
}
|
||||
|
||||
with open(args.output, "a") as fout:
|
||||
fout.write(json.dumps(record) + "\n")
|
||||
|
||||
# Wait for the server to shutdown
|
||||
time.sleep(5)
|
||||
|
||||
|
||||
# The __main__ condition is necessary here because we use "spawn" to create subprocesses
|
||||
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
ServerArgs.add_cli_args(parser)
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=(1, 2, 4, 8, 16),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--steps",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=(0, 1, 3, 5, 7), # use (0, 1, 2, 3, 4) for large batch size
|
||||
)
|
||||
parser.add_argument(
|
||||
"--topk",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=(0, 1, 2, 4, 8),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_draft_tokens",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=(0, 2, 4, 8, 16, 32), # use (0, 2, 4, 8) for large batch size
|
||||
)
|
||||
parser.add_argument("--num-prompts", type=int, default=16)
|
||||
parser.add_argument("--start", type=int, default=0)
|
||||
parser.add_argument("--end", type=int)
|
||||
parser.add_argument("--output", type=str, default="output.jsonl")
|
||||
parser.add_argument("--is-multimodal", action="store_true", default=False)
|
||||
args = parser.parse_args()
|
||||
server_args: ServerArgs = ServerArgs.from_cli_args(args)
|
||||
|
||||
main(args, server_args)
|
||||
22
scripts/playground/disaggregation/cli-logprob.py
Normal file
22
scripts/playground/disaggregation/cli-logprob.py
Normal file
@@ -0,0 +1,22 @@
|
||||
prompt = "The capital of france is "
|
||||
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
response = requests.post(
|
||||
"http://0.0.0.0:8000/generate",
|
||||
json={
|
||||
"text": prompt,
|
||||
"sampling_params": {"temperature": 0},
|
||||
"return_logprob": True,
|
||||
"return_input_logprob": True,
|
||||
"logprob_start_len": 0,
|
||||
},
|
||||
)
|
||||
|
||||
j = response.json()
|
||||
input_logprobs = j["meta_info"]["input_token_logprobs"]
|
||||
output_logprobs = j["meta_info"]["output_token_logprobs"]
|
||||
|
||||
print(len(input_logprobs), len(output_logprobs))
|
||||
34
scripts/playground/disaggregation/cli-so.py
Normal file
34
scripts/playground/disaggregation/cli-so.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
port = 8000
|
||||
|
||||
json_schema = json.dumps(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string", "pattern": "^[\\w]+$"},
|
||||
"population": {"type": "integer"},
|
||||
},
|
||||
"required": ["name", "population"],
|
||||
}
|
||||
)
|
||||
|
||||
# JSON
|
||||
response = requests.post(
|
||||
f"http://localhost:{port}/generate",
|
||||
json={
|
||||
"text": "Here is the information of the capital of France in the JSON format.\n",
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 64,
|
||||
"json_schema": json_schema,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
print(response.json())
|
||||
|
||||
|
||||
# python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --trust-remote-code --disaggregation-mode prefill --tp 2 --disaggregation-ib-device mlx5_roce0,mlx5_roce1 --speculative-algorithm EAGLE --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 3 --speculative-eagle-topk 4 --speculative-num-draft-tokens 16 --cuda-graph-max-bs 8 --host 127.0.0.1 --port 8100
|
||||
29
scripts/playground/disaggregation/cli.py
Normal file
29
scripts/playground/disaggregation/cli.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
prompt = """
|
||||
According to CNBC's Faber, the investors present on the call interpreted this statement as an indication of an upcoming funding round. While speculative, Faber believes the funding round could be as large as $25 billion, and bestow a valuation of between $150 billion and $200 billion on xAI.
|
||||
|
||||
For the benefit of those who might not be aware, xAI recently acquired the social media platform X in an all-stock deal that valued the former at $80 billion and the latter at $33 billion, inclusive of $12 billion in liabilities. This meant that the deal bestowed a gross valuation of $45 billion on X before factoring in its debt load of $12 billion.
|
||||
|
||||
Bear in mind that Elon Musk took X (then called Twitter) private back in 2022 in a $44 billion deal. Since then, Musk has managed to stem X's cash bleed, with the company reportedly generating $1.2 billion in adjusted EBITDA in 2024.
|
||||
|
||||
According to the investors present on the call, xAI is currently generating around $1 billion in annual revenue. This contrasts sharply with the erstwhile muted expectations of many investors, who did not expect the startup to generate any material revenue this year.
|
||||
|
||||
Elsewhere, Faber also alludes to the fact that xAI is already working on its next big training supercluster, officially dubbed the Colossus 2, which is expected to eventually house as many as 1 million NVIDIA GPUs at a cost of between $35 billion and $40 billion.
|
||||
|
||||
|
||||
Even though xAI's Grok LLM is already largely comparable with OpenAI's cutting-edge models, the Colossus 2 would significantly up the ante, and could feasibly challenge OpenAI's apex position in the AI sphere.
|
||||
|
||||
Give your honest take on the above text:
|
||||
"""
|
||||
|
||||
response = requests.post(
|
||||
"http://0.0.0.0:8000/generate",
|
||||
json={"text": prompt, "sampling_params": {"temperature": 0}},
|
||||
)
|
||||
|
||||
|
||||
response_json = response.json()
|
||||
print(response_json["text"])
|
||||
241
scripts/playground/frontend_reasoning.ipynb
Normal file
241
scripts/playground/frontend_reasoning.ipynb
Normal file
@@ -0,0 +1,241 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Launch A Server\n",
|
||||
"\n",
|
||||
"Launch the server with a reasoning model (Qwen 3.5-4B) and reasoning parser."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sglang import separate_reasoning, assistant_begin, assistant_end\n",
|
||||
"from sglang import assistant, function, gen, system, user\n",
|
||||
"from sglang import image\n",
|
||||
"from sglang import RuntimeEndpoint, set_default_backend\n",
|
||||
"from sglang.srt.utils import load_image\n",
|
||||
"from sglang.test.test_utils import is_in_ci\n",
|
||||
"from sglang.utils import print_highlight, terminate_process, wait_for_server\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"if is_in_ci():\n",
|
||||
" from patch import launch_server_cmd\n",
|
||||
"else:\n",
|
||||
" from sglang.utils import launch_server_cmd\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"server_process, port = launch_server_cmd(\n",
|
||||
" \"python3 -m sglang.launch_server --model-path Qwen/Qwen3-4B --reasoning-parser qwen3 --host 0.0.0.0\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"wait_for_server(f\"http://localhost:{port}\")\n",
|
||||
"print(f\"Server started on http://localhost:{port}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Set the default backend. Note: you can set chat_template_name in RontimeEndpoint. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"set_default_backend(\n",
|
||||
" RuntimeEndpoint(f\"http://localhost:{port}\", chat_template_name=\"qwen\")\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's start with a basic question-answering task. And see how the reasoning content is generated."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@function\n",
|
||||
"def basic_qa(s, question):\n",
|
||||
" s += system(f\"You are a helpful assistant than can answer questions.\")\n",
|
||||
" s += user(question)\n",
|
||||
" s += assistant_begin()\n",
|
||||
" s += gen(\"answer\", max_tokens=512)\n",
|
||||
" s += assistant_end()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"state = basic_qa(\"List 3 countries and their capitals.\")\n",
|
||||
"print_highlight(state[\"answer\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"With `separate_reasoning`, you can move the reasoning content to `{param_name}_reasoning_content` in the state."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@function\n",
|
||||
"def basic_qa_separate_reasoning(s, question):\n",
|
||||
" s += system(f\"You are a helpful assistant than can answer questions.\")\n",
|
||||
" s += user(question)\n",
|
||||
" s += assistant_begin()\n",
|
||||
" s += separate_reasoning(gen(\"answer\", max_tokens=512), model_type=\"qwen3\")\n",
|
||||
" s += assistant_end()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"reasoning_state = basic_qa_separate_reasoning(\"List 3 countries and their capitals.\")\n",
|
||||
"print_highlight(reasoning_state.stream_executor.variable_event.keys())\n",
|
||||
"print_highlight(\n",
|
||||
" f\"\\nSeparated Reasoning Content:\\n{reasoning_state['answer_reasoning_content']}\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print_highlight(f\"\\n\\nContent:\\n{reasoning_state['answer']}\")\n",
|
||||
"print_highlight(f\"\\n\\nMessages:\\n{reasoning_state.messages()[-1]}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"`separate_reasoning` can also be used in multi-turn conversations."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@function\n",
|
||||
"def multi_turn_qa(s):\n",
|
||||
" s += system(f\"You are a helpful assistant than can answer questions.\")\n",
|
||||
" s += user(\"Please give me a list of 3 countries and their capitals.\")\n",
|
||||
" s += assistant(\n",
|
||||
" separate_reasoning(gen(\"first_answer\", max_tokens=512), model_type=\"qwen3\")\n",
|
||||
" )\n",
|
||||
" s += user(\"Please give me another list of 3 countries and their capitals.\")\n",
|
||||
" s += assistant(\n",
|
||||
" separate_reasoning(gen(\"second_answer\", max_tokens=512), model_type=\"qwen3\")\n",
|
||||
" )\n",
|
||||
" return s\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"reasoning_state = multi_turn_qa()\n",
|
||||
"print_highlight(f\"\\n\\nfirst_answer:\\n{reasoning_state['first_answer']}\")\n",
|
||||
"print_highlight(\n",
|
||||
" f\"\\n\\nfirst_answer_reasoning_content:\\n{reasoning_state['first_answer_reasoning_content']}\"\n",
|
||||
")\n",
|
||||
"print_highlight(f\"\\n\\nsecond_answer:\\n{reasoning_state['second_answer']}\")\n",
|
||||
"print_highlight(\n",
|
||||
" f\"\\n\\nsecond_answer_reasoning_content:\\n{reasoning_state['second_answer_reasoning_content']}\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Using No thinking as Qwen 3's advanced feature \n",
|
||||
"\n",
|
||||
"sglang separate_reasoning is particularly useful when combined with Qwen 3's advanced feature.\n",
|
||||
"\n",
|
||||
"[Qwen 3's advanced usages](https://qwenlm.github.io/blog/qwen3/#advanced-usages)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"reasoning_state = basic_qa_separate_reasoning(\n",
|
||||
" \"List 3 countries and their capitals. /no_think\"\n",
|
||||
")\n",
|
||||
"print_highlight(f\"Reasoning Content:\\n{reasoning_state['answer_reasoning_content']}\")\n",
|
||||
"print_highlight(f\"Content:\\n{reasoning_state['answer']}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"`separate_reasoning` can also be used in regular expression generation."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@function\n",
|
||||
"def regular_expression_gen(s):\n",
|
||||
" s += user(\n",
|
||||
" \"What is the IP address of the Google DNS servers? just provide the answer\"\n",
|
||||
" )\n",
|
||||
" s += assistant(\n",
|
||||
" separate_reasoning(\n",
|
||||
" gen(\n",
|
||||
" \"answer\",\n",
|
||||
" temperature=0,\n",
|
||||
" regex=r\"((25[0-5]|2[0-4]\\d|[01]?\\d\\d?).){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\",\n",
|
||||
" max_tokens=512,\n",
|
||||
" ),\n",
|
||||
" model_type=\"qwen3\",\n",
|
||||
" ),\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"reasoning_state = regular_expression_gen()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print_highlight(f\"Answer:\\n{reasoning_state['answer']}\")\n",
|
||||
"print_highlight(\n",
|
||||
" f\"\\n\\nReasoning Content:\\n{reasoning_state['answer_reasoning_content']}\"\n",
|
||||
")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
7
scripts/playground/launch_tgi.sh
Normal file
7
scripts/playground/launch_tgi.sh
Normal file
@@ -0,0 +1,7 @@
|
||||
# Assuming the model is downdloaded at /home/ubuntu/model_weights/Llama-2-7b-chat-hf
|
||||
docker run --name tgi --rm -ti --gpus all --network host \
|
||||
-v /home/ubuntu/model_weights/Llama-2-7b-chat-hf:/Llama-2-7b-chat-hf \
|
||||
ghcr.io/huggingface/text-generation-inference:1.1.0 \
|
||||
--model-id /Llama-2-7b-chat-hf --num-shard 1 --trust-remote-code \
|
||||
--max-input-length 2048 --max-total-tokens 4096 \
|
||||
--port 24000
|
||||
14
scripts/playground/load_tokenizer.py
Normal file
14
scripts/playground/load_tokenizer.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import argparse
|
||||
import code
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--name", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
t = get_tokenizer(args.name)
|
||||
code.interact(local=locals())
|
||||
36
scripts/playground/long_context_example.py
Normal file
36
scripts/playground/long_context_example.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from urllib.request import urlopen
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
test_cases = {
|
||||
"64k": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/64k.txt",
|
||||
"200k": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/200k.txt",
|
||||
"600k": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt",
|
||||
"1m": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt",
|
||||
}
|
||||
|
||||
client = OpenAI(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")
|
||||
|
||||
for name, url in test_cases.items():
|
||||
print(f"\n==== Running test case: {name} ====")
|
||||
try:
|
||||
with urlopen(url, timeout=10) as response:
|
||||
prompt = response.read().decode("utf-8")
|
||||
except Exception as e:
|
||||
print(f"Failed to load prompt for {name}: {e}")
|
||||
continue
|
||||
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
stream=True,
|
||||
max_tokens=128,
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
for chunk in response:
|
||||
if chunk.choices and chunk.choices[0].delta.content is not None:
|
||||
print(chunk.choices[0].delta.content, end="", flush=True)
|
||||
except Exception as e:
|
||||
print(f"\nError during completion for {name}: {e}")
|
||||
77
scripts/playground/lora/analyzer.py
Normal file
77
scripts/playground/lora/analyzer.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
sys.path.append("../../")
|
||||
from fix_corrupted_json import clean_json_file
|
||||
|
||||
dirpath = "/Users/ying"
|
||||
output_file_prefix = "analyzed_log"
|
||||
|
||||
time = {}
|
||||
tot_time = {}
|
||||
size = {}
|
||||
|
||||
os.system(f"rm {output_file_prefix}*")
|
||||
|
||||
for dirname in glob.glob(os.path.join(dirpath, "trace*")):
|
||||
print(dirname)
|
||||
trace_name = dirname.split("/")[-1]
|
||||
time[trace_name] = {}
|
||||
size[trace_name] = {}
|
||||
total_time = 0
|
||||
for filename in tqdm(glob.glob(os.path.join(dirname, "*.json"))):
|
||||
step_name = filename.split("/")[-1].split(".")[0]
|
||||
step_name = "_".join(step_name.split("_")[1:])
|
||||
if "prefill" not in filename and "decode" not in filename:
|
||||
continue
|
||||
|
||||
match = re.search(r"(prefill|decode)_step_(\d+)\.json", filename)
|
||||
if match:
|
||||
phase = match.group(1)
|
||||
step = match.group(2)
|
||||
else:
|
||||
raise Exception(f"Cannot parse {filename}")
|
||||
|
||||
try:
|
||||
with open(filename, "r") as f:
|
||||
trace = json.load(f)
|
||||
except:
|
||||
clean_json_file(filename, filename)
|
||||
with open(filename, "r") as f:
|
||||
trace = json.load(f)
|
||||
|
||||
for event in trace["traceEvents"]:
|
||||
name = event["name"]
|
||||
if name in ["profile_prefill_step", "profile_decode_step"]:
|
||||
dur = event["dur"] / 1e3
|
||||
time[trace_name][step_name] = dur
|
||||
break
|
||||
total_time += dur
|
||||
|
||||
step = int(step_name.split("_")[-1])
|
||||
with open(os.path.join(dirname, f"size_{step}.json"), "r") as f:
|
||||
size_info = json.load(f)
|
||||
size[trace_name][step_name] = size_info["size"]
|
||||
|
||||
tot_time[trace_name] = total_time
|
||||
time[trace_name] = dict(
|
||||
sorted(time[trace_name].items(), key=lambda x: int(x[0].split("_")[-1]))
|
||||
)
|
||||
size[trace_name] = dict(
|
||||
sorted(size[trace_name].items(), key=lambda x: int(x[0].split("_")[-1]))
|
||||
)
|
||||
|
||||
with open(f"{output_file_prefix}_{trace_name}", "a") as f:
|
||||
for k, v in time[trace_name].items():
|
||||
size_v = size[trace_name][k]
|
||||
print(f"{k:>15}{v:10.2f}\t{size_v}")
|
||||
f.write(f"{k:>15}{v:10.2f}\t{size_v}\n")
|
||||
|
||||
with open(f"{output_file_prefix}_total_time", "w") as f:
|
||||
print(tot_time)
|
||||
json.dump(tot_time, f)
|
||||
62
scripts/playground/lora/lora_hf_play.py
Normal file
62
scripts/playground/lora/lora_hf_play.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import torch
|
||||
from peft import PeftModel
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
|
||||
# ADAPTER = "winddude/wizardLM-LlaMA-LoRA-7B"
|
||||
ADAPTER = "/home/ying/test_lora"
|
||||
HF_TOKEN = "..."
|
||||
|
||||
|
||||
prompt = """
|
||||
### Instruction:
|
||||
Write a poem about the transformers Python library.
|
||||
Mention the word "large language models" in that poem.
|
||||
### Response:
|
||||
The Transformers are large language models,
|
||||
They're used to make predictions on text.
|
||||
"""
|
||||
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained(MODEL)
|
||||
|
||||
base_model = LlamaForCausalLM.from_pretrained(
|
||||
MODEL,
|
||||
device_map="auto",
|
||||
# load_in_8bit=True,
|
||||
torch_dtype=torch.float16,
|
||||
# use_auth_token=HF_TOKEN,
|
||||
).cuda()
|
||||
|
||||
|
||||
# base model generate
|
||||
with torch.no_grad():
|
||||
output_tensors = base_model.generate(
|
||||
input_ids=tokenizer(prompt, return_tensors="pt").input_ids.cuda(),
|
||||
max_new_tokens=32,
|
||||
do_sample=False,
|
||||
)[0]
|
||||
|
||||
output = tokenizer.decode(output_tensors, skip_special_tokens=True)
|
||||
print("======= base output ========")
|
||||
print(output)
|
||||
|
||||
|
||||
# peft model generate
|
||||
model = PeftModel.from_pretrained(
|
||||
base_model,
|
||||
ADAPTER,
|
||||
torch_dtype=torch.float16,
|
||||
is_trainable=False,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
output_tensors = model.generate(
|
||||
input_ids=tokenizer(prompt, return_tensors="pt").input_ids.cuda(),
|
||||
max_new_tokens=32,
|
||||
do_sample=False,
|
||||
)[0]
|
||||
|
||||
output = tokenizer.decode(output_tensors, skip_special_tokens=True)
|
||||
print("======= peft output ========")
|
||||
print(output)
|
||||
30
scripts/playground/lora/lora_vllm_play.py
Normal file
30
scripts/playground/lora/lora_vllm_play.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
|
||||
ADAPTER = "/home/ying/test_lora"
|
||||
prompt = """
|
||||
### Instruction:
|
||||
Write a poem about the transformers Python library.
|
||||
Mention the word "large language models" in that poem.
|
||||
### Response:
|
||||
The Transformers are large language models,
|
||||
They're used to make predictions on text.
|
||||
"""
|
||||
|
||||
|
||||
llm = LLM(model=MODEL, enable_lora=True)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
)
|
||||
|
||||
prompts = [prompt]
|
||||
|
||||
outputs = llm.generate(
|
||||
prompts, sampling_params, lora_request=LoRARequest("test_lora", 1, ADAPTER)
|
||||
)
|
||||
|
||||
print(outputs[0].prompt)
|
||||
print(outputs[0].outputs[0].text)
|
||||
197
scripts/playground/reference_hf.py
Normal file
197
scripts/playground/reference_hf.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""
|
||||
Usage: python3 scripts/playground/reference_hf.py --model-path MODEL_PATH --model-type {text,vlm} [--max-new-tokens NUM] [--dtype DTYPE]
|
||||
--model-path MODEL_PATH: Path to model (default: TinyLlama/TinyLlama-1.1B-Chat-v0.4)
|
||||
--model-type {text,vlm}: Model type, text or vlm (default: text)
|
||||
--max-new-tokens NUM: Max new tokens to generate (default: 16)
|
||||
--dtype DTYPE: Data type for computation (default: float16)
|
||||
Note: '--model' is deprecated; use '--model-path'. Runs normal_text() for text, vlm_text_with_image() for vlm.
|
||||
|
||||
Reference output:
|
||||
========== Prompt 0 ==========
|
||||
prefill logits (final) tensor([-8.3125, -7.1172, 3.3398, ..., -4.9531, -4.1328, -3.4141],
|
||||
device='cuda:0')
|
||||
<s> The capital of France is Paris.
|
||||
The capital of the United States is Washington, D.C.
|
||||
|
||||
========== Prompt 1 ==========
|
||||
prefill logits (final) tensor([-8.9062, -9.0156, 4.1484, ..., -4.9922, -4.4961, -4.0742],
|
||||
device='cuda:0')
|
||||
<s> The capital of the United Kindom is London.
|
||||
The capital of the United Kingdom is London.
|
||||
The capital of
|
||||
|
||||
========== Prompt 2 ==========
|
||||
prefill logits (final) tensor([-9.6328, -9.0547, 4.0234, ..., -5.3047, -4.7148, -4.4609],
|
||||
device='cuda:0')
|
||||
<s> Today is a sunny day and I like to go for a walk in the park.
|
||||
I'm going to the
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForImageTextToText,
|
||||
AutoProcessor,
|
||||
)
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def vlm_text_with_image(args):
|
||||
# Load the processor and model for ImageTextToText tasks
|
||||
processor = AutoProcessor.from_pretrained(args.model_path, trust_remote_code=True)
|
||||
model = AutoModelForImageTextToText.from_pretrained(
|
||||
args.model_path,
|
||||
torch_dtype=args.dtype,
|
||||
low_cpu_mem_usage=True,
|
||||
device_map="auto",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
torch.cuda.set_device(0)
|
||||
|
||||
# List of image URLs to process
|
||||
image_urls = [
|
||||
"https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
|
||||
]
|
||||
|
||||
# Conversation template for the processor
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
},
|
||||
{"type": "text", "text": "Describe this image."},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
max_new_tokens = args.max_new_tokens
|
||||
|
||||
for i, url in enumerate(image_urls):
|
||||
# Load the image from the URL
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
# Apply the chat template to the text prompt
|
||||
# Notice that not all processors support chat templates.
|
||||
# LLaVA and QWen are two processors that support chat templates.
|
||||
if not hasattr(processor, "apply_chat_template"):
|
||||
raise ValueError("The processor does not support chat templates.")
|
||||
text_prompt = processor.apply_chat_template(
|
||||
conversation, add_generation_prompt=True
|
||||
)
|
||||
|
||||
# Prepare inputs for the model
|
||||
inputs = processor(text=[text_prompt], images=[image], return_tensors="pt").to(
|
||||
"cuda:0"
|
||||
)
|
||||
|
||||
# Generate output from the model
|
||||
output_ids = model.generate(
|
||||
**inputs, do_sample=False, max_new_tokens=max_new_tokens
|
||||
)
|
||||
output_str = processor.decode(output_ids[0])
|
||||
|
||||
# Get the logits from the model's forward pass
|
||||
outputs = model.forward(**inputs)
|
||||
logits = outputs.logits[0, -1, :]
|
||||
|
||||
print(f"\n========== Image {i} ==========")
|
||||
print("prefill logits (final)", logits)
|
||||
# TODO(gaocegege): The output contains numerous <|image_pad|> tokens,
|
||||
# making it cluttered and difficult to read.
|
||||
# These tokens should be removed or cleaned up for better readability.
|
||||
print(output_str)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def normal_text(args):
|
||||
t = get_tokenizer(args.model_path, trust_remote_code=True)
|
||||
m = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_path,
|
||||
torch_dtype=args.dtype,
|
||||
low_cpu_mem_usage=True,
|
||||
device_map="auto",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
prompts = [
|
||||
"The capital of France is",
|
||||
"The capital of the United Kindom is",
|
||||
"Today is a sunny day and I like",
|
||||
]
|
||||
max_new_tokens = args.max_new_tokens
|
||||
|
||||
torch.cuda.set_device(0)
|
||||
|
||||
for i, p in enumerate(prompts):
|
||||
if isinstance(p, str):
|
||||
input_ids = t.encode(p, return_tensors="pt").to("cuda:0")
|
||||
else:
|
||||
input_ids = torch.tensor([p], device="cuda:0")
|
||||
|
||||
output_ids = m.generate(
|
||||
input_ids, do_sample=False, max_new_tokens=max_new_tokens
|
||||
)
|
||||
output_str = t.decode(output_ids[0])
|
||||
|
||||
prefill_logits = m.forward(input_ids).logits[0][-1]
|
||||
|
||||
print(f"\n========== Prompt {i} ==========")
|
||||
print("prefill logits (final)", prefill_logits)
|
||||
print(output_str)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def synthetic_tokens(args):
|
||||
m = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
|
||||
)
|
||||
m.cuda()
|
||||
print(m)
|
||||
|
||||
input_len = 256
|
||||
output_len = 8
|
||||
prompts = [list(range(5, 5 + input_len))]
|
||||
|
||||
for p in prompts:
|
||||
input_ids = p
|
||||
for i in range(output_len + 1):
|
||||
prefill_logits = m.forward(torch.tensor([input_ids], device="cuda")).logits[
|
||||
0
|
||||
][-1]
|
||||
|
||||
if i == 0:
|
||||
print("prefill logits", prefill_logits)
|
||||
else:
|
||||
print("decode", i - 1, prefill_logits)
|
||||
|
||||
input_ids.append(torch.argmax(prefill_logits).item())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
type=str,
|
||||
default="TinyLlama/TinyLlama-1.1B-Chat-v0.4",
|
||||
)
|
||||
parser.add_argument("--max-new-tokens", type=int, default=16)
|
||||
|
||||
parser.add_argument("--dtype", type=str, default="float16")
|
||||
|
||||
parser.add_argument("--model-type", type=str, default="text")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.model_type == "vlm":
|
||||
vlm_text_with_image(args)
|
||||
else:
|
||||
normal_text(args)
|
||||
151
scripts/playground/replay_request_dump.py
Normal file
151
scripts/playground/replay_request_dump.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""
|
||||
Usage:
|
||||
# replay from a folder
|
||||
python3 replay_request_dump.py --file-number 100 --parallel 512 --input-folder /data/lianmin/sglang_request_dump/grok-mini-0220-engine-5756f8f94-28bm6/
|
||||
|
||||
# replay from a single file
|
||||
python3 replay_request_dump.py --parallel 512 --input-file /data/sglang_crash_dump/memx-cti-34-sr1.xpop.twttr.net/crash_dump_2025-06-04_20-13-18.pkl
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import json
|
||||
import pickle
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import asdict
|
||||
from datetime import datetime
|
||||
|
||||
import requests
|
||||
|
||||
from sglang.bench_serving import set_ulimit
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
|
||||
def read_records(files):
|
||||
records = []
|
||||
for f in files:
|
||||
tmp = pickle.load(open(f, "rb"))
|
||||
if isinstance(tmp, dict) and "requests" in tmp:
|
||||
records.extend(tmp["requests"])
|
||||
else:
|
||||
records.extend(tmp)
|
||||
|
||||
return records
|
||||
|
||||
|
||||
def run_one_request_internal(record):
|
||||
(req, output, replay_init_time, start_time, end_time, idx) = record
|
||||
time.sleep(max(0, (start_time - (time.time() - replay_init_time)) / args.speed))
|
||||
|
||||
if "completion_tokens" in output.get("meta_info", {}):
|
||||
recorded_completion_tokens = output["meta_info"]["completion_tokens"]
|
||||
else:
|
||||
recorded_completion_tokens = ""
|
||||
|
||||
json_data = asdict(req)
|
||||
stream = json_data["stream"]
|
||||
|
||||
if args.ignore_eos:
|
||||
json_data["sampling_params"]["ignore_eos"] = True
|
||||
if recorded_completion_tokens:
|
||||
json_data["sampling_params"]["max_new_tokens"] = recorded_completion_tokens
|
||||
|
||||
response = requests.post(
|
||||
f"http://{args.host}:{args.port}/generate",
|
||||
json=json_data,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
if stream:
|
||||
for chunk in response.iter_lines(decode_unicode=False):
|
||||
chunk = chunk.decode("utf-8")
|
||||
if chunk and chunk.startswith("data:"):
|
||||
if chunk == "data: [DONE]":
|
||||
break
|
||||
ret = json.loads(chunk[5:].strip("\n"))
|
||||
else:
|
||||
ret = response.json()
|
||||
|
||||
prompt_tokens = ret["meta_info"]["prompt_tokens"]
|
||||
completion_tokens = ret["meta_info"]["completion_tokens"]
|
||||
print(
|
||||
f"{idx=}, {start_time=:.2f}, {prompt_tokens=}, "
|
||||
f"{completion_tokens=}, {recorded_completion_tokens=}"
|
||||
)
|
||||
|
||||
|
||||
def run_one_request(record):
|
||||
# global success_ct, error_ct
|
||||
|
||||
try:
|
||||
run_one_request_internal(record)
|
||||
# success_ct += 1
|
||||
except Exception:
|
||||
# error_ct += 1
|
||||
traceback = get_exception_traceback()
|
||||
print(f"Hit an exception: {traceback}")
|
||||
|
||||
|
||||
def main(records):
|
||||
if len(records) == 0:
|
||||
return
|
||||
|
||||
base_time = records[0][-2]
|
||||
base_time_str = datetime.fromtimestamp(base_time).strftime("%y-%m-%d %H:%M:%S")
|
||||
print(f"{base_time_str=}")
|
||||
replay_init_time = time.time()
|
||||
|
||||
for i in range(len(records)):
|
||||
req, output, start_time, end_time = records[i]
|
||||
start_time -= base_time
|
||||
records[i] = (req, output, replay_init_time, start_time, end_time, i)
|
||||
|
||||
with ThreadPoolExecutor(args.parallel) as executor:
|
||||
executor.map(run_one_request, records)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="localhost")
|
||||
parser.add_argument("--port", type=int, default=30000)
|
||||
parser.add_argument(
|
||||
"--input-folder", type=str, default=None, help="Folder containing pickle files"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input-file", type=str, default=None, help="Single pickle file to process"
|
||||
)
|
||||
parser.add_argument("--file-number", type=int, default=1)
|
||||
parser.add_argument("--req-number", type=int, default=1000000)
|
||||
parser.add_argument("--req-start", type=int, default=0)
|
||||
parser.add_argument("--parallel", type=int, default=512)
|
||||
parser.add_argument("--idx", type=int, default=None)
|
||||
parser.add_argument("--ignore-eos", action="store_true")
|
||||
parser.add_argument("--speed", type=float, default=1)
|
||||
args = parser.parse_args()
|
||||
|
||||
set_ulimit()
|
||||
|
||||
files = []
|
||||
if args.input_file:
|
||||
files = [args.input_file]
|
||||
if args.file_number > 1:
|
||||
print("Warning: --file-number is ignored when --input-file is provided.")
|
||||
elif args.input_folder:
|
||||
files = glob.glob(f"{args.input_folder}/*.pkl")
|
||||
files = files[: args.file_number]
|
||||
else:
|
||||
print("Error: Either --input-folder or --input-file must be provided.")
|
||||
exit(1)
|
||||
print(f"{files=}")
|
||||
|
||||
records = read_records(files)
|
||||
# Sort by the receive time, before filtering
|
||||
records.sort(key=lambda x: x[-2])
|
||||
records = records[args.req_start :]
|
||||
if args.idx:
|
||||
records = [records[args.idx]]
|
||||
print(f"testing {args.idx=}")
|
||||
print(f"{records[0]}")
|
||||
print(f"{len(records)=}")
|
||||
main(records)
|
||||
207
scripts/playground/router/test_tree.py
Normal file
207
scripts/playground/router/test_tree.py
Normal file
@@ -0,0 +1,207 @@
|
||||
import random
|
||||
import string
|
||||
import time
|
||||
import unittest
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from tree import MultiTenantRadixTree
|
||||
|
||||
|
||||
class TestMultiTenantRadixTree(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.tree = MultiTenantRadixTree()
|
||||
|
||||
def test_insert_exact_match(self):
|
||||
"""Test 1: Basic insert and exact match operations"""
|
||||
# Insert a single string for one tenant
|
||||
self.tree.insert("hello", "tenant1")
|
||||
matched, tenant = self.tree.prefix_match("hello")
|
||||
self.assertEqual(matched, "hello")
|
||||
self.assertEqual(tenant, "tenant1")
|
||||
|
||||
# Insert same string for different tenant
|
||||
self.tree.insert("hello", "tenant2")
|
||||
matched, tenant = self.tree.prefix_match("hello")
|
||||
self.assertIn(tenant, ["tenant1", "tenant2"])
|
||||
|
||||
# Insert different string for same tenant
|
||||
self.tree.insert("world", "tenant1")
|
||||
matched, tenant = self.tree.prefix_match("world")
|
||||
self.assertEqual(matched, "world")
|
||||
self.assertEqual(tenant, "tenant1")
|
||||
|
||||
print(self.tree.pretty_print())
|
||||
|
||||
def test_insert_partial_match(self):
|
||||
"""Test 2: Insert with partial matching scenarios"""
|
||||
# Test partial matches with common prefixes
|
||||
self.tree.insert("hello", "tenant1")
|
||||
print(self.tree.pretty_print())
|
||||
self.tree.insert("help", "tenant2")
|
||||
print(self.tree.pretty_print())
|
||||
|
||||
# Match exact strings
|
||||
matched, tenant = self.tree.prefix_match("hello")
|
||||
self.assertEqual(matched, "hello")
|
||||
self.assertEqual(tenant, "tenant1")
|
||||
|
||||
matched, tenant = self.tree.prefix_match("help")
|
||||
self.assertEqual(matched, "help")
|
||||
self.assertEqual(tenant, "tenant2")
|
||||
|
||||
# Match partial string
|
||||
matched, tenant = self.tree.prefix_match("hel")
|
||||
self.assertEqual(matched, "hel")
|
||||
self.assertIn(tenant, ["tenant1", "tenant2"])
|
||||
|
||||
# Match longer string
|
||||
matched, tenant = self.tree.prefix_match("hello_world")
|
||||
self.assertEqual(matched, "hello")
|
||||
self.assertEqual(tenant, "tenant1")
|
||||
|
||||
def test_insert_edge_cases(self):
|
||||
"""Test 3: Edge cases for insert and match operations"""
|
||||
# Empty string
|
||||
self.tree.insert("", "tenant1")
|
||||
matched, tenant = self.tree.prefix_match("")
|
||||
self.assertEqual(matched, "")
|
||||
self.assertEqual(tenant, "tenant1")
|
||||
|
||||
# Single character
|
||||
self.tree.insert("a", "tenant1")
|
||||
matched, tenant = self.tree.prefix_match("a")
|
||||
self.assertEqual(matched, "a")
|
||||
self.assertEqual(tenant, "tenant1")
|
||||
|
||||
# Very long string
|
||||
long_str = "a" * 1000
|
||||
self.tree.insert(long_str, "tenant1")
|
||||
matched, tenant = self.tree.prefix_match(long_str)
|
||||
self.assertEqual(matched, long_str)
|
||||
self.assertEqual(tenant, "tenant1")
|
||||
|
||||
# Unicode characters
|
||||
self.tree.insert("你好", "tenant1")
|
||||
matched, tenant = self.tree.prefix_match("你好")
|
||||
self.assertEqual(matched, "你好")
|
||||
self.assertEqual(tenant, "tenant1")
|
||||
|
||||
def test_simple_eviction(self):
|
||||
"""Test 4: Simple eviction scenarios
|
||||
Tenant1: limit 10 chars
|
||||
Tenant2: limit 5 chars
|
||||
|
||||
Should demonstrate:
|
||||
1. Basic eviction when size limit exceeded
|
||||
2. Proper eviction based on last access time
|
||||
3. Verification that shared nodes remain intact for other tenants
|
||||
"""
|
||||
# Set up size limits
|
||||
max_size = {"tenant1": 10, "tenant2": 5}
|
||||
|
||||
# Insert strings for both tenants
|
||||
self.tree.insert("hello", "tenant1") # size 5
|
||||
self.tree.insert("hello", "tenant2") # size 5
|
||||
self.tree.insert("world", "tenant2") # size 5, total for tenant2 = 10
|
||||
|
||||
# Verify initial sizes
|
||||
sizes_before = self.tree.get_used_size_per_tenant()
|
||||
self.assertEqual(sizes_before["tenant1"], 5) # "hello" = 5
|
||||
self.assertEqual(sizes_before["tenant2"], 10) # "hello" + "world" = 10
|
||||
|
||||
# Evict - should remove "hello" from tenant2 as it's the oldest
|
||||
self.tree.evict_tenant_data(max_size)
|
||||
|
||||
# Verify sizes after eviction
|
||||
sizes_after = self.tree.get_used_size_per_tenant()
|
||||
self.assertEqual(sizes_after["tenant1"], 5) # Should be unchanged
|
||||
self.assertEqual(sizes_after["tenant2"], 5) # Only "world" remains
|
||||
|
||||
# Verify "world" remains for tenant2 (was accessed more recently)
|
||||
matched, tenant = self.tree.prefix_match("world")
|
||||
self.assertEqual(matched, "world")
|
||||
self.assertEqual(tenant, "tenant2")
|
||||
|
||||
def test_medium_eviction(self):
|
||||
"""Test 5: Medium complexity eviction scenarios with shared prefixes
|
||||
Tenant1: limit 10 chars
|
||||
Tenant2: limit 7 chars (forces one string to be evicted)
|
||||
|
||||
Tree structure after inserts:
|
||||
└── 'h' [t1, t2]
|
||||
├── 'i' [t1, t2] # Oldest for t2
|
||||
└── 'e' [t1, t2]
|
||||
├── 'llo' [t1, t2]
|
||||
└── 'y' [t2] # Newest for t2
|
||||
|
||||
Size calculations:
|
||||
tenant1: "h"(1) + "i"(1) + "e"(1) + "llo"(3) = 6 chars
|
||||
tenant2: "h"(1) + "i"(1) + "e"(1) + "llo"(3) + "y"(1) = 7 chars
|
||||
|
||||
After eviction (tenant2 exceeds limit by 1 char):
|
||||
"hi" should be removed from tenant2 as it's the oldest access
|
||||
"""
|
||||
max_size = {
|
||||
"tenant1": 10,
|
||||
"tenant2": 6,
|
||||
} # tenant2 will need to evict one string
|
||||
|
||||
# Create a tree with overlapping prefixes
|
||||
self.tree.insert("hi", "tenant1")
|
||||
self.tree.insert("hi", "tenant2") # OLDEST for t2
|
||||
|
||||
self.tree.insert("hello", "tenant1")
|
||||
self.tree.insert("hello", "tenant2")
|
||||
|
||||
self.tree.insert("hey", "tenant2") # NEWEST for t2
|
||||
|
||||
# Verify initial sizes
|
||||
sizes_before = self.tree.get_used_size_per_tenant()
|
||||
self.assertEqual(sizes_before["tenant1"], 6) # h(1) + i(1) + e(1) + llo(3) = 6
|
||||
self.assertEqual(
|
||||
sizes_before["tenant2"], 7
|
||||
) # h(1) + i(1) + e(1) + llo(3) + y(1) = 7
|
||||
|
||||
print("\nTree before eviction:")
|
||||
print(self.tree.pretty_print())
|
||||
|
||||
# Evict - should remove "hi" from tenant2 as it's the oldest
|
||||
self.tree.evict_tenant_data(max_size)
|
||||
|
||||
print("\nTree after eviction:")
|
||||
print(self.tree.pretty_print())
|
||||
|
||||
# Verify sizes after eviction
|
||||
sizes_after = self.tree.get_used_size_per_tenant()
|
||||
self.assertEqual(sizes_after["tenant1"], 6) # Should be unchanged
|
||||
self.assertEqual(sizes_after["tenant2"], 6) # h(1) + e(1) + llo(3) + y(1) = 6
|
||||
|
||||
def test_advanced_eviction(self):
|
||||
...
|
||||
# Create 4 tenants
|
||||
# Each tenants keeps adding strings with shared prefixes to thousands usage
|
||||
# Set a strict limit for each tenant to only 100
|
||||
# At the end, check whether all of the tenant is under 100 after eviction
|
||||
|
||||
max_size = {"tenant1": 100, "tenant2": 100, "tenant3": 100, "tenant4": 100}
|
||||
|
||||
prefixes = ["aqwefcisdf", "iajsdfkmade", "kjnzxcvewqe", "iejksduqasd"]
|
||||
for i in range(100):
|
||||
for j, prefix in enumerate(prefixes):
|
||||
random_suffix = "".join(random.choices(string.ascii_letters, k=10))
|
||||
self.tree.insert(prefix + random_suffix, f"tenant{j+1}")
|
||||
|
||||
sizes_before = self.tree.get_used_size_per_tenant()
|
||||
print(sizes_before)
|
||||
|
||||
self.tree.evict_tenant_data(max_size)
|
||||
|
||||
sizes_after = self.tree.get_used_size_per_tenant()
|
||||
print(sizes_after)
|
||||
# ensure size_after is below max_size
|
||||
for tenant, size in sizes_after.items():
|
||||
self.assertLessEqual(size, max_size[tenant])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
292
scripts/playground/router/tree.py
Normal file
292
scripts/playground/router/tree.py
Normal file
@@ -0,0 +1,292 @@
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
class Node:
|
||||
def __init__(self):
|
||||
self.children: Dict[str, Node] = dict()
|
||||
# We choose to use text because most of the use cases are text-to-text,
|
||||
# so we can save the tokenizing overhead.
|
||||
self.text: str = ""
|
||||
# Maps tenant_id to their last access timestamp
|
||||
self.tenant_last_access_time: Dict[str, float] = dict()
|
||||
self.parent = None
|
||||
|
||||
|
||||
def shared_prefix_length(s1, s2):
|
||||
min_length = min(len(s1), len(s2))
|
||||
for i in range(min_length):
|
||||
if s1[i] != s2[i]:
|
||||
return i
|
||||
return min_length
|
||||
|
||||
|
||||
class MultiTenantRadixTree:
|
||||
"""
|
||||
Python Reference of Rust implementation of MultiTenantRadixTree
|
||||
|
||||
MultiTenantRadixTree is the overlap of multiple radix trees by different tenant
|
||||
Each node in the tree can be owned by multiple tenants, allowing for efficient storage of common prefixes
|
||||
while maintaining tenant isolation.
|
||||
|
||||
Key concepts:
|
||||
- Tenant: An entity that owns a subset of the stored strings
|
||||
- Each node tracks which tenants have access to it via tenant_last_access_time
|
||||
- The tree structure is shared, but queries can be filtered by tenant_id
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.root = Node()
|
||||
|
||||
def insert(self, s: str, tenant_id: str) -> None:
|
||||
"""
|
||||
Insert string 's' and associate it with the given tenant_id.
|
||||
|
||||
Args:
|
||||
s: The string to insert
|
||||
tenant_id: The identifier of the tenant who owns this string
|
||||
"""
|
||||
curr = self.root
|
||||
curr_idx = 0
|
||||
curr.tenant_last_access_time[tenant_id] = time.time()
|
||||
|
||||
while curr_idx < len(s):
|
||||
matched_node = None
|
||||
if s[curr_idx] in curr.children:
|
||||
matched_node = curr.children[s[curr_idx]]
|
||||
|
||||
if matched_node is None:
|
||||
# No match => create a new node
|
||||
new_node = Node()
|
||||
new_node.text = s[curr_idx:]
|
||||
new_node.parent = curr
|
||||
|
||||
curr.children[s[curr_idx]] = new_node
|
||||
curr_idx = len(s)
|
||||
curr = new_node
|
||||
curr.tenant_last_access_time[tenant_id] = time.time()
|
||||
else:
|
||||
shared_len = shared_prefix_length(s[curr_idx:], matched_node.text)
|
||||
|
||||
# 1. If the matched text is shorter than the node text => split the node
|
||||
if shared_len < len(matched_node.text):
|
||||
# Split structure: [matched_node] => [new_node] -> [contracted_matched_node]
|
||||
|
||||
matched_text = matched_node.text[:shared_len]
|
||||
unmatched_text = matched_node.text[shared_len:]
|
||||
|
||||
new_node = Node()
|
||||
new_node.text = matched_text
|
||||
new_node.children = {unmatched_text[0]: matched_node}
|
||||
new_node.parent = curr
|
||||
new_node.parent.children[matched_text[0]] = new_node
|
||||
new_node.tenant_last_access_time = (
|
||||
matched_node.tenant_last_access_time.copy()
|
||||
)
|
||||
|
||||
# Contract matched node
|
||||
matched_node.text = unmatched_text
|
||||
matched_node.parent = new_node
|
||||
|
||||
curr_idx += shared_len
|
||||
curr = new_node
|
||||
curr.tenant_last_access_time[tenant_id] = time.time()
|
||||
# 2. If the matched text is longer or equal to the node text => walk down the node
|
||||
else:
|
||||
curr_idx += shared_len
|
||||
curr = matched_node
|
||||
curr.tenant_last_access_time[tenant_id] = time.time()
|
||||
|
||||
def prefix_match(self, s: str) -> tuple[str, int]:
|
||||
"""
|
||||
Match string 's' with multiple tenants' trees in one operation.
|
||||
|
||||
Args:
|
||||
s: The string to match
|
||||
|
||||
Returns:
|
||||
Tuple(str, int): The longest prefix of 's' that matches the tree and the first tenant_id that own the matched prefix
|
||||
"""
|
||||
curr = self.root
|
||||
curr_idx = 0
|
||||
|
||||
ret_text = ""
|
||||
ret_tenant = None
|
||||
|
||||
while curr_idx < len(s):
|
||||
matched_node = None
|
||||
if s[curr_idx] in curr.children:
|
||||
matched_node = curr.children[s[curr_idx]]
|
||||
|
||||
if matched_node is None:
|
||||
break
|
||||
|
||||
shared_len = shared_prefix_length(s[curr_idx:], matched_node.text)
|
||||
if shared_len == len(matched_node.text):
|
||||
curr_idx += shared_len
|
||||
curr = matched_node
|
||||
else:
|
||||
curr_idx += shared_len
|
||||
curr = matched_node
|
||||
break
|
||||
|
||||
selected_tenant = list(curr.tenant_last_access_time.keys())[0]
|
||||
|
||||
# traverse back to the root to update last access time for the selected tenant
|
||||
while curr != self.root:
|
||||
curr.tenant_last_access_time[selected_tenant] = time.time()
|
||||
curr = curr.parent
|
||||
|
||||
return s[:curr_idx], selected_tenant
|
||||
|
||||
def evict_tenant_data(self, max_size_per_tenant: Dict[str, int]) -> None:
|
||||
"""
|
||||
Evict data for tenants that have exceeded their storage limits.
|
||||
|
||||
Args:
|
||||
max_size_per_tenant: Dictionary mapping tenant_id to their maximum allowed storage size
|
||||
"""
|
||||
|
||||
def leaf_of(node):
|
||||
"""
|
||||
If the node is a leaf for a tenant, add tenant_id to the return list
|
||||
This will return list of tenant ids
|
||||
If not a leaf for all tenants, return []
|
||||
"""
|
||||
candidates = dict([(k, True) for k in node.tenant_last_access_time.keys()])
|
||||
|
||||
for n in node.children.values():
|
||||
for c in n.tenant_last_access_time.keys():
|
||||
candidates[c] = False
|
||||
|
||||
return [k for k, v in candidates.items() if v]
|
||||
|
||||
# maintain a heap with (time, tenant, node) as the value
|
||||
import heapq
|
||||
|
||||
# 1. traverse the tree to
|
||||
# a. add all the leaves into a heap (a node with N tenants will be added N times into the heap)
|
||||
# b. calculate the used size for each tenant
|
||||
# do a dfs with stack
|
||||
stack = [self.root]
|
||||
pq = []
|
||||
used_size_per_tenant = defaultdict(int)
|
||||
|
||||
while stack:
|
||||
curr = stack.pop()
|
||||
for t in curr.tenant_last_access_time.keys():
|
||||
used_size_per_tenant[t] += len(curr.text)
|
||||
|
||||
for c in curr.children.values():
|
||||
stack.append(c)
|
||||
|
||||
# if the node is a leaf for a tenant, add the tenant to the heap
|
||||
tenants = leaf_of(curr)
|
||||
for t in tenants:
|
||||
heapq.heappush(pq, (curr.tenant_last_access_time[t], t, curr))
|
||||
|
||||
# 2. pop the heap
|
||||
# a. if the tenant's used size is less than the limit, continue
|
||||
# b. if the tenant's used size is greater than the limit, remove the leaf and update the used size, and add its parent to the heap
|
||||
while len(pq) > 0:
|
||||
time, tenant, node = heapq.heappop(pq)
|
||||
if used_size_per_tenant[tenant] <= max_size_per_tenant[tenant]:
|
||||
continue
|
||||
|
||||
# remove the leaf
|
||||
used_size_per_tenant[tenant] -= len(node.text)
|
||||
del node.tenant_last_access_time[tenant]
|
||||
# if no children and no tenants, remove the node
|
||||
if len(node.children) == 0 and len(node.tenant_last_access_time) == 0:
|
||||
del node.parent.children[node.text[0]]
|
||||
|
||||
# add its parent to the heap
|
||||
if tenant in leaf_of(node.parent):
|
||||
heapq.heappush(
|
||||
pq,
|
||||
(node.parent.tenant_last_access_time[tenant], tenant, node.parent),
|
||||
)
|
||||
|
||||
def get_used_size_per_tenant(self) -> Dict[str, int]:
|
||||
"""
|
||||
Calculate the used storage size for each tenant.
|
||||
|
||||
Returns:
|
||||
Dict[str, int]: A dictionary mapping tenant_id to their used storage size
|
||||
"""
|
||||
used_size_per_tenant = defaultdict(int)
|
||||
|
||||
stack = [self.root]
|
||||
while stack:
|
||||
curr = stack.pop()
|
||||
for t in curr.tenant_last_access_time.keys():
|
||||
used_size_per_tenant[t] += len(curr.text)
|
||||
|
||||
for c in curr.children.values():
|
||||
stack.append(c)
|
||||
|
||||
return used_size_per_tenant
|
||||
|
||||
def remove_tenant(self, tenant_id: str) -> None:
|
||||
"""
|
||||
Remove all data associated with a specific tenant from the tree.
|
||||
This operation maintains the integrity of the shared tree structure while
|
||||
removing only the specified tenant's access information.
|
||||
|
||||
Args:
|
||||
tenant_id: The identifier of the tenant whose data should be removed
|
||||
"""
|
||||
# TODO: Implementation needed
|
||||
pass
|
||||
|
||||
def pretty_print(self) -> str:
|
||||
"""
|
||||
Returns a string representation of the tree showing the structure, tenant ownership,
|
||||
and leaf status for each node.
|
||||
|
||||
Returns:
|
||||
str: A formatted string showing the tree hierarchy with tenant information
|
||||
"""
|
||||
|
||||
def _node_to_str(node: Node, prefix: str = "", is_last: bool = True) -> str:
|
||||
# Current node representation
|
||||
node_str = prefix
|
||||
node_str += "└── " if is_last else "├── "
|
||||
|
||||
# Add node text
|
||||
node_str += f"'{node.text}' ["
|
||||
|
||||
# Add tenant information including both timestamp and leaf status
|
||||
tenant_info = []
|
||||
for tid, ts in node.tenant_last_access_time.items():
|
||||
time_str = (
|
||||
time.strftime("%H:%M:%S.", time.localtime(ts))
|
||||
+ f"{(ts % 1):0.3f}"[2:]
|
||||
)
|
||||
tenant_info.append(f"{tid} | {time_str}")
|
||||
|
||||
node_str += ", ".join(tenant_info)
|
||||
node_str += "]\n"
|
||||
|
||||
# Handle children
|
||||
children = list(node.children.items())
|
||||
for i, (char, child) in enumerate(children):
|
||||
is_last_child = i == len(children) - 1
|
||||
# Adjust prefix for children based on whether this is the last child
|
||||
new_prefix = prefix + (" " if is_last else "│ ")
|
||||
node_str += _node_to_str(child, new_prefix, is_last_child)
|
||||
|
||||
return node_str
|
||||
|
||||
if not self.root.children:
|
||||
return "Empty tree"
|
||||
|
||||
# Start with root's children since root itself is just an empty node
|
||||
result = ""
|
||||
children = list(self.root.children.items())
|
||||
for i, (char, child) in enumerate(children):
|
||||
is_last = i == len(children) - 1
|
||||
result += _node_to_str(child, "", is_last)
|
||||
|
||||
return result
|
||||
Reference in New Issue
Block a user