"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
# and "Punica: Multi-Tenant LoRA Serving"
import logging
import re
import torch
from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_hip, replace_submodule
logger = logging.getLogger(__name__)

# FlashInfer (which provides SegmentGEMMWrapper) is not yet available on ROCm.
if not is_hip():
    from flashinfer import SegmentGEMMWrapper


def get_module_name(name):
    # Fallback mapping from a config module name to the module name used in the
    # model class. Please check that it matches your base model; if it does not,
    # implement get_module_name() in the model class instead (see llama.py for a
    # reference implementation).
    params_mapping = {
        "q_proj": "qkv_proj",
        "k_proj": "qkv_proj",
        "v_proj": "qkv_proj",
        "gate_proj": "gate_up_proj",
        "up_proj": "gate_up_proj",
    }
    return params_mapping.get(name, name)
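
# Example: stacked projections share one fused module in the model class, so
# get_module_name("q_proj") == get_module_name("v_proj") == "qkv_proj", while
# an unmapped name such as "down_proj" is returned unchanged.

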
def get_hidden_dim(module_name, config):
    # Fallback for computing the (input_dim, output_dim) of different modules.
    # Please check that it matches your base model; if it does not, implement
    # get_hidden_dim() in the model class instead (see llama.py for a reference
    # implementation).
    if module_name in ["q_proj", "o_proj", "qkv_proj"]:
        return config.hidden_size, config.hidden_size
    elif module_name in ["kv_proj"]:
        return config.hidden_size, config.hidden_size // (
            config.num_attention_heads // config.num_key_value_heads
        )
    elif module_name == "gate_up_proj":
        return config.hidden_size, config.intermediate_size
    elif module_name == "down_proj":
        return config.intermediate_size, config.hidden_size
    else:
        raise NotImplementedError()
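
# Worked example (assuming a Llama-3-8B-style config: hidden_size=4096,
# intermediate_size=14336, num_attention_heads=32, num_key_value_heads=8):
#   get_hidden_dim("kv_proj", config)   -> (4096, 4096 // (32 // 8)) = (4096, 1024)
#   get_hidden_dim("down_proj", config) -> (14336, 4096)

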
def get_stacked_name(name):
    # Maps an original module name to (name in the A buffer, name in the B buffer).
    params_mapping = {
        "q_proj": ("qkv_proj", "q_proj"),
        "k_proj": ("qkv_proj", "kv_proj"),
        "v_proj": ("qkv_proj", "kv_proj"),
        "gate_proj": ("gate_up_proj", "gate_up_proj"),
        "up_proj": ("gate_up_proj", "gate_up_proj"),
    }
    return params_mapping.get(name, (name, name))
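
# Example: the A matrices of q/k/v are stacked into one "qkv_proj" buffer, but
# their B matrices stay split: get_stacked_name("k_proj") -> ("qkv_proj", "kv_proj").
# An unmapped name maps to itself: get_stacked_name("o_proj") -> ("o_proj", "o_proj").

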
def get_layer_id(name):
    # Extracts the decoder layer index from a parameter/module name.
    match = re.search(r"layers\.(\d+)\.", name)
    if match is None:
        return None
    return int(match.group(1))
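
# Example: get_layer_id("model.layers.11.self_attn.qkv_proj") -> 11, while a
# name outside the decoder stack (e.g. "model.embed_tokens") returns None.

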
class LoRAManager:
    def __init__(
        self,
        base_model,
        lora_paths,
        base_hf_config,
        max_loras_per_batch,
        load_config,
        dtype,
    ):
        self.base_model = base_model
        self.lora_paths = lora_paths
        self.base_hf_config = base_hf_config
        self.max_loras_per_batch = max_loras_per_batch
        self.load_config = load_config
        self.dtype = dtype

        # 1 MiB workspace buffer for FlashInfer's segment GEMM kernels.
        workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
        self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)

        self.init_loras()
        self.init_lora_memory_pool()
        self.init_lora_batch()
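
    # Illustrative construction (a sketch; the argument values below are
    # hypothetical, not part of this module):
    #
    #   manager = LoRAManager(
    #       base_model=model,                            # already-loaded base model
    #       lora_paths={"sql": "/path/to/sql_adapter"},  # uid -> adapter path
    #       base_hf_config=hf_config,
    #       max_loras_per_batch=4,                       # slots in the GPU pool
    #       load_config=load_config,
    #       dtype=torch.float16,
    #   )
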
    def match_target_modules(self, module_name):
        for target_module in self.target_modules:
            if module_name.split(".")[-1] == target_module:
                return True
        return False

    def get_target_modules(self):
        modules = []
        for module_name, module in self.base_model.named_modules():
            if self.match_target_modules(module_name):
                modules.append((module_name, module))
        return modules

    def set_lora_module(self, module_name, module):
        # Wrap the base module with its LoRA-aware counterpart and splice it
        # into the model in place.
        lora_module = get_lora_layer(
            module, self.segment_gemm, self.max_lora_dim, self.scaling
        )
        replace_submodule(self.base_model, module_name, lora_module)
        return lora_module

    def init_loras(self):
        # Get configs and target modules.
        self.configs = {}
        self.origin_target_modules = set()
        for name, path in self.lora_paths.items():
            self.configs[name] = LoRAConfig(path)
            self.origin_target_modules |= set(self.configs[name].target_modules)

        if hasattr(self.base_model, "get_module_name"):
            self.target_modules = {
                self.base_model.get_module_name(module)
                for module in self.origin_target_modules
            }
        else:
            logger.warning(
                "WARNING: get_module_name() is not defined, "
                "which is used to map a config module name to the module name in the model class. "
                "Using the default mapping; please check that it is correct for your model."
            )
            self.target_modules = {
                get_module_name(module) for module in self.origin_target_modules
            }
        self.target_weights = {
            get_stacked_name(module) for module in self.origin_target_modules
        }

        # Load all weights to CPU.
        self.loras = []
        self.lora_id = {}
        for name in self.lora_paths.keys():
            self.lora_id[name] = len(self.loras)
            self.loras.append(
                LoRAAdapter(
                    name, self.configs[name], self.base_hf_config, self.load_config
                )
            )
            self.loras[-1].initialize_weights()

        # Misc lora configs.
        self.max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
        self.scaling = self.loras[0].scaling
        # FIXME: remove the restriction that all adapters share the same rank and scaling.
        assert all(x.hf_config["r"] == self.max_lora_dim for x in self.configs.values())
        assert all(x.scaling == self.scaling for x in self.loras)

        # Monkey-patch the target modules to use the LoRA versions.
        self.lora_modules = []
        for module_name, module in self.get_target_modules():
            self.lora_modules.append(
                (module_name, self.set_lora_module(module_name, module))
            )

    def init_lora_memory_pool(self):
        # Preallocate the lora memory pool: one slot per adapter, per layer,
        # for both the A and B matrices.
        self.A_buffer = {}
        self.B_buffer = {}
        num_layer = self.base_hf_config.num_hidden_layers
        for module_A, module_B in self.target_weights:
            # Init A tensor, column_major=True.
            if hasattr(self.base_model, "get_hidden_dim"):
                hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A)
            else:
                logger.warning(
                    "WARNING: get_hidden_dim() is not defined, "
                    "which is used to get the hidden dims of the lora modules. "
                    "Using the default values; please check that they are correct for your model."
                )
                hidden_dim_A, _ = get_hidden_dim(module_A, self.base_hf_config)
            c = self.loras[-1].get_stacked_multiply(module_A)
            if module_A not in self.A_buffer:
                self.A_buffer[module_A] = [
                    torch.empty(
                        (
                            self.max_loras_per_batch,
                            self.max_lora_dim * c,
                            hidden_dim_A,
                        ),
                        dtype=self.dtype,
                        device="cuda",
                    )
                    for i in range(num_layer)
                ]

            # Init B tensor, column_major=True.
            if hasattr(self.base_model, "get_hidden_dim"):
                _, hidden_dim_B = self.base_model.get_hidden_dim(module_B)
            else:
                logger.warning(
                    "WARNING: get_hidden_dim() is not defined, "
                    "which is used to get the hidden dims of the lora modules. "
                    "Using the default values; please check that they are correct for your model."
                )
                _, hidden_dim_B = get_hidden_dim(module_B, self.base_hf_config)
            c = self.loras[-1].get_stacked_multiply(module_B)
            if module_B not in self.B_buffer:
                self.B_buffer[module_B] = [
                    torch.empty(
                        (
                            self.max_loras_per_batch,
                            hidden_dim_B * c,
                            self.max_lora_dim,
                        ),
                        dtype=self.dtype,
                        device="cuda",
                    )
                    for i in range(num_layer)
                ]
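
    # Rough pool-size arithmetic (a sketch with assumed numbers, not measured):
    # for a single qkv_proj A buffer with max_loras_per_batch=4, max_lora_dim=16,
    # c=3 (q, k, v stacked), hidden_dim_A=4096, 32 layers, fp16 (2 bytes):
    #   4 * (16 * 3) * 4096 * 32 * 2 bytes = 48 MiB, allocated once at startup.
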
    def init_lora_batch(self):
        self.active_uids = set()  # set of active lora uids
        self.buffer_id = {}  # lora uid -> idx in memory pool

    def get_weight_name(self, name, idx):
        # idx 0 selects the A-buffer name, idx 1 the B-buffer name
        # (see get_stacked_name). Returns None if nothing matches.
        for target_weight_name in self.target_weights:
            if target_weight_name[idx] in name:
                return target_weight_name[idx]
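
    # Example: if self.target_weights contains ("qkv_proj", "q_proj"), then for
    # a weight named "...qkv_proj.lora_A.weight", get_weight_name(name, 0)
    # returns "qkv_proj" because the A-buffer name appears in the weight name.
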
    def load_lora(self, uid, buffer_id):
        num_layer = self.base_hf_config.num_hidden_layers
        if uid is None:
            # No adapter: zero out the slot so the LoRA delta is a no-op.
            for i in range(num_layer):
                for k in self.A_buffer.keys():
                    self.A_buffer[k][i][buffer_id] *= 0
            return

        for i in range(num_layer):
            layer_weights = self.loras[self.lora_id[uid]].layers[i].weights
            for name, weights in layer_weights.items():
                if "lora_A" in name:
                    lora_weight_name = self.get_weight_name(name, 0)
                    if lora_weight_name:
                        self.A_buffer[lora_weight_name][i][buffer_id].copy_(weights)
                else:
                    lora_weight_name = self.get_weight_name(name, 1)
                    if lora_weight_name:
                        self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights)

    def prepare_lora_batch(self, forward_batch: ForwardBatch):
        # Load the active loras into the lora memory pool, evicting resident
        # adapters that the current batch does not need.
        cur_uids = set(forward_batch.lora_paths)
        assert len(cur_uids) <= self.max_loras_per_batch
        i = 0
        evictable_uids = list(self.active_uids)
        for uid in cur_uids:
            if uid not in self.active_uids:
                # Skip slots whose adapters are still needed by this batch.
                while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
                    i += 1
                if i < len(evictable_uids):
                    self.active_uids.remove(evictable_uids[i])
                    self.buffer_id.pop(evictable_uids[i])
                self.load_lora(uid, i)
                self.active_uids.add(uid)
                self.buffer_id[uid] = i
                i += 1

        if cur_uids == {None}:
            return

        # Set up lora info in the forward modules.
        bs = forward_batch.batch_size
        seg_lens = (
            forward_batch.extend_seq_lens
            if forward_batch.forward_mode.is_extend()
            else torch.ones(bs)
        )
        weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
        for i, lora_path in enumerate(forward_batch.lora_paths):
            weight_indices[i] = self.buffer_id[lora_path]

        for module_name, module in self.lora_modules:
            layer_id = get_layer_id(module_name)
            if "qkv_proj" not in module_name:
                weight_name = self.get_weight_name(module_name, 0)
                module.set_lora_info(
                    self.A_buffer[weight_name][layer_id],
                    self.B_buffer[weight_name][layer_id],
                    bs,
                    seg_lens,
                    weight_indices,
                )
            else:
                module.set_lora_info(
                    self.A_buffer["qkv_proj"][layer_id],
                    self.B_buffer["q_proj"][layer_id],
                    self.B_buffer["kv_proj"][layer_id],
                    bs,
                    seg_lens,
                    weight_indices,
                )
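
    # Worked example (hypothetical uids): with max_loras_per_batch=2 and
    # adapters "a" and "b" resident, a batch with lora_paths ["a", "c"]
    # keeps "a", evicts "b" (the only resident adapter not in the batch),
    # and loads "c" into the freed slot; weight_indices then holds each
    # request's slot index in the memory pool.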