From 5ded39cab26116545de9ad6b8ce342c5a3615042 Mon Sep 17 00:00:00 2001 From: Lifu Huang Date: Mon, 11 Aug 2025 22:59:29 -0700 Subject: [PATCH] Fix race condition in async lora unload (#9084) --- python/sglang/srt/lora/lora_registry.py | 29 +++++++++++++------ python/sglang/srt/managers/io_struct.py | 1 + .../sglang/srt/managers/tokenizer_manager.py | 22 ++++++++------ python/sglang/srt/utils.py | 2 +- 4 files changed, 35 insertions(+), 19 deletions(-) diff --git a/python/sglang/srt/lora/lora_registry.py b/python/sglang/srt/lora/lora_registry.py index 082f9a2d3..535ab47b4 100644 --- a/python/sglang/srt/lora/lora_registry.py +++ b/python/sglang/srt/lora/lora_registry.py @@ -14,7 +14,6 @@ import asyncio -from collections import defaultdict from dataclasses import dataclass, field, fields from typing import Dict, List, Optional, Union from uuid import uuid4 @@ -106,7 +105,6 @@ class LoRARegistry: f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}" ) del self._registry[lora_name] - del self._counters[lora_ref.lora_id] return lora_ref.lora_id @@ -117,6 +115,9 @@ class LoRARegistry: """ def _lookup(name: str) -> str: + if name is None: + return None + lora_ref = self._registry.get(name, None) if lora_ref is None: raise ValueError( @@ -135,7 +136,11 @@ class LoRARegistry: # Increment the counters only after all IDs are looked up. await asyncio.gather( - *[self._counters[id].increment(notify_all=False) for id in lora_ids] + *[ + self._counters[id].increment(notify_all=False) + for id in lora_ids + if id is not None + ] ) return lora_ids else: @@ -153,7 +158,11 @@ class LoRARegistry: await self._counters[lora_id].decrement() elif isinstance(lora_id, list): await asyncio.gather( - *[self._counters[id].decrement() for id in lora_id] + *[ + self._counters[id].decrement() + for id in lora_id + if id is not None + ] ) else: raise TypeError("lora_id must be either a string or a list of strings.") @@ -169,11 +178,13 @@ class LoRARegistry: assert ( lora_id not in self._registry ), "wait_for_unload should only be called after the LoRA adapter has been unregistered. " - counter = self._counters.get(lora_id) - if counter: - # Wait until no requests are using this LoRA adapter. - await counter.wait_for_zero() - del self._counters[lora_id] + assert ( + lora_id in self._counters + ), "The LoRA ID should still have a counter if it has been registered before." + + # Wait until no requests are using this LoRA adapter. + await self._counters[lora_id].wait_for_zero() + del self._counters[lora_id] def _register_adapter(self, lora_ref: LoRARef): """ diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 314339a8b..3c7c0069e 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -455,6 +455,7 @@ class GenerateReqInput: log_metrics=self.log_metrics, modalities=self.modalities[i] if self.modalities else None, lora_path=self.lora_path[i] if self.lora_path is not None else None, + lora_id=self.lora_id[i] if self.lora_id is not None else None, custom_logit_processor=( self.custom_logit_processor[i] if self.custom_logit_processor is not None diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index a1a81a87f..00dd6a065 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -485,6 +485,10 @@ class TokenizerManager: await self.is_pause_cond.wait_for(lambda: not self.is_pause) async with self.model_update_lock.reader_lock: + if self.server_args.enable_lora and obj.lora_path: + # Look up the LoRA ID from the registry and start tracking ongoing LoRA requests. + obj.lora_id = await self.lora_registry.acquire(obj.lora_path) + if obj.is_single: tokenized_obj = await self._tokenize_one_request(obj) state = self._send_one_request(obj, tokenized_obj, created_time) @@ -552,11 +556,6 @@ class TokenizerManager: else: mm_inputs = None - if self.server_args.enable_lora and obj.lora_path: - # Start tracking ongoing requests for LoRA adapters and replace the user-friendly LoRA names in - # `lora_path` with their corresponding unique LoRA IDs, as required for internal processing. - obj.lora_id = await self.lora_registry.acquire(obj.lora_path) - self._validate_one_request(obj, input_ids) return self._create_tokenized_object( obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids @@ -774,10 +773,6 @@ class TokenizerManager: msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}" logger.info(msg) - # Mark ongoing LoRA request as finished. - if self.server_args.enable_lora and obj.lora_path: - await self.lora_registry.release(obj.lora_id) - # Check if this was an abort/error created by scheduler if isinstance(out["meta_info"].get("finish_reason"), dict): finish_reason = out["meta_info"]["finish_reason"] @@ -796,6 +791,11 @@ class TokenizerManager: # Delete the key to prevent resending abort request to the scheduler and # to ensure aborted request state is cleaned up. del self.rid_to_state[state.obj.rid] + + # Mark ongoing LoRA request as finished. + if self.server_args.enable_lora and state.obj.lora_path: + await self.lora_registry.release(state.obj.lora_id) + raise fastapi.HTTPException( status_code=finish_reason["status_code"], detail=finish_reason["message"], @@ -1599,6 +1599,10 @@ class TokenizerManager: meta_info["e2e_latency"] = state.finished_time - state.created_time del self.rid_to_state[rid] + # Mark ongoing LoRA request as finished. + if self.server_args.enable_lora and state.obj.lora_path: + asyncio.create_task(self.lora_registry.release(state.obj.lora_id)) + state.out_list.append(out_dict) state.event.set() diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index a234e7547..f8be9c8d8 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -2960,7 +2960,7 @@ class ConcurrentCounter: This suspends the calling coroutine without blocking the thread, allowing other tasks to run while waiting. When the counter becomes zero, the coroutine resumes. """ - self.wait_for(lambda count: count == 0) + await self.wait_for(lambda count: count == 0) @lru_cache(maxsize=1)