Introduce Stable LoRA ID System for Overlapped Updates and Prefix Caching (#8261)

This commit is contained in:
Lifu Huang
2025-07-23 00:32:16 -07:00
committed by GitHub
parent e885bfdc6a
commit 8abd3e77fe
11 changed files with 400 additions and 261 deletions

View File

@@ -62,6 +62,7 @@ from sglang.srt.hf_transformers_utils import (
get_tokenizer,
get_tokenizer_from_processor,
)
from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
from sglang.srt.managers.io_struct import (
AbortReq,
BatchEmbeddingOut,
@@ -242,11 +243,11 @@ class TokenizerManager:
revision=server_args.revision,
)
# Initialize loaded LoRA adapters with the initial lora paths in the server_args.
# This list will be updated when new LoRA adapters are loaded or unloaded dynamically.
self.loaded_lora_adapters: Dict[str, str] = dict(
self.server_args.lora_paths or {}
)
# Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
# The registry dynamically updates as adapters are loaded / unloaded during runtime. It
# serves as the source of truth for available adapters and maps user-friendly LoRA names
# to internally used unique LoRA IDs.
self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
# Store states
self.no_create_loop = False
@@ -523,6 +524,10 @@ class TokenizerManager:
else:
mm_inputs = None
if self.server_args.enable_lora and obj.lora_path:
# Replace the user-friendly LoRA names in `lora_path` with their corresponding unique LoRA IDs.
obj.lora_path = await self.lora_registry.acquire(obj.lora_path)
self._validate_one_request(obj, input_ids)
return self._create_tokenized_object(
obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
@@ -574,8 +579,6 @@ class TokenizerManager:
"The server is not configured to enable custom logit processor. "
"Please set `--enable-custom-logits-processor` to enable this feature."
)
if self.server_args.enable_lora and obj.lora_path:
self._validate_lora_adapters(obj)
def _validate_input_ids_in_vocab(
self, input_ids: List[int], vocab_size: int
@@ -689,21 +692,6 @@ class TokenizerManager:
"Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
)
def _validate_lora_adapters(self, obj: GenerateReqInput):
    """Ensure every LoRA adapter requested by `obj` has already been loaded.

    `obj.lora_path` may be a single adapter name or a list of names; any
    name absent from `self.loaded_lora_adapters` triggers a ValueError.
    """
    # Normalize the request to a set of adapter names, whether the caller
    # passed one name or a list of them.
    if isinstance(obj.lora_path, list):
        requested_adapters = set(obj.lora_path)
    else:
        requested_adapters = {obj.lora_path}
    # An empty/None adapter mapping means nothing is loaded yet.
    if self.loaded_lora_adapters:
        loaded_adapters = self.loaded_lora_adapters.keys()
    else:
        loaded_adapters = set()
    unloaded_adapters = requested_adapters - loaded_adapters
    if unloaded_adapters:
        raise ValueError(
            f"The following requested LoRA adapters are not loaded: {unloaded_adapters}\n"
            f"Loaded adapters: {loaded_adapters}."
        )
def _send_one_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -1054,8 +1042,18 @@ class TokenizerManager:
)
async with self.model_update_lock.writer_lock:
# Generate new uniquely identifiable LoRARef object.
new_adapter = LoRARef(
lora_name=obj.lora_name,
lora_path=obj.lora_path,
)
# Register the new adapter in the registry.
obj.lora_id = new_adapter.lora_id
result = (await self.update_lora_adapter_communicator(obj))[0]
self.loaded_lora_adapters = result.loaded_adapters
if result.success:
await self.lora_registry.register(new_adapter)
return result
async def unload_lora_adapter(
@@ -1069,6 +1067,10 @@ class TokenizerManager:
"LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
)
assert (
obj.lora_name is not None
), "lora_name must be provided to unload LoRA adapter"
# TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
# with dp_size > 1.
assert (
@@ -1080,8 +1082,9 @@ class TokenizerManager:
)
async with self.model_update_lock.writer_lock:
obj.lora_id = await self.lora_registry.unregister(obj.lora_name)
result = (await self.update_lora_adapter_communicator(obj))[0]
self.loaded_lora_adapters = result.loaded_adapters
return result
async def get_weights_by_name(
@@ -1309,7 +1312,7 @@ class TokenizerManager:
filename = os.path.join(
self.crash_dump_folder,
os.getenv("HOSTNAME", None),
f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl',
f"crash_dump_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pkl",
)
os.makedirs(os.path.dirname(filename), exist_ok=True)