[Fix] Fix accuracy bug and refactor codes for lora (#3413)
This commit is contained in:
@@ -1,8 +1,28 @@
|
|||||||
from .base_backend import BaseLoraBackend
|
from .base_backend import BaseLoRABackend
|
||||||
from .flashinfer_backend import FlashInferLoraBackend
|
from .flashinfer_backend import FlashInferLoRABackend
|
||||||
from .triton_backend import TritonLoraBackend
|
from .triton_backend import TritonLoRABackend
|
||||||
|
|
||||||
|
|
||||||
|
def get_backend_from_name(name: str) -> BaseLoRABackend:
|
||||||
|
"""
|
||||||
|
Get corresponding backend class from backend's name
|
||||||
|
"""
|
||||||
|
backend_mapping = {
|
||||||
|
"triton": TritonLoRABackend,
|
||||||
|
"flashinfer": FlashInferLoRABackend,
|
||||||
|
}
|
||||||
|
|
||||||
|
if name in backend_mapping:
|
||||||
|
return backend_mapping[name]
|
||||||
|
|
||||||
|
raise Exception(
|
||||||
|
f"No supported lora backend called {name}. It should be one of {list(backend_mapping.keys())}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"FlashInferLoraBackend",
|
"BaseLoRABackend",
|
||||||
"TritonLoraBackend",
|
"FlashInferLoRABackend",
|
||||||
|
"TritonLoRABackend",
|
||||||
|
"get_backend_from_name",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from typing import Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.lora.lora import LoraBatchInfo
|
from sglang.srt.lora.utils import LoRABatchInfo
|
||||||
|
|
||||||
|
|
||||||
def get_fuse_output_scaling_add_from_name(name: str) -> bool:
|
def get_fuse_output_scaling_add_from_name(name: str) -> bool:
|
||||||
@@ -13,7 +13,7 @@ def get_fuse_output_scaling_add_from_name(name: str) -> bool:
|
|||||||
return mapping.get(name, False)
|
return mapping.get(name, False)
|
||||||
|
|
||||||
|
|
||||||
def get_fuse_qkv_lora_b_from_name(name: str) -> bool:
|
def get_fuse_stacked_lora_b_from_name(name: str) -> bool:
|
||||||
mapping = {
|
mapping = {
|
||||||
"triton": True,
|
"triton": True,
|
||||||
"flashinfer": False,
|
"flashinfer": False,
|
||||||
@@ -21,7 +21,7 @@ def get_fuse_qkv_lora_b_from_name(name: str) -> bool:
|
|||||||
return mapping.get(name, False)
|
return mapping.get(name, False)
|
||||||
|
|
||||||
|
|
||||||
class BaseLoraBackend:
|
class BaseLoRABackend:
|
||||||
"""Base class for different Lora backends.
|
"""Base class for different Lora backends.
|
||||||
Each backend has its own implementation of Lora kernels.
|
Each backend has its own implementation of Lora kernels.
|
||||||
|
|
||||||
@@ -32,11 +32,11 @@ class BaseLoraBackend:
|
|||||||
and the operation of scaling and adding will be fused into kernel
|
and the operation of scaling and adding will be fused into kernel
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, name: str, batch_info: LoraBatchInfo = None):
|
def __init__(self, name: str, batch_info: LoRABatchInfo = None):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.batch_info = batch_info
|
self.batch_info = batch_info
|
||||||
self.fuse_output_scaling_add = get_fuse_output_scaling_add_from_name(name)
|
self.fuse_output_scaling_add = get_fuse_output_scaling_add_from_name(name)
|
||||||
self.fuse_qkv_lora_b = get_fuse_qkv_lora_b_from_name(name)
|
self.fuse_stacked_lora_b = get_fuse_stacked_lora_b_from_name(name)
|
||||||
|
|
||||||
def run_lora_a_sgemm(
|
def run_lora_a_sgemm(
|
||||||
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
|
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
|
||||||
@@ -46,10 +46,11 @@ class BaseLoraBackend:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
|
x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
|
||||||
weights: a set of lora weights with shape (num_lora, r, input_dim), here r is lora rank
|
weights: a set of lora weights with shape (num_lora, c * r, input_dim),
|
||||||
|
here r is lora rank, c is a multiplier for stacked modules (e.g., c=3 for qkv_proj, c=2 for gate_up_proj)
|
||||||
usually input_dim is much larger than r
|
usually input_dim is much larger than r
|
||||||
Returns:
|
Returns:
|
||||||
result with shape (s, r)
|
result with shape (s, c * r)
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -83,7 +84,7 @@ class BaseLoraBackend:
|
|||||||
qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim)
|
qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim)
|
||||||
qkv_lora_b: lora_b module for qkv.
|
qkv_lora_b: lora_b module for qkv.
|
||||||
If passed in as a tensor, its shape should be (num_lora,output_dim_q + 2 * output_dim_kv, r)
|
If passed in as a tensor, its shape should be (num_lora,output_dim_q + 2 * output_dim_kv, r)
|
||||||
If passed in as a tuple of two tensors containing:
|
If passed in as a tuple of two tensors, it should contain:
|
||||||
a lora_b module for q, with shape (1, num_lora, output_dim_q, r)
|
a lora_b module for q, with shape (1, num_lora, output_dim_q, r)
|
||||||
and a combined lora_b module for kv, with shape (2, num_lora, output_dim_kv, r)
|
and a combined lora_b module for kv, with shape (2, num_lora, output_dim_kv, r)
|
||||||
Returns:
|
Returns:
|
||||||
@@ -91,5 +92,26 @@ class BaseLoraBackend:
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def set_batch_info(self, batch_info: LoraBatchInfo):
|
def run_gate_up_lora(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
gate_up_lora_a: torch.Tensor,
|
||||||
|
gate_up_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
|
||||||
|
gate_up_lora_a: lora_a module for gate_up_proj, with shape (num_lora, 2 * r, input_dim)
|
||||||
|
gate_up_lora_b: lora_b module for qkv.
|
||||||
|
If passed in as a tensor, its shape should be (num_lora, 2 * output_dim, r)
|
||||||
|
If passed in as a tuple, it should contain two tensors with shape (num_lora, output_dim, r)
|
||||||
|
Returns:
|
||||||
|
result with shape (s, 2 * output_dim)
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def set_batch_info(self, batch_info: LoRABatchInfo):
|
||||||
self.batch_info = batch_info
|
self.batch_info = batch_info
|
||||||
|
|||||||
@@ -2,17 +2,17 @@ from typing import Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.lora.backend import BaseLoraBackend
|
from sglang.srt.lora.backend import BaseLoRABackend
|
||||||
from sglang.srt.lora.lora import LoraBatchInfo
|
from sglang.srt.lora.utils import LoRABatchInfo
|
||||||
from sglang.srt.utils import is_flashinfer_available
|
from sglang.srt.utils import is_flashinfer_available
|
||||||
|
|
||||||
if is_flashinfer_available():
|
if is_flashinfer_available():
|
||||||
from flashinfer import SegmentGEMMWrapper
|
from flashinfer import SegmentGEMMWrapper
|
||||||
|
|
||||||
|
|
||||||
class FlashInferLoraBackend(BaseLoraBackend):
|
class FlashInferLoRABackend(BaseLoRABackend):
|
||||||
|
|
||||||
def __init__(self, name: str, batch_info: LoraBatchInfo = None):
|
def __init__(self, name: str, batch_info: LoRABatchInfo = None):
|
||||||
super().__init__(name, batch_info)
|
super().__init__(name, batch_info)
|
||||||
|
|
||||||
# Set up SGemm Wrapper from flashinfer
|
# Set up SGemm Wrapper from flashinfer
|
||||||
@@ -55,6 +55,8 @@ class FlashInferLoraBackend(BaseLoraBackend):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
assert isinstance(qkv_lora_b, tuple) and len(qkv_lora_b) == 2
|
||||||
|
|
||||||
# Shape of lora_a_output: (s, 3 * r)
|
# Shape of lora_a_output: (s, 3 * r)
|
||||||
lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a)
|
lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a)
|
||||||
|
|
||||||
@@ -89,3 +91,38 @@ class FlashInferLoraBackend(BaseLoraBackend):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return lora_output
|
return lora_output
|
||||||
|
|
||||||
|
def run_gate_up_lora(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
gate_up_lora_a: torch.Tensor,
|
||||||
|
gate_up_lora_b: Tuple[torch.Tensor],
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
assert isinstance(gate_up_lora_b, tuple) and len(gate_up_lora_b) == 2
|
||||||
|
lora_rank = gate_up_lora_b[0].shape[-1]
|
||||||
|
output_dim = gate_up_lora_b[0].shape[-2]
|
||||||
|
|
||||||
|
# Shape of lora_a_output: (s, 2 * r)
|
||||||
|
lora_a_output = self.run_lora_a_sgemm(x=x, weights=gate_up_lora_a)
|
||||||
|
|
||||||
|
lora_output = torch.empty(
|
||||||
|
(x.shape[0], 2 * output_dim),
|
||||||
|
device=x.device,
|
||||||
|
dtype=x.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute lora for gate and up proj respectively
|
||||||
|
lora_output[:, :output_dim] = self.run_lora_b_sgemm(
|
||||||
|
x=lora_a_output[:, :lora_rank].contiguous(),
|
||||||
|
weights=gate_up_lora_b[0],
|
||||||
|
)
|
||||||
|
|
||||||
|
lora_output[:, output_dim:] = self.run_lora_b_sgemm(
|
||||||
|
x=lora_a_output[:, lora_rank:].contiguous(),
|
||||||
|
weights=gate_up_lora_b[1],
|
||||||
|
)
|
||||||
|
|
||||||
|
return lora_output
|
||||||
|
|||||||
@@ -1,17 +1,18 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.lora.backend import BaseLoraBackend
|
from sglang.srt.lora.backend import BaseLoRABackend
|
||||||
from sglang.srt.lora.lora import LoraBatchInfo
|
|
||||||
from sglang.srt.lora.triton_ops import (
|
from sglang.srt.lora.triton_ops import (
|
||||||
|
gate_up_lora_b_fwd,
|
||||||
qkv_lora_b_fwd,
|
qkv_lora_b_fwd,
|
||||||
sgemm_lora_a_fwd,
|
sgemm_lora_a_fwd,
|
||||||
sgemm_lora_b_fwd,
|
sgemm_lora_b_fwd,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.lora.utils import LoRABatchInfo
|
||||||
|
|
||||||
|
|
||||||
class TritonLoraBackend(BaseLoraBackend):
|
class TritonLoRABackend(BaseLoRABackend):
|
||||||
|
|
||||||
def __init__(self, name: str, batch_info: LoraBatchInfo = None):
|
def __init__(self, name: str, batch_info: LoRABatchInfo = None):
|
||||||
super().__init__(name, batch_info)
|
super().__init__(name, batch_info)
|
||||||
|
|
||||||
def run_lora_a_sgemm(
|
def run_lora_a_sgemm(
|
||||||
@@ -59,3 +60,32 @@ class TritonLoraBackend(BaseLoraBackend):
|
|||||||
scaling,
|
scaling,
|
||||||
)
|
)
|
||||||
return lora_output
|
return lora_output
|
||||||
|
|
||||||
|
def run_gate_up_lora(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
gate_up_lora_a: torch.Tensor,
|
||||||
|
gate_up_lora_b: torch.Tensor,
|
||||||
|
base_output: torch.Tensor = None,
|
||||||
|
scaling: float = 1.0,
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
# x: (s, input_dim)
|
||||||
|
# gate_up_lora_a: (num_lora, 2 * r, input_dim)
|
||||||
|
# gate_up_lora_b: (num_lora, 2 * output_dim, r)
|
||||||
|
assert isinstance(gate_up_lora_b, torch.Tensor)
|
||||||
|
output_dim = gate_up_lora_b.shape[-2] // 2
|
||||||
|
|
||||||
|
# lora_a_output: (s, 2 * r)
|
||||||
|
lora_a_output = sgemm_lora_a_fwd(x, gate_up_lora_a, self.batch_info)
|
||||||
|
lora_output = gate_up_lora_b_fwd(
|
||||||
|
lora_a_output,
|
||||||
|
gate_up_lora_b,
|
||||||
|
self.batch_info,
|
||||||
|
output_dim,
|
||||||
|
base_output,
|
||||||
|
scaling,
|
||||||
|
)
|
||||||
|
return lora_output
|
||||||
|
|||||||
293
python/sglang/srt/lora/layers.py
Normal file
293
python/sglang/srt/lora/layers.py
Normal file
@@ -0,0 +1,293 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from sglang.srt.distributed import (
|
||||||
|
get_tensor_model_parallel_rank,
|
||||||
|
split_tensor_along_last_dim,
|
||||||
|
tensor_model_parallel_all_gather,
|
||||||
|
tensor_model_parallel_all_reduce,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.linear import (
|
||||||
|
ColumnParallelLinear,
|
||||||
|
MergedColumnParallelLinear,
|
||||||
|
QKVParallelLinear,
|
||||||
|
RowParallelLinear,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||||
|
from sglang.srt.lora.backend import BaseLoRABackend
|
||||||
|
|
||||||
|
|
||||||
|
class BaseLayerWithLoRA(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_layer: nn.Module,
|
||||||
|
lora_rank: int,
|
||||||
|
scaling: float,
|
||||||
|
lora_backend: BaseLoRABackend,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.base_layer: nn.Module = base_layer
|
||||||
|
self.lora_rank: int = lora_rank
|
||||||
|
self.scaling: float = scaling
|
||||||
|
self.set_lora: bool = False
|
||||||
|
self.lora_backend: BaseLoRABackend = lora_backend
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
return self.base_layer.forward(x)
|
||||||
|
|
||||||
|
def set_lora_info(self, *args):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_layer: VocabParallelEmbedding,
|
||||||
|
lora_rank: int,
|
||||||
|
scaling: float,
|
||||||
|
lora_backend: BaseLoRABackend,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(base_layer, lora_rank, scaling, lora_backend)
|
||||||
|
self.weight = base_layer.weight
|
||||||
|
|
||||||
|
|
||||||
|
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_layer: ColumnParallelLinear,
|
||||||
|
lora_rank: int,
|
||||||
|
scaling: float,
|
||||||
|
lora_backend: BaseLoRABackend,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(base_layer, lora_rank, scaling, lora_backend)
|
||||||
|
|
||||||
|
def set_lora_info(
|
||||||
|
self,
|
||||||
|
A_buffer: torch.Tensor,
|
||||||
|
B_buffer: torch.Tensor,
|
||||||
|
):
|
||||||
|
self.set_lora = True
|
||||||
|
self.A_buffer = A_buffer
|
||||||
|
self.B_buffer = B_buffer
|
||||||
|
|
||||||
|
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
|
||||||
|
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
|
||||||
|
lora_output = self.lora_backend.run_lora_b_sgemm(
|
||||||
|
lora_a_output,
|
||||||
|
self.B_buffer[0],
|
||||||
|
**backend_kwargs,
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
lora_output
|
||||||
|
if self.lora_backend.fuse_output_scaling_add
|
||||||
|
else base_output + lora_output * self.scaling
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, input_: torch.Tensor):
|
||||||
|
# duplicate the logic in ColumnParallelLinear
|
||||||
|
bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
|
||||||
|
output_parallel = self.base_layer.quant_method.apply(
|
||||||
|
self.base_layer, input_, bias
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.set_lora:
|
||||||
|
output_parallel = self.apply_lora(output_parallel, input_)
|
||||||
|
|
||||||
|
if self.base_layer.gather_output:
|
||||||
|
output = tensor_model_parallel_all_gather(output_parallel)
|
||||||
|
else:
|
||||||
|
output = output_parallel
|
||||||
|
output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
|
||||||
|
return output, output_bias
|
||||||
|
|
||||||
|
|
||||||
|
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_layer: MergedColumnParallelLinear,
|
||||||
|
lora_rank: int,
|
||||||
|
scaling: float,
|
||||||
|
lora_backend: BaseLoRABackend,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(base_layer, lora_rank, scaling, lora_backend)
|
||||||
|
|
||||||
|
def set_lora_info(
|
||||||
|
self,
|
||||||
|
A_buffer: torch.Tensor,
|
||||||
|
B_buffer: torch.Tensor,
|
||||||
|
):
|
||||||
|
self.set_lora = True
|
||||||
|
self.A_buffer_gate_up = A_buffer
|
||||||
|
if self.lora_backend.fuse_stacked_lora_b:
|
||||||
|
# B_buffer_gate_up: (num_lora, 2 * output_dim, r)
|
||||||
|
self.B_buffer_gate_up = torch.cat(
|
||||||
|
(B_buffer[0], B_buffer[1]), dim=-2
|
||||||
|
).contiguous()
|
||||||
|
else:
|
||||||
|
self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])
|
||||||
|
|
||||||
|
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
|
||||||
|
|
||||||
|
lora_output = self.lora_backend.run_gate_up_lora(
|
||||||
|
x,
|
||||||
|
self.A_buffer_gate_up,
|
||||||
|
self.B_buffer_gate_up,
|
||||||
|
**backend_kwargs,
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
lora_output
|
||||||
|
if self.lora_backend.fuse_output_scaling_add
|
||||||
|
else base_output + lora_output * self.scaling
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||||
|
def init__(
|
||||||
|
self,
|
||||||
|
base_layer: QKVParallelLinear,
|
||||||
|
lora_rank: int,
|
||||||
|
scaling: float,
|
||||||
|
lora_backend: BaseLoRABackend,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(base_layer, lora_rank, scaling, lora_backend)
|
||||||
|
|
||||||
|
def set_lora_info(
|
||||||
|
self,
|
||||||
|
A_buffer_qkv: torch.Tensor,
|
||||||
|
B_buffer_q: torch.Tensor,
|
||||||
|
B_buffer_kv: torch.Tensor,
|
||||||
|
):
|
||||||
|
self.set_lora = True
|
||||||
|
self.A_buffer_qkv = A_buffer_qkv
|
||||||
|
|
||||||
|
if self.lora_backend.fuse_stacked_lora_b:
|
||||||
|
assert (
|
||||||
|
B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
|
||||||
|
), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
|
||||||
|
output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
|
||||||
|
|
||||||
|
# B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
|
||||||
|
self.B_buffer_qkv = torch.cat(
|
||||||
|
(B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
|
||||||
|
).contiguous()
|
||||||
|
|
||||||
|
# Offsets of q/k/v in output dimension
|
||||||
|
self.output_offset = torch.tensor(
|
||||||
|
[
|
||||||
|
0,
|
||||||
|
output_dim_q,
|
||||||
|
output_dim_q + output_dim_kv,
|
||||||
|
output_dim_q + 2 * output_dim_kv,
|
||||||
|
],
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=B_buffer_q.device,
|
||||||
|
)
|
||||||
|
# For computing number of launched blocks
|
||||||
|
self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
|
||||||
|
else:
|
||||||
|
self.B_buffer_qkv = (
|
||||||
|
B_buffer_q,
|
||||||
|
B_buffer_kv,
|
||||||
|
)
|
||||||
|
|
||||||
|
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
|
||||||
|
if self.lora_backend.fuse_stacked_lora_b:
|
||||||
|
backend_kwargs["output_offset"] = self.output_offset
|
||||||
|
backend_kwargs["max_qkv_out_dim"] = self.max_qkv_out_dim
|
||||||
|
|
||||||
|
lora_output = self.lora_backend.run_qkv_lora(
|
||||||
|
x,
|
||||||
|
self.A_buffer_qkv,
|
||||||
|
self.B_buffer_qkv,
|
||||||
|
**backend_kwargs,
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
lora_output
|
||||||
|
if self.lora_backend.fuse_output_scaling_add
|
||||||
|
else base_output + lora_output * self.scaling
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_layer: RowParallelLinear,
|
||||||
|
lora_rank: int,
|
||||||
|
scaling: float,
|
||||||
|
lora_backend: BaseLoRABackend,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(base_layer, lora_rank, scaling, lora_backend)
|
||||||
|
|
||||||
|
def set_lora_info(self, A_buffer: torch.Tensor, B_buffer: torch.Tensor):
|
||||||
|
self.set_lora = True
|
||||||
|
self.A_buffer = A_buffer
|
||||||
|
self.B_buffer = B_buffer
|
||||||
|
|
||||||
|
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
|
||||||
|
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
|
||||||
|
lora_output = self.lora_backend.run_lora_b_sgemm(
|
||||||
|
lora_a_output,
|
||||||
|
self.B_buffer[0],
|
||||||
|
**backend_kwargs,
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
lora_output
|
||||||
|
if self.lora_backend.fuse_output_scaling_add
|
||||||
|
else base_output + lora_output * self.scaling
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, input_: torch.Tensor):
|
||||||
|
# duplicate the logic in RowParallelLinear
|
||||||
|
if self.base_layer.input_is_parallel:
|
||||||
|
input_parallel = input_
|
||||||
|
else:
|
||||||
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
splitted_input = split_tensor_along_last_dim(
|
||||||
|
input_, num_partitions=self.base_layer.tp_size
|
||||||
|
)
|
||||||
|
input_parallel = splitted_input[tp_rank].contiguous()
|
||||||
|
output_parallel = self.base_layer.quant_method.apply(
|
||||||
|
self.base_layer, input_parallel
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.set_lora:
|
||||||
|
output_parallel = self.apply_lora(output_parallel, input_parallel)
|
||||||
|
|
||||||
|
if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
|
||||||
|
output_ = tensor_model_parallel_all_reduce(output_parallel)
|
||||||
|
else:
|
||||||
|
output_ = output_parallel
|
||||||
|
|
||||||
|
if not self.base_layer.skip_bias_add:
|
||||||
|
output = (
|
||||||
|
output_ + self.base_layer.bias
|
||||||
|
if self.base_layer.bias is not None
|
||||||
|
else output_
|
||||||
|
)
|
||||||
|
output_bias = None
|
||||||
|
else:
|
||||||
|
output = output_
|
||||||
|
output_bias = self.base_layer.bias
|
||||||
|
return output, output_bias
|
||||||
|
|
||||||
|
|
||||||
|
def get_lora_layer(
|
||||||
|
layer: nn.Module, lora_rank: int, scaling: int, lora_backend: BaseLoRABackend
|
||||||
|
) -> BaseLayerWithLoRA:
|
||||||
|
supported_layer_types = {
|
||||||
|
# the order matters
|
||||||
|
VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
|
||||||
|
QKVParallelLinear: QKVParallelLinearWithLoRA,
|
||||||
|
MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
|
||||||
|
ColumnParallelLinear: ColumnParallelLinearWithLoRA,
|
||||||
|
RowParallelLinear: RowParallelLinearWithLoRA,
|
||||||
|
}
|
||||||
|
for src_layer_type, lora_layer_type in supported_layer_types.items():
|
||||||
|
if isinstance(layer, src_layer_type): # pylint: disable=unidiomatic-typecheck
|
||||||
|
ret = lora_layer_type(layer, lora_rank, scaling, lora_backend)
|
||||||
|
return ret
|
||||||
|
raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")
|
||||||
@@ -19,282 +19,25 @@
|
|||||||
# https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py
|
# https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass
|
from typing import Dict, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from sglang.srt.layers.linear import (
|
from sglang.srt.configs.load_config import LoadConfig
|
||||||
ColumnParallelLinear,
|
from sglang.srt.hf_transformers_utils import AutoConfig
|
||||||
MergedColumnParallelLinear,
|
from sglang.srt.lora.backend import BaseLoRABackend
|
||||||
QKVParallelLinear,
|
from sglang.srt.lora.lora_config import LoRAConfig
|
||||||
RowParallelLinear,
|
|
||||||
)
|
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
|
||||||
from sglang.srt.model_loader.loader import DefaultModelLoader
|
from sglang.srt.model_loader.loader import DefaultModelLoader
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class LoraBatchInfo:
|
|
||||||
# Batch size
|
|
||||||
bs: int
|
|
||||||
|
|
||||||
# Lengths of each sequence in shape (bs,)
|
|
||||||
seg_lens: torch.Tensor
|
|
||||||
|
|
||||||
# Indice pointers of each sequence in shape (bs + 1, )
|
|
||||||
seg_indptr: torch.Tensor
|
|
||||||
|
|
||||||
# Maximum sequence length of current batch
|
|
||||||
max_len: int
|
|
||||||
|
|
||||||
# The index of lora adapter used by each sequence, in shape (bs,)
|
|
||||||
weight_indices: torch.Tensor
|
|
||||||
|
|
||||||
|
|
||||||
class BaseLayerWithLoRA(nn.Module):
|
|
||||||
def __init__(self, base_layer, lora_rank, scaling, lora_backend):
|
|
||||||
super().__init__()
|
|
||||||
self.base_layer = base_layer
|
|
||||||
self.lora_rank = lora_rank
|
|
||||||
self.scaling = scaling
|
|
||||||
self.set_lora = False
|
|
||||||
self.lora_backend = lora_backend
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
|
||||||
return self.base_layer.forward(x)
|
|
||||||
|
|
||||||
def set_lora_info(self, *args):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
|
|
||||||
def __init__(
|
|
||||||
self, base_layer: VocabParallelEmbedding, lora_rank, scaling, lora_backend
|
|
||||||
) -> None:
|
|
||||||
super().__init__(base_layer, lora_rank, scaling, lora_backend)
|
|
||||||
self.weight = base_layer.weight
|
|
||||||
|
|
||||||
|
|
||||||
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
|
|
||||||
def __init__(
|
|
||||||
self, base_layer: ColumnParallelLinear, lora_rank, scaling, lora_backend
|
|
||||||
) -> None:
|
|
||||||
super().__init__(base_layer, lora_rank, scaling, lora_backend)
|
|
||||||
|
|
||||||
def apply_lora(self, output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
# TODO
|
|
||||||
return output
|
|
||||||
|
|
||||||
def forward(self, input_: torch.Tensor):
|
|
||||||
# duplicate the logic in ColumnParallelLinear
|
|
||||||
bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
|
|
||||||
output_parallel = self.base_layer.quant_method.apply(
|
|
||||||
self.base_layer, input_, bias
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.set_lora:
|
|
||||||
output_parallel = self.apply_lora(output_parallel, input_)
|
|
||||||
|
|
||||||
if self.base_layer.gather_output:
|
|
||||||
output = tensor_model_parallel_all_gather(output_parallel)
|
|
||||||
else:
|
|
||||||
output = output_parallel
|
|
||||||
output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
|
|
||||||
return output, output_bias
|
|
||||||
|
|
||||||
|
|
||||||
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
|
||||||
def __init__(
|
|
||||||
self, base_layer: MergedColumnParallelLinear, lora_rank, scaling, lora_backend
|
|
||||||
) -> None:
|
|
||||||
super().__init__(base_layer, lora_rank, scaling, lora_backend)
|
|
||||||
|
|
||||||
def set_lora_info(
|
|
||||||
self,
|
|
||||||
A_buffer,
|
|
||||||
B_buffer,
|
|
||||||
):
|
|
||||||
self.set_lora = True
|
|
||||||
self.A_buffer = A_buffer
|
|
||||||
self.B_buffer = B_buffer
|
|
||||||
|
|
||||||
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
lora_a_output = self.lora_backend.run_lora_a_sgemm(x=x, weights=self.A_buffer)
|
|
||||||
|
|
||||||
output_dim = base_output.shape[-1]
|
|
||||||
lora_output = torch.empty_like(base_output)
|
|
||||||
lora_output[:, :output_dim] = self.lora_backend.run_lora_b_sgemm(
|
|
||||||
x=lora_a_output[:, 0 : self.lora_rank].contiguous(),
|
|
||||||
weights=self.B_buffer[0],
|
|
||||||
)
|
|
||||||
|
|
||||||
lora_output[:, output_dim : 2 * output_dim] = (
|
|
||||||
self.lora_backend.run_lora_b_sgemm(
|
|
||||||
x=lora_a_output[:, self.lora_rank : 2 * self.lora_rank].contiguous(),
|
|
||||||
weights=self.B_buffer[1],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return base_output + lora_output * self.scaling
|
|
||||||
|
|
||||||
|
|
||||||
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
|
||||||
def init__(
|
|
||||||
self, base_layer: QKVParallelLinear, lora_rank, scaling, lora_backend
|
|
||||||
) -> None:
|
|
||||||
super().__init__(base_layer, lora_rank, scaling, lora_backend)
|
|
||||||
|
|
||||||
def set_lora_info(
|
|
||||||
self,
|
|
||||||
A_buffer_qkv,
|
|
||||||
B_buffer_q,
|
|
||||||
B_buffer_kv,
|
|
||||||
):
|
|
||||||
self.set_lora = True
|
|
||||||
self.A_buffer_qkv = A_buffer_qkv
|
|
||||||
|
|
||||||
if self.lora_backend.fuse_qkv_lora_b:
|
|
||||||
assert (
|
|
||||||
B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
|
|
||||||
), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
|
|
||||||
output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
|
|
||||||
|
|
||||||
# B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
|
|
||||||
self.B_buffer_qkv = torch.cat(
|
|
||||||
(B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
|
|
||||||
).contiguous()
|
|
||||||
|
|
||||||
# Offsets of q/k/v in output dimension
|
|
||||||
self.output_offset = torch.tensor(
|
|
||||||
[
|
|
||||||
0,
|
|
||||||
output_dim_q,
|
|
||||||
output_dim_q + output_dim_kv,
|
|
||||||
output_dim_q + 2 * output_dim_kv,
|
|
||||||
],
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=B_buffer_q.device,
|
|
||||||
)
|
|
||||||
# For computing number of launched blocks
|
|
||||||
self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
|
|
||||||
else:
|
|
||||||
self.B_buffer_qkv = (
|
|
||||||
B_buffer_q,
|
|
||||||
B_buffer_kv,
|
|
||||||
)
|
|
||||||
self.output_offset = None
|
|
||||||
self.max_qkv_out_dim = None
|
|
||||||
|
|
||||||
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
lora_output = self.lora_backend.run_qkv_lora(
|
|
||||||
x,
|
|
||||||
self.A_buffer_qkv,
|
|
||||||
self.B_buffer_qkv,
|
|
||||||
output_offset=self.output_offset,
|
|
||||||
max_qkv_out_dim=self.max_qkv_out_dim,
|
|
||||||
base_output=base_output,
|
|
||||||
scaling=self.scaling,
|
|
||||||
)
|
|
||||||
return (
|
|
||||||
lora_output
|
|
||||||
if self.lora_backend.fuse_output_scaling_add
|
|
||||||
else base_output + lora_output * self.scaling
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
|
||||||
def __init__(
|
|
||||||
self, base_layer: RowParallelLinear, lora_rank, scaling, lora_backend
|
|
||||||
) -> None:
|
|
||||||
super().__init__(base_layer, lora_rank, scaling, lora_backend)
|
|
||||||
|
|
||||||
def set_lora_info(self, A_buffer, B_buffer):
|
|
||||||
self.set_lora = True
|
|
||||||
self.A_buffer = A_buffer
|
|
||||||
self.B_buffer = B_buffer
|
|
||||||
|
|
||||||
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
|
|
||||||
lora_output = self.lora_backend.run_lora_b_sgemm(
|
|
||||||
lora_a_output,
|
|
||||||
self.B_buffer[0],
|
|
||||||
base_output=base_output,
|
|
||||||
scaling=self.scaling,
|
|
||||||
)
|
|
||||||
return (
|
|
||||||
lora_output
|
|
||||||
if self.lora_backend.fuse_output_scaling_add
|
|
||||||
else base_output + lora_output * self.scaling
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, input_):
|
|
||||||
# duplicate the logic in RowParallelLinear
|
|
||||||
if self.base_layer.input_is_parallel:
|
|
||||||
input_parallel = input_
|
|
||||||
else:
|
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
|
||||||
splitted_input = split_tensor_along_last_dim(
|
|
||||||
input_, num_partitions=self.base_layer.tp_size
|
|
||||||
)
|
|
||||||
input_parallel = splitted_input[tp_rank].contiguous()
|
|
||||||
output_parallel = self.base_layer.quant_method.apply(
|
|
||||||
self.base_layer, input_parallel
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.set_lora:
|
|
||||||
output_parallel = self.apply_lora(output_parallel, input_parallel)
|
|
||||||
|
|
||||||
if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
|
|
||||||
output_ = tensor_model_parallel_all_reduce(output_parallel)
|
|
||||||
else:
|
|
||||||
output_ = output_parallel
|
|
||||||
|
|
||||||
if not self.base_layer.skip_bias_add:
|
|
||||||
output = (
|
|
||||||
output_ + self.base_layer.bias
|
|
||||||
if self.base_layer.bias is not None
|
|
||||||
else output_
|
|
||||||
)
|
|
||||||
output_bias = None
|
|
||||||
else:
|
|
||||||
output = output_
|
|
||||||
output_bias = self.base_layer.bias
|
|
||||||
return output, output_bias
|
|
||||||
|
|
||||||
|
|
||||||
def get_lora_layer(
|
|
||||||
layer: nn.Module, lora_rank, scaling, lora_backend
|
|
||||||
) -> BaseLayerWithLoRA:
|
|
||||||
supported_layer_types = {
|
|
||||||
# the order matters
|
|
||||||
VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
|
|
||||||
QKVParallelLinear: QKVParallelLinearWithLoRA,
|
|
||||||
MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
|
|
||||||
ColumnParallelLinear: ColumnParallelLinearWithLoRA,
|
|
||||||
RowParallelLinear: RowParallelLinearWithLoRA,
|
|
||||||
}
|
|
||||||
for src_layer_type, lora_layer_type in supported_layer_types.items():
|
|
||||||
if isinstance(layer, src_layer_type): # pylint: disable=unidiomatic-typecheck
|
|
||||||
ret = lora_layer_type(layer, lora_rank, scaling, lora_backend)
|
|
||||||
return ret
|
|
||||||
raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")
|
|
||||||
|
|
||||||
|
|
||||||
def get_mapped_params(module_names):
|
|
||||||
ret = set()
|
|
||||||
for module_name in module_names:
|
|
||||||
ret.add(params_mapping(module_name))
|
|
||||||
return list(ret)
|
|
||||||
|
|
||||||
|
|
||||||
class LoRALayer(nn.Module):
|
class LoRALayer(nn.Module):
|
||||||
def __init__(self, config, base_hf_config):
|
def __init__(self, config: LoRAConfig, base_hf_config: AutoConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config: LoRAConfig = config
|
||||||
self.base_hf_config = base_hf_config
|
self.base_hf_config: AutoConfig = base_hf_config
|
||||||
self.weights = {}
|
self.weights: Dict[str, torch.Tensor] = {}
|
||||||
self.weight_gpu = {}
|
self.weight_gpu: Dict[str, torch.Tensor] = {}
|
||||||
|
|
||||||
def load_to_gpu(self):
|
def load_to_gpu(self):
|
||||||
for name, weight in self.weights.items():
|
for name, weight in self.weights.items():
|
||||||
@@ -306,33 +49,32 @@ class LoRALayer(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class LoRAAdapter(nn.Module):
|
class LoRAAdapter(nn.Module):
|
||||||
def __init__(self, uid, config, base_hf_config, load_config, lora_backend):
|
def __init__(
|
||||||
|
self,
|
||||||
|
uid: str,
|
||||||
|
config: LoRAConfig,
|
||||||
|
base_hf_config: AutoConfig,
|
||||||
|
load_config: LoadConfig,
|
||||||
|
lora_backend: BaseLoRABackend,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.uid = uid
|
self.uid: str = uid
|
||||||
self.config = config
|
self.config: LoRAConfig = config
|
||||||
assert self.config.hf_config["peft_type"].lower() == "lora"
|
assert self.config.hf_config["peft_type"].lower() == "lora"
|
||||||
self.base_hf_config = base_hf_config
|
self.base_hf_config: AutoConfig = base_hf_config
|
||||||
self.load_config = load_config
|
self.load_config: LoadConfig = load_config
|
||||||
self.lora_backend = lora_backend
|
self.lora_backend: BaseLoRABackend = lora_backend
|
||||||
self.scaling = self.config.lora_alpha / self.config.r
|
self.scaling: float = self.config.lora_alpha / self.config.r
|
||||||
|
|
||||||
self.layers = nn.ModuleList(
|
self.layers: List[LoRALayer] = nn.ModuleList(
|
||||||
[
|
[
|
||||||
LoRALayer(config, base_hf_config)
|
LoRALayer(config, base_hf_config)
|
||||||
for i in range(base_hf_config.num_hidden_layers)
|
for i in range(base_hf_config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
self.weights = {}
|
self.weights: Dict[str, torch.Tensor] = {}
|
||||||
self.weights_gpu = {}
|
self.weights_gpu: Dict[str, torch.Tensor] = {}
|
||||||
|
|
||||||
def get_stacked_multiply(self, module_name):
|
|
||||||
stacked_rank = {
|
|
||||||
"qkv_proj": 3,
|
|
||||||
"kv_proj": 2,
|
|
||||||
"gate_up_proj": 2,
|
|
||||||
}
|
|
||||||
return stacked_rank[module_name] if module_name in stacked_rank else 1
|
|
||||||
|
|
||||||
def load_to_gpu(self):
|
def load_to_gpu(self):
|
||||||
for name, weight in self.weights.items():
|
for name, weight in self.weights.items():
|
||||||
@@ -367,44 +109,77 @@ class LoRAAdapter(nn.Module):
|
|||||||
for i in range(self.base_hf_config.num_hidden_layers):
|
for i in range(self.base_hf_config.num_hidden_layers):
|
||||||
layer = self.layers[i]
|
layer = self.layers[i]
|
||||||
weight_names = [name for name, _ in layer.weights.items()]
|
weight_names = [name for name, _ in layer.weights.items()]
|
||||||
for weight_name in weight_names:
|
self.stack_qkv_proj(weight_names, layer.weights)
|
||||||
if "k_proj" in weight_name:
|
self.stack_gate_up_proj(weight_names, layer.weights)
|
||||||
q_name = weight_name.replace("k_proj", "q_proj")
|
|
||||||
v_name = weight_name.replace("k_proj", "v_proj")
|
def stack_qkv_proj(self, weight_names: List[str], weights: Dict[str, torch.Tensor]):
|
||||||
kv_name = weight_name.replace("k_proj", "kv_proj")
|
|
||||||
qkv_name = weight_name.replace("k_proj", "qkv_proj")
|
# Collect target q/k/v modules. This process is necessary since there might be no lora attached to k_proj
|
||||||
if "lora_A" in weight_name:
|
target_module = set()
|
||||||
layer.weights[qkv_name] = torch.cat(
|
for weight_name in weight_names:
|
||||||
(
|
if "k_proj" in weight_name:
|
||||||
layer.weights[q_name],
|
target_module.add("k_proj")
|
||||||
layer.weights[weight_name],
|
if "q_proj" in weight_name:
|
||||||
layer.weights[v_name],
|
target_module.add("q_proj")
|
||||||
),
|
if "v_proj" in weight_name:
|
||||||
0,
|
target_module.add("v_proj")
|
||||||
)
|
if len(target_module) == 0:
|
||||||
layer.weights.pop(q_name)
|
return
|
||||||
layer.weights.pop(weight_name)
|
|
||||||
layer.weights.pop(v_name)
|
for weight_name in weight_names:
|
||||||
else:
|
# We assume every lora adaptor should contain lora modules for q_proj
|
||||||
layer.weights[kv_name] = torch.stack(
|
if "q_proj" in weight_name:
|
||||||
[
|
q_name = weight_name
|
||||||
layer.weights[weight_name],
|
k_name = weight_name.replace("q_proj", "k_proj")
|
||||||
layer.weights[v_name],
|
v_name = weight_name.replace("q_proj", "v_proj")
|
||||||
],
|
kv_name = weight_name.replace("q_proj", "kv_proj")
|
||||||
dim=0,
|
qkv_name = weight_name.replace("q_proj", "qkv_proj")
|
||||||
)
|
|
||||||
layer.weights.pop(weight_name)
|
# If k_proj doesn't have lora, initialize it to zero
|
||||||
layer.weights.pop(v_name)
|
k_proj_weight = (
|
||||||
elif "gate_proj" in weight_name:
|
weights[k_name]
|
||||||
up_name = weight_name.replace("gate_proj", "up_proj")
|
if "k_proj" in target_module
|
||||||
gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
|
else torch.zeros_like(weights[v_name])
|
||||||
if "lora_A" in weight_name:
|
)
|
||||||
layer.weights[gate_up_name] = torch.cat(
|
if "lora_A" in weight_name:
|
||||||
(layer.weights[weight_name], layer.weights[up_name]), 0
|
weights[qkv_name] = torch.cat(
|
||||||
)
|
(
|
||||||
else:
|
weights[q_name],
|
||||||
layer.weights[gate_up_name] = torch.stack(
|
k_proj_weight,
|
||||||
[layer.weights[weight_name], layer.weights[up_name]], dim=0
|
weights[v_name],
|
||||||
)
|
),
|
||||||
layer.weights.pop(weight_name)
|
0,
|
||||||
layer.weights.pop(up_name)
|
)
|
||||||
|
weights.pop(q_name)
|
||||||
|
if "k_proj" in target_module:
|
||||||
|
weights.pop(k_name)
|
||||||
|
weights.pop(v_name)
|
||||||
|
else:
|
||||||
|
weights[kv_name] = torch.stack(
|
||||||
|
[
|
||||||
|
k_proj_weight,
|
||||||
|
weights[v_name],
|
||||||
|
],
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
if "k_proj" in target_module:
|
||||||
|
weights.pop(k_name)
|
||||||
|
weights.pop(v_name)
|
||||||
|
|
||||||
|
def stack_gate_up_proj(
|
||||||
|
self, weight_names: List[str], weights: Dict[str, torch.Tensor]
|
||||||
|
):
|
||||||
|
for weight_name in weight_names:
|
||||||
|
if "gate_proj" in weight_name:
|
||||||
|
up_name = weight_name.replace("gate_proj", "up_proj")
|
||||||
|
gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
|
||||||
|
if "lora_A" in weight_name:
|
||||||
|
weights[gate_up_name] = torch.cat(
|
||||||
|
(weights[weight_name], weights[up_name]), 0
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
weights[gate_up_name] = torch.stack(
|
||||||
|
[weights[weight_name], weights[up_name]], dim=0
|
||||||
|
)
|
||||||
|
weights.pop(weight_name)
|
||||||
|
weights.pop(up_name)
|
||||||
|
|||||||
@@ -16,307 +16,115 @@
|
|||||||
# and "Punica: Multi-Tenant LoRA Serving"
|
# and "Punica: Multi-Tenant LoRA Serving"
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
from typing import Dict, List, Set, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.lora.backend import FlashInferLoraBackend, TritonLoraBackend
|
from sglang.srt.configs.load_config import LoadConfig
|
||||||
from sglang.srt.lora.lora import LoRAAdapter, LoraBatchInfo, get_lora_layer
|
from sglang.srt.hf_transformers_utils import AutoConfig
|
||||||
|
from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name
|
||||||
|
from sglang.srt.lora.layers import get_lora_layer
|
||||||
|
from sglang.srt.lora.lora import LoRAAdapter
|
||||||
from sglang.srt.lora.lora_config import LoRAConfig
|
from sglang.srt.lora.lora_config import LoRAConfig
|
||||||
|
from sglang.srt.lora.mem_pool import LoRAMemoryPool
|
||||||
|
from sglang.srt.lora.utils import (
|
||||||
|
LoRABatchInfo,
|
||||||
|
LoRAType,
|
||||||
|
get_customized_names_from_hf_names,
|
||||||
|
get_layer_id,
|
||||||
|
get_stacked_name,
|
||||||
|
get_weight_name,
|
||||||
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.utils import is_flashinfer_available, replace_submodule
|
from sglang.srt.utils import replace_submodule
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_module_name(name):
|
|
||||||
# Fallback solution of mapping from config module name to module name in model class.
|
|
||||||
# Please check if it aligns with your base model.
|
|
||||||
# Please implement the function in the model class if it is not.
|
|
||||||
# You can reference this function in llama.py.
|
|
||||||
params_mapping = {
|
|
||||||
"q_proj": "qkv_proj",
|
|
||||||
"k_proj": "qkv_proj",
|
|
||||||
"v_proj": "qkv_proj",
|
|
||||||
"gate_proj": "gate_up_proj",
|
|
||||||
"up_proj": "gate_up_proj",
|
|
||||||
}
|
|
||||||
return params_mapping.get(name, name)
|
|
||||||
|
|
||||||
|
|
||||||
def get_hidden_dim(module_name, config):
|
|
||||||
# Fallback solution of get_hidden_dim for different modules
|
|
||||||
# Please check if it aligns with your base model.
|
|
||||||
# Please implement the function in the model class if it is not.
|
|
||||||
# You can reference this function in llama.py.
|
|
||||||
if module_name in ["q_proj", "o_proj", "qkv_proj"]:
|
|
||||||
return config.hidden_size, config.hidden_size
|
|
||||||
elif module_name in ["kv_proj"]:
|
|
||||||
return config.hidden_size, config.hidden_size // (
|
|
||||||
config.num_attention_heads // config.num_key_value_heads
|
|
||||||
)
|
|
||||||
elif module_name == "gate_up_proj":
|
|
||||||
return config.hidden_size, config.intermediate_size
|
|
||||||
elif module_name == "down_proj":
|
|
||||||
return config.intermediate_size, config.hidden_size
|
|
||||||
else:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
def get_stacked_name(name):
|
|
||||||
# origin name -> (name for A, name for B)
|
|
||||||
params_mapping = {
|
|
||||||
"q_proj": ("qkv_proj", "q_proj"),
|
|
||||||
"k_proj": ("qkv_proj", "kv_proj"),
|
|
||||||
"v_proj": ("qkv_proj", "kv_proj"),
|
|
||||||
"gate_proj": ("gate_up_proj", "gate_up_proj"),
|
|
||||||
"up_proj": ("gate_up_proj", "gate_up_proj"),
|
|
||||||
}
|
|
||||||
return params_mapping.get(name, (name, name))
|
|
||||||
|
|
||||||
|
|
||||||
def get_backend_from_name(name):
|
|
||||||
backend_mapping = {
|
|
||||||
"triton": TritonLoraBackend,
|
|
||||||
"flashinfer": FlashInferLoraBackend,
|
|
||||||
}
|
|
||||||
|
|
||||||
if name in backend_mapping:
|
|
||||||
return backend_mapping[name]
|
|
||||||
|
|
||||||
raise Exception(
|
|
||||||
f"No supported lora backend called {name}. It should be one of {list(backend_mapping.keys())}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_layer_id(name):
|
|
||||||
match = re.search(r"layers\.(\d+)\.", name)
|
|
||||||
if match is None:
|
|
||||||
return None
|
|
||||||
return int(match.group(1))
|
|
||||||
|
|
||||||
|
|
||||||
class LoRAManager:
|
class LoRAManager:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
base_model,
|
base_model: torch.nn.Module,
|
||||||
lora_paths,
|
lora_paths: Dict[str, str],
|
||||||
base_hf_config,
|
base_hf_config: AutoConfig,
|
||||||
max_loras_per_batch,
|
max_loras_per_batch: int,
|
||||||
load_config,
|
load_config: LoadConfig,
|
||||||
dtype,
|
dtype: torch.dtype,
|
||||||
lora_backend,
|
lora_backend: str = "triton",
|
||||||
):
|
):
|
||||||
self.base_model = base_model
|
self.base_model: torch.nn.Module = base_model
|
||||||
self.lora_paths = lora_paths
|
self.lora_paths: Dict[str, str] = lora_paths
|
||||||
self.base_hf_config = base_hf_config
|
self.base_hf_config: AutoConfig = base_hf_config
|
||||||
self.max_loras_per_batch = max_loras_per_batch
|
self.max_loras_per_batch: int = max_loras_per_batch
|
||||||
self.load_config = load_config
|
self.load_config: LoadConfig = load_config
|
||||||
self.dtype = dtype
|
self.dtype: torch.dtype = dtype
|
||||||
|
|
||||||
logger.info(f"Using {lora_backend} as backend of Lora kernels.")
|
# LoRA backend for running sgemm kernels
|
||||||
|
logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
|
||||||
backend_type = get_backend_from_name(lora_backend)
|
backend_type = get_backend_from_name(lora_backend)
|
||||||
self.lora_backend = backend_type(lora_backend)
|
self.lora_backend: BaseLoRABackend = backend_type(lora_backend)
|
||||||
|
|
||||||
self.init_loras()
|
self.init_loras()
|
||||||
self.init_lora_memory_pool()
|
self.init_lora_memory_pool()
|
||||||
self.init_lora_batch()
|
|
||||||
|
|
||||||
def match_target_modules(self, module_name):
|
|
||||||
for target_module in self.target_modules:
|
|
||||||
if module_name.split(".")[-1] == target_module:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_target_modules(self):
|
|
||||||
modules = []
|
|
||||||
for module_name, module in self.base_model.named_modules():
|
|
||||||
if self.match_target_modules(module_name):
|
|
||||||
modules.append((module_name, module))
|
|
||||||
return modules
|
|
||||||
|
|
||||||
def set_lora_module(self, module_name, module):
|
|
||||||
lora_module = get_lora_layer(
|
|
||||||
module, self.max_lora_dim, self.scaling, self.lora_backend
|
|
||||||
)
|
|
||||||
replace_submodule(self.base_model, module_name, lora_module)
|
|
||||||
return lora_module
|
|
||||||
|
|
||||||
def init_loras(self):
|
def init_loras(self):
|
||||||
# get configs and target modules
|
# Config of each LoRA adapter
|
||||||
self.configs = {}
|
self.configs: Dict[str, LoRAConfig] = {}
|
||||||
self.origin_target_modules = set()
|
|
||||||
|
# Target module names in huggingface lora configs.
|
||||||
|
# e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
|
||||||
|
self.hf_target_names: Set[str] = set()
|
||||||
for name, path in self.lora_paths.items():
|
for name, path in self.lora_paths.items():
|
||||||
self.configs[name] = LoRAConfig(path)
|
self.configs[name] = LoRAConfig(path)
|
||||||
self.origin_target_modules = set(self.origin_target_modules) | set(
|
self.hf_target_names = set(self.hf_target_names) | set(
|
||||||
self.configs[name].target_modules
|
self.configs[name].target_modules
|
||||||
)
|
)
|
||||||
if hasattr(self.base_model, "get_module_name"):
|
|
||||||
self.target_modules = {
|
# Target lora weight names for lora_a and lora_b modules repectively.
|
||||||
self.base_model.get_module_name(module)
|
# e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
|
||||||
for module in self.origin_target_modules
|
self.lora_weight_names: Set[Tuple[str]] = set(
|
||||||
}
|
[get_stacked_name(module) for module in self.hf_target_names]
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"WARNING: get_module_name() is not defined, "
|
|
||||||
"which is used to map config module name to model implementation module name."
|
|
||||||
"Use the default one, but please check if it is correct for your model."
|
|
||||||
)
|
|
||||||
self.target_modules = {
|
|
||||||
get_module_name(module) for module in self.origin_target_modules
|
|
||||||
}
|
|
||||||
self.target_weights = set(
|
|
||||||
[get_stacked_name(module) for module in self.origin_target_modules]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# load all weights to cpu
|
# load all weights to cpu
|
||||||
self.loras = []
|
self.loras: Dict[str, LoRAAdapter] = {}
|
||||||
self.lora_id = {}
|
|
||||||
for name in self.lora_paths.keys():
|
for name in self.lora_paths.keys():
|
||||||
self.lora_id[name] = len(self.loras)
|
lora_adapter = LoRAAdapter(
|
||||||
self.loras.append(
|
name,
|
||||||
LoRAAdapter(
|
self.configs[name],
|
||||||
name,
|
self.base_hf_config,
|
||||||
self.configs[name],
|
self.load_config,
|
||||||
self.base_hf_config,
|
self.lora_backend,
|
||||||
self.load_config,
|
|
||||||
self.lora_backend,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
self.loras[-1].initialize_weights()
|
lora_adapter.initialize_weights()
|
||||||
|
self.loras[name] = lora_adapter
|
||||||
|
|
||||||
# misc lora configs
|
# misc lora configs
|
||||||
self.max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
|
# FIXME remove the restrictions after implementing unified paging
|
||||||
self.scaling = self.loras[0].scaling
|
self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
|
||||||
# FIXME remove the restrictions
|
self.scaling: float = list(self.loras.values())[0].scaling
|
||||||
assert all(x.hf_config["r"] == self.max_lora_dim for x in self.configs.values())
|
assert all(x.hf_config["r"] == self.max_lora_dim for x in self.configs.values())
|
||||||
assert all(x.scaling == self.scaling for x in self.loras)
|
assert all(x.scaling == self.scaling for x in self.loras.values())
|
||||||
|
|
||||||
# monkey patch to use the LoRA version
|
# Convert original model layers to layers with LoRA
|
||||||
self.lora_modules = []
|
self.convert_to_lora_layers()
|
||||||
for module_name, module in self.get_target_modules():
|
|
||||||
self.lora_modules.append(
|
|
||||||
(module_name, self.set_lora_module(module_name, module))
|
|
||||||
)
|
|
||||||
|
|
||||||
def init_lora_memory_pool(self):
|
def init_lora_memory_pool(self):
|
||||||
# preallocate lora memory pool
|
# Initialize memory pool
|
||||||
self.A_buffer = {}
|
self.memory_pool = LoRAMemoryPool(
|
||||||
self.B_buffer = {}
|
self.base_hf_config, self.max_loras_per_batch, self.max_lora_dim, self.dtype
|
||||||
num_layer = self.base_hf_config.num_hidden_layers
|
)
|
||||||
for module_A, module_B in self.target_weights:
|
|
||||||
# init A tensor, column_major=True
|
|
||||||
if hasattr(self.base_model, "get_hidden_dim"):
|
|
||||||
hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A)
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"WARNING: get_hidden_dim() is not defined, "
|
|
||||||
"which is used to get the hidden dim for different lora modules"
|
|
||||||
"Use the default one, but please check if it is correct for your model."
|
|
||||||
)
|
|
||||||
hidden_dim_A, _ = get_hidden_dim(module_A, self.base_hf_config)
|
|
||||||
c = self.loras[-1].get_stacked_multiply(module_A)
|
|
||||||
if module_A not in self.A_buffer:
|
|
||||||
self.A_buffer[module_A] = [
|
|
||||||
torch.empty(
|
|
||||||
(
|
|
||||||
self.max_loras_per_batch,
|
|
||||||
self.max_lora_dim * c,
|
|
||||||
hidden_dim_A,
|
|
||||||
),
|
|
||||||
dtype=self.dtype,
|
|
||||||
device="cuda",
|
|
||||||
)
|
|
||||||
for i in range(num_layer)
|
|
||||||
]
|
|
||||||
# init B tensor, column_major=True
|
|
||||||
if hasattr(self.base_model, "get_hidden_dim"):
|
|
||||||
_, hidden_dim_B = self.base_model.get_hidden_dim(module_B)
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"WARNING: get_hidden_dim() is not defined, "
|
|
||||||
"which is used to get the hidden dim for different lora modules"
|
|
||||||
"Use the default one, but please check if it is correct for your model."
|
|
||||||
)
|
|
||||||
_, hidden_dim_B = get_hidden_dim(module_B, self.base_hf_config)
|
|
||||||
c = self.loras[-1].get_stacked_multiply(module_B)
|
|
||||||
if module_B not in self.B_buffer:
|
|
||||||
self.B_buffer[module_B] = [
|
|
||||||
torch.empty(
|
|
||||||
(
|
|
||||||
c,
|
|
||||||
self.max_loras_per_batch,
|
|
||||||
hidden_dim_B,
|
|
||||||
self.max_lora_dim,
|
|
||||||
),
|
|
||||||
dtype=self.dtype,
|
|
||||||
device="cuda",
|
|
||||||
)
|
|
||||||
for i in range(num_layer)
|
|
||||||
]
|
|
||||||
|
|
||||||
def init_lora_batch(self):
|
# Initialize target lora modules in memory pool
|
||||||
self.active_uids = set() # set of active loras
|
self.memory_pool.init_buffers(self.lora_weight_names, self.base_model)
|
||||||
self.buffer_id = {} # lora uid -> idx in memory pool
|
|
||||||
|
|
||||||
def get_weight_name(self, name, idx):
|
|
||||||
for target_weight_name in self.target_weights:
|
|
||||||
if target_weight_name[idx] in name:
|
|
||||||
return target_weight_name[idx]
|
|
||||||
|
|
||||||
def load_lora(self, uid, buffer_id):
|
|
||||||
num_layer = self.base_hf_config.num_hidden_layers
|
|
||||||
if uid is None:
|
|
||||||
for i in range(num_layer):
|
|
||||||
for k in self.A_buffer.keys():
|
|
||||||
self.A_buffer[k][i][buffer_id] *= 0
|
|
||||||
return
|
|
||||||
|
|
||||||
for i in range(num_layer):
|
|
||||||
layer_weights = self.loras[self.lora_id[uid]].layers[i].weights
|
|
||||||
for name, weights in layer_weights.items():
|
|
||||||
if "lora_A" in name:
|
|
||||||
lora_weight_name = self.get_weight_name(name, 0)
|
|
||||||
if lora_weight_name:
|
|
||||||
self.A_buffer[lora_weight_name][i][buffer_id].copy_(weights)
|
|
||||||
else:
|
|
||||||
lora_weight_name = self.get_weight_name(name, 1)
|
|
||||||
if lora_weight_name:
|
|
||||||
c = self.loras[-1].get_stacked_multiply(lora_weight_name)
|
|
||||||
if c > 1:
|
|
||||||
for j in range(c):
|
|
||||||
self.B_buffer[lora_weight_name][i][j][buffer_id].copy_(
|
|
||||||
weights[j]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.B_buffer[lora_weight_name][i][0][buffer_id].copy_(
|
|
||||||
weights
|
|
||||||
)
|
|
||||||
|
|
||||||
def prepare_lora_batch(self, forward_batch: ForwardBatch):
|
def prepare_lora_batch(self, forward_batch: ForwardBatch):
|
||||||
# load active loras into lora memory pool
|
# load active loras into lora memory pool
|
||||||
cur_uids = set(forward_batch.lora_paths)
|
cur_uids = set(forward_batch.lora_paths)
|
||||||
assert len(cur_uids) <= self.max_loras_per_batch
|
assert len(cur_uids) <= self.max_loras_per_batch
|
||||||
i = 0
|
self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
|
||||||
j = len(self.active_uids)
|
|
||||||
evictable_uids = list(self.active_uids)
|
|
||||||
for uid in cur_uids:
|
|
||||||
if uid not in self.active_uids:
|
|
||||||
if j < self.max_loras_per_batch:
|
|
||||||
index = j
|
|
||||||
j += 1
|
|
||||||
else:
|
|
||||||
while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
|
|
||||||
i += 1
|
|
||||||
assert i < len(evictable_uids)
|
|
||||||
self.active_uids.remove(evictable_uids[i])
|
|
||||||
self.buffer_id.pop(evictable_uids[i])
|
|
||||||
index = i
|
|
||||||
i += 1
|
|
||||||
self.load_lora(uid, index)
|
|
||||||
self.active_uids.add(uid)
|
|
||||||
self.buffer_id[uid] = index
|
|
||||||
|
|
||||||
|
# FIXME: Handle lora uid with None more safely
|
||||||
if cur_uids == set([None]):
|
if cur_uids == set([None]):
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -332,9 +140,9 @@ class LoRAManager:
|
|||||||
max_len = int(torch.max(seg_lens))
|
max_len = int(torch.max(seg_lens))
|
||||||
weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
|
weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
|
||||||
for i, lora_path in enumerate(forward_batch.lora_paths):
|
for i, lora_path in enumerate(forward_batch.lora_paths):
|
||||||
weight_indices[i] = self.buffer_id[lora_path]
|
weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
|
||||||
|
|
||||||
batch_info = LoraBatchInfo(
|
batch_info = LoRABatchInfo(
|
||||||
bs=bs,
|
bs=bs,
|
||||||
seg_lens=seg_lens,
|
seg_lens=seg_lens,
|
||||||
seg_indptr=seg_indptr,
|
seg_indptr=seg_indptr,
|
||||||
@@ -346,16 +154,40 @@ class LoRAManager:
|
|||||||
# call set_lora_info for each lora modules
|
# call set_lora_info for each lora modules
|
||||||
for module_name, module in self.lora_modules:
|
for module_name, module in self.lora_modules:
|
||||||
layer_id = get_layer_id(module_name)
|
layer_id = get_layer_id(module_name)
|
||||||
|
|
||||||
if "qkv_proj" not in module_name:
|
if "qkv_proj" not in module_name:
|
||||||
weight_name = self.get_weight_name(module_name, 0)
|
weight_name = get_weight_name(
|
||||||
|
module_name, self.lora_weight_names, LoRAType.LORA_A
|
||||||
|
)
|
||||||
module.set_lora_info(
|
module.set_lora_info(
|
||||||
self.A_buffer[weight_name][layer_id],
|
self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_A),
|
||||||
self.B_buffer[weight_name][layer_id],
|
self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_B),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
module.set_lora_info(
|
module.set_lora_info(
|
||||||
self.A_buffer["qkv_proj"][layer_id],
|
self.memory_pool.get_tensor("qkv_proj", layer_id, LoRAType.LORA_A),
|
||||||
self.B_buffer["q_proj"][layer_id],
|
self.memory_pool.get_tensor("q_proj", layer_id, LoRAType.LORA_B),
|
||||||
self.B_buffer["kv_proj"][layer_id],
|
self.memory_pool.get_tensor("kv_proj", layer_id, LoRAType.LORA_B),
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_lora_module(self, module_name, module):
|
||||||
|
lora_module = get_lora_layer(
|
||||||
|
module, self.max_lora_dim, self.scaling, self.lora_backend
|
||||||
|
)
|
||||||
|
replace_submodule(self.base_model, module_name, lora_module)
|
||||||
|
return lora_module
|
||||||
|
|
||||||
|
def convert_to_lora_layers(self):
|
||||||
|
# Target module names of customized layers defined in python/sglang/srt/layers
|
||||||
|
# e.g., {"qkv_proj", "o_proj"}
|
||||||
|
customized_target_names = get_customized_names_from_hf_names(
|
||||||
|
self.hf_target_names, self.base_model
|
||||||
|
)
|
||||||
|
|
||||||
|
# Monkey patch to use the LoRA version layers
|
||||||
|
self.lora_modules: List[Tuple[str, torch.nn.Module]] = []
|
||||||
|
for module_name, module in self.base_model.named_modules():
|
||||||
|
# The module should be converted if it is included in target_names
|
||||||
|
if module_name.split(".")[-1] in customized_target_names:
|
||||||
|
self.lora_modules.append(
|
||||||
|
(module_name, self.set_lora_module(module_name, module))
|
||||||
)
|
)
|
||||||
|
|||||||
174
python/sglang/srt/lora/mem_pool.py
Normal file
174
python/sglang/srt/lora/mem_pool.py
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
from typing import Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.hf_transformers_utils import AutoConfig
|
||||||
|
from sglang.srt.lora.lora import LoRAAdapter
|
||||||
|
from sglang.srt.lora.utils import (
|
||||||
|
LoRAType,
|
||||||
|
get_hidden_dim,
|
||||||
|
get_stacked_multiply,
|
||||||
|
get_weight_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LoRAMemoryPool:
|
||||||
|
"""Class for memory pool management of lora modules"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_hf_config: AutoConfig,
|
||||||
|
max_loras_per_batch: int,
|
||||||
|
max_lora_dim: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
):
|
||||||
|
|
||||||
|
self.base_hf_config: AutoConfig = base_hf_config
|
||||||
|
self.num_layer: int = base_hf_config.num_hidden_layers
|
||||||
|
self.max_loras_per_batch: int = max_loras_per_batch
|
||||||
|
self.max_lora_dim: int = max_lora_dim
|
||||||
|
self.dtype: torch.dtype = dtype
|
||||||
|
|
||||||
|
# Both A_buffer and B_buffer maps lora weight names to its buffer space.
|
||||||
|
# A_buffer contains num_layer number of row-major tensors with shape
|
||||||
|
# (max_loras_per_batch, stacked_num * max_lora_dim, input_dim)
|
||||||
|
# B_buffer contains num_layer number of column-major tensors with shape
|
||||||
|
# (stacked_num, max_loras_per_batch, output_dim, max_lora_dim)
|
||||||
|
self.A_buffer: Dict[str, List[torch.Tensor]] = {}
|
||||||
|
self.B_buffer: Dict[str, List[torch.Tensor]] = {}
|
||||||
|
|
||||||
|
# Lora uid -> buffer idx in memory pool
|
||||||
|
self.uid_to_buffer_id: Dict[Optional[str], int] = {}
|
||||||
|
|
||||||
|
# Buffer idx -> lora uid in memory pool
|
||||||
|
# All uids are initalized as empty strings for empty buffer slots
|
||||||
|
# Here we don't initalize to None since None is a valid uid
|
||||||
|
self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
|
||||||
|
|
||||||
|
def init_buffers(
|
||||||
|
self,
|
||||||
|
lora_weight_names: Set[Tuple[str]],
|
||||||
|
base_model: torch.nn.Module,
|
||||||
|
):
|
||||||
|
|
||||||
|
# lora_weight_names is a set of name pairs indicating each pair of lora modules to load
|
||||||
|
# e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
|
||||||
|
self.lora_weight_names: Set[Tuple[str]] = lora_weight_names
|
||||||
|
|
||||||
|
for module_A, module_B in lora_weight_names:
|
||||||
|
# Init A tensor, column_major=False
|
||||||
|
input_dim, _ = get_hidden_dim(module_A, self.base_hf_config, base_model)
|
||||||
|
c = get_stacked_multiply(module_A)
|
||||||
|
if module_A not in self.A_buffer:
|
||||||
|
self.A_buffer[module_A] = [
|
||||||
|
torch.empty(
|
||||||
|
(
|
||||||
|
self.max_loras_per_batch,
|
||||||
|
self.max_lora_dim * c,
|
||||||
|
input_dim,
|
||||||
|
),
|
||||||
|
dtype=self.dtype,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
for i in range(self.num_layer)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Init B tensor, column_major=True
|
||||||
|
_, output_dim = get_hidden_dim(module_B, self.base_hf_config, base_model)
|
||||||
|
c = get_stacked_multiply(module_B)
|
||||||
|
if module_B not in self.B_buffer:
|
||||||
|
self.B_buffer[module_B] = [
|
||||||
|
torch.empty(
|
||||||
|
(
|
||||||
|
c, # stacked lora_b modules might need separation
|
||||||
|
self.max_loras_per_batch,
|
||||||
|
output_dim,
|
||||||
|
self.max_lora_dim,
|
||||||
|
),
|
||||||
|
dtype=self.dtype,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
for i in range(self.num_layer)
|
||||||
|
]
|
||||||
|
|
||||||
|
def prepare_lora_batch(
|
||||||
|
self,
|
||||||
|
cur_uids: Set[Optional[str]],
|
||||||
|
lora_adapters: Dict[str, LoRAAdapter],
|
||||||
|
):
|
||||||
|
|
||||||
|
def get_available_buffer_slot():
|
||||||
|
for buffer_id in range(self.max_loras_per_batch):
|
||||||
|
# Prioritize empty slots
|
||||||
|
if self.buffer_id_to_uid[buffer_id] == "":
|
||||||
|
return buffer_id, ""
|
||||||
|
|
||||||
|
for buffer_id in range(self.max_loras_per_batch):
|
||||||
|
# Evict unneeded lora
|
||||||
|
if self.buffer_id_to_uid[buffer_id] not in cur_uids:
|
||||||
|
return buffer_id, self.buffer_id_to_uid[buffer_id]
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
"No available buffer slots found. Please ensure the number of active loras is less than max_loras_per_batch."
|
||||||
|
)
|
||||||
|
|
||||||
|
for uid in cur_uids:
|
||||||
|
if uid not in self.uid_to_buffer_id:
|
||||||
|
buffer_id, evicted_lora_uid = get_available_buffer_slot()
|
||||||
|
if evicted_lora_uid != "":
|
||||||
|
self.uid_to_buffer_id.pop(evicted_lora_uid)
|
||||||
|
self.load_lora_weight_to_buffer(
|
||||||
|
uid, buffer_id, lora_adapters.get(uid, None)
|
||||||
|
)
|
||||||
|
self.uid_to_buffer_id[uid] = buffer_id
|
||||||
|
self.buffer_id_to_uid[buffer_id] = uid
|
||||||
|
|
||||||
|
def load_lora_weight_to_buffer(
    self, uid: str, buffer_id: int, lora_adapter: LoRAAdapter = None
):
    """Copy one adapter's weights into pool slot ``buffer_id``.

    ``uid is None`` marks the no-LoRA case: only the A buffers of the slot
    are zeroed.  NOTE(review): B buffers are left untouched — presumably a
    zero lora_A already nulls the whole LoRA contribution; confirm.
    """
    if uid is None:
        for layer_idx in range(self.num_layer):
            for weight_name in self.A_buffer.keys():
                self.A_buffer[weight_name][layer_idx][buffer_id] *= 0
        return

    assert lora_adapter is not None
    for layer_idx in range(self.num_layer):
        for name, weights in lora_adapter.layers[layer_idx].weights.items():
            is_lora_a = "lora_A" in name
            lora_type = LoRAType.LORA_A if is_lora_a else LoRAType.LORA_B
            weight_name = get_weight_name(name, self.lora_weight_names, lora_type)
            if not weight_name:
                # No stacked name matches this weight; skip it.
                continue
            if is_lora_a:
                self.A_buffer[weight_name][layer_idx][buffer_id].copy_(weights)
                continue
            stack = get_stacked_multiply(weight_name)
            if stack > 1:
                # Stacked lora_B weights carry one tensor per stacked part.
                for part in range(stack):
                    self.B_buffer[weight_name][layer_idx][part][buffer_id].copy_(
                        weights[part]
                    )
            else:
                self.B_buffer[weight_name][layer_idx][0][buffer_id].copy_(weights)
|
||||||
|
|
||||||
|
def get_tensor(
    self, weight_name: str, layer_id: int, lora_type: LoRAType
) -> torch.Tensor:
    """Return the pooled buffer for ``weight_name`` at ``layer_id``.

    ``lora_type`` selects between the lora_A and lora_B buffer families.
    """
    family = self.A_buffer if lora_type == LoRAType.LORA_A else self.B_buffer
    return family[weight_name][layer_id]
|
||||||
|
|
||||||
|
def get_buffer_id(self, lora_uid: str):
    # Map an adapter uid to its pool slot; raises KeyError when the uid
    # has not been loaded via prepare_lora_batch.
    return self.uid_to_buffer_id[lora_uid]
|
||||||
@@ -1,5 +1,11 @@
|
|||||||
|
from .gate_up_lora_b import gate_up_lora_b_fwd
|
||||||
from .qkv_lora_b import qkv_lora_b_fwd
|
from .qkv_lora_b import qkv_lora_b_fwd
|
||||||
from .sgemm_lora_a import sgemm_lora_a_fwd
|
from .sgemm_lora_a import sgemm_lora_a_fwd
|
||||||
from .sgemm_lora_b import sgemm_lora_b_fwd
|
from .sgemm_lora_b import sgemm_lora_b_fwd
|
||||||
|
|
||||||
__all__ = ["qkv_lora_b_fwd", "sgemm_lora_a_fwd", "sgemm_lora_b_fwd"]
|
__all__ = [
|
||||||
|
"gate_up_lora_b_fwd",
|
||||||
|
"qkv_lora_b_fwd",
|
||||||
|
"sgemm_lora_a_fwd",
|
||||||
|
"sgemm_lora_b_fwd",
|
||||||
|
]
|
||||||
|
|||||||
170
python/sglang/srt/lora/triton_ops/gate_up_lora_b.py
Normal file
170
python/sglang/srt/lora/triton_ops/gate_up_lora_b.py
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
from sglang.srt.lora.utils import LoRABatchInfo
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _gate_up_lora_b_kernel(
    # Pointers to matrices
    x,
    weights,
    output,
    # Parameters of size
    K,  # K = R, the lora rank
    output_dim,
    # Strides
    x_stride_0,
    x_stride_1,
    w_stride_0,
    w_stride_1,
    w_stride_2,
    output_stride_0,
    output_stride_1,
    # Information on sequence lengths and weight id
    seg_lens,
    seg_indptr,
    weight_indices,
    # Meta parameters
    BLOCK_S: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    # For fused output scaling and adding
    fuse_scaling_add,
    scaling,
):
    # This kernel packs 2 sgemms (gate/up) into a single kernel.

    # x: (s, 2 * K), s is the sum of sequence lengths, K equals to lora rank
    # weights: (num_lora, 2 * output_dim, K)
    # output: (s, 2 * output_dim)
    # output_dim >> K

    # Current block computes sequence with batch_id,
    # which starts from row seg_start of x with length seg_len.
    # gate_up_id decides which of gate or up (0: gate, 1: up)
    batch_id = tl.program_id(axis=2)
    gate_up_id = tl.program_id(axis=1)
    pid = tl.program_id(axis=0)
    seg_len = tl.load(seg_lens + batch_id)
    w_index = tl.load(weight_indices + batch_id)
    seg_start = tl.load(seg_indptr + batch_id)
    n_start = gate_up_id * output_dim  # offset on output dim

    # The tile in output matrix will have (pid_s, pid_n) as id
    num_pid_n = tl.cdiv(output_dim, BLOCK_N)
    pid_s = pid // num_pid_n
    pid_n = pid % num_pid_n

    # Create pointers for the first block of x and weights
    # The pointers will be advanced as we move in the K direction
    # and accumulate
    s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
    n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    k_offset = tl.arange(0, BLOCK_K)

    # gate_up_id * K shifts into the gate or up half of x's column space.
    x_ptrs = (x + seg_start * x_stride_0 + (gate_up_id * K) * x_stride_1) + (
        s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
    )
    w_ptrs = (weights + w_index * w_stride_0 + n_start * w_stride_1) + (
        k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
    )

    # Iterate to compute the block in output matrix.
    # NOTE(review): the masks below combine tensor conditions with `and`
    # rather than `&`; this appears intentional here but verify it produces
    # elementwise logical-and under the triton version in use.
    partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        x_tile = tl.load(
            x_ptrs,
            mask=(s_offset[:, None] < seg_len)
            and (k_offset[None, :] < K - k * BLOCK_K),
            other=0.0,
        )
        w_tile = tl.load(
            w_ptrs,
            mask=(k_offset[:, None] < K - k * BLOCK_K)
            and (n_offset[None, :] < output_dim),
            other=0.0,
        )
        partial_sum += tl.dot(x_tile, w_tile)

        x_ptrs += BLOCK_K * x_stride_1
        w_ptrs += BLOCK_K * w_stride_2

    # Store result to output matrix (optionally accumulating into the
    # existing base output when fuse_scaling_add is set).
    partial_sum *= scaling
    partial_sum = partial_sum.to(x.dtype.element_ty)
    output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
        s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
    )
    output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < output_dim)
    if fuse_scaling_add:
        partial_sum += tl.load(output_ptr, mask=output_mask)
    tl.store(output_ptr, partial_sum, mask=output_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def gate_up_lora_b_fwd(
    x: torch.Tensor,
    gate_up_lora_b: torch.Tensor,
    batch_info: LoRABatchInfo,
    output_dim: int,
    base_output: torch.Tensor = None,
    scaling: float = 1.0,
) -> torch.Tensor:
    """Apply the lora_B half of a fused gate/up projection.

    x:              (s, 2 * r) — lora_A activations for gate and up, side by side
    gate_up_lora_b: (num_lora, 2 * output_dim, r)
    returns:        (s, 2 * output_dim)

    Computed as:
      output[:, :output_dim]  = sgemm(x[:, :r], gate_up_lora_b[:, :output_dim, :])
      output[:, output_dim:]  = sgemm(x[:, r:], gate_up_lora_b[:, output_dim:, :])
    When ``base_output`` is given, the scaled result is accumulated into it
    inside the kernel instead of allocating a fresh tensor.
    """
    num_tokens = x.shape[0]
    rank = gate_up_lora_b.shape[-1]
    assert x.shape[1] == 2 * rank

    # Tile sizes for the kernel launch.
    BLOCK_S = 16
    BLOCK_R = 16
    BLOCK_OUT = 64

    # Grid axis 1 selects gate (0) vs up (1); axis 2 is the batch index.
    grid_b = (
        triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(output_dim, BLOCK_OUT),
        2,
        batch_info.bs,
    )

    fuse_scaling_add = base_output is not None
    if fuse_scaling_add:
        output = base_output
    else:
        output = torch.empty(
            (num_tokens, 2 * output_dim), device=x.device, dtype=x.dtype
        )

    _gate_up_lora_b_kernel[grid_b](
        x,
        gate_up_lora_b,
        output,
        rank,
        output_dim,
        x.stride(0),
        x.stride(1),
        gate_up_lora_b.stride(0),
        gate_up_lora_b.stride(1),
        gate_up_lora_b.stride(2),
        output.stride(0),
        output.stride(1),
        batch_info.seg_lens,
        batch_info.seg_indptr,
        batch_info.weight_indices,
        BLOCK_S,
        BLOCK_OUT,
        BLOCK_R,
        fuse_scaling_add,
        scaling,
    )

    return output
|
||||||
@@ -2,7 +2,7 @@ import torch
|
|||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.lora.lora import LoraBatchInfo
|
from sglang.srt.lora.utils import LoRABatchInfo
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
@@ -108,7 +108,7 @@ def _qkv_lora_b_kernel(
|
|||||||
def qkv_lora_b_fwd(
|
def qkv_lora_b_fwd(
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
qkv_lora_b: torch.Tensor,
|
qkv_lora_b: torch.Tensor,
|
||||||
batch_info: LoraBatchInfo,
|
batch_info: LoRABatchInfo,
|
||||||
output_offset: torch.Tensor,
|
output_offset: torch.Tensor,
|
||||||
max_qkv_out_dim: int,
|
max_qkv_out_dim: int,
|
||||||
base_output: torch.Tensor = None,
|
base_output: torch.Tensor = None,
|
||||||
@@ -123,11 +123,11 @@ def qkv_lora_b_fwd(
|
|||||||
# output: (s, output_dim_q + 2 * output_dim_kv)
|
# output: (s, output_dim_q + 2 * output_dim_kv)
|
||||||
|
|
||||||
# Compute lora_output with shape (s, output_dim) as follows:
|
# Compute lora_output with shape (s, output_dim) as follows:
|
||||||
# lora_output[:, :output_dim_q] = sgemm(lora_output_a[:, :r], )
|
# lora_output[:, :output_dim_q] = sgemm(x[:, :r], qkv_lora_b[:, :outptu_dim_q, :])
|
||||||
# lora_output[:, output_dim_q: output_dim_q + output_dim_kv]
|
# lora_output[:, output_dim_q: output_dim_q + output_dim_kv]
|
||||||
# = sgemm(lora_output_a[:, r: 2 * r], kv_lora_b[0])
|
# = sgemm(x[:, r: 2 * r], qkv_lora_b[:, outptu_dim_q: output_dim_q + output_dim_kv, :])
|
||||||
# lora_output[:, output_dim_q + output_dim_kv: ]
|
# lora_output[:, output_dim_q + output_dim_kv: ]
|
||||||
# = sgemm(lora_output_a[:, 2 * r: 3 * r], kv_lora_b[1])
|
# = sgemm(x[:, 2 * r: , qkv_lora_b[:, output_dim_q + output_dim_kv: , :])
|
||||||
|
|
||||||
# Get dims
|
# Get dims
|
||||||
s = x.shape[0]
|
s = x.shape[0]
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import torch
|
|||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.lora.lora import LoraBatchInfo
|
from sglang.srt.lora.utils import LoRABatchInfo
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
@@ -91,7 +91,7 @@ def _sgemm_lora_a_kernel(
|
|||||||
|
|
||||||
|
|
||||||
def sgemm_lora_a_fwd(
|
def sgemm_lora_a_fwd(
|
||||||
x: torch.Tensor, weights: torch.Tensor, batch_info: LoraBatchInfo
|
x: torch.Tensor, weights: torch.Tensor, batch_info: LoRABatchInfo
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# x: (s, input_dim)
|
# x: (s, input_dim)
|
||||||
# weights: (num_lora, r, input_dim)
|
# weights: (num_lora, r, input_dim)
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import torch
|
|||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.lora.lora import LoraBatchInfo
|
from sglang.srt.lora.utils import LoRABatchInfo
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
@@ -98,7 +98,7 @@ def _sgemm_lora_b_kernel(
|
|||||||
def sgemm_lora_b_fwd(
|
def sgemm_lora_b_fwd(
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
weights: torch.Tensor,
|
weights: torch.Tensor,
|
||||||
batch_info: LoraBatchInfo,
|
batch_info: LoRABatchInfo,
|
||||||
base_output: torch.Tensor = None,
|
base_output: torch.Tensor = None,
|
||||||
scaling: float = 1.0,
|
scaling: float = 1.0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|||||||
141
python/sglang/srt/lora/utils.py
Normal file
141
python/sglang/srt/lora/utils.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional, Set, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.hf_transformers_utils import AutoConfig
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class LoRABatchInfo:
    """Per-batch metadata consumed by the LoRA kernels (segment layout and
    which adapter slot each sequence uses)."""

    # Batch size
    bs: int

    # Lengths of each sequence in shape (bs,)
    seg_lens: torch.Tensor

    # Index pointers of each sequence in shape (bs + 1,)
    seg_indptr: torch.Tensor

    # Maximum sequence length of current batch
    max_len: int

    # The index of lora adapter used by each sequence, in shape (bs,)
    weight_indices: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class LoRAType(Enum):
    """Selects between the lora_A and lora_B weight families."""

    LORA_A = 0
    LORA_B = 1
|
||||||
|
|
||||||
|
|
||||||
|
def get_layer_id(name: str) -> Optional[int]:
    """
    Extract the integer layer id from a parameter name.

    e.g. "model.layers.11.self_attn.qkv_proj" -> 11.
    Returns None when the name contains no "layers.<i>." component.
    (Annotation fixed: the original claimed ``-> int`` but returns None
    on no match.)
    """
    match = re.search(r"layers\.(\d+)\.", name)
    if match is None:
        return None
    return int(match.group(1))
|
||||||
|
|
||||||
|
|
||||||
|
def get_customized_names_from_hf_names(
    hf_module_names: Set[str], base_model: torch.nn.Module
) -> Set[str]:
    """Translate huggingface-style module names into sglang layer names.

    e.g. {"k_proj", "q_proj", "v_proj", "o_proj"} -> {"qkv_proj", "o_proj"}.

    The base model may provide its own mapping via ``get_module_name``.
    Otherwise a default llama-style mapping is applied as a fallback —
    please check it aligns with your base model, and implement
    ``get_module_name`` in the model class if it does not (see llama.py
    for a reference implementation).
    """
    if hasattr(base_model, "get_module_name"):
        return {base_model.get_module_name(name) for name in hf_module_names}

    default_mapping = {
        "q_proj": "qkv_proj",
        "k_proj": "qkv_proj",
        "v_proj": "qkv_proj",
        "gate_proj": "gate_up_proj",
        "up_proj": "gate_up_proj",
    }
    return {default_mapping.get(name, name) for name in hf_module_names}
|
||||||
|
|
||||||
|
|
||||||
|
def get_hidden_dim(
    module_name: str, config: "AutoConfig", base_model: torch.nn.Module
) -> Tuple[int, int]:
    """
    Given a module_name (might be a stacked name), return the hidden dims of
    the module's input and output as ``(input_dim, output_dim)``.

    Prefers ``base_model.get_hidden_dim`` when the model defines it; otherwise
    falls back to a default llama-style layout derived from ``config``.
    Please check that the fallback is correct for your model and implement
    ``get_hidden_dim`` in the model class if it is not (see llama.py for a
    reference implementation).
    """
    if hasattr(base_model, "get_hidden_dim"):
        return base_model.get_hidden_dim(module_name)

    if module_name in ["q_proj", "o_proj", "qkv_proj"]:
        return config.hidden_size, config.hidden_size
    elif module_name in ["kv_proj"]:
        # With grouped-query attention the kv output dim shrinks by the
        # ratio of attention heads to kv heads.
        return config.hidden_size, config.hidden_size // (
            config.num_attention_heads // config.num_key_value_heads
        )
    elif module_name == "gate_up_proj":
        return config.hidden_size, config.intermediate_size
    elif module_name == "down_proj":
        return config.intermediate_size, config.hidden_size
    else:
        # Informative error instead of a bare NotImplementedError().
        raise NotImplementedError(
            f"get_hidden_dim has no default rule for module {module_name!r}; "
            "please implement get_hidden_dim() in the model class."
        )
|
||||||
|
|
||||||
|
|
||||||
|
def get_stacked_name(name: str) -> Tuple[str, str]:
    """
    Map a target module name to (stacked name for lora_A, stacked name for lora_B).

    q/k/v and gate/up are fused into one stacked module on the lora_A side;
    on the lora_B side k and v share a stacked "kv_proj". Unstacked module
    names map to themselves. (Annotation fixed: the function returns a pair,
    so ``Tuple[str, str]`` rather than ``Tuple[str]``.)
    """
    params_mapping = {
        "q_proj": ("qkv_proj", "q_proj"),
        "k_proj": ("qkv_proj", "kv_proj"),
        "v_proj": ("qkv_proj", "kv_proj"),
        "gate_proj": ("gate_up_proj", "gate_up_proj"),
        "up_proj": ("gate_up_proj", "gate_up_proj"),
    }
    return params_mapping.get(name, (name, name))
|
||||||
|
|
||||||
|
|
||||||
|
def get_stacked_multiply(module_name: str) -> int:
    """
    Map a lora module name to its magnification at output dimension.

    Stacked modules pack several projections into one buffer (qkv packs 3;
    kv and gate_up each pack 2); any other module is unstacked and returns 1.
    """
    stacked_rank = {
        "qkv_proj": 3,
        "kv_proj": 2,
        "gate_up_proj": 2,
    }
    # dict.get replaces the hand-rolled "x[k] if k in x else 1" idiom.
    return stacked_rank.get(module_name, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def get_weight_name(
|
||||||
|
target_name: str, lora_weight_names: Set[Tuple[str]], lora_type: LoRAType
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
target_name is name of a given module,
|
||||||
|
lora_weight_names is a set of lora stacked name pairs (see get_stacked_name method above)
|
||||||
|
If there is a weight name in lora_weight_names that can match target_name, return this name
|
||||||
|
Else return None
|
||||||
|
"""
|
||||||
|
idx = 0 if lora_type == LoRAType.LORA_A else 1
|
||||||
|
for weight_name_pair in lora_weight_names:
|
||||||
|
if weight_name_pair[idx] in target_name:
|
||||||
|
return weight_name_pair[idx]
|
||||||
@@ -22,7 +22,11 @@ from sglang.test.test_utils import calculate_rouge_l
|
|||||||
|
|
||||||
LORA_SETS = [
|
LORA_SETS = [
|
||||||
{"base": "meta-llama/Llama-2-7b-hf", "loras": ["winddude/wizardLM-LlaMA-LoRA-7B"]},
|
{"base": "meta-llama/Llama-2-7b-hf", "loras": ["winddude/wizardLM-LlaMA-LoRA-7B"]},
|
||||||
# {"base": "meta-llama/Llama-2-7b-hf", "loras": ["RuterNorway/Llama-2-7b-chat-norwegian-LoRa"]}
|
{
|
||||||
|
"base": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"loras": ["reissbaker/llama-3.1-8b-abliterated-lora"],
|
||||||
|
"decode_tolerance": 8e-2,
|
||||||
|
},
|
||||||
]
|
]
|
||||||
TORCH_DTYPES = [torch.float16]
|
TORCH_DTYPES = [torch.float16]
|
||||||
|
|
||||||
@@ -128,7 +132,8 @@ class TestLoRABackend(unittest.TestCase):
|
|||||||
torch.max(abs(hf_logprobs - hf_no_lora_logprobs)),
|
torch.max(abs(hf_logprobs - hf_no_lora_logprobs)),
|
||||||
)
|
)
|
||||||
if hf_logprobs.shape[0] <= 100:
|
if hf_logprobs.shape[0] <= 100:
|
||||||
assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), (
|
tol = lora_set.get("prefill_tolerance", prefill_tolerance)
|
||||||
|
assert torch.all(abs(hf_logprobs - srt_logprobs) < tol), (
|
||||||
f"prefill logprobs are not all close with model_path={base_path},"
|
f"prefill logprobs are not all close with model_path={base_path},"
|
||||||
f"lora_path={batch_lora_paths[i]}, backend={backend}, prompt={prompts[i]}"
|
f"lora_path={batch_lora_paths[i]}, backend={backend}, prompt={prompts[i]}"
|
||||||
f"prefill_tolerance={prefill_tolerance}."
|
f"prefill_tolerance={prefill_tolerance}."
|
||||||
@@ -144,7 +149,8 @@ class TestLoRABackend(unittest.TestCase):
|
|||||||
"\n",
|
"\n",
|
||||||
)
|
)
|
||||||
if hf_logprobs.shape[0] <= 100:
|
if hf_logprobs.shape[0] <= 100:
|
||||||
assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), (
|
tol = lora_set.get("decode_tolerance", decode_tolerance)
|
||||||
|
assert torch.all(abs(hf_logprobs - srt_logprobs) < tol), (
|
||||||
f"decode logprobs are not all close with model_path={base_path},"
|
f"decode logprobs are not all close with model_path={base_path},"
|
||||||
f"lora_path={batch_lora_paths[i]}, backend={backend}, prompt={prompts[i]}"
|
f"lora_path={batch_lora_paths[i]}, backend={backend}, prompt={prompts[i]}"
|
||||||
f"decode_tolerance={decode_tolerance}."
|
f"decode_tolerance={decode_tolerance}."
|
||||||
@@ -153,7 +159,7 @@ class TestLoRABackend(unittest.TestCase):
|
|||||||
|
|
||||||
# compare output strings
|
# compare output strings
|
||||||
srt_output_str = srt_outputs.output_strs[i].strip(" ")
|
srt_output_str = srt_outputs.output_strs[i].strip(" ")
|
||||||
hf_output_str = hf_outputs.output_strs[i]
|
hf_output_str = hf_outputs.output_strs[i].strip(" ")
|
||||||
print(f"srt_output_str={srt_output_str}")
|
print(f"srt_output_str={srt_output_str}")
|
||||||
print(f"hf_output_str={hf_output_str}")
|
print(f"hf_output_str={hf_output_str}")
|
||||||
rouge_l_scores = calculate_rouge_l([srt_output_str], [hf_output_str])
|
rouge_l_scores = calculate_rouge_l([srt_output_str], [hf_output_str])
|
||||||
|
|||||||
Reference in New Issue
Block a user