sglang v0.5.2 & support Qwen3-Next-80B-A3B-Instruct

This commit is contained in:
maxiao1
2025-09-13 17:00:20 +08:00
commit 118f1fc726
2037 changed files with 515371 additions and 0 deletions

View File

@@ -0,0 +1,308 @@
"""
Usage:
# single GPU
python3 bench_speculative.py --model-path meta-llama/Llama-2-7b-chat-hf --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B
# multiple GPU
python3 bench_speculative.py --model-path deepseek-ai/DeepSeek-V3 --speculative-draft-model-path lmsys/DeepSeek-V3-NextN --tp-size 8 --trust-remote-code --batch-size 1 4 8 16 32 --steps 0 1 2 --topk 0 1 2 4 --num_draft_tokens 0 2 4 8
"""
import argparse
import asyncio
import json
import os
import time
from types import SimpleNamespace
import numpy as np
import requests
from transformers import AutoTokenizer
from sglang.bench_serving import (
DatasetRow,
benchmark,
sample_mmmu_requests,
set_global_args,
)
from sglang.srt.server_args import ServerArgs
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
kill_process_tree,
popen_launch_server,
)
def node0_print(msg):
    """Print *msg* only on node rank 0 so multi-node runs log each message once.

    Relies on the module-level ``server_args`` assigned in the __main__ block.
    """
    if server_args.node_rank != 0:
        return
    print(msg)
# Fixed prompt set used for text (non-multimodal) benchmarking. The prompts are
# repeated/truncated to the requested request count before being sent, so their
# exact text (including any typos) is part of the benchmark workload — do not edit.
prompts = [
    "Human: Give me a fully functional FastAPI server. Show the full, long python code without stop.\n\nAssistant:",
    "Human: Imagine you are an experienced Ethereum developer tasked with creating a smart contract for a blockchain messenger. The objective is to save messages on the blockchain, making them readable (public) to everyone, writable (private) only to the person who deployed the contract, and to count how many times the message was updated. Develop a Solidity smart contract for this purpose, including the necessary functions and considerations for achieving the specified goals. Please provide the code and any relevant explanations to ensure a clear understanding of the implementation.\n\nAssistant:",
    "Human: Write a travel blog post to Hawaii.\n\nAssistant:",
    "Human: I want you to act as an English translator, spelling corrector and improver. I will speak to you in any language and you will detect the language, translate it and answer in the corrected and improved version of my text, in English. I want you to replace my simplified A0-level words and sentences with more beautiful and elegant, upper level English words and sentences. Keep the meaning same, but make them more literary. My first sentence is 'istanbulu cok seviyom burada olmak cok guzel'. Answer in more than 5000 words.\n\nAssistant:",
    "Human: I want you to act as a storyteller. You will come up with entertaining stories that are engaging, imaginative and captivating for the audience. It can be fairy tales, educational stories or any other type of stories which has the potential to capture people's attention and imagination. Depending on the target audience, you may choose specific themes or topics for your storytelling session e.g., if its children then you can talk about animals; If its adults then history-based tales might engage them better etc. Answer in more than 5000 words. My first request is 'I need an interesting story on perseverance.'\n\nAssistant:",
    "Human: Solve x^2 = -1. Think step-by-step. Give me a long detailed explanation. \n\nAssistant:",
    "Human: Tell me about the president of the USA in wikipedia style.\n\nAssistant:",
    "Human: Hello? Who are you? Write code, math, and poem to explanin yourself.\n\nAssistant:",
]
class FakeTokenizer:
    """Minimal tokenizer stand-in for code paths that require a tokenizer
    object but never need real token ids."""

    def encode(self, text: str, add_special_tokens: bool = False):
        # Token ids are irrelevant here; always report an empty encoding.
        return []
