diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py
index 3ab93c73b..c2a3eaabc 100644
--- a/python/sglang/srt/lora/lora_manager.py
+++ b/python/sglang/srt/lora/lora_manager.py
@@ -32,8 +32,8 @@ from sglang.srt.lora.utils import (
     LoRABatchInfo,
     LoRAType,
     get_layer_id,
-    get_normalized_lora_weight_names,
-    get_weight_name,
+    get_normalized_target_modules,
+    get_target_module_name,
 )
 from sglang.srt.managers.io_struct import LoRAUpdateResult
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -350,12 +350,20 @@ class LoRAManager:
         """
         for layer_id, layer_modules in enumerate(self.lora_modules):
             for module_name, module in layer_modules.items():
-                weight_name = get_weight_name(
-                    module_name, self.memory_pool.lora_weight_names
+                target_module = get_target_module_name(
+                    module_name, self.memory_pool.target_modules
                 )
                 module.set_lora_info(
-                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_A),
-                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_B),
+                    self.memory_pool.get_tensor(
+                        target_module=target_module,
+                        layer_id=layer_id,
+                        lora_type=LoRAType.LORA_A,
+                    ),
+                    self.memory_pool.get_tensor(
+                        target_module=target_module,
+                        layer_id=layer_id,
+                        lora_type=LoRAType.LORA_B,
+                    ),
                 )

     def init_state(
@@ -380,7 +388,6 @@ class LoRAManager:
             max_lora_rank=max_lora_rank,
             target_modules=target_modules,
         )
-        self.init_lora_weight_names()
         self.init_lora_modules()
         self.init_memory_pool()
         self.update_lora_info()
@@ -426,6 +433,7 @@ class LoRAManager:
                     "enable all support modules types. "
                 )
             self.target_modules.update(config.target_modules)
+        self.target_modules = get_normalized_target_modules(self.target_modules)

         if max_lora_rank is not None:
             self.max_lora_rank = max_lora_rank
@@ -435,15 +443,6 @@ class LoRAManager:
                 default=0,
             )

-    def init_lora_weight_names(self):
-        """
-        Add new LoRA weight names if needed based on the current `self.configs`.
-        """
-
-        self.lora_weight_names: Set[str] = get_normalized_lora_weight_names(
-            self.target_modules
-        )
-
     def load_lora_weights(self, lora_ref: LoRARef):
         """
         Load the weights of a LoRA adapter to CPU memory and conducts post-loading validation.
@@ -467,7 +466,7 @@ class LoRAManager:
             tp_size=self.tp_size,
             tp_rank=self.tp_rank,
             max_lora_rank=self.max_lora_rank,
-            lora_weight_names=self.lora_weight_names,
+            target_modules=self.target_modules,
             base_model=self.base_model,
         )

@@ -494,7 +493,7 @@ class LoRAManager:
                 continue

             # The module should be converted if it is included in target_names
-            if module_name.split(".")[-1] in self.lora_weight_names:
+            if module_name.split(".")[-1] in self.target_modules:
                 layer_id = get_layer_id(module_name)
                 self.lora_modules[layer_id][module_name] = self.set_lora_module(
                     module_name, module
diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py
index 56cd39d67..94955f414 100644
--- a/python/sglang/srt/lora/mem_pool.py
+++ b/python/sglang/srt/lora/mem_pool.py
@@ -13,9 +13,9 @@ from sglang.srt.lora.utils import (
     ROW_PARALLELISM_LINEAR_LORA_NAMES,
     LoRAType,
     get_hidden_dim,
-    get_normalized_lora_weight_names,
+    get_normalized_target_modules,
     get_stacked_multiply,
-    get_weight_name,
+    get_target_module_name,
 )

 logger = logging.getLogger(__name__)
@@ -52,7 +52,7 @@ class LoRAMemoryPool:
         tp_size: int,
         tp_rank: int,
         max_lora_rank: int,
-        lora_weight_names: Set[str],
+        target_modules: Set[str],
         base_model: torch.nn.Module,
     ):
         self.base_hf_config: AutoConfig = base_hf_config
@@ -62,7 +62,7 @@ class LoRAMemoryPool:
         self.tp_size: int = tp_size
         self.tp_rank: int = tp_rank
         self.max_lora_rank: int = max_lora_rank
-        self.lora_weight_names: Set[str] = lora_weight_names
+        self.target_modules: Set[str] = target_modules

         # Both A_buffer and B_buffer maps lora weight names to its buffer space.
         # A_buffer contains num_layer number of row-major tensors with shape
@@ -95,8 +95,8 @@ class LoRAMemoryPool:
             """
             if config.r > self.max_lora_rank:
                 return False
-            weights = get_normalized_lora_weight_names(config.target_modules)
-            return weights.issubset(self.lora_weight_names)
+            target_module_names = get_normalized_target_modules(config.target_modules)
+            return target_module_names.issubset(self.target_modules)

         if isinstance(config, LoRAConfig):
             return _can_support(config)
@@ -139,10 +139,10 @@ class LoRAMemoryPool:

         def init_buffer(
             buffer: Dict[str, List[torch.Tensor]],
-            lora_weight_names: Set[str],
+            target_modules: Set[str],
             get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
         ):
-            for module_name in lora_weight_names:
+            for module_name in target_modules:
                 lora_shape = get_lora_shape_fn(
                     module_name, base_model, self.max_lora_rank
                 )
@@ -157,13 +157,13 @@ class LoRAMemoryPool:

         init_buffer(
             self.A_buffer,
-            self.lora_weight_names,
+            self.target_modules,
             self.get_lora_A_shape,
         )

         init_buffer(
             self.B_buffer,
-            self.lora_weight_names,
+            self.target_modules,
             self.get_lora_B_shape,
         )

@@ -242,32 +242,34 @@ class LoRAMemoryPool:
         for layer_id in range(self.num_layer):
             layer_weights = lora_adapter.layers[layer_id].weights
             temp_A_buffer: Dict[str, Optional[torch.Tensor]] = {
-                weight_name: None for weight_name in self.A_buffer
+                target_module: None for target_module in self.A_buffer
             }
             temp_B_buffer: Dict[str, Optional[torch.Tensor]] = {
-                weight_name: None for weight_name in self.B_buffer
+                target_module: None for target_module in self.B_buffer
             }
             for name, weights in layer_weights.items():
-                lora_weight_name = get_weight_name(name, self.lora_weight_names)
+                target_module = get_target_module_name(name, self.target_modules)
                 if "lora_A" in name:
-                    temp_A_buffer[lora_weight_name] = weights
+                    temp_A_buffer[target_module] = weights
                 else:
-                    temp_B_buffer[lora_weight_name] = weights
+                    temp_B_buffer[target_module] = weights

             if self.tp_size > 1:
                 cur_layer_modules = lora_modules[layer_id]
                 for module_name, module in cur_layer_modules.items():
-                    weight_name = get_weight_name(module_name, self.lora_weight_names)
+                    target_module = get_target_module_name(
+                        module_name, self.target_modules
+                    )

-                    if temp_A_buffer[weight_name] is None:
+                    if temp_A_buffer[target_module] is None:
                         # Skip weight slicing if the weight is not present in the adapter
                         continue

-                    temp_A_buffer[weight_name] = module.slice_lora_a_weights(
-                        temp_A_buffer[weight_name], self.tp_rank
+                    temp_A_buffer[target_module] = module.slice_lora_a_weights(
+                        temp_A_buffer[target_module], self.tp_rank
                     )
-                    temp_B_buffer[weight_name] = module.slice_lora_b_weights(
-                        temp_B_buffer[weight_name], self.tp_rank
+                    temp_B_buffer[target_module] = module.slice_lora_b_weights(
+                        temp_B_buffer[target_module], self.tp_rank
                     )

             for name, weights in temp_A_buffer.items():
@@ -282,12 +284,12 @@ class LoRAMemoryPool:
                 load_lora_weight_tensor(buffer_view, weights)

     def get_tensor(
-        self, weight_name: str, layer_id: int, lora_type: LoRAType
+        self, target_module: str, layer_id: int, lora_type: LoRAType
     ) -> torch.Tensor:
         if lora_type == LoRAType.LORA_A:
-            return self.A_buffer[weight_name][layer_id]
+            return self.A_buffer[target_module][layer_id]

-        return self.B_buffer[weight_name][layer_id]
+        return self.B_buffer[target_module][layer_id]

     def get_buffer_id(self, lora_uid: str):
         return self.uid_to_buffer_id[lora_uid]
diff --git a/python/sglang/srt/lora/utils.py b/python/sglang/srt/lora/utils.py
index e5aa43eff..1067b40b0 100644
--- a/python/sglang/srt/lora/utils.py
+++ b/python/sglang/srt/lora/utils.py
@@ -84,7 +84,7 @@ def get_hidden_dim(
     raise NotImplementedError()


-def get_normalized_lora_weight_names(
+def get_normalized_target_modules(
     target_modules: Iterable[str],
 ) -> set[str]:
     """
@@ -100,8 +100,8 @@ def get_normalized_lora_weight_names(

     result = set()
     for name in target_modules:
-        weight_name = params_mapping.get(name, name)
-        result.add(weight_name)
+        normalized_name = params_mapping.get(name, name)
+        result.add(normalized_name)
     return result


@@ -116,20 +116,18 @@ def get_stacked_multiply(module_name: str) -> int:
     return stacked_rank[module_name] if module_name in stacked_rank else 1


-def get_weight_name(
-    target_name: str, lora_weight_names: Tuple[Set[str]]
-) -> Optional[str]:
+def get_target_module_name(full_module_name: str, target_modules: Set[str]) -> str:
     """
-    Get the weight name in lora_weight_names that can match target_name.
+    Get the target module name in target_modules that can match full_module_name.

-    If there is a weight name in lora_weight_names that can match target_name, return this name
+    If there is a target module name in target_modules that can match full_module_name, return this name
     Else raise ValueError.
     """
-    for weight_name in lora_weight_names:
-        if weight_name in target_name:
-            return weight_name
+    for target_module in target_modules:
+        if target_module in full_module_name:
+            return target_module
     raise ValueError(
-        f"Cannot find weight name for {target_name} in {lora_weight_names}"
+        f"Cannot find target module name for {full_module_name} in {target_modules}"
     )


diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index bcba0503b..b31f2a5ec 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -2874,6 +2874,8 @@ SUPPORTED_LORA_TARGET_MODULES = [
     "gate_proj",
     "up_proj",
     "down_proj",
+    "qkv_proj",
+    "gate_up_proj",
 ]

 LORA_TARGET_ALL_MODULES = "all"
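For reviewers, a minimal standalone sketch of how the two renamed helpers compose. The two function bodies mirror the patched code in `python/sglang/srt/lora/utils.py`; the `params_mapping` table below is an assumption inferred from the new `qkv_proj`/`gate_up_proj` entries in `SUPPORTED_LORA_TARGET_MODULES` (the real table is defined elsewhere in `lora/utils.py` and is not part of this diff):

```python
from typing import Iterable, Set

# Hypothetical normalization table (assumption, not from this diff):
# per-projection names collapse to their stacked equivalents; names with
# no stacked form (e.g. "o_proj", "down_proj") pass through unchanged.
params_mapping = {
    "q_proj": "qkv_proj",
    "k_proj": "qkv_proj",
    "v_proj": "qkv_proj",
    "gate_proj": "gate_up_proj",
    "up_proj": "gate_up_proj",
}


def get_normalized_target_modules(target_modules: Iterable[str]) -> set[str]:
    # Mirrors the patched helper: map each configured target module to its
    # normalized (stacked) name, deduplicating via the set.
    return {params_mapping.get(name, name) for name in target_modules}


def get_target_module_name(full_module_name: str, target_modules: Set[str]) -> str:
    # Mirrors the patched helper: substring match of a normalized target
    # module name against a fully qualified module path.
    for target_module in target_modules:
        if target_module in full_module_name:
            return target_module
    raise ValueError(
        f"Cannot find target module name for {full_module_name} in {target_modules}"
    )


# Usage: adapter configs list per-projection names; the pool keys by stacked names.
targets = get_normalized_target_modules({"q_proj", "v_proj", "down_proj"})
assert targets == {"qkv_proj", "down_proj"}
assert get_target_module_name("model.layers.0.self_attn.qkv_proj", targets) == "qkv_proj"
```

This composition is also why `init_lora_weight_names` could be deleted: normalization now happens once, right where the target modules are collected, so `self.target_modules` is already in normalized form by the time the memory pool and the module-matching loop in `init_lora_modules` consume it.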