[Feature] Define backends and add Triton backend for Lora (#3161)
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
python/sglang/srt/lora/backend/__init__.py (new file, +8 lines)
@@ -0,0 +1,8 @@
from .base_backend import BaseLoraBackend
from .flashinfer_backend import FlashInferLoraBackend
from .triton_backend import TritonLoraBackend

__all__ = [
    "FlashInferLoraBackend",
    "TritonLoraBackend",
]
python/sglang/srt/lora/backend/base_backend.py (new file, +95 lines)
@@ -0,0 +1,95 @@
from typing import Tuple, Union

import torch

from sglang.srt.lora.lora import LoraBatchInfo


def get_fuse_output_scaling_add_from_name(name: str) -> bool:
    mapping = {
        "triton": True,
        "flashinfer": False,
    }
    return mapping.get(name, False)


def get_fuse_qkv_lora_b_from_name(name: str) -> bool:
    mapping = {
        "triton": True,
        "flashinfer": False,
    }
    return mapping.get(name, False)


class BaseLoraBackend:
    """Base class for different LoRA backends.
    Each backend has its own implementation of the LoRA kernels.

    Args:
        name: name of the backend
        batch_info: information about the current batch
        fuse_output_scaling_add: if True, the output buffer for storing the result is passed in
            when running lora_b forward, and the scaling and add operations are fused into the kernel
    """

    def __init__(self, name: str, batch_info: LoraBatchInfo = None):
        self.name = name
        self.batch_info = batch_info
        self.fuse_output_scaling_add = get_fuse_output_scaling_add_from_name(name)
        self.fuse_qkv_lora_b = get_fuse_qkv_lora_b_from_name(name)

    def run_lora_a_sgemm(
        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
        """Run segment GEMM of lora_a modules with the current backend.
        Segment GEMM is defined at https://docs.flashinfer.ai/api/gemm.html.

        Args:
            x: input matrix with shape (s, input_dim), where s is the sum of all sequence lengths
            weights: a set of LoRA weights with shape (num_lora, r, input_dim), where r is the LoRA rank;
                usually input_dim is much larger than r
        Returns:
            result with shape (s, r)
        """
        pass

    def run_lora_b_sgemm(
        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
        """Run segment GEMM of lora_b modules with the current backend.
        Segment GEMM is defined at https://docs.flashinfer.ai/api/gemm.html.

        Args:
            x: input matrix with shape (s, r), where s is the sum of all sequence lengths and r is the LoRA rank
            weights: a set of LoRA weights with shape (num_lora, output_dim, r);
                usually output_dim is much larger than r
        Returns:
            result with shape (s, output_dim)
        """
        pass

    def run_qkv_lora(
        self,
        x: torch.Tensor,
        qkv_lora_a: torch.Tensor,
        qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
        *args,
        **kwargs
    ) -> torch.Tensor:
        """Run the LoRA pass for the QKV layer.

        Args:
            x: input matrix with shape (s, input_dim), where s is the sum of all sequence lengths
            qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim)
            qkv_lora_b: lora_b module for qkv.
                If passed in as a tensor, its shape should be (num_lora, output_dim_q + 2 * output_dim_kv, r).
                If passed in as a tuple of two tensors, it contains
                a lora_b module for q, with shape (1, num_lora, output_dim_q, r),
                and a combined lora_b module for kv, with shape (2, num_lora, output_dim_kv, r).
        Returns:
            result with shape (s, output_dim_q + 2 * output_dim_kv)
        """
        pass

    def set_batch_info(self, batch_info: LoraBatchInfo):
        self.batch_info = batch_info
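The docstrings above lean on the notion of a segment GEMM. As a rough reference (not part of this commit), a naive PyTorch equivalent of what run_lora_a_sgemm computes, using the seg_indptr and weight_indices fields of the LoraBatchInfo defined later in this diff, could look like:

import torch

def segment_gemm_reference(x, weights, seg_indptr, weight_indices):
    # x: (s, input_dim), weights: (num_lora, r, input_dim)
    # Each sequence i occupies rows seg_indptr[i]:seg_indptr[i + 1] of x and
    # is multiplied by the adapter selected by weight_indices[i].
    out = x.new_empty((x.shape[0], weights.shape[1]))
    for i in range(len(weight_indices)):
        start, end = int(seg_indptr[i]), int(seg_indptr[i + 1])
        w = weights[int(weight_indices[i])]      # (r, input_dim)
        out[start:end] = x[start:end] @ w.T      # (seq_len_i, r)
    return out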
python/sglang/srt/lora/backend/flashinfer_backend.py (new file, +88 lines)
@@ -0,0 +1,88 @@
from typing import Tuple

import torch
from flashinfer import SegmentGEMMWrapper

from sglang.srt.lora.backend import BaseLoraBackend
from sglang.srt.lora.lora import LoraBatchInfo


class FlashInferLoraBackend(BaseLoraBackend):

    def __init__(self, name: str, batch_info: LoraBatchInfo = None):
        super().__init__(name, batch_info)

        # Set up SGemm Wrapper from flashinfer
        # FIXME wait for flashinfer segment gemm update
        workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
        self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)

    def run_lora_a_sgemm(
        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:

        return self.segment_gemm.run(
            x=x,
            weights=weights,
            batch_size=self.batch_info.bs,
            weight_column_major=True,
            seg_indptr=self.batch_info.seg_indptr,
            weight_indices=self.batch_info.weight_indices,
        )

    def run_lora_b_sgemm(
        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:

        return self.segment_gemm.run(
            x=x,
            weights=weights,
            batch_size=self.batch_info.bs,
            weight_column_major=True,
            seg_indptr=self.batch_info.seg_indptr,
            weight_indices=self.batch_info.weight_indices,
        )

    def run_qkv_lora(
        self,
        x: torch.Tensor,
        qkv_lora_a: torch.Tensor,
        qkv_lora_b: Tuple[torch.Tensor],
        *args,
        **kwargs,
    ) -> torch.Tensor:

        # Shape of lora_a_output: (s, 3 * r)
        lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a)

        q_lora_b, kv_lora_b = qkv_lora_b
        lora_rank = kv_lora_b.shape[-1]
        output_dim_q = q_lora_b.shape[-2]
        output_dim_kv = kv_lora_b.shape[-2]
        lora_output = torch.empty(
            (x.shape[0], output_dim_q + 2 * output_dim_kv),
            device=x.device,
            dtype=x.dtype,
        )

        # q
        lora_output[:, :output_dim_q] = self.run_lora_b_sgemm(
            x=lora_a_output[:, :lora_rank].contiguous(), weights=q_lora_b[0]
        )

        # kv
        lora_output[:, output_dim_q : output_dim_q + output_dim_kv] = (
            self.run_lora_b_sgemm(
                x=lora_a_output[:, lora_rank : 2 * lora_rank].contiguous(),
                weights=kv_lora_b[0],
            )
        )

        lora_output[
            :, output_dim_q + output_dim_kv : output_dim_q + 2 * output_dim_kv
        ] = self.run_lora_b_sgemm(
            x=lora_a_output[:, 2 * lora_rank : 3 * lora_rank].contiguous(),
            weights=kv_lora_b[1],
        )

        return lora_output
python/sglang/srt/lora/backend/triton_backend.py (new file, +61 lines)
@@ -0,0 +1,61 @@
import torch

from sglang.srt.lora.backend import BaseLoraBackend
from sglang.srt.lora.lora import LoraBatchInfo
from sglang.srt.lora.triton_ops import (
    qkv_lora_b_fwd,
    sgemm_lora_a_fwd,
    sgemm_lora_b_fwd,
)


class TritonLoraBackend(BaseLoraBackend):

    def __init__(self, name: str, batch_info: LoraBatchInfo = None):
        super().__init__(name, batch_info)

    def run_lora_a_sgemm(
        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
        return sgemm_lora_a_fwd(x, weights, self.batch_info)

    def run_lora_b_sgemm(
        self,
        x: torch.Tensor,
        weights: torch.Tensor,
        base_output: torch.Tensor = None,
        scaling: float = 1.0,
        *args,
        **kwargs
    ) -> torch.Tensor:
        return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output, scaling)

    def run_qkv_lora(
        self,
        x: torch.Tensor,
        qkv_lora_a: torch.Tensor,
        qkv_lora_b: torch.Tensor,
        output_offset: torch.Tensor,
        max_qkv_out_dim: int,
        base_output: torch.Tensor = None,
        scaling: float = 1.0,
        *args,
        **kwargs
    ) -> torch.Tensor:

        # x: (s, input_dim)
        # qkv_lora_a: (num_lora, 3 * r, input_dim)
        # qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
        assert isinstance(qkv_lora_b, torch.Tensor)

        lora_a_output = sgemm_lora_a_fwd(x, qkv_lora_a, self.batch_info)
        lora_output = qkv_lora_b_fwd(
            lora_a_output,
            qkv_lora_b,
            self.batch_info,
            output_offset,
            max_qkv_out_dim,
            base_output,
            scaling,
        )
        return lora_output
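The FlashInfer and Triton backends produce the same result but differ in where scaling and the residual add happen: with the Triton backend they are fused into the lora_b kernel (fuse_output_scaling_add is True), while with FlashInfer the caller applies them. A rough, illustrative sketch of how a layer would drive either backend (assuming batch info has already been set on the backend; this helper is not part of the commit):

def apply_lora_delta(backend, x, A_buffer, B_buffer, base_output, scaling):
    # lora_a then lora_b; both are segment GEMMs over the batched sequences.
    lora_a_out = backend.run_lora_a_sgemm(x, A_buffer)
    lora_out = backend.run_lora_b_sgemm(
        lora_a_out, B_buffer, base_output=base_output, scaling=scaling
    )
    # Triton backend: scaling and the add into base_output are fused in the kernel.
    # FlashInfer backend: the caller applies them explicitly.
    if backend.fuse_output_scaling_add:
        return lora_out
    return base_output + lora_out * scaling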
python/sglang/srt/lora/lora.py
@@ -18,8 +18,8 @@
# LoRA layers class inheritance adapted from:
# https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py


import re
from dataclasses import dataclass

import torch
from torch import nn
@@ -34,14 +34,32 @@ from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_loader.loader import DefaultModelLoader


@dataclass
class LoraBatchInfo:
    # Batch size
    bs: int

    # Lengths of each sequence in shape (bs,)
    seg_lens: torch.Tensor

    # Index pointers of each sequence in shape (bs + 1, )
    seg_indptr: torch.Tensor

    # Maximum sequence length of current batch
    max_len: int

    # The index of lora adapter used by each sequence, in shape (bs,)
    weight_indices: torch.Tensor


class BaseLayerWithLoRA(nn.Module):
    def __init__(self, base_layer, segment_gemm, lora_rank, scaling):
    def __init__(self, base_layer, lora_rank, scaling, lora_backend):
        super().__init__()
        self.base_layer = base_layer
        self.segment_gemm = segment_gemm
        self.lora_rank = lora_rank
        self.scaling = scaling
        self.set_lora = False
        self.lora_backend = lora_backend

    def forward(self, x: torch.Tensor):
        return self.base_layer.forward(x)
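To make the dataclass concrete, here is an illustrative LoraBatchInfo (hypothetical values, not taken from the commit) for a batch of three sequences with lengths 3, 5, and 2, where the first two sequences use adapter 0 and the third uses adapter 1:

import torch

from sglang.srt.lora.lora import LoraBatchInfo

seg_lens = torch.tensor([3, 5, 2], dtype=torch.int32, device="cuda")
# seg_indptr is the cumulative sum of seg_lens with a leading 0
seg_indptr = torch.tensor([0, 3, 8, 10], dtype=torch.int32, device="cuda")
batch_info = LoraBatchInfo(
    bs=3,
    seg_lens=seg_lens,
    seg_indptr=seg_indptr,
    max_len=5,
    weight_indices=torch.tensor([0, 0, 1], dtype=torch.int64, device="cuda"),
)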
@@ -52,17 +70,17 @@ class BaseLayerWithLoRA(nn.Module):

class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
    def __init__(
        self, base_layer: VocabParallelEmbedding, segment_gemm, lora_rank, scaling
        self, base_layer: VocabParallelEmbedding, lora_rank, scaling, lora_backend
    ) -> None:
        super().__init__(base_layer, segment_gemm, lora_rank, scaling)
        super().__init__(base_layer, lora_rank, scaling, lora_backend)
        self.weight = base_layer.weight


class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
    def __init__(
        self, base_layer: ColumnParallelLinear, segment_gemm, lora_rank, scaling
        self, base_layer: ColumnParallelLinear, lora_rank, scaling, lora_backend
    ) -> None:
        super().__init__(base_layer, segment_gemm, lora_rank, scaling)
        super().__init__(base_layer, lora_rank, scaling, lora_backend)

    def apply_lora(self, output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # TODO
@@ -88,136 +106,127 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):

class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    def __init__(
        self, base_layer: MergedColumnParallelLinear, segment_gemm, lora_rank, scaling
        self, base_layer: MergedColumnParallelLinear, lora_rank, scaling, lora_backend
    ) -> None:
        super().__init__(base_layer, segment_gemm, lora_rank, scaling)
        super().__init__(base_layer, lora_rank, scaling, lora_backend)

    def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices):
    def set_lora_info(
        self,
        A_buffer,
        B_buffer,
    ):
        self.set_lora = True
        self.A_buffer = A_buffer
        self.B_buffer = B_buffer
        self.bs = bs
        self.seg_indptr = seg_indptr
        self.weight_indices = weight_indices

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        lora_a_output = self.segment_gemm.run(
            x=x,
            weights=self.A_buffer,
            batch_size=self.bs,
            weight_column_major=True,
            seg_indptr=self.seg_indptr,
            weight_indices=self.weight_indices,
        )
        # FIXME
        lora_a_output = self.lora_backend.run_lora_a_sgemm(x=x, weights=self.A_buffer)

        output_dim = base_output.shape[-1]
        lora_output = torch.empty_like(base_output)
        output_dim = lora_output.shape[-1] // 2
        for i in range(2):
            left = output_dim * i
            right = left + output_dim
            lora_output[:, left:right] = self.segment_gemm.run(
                x=lora_a_output[
                    :, self.lora_rank * i : self.lora_rank * (i + 1)
                ].contiguous(),
                weights=self.B_buffer[:, left:right, :].contiguous(),
                batch_size=self.bs,
                weight_column_major=True,
                seg_indptr=self.seg_indptr,
                weight_indices=self.weight_indices,
        lora_output[:, :output_dim] = self.lora_backend.run_lora_b_sgemm(
            x=lora_a_output[:, 0 : self.lora_rank].contiguous(),
            weights=self.B_buffer[0],
        )

        lora_output[:, output_dim : 2 * output_dim] = (
            self.lora_backend.run_lora_b_sgemm(
                x=lora_a_output[:, self.lora_rank : 2 * self.lora_rank].contiguous(),
                weights=self.B_buffer[1],
            )
        )

        return base_output + lora_output * self.scaling


class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    def __init__(
        self, base_layer: QKVParallelLinear, segment_gemm, lora_rank, scaling
    def __init__(
        self, base_layer: QKVParallelLinear, lora_rank, scaling, lora_backend
    ) -> None:
        super().__init__(base_layer, segment_gemm, lora_rank, scaling)
        super().__init__(base_layer, lora_rank, scaling, lora_backend)

    def set_lora_info(
        self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seg_indptr, weight_indices
        self,
        A_buffer_qkv,
        B_buffer_q,
        B_buffer_kv,
    ):
        self.set_lora = True
        self.A_buffer_qkv = A_buffer_qkv
        self.B_buffer_q = B_buffer_q
        self.B_buffer_kv = B_buffer_kv
        self.bs = bs
        self.seg_indptr = seg_indptr
        self.weight_indices = weight_indices

        if self.lora_backend.fuse_qkv_lora_b:
            assert (
                B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
            ), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
            output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]

            # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
            self.B_buffer_qkv = torch.cat(
                (B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
            ).contiguous()

            # Offsets of q/k/v in output dimension
            self.output_offset = torch.tensor(
                [
                    0,
                    output_dim_q,
                    output_dim_q + output_dim_kv,
                    output_dim_q + 2 * output_dim_kv,
                ],
                dtype=torch.int32,
                device=B_buffer_q.device,
            )
            # For computing number of launched blocks
            self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
        else:
            self.B_buffer_qkv = (
                B_buffer_q,
                B_buffer_kv,
            )
            self.output_offset = None
            self.max_qkv_out_dim = None

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        lora_a_output = self.segment_gemm.run(
            x=x,
            weights=self.A_buffer_qkv,
            batch_size=self.bs,
            weight_column_major=True,
            seg_indptr=self.seg_indptr,
            weight_indices=self.weight_indices,
        lora_output = self.lora_backend.run_qkv_lora(
            x,
            self.A_buffer_qkv,
            self.B_buffer_qkv,
            output_offset=self.output_offset,
            max_qkv_out_dim=self.max_qkv_out_dim,
            base_output=base_output,
            scaling=self.scaling,
        )
        # FIXME parallelize qkv
        lora_output = torch.empty_like(base_output)
        # q
        output_dim_q = self.B_buffer_q.shape[-2]
        lora_output[:, :output_dim_q] = self.segment_gemm.run(
            x=lora_a_output[:, : self.lora_rank].contiguous(),
            weights=self.B_buffer_q,
            batch_size=self.bs,
            weight_column_major=True,
            seg_indptr=self.seg_indptr,
            weight_indices=self.weight_indices,
        return (
            lora_output
            if self.lora_backend.fuse_output_scaling_add
            else base_output + lora_output * self.scaling
        )
        # kv
        output_dim_kv = self.B_buffer_kv.shape[-2] // 2
        for i in range(2):
            left = output_dim_kv * i
            right = left + output_dim_kv
            lora_output[:, output_dim_q + left : output_dim_q + right] = (
                self.segment_gemm.run(
                    x=lora_a_output[
                        :, self.lora_rank * (i + 1) : self.lora_rank * (i + 2)
                    ].contiguous(),
                    weights=self.B_buffer_kv[:, left:right, :].contiguous(),
                    batch_size=self.bs,
                    weight_column_major=True,
                    seg_indptr=self.seg_indptr,
                    weight_indices=self.weight_indices,
                )
            )
        return base_output + lora_output * self.scaling


class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
    def __init__(
        self, base_layer: RowParallelLinear, segment_gemm, lora_rank, scaling
        self, base_layer: RowParallelLinear, lora_rank, scaling, lora_backend
    ) -> None:
        super().__init__(base_layer, segment_gemm, lora_rank, scaling)
        super().__init__(base_layer, lora_rank, scaling, lora_backend)

    def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices):
    def set_lora_info(self, A_buffer, B_buffer):
        self.set_lora = True
        self.A_buffer = A_buffer
        self.B_buffer = B_buffer
        self.bs = bs
        self.seg_indptr = seg_indptr
        self.weight_indices = weight_indices

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        lora_output = self.segment_gemm.run(
            x=x,
            weights=self.A_buffer,
            batch_size=self.bs,
            weight_column_major=True,
            seg_indptr=self.seg_indptr,
            weight_indices=self.weight_indices,
        lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
        lora_output = self.lora_backend.run_lora_b_sgemm(
            lora_a_output,
            self.B_buffer[0],
            base_output=base_output,
            scaling=self.scaling,
        )
        lora_output = self.segment_gemm.run(
            x=lora_output,
            weights=self.B_buffer,
            batch_size=self.bs,
            weight_column_major=True,
            seg_indptr=self.seg_indptr,
            weight_indices=self.weight_indices,
        return (
            lora_output
            if self.lora_backend.fuse_output_scaling_add
            else base_output + lora_output * self.scaling
        )
        return base_output + lora_output * self.scaling

    def forward(self, input_):
        # duplicate the logic in RowParallelLinear
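For orientation, the fused-QKV bookkeeping above works out as follows for hypothetical dimensions (not taken from the commit), e.g. output_dim_q = 4096 and output_dim_kv = 1024:

output_dim_q, output_dim_kv = 4096, 1024  # hypothetical head dimensions
output_offset = [
    0,                                  # start of the q slice
    output_dim_q,                       # start of the k slice -> 4096
    output_dim_q + output_dim_kv,       # start of the v slice -> 5120
    output_dim_q + 2 * output_dim_kv,   # total output width   -> 6144
]
max_qkv_out_dim = max(output_dim_q, output_dim_kv)  # 4096, used to size the kernel grid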
@@ -255,7 +264,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):


def get_lora_layer(
    layer: nn.Module, segment_gemm, lora_rank, scaling
    layer: nn.Module, lora_rank, scaling, lora_backend
) -> BaseLayerWithLoRA:
    supported_layer_types = {
        # the order matters
@@ -267,7 +276,7 @@ def get_lora_layer(
    }
    for src_layer_type, lora_layer_type in supported_layer_types.items():
        if isinstance(layer, src_layer_type):  # pylint: disable=unidiomatic-typecheck
            ret = lora_layer_type(layer, segment_gemm, lora_rank, scaling)
            ret = lora_layer_type(layer, lora_rank, scaling, lora_backend)
            return ret
    raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")

@@ -297,13 +306,14 @@ class LoRALayer(nn.Module):


class LoRAAdapter(nn.Module):
    def __init__(self, uid, config, base_hf_config, load_config):
    def __init__(self, uid, config, base_hf_config, load_config, lora_backend):
        super().__init__()
        self.uid = uid
        self.config = config
        assert self.config.hf_config["peft_type"].lower() == "lora"
        self.base_hf_config = base_hf_config
        self.load_config = load_config
        self.lora_backend = lora_backend
        self.scaling = self.config.lora_alpha / self.config.r

        self.layers = nn.ModuleList(
@@ -376,20 +386,25 @@ class LoRAAdapter(nn.Module):
                        layer.weights.pop(weight_name)
                        layer.weights.pop(v_name)
                    else:
                        layer.weights[kv_name] = torch.cat(
                            (
                        layer.weights[kv_name] = torch.stack(
                            [
                                layer.weights[weight_name],
                                layer.weights[v_name],
                            ),
                            0,
                            ],
                            dim=0,
                        )
                        layer.weights.pop(weight_name)
                        layer.weights.pop(v_name)
                elif "gate_proj" in weight_name:
                    up_name = weight_name.replace("gate_proj", "up_proj")
                    gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
                    layer.weights[gate_up_name] = torch.cat(
                        (layer.weights[weight_name], layer.weights[up_name]), 0
                    )
                    if "lora_A" in weight_name:
                        layer.weights[gate_up_name] = torch.cat(
                            (layer.weights[weight_name], layer.weights[up_name]), 0
                        )
                    else:
                        layer.weights[gate_up_name] = torch.stack(
                            [layer.weights[weight_name], layer.weights[up_name]], dim=0
                        )
                    layer.weights.pop(weight_name)
                    layer.weights.pop(up_name)

python/sglang/srt/lora/lora_manager.py
@@ -20,16 +20,14 @@ import re

import torch

from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
from sglang.srt.lora.backend import FlashInferLoraBackend, TritonLoraBackend
from sglang.srt.lora.lora import LoRAAdapter, LoraBatchInfo, get_lora_layer
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_flashinfer_available, replace_submodule

logger = logging.getLogger(__name__)

if is_flashinfer_available():
    from flashinfer import SegmentGEMMWrapper


def get_module_name(name):
    # Fallback solution of mapping from config module name to module name in model class.
@@ -77,6 +75,20 @@ def get_stacked_name(name):
    return params_mapping.get(name, (name, name))


def get_backend_from_name(name):
    backend_mapping = {
        "triton": TritonLoraBackend,
        "flashinfer": FlashInferLoraBackend,
    }

    if name in backend_mapping:
        return backend_mapping[name]

    raise Exception(
        f"No supported lora backend called {name}. It should be one of {list(backend_mapping.keys())}"
    )


def get_layer_id(name):
    match = re.search(r"layers\.(\d+)\.", name)
    if match is None:
@@ -93,6 +105,7 @@ class LoRAManager:
        max_loras_per_batch,
        load_config,
        dtype,
        lora_backend,
    ):
        self.base_model = base_model
        self.lora_paths = lora_paths
@@ -101,8 +114,9 @@
        self.load_config = load_config
        self.dtype = dtype

        workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
        self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)
        logger.info(f"Using {lora_backend} as backend of Lora kernels.")
        backend_type = get_backend_from_name(lora_backend)
        self.lora_backend = backend_type(lora_backend)

        self.init_loras()
        self.init_lora_memory_pool()
@@ -123,7 +137,7 @@

    def set_lora_module(self, module_name, module):
        lora_module = get_lora_layer(
            module, self.segment_gemm, self.max_lora_dim, self.scaling
            module, self.max_lora_dim, self.scaling, self.lora_backend
        )
        replace_submodule(self.base_model, module_name, lora_module)
        return lora_module
@@ -162,7 +176,11 @@
            self.lora_id[name] = len(self.loras)
            self.loras.append(
                LoRAAdapter(
                    name, self.configs[name], self.base_hf_config, self.load_config
                    name,
                    self.configs[name],
                    self.base_hf_config,
                    self.load_config,
                    self.lora_backend,
                )
            )
            self.loras[-1].initialize_weights()
@@ -226,8 +244,9 @@
            self.B_buffer[module_B] = [
                torch.empty(
                    (
                        c,
                        self.max_loras_per_batch,
                        hidden_dim_B * c,
                        hidden_dim_B,
                        self.max_lora_dim,
                    ),
                    dtype=self.dtype,
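For orientation (hypothetical numbers, not from the commit), a stacked module such as kv_proj now keeps one leading slot per stacked projection, so its lora_b pool would look like:

import torch

# c = 2 stacked projections (k and v), 8 adapter slots,
# hidden_dim_B = 1024 per projection, max_lora_dim = 16
B_buffer_kv = torch.empty(
    (2, 8, 1024, 16),  # (c, max_loras_per_batch, hidden_dim_B, max_lora_dim)
    dtype=torch.float16,
    device="cuda",
)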
@@ -263,7 +282,16 @@
                else:
                    lora_weight_name = self.get_weight_name(name, 1)
                    if lora_weight_name:
                        self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights)
                        c = self.loras[-1].get_stacked_multiply(lora_weight_name)
                        if c > 1:
                            for j in range(c):
                                self.B_buffer[lora_weight_name][i][j][buffer_id].copy_(
                                    weights[j]
                                )
                        else:
                            self.B_buffer[lora_weight_name][i][0][buffer_id].copy_(
                                weights
                            )

    def prepare_lora_batch(self, forward_batch: ForwardBatch):
        # load active loras into lora memory pool
@@ -292,20 +320,30 @@
        if cur_uids == set([None]):
            return

        # setup lora in forward modules
        # set up batch info shared by all lora modules
        bs = forward_batch.batch_size
        seg_lens = (
            forward_batch.extend_seq_lens
            if forward_batch.forward_mode.is_extend()
            else torch.ones(bs, device="cuda")
        )
        # FIXME: reuse the data rather than recompute
        seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
        seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
        max_len = int(torch.max(seg_lens))
        weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
        for i, lora_path in enumerate(forward_batch.lora_paths):
            weight_indices[i] = self.buffer_id[lora_path]

        batch_info = LoraBatchInfo(
            bs=bs,
            seg_lens=seg_lens,
            seg_indptr=seg_indptr,
            max_len=max_len,
            weight_indices=weight_indices,
        )
        self.lora_backend.set_batch_info(batch_info)

        # call set_lora_info for each lora module
        for module_name, module in self.lora_modules:
            layer_id = get_layer_id(module_name)

@@ -314,16 +352,10 @@
                module.set_lora_info(
                    self.A_buffer[weight_name][layer_id],
                    self.B_buffer[weight_name][layer_id],
                    bs,
                    seg_indptr,
                    weight_indices,
                )
            else:
                module.set_lora_info(
                    self.A_buffer["qkv_proj"][layer_id],
                    self.B_buffer["q_proj"][layer_id],
                    self.B_buffer["kv_proj"][layer_id],
                    bs,
                    seg_indptr,
                    weight_indices,
                )
python/sglang/srt/lora/triton_ops/__init__.py (new file, +5 lines)
@@ -0,0 +1,5 @@
from .qkv_lora_b import qkv_lora_b_fwd
from .sgemm_lora_a import sgemm_lora_a_fwd
from .sgemm_lora_b import sgemm_lora_b_fwd

__all__ = ["qkv_lora_b_fwd", "sgemm_lora_a_fwd", "sgemm_lora_b_fwd"]
python/sglang/srt/lora/triton_ops/qkv_lora_b.py (new file, +182 lines)
@@ -0,0 +1,182 @@
import torch
import triton
import triton.language as tl

from sglang.srt.lora.lora import LoraBatchInfo


@triton.jit
def _qkv_lora_b_kernel(
    # Pointers to matrices
    x,
    weights,
    output,
    # Parameters of size
    K,  # K = R
    max_qkv_out_dim,  # max(output_q_dim, output_kv_dim)
    # Strides
    x_stride_0,
    x_stride_1,
    w_stride_0,
    w_stride_1,
    w_stride_2,
    output_stride_0,
    output_stride_1,
    # Information on sequence lengths and weight id
    seg_lens,
    seg_indptr,
    weight_indices,
    # Offsets of q/k/v slice on output dimension
    n_offs,
    # Meta parameters
    BLOCK_S: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    # For fused output scaling and adding
    fuse_scaling_add,
    scaling,
):
    # This kernel packs 3 sgemms (q/k/v) into a single kernel.

    # x: (s, 3 * K), s is the sum of sequence lengths, K equals to lora rank
    # weights: (num_lora, N_Q + 2 * N_KV, K)
    # output: (s, N_Q + 2 * N_KV)
    # N_Q >> K, N_KV >> K

    # Current block computes sequence with batch_id,
    # which starts from row seg_start of x with length seg_len.
    # qkv_id decides which of q,k,v to compute (0: q, 1: k, 2: v)
    batch_id = tl.program_id(axis=2)
    qkv_id = tl.program_id(axis=1)
    pid = tl.program_id(axis=0)
    seg_len = tl.load(seg_lens + batch_id)
    w_index = tl.load(weight_indices + batch_id)
    seg_start = tl.load(seg_indptr + batch_id)
    n_start = tl.load(n_offs + qkv_id)
    n_size = tl.load(n_offs + qkv_id + 1) - n_start

    # The tile in output matrix will have (pid_s, pid_n) as id
    num_pid_n = tl.cdiv(max_qkv_out_dim, BLOCK_N)
    pid_s = pid // num_pid_n
    pid_n = pid % num_pid_n

    # Create pointers for the first block of x and weights[batch_id][n_start: n_end][:]
    # The pointers will be advanced as we move in the K direction
    # and accumulate
    s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
    n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    k_offset = tl.arange(0, BLOCK_K)

    x_ptrs = (x + seg_start * x_stride_0 + (qkv_id * K) * x_stride_1) + (
        s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
    )
    w_ptrs = (weights + w_index * w_stride_0 + n_start * w_stride_1) + (
        k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
    )

    # Iterate to compute the block in output matrix
    partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        x_tile = tl.load(
            x_ptrs,
            mask=(s_offset[:, None] < seg_len)
            and (k_offset[None, :] < K - k * BLOCK_K),
            other=0.0,
        )
        w_tile = tl.load(
            w_ptrs,
            mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < n_size),
            other=0.0,
        )
        partial_sum += tl.dot(x_tile, w_tile)

        x_ptrs += BLOCK_K * x_stride_1
        w_ptrs += BLOCK_K * w_stride_2

    # Store result to output matrix
    partial_sum *= scaling
    partial_sum = partial_sum.to(x.dtype.element_ty)
    output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
        s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
    )
    output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size)
    if fuse_scaling_add:
        partial_sum += tl.load(output_ptr, mask=output_mask)
    tl.store(output_ptr, partial_sum, mask=output_mask)


def qkv_lora_b_fwd(
    x: torch.Tensor,
    qkv_lora_b: torch.Tensor,
    batch_info: LoraBatchInfo,
    output_offset: torch.Tensor,
    max_qkv_out_dim: int,
    base_output: torch.Tensor = None,
    scaling: float = 1.0,
) -> torch.Tensor:

    # x: (s, 3 * r)
    # qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
    # output_offset = [0, output_dim_q, output_dim_q + output_dim_kv,
    #                  output_dim_q + 2 * output_dim_kv]
    # max_qkv_out_dim = max(output_dim_q, output_dim_kv)
    # output: (s, output_dim_q + 2 * output_dim_kv)

    # Compute lora_output with shape (s, output_dim) as follows:
    # lora_output[:, :output_dim_q] = sgemm(lora_output_a[:, :r], q_lora_b)
    # lora_output[:, output_dim_q: output_dim_q + output_dim_kv]
    #      = sgemm(lora_output_a[:, r: 2 * r], kv_lora_b[0])
    # lora_output[:, output_dim_q + output_dim_kv: ]
    #      = sgemm(lora_output_a[:, 2 * r: 3 * r], kv_lora_b[1])

    # Get dims
    s = x.shape[0]
    input_dim = x.shape[1]
    r = qkv_lora_b.shape[-1]
    output_dim = qkv_lora_b.shape[-2]
    assert input_dim == 3 * r
    assert output_offset.shape[0] == 4

    BLOCK_S = 16
    BLOCK_R = 16
    BLOCK_OUT = 64

    grid_b = (
        triton.cdiv(batch_info.max_len, BLOCK_S)
        * triton.cdiv(max_qkv_out_dim, BLOCK_OUT),
        3,  # this dimension decides current block computes on q, k or v
        batch_info.bs,
    )

    if base_output is None:
        output = torch.empty((s, output_dim), device=x.device, dtype=x.dtype)
        fuse_scaling_add = False
    else:
        output = base_output
        fuse_scaling_add = True

    _qkv_lora_b_kernel[grid_b](
        x,
        qkv_lora_b,
        output,
        r,
        max_qkv_out_dim,
        x.stride(0),
        x.stride(1),
        qkv_lora_b.stride(0),
        qkv_lora_b.stride(1),
        qkv_lora_b.stride(2),
        output.stride(0),
        output.stride(1),
        batch_info.seg_lens,
        batch_info.seg_indptr,
        batch_info.weight_indices,
        output_offset,
        BLOCK_S,
        BLOCK_OUT,
        BLOCK_R,
        fuse_scaling_add,
        scaling,
    )

    return output
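As a cross-checking aid (not part of the commit), a slow PyTorch reference of what qkv_lora_b_fwd computes, assuming the argument shapes documented above, might read:

import torch

def qkv_lora_b_reference(x, qkv_lora_b, seg_indptr, weight_indices,
                         output_offset, base_output=None, scaling=1.0):
    # x: (s, 3 * r), qkv_lora_b: (num_lora, out_dim_q + 2 * out_dim_kv, r)
    r = x.shape[1] // 3
    if base_output is None:
        out = torch.zeros((x.shape[0], qkv_lora_b.shape[-2]), device=x.device, dtype=x.dtype)
    else:
        out = base_output
    for i in range(len(weight_indices)):                      # one segment per sequence
        rows = slice(int(seg_indptr[i]), int(seg_indptr[i + 1]))
        w = qkv_lora_b[int(weight_indices[i])]                # (out_dim, r)
        for qkv_id in range(3):                               # q, k, v output slices
            n0, n1 = int(output_offset[qkv_id]), int(output_offset[qkv_id + 1])
            out[rows, n0:n1] += scaling * (
                x[rows, qkv_id * r : (qkv_id + 1) * r] @ w[n0:n1].T
            )
    return out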
python/sglang/srt/lora/triton_ops/sgemm_lora_a.py (new file, +143 lines)
@@ -0,0 +1,143 @@
import torch
import triton
import triton.language as tl

from sglang.srt.lora.lora import LoraBatchInfo


@triton.jit
def _sgemm_lora_a_kernel(
    # Pointers to matrices
    x,
    weights,
    output,
    # Matrix dimensions
    N,  # r
    K,  # input_dim
    # Strides
    x_stride_0,
    x_stride_1,
    w_stride_0,
    w_stride_1,
    w_stride_2,
    output_stride_0,
    output_stride_1,
    # Information on sequence lengths and weight id
    seg_lens,
    seg_indptr,
    weight_indices,
    # Meta parameters
    BLOCK_S: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):

    # x: (s, K), s is the sum of sequence lengths
    # weights: (num_lora, N, K)
    # output: (s, N)

    # Current block computes sequence with batch_id,
    # which starts from row seg_start of x with length seg_len
    batch_id = tl.program_id(axis=1)
    pid = tl.program_id(axis=0)
    seg_len = tl.load(seg_lens + batch_id)
    w_index = tl.load(weight_indices + batch_id)
    seg_start = tl.load(seg_indptr + batch_id)

    # The tile in output matrix will have (pid_s, pid_n) as id
    num_pid_n = tl.cdiv(N, BLOCK_N)
    pid_s = pid // num_pid_n
    pid_n = pid % num_pid_n

    # Create pointers for the first block of x and weights[batch_id]
    # The pointers will be advanced as we move in the K direction
    # and accumulate
    s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
    n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    k_offset = tl.arange(0, BLOCK_K)
    x_ptrs = (x + seg_start * x_stride_0) + (
        s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
    )
    w_ptrs = (weights + w_index * w_stride_0) + (
        k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
    )

    # Iterate to compute the block in output matrix
    partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        x_tile = tl.load(
            x_ptrs,
            mask=(s_offset[:, None] < seg_len)
            and (k_offset[None, :] < K - k * BLOCK_K),
            other=0.0,
        )
        w_tile = tl.load(
            w_ptrs,
            mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < N),
            other=0.0,
        )
        partial_sum += tl.dot(x_tile, w_tile)

        x_ptrs += BLOCK_K * x_stride_1
        w_ptrs += BLOCK_K * w_stride_2

    # Store result to output matrix
    partial_sum = partial_sum.to(x.dtype.element_ty)
    output_ptr = (output + seg_start * output_stride_0) + (
        s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
    )
    output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < N)
    tl.store(output_ptr, partial_sum, mask=output_mask)


def sgemm_lora_a_fwd(
    x: torch.Tensor, weights: torch.Tensor, batch_info: LoraBatchInfo
) -> torch.Tensor:
    # x: (s, input_dim)
    # weights: (num_lora, r, input_dim)
    # output: (s, r)
    # when called by run_qkv_lora, the weights.shape[-2] will be 3 * r
    # input_dim is much larger than r

    assert x.is_contiguous()
    assert weights.is_contiguous()
    assert len(x.shape) == 2
    assert len(weights.shape) == 3

    S = x.shape[0]
    R = weights.shape[-2]
    K = weights.shape[-1]
    assert x.shape[-1] == K

    # Block shapes
    BLOCK_S = 16
    BLOCK_K = 256
    BLOCK_R = 16

    grid = (
        triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(R, BLOCK_R),
        batch_info.bs,
    )

    output = torch.empty((S, R), device=x.device, dtype=x.dtype)
    _sgemm_lora_a_kernel[grid](
        x,
        weights,
        output,
        R,
        K,
        x.stride(0),
        x.stride(1),
        weights.stride(0),
        weights.stride(1),
        weights.stride(2),
        output.stride(0),
        output.stride(1),
        batch_info.seg_lens,
        batch_info.seg_indptr,
        batch_info.weight_indices,
        BLOCK_S,
        BLOCK_R,
        BLOCK_K,
    )
    return output
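A quick sanity check of the launch grid above, using the small hypothetical batch from the earlier example (max_len = 5, bs = 3) and r = 16:

import triton

max_len, bs, R = 5, 3, 16            # hypothetical batch, not from the commit
BLOCK_S, BLOCK_R = 16, 16
grid = (
    triton.cdiv(max_len, BLOCK_S) * triton.cdiv(R, BLOCK_R),  # 1 * 1 = 1 tile per sequence
    bs,                                                       # one program per sequence on axis 1
)
assert grid == (1, 3)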
python/sglang/srt/lora/triton_ops/sgemm_lora_b.py (new file, +159 lines)
@@ -0,0 +1,159 @@
import torch
import triton
import triton.language as tl

from sglang.srt.lora.lora import LoraBatchInfo


@triton.jit
def _sgemm_lora_b_kernel(
    # Pointers to matrices
    x,
    weights,
    output,
    # Matrix dimensions
    N,  # output_dim
    K,  # r
    # Strides
    x_stride_0,
    x_stride_1,
    w_stride_0,
    w_stride_1,
    w_stride_2,
    output_stride_0,
    output_stride_1,
    # Information on sequence lengths and weight id
    seg_lens,
    seg_indptr,
    weight_indices,
    # Meta parameters
    BLOCK_S: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    # For fused output scaling and adding
    fuse_scaling_add,
    scaling,
):
    # x: (s, K), s is the sum of sequence lengths
    # weights: (num_lora, N, K)
    # output: (s, N)

    # Current block computes sequence with batch_id,
    # which starts from row seg_start of x with length seg_len
    batch_id = tl.program_id(axis=1)
    pid = tl.program_id(axis=0)
    seg_len = tl.load(seg_lens + batch_id)
    w_index = tl.load(weight_indices + batch_id)
    seg_start = tl.load(seg_indptr + batch_id)

    # The tile in output matrix will have (pid_s, pid_n) as id
    num_pid_n = tl.cdiv(N, BLOCK_N)
    pid_s = pid // num_pid_n
    pid_n = pid % num_pid_n

    # Create pointers for the first block of x and weights[batch_id]
    # The pointers will be advanced as we move in the K direction
    # and accumulate
    s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
    n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    k_offset = tl.arange(0, BLOCK_K)
    x_ptrs = (x + seg_start * x_stride_0) + (
        s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
    )
    w_ptrs = (weights + w_index * w_stride_0) + (
        k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
    )

    # Iterate to compute the block in output matrix
    partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        x_tile = tl.load(
            x_ptrs,
            mask=(s_offset[:, None] < seg_len)
            and (k_offset[None, :] < K - k * BLOCK_K),
            other=0.0,
        )
        w_tile = tl.load(
            w_ptrs,
            mask=(k_offset[:, None] < K - k * BLOCK_K),
            other=0.0,
        )
        partial_sum += tl.dot(x_tile, w_tile)

        x_ptrs += BLOCK_K * x_stride_1
        w_ptrs += BLOCK_K * w_stride_2

    # Store result to output matrix
    partial_sum *= scaling
    partial_sum = partial_sum.to(x.dtype.element_ty)
    output_ptr = (output + seg_start * output_stride_0) + (
        s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
    )
    output_mask = s_offset[:, None] < seg_len
    if fuse_scaling_add:
        partial_sum += tl.load(output_ptr, mask=output_mask)
    tl.store(output_ptr, partial_sum, mask=output_mask)


def sgemm_lora_b_fwd(
    x: torch.Tensor,
    weights: torch.Tensor,
    batch_info: LoraBatchInfo,
    base_output: torch.Tensor = None,
    scaling: float = 1.0,
) -> torch.Tensor:
    # x: (s, r)
    # weights: (num_lora, output_dim, r)
    # output: (s, output_dim)
    # output_dim is much larger than r

    assert x.is_contiguous()
    assert weights.is_contiguous()
    assert len(x.shape) == 2
    assert len(weights.shape) == 3

    S = x.shape[0]
    N = weights.shape[-2]
    R = weights.shape[-1]
    assert x.shape[-1] == R

    # Block shapes
    BLOCK_S = 16
    BLOCK_R = 16
    BLOCK_N = 256

    grid = (
        triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(N, BLOCK_N),
        batch_info.bs,
    )

    if base_output is None:
        output = torch.empty((S, N), device=x.device, dtype=x.dtype)
        fuse_scaling_add = False
    else:
        output = base_output
        fuse_scaling_add = True

    _sgemm_lora_b_kernel[grid](
        x,
        weights,
        output,
        N,
        R,
        x.stride(0),
        x.stride(1),
        weights.stride(0),
        weights.stride(1),
        weights.stride(2),
        output.stride(0),
        output.stride(1),
        batch_info.seg_lens,
        batch_info.seg_indptr,
        batch_info.weight_indices,
        BLOCK_S,
        BLOCK_N,
        BLOCK_R,
        fuse_scaling_add,
        scaling,
    )
    return output
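A minimal sketch of how the fused scaling-and-add path could be spot-checked against plain PyTorch (hypothetical shapes, requires a CUDA device; not part of the commit):

import torch

from sglang.srt.lora.lora import LoraBatchInfo
from sglang.srt.lora.triton_ops import sgemm_lora_b_fwd

# One sequence of 4 tokens, one adapter, rank 16, output_dim 256.
x = torch.randn(4, 16, dtype=torch.float16, device="cuda")
weights = torch.randn(1, 256, 16, dtype=torch.float16, device="cuda")
base = torch.randn(4, 256, dtype=torch.float16, device="cuda")
info = LoraBatchInfo(
    bs=1,
    seg_lens=torch.tensor([4], dtype=torch.int32, device="cuda"),
    seg_indptr=torch.tensor([0, 4], dtype=torch.int32, device="cuda"),
    max_len=4,
    weight_indices=torch.tensor([0], dtype=torch.int64, device="cuda"),
)
expected = base + 2.0 * (x @ weights[0].T)  # plain PyTorch reference, scaling = 2.0
# base_output is written in place when fused, so pass a clone
got = sgemm_lora_b_fwd(x, weights, info, base_output=base.clone(), scaling=2.0)
torch.testing.assert_close(got, expected, rtol=1e-2, atol=1e-2)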
python/sglang/srt/model_executor/model_runner.py
@@ -530,6 +530,7 @@ class ModelRunner:
            max_loras_per_batch=self.server_args.max_loras_per_batch,
            load_config=self.load_config,
            dtype=self.dtype,
            lora_backend=self.server_args.lora_backend,
        )
        logger.info("LoRA manager ready.")

python/sglang/srt/server_args.py
@@ -113,6 +113,7 @@ class ServerArgs:
    # LoRA
    lora_paths: Optional[List[str]] = None
    max_loras_per_batch: int = 8
    lora_backend: str = "triton"

    # Kernel backend
    attention_backend: Optional[str] = None
@@ -653,13 +654,19 @@
            nargs="*",
            default=None,
            action=LoRAPathAction,
            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}.",
        )
        parser.add_argument(
            "--max-loras-per-batch",
            type=int,
            default=8,
            help="Maximum number of adapters for a running batch, include base-only request",
            help="Maximum number of adapters for a running batch, include base-only request.",
        )
        parser.add_argument(
            "--lora-backend",
            type=str,
            default="triton",
            help="Choose the kernel backend for multi-LoRA serving.",
        )

        # Kernel backend
python/sglang/test/runners.py
@@ -272,6 +272,7 @@ class SRTRunner:
        port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
        lora_paths: List[str] = None,
        max_loras_per_batch: int = 4,
        lora_backend: str = "triton",
        disable_cuda_graph: bool = False,
        disable_radix_cache: bool = False,
    ):
@@ -287,6 +288,7 @@
            is_embedding=not self.is_generation,
            lora_paths=lora_paths,
            max_loras_per_batch=max_loras_per_batch,
            lora_backend=lora_backend,
            disable_cuda_graph=disable_cuda_graph,
            disable_radix_cache=disable_radix_cache,
        )