def send_one_batch(base_url, num_prompts, batch_size, tokenizer, is_multimodal):
    """Run one benchmark round against a running server and return speed metrics.

    Args:
        base_url: Root URL of the running sglang server.
        num_prompts: Number of requests to send.
        batch_size: Max concurrency; also the key used to look up recorded
            per-step latencies in the server info.
        tokenizer: Tokenizer passed through to the bench_serving helpers.
        is_multimodal: Use the MMMU chat dataset instead of the fixed text prompts.

    Returns:
        Tuple of (acc_length, step_time, speed, avg_output_token), where
        ``speed`` is accepted tokens per second derived from the
        20th-percentile step time.
    """
    # format: (prompt, input_len, output len). We set input_len as a dummy value 0.
    if is_multimodal:
        input_requests = sample_mmmu_requests(
            num_prompts,
            tokenizer,
            512,
            apply_chat_template=False,
        )
        backend = "sglang-oai-chat"
        api_url = f"{base_url}/v1/chat/completions"
    else:
        # Repeat the fixed prompt list (rounding up) and truncate to num_prompts.
        padded_prompts = (prompts * ((num_prompts + len(prompts) - 1) // len(prompts)))[
            :num_prompts
        ]
        # Fix: the original annotation used `List`, which was never imported;
        # the builtin generic (PEP 585) needs no import.
        input_requests: list[DatasetRow] = [
            DatasetRow(p, 0, 512) for p in padded_prompts
        ]
        backend = "sglang"
        api_url = f"{base_url}/generate"

    # We need to set some dummy values in order to call `benchmark` below.
    args = SimpleNamespace(
        disable_ignore_eos=False,
        disable_stream=False,
        return_logprob=False,
        backend=backend,
        dataset_name="custom",
        num_prompts=None,
        sharegpt_output_len=None,
        random_input_len=None,
        random_output_len=None,
        random_range_ratio=None,
        output_file=None,
        warmup_requests=1,
        output_details=False,
    )
    set_global_args(args)

    # Run benchmark
    results = asyncio.run(
        benchmark(
            backend=backend,
            api_url=api_url,
            base_url=base_url,
            model_id="default",
            tokenizer=tokenizer,
            input_requests=input_requests,
            request_rate=float("inf"),
            max_concurrency=batch_size,
            disable_tqdm=False,
            lora_names=None,
            extra_request_body={},
            profile=None,
        )
    )
    assert results["completed"] == len(input_requests)

    # accept_length may be None for non-speculative runs; treat that as 1 token/step.
    acc_length = results["accept_length"] or 1.0
    avg_output_token = results["total_output_tokens"] / results["completed"]
    server_info = requests.get(base_url + "/get_server_info").json()
    # We use 20% percentile instead of median on purpose
    step_time = np.percentile(
        server_info["internal_states"][0]["step_time_dict"][str(batch_size)], 20
    )
    speed = 1 / step_time * acc_length
    return (
        round(acc_length, 3),
        round(step_time, 5),
        round(speed, 3),
        avg_output_token,
    )
def main(args, server_args):
    """Sweep (batch_size, steps, topk, num_draft_tokens) configurations.

    For each valid configuration, launch a fresh server, benchmark it with
    send_one_batch(), and append one JSON line of results to args.output.
    steps == topk == num_draft_tokens == 0 denotes the non-speculative baseline.
    """
    base_url = "http://127.0.0.1:20000"

    # Enumerate all requested combinations, dropping invalid ones.
    configs = []
    for batch_size in args.batch_size:
        for steps in args.steps:
            for topk in args.topk:
                for num_draft_tokens in args.num_draft_tokens:
                    # A draft tree of `steps` levels of width `topk` plus the root
                    # holds at most steps * topk + 1 tokens; skip impossible combos.
                    if steps * topk + 1 < num_draft_tokens:
                        continue
                    if (steps == 0 or topk == 0 or num_draft_tokens == 0) and (
                        steps + topk + num_draft_tokens != 0
                    ):
                        # steps == 0 and topk == 0 and num_draft_tokens == 0 is a special case for non-speculative decoding.
                        continue
                    configs.append((batch_size, steps, topk, num_draft_tokens))

    # --start/--end let a long sweep be resumed or partitioned.
    for i in range(args.start, args.end or len(configs)):
        batch_size, steps, topk, num_draft_tokens = configs[i]
        node0_print(
            f"Start {i=}: {batch_size=}, {steps=}, {topk=}, {num_draft_tokens=}"
        )

        # Create an LLM.
        if steps == 0:
            # Baseline run: no speculative decoding flags at all.
            other_args = []
        else:
            other_args = [
                "--speculative-num-steps",
                steps,
                "--speculative-eagle-topk",
                topk,
                "--speculative-num-draft-tokens",
                num_draft_tokens,
            ]
            if server_args.speculative_draft_model_path is not None:
                other_args.extend(
                    [
                        "--speculative-draft-model-path",
                        server_args.speculative_draft_model_path,
                        "--speculative-algorithm",
                        server_args.speculative_algorithm,
                    ]
                )
        other_args.extend(
            [
                "--cuda-graph-max-bs",
                batch_size,
                "--mem-fraction-static",
                server_args.mem_fraction_static,
                "--tp-size",
                server_args.tp_size,
                "--max-running-requests",
                batch_size,
            ]
        )
        if server_args.trust_remote_code:
            other_args.extend(
                [
                    "--trust-remote-code",
                ]
            )
        if server_args.attention_backend:
            other_args.extend(
                [
                    "--attention-backend",
                    server_args.attention_backend,
                ]
            )
        if server_args.quantization:
            other_args.extend(
                [
                    "--quantization",
                    server_args.quantization,
                ]
            )
        # SGLANG_RECORD_STEP_TIME makes the server record per-step latencies,
        # which send_one_batch() later reads back via /get_server_info.
        process = popen_launch_server(
            args.model_path,
            base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=other_args,
            env={
                "SGLANG_RECORD_STEP_TIME": "1",
                **os.environ,
            },
        )
        tokenizer = AutoTokenizer.from_pretrained(
            args.model_path, trust_remote_code=server_args.trust_remote_code
        )
        try:
            # Warmup
            send_one_batch(
                base_url, batch_size, batch_size, tokenizer, args.is_multimodal
            )
            # Benchmark
            acc_length, step_time, speed, completion_tokens = send_one_batch(
                base_url,
                max(args.num_prompts, batch_size),
                batch_size,
                tokenizer,
                args.is_multimodal,
            )
        finally:
            # Always tear the server down, even if the benchmark failed.
            kill_process_tree(process.pid)

        node0_print(
            f"Finish {i=}: {batch_size=}, {steps=}, {topk=}, {num_draft_tokens=}, {speed=:.2f} token/s, step_time={step_time * 1000:.2f} ms"
        )
        # Append one JSON line per configuration so partial sweeps survive crashes.
        record = {
            "batch_size": batch_size,
            "steps": steps,
            "topk": topk,
            "num_draft_tokens": num_draft_tokens,
            "acc_length": acc_length,
            "step_time": step_time,
            "speed": speed,
            "completion_tokens": completion_tokens,
        }
        with open(args.output, "a") as fout:
            fout.write(json.dumps(record) + "\n")

        # Wait for the server to shutdown
        time.sleep(5)
# The __main__ condition is necessary here because we use "spawn" to create subprocesses
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
# The __main__ condition is necessary here because we use "spawn" to create subprocesses
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Reuse sglang's server CLI flags (model path, tp size, ...) alongside the
    # sweep-specific flags below.
    ServerArgs.add_cli_args(parser)
    parser.add_argument(
        "--batch-size",
        type=int,
        nargs="+",
        default=(1, 2, 4, 8, 16),
    )
    parser.add_argument(
        "--steps",
        type=int,
        nargs="+",
        default=(0, 1, 3, 5, 7),  # use (0, 1, 2, 3, 4) for large batch size
    )
    parser.add_argument(
        "--topk",
        type=int,
        nargs="+",
        default=(0, 1, 2, 4, 8),
    )
    parser.add_argument(
        "--num_draft_tokens",
        type=int,
        nargs="+",
        default=(0, 2, 4, 8, 16, 32),  # use (0, 2, 4, 8) for large batch size
    )
    parser.add_argument("--num-prompts", type=int, default=16)
    # --start/--end select a slice of the generated configuration list.
    parser.add_argument("--start", type=int, default=0)
    parser.add_argument("--end", type=int)
    parser.add_argument("--output", type=str, default="output.jsonl")
    parser.add_argument("--is-multimodal", action="store_true", default=False)
    args = parser.parse_args()
    server_args: ServerArgs = ServerArgs.from_cli_args(args)
    main(args, server_args)

View File

@@ -0,0 +1,22 @@
"""Query a local sglang server and print prompt/output token logprob counts."""
# Fix: imports were placed after executable code; PEP 8 puts them at the top.
import json

import requests

prompt = "The capital of france is "

response = requests.post(
    "http://0.0.0.0:8000/generate",
    json={
        "text": prompt,
        "sampling_params": {"temperature": 0},
        "return_logprob": True,
        "return_input_logprob": True,
        "logprob_start_len": 0,
    },
)
j = response.json()

input_logprobs = j["meta_info"]["input_token_logprobs"]
output_logprobs = j["meta_info"]["output_token_logprobs"]
print(len(input_logprobs), len(output_logprobs))

View File

@@ -0,0 +1,34 @@
import json
import requests
port = 8000

# Schema for constrained decoding: an object with a word-only "name" and an
# integer "population", both required.
json_schema = json.dumps(
    {
        "type": "object",
        "properties": {
            "name": {"type": "string", "pattern": "^[\\w]+$"},
            "population": {"type": "integer"},
        },
        "required": ["name", "population"],
    }
)

# JSON
payload = {
    "text": "Here is the information of the capital of France in the JSON format.\n",
    "sampling_params": {
        "temperature": 0,
        "max_new_tokens": 64,
        "json_schema": json_schema,
    },
}
response = requests.post(f"http://localhost:{port}/generate", json=payload)
print(response.json())

# python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --trust-remote-code --disaggregation-mode prefill --tp 2 --disaggregation-ib-device mlx5_roce0,mlx5_roce1 --speculative-algorithm EAGLE --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 3 --speculative-eagle-topk 4 --speculative-num-draft-tokens 16 --cuda-graph-max-bs 8 --host 127.0.0.1 --port 8100

View File

@@ -0,0 +1,29 @@
import json
import requests
# Long free-form prompt used to exercise generation on a local sglang server.
prompt = """
According to CNBC's Faber, the investors present on the call interpreted this statement as an indication of an upcoming funding round. While speculative, Faber believes the funding round could be as large as $25 billion, and bestow a valuation of between $150 billion and $200 billion on xAI.
For the benefit of those who might not be aware, xAI recently acquired the social media platform X in an all-stock deal that valued the former at $80 billion and the latter at $33 billion, inclusive of $12 billion in liabilities. This meant that the deal bestowed a gross valuation of $45 billion on X before factoring in its debt load of $12 billion.
Bear in mind that Elon Musk took X (then called Twitter) private back in 2022 in a $44 billion deal. Since then, Musk has managed to stem X's cash bleed, with the company reportedly generating $1.2 billion in adjusted EBITDA in 2024.
According to the investors present on the call, xAI is currently generating around $1 billion in annual revenue. This contrasts sharply with the erstwhile muted expectations of many investors, who did not expect the startup to generate any material revenue this year.
Elsewhere, Faber also alludes to the fact that xAI is already working on its next big training supercluster, officially dubbed the Colossus 2, which is expected to eventually house as many as 1 million NVIDIA GPUs at a cost of between $35 billion and $40 billion.
Even though xAI's Grok LLM is already largely comparable with OpenAI's cutting-edge models, the Colossus 2 would significantly up the ante, and could feasibly challenge OpenAI's apex position in the AI sphere.
Give your honest take on the above text:
"""

# Greedy generation; print only the completion text.
payload = {"text": prompt, "sampling_params": {"temperature": 0}}
response = requests.post("http://0.0.0.0:8000/generate", json=payload)
response_json = response.json()
print(response_json["text"])

View File

@@ -0,0 +1,241 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Launch A Server\n",
"\n",
"Launch the server with a reasoning model (Qwen3-4B) and a reasoning parser."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sglang import separate_reasoning, assistant_begin, assistant_end\n",
"from sglang import assistant, function, gen, system, user\n",
"from sglang import image\n",
"from sglang import RuntimeEndpoint, set_default_backend\n",
"from sglang.srt.utils import load_image\n",
"from sglang.test.test_utils import is_in_ci\n",
"from sglang.utils import print_highlight, terminate_process, wait_for_server\n",
"\n",
"\n",
"if is_in_ci():\n",
" from patch import launch_server_cmd\n",
"else:\n",
" from sglang.utils import launch_server_cmd\n",
"\n",
"\n",
"server_process, port = launch_server_cmd(\n",
" \"python3 -m sglang.launch_server --model-path Qwen/Qwen3-4B --reasoning-parser qwen3 --host 0.0.0.0\"\n",
")\n",
"\n",
"wait_for_server(f\"http://localhost:{port}\")\n",
"print(f\"Server started on http://localhost:{port}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Set the default backend. Note: you can set chat_template_name in RuntimeEndpoint. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"set_default_backend(\n",
" RuntimeEndpoint(f\"http://localhost:{port}\", chat_template_name=\"qwen\")\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's start with a basic question-answering task. And see how the reasoning content is generated."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@function\n",
"def basic_qa(s, question):\n",
" s += system(f\"You are a helpful assistant than can answer questions.\")\n",
" s += user(question)\n",
" s += assistant_begin()\n",
" s += gen(\"answer\", max_tokens=512)\n",
" s += assistant_end()\n",
"\n",
"\n",
"state = basic_qa(\"List 3 countries and their capitals.\")\n",
"print_highlight(state[\"answer\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `separate_reasoning`, you can move the reasoning content to `{param_name}_reasoning_content` in the state."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@function\n",
"def basic_qa_separate_reasoning(s, question):\n",
" s += system(f\"You are a helpful assistant than can answer questions.\")\n",
" s += user(question)\n",
" s += assistant_begin()\n",
" s += separate_reasoning(gen(\"answer\", max_tokens=512), model_type=\"qwen3\")\n",
" s += assistant_end()\n",
"\n",
"\n",
"reasoning_state = basic_qa_separate_reasoning(\"List 3 countries and their capitals.\")\n",
"print_highlight(reasoning_state.stream_executor.variable_event.keys())\n",
"print_highlight(\n",
" f\"\\nSeparated Reasoning Content:\\n{reasoning_state['answer_reasoning_content']}\"\n",
")\n",
"\n",
"print_highlight(f\"\\n\\nContent:\\n{reasoning_state['answer']}\")\n",
"print_highlight(f\"\\n\\nMessages:\\n{reasoning_state.messages()[-1]}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`separate_reasoning` can also be used in multi-turn conversations."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@function\n",
"def multi_turn_qa(s):\n",
" s += system(f\"You are a helpful assistant than can answer questions.\")\n",
" s += user(\"Please give me a list of 3 countries and their capitals.\")\n",
" s += assistant(\n",
" separate_reasoning(gen(\"first_answer\", max_tokens=512), model_type=\"qwen3\")\n",
" )\n",
" s += user(\"Please give me another list of 3 countries and their capitals.\")\n",
" s += assistant(\n",
" separate_reasoning(gen(\"second_answer\", max_tokens=512), model_type=\"qwen3\")\n",
" )\n",
" return s\n",
"\n",
"\n",
"reasoning_state = multi_turn_qa()\n",
"print_highlight(f\"\\n\\nfirst_answer:\\n{reasoning_state['first_answer']}\")\n",
"print_highlight(\n",
" f\"\\n\\nfirst_answer_reasoning_content:\\n{reasoning_state['first_answer_reasoning_content']}\"\n",
")\n",
"print_highlight(f\"\\n\\nsecond_answer:\\n{reasoning_state['second_answer']}\")\n",
"print_highlight(\n",
" f\"\\n\\nsecond_answer_reasoning_content:\\n{reasoning_state['second_answer_reasoning_content']}\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using No thinking as Qwen 3's advanced feature \n",
"\n",
"sglang separate_reasoning is particularly useful when combined with Qwen 3's advanced feature.\n",
"\n",
"[Qwen 3's advanced usages](https://qwenlm.github.io/blog/qwen3/#advanced-usages)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"reasoning_state = basic_qa_separate_reasoning(\n",
" \"List 3 countries and their capitals. /no_think\"\n",
")\n",
"print_highlight(f\"Reasoning Content:\\n{reasoning_state['answer_reasoning_content']}\")\n",
"print_highlight(f\"Content:\\n{reasoning_state['answer']}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`separate_reasoning` can also be used in regular expression generation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@function\n",
"def regular_expression_gen(s):\n",
" s += user(\n",
" \"What is the IP address of the Google DNS servers? just provide the answer\"\n",
" )\n",
" s += assistant(\n",
" separate_reasoning(\n",
" gen(\n",
" \"answer\",\n",
" temperature=0,\n",
" regex=r\"((25[0-5]|2[0-4]\\d|[01]?\\d\\d?).){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\",\n",
" max_tokens=512,\n",
" ),\n",
" model_type=\"qwen3\",\n",
" ),\n",
" )\n",
"\n",
"\n",
"reasoning_state = regular_expression_gen()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print_highlight(f\"Answer:\\n{reasoning_state['answer']}\")\n",
"print_highlight(\n",
" f\"\\n\\nReasoning Content:\\n{reasoning_state['answer_reasoning_content']}\"\n",
")"
]
}
],
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -0,0 +1,7 @@
# Launch HuggingFace text-generation-inference (TGI) v1.1.0 for comparison runs.
# Assuming the model is downloaded at /home/ubuntu/model_weights/Llama-2-7b-chat-hf
docker run --name tgi --rm -ti --gpus all --network host \
  -v /home/ubuntu/model_weights/Llama-2-7b-chat-hf:/Llama-2-7b-chat-hf \
  ghcr.io/huggingface/text-generation-inference:1.1.0 \
  --model-id /Llama-2-7b-chat-hf --num-shard 1 --trust-remote-code \
  --max-input-length 2048 --max-total-tokens 4096 \
  --port 24000

