[Feature] Initial support for multi-LoRA serving (#1307)
This commit is contained in:
403
python/sglang/srt/lora/lora.py
Normal file
403
python/sglang/srt/lora/lora.py
Normal file
@@ -0,0 +1,403 @@
|
||||
"""
|
||||
Copyright 2023-2024 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
# Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
|
||||
# and "Punica: Multi-Tenant LoRA Serving"
|
||||
|
||||
# LoRA layers class inheritance adapted from:
|
||||
# https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py
|
||||
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from torch import nn
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.model_loader.loader import DefaultModelLoader
|
||||
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
|
||||
|
||||
|
||||
class BaseLayerWithLoRA(nn.Module):
|
||||
def __init__(self, base_layer, segment_gemm, lora_rank, scaling):
|
||||
super().__init__()
|
||||
self.base_layer = base_layer
|
||||
self.segment_gemm = segment_gemm
|
||||
self.lora_rank = lora_rank
|
||||
self.scaling = scaling
|
||||
self.set_lora = False
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return self.base_layer.forward(x)
|
||||
|
||||
def set_lora_info(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
|
||||
def __init__(
|
||||
self, base_layer: VocabParallelEmbedding, segment_gemm, lora_rank, scaling
|
||||
) -> None:
|
||||
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
|
||||
self.weight = base_layer.weight
|
||||
|
||||
|
||||
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
def __init__(
|
||||
self, base_layer: ColumnParallelLinear, segment_gemm, lora_rank, scaling
|
||||
) -> None:
|
||||
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
|
||||
|
||||
def apply_lora(self, output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||
# TODO
|
||||
return output
|
||||
|
||||
def forward(self, input_: torch.Tensor):
|
||||
# duplicate the logic in ColumnParallelLinear
|
||||
bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
|
||||
output_parallel = self.base_layer.quant_method.apply(
|
||||
self.base_layer, input_, bias
|
||||
)
|
||||
|
||||
if self.set_lora:
|
||||
output_parallel = self.apply_lora(output_parallel, input_)
|
||||
|
||||
if self.base_layer.gather_output:
|
||||
output = tensor_model_parallel_all_gather(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
def __init__(
|
||||
self, base_layer: MergedColumnParallelLinear, segment_gemm, lora_rank, scaling
|
||||
) -> None:
|
||||
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
|
||||
|
||||
def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
|
||||
self.set_lora = True
|
||||
self.A_buffer = A_buffer
|
||||
self.B_buffer = B_buffer
|
||||
self.bs = bs
|
||||
self.seq_lens = seq_lens
|
||||
self.weight_indices = weight_indices
|
||||
|
||||
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||
lora_a_output = self.segment_gemm.run(
|
||||
x=x,
|
||||
weights=self.A_buffer,
|
||||
batch_size=self.bs,
|
||||
weight_column_major=True,
|
||||
seg_lens=self.seq_lens,
|
||||
weight_indices=self.weight_indices,
|
||||
)
|
||||
# FIXME
|
||||
assert lora_a_output.shape[-1] == self.lora_rank * 2
|
||||
lora_output = torch.empty_like(base_output)
|
||||
output_dim = lora_output.shape[-1] // 2
|
||||
for i in range(2):
|
||||
left = output_dim * i
|
||||
right = left + output_dim
|
||||
lora_output[:, left:right] = self.segment_gemm.run(
|
||||
x=lora_a_output[
|
||||
:, self.lora_rank * i : self.lora_rank * (i + 1)
|
||||
].contiguous(),
|
||||
weights=self.B_buffer[:, left:right, :].contiguous(),
|
||||
batch_size=self.bs,
|
||||
weight_column_major=True,
|
||||
seg_lens=self.seq_lens,
|
||||
weight_indices=self.weight_indices,
|
||||
)
|
||||
return base_output + lora_output * self.scaling
|
||||
|
||||
|
||||
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
def __init__(
|
||||
self, base_layer: QKVParallelLinear, segment_gemm, lora_rank, scaling
|
||||
) -> None:
|
||||
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
|
||||
|
||||
def set_lora_info(
|
||||
self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seq_lens, weight_indices
|
||||
):
|
||||
self.set_lora = True
|
||||
self.A_buffer_qkv = A_buffer_qkv
|
||||
self.B_buffer_q = B_buffer_q
|
||||
self.B_buffer_kv = B_buffer_kv
|
||||
self.bs = bs
|
||||
self.seq_lens = seq_lens
|
||||
self.weight_indices = weight_indices
|
||||
|
||||
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||
lora_a_output = self.segment_gemm.run(
|
||||
x=x,
|
||||
weights=self.A_buffer_qkv,
|
||||
batch_size=self.bs,
|
||||
weight_column_major=True,
|
||||
seg_lens=self.seq_lens,
|
||||
weight_indices=self.weight_indices,
|
||||
)
|
||||
# FIXME parallelize qkv
|
||||
lora_output = torch.empty_like(base_output)
|
||||
# q
|
||||
output_dim_q = self.B_buffer_q.shape[-2]
|
||||
lora_output[:, :output_dim_q] = self.segment_gemm.run(
|
||||
x=lora_a_output[:, : self.lora_rank].contiguous(),
|
||||
weights=self.B_buffer_q,
|
||||
batch_size=self.bs,
|
||||
weight_column_major=True,
|
||||
seg_lens=self.seq_lens,
|
||||
weight_indices=self.weight_indices,
|
||||
)
|
||||
# kv
|
||||
output_dim_kv = self.B_buffer_kv.shape[-2] // 2
|
||||
for i in range(2):
|
||||
left = output_dim_kv * i
|
||||
right = left + output_dim_kv
|
||||
lora_output[:, output_dim_q + left : output_dim_q + right] = (
|
||||
self.segment_gemm.run(
|
||||
x=lora_a_output[
|
||||
:, self.lora_rank * (i + 1) : self.lora_rank * (i + 2)
|
||||
].contiguous(),
|
||||
weights=self.B_buffer_kv[:, left:right, :].contiguous(),
|
||||
batch_size=self.bs,
|
||||
weight_column_major=True,
|
||||
seg_lens=self.seq_lens,
|
||||
weight_indices=self.weight_indices,
|
||||
)
|
||||
)
|
||||
return base_output + lora_output * self.scaling
|
||||
|
||||
|
||||
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
def __init__(
|
||||
self, base_layer: RowParallelLinear, segment_gemm, lora_rank, scaling
|
||||
) -> None:
|
||||
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
|
||||
|
||||
def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
|
||||
self.set_lora = True
|
||||
self.A_buffer = A_buffer
|
||||
self.B_buffer = B_buffer
|
||||
self.bs = bs
|
||||
self.seq_lens = seq_lens
|
||||
self.weight_indices = weight_indices
|
||||
|
||||
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||
lora_output = self.segment_gemm.run(
|
||||
x=x,
|
||||
weights=self.A_buffer,
|
||||
batch_size=self.bs,
|
||||
weight_column_major=True,
|
||||
seg_lens=self.seq_lens,
|
||||
weight_indices=self.weight_indices,
|
||||
)
|
||||
lora_output = self.segment_gemm.run(
|
||||
x=lora_output,
|
||||
weights=self.B_buffer,
|
||||
batch_size=self.bs,
|
||||
weight_column_major=True,
|
||||
seg_lens=self.seq_lens,
|
||||
weight_indices=self.weight_indices,
|
||||
)
|
||||
return base_output + lora_output * self.scaling
|
||||
|
||||
def forward(self, input_):
|
||||
# duplicate the logic in RowParallelLinear
|
||||
if self.base_layer.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.base_layer.tp_size
|
||||
)
|
||||
input_parallel = splitted_input[tp_rank].contiguous()
|
||||
output_parallel = self.base_layer.quant_method.apply(
|
||||
self.base_layer, input_parallel
|
||||
)
|
||||
|
||||
if self.set_lora:
|
||||
output_parallel = self.apply_lora(output_parallel, input_parallel)
|
||||
|
||||
if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
|
||||
output_ = tensor_model_parallel_all_reduce(output_parallel)
|
||||
else:
|
||||
output_ = output_parallel
|
||||
|
||||
if not self.base_layer.skip_bias_add:
|
||||
output = (
|
||||
output_ + self.base_layer.bias
|
||||
if self.base_layer.bias is not None
|
||||
else output_
|
||||
)
|
||||
output_bias = None
|
||||
else:
|
||||
output = output_
|
||||
output_bias = self.base_layer.bias
|
||||
return output, output_bias
|
||||
|
||||
|
||||
def get_lora_layer(
|
||||
layer: nn.Module, segment_gemm, lora_rank, scaling
|
||||
) -> BaseLayerWithLoRA:
|
||||
supported_layer_types = {
|
||||
# the order matters
|
||||
VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
|
||||
QKVParallelLinear: QKVParallelLinearWithLoRA,
|
||||
MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
|
||||
ColumnParallelLinear: ColumnParallelLinearWithLoRA,
|
||||
RowParallelLinear: RowParallelLinearWithLoRA,
|
||||
}
|
||||
for src_layer_type, lora_layer_type in supported_layer_types.items():
|
||||
if isinstance(layer, src_layer_type): # pylint: disable=unidiomatic-typecheck
|
||||
ret = lora_layer_type(layer, segment_gemm, lora_rank, scaling)
|
||||
return ret
|
||||
raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")
|
||||
|
||||
|
||||
def get_mapped_params(module_names):
|
||||
ret = set()
|
||||
for module_name in module_names:
|
||||
ret.add(params_mapping(module_name))
|
||||
return list(ret)
|
||||
|
||||
|
||||
class LoRALayer(nn.Module):
|
||||
def __init__(self, config, base_hf_config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.base_hf_config = base_hf_config
|
||||
self.weights = {}
|
||||
self.weight_gpu = {}
|
||||
|
||||
def load_to_gpu(self):
|
||||
for name, weight in self.weights.items():
|
||||
self.weight_gpu[name] = weight.to(torch.float16).to("cuda")
|
||||
|
||||
def offload_from_gpu(self):
|
||||
for name, weight in self.weights.items():
|
||||
self.weight_gpu[name] = None
|
||||
|
||||
|
||||
class LoRAAdapter(nn.Module):
|
||||
def __init__(self, uid, config, base_hf_config, load_config):
|
||||
super().__init__()
|
||||
self.uid = uid
|
||||
self.config = config
|
||||
assert self.config.hf_config["peft_type"].lower() == "lora"
|
||||
self.base_hf_config = base_hf_config
|
||||
self.load_config = load_config
|
||||
self.scaling = self.config.lora_alpha / self.config.r
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
LoRALayer(config, base_hf_config)
|
||||
for i in range(base_hf_config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.weights = {}
|
||||
self.weights_gpu = {}
|
||||
|
||||
def get_stacked_multiply(self, module_name):
|
||||
stacked_rank = {
|
||||
"qkv_proj": 3,
|
||||
"kv_proj": 2,
|
||||
"gate_up_proj": 2,
|
||||
}
|
||||
return stacked_rank[module_name] if module_name in stacked_rank else 1
|
||||
|
||||
def load_to_gpu(self):
|
||||
for name, weight in self.weights.items():
|
||||
self.weights_gpu[name] = weight.to(torch.float16).to("cuda")
|
||||
for layer in self.layers:
|
||||
layer.load_to_gpu()
|
||||
|
||||
def offload_from_gpu(self):
|
||||
for name, weight in self.weights.items():
|
||||
self.weights_gpu[name] = None
|
||||
for layer in self.layers:
|
||||
layer.offload_from_gpu()
|
||||
|
||||
# initialize the LoRA weights to cpu
|
||||
def initialize_weights(self):
|
||||
model_path = self.config.path
|
||||
loader = DefaultModelLoader(self.load_config)
|
||||
revision = getattr(self.config.hf_config, "revision", None)
|
||||
for name, loaded_weight in loader._get_weights_iterator(
|
||||
model_path, revision=revision, fall_back_to_pt=True
|
||||
):
|
||||
match = re.search(r"layers\.(\d+)\.", name)
|
||||
if match is not None:
|
||||
layer_id = int(match.group(1))
|
||||
self.layers[layer_id].weights[name] = loaded_weight.cpu()
|
||||
else:
|
||||
self.weights[name] = loaded_weight.cpu()
|
||||
|
||||
# stack kv_proj and gate_up_proj
|
||||
for i in range(self.base_hf_config.num_hidden_layers):
|
||||
layer = self.layers[i]
|
||||
weight_names = [name for name, _ in layer.weights.items()]
|
||||
for weight_name in weight_names:
|
||||
if "k_proj" in weight_name:
|
||||
q_name = weight_name.replace("k_proj", "q_proj")
|
||||
v_name = weight_name.replace("k_proj", "v_proj")
|
||||
kv_name = weight_name.replace("k_proj", "kv_proj")
|
||||
qkv_name = weight_name.replace("k_proj", "qkv_proj")
|
||||
if "lora_A" in weight_name:
|
||||
layer.weights[qkv_name] = torch.cat(
|
||||
(
|
||||
layer.weights[q_name],
|
||||
layer.weights[weight_name],
|
||||
layer.weights[v_name],
|
||||
),
|
||||
0,
|
||||
)
|
||||
layer.weights.pop(q_name)
|
||||
layer.weights.pop(weight_name)
|
||||
layer.weights.pop(v_name)
|
||||
else:
|
||||
layer.weights[kv_name] = torch.cat(
|
||||
(
|
||||
layer.weights[weight_name],
|
||||
layer.weights[v_name],
|
||||
),
|
||||
0,
|
||||
)
|
||||
layer.weights.pop(weight_name)
|
||||
layer.weights.pop(v_name)
|
||||
elif "gate_proj" in weight_name:
|
||||
up_name = weight_name.replace("gate_proj", "up_proj")
|
||||
gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
|
||||
layer.weights[gate_up_name] = torch.cat(
|
||||
(layer.weights[weight_name], layer.weights[up_name]), 0
|
||||
)
|
||||
layer.weights.pop(weight_name)
|
||||
layer.weights.pop(up_name)
|
||||
43
python/sglang/srt/lora/lora_config.py
Normal file
43
python/sglang/srt/lora/lora_config.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
Copyright 2023-2024 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
|
||||
class LoRAConfig:
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
) -> None:
|
||||
self.path = path
|
||||
self.hf_config = self.get_lora_config()
|
||||
self.target_modules = self.hf_config["target_modules"]
|
||||
self.r = self.hf_config["r"]
|
||||
self.lora_alpha = self.hf_config["lora_alpha"]
|
||||
|
||||
def get_lora_config(self, dummy=False):
|
||||
if dummy:
|
||||
raise NotImplementedError()
|
||||
else:
|
||||
if not os.path.isdir(self.path):
|
||||
weights_dir = snapshot_download(self.path, allow_patterns=["*.json"])
|
||||
else:
|
||||
weights_dir = self.path
|
||||
config_name = "adapter_config.json"
|
||||
with open(os.path.join(weights_dir, config_name), "r") as f:
|
||||
return json.load(f)
|
||||
256
python/sglang/srt/lora/lora_manager.py
Normal file
256
python/sglang/srt/lora/lora_manager.py
Normal file
@@ -0,0 +1,256 @@
|
||||
"""
|
||||
Copyright 2023-2024 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
# Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
|
||||
# and "Punica: Multi-Tenant LoRA Serving"
|
||||
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from flashinfer import SegmentGEMMWrapper
|
||||
|
||||
from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
|
||||
from sglang.srt.lora.lora_config import LoRAConfig
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.utils import replace_submodule
|
||||
|
||||
|
||||
def get_stacked_name(name):
|
||||
# origin name -> (name for A, name for B)
|
||||
params_mapping = {
|
||||
"q_proj": ("qkv_proj", "q_proj"),
|
||||
"k_proj": ("qkv_proj", "kv_proj"),
|
||||
"v_proj": ("qkv_proj", "kv_proj"),
|
||||
"gate_proj": ("gate_up_proj", "gate_up_proj"),
|
||||
"up_proj": ("gate_up_proj", "gate_up_proj"),
|
||||
}
|
||||
return params_mapping.get(name, (name, name))
|
||||
|
||||
|
||||
def get_layer_id(name):
|
||||
match = re.search(r"layers\.(\d+)\.", name)
|
||||
if match is None:
|
||||
return None
|
||||
return int(match.group(1))
|
||||
|
||||
|
||||
class LoRAManager:
|
||||
def __init__(
|
||||
self,
|
||||
base_model,
|
||||
lora_paths,
|
||||
base_hf_config,
|
||||
max_loras_per_batch,
|
||||
load_config,
|
||||
dtype,
|
||||
):
|
||||
self.base_model = base_model
|
||||
self.lora_paths = lora_paths
|
||||
self.base_hf_config = base_hf_config
|
||||
self.max_loras_per_batch = max_loras_per_batch
|
||||
self.load_config = load_config
|
||||
self.dtype = dtype
|
||||
|
||||
workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
|
||||
self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)
|
||||
|
||||
self.init_loras()
|
||||
self.init_lora_memory_pool()
|
||||
self.init_lora_batch()
|
||||
|
||||
def match_target_modules(self, module_name):
|
||||
for target_module in self.target_modules:
|
||||
if module_name.split(".")[-1] == target_module:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_target_modules(self):
|
||||
modules = []
|
||||
for module_name, module in self.base_model.named_modules():
|
||||
if self.match_target_modules(module_name):
|
||||
modules.append((module_name, module))
|
||||
return modules
|
||||
|
||||
def set_lora_module(self, module_name, module):
|
||||
lora_module = get_lora_layer(
|
||||
module, self.segment_gemm, self.max_lora_dim, self.scaling
|
||||
)
|
||||
replace_submodule(self.base_model, module_name, lora_module)
|
||||
return lora_module
|
||||
|
||||
def init_loras(self):
|
||||
# get configs and target modules
|
||||
self.configs = {}
|
||||
self.origin_target_modules = set()
|
||||
for path in self.lora_paths:
|
||||
self.configs[path] = LoRAConfig(path)
|
||||
self.origin_target_modules = set(self.origin_target_modules) | set(
|
||||
self.configs[path].target_modules
|
||||
)
|
||||
self.target_modules = set(
|
||||
[
|
||||
self.base_model.get_module_name(module)
|
||||
for module in self.origin_target_modules
|
||||
]
|
||||
)
|
||||
self.target_weights = set(
|
||||
[get_stacked_name(module) for module in self.origin_target_modules]
|
||||
)
|
||||
|
||||
# load all weights to cpu
|
||||
self.loras = []
|
||||
self.lora_id = {}
|
||||
for path in self.lora_paths:
|
||||
self.lora_id[path] = len(self.loras)
|
||||
self.loras.append(
|
||||
LoRAAdapter(
|
||||
path, self.configs[path], self.base_hf_config, self.load_config
|
||||
)
|
||||
)
|
||||
self.loras[-1].initialize_weights()
|
||||
|
||||
# misc lora configs
|
||||
self.max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
|
||||
self.scaling = self.loras[0].scaling
|
||||
# FIXME remove the restrictions
|
||||
assert all(x.hf_config["r"] == self.max_lora_dim for x in self.configs.values())
|
||||
assert all(x.scaling == self.scaling for x in self.loras)
|
||||
|
||||
# monkey patch to use the LoRA version
|
||||
self.lora_modules = []
|
||||
for module_name, module in self.get_target_modules():
|
||||
self.lora_modules.append(
|
||||
(module_name, self.set_lora_module(module_name, module))
|
||||
)
|
||||
|
||||
def init_lora_memory_pool(self):
|
||||
# preallocate lora memory pool
|
||||
self.A_buffer = {}
|
||||
self.B_buffer = {}
|
||||
num_layer = self.base_hf_config.num_hidden_layers
|
||||
for module_A, module_B in self.target_weights:
|
||||
# init A tensor, column_major=True
|
||||
hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A)
|
||||
c = self.loras[-1].get_stacked_multiply(module_A)
|
||||
if module_A not in self.A_buffer:
|
||||
self.A_buffer[module_A] = [
|
||||
torch.empty(
|
||||
(
|
||||
self.max_loras_per_batch,
|
||||
self.max_lora_dim * c,
|
||||
hidden_dim_A,
|
||||
),
|
||||
dtype=self.dtype,
|
||||
device="cuda",
|
||||
)
|
||||
for i in range(num_layer)
|
||||
]
|
||||
# init B tensor, column_major=True
|
||||
_, hidden_dim_B = self.base_model.get_hidden_dim(module_B)
|
||||
c = self.loras[-1].get_stacked_multiply(module_B)
|
||||
if module_B not in self.B_buffer:
|
||||
self.B_buffer[module_B] = [
|
||||
torch.empty(
|
||||
(
|
||||
self.max_loras_per_batch,
|
||||
hidden_dim_B * c,
|
||||
self.max_lora_dim,
|
||||
),
|
||||
dtype=self.dtype,
|
||||
device="cuda",
|
||||
)
|
||||
for i in range(num_layer)
|
||||
]
|
||||
|
||||
def init_lora_batch(self):
|
||||
self.active_uids = set() # set of active loras
|
||||
self.buffer_id = {} # lora uid -> idx in memory pool
|
||||
|
||||
def get_weight_name(self, name, idx):
|
||||
for target_weight_name in self.target_weights:
|
||||
if target_weight_name[idx] in name:
|
||||
return target_weight_name[idx]
|
||||
|
||||
def load_lora(self, uid, buffer_id):
|
||||
num_layer = self.base_hf_config.num_hidden_layers
|
||||
if uid is None:
|
||||
for i in range(num_layer):
|
||||
for k in self.A_buffer.keys():
|
||||
self.A_buffer[k][i][buffer_id] *= 0
|
||||
return
|
||||
|
||||
for i in range(num_layer):
|
||||
layer_weights = self.loras[self.lora_id[uid]].layers[i].weights
|
||||
for name, weights in layer_weights.items():
|
||||
if "lora_A" in name:
|
||||
lora_weight_name = self.get_weight_name(name, 0)
|
||||
if lora_weight_name:
|
||||
self.A_buffer[lora_weight_name][i][buffer_id].copy_(weights)
|
||||
else:
|
||||
lora_weight_name = self.get_weight_name(name, 1)
|
||||
if lora_weight_name:
|
||||
self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights)
|
||||
|
||||
def prepare_lora_batch(self, batch, extend_seq_lens=None):
|
||||
# load active loras into lora memory pool
|
||||
cur_uids = set([req.lora_path for req in batch.reqs])
|
||||
assert len(cur_uids) <= self.max_loras_per_batch
|
||||
i = 0
|
||||
evictable_uids = list(self.active_uids)
|
||||
for uid in cur_uids:
|
||||
if uid not in self.active_uids:
|
||||
while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
|
||||
i += 1
|
||||
if i < len(evictable_uids):
|
||||
self.active_uids.remove(evictable_uids[i])
|
||||
self.buffer_id.pop(evictable_uids[i])
|
||||
self.load_lora(uid, i)
|
||||
self.active_uids.add(uid)
|
||||
self.buffer_id[uid] = i
|
||||
i += 1
|
||||
|
||||
if cur_uids == set([None]):
|
||||
return
|
||||
|
||||
# setup lora in forward modules
|
||||
bs = len(batch.reqs)
|
||||
seg_lens = extend_seq_lens if batch.forward_mode.is_extend() else torch.ones(bs)
|
||||
weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
|
||||
for i, req in enumerate(batch.reqs):
|
||||
weight_indices[i] = self.buffer_id[req.lora_path]
|
||||
|
||||
for module_name, module in self.lora_modules:
|
||||
layer_id = get_layer_id(module_name)
|
||||
|
||||
if "qkv_proj" not in module_name:
|
||||
weight_name = self.get_weight_name(module_name, 0)
|
||||
module.set_lora_info(
|
||||
self.A_buffer[weight_name][layer_id],
|
||||
self.B_buffer[weight_name][layer_id],
|
||||
bs,
|
||||
seg_lens,
|
||||
weight_indices,
|
||||
)
|
||||
else:
|
||||
module.set_lora_info(
|
||||
self.A_buffer["qkv_proj"][layer_id],
|
||||
self.B_buffer["q_proj"][layer_id],
|
||||
self.B_buffer["kv_proj"][layer_id],
|
||||
bs,
|
||||
seg_lens,
|
||||
weight_indices,
|
||||
)
|
||||
Reference in New Issue
Block a user