Introduce Stable LoRA ID System for Overlapped Updates and Prefix Caching (#8261)
@@ -62,6 +62,7 @@ from sglang.srt.hf_transformers_utils import (
     get_tokenizer,
     get_tokenizer_from_processor,
 )
+from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -242,11 +243,11 @@ class TokenizerManager:
             revision=server_args.revision,
         )
 
-        # Initialize loaded loRA adapters with the initial lora paths in the server_args.
-        # This list will be updated when new LoRA adapters are loaded or unloaded dynamically.
-        self.loaded_lora_adapters: Dict[str, str] = dict(
-            self.server_args.lora_paths or {}
-        )
+        # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
+        # The registry dynamically updates as adapters are loaded / unloaded during runtime. It
+        # serves as the source of truth for available adapters and maps user-friendly LoRA names
+        # to internally used unique LoRA IDs.
+        self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
 
         # Store states
         self.no_create_loop = False
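Conceptually, the registry pairs every adapter name with an immutable LoRARef that carries a stable, unique lora_id. The following is a minimal sketch of that contract; the `Sketch` names are illustrative, and the real classes live in sglang.srt.lora.lora_registry with async methods and additional bookkeeping.

    import uuid
    from dataclasses import dataclass, field
    from typing import Dict, Optional

    @dataclass(frozen=True)
    class LoRARefSketch:
        # Immutable pairing of a user-facing name, its path, and a stable unique ID.
        lora_name: str
        lora_path: str
        lora_id: str = field(default_factory=lambda: uuid.uuid4().hex)

    class LoRARegistrySketch:
        def __init__(self, lora_paths: Optional[Dict[str, str]] = None):
            # Seed the registry from the initial `--lora-paths`, mirroring the constructor call above.
            self._refs: Dict[str, LoRARefSketch] = {
                name: LoRARefSketch(lora_name=name, lora_path=path)
                for name, path in (lora_paths or {}).items()
            }

        def register(self, ref: LoRARefSketch) -> None:
            if ref.lora_name in self._refs:
                raise ValueError(f"Adapter '{ref.lora_name}' already registered.")
            self._refs[ref.lora_name] = ref

        def unregister(self, lora_name: str) -> str:
            ref = self._refs.pop(lora_name, None)
            if ref is None:
                raise ValueError(f"Adapter '{lora_name}' is not registered.")
            return ref.lora_id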
@@ -523,6 +524,10 @@ class TokenizerManager:
         else:
             mm_inputs = None
 
+        if self.server_args.enable_lora and obj.lora_path:
+            # Replace the user-friendly LoRA names in `lora_path` with their corresponding unique LoRA IDs.
+            obj.lora_path = await self.lora_registry.acquire(obj.lora_path)
+
         self._validate_one_request(obj, input_ids)
         return self._create_tokenized_object(
             obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
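The acquire call translates what the user sent, a single name or a list of names, into the corresponding unique IDs, and rejects names that are not registered; this is what allows the old _validate_lora_adapters helper to be deleted below. A hedged sketch of that lookup (the real method is async, and presumably also tracks in-flight usage so unloads can wait for requests to drain):

    from typing import Dict, List, Union

    def acquire_sketch(
        refs: Dict[str, "LoRARefSketch"], lora_path: Union[str, List[str]]
    ) -> Union[str, List[str]]:
        # Resolve one user-friendly name to its stable unique ID.
        def resolve(name: str) -> str:
            if name not in refs:
                raise ValueError(f"LoRA adapter '{name}' is not loaded.")
            return refs[name].lora_id

        if isinstance(lora_path, list):
            return [resolve(name) for name in lora_path]
        return resolve(lora_path)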
@@ -574,8 +579,6 @@ class TokenizerManager:
                 "The server is not configured to enable custom logit processor. "
                 "Please set `--enable-custom-logits-processor` to enable this feature."
             )
-        if self.server_args.enable_lora and obj.lora_path:
-            self._validate_lora_adapters(obj)
 
     def _validate_input_ids_in_vocab(
         self, input_ids: List[int], vocab_size: int
@@ -689,21 +692,6 @@ class TokenizerManager:
                 "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
             )
 
-    def _validate_lora_adapters(self, obj: GenerateReqInput):
-        """Validate that the requested LoRA adapters are loaded."""
-        requested_adapters = (
-            set(obj.lora_path) if isinstance(obj.lora_path, list) else {obj.lora_path}
-        )
-        loaded_adapters = (
-            self.loaded_lora_adapters.keys() if self.loaded_lora_adapters else set()
-        )
-        unloaded_adapters = requested_adapters - loaded_adapters
-        if unloaded_adapters:
-            raise ValueError(
-                f"The following requested LoRA adapters are not loaded: {unloaded_adapters}\n"
-                f"Loaded adapters: {loaded_adapters}."
-            )
-
     def _send_one_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -1054,8 +1042,18 @@ class TokenizerManager:
             )
 
         async with self.model_update_lock.writer_lock:
+            # Generate new uniquely identifiable LoRARef object.
+            new_adapter = LoRARef(
+                lora_name=obj.lora_name,
+                lora_path=obj.lora_path,
+            )
+
+            # Register the new adapter in the registry.
+            obj.lora_id = new_adapter.lora_id
             result = (await self.update_lora_adapter_communicator(obj))[0]
-            self.loaded_lora_adapters = result.loaded_adapters
+            if result.success:
+                await self.lora_registry.register(new_adapter)
 
             return result
 
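Because every load creates a brand-new lora_id, caches keyed by ID can never confuse two different adapters that happened to share a name across an unload/reload cycle. A small self-contained illustration of why this matters for prefix caching (the cache-key shape here is hypothetical, not SGLang's actual radix-cache key):

    import uuid
    from typing import Tuple

    def prefix_cache_key(lora_id: str, token_ids: Tuple[int, ...]) -> Tuple[str, Tuple[int, ...]]:
        # Keying by the stable unique ID, not the reusable name, keeps entries
        # from one load of "my_adapter" invisible to a later reload of that name.
        return (lora_id, token_ids)

    first_load = uuid.uuid4().hex   # "my_adapter" loaded the first time
    second_load = uuid.uuid4().hex  # "my_adapter" unloaded, then reloaded
    assert prefix_cache_key(first_load, (1, 2, 3)) != prefix_cache_key(second_load, (1, 2, 3))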
@@ -1069,6 +1067,10 @@ async def unload_lora_adapter(
                 "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
             )
 
+        assert (
+            obj.lora_name is not None
+        ), "lora_name must be provided to unload LoRA adapter"
+
         # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
         # with dp_size > 1.
         assert (
@@ -1080,8 +1082,9 @@ async def unload_lora_adapter(
             )
 
         async with self.model_update_lock.writer_lock:
+            obj.lora_id = await self.lora_registry.unregister(obj.lora_name)
             result = (await self.update_lora_adapter_communicator(obj))[0]
-            self.loaded_lora_adapters = result.loaded_adapters
 
             return result
 
     async def get_weights_by_name(
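Unregistering before broadcasting the unload means new requests can no longer resolve the name, while requests already holding the old ID can keep running; this is the "overlapped updates" part of the title. One plausible drain primitive for that pattern, offered as an assumption about the mechanism rather than the actual implementation:

    import asyncio

    class InFlightCounter:
        # Counts requests currently using an adapter; an unload can await
        # `drained()` so weights are only evicted once all holders finish.
        def __init__(self) -> None:
            self._count = 0
            self._idle = asyncio.Event()
            self._idle.set()

        def acquire(self) -> None:
            self._count += 1
            self._idle.clear()

        def release(self) -> None:
            self._count -= 1
            if self._count == 0:
                self._idle.set()

        async def drained(self) -> None:
            await self._idle.wait()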
@@ -1309,7 +1312,7 @@ class TokenizerManager:
             filename = os.path.join(
                 self.crash_dump_folder,
                 os.getenv("HOSTNAME", None),
-                f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl',
+                f"crash_dump_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pkl",
             )
 
             os.makedirs(os.path.dirname(filename), exist_ok=True)