View File

@@ -0,0 +1,14 @@
import argparse
import code
from sglang.srt.hf_transformers_utils import get_tokenizer
if __name__ == "__main__":
    # Load a tokenizer by name and drop into an interactive Python shell with
    # it bound to `t` (and `args` in scope) for manual experimentation.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--name", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct"
    )
    args = parser.parse_args()
    t = get_tokenizer(args.name)
    code.interact(local=locals())

View File

@@ -0,0 +1,36 @@
from urllib.request import urlopen
from openai import OpenAI
# Long-context smoke test: stream chat completions for prompts of increasing
# length (64k .. 1m) against a local OpenAI-compatible sglang server.
test_cases = {
    "64k": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/64k.txt",
    "200k": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/200k.txt",
    "600k": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt",
    "1m": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt",
}

client = OpenAI(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")

for name, url in test_cases.items():
    print(f"\n==== Running test case: {name} ====")
    # Download the prompt text; skip this case (don't abort the run) on failure.
    try:
        with urlopen(url, timeout=10) as response:
            prompt = response.read().decode("utf-8")
    except Exception as e:
        print(f"Failed to load prompt for {name}: {e}")
        continue
    try:
        response = client.chat.completions.create(
            model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
            messages=[{"role": "user", "content": prompt}],
            stream=True,
            max_tokens=128,
            temperature=0,
        )
        # Print streamed tokens as they arrive.
        for chunk in response:
            if chunk.choices and chunk.choices[0].delta.content is not None:
                print(chunk.choices[0].delta.content, end="", flush=True)
    except Exception as e:
        print(f"\nError during completion for {name}: {e}")

View File

@@ -0,0 +1,77 @@
import glob
import json
import os
import re
import sys
from tqdm import tqdm
sys.path.append("../../")
from fix_corrupted_json import clean_json_file
# Aggregate per-step durations and batch sizes from Chrome-trace JSON dumps
# under {dirpath}/trace*, writing one "analyzed_log_<trace>" report per trace
# directory plus a total-time summary.
dirpath = "/Users/ying"
output_file_prefix = "analyzed_log"

# NOTE: `time` here is a dict and shadows the stdlib module name on purpose
# (the stdlib `time` module is not used in this script).
time = {}
tot_time = {}
size = {}

# Remove reports from previous runs so the "a" (append) opens start fresh.
os.system(f"rm {output_file_prefix}*")

for dirname in glob.glob(os.path.join(dirpath, "trace*")):
    print(dirname)
    trace_name = dirname.split("/")[-1]
    time[trace_name] = {}
    size[trace_name] = {}
    total_time = 0
    for filename in tqdm(glob.glob(os.path.join(dirname, "*.json"))):
        # e.g. "xxx_prefill_step_3.json" -> step_name "prefill_step_3"
        step_name = filename.split("/")[-1].split(".")[0]
        step_name = "_".join(step_name.split("_")[1:])
        if "prefill" not in filename and "decode" not in filename:
            continue
        match = re.search(r"(prefill|decode)_step_(\d+)\.json", filename)
        if match:
            phase = match.group(1)
            step = match.group(2)
        else:
            # Fix: the original f-string had no placeholder, losing the filename.
            raise Exception(f"Cannot parse {filename}")
        # Traces are sometimes truncated; repair and retry once on decode errors.
        try:
            with open(filename, "r") as f:
                trace = json.load(f)
        except json.JSONDecodeError:
            clean_json_file(filename, filename)
            with open(filename, "r") as f:
                trace = json.load(f)
        # Fix: initialize dur so a trace without a matching profile event
        # cannot raise NameError (or reuse the previous file's value) below.
        dur = 0.0
        for event in trace["traceEvents"]:
            name = event["name"]
            if name in ["profile_prefill_step", "profile_decode_step"]:
                dur = event["dur"] / 1e3  # microseconds -> milliseconds
                time[trace_name][step_name] = dur
                break
        total_time += dur
        step = int(step_name.split("_")[-1])
        with open(os.path.join(dirname, f"size_{step}.json"), "r") as f:
            size_info = json.load(f)
        size[trace_name][step_name] = size_info["size"]
    tot_time[trace_name] = total_time
    # Sort steps numerically by their trailing step index.
    time[trace_name] = dict(
        sorted(time[trace_name].items(), key=lambda x: int(x[0].split("_")[-1]))
    )
    size[trace_name] = dict(
        sorted(size[trace_name].items(), key=lambda x: int(x[0].split("_")[-1]))
    )
    with open(f"{output_file_prefix}_{trace_name}", "a") as f:
        for k, v in time[trace_name].items():
            size_v = size[trace_name][k]
            print(f"{k:>15}{v:10.2f}\t{size_v}")
            f.write(f"{k:>15}{v:10.2f}\t{size_v}\n")

with open(f"{output_file_prefix}_total_time", "w") as f:
    print(tot_time)
    json.dump(tot_time, f)

View File

@@ -0,0 +1,62 @@
import torch
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer
# Compare base-model vs. LoRA-adapted generations on one fixed prompt.
MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
# ADAPTER = "winddude/wizardLM-LlaMA-LoRA-7B"
ADAPTER = "/home/ying/test_lora"
HF_TOKEN = "..."  # placeholder; only needed for gated models (see use_auth_token below)

prompt = """
### Instruction:
Write a poem about the transformers Python library.
Mention the word "large language models" in that poem.
### Response:
The Transformers are large language models,
They're used to make predictions on text.
"""

tokenizer = LlamaTokenizer.from_pretrained(MODEL)
base_model = LlamaForCausalLM.from_pretrained(
    MODEL,
    device_map="auto",
    # load_in_8bit=True,
    torch_dtype=torch.float16,
    # use_auth_token=HF_TOKEN,
).cuda()

# base model generate
with torch.no_grad():
    output_tensors = base_model.generate(
        input_ids=tokenizer(prompt, return_tensors="pt").input_ids.cuda(),
        max_new_tokens=32,
        do_sample=False,  # greedy, so base vs. peft outputs are comparable
    )[0]
output = tokenizer.decode(output_tensors, skip_special_tokens=True)
print("======= base output ========")
print(output)

# peft model generate
# NOTE(review): PeftModel.from_pretrained wraps base_model in place, so using
# base_model again after this point would include the adapter — confirm intended.
model = PeftModel.from_pretrained(
    base_model,
    ADAPTER,
    torch_dtype=torch.float16,
    is_trainable=False,
)
with torch.no_grad():
    output_tensors = model.generate(
        input_ids=tokenizer(prompt, return_tensors="pt").input_ids.cuda(),
        max_new_tokens=32,
        do_sample=False,
    )[0]
output = tokenizer.decode(output_tensors, skip_special_tokens=True)
print("======= peft output ========")
print(output)

View File

@@ -0,0 +1,30 @@
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
# Generate with a LoRA adapter through vLLM on one fixed prompt.
MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
ADAPTER = "/home/ying/test_lora"

prompt = """
### Instruction:
Write a poem about the transformers Python library.
Mention the word "large language models" in that poem.
### Response:
The Transformers are large language models,
They're used to make predictions on text.
"""

# Greedy decoding, short completion.
sampling_params = SamplingParams(
    temperature=0,
    max_tokens=32,
)
llm = LLM(model=MODEL, enable_lora=True)

lora_request = LoRARequest("test_lora", 1, ADAPTER)
outputs = llm.generate([prompt], sampling_params, lora_request=lora_request)

first = outputs[0]
print(first.prompt)
print(first.outputs[0].text)

View File

@@ -0,0 +1,197 @@
"""
Usage: python3 scripts/playground/reference_hf.py --model-path MODEL_PATH --model-type {text,vlm} [--max-new-tokens NUM] [--dtype DTYPE]
--model-path MODEL_PATH: Path to model (default: TinyLlama/TinyLlama-1.1B-Chat-v0.4)
--model-type {text,vlm}: Model type, text or vlm (default: text)
--max-new-tokens NUM: Max new tokens to generate (default: 16)
--dtype DTYPE: Data type for computation (default: float16)
Note: '--model' is deprecated; use '--model-path'. Runs normal_text() for text, vlm_text_with_image() for vlm.
Reference output:
========== Prompt 0 ==========
prefill logits (final) tensor([-8.3125, -7.1172, 3.3398, ..., -4.9531, -4.1328, -3.4141],
device='cuda:0')
<s> The capital of France is Paris.
The capital of the United States is Washington, D.C.
========== Prompt 1 ==========
prefill logits (final) tensor([-8.9062, -9.0156, 4.1484, ..., -4.9922, -4.4961, -4.0742],
device='cuda:0')
<s> The capital of the United Kindom is London.
The capital of the United Kingdom is London.
The capital of
========== Prompt 2 ==========
prefill logits (final) tensor([-9.6328, -9.0547, 4.0234, ..., -5.3047, -4.7148, -4.4609],
device='cuda:0')
<s> Today is a sunny day and I like to go for a walk in the park.
I'm going to the
"""
import argparse
import requests
import torch
from PIL import Image
from transformers import (
AutoModelForCausalLM,
AutoModelForImageTextToText,
AutoProcessor,
)
from sglang.srt.hf_transformers_utils import get_tokenizer
@torch.no_grad()
def vlm_text_with_image(args):
    """Describe a fixed test image with an image-text-to-text model, printing
    the final prefill logits and the decoded output for reference checking."""
    # Load the processor and model for ImageTextToText tasks
    processor = AutoProcessor.from_pretrained(args.model_path, trust_remote_code=True)
    model = AutoModelForImageTextToText.from_pretrained(
        args.model_path,
        torch_dtype=args.dtype,
        low_cpu_mem_usage=True,
        device_map="auto",
        trust_remote_code=True,
    )
    torch.cuda.set_device(0)

    # List of image URLs to process
    image_urls = [
        "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
    ]

    # Conversation template for the processor
    conversation = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                },
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ]

    max_new_tokens = args.max_new_tokens
    for i, url in enumerate(image_urls):
        # Load the image from the URL
        image = Image.open(requests.get(url, stream=True).raw)
        # Apply the chat template to the text prompt
        # Notice that not all processors support chat templates.
        # LLaVA and QWen are two processors that support chat templates.
        if not hasattr(processor, "apply_chat_template"):
            raise ValueError("The processor does not support chat templates.")
        text_prompt = processor.apply_chat_template(
            conversation, add_generation_prompt=True
        )
        # Prepare inputs for the model
        inputs = processor(text=[text_prompt], images=[image], return_tensors="pt").to(
            "cuda:0"
        )
        # Generate output from the model (greedy decoding)
        output_ids = model.generate(
            **inputs, do_sample=False, max_new_tokens=max_new_tokens
        )
        output_str = processor.decode(output_ids[0])
        # Get the logits from the model's forward pass (last prompt position)
        outputs = model.forward(**inputs)
        logits = outputs.logits[0, -1, :]
        print(f"\n========== Image {i} ==========")
        print("prefill logits (final)", logits)
        # TODO(gaocegege): The output contains numerous <|image_pad|> tokens,
        # making it cluttered and difficult to read.
        # These tokens should be removed or cleaned up for better readability.
        print(output_str)
