Improve LoRA Perf by Deprecating FlashInfer and Eliminating Redundant Tensor Ops (#8940)
@@ -5,22 +5,6 @@ import torch
from sglang.srt.lora.utils import LoRABatchInfo


def get_fuse_output_add_from_name(name: str) -> bool:
    mapping = {
        "triton": True,
        "flashinfer": False,
    }
    return mapping.get(name, False)


def get_fuse_stacked_lora_b_from_name(name: str) -> bool:
    mapping = {
        "triton": True,
        "flashinfer": False,
    }
    return mapping.get(name, False)


class BaseLoRABackend:
    """Base class for different Lora backends.
    Each backend has its own implementation of Lora kernels.
@@ -28,15 +12,11 @@ class BaseLoRABackend:
    Args:
        name: name of backend
        batch_info: information of current batch for use
        fuse_output_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward,
            and the operation of adding will be fused into kernel
    """

    def __init__(self, name: str, batch_info: LoRABatchInfo = None):
        self.name = name
        self.batch_info = batch_info
        self.fuse_output_add = get_fuse_output_add_from_name(name)
        self.fuse_stacked_lora_b = get_fuse_stacked_lora_b_from_name(name)

    def run_lora_a_sgemm(
        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
@@ -126,8 +106,8 @@ def get_backend_from_name(name: str) -> BaseLoRABackend:

        return TritonLoRABackend
    elif name == "flashinfer":
        from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend

        return FlashInferLoRABackend
        raise ValueError(
            "FlashInfer LoRA backend has been deprecated, please use `triton` instead."
        )
    else:
        raise ValueError(f"Invalid backend: {name}")
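For orientation (not part of the commit), a minimal sketch of how the dispatch above now behaves, assuming `get_backend_from_name` is importable from `sglang.srt.lora.backend.base_backend` as the surrounding hunks suggest:

```python
from sglang.srt.lora.backend.base_backend import get_backend_from_name

backend_cls = get_backend_from_name("triton")  # -> TritonLoRABackend

try:
    get_backend_from_name("flashinfer")
except ValueError as err:
    # "FlashInfer LoRA backend has been deprecated, please use `triton` instead."
    print(err)
```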
@@ -1,131 +0,0 @@
from typing import Tuple

import torch

from sglang.srt.lora.backend.base_backend import BaseLoRABackend
from sglang.srt.lora.utils import LoRABatchInfo
from sglang.srt.utils import is_flashinfer_available

if is_flashinfer_available():
    from flashinfer import SegmentGEMMWrapper


class FlashInferLoRABackend(BaseLoRABackend):

    def __init__(self, name: str, batch_info: LoRABatchInfo = None):
        super().__init__(name, batch_info)

        # Set up SGemm Wrapper from flashinfer
        # FIXME wait for flashinfer segment gemm update
        workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
        self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)

    def run_lora_a_sgemm(
        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:

        return self.segment_gemm.run(
            x=x,
            weights=weights,
            batch_size=self.batch_info.bs,
            weight_column_major=True,
            seg_indptr=self.batch_info.seg_indptr,
            weight_indices=self.batch_info.weight_indices,
        )

    def run_lora_b_sgemm(
        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:

        return (
            self.segment_gemm.run(
                x=x,
                weights=weights,
                batch_size=self.batch_info.bs,
                weight_column_major=True,
                seg_indptr=self.batch_info.seg_indptr,
                weight_indices=self.batch_info.weight_indices,
            )
            * self.batch_info.scalings[0]
        )

    def run_qkv_lora(
        self,
        x: torch.Tensor,
        qkv_lora_a: torch.Tensor,
        qkv_lora_b: Tuple[torch.Tensor],
        *args,
        **kwargs,
    ) -> torch.Tensor:

        assert isinstance(qkv_lora_b, tuple) and len(qkv_lora_b) == 2

        # Shape of lora_a_output: (s, 3 * r)
        lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a)

        q_lora_b, kv_lora_b = qkv_lora_b
        lora_rank = kv_lora_b.shape[-1]
        output_dim_q = q_lora_b.shape[-2]
        output_dim_kv = kv_lora_b.shape[-2]
        lora_output = torch.empty(
            (x.shape[0], output_dim_q + 2 * output_dim_kv),
            device=x.device,
            dtype=x.dtype,
        )

        # q
        lora_output[:, :output_dim_q] = self.run_lora_b_sgemm(
            x=lora_a_output[:, :lora_rank].contiguous(), weights=q_lora_b[0]
        )

        # kv
        lora_output[:, output_dim_q : output_dim_q + output_dim_kv] = (
            self.run_lora_b_sgemm(
                x=lora_a_output[:, lora_rank : 2 * lora_rank].contiguous(),
                weights=kv_lora_b[0],
            )
        )

        lora_output[
            :, output_dim_q + output_dim_kv : output_dim_q + 2 * output_dim_kv
        ] = self.run_lora_b_sgemm(
            x=lora_a_output[:, 2 * lora_rank : 3 * lora_rank].contiguous(),
            weights=kv_lora_b[1],
        )

        return lora_output * self.batch_info.scalings[0]

    def run_gate_up_lora(
        self,
        x: torch.Tensor,
        gate_up_lora_a: torch.Tensor,
        gate_up_lora_b: Tuple[torch.Tensor],
        *args,
        **kwargs,
    ) -> torch.Tensor:

        assert isinstance(gate_up_lora_b, tuple) and len(gate_up_lora_b) == 2
        lora_rank = gate_up_lora_b[0].shape[-1]
        output_dim = gate_up_lora_b[0].shape[-2]

        # Shape of lora_a_output: (s, 2 * r)
        lora_a_output = self.run_lora_a_sgemm(x=x, weights=gate_up_lora_a)

        lora_output = torch.empty(
            (x.shape[0], 2 * output_dim),
            device=x.device,
            dtype=x.dtype,
        )

        # Compute lora for gate and up proj respectively
        lora_output[:, :output_dim] = self.run_lora_b_sgemm(
            x=lora_a_output[:, :lora_rank].contiguous(),
            weights=gate_up_lora_b[0],
        )

        lora_output[:, output_dim:] = self.run_lora_b_sgemm(
            x=lora_a_output[:, lora_rank:].contiguous(),
            weights=gate_up_lora_b[1],
        )

        return lora_output * self.batch_info.scalings[0]
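For reference (not part of the diff), the computation both backends implement is the standard LoRA update; a minimal dense sketch with made-up shapes, ignoring the segment/batch bookkeeping that the removed SegmentGEMM calls handled:

```python
import torch

def lora_delta(x, lora_a, lora_b, scaling):
    # x: (tokens, hidden), lora_a: (r, hidden), lora_b: (out_dim, r)
    # Equivalent to run_lora_a_sgemm followed by run_lora_b_sgemm and scaling.
    return (x @ lora_a.T) @ lora_b.T * scaling

x = torch.randn(4, 64)
lora_a = torch.randn(8, 64)          # rank r = 8
lora_b = torch.randn(64, 8)
base_output = torch.randn(4, 64)
output = base_output + lora_delta(x, lora_a, lora_b, scaling=0.5)
```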
@@ -1,5 +1,3 @@
from typing import List, Tuple

import torch
from torch import nn

@@ -79,18 +77,13 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
        self.B_buffer = B_buffer

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        backend_kwargs = {"base_output": base_output}
        lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
        lora_output = self.lora_backend.run_lora_b_sgemm(
            lora_a_output,
            self.B_buffer[0],
            **backend_kwargs,
        )
        return (
            lora_output
            if self.lora_backend.fuse_output_add
            else base_output + lora_output
            x=lora_a_output,
            weights=self.B_buffer,
            base_output=base_output,
        )
        return lora_output

    def forward(self, input_: torch.Tensor):
        # duplicate the logic in ColumnParallelLinear
@@ -135,37 +128,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    ):
        self.set_lora = True
        self.A_buffer_gate_up = A_buffer
        if self.lora_backend.fuse_stacked_lora_b:
            # B_buffer_gate_up: (num_lora, 2 * output_dim, r)
            if getattr(self, "B_buffer_gate_up", None) is None:
                self.B_buffer_gate_up = torch.empty(
                    (
                        B_buffer[0].shape[0],
                        2 * B_buffer[0].shape[1],
                        B_buffer[0].shape[2],
                    ),
                    dtype=B_buffer[0].dtype,
                    device=B_buffer[0].device,
                )
            self.B_buffer_gate_up[:, : B_buffer[0].shape[1], :].copy_(B_buffer[0])
            self.B_buffer_gate_up[:, B_buffer[0].shape[1] :, :].copy_(B_buffer[1])
        else:
            self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])
        self.B_buffer_gate_up = B_buffer

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        backend_kwargs = {"base_output": base_output}

        lora_output = self.lora_backend.run_gate_up_lora(
            x,
            self.A_buffer_gate_up,
            self.B_buffer_gate_up,
            **backend_kwargs,
        )
        return (
            lora_output
            if self.lora_backend.fuse_output_add
            else base_output + lora_output
            x=x,
            gate_up_lora_a=self.A_buffer_gate_up,
            gate_up_lora_b=self.B_buffer_gate_up,
            base_output=base_output,
        )
        return lora_output

    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
        return A
@@ -173,9 +145,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
        # Since the outputs for both gate and up are identical, we use a random one.
        shard_size = self.base_layer.output_partition_sizes[0]
        gate_size = self.base_layer.output_sizes[0]
        start_idx = tp_rank * shard_size
        end_idx = (tp_rank + 1) * shard_size
        return B[:, start_idx:end_idx, :]
        return torch.concat(
            (
                B[start_idx:end_idx, :],
                B[gate_size + start_idx : gate_size + end_idx],
            ),
            dim=0,
        )
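A small sketch (hypothetical shapes, not part of the diff) of what the new `slice_lora_b_weights` above does: the stacked gate/up LoRA-B weight stores gate rows first, then up rows, and each tensor-parallel rank keeps its own shard of each half:

```python
import torch

def slice_gate_up_lora_b(B, gate_size, shard_size, tp_rank):
    # B: (gate_size + up_size, r), gate rows first, then up rows.
    start, end = tp_rank * shard_size, (tp_rank + 1) * shard_size
    return torch.concat(
        (B[start:end, :], B[gate_size + start : gate_size + end, :]), dim=0
    )

B = torch.arange(16.0).reshape(8, 2)  # gate_size = up_size = 4, rank r = 2
print(slice_gate_up_lora_b(B, gate_size=4, shard_size=2, tp_rank=1))
# rows 2-3 (gate shard) followed by rows 6-7 (up shard)
```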
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
@@ -185,86 +164,46 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
        lora_backend: BaseLoRABackend,
    ) -> None:
        super().__init__(base_layer, lora_backend)
        q_proj_shard_size = self.base_layer.q_proj_shard_size
        kv_proj_shard_size = self.base_layer.kv_proj_shard_size
        self.output_offset = torch.tensor(
            [
                0,
                q_proj_shard_size,
                q_proj_shard_size + kv_proj_shard_size,
                q_proj_shard_size + 2 * kv_proj_shard_size,
            ],
            dtype=torch.int32,
            device=next(self.base_layer.parameters()).device,
        )

        # For computing number of launched blocks
        self.max_qkv_out_dim = max(q_proj_shard_size, kv_proj_shard_size)
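As an illustration (hypothetical sizes, not from the diff), `output_offset` marks where the q, k, and v slices begin and end along the fused output dimension, which is what the kernels use to place each partial result:

```python
import torch

q_proj_shard_size = 2048   # hypothetical per-rank q width
kv_proj_shard_size = 512   # hypothetical per-rank k (and v) width

output_offset = torch.tensor(
    [
        0,
        q_proj_shard_size,
        q_proj_shard_size + kv_proj_shard_size,
        q_proj_shard_size + 2 * kv_proj_shard_size,
    ],
    dtype=torch.int32,
)
print(output_offset)  # tensor([   0, 2048, 2560, 3072], dtype=torch.int32)
```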
    def set_lora_info(
        self,
        A_buffer_qkv: torch.Tensor,
        B_buffer_q: torch.Tensor,
        B_buffer_kv: torch.Tensor,
        B_buffer_qkv: torch.Tensor,
    ):
        self.set_lora = True
        self.A_buffer_qkv = A_buffer_qkv

        if self.lora_backend.fuse_stacked_lora_b:
            assert (
                B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
            ), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
            output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]

            # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
            if getattr(self, "B_buffer_qkv", None) is None:
                self.B_buffer_qkv = torch.empty(
                    (
                        B_buffer_q[0].shape[0],
                        output_dim_q + 2 * output_dim_kv,
                        B_buffer_q[0].shape[2],
                    ),
                    dtype=B_buffer_q[0].dtype,
                    device=B_buffer_q[0].device,
                )
            self.B_buffer_qkv[:, :output_dim_q, :].copy_(B_buffer_q[0])
            self.B_buffer_qkv[:, output_dim_q : output_dim_q + output_dim_kv, :].copy_(
                B_buffer_kv[0]
            )
            self.B_buffer_qkv[:, output_dim_q + output_dim_kv :, :].copy_(
                B_buffer_kv[1]
            )

            # Offsets of q/k/v in output dimension
            if getattr(self, "output_offset", None) is None:
                self.output_offset = torch.tensor(
                    [
                        0,
                        output_dim_q,
                        output_dim_q + output_dim_kv,
                        output_dim_q + 2 * output_dim_kv,
                    ],
                    dtype=torch.int32,
                    device=B_buffer_q.device,
                )
            # For computing number of launched blocks
            self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
        else:
            self.B_buffer_qkv = (
                B_buffer_q,
                B_buffer_kv,
            )
        self.B_buffer_qkv = B_buffer_qkv

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        backend_kwargs = {"base_output": base_output}
        if self.lora_backend.fuse_stacked_lora_b:
            backend_kwargs["output_offset"] = self.output_offset
            backend_kwargs["max_qkv_out_dim"] = self.max_qkv_out_dim

        lora_output = self.lora_backend.run_qkv_lora(
            x,
            self.A_buffer_qkv,
            self.B_buffer_qkv,
            **backend_kwargs,
        )
        return (
            lora_output
            if self.lora_backend.fuse_output_add
            else base_output + lora_output
            x=x,
            qkv_lora_a=self.A_buffer_qkv,
            qkv_lora_b=self.B_buffer_qkv,
            base_output=base_output,
            output_offset=self.output_offset,
            max_qkv_out_dim=self.max_qkv_out_dim,
        )
        return lora_output

    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
        return A

    def slice_lora_b_weights(
        self, B: List[torch.Tensor], tp_rank: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        B_q, B_kv = B
    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int) -> torch.Tensor:
        base_layer = self.base_layer
        q_proj_shard_size = base_layer.q_proj_shard_size
        kv_proj_shard_size = base_layer.kv_proj_shard_size
@@ -277,7 +216,19 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
        kv_start_idx = kv_proj_shard_size * kv_shard_id
        kv_end_idx = kv_start_idx + kv_proj_shard_size

        return B_q[q_start_idx:q_end_idx, :], B_kv[:, kv_start_idx:kv_end_idx, :]
        q_size, k_size, _ = base_layer.output_sizes
        B_q_shard = B[q_start_idx:q_end_idx, :]
        B_k_shard = B[q_size + kv_start_idx : q_size + kv_end_idx, :]
        B_v_shard = B[q_size + k_size + kv_start_idx : q_size + k_size + kv_end_idx, :]

        return torch.concat(
            (
                B_q_shard,
                B_k_shard,
                B_v_shard,
            ),
            dim=0,
        )
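A sketch (hypothetical helper, not part of the diff) of the row arithmetic behind the new QKV slicing above: the stacked LoRA-B weight is laid out as q rows, then k rows, then v rows, and a rank keeps its q shard plus the matching k and v shards:

```python
import torch

def slice_qkv_lora_b(B, q_size, k_size, q_shard, kv_shard, tp_rank, kv_shard_id):
    # B: (q_size + k_size + v_size, r), rows ordered q, then k, then v.
    q_start, q_end = tp_rank * q_shard, (tp_rank + 1) * q_shard
    kv_start, kv_end = kv_shard_id * kv_shard, (kv_shard_id + 1) * kv_shard
    return torch.concat(
        (
            B[q_start:q_end, :],
            B[q_size + kv_start : q_size + kv_end, :],
            B[q_size + k_size + kv_start : q_size + k_size + kv_end, :],
        ),
        dim=0,
    )
```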
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
@@ -294,18 +245,13 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
        self.B_buffer = B_buffer

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        backend_kwargs = {"base_output": base_output}
        lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
        lora_output = self.lora_backend.run_lora_b_sgemm(
            lora_a_output,
            self.B_buffer[0],
            **backend_kwargs,
        )
        return (
            lora_output
            if self.lora_backend.fuse_output_add
            else base_output + lora_output
            x=lora_a_output,
            weights=self.B_buffer,
            base_output=base_output,
        )
        return lora_output

    def forward(self, input_: torch.Tensor):
        # duplicate the logic in RowParallelLinear

@@ -117,7 +117,6 @@ class LoRAAdapter(nn.Module):
            q_name = weight_name
            k_name = weight_name.replace("q_proj", "k_proj")
            v_name = weight_name.replace("q_proj", "v_proj")
            kv_name = weight_name.replace("q_proj", "kv_proj")
            qkv_name = weight_name.replace("q_proj", "qkv_proj")

            # If k_proj doesn't have lora, initialize it to zero
@@ -126,57 +125,27 @@ class LoRAAdapter(nn.Module):
                if "k_proj" in target_module
                else torch.zeros_like(weights[v_name])
            )
            if "lora_A" in weight_name:
                weights[qkv_name] = torch.cat(
                    (
                        weights[q_name],
                        k_proj_weight,
                        weights[v_name],
                    ),
                    0,
                )
                weights.pop(q_name)
                if "k_proj" in target_module:
                    weights.pop(k_name)
                weights.pop(v_name)
            else:
                weights[kv_name] = torch.stack(
                    [
                        k_proj_weight,
                        weights[v_name],
                    ],
                    dim=0,
                )
                if "k_proj" in target_module:
                    weights.pop(k_name)
                weights.pop(v_name)
            weights[qkv_name] = torch.cat(
                (
                    weights[q_name],
                    k_proj_weight,
                    weights[v_name],
                ),
                0,
            )
            weights.pop(q_name)
            if "k_proj" in target_module:
                weights.pop(k_name)
            weights.pop(v_name)
        elif "qkv_proj" in weight_name:
            # If qkv_proj is already stacked, we normalize it following the SGL convention.
            qkv_name = weight_name
            q_name = weight_name.replace("qkv_proj", "q_proj")
            k_name = weight_name.replace("qkv_proj", "k_proj")
            v_name = weight_name.replace("qkv_proj", "v_proj")
            kv_name = weight_name.replace("qkv_proj", "kv_proj")
            if "lora_A" in weight_name:
                weights[qkv_name] = weights[qkv_name].repeat(3, 1)
            else:
                head_size = (
                    self.base_hf_config.hidden_size
                    // self.base_hf_config.num_attention_heads
                )
                weights[q_name], k_proj_weight, v_proj_weight = torch.split(
                    weights[qkv_name],
                    [
                        head_size * self.base_hf_config.num_attention_heads,
                        head_size * self.base_hf_config.num_key_value_heads,
                        head_size * self.base_hf_config.num_key_value_heads,
                    ],
                    dim=0,
                )
                weights[kv_name] = torch.stack(
                    [k_proj_weight, v_proj_weight],
                    dim=0,
                )
            # else: no-op as LoRA B weight is already stacked.
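As an illustration (hypothetical shapes, not from the diff) of the normalization above: separate q/k/v LoRA-A weights are stacked row-wise into a single qkv_proj weight, with a zero block standing in for any projection the adapter does not target:

```python
import torch

r, hidden = 8, 64
q_A = torch.randn(r, hidden)
v_A = torch.randn(r, hidden)
k_A = torch.zeros_like(v_A)            # k_proj absent from the adapter

qkv_A = torch.cat((q_A, k_A, v_A), 0)  # shape (3 * r, hidden)
```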
    def normalize_gate_up_proj(
        self, weight_names: List[str], weights: Dict[str, torch.Tensor]
@@ -187,20 +156,14 @@ class LoRAAdapter(nn.Module):
            gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
            if up_name not in weights:
                weights[up_name] = torch.zeros_like(weights[weight_name])
                # FIXME: Add gate-only support for flashinfer in future implementations
                assert self.lora_backend.name == "triton", (
                    f"LoRA weight initialization currently only supported for 'triton' backend. "
                    f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
                    f"or consider implementing custom initialization logic for other backends."
                )
            if "lora_A" in weight_name:
                weights[gate_up_name] = torch.cat(
                    (weights[weight_name], weights[up_name]), 0
                )
            else:
                weights[gate_up_name] = torch.stack(
                    [weights[weight_name], weights[up_name]], dim=0
                )
            weights[gate_up_name] = torch.cat(
                (weights[weight_name], weights[up_name]), 0
            )
            weights.pop(weight_name)
            if up_name in weights:
                weights.pop(up_name)
@@ -209,12 +172,4 @@ class LoRAAdapter(nn.Module):
            gate_up_name = weight_name
            if "lora_A" in weight_name:
                weights[gate_up_name] = weights[gate_up_name].repeat(2, 1)
            else:
                output_dim = weights[gate_up_name].shape[0] // 2
                weights[gate_up_name] = torch.stack(
                    [
                        weights[gate_up_name][:output_dim, :],
                        weights[gate_up_name][output_dim:, :],
                    ],
                    dim=0,
                )
            # else: no-op as LoRA B weight is already stacked.

@@ -31,7 +31,6 @@ from sglang.srt.lora.mem_pool import LoRAMemoryPool
from sglang.srt.lora.utils import (
    LoRABatchInfo,
    LoRAType,
    get_customized_names_from_hf_names,
    get_layer_id,
    get_normalized_lora_weight_names,
    get_weight_name,
@@ -345,40 +344,19 @@ class LoRAManager:
        )
        self.lora_backend.set_batch_info(batch_info)

        # TODO (lifuhuang): one potential perf optimization that is worth considering is to see if we can call
        # this method only when loading/unloading LoRA adapters, instead of calling it for every micro-batch.
        self.update_lora_info()

    def update_lora_info(self):
        """
        Update all LoRA modules to associate them with the latest memory buffer.
        """
        for layer_id, layer_modules in enumerate(self.lora_modules):
            for module_name, module in layer_modules.items():
                if "qkv_proj" in module_name:
                    module.set_lora_info(
                        self.memory_pool.get_tensor(
                            "qkv_proj", layer_id, LoRAType.LORA_A
                        ),
                        self.memory_pool.get_tensor(
                            "q_proj", layer_id, LoRAType.LORA_B
                        ),
                        self.memory_pool.get_tensor(
                            "kv_proj", layer_id, LoRAType.LORA_B
                        ),
                    )
                else:
                    weight_name = get_weight_name(
                        module_name, self.memory_pool.lora_weight_names, LoRAType.LORA_A
                    )
                    module.set_lora_info(
                        self.memory_pool.get_tensor(
                            weight_name, layer_id, LoRAType.LORA_A
                        ),
                        self.memory_pool.get_tensor(
                            weight_name, layer_id, LoRAType.LORA_B
                        ),
                    )
                weight_name = get_weight_name(
                    module_name, self.memory_pool.lora_weight_names
                )
                module.set_lora_info(
                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_A),
                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_B),
                )

    def init_state(
        self,
@@ -405,6 +383,7 @@ class LoRAManager:
        self.init_lora_weight_names()
        self.init_lora_modules()
        self.init_memory_pool()
        self.update_lora_info()

    def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
        # Configs of all active LoRA adapters, indexed by LoRA ID.
@@ -461,9 +440,9 @@ class LoRAManager:
        Add new LoRA weight names if needed based on the current `self.configs`.
        """

        # Target lora weight names for lora_a and lora_b modules respectively.
        lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules)
        self.lora_weight_names: Tuple[Set[str]] = (set(lora_A), set(lora_B))
        self.lora_weight_names: Set[str] = get_normalized_lora_weight_names(
            self.target_modules
        )

    def load_lora_weights(self, lora_ref: LoRARef):
        """
@@ -479,15 +458,6 @@ class LoRAManager:
        lora_adapter.initialize_weights()
        self.loras[lora_ref.lora_id] = lora_adapter

        # Additional checks for flashinfer backend
        # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
        if self.lora_backend == "flashinfer":
            lora_dims = set(x.r for x in self.configs.values())
            scalings = set(x.scaling for x in self.loras.values())
            assert (
                len(lora_dims) == 1 and len(scalings) == 1
            ), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "

    def init_memory_pool(self):
        """(Re)initialize the LoRA memory pool based on the current configurations."""
        self.memory_pool = LoRAMemoryPool(
@@ -512,12 +482,6 @@ class LoRAManager:
            {} for _ in range(self.base_hf_config.num_hidden_layers)
        ]

        # Target module names of customized layers defined in python/sglang/srt/layers
        # e.g., {"qkv_proj", "o_proj"}
        customized_target_names = get_customized_names_from_hf_names(
            self.target_modules, self.base_model
        )

        for module_name, module in self.base_model.named_modules():
            # TODO (lifuhuang): in the future, we should consider generalizing the
            # should_apply_lora function to support mapping by full module name instead
@@ -530,7 +494,7 @@ class LoRAManager:
                continue

            # The module should be converted if it is included in target_names
            if module_name.split(".")[-1] in customized_target_names:
            if module_name.split(".")[-1] in self.lora_weight_names:
                layer_id = get_layer_id(module_name)
                self.lora_modules[layer_id][module_name] = self.set_lora_module(
                    module_name, module
@@ -52,7 +52,7 @@ class LoRAMemoryPool:
        tp_size: int,
        tp_rank: int,
        max_lora_rank: int,
        lora_weight_names: Tuple[Set[str], Set[str]],
        lora_weight_names: Set[str],
        base_model: torch.nn.Module,
    ):
        self.base_hf_config: AutoConfig = base_hf_config
@@ -62,9 +62,7 @@ class LoRAMemoryPool:
        self.tp_size: int = tp_size
        self.tp_rank: int = tp_rank
        self.max_lora_rank: int = max_lora_rank

        # lora weight names for LoRA A and B respectively.
        self.lora_weight_names: Tuple[Set[str], Set[str]] = lora_weight_names
        self.lora_weight_names: Set[str] = lora_weight_names

        # Both A_buffer and B_buffer maps lora weight names to its buffer space.
        # A_buffer contains num_layer number of row-major tensors with shape
@@ -97,12 +95,8 @@ class LoRAMemoryPool:
        """
        if config.r > self.max_lora_rank:
            return False
        weights_a, weights_b = get_normalized_lora_weight_names(
            config.target_modules
        )
        return weights_a.issubset(self.lora_weight_names[0]) and weights_b.issubset(
            self.lora_weight_names[1]
        )
        weights = get_normalized_lora_weight_names(config.target_modules)
        return weights.issubset(self.lora_weight_names)

        if isinstance(config, LoRAConfig):
            return _can_support(config)
@@ -132,11 +126,9 @@ class LoRAMemoryPool:
        Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
        """
        _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
        c = get_stacked_multiply(module_name)
        if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
            output_dim = divide(output_dim, self.tp_size)
        return (
            c,
            self.max_loras_per_batch,
            output_dim,
            max_lora_dim,
@@ -165,13 +157,13 @@ class LoRAMemoryPool:

        init_buffer(
            self.A_buffer,
            self.lora_weight_names[0],
            self.lora_weight_names,
            self.get_lora_A_shape,
        )

        init_buffer(
            self.B_buffer,
            self.lora_weight_names[1],
            self.lora_weight_names,
            self.get_lora_B_shape,
        )

@@ -246,7 +238,7 @@ class LoRAMemoryPool:
            return

        assert lora_adapter is not None
        lora_rank = lora_adapter.config.hf_config["r"]
        lora_rank = lora_adapter.config.r
        for layer_id in range(self.num_layer):
            layer_weights = lora_adapter.layers[layer_id].weights
            temp_A_buffer: Dict[str, Optional[torch.Tensor]] = {
@@ -256,73 +248,38 @@ class LoRAMemoryPool:
                weight_name: None for weight_name in self.B_buffer
            }
            for name, weights in layer_weights.items():
                lora_weight_name = get_weight_name(name, self.lora_weight_names)
                if "lora_A" in name:
                    lora_weight_name = get_weight_name(
                        name, self.lora_weight_names, LoRAType.LORA_A
                    )
                    temp_A_buffer[lora_weight_name] = weights
                else:
                    lora_weight_name = get_weight_name(
                        name, self.lora_weight_names, LoRAType.LORA_B
                    )
                    temp_B_buffer[lora_weight_name] = weights

            if self.tp_size > 1:
                cur_layer_modules = lora_modules[layer_id]
                for module_name, module in cur_layer_modules.items():
                    weight_name = get_weight_name(
                        module_name, self.lora_weight_names, LoRAType.LORA_A
                    )
                    weight_name = get_weight_name(module_name, self.lora_weight_names)

                    if temp_A_buffer[weight_name] is None:
                        # Skip weight slicing if the weight is not present in the adapter
                        continue

                    if "qkv_proj" in module_name:
                        temp_A_buffer["qkv_proj"] = module.slice_lora_a_weights(
                            temp_A_buffer["qkv_proj"], self.tp_rank
                        )
                        temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"] = (
                            module.slice_lora_b_weights(
                                [temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"]],
                                self.tp_rank,
                            )
                        )
                    else:
                        # TODO (lifuhuang): Ideally, we should call `get_weight_name` separately for both A and B.
                        # Currently, we're reusing A's weight name as a workaround, relying on the fact that A and
                        # B share the same name except for `qkv_proj`. We should clean this up once we deprecate the
                        # FlashInfer LoRA backend.
                        temp_A_buffer[weight_name] = module.slice_lora_a_weights(
                            temp_A_buffer[weight_name], self.tp_rank
                        )
                        temp_B_buffer[weight_name] = module.slice_lora_b_weights(
                            temp_B_buffer[weight_name], self.tp_rank
                        )
                    temp_A_buffer[weight_name] = module.slice_lora_a_weights(
                        temp_A_buffer[weight_name], self.tp_rank
                    )
                    temp_B_buffer[weight_name] = module.slice_lora_b_weights(
                        temp_B_buffer[weight_name], self.tp_rank
                    )

            for name, weights in temp_A_buffer.items():
                c = get_stacked_multiply(name)
                buffer_view = self.A_buffer[name][layer_id][buffer_id][
                    : lora_rank * c, :
                ]
                target_buffer = self.A_buffer[name][layer_id]
                buffer_view = target_buffer[buffer_id, : lora_rank * c, :]
                load_lora_weight_tensor(buffer_view, weights)

            for name, weights in temp_B_buffer.items():
                c = get_stacked_multiply(name)
                if c > 1:
                    for stacked_id in range(c):
                        buffer_view = self.B_buffer[name][layer_id][stacked_id][
                            buffer_id
                        ][:, :lora_rank]
                        weight_slice = (
                            weights[stacked_id] if weights is not None else None
                        )
                        load_lora_weight_tensor(buffer_view, weight_slice)
                else:
                    buffer_view = self.B_buffer[name][layer_id][0][buffer_id][
                        :, :lora_rank
                    ]
                    load_lora_weight_tensor(buffer_view, weights)
                target_buffer = self.B_buffer[name][layer_id]
                buffer_view = target_buffer[buffer_id, :, :lora_rank]
                load_lora_weight_tensor(buffer_view, weights)

    def get_tensor(
        self, weight_name: str, layer_id: int, lora_type: LoRAType

@@ -119,7 +119,7 @@ def _qkv_lora_b_kernel(
    output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
        s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
    )
    output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size)
    output_mask = (s_offset[:, None] < seg_len) & (n_offset[None, :] < n_size)
    partial_sum += tl.load(output_ptr, mask=output_mask)
    tl.store(output_ptr, partial_sum, mask=output_mask)
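One note on the hunk above: Python's `and` is not an elementwise operator for block-shaped masks, so the corrected line builds the store mask with bitwise `&`. A plain-PyTorch stand-in (not Triton) showing the broadcasted mask it is meant to produce:

```python
import torch

s_offset = torch.arange(4)[:, None]  # (BLOCK_S, 1)
n_offset = torch.arange(8)[None, :]  # (1, BLOCK_N)
seg_len, n_size = 3, 5

output_mask = (s_offset < seg_len) & (n_offset < n_size)
print(output_mask.shape)  # torch.Size([4, 8])
```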
@@ -47,34 +47,6 @@ def get_layer_id(name: str) -> int:
    return int(match.group(1))


def get_customized_names_from_hf_names(
    hf_module_names: Set[str], base_model: torch.nn.Module
) -> Set[str]:
    """
    This function takes in a set of huggingface style module names:
        e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
    and outputs a set of module names of customized sglang layers:
        e.g., {"qkv_proj", "o_proj"}
    """
    if hasattr(base_model, "get_module_name"):
        return {base_model.get_module_name(name) for name in hf_module_names}
    else:
        """
        Fallback solution of mapping from config module name to module name in model class.
        Please check if it aligns with your base model.
        Please implement the function in the model class if it is not.
        You can reference this function in llama.py.
        """
        params_mapping = {
            "q_proj": "qkv_proj",
            "k_proj": "qkv_proj",
            "v_proj": "qkv_proj",
            "gate_proj": "gate_up_proj",
            "up_proj": "gate_up_proj",
        }
        return {params_mapping.get(name, name) for name in hf_module_names}


def get_hidden_dim(
    module_name: str, config: AutoConfig, base_model: torch.nn.Module
) -> Tuple[int]:
@@ -95,22 +67,9 @@ def get_hidden_dim(
        head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )

        # TODO: the special handling of qkv will be addressed in #8940.
        if module_name == "qkv_proj":
            return (
                config.hidden_size,
                None,  # qkv_proj is only used in LoRA A
            )
        elif module_name == "kv_proj":
            return (
                None,  # kv_proj is only used in LoRA B
                head_dim * config.num_key_value_heads,
            )
        elif module_name == "q_proj":
            return (
                None,  # q_proj is only used in LoRA B
                head_dim * config.num_attention_heads,
            return config.hidden_size, head_dim * (
                config.num_attention_heads + config.num_key_value_heads * 2
            )
        elif module_name == "o_proj":
            return (
@@ -118,7 +77,7 @@ def get_hidden_dim(
                config.hidden_size,
            )
        elif module_name == "gate_up_proj":
            return config.hidden_size, config.intermediate_size
            return config.hidden_size, config.intermediate_size * 2
        elif module_name == "down_proj":
            return config.intermediate_size, config.hidden_size
        else:
@@ -127,26 +86,22 @@ def get_hidden_dim(

def get_normalized_lora_weight_names(
    target_modules: Iterable[str],
) -> Tuple[set[str], set[str]]:
) -> set[str]:
    """
    Mapping a list of target module name to names of the normalized LoRA weights.
    Returned tuple contains (name for Lora A, name for Lora B)
    """
    params_mapping = {
        "q_proj": (["qkv_proj"], ["q_proj"]),
        "k_proj": (["qkv_proj"], ["kv_proj"]),
        "v_proj": (["qkv_proj"], ["kv_proj"]),
        "gate_proj": (["gate_up_proj"], ["gate_up_proj"]),
        "up_proj": (["gate_up_proj"], ["gate_up_proj"]),
        "qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
        "gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
        "q_proj": "qkv_proj",
        "k_proj": "qkv_proj",
        "v_proj": "qkv_proj",
        "gate_proj": "gate_up_proj",
        "up_proj": "gate_up_proj",
    }

    result = (set(), set())
    result = set()
    for name in target_modules:
        lora_a, lora_b = params_mapping.get(name, ([name], [name]))
        result[0].update(lora_a)
        result[1].update(lora_b)
        weight_name = params_mapping.get(name, name)
        result.add(weight_name)
    return result
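For illustration (not part of the diff), the simplified helper now returns a single set of stacked weight names; with the mapping shown above:

```python
params_mapping = {
    "q_proj": "qkv_proj",
    "k_proj": "qkv_proj",
    "v_proj": "qkv_proj",
    "gate_proj": "gate_up_proj",
    "up_proj": "gate_up_proj",
}

target_modules = ["q_proj", "v_proj", "up_proj", "o_proj"]
normalized = {params_mapping.get(name, name) for name in target_modules}
print(normalized)  # {'qkv_proj', 'gate_up_proj', 'o_proj'} (set order may vary)
```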
@@ -156,23 +111,21 @@ def get_stacked_multiply(module_name: str) -> int:
    """
    stacked_rank = {
        "qkv_proj": 3,
        "kv_proj": 2,
        "gate_up_proj": 2,
    }
    return stacked_rank[module_name] if module_name in stacked_rank else 1


def get_weight_name(
    target_name: str, lora_weight_names: Tuple[Set[str]], lora_type: LoRAType
    target_name: str, lora_weight_names: Tuple[Set[str]]
) -> Optional[str]:
    """
    target_name is name of a given module,
    lora_weight_names is a set of lora stacked name pairs (see get_stacked_name method above)
    Get the weight name in lora_weight_names that can match target_name.

    If there is a weight name in lora_weight_names that can match target_name, return this name
    Else raise ValueError.
    """
    idx = 0 if lora_type == LoRAType.LORA_A else 1
    for weight_name in lora_weight_names[idx]:
    for weight_name in lora_weight_names:
        if weight_name in target_name:
            return weight_name
    raise ValueError(
@@ -180,9 +133,4 @@ def get_weight_name(
    )


# TODO: [PR #4274] For future use to simplify the mapping between HF module names and customized module names.
VOCAB_PARALLELISM_EMBEDDING_NAMES = ["embeddings"]
COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_proj", "up_proj"]
MERGED_COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_up_proj"]
QKV_PARALLELISM_LINEAR_LORA_NAMES = ["qkv_proj"]
ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj"]

@@ -501,23 +501,16 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):

    def get_hidden_dim(self, module_name):
        # return input_dim, output_dim
        # TODO: the special handling of qkv will be addressed in #8940.
        if module_name == "qkv_proj":
            return (
                self.config.hidden_size,
                None,  # qkv_proj is only used in LoRA A
                self.config.head_dim
                * (
                    self.config.num_attention_heads
                    + self.config.num_key_value_heads * 2
                ),
            )
        elif module_name == "kv_proj":
            return (
                None,  # kv_proj is only used in LoRA B
                self.config.head_dim * self.config.num_key_value_heads,
            )
        elif module_name == "q_proj":
            return (
                None,  # q_proj is only used in LoRA B
                self.config.head_dim * self.config.num_attention_heads,
            )
        elif module_name in ["o_proj"]:
        elif module_name == "o_proj":
            return (
                self.config.head_dim * self.config.num_attention_heads,
                self.config.hidden_size,
@@ -527,7 +520,7 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
                "Currently SGLang requires uniform intermediate size for all layers. "
                "Please file an issue if you need support for non-uniform intermediate sizes."
            )
            return self.config.hidden_size, self.config.intermediate_size[0]
            return self.config.hidden_size, self.config.intermediate_size[0] * 2
        elif module_name == "down_proj":
            assert len(set(self.config.intermediate_size)) == 1, (
                "Currently SGLang requires uniform intermediate size for all layers. "