From 75e6a7cde144fdb60683fcdcdbc3f4e0a8411ef9 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Mon, 11 Aug 2025 10:14:11 -0700 Subject: [PATCH] Support radix cache for Lora feature (#7216) --- docs/advanced_features/lora.ipynb | 5 +- python/sglang/srt/managers/schedule_batch.py | 29 +- python/sglang/srt/managers/scheduler.py | 15 +- .../sglang/srt/mem_cache/lora_radix_cache.py | 421 ++++++++++++++++++ python/sglang/srt/server_args.py | 6 +- test/srt/lora/test_lora.py | 1 - test/srt/lora/test_lora_eviction.py | 1 - test/srt/lora/test_lora_qwen3.py | 1 - test/srt/lora/test_lora_radix_cache.py | 83 ++++ test/srt/lora/test_lora_update.py | 2 - test/srt/lora/utils.py | 8 +- test/srt/run_suite.py | 1 + 12 files changed, 546 insertions(+), 27 deletions(-) create mode 100644 python/sglang/srt/mem_cache/lora_radix_cache.py create mode 100644 test/srt/lora/test_lora_radix_cache.py diff --git a/docs/advanced_features/lora.ipynb b/docs/advanced_features/lora.ipynb index 1a732cecc..708508134 100644 --- a/docs/advanced_features/lora.ipynb +++ b/docs/advanced_features/lora.ipynb @@ -80,7 +80,6 @@ " --enable-lora \\\n", " --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n", " --max-loras-per-batch 1 --lora-backend triton \\\n", - " --disable-radix-cache\n", "\"\"\"\n", ")\n", "\n", @@ -140,7 +139,6 @@ " --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n", " lora1=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16 \\\n", " --max-loras-per-batch 2 --lora-backend triton \\\n", - " --disable-radix-cache\n", "\"\"\"\n", ")\n", "\n", @@ -215,7 +213,6 @@ " --enable-lora \\\n", " --cuda-graph-max-bs 2 \\\n", " --max-loras-per-batch 2 --lora-backend triton \\\n", - " --disable-radix-cache\n", " --max-lora-rank 256\n", " --lora-target-modules all\n", " \"\"\"\n", @@ -462,7 +459,7 @@ "source": [ "## Future Works\n", "\n", - "The development roadmap for LoRA-related features can be found in this [issue](https://github.com/sgl-project/sglang/issues/2929). Currently radix attention is incompatible with LoRA and must be manually disabled. Other features, including Unified Paging, Cutlass backend, and dynamic loading/unloadingm, are still under development." + "The development roadmap for LoRA-related features can be found in this [issue](https://github.com/sgl-project/sglang/issues/2929). Other features, including Embedding Layer, Unified Paging, Cutlass backend are still under development." ] } ], diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index e6b8d42ba..faa8a9b93 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -58,6 +58,7 @@ from sglang.srt.mem_cache.allocator import ( ) from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache +from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache from sglang.srt.metrics.collector import TimeStats @@ -639,14 +640,26 @@ class Req: ): self.fill_ids = self.origin_input_ids + self.output_ids if tree_cache is not None: - ( - self.prefix_indices, - self.last_node, - self.last_host_node, - self.host_hit_length, - ) = tree_cache.match_prefix( - key=self.adjust_max_prefix_ids(), - ) + if isinstance(tree_cache, LoRARadixCache): + ( + self.prefix_indices, + self.last_node, + self.last_host_node, + self.host_hit_length, + ) = tree_cache.match_prefix_with_lora_id( + key=LoRAKey( + lora_id=self.lora_id, token_ids=self.adjust_max_prefix_ids() + ), + ) + else: + ( + self.prefix_indices, + self.last_node, + self.last_host_node, + self.host_hit_length, + ) = tree_cache.match_prefix( + key=self.adjust_max_prefix_ids(), + ) self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices) def adjust_max_prefix_ids(self): diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 3856bf259..5b10eef59 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -130,6 +130,7 @@ from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache from sglang.srt.mem_cache.hiradix_cache import HiRadixCache +from sglang.srt.mem_cache.lora_radix_cache import LoRARadixCache from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors @@ -630,7 +631,19 @@ class Scheduler( page_size=self.page_size, disable=server_args.disable_radix_cache, ) - + elif self.enable_lora: + assert ( + not self.enable_hierarchical_cache + ), "LoRA radix cache doesn't support hierarchical cache" + assert ( + self.schedule_policy == "fcfs" + ), "LoRA radix cache only supports FCFS policy" + self.tree_cache = LoRARadixCache( + req_to_token_pool=self.req_to_token_pool, + token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, + page_size=self.page_size, + disable=server_args.disable_radix_cache, + ) else: self.tree_cache = RadixCache( req_to_token_pool=self.req_to_token_pool, diff --git a/python/sglang/srt/mem_cache/lora_radix_cache.py b/python/sglang/srt/mem_cache/lora_radix_cache.py new file mode 100644 index 000000000..fa5626012 --- /dev/null +++ b/python/sglang/srt/mem_cache/lora_radix_cache.py @@ -0,0 +1,421 @@ +"""Radix cache for LoRA. It's modified based on RadixCache with lora_id added to the key of nodes.""" + +import heapq +import time +from collections import defaultdict +from typing import TYPE_CHECKING, Any, List, Optional + +import torch + +from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator +from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult +from sglang.srt.mem_cache.memory_pool import ReqToTokenPool + +if TYPE_CHECKING: + from sglang.srt.managers.schedule_batch import Req +else: + Req = Any # Placeholder for Req type when not type checking + + +class LoRAKey: + + def __init__(self, lora_id: str, token_ids: List[int]): + self.lora_id = ( + lora_id # lora_id of adaptor, should be hash value of adaptor path + ) + self.token_ids = token_ids # token_ids of the key + + def __len__(self): + return len(self.token_ids) + + +def get_child_key(key: LoRAKey): + # Here the key of children dict is the hash of lora_id + str(token_ids[0]) + # So the child key can be matched only when lora_id and token_ids[0] are the same + if key.lora_id is None: + return hash(str(key.token_ids[0])) + else: + return hash(key.lora_id + str(key.token_ids[0])) + + +class LoRATreeNode: + + counter = 0 + + def __init__(self, id: Optional[int] = None): + self.children = defaultdict(LoRATreeNode) + self.parent: LoRATreeNode = None + self.key: LoRAKey = None + self.value: Optional[torch.Tensor] = None + self.lock_ref = 0 + self.last_access_time = time.monotonic() + + self.id = LoRATreeNode.counter if id is None else id + LoRATreeNode.counter += 1 + + @property + def evicted(self): + return self.value is None + + def __lt__(self, other: "LoRATreeNode"): + return self.last_access_time < other.last_access_time + + +def _key_match(key0: LoRAKey, key1: LoRAKey): + if key0.lora_id != key1.lora_id: + raise ValueError( + f"_key_match should be run on the same lora_id, but got key0.lora_id={key0.lora_id} != key1.lora_id={key1.lora_id}" + ) + i = 0 + for k0, k1 in zip(key0.token_ids, key1.token_ids): + if k0 != k1: + break + i += 1 + return i + + +class LoRARadixCache(BasePrefixCache): + + def __init__( + self, + req_to_token_pool: ReqToTokenPool, + token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator, + page_size: int, + disable: bool = False, + ): + if page_size > 1: + raise ValueError("LoRARadixCache currently only supports page_size = 1") + + if token_to_kv_pool_allocator is None: + raise ValueError( + "token_to_kv_pool_allocator is required to run LoraRadixCache" + ) + + self.req_to_token_pool = req_to_token_pool + self.token_to_kv_pool_allocator = token_to_kv_pool_allocator + self.page_size = page_size + self.disable = disable + self.device = self.token_to_kv_pool_allocator.device + + self.key_match_fn = _key_match + self.get_child_key_fn = get_child_key + self.reset() + + def reset(self): + self.root_node = LoRATreeNode() + self.root_node.key = LoRAKey(lora_id="", token_ids=[]) + self.root_node.value = None + self.evictable_size_ = 0 + self.protected_size_ = 0 + + def match_prefix(self, key: List[int], **kwargs) -> MatchResult: + raise ValueError( + "LoRARadixCache needs both token ids and lora id as inputs for matching. Please use match_prefix_with_lora_id instead." + ) + + def match_prefix_with_lora_id(self, key: LoRAKey, **kwargs) -> MatchResult: + """Find the matching prefix from the lora radix tree. + Args: + key: A LoRAKey to find a matching prefix. + Returns: + A tuple of a tensor of matching prefix token IDs and + the last node that contains the prefix values. Note that + this API can modify the internal state of the Radix tree. + The last node create a new child if the prefix is shorter + than the last node's value. + """ + if self.disable or len(key) == 0: + return MatchResult( + device_indices=torch.empty( + (0,), + dtype=torch.int64, + device=self.device, + ), + last_device_node=self.root_node, + last_host_node=self.root_node, + ) + + value, last_node = self._match_prefix_helper(self.root_node, key) + if value: + value = torch.cat(value) + else: + value = torch.empty((0,), dtype=torch.int64, device=self.device) + return MatchResult( + device_indices=value, + last_device_node=last_node, + last_host_node=last_node, + ) + + def insert(self, key: LoRAKey, value=None): + if self.disable: + return 0 + + if value is None: + value = [x for x in key.token_ids] + return self._insert_helper(self.root_node, key, value) + + def cache_finished_req(self, req: Req): + """Cache request when it finishes.""" + if self.disable: + kv_indices = self.req_to_token_pool.req_to_token[ + req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1 + ] + self.token_to_kv_pool_allocator.free(kv_indices) + self.req_to_token_pool.free(req.req_pool_idx) + return + + token_ids = (req.origin_input_ids + req.output_ids)[:-1] + kv_indices = self.req_to_token_pool.req_to_token[ + req.req_pool_idx, : len(token_ids) + ] + + page_aligned_len = len(kv_indices) + page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True) + + # Radix Cache takes one ref in memory pool + lora_key = LoRAKey(lora_id=req.lora_id, token_ids=token_ids[:page_aligned_len]) + new_prefix_len = self.insert(lora_key, page_aligned_kv_indices) + self.token_to_kv_pool_allocator.free( + kv_indices[len(req.prefix_indices) : new_prefix_len] + ) + + # Remove req slot release the cache lock + self.req_to_token_pool.free(req.req_pool_idx) + self.dec_lock_ref(req.last_node) + + def cache_unfinished_req(self, req: Req): + """Cache request when it is unfinished.""" + if self.disable: + return + + token_ids = req.fill_ids + kv_indices = self.req_to_token_pool.req_to_token[ + req.req_pool_idx, : len(token_ids) + ] + + page_aligned_len = len(kv_indices) + page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True) + page_aligned_token_ids = token_ids[:page_aligned_len] + + # Radix Cache takes one ref in memory pool + inserted_key = LoRAKey(lora_id=req.lora_id, token_ids=page_aligned_token_ids) + new_prefix_len = self.insert(inserted_key, page_aligned_kv_indices) + self.token_to_kv_pool_allocator.free( + kv_indices[len(req.prefix_indices) : new_prefix_len] + ) + + # The prefix indices could be updated, reuse it + new_indices, new_last_node, _, _ = self.match_prefix_with_lora_id(inserted_key) + self.req_to_token_pool.write( + (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))), + new_indices[len(req.prefix_indices) :], + ) + + self.dec_lock_ref(req.last_node) + self.inc_lock_ref(new_last_node) + + # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later + req.prefix_indices = new_indices + req.last_node = new_last_node + + def pretty_print(self): + self._print_helper(self.root_node, 0) + print(f"#tokens: {self.total_size()}") + + def total_size(self): + return self._total_size_helper() + + def evict(self, num_tokens: int): + if self.disable: + return + + leaves = self._collect_leaves() + heapq.heapify(leaves) + + num_evicted = 0 + while num_evicted < num_tokens and len(leaves): + x = heapq.heappop(leaves) + + if x == self.root_node: + break + if x.lock_ref > 0: + continue + + self.token_to_kv_pool_allocator.free(x.value) + num_evicted += len(x.value) + self._delete_leaf(x) + + if len(x.parent.children) == 0: + heapq.heappush(leaves, x.parent) + + def inc_lock_ref(self, node: LoRATreeNode): + if self.disable: + return 0 + + delta = 0 + while node != self.root_node: + if node.lock_ref == 0: + self.evictable_size_ -= len(node.value) + self.protected_size_ += len(node.value) + delta -= len(node.value) + node.lock_ref += 1 + node = node.parent + return delta + + def dec_lock_ref(self, node: LoRATreeNode): + if self.disable: + return 0 + + delta = 0 + while node != self.root_node: + if node.lock_ref == 1: + self.evictable_size_ += len(node.value) + self.protected_size_ -= len(node.value) + delta += len(node.value) + node.lock_ref -= 1 + node = node.parent + return delta + + def evictable_size(self): + return self.evictable_size_ + + def protected_size(self): + # protected size refers to the size of the cache that is locked + return self.protected_size_ + + def all_values_flatten(self): + values = [] + + def _dfs_helper(node: LoRATreeNode): + for _, child in node.children.items(): + values.append(child.value) + _dfs_helper(child) + + _dfs_helper(self.root_node) + return torch.cat(values) + + ##### Internal Helper Functions ##### + + def _match_prefix_helper(self, node: LoRATreeNode, key: LoRAKey): + node.last_access_time = time.monotonic() + + child_key = self.get_child_key_fn(key) + + value = [] + while len(key) > 0 and child_key in node.children.keys(): + child = node.children[child_key] + child.last_access_time = time.monotonic() + prefix_len = self.key_match_fn(child.key, key) + if prefix_len < len(child.key): + new_node = self._split_node(child.key, child, prefix_len) + value.append(new_node.value) + node = new_node + break + else: + value.append(child.value) + node = child + key = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[prefix_len:]) + + if len(key): + child_key = self.get_child_key_fn(key) + + return value, node + + def _split_node(self, key: LoRAKey, child: LoRATreeNode, split_len: int): + # new_node -> child + new_node = LoRATreeNode() + key_split_1 = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[:split_len]) + key_split_2 = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[split_len:]) + new_node.children = {self.get_child_key_fn(key_split_2): child} + new_node.parent = child.parent + new_node.lock_ref = child.lock_ref + new_node.key = key_split_1 + new_node.value = child.value[:split_len] + child.parent = new_node + child.key = key_split_2 + child.value = child.value[split_len:] + new_node.parent.children[self.get_child_key_fn(key)] = new_node + + return new_node + + def _insert_helper(self, node: LoRATreeNode, key: LoRAKey, value): + node.last_access_time = time.monotonic() + if len(key) == 0: + return 0 + + child_key = self.get_child_key_fn(key) + + total_prefix_length = 0 + while len(key) > 0 and child_key in node.children.keys(): + node = node.children[child_key] + node.last_access_time = time.monotonic() + prefix_len = self.key_match_fn(node.key, key) + total_prefix_length += prefix_len + key = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[prefix_len:]) + value = value[prefix_len:] + + if prefix_len < len(node.key): + new_node = self._split_node(node.key, node, prefix_len) + node = new_node + + if len(key): + child_key = self.get_child_key_fn(key) + + if len(key): + new_node = LoRATreeNode() + new_node.parent = node + new_node.key = key + new_node.value = value + node.children[child_key] = new_node + self.evictable_size_ += len(value) + return total_prefix_length + + def _print_helper(self, node: LoRATreeNode, indent: int): + """Prints the radix tree in a human-readable format.""" + stack = [(node, indent)] + while stack: + current_node, current_indent = stack.pop() + print( + " " * current_indent, + len(current_node.key), + current_node.key.token_ids[:10], + f"r={current_node.lock_ref}", + ) + for key, child in current_node.children.items(): + stack.append((child, current_indent + 2)) + + assert key == self.get_child_key_fn( + child.key + ), f"{key=}, {self.get_child_key_fn(child.key)=}" + + def _delete_leaf(self, node): + for k, v in node.parent.children.items(): + if v == node: + break + del node.parent.children[k] + self.evictable_size_ -= len(node.key) + + def _total_size_helper(self): + total_size = 0 + stack = [self.root_node] + while stack: + current_node = stack.pop() + total_size += len(current_node.value) + for child in current_node.children.values(): + if child.evicted: + continue + stack.append(child) + return total_size + + def _collect_leaves(self): + ret_list = [] + stack = [self.root_node] + + while stack: + cur_node = stack.pop() + if len(cur_node.children) == 0: + ret_list.append(cur_node) + else: + stack.extend(cur_node.children.values()) + + return ret_list diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index b2d8901a7..93ceb6797 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -2004,11 +2004,7 @@ class ServerArgs: ), "chunked_prefill_size must be divisible by page_size" def check_lora_server_args(self): - assert ( - self.max_loras_per_batch > 0 - # FIXME - and (self.lora_paths is None or self.disable_radix_cache) - ), "compatibility of lora and radix attention is in progress" + assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive" # Enable LoRA if any LoRA paths are provided for backward compatibility. if self.lora_paths: diff --git a/test/srt/lora/test_lora.py b/test/srt/lora/test_lora.py index 17aa6f3b8..536cec71a 100644 --- a/test/srt/lora/test_lora.py +++ b/test/srt/lora/test_lora.py @@ -104,7 +104,6 @@ class TestLoRA(CustomTestCase): lora_paths=[lora_adapter_paths[0], lora_adapter_paths[1]], max_loras_per_batch=len(lora_adapter_paths) + 1, lora_backend=backend, - disable_radix_cache=True, sleep_on_idle=True, # Eliminate non-determinism by forcing all requests to be processed in one batch. attention_backend="torch_native", ) diff --git a/test/srt/lora/test_lora_eviction.py b/test/srt/lora/test_lora_eviction.py index b352da2d5..d27b11906 100644 --- a/test/srt/lora/test_lora_eviction.py +++ b/test/srt/lora/test_lora_eviction.py @@ -97,7 +97,6 @@ class TestLoRAEviction(CustomTestCase): lora_paths=initial_lora_paths, max_loras_per_batch=1, lora_backend=backend, - disable_radix_cache=True, enable_lora=True, max_lora_rank=256, lora_target_modules=["all"], diff --git a/test/srt/lora/test_lora_qwen3.py b/test/srt/lora/test_lora_qwen3.py index 4519c3c1f..d114e1ee8 100644 --- a/test/srt/lora/test_lora_qwen3.py +++ b/test/srt/lora/test_lora_qwen3.py @@ -140,7 +140,6 @@ class TestLoRA(CustomTestCase): lora_paths=[lora_adapter_paths[0], lora_adapter_paths[1]], max_loras_per_batch=len(lora_adapter_paths) + 1, lora_backend=backend, - disable_radix_cache=True, ) hf_runner = HFRunner( base_path, diff --git a/test/srt/lora/test_lora_radix_cache.py b/test/srt/lora/test_lora_radix_cache.py new file mode 100644 index 000000000..d3ecb219c --- /dev/null +++ b/test/srt/lora/test_lora_radix_cache.py @@ -0,0 +1,83 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import multiprocessing as mp +import random +import unittest + +import torch +from utils import CI_MULTI_LORA_MODELS, DEFAULT_PROMPTS, run_lora_test_one_by_one + +from sglang.test.runners import HFRunner, SRTRunner +from sglang.test.test_utils import CustomTestCase + +PROMPTS = [ + "AI is a field of computer science focused on", + """ + ### Instruction: + Tell me about llamas and alpacas + ### Response: + Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids. + ### Question: + What do you know about llamas? + ### Answer: + """, +] + + +class TestLoRARadixCache(CustomTestCase): + + def test_lora_radix_cache(self): + # Here we need a model case with multiple adaptors for testing correctness of radix cache + model_case = CI_MULTI_LORA_MODELS[0] + + torch_dtype = torch.float16 + max_new_tokens = 32 + backend = "triton" + batch_prompts = ( + PROMPTS + if not model_case.skip_long_prompt + else [p for p in PROMPTS if len(p) < 1000] + ) + + # Test lora with radix cache + run_lora_test_one_by_one( + batch_prompts, + model_case, + torch_dtype, + max_new_tokens=max_new_tokens, + backend=backend, + disable_radix_cache=False, + test_tag="lora-with-radix-cache", + ) + + # Test lora without radix cache + run_lora_test_one_by_one( + batch_prompts, + model_case, + torch_dtype, + max_new_tokens=max_new_tokens, + backend=backend, + disable_radix_cache=True, + test_tag="lora-without-radix-cache", + ) + + +if __name__ == "__main__": + try: + mp.set_start_method("spawn") + except RuntimeError: + pass + + unittest.main(warnings="ignore") diff --git a/test/srt/lora/test_lora_update.py b/test/srt/lora/test_lora_update.py index 9afbde79c..e33fccc02 100644 --- a/test/srt/lora/test_lora_update.py +++ b/test/srt/lora/test_lora_update.py @@ -787,7 +787,6 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase): max_loaded_loras=self.max_loaded_loras, disable_cuda_graph=self.disable_cuda_graph, cuda_graph_max_bs=self.cuda_graph_max_bs, - disable_radix_cache=True, enable_lora=self.enable_lora, ) self.handle.__enter__() @@ -917,7 +916,6 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase): str(self.max_loras_per_batch), "--lora-backend", self.lora_backend, - "--disable-radix-cache", "--random-seed", "42", "--max-running-request", diff --git a/test/srt/lora/utils.py b/test/srt/lora/utils.py index 642b8731e..705231965 100644 --- a/test/srt/lora/utils.py +++ b/test/srt/lora/utils.py @@ -136,7 +136,7 @@ def run_lora_test_one_by_one( max_new_tokens: int, backend: str, disable_cuda_graph: bool = False, - disable_radix_cache: bool = True, + disable_radix_cache: bool = False, mem_fraction_static: float = 0.88, test_tag: str = "", ): @@ -156,7 +156,7 @@ def run_lora_test_one_by_one( max_new_tokens (int): The maximum number of new tokens to generate. backend (str): The lora backend to use. disable_cuda_graph (bool, optional): Whether to disable CUDA graph. Defaults to False. - disable_radix_cache (bool, optional): Whether to disable radix cache. Defaults to True. + disable_radix_cache (bool, optional): Whether to disable radix cache. Defaults to False. mem_fraction_static (float, optional): The fraction of memory to use. Defaults to 0.88. test_tag (str, optional): The tag to use for the test. Defaults to "". """ @@ -284,7 +284,7 @@ def run_lora_test_by_batch( max_new_tokens: int, backend: str, disable_cuda_graph: bool = False, - disable_radix_cache: bool = True, + disable_radix_cache: bool = False, mem_fraction_static: float = 0.88, test_tag: str = "", ): @@ -303,7 +303,7 @@ def run_lora_test_by_batch( max_new_tokens (int): The maximum number of new tokens to generate. backend (str): The lora backend to use. disable_cuda_graph (bool, optional): Whether to disable CUDA graph. Defaults to False. - disable_radix_cache (bool, optional): Whether to disable radix cache. Defaults to True. + disable_radix_cache (bool, optional): Whether to disable radix cache. Defaults to False. mem_fraction_static (float, optional): The fraction of memory to use. Defaults to 0.88. test_tag (str, optional): The tag to use for the test. Defaults to "". """ diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index dab218994..dcf2c0efb 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -23,6 +23,7 @@ suites = { TestFile("lora/test_lora_cuda_graph.py", 250), TestFile("lora/test_lora_update.py", 400), TestFile("lora/test_lora_qwen3.py", 97), + TestFile("lora/test_lora_radix_cache.py", 100), TestFile("models/test_embedding_models.py", 73), # TestFile("models/test_clip_models.py", 52), TestFile("models/test_encoder_embedding_models.py", 100),