[Feature] Initial support for multi-LoRA serving (#1307)

This commit is contained in:
Ying Sheng
2024-09-12 16:46:14 -07:00
committed by GitHub
parent c33d82a211
commit 712216928f
21 changed files with 1435 additions and 22 deletions

View File

@@ -0,0 +1,403 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
# and "Punica: Multi-Tenant LoRA Serving"
# LoRA layers class inheritance adapted from:
# https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py
import json
import os
import re
from typing import Any, Dict, List, Optional, Tuple
import safetensors.torch
import torch
from torch import nn
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
class BaseLayerWithLoRA(nn.Module):
    """Common wrapper that pairs a base layer with LoRA bookkeeping.

    The wrapper stores the wrapped layer, the segment-GEMM helper, the LoRA
    rank and the scaling factor. On its own it only delegates ``forward`` to
    the wrapped layer; ``set_lora`` starts out ``False``.
    """

    def __init__(self, base_layer, segment_gemm, lora_rank, scaling):
        super().__init__()
        self.base_layer, self.segment_gemm = base_layer, segment_gemm
        self.lora_rank, self.scaling = lora_rank, scaling
        # No adapter information has been attached yet.
        self.set_lora = False

    def forward(self, x: torch.Tensor):
        """Return the result of the wrapped layer's ``forward`` for ``x``."""
        base_forward = self.base_layer.forward
        return base_forward(x)

    def set_lora_info(self, *args):
        """Accept and ignore any arguments; this base implementation does nothing."""
        return None
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
    """Wrapper for a ``VocabParallelEmbedding`` layer.

    Adds no LoRA computation of its own: everything is inherited from the
    parent class, and the wrapped layer's ``weight`` is re-exposed on the
    wrapper under the same attribute name.
    """

    def __init__(
        self, base_layer: VocabParallelEmbedding, segment_gemm, lora_rank, scaling
    ) -> None:
        super().__init__(
            base_layer,
            segment_gemm,
            lora_rank,
            scaling,
        )
        # Same object as the wrapped layer's weight, reachable from the wrapper.
        self.weight = self.base_layer.weight
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
    """LoRA-aware wrapper for a ``ColumnParallelLinear`` layer.

    ``forward`` re-implements the wrapped layer's forward path so that the
    LoRA contribution can be inserted between the base matmul and the optional
    all-gather. ``apply_lora`` is a placeholder in this class; subclasses
    override it.
    """

    def __init__(
        self, base_layer: ColumnParallelLinear, segment_gemm, lora_rank, scaling
    ) -> None:
        super().__init__(base_layer, segment_gemm, lora_rank, scaling)

    def apply_lora(self, output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return ``output`` unchanged (LoRA for this layer type is still TODO)."""
        # TODO
        return output

    def forward(self, input_: torch.Tensor):
        """Run the base layer, add LoRA when configured, and gather if requested.

        Args:
            input_: Input tensor handed to the wrapped layer's quant method.

        Returns:
            ``(output, output_bias)``. ``output_bias`` is the wrapped layer's
            bias when ``skip_bias_add`` is set (the bias was not applied),
            otherwise ``None``.
        """
        # duplicate the logic in ColumnParallelLinear
        base = self.base_layer
        bias = base.bias if not base.skip_bias_add else None
        output_parallel = base.quant_method.apply(base, input_, bias)
        if self.set_lora:
            output_parallel = self.apply_lora(output_parallel, input_)
        if base.gather_output:
            # The all-gather helper is not imported in this file's import
            # block, so this branch used to fail with NameError. Import it
            # lazily from the package this file already depends on.
            from vllm.distributed import tensor_model_parallel_all_gather

            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = base.bias if base.skip_bias_add else None
        return output, output_bias
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """LoRA wrapper for a ``MergedColumnParallelLinear`` made of two halves.

    The A projection is computed once for both halves; its output is asserted
    to have ``2 * lora_rank`` columns. Each half then gets its own B
    projection, written into the matching half of the output columns.
    """

    def __init__(
        self, base_layer: MergedColumnParallelLinear, segment_gemm, lora_rank, scaling
    ) -> None:
        super().__init__(base_layer, segment_gemm, lora_rank, scaling)

    def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
        """Store the buffers and batch layout for later forwards and enable LoRA."""
        self.set_lora = True
        self.A_buffer, self.B_buffer = A_buffer, B_buffer
        self.bs, self.seq_lens, self.weight_indices = bs, seq_lens, weight_indices

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return ``base_output`` plus the scaled LoRA contribution for ``x``."""
        gemm = self.segment_gemm.run
        # Arguments shared by every segment-GEMM call in this method.
        shared = dict(
            batch_size=self.bs,
            weight_column_major=True,
            seg_lens=self.seq_lens,
            weight_indices=self.weight_indices,
        )
        rank = self.lora_rank

        lora_a_output = gemm(x=x, weights=self.A_buffer, **shared)
        # FIXME
        assert lora_a_output.shape[-1] == rank * 2

        lora_output = torch.empty_like(base_output)
        output_dim = lora_output.shape[-1] // 2
        for part in (0, 1):
            out_cols = slice(output_dim * part, output_dim * (part + 1))
            rank_cols = slice(rank * part, rank * (part + 1))
            lora_output[:, out_cols] = gemm(
                x=lora_a_output[:, rank_cols].contiguous(),
                weights=self.B_buffer[:, out_cols, :].contiguous(),
                **shared,
            )
        return base_output + lora_output * self.scaling
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """LoRA wrapper for a ``QKVParallelLinear`` layer.

    One A projection is computed for q, k and v together. The first
    ``lora_rank`` columns of its output feed the q B-projection; the next two
    ``lora_rank``-wide column groups feed the two halves of the kv
    B-projection.
    """

    def __init__(
        self, base_layer: QKVParallelLinear, segment_gemm, lora_rank, scaling
    ) -> None:
        super().__init__(base_layer, segment_gemm, lora_rank, scaling)

    def set_lora_info(
        self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seq_lens, weight_indices
    ):
        """Store the buffers and batch layout for later forwards and enable LoRA."""
        self.set_lora = True
        self.A_buffer_qkv = A_buffer_qkv
        self.B_buffer_q, self.B_buffer_kv = B_buffer_q, B_buffer_kv
        self.bs, self.seq_lens, self.weight_indices = bs, seq_lens, weight_indices

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return ``base_output`` plus the scaled LoRA contribution for ``x``."""
        gemm = self.segment_gemm.run
        # Arguments shared by every segment-GEMM call in this method.
        shared = dict(
            batch_size=self.bs,
            weight_column_major=True,
            seg_lens=self.seq_lens,
            weight_indices=self.weight_indices,
        )
        rank = self.lora_rank

        lora_a_output = gemm(x=x, weights=self.A_buffer_qkv, **shared)
        # FIXME parallelize qkv
        lora_output = torch.empty_like(base_output)

        # q
        output_dim_q = self.B_buffer_q.shape[-2]
        lora_output[:, :output_dim_q] = gemm(
            x=lora_a_output[:, :rank].contiguous(),
            weights=self.B_buffer_q,
            **shared,
        )

        # kv
        output_dim_kv = self.B_buffer_kv.shape[-2] // 2
        for part in (0, 1):
            start = output_dim_kv * part
            stop = start + output_dim_kv
            rank_cols = slice(rank * (part + 1), rank * (part + 2))
            lora_output[:, output_dim_q + start : output_dim_q + stop] = gemm(
                x=lora_a_output[:, rank_cols].contiguous(),
                weights=self.B_buffer_kv[:, start:stop, :].contiguous(),
                **shared,
            )
        return base_output + lora_output * self.scaling
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
    """LoRA-aware wrapper for a ``RowParallelLinear`` layer.

    ``forward`` re-implements the wrapped layer's forward path so that the
    LoRA contribution is added to the per-rank output before the optional
    all-reduce and bias handling.
    """

    def __init__(
        self, base_layer: RowParallelLinear, segment_gemm, lora_rank, scaling
    ) -> None:
        super().__init__(base_layer, segment_gemm, lora_rank, scaling)

    def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
        """Store the buffers and batch layout for later forwards and enable LoRA."""
        self.set_lora = True
        self.A_buffer = A_buffer
        self.B_buffer = B_buffer
        self.bs = bs
        self.seq_lens = seq_lens
        self.weight_indices = weight_indices

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return ``base_output`` plus the scaled result of the A then B projection."""
        lora_output = self.segment_gemm.run(
            x=x,
            weights=self.A_buffer,
            batch_size=self.bs,
            weight_column_major=True,
            seg_lens=self.seq_lens,
            weight_indices=self.weight_indices,
        )
        lora_output = self.segment_gemm.run(
            x=lora_output,
            weights=self.B_buffer,
            batch_size=self.bs,
            weight_column_major=True,
            seg_lens=self.seq_lens,
            weight_indices=self.weight_indices,
        )
        return base_output + lora_output * self.scaling

    def forward(self, input_):
        """Run the base layer, add LoRA when configured, then reduce and add bias.

        Args:
            input_: Input tensor; it is split along its last dimension first
                unless the wrapped layer reports ``input_is_parallel``.

        Returns:
            ``(output, output_bias)``. When ``skip_bias_add`` is set the bias
            is returned separately instead of being added to ``output``.
        """
        # duplicate the logic in RowParallelLinear
        base = self.base_layer
        if base.input_is_parallel:
            input_parallel = input_
        else:
            # These helpers are not imported in this file's import block, so
            # this branch used to fail with NameError. Import them lazily
            # from the package this file already depends on.
            from vllm.distributed import (
                get_tensor_model_parallel_rank,
                split_tensor_along_last_dim,
            )

            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=base.tp_size
            )
            input_parallel = splitted_input[tp_rank].contiguous()

        output_parallel = base.quant_method.apply(base, input_parallel)
        if self.set_lora:
            output_parallel = self.apply_lora(output_parallel, input_parallel)

        if base.reduce_results and base.tp_size > 1:
            # Same missing-import problem as above for the all-reduce helper.
            from vllm.distributed import tensor_model_parallel_all_reduce

            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not base.skip_bias_add:
            output = output_ + base.bias if base.bias is not None else output_
            output_bias = None
        else:
            output = output_
            output_bias = base.bias
        return output, output_bias
def get_lora_layer(
    layer: nn.Module, segment_gemm, lora_rank, scaling
) -> BaseLayerWithLoRA:
    """Wrap ``layer`` in the LoRA wrapper registered for its type.

    Candidate types are tried from top to bottom and the first ``isinstance``
    match wins.

    Raises:
        Exception: If ``layer`` is not an instance of any supported type.
    """
    # the order matters: the first matching entry is used
    # (presumably some of these base types are related by inheritance -
    # confirm before reordering).
    wrapper_by_type = (
        (VocabParallelEmbedding, VocabParallelEmbeddingWithLoRA),
        (QKVParallelLinear, QKVParallelLinearWithLoRA),
        (MergedColumnParallelLinear, MergedColumnParallelLinearWithLoRA),
        (ColumnParallelLinear, ColumnParallelLinearWithLoRA),
        (RowParallelLinear, RowParallelLinearWithLoRA),
    )
    for base_type, wrapper_type in wrapper_by_type:
        if not isinstance(layer, base_type):
            continue
        return wrapper_type(layer, segment_gemm, lora_rank, scaling)
    raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")
def get_mapped_params(module_names):
    """Return the de-duplicated ``params_mapping`` results for ``module_names``.

    The result is a list built from a set, so its order is unspecified.
    """
    # NOTE(review): `params_mapping` is neither defined nor imported in the
    # visible part of this file; any non-empty input would raise NameError
    # unless it is provided elsewhere - confirm.
    return list({params_mapping(module_name) for module_name in module_names})
class LoRALayer(nn.Module):
    """Weights of one adapter for a single layer.

    ``weights`` maps weight names to tensors; ``weight_gpu`` holds, under the
    same names, the float16 CUDA copies made by ``load_to_gpu`` (or ``None``
    after ``offload_from_gpu``).
    """

    def __init__(self, config, base_hf_config):
        super().__init__()
        self.config = config
        self.base_hf_config = base_hf_config
        self.weights = {}
        self.weight_gpu = {}

    def load_to_gpu(self):
        """Store a float16 CUDA copy of every entry of ``weights`` in ``weight_gpu``."""
        for name in self.weights:
            as_half = self.weights[name].to(torch.float16)
            self.weight_gpu[name] = as_half.to("cuda")

    def offload_from_gpu(self):
        """Replace the ``weight_gpu`` entry of every known weight name with ``None``."""
        for name in self.weights:
            self.weight_gpu[name] = None
class LoRAAdapter(nn.Module):
    """One LoRA adapter: its config, per-layer weights and remaining weights.

    Weights whose name contains ``layers.<n>.`` are stored on the ``n``-th
    ``LoRALayer``; all others are stored in ``weights`` on the adapter itself.
    """

    def __init__(self, uid, config, base_hf_config, load_config):
        super().__init__()
        self.uid = uid
        self.config = config
        assert self.config.hf_config["peft_type"].lower() == "lora"
        self.base_hf_config = base_hf_config
        self.load_config = load_config
        self.scaling = self.config.lora_alpha / self.config.r

        per_layer = [
            LoRALayer(config, base_hf_config)
            for _ in range(base_hf_config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(per_layer)

        self.weights = {}
        self.weights_gpu = {}

    def get_stacked_multiply(self, module_name):
        """Return how many projections are stacked in ``module_name`` (default 1)."""
        stacked_rank = {"qkv_proj": 3, "kv_proj": 2, "gate_up_proj": 2}
        return stacked_rank.get(module_name, 1)

    def load_to_gpu(self):
        """Copy adapter-level and per-layer weights to CUDA as float16."""
        for name in self.weights:
            as_half = self.weights[name].to(torch.float16)
            self.weights_gpu[name] = as_half.to("cuda")
        for layer in self.layers:
            layer.load_to_gpu()

    def offload_from_gpu(self):
        """Drop the CUDA copies by replacing them with ``None``."""
        for name in self.weights:
            self.weights_gpu[name] = None
        for layer in self.layers:
            layer.offload_from_gpu()

    @staticmethod
    def _stack_layer_weights(weights):
        """Merge q/k/v and gate/up entries of one layer's weight dict in place.

        For a ``k_proj`` ``lora_A`` weight, q, k and v are concatenated along
        dim 0 into ``qkv_proj``. For any other ``k_proj`` weight, k and v are
        concatenated into ``kv_proj`` and q is left untouched. For a
        ``gate_proj`` weight, gate and up are concatenated into
        ``gate_up_proj``. The merged source entries are removed.
        """
        # Iterate over a snapshot of the names because the dict is mutated.
        for weight_name in list(weights):
            if "k_proj" in weight_name:
                v_name = weight_name.replace("k_proj", "v_proj")
                if "lora_A" in weight_name:
                    q_name = weight_name.replace("k_proj", "q_proj")
                    stacked_name = weight_name.replace("k_proj", "qkv_proj")
                    sources = (q_name, weight_name, v_name)
                else:
                    stacked_name = weight_name.replace("k_proj", "kv_proj")
                    sources = (weight_name, v_name)
            elif "gate_proj" in weight_name:
                up_name = weight_name.replace("gate_proj", "up_proj")
                stacked_name = weight_name.replace("gate_proj", "gate_up_proj")
                sources = (weight_name, up_name)
            else:
                continue

            weights[stacked_name] = torch.cat(
                tuple(weights[source] for source in sources), 0
            )
            for source in sources:
                weights.pop(source)

    # initialize the LoRA weights to cpu
    def initialize_weights(self):
        """Load every adapter weight onto the CPU, then stack related projections."""
        model_path = self.config.path
        loader = DefaultModelLoader(self.load_config)
        # NOTE(review): `hf_config` is indexed like a mapping in __init__; if
        # it is a plain dict, getattr() never sees a "revision" key and this
        # is always None - confirm the intent.
        revision = getattr(self.config.hf_config, "revision", None)
        weight_iterator = loader._get_weights_iterator(
            model_path, revision=revision, fall_back_to_pt=True
        )
        for name, loaded_weight in weight_iterator:
            found = re.search(r"layers\.(\d+)\.", name)
            if found is None:
                self.weights[name] = loaded_weight.cpu()
                continue
            layer_id = int(found.group(1))
            self.layers[layer_id].weights[name] = loaded_weight.cpu()

        # stack kv_proj and gate_up_proj
        for layer_index in range(self.base_hf_config.num_hidden_layers):
            self._stack_layer_weights(self.layers[layer_index].weights)

View File

@@ -0,0 +1,43 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
import os
from huggingface_hub import snapshot_download
class LoRAConfig:
    """Configuration of one LoRA adapter, read from ``adapter_config.json``.

    Attributes:
        path: Local directory or hub identifier the adapter was requested by.
        hf_config: Parsed content of ``adapter_config.json``.
        target_modules: The ``target_modules`` entry of the config.
        r: The ``r`` entry of the config.
        lora_alpha: The ``lora_alpha`` entry of the config.
    """

    def __init__(
        self,
        path: str,
    ) -> None:
        self.path = path
        self.hf_config = self.get_lora_config()
        self.target_modules = self.hf_config["target_modules"]
        self.r = self.hf_config["r"]
        self.lora_alpha = self.hf_config["lora_alpha"]

    def get_lora_config(self, dummy=False):
        """Load and return the parsed ``adapter_config.json`` for ``self.path``.

        When ``self.path`` is not an existing directory it is passed to
        ``snapshot_download`` (restricted to ``*.json`` files) and the
        returned directory is used instead.

        Args:
            dummy: Not supported; a truthy value raises.

        Raises:
            NotImplementedError: If ``dummy`` is truthy.
        """
        if dummy:
            raise NotImplementedError()

        if os.path.isdir(self.path):
            weights_dir = self.path
        else:
            weights_dir = snapshot_download(self.path, allow_patterns=["*.json"])

        config_name = "adapter_config.json"
        # JSON is UTF-8 encoded; do not rely on the platform default encoding.
        with open(os.path.join(weights_dir, config_name), "r", encoding="utf-8") as f:
            return json.load(f)

View File

@@ -0,0 +1,256 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
# and "Punica: Multi-Tenant LoRA Serving"
import re
from dataclasses import dataclass
import torch
from flashinfer import SegmentGEMMWrapper
from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import replace_submodule
def get_stacked_name(name):
    """Map an original projection name to ``(name for A, name for B)``.

    Names without a stacked counterpart map to themselves for both A and B.
    """
    qkv, kv, gate_up = "qkv_proj", "kv_proj", "gate_up_proj"
    # origin name -> (name for A, name for B)
    stacked_names = {
        "q_proj": (qkv, "q_proj"),
        "k_proj": (qkv, kv),
        "v_proj": (qkv, kv),
        "gate_proj": (gate_up, gate_up),
        "up_proj": (gate_up, gate_up),
    }
    if name in stacked_names:
        return stacked_names[name]
    return (name, name)
def get_layer_id(name):
    """Return the integer following the first ``layers.`` in ``name``.

    The number must be followed by a dot. Returns ``None`` when ``name``
    contains no such ``layers.<digits>.`` fragment.
    """
    found = re.search(r"layers\.(\d+)\.", name)
    return int(found.group(1)) if found is not None else None
class LoRAManager:
    """Owns all LoRA adapters and the fixed-size device pool they are served from.

    On construction the manager loads every adapter in ``lora_paths`` to the
    CPU, replaces the targeted modules of ``base_model`` with LoRA wrappers,
    and preallocates per-layer A/B buffers with ``max_loras_per_batch`` slots.
    ``prepare_lora_batch`` then makes sure every adapter used by a batch
    occupies a slot and points the wrapped modules at the right buffers.
    """

    def __init__(
        self,
        base_model,
        lora_paths,
        base_hf_config,
        max_loras_per_batch,
        load_config,
        dtype,
    ):
        self.base_model = base_model
        self.lora_paths = lora_paths
        self.base_hf_config = base_hf_config
        self.max_loras_per_batch = max_loras_per_batch
        self.load_config = load_config
        self.dtype = dtype

        workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
        self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)

        self.init_loras()
        self.init_lora_memory_pool()
        self.init_lora_batch()

    def match_target_modules(self, module_name):
        """Return True when the last dotted component of ``module_name`` is targeted."""
        return module_name.split(".")[-1] in self.target_modules

    def get_target_modules(self):
        """Return ``(name, module)`` pairs of all targeted modules of the base model."""
        return [
            (module_name, module)
            for module_name, module in self.base_model.named_modules()
            if self.match_target_modules(module_name)
        ]

    def set_lora_module(self, module_name, module):
        """Swap ``module`` in the base model for its LoRA wrapper and return the wrapper."""
        lora_module = get_lora_layer(
            module, self.segment_gemm, self.max_lora_dim, self.scaling
        )
        replace_submodule(self.base_model, module_name, lora_module)
        return lora_module

    def init_loras(self):
        """Read all adapter configs, load their weights to CPU and wrap the model."""
        # get configs and target modules
        self.configs = {}
        self.origin_target_modules = set()
        for path in self.lora_paths:
            self.configs[path] = LoRAConfig(path)
            self.origin_target_modules |= set(self.configs[path].target_modules)
        self.target_modules = {
            self.base_model.get_module_name(module)
            for module in self.origin_target_modules
        }
        self.target_weights = {
            get_stacked_name(module) for module in self.origin_target_modules
        }

        # load all weights to cpu
        self.loras = []
        self.lora_id = {}
        for path in self.lora_paths:
            self.lora_id[path] = len(self.loras)
            adapter = LoRAAdapter(
                path, self.configs[path], self.base_hf_config, self.load_config
            )
            self.loras.append(adapter)
            adapter.initialize_weights()

        # misc lora configs
        self.max_lora_dim = max(x.hf_config["r"] for x in self.configs.values())
        self.scaling = self.loras[0].scaling
        # FIXME remove the restrictions
        assert all(x.hf_config["r"] == self.max_lora_dim for x in self.configs.values())
        assert all(x.scaling == self.scaling for x in self.loras)

        # monkey patch to use the LoRA version
        self.lora_modules = []
        for module_name, module in self.get_target_modules():
            self.lora_modules.append(
                (module_name, self.set_lora_module(module_name, module))
            )

    def init_lora_memory_pool(self):
        """Preallocate one A and one B buffer per target weight and layer.

        Each buffer has ``max_loras_per_batch`` slots along its first
        dimension. Buffers are allocated uninitialized.
        """
        # preallocate lora memory pool
        self.A_buffer = {}
        self.B_buffer = {}
        num_layer = self.base_hf_config.num_hidden_layers
        for module_A, module_B in self.target_weights:
            # init A tensor, column_major=True
            hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A)
            c = self.loras[-1].get_stacked_multiply(module_A)
            if module_A not in self.A_buffer:
                self.A_buffer[module_A] = [
                    torch.empty(
                        (
                            self.max_loras_per_batch,
                            self.max_lora_dim * c,
                            hidden_dim_A,
                        ),
                        dtype=self.dtype,
                        device="cuda",
                    )
                    for _ in range(num_layer)
                ]
            # init B tensor, column_major=True
            _, hidden_dim_B = self.base_model.get_hidden_dim(module_B)
            c = self.loras[-1].get_stacked_multiply(module_B)
            if module_B not in self.B_buffer:
                self.B_buffer[module_B] = [
                    torch.empty(
                        (
                            self.max_loras_per_batch,
                            hidden_dim_B * c,
                            self.max_lora_dim,
                        ),
                        dtype=self.dtype,
                        device="cuda",
                    )
                    for _ in range(num_layer)
                ]

    def init_lora_batch(self):
        """Reset the bookkeeping of which adapter occupies which pool slot."""
        self.active_uids = set()  # set of active loras
        self.buffer_id = {}  # lora uid -> idx in memory pool

    def get_weight_name(self, name, idx):
        """Return the first target weight name (A if ``idx`` is 0, B if 1) found in ``name``.

        Matching is by substring. Returns ``None`` when nothing matches.
        """
        for target_weight_name in self.target_weights:
            if target_weight_name[idx] in name:
                return target_weight_name[idx]
        return None

    def load_lora(self, uid, buffer_id):
        """Copy adapter ``uid`` into pool slot ``buffer_id`` for every layer.

        ``uid=None`` stands for "no adapter": the slot is cleared instead.
        """
        num_layer = self.base_hf_config.num_hidden_layers
        if uid is None:
            # The pool is allocated with torch.empty, so a slot may hold
            # NaN/Inf bit patterns. Multiplying by zero would keep those
            # (NaN * 0 is NaN), and a zero A output times a non-finite B is
            # NaN as well, so overwrite both slots with exact zeros.
            for i in range(num_layer):
                for buffers in (self.A_buffer, self.B_buffer):
                    for per_layer in buffers.values():
                        per_layer[i][buffer_id].zero_()
            return

        for i in range(num_layer):
            layer_weights = self.loras[self.lora_id[uid]].layers[i].weights
            for name, weights in layer_weights.items():
                if "lora_A" in name:
                    lora_weight_name = self.get_weight_name(name, 0)
                    if lora_weight_name:
                        self.A_buffer[lora_weight_name][i][buffer_id].copy_(weights)
                else:
                    lora_weight_name = self.get_weight_name(name, 1)
                    if lora_weight_name:
                        self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights)

    def _assign_buffer_slots(self, cur_uids):
        """Give every uid in ``cur_uids`` a pool slot, evicting only when full."""
        missing = [uid for uid in cur_uids if uid not in self.active_uids]
        if not missing:
            return

        occupied = set(self.buffer_id.values())
        # Reversed so that pop() hands out the lowest free slot first.
        free_slots = [
            slot
            for slot in reversed(range(self.max_loras_per_batch))
            if slot not in occupied
        ]
        evictable_uids = [uid for uid in self.active_uids if uid not in cur_uids]

        for uid in missing:
            if free_slots:
                slot = free_slots.pop()
            else:
                # Reuse the slot the evicted adapter actually occupied. The
                # position of a uid inside a set-derived list says nothing
                # about its slot, and using it could overwrite the weights of
                # an adapter that is still active.
                evicted_uid = evictable_uids.pop()
                self.active_uids.remove(evicted_uid)
                slot = self.buffer_id.pop(evicted_uid)
            self.load_lora(uid, slot)
            self.active_uids.add(uid)
            self.buffer_id[uid] = slot

    def prepare_lora_batch(self, batch, extend_seq_lens=None):
        """Load the adapters used by ``batch`` and configure all LoRA modules.

        Args:
            batch: Object with ``reqs`` (each exposing ``lora_path``, ``None``
                meaning no adapter) and ``forward_mode.is_extend()``.
            extend_seq_lens: Segment lengths used when the batch is in extend
                mode; otherwise every request is one segment of length 1.
        """
        # load active loras into lora memory pool
        cur_uids = {req.lora_path for req in batch.reqs}
        assert len(cur_uids) <= self.max_loras_per_batch
        self._assign_buffer_slots(cur_uids)

        if cur_uids == {None}:
            # No request uses an adapter. Switch LoRA off so the modules do
            # not run with buffers and batch layout left over from an
            # earlier batch.
            for _, module in self.lora_modules:
                module.set_lora = False
            return

        # setup lora in forward modules
        bs = len(batch.reqs)
        # NOTE(review): torch.ones(bs) is a float32 CPU tensor, unlike
        # weight_indices below - confirm the segment GEMM accepts that.
        seg_lens = extend_seq_lens if batch.forward_mode.is_extend() else torch.ones(bs)
        # Build the indices in one transfer instead of one device write per request.
        weight_indices = torch.tensor(
            [self.buffer_id[req.lora_path] for req in batch.reqs],
            dtype=torch.int64,
            device="cuda",
        )

        for module_name, module in self.lora_modules:
            layer_id = get_layer_id(module_name)

            if "qkv_proj" not in module_name:
                weight_name = self.get_weight_name(module_name, 0)
                module.set_lora_info(
                    self.A_buffer[weight_name][layer_id],
                    self.B_buffer[weight_name][layer_id],
                    bs,
                    seg_lens,
                    weight_indices,
                )
            else:
                module.set_lora_info(
                    self.A_buffer["qkv_proj"][layer_id],
                    self.B_buffer["q_proj"][layer_id],
                    self.B_buffer["kv_proj"][layer_id],
                    bs,
                    seg_lens,
                    weight_indices,
                )