@torch.no_grad()
def normal_text(args):
    """Generate greedily from a few fixed text prompts, printing the final
    prefill logits and decoded outputs (reference values are in the module
    docstring)."""
    t = get_tokenizer(args.model_path, trust_remote_code=True)
    m = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=args.dtype,
        low_cpu_mem_usage=True,
        device_map="auto",
        trust_remote_code=True,
    )
    # NOTE: "United Kindom" typo is kept on purpose — the recorded reference
    # output in the module docstring was produced with this exact prompt.
    prompts = [
        "The capital of France is",
        "The capital of the United Kindom is",
        "Today is a sunny day and I like",
    ]
    max_new_tokens = args.max_new_tokens
    torch.cuda.set_device(0)
    for i, p in enumerate(prompts):
        # Prompts may be raw strings or pre-tokenized id lists.
        if isinstance(p, str):
            input_ids = t.encode(p, return_tensors="pt").to("cuda:0")
        else:
            input_ids = torch.tensor([p], device="cuda:0")
        output_ids = m.generate(
            input_ids, do_sample=False, max_new_tokens=max_new_tokens
        )
        output_str = t.decode(output_ids[0])
        # Logits at the last prompt position from a fresh forward pass.
        prefill_logits = m.forward(input_ids).logits[0][-1]
        print(f"\n========== Prompt {i} ==========")
        print("prefill logits (final)", prefill_logits)
        print(output_str)
