Files
sglang/python/sglang/srt/lora/utils.py
2025-09-04 15:56:33 +08:00

135 lines
3.8 KiB
Python

import re
from dataclasses import dataclass
from enum import Enum
from typing import Iterable, Optional, Set, Tuple
import torch
from sglang.srt.hf_transformers_utils import AutoConfig
@dataclass
class LoRABatchInfo:
# Batch size
bs: int
# Lengths of each sequence in shape (bs,)
seg_lens: torch.Tensor
# Indice pointers of each sequence in shape (bs + 1, )
seg_indptr: torch.Tensor
# Maximum sequence length of current batch
max_len: int
# The index of lora adapter used by each sequence, in shape (bs,)
weight_indices: torch.Tensor
# ranks of each lora adapter, in shape (lora_num,)
lora_ranks: torch.Tensor
# scaling of each lora adapter, in shape (lora_num,)
scalings: torch.Tensor
class LoRAType(Enum):
LORA_A = 0
LORA_B = 1
def get_layer_id(name: str) -> int:
"""
Extract integer id of layer from its name in string.
"""
match = re.search(r"layers\.(\d+)\.", name)
if match is None:
return None
return int(match.group(1))
def get_hidden_dim(
module_name: str, config: AutoConfig, base_model: torch.nn.Module
) -> Tuple[int]:
"""
Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
"""
if hasattr(base_model, "get_hidden_dim"):
return base_model.get_hidden_dim(module_name)
else:
"""
WARNING: get_hidden_dim() is not defined,
which is used to get the hidden dim for different lora modules
Use the default one, but please check if it is correct for your model.
Please implement the function in the model class if it is not.
You can reference this function in llama.py.
"""
head_dim = getattr(
config, "head_dim", config.hidden_size // config.num_attention_heads
)
if module_name == "qkv_proj":
return config.hidden_size, head_dim * (
config.num_attention_heads + config.num_key_value_heads * 2
)
elif module_name == "o_proj":
return (
head_dim * config.num_attention_heads,
config.hidden_size,
)
elif module_name == "gate_up_proj":
return config.hidden_size, config.intermediate_size * 2
elif module_name == "down_proj":
return config.intermediate_size, config.hidden_size
else:
raise NotImplementedError()
def get_normalized_target_modules(
target_modules: Iterable[str],
) -> set[str]:
"""
Mapping a list of target module name to names of the normalized LoRA weights.
"""
params_mapping = {
"q_proj": "qkv_proj",
"k_proj": "qkv_proj",
"v_proj": "qkv_proj",
"gate_proj": "gate_up_proj",
"up_proj": "gate_up_proj",
}
result = set()
for name in target_modules:
normalized_name = params_mapping.get(name, name)
result.add(normalized_name)
return result
def get_stacked_multiply(module_name: str) -> int:
"""
Mapping a lora module name to its magnification at output dimension
"""
stacked_rank = {
"qkv_proj": 3,
"gate_up_proj": 2,
}
return stacked_rank[module_name] if module_name in stacked_rank else 1
def get_target_module_name(full_module_name: str, target_modules: Set[str]) -> str:
"""
Get the target module name in target_modules that can match full_module_name.
If there is a target module name in target_modules that can match full_module_name, return this name
Else raise ValueError.
"""
for target_module in target_modules:
if target_module in full_module_name:
return target_module
raise ValueError(
f"Cannot find target module name for {full_module_name} in {target_modules}"
)
ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj"]