[Fix] Fix accuracy bug and refactor codes for lora (#3413)
This commit is contained in:
@@ -1,8 +1,28 @@
|
||||
from .base_backend import BaseLoraBackend
|
||||
from .flashinfer_backend import FlashInferLoraBackend
|
||||
from .triton_backend import TritonLoraBackend
|
||||
from .base_backend import BaseLoRABackend
|
||||
from .flashinfer_backend import FlashInferLoRABackend
|
||||
from .triton_backend import TritonLoRABackend
|
||||
|
||||
|
||||
def get_backend_from_name(name: str) -> BaseLoRABackend:
|
||||
"""
|
||||
Get corresponding backend class from backend's name
|
||||
"""
|
||||
backend_mapping = {
|
||||
"triton": TritonLoRABackend,
|
||||
"flashinfer": FlashInferLoRABackend,
|
||||
}
|
||||
|
||||
if name in backend_mapping:
|
||||
return backend_mapping[name]
|
||||
|
||||
raise Exception(
|
||||
f"No supported lora backend called {name}. It should be one of {list(backend_mapping.keys())}"
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"FlashInferLoraBackend",
|
||||
"TritonLoraBackend",
|
||||
"BaseLoRABackend",
|
||||
"FlashInferLoRABackend",
|
||||
"TritonLoRABackend",
|
||||
"get_backend_from_name",
|
||||
]
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.lora.lora import LoraBatchInfo
|
||||
from sglang.srt.lora.utils import LoRABatchInfo
|
||||
|
||||
|
||||
def get_fuse_output_scaling_add_from_name(name: str) -> bool:
|
||||
@@ -13,7 +13,7 @@ def get_fuse_output_scaling_add_from_name(name: str) -> bool:
|
||||
return mapping.get(name, False)
|
||||
|
||||
|
||||
def get_fuse_qkv_lora_b_from_name(name: str) -> bool:
|
||||
def get_fuse_stacked_lora_b_from_name(name: str) -> bool:
|
||||
mapping = {
|
||||
"triton": True,
|
||||
"flashinfer": False,
|
||||
@@ -21,7 +21,7 @@ def get_fuse_qkv_lora_b_from_name(name: str) -> bool:
|
||||
return mapping.get(name, False)
|
||||
|
||||
|
||||
class BaseLoraBackend:
|
||||
class BaseLoRABackend:
|
||||
"""Base class for different Lora backends.
|
||||
Each backend has its own implementation of Lora kernels.
|
||||
|
||||
@@ -32,11 +32,11 @@ class BaseLoraBackend:
|
||||
and the operation of scaling and adding will be fused into kernel
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, batch_info: LoraBatchInfo = None):
|
||||
def __init__(self, name: str, batch_info: LoRABatchInfo = None):
|
||||
self.name = name
|
||||
self.batch_info = batch_info
|
||||
self.fuse_output_scaling_add = get_fuse_output_scaling_add_from_name(name)
|
||||
self.fuse_qkv_lora_b = get_fuse_qkv_lora_b_from_name(name)
|
||||
self.fuse_stacked_lora_b = get_fuse_stacked_lora_b_from_name(name)
|
||||
|
||||
def run_lora_a_sgemm(
|
||||
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
|
||||
@@ -46,10 +46,11 @@ class BaseLoraBackend:
|
||||
|
||||
Args:
|
||||
x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
|
||||
weights: a set of lora weights with shape (num_lora, r, input_dim), here r is lora rank
|
||||
weights: a set of lora weights with shape (num_lora, c * r, input_dim),
|
||||
here r is lora rank, c is a multiplier for stacked modules (e.g., c=3 for qkv_proj, c=2 for gate_up_proj)
|
||||
usually input_dim is much larger than r
|
||||
Returns:
|
||||
result with shape (s, r)
|
||||
result with shape (s, c * r)
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -83,7 +84,7 @@ class BaseLoraBackend:
|
||||
qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim)
|
||||
qkv_lora_b: lora_b module for qkv.
|
||||
If passed in as a tensor, its shape should be (num_lora,output_dim_q + 2 * output_dim_kv, r)
|
||||
If passed in as a tuple of two tensors containing:
|
||||
If passed in as a tuple of two tensors, it should contain:
|
||||
a lora_b module for q, with shape (1, num_lora, output_dim_q, r)
|
||||
and a combined lora_b module for kv, with shape (2, num_lora, output_dim_kv, r)
|
||||
Returns:
|
||||
@@ -91,5 +92,26 @@ class BaseLoraBackend:
|
||||
"""
|
||||
pass
|
||||
|
||||
def set_batch_info(self, batch_info: LoraBatchInfo):
|
||||
def run_gate_up_lora(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
gate_up_lora_a: torch.Tensor,
|
||||
gate_up_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
|
||||
*args,
|
||||
**kwargs
|
||||
) -> torch.Tensor:
|
||||
"""Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
|
||||
|
||||
Args:
|
||||
x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
|
||||
gate_up_lora_a: lora_a module for gate_up_proj, with shape (num_lora, 2 * r, input_dim)
|
||||
gate_up_lora_b: lora_b module for qkv.
|
||||
If passed in as a tensor, its shape should be (num_lora, 2 * output_dim, r)
|
||||
If passed in as a tuple, it should contain two tensors with shape (num_lora, output_dim, r)
|
||||
Returns:
|
||||
result with shape (s, 2 * output_dim)
|
||||
"""
|
||||
pass
|
||||
|
||||
def set_batch_info(self, batch_info: LoRABatchInfo):
|
||||
self.batch_info = batch_info
|
||||
|
||||
@@ -2,17 +2,17 @@ from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.lora.backend import BaseLoraBackend
|
||||
from sglang.srt.lora.lora import LoraBatchInfo
|
||||
from sglang.srt.lora.backend import BaseLoRABackend
|
||||
from sglang.srt.lora.utils import LoRABatchInfo
|
||||
from sglang.srt.utils import is_flashinfer_available
|
||||
|
||||
if is_flashinfer_available():
|
||||
from flashinfer import SegmentGEMMWrapper
|
||||
|
||||
|
||||
class FlashInferLoraBackend(BaseLoraBackend):
|
||||
class FlashInferLoRABackend(BaseLoRABackend):
|
||||
|
||||
def __init__(self, name: str, batch_info: LoraBatchInfo = None):
|
||||
def __init__(self, name: str, batch_info: LoRABatchInfo = None):
|
||||
super().__init__(name, batch_info)
|
||||
|
||||
# Set up SGemm Wrapper from flashinfer
|
||||
@@ -55,6 +55,8 @@ class FlashInferLoraBackend(BaseLoraBackend):
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
|
||||
assert isinstance(qkv_lora_b, tuple) and len(qkv_lora_b) == 2
|
||||
|
||||
# Shape of lora_a_output: (s, 3 * r)
|
||||
lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a)
|
||||
|
||||
@@ -89,3 +91,38 @@ class FlashInferLoraBackend(BaseLoraBackend):
|
||||
)
|
||||
|
||||
return lora_output
|
||||
|
||||
def run_gate_up_lora(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
gate_up_lora_a: torch.Tensor,
|
||||
gate_up_lora_b: Tuple[torch.Tensor],
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
|
||||
assert isinstance(gate_up_lora_b, tuple) and len(gate_up_lora_b) == 2
|
||||
lora_rank = gate_up_lora_b[0].shape[-1]
|
||||
output_dim = gate_up_lora_b[0].shape[-2]
|
||||
|
||||
# Shape of lora_a_output: (s, 2 * r)
|
||||
lora_a_output = self.run_lora_a_sgemm(x=x, weights=gate_up_lora_a)
|
||||
|
||||
lora_output = torch.empty(
|
||||
(x.shape[0], 2 * output_dim),
|
||||
device=x.device,
|
||||
dtype=x.dtype,
|
||||
)
|
||||
|
||||
# Compute lora for gate and up proj respectively
|
||||
lora_output[:, :output_dim] = self.run_lora_b_sgemm(
|
||||
x=lora_a_output[:, :lora_rank].contiguous(),
|
||||
weights=gate_up_lora_b[0],
|
||||
)
|
||||
|
||||
lora_output[:, output_dim:] = self.run_lora_b_sgemm(
|
||||
x=lora_a_output[:, lora_rank:].contiguous(),
|
||||
weights=gate_up_lora_b[1],
|
||||
)
|
||||
|
||||
return lora_output
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
import torch
|
||||
|
||||
from sglang.srt.lora.backend import BaseLoraBackend
|
||||
from sglang.srt.lora.lora import LoraBatchInfo
|
||||
from sglang.srt.lora.backend import BaseLoRABackend
|
||||
from sglang.srt.lora.triton_ops import (
|
||||
gate_up_lora_b_fwd,
|
||||
qkv_lora_b_fwd,
|
||||
sgemm_lora_a_fwd,
|
||||
sgemm_lora_b_fwd,
|
||||
)
|
||||
from sglang.srt.lora.utils import LoRABatchInfo
|
||||
|
||||
|
||||
class TritonLoraBackend(BaseLoraBackend):
|
||||
class TritonLoRABackend(BaseLoRABackend):
|
||||
|
||||
def __init__(self, name: str, batch_info: LoraBatchInfo = None):
|
||||
def __init__(self, name: str, batch_info: LoRABatchInfo = None):
|
||||
super().__init__(name, batch_info)
|
||||
|
||||
def run_lora_a_sgemm(
|
||||
@@ -59,3 +60,32 @@ class TritonLoraBackend(BaseLoraBackend):
|
||||
scaling,
|
||||
)
|
||||
return lora_output
|
||||
|
||||
def run_gate_up_lora(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
gate_up_lora_a: torch.Tensor,
|
||||
gate_up_lora_b: torch.Tensor,
|
||||
base_output: torch.Tensor = None,
|
||||
scaling: float = 1.0,
|
||||
*args,
|
||||
**kwargs
|
||||
) -> torch.Tensor:
|
||||
|
||||
# x: (s, input_dim)
|
||||
# gate_up_lora_a: (num_lora, 2 * r, input_dim)
|
||||
# gate_up_lora_b: (num_lora, 2 * output_dim, r)
|
||||
assert isinstance(gate_up_lora_b, torch.Tensor)
|
||||
output_dim = gate_up_lora_b.shape[-2] // 2
|
||||
|
||||
# lora_a_output: (s, 2 * r)
|
||||
lora_a_output = sgemm_lora_a_fwd(x, gate_up_lora_a, self.batch_info)
|
||||
lora_output = gate_up_lora_b_fwd(
|
||||
lora_a_output,
|
||||
gate_up_lora_b,
|
||||
self.batch_info,
|
||||
output_dim,
|
||||
base_output,
|
||||
scaling,
|
||||
)
|
||||
return lora_output
|
||||
|
||||
293
python/sglang/srt/lora/layers.py
Normal file
293
python/sglang/srt/lora/layers.py
Normal file
@@ -0,0 +1,293 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from sglang.srt.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from sglang.srt.lora.backend import BaseLoRABackend
|
||||
|
||||
|
||||
class BaseLayerWithLoRA(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
base_layer: nn.Module,
|
||||
lora_rank: int,
|
||||
scaling: float,
|
||||
lora_backend: BaseLoRABackend,
|
||||
):
|
||||
super().__init__()
|
||||
self.base_layer: nn.Module = base_layer
|
||||
self.lora_rank: int = lora_rank
|
||||
self.scaling: float = scaling
|
||||
self.set_lora: bool = False
|
||||
self.lora_backend: BaseLoRABackend = lora_backend
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return self.base_layer.forward(x)
|
||||
|
||||
def set_lora_info(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
|
||||
def __init__(
|
||||
self,
|
||||
base_layer: VocabParallelEmbedding,
|
||||
lora_rank: int,
|
||||
scaling: float,
|
||||
lora_backend: BaseLoRABackend,
|
||||
) -> None:
|
||||
super().__init__(base_layer, lora_rank, scaling, lora_backend)
|
||||
self.weight = base_layer.weight
|
||||
|
||||
|
||||
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
def __init__(
|
||||
self,
|
||||
base_layer: ColumnParallelLinear,
|
||||
lora_rank: int,
|
||||
scaling: float,
|
||||
lora_backend: BaseLoRABackend,
|
||||
) -> None:
|
||||
super().__init__(base_layer, lora_rank, scaling, lora_backend)
|
||||
|
||||
def set_lora_info(
|
||||
self,
|
||||
A_buffer: torch.Tensor,
|
||||
B_buffer: torch.Tensor,
|
||||
):
|
||||
self.set_lora = True
|
||||
self.A_buffer = A_buffer
|
||||
self.B_buffer = B_buffer
|
||||
|
||||
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||
backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
|
||||
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
|
||||
lora_output = self.lora_backend.run_lora_b_sgemm(
|
||||
lora_a_output,
|
||||
self.B_buffer[0],
|
||||
**backend_kwargs,
|
||||
)
|
||||
return (
|
||||
lora_output
|
||||
if self.lora_backend.fuse_output_scaling_add
|
||||
else base_output + lora_output * self.scaling
|
||||
)
|
||||
|
||||
def forward(self, input_: torch.Tensor):
|
||||
# duplicate the logic in ColumnParallelLinear
|
||||
bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
|
||||
output_parallel = self.base_layer.quant_method.apply(
|
||||
self.base_layer, input_, bias
|
||||
)
|
||||
|
||||
if self.set_lora:
|
||||
output_parallel = self.apply_lora(output_parallel, input_)
|
||||
|
||||
if self.base_layer.gather_output:
|
||||
output = tensor_model_parallel_all_gather(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
def __init__(
|
||||
self,
|
||||
base_layer: MergedColumnParallelLinear,
|
||||
lora_rank: int,
|
||||
scaling: float,
|
||||
lora_backend: BaseLoRABackend,
|
||||
) -> None:
|
||||
super().__init__(base_layer, lora_rank, scaling, lora_backend)
|
||||
|
||||
def set_lora_info(
|
||||
self,
|
||||
A_buffer: torch.Tensor,
|
||||
B_buffer: torch.Tensor,
|
||||
):
|
||||
self.set_lora = True
|
||||
self.A_buffer_gate_up = A_buffer
|
||||
if self.lora_backend.fuse_stacked_lora_b:
|
||||
# B_buffer_gate_up: (num_lora, 2 * output_dim, r)
|
||||
self.B_buffer_gate_up = torch.cat(
|
||||
(B_buffer[0], B_buffer[1]), dim=-2
|
||||
).contiguous()
|
||||
else:
|
||||
self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])
|
||||
|
||||
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||
backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
|
||||
|
||||
lora_output = self.lora_backend.run_gate_up_lora(
|
||||
x,
|
||||
self.A_buffer_gate_up,
|
||||
self.B_buffer_gate_up,
|
||||
**backend_kwargs,
|
||||
)
|
||||
return (
|
||||
lora_output
|
||||
if self.lora_backend.fuse_output_scaling_add
|
||||
else base_output + lora_output * self.scaling
|
||||
)
|
||||
|
||||
|
||||
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
def init__(
|
||||
self,
|
||||
base_layer: QKVParallelLinear,
|
||||
lora_rank: int,
|
||||
scaling: float,
|
||||
lora_backend: BaseLoRABackend,
|
||||
) -> None:
|
||||
super().__init__(base_layer, lora_rank, scaling, lora_backend)
|
||||
|
||||
def set_lora_info(
|
||||
self,
|
||||
A_buffer_qkv: torch.Tensor,
|
||||
B_buffer_q: torch.Tensor,
|
||||
B_buffer_kv: torch.Tensor,
|
||||
):
|
||||
self.set_lora = True
|
||||
self.A_buffer_qkv = A_buffer_qkv
|
||||
|
||||
if self.lora_backend.fuse_stacked_lora_b:
|
||||
assert (
|
||||
B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
|
||||
), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
|
||||
output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
|
||||
|
||||
# B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
|
||||
self.B_buffer_qkv = torch.cat(
|
||||
(B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
|
||||
).contiguous()
|
||||
|
||||
# Offsets of q/k/v in output dimension
|
||||
self.output_offset = torch.tensor(
|
||||
[
|
||||
0,
|
||||
output_dim_q,
|
||||
output_dim_q + output_dim_kv,
|
||||
output_dim_q + 2 * output_dim_kv,
|
||||
],
|
||||
dtype=torch.int32,
|
||||
device=B_buffer_q.device,
|
||||
)
|
||||
# For computing number of launched blocks
|
||||
self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
|
||||
else:
|
||||
self.B_buffer_qkv = (
|
||||
B_buffer_q,
|
||||
B_buffer_kv,
|
||||
)
|
||||
|
||||
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||
backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
|
||||
if self.lora_backend.fuse_stacked_lora_b:
|
||||
backend_kwargs["output_offset"] = self.output_offset
|
||||
backend_kwargs["max_qkv_out_dim"] = self.max_qkv_out_dim
|
||||
|
||||
lora_output = self.lora_backend.run_qkv_lora(
|
||||
x,
|
||||
self.A_buffer_qkv,
|
||||
self.B_buffer_qkv,
|
||||
**backend_kwargs,
|
||||
)
|
||||
return (
|
||||
lora_output
|
||||
if self.lora_backend.fuse_output_scaling_add
|
||||
else base_output + lora_output * self.scaling
|
||||
)
|
||||
|
||||
|
||||
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
def __init__(
|
||||
self,
|
||||
base_layer: RowParallelLinear,
|
||||
lora_rank: int,
|
||||
scaling: float,
|
||||
lora_backend: BaseLoRABackend,
|
||||
) -> None:
|
||||
super().__init__(base_layer, lora_rank, scaling, lora_backend)
|
||||
|
||||
def set_lora_info(self, A_buffer: torch.Tensor, B_buffer: torch.Tensor):
|
||||
self.set_lora = True
|
||||
self.A_buffer = A_buffer
|
||||
self.B_buffer = B_buffer
|
||||
|
||||
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||
backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
|
||||
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
|
||||
lora_output = self.lora_backend.run_lora_b_sgemm(
|
||||
lora_a_output,
|
||||
self.B_buffer[0],
|
||||
**backend_kwargs,
|
||||
)
|
||||
return (
|
||||
lora_output
|
||||
if self.lora_backend.fuse_output_scaling_add
|
||||
else base_output + lora_output * self.scaling
|
||||
)
|
||||
|
||||
def forward(self, input_: torch.Tensor):
|
||||
# duplicate the logic in RowParallelLinear
|
||||
if self.base_layer.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.base_layer.tp_size
|
||||
)
|
||||
input_parallel = splitted_input[tp_rank].contiguous()
|
||||
output_parallel = self.base_layer.quant_method.apply(
|
||||
self.base_layer, input_parallel
|
||||
)
|
||||
|
||||
if self.set_lora:
|
||||
output_parallel = self.apply_lora(output_parallel, input_parallel)
|
||||
|
||||
if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
|
||||
output_ = tensor_model_parallel_all_reduce(output_parallel)
|
||||
else:
|
||||
output_ = output_parallel
|
||||
|
||||
if not self.base_layer.skip_bias_add:
|
||||
output = (
|
||||
output_ + self.base_layer.bias
|
||||
if self.base_layer.bias is not None
|
||||
else output_
|
||||
)
|
||||
output_bias = None
|
||||
else:
|
||||
output = output_
|
||||
output_bias = self.base_layer.bias
|
||||
return output, output_bias
|
||||
|
||||
|
||||
def get_lora_layer(
|
||||
layer: nn.Module, lora_rank: int, scaling: int, lora_backend: BaseLoRABackend
|
||||
) -> BaseLayerWithLoRA:
|
||||
supported_layer_types = {
|
||||
# the order matters
|
||||
VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
|
||||
QKVParallelLinear: QKVParallelLinearWithLoRA,
|
||||
MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
|
||||
ColumnParallelLinear: ColumnParallelLinearWithLoRA,
|
||||
RowParallelLinear: RowParallelLinearWithLoRA,
|
||||
}
|
||||
for src_layer_type, lora_layer_type in supported_layer_types.items():
|
||||
if isinstance(layer, src_layer_type): # pylint: disable=unidiomatic-typecheck
|
||||
ret = lora_layer_type(layer, lora_rank, scaling, lora_backend)
|
||||
return ret
|
||||
raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")
|
||||
@@ -19,282 +19,25 @@
|
||||
# https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from sglang.srt.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from sglang.srt.configs.load_config import LoadConfig
|
||||
from sglang.srt.hf_transformers_utils import AutoConfig
|
||||
from sglang.srt.lora.backend import BaseLoRABackend
|
||||
from sglang.srt.lora.lora_config import LoRAConfig
|
||||
from sglang.srt.model_loader.loader import DefaultModelLoader
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoraBatchInfo:
|
||||
# Batch size
|
||||
bs: int
|
||||
|
||||
# Lengths of each sequence in shape (bs,)
|
||||
seg_lens: torch.Tensor
|
||||
|
||||
# Indice pointers of each sequence in shape (bs + 1, )
|
||||
seg_indptr: torch.Tensor
|
||||
|
||||
# Maximum sequence length of current batch
|
||||
max_len: int
|
||||
|
||||
# The index of lora adapter used by each sequence, in shape (bs,)
|
||||
weight_indices: torch.Tensor
|
||||
|
||||
|
||||
class BaseLayerWithLoRA(nn.Module):
|
||||
def __init__(self, base_layer, lora_rank, scaling, lora_backend):
|
||||
super().__init__()
|
||||
self.base_layer = base_layer
|
||||
self.lora_rank = lora_rank
|
||||
self.scaling = scaling
|
||||
self.set_lora = False
|
||||
self.lora_backend = lora_backend
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return self.base_layer.forward(x)
|
||||
|
||||
def set_lora_info(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
|
||||
def __init__(
|
||||
self, base_layer: VocabParallelEmbedding, lora_rank, scaling, lora_backend
|
||||
) -> None:
|
||||
super().__init__(base_layer, lora_rank, scaling, lora_backend)
|
||||
self.weight = base_layer.weight
|
||||
|
||||
|
||||
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
def __init__(
|
||||
self, base_layer: ColumnParallelLinear, lora_rank, scaling, lora_backend
|
||||
) -> None:
|
||||
super().__init__(base_layer, lora_rank, scaling, lora_backend)
|
||||
|
||||
def apply_lora(self, output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||
# TODO
|
||||
return output
|
||||
|
||||
def forward(self, input_: torch.Tensor):
|
||||
# duplicate the logic in ColumnParallelLinear
|
||||
bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
|
||||
output_parallel = self.base_layer.quant_method.apply(
|
||||
self.base_layer, input_, bias
|
||||
)
|
||||
|
||||
if self.set_lora:
|
||||
output_parallel = self.apply_lora(output_parallel, input_)
|
||||
|
||||
if self.base_layer.gather_output:
|
||||
output = tensor_model_parallel_all_gather(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
def __init__(
|
||||
self, base_layer: MergedColumnParallelLinear, lora_rank, scaling, lora_backend
|
||||
) -> None:
|
||||
super().__init__(base_layer, lora_rank, scaling, lora_backend)
|
||||
|
||||
def set_lora_info(
|
||||
self,
|
||||
A_buffer,
|
||||
B_buffer,
|
||||
):
|
||||
self.set_lora = True
|
||||
self.A_buffer = A_buffer
|
||||
self.B_buffer = B_buffer
|
||||
|
||||
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||
lora_a_output = self.lora_backend.run_lora_a_sgemm(x=x, weights=self.A_buffer)
|
||||
|
||||
output_dim = base_output.shape[-1]
|
||||
lora_output = torch.empty_like(base_output)
|
||||
lora_output[:, :output_dim] = self.lora_backend.run_lora_b_sgemm(
|
||||
x=lora_a_output[:, 0 : self.lora_rank].contiguous(),
|
||||
weights=self.B_buffer[0],
|
||||
)
|
||||
|
||||
lora_output[:, output_dim : 2 * output_dim] = (
|
||||
self.lora_backend.run_lora_b_sgemm(
|
||||
x=lora_a_output[:, self.lora_rank : 2 * self.lora_rank].contiguous(),
|
||||
weights=self.B_buffer[1],
|
||||
)
|
||||
)
|
||||
|
||||
return base_output + lora_output * self.scaling
|
||||
|
||||
|
||||
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
def init__(
|
||||
self, base_layer: QKVParallelLinear, lora_rank, scaling, lora_backend
|
||||
) -> None:
|
||||
super().__init__(base_layer, lora_rank, scaling, lora_backend)
|
||||
|
||||
def set_lora_info(
|
||||
self,
|
||||
A_buffer_qkv,
|
||||
B_buffer_q,
|
||||
B_buffer_kv,
|
||||
):
|
||||
self.set_lora = True
|
||||
self.A_buffer_qkv = A_buffer_qkv
|
||||
|
||||
if self.lora_backend.fuse_qkv_lora_b:
|
||||
assert (
|
||||
B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
|
||||
), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
|
||||
output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
|
||||
|
||||
# B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
|
||||
self.B_buffer_qkv = torch.cat(
|
||||
(B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
|
||||
).contiguous()
|
||||
|
||||
# Offsets of q/k/v in output dimension
|
||||
self.output_offset = torch.tensor(
|
||||
[
|
||||
0,
|
||||
output_dim_q,
|
||||
output_dim_q + output_dim_kv,
|
||||
output_dim_q + 2 * output_dim_kv,
|
||||
],
|
||||
dtype=torch.int32,
|
||||
device=B_buffer_q.device,
|
||||
)
|
||||
# For computing number of launched blocks
|
||||
self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
|
||||
else:
|
||||
self.B_buffer_qkv = (
|
||||
B_buffer_q,
|
||||
B_buffer_kv,
|
||||
)
|
||||
self.output_offset = None
|
||||
self.max_qkv_out_dim = None
|
||||
|
||||
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||
lora_output = self.lora_backend.run_qkv_lora(
|
||||
x,
|
||||
self.A_buffer_qkv,
|
||||
self.B_buffer_qkv,
|
||||
output_offset=self.output_offset,
|
||||
max_qkv_out_dim=self.max_qkv_out_dim,
|
||||
base_output=base_output,
|
||||
scaling=self.scaling,
|
||||
)
|
||||
return (
|
||||
lora_output
|
||||
if self.lora_backend.fuse_output_scaling_add
|
||||
else base_output + lora_output * self.scaling
|
||||
)
|
||||
|
||||
|
||||
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
def __init__(
|
||||
self, base_layer: RowParallelLinear, lora_rank, scaling, lora_backend
|
||||
) -> None:
|
||||
super().__init__(base_layer, lora_rank, scaling, lora_backend)
|
||||
|
||||
def set_lora_info(self, A_buffer, B_buffer):
|
||||
self.set_lora = True
|
||||
self.A_buffer = A_buffer
|
||||
self.B_buffer = B_buffer
|
||||
|
||||
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
|
||||
lora_output = self.lora_backend.run_lora_b_sgemm(
|
||||
lora_a_output,
|
||||
self.B_buffer[0],
|
||||
base_output=base_output,
|
||||
scaling=self.scaling,
|
||||
)
|
||||
return (
|
||||
lora_output
|
||||
if self.lora_backend.fuse_output_scaling_add
|
||||
else base_output + lora_output * self.scaling
|
||||
)
|
||||
|
||||
def forward(self, input_):
|
||||
# duplicate the logic in RowParallelLinear
|
||||
if self.base_layer.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.base_layer.tp_size
|
||||
)
|
||||
input_parallel = splitted_input[tp_rank].contiguous()
|
||||
output_parallel = self.base_layer.quant_method.apply(
|
||||
self.base_layer, input_parallel
|
||||
)
|
||||
|
||||
if self.set_lora:
|
||||
output_parallel = self.apply_lora(output_parallel, input_parallel)
|
||||
|
||||
if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
|
||||
output_ = tensor_model_parallel_all_reduce(output_parallel)
|
||||
else:
|
||||
output_ = output_parallel
|
||||
|
||||
if not self.base_layer.skip_bias_add:
|
||||
output = (
|
||||
output_ + self.base_layer.bias
|
||||
if self.base_layer.bias is not None
|
||||
else output_
|
||||
)
|
||||
output_bias = None
|
||||
else:
|
||||
output = output_
|
||||
output_bias = self.base_layer.bias
|
||||
return output, output_bias
|
||||
|
||||
|
||||
def get_lora_layer(
|
||||
layer: nn.Module, lora_rank, scaling, lora_backend
|
||||
) -> BaseLayerWithLoRA:
|
||||
supported_layer_types = {
|
||||
# the order matters
|
||||
VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
|
||||
QKVParallelLinear: QKVParallelLinearWithLoRA,
|
||||
MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
|
||||
ColumnParallelLinear: ColumnParallelLinearWithLoRA,
|
||||
RowParallelLinear: RowParallelLinearWithLoRA,
|
||||
}
|
||||
for src_layer_type, lora_layer_type in supported_layer_types.items():
|
||||
if isinstance(layer, src_layer_type): # pylint: disable=unidiomatic-typecheck
|
||||
ret = lora_layer_type(layer, lora_rank, scaling, lora_backend)
|
||||
return ret
|
||||
raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")
|
||||
|
||||
|
||||
def get_mapped_params(module_names):
|
||||
ret = set()
|
||||
for module_name in module_names:
|
||||
ret.add(params_mapping(module_name))
|
||||
return list(ret)
|
||||
|
||||
|
||||
class LoRALayer(nn.Module):
|
||||
def __init__(self, config, base_hf_config):
|
||||
def __init__(self, config: LoRAConfig, base_hf_config: AutoConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.base_hf_config = base_hf_config
|
||||
self.weights = {}
|
||||
self.weight_gpu = {}
|
||||
self.config: LoRAConfig = config
|
||||
self.base_hf_config: AutoConfig = base_hf_config
|
||||
self.weights: Dict[str, torch.Tensor] = {}
|
||||
self.weight_gpu: Dict[str, torch.Tensor] = {}
|
||||
|
||||
def load_to_gpu(self):
|
||||
for name, weight in self.weights.items():
|
||||
@@ -306,33 +49,32 @@ class LoRALayer(nn.Module):
|
||||
|
||||
|
||||
class LoRAAdapter(nn.Module):
|
||||
def __init__(self, uid, config, base_hf_config, load_config, lora_backend):
|
||||
def __init__(
|
||||
self,
|
||||
uid: str,
|
||||
config: LoRAConfig,
|
||||
base_hf_config: AutoConfig,
|
||||
load_config: LoadConfig,
|
||||
lora_backend: BaseLoRABackend,
|
||||
):
|
||||
super().__init__()
|
||||
self.uid = uid
|
||||
self.config = config
|
||||
self.uid: str = uid
|
||||
self.config: LoRAConfig = config
|
||||
assert self.config.hf_config["peft_type"].lower() == "lora"
|
||||
self.base_hf_config = base_hf_config
|
||||
self.load_config = load_config
|
||||
self.lora_backend = lora_backend
|
||||
self.scaling = self.config.lora_alpha / self.config.r
|
||||
self.base_hf_config: AutoConfig = base_hf_config
|
||||
self.load_config: LoadConfig = load_config
|
||||
self.lora_backend: BaseLoRABackend = lora_backend
|
||||
self.scaling: float = self.config.lora_alpha / self.config.r
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
self.layers: List[LoRALayer] = nn.ModuleList(
|
||||
[
|
||||
LoRALayer(config, base_hf_config)
|
||||
for i in range(base_hf_config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.weights = {}
|
||||
self.weights_gpu = {}
|
||||
|
||||
def get_stacked_multiply(self, module_name):
|
||||
stacked_rank = {
|
||||
"qkv_proj": 3,
|
||||
"kv_proj": 2,
|
||||
"gate_up_proj": 2,
|
||||
}
|
||||
return stacked_rank[module_name] if module_name in stacked_rank else 1
|
||||
self.weights: Dict[str, torch.Tensor] = {}
|
||||
self.weights_gpu: Dict[str, torch.Tensor] = {}
|
||||
|
||||
def load_to_gpu(self):
|
||||
for name, weight in self.weights.items():
|
||||
@@ -367,44 +109,77 @@ class LoRAAdapter(nn.Module):
|
||||
for i in range(self.base_hf_config.num_hidden_layers):
|
||||
layer = self.layers[i]
|
||||
weight_names = [name for name, _ in layer.weights.items()]
|
||||
for weight_name in weight_names:
|
||||
if "k_proj" in weight_name:
|
||||
q_name = weight_name.replace("k_proj", "q_proj")
|
||||
v_name = weight_name.replace("k_proj", "v_proj")
|
||||
kv_name = weight_name.replace("k_proj", "kv_proj")
|
||||
qkv_name = weight_name.replace("k_proj", "qkv_proj")
|
||||
if "lora_A" in weight_name:
|
||||
layer.weights[qkv_name] = torch.cat(
|
||||
(
|
||||
layer.weights[q_name],
|
||||
layer.weights[weight_name],
|
||||
layer.weights[v_name],
|
||||
),
|
||||
0,
|
||||
)
|
||||
layer.weights.pop(q_name)
|
||||
layer.weights.pop(weight_name)
|
||||
layer.weights.pop(v_name)
|
||||
else:
|
||||
layer.weights[kv_name] = torch.stack(
|
||||
[
|
||||
layer.weights[weight_name],
|
||||
layer.weights[v_name],
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
layer.weights.pop(weight_name)
|
||||
layer.weights.pop(v_name)
|
||||
elif "gate_proj" in weight_name:
|
||||
up_name = weight_name.replace("gate_proj", "up_proj")
|
||||
gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
|
||||
if "lora_A" in weight_name:
|
||||
layer.weights[gate_up_name] = torch.cat(
|
||||
(layer.weights[weight_name], layer.weights[up_name]), 0
|
||||
)
|
||||
else:
|
||||
layer.weights[gate_up_name] = torch.stack(
|
||||
[layer.weights[weight_name], layer.weights[up_name]], dim=0
|
||||
)
|
||||
layer.weights.pop(weight_name)
|
||||
layer.weights.pop(up_name)
|
||||
self.stack_qkv_proj(weight_names, layer.weights)
|
||||
self.stack_gate_up_proj(weight_names, layer.weights)
|
||||
|
||||
def stack_qkv_proj(self, weight_names: List[str], weights: Dict[str, torch.Tensor]):
|
||||
|
||||
# Collect target q/k/v modules. This process is necessary since there might be no lora attached to k_proj
|
||||
target_module = set()
|
||||
for weight_name in weight_names:
|
||||
if "k_proj" in weight_name:
|
||||
target_module.add("k_proj")
|
||||
if "q_proj" in weight_name:
|
||||
target_module.add("q_proj")
|
||||
if "v_proj" in weight_name:
|
||||
target_module.add("v_proj")
|
||||
if len(target_module) == 0:
|
||||
return
|
||||
|
||||
for weight_name in weight_names:
|
||||
# We assume every lora adaptor should contain lora modules for q_proj
|
||||
if "q_proj" in weight_name:
|
||||
q_name = weight_name
|
||||
k_name = weight_name.replace("q_proj", "k_proj")
|
||||
v_name = weight_name.replace("q_proj", "v_proj")
|
||||
kv_name = weight_name.replace("q_proj", "kv_proj")
|
||||
qkv_name = weight_name.replace("q_proj", "qkv_proj")
|
||||
|
||||
# If k_proj doesn't have lora, initialize it to zero
|
||||
k_proj_weight = (
|
||||
weights[k_name]
|
||||
if "k_proj" in target_module
|
||||
else torch.zeros_like(weights[v_name])
|
||||
)
|
||||
if "lora_A" in weight_name:
|
||||
weights[qkv_name] = torch.cat(
|
||||
(
|
||||
weights[q_name],
|
||||
k_proj_weight,
|
||||
weights[v_name],
|
||||
),
|
||||
0,
|
||||
)
|
||||
weights.pop(q_name)
|
||||
if "k_proj" in target_module:
|
||||
weights.pop(k_name)
|
||||
weights.pop(v_name)
|
||||
else:
|
||||
weights[kv_name] = torch.stack(
|
||||
[
|
||||
k_proj_weight,
|
||||
weights[v_name],
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
if "k_proj" in target_module:
|
||||
weights.pop(k_name)
|
||||
weights.pop(v_name)
|
||||
|
||||
def stack_gate_up_proj(
|
||||
self, weight_names: List[str], weights: Dict[str, torch.Tensor]
|
||||
):
|
||||
for weight_name in weight_names:
|
||||
if "gate_proj" in weight_name:
|
||||
up_name = weight_name.replace("gate_proj", "up_proj")
|
||||
gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
|
||||
if "lora_A" in weight_name:
|
||||
weights[gate_up_name] = torch.cat(
|
||||
(weights[weight_name], weights[up_name]), 0
|
||||
)
|
||||
else:
|
||||
weights[gate_up_name] = torch.stack(
|
||||
[weights[weight_name], weights[up_name]], dim=0
|
||||
)
|
||||
weights.pop(weight_name)
|
||||
weights.pop(up_name)
|
||||
|
||||
@@ -16,307 +16,115 @@
|
||||
# and "Punica: Multi-Tenant LoRA Serving"
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict, List, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.lora.backend import FlashInferLoraBackend, TritonLoraBackend
|
||||
from sglang.srt.lora.lora import LoRAAdapter, LoraBatchInfo, get_lora_layer
|
||||
from sglang.srt.configs.load_config import LoadConfig
|
||||
from sglang.srt.hf_transformers_utils import AutoConfig
|
||||
from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name
|
||||
from sglang.srt.lora.layers import get_lora_layer
|
||||
from sglang.srt.lora.lora import LoRAAdapter
|
||||
from sglang.srt.lora.lora_config import LoRAConfig
|
||||
from sglang.srt.lora.mem_pool import LoRAMemoryPool
|
||||
from sglang.srt.lora.utils import (
|
||||
LoRABatchInfo,
|
||||
LoRAType,
|
||||
get_customized_names_from_hf_names,
|
||||
get_layer_id,
|
||||
get_stacked_name,
|
||||
get_weight_name,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.utils import is_flashinfer_available, replace_submodule
|
||||
from sglang.srt.utils import replace_submodule
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_module_name(name):
|
||||
# Fallback solution of mapping from config module name to module name in model class.
|
||||
# Please check if it aligns with your base model.
|
||||
# Please implement the function in the model class if it is not.
|
||||
# You can reference this function in llama.py.
|
||||
params_mapping = {
|
||||
"q_proj": "qkv_proj",
|
||||
"k_proj": "qkv_proj",
|
||||
"v_proj": "qkv_proj",
|
||||
"gate_proj": "gate_up_proj",
|
||||
"up_proj": "gate_up_proj",
|
||||
}
|
||||
return params_mapping.get(name, name)
|
||||
|
||||
|
||||
def get_hidden_dim(module_name, config):
|
||||
# Fallback solution of get_hidden_dim for different modules
|
||||
# Please check if it aligns with your base model.
|
||||
# Please implement the function in the model class if it is not.
|
||||
# You can reference this function in llama.py.
|
||||
if module_name in ["q_proj", "o_proj", "qkv_proj"]:
|
||||
return config.hidden_size, config.hidden_size
|
||||
elif module_name in ["kv_proj"]:
|
||||
return config.hidden_size, config.hidden_size // (
|
||||
config.num_attention_heads // config.num_key_value_heads
|
||||
)
|
||||
elif module_name == "gate_up_proj":
|
||||
return config.hidden_size, config.intermediate_size
|
||||
elif module_name == "down_proj":
|
||||
return config.intermediate_size, config.hidden_size
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def get_stacked_name(name):
|
||||
# origin name -> (name for A, name for B)
|
||||
params_mapping = {
|
||||
"q_proj": ("qkv_proj", "q_proj"),
|
||||
"k_proj": ("qkv_proj", "kv_proj"),
|
||||
"v_proj": ("qkv_proj", "kv_proj"),
|
||||
"gate_proj": ("gate_up_proj", "gate_up_proj"),
|
||||
"up_proj": ("gate_up_proj", "gate_up_proj"),
|
||||
}
|
||||
return params_mapping.get(name, (name, name))
|
||||
|
||||
|
||||
def get_backend_from_name(name):
|
||||
backend_mapping = {
|
||||
"triton": TritonLoraBackend,
|
||||
"flashinfer": FlashInferLoraBackend,
|
||||
}
|
||||
|
||||
if name in backend_mapping:
|
||||
return backend_mapping[name]
|
||||
|
||||
raise Exception(
|
||||
f"No supported lora backend called {name}. It should be one of {list(backend_mapping.keys())}"
|
||||
)
|
||||
|
||||
|
||||
def get_layer_id(name):
|
||||
match = re.search(r"layers\.(\d+)\.", name)
|
||||
if match is None:
|
||||
return None
|
||||
return int(match.group(1))
|
||||
|
||||
|
||||
class LoRAManager:
|
||||
def __init__(
|
||||
self,
|
||||
base_model,
|
||||
lora_paths,
|
||||
base_hf_config,
|
||||
max_loras_per_batch,
|
||||
load_config,
|
||||
dtype,
|
||||
lora_backend,
|
||||
base_model: torch.nn.Module,
|
||||
lora_paths: Dict[str, str],
|
||||
base_hf_config: AutoConfig,
|
||||
max_loras_per_batch: int,
|
||||
load_config: LoadConfig,
|
||||
dtype: torch.dtype,
|
||||
lora_backend: str = "triton",
|
||||
):
|
||||
self.base_model = base_model
|
||||
self.lora_paths = lora_paths
|
||||
self.base_hf_config = base_hf_config
|
||||
self.max_loras_per_batch = max_loras_per_batch
|
||||
self.load_config = load_config
|
||||
self.dtype = dtype
|
||||
self.base_model: torch.nn.Module = base_model
|
||||
self.lora_paths: Dict[str, str] = lora_paths
|
||||
self.base_hf_config: AutoConfig = base_hf_config
|
||||
self.max_loras_per_batch: int = max_loras_per_batch
|
||||
self.load_config: LoadConfig = load_config
|
||||
self.dtype: torch.dtype = dtype
|
||||
|
||||
logger.info(f"Using {lora_backend} as backend of Lora kernels.")
|
||||
# LoRA backend for running sgemm kernels
|
||||
logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
|
||||
backend_type = get_backend_from_name(lora_backend)
|
||||
self.lora_backend = backend_type(lora_backend)
|
||||
self.lora_backend: BaseLoRABackend = backend_type(lora_backend)
|
||||
|
||||
self.init_loras()
|
||||
self.init_lora_memory_pool()
|
||||
self.init_lora_batch()
|
||||
|
||||
def match_target_modules(self, module_name):
|
||||
for target_module in self.target_modules:
|
||||
if module_name.split(".")[-1] == target_module:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_target_modules(self):
|
||||
modules = []
|
||||
for module_name, module in self.base_model.named_modules():
|
||||
if self.match_target_modules(module_name):
|
||||
modules.append((module_name, module))
|
||||
return modules
|
||||
|
||||
def set_lora_module(self, module_name, module):
|
||||
lora_module = get_lora_layer(
|
||||
module, self.max_lora_dim, self.scaling, self.lora_backend
|
||||
)
|
||||
replace_submodule(self.base_model, module_name, lora_module)
|
||||
return lora_module
|
||||
|
||||
def init_loras(self):
|
||||
# get configs and target modules
|
||||
self.configs = {}
|
||||
self.origin_target_modules = set()
|
||||
# Config of each LoRA adapter
|
||||
self.configs: Dict[str, LoRAConfig] = {}
|
||||
|
||||
# Target module names in huggingface lora configs.
|
||||
# e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
|
||||
self.hf_target_names: Set[str] = set()
|
||||
for name, path in self.lora_paths.items():
|
||||
self.configs[name] = LoRAConfig(path)
|
||||
self.origin_target_modules = set(self.origin_target_modules) | set(
|
||||
self.hf_target_names = set(self.hf_target_names) | set(
|
||||
self.configs[name].target_modules
|
||||
)
|
||||
if hasattr(self.base_model, "get_module_name"):
|
||||
self.target_modules = {
|
||||
self.base_model.get_module_name(module)
|
||||
for module in self.origin_target_modules
|
||||
}
|
||||
else:
|
||||
logger.warning(
|
||||
"WARNING: get_module_name() is not defined, "
|
||||
"which is used to map config module name to model implementation module name."
|
||||
"Use the default one, but please check if it is correct for your model."
|
||||
)
|
||||
self.target_modules = {
|
||||
get_module_name(module) for module in self.origin_target_modules
|
||||
}
|
||||
self.target_weights = set(
|
||||
[get_stacked_name(module) for module in self.origin_target_modules]
|
||||
|
||||
# Target lora weight names for lora_a and lora_b modules repectively.
|
||||
# e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
|
||||
self.lora_weight_names: Set[Tuple[str]] = set(
|
||||
[get_stacked_name(module) for module in self.hf_target_names]
|
||||
)
|
||||
|
||||
# load all weights to cpu
|
||||
self.loras = []
|
||||
self.lora_id = {}
|
||||
self.loras: Dict[str, LoRAAdapter] = {}
|
||||
for name in self.lora_paths.keys():
|
||||
self.lora_id[name] = len(self.loras)
|
||||
self.loras.append(
|
||||
LoRAAdapter(
|
||||
name,
|
||||
self.configs[name],
|
||||
self.base_hf_config,
|
||||
self.load_config,
|
||||
self.lora_backend,
|
||||
)
|
||||
lora_adapter = LoRAAdapter(
|
||||
name,
|
||||
self.configs[name],
|
||||
self.base_hf_config,
|
||||
self.load_config,
|
||||
self.lora_backend,
|
||||
)
|
||||
self.loras[-1].initialize_weights()
|
||||
lora_adapter.initialize_weights()
|
||||
self.loras[name] = lora_adapter
|
||||
|
||||
# misc lora configs
|
||||
self.max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
|
||||
self.scaling = self.loras[0].scaling
|
||||
# FIXME remove the restrictions
|
||||
# FIXME remove the restrictions after implementing unified paging
|
||||
self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
|
||||
self.scaling: float = list(self.loras.values())[0].scaling
|
||||
assert all(x.hf_config["r"] == self.max_lora_dim for x in self.configs.values())
|
||||
assert all(x.scaling == self.scaling for x in self.loras)
|
||||
assert all(x.scaling == self.scaling for x in self.loras.values())
|
||||
|
||||
# monkey patch to use the LoRA version
|
||||
self.lora_modules = []
|
||||
for module_name, module in self.get_target_modules():
|
||||
self.lora_modules.append(
|
||||
(module_name, self.set_lora_module(module_name, module))
|
||||
)
|
||||
# Convert original model layers to layers with LoRA
|
||||
self.convert_to_lora_layers()
|
||||
|
||||
def init_lora_memory_pool(self):
|
||||
# preallocate lora memory pool
|
||||
self.A_buffer = {}
|
||||
self.B_buffer = {}
|
||||
num_layer = self.base_hf_config.num_hidden_layers
|
||||
for module_A, module_B in self.target_weights:
|
||||
# init A tensor, column_major=True
|
||||
if hasattr(self.base_model, "get_hidden_dim"):
|
||||
hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A)
|
||||
else:
|
||||
logger.warning(
|
||||
"WARNING: get_hidden_dim() is not defined, "
|
||||
"which is used to get the hidden dim for different lora modules"
|
||||
"Use the default one, but please check if it is correct for your model."
|
||||
)
|
||||
hidden_dim_A, _ = get_hidden_dim(module_A, self.base_hf_config)
|
||||
c = self.loras[-1].get_stacked_multiply(module_A)
|
||||
if module_A not in self.A_buffer:
|
||||
self.A_buffer[module_A] = [
|
||||
torch.empty(
|
||||
(
|
||||
self.max_loras_per_batch,
|
||||
self.max_lora_dim * c,
|
||||
hidden_dim_A,
|
||||
),
|
||||
dtype=self.dtype,
|
||||
device="cuda",
|
||||
)
|
||||
for i in range(num_layer)
|
||||
]
|
||||
# init B tensor, column_major=True
|
||||
if hasattr(self.base_model, "get_hidden_dim"):
|
||||
_, hidden_dim_B = self.base_model.get_hidden_dim(module_B)
|
||||
else:
|
||||
logger.warning(
|
||||
"WARNING: get_hidden_dim() is not defined, "
|
||||
"which is used to get the hidden dim for different lora modules"
|
||||
"Use the default one, but please check if it is correct for your model."
|
||||
)
|
||||
_, hidden_dim_B = get_hidden_dim(module_B, self.base_hf_config)
|
||||
c = self.loras[-1].get_stacked_multiply(module_B)
|
||||
if module_B not in self.B_buffer:
|
||||
self.B_buffer[module_B] = [
|
||||
torch.empty(
|
||||
(
|
||||
c,
|
||||
self.max_loras_per_batch,
|
||||
hidden_dim_B,
|
||||
self.max_lora_dim,
|
||||
),
|
||||
dtype=self.dtype,
|
||||
device="cuda",
|
||||
)
|
||||
for i in range(num_layer)
|
||||
]
|
||||
# Initialize memory pool
|
||||
self.memory_pool = LoRAMemoryPool(
|
||||
self.base_hf_config, self.max_loras_per_batch, self.max_lora_dim, self.dtype
|
||||
)
|
||||
|
||||
def init_lora_batch(self):
|
||||
self.active_uids = set() # set of active loras
|
||||
self.buffer_id = {} # lora uid -> idx in memory pool
|
||||
|
||||
def get_weight_name(self, name, idx):
|
||||
for target_weight_name in self.target_weights:
|
||||
if target_weight_name[idx] in name:
|
||||
return target_weight_name[idx]
|
||||
|
||||
def load_lora(self, uid, buffer_id):
|
||||
num_layer = self.base_hf_config.num_hidden_layers
|
||||
if uid is None:
|
||||
for i in range(num_layer):
|
||||
for k in self.A_buffer.keys():
|
||||
self.A_buffer[k][i][buffer_id] *= 0
|
||||
return
|
||||
|
||||
for i in range(num_layer):
|
||||
layer_weights = self.loras[self.lora_id[uid]].layers[i].weights
|
||||
for name, weights in layer_weights.items():
|
||||
if "lora_A" in name:
|
||||
lora_weight_name = self.get_weight_name(name, 0)
|
||||
if lora_weight_name:
|
||||
self.A_buffer[lora_weight_name][i][buffer_id].copy_(weights)
|
||||
else:
|
||||
lora_weight_name = self.get_weight_name(name, 1)
|
||||
if lora_weight_name:
|
||||
c = self.loras[-1].get_stacked_multiply(lora_weight_name)
|
||||
if c > 1:
|
||||
for j in range(c):
|
||||
self.B_buffer[lora_weight_name][i][j][buffer_id].copy_(
|
||||
weights[j]
|
||||
)
|
||||
else:
|
||||
self.B_buffer[lora_weight_name][i][0][buffer_id].copy_(
|
||||
weights
|
||||
)
|
||||
# Initialize target lora modules in memory pool
|
||||
self.memory_pool.init_buffers(self.lora_weight_names, self.base_model)
|
||||
|
||||
def prepare_lora_batch(self, forward_batch: ForwardBatch):
|
||||
# load active loras into lora memory pool
|
||||
cur_uids = set(forward_batch.lora_paths)
|
||||
assert len(cur_uids) <= self.max_loras_per_batch
|
||||
i = 0
|
||||
j = len(self.active_uids)
|
||||
evictable_uids = list(self.active_uids)
|
||||
for uid in cur_uids:
|
||||
if uid not in self.active_uids:
|
||||
if j < self.max_loras_per_batch:
|
||||
index = j
|
||||
j += 1
|
||||
else:
|
||||
while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
|
||||
i += 1
|
||||
assert i < len(evictable_uids)
|
||||
self.active_uids.remove(evictable_uids[i])
|
||||
self.buffer_id.pop(evictable_uids[i])
|
||||
index = i
|
||||
i += 1
|
||||
self.load_lora(uid, index)
|
||||
self.active_uids.add(uid)
|
||||
self.buffer_id[uid] = index
|
||||
self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
|
||||
|
||||
# FIXME: Handle lora uid with None more safely
|
||||
if cur_uids == set([None]):
|
||||
return
|
||||
|
||||
@@ -332,9 +140,9 @@ class LoRAManager:
|
||||
max_len = int(torch.max(seg_lens))
|
||||
weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
|
||||
for i, lora_path in enumerate(forward_batch.lora_paths):
|
||||
weight_indices[i] = self.buffer_id[lora_path]
|
||||
weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
|
||||
|
||||
batch_info = LoraBatchInfo(
|
||||
batch_info = LoRABatchInfo(
|
||||
bs=bs,
|
||||
seg_lens=seg_lens,
|
||||
seg_indptr=seg_indptr,
|
||||
@@ -346,16 +154,40 @@ class LoRAManager:
|
||||
# call set_lora_info for each lora modules
|
||||
for module_name, module in self.lora_modules:
|
||||
layer_id = get_layer_id(module_name)
|
||||
|
||||
if "qkv_proj" not in module_name:
|
||||
weight_name = self.get_weight_name(module_name, 0)
|
||||
weight_name = get_weight_name(
|
||||
module_name, self.lora_weight_names, LoRAType.LORA_A
|
||||
)
|
||||
module.set_lora_info(
|
||||
self.A_buffer[weight_name][layer_id],
|
||||
self.B_buffer[weight_name][layer_id],
|
||||
self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_A),
|
||||
self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_B),
|
||||
)
|
||||
else:
|
||||
module.set_lora_info(
|
||||
self.A_buffer["qkv_proj"][layer_id],
|
||||
self.B_buffer["q_proj"][layer_id],
|
||||
self.B_buffer["kv_proj"][layer_id],
|
||||
self.memory_pool.get_tensor("qkv_proj", layer_id, LoRAType.LORA_A),
|
||||
self.memory_pool.get_tensor("q_proj", layer_id, LoRAType.LORA_B),
|
||||
self.memory_pool.get_tensor("kv_proj", layer_id, LoRAType.LORA_B),
|
||||
)
|
||||
|
||||
def set_lora_module(self, module_name, module):
|
||||
lora_module = get_lora_layer(
|
||||
module, self.max_lora_dim, self.scaling, self.lora_backend
|
||||
)
|
||||
replace_submodule(self.base_model, module_name, lora_module)
|
||||
return lora_module
|
||||
|
||||
def convert_to_lora_layers(self):
|
||||
# Target module names of customized layers defined in python/sglang/srt/layers
|
||||
# e.g., {"qkv_proj", "o_proj"}
|
||||
customized_target_names = get_customized_names_from_hf_names(
|
||||
self.hf_target_names, self.base_model
|
||||
)
|
||||
|
||||
# Monkey patch to use the LoRA version layers
|
||||
self.lora_modules: List[Tuple[str, torch.nn.Module]] = []
|
||||
for module_name, module in self.base_model.named_modules():
|
||||
# The module should be converted if it is included in target_names
|
||||
if module_name.split(".")[-1] in customized_target_names:
|
||||
self.lora_modules.append(
|
||||
(module_name, self.set_lora_module(module_name, module))
|
||||
)
|
||||
|
||||
174
python/sglang/srt/lora/mem_pool.py
Normal file
174
python/sglang/srt/lora/mem_pool.py
Normal file
@@ -0,0 +1,174 @@
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.hf_transformers_utils import AutoConfig
|
||||
from sglang.srt.lora.lora import LoRAAdapter
|
||||
from sglang.srt.lora.utils import (
|
||||
LoRAType,
|
||||
get_hidden_dim,
|
||||
get_stacked_multiply,
|
||||
get_weight_name,
|
||||
)
|
||||
|
||||
|
||||
class LoRAMemoryPool:
|
||||
"""Class for memory pool management of lora modules"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_hf_config: AutoConfig,
|
||||
max_loras_per_batch: int,
|
||||
max_lora_dim: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
|
||||
self.base_hf_config: AutoConfig = base_hf_config
|
||||
self.num_layer: int = base_hf_config.num_hidden_layers
|
||||
self.max_loras_per_batch: int = max_loras_per_batch
|
||||
self.max_lora_dim: int = max_lora_dim
|
||||
self.dtype: torch.dtype = dtype
|
||||
|
||||
# Both A_buffer and B_buffer maps lora weight names to its buffer space.
|
||||
# A_buffer contains num_layer number of row-major tensors with shape
|
||||
# (max_loras_per_batch, stacked_num * max_lora_dim, input_dim)
|
||||
# B_buffer contains num_layer number of column-major tensors with shape
|
||||
# (stacked_num, max_loras_per_batch, output_dim, max_lora_dim)
|
||||
self.A_buffer: Dict[str, List[torch.Tensor]] = {}
|
||||
self.B_buffer: Dict[str, List[torch.Tensor]] = {}
|
||||
|
||||
# Lora uid -> buffer idx in memory pool
|
||||
self.uid_to_buffer_id: Dict[Optional[str], int] = {}
|
||||
|
||||
# Buffer idx -> lora uid in memory pool
|
||||
# All uids are initalized as empty strings for empty buffer slots
|
||||
# Here we don't initalize to None since None is a valid uid
|
||||
self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
|
||||
|
||||
def init_buffers(
|
||||
self,
|
||||
lora_weight_names: Set[Tuple[str]],
|
||||
base_model: torch.nn.Module,
|
||||
):
|
||||
|
||||
# lora_weight_names is a set of name pairs indicating each pair of lora modules to load
|
||||
# e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
|
||||
self.lora_weight_names: Set[Tuple[str]] = lora_weight_names
|
||||
|
||||
for module_A, module_B in lora_weight_names:
|
||||
# Init A tensor, column_major=False
|
||||
input_dim, _ = get_hidden_dim(module_A, self.base_hf_config, base_model)
|
||||
c = get_stacked_multiply(module_A)
|
||||
if module_A not in self.A_buffer:
|
||||
self.A_buffer[module_A] = [
|
||||
torch.empty(
|
||||
(
|
||||
self.max_loras_per_batch,
|
||||
self.max_lora_dim * c,
|
||||
input_dim,
|
||||
),
|
||||
dtype=self.dtype,
|
||||
device="cuda",
|
||||
)
|
||||
for i in range(self.num_layer)
|
||||
]
|
||||
|
||||
# Init B tensor, column_major=True
|
||||
_, output_dim = get_hidden_dim(module_B, self.base_hf_config, base_model)
|
||||
c = get_stacked_multiply(module_B)
|
||||
if module_B not in self.B_buffer:
|
||||
self.B_buffer[module_B] = [
|
||||
torch.empty(
|
||||
(
|
||||
c, # stacked lora_b modules might need separation
|
||||
self.max_loras_per_batch,
|
||||
output_dim,
|
||||
self.max_lora_dim,
|
||||
),
|
||||
dtype=self.dtype,
|
||||
device="cuda",
|
||||
)
|
||||
for i in range(self.num_layer)
|
||||
]
|
||||
|
||||
def prepare_lora_batch(
|
||||
self,
|
||||
cur_uids: Set[Optional[str]],
|
||||
lora_adapters: Dict[str, LoRAAdapter],
|
||||
):
|
||||
|
||||
def get_available_buffer_slot():
|
||||
for buffer_id in range(self.max_loras_per_batch):
|
||||
# Prioritize empty slots
|
||||
if self.buffer_id_to_uid[buffer_id] == "":
|
||||
return buffer_id, ""
|
||||
|
||||
for buffer_id in range(self.max_loras_per_batch):
|
||||
# Evict unneeded lora
|
||||
if self.buffer_id_to_uid[buffer_id] not in cur_uids:
|
||||
return buffer_id, self.buffer_id_to_uid[buffer_id]
|
||||
|
||||
raise ValueError(
|
||||
"No available buffer slots found. Please ensure the number of active loras is less than max_loras_per_batch."
|
||||
)
|
||||
|
||||
for uid in cur_uids:
|
||||
if uid not in self.uid_to_buffer_id:
|
||||
buffer_id, evicted_lora_uid = get_available_buffer_slot()
|
||||
if evicted_lora_uid != "":
|
||||
self.uid_to_buffer_id.pop(evicted_lora_uid)
|
||||
self.load_lora_weight_to_buffer(
|
||||
uid, buffer_id, lora_adapters.get(uid, None)
|
||||
)
|
||||
self.uid_to_buffer_id[uid] = buffer_id
|
||||
self.buffer_id_to_uid[buffer_id] = uid
|
||||
|
    def load_lora_weight_to_buffer(
        self, uid: str, buffer_id: int, lora_adapter: LoRAAdapter = None
    ):

        if uid is None:
            for i in range(self.num_layer):
                for k in self.A_buffer.keys():
                    self.A_buffer[k][i][buffer_id] *= 0
            return

        assert lora_adapter is not None
        for layer_id in range(self.num_layer):
            layer_weights = lora_adapter.layers[layer_id].weights
            for name, weights in layer_weights.items():
                if "lora_A" in name:
                    lora_weight_name = get_weight_name(
                        name, self.lora_weight_names, LoRAType.LORA_A
                    )
                    if lora_weight_name:
                        self.A_buffer[lora_weight_name][layer_id][buffer_id].copy_(
                            weights
                        )
                else:
                    lora_weight_name = get_weight_name(
                        name, self.lora_weight_names, LoRAType.LORA_B
                    )
                    if lora_weight_name:
                        c = get_stacked_multiply(lora_weight_name)
                        if c > 1:
                            for stacked_id in range(c):
                                self.B_buffer[lora_weight_name][layer_id][stacked_id][
                                    buffer_id
                                ].copy_(weights[stacked_id])
                        else:
                            self.B_buffer[lora_weight_name][layer_id][0][
                                buffer_id
                            ].copy_(weights)

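        # Illustrative sketch (hypothetical names/shapes, not from this diff): a layer
        # weight named "...self_attn.kv_proj.lora_B" matched to "kv_proj" (c=2) is
        # expected to be indexable per stacked module, e.g. shape (2, output_dim_kv, r);
        # slice 0 (k) and slice 1 (v) land in B_buffer["kv_proj"][layer_id][0][buffer_id]
        # and B_buffer["kv_proj"][layer_id][1][buffer_id]. A non-stacked module such as
        # "o_proj" (c=1) copies its weights into stack index 0 directly.
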
    def get_tensor(
        self, weight_name: str, layer_id: int, lora_type: LoRAType
    ) -> torch.Tensor:

        if lora_type == LoRAType.LORA_A:
            return self.A_buffer[weight_name][layer_id]

        return self.B_buffer[weight_name][layer_id]

    def get_buffer_id(self, lora_uid: str):
        return self.uid_to_buffer_id[lora_uid]
@@ -1,5 +1,11 @@
from .gate_up_lora_b import gate_up_lora_b_fwd
from .qkv_lora_b import qkv_lora_b_fwd
from .sgemm_lora_a import sgemm_lora_a_fwd
from .sgemm_lora_b import sgemm_lora_b_fwd

__all__ = ["qkv_lora_b_fwd", "sgemm_lora_a_fwd", "sgemm_lora_b_fwd"]
__all__ = [
    "gate_up_lora_b_fwd",
    "qkv_lora_b_fwd",
    "sgemm_lora_a_fwd",
    "sgemm_lora_b_fwd",
]

python/sglang/srt/lora/triton_ops/gate_up_lora_b.py (new file, 170 lines)
@@ -0,0 +1,170 @@
import torch
import triton
import triton.language as tl

from sglang.srt.lora.utils import LoRABatchInfo


@triton.jit
def _gate_up_lora_b_kernel(
    # Pointers to matrices
    x,
    weights,
    output,
    # Parameters of size
    K,  # K = R
    output_dim,
    # Strides
    x_stride_0,
    x_stride_1,
    w_stride_0,
    w_stride_1,
    w_stride_2,
    output_stride_0,
    output_stride_1,
    # Information on sequence lengths and weight id
    seg_lens,
    seg_indptr,
    weight_indices,
    # Meta parameters
    BLOCK_S: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    # For fused output scaling and adding
    fuse_scaling_add,
    scaling,
):
    # This kernel packs 2 sgemms (gate/up) into a single kernel.

    # x: (s, 2 * K), s is the sum of sequence lengths, K equals to lora rank
    # weights: (num_lora, 2 * output_dim, K)
    # output: (s, 2 * output_dim)
    # output_dim >> K

    # Current block computes sequence with batch_id,
    # which starts from row seg_start of x with length seg_len.
    # gate_up_id decides which of gate or up (0: gate, 1: up)
    batch_id = tl.program_id(axis=2)
    gate_up_id = tl.program_id(axis=1)
    pid = tl.program_id(axis=0)
    seg_len = tl.load(seg_lens + batch_id)
    w_index = tl.load(weight_indices + batch_id)
    seg_start = tl.load(seg_indptr + batch_id)
    n_start = gate_up_id * output_dim  # offset on output dim

    # The tile in output matrix will have (pid_s, pid_n) as id
    num_pid_n = tl.cdiv(output_dim, BLOCK_N)
    pid_s = pid // num_pid_n
    pid_n = pid % num_pid_n

    # Create pointers for the first block of x and weights
    # The pointers will be advanced as we move in the K direction
    # and accumulate
    s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
    n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    k_offset = tl.arange(0, BLOCK_K)

    x_ptrs = (x + seg_start * x_stride_0 + (gate_up_id * K) * x_stride_1) + (
        s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
    )
    w_ptrs = (weights + w_index * w_stride_0 + n_start * w_stride_1) + (
        k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
    )

    # Iterate to compute the block in output matrix
    partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        x_tile = tl.load(
            x_ptrs,
            mask=(s_offset[:, None] < seg_len)
            and (k_offset[None, :] < K - k * BLOCK_K),
            other=0.0,
        )
        w_tile = tl.load(
            w_ptrs,
            mask=(k_offset[:, None] < K - k * BLOCK_K)
            and (n_offset[None, :] < output_dim),
            other=0.0,
        )
        partial_sum += tl.dot(x_tile, w_tile)

        x_ptrs += BLOCK_K * x_stride_1
        w_ptrs += BLOCK_K * w_stride_2

    # Store result to output matrix
    partial_sum *= scaling
    partial_sum = partial_sum.to(x.dtype.element_ty)
    output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
        s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
    )
    output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < output_dim)
    if fuse_scaling_add:
        partial_sum += tl.load(output_ptr, mask=output_mask)
    tl.store(output_ptr, partial_sum, mask=output_mask)


def gate_up_lora_b_fwd(
    x: torch.Tensor,
    gate_up_lora_b: torch.Tensor,
    batch_info: LoRABatchInfo,
    output_dim: int,
    base_output: torch.Tensor = None,
    scaling: float = 1.0,
) -> torch.Tensor:

    # x: (s, 2 * r)
    # gate_up_lora_b: (num_lora, 2 * output_dim, r)
    # output: (s, 2 * output_dim)

    # Compute lora_output with shape (s, 2 * output_dim) as follows:
    # lora_output[:, :output_dim] = sgemm(x[:, :r], gate_up_lora_b[:, :output_dim, :])
    # lora_output[:, output_dim:]
    #     = sgemm(x[:, r:], gate_up_lora_b[:, output_dim:, :])

    # Get dims
    s = x.shape[0]
    input_dim = x.shape[1]
    r = gate_up_lora_b.shape[-1]
    assert input_dim == 2 * r

    BLOCK_S = 16
    BLOCK_R = 16
    BLOCK_OUT = 64

    grid_b = (
        triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(output_dim, BLOCK_OUT),
        2,  # this dimension decides current block computes on gate or up proj
        batch_info.bs,
    )

    if base_output is None:
        output = torch.empty((s, 2 * output_dim), device=x.device, dtype=x.dtype)
        fuse_scaling_add = False
    else:
        output = base_output
        fuse_scaling_add = True

    _gate_up_lora_b_kernel[grid_b](
        x,
        gate_up_lora_b,
        output,
        r,
        output_dim,
        x.stride(0),
        x.stride(1),
        gate_up_lora_b.stride(0),
        gate_up_lora_b.stride(1),
        gate_up_lora_b.stride(2),
        output.stride(0),
        output.stride(1),
        batch_info.seg_lens,
        batch_info.seg_indptr,
        batch_info.weight_indices,
        BLOCK_S,
        BLOCK_OUT,
        BLOCK_R,
        fuse_scaling_add,
        scaling,
    )

    return output
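
# Reference semantics (illustrative sketch in plain PyTorch, not part of the kernel):
# for a single sequence using adapter weights w = gate_up_lora_b[i], the fused kernel
# computes, up to blocking and masking,
#   out[:, :output_dim] = scaling * x[:, :r] @ w[:output_dim, :].T
#   out[:, output_dim:] = scaling * x[:, r:]  @ w[output_dim:, :].T
# and accumulates the result into base_output when fuse_scaling_add is set.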
@@ -2,7 +2,7 @@ import torch
import triton
import triton.language as tl

from sglang.srt.lora.lora import LoraBatchInfo
from sglang.srt.lora.utils import LoRABatchInfo


@triton.jit
@@ -108,7 +108,7 @@ def _qkv_lora_b_kernel(
def qkv_lora_b_fwd(
    x: torch.Tensor,
    qkv_lora_b: torch.Tensor,
    batch_info: LoraBatchInfo,
    batch_info: LoRABatchInfo,
    output_offset: torch.Tensor,
    max_qkv_out_dim: int,
    base_output: torch.Tensor = None,
@@ -123,11 +123,11 @@ def qkv_lora_b_fwd(
    # output: (s, output_dim_q + 2 * output_dim_kv)

    # Compute lora_output with shape (s, output_dim) as follows:
    # lora_output[:, :output_dim_q] = sgemm(lora_output_a[:, :r], )
    # lora_output[:, :output_dim_q] = sgemm(x[:, :r], qkv_lora_b[:, :output_dim_q, :])
    # lora_output[:, output_dim_q: output_dim_q + output_dim_kv]
    #     = sgemm(lora_output_a[:, r: 2 * r], kv_lora_b[0])
    #     = sgemm(x[:, r: 2 * r], qkv_lora_b[:, output_dim_q: output_dim_q + output_dim_kv, :])
    # lora_output[:, output_dim_q + output_dim_kv:]
    #     = sgemm(lora_output_a[:, 2 * r: 3 * r], kv_lora_b[1])
    #     = sgemm(x[:, 2 * r:], qkv_lora_b[:, output_dim_q + output_dim_kv:, :])

    # Get dims
    s = x.shape[0]

@@ -2,7 +2,7 @@ import torch
import triton
import triton.language as tl

from sglang.srt.lora.lora import LoraBatchInfo
from sglang.srt.lora.utils import LoRABatchInfo


@triton.jit
@@ -91,7 +91,7 @@ def _sgemm_lora_a_kernel(


def sgemm_lora_a_fwd(
    x: torch.Tensor, weights: torch.Tensor, batch_info: LoraBatchInfo
    x: torch.Tensor, weights: torch.Tensor, batch_info: LoRABatchInfo
) -> torch.Tensor:
    # x: (s, input_dim)
    # weights: (num_lora, r, input_dim)

@@ -2,7 +2,7 @@ import torch
import triton
import triton.language as tl

from sglang.srt.lora.lora import LoraBatchInfo
from sglang.srt.lora.utils import LoRABatchInfo


@triton.jit
@@ -98,7 +98,7 @@ def _sgemm_lora_b_kernel(
def sgemm_lora_b_fwd(
    x: torch.Tensor,
    weights: torch.Tensor,
    batch_info: LoraBatchInfo,
    batch_info: LoRABatchInfo,
    base_output: torch.Tensor = None,
    scaling: float = 1.0,
) -> torch.Tensor:

python/sglang/srt/lora/utils.py (new file, 141 lines)
@@ -0,0 +1,141 @@
import re
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Set, Tuple

import torch

from sglang.srt.hf_transformers_utils import AutoConfig


@dataclass
class LoRABatchInfo:
    # Batch size
    bs: int

    # Lengths of each sequence in shape (bs,)
    seg_lens: torch.Tensor

    # Index pointers of each sequence in shape (bs + 1,)
    seg_indptr: torch.Tensor

    # Maximum sequence length of current batch
    max_len: int

    # The index of lora adapter used by each sequence, in shape (bs,)
    weight_indices: torch.Tensor

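# Illustrative construction (hypothetical values, not from this diff): a batch of
# two sequences of lengths 3 and 5 that use adapter slots 1 and 0 would carry
#   LoRABatchInfo(bs=2,
#                 seg_lens=torch.tensor([3, 5], dtype=torch.int32),
#                 seg_indptr=torch.tensor([0, 3, 8], dtype=torch.int32),
#                 max_len=5,
#                 weight_indices=torch.tensor([1, 0], dtype=torch.int32))
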
class LoRAType(Enum):
    LORA_A = 0
    LORA_B = 1


def get_layer_id(name: str) -> Optional[int]:
    """
    Extract integer id of layer from its name in string.
    """
    match = re.search(r"layers\.(\d+)\.", name)
    if match is None:
        return None
    return int(match.group(1))

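# Example (hypothetical module name): get_layer_id("model.layers.12.self_attn.qkv_proj")
# returns 12; a name without a "layers.<n>." segment returns None.
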
def get_customized_names_from_hf_names(
    hf_module_names: Set[str], base_model: torch.nn.Module
) -> Set[str]:
    """
    This function takes in a set of huggingface style module names:
        e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
    and outputs a set of module names of customized sglang layers:
        e.g., {"qkv_proj", "o_proj"}
    """
    if hasattr(base_model, "get_module_name"):
        return {base_model.get_module_name(name) for name in hf_module_names}
    else:
        """
        Fallback solution of mapping from config module name to module name in model class.
        Please check if it aligns with your base model.
        Please implement the function in the model class if it is not.
        You can reference this function in llama.py.
        """
        params_mapping = {
            "q_proj": "qkv_proj",
            "k_proj": "qkv_proj",
            "v_proj": "qkv_proj",
            "gate_proj": "gate_up_proj",
            "up_proj": "gate_up_proj",
        }
        return {params_mapping.get(name, name) for name in hf_module_names}

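# Example (fallback path, hypothetical inputs): for a base model without
# get_module_name, {"q_proj", "k_proj", "v_proj", "o_proj", "up_proj"} maps to
# {"qkv_proj", "o_proj", "gate_up_proj"}.
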
def get_hidden_dim(
    module_name: str, config: AutoConfig, base_model: torch.nn.Module
) -> Tuple[int, int]:
    """
    Given a module_name (might be a stacked name), return the hidden dims of the module's input and output.
    """

    if hasattr(base_model, "get_hidden_dim"):
        return base_model.get_hidden_dim(module_name)
    else:
        """
        WARNING: get_hidden_dim() is not defined,
        which is used to get the hidden dim for different lora modules
        Use the default one, but please check if it is correct for your model.
        Please implement the function in the model class if it is not.
        You can reference this function in llama.py.
        """
        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
            return config.hidden_size, config.hidden_size
        elif module_name in ["kv_proj"]:
            return config.hidden_size, config.hidden_size // (
                config.num_attention_heads // config.num_key_value_heads
            )
        elif module_name == "gate_up_proj":
            return config.hidden_size, config.intermediate_size
        elif module_name == "down_proj":
            return config.intermediate_size, config.hidden_size
        else:
            raise NotImplementedError()

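# Worked example (fallback path, hypothetical config): with hidden_size=4096,
# num_attention_heads=32, num_key_value_heads=8 and intermediate_size=11008,
#   get_hidden_dim("kv_proj", ...)      -> (4096, 1024)   # 4096 // (32 // 8)
#   get_hidden_dim("gate_up_proj", ...) -> (4096, 11008)
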
def get_stacked_name(name: str) -> Tuple[str, str]:
    """
    Mapping a target module name to (stacked name for Lora A, stacked name for Lora B)
    """
    params_mapping = {
        "q_proj": ("qkv_proj", "q_proj"),
        "k_proj": ("qkv_proj", "kv_proj"),
        "v_proj": ("qkv_proj", "kv_proj"),
        "gate_proj": ("gate_up_proj", "gate_up_proj"),
        "up_proj": ("gate_up_proj", "gate_up_proj"),
    }
    return params_mapping.get(name, (name, name))

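# Example: get_stacked_name("k_proj") -> ("qkv_proj", "kv_proj"); an unmapped name
# such as "o_proj" is returned unchanged as ("o_proj", "o_proj").
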
def get_stacked_multiply(module_name: str) -> int:
    """
    Mapping a lora module name to its magnification at output dimension
    """
    stacked_rank = {
        "qkv_proj": 3,
        "kv_proj": 2,
        "gate_up_proj": 2,
    }
    return stacked_rank[module_name] if module_name in stacked_rank else 1


def get_weight_name(
    target_name: str, lora_weight_names: Set[Tuple[str]], lora_type: LoRAType
) -> Optional[str]:
    """
    target_name is the name of a given module,
    lora_weight_names is a set of lora stacked name pairs (see get_stacked_name method above)
    If there is a weight name in lora_weight_names that can match target_name, return this name
    Else return None
    """
    idx = 0 if lora_type == LoRAType.LORA_A else 1
    for weight_name_pair in lora_weight_names:
        if weight_name_pair[idx] in target_name:
            return weight_name_pair[idx]
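
# Example (hypothetical target name): with lora_weight_names =
# {("qkv_proj", "q_proj"), ("o_proj", "o_proj")}, the weight named
# "model.layers.0.self_attn.qkv_proj.lora_A.weight" matches "qkv_proj" under
# LoRAType.LORA_A; a name that matches no pair falls through and returns None.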