[chore] Clean up redundant lora_weight_names concept to simplify code (#9131)
@@ -32,8 +32,8 @@ from sglang.srt.lora.utils import (
     LoRABatchInfo,
     LoRAType,
     get_layer_id,
-    get_normalized_lora_weight_names,
-    get_weight_name,
+    get_normalized_target_modules,
+    get_target_module_name,
 )
 from sglang.srt.managers.io_struct import LoRAUpdateResult
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -350,12 +350,20 @@ class LoRAManager:
         """
         for layer_id, layer_modules in enumerate(self.lora_modules):
             for module_name, module in layer_modules.items():
-                weight_name = get_weight_name(
-                    module_name, self.memory_pool.lora_weight_names
+                target_module = get_target_module_name(
+                    module_name, self.memory_pool.target_modules
                 )
                 module.set_lora_info(
-                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_A),
-                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_B),
+                    self.memory_pool.get_tensor(
+                        target_module=target_module,
+                        layer_id=layer_id,
+                        lora_type=LoRAType.LORA_A,
+                    ),
+                    self.memory_pool.get_tensor(
+                        target_module=target_module,
+                        layer_id=layer_id,
+                        lora_type=LoRAType.LORA_B,
+                    ),
                 )

     def init_state(
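For readers scanning the hunk above: the memory-pool accessor is unchanged in behavior, but call sites now pass keyword arguments keyed by the normalized target-module name, so three same-typed arguments become self-describing. A minimal self-contained sketch of the new call shape (FakePool and its buffer contents are hypothetical stand-ins, not sglang code):

from enum import Enum, auto

class LoRAType(Enum):  # mirrors sglang.srt.lora.utils.LoRAType for this sketch
    LORA_A = auto()
    LORA_B = auto()

class FakePool:
    # Hypothetical stand-in for LoRAMemoryPool: maps target module -> per-layer tensors.
    def __init__(self):
        self.A_buffer = {"qkv_proj": ["A tensor, layer 0"]}
        self.B_buffer = {"qkv_proj": ["B tensor, layer 0"]}

    def get_tensor(self, target_module: str, layer_id: int, lora_type: LoRAType):
        buffer = self.A_buffer if lora_type == LoRAType.LORA_A else self.B_buffer
        return buffer[target_module][layer_id]

pool = FakePool()
print(pool.get_tensor(target_module="qkv_proj", layer_id=0, lora_type=LoRAType.LORA_A))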
@@ -380,7 +388,6 @@ class LoRAManager:
             max_lora_rank=max_lora_rank,
             target_modules=target_modules,
         )
-        self.init_lora_weight_names()
         self.init_lora_modules()
         self.init_memory_pool()
         self.update_lora_info()
@@ -426,6 +433,7 @@ class LoRAManager:
                    "enable all support modules types. "
                )
                self.target_modules.update(config.target_modules)
+        self.target_modules = get_normalized_target_modules(self.target_modules)

        if max_lora_rank is not None:
            self.max_lora_rank = max_lora_rank
@@ -435,15 +443,6 @@ class LoRAManager:
                default=0,
            )

-    def init_lora_weight_names(self):
-        """
-        Add new LoRA weight names if needed based on the current `self.configs`.
-        """
-
-        self.lora_weight_names: Set[str] = get_normalized_lora_weight_names(
-            self.target_modules
-        )
-
     def load_lora_weights(self, lora_ref: LoRARef):
         """
         Load the weights of a LoRA adapter to CPU memory and conducts post-loading validation.
@@ -467,7 +466,7 @@ class LoRAManager:
             tp_size=self.tp_size,
             tp_rank=self.tp_rank,
             max_lora_rank=self.max_lora_rank,
-            lora_weight_names=self.lora_weight_names,
+            target_modules=self.target_modules,
             base_model=self.base_model,
         )

@@ -494,7 +493,7 @@ class LoRAManager:
                 continue

             # The module should be converted if it is included in target_names
-            if module_name.split(".")[-1] in self.lora_weight_names:
+            if module_name.split(".")[-1] in self.target_modules:
                 layer_id = get_layer_id(module_name)
                 self.lora_modules[layer_id][module_name] = self.set_lora_module(
                     module_name, module
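The membership test in this hunk relies on module names ending in their target-module suffix. A small sketch of the check, with an illustrative module path:

target_modules = {"qkv_proj", "o_proj", "gate_up_proj", "down_proj"}

# The last dotted component of the full module name is compared against the
# normalized target set.
module_name = "model.layers.0.self_attn.qkv_proj"
print(module_name.split(".")[-1] in target_modules)  # True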
@@ -13,9 +13,9 @@ from sglang.srt.lora.utils import (
     ROW_PARALLELISM_LINEAR_LORA_NAMES,
     LoRAType,
     get_hidden_dim,
-    get_normalized_lora_weight_names,
+    get_normalized_target_modules,
     get_stacked_multiply,
-    get_weight_name,
+    get_target_module_name,
 )

 logger = logging.getLogger(__name__)
@@ -52,7 +52,7 @@ class LoRAMemoryPool:
         tp_size: int,
         tp_rank: int,
         max_lora_rank: int,
-        lora_weight_names: Set[str],
+        target_modules: Set[str],
         base_model: torch.nn.Module,
     ):
         self.base_hf_config: AutoConfig = base_hf_config
@@ -62,7 +62,7 @@ class LoRAMemoryPool:
         self.tp_size: int = tp_size
         self.tp_rank: int = tp_rank
         self.max_lora_rank: int = max_lora_rank
-        self.lora_weight_names: Set[str] = lora_weight_names
+        self.target_modules: Set[str] = target_modules

         # Both A_buffer and B_buffer maps lora weight names to its buffer space.
         # A_buffer contains num_layer number of row-major tensors with shape
@@ -95,8 +95,8 @@ class LoRAMemoryPool:
            """
            if config.r > self.max_lora_rank:
                return False
-            weights = get_normalized_lora_weight_names(config.target_modules)
-            return weights.issubset(self.lora_weight_names)
+            target_module_names = get_normalized_target_modules(config.target_modules)
+            return target_module_names.issubset(self.target_modules)

        if isinstance(config, LoRAConfig):
            return _can_support(config)
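The support check keeps the same shape as before the rename: normalize the adapter's target modules, then require them to be a subset of what the pool allocated buffers for. A sketch assuming both sets are already normalized:

# Pool buffers allocated at startup (assumed already normalized).
pool_target_modules = {"qkv_proj", "o_proj", "gate_up_proj", "down_proj"}

# An adapter is servable only if every module it touches has buffer space.
print({"qkv_proj", "gate_up_proj"}.issubset(pool_target_modules))  # True
print({"embed_tokens"}.issubset(pool_target_modules))              # False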
@@ -139,10 +139,10 @@ class LoRAMemoryPool:

         def init_buffer(
             buffer: Dict[str, List[torch.Tensor]],
-            lora_weight_names: Set[str],
+            target_modules: Set[str],
             get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
         ):
-            for module_name in lora_weight_names:
+            for module_name in target_modules:
                 lora_shape = get_lora_shape_fn(
                     module_name, base_model, self.max_lora_rank
                 )
@@ -157,13 +157,13 @@ class LoRAMemoryPool:

         init_buffer(
             self.A_buffer,
-            self.lora_weight_names,
+            self.target_modules,
             self.get_lora_A_shape,
         )

         init_buffer(
             self.B_buffer,
-            self.lora_weight_names,
+            self.target_modules,
             self.get_lora_B_shape,
         )

@@ -242,32 +242,34 @@ class LoRAMemoryPool:
         for layer_id in range(self.num_layer):
             layer_weights = lora_adapter.layers[layer_id].weights
             temp_A_buffer: Dict[str, Optional[torch.Tensor]] = {
-                weight_name: None for weight_name in self.A_buffer
+                target_module: None for target_module in self.A_buffer
             }
             temp_B_buffer: Dict[str, Optional[torch.Tensor]] = {
-                weight_name: None for weight_name in self.B_buffer
+                target_module: None for target_module in self.B_buffer
             }
             for name, weights in layer_weights.items():
-                lora_weight_name = get_weight_name(name, self.lora_weight_names)
+                target_module = get_target_module_name(name, self.target_modules)
                 if "lora_A" in name:
-                    temp_A_buffer[lora_weight_name] = weights
+                    temp_A_buffer[target_module] = weights
                 else:
-                    temp_B_buffer[lora_weight_name] = weights
+                    temp_B_buffer[target_module] = weights

             if self.tp_size > 1:
                 cur_layer_modules = lora_modules[layer_id]
                 for module_name, module in cur_layer_modules.items():
-                    weight_name = get_weight_name(module_name, self.lora_weight_names)
+                    target_module = get_target_module_name(
+                        module_name, self.target_modules
+                    )

-                    if temp_A_buffer[weight_name] is None:
+                    if temp_A_buffer[target_module] is None:
                         # Skip weight slicing if the weight is not present in the adapter
                         continue

-                    temp_A_buffer[weight_name] = module.slice_lora_a_weights(
-                        temp_A_buffer[weight_name], self.tp_rank
+                    temp_A_buffer[target_module] = module.slice_lora_a_weights(
+                        temp_A_buffer[target_module], self.tp_rank
                     )
-                    temp_B_buffer[weight_name] = module.slice_lora_b_weights(
-                        temp_B_buffer[weight_name], self.tp_rank
+                    temp_B_buffer[target_module] = module.slice_lora_b_weights(
+                        temp_B_buffer[target_module], self.tp_rank
                     )

             for name, weights in temp_A_buffer.items():
@@ -282,12 +284,12 @@ class LoRAMemoryPool:
                 load_lora_weight_tensor(buffer_view, weights)

     def get_tensor(
-        self, weight_name: str, layer_id: int, lora_type: LoRAType
+        self, target_module: str, layer_id: int, lora_type: LoRAType
     ) -> torch.Tensor:
         if lora_type == LoRAType.LORA_A:
-            return self.A_buffer[weight_name][layer_id]
+            return self.A_buffer[target_module][layer_id]

-        return self.B_buffer[weight_name][layer_id]
+        return self.B_buffer[target_module][layer_id]

     def get_buffer_id(self, lora_uid: str):
         return self.uid_to_buffer_id[lora_uid]
@@ -84,7 +84,7 @@ def get_hidden_dim(
     raise NotImplementedError()


-def get_normalized_lora_weight_names(
+def get_normalized_target_modules(
     target_modules: Iterable[str],
 ) -> set[str]:
     """
@@ -100,8 +100,8 @@ def get_normalized_lora_weight_names(

     result = set()
     for name in target_modules:
-        weight_name = params_mapping.get(name, name)
-        result.add(weight_name)
+        normalized_name = params_mapping.get(name, name)
+        result.add(normalized_name)
     return result

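Behavior of the renamed helper, transcribed from the diff; the params_mapping here is an illustrative assumption (the real mapping is defined elsewhere in this file and folds stacked projections into their fused names):

# Assumed mapping for illustration only.
params_mapping = {
    "q_proj": "qkv_proj",
    "k_proj": "qkv_proj",
    "v_proj": "qkv_proj",
    "gate_proj": "gate_up_proj",
    "up_proj": "gate_up_proj",
}

def get_normalized_target_modules(target_modules):
    result = set()
    for name in target_modules:
        normalized_name = params_mapping.get(name, name)
        result.add(normalized_name)
    return result

print(get_normalized_target_modules(["q_proj", "v_proj", "o_proj"]))
# -> {'qkv_proj', 'o_proj'} (set ordering may vary)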
@@ -116,20 +116,18 @@ def get_stacked_multiply(module_name: str) -> int:
     return stacked_rank[module_name] if module_name in stacked_rank else 1


-def get_weight_name(
-    target_name: str, lora_weight_names: Tuple[Set[str]]
-) -> Optional[str]:
+def get_target_module_name(full_module_name: str, target_modules: Set[str]) -> str:
     """
-    Get the weight name in lora_weight_names that can match target_name.
+    Get the target module name in target_modules that can match full_module_name.

-    If there is a weight name in lora_weight_names that can match target_name, return this name
+    If there is a target module name in target_modules that can match full_module_name, return this name
     Else raise ValueError.
     """
-    for weight_name in lora_weight_names:
-        if weight_name in target_name:
-            return weight_name
+    for target_module in target_modules:
+        if target_module in full_module_name:
+            return target_module
     raise ValueError(
-        f"Cannot find weight name for {target_name} in {lora_weight_names}"
+        f"Cannot find target module name for {full_module_name} in {target_modules}"
     )

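Since the hunk above shows the full body of the renamed helper, here is a runnable transcription with a usage example (the module path is illustrative):

from typing import Set

def get_target_module_name(full_module_name: str, target_modules: Set[str]) -> str:
    # Matching is by substring: a fully qualified parameter path resolves to
    # whichever normalized target-module name it contains.
    for target_module in target_modules:
        if target_module in full_module_name:
            return target_module
    raise ValueError(
        f"Cannot find target module name for {full_module_name} in {target_modules}"
    )

print(get_target_module_name(
    "model.layers.0.self_attn.qkv_proj.lora_A.weight",
    {"qkv_proj", "o_proj"},
))  # qkv_proj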
@@ -2874,6 +2874,8 @@ SUPPORTED_LORA_TARGET_MODULES = [
     "gate_proj",
     "up_proj",
     "down_proj",
+    "qkv_proj",
+    "gate_up_proj",
 ]

 LORA_TARGET_ALL_MODULES = "all"
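With the two fused names added, user-specified target modules can name either the split or the stacked projections. A sketch of how a caller might validate a request against the list (the helper function and the list entries beyond those shown in the hunk are assumptions, not sglang's actual CLI handling):

SUPPORTED_LORA_TARGET_MODULES = [
    "q_proj", "k_proj", "v_proj", "o_proj",   # assumed earlier entries
    "gate_proj", "up_proj", "down_proj",
    "qkv_proj", "gate_up_proj",               # added by this commit
]
LORA_TARGET_ALL_MODULES = "all"

def validate_target_modules(requested):
    # "all" expands to every supported module; anything else must be known.
    if requested == [LORA_TARGET_ALL_MODULES]:
        return list(SUPPORTED_LORA_TARGET_MODULES)
    unknown = [m for m in requested if m not in SUPPORTED_LORA_TARGET_MODULES]
    if unknown:
        raise ValueError(f"Unsupported LoRA target modules: {unknown}")
    return requested

print(validate_target_modules(["qkv_proj", "gate_up_proj"]))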