[router] add base_gpu_id server args & merged radix tree python reference (#2115)
This commit is contained in:
@@ -156,7 +156,7 @@ class DataParallelController:
|
||||
)
|
||||
for tp_rank in tp_rank_range:
|
||||
reader, writer = mp.Pipe(duplex=False)
|
||||
gpu_id = base_gpu_id + tp_rank % tp_size_per_node
|
||||
gpu_id = server_args.base_gpu_id + base_gpu_id + tp_rank % tp_size_per_node
|
||||
proc = mp.Process(
|
||||
target=run_scheduler_process,
|
||||
args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
|
||||
|
||||
@@ -1380,6 +1380,10 @@ def run_scheduler_process(
|
||||
dp_rank: Optional[int],
|
||||
pipe_writer,
|
||||
):
|
||||
# [For Router] if env var "DP_RANK" exist, set dp_rank to the value of the env var
|
||||
if dp_rank is None:
|
||||
dp_rank = int(os.getenv("DP_RANK", -1))
|
||||
|
||||
if dp_rank is None:
|
||||
configure_logger(server_args, prefix=f" TP{tp_rank}")
|
||||
else:
|
||||
|
||||
@@ -418,7 +418,7 @@ def launch_engine(
|
||||
)
|
||||
for tp_rank in tp_rank_range:
|
||||
reader, writer = mp.Pipe(duplex=False)
|
||||
gpu_id = tp_rank % tp_size_per_node
|
||||
gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node
|
||||
proc = mp.Process(
|
||||
target=run_scheduler_process,
|
||||
args=(server_args, port_args, gpu_id, tp_rank, None, writer),
|
||||
|
||||
@@ -72,6 +72,7 @@ class ServerArgs:
|
||||
constrained_json_whitespace_pattern: Optional[str] = None
|
||||
watchdog_timeout: float = 300
|
||||
download_dir: Optional[str] = None
|
||||
base_gpu_id: int = 0
|
||||
|
||||
# Logging
|
||||
log_level: str = "info"
|
||||
@@ -412,6 +413,12 @@ class ServerArgs:
|
||||
default=ServerArgs.download_dir,
|
||||
help="Model download directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base-gpu-id",
|
||||
type=int,
|
||||
default=ServerArgs.base_gpu_id,
|
||||
help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
|
||||
)
|
||||
|
||||
# Logging
|
||||
parser.add_argument(
|
||||
@@ -736,6 +743,7 @@ class ServerArgs:
|
||||
and (self.lora_paths is None or self.disable_cuda_graph)
|
||||
and (self.lora_paths is None or self.disable_radix_cache)
|
||||
), "compatibility of lora and cuda graph and radix attention is in progress"
|
||||
assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
|
||||
|
||||
if isinstance(self.lora_paths, list):
|
||||
lora_paths = self.lora_paths
|
||||
|
||||
207
scripts/playground/router/test_tree.py
Normal file
207
scripts/playground/router/test_tree.py
Normal file
@@ -0,0 +1,207 @@
|
||||
import random
|
||||
import string
|
||||
import time
|
||||
import unittest
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from tree import MultiTenantRadixTree
|
||||
|
||||
|
||||
class TestMultiTenantRadixTree(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.tree = MultiTenantRadixTree()
|
||||
|
||||
def test_insert_exact_match(self):
|
||||
"""Test 1: Basic insert and exact match operations"""
|
||||
# Insert a single string for one tenant
|
||||
self.tree.insert("hello", "tenant1")
|
||||
matched, tenant = self.tree.prefix_match("hello")
|
||||
self.assertEqual(matched, "hello")
|
||||
self.assertEqual(tenant, "tenant1")
|
||||
|
||||
# Insert same string for different tenant
|
||||
self.tree.insert("hello", "tenant2")
|
||||
matched, tenant = self.tree.prefix_match("hello")
|
||||
self.assertIn(tenant, ["tenant1", "tenant2"])
|
||||
|
||||
# Insert different string for same tenant
|
||||
self.tree.insert("world", "tenant1")
|
||||
matched, tenant = self.tree.prefix_match("world")
|
||||
self.assertEqual(matched, "world")
|
||||
self.assertEqual(tenant, "tenant1")
|
||||
|
||||
print(self.tree.pretty_print())
|
||||
|
||||
def test_insert_partial_match(self):
|
||||
"""Test 2: Insert with partial matching scenarios"""
|
||||
# Test partial matches with common prefixes
|
||||
self.tree.insert("hello", "tenant1")
|
||||
print(self.tree.pretty_print())
|
||||
self.tree.insert("help", "tenant2")
|
||||
print(self.tree.pretty_print())
|
||||
|
||||
# Match exact strings
|
||||
matched, tenant = self.tree.prefix_match("hello")
|
||||
self.assertEqual(matched, "hello")
|
||||
self.assertEqual(tenant, "tenant1")
|
||||
|
||||
matched, tenant = self.tree.prefix_match("help")
|
||||
self.assertEqual(matched, "help")
|
||||
self.assertEqual(tenant, "tenant2")
|
||||
|
||||
# Match partial string
|
||||
matched, tenant = self.tree.prefix_match("hel")
|
||||
self.assertEqual(matched, "hel")
|
||||
self.assertIn(tenant, ["tenant1", "tenant2"])
|
||||
|
||||
# Match longer string
|
||||
matched, tenant = self.tree.prefix_match("hello_world")
|
||||
self.assertEqual(matched, "hello")
|
||||
self.assertEqual(tenant, "tenant1")
|
||||
|
||||
def test_insert_edge_cases(self):
|
||||
"""Test 3: Edge cases for insert and match operations"""
|
||||
# Empty string
|
||||
self.tree.insert("", "tenant1")
|
||||
matched, tenant = self.tree.prefix_match("")
|
||||
self.assertEqual(matched, "")
|
||||
self.assertEqual(tenant, "tenant1")
|
||||
|
||||
# Single character
|
||||
self.tree.insert("a", "tenant1")
|
||||
matched, tenant = self.tree.prefix_match("a")
|
||||
self.assertEqual(matched, "a")
|
||||
self.assertEqual(tenant, "tenant1")
|
||||
|
||||
# Very long string
|
||||
long_str = "a" * 1000
|
||||
self.tree.insert(long_str, "tenant1")
|
||||
matched, tenant = self.tree.prefix_match(long_str)
|
||||
self.assertEqual(matched, long_str)
|
||||
self.assertEqual(tenant, "tenant1")
|
||||
|
||||
# Unicode characters
|
||||
self.tree.insert("你好", "tenant1")
|
||||
matched, tenant = self.tree.prefix_match("你好")
|
||||
self.assertEqual(matched, "你好")
|
||||
self.assertEqual(tenant, "tenant1")
|
||||
|
||||
def test_simple_eviction(self):
|
||||
"""Test 4: Simple eviction scenarios
|
||||
Tenant1: limit 10 chars
|
||||
Tenant2: limit 5 chars
|
||||
|
||||
Should demonstrate:
|
||||
1. Basic eviction when size limit exceeded
|
||||
2. Proper eviction based on last access time
|
||||
3. Verification that shared nodes remain intact for other tenants
|
||||
"""
|
||||
# Set up size limits
|
||||
max_size = {"tenant1": 10, "tenant2": 5}
|
||||
|
||||
# Insert strings for both tenants
|
||||
self.tree.insert("hello", "tenant1") # size 5
|
||||
self.tree.insert("hello", "tenant2") # size 5
|
||||
self.tree.insert("world", "tenant2") # size 5, total for tenant2 = 10
|
||||
|
||||
# Verify initial sizes
|
||||
sizes_before = self.tree.get_used_size_per_tenant()
|
||||
self.assertEqual(sizes_before["tenant1"], 5) # "hello" = 5
|
||||
self.assertEqual(sizes_before["tenant2"], 10) # "hello" + "world" = 10
|
||||
|
||||
# Evict - should remove "hello" from tenant2 as it's the oldest
|
||||
self.tree.evict_tenant_data(max_size)
|
||||
|
||||
# Verify sizes after eviction
|
||||
sizes_after = self.tree.get_used_size_per_tenant()
|
||||
self.assertEqual(sizes_after["tenant1"], 5) # Should be unchanged
|
||||
self.assertEqual(sizes_after["tenant2"], 5) # Only "world" remains
|
||||
|
||||
# Verify "world" remains for tenant2 (was accessed more recently)
|
||||
matched, tenant = self.tree.prefix_match("world")
|
||||
self.assertEqual(matched, "world")
|
||||
self.assertEqual(tenant, "tenant2")
|
||||
|
||||
def test_medium_eviction(self):
|
||||
"""Test 5: Medium complexity eviction scenarios with shared prefixes
|
||||
Tenant1: limit 10 chars
|
||||
Tenant2: limit 7 chars (forces one string to be evicted)
|
||||
|
||||
Tree structure after inserts:
|
||||
└── 'h' [t1, t2]
|
||||
├── 'i' [t1, t2] # Oldest for t2
|
||||
└── 'e' [t1, t2]
|
||||
├── 'llo' [t1, t2]
|
||||
└── 'y' [t2] # Newest for t2
|
||||
|
||||
Size calculations:
|
||||
tenant1: "h"(1) + "i"(1) + "e"(1) + "llo"(3) = 6 chars
|
||||
tenant2: "h"(1) + "i"(1) + "e"(1) + "llo"(3) + "y"(1) = 7 chars
|
||||
|
||||
After eviction (tenant2 exceeds limit by 1 char):
|
||||
"hi" should be removed from tenant2 as it's the oldest access
|
||||
"""
|
||||
max_size = {
|
||||
"tenant1": 10,
|
||||
"tenant2": 6,
|
||||
} # tenant2 will need to evict one string
|
||||
|
||||
# Create a tree with overlapping prefixes
|
||||
self.tree.insert("hi", "tenant1")
|
||||
self.tree.insert("hi", "tenant2") # OLDEST for t2
|
||||
|
||||
self.tree.insert("hello", "tenant1")
|
||||
self.tree.insert("hello", "tenant2")
|
||||
|
||||
self.tree.insert("hey", "tenant2") # NEWEST for t2
|
||||
|
||||
# Verify initial sizes
|
||||
sizes_before = self.tree.get_used_size_per_tenant()
|
||||
self.assertEqual(sizes_before["tenant1"], 6) # h(1) + i(1) + e(1) + llo(3) = 6
|
||||
self.assertEqual(
|
||||
sizes_before["tenant2"], 7
|
||||
) # h(1) + i(1) + e(1) + llo(3) + y(1) = 7
|
||||
|
||||
print("\nTree before eviction:")
|
||||
print(self.tree.pretty_print())
|
||||
|
||||
# Evict - should remove "hi" from tenant2 as it's the oldest
|
||||
self.tree.evict_tenant_data(max_size)
|
||||
|
||||
print("\nTree after eviction:")
|
||||
print(self.tree.pretty_print())
|
||||
|
||||
# Verify sizes after eviction
|
||||
sizes_after = self.tree.get_used_size_per_tenant()
|
||||
self.assertEqual(sizes_after["tenant1"], 6) # Should be unchanged
|
||||
self.assertEqual(sizes_after["tenant2"], 6) # h(1) + e(1) + llo(3) + y(1) = 6
|
||||
|
||||
def test_advanced_eviction(self):
|
||||
...
|
||||
# Create 4 tenants
|
||||
# Each tenants keeps adding strings with shared prefixes to thousands usage
|
||||
# Set a strict limit for each tenant to only 100
|
||||
# At the end, check whether all of the tenant is under 100 after eviction
|
||||
|
||||
max_size = {"tenant1": 100, "tenant2": 100, "tenant3": 100, "tenant4": 100}
|
||||
|
||||
prefixes = ["aqwefcisdf", "iajsdfkmade", "kjnzxcvewqe", "iejksduqasd"]
|
||||
for i in range(100):
|
||||
for j, prefix in enumerate(prefixes):
|
||||
random_suffix = "".join(random.choices(string.ascii_letters, k=10))
|
||||
self.tree.insert(prefix + random_suffix, f"tenant{j+1}")
|
||||
|
||||
sizes_before = self.tree.get_used_size_per_tenant()
|
||||
print(sizes_before)
|
||||
|
||||
self.tree.evict_tenant_data(max_size)
|
||||
|
||||
sizes_after = self.tree.get_used_size_per_tenant()
|
||||
print(sizes_after)
|
||||
# ensure size_after is below max_size
|
||||
for tenant, size in sizes_after.items():
|
||||
self.assertLessEqual(size, max_size[tenant])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
292
scripts/playground/router/tree.py
Normal file
292
scripts/playground/router/tree.py
Normal file
@@ -0,0 +1,292 @@
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
class Node:
|
||||
def __init__(self):
|
||||
self.children: Dict[str, Node] = dict()
|
||||
# We choose to use text because most of the use cases are text-to-text,
|
||||
# so we can save the tokenizing overhead.
|
||||
self.text: str = ""
|
||||
# Maps tenant_id to their last access timestamp
|
||||
self.tenant_last_access_time: Dict[str, float] = dict()
|
||||
self.parent = None
|
||||
|
||||
|
||||
def shared_prefix_length(s1, s2):
|
||||
min_length = min(len(s1), len(s2))
|
||||
for i in range(min_length):
|
||||
if s1[i] != s2[i]:
|
||||
return i
|
||||
return min_length
|
||||
|
||||
|
||||
class MultiTenantRadixTree:
|
||||
"""
|
||||
Python Reference of Rust implementation of MultiTenantRadixTree
|
||||
|
||||
MultiTenantRadixTree is the overlap of multiple radix trees by different tenant
|
||||
Each node in the tree can be owned by multiple tenants, allowing for efficient storage of common prefixes
|
||||
while maintaining tenant isolation.
|
||||
|
||||
Key concepts:
|
||||
- Tenant: An entity that owns a subset of the stored strings
|
||||
- Each node tracks which tenants have access to it via tenant_last_access_time
|
||||
- The tree structure is shared, but queries can be filtered by tenant_id
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.root = Node()
|
||||
|
||||
def insert(self, s: str, tenant_id: str) -> None:
|
||||
"""
|
||||
Insert string 's' and associate it with the given tenant_id.
|
||||
|
||||
Args:
|
||||
s: The string to insert
|
||||
tenant_id: The identifier of the tenant who owns this string
|
||||
"""
|
||||
curr = self.root
|
||||
curr_idx = 0
|
||||
curr.tenant_last_access_time[tenant_id] = time.time()
|
||||
|
||||
while curr_idx < len(s):
|
||||
matched_node = None
|
||||
if s[curr_idx] in curr.children:
|
||||
matched_node = curr.children[s[curr_idx]]
|
||||
|
||||
if matched_node is None:
|
||||
# No match => create a new node
|
||||
new_node = Node()
|
||||
new_node.text = s[curr_idx:]
|
||||
new_node.parent = curr
|
||||
|
||||
curr.children[s[curr_idx]] = new_node
|
||||
curr_idx = len(s)
|
||||
curr = new_node
|
||||
curr.tenant_last_access_time[tenant_id] = time.time()
|
||||
else:
|
||||
shared_len = shared_prefix_length(s[curr_idx:], matched_node.text)
|
||||
|
||||
# 1. If the matched text is shorter than the node text => split the node
|
||||
if shared_len < len(matched_node.text):
|
||||
# Split structure: [matched_node] => [new_node] -> [contracted_matched_node]
|
||||
|
||||
matched_text = matched_node.text[:shared_len]
|
||||
unmatched_text = matched_node.text[shared_len:]
|
||||
|
||||
new_node = Node()
|
||||
new_node.text = matched_text
|
||||
new_node.children = {unmatched_text[0]: matched_node}
|
||||
new_node.parent = curr
|
||||
new_node.parent.children[matched_text[0]] = new_node
|
||||
new_node.tenant_last_access_time = (
|
||||
matched_node.tenant_last_access_time.copy()
|
||||
)
|
||||
|
||||
# Contract matched node
|
||||
matched_node.text = unmatched_text
|
||||
matched_node.parent = new_node
|
||||
|
||||
curr_idx += shared_len
|
||||
curr = new_node
|
||||
curr.tenant_last_access_time[tenant_id] = time.time()
|
||||
# 2. If the matched text is longer or equal to the node text => walk down the node
|
||||
else:
|
||||
curr_idx += shared_len
|
||||
curr = matched_node
|
||||
curr.tenant_last_access_time[tenant_id] = time.time()
|
||||
|
||||
def prefix_match(self, s: str) -> tuple[str, int]:
|
||||
"""
|
||||
Match string 's' with multiple tenants' trees in one operation.
|
||||
|
||||
Args:
|
||||
s: The string to match
|
||||
|
||||
Returns:
|
||||
Tuple(str, int): The longest prefix of 's' that matches the tree and the first tenant_id that own the matched prefix
|
||||
"""
|
||||
curr = self.root
|
||||
curr_idx = 0
|
||||
|
||||
ret_text = ""
|
||||
ret_tenant = None
|
||||
|
||||
while curr_idx < len(s):
|
||||
matched_node = None
|
||||
if s[curr_idx] in curr.children:
|
||||
matched_node = curr.children[s[curr_idx]]
|
||||
|
||||
if matched_node is None:
|
||||
break
|
||||
|
||||
shared_len = shared_prefix_length(s[curr_idx:], matched_node.text)
|
||||
if shared_len == len(matched_node.text):
|
||||
curr_idx += shared_len
|
||||
curr = matched_node
|
||||
else:
|
||||
curr_idx += shared_len
|
||||
curr = matched_node
|
||||
break
|
||||
|
||||
selected_tenant = list(curr.tenant_last_access_time.keys())[0]
|
||||
|
||||
# traverse back to the root to update last access time for the selected tenant
|
||||
while curr != self.root:
|
||||
curr.tenant_last_access_time[selected_tenant] = time.time()
|
||||
curr = curr.parent
|
||||
|
||||
return s[:curr_idx], selected_tenant
|
||||
|
||||
def evict_tenant_data(self, max_size_per_tenant: Dict[str, int]) -> None:
|
||||
"""
|
||||
Evict data for tenants that have exceeded their storage limits.
|
||||
|
||||
Args:
|
||||
max_size_per_tenant: Dictionary mapping tenant_id to their maximum allowed storage size
|
||||
"""
|
||||
|
||||
def leaf_of(node):
|
||||
"""
|
||||
If the node is a leaf for a tenant, add tenant_id to the return list
|
||||
This will return list of tenant ids
|
||||
If not a leaf for all tenants, return []
|
||||
"""
|
||||
candidates = dict([(k, True) for k in node.tenant_last_access_time.keys()])
|
||||
|
||||
for n in node.children.values():
|
||||
for c in n.tenant_last_access_time.keys():
|
||||
candidates[c] = False
|
||||
|
||||
return [k for k, v in candidates.items() if v]
|
||||
|
||||
# maintain a heap with (time, tenant, node) as the value
|
||||
import heapq
|
||||
|
||||
# 1. traverse the tree to
|
||||
# a. add all the leaves into a heap (a node with N tenants will be added N times into the heap)
|
||||
# b. calculate the used size for each tenant
|
||||
# do a dfs with stack
|
||||
stack = [self.root]
|
||||
pq = []
|
||||
used_size_per_tenant = defaultdict(int)
|
||||
|
||||
while stack:
|
||||
curr = stack.pop()
|
||||
for t in curr.tenant_last_access_time.keys():
|
||||
used_size_per_tenant[t] += len(curr.text)
|
||||
|
||||
for c in curr.children.values():
|
||||
stack.append(c)
|
||||
|
||||
# if the node is a leaf for a tenant, add the tenant to the heap
|
||||
tenants = leaf_of(curr)
|
||||
for t in tenants:
|
||||
heapq.heappush(pq, (curr.tenant_last_access_time[t], t, curr))
|
||||
|
||||
# 2. pop the heap
|
||||
# a. if the tenant's used size is less than the limit, continue
|
||||
# b. if the tenant's used size is greater than the limit, remove the leaf and update the used size, and add its parent to the heap
|
||||
while len(pq) > 0:
|
||||
time, tenant, node = heapq.heappop(pq)
|
||||
if used_size_per_tenant[tenant] <= max_size_per_tenant[tenant]:
|
||||
continue
|
||||
|
||||
# remove the leaf
|
||||
used_size_per_tenant[tenant] -= len(node.text)
|
||||
del node.tenant_last_access_time[tenant]
|
||||
# if no children and no tenants, remove the node
|
||||
if len(node.children) == 0 and len(node.tenant_last_access_time) == 0:
|
||||
del node.parent.children[node.text[0]]
|
||||
|
||||
# add its parent to the heap
|
||||
if tenant in leaf_of(node.parent):
|
||||
heapq.heappush(
|
||||
pq,
|
||||
(node.parent.tenant_last_access_time[tenant], tenant, node.parent),
|
||||
)
|
||||
|
||||
def get_used_size_per_tenant(self) -> Dict[str, int]:
|
||||
"""
|
||||
Calculate the used storage size for each tenant.
|
||||
|
||||
Returns:
|
||||
Dict[str, int]: A dictionary mapping tenant_id to their used storage size
|
||||
"""
|
||||
used_size_per_tenant = defaultdict(int)
|
||||
|
||||
stack = [self.root]
|
||||
while stack:
|
||||
curr = stack.pop()
|
||||
for t in curr.tenant_last_access_time.keys():
|
||||
used_size_per_tenant[t] += len(curr.text)
|
||||
|
||||
for c in curr.children.values():
|
||||
stack.append(c)
|
||||
|
||||
return used_size_per_tenant
|
||||
|
||||
def remove_tenant(self, tenant_id: str) -> None:
|
||||
"""
|
||||
Remove all data associated with a specific tenant from the tree.
|
||||
This operation maintains the integrity of the shared tree structure while
|
||||
removing only the specified tenant's access information.
|
||||
|
||||
Args:
|
||||
tenant_id: The identifier of the tenant whose data should be removed
|
||||
"""
|
||||
# TODO: Implementation needed
|
||||
pass
|
||||
|
||||
def pretty_print(self) -> str:
|
||||
"""
|
||||
Returns a string representation of the tree showing the structure, tenant ownership,
|
||||
and leaf status for each node.
|
||||
|
||||
Returns:
|
||||
str: A formatted string showing the tree hierarchy with tenant information
|
||||
"""
|
||||
|
||||
def _node_to_str(node: Node, prefix: str = "", is_last: bool = True) -> str:
|
||||
# Current node representation
|
||||
node_str = prefix
|
||||
node_str += "└── " if is_last else "├── "
|
||||
|
||||
# Add node text
|
||||
node_str += f"'{node.text}' ["
|
||||
|
||||
# Add tenant information including both timestamp and leaf status
|
||||
tenant_info = []
|
||||
for tid, ts in node.tenant_last_access_time.items():
|
||||
time_str = (
|
||||
time.strftime("%H:%M:%S.", time.localtime(ts))
|
||||
+ f"{(ts % 1):0.3f}"[2:]
|
||||
)
|
||||
tenant_info.append(f"{tid} | {time_str}")
|
||||
|
||||
node_str += ", ".join(tenant_info)
|
||||
node_str += "]\n"
|
||||
|
||||
# Handle children
|
||||
children = list(node.children.items())
|
||||
for i, (char, child) in enumerate(children):
|
||||
is_last_child = i == len(children) - 1
|
||||
# Adjust prefix for children based on whether this is the last child
|
||||
new_prefix = prefix + (" " if is_last else "│ ")
|
||||
node_str += _node_to_str(child, new_prefix, is_last_child)
|
||||
|
||||
return node_str
|
||||
|
||||
if not self.root.children:
|
||||
return "Empty tree"
|
||||
|
||||
# Start with root's children since root itself is just an empty node
|
||||
result = ""
|
||||
children = list(self.root.children.items())
|
||||
for i, (char, child) in enumerate(children):
|
||||
is_last = i == len(children) - 1
|
||||
result += _node_to_str(child, "", is_last)
|
||||
|
||||
return result
|
||||
Reference in New Issue
Block a user