[chore] Clean up redundant lora_weight_names concept to simplify code (#9131)
@@ -32,8 +32,8 @@ from sglang.srt.lora.utils import (
     LoRABatchInfo,
     LoRAType,
     get_layer_id,
-    get_normalized_lora_weight_names,
-    get_weight_name,
+    get_normalized_target_modules,
+    get_target_module_name,
 )
 from sglang.srt.managers.io_struct import LoRAUpdateResult
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -350,12 +350,20 @@ class LoRAManager:
         """
         for layer_id, layer_modules in enumerate(self.lora_modules):
             for module_name, module in layer_modules.items():
-                weight_name = get_weight_name(
-                    module_name, self.memory_pool.lora_weight_names
+                target_module = get_target_module_name(
+                    module_name, self.memory_pool.target_modules
                 )
                 module.set_lora_info(
-                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_A),
-                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_B),
+                    self.memory_pool.get_tensor(
+                        target_module=target_module,
+                        layer_id=layer_id,
+                        lora_type=LoRAType.LORA_A,
+                    ),
+                    self.memory_pool.get_tensor(
+                        target_module=target_module,
+                        layer_id=layer_id,
+                        lora_type=LoRAType.LORA_B,
+                    ),
                 )

     def init_state(
@@ -380,7 +388,6 @@ class LoRAManager:
             max_lora_rank=max_lora_rank,
             target_modules=target_modules,
         )
-        self.init_lora_weight_names()
         self.init_lora_modules()
         self.init_memory_pool()
         self.update_lora_info()
@@ -426,6 +433,7 @@ class LoRAManager:
                 "enable all support modules types. "
             )
             self.target_modules.update(config.target_modules)
+        self.target_modules = get_normalized_target_modules(self.target_modules)

         if max_lora_rank is not None:
             self.max_lora_rank = max_lora_rank
@@ -435,15 +443,6 @@ class LoRAManager:
             default=0,
         )

-    def init_lora_weight_names(self):
-        """
-        Add new LoRA weight names if needed based on the current `self.configs`.
-        """
-
-        self.lora_weight_names: Set[str] = get_normalized_lora_weight_names(
-            self.target_modules
-        )
-
     def load_lora_weights(self, lora_ref: LoRARef):
         """
         Load the weights of a LoRA adapter to CPU memory and conducts post-loading validation.
@@ -467,7 +466,7 @@ class LoRAManager:
             tp_size=self.tp_size,
             tp_rank=self.tp_rank,
             max_lora_rank=self.max_lora_rank,
-            lora_weight_names=self.lora_weight_names,
+            target_modules=self.target_modules,
             base_model=self.base_model,
         )

@@ -494,7 +493,7 @@ class LoRAManager:
                 continue

             # The module should be converted if it is included in target_names
-            if module_name.split(".")[-1] in self.lora_weight_names:
+            if module_name.split(".")[-1] in self.target_modules:
                 layer_id = get_layer_id(module_name)
                 self.lora_modules[layer_id][module_name] = self.set_lora_module(
                     module_name, module
@@ -13,9 +13,9 @@ from sglang.srt.lora.utils import (
     ROW_PARALLELISM_LINEAR_LORA_NAMES,
     LoRAType,
     get_hidden_dim,
-    get_normalized_lora_weight_names,
+    get_normalized_target_modules,
     get_stacked_multiply,
-    get_weight_name,
+    get_target_module_name,
 )

 logger = logging.getLogger(__name__)
@@ -52,7 +52,7 @@ class LoRAMemoryPool:
         tp_size: int,
         tp_rank: int,
         max_lora_rank: int,
-        lora_weight_names: Set[str],
+        target_modules: Set[str],
         base_model: torch.nn.Module,
     ):
         self.base_hf_config: AutoConfig = base_hf_config
@@ -62,7 +62,7 @@ class LoRAMemoryPool:
         self.tp_size: int = tp_size
         self.tp_rank: int = tp_rank
         self.max_lora_rank: int = max_lora_rank
-        self.lora_weight_names: Set[str] = lora_weight_names
+        self.target_modules: Set[str] = target_modules

         # Both A_buffer and B_buffer maps lora weight names to its buffer space.
         # A_buffer contains num_layer number of row-major tensors with shape
@@ -95,8 +95,8 @@ class LoRAMemoryPool:
             """
             if config.r > self.max_lora_rank:
                 return False
-            weights = get_normalized_lora_weight_names(config.target_modules)
-            return weights.issubset(self.lora_weight_names)
+            target_module_names = get_normalized_target_modules(config.target_modules)
+            return target_module_names.issubset(self.target_modules)

         if isinstance(config, LoRAConfig):
             return _can_support(config)
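
For context, the check above reads: an adapter fits in the pool only if its rank is within the pool's maximum and its normalized target modules are a subset of the modules the pool was built for. A minimal, self-contained sketch (placeholder values and a hypothetical `FakeLoRAConfig` stand-in, not sglang's actual implementation):

```python
# Illustrative only: a stand-in for LoRAMemoryPool.can_support(). The sets,
# the rank limit, and FakeLoRAConfig are placeholder assumptions.
from dataclasses import dataclass
from typing import Set


@dataclass
class FakeLoRAConfig:  # hypothetical stand-in for sglang's LoRAConfig
    r: int
    target_modules: Set[str]


pool_max_lora_rank = 16
pool_target_modules = {"qkv_proj", "o_proj", "gate_up_proj", "down_proj"}


def can_support(config: FakeLoRAConfig) -> bool:
    # Mirrors the rank check plus the subset check in the hunk above,
    # assuming config.target_modules are already normalized.
    if config.r > pool_max_lora_rank:
        return False
    return config.target_modules.issubset(pool_target_modules)


print(can_support(FakeLoRAConfig(r=8, target_modules={"qkv_proj", "o_proj"})))  # True
print(can_support(FakeLoRAConfig(r=64, target_modules={"qkv_proj"})))  # False: rank too large
```
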
@@ -139,10 +139,10 @@ class LoRAMemoryPool:

         def init_buffer(
             buffer: Dict[str, List[torch.Tensor]],
-            lora_weight_names: Set[str],
+            target_modules: Set[str],
             get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
         ):
-            for module_name in lora_weight_names:
+            for module_name in target_modules:
                 lora_shape = get_lora_shape_fn(
                     module_name, base_model, self.max_lora_rank
                 )
@@ -157,13 +157,13 @@ class LoRAMemoryPool:

         init_buffer(
             self.A_buffer,
-            self.lora_weight_names,
+            self.target_modules,
             self.get_lora_A_shape,
         )

         init_buffer(
             self.B_buffer,
-            self.lora_weight_names,
+            self.target_modules,
             self.get_lora_B_shape,
         )

@@ -242,32 +242,34 @@ class LoRAMemoryPool:
         for layer_id in range(self.num_layer):
             layer_weights = lora_adapter.layers[layer_id].weights
             temp_A_buffer: Dict[str, Optional[torch.Tensor]] = {
-                weight_name: None for weight_name in self.A_buffer
+                target_module: None for target_module in self.A_buffer
             }
             temp_B_buffer: Dict[str, Optional[torch.Tensor]] = {
-                weight_name: None for weight_name in self.B_buffer
+                target_module: None for target_module in self.B_buffer
             }
             for name, weights in layer_weights.items():
-                lora_weight_name = get_weight_name(name, self.lora_weight_names)
+                target_module = get_target_module_name(name, self.target_modules)
                 if "lora_A" in name:
-                    temp_A_buffer[lora_weight_name] = weights
+                    temp_A_buffer[target_module] = weights
                 else:
-                    temp_B_buffer[lora_weight_name] = weights
+                    temp_B_buffer[target_module] = weights

             if self.tp_size > 1:
                 cur_layer_modules = lora_modules[layer_id]
                 for module_name, module in cur_layer_modules.items():
-                    weight_name = get_weight_name(module_name, self.lora_weight_names)
+                    target_module = get_target_module_name(
+                        module_name, self.target_modules
+                    )

-                    if temp_A_buffer[weight_name] is None:
+                    if temp_A_buffer[target_module] is None:
                         # Skip weight slicing if the weight is not present in the adapter
                         continue

-                    temp_A_buffer[weight_name] = module.slice_lora_a_weights(
-                        temp_A_buffer[weight_name], self.tp_rank
+                    temp_A_buffer[target_module] = module.slice_lora_a_weights(
+                        temp_A_buffer[target_module], self.tp_rank
                     )
-                    temp_B_buffer[weight_name] = module.slice_lora_b_weights(
-                        temp_B_buffer[weight_name], self.tp_rank
+                    temp_B_buffer[target_module] = module.slice_lora_b_weights(
+                        temp_B_buffer[target_module], self.tp_rank
                     )

             for name, weights in temp_A_buffer.items():
@@ -282,12 +284,12 @@ class LoRAMemoryPool:
                 load_lora_weight_tensor(buffer_view, weights)

     def get_tensor(
-        self, weight_name: str, layer_id: int, lora_type: LoRAType
+        self, target_module: str, layer_id: int, lora_type: LoRAType
     ) -> torch.Tensor:
         if lora_type == LoRAType.LORA_A:
-            return self.A_buffer[weight_name][layer_id]
+            return self.A_buffer[target_module][layer_id]

-        return self.B_buffer[weight_name][layer_id]
+        return self.B_buffer[target_module][layer_id]

     def get_buffer_id(self, lora_uid: str):
         return self.uid_to_buffer_id[lora_uid]
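
As a rough illustration of what `get_tensor` now indexes into, the sketch below models the A/B buffers as dicts keyed by normalized target module, with one tensor per layer; the shapes, sizes, and module names are placeholders rather than the pool's real allocation logic:

```python
# Toy model of the per-target-module buffers; sizes and names are placeholders.
from enum import Enum, auto
from typing import Dict, List

import torch


class LoRAType(Enum):  # mirrors the enum referenced in the diff
    LORA_A = auto()
    LORA_B = auto()


num_layers, max_rank, hidden = 2, 8, 32
A_buffer: Dict[str, List[torch.Tensor]] = {
    "qkv_proj": [torch.zeros(max_rank, hidden) for _ in range(num_layers)]
}
B_buffer: Dict[str, List[torch.Tensor]] = {
    "qkv_proj": [torch.zeros(hidden, max_rank) for _ in range(num_layers)]
}


def get_tensor(target_module: str, layer_id: int, lora_type: LoRAType) -> torch.Tensor:
    # Same branch structure as LoRAMemoryPool.get_tensor above.
    if lora_type == LoRAType.LORA_A:
        return A_buffer[target_module][layer_id]
    return B_buffer[target_module][layer_id]


print(get_tensor(target_module="qkv_proj", layer_id=0, lora_type=LoRAType.LORA_A).shape)
```
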
@@ -84,7 +84,7 @@ def get_hidden_dim(
     raise NotImplementedError()


-def get_normalized_lora_weight_names(
+def get_normalized_target_modules(
     target_modules: Iterable[str],
 ) -> set[str]:
     """
@@ -100,8 +100,8 @@ def get_normalized_lora_weight_names(

     result = set()
     for name in target_modules:
-        weight_name = params_mapping.get(name, name)
-        result.add(weight_name)
+        normalized_name = params_mapping.get(name, name)
+        result.add(normalized_name)
     return result


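
The loop above leans on a `params_mapping` dict (defined elsewhere in utils.py) that collapses per-projection names into the stacked modules the runtime actually uses; unknown names pass through unchanged. The mapping in this sketch is an assumption for illustration only:

```python
# Hypothetical params_mapping used only for this illustration; the real mapping
# lives in sglang.srt.lora.utils and may differ.
from typing import Iterable, Set

params_mapping = {
    "q_proj": "qkv_proj",
    "k_proj": "qkv_proj",
    "v_proj": "qkv_proj",
    "gate_proj": "gate_up_proj",
    "up_proj": "gate_up_proj",
}


def get_normalized_target_modules(target_modules: Iterable[str]) -> Set[str]:
    # Same idea as the function above: map each raw name, pass unknowns through.
    return {params_mapping.get(name, name) for name in target_modules}


print(get_normalized_target_modules(["q_proj", "k_proj", "v_proj", "o_proj", "up_proj"]))
# -> {'qkv_proj', 'o_proj', 'gate_up_proj'} (set order may vary)
```

This stacking is also why "qkv_proj" and "gate_up_proj" are added to SUPPORTED_LORA_TARGET_MODULES further down in this commit.
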
@@ -116,20 +116,18 @@ def get_stacked_multiply(module_name: str) -> int:
     return stacked_rank[module_name] if module_name in stacked_rank else 1


-def get_weight_name(
-    target_name: str, lora_weight_names: Tuple[Set[str]]
-) -> Optional[str]:
+def get_target_module_name(full_module_name: str, target_modules: Set[str]) -> str:
     """
-    Get the weight name in lora_weight_names that can match target_name.
+    Get the target module name in target_modules that can match full_module_name.

-    If there is a weight name in lora_weight_names that can match target_name, return this name
+    If there is a target module name in target_modules that can match full_module_name, return this name
     Else raise ValueError.
     """
-    for weight_name in lora_weight_names:
-        if weight_name in target_name:
-            return weight_name
+    for target_module in target_modules:
+        if target_module in full_module_name:
+            return target_module
     raise ValueError(
-        f"Cannot find weight name for {target_name} in {lora_weight_names}"
+        f"Cannot find target module name for {full_module_name} in {target_modules}"
     )


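
A quick usage sketch of the renamed helper: matching is a plain substring test of each target module against the fully qualified parameter name, so an adapter checkpoint's weight name resolves to the buffer key it should load into. The module path below is made up for the example:

```python
# Re-implements the matching rule shown above and applies it to a made-up path.
from typing import Set


def get_target_module_name(full_module_name: str, target_modules: Set[str]) -> str:
    for target_module in target_modules:
        if target_module in full_module_name:
            return target_module
    raise ValueError(
        f"Cannot find target module name for {full_module_name} in {target_modules}"
    )


target_modules = {"qkv_proj", "o_proj", "gate_up_proj", "down_proj"}
name = "base_model.model.layers.0.self_attn.qkv_proj.lora_A.weight"  # hypothetical
print(get_target_module_name(name, target_modules))  # -> qkv_proj
```
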
@@ -2874,6 +2874,8 @@ SUPPORTED_LORA_TARGET_MODULES = [
     "gate_proj",
     "up_proj",
     "down_proj",
+    "qkv_proj",
+    "gate_up_proj",
 ]

 LORA_TARGET_ALL_MODULES = "all"