@torch.no_grad()
def synthetic_tokens(args):
    """Greedy-decode from a synthetic token-id prompt, printing the prefill
    logits once and then the logits for each decode step.

    Uses a fixed prompt of 256 consecutive token ids and generates 8 tokens
    by repeatedly re-running the full forward pass (no KV cache).
    """
    m = AutoModelForCausalLM.from_pretrained(
        # Fix: honor --dtype like normal_text()/vlm_text_with_image() instead
        # of hardcoding float16 (the default is still "float16", so behavior
        # is unchanged for default invocations).
        args.model_path,
        torch_dtype=args.dtype,
        low_cpu_mem_usage=True,
    )
    m.cuda()
    print(m)

    input_len = 256
    output_len = 8
    prompts = [list(range(5, 5 + input_len))]
    for p in prompts:
        # Fix: copy the prompt so appending generated ids below does not
        # mutate the prompt list itself.
        input_ids = list(p)
        for i in range(output_len + 1):
            prefill_logits = m.forward(torch.tensor([input_ids], device="cuda")).logits[
                0
            ][-1]
            if i == 0:
                print("prefill logits", prefill_logits)
            else:
                print("decode", i - 1, prefill_logits)
            # Greedy step: append the argmax token and re-run the model.
            input_ids.append(torch.argmax(prefill_logits).item())
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path",
        type=str,
        default="TinyLlama/TinyLlama-1.1B-Chat-v0.4",
    )
    parser.add_argument("--max-new-tokens", type=int, default=16)
    parser.add_argument("--dtype", type=str, default="float16")
    # "text" (default) runs normal_text(); "vlm" runs vlm_text_with_image().
    parser.add_argument("--model-type", type=str, default="text")
    args = parser.parse_args()
    if args.model_type == "vlm":
        vlm_text_with_image(args)
    else:
        normal_text(args)

View File

@@ -0,0 +1,151 @@
"""
Usage:
# replay from a folder
python3 replay_request_dump.py --file-number 100 --parallel 512 --input-folder /data/lianmin/sglang_request_dump/grok-mini-0220-engine-5756f8f94-28bm6/
# replay from a single file
python3 replay_request_dump.py --parallel 512 --input-file /data/sglang_crash_dump/memx-cti-34-sr1.xpop.twttr.net/crash_dump_2025-06-04_20-13-18.pkl
"""
import argparse
import glob
import json
import pickle
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict
from datetime import datetime
import requests
from sglang.bench_serving import set_ulimit
from sglang.utils import get_exception_traceback
def read_records(files):
    """Load and concatenate request records from pickle dump files.

    Each file holds either a bare list of records or a dict with a "requests"
    key (crash-dump format). Files are opened with a context manager so the
    handles are closed promptly (the original leaked them via
    ``pickle.load(open(...))``).

    Args:
        files: Iterable of pickle file paths.

    Returns:
        A flat list of all records, in file order.
    """
    records = []
    for path in files:
        with open(path, "rb") as fin:
            tmp = pickle.load(fin)
        if isinstance(tmp, dict) and "requests" in tmp:
            records.extend(tmp["requests"])
        else:
            records.extend(tmp)
    return records
