diff --git a/python/sglang/srt/lora/backend/__init__.py b/python/sglang/srt/lora/backend/__init__.py index ed377b4b4..07fe11d23 100644 --- a/python/sglang/srt/lora/backend/__init__.py +++ b/python/sglang/srt/lora/backend/__init__.py @@ -1,8 +1,28 @@ -from .base_backend import BaseLoraBackend -from .flashinfer_backend import FlashInferLoraBackend -from .triton_backend import TritonLoraBackend +from .base_backend import BaseLoRABackend +from .flashinfer_backend import FlashInferLoRABackend +from .triton_backend import TritonLoRABackend + + +def get_backend_from_name(name: str) -> BaseLoRABackend: + """ + Get corresponding backend class from backend's name + """ + backend_mapping = { + "triton": TritonLoRABackend, + "flashinfer": FlashInferLoRABackend, + } + + if name in backend_mapping: + return backend_mapping[name] + + raise Exception( + f"No supported lora backend called {name}. It should be one of {list(backend_mapping.keys())}" + ) + __all__ = [ - "FlashInferLoraBackend", - "TritonLoraBackend", + "BaseLoRABackend", + "FlashInferLoRABackend", + "TritonLoRABackend", + "get_backend_from_name", ] diff --git a/python/sglang/srt/lora/backend/base_backend.py b/python/sglang/srt/lora/backend/base_backend.py index d6c72a14e..e09f3dfd9 100644 --- a/python/sglang/srt/lora/backend/base_backend.py +++ b/python/sglang/srt/lora/backend/base_backend.py @@ -2,7 +2,7 @@ from typing import Tuple, Union import torch -from sglang.srt.lora.lora import LoraBatchInfo +from sglang.srt.lora.utils import LoRABatchInfo def get_fuse_output_scaling_add_from_name(name: str) -> bool: @@ -13,7 +13,7 @@ def get_fuse_output_scaling_add_from_name(name: str) -> bool: return mapping.get(name, False) -def get_fuse_qkv_lora_b_from_name(name: str) -> bool: +def get_fuse_stacked_lora_b_from_name(name: str) -> bool: mapping = { "triton": True, "flashinfer": False, @@ -21,7 +21,7 @@ def get_fuse_qkv_lora_b_from_name(name: str) -> bool: return mapping.get(name, False) -class BaseLoraBackend: +class BaseLoRABackend: """Base class for different Lora backends. Each backend has its own implementation of Lora kernels. @@ -32,11 +32,11 @@ class BaseLoraBackend: and the operation of scaling and adding will be fused into kernel """ - def __init__(self, name: str, batch_info: LoraBatchInfo = None): + def __init__(self, name: str, batch_info: LoRABatchInfo = None): self.name = name self.batch_info = batch_info self.fuse_output_scaling_add = get_fuse_output_scaling_add_from_name(name) - self.fuse_qkv_lora_b = get_fuse_qkv_lora_b_from_name(name) + self.fuse_stacked_lora_b = get_fuse_stacked_lora_b_from_name(name) def run_lora_a_sgemm( self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs @@ -46,10 +46,11 @@ class BaseLoraBackend: Args: x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths - weights: a set of lora weights with shape (num_lora, r, input_dim), here r is lora rank + weights: a set of lora weights with shape (num_lora, c * r, input_dim), + here r is lora rank, c is a multiplier for stacked modules (e.g., c=3 for qkv_proj, c=2 for gate_up_proj) usually input_dim is much larger than r Returns: - result with shape (s, r) + result with shape (s, c * r) """ pass @@ -83,7 +84,7 @@ class BaseLoraBackend: qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim) qkv_lora_b: lora_b module for qkv. 
                If passed in as a tensor, its shape should be (num_lora,output_dim_q + 2 * output_dim_kv, r)
-               If passed in as a tuple of two tensors containing:
+               If passed in as a tuple of two tensors, it should contain:
                a lora_b module for q, with shape (1, num_lora, output_dim_q, r)
                and a combined lora_b module for kv, with shape (2, num_lora, output_dim_kv, r)
        Returns:
@@ -91,5 +92,26 @@
         """
         pass
 
-    def set_batch_info(self, batch_info: LoraBatchInfo):
+    def run_gate_up_lora(
+        self,
+        x: torch.Tensor,
+        gate_up_lora_a: torch.Tensor,
+        gate_up_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
+        *args,
+        **kwargs
+    ) -> torch.Tensor:
+        """Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
+
+        Args:
+            x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
+            gate_up_lora_a: lora_a module for gate_up_proj, with shape (num_lora, 2 * r, input_dim)
+            gate_up_lora_b: lora_b module for gate_up_proj.
+                If passed in as a tensor, its shape should be (num_lora, 2 * output_dim, r)
+                If passed in as a tuple, it should contain two tensors with shape (num_lora, output_dim, r)
+        Returns:
+            result with shape (s, 2 * output_dim)
+        """
+        pass
+
+    def set_batch_info(self, batch_info: LoRABatchInfo):
         self.batch_info = batch_info
diff --git a/python/sglang/srt/lora/backend/flashinfer_backend.py b/python/sglang/srt/lora/backend/flashinfer_backend.py
index 91c15be3c..9f7218312 100644
--- a/python/sglang/srt/lora/backend/flashinfer_backend.py
+++ b/python/sglang/srt/lora/backend/flashinfer_backend.py
@@ -2,17 +2,17 @@ from typing import Tuple
 
 import torch
 
-from sglang.srt.lora.backend import BaseLoraBackend
-from sglang.srt.lora.lora import LoraBatchInfo
+from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.utils import LoRABatchInfo
 from sglang.srt.utils import is_flashinfer_available
 
 if is_flashinfer_available():
     from flashinfer import SegmentGEMMWrapper
 
 
-class FlashInferLoraBackend(BaseLoraBackend):
+class FlashInferLoRABackend(BaseLoRABackend):
 
-    def __init__(self, name: str, batch_info: LoraBatchInfo = None):
+    def __init__(self, name: str, batch_info: LoRABatchInfo = None):
         super().__init__(name, batch_info)
 
         # Set up SGemm Wrapper from flashinfer
@@ -55,6 +55,8 @@ class FlashInferLoraBackend(BaseLoraBackend):
         **kwargs,
     ) -> torch.Tensor:
 
+        assert isinstance(qkv_lora_b, tuple) and len(qkv_lora_b) == 2
+
         # Shape of lora_a_output: (s, 3 * r)
         lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a)
 
@@ -89,3 +91,38 @@ class FlashInferLoraBackend(BaseLoraBackend):
         )
 
         return lora_output
+
+    def run_gate_up_lora(
+        self,
+        x: torch.Tensor,
+        gate_up_lora_a: torch.Tensor,
+        gate_up_lora_b: Tuple[torch.Tensor],
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+
+        assert isinstance(gate_up_lora_b, tuple) and len(gate_up_lora_b) == 2
+        lora_rank = gate_up_lora_b[0].shape[-1]
+        output_dim = gate_up_lora_b[0].shape[-2]
+
+        # Shape of lora_a_output: (s, 2 * r)
+        lora_a_output = self.run_lora_a_sgemm(x=x, weights=gate_up_lora_a)
+
+        lora_output = torch.empty(
+            (x.shape[0], 2 * output_dim),
+            device=x.device,
+            dtype=x.dtype,
+        )
+
+        # Compute lora for gate and up proj respectively
+        lora_output[:, :output_dim] = self.run_lora_b_sgemm(
+            x=lora_a_output[:, :lora_rank].contiguous(),
+            weights=gate_up_lora_b[0],
+        )
+
+        lora_output[:, output_dim:] = self.run_lora_b_sgemm(
+            x=lora_a_output[:, lora_rank:].contiguous(),
+            weights=gate_up_lora_b[1],
+        )
+
+        return lora_output
diff --git a/python/sglang/srt/lora/backend/triton_backend.py
b/python/sglang/srt/lora/backend/triton_backend.py index 357040bf9..1ae9dcb2d 100644 --- a/python/sglang/srt/lora/backend/triton_backend.py +++ b/python/sglang/srt/lora/backend/triton_backend.py @@ -1,17 +1,18 @@ import torch -from sglang.srt.lora.backend import BaseLoraBackend -from sglang.srt.lora.lora import LoraBatchInfo +from sglang.srt.lora.backend import BaseLoRABackend from sglang.srt.lora.triton_ops import ( + gate_up_lora_b_fwd, qkv_lora_b_fwd, sgemm_lora_a_fwd, sgemm_lora_b_fwd, ) +from sglang.srt.lora.utils import LoRABatchInfo -class TritonLoraBackend(BaseLoraBackend): +class TritonLoRABackend(BaseLoRABackend): - def __init__(self, name: str, batch_info: LoraBatchInfo = None): + def __init__(self, name: str, batch_info: LoRABatchInfo = None): super().__init__(name, batch_info) def run_lora_a_sgemm( @@ -59,3 +60,32 @@ class TritonLoraBackend(BaseLoraBackend): scaling, ) return lora_output + + def run_gate_up_lora( + self, + x: torch.Tensor, + gate_up_lora_a: torch.Tensor, + gate_up_lora_b: torch.Tensor, + base_output: torch.Tensor = None, + scaling: float = 1.0, + *args, + **kwargs + ) -> torch.Tensor: + + # x: (s, input_dim) + # gate_up_lora_a: (num_lora, 2 * r, input_dim) + # gate_up_lora_b: (num_lora, 2 * output_dim, r) + assert isinstance(gate_up_lora_b, torch.Tensor) + output_dim = gate_up_lora_b.shape[-2] // 2 + + # lora_a_output: (s, 2 * r) + lora_a_output = sgemm_lora_a_fwd(x, gate_up_lora_a, self.batch_info) + lora_output = gate_up_lora_b_fwd( + lora_a_output, + gate_up_lora_b, + self.batch_info, + output_dim, + base_output, + scaling, + ) + return lora_output diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py new file mode 100644 index 000000000..558fffa90 --- /dev/null +++ b/python/sglang/srt/lora/layers.py @@ -0,0 +1,293 @@ +import torch +from torch import nn + +from sglang.srt.distributed import ( + get_tensor_model_parallel_rank, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, +) +from sglang.srt.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding +from sglang.srt.lora.backend import BaseLoRABackend + + +class BaseLayerWithLoRA(nn.Module): + def __init__( + self, + base_layer: nn.Module, + lora_rank: int, + scaling: float, + lora_backend: BaseLoRABackend, + ): + super().__init__() + self.base_layer: nn.Module = base_layer + self.lora_rank: int = lora_rank + self.scaling: float = scaling + self.set_lora: bool = False + self.lora_backend: BaseLoRABackend = lora_backend + + def forward(self, x: torch.Tensor): + return self.base_layer.forward(x) + + def set_lora_info(self, *args): + pass + + +class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): + def __init__( + self, + base_layer: VocabParallelEmbedding, + lora_rank: int, + scaling: float, + lora_backend: BaseLoRABackend, + ) -> None: + super().__init__(base_layer, lora_rank, scaling, lora_backend) + self.weight = base_layer.weight + + +class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): + def __init__( + self, + base_layer: ColumnParallelLinear, + lora_rank: int, + scaling: float, + lora_backend: BaseLoRABackend, + ) -> None: + super().__init__(base_layer, lora_rank, scaling, lora_backend) + + def set_lora_info( + self, + A_buffer: torch.Tensor, + B_buffer: torch.Tensor, + ): + self.set_lora = True + self.A_buffer = A_buffer + self.B_buffer = B_buffer + + def apply_lora(self, 
base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + backend_kwargs = {"base_output": base_output, "scaling": self.scaling} + lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer) + lora_output = self.lora_backend.run_lora_b_sgemm( + lora_a_output, + self.B_buffer[0], + **backend_kwargs, + ) + return ( + lora_output + if self.lora_backend.fuse_output_scaling_add + else base_output + lora_output * self.scaling + ) + + def forward(self, input_: torch.Tensor): + # duplicate the logic in ColumnParallelLinear + bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None + output_parallel = self.base_layer.quant_method.apply( + self.base_layer, input_, bias + ) + + if self.set_lora: + output_parallel = self.apply_lora(output_parallel, input_) + + if self.base_layer.gather_output: + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None + return output, output_bias + + +class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): + def __init__( + self, + base_layer: MergedColumnParallelLinear, + lora_rank: int, + scaling: float, + lora_backend: BaseLoRABackend, + ) -> None: + super().__init__(base_layer, lora_rank, scaling, lora_backend) + + def set_lora_info( + self, + A_buffer: torch.Tensor, + B_buffer: torch.Tensor, + ): + self.set_lora = True + self.A_buffer_gate_up = A_buffer + if self.lora_backend.fuse_stacked_lora_b: + # B_buffer_gate_up: (num_lora, 2 * output_dim, r) + self.B_buffer_gate_up = torch.cat( + (B_buffer[0], B_buffer[1]), dim=-2 + ).contiguous() + else: + self.B_buffer_gate_up = (B_buffer[0], B_buffer[1]) + + def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + backend_kwargs = {"base_output": base_output, "scaling": self.scaling} + + lora_output = self.lora_backend.run_gate_up_lora( + x, + self.A_buffer_gate_up, + self.B_buffer_gate_up, + **backend_kwargs, + ) + return ( + lora_output + if self.lora_backend.fuse_output_scaling_add + else base_output + lora_output * self.scaling + ) + + +class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): + def init__( + self, + base_layer: QKVParallelLinear, + lora_rank: int, + scaling: float, + lora_backend: BaseLoRABackend, + ) -> None: + super().__init__(base_layer, lora_rank, scaling, lora_backend) + + def set_lora_info( + self, + A_buffer_qkv: torch.Tensor, + B_buffer_q: torch.Tensor, + B_buffer_kv: torch.Tensor, + ): + self.set_lora = True + self.A_buffer_qkv = A_buffer_qkv + + if self.lora_backend.fuse_stacked_lora_b: + assert ( + B_buffer_q.shape[-1] == B_buffer_kv.shape[-1] + ), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b" + output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2] + + # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r) + self.B_buffer_qkv = torch.cat( + (B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2 + ).contiguous() + + # Offsets of q/k/v in output dimension + self.output_offset = torch.tensor( + [ + 0, + output_dim_q, + output_dim_q + output_dim_kv, + output_dim_q + 2 * output_dim_kv, + ], + dtype=torch.int32, + device=B_buffer_q.device, + ) + # For computing number of launched blocks + self.max_qkv_out_dim = max(output_dim_q, output_dim_kv) + else: + self.B_buffer_qkv = ( + B_buffer_q, + B_buffer_kv, + ) + + def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + backend_kwargs = {"base_output": 
base_output, "scaling": self.scaling} + if self.lora_backend.fuse_stacked_lora_b: + backend_kwargs["output_offset"] = self.output_offset + backend_kwargs["max_qkv_out_dim"] = self.max_qkv_out_dim + + lora_output = self.lora_backend.run_qkv_lora( + x, + self.A_buffer_qkv, + self.B_buffer_qkv, + **backend_kwargs, + ) + return ( + lora_output + if self.lora_backend.fuse_output_scaling_add + else base_output + lora_output * self.scaling + ) + + +class RowParallelLinearWithLoRA(BaseLayerWithLoRA): + def __init__( + self, + base_layer: RowParallelLinear, + lora_rank: int, + scaling: float, + lora_backend: BaseLoRABackend, + ) -> None: + super().__init__(base_layer, lora_rank, scaling, lora_backend) + + def set_lora_info(self, A_buffer: torch.Tensor, B_buffer: torch.Tensor): + self.set_lora = True + self.A_buffer = A_buffer + self.B_buffer = B_buffer + + def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + backend_kwargs = {"base_output": base_output, "scaling": self.scaling} + lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer) + lora_output = self.lora_backend.run_lora_b_sgemm( + lora_a_output, + self.B_buffer[0], + **backend_kwargs, + ) + return ( + lora_output + if self.lora_backend.fuse_output_scaling_add + else base_output + lora_output * self.scaling + ) + + def forward(self, input_: torch.Tensor): + # duplicate the logic in RowParallelLinear + if self.base_layer.input_is_parallel: + input_parallel = input_ + else: + tp_rank = get_tensor_model_parallel_rank() + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.base_layer.tp_size + ) + input_parallel = splitted_input[tp_rank].contiguous() + output_parallel = self.base_layer.quant_method.apply( + self.base_layer, input_parallel + ) + + if self.set_lora: + output_parallel = self.apply_lora(output_parallel, input_parallel) + + if self.base_layer.reduce_results and self.base_layer.tp_size > 1: + output_ = tensor_model_parallel_all_reduce(output_parallel) + else: + output_ = output_parallel + + if not self.base_layer.skip_bias_add: + output = ( + output_ + self.base_layer.bias + if self.base_layer.bias is not None + else output_ + ) + output_bias = None + else: + output = output_ + output_bias = self.base_layer.bias + return output, output_bias + + +def get_lora_layer( + layer: nn.Module, lora_rank: int, scaling: int, lora_backend: BaseLoRABackend +) -> BaseLayerWithLoRA: + supported_layer_types = { + # the order matters + VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA, + QKVParallelLinear: QKVParallelLinearWithLoRA, + MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA, + ColumnParallelLinear: ColumnParallelLinearWithLoRA, + RowParallelLinear: RowParallelLinearWithLoRA, + } + for src_layer_type, lora_layer_type in supported_layer_types.items(): + if isinstance(layer, src_layer_type): # pylint: disable=unidiomatic-typecheck + ret = lora_layer_type(layer, lora_rank, scaling, lora_backend) + return ret + raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.") diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py index 9de3b9236..643946f43 100644 --- a/python/sglang/srt/lora/lora.py +++ b/python/sglang/srt/lora/lora.py @@ -19,282 +19,25 @@ # https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py import re -from dataclasses import dataclass +from typing import Dict, List import torch from torch import nn -from sglang.srt.layers.linear import ( - ColumnParallelLinear, 
- MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear, -) -from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding +from sglang.srt.configs.load_config import LoadConfig +from sglang.srt.hf_transformers_utils import AutoConfig +from sglang.srt.lora.backend import BaseLoRABackend +from sglang.srt.lora.lora_config import LoRAConfig from sglang.srt.model_loader.loader import DefaultModelLoader -@dataclass -class LoraBatchInfo: - # Batch size - bs: int - - # Lengths of each sequence in shape (bs,) - seg_lens: torch.Tensor - - # Indice pointers of each sequence in shape (bs + 1, ) - seg_indptr: torch.Tensor - - # Maximum sequence length of current batch - max_len: int - - # The index of lora adapter used by each sequence, in shape (bs,) - weight_indices: torch.Tensor - - -class BaseLayerWithLoRA(nn.Module): - def __init__(self, base_layer, lora_rank, scaling, lora_backend): - super().__init__() - self.base_layer = base_layer - self.lora_rank = lora_rank - self.scaling = scaling - self.set_lora = False - self.lora_backend = lora_backend - - def forward(self, x: torch.Tensor): - return self.base_layer.forward(x) - - def set_lora_info(self, *args): - pass - - -class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): - def __init__( - self, base_layer: VocabParallelEmbedding, lora_rank, scaling, lora_backend - ) -> None: - super().__init__(base_layer, lora_rank, scaling, lora_backend) - self.weight = base_layer.weight - - -class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): - def __init__( - self, base_layer: ColumnParallelLinear, lora_rank, scaling, lora_backend - ) -> None: - super().__init__(base_layer, lora_rank, scaling, lora_backend) - - def apply_lora(self, output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - # TODO - return output - - def forward(self, input_: torch.Tensor): - # duplicate the logic in ColumnParallelLinear - bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None - output_parallel = self.base_layer.quant_method.apply( - self.base_layer, input_, bias - ) - - if self.set_lora: - output_parallel = self.apply_lora(output_parallel, input_) - - if self.base_layer.gather_output: - output = tensor_model_parallel_all_gather(output_parallel) - else: - output = output_parallel - output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None - return output, output_bias - - -class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): - def __init__( - self, base_layer: MergedColumnParallelLinear, lora_rank, scaling, lora_backend - ) -> None: - super().__init__(base_layer, lora_rank, scaling, lora_backend) - - def set_lora_info( - self, - A_buffer, - B_buffer, - ): - self.set_lora = True - self.A_buffer = A_buffer - self.B_buffer = B_buffer - - def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - lora_a_output = self.lora_backend.run_lora_a_sgemm(x=x, weights=self.A_buffer) - - output_dim = base_output.shape[-1] - lora_output = torch.empty_like(base_output) - lora_output[:, :output_dim] = self.lora_backend.run_lora_b_sgemm( - x=lora_a_output[:, 0 : self.lora_rank].contiguous(), - weights=self.B_buffer[0], - ) - - lora_output[:, output_dim : 2 * output_dim] = ( - self.lora_backend.run_lora_b_sgemm( - x=lora_a_output[:, self.lora_rank : 2 * self.lora_rank].contiguous(), - weights=self.B_buffer[1], - ) - ) - - return base_output + lora_output * self.scaling - - -class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): - def init__( - self, base_layer: 
QKVParallelLinear, lora_rank, scaling, lora_backend - ) -> None: - super().__init__(base_layer, lora_rank, scaling, lora_backend) - - def set_lora_info( - self, - A_buffer_qkv, - B_buffer_q, - B_buffer_kv, - ): - self.set_lora = True - self.A_buffer_qkv = A_buffer_qkv - - if self.lora_backend.fuse_qkv_lora_b: - assert ( - B_buffer_q.shape[-1] == B_buffer_kv.shape[-1] - ), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b" - output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2] - - # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r) - self.B_buffer_qkv = torch.cat( - (B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2 - ).contiguous() - - # Offsets of q/k/v in output dimension - self.output_offset = torch.tensor( - [ - 0, - output_dim_q, - output_dim_q + output_dim_kv, - output_dim_q + 2 * output_dim_kv, - ], - dtype=torch.int32, - device=B_buffer_q.device, - ) - # For computing number of launched blocks - self.max_qkv_out_dim = max(output_dim_q, output_dim_kv) - else: - self.B_buffer_qkv = ( - B_buffer_q, - B_buffer_kv, - ) - self.output_offset = None - self.max_qkv_out_dim = None - - def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - lora_output = self.lora_backend.run_qkv_lora( - x, - self.A_buffer_qkv, - self.B_buffer_qkv, - output_offset=self.output_offset, - max_qkv_out_dim=self.max_qkv_out_dim, - base_output=base_output, - scaling=self.scaling, - ) - return ( - lora_output - if self.lora_backend.fuse_output_scaling_add - else base_output + lora_output * self.scaling - ) - - -class RowParallelLinearWithLoRA(BaseLayerWithLoRA): - def __init__( - self, base_layer: RowParallelLinear, lora_rank, scaling, lora_backend - ) -> None: - super().__init__(base_layer, lora_rank, scaling, lora_backend) - - def set_lora_info(self, A_buffer, B_buffer): - self.set_lora = True - self.A_buffer = A_buffer - self.B_buffer = B_buffer - - def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer) - lora_output = self.lora_backend.run_lora_b_sgemm( - lora_a_output, - self.B_buffer[0], - base_output=base_output, - scaling=self.scaling, - ) - return ( - lora_output - if self.lora_backend.fuse_output_scaling_add - else base_output + lora_output * self.scaling - ) - - def forward(self, input_): - # duplicate the logic in RowParallelLinear - if self.base_layer.input_is_parallel: - input_parallel = input_ - else: - tp_rank = get_tensor_model_parallel_rank() - splitted_input = split_tensor_along_last_dim( - input_, num_partitions=self.base_layer.tp_size - ) - input_parallel = splitted_input[tp_rank].contiguous() - output_parallel = self.base_layer.quant_method.apply( - self.base_layer, input_parallel - ) - - if self.set_lora: - output_parallel = self.apply_lora(output_parallel, input_parallel) - - if self.base_layer.reduce_results and self.base_layer.tp_size > 1: - output_ = tensor_model_parallel_all_reduce(output_parallel) - else: - output_ = output_parallel - - if not self.base_layer.skip_bias_add: - output = ( - output_ + self.base_layer.bias - if self.base_layer.bias is not None - else output_ - ) - output_bias = None - else: - output = output_ - output_bias = self.base_layer.bias - return output, output_bias - - -def get_lora_layer( - layer: nn.Module, lora_rank, scaling, lora_backend -) -> BaseLayerWithLoRA: - supported_layer_types = { - # the order matters - VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA, - 
QKVParallelLinear: QKVParallelLinearWithLoRA, - MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA, - ColumnParallelLinear: ColumnParallelLinearWithLoRA, - RowParallelLinear: RowParallelLinearWithLoRA, - } - for src_layer_type, lora_layer_type in supported_layer_types.items(): - if isinstance(layer, src_layer_type): # pylint: disable=unidiomatic-typecheck - ret = lora_layer_type(layer, lora_rank, scaling, lora_backend) - return ret - raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.") - - -def get_mapped_params(module_names): - ret = set() - for module_name in module_names: - ret.add(params_mapping(module_name)) - return list(ret) - - class LoRALayer(nn.Module): - def __init__(self, config, base_hf_config): + def __init__(self, config: LoRAConfig, base_hf_config: AutoConfig): super().__init__() - self.config = config - self.base_hf_config = base_hf_config - self.weights = {} - self.weight_gpu = {} + self.config: LoRAConfig = config + self.base_hf_config: AutoConfig = base_hf_config + self.weights: Dict[str, torch.Tensor] = {} + self.weight_gpu: Dict[str, torch.Tensor] = {} def load_to_gpu(self): for name, weight in self.weights.items(): @@ -306,33 +49,32 @@ class LoRALayer(nn.Module): class LoRAAdapter(nn.Module): - def __init__(self, uid, config, base_hf_config, load_config, lora_backend): + def __init__( + self, + uid: str, + config: LoRAConfig, + base_hf_config: AutoConfig, + load_config: LoadConfig, + lora_backend: BaseLoRABackend, + ): super().__init__() - self.uid = uid - self.config = config + self.uid: str = uid + self.config: LoRAConfig = config assert self.config.hf_config["peft_type"].lower() == "lora" - self.base_hf_config = base_hf_config - self.load_config = load_config - self.lora_backend = lora_backend - self.scaling = self.config.lora_alpha / self.config.r + self.base_hf_config: AutoConfig = base_hf_config + self.load_config: LoadConfig = load_config + self.lora_backend: BaseLoRABackend = lora_backend + self.scaling: float = self.config.lora_alpha / self.config.r - self.layers = nn.ModuleList( + self.layers: List[LoRALayer] = nn.ModuleList( [ LoRALayer(config, base_hf_config) for i in range(base_hf_config.num_hidden_layers) ] ) - self.weights = {} - self.weights_gpu = {} - - def get_stacked_multiply(self, module_name): - stacked_rank = { - "qkv_proj": 3, - "kv_proj": 2, - "gate_up_proj": 2, - } - return stacked_rank[module_name] if module_name in stacked_rank else 1 + self.weights: Dict[str, torch.Tensor] = {} + self.weights_gpu: Dict[str, torch.Tensor] = {} def load_to_gpu(self): for name, weight in self.weights.items(): @@ -367,44 +109,77 @@ class LoRAAdapter(nn.Module): for i in range(self.base_hf_config.num_hidden_layers): layer = self.layers[i] weight_names = [name for name, _ in layer.weights.items()] - for weight_name in weight_names: - if "k_proj" in weight_name: - q_name = weight_name.replace("k_proj", "q_proj") - v_name = weight_name.replace("k_proj", "v_proj") - kv_name = weight_name.replace("k_proj", "kv_proj") - qkv_name = weight_name.replace("k_proj", "qkv_proj") - if "lora_A" in weight_name: - layer.weights[qkv_name] = torch.cat( - ( - layer.weights[q_name], - layer.weights[weight_name], - layer.weights[v_name], - ), - 0, - ) - layer.weights.pop(q_name) - layer.weights.pop(weight_name) - layer.weights.pop(v_name) - else: - layer.weights[kv_name] = torch.stack( - [ - layer.weights[weight_name], - layer.weights[v_name], - ], - dim=0, - ) - layer.weights.pop(weight_name) - layer.weights.pop(v_name) - elif "gate_proj" in 
weight_name: - up_name = weight_name.replace("gate_proj", "up_proj") - gate_up_name = weight_name.replace("gate_proj", "gate_up_proj") - if "lora_A" in weight_name: - layer.weights[gate_up_name] = torch.cat( - (layer.weights[weight_name], layer.weights[up_name]), 0 - ) - else: - layer.weights[gate_up_name] = torch.stack( - [layer.weights[weight_name], layer.weights[up_name]], dim=0 - ) - layer.weights.pop(weight_name) - layer.weights.pop(up_name) + self.stack_qkv_proj(weight_names, layer.weights) + self.stack_gate_up_proj(weight_names, layer.weights) + + def stack_qkv_proj(self, weight_names: List[str], weights: Dict[str, torch.Tensor]): + + # Collect target q/k/v modules. This process is necessary since there might be no lora attached to k_proj + target_module = set() + for weight_name in weight_names: + if "k_proj" in weight_name: + target_module.add("k_proj") + if "q_proj" in weight_name: + target_module.add("q_proj") + if "v_proj" in weight_name: + target_module.add("v_proj") + if len(target_module) == 0: + return + + for weight_name in weight_names: + # We assume every lora adaptor should contain lora modules for q_proj + if "q_proj" in weight_name: + q_name = weight_name + k_name = weight_name.replace("q_proj", "k_proj") + v_name = weight_name.replace("q_proj", "v_proj") + kv_name = weight_name.replace("q_proj", "kv_proj") + qkv_name = weight_name.replace("q_proj", "qkv_proj") + + # If k_proj doesn't have lora, initialize it to zero + k_proj_weight = ( + weights[k_name] + if "k_proj" in target_module + else torch.zeros_like(weights[v_name]) + ) + if "lora_A" in weight_name: + weights[qkv_name] = torch.cat( + ( + weights[q_name], + k_proj_weight, + weights[v_name], + ), + 0, + ) + weights.pop(q_name) + if "k_proj" in target_module: + weights.pop(k_name) + weights.pop(v_name) + else: + weights[kv_name] = torch.stack( + [ + k_proj_weight, + weights[v_name], + ], + dim=0, + ) + if "k_proj" in target_module: + weights.pop(k_name) + weights.pop(v_name) + + def stack_gate_up_proj( + self, weight_names: List[str], weights: Dict[str, torch.Tensor] + ): + for weight_name in weight_names: + if "gate_proj" in weight_name: + up_name = weight_name.replace("gate_proj", "up_proj") + gate_up_name = weight_name.replace("gate_proj", "gate_up_proj") + if "lora_A" in weight_name: + weights[gate_up_name] = torch.cat( + (weights[weight_name], weights[up_name]), 0 + ) + else: + weights[gate_up_name] = torch.stack( + [weights[weight_name], weights[up_name]], dim=0 + ) + weights.pop(weight_name) + weights.pop(up_name) diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 404f3f507..be9df347e 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -16,307 +16,115 @@ # and "Punica: Multi-Tenant LoRA Serving" import logging -import re +from typing import Dict, List, Set, Tuple import torch -from sglang.srt.lora.backend import FlashInferLoraBackend, TritonLoraBackend -from sglang.srt.lora.lora import LoRAAdapter, LoraBatchInfo, get_lora_layer +from sglang.srt.configs.load_config import LoadConfig +from sglang.srt.hf_transformers_utils import AutoConfig +from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name +from sglang.srt.lora.layers import get_lora_layer +from sglang.srt.lora.lora import LoRAAdapter from sglang.srt.lora.lora_config import LoRAConfig +from sglang.srt.lora.mem_pool import LoRAMemoryPool +from sglang.srt.lora.utils import ( + LoRABatchInfo, + LoRAType, + get_customized_names_from_hf_names, 
+ get_layer_id, + get_stacked_name, + get_weight_name, +) from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.utils import is_flashinfer_available, replace_submodule +from sglang.srt.utils import replace_submodule logger = logging.getLogger(__name__) -def get_module_name(name): - # Fallback solution of mapping from config module name to module name in model class. - # Please check if it aligns with your base model. - # Please implement the function in the model class if it is not. - # You can reference this function in llama.py. - params_mapping = { - "q_proj": "qkv_proj", - "k_proj": "qkv_proj", - "v_proj": "qkv_proj", - "gate_proj": "gate_up_proj", - "up_proj": "gate_up_proj", - } - return params_mapping.get(name, name) - - -def get_hidden_dim(module_name, config): - # Fallback solution of get_hidden_dim for different modules - # Please check if it aligns with your base model. - # Please implement the function in the model class if it is not. - # You can reference this function in llama.py. - if module_name in ["q_proj", "o_proj", "qkv_proj"]: - return config.hidden_size, config.hidden_size - elif module_name in ["kv_proj"]: - return config.hidden_size, config.hidden_size // ( - config.num_attention_heads // config.num_key_value_heads - ) - elif module_name == "gate_up_proj": - return config.hidden_size, config.intermediate_size - elif module_name == "down_proj": - return config.intermediate_size, config.hidden_size - else: - raise NotImplementedError() - - -def get_stacked_name(name): - # origin name -> (name for A, name for B) - params_mapping = { - "q_proj": ("qkv_proj", "q_proj"), - "k_proj": ("qkv_proj", "kv_proj"), - "v_proj": ("qkv_proj", "kv_proj"), - "gate_proj": ("gate_up_proj", "gate_up_proj"), - "up_proj": ("gate_up_proj", "gate_up_proj"), - } - return params_mapping.get(name, (name, name)) - - -def get_backend_from_name(name): - backend_mapping = { - "triton": TritonLoraBackend, - "flashinfer": FlashInferLoraBackend, - } - - if name in backend_mapping: - return backend_mapping[name] - - raise Exception( - f"No supported lora backend called {name}. 
It should be one of {list(backend_mapping.keys())}"
-    )
-
-
-def get_layer_id(name):
-    match = re.search(r"layers\.(\d+)\.", name)
-    if match is None:
-        return None
-    return int(match.group(1))
-
-
 class LoRAManager:
     def __init__(
         self,
-        base_model,
-        lora_paths,
-        base_hf_config,
-        max_loras_per_batch,
-        load_config,
-        dtype,
-        lora_backend,
+        base_model: torch.nn.Module,
+        lora_paths: Dict[str, str],
+        base_hf_config: AutoConfig,
+        max_loras_per_batch: int,
+        load_config: LoadConfig,
+        dtype: torch.dtype,
+        lora_backend: str = "triton",
     ):
-        self.base_model = base_model
-        self.lora_paths = lora_paths
-        self.base_hf_config = base_hf_config
-        self.max_loras_per_batch = max_loras_per_batch
-        self.load_config = load_config
-        self.dtype = dtype
+        self.base_model: torch.nn.Module = base_model
+        self.lora_paths: Dict[str, str] = lora_paths
+        self.base_hf_config: AutoConfig = base_hf_config
+        self.max_loras_per_batch: int = max_loras_per_batch
+        self.load_config: LoadConfig = load_config
+        self.dtype: torch.dtype = dtype
 
-        logger.info(f"Using {lora_backend} as backend of Lora kernels.")
+        # LoRA backend for running sgemm kernels
+        logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
         backend_type = get_backend_from_name(lora_backend)
-        self.lora_backend = backend_type(lora_backend)
+        self.lora_backend: BaseLoRABackend = backend_type(lora_backend)
 
         self.init_loras()
         self.init_lora_memory_pool()
-        self.init_lora_batch()
-
-    def match_target_modules(self, module_name):
-        for target_module in self.target_modules:
-            if module_name.split(".")[-1] == target_module:
-                return True
-        return False
-
-    def get_target_modules(self):
-        modules = []
-        for module_name, module in self.base_model.named_modules():
-            if self.match_target_modules(module_name):
-                modules.append((module_name, module))
-        return modules
-
-    def set_lora_module(self, module_name, module):
-        lora_module = get_lora_layer(
-            module, self.max_lora_dim, self.scaling, self.lora_backend
-        )
-        replace_submodule(self.base_model, module_name, lora_module)
-        return lora_module
 
     def init_loras(self):
-        # get configs and target modules
-        self.configs = {}
-        self.origin_target_modules = set()
+        # Config of each LoRA adapter
+        self.configs: Dict[str, LoRAConfig] = {}
+
+        # Target module names in huggingface lora configs.
+        # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
+        self.hf_target_names: Set[str] = set()
         for name, path in self.lora_paths.items():
             self.configs[name] = LoRAConfig(path)
-            self.origin_target_modules = set(self.origin_target_modules) | set(
+            self.hf_target_names = set(self.hf_target_names) | set(
                 self.configs[name].target_modules
             )
-        if hasattr(self.base_model, "get_module_name"):
-            self.target_modules = {
-                self.base_model.get_module_name(module)
-                for module in self.origin_target_modules
-            }
-        else:
-            logger.warning(
-                "WARNING: get_module_name() is not defined, "
-                "which is used to map config module name to model implementation module name."
-                "Use the default one, but please check if it is correct for your model."
-            )
-            self.target_modules = {
-                get_module_name(module) for module in self.origin_target_modules
-            }
-        self.target_weights = set(
-            [get_stacked_name(module) for module in self.origin_target_modules]
+
+        # Target lora weight names for lora_a and lora_b modules respectively.
+ # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")} + self.lora_weight_names: Set[Tuple[str]] = set( + [get_stacked_name(module) for module in self.hf_target_names] ) # load all weights to cpu - self.loras = [] - self.lora_id = {} + self.loras: Dict[str, LoRAAdapter] = {} for name in self.lora_paths.keys(): - self.lora_id[name] = len(self.loras) - self.loras.append( - LoRAAdapter( - name, - self.configs[name], - self.base_hf_config, - self.load_config, - self.lora_backend, - ) + lora_adapter = LoRAAdapter( + name, + self.configs[name], + self.base_hf_config, + self.load_config, + self.lora_backend, ) - self.loras[-1].initialize_weights() + lora_adapter.initialize_weights() + self.loras[name] = lora_adapter # misc lora configs - self.max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()]) - self.scaling = self.loras[0].scaling - # FIXME remove the restrictions + # FIXME remove the restrictions after implementing unified paging + self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()]) + self.scaling: float = list(self.loras.values())[0].scaling assert all(x.hf_config["r"] == self.max_lora_dim for x in self.configs.values()) - assert all(x.scaling == self.scaling for x in self.loras) + assert all(x.scaling == self.scaling for x in self.loras.values()) - # monkey patch to use the LoRA version - self.lora_modules = [] - for module_name, module in self.get_target_modules(): - self.lora_modules.append( - (module_name, self.set_lora_module(module_name, module)) - ) + # Convert original model layers to layers with LoRA + self.convert_to_lora_layers() def init_lora_memory_pool(self): - # preallocate lora memory pool - self.A_buffer = {} - self.B_buffer = {} - num_layer = self.base_hf_config.num_hidden_layers - for module_A, module_B in self.target_weights: - # init A tensor, column_major=True - if hasattr(self.base_model, "get_hidden_dim"): - hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A) - else: - logger.warning( - "WARNING: get_hidden_dim() is not defined, " - "which is used to get the hidden dim for different lora modules" - "Use the default one, but please check if it is correct for your model." - ) - hidden_dim_A, _ = get_hidden_dim(module_A, self.base_hf_config) - c = self.loras[-1].get_stacked_multiply(module_A) - if module_A not in self.A_buffer: - self.A_buffer[module_A] = [ - torch.empty( - ( - self.max_loras_per_batch, - self.max_lora_dim * c, - hidden_dim_A, - ), - dtype=self.dtype, - device="cuda", - ) - for i in range(num_layer) - ] - # init B tensor, column_major=True - if hasattr(self.base_model, "get_hidden_dim"): - _, hidden_dim_B = self.base_model.get_hidden_dim(module_B) - else: - logger.warning( - "WARNING: get_hidden_dim() is not defined, " - "which is used to get the hidden dim for different lora modules" - "Use the default one, but please check if it is correct for your model." 
- ) - _, hidden_dim_B = get_hidden_dim(module_B, self.base_hf_config) - c = self.loras[-1].get_stacked_multiply(module_B) - if module_B not in self.B_buffer: - self.B_buffer[module_B] = [ - torch.empty( - ( - c, - self.max_loras_per_batch, - hidden_dim_B, - self.max_lora_dim, - ), - dtype=self.dtype, - device="cuda", - ) - for i in range(num_layer) - ] + # Initialize memory pool + self.memory_pool = LoRAMemoryPool( + self.base_hf_config, self.max_loras_per_batch, self.max_lora_dim, self.dtype + ) - def init_lora_batch(self): - self.active_uids = set() # set of active loras - self.buffer_id = {} # lora uid -> idx in memory pool - - def get_weight_name(self, name, idx): - for target_weight_name in self.target_weights: - if target_weight_name[idx] in name: - return target_weight_name[idx] - - def load_lora(self, uid, buffer_id): - num_layer = self.base_hf_config.num_hidden_layers - if uid is None: - for i in range(num_layer): - for k in self.A_buffer.keys(): - self.A_buffer[k][i][buffer_id] *= 0 - return - - for i in range(num_layer): - layer_weights = self.loras[self.lora_id[uid]].layers[i].weights - for name, weights in layer_weights.items(): - if "lora_A" in name: - lora_weight_name = self.get_weight_name(name, 0) - if lora_weight_name: - self.A_buffer[lora_weight_name][i][buffer_id].copy_(weights) - else: - lora_weight_name = self.get_weight_name(name, 1) - if lora_weight_name: - c = self.loras[-1].get_stacked_multiply(lora_weight_name) - if c > 1: - for j in range(c): - self.B_buffer[lora_weight_name][i][j][buffer_id].copy_( - weights[j] - ) - else: - self.B_buffer[lora_weight_name][i][0][buffer_id].copy_( - weights - ) + # Initialize target lora modules in memory pool + self.memory_pool.init_buffers(self.lora_weight_names, self.base_model) def prepare_lora_batch(self, forward_batch: ForwardBatch): # load active loras into lora memory pool cur_uids = set(forward_batch.lora_paths) assert len(cur_uids) <= self.max_loras_per_batch - i = 0 - j = len(self.active_uids) - evictable_uids = list(self.active_uids) - for uid in cur_uids: - if uid not in self.active_uids: - if j < self.max_loras_per_batch: - index = j - j += 1 - else: - while i < len(evictable_uids) and evictable_uids[i] in cur_uids: - i += 1 - assert i < len(evictable_uids) - self.active_uids.remove(evictable_uids[i]) - self.buffer_id.pop(evictable_uids[i]) - index = i - i += 1 - self.load_lora(uid, index) - self.active_uids.add(uid) - self.buffer_id[uid] = index + self.memory_pool.prepare_lora_batch(cur_uids, self.loras) + # FIXME: Handle lora uid with None more safely if cur_uids == set([None]): return @@ -332,9 +140,9 @@ class LoRAManager: max_len = int(torch.max(seg_lens)) weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda") for i, lora_path in enumerate(forward_batch.lora_paths): - weight_indices[i] = self.buffer_id[lora_path] + weight_indices[i] = self.memory_pool.get_buffer_id(lora_path) - batch_info = LoraBatchInfo( + batch_info = LoRABatchInfo( bs=bs, seg_lens=seg_lens, seg_indptr=seg_indptr, @@ -346,16 +154,40 @@ class LoRAManager: # call set_lora_info for each lora modules for module_name, module in self.lora_modules: layer_id = get_layer_id(module_name) - if "qkv_proj" not in module_name: - weight_name = self.get_weight_name(module_name, 0) + weight_name = get_weight_name( + module_name, self.lora_weight_names, LoRAType.LORA_A + ) module.set_lora_info( - self.A_buffer[weight_name][layer_id], - self.B_buffer[weight_name][layer_id], + self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_A), + 
self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_B), ) else: module.set_lora_info( - self.A_buffer["qkv_proj"][layer_id], - self.B_buffer["q_proj"][layer_id], - self.B_buffer["kv_proj"][layer_id], + self.memory_pool.get_tensor("qkv_proj", layer_id, LoRAType.LORA_A), + self.memory_pool.get_tensor("q_proj", layer_id, LoRAType.LORA_B), + self.memory_pool.get_tensor("kv_proj", layer_id, LoRAType.LORA_B), + ) + + def set_lora_module(self, module_name, module): + lora_module = get_lora_layer( + module, self.max_lora_dim, self.scaling, self.lora_backend + ) + replace_submodule(self.base_model, module_name, lora_module) + return lora_module + + def convert_to_lora_layers(self): + # Target module names of customized layers defined in python/sglang/srt/layers + # e.g., {"qkv_proj", "o_proj"} + customized_target_names = get_customized_names_from_hf_names( + self.hf_target_names, self.base_model + ) + + # Monkey patch to use the LoRA version layers + self.lora_modules: List[Tuple[str, torch.nn.Module]] = [] + for module_name, module in self.base_model.named_modules(): + # The module should be converted if it is included in target_names + if module_name.split(".")[-1] in customized_target_names: + self.lora_modules.append( + (module_name, self.set_lora_module(module_name, module)) ) diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py new file mode 100644 index 000000000..a1a6e549d --- /dev/null +++ b/python/sglang/srt/lora/mem_pool.py @@ -0,0 +1,174 @@ +from typing import Dict, List, Optional, Set, Tuple + +import torch + +from sglang.srt.hf_transformers_utils import AutoConfig +from sglang.srt.lora.lora import LoRAAdapter +from sglang.srt.lora.utils import ( + LoRAType, + get_hidden_dim, + get_stacked_multiply, + get_weight_name, +) + + +class LoRAMemoryPool: + """Class for memory pool management of lora modules""" + + def __init__( + self, + base_hf_config: AutoConfig, + max_loras_per_batch: int, + max_lora_dim: int, + dtype: torch.dtype, + ): + + self.base_hf_config: AutoConfig = base_hf_config + self.num_layer: int = base_hf_config.num_hidden_layers + self.max_loras_per_batch: int = max_loras_per_batch + self.max_lora_dim: int = max_lora_dim + self.dtype: torch.dtype = dtype + + # Both A_buffer and B_buffer maps lora weight names to its buffer space. 
+        # A_buffer contains num_layer number of row-major tensors with shape
+        # (max_loras_per_batch, stacked_num * max_lora_dim, input_dim)
+        # B_buffer contains num_layer number of column-major tensors with shape
+        # (stacked_num, max_loras_per_batch, output_dim, max_lora_dim)
+        self.A_buffer: Dict[str, List[torch.Tensor]] = {}
+        self.B_buffer: Dict[str, List[torch.Tensor]] = {}
+
+        # Lora uid -> buffer idx in memory pool
+        self.uid_to_buffer_id: Dict[Optional[str], int] = {}
+
+        # Buffer idx -> lora uid in memory pool
+        # All uids are initialized as empty strings for empty buffer slots
+        # Here we don't initialize to None since None is a valid uid
+        self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
+
+    def init_buffers(
+        self,
+        lora_weight_names: Set[Tuple[str]],
+        base_model: torch.nn.Module,
+    ):
+
+        # lora_weight_names is a set of name pairs indicating each pair of lora modules to load
+        # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
+        self.lora_weight_names: Set[Tuple[str]] = lora_weight_names
+
+        for module_A, module_B in lora_weight_names:
+            # Init A tensor, column_major=False
+            input_dim, _ = get_hidden_dim(module_A, self.base_hf_config, base_model)
+            c = get_stacked_multiply(module_A)
+            if module_A not in self.A_buffer:
+                self.A_buffer[module_A] = [
+                    torch.empty(
+                        (
+                            self.max_loras_per_batch,
+                            self.max_lora_dim * c,
+                            input_dim,
+                        ),
+                        dtype=self.dtype,
+                        device="cuda",
+                    )
+                    for i in range(self.num_layer)
+                ]
+
+            # Init B tensor, column_major=True
+            _, output_dim = get_hidden_dim(module_B, self.base_hf_config, base_model)
+            c = get_stacked_multiply(module_B)
+            if module_B not in self.B_buffer:
+                self.B_buffer[module_B] = [
+                    torch.empty(
+                        (
+                            c,  # stacked lora_b modules might need separation
+                            self.max_loras_per_batch,
+                            output_dim,
+                            self.max_lora_dim,
+                        ),
+                        dtype=self.dtype,
+                        device="cuda",
+                    )
+                    for i in range(self.num_layer)
+                ]
+
+    def prepare_lora_batch(
+        self,
+        cur_uids: Set[Optional[str]],
+        lora_adapters: Dict[str, LoRAAdapter],
+    ):
+
+        def get_available_buffer_slot():
+            for buffer_id in range(self.max_loras_per_batch):
+                # Prioritize empty slots
+                if self.buffer_id_to_uid[buffer_id] == "":
+                    return buffer_id, ""
+
+            for buffer_id in range(self.max_loras_per_batch):
+                # Evict unneeded lora
+                if self.buffer_id_to_uid[buffer_id] not in cur_uids:
+                    return buffer_id, self.buffer_id_to_uid[buffer_id]
+
+            raise ValueError(
+                "No available buffer slots found. Please ensure the number of active loras is less than max_loras_per_batch."
+ ) + + for uid in cur_uids: + if uid not in self.uid_to_buffer_id: + buffer_id, evicted_lora_uid = get_available_buffer_slot() + if evicted_lora_uid != "": + self.uid_to_buffer_id.pop(evicted_lora_uid) + self.load_lora_weight_to_buffer( + uid, buffer_id, lora_adapters.get(uid, None) + ) + self.uid_to_buffer_id[uid] = buffer_id + self.buffer_id_to_uid[buffer_id] = uid + + def load_lora_weight_to_buffer( + self, uid: str, buffer_id: int, lora_adapter: LoRAAdapter = None + ): + + if uid is None: + for i in range(self.num_layer): + for k in self.A_buffer.keys(): + self.A_buffer[k][i][buffer_id] *= 0 + return + + assert lora_adapter is not None + for layer_id in range(self.num_layer): + layer_weights = lora_adapter.layers[layer_id].weights + for name, weights in layer_weights.items(): + if "lora_A" in name: + lora_weight_name = get_weight_name( + name, self.lora_weight_names, LoRAType.LORA_A + ) + if lora_weight_name: + self.A_buffer[lora_weight_name][layer_id][buffer_id].copy_( + weights + ) + else: + lora_weight_name = get_weight_name( + name, self.lora_weight_names, LoRAType.LORA_B + ) + if lora_weight_name: + c = get_stacked_multiply(lora_weight_name) + if c > 1: + for stacked_id in range(c): + self.B_buffer[lora_weight_name][layer_id][stacked_id][ + buffer_id + ].copy_(weights[stacked_id]) + else: + self.B_buffer[lora_weight_name][layer_id][0][ + buffer_id + ].copy_(weights) + + def get_tensor( + self, weight_name: str, layer_id: int, lora_type: LoRAType + ) -> torch.Tensor: + + if lora_type == LoRAType.LORA_A: + return self.A_buffer[weight_name][layer_id] + + return self.B_buffer[weight_name][layer_id] + + def get_buffer_id(self, lora_uid: str): + return self.uid_to_buffer_id[lora_uid] diff --git a/python/sglang/srt/lora/triton_ops/__init__.py b/python/sglang/srt/lora/triton_ops/__init__.py index efc76bb8b..da55e8fd5 100644 --- a/python/sglang/srt/lora/triton_ops/__init__.py +++ b/python/sglang/srt/lora/triton_ops/__init__.py @@ -1,5 +1,11 @@ +from .gate_up_lora_b import gate_up_lora_b_fwd from .qkv_lora_b import qkv_lora_b_fwd from .sgemm_lora_a import sgemm_lora_a_fwd from .sgemm_lora_b import sgemm_lora_b_fwd -__all__ = ["qkv_lora_b_fwd", "sgemm_lora_a_fwd", "sgemm_lora_b_fwd"] +__all__ = [ + "gate_up_lora_b_fwd", + "qkv_lora_b_fwd", + "sgemm_lora_a_fwd", + "sgemm_lora_b_fwd", +] diff --git a/python/sglang/srt/lora/triton_ops/gate_up_lora_b.py b/python/sglang/srt/lora/triton_ops/gate_up_lora_b.py new file mode 100644 index 000000000..ceaf8a6c7 --- /dev/null +++ b/python/sglang/srt/lora/triton_ops/gate_up_lora_b.py @@ -0,0 +1,170 @@ +import torch +import triton +import triton.language as tl + +from sglang.srt.lora.utils import LoRABatchInfo + + +@triton.jit +def _gate_up_lora_b_kernel( + # Pointers to matrices + x, + weights, + output, + # Parameters of size + K, # K = R + output_dim, + # Strides + x_stride_0, + x_stride_1, + w_stride_0, + w_stride_1, + w_stride_2, + output_stride_0, + output_stride_1, + # Information on sequence lengths and weight id + seg_lens, + seg_indptr, + weight_indices, + # Meta parameters + BLOCK_S: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + # For fused output scaling and adding + fuse_scaling_add, + scaling, +): + # This kernel packs 2 sgemms (gate/up) into a single kernel. 
+
+    # x: (s, 2 * K), s is the sum of sequence lengths, K equals the lora rank
+    # weights: (num_lora, 2 * output_dim, K)
+    # output: (s, 2 * output_dim)
+    # output_dim >> K
+
+    # Current block computes sequence with batch_id,
+    # which starts from row seg_start of x with length seg_len.
+    # gate_up_id decides which of gate or up (0: gate, 1: up)
+    batch_id = tl.program_id(axis=2)
+    gate_up_id = tl.program_id(axis=1)
+    pid = tl.program_id(axis=0)
+    seg_len = tl.load(seg_lens + batch_id)
+    w_index = tl.load(weight_indices + batch_id)
+    seg_start = tl.load(seg_indptr + batch_id)
+    n_start = gate_up_id * output_dim  # offset on output dim
+
+    # The tile in output matrix will have (pid_s, pid_n) as id
+    num_pid_n = tl.cdiv(output_dim, BLOCK_N)
+    pid_s = pid // num_pid_n
+    pid_n = pid % num_pid_n
+
+    # Create pointers for the first block of x and weights
+    # The pointers will be advanced as we move in the K direction
+    # and accumulate
+    s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
+    n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
+    k_offset = tl.arange(0, BLOCK_K)
+
+    x_ptrs = (x + seg_start * x_stride_0 + (gate_up_id * K) * x_stride_1) + (
+        s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
+    )
+    w_ptrs = (weights + w_index * w_stride_0 + n_start * w_stride_1) + (
+        k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
+    )
+
+    # Iterate to compute the block in the output matrix
+    partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_K)):
+        x_tile = tl.load(
+            x_ptrs,
+            mask=(s_offset[:, None] < seg_len)
+            and (k_offset[None, :] < K - k * BLOCK_K),
+            other=0.0,
+        )
+        w_tile = tl.load(
+            w_ptrs,
+            mask=(k_offset[:, None] < K - k * BLOCK_K)
+            and (n_offset[None, :] < output_dim),
+            other=0.0,
+        )
+        partial_sum += tl.dot(x_tile, w_tile)
+
+        x_ptrs += BLOCK_K * x_stride_1
+        w_ptrs += BLOCK_K * w_stride_2
+
+    # Store result to output matrix
+    partial_sum *= scaling
+    partial_sum = partial_sum.to(x.dtype.element_ty)
+    output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
+        s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
+    )
+    output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < output_dim)
+    if fuse_scaling_add:
+        partial_sum += tl.load(output_ptr, mask=output_mask)
+    tl.store(output_ptr, partial_sum, mask=output_mask)
+
+
+def gate_up_lora_b_fwd(
+    x: torch.Tensor,
+    gate_up_lora_b: torch.Tensor,
+    batch_info: LoRABatchInfo,
+    output_dim: int,
+    base_output: torch.Tensor = None,
+    scaling: float = 1.0,
+) -> torch.Tensor:
+
+    # x: (s, 2 * r)
+    # gate_up_lora_b: (num_lora, 2 * output_dim, r)
+    # output: (s, 2 * output_dim)
+
+    # Compute lora_output with shape (s, 2 * output_dim) as follows:
+    # lora_output[:, :output_dim] = sgemm(x[:, :r], gate_up_lora_b[:, :output_dim, :])
+    # lora_output[:, output_dim:]
+    # = sgemm(x[:, r:], gate_up_lora_b[:, output_dim:, :])
+
+    # Get dims
+    s = x.shape[0]
+    input_dim = x.shape[1]
+    r = gate_up_lora_b.shape[-1]
+    assert input_dim == 2 * r
+
+    BLOCK_S = 16
+    BLOCK_R = 16
+    BLOCK_OUT = 64
+
+    grid_b = (
+        triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(output_dim, BLOCK_OUT),
+        2,  # this dimension decides whether the current block computes the gate or the up proj
+        batch_info.bs,
+    )
+
+    if base_output is None:
+        output = torch.empty((s, 2 * output_dim), device=x.device, dtype=x.dtype)
+        fuse_scaling_add = False
+    else:
+        output = base_output
+        fuse_scaling_add = True
+
+    _gate_up_lora_b_kernel[grid_b](
+        x,
+        gate_up_lora_b,
+        output,
+        r,
+        output_dim,
+        x.stride(0),
+        x.stride(1),
+        gate_up_lora_b.stride(0),
+        gate_up_lora_b.stride(1),
+        gate_up_lora_b.stride(2),
+        output.stride(0),
+        output.stride(1),
+        batch_info.seg_lens,
+        batch_info.seg_indptr,
+        batch_info.weight_indices,
+        BLOCK_S,
+        BLOCK_OUT,
+        BLOCK_R,
+        fuse_scaling_add,
+        scaling,
+    )
+
+    return output
diff --git a/python/sglang/srt/lora/triton_ops/qkv_lora_b.py b/python/sglang/srt/lora/triton_ops/qkv_lora_b.py
index 3e090f4dc..bf56eef71 100644
--- a/python/sglang/srt/lora/triton_ops/qkv_lora_b.py
+++ b/python/sglang/srt/lora/triton_ops/qkv_lora_b.py
@@ -2,7 +2,7 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.lora.lora import LoraBatchInfo
+from sglang.srt.lora.utils import LoRABatchInfo
 
 
 @triton.jit
@@ -108,7 +108,7 @@ def _qkv_lora_b_kernel(
 def qkv_lora_b_fwd(
     x: torch.Tensor,
     qkv_lora_b: torch.Tensor,
-    batch_info: LoraBatchInfo,
+    batch_info: LoRABatchInfo,
     output_offset: torch.Tensor,
     max_qkv_out_dim: int,
     base_output: torch.Tensor = None,
@@ -123,11 +123,11 @@ def qkv_lora_b_fwd(
     # output: (s, output_dim_q + 2 * output_dim_kv)
 
     # Compute lora_output with shape (s, output_dim) as follows:
-    # lora_output[:, :output_dim_q] = sgemm(lora_output_a[:, :r], )
+    # lora_output[:, :output_dim_q] = sgemm(x[:, :r], qkv_lora_b[:, :output_dim_q, :])
     # lora_output[:, output_dim_q: output_dim_q + output_dim_kv]
-    # = sgemm(lora_output_a[:, r: 2 * r], kv_lora_b[0])
+    # = sgemm(x[:, r: 2 * r], qkv_lora_b[:, output_dim_q: output_dim_q + output_dim_kv, :])
     # lora_output[:, output_dim_q + output_dim_kv: ]
-    # = sgemm(lora_output_a[:, 2 * r: 3 * r], kv_lora_b[1])
+    # = sgemm(x[:, 2 * r:], qkv_lora_b[:, output_dim_q + output_dim_kv: , :])
 
     # Get dims
     s = x.shape[0]
diff --git a/python/sglang/srt/lora/triton_ops/sgemm_lora_a.py b/python/sglang/srt/lora/triton_ops/sgemm_lora_a.py
index 305bb8c5f..e2d24c3f4 100644
--- a/python/sglang/srt/lora/triton_ops/sgemm_lora_a.py
+++ b/python/sglang/srt/lora/triton_ops/sgemm_lora_a.py
@@ -2,7 +2,7 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.lora.lora import LoraBatchInfo
+from sglang.srt.lora.utils import LoRABatchInfo
 
 
 @triton.jit
@@ -91,7 +91,7 @@ def _sgemm_lora_a_kernel(
 def sgemm_lora_a_fwd(
-    x: torch.Tensor, weights: torch.Tensor, batch_info: LoraBatchInfo
+    x: torch.Tensor, weights: torch.Tensor, batch_info: LoRABatchInfo
 ) -> torch.Tensor:
     # x: (s, input_dim)
     # weights: (num_lora, r, input_dim)
diff --git a/python/sglang/srt/lora/triton_ops/sgemm_lora_b.py b/python/sglang/srt/lora/triton_ops/sgemm_lora_b.py
index c0bc91363..2e2e3a04c 100644
--- a/python/sglang/srt/lora/triton_ops/sgemm_lora_b.py
+++ b/python/sglang/srt/lora/triton_ops/sgemm_lora_b.py
@@ -2,7 +2,7 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.lora.lora import LoraBatchInfo
+from sglang.srt.lora.utils import LoRABatchInfo
 
 
 @triton.jit
@@ -98,7 +98,7 @@ def _sgemm_lora_b_kernel(
 def sgemm_lora_b_fwd(
     x: torch.Tensor,
     weights: torch.Tensor,
-    batch_info: LoraBatchInfo,
+    batch_info: LoRABatchInfo,
     base_output: torch.Tensor = None,
     scaling: float = 1.0,
 ) -> torch.Tensor:
diff --git a/python/sglang/srt/lora/utils.py b/python/sglang/srt/lora/utils.py
new file mode 100644
index 000000000..f5845373d
--- /dev/null
+++ b/python/sglang/srt/lora/utils.py
@@ -0,0 +1,141 @@
+import re
+from dataclasses import dataclass
+from enum import Enum
+from typing import Optional, Set, Tuple
+
+import torch
+
+from sglang.srt.hf_transformers_utils import AutoConfig
+
+
+@dataclass
+class LoRABatchInfo:
+
+    # Batch size
+    bs: int
+
+    # Lengths of each sequence in shape (bs,)
+    seg_lens: torch.Tensor
+
+    # Index pointers of each sequence, in shape (bs + 1,)
+    seg_indptr: torch.Tensor
+
+    # Maximum sequence length of current batch
+    max_len: int
+
+    # The index of lora adapter used by each sequence, in shape (bs,)
+    weight_indices: torch.Tensor
+
+
+class LoRAType(Enum):
+    LORA_A = 0
+    LORA_B = 1
+
+
+def get_layer_id(name: str) -> Optional[int]:
+    """
+    Extract the integer layer id from a module name; return None if no layer index is found.
+    """
+    match = re.search(r"layers\.(\d+)\.", name)
+    if match is None:
+        return None
+    return int(match.group(1))
+
+
+def get_customized_names_from_hf_names(
+    hf_module_names: Set[str], base_model: torch.nn.Module
+) -> Set[str]:
+    """
+    This function takes in a set of huggingface style module names:
+         e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
+    and outputs a set of module names of customized sglang layers:
+         e.g., {"qkv_proj", "o_proj"}
+    """
+    if hasattr(base_model, "get_module_name"):
+        return {base_model.get_module_name(name) for name in hf_module_names}
+    else:
+        """
+        Fallback solution of mapping from config module name to module name in model class.
+        Please check if it aligns with your base model.
+        Please implement get_module_name() in the model class if it does not.
+        You can reference this function in llama.py.
+        """
+        params_mapping = {
+            "q_proj": "qkv_proj",
+            "k_proj": "qkv_proj",
+            "v_proj": "qkv_proj",
+            "gate_proj": "gate_up_proj",
+            "up_proj": "gate_up_proj",
+        }
+        return {params_mapping.get(name, name) for name in hf_module_names}
+
+
+def get_hidden_dim(
+    module_name: str, config: AutoConfig, base_model: torch.nn.Module
+) -> Tuple[int, int]:
+    """
+    Given a module_name (might be a stacked name), return the hidden dims of the module's input and output.
+    """
+
+    if hasattr(base_model, "get_hidden_dim"):
+        return base_model.get_hidden_dim(module_name)
+    else:
+        """
+        WARNING: get_hidden_dim() is not defined,
+        which is used to get the hidden dim for different lora modules.
+        Falling back to the default, but please check if it is correct for your model.
+        Please implement the function in the model class if it is not.
+        You can reference this function in llama.py.
+ """ + if module_name in ["q_proj", "o_proj", "qkv_proj"]: + return config.hidden_size, config.hidden_size + elif module_name in ["kv_proj"]: + return config.hidden_size, config.hidden_size // ( + config.num_attention_heads // config.num_key_value_heads + ) + elif module_name == "gate_up_proj": + return config.hidden_size, config.intermediate_size + elif module_name == "down_proj": + return config.intermediate_size, config.hidden_size + else: + raise NotImplementedError() + + +def get_stacked_name(name: str) -> Tuple[str]: + """ + Mapping a target module name to (stacked name for Lora A, stacked name for Lora B) + """ + params_mapping = { + "q_proj": ("qkv_proj", "q_proj"), + "k_proj": ("qkv_proj", "kv_proj"), + "v_proj": ("qkv_proj", "kv_proj"), + "gate_proj": ("gate_up_proj", "gate_up_proj"), + "up_proj": ("gate_up_proj", "gate_up_proj"), + } + return params_mapping.get(name, (name, name)) + + +def get_stacked_multiply(module_name: str) -> int: + """ + Mapping a lora module name to its magnification at output dimension + """ + stacked_rank = { + "qkv_proj": 3, + "kv_proj": 2, + "gate_up_proj": 2, + } + return stacked_rank[module_name] if module_name in stacked_rank else 1 + + +def get_weight_name( + target_name: str, lora_weight_names: Set[Tuple[str]], lora_type: LoRAType +) -> Optional[str]: + """ + target_name is name of a given module, + lora_weight_names is a set of lora stacked name pairs (see get_stacked_name method above) + If there is a weight name in lora_weight_names that can match target_name, return this name + Else return None + """ + idx = 0 if lora_type == LoRAType.LORA_A else 1 + for weight_name_pair in lora_weight_names: + if weight_name_pair[idx] in target_name: + return weight_name_pair[idx] diff --git a/test/srt/models/test_lora_backend.py b/test/srt/models/test_lora_backend.py index 6d6163300..82e3ff167 100644 --- a/test/srt/models/test_lora_backend.py +++ b/test/srt/models/test_lora_backend.py @@ -22,7 +22,11 @@ from sglang.test.test_utils import calculate_rouge_l LORA_SETS = [ {"base": "meta-llama/Llama-2-7b-hf", "loras": ["winddude/wizardLM-LlaMA-LoRA-7B"]}, - # {"base": "meta-llama/Llama-2-7b-hf", "loras": ["RuterNorway/Llama-2-7b-chat-norwegian-LoRa"]} + { + "base": "meta-llama/Llama-3.1-8B-Instruct", + "loras": ["reissbaker/llama-3.1-8b-abliterated-lora"], + "decode_tolerance": 8e-2, + }, ] TORCH_DTYPES = [torch.float16] @@ -128,7 +132,8 @@ class TestLoRABackend(unittest.TestCase): torch.max(abs(hf_logprobs - hf_no_lora_logprobs)), ) if hf_logprobs.shape[0] <= 100: - assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), ( + tol = lora_set.get("prefill_tolerance", prefill_tolerance) + assert torch.all(abs(hf_logprobs - srt_logprobs) < tol), ( f"prefill logprobs are not all close with model_path={base_path}," f"lora_path={batch_lora_paths[i]}, backend={backend}, prompt={prompts[i]}" f"prefill_tolerance={prefill_tolerance}." @@ -144,7 +149,8 @@ class TestLoRABackend(unittest.TestCase): "\n", ) if hf_logprobs.shape[0] <= 100: - assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), ( + tol = lora_set.get("decode_tolerance", decode_tolerance) + assert torch.all(abs(hf_logprobs - srt_logprobs) < tol), ( f"decode logprobs are not all close with model_path={base_path}," f"lora_path={batch_lora_paths[i]}, backend={backend}, prompt={prompts[i]}" f"decode_tolerance={decode_tolerance}." 
@@ -153,7 +159,7 @@ class TestLoRABackend(unittest.TestCase):
 
             # compare output strings
             srt_output_str = srt_outputs.output_strs[i].strip(" ")
-            hf_output_str = hf_outputs.output_strs[i]
+            hf_output_str = hf_outputs.output_strs[i].strip(" ")
             print(f"srt_output_str={srt_output_str}")
             print(f"hf_output_str={hf_output_str}")
            rouge_l_scores = calculate_rouge_l([srt_output_str], [hf_output_str])
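
As a quick illustration of how the pieces added in python/sglang/srt/lora/utils.py fit together: seg_indptr holds the prefix-sum offsets of seg_lens over the flattened token dimension s, while get_stacked_name / get_weight_name / get_stacked_multiply translate per-module LoRA weight names into the stacked sglang modules and their output-dimension multipliers. The sketch below is a hand-written example, not code from the patch; the batch sizes, sequence lengths, adapter indices, and the module path "model.layers.0.self_attn.qkv_proj" are made-up illustrative values, and only the imported names come from the new file above.

import torch

from sglang.srt.lora.utils import (
    LoRABatchInfo,
    LoRAType,
    get_stacked_multiply,
    get_stacked_name,
    get_weight_name,
)

# Two sequences (lengths 3 and 5), each served by a different adapter slot.
seg_lens = torch.tensor([3, 5], dtype=torch.int32)
seg_indptr = torch.cat(
    [torch.zeros(1, dtype=torch.int32), torch.cumsum(seg_lens, dim=0).to(torch.int32)]
)  # [0, 3, 8]: per-sequence offsets into the flattened token dimension s

batch_info = LoRABatchInfo(
    bs=2,
    seg_lens=seg_lens,
    seg_indptr=seg_indptr,
    max_len=int(seg_lens.max()),
    weight_indices=torch.tensor([0, 1], dtype=torch.int32),
)

# q/k/v LoRA-A weights live under the stacked "qkv_proj" name.
lora_weight_names = {get_stacked_name("q_proj")}  # {("qkv_proj", "q_proj")}
name_a = get_weight_name(
    "model.layers.0.self_attn.qkv_proj", lora_weight_names, LoRAType.LORA_A
)
assert name_a == "qkv_proj"
assert get_stacked_multiply(name_a) == 3  # qkv stacks three projections on the output dim

A LoRABatchInfo built this way is what the triton ops above (sgemm_lora_a_fwd, sgemm_lora_b_fwd, qkv_lora_b_fwd) receive as batch_info to segment the flattened tokens per request and pick the right adapter weights.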