def run_one_request_internal(record):
    """Replay one recorded request against the target server.

    `record` is (req, output, replay_init_time, start_time, end_time, idx) as
    rewritten by main(). Sleeps until the request's original arrival offset
    (scaled by --speed), re-sends it, and prints token counts so they can be
    compared against the recording.
    """
    (req, output, replay_init_time, start_time, end_time, idx) = record

    # Reproduce the original arrival time relative to replay start, sped up
    # by --speed (> 1 replays faster than real time).
    time.sleep(max(0, (start_time - (time.time() - replay_init_time)) / args.speed))

    # Completion token count from the original run, if the dump recorded it.
    if "completion_tokens" in output.get("meta_info", {}):
        recorded_completion_tokens = output["meta_info"]["completion_tokens"]
    else:
        recorded_completion_tokens = ""

    json_data = asdict(req)
    stream = json_data["stream"]
    if args.ignore_eos:
        json_data["sampling_params"]["ignore_eos"] = True
        if recorded_completion_tokens:
            # Cap generation at the original length so ignore_eos cannot run away.
            json_data["sampling_params"]["max_new_tokens"] = recorded_completion_tokens

    response = requests.post(
        f"http://{args.host}:{args.port}/generate",
        json=json_data,
        stream=stream,
    )
    if stream:
        # Consume the SSE stream; `ret` ends up holding the last data payload.
        for chunk in response.iter_lines(decode_unicode=False):
            chunk = chunk.decode("utf-8")
            if chunk and chunk.startswith("data:"):
                if chunk == "data: [DONE]":
                    break
                ret = json.loads(chunk[5:].strip("\n"))
    else:
        ret = response.json()

    prompt_tokens = ret["meta_info"]["prompt_tokens"]
    completion_tokens = ret["meta_info"]["completion_tokens"]
    print(
        f"{idx=}, {start_time=:.2f}, {prompt_tokens=}, "
        f"{completion_tokens=}, {recorded_completion_tokens=}"
    )
def run_one_request(record):
    """Replay a single record, logging (never propagating) any failure.

    Workers in the thread pool must not raise, otherwise the remaining
    mapped work would be silently abandoned, so every exception is caught
    and printed with its traceback here.
    """
    try:
        run_one_request_internal(record)
    except Exception:
        traceback = get_exception_traceback()
        print(f"Hit an exception: {traceback}")
def main(records):
    """Replay all records against the server, preserving relative timing.

    Rewrites each 4-tuple record into the 6-tuple form expected by
    ``run_one_request`` (offsets are made relative to the first record's
    receive time), then fans the work out over a thread pool.
    """
    if not records:
        return
    base_time = records[0][-2]
    base_time_str = datetime.fromtimestamp(base_time).strftime("%y-%m-%d %H:%M:%S")
    print(f"{base_time_str=}")
    replay_init_time = time.time()
    # Rebuild the list in place so callers observe the transformed records.
    records[:] = [
        (req, output, replay_init_time, start - base_time, end, idx)
        for idx, (req, output, start, end) in enumerate(records)
    ]
    with ThreadPoolExecutor(args.parallel) as pool:
        pool.map(run_one_request, records)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=30000)
    parser.add_argument(
        "--input-folder", type=str, default=None, help="Folder containing pickle files"
    )
    parser.add_argument(
        "--input-file", type=str, default=None, help="Single pickle file to process"
    )
    parser.add_argument("--file-number", type=int, default=1)
    parser.add_argument("--req-number", type=int, default=1000000)
    parser.add_argument("--req-start", type=int, default=0)
    parser.add_argument("--parallel", type=int, default=512)
    parser.add_argument("--idx", type=int, default=None)
    parser.add_argument("--ignore-eos", action="store_true")
    parser.add_argument("--speed", type=float, default=1)
    args = parser.parse_args()

    # Raise the open-file limit: each in-flight request holds a socket.
    set_ulimit()

    files = []
    if args.input_file:
        files = [args.input_file]
        if args.file_number > 1:
            print("Warning: --file-number is ignored when --input-file is provided.")
    elif args.input_folder:
        files = glob.glob(f"{args.input_folder}/*.pkl")
        files = files[: args.file_number]
    else:
        print("Error: Either --input-folder or --input-file must be provided.")
        exit(1)
    print(f"{files=}")

    records = read_records(files)
    # Sort by the receive time, before filtering
    records.sort(key=lambda x: x[-2])
    records = records[args.req_start :]
    # BUG FIX: the original `if args.idx:` treated `--idx 0` as "not
    # provided" (0 is falsy); compare against None so index 0 is replayable.
    if args.idx is not None:
        records = [records[args.idx]]
        print(f"testing {args.idx=}")
        print(f"{records[0]}")
    print(f"{len(records)=}")

    main(records)

View File

@@ -0,0 +1,207 @@
import random
import string
import time
import unittest
from typing import Dict, List, Tuple
from tree import MultiTenantRadixTree
class TestMultiTenantRadixTree(unittest.TestCase):
    """Unit tests for the pure-Python MultiTenantRadixTree reference."""

    def setUp(self):
        # A fresh, empty tree for every test case.
        self.tree = MultiTenantRadixTree()

    def test_insert_exact_match(self):
        """Test 1: Basic insert and exact match operations"""
        # Insert a single string for one tenant
        self.tree.insert("hello", "tenant1")
        matched, tenant = self.tree.prefix_match("hello")
        self.assertEqual(matched, "hello")
        self.assertEqual(tenant, "tenant1")
        # Insert same string for different tenant; prefix_match may report
        # either owner, so only membership is asserted.
        self.tree.insert("hello", "tenant2")
        matched, tenant = self.tree.prefix_match("hello")
        self.assertIn(tenant, ["tenant1", "tenant2"])
        # Insert different string for same tenant
        self.tree.insert("world", "tenant1")
        matched, tenant = self.tree.prefix_match("world")
        self.assertEqual(matched, "world")
        self.assertEqual(tenant, "tenant1")
        print(self.tree.pretty_print())

    def test_insert_partial_match(self):
        """Test 2: Insert with partial matching scenarios"""
        # Test partial matches with common prefixes ("hel" is shared)
        self.tree.insert("hello", "tenant1")
        print(self.tree.pretty_print())
        self.tree.insert("help", "tenant2")
        print(self.tree.pretty_print())
        # Match exact strings
        matched, tenant = self.tree.prefix_match("hello")
        self.assertEqual(matched, "hello")
        self.assertEqual(tenant, "tenant1")
        matched, tenant = self.tree.prefix_match("help")
        self.assertEqual(matched, "help")
        self.assertEqual(tenant, "tenant2")
        # Match partial string — the shared "hel" node belongs to both tenants
        matched, tenant = self.tree.prefix_match("hel")
        self.assertEqual(matched, "hel")
        self.assertIn(tenant, ["tenant1", "tenant2"])
        # Match longer string — only the stored prefix is matched
        matched, tenant = self.tree.prefix_match("hello_world")
        self.assertEqual(matched, "hello")
        self.assertEqual(tenant, "tenant1")

    def test_insert_edge_cases(self):
        """Test 3: Edge cases for insert and match operations"""
        # Empty string
        self.tree.insert("", "tenant1")
        matched, tenant = self.tree.prefix_match("")
        self.assertEqual(matched, "")
        self.assertEqual(tenant, "tenant1")
        # Single character
        self.tree.insert("a", "tenant1")
        matched, tenant = self.tree.prefix_match("a")
        self.assertEqual(matched, "a")
        self.assertEqual(tenant, "tenant1")
        # Very long string
        long_str = "a" * 1000
        self.tree.insert(long_str, "tenant1")
        matched, tenant = self.tree.prefix_match(long_str)
        self.assertEqual(matched, long_str)
        self.assertEqual(tenant, "tenant1")
        # Unicode characters
        self.tree.insert("你好", "tenant1")
        matched, tenant = self.tree.prefix_match("你好")
        self.assertEqual(matched, "你好")
        self.assertEqual(tenant, "tenant1")

    def test_simple_eviction(self):
        """Test 4: Simple eviction scenarios
        Tenant1: limit 10 chars
        Tenant2: limit 5 chars
        Should demonstrate:
        1. Basic eviction when size limit exceeded
        2. Proper eviction based on last access time
        3. Verification that shared nodes remain intact for other tenants
        """
        # Set up size limits
        max_size = {"tenant1": 10, "tenant2": 5}
        # Insert strings for both tenants
        self.tree.insert("hello", "tenant1")  # size 5
        self.tree.insert("hello", "tenant2")  # size 5
        self.tree.insert("world", "tenant2")  # size 5, total for tenant2 = 10
        # Verify initial sizes
        sizes_before = self.tree.get_used_size_per_tenant()
        self.assertEqual(sizes_before["tenant1"], 5)  # "hello" = 5
        self.assertEqual(sizes_before["tenant2"], 10)  # "hello" + "world" = 10
        # Evict - should remove "hello" from tenant2 as it's the oldest
        self.tree.evict_tenant_data(max_size)
        # Verify sizes after eviction
        sizes_after = self.tree.get_used_size_per_tenant()
        self.assertEqual(sizes_after["tenant1"], 5)  # Should be unchanged
        self.assertEqual(sizes_after["tenant2"], 5)  # Only "world" remains
        # Verify "world" remains for tenant2 (was accessed more recently)
        matched, tenant = self.tree.prefix_match("world")
        self.assertEqual(matched, "world")
        self.assertEqual(tenant, "tenant2")

    def test_medium_eviction(self):
        """Test 5: Medium complexity eviction scenarios with shared prefixes
        Tenant1: limit 10 chars
        Tenant2: limit 6 chars (forces one string to be evicted)
        Tree structure after inserts:
        └── 'h' [t1, t2]
            ├── 'i' [t1, t2]      # Oldest for t2
            └── 'e' [t1, t2]
                ├── 'llo' [t1, t2]
                └── 'y' [t2]      # Newest for t2
        Size calculations:
        tenant1: "h"(1) + "i"(1) + "e"(1) + "llo"(3) = 6 chars
        tenant2: "h"(1) + "i"(1) + "e"(1) + "llo"(3) + "y"(1) = 7 chars
        After eviction (tenant2 exceeds limit by 1 char):
        "hi" should be removed from tenant2 as it's the oldest access
        """
        max_size = {
            "tenant1": 10,
            "tenant2": 6,
        }  # tenant2 will need to evict one string
        # Create a tree with overlapping prefixes
        self.tree.insert("hi", "tenant1")
        self.tree.insert("hi", "tenant2")  # OLDEST for t2
        self.tree.insert("hello", "tenant1")
        self.tree.insert("hello", "tenant2")
        self.tree.insert("hey", "tenant2")  # NEWEST for t2
        # Verify initial sizes
        sizes_before = self.tree.get_used_size_per_tenant()
        self.assertEqual(sizes_before["tenant1"], 6)  # h(1) + i(1) + e(1) + llo(3) = 6
        self.assertEqual(
            sizes_before["tenant2"], 7
        )  # h(1) + i(1) + e(1) + llo(3) + y(1) = 7
        print("\nTree before eviction:")
        print(self.tree.pretty_print())
        # Evict - should remove "hi" from tenant2 as it's the oldest
        self.tree.evict_tenant_data(max_size)
        print("\nTree after eviction:")
        print(self.tree.pretty_print())
        # Verify sizes after eviction
        sizes_after = self.tree.get_used_size_per_tenant()
        self.assertEqual(sizes_after["tenant1"], 6)  # Should be unchanged
        self.assertEqual(sizes_after["tenant2"], 6)  # h(1) + e(1) + llo(3) + y(1) = 6

    def test_advanced_eviction(self):
        """Test 6: Stress eviction with 4 tenants under a strict 100-char budget."""
        ...
        # Create 4 tenants
        # Each tenants keeps adding strings with shared prefixes to thousands usage
        # Set a strict limit for each tenant to only 100
        # At the end, check whether all of the tenant is under 100 after eviction
        max_size = {"tenant1": 100, "tenant2": 100, "tenant3": 100, "tenant4": 100}
        prefixes = ["aqwefcisdf", "iajsdfkmade", "kjnzxcvewqe", "iejksduqasd"]
        for i in range(100):
            for j, prefix in enumerate(prefixes):
                random_suffix = "".join(random.choices(string.ascii_letters, k=10))
                self.tree.insert(prefix + random_suffix, f"tenant{j+1}")
        sizes_before = self.tree.get_used_size_per_tenant()
        print(sizes_before)
        self.tree.evict_tenant_data(max_size)
        sizes_after = self.tree.get_used_size_per_tenant()
        print(sizes_after)
        # ensure size_after is below max_size
        for tenant, size in sizes_after.items():
            self.assertLessEqual(size, max_size[tenant])
# Allow running this test module directly with `python <file>.py`.
if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,292 @@
import time
from collections import defaultdict
from typing import Dict, List
class Node:
    """A single radix-tree node shared across tenants.

    Edge labels are stored as plain text (rather than token ids) because
    most use cases are text-to-text, which avoids tokenizing overhead.
    """

    def __init__(self):
        # Child nodes keyed by the first character of their edge label.
        self.children: Dict[str, "Node"] = {}
        # Edge label on the link leading into this node.
        self.text: str = ""
        # tenant_id -> timestamp of that tenant's most recent access.
        self.tenant_last_access_time: Dict[str, float] = {}
        self.parent = None
def shared_prefix_length(s1, s2):
    """Return the length of the longest common prefix of s1 and s2."""
    count = 0
    # zip stops at the shorter string, so running off either end is safe.
    for a, b in zip(s1, s2):
        if a != b:
            break
        count += 1
    return count
class MultiTenantRadixTree:
    """
    Python Reference of Rust implementation of MultiTenantRadixTree
    MultiTenantRadixTree is the overlap of multiple radix trees by different tenant
    Each node in the tree can be owned by multiple tenants, allowing for efficient storage of common prefixes
    while maintaining tenant isolation.
    Key concepts:
    - Tenant: An entity that owns a subset of the stored strings
    - Each node tracks which tenants have access to it via tenant_last_access_time
    - The tree structure is shared, but queries can be filtered by tenant_id
    """

    def __init__(self):
        # Sentinel root: its edge text stays empty, but it does accumulate
        # per-tenant access timestamps once insert() runs.
        self.root = Node()
    def insert(self, s: str, tenant_id: str) -> None:
        """
        Insert string 's' and associate it with the given tenant_id.

        Walks down from the root consuming characters of 's'; creates a new
        leaf when no child matches, and splits an existing edge when 's'
        diverges in the middle of it. Every node touched on the way down has
        its access timestamp for `tenant_id` refreshed.

        Args:
            s: The string to insert
            tenant_id: The identifier of the tenant who owns this string
        """
        curr = self.root
        curr_idx = 0
        # The root always records the tenant's access, even for s == "".
        curr.tenant_last_access_time[tenant_id] = time.time()
        while curr_idx < len(s):
            matched_node = None
            # Children are keyed by the first character of their edge label,
            # so at most one child can match the remaining suffix.
            if s[curr_idx] in curr.children:
                matched_node = curr.children[s[curr_idx]]
            if matched_node is None:
                # No match => create a new node holding the whole remainder
                new_node = Node()
                new_node.text = s[curr_idx:]
                new_node.parent = curr
                curr.children[s[curr_idx]] = new_node
                curr_idx = len(s)
                curr = new_node
                curr.tenant_last_access_time[tenant_id] = time.time()
            else:
                shared_len = shared_prefix_length(s[curr_idx:], matched_node.text)
                # 1. If the matched text is shorter than the node text => split the node
                if shared_len < len(matched_node.text):
                    # Split structure: [matched_node] => [new_node] -> [contracted_matched_node]
                    matched_text = matched_node.text[:shared_len]
                    unmatched_text = matched_node.text[shared_len:]
                    new_node = Node()
                    new_node.text = matched_text
                    new_node.children = {unmatched_text[0]: matched_node}
                    new_node.parent = curr
                    new_node.parent.children[matched_text[0]] = new_node
                    # The split point inherits the tenants of the node it
                    # replaces; copy so later edits stay independent.
                    new_node.tenant_last_access_time = (
                        matched_node.tenant_last_access_time.copy()
                    )
                    # Contract matched node
                    matched_node.text = unmatched_text
                    matched_node.parent = new_node
                    curr_idx += shared_len
                    curr = new_node
                    curr.tenant_last_access_time[tenant_id] = time.time()
                # 2. If the matched text is longer or equal to the node text => walk down the node
                else:
                    curr_idx += shared_len
                    curr = matched_node
                    curr.tenant_last_access_time[tenant_id] = time.time()
    def prefix_match(self, s: str) -> tuple[str, int]:
        """
        Match string 's' with multiple tenants' trees in one operation.

        Args:
            s: The string to match

        Returns:
            Tuple(str, int): The longest prefix of 's' that matches the tree and the first tenant_id that own the matched prefix
        """
        curr = self.root
        curr_idx = 0
        ret_text = ""
        ret_tenant = None
        while curr_idx < len(s):
            matched_node = None
            if s[curr_idx] in curr.children:
                matched_node = curr.children[s[curr_idx]]
            if matched_node is None:
                break
            shared_len = shared_prefix_length(s[curr_idx:], matched_node.text)
            if shared_len == len(matched_node.text):
                # Full edge consumed: descend and keep matching.
                curr_idx += shared_len
                curr = matched_node
            else:
                # Partial match inside this edge: the walk ends here.
                curr_idx += shared_len
                curr = matched_node
                break
        # Pick an arbitrary owner of the final node (dict insertion order).
        # NOTE(review): this raises IndexError on a tree that has never seen
        # an insert (root has no tenants) — confirm callers insert first.
        selected_tenant = list(curr.tenant_last_access_time.keys())[0]
        # traverse back to the root to update last access time for the selected tenant
        while curr != self.root:
            curr.tenant_last_access_time[selected_tenant] = time.time()
            curr = curr.parent
        return s[:curr_idx], selected_tenant
    def evict_tenant_data(self, max_size_per_tenant: Dict[str, int]) -> None:
        """
        Evict data for tenants that have exceeded their storage limits.

        Eviction is LRU per tenant: the least recently accessed leaf for an
        over-budget tenant is removed first, repeatedly, until that tenant
        fits inside its limit. Nodes left with no tenants and no children
        are detached from the tree.

        Args:
            max_size_per_tenant: Dictionary mapping tenant_id to their maximum allowed storage size
        """

        def leaf_of(node):
            """
            If the node is a leaf for a tenant, add tenant_id to the return list
            This will return list of tenant ids
            If not a leaf for all tenants, return []
            """
            # A tenant is a "leaf owner" of this node iff no child of the
            # node also belongs to that tenant.
            candidates = dict([(k, True) for k in node.tenant_last_access_time.keys()])
            for n in node.children.values():
                for c in n.tenant_last_access_time.keys():
                    candidates[c] = False
            return [k for k, v in candidates.items() if v]

        # maintain a heap with (time, tenant, node) as the value
        import heapq

        # 1. traverse the tree to
        #    a. add all the leaves into a heap (a node with N tenants will be added N times into the heap)
        #    b. calculate the used size for each tenant
        # do a dfs with stack
        stack = [self.root]
        pq = []
        used_size_per_tenant = defaultdict(int)
        while stack:
            curr = stack.pop()
            for t in curr.tenant_last_access_time.keys():
                used_size_per_tenant[t] += len(curr.text)
            for c in curr.children.values():
                stack.append(c)
            # if the node is a leaf for a tenant, add the tenant to the heap
            tenants = leaf_of(curr)
            for t in tenants:
                heapq.heappush(pq, (curr.tenant_last_access_time[t], t, curr))
        # 2. pop the heap
        #    a. if the tenant's used size is less than the limit, continue
        #    b. if the tenant's used size is greater than the limit, remove the leaf and update the used size, and add its parent to the heap
        while len(pq) > 0:
            # NOTE(review): `time` here shadows the imported `time` module
            # within this function; harmless today because the module is not
            # used after this point, but a rename would be safer.
            time, tenant, node = heapq.heappop(pq)
            if used_size_per_tenant[tenant] <= max_size_per_tenant[tenant]:
                continue
            # remove the leaf
            used_size_per_tenant[tenant] -= len(node.text)
            del node.tenant_last_access_time[tenant]
            # if no children and no tenants, remove the node
            if len(node.children) == 0 and len(node.tenant_last_access_time) == 0:
                del node.parent.children[node.text[0]]
            # add its parent to the heap
            if tenant in leaf_of(node.parent):
                heapq.heappush(
                    pq,
                    (node.parent.tenant_last_access_time[tenant], tenant, node.parent),
                )
def get_used_size_per_tenant(self) -> Dict[str, int]:
"""
Calculate the used storage size for each tenant.
Returns:
Dict[str, int]: A dictionary mapping tenant_id to their used storage size
"""
used_size_per_tenant = defaultdict(int)
stack = [self.root]
while stack:
curr = stack.pop()
for t in curr.tenant_last_access_time.keys():
used_size_per_tenant[t] += len(curr.text)
for c in curr.children.values():
stack.append(c)
return used_size_per_tenant
def remove_tenant(self, tenant_id: str) -> None:
"""
Remove all data associated with a specific tenant from the tree.
This operation maintains the integrity of the shared tree structure while
removing only the specified tenant's access information.
Args:
tenant_id: The identifier of the tenant whose data should be removed
"""
# TODO: Implementation needed
pass
    def pretty_print(self) -> str:
        """
        Returns a string representation of the tree showing the structure, tenant ownership,
        and leaf status for each node.

        Returns:
            str: A formatted string showing the tree hierarchy with tenant information
        """

        def _node_to_str(node: Node, prefix: str = "", is_last: bool = True) -> str:
            # Current node representation: box-drawing connector + edge text.
            node_str = prefix
            node_str += "└── " if is_last else "├── "
            # Add node text
            node_str += f"'{node.text}' ["
            # Add tenant information including both timestamp and leaf status
            tenant_info = []
            for tid, ts in node.tenant_last_access_time.items():
                # Render as HH:MM:SS.mmm; the [2:] slice drops the "0."
                # prefix of the fractional-seconds f-string.
                time_str = (
                    time.strftime("%H:%M:%S.", time.localtime(ts))
                    + f"{(ts % 1):0.3f}"[2:]
                )
                tenant_info.append(f"{tid} | {time_str}")
            node_str += ", ".join(tenant_info)
            node_str += "]\n"
            # Handle children recursively, marking the last child specially.
            children = list(node.children.items())
            for i, (char, child) in enumerate(children):
                is_last_child = i == len(children) - 1
                # Adjust prefix for children based on whether this is the last child
                new_prefix = prefix + (" " if is_last else "")
                node_str += _node_to_str(child, new_prefix, is_last_child)
            return node_str

        if not self.root.children:
            return "Empty tree"
        # Start with root's children since root itself is just an empty node
        result = ""
        children = list(self.root.children.items())
        for i, (char, child) in enumerate(children):
            is_last = i == len(children) - 1
            result += _node_to_str(child, "", is_last)
        return result