Improve LoRA Perf by Deprecating FlashInfer and Eliminating Redundant Tensor Ops (#8940)
This commit is contained in:
@@ -35,7 +35,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"* `max_loaded_loras`: If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `max-loras-per-batch`.\n",
|
"* `max_loaded_loras`: If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `max-loras-per-batch`.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"* `lora_backend`: The backend of running GEMM kernels for Lora modules. It can be one of `triton` or `flashinfer`, and set to `triton` by default. For better performance and stability, we recommend using the Triton LoRA backend. In the future, faster backend built upon Cutlass or Cuda kernels will be added.\n",
|
"* `lora_backend`: The backend of running GEMM kernels for Lora modules. Currently we only support Triton LoRA backend. In the future, faster backend built upon Cutlass or Cuda kernels will be added.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n",
|
"* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|||||||
@@ -5,22 +5,6 @@ import torch
|
|||||||
from sglang.srt.lora.utils import LoRABatchInfo
|
from sglang.srt.lora.utils import LoRABatchInfo
|
||||||
|
|
||||||
|
|
||||||
def get_fuse_output_add_from_name(name: str) -> bool:
|
|
||||||
mapping = {
|
|
||||||
"triton": True,
|
|
||||||
"flashinfer": False,
|
|
||||||
}
|
|
||||||
return mapping.get(name, False)
|
|
||||||
|
|
||||||
|
|
||||||
def get_fuse_stacked_lora_b_from_name(name: str) -> bool:
|
|
||||||
mapping = {
|
|
||||||
"triton": True,
|
|
||||||
"flashinfer": False,
|
|
||||||
}
|
|
||||||
return mapping.get(name, False)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseLoRABackend:
|
class BaseLoRABackend:
|
||||||
"""Base class for different Lora backends.
|
"""Base class for different Lora backends.
|
||||||
Each backend has its own implementation of Lora kernels.
|
Each backend has its own implementation of Lora kernels.
|
||||||
@@ -28,15 +12,11 @@ class BaseLoRABackend:
|
|||||||
Args:
|
Args:
|
||||||
name: name of backend
|
name: name of backend
|
||||||
batch_info: information of current batch for use
|
batch_info: information of current batch for use
|
||||||
fuse_output_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward,
|
|
||||||
and the operation of adding will be fused into kernel
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, name: str, batch_info: LoRABatchInfo = None):
|
def __init__(self, name: str, batch_info: LoRABatchInfo = None):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.batch_info = batch_info
|
self.batch_info = batch_info
|
||||||
self.fuse_output_add = get_fuse_output_add_from_name(name)
|
|
||||||
self.fuse_stacked_lora_b = get_fuse_stacked_lora_b_from_name(name)
|
|
||||||
|
|
||||||
def run_lora_a_sgemm(
|
def run_lora_a_sgemm(
|
||||||
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
|
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
|
||||||
@@ -126,8 +106,8 @@ def get_backend_from_name(name: str) -> BaseLoRABackend:
|
|||||||
|
|
||||||
return TritonLoRABackend
|
return TritonLoRABackend
|
||||||
elif name == "flashinfer":
|
elif name == "flashinfer":
|
||||||
from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend
|
raise ValueError(
|
||||||
|
"FlashInfer LoRA backend has been deprecated, please use `triton` instead."
|
||||||
return FlashInferLoRABackend
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid backend: {name}")
|
raise ValueError(f"Invalid backend: {name}")
|
||||||
|
|||||||
@@ -1,131 +0,0 @@
|
|||||||
from typing import Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
|
|
||||||
from sglang.srt.lora.utils import LoRABatchInfo
|
|
||||||
from sglang.srt.utils import is_flashinfer_available
|
|
||||||
|
|
||||||
if is_flashinfer_available():
|
|
||||||
from flashinfer import SegmentGEMMWrapper
|
|
||||||
|
|
||||||
|
|
||||||
class FlashInferLoRABackend(BaseLoRABackend):
|
|
||||||
|
|
||||||
def __init__(self, name: str, batch_info: LoRABatchInfo = None):
|
|
||||||
super().__init__(name, batch_info)
|
|
||||||
|
|
||||||
# Set up SGemm Wrapper from flashinfer
|
|
||||||
# FIXME wait for flashinfer segment gemm update
|
|
||||||
workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
|
|
||||||
self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)
|
|
||||||
|
|
||||||
def run_lora_a_sgemm(
|
|
||||||
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
|
|
||||||
) -> torch.Tensor:
|
|
||||||
|
|
||||||
return self.segment_gemm.run(
|
|
||||||
x=x,
|
|
||||||
weights=weights,
|
|
||||||
batch_size=self.batch_info.bs,
|
|
||||||
weight_column_major=True,
|
|
||||||
seg_indptr=self.batch_info.seg_indptr,
|
|
||||||
weight_indices=self.batch_info.weight_indices,
|
|
||||||
)
|
|
||||||
|
|
||||||
def run_lora_b_sgemm(
|
|
||||||
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
|
|
||||||
) -> torch.Tensor:
|
|
||||||
|
|
||||||
return (
|
|
||||||
self.segment_gemm.run(
|
|
||||||
x=x,
|
|
||||||
weights=weights,
|
|
||||||
batch_size=self.batch_info.bs,
|
|
||||||
weight_column_major=True,
|
|
||||||
seg_indptr=self.batch_info.seg_indptr,
|
|
||||||
weight_indices=self.batch_info.weight_indices,
|
|
||||||
)
|
|
||||||
* self.batch_info.scalings[0]
|
|
||||||
)
|
|
||||||
|
|
||||||
def run_qkv_lora(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
qkv_lora_a: torch.Tensor,
|
|
||||||
qkv_lora_b: Tuple[torch.Tensor],
|
|
||||||
*args,
|
|
||||||
**kwargs,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
|
|
||||||
assert isinstance(qkv_lora_b, tuple) and len(qkv_lora_b) == 2
|
|
||||||
|
|
||||||
# Shape of lora_a_output: (s, 3 * r)
|
|
||||||
lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a)
|
|
||||||
|
|
||||||
q_lora_b, kv_lora_b = qkv_lora_b
|
|
||||||
lora_rank = kv_lora_b.shape[-1]
|
|
||||||
output_dim_q = q_lora_b.shape[-2]
|
|
||||||
output_dim_kv = kv_lora_b.shape[-2]
|
|
||||||
lora_output = torch.empty(
|
|
||||||
(x.shape[0], output_dim_q + 2 * output_dim_kv),
|
|
||||||
device=x.device,
|
|
||||||
dtype=x.dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
# q
|
|
||||||
lora_output[:, :output_dim_q] = self.run_lora_b_sgemm(
|
|
||||||
x=lora_a_output[:, :lora_rank].contiguous(), weights=q_lora_b[0]
|
|
||||||
)
|
|
||||||
|
|
||||||
# kv
|
|
||||||
lora_output[:, output_dim_q : output_dim_q + output_dim_kv] = (
|
|
||||||
self.run_lora_b_sgemm(
|
|
||||||
x=lora_a_output[:, lora_rank : 2 * lora_rank].contiguous(),
|
|
||||||
weights=kv_lora_b[0],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
lora_output[
|
|
||||||
:, output_dim_q + output_dim_kv : output_dim_q + 2 * output_dim_kv
|
|
||||||
] = self.run_lora_b_sgemm(
|
|
||||||
x=lora_a_output[:, 2 * lora_rank : 3 * lora_rank].contiguous(),
|
|
||||||
weights=kv_lora_b[1],
|
|
||||||
)
|
|
||||||
|
|
||||||
return lora_output * self.batch_info.scalings[0]
|
|
||||||
|
|
||||||
def run_gate_up_lora(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
gate_up_lora_a: torch.Tensor,
|
|
||||||
gate_up_lora_b: Tuple[torch.Tensor],
|
|
||||||
*args,
|
|
||||||
**kwargs,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
|
|
||||||
assert isinstance(gate_up_lora_b, tuple) and len(gate_up_lora_b) == 2
|
|
||||||
lora_rank = gate_up_lora_b[0].shape[-1]
|
|
||||||
output_dim = gate_up_lora_b[0].shape[-2]
|
|
||||||
|
|
||||||
# Shape of lora_a_output: (s, 2 * r)
|
|
||||||
lora_a_output = self.run_lora_a_sgemm(x=x, weights=gate_up_lora_a)
|
|
||||||
|
|
||||||
lora_output = torch.empty(
|
|
||||||
(x.shape[0], 2 * output_dim),
|
|
||||||
device=x.device,
|
|
||||||
dtype=x.dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Compute lora for gate and up proj respectively
|
|
||||||
lora_output[:, :output_dim] = self.run_lora_b_sgemm(
|
|
||||||
x=lora_a_output[:, :lora_rank].contiguous(),
|
|
||||||
weights=gate_up_lora_b[0],
|
|
||||||
)
|
|
||||||
|
|
||||||
lora_output[:, output_dim:] = self.run_lora_b_sgemm(
|
|
||||||
x=lora_a_output[:, lora_rank:].contiguous(),
|
|
||||||
weights=gate_up_lora_b[1],
|
|
||||||
)
|
|
||||||
|
|
||||||
return lora_output * self.batch_info.scalings[0]
|
|
||||||
@@ -1,5 +1,3 @@
|
|||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
@@ -79,18 +77,13 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
|
|||||||
self.B_buffer = B_buffer
|
self.B_buffer = B_buffer
|
||||||
|
|
||||||
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||||
backend_kwargs = {"base_output": base_output}
|
|
||||||
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
|
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
|
||||||
lora_output = self.lora_backend.run_lora_b_sgemm(
|
lora_output = self.lora_backend.run_lora_b_sgemm(
|
||||||
lora_a_output,
|
x=lora_a_output,
|
||||||
self.B_buffer[0],
|
weights=self.B_buffer,
|
||||||
**backend_kwargs,
|
base_output=base_output,
|
||||||
)
|
|
||||||
return (
|
|
||||||
lora_output
|
|
||||||
if self.lora_backend.fuse_output_add
|
|
||||||
else base_output + lora_output
|
|
||||||
)
|
)
|
||||||
|
return lora_output
|
||||||
|
|
||||||
def forward(self, input_: torch.Tensor):
|
def forward(self, input_: torch.Tensor):
|
||||||
# duplicate the logic in ColumnParallelLinear
|
# duplicate the logic in ColumnParallelLinear
|
||||||
@@ -135,37 +128,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
|||||||
):
|
):
|
||||||
self.set_lora = True
|
self.set_lora = True
|
||||||
self.A_buffer_gate_up = A_buffer
|
self.A_buffer_gate_up = A_buffer
|
||||||
if self.lora_backend.fuse_stacked_lora_b:
|
self.B_buffer_gate_up = B_buffer
|
||||||
# B_buffer_gate_up: (num_lora, 2 * output_dim, r)
|
|
||||||
if getattr(self, "B_buffer_gate_up", None) is None:
|
|
||||||
self.B_buffer_gate_up = torch.empty(
|
|
||||||
(
|
|
||||||
B_buffer[0].shape[0],
|
|
||||||
2 * B_buffer[0].shape[1],
|
|
||||||
B_buffer[0].shape[2],
|
|
||||||
),
|
|
||||||
dtype=B_buffer[0].dtype,
|
|
||||||
device=B_buffer[0].device,
|
|
||||||
)
|
|
||||||
self.B_buffer_gate_up[:, : B_buffer[0].shape[1], :].copy_(B_buffer[0])
|
|
||||||
self.B_buffer_gate_up[:, B_buffer[0].shape[1] :, :].copy_(B_buffer[1])
|
|
||||||
else:
|
|
||||||
self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])
|
|
||||||
|
|
||||||
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||||
backend_kwargs = {"base_output": base_output}
|
|
||||||
|
|
||||||
lora_output = self.lora_backend.run_gate_up_lora(
|
lora_output = self.lora_backend.run_gate_up_lora(
|
||||||
x,
|
x=x,
|
||||||
self.A_buffer_gate_up,
|
gate_up_lora_a=self.A_buffer_gate_up,
|
||||||
self.B_buffer_gate_up,
|
gate_up_lora_b=self.B_buffer_gate_up,
|
||||||
**backend_kwargs,
|
base_output=base_output,
|
||||||
)
|
|
||||||
return (
|
|
||||||
lora_output
|
|
||||||
if self.lora_backend.fuse_output_add
|
|
||||||
else base_output + lora_output
|
|
||||||
)
|
)
|
||||||
|
return lora_output
|
||||||
|
|
||||||
def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
|
def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
|
||||||
return A
|
return A
|
||||||
@@ -173,9 +145,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
|||||||
def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
|
def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
|
||||||
# Since the outputs for both gate and up are identical, we use a random one.
|
# Since the outputs for both gate and up are identical, we use a random one.
|
||||||
shard_size = self.base_layer.output_partition_sizes[0]
|
shard_size = self.base_layer.output_partition_sizes[0]
|
||||||
|
gate_size = self.base_layer.output_sizes[0]
|
||||||
start_idx = tp_rank * shard_size
|
start_idx = tp_rank * shard_size
|
||||||
end_idx = (tp_rank + 1) * shard_size
|
end_idx = (tp_rank + 1) * shard_size
|
||||||
return B[:, start_idx:end_idx, :]
|
return torch.concat(
|
||||||
|
(
|
||||||
|
B[start_idx:end_idx, :],
|
||||||
|
B[gate_size + start_idx : gate_size + end_idx],
|
||||||
|
),
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||||
@@ -185,86 +164,46 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
|||||||
lora_backend: BaseLoRABackend,
|
lora_backend: BaseLoRABackend,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(base_layer, lora_backend)
|
super().__init__(base_layer, lora_backend)
|
||||||
|
q_proj_shard_size = self.base_layer.q_proj_shard_size
|
||||||
|
kv_proj_shard_size = self.base_layer.kv_proj_shard_size
|
||||||
|
self.output_offset = torch.tensor(
|
||||||
|
[
|
||||||
|
0,
|
||||||
|
q_proj_shard_size,
|
||||||
|
q_proj_shard_size + kv_proj_shard_size,
|
||||||
|
q_proj_shard_size + 2 * kv_proj_shard_size,
|
||||||
|
],
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=next(self.base_layer.parameters()).device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# For computing number of launched blocks
|
||||||
|
self.max_qkv_out_dim = max(q_proj_shard_size, kv_proj_shard_size)
|
||||||
|
|
||||||
def set_lora_info(
|
def set_lora_info(
|
||||||
self,
|
self,
|
||||||
A_buffer_qkv: torch.Tensor,
|
A_buffer_qkv: torch.Tensor,
|
||||||
B_buffer_q: torch.Tensor,
|
B_buffer_qkv: torch.Tensor,
|
||||||
B_buffer_kv: torch.Tensor,
|
|
||||||
):
|
):
|
||||||
self.set_lora = True
|
self.set_lora = True
|
||||||
self.A_buffer_qkv = A_buffer_qkv
|
self.A_buffer_qkv = A_buffer_qkv
|
||||||
|
self.B_buffer_qkv = B_buffer_qkv
|
||||||
if self.lora_backend.fuse_stacked_lora_b:
|
|
||||||
assert (
|
|
||||||
B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
|
|
||||||
), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
|
|
||||||
output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
|
|
||||||
|
|
||||||
# B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
|
|
||||||
if getattr(self, "B_buffer_qkv", None) is None:
|
|
||||||
self.B_buffer_qkv = torch.empty(
|
|
||||||
(
|
|
||||||
B_buffer_q[0].shape[0],
|
|
||||||
output_dim_q + 2 * output_dim_kv,
|
|
||||||
B_buffer_q[0].shape[2],
|
|
||||||
),
|
|
||||||
dtype=B_buffer_q[0].dtype,
|
|
||||||
device=B_buffer_q[0].device,
|
|
||||||
)
|
|
||||||
self.B_buffer_qkv[:, :output_dim_q, :].copy_(B_buffer_q[0])
|
|
||||||
self.B_buffer_qkv[:, output_dim_q : output_dim_q + output_dim_kv, :].copy_(
|
|
||||||
B_buffer_kv[0]
|
|
||||||
)
|
|
||||||
self.B_buffer_qkv[:, output_dim_q + output_dim_kv :, :].copy_(
|
|
||||||
B_buffer_kv[1]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Offsets of q/k/v in output dimension
|
|
||||||
if getattr(self, "output_offset", None) is None:
|
|
||||||
self.output_offset = torch.tensor(
|
|
||||||
[
|
|
||||||
0,
|
|
||||||
output_dim_q,
|
|
||||||
output_dim_q + output_dim_kv,
|
|
||||||
output_dim_q + 2 * output_dim_kv,
|
|
||||||
],
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=B_buffer_q.device,
|
|
||||||
)
|
|
||||||
# For computing number of launched blocks
|
|
||||||
self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
|
|
||||||
else:
|
|
||||||
self.B_buffer_qkv = (
|
|
||||||
B_buffer_q,
|
|
||||||
B_buffer_kv,
|
|
||||||
)
|
|
||||||
|
|
||||||
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||||
backend_kwargs = {"base_output": base_output}
|
|
||||||
if self.lora_backend.fuse_stacked_lora_b:
|
|
||||||
backend_kwargs["output_offset"] = self.output_offset
|
|
||||||
backend_kwargs["max_qkv_out_dim"] = self.max_qkv_out_dim
|
|
||||||
|
|
||||||
lora_output = self.lora_backend.run_qkv_lora(
|
lora_output = self.lora_backend.run_qkv_lora(
|
||||||
x,
|
x=x,
|
||||||
self.A_buffer_qkv,
|
qkv_lora_a=self.A_buffer_qkv,
|
||||||
self.B_buffer_qkv,
|
qkv_lora_b=self.B_buffer_qkv,
|
||||||
**backend_kwargs,
|
base_output=base_output,
|
||||||
)
|
output_offset=self.output_offset,
|
||||||
return (
|
max_qkv_out_dim=self.max_qkv_out_dim,
|
||||||
lora_output
|
|
||||||
if self.lora_backend.fuse_output_add
|
|
||||||
else base_output + lora_output
|
|
||||||
)
|
)
|
||||||
|
return lora_output
|
||||||
|
|
||||||
def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
|
def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
|
||||||
return A
|
return A
|
||||||
|
|
||||||
def slice_lora_b_weights(
|
def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int) -> torch.Tensor:
|
||||||
self, B: List[torch.Tensor], tp_rank: int
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
B_q, B_kv = B
|
|
||||||
base_layer = self.base_layer
|
base_layer = self.base_layer
|
||||||
q_proj_shard_size = base_layer.q_proj_shard_size
|
q_proj_shard_size = base_layer.q_proj_shard_size
|
||||||
kv_proj_shard_size = base_layer.kv_proj_shard_size
|
kv_proj_shard_size = base_layer.kv_proj_shard_size
|
||||||
@@ -277,7 +216,19 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
|||||||
kv_start_idx = kv_proj_shard_size * kv_shard_id
|
kv_start_idx = kv_proj_shard_size * kv_shard_id
|
||||||
kv_end_idx = kv_start_idx + kv_proj_shard_size
|
kv_end_idx = kv_start_idx + kv_proj_shard_size
|
||||||
|
|
||||||
return B_q[q_start_idx:q_end_idx, :], B_kv[:, kv_start_idx:kv_end_idx, :]
|
q_size, k_size, _ = base_layer.output_sizes
|
||||||
|
B_q_shard = B[q_start_idx:q_end_idx, :]
|
||||||
|
B_k_shard = B[q_size + kv_start_idx : q_size + kv_end_idx, :]
|
||||||
|
B_v_shard = B[q_size + k_size + kv_start_idx : q_size + k_size + kv_end_idx, :]
|
||||||
|
|
||||||
|
return torch.concat(
|
||||||
|
(
|
||||||
|
B_q_shard,
|
||||||
|
B_k_shard,
|
||||||
|
B_v_shard,
|
||||||
|
),
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||||
@@ -294,18 +245,13 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
|||||||
self.B_buffer = B_buffer
|
self.B_buffer = B_buffer
|
||||||
|
|
||||||
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
||||||
backend_kwargs = {"base_output": base_output}
|
|
||||||
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
|
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
|
||||||
lora_output = self.lora_backend.run_lora_b_sgemm(
|
lora_output = self.lora_backend.run_lora_b_sgemm(
|
||||||
lora_a_output,
|
x=lora_a_output,
|
||||||
self.B_buffer[0],
|
weights=self.B_buffer,
|
||||||
**backend_kwargs,
|
base_output=base_output,
|
||||||
)
|
|
||||||
return (
|
|
||||||
lora_output
|
|
||||||
if self.lora_backend.fuse_output_add
|
|
||||||
else base_output + lora_output
|
|
||||||
)
|
)
|
||||||
|
return lora_output
|
||||||
|
|
||||||
def forward(self, input_: torch.Tensor):
|
def forward(self, input_: torch.Tensor):
|
||||||
# duplicate the logic in RowParallelLinear
|
# duplicate the logic in RowParallelLinear
|
||||||
|
|||||||
@@ -117,7 +117,6 @@ class LoRAAdapter(nn.Module):
|
|||||||
q_name = weight_name
|
q_name = weight_name
|
||||||
k_name = weight_name.replace("q_proj", "k_proj")
|
k_name = weight_name.replace("q_proj", "k_proj")
|
||||||
v_name = weight_name.replace("q_proj", "v_proj")
|
v_name = weight_name.replace("q_proj", "v_proj")
|
||||||
kv_name = weight_name.replace("q_proj", "kv_proj")
|
|
||||||
qkv_name = weight_name.replace("q_proj", "qkv_proj")
|
qkv_name = weight_name.replace("q_proj", "qkv_proj")
|
||||||
|
|
||||||
# If k_proj doesn't have lora, initialize it to zero
|
# If k_proj doesn't have lora, initialize it to zero
|
||||||
@@ -126,57 +125,27 @@ class LoRAAdapter(nn.Module):
|
|||||||
if "k_proj" in target_module
|
if "k_proj" in target_module
|
||||||
else torch.zeros_like(weights[v_name])
|
else torch.zeros_like(weights[v_name])
|
||||||
)
|
)
|
||||||
if "lora_A" in weight_name:
|
weights[qkv_name] = torch.cat(
|
||||||
weights[qkv_name] = torch.cat(
|
(
|
||||||
(
|
weights[q_name],
|
||||||
weights[q_name],
|
k_proj_weight,
|
||||||
k_proj_weight,
|
weights[v_name],
|
||||||
weights[v_name],
|
),
|
||||||
),
|
0,
|
||||||
0,
|
)
|
||||||
)
|
weights.pop(q_name)
|
||||||
weights.pop(q_name)
|
if "k_proj" in target_module:
|
||||||
if "k_proj" in target_module:
|
weights.pop(k_name)
|
||||||
weights.pop(k_name)
|
weights.pop(v_name)
|
||||||
weights.pop(v_name)
|
|
||||||
else:
|
|
||||||
weights[kv_name] = torch.stack(
|
|
||||||
[
|
|
||||||
k_proj_weight,
|
|
||||||
weights[v_name],
|
|
||||||
],
|
|
||||||
dim=0,
|
|
||||||
)
|
|
||||||
if "k_proj" in target_module:
|
|
||||||
weights.pop(k_name)
|
|
||||||
weights.pop(v_name)
|
|
||||||
elif "qkv_proj" in weight_name:
|
elif "qkv_proj" in weight_name:
|
||||||
# If qkv_proj is already stacked, we normalize it following the SGL convention.
|
# If qkv_proj is already stacked, we normalize it following the SGL convention.
|
||||||
qkv_name = weight_name
|
qkv_name = weight_name
|
||||||
q_name = weight_name.replace("qkv_proj", "q_proj")
|
q_name = weight_name.replace("qkv_proj", "q_proj")
|
||||||
k_name = weight_name.replace("qkv_proj", "k_proj")
|
k_name = weight_name.replace("qkv_proj", "k_proj")
|
||||||
v_name = weight_name.replace("qkv_proj", "v_proj")
|
v_name = weight_name.replace("qkv_proj", "v_proj")
|
||||||
kv_name = weight_name.replace("qkv_proj", "kv_proj")
|
|
||||||
if "lora_A" in weight_name:
|
if "lora_A" in weight_name:
|
||||||
weights[qkv_name] = weights[qkv_name].repeat(3, 1)
|
weights[qkv_name] = weights[qkv_name].repeat(3, 1)
|
||||||
else:
|
# else: no-op as LoRA B weight is already stacked.
|
||||||
head_size = (
|
|
||||||
self.base_hf_config.hidden_size
|
|
||||||
// self.base_hf_config.num_attention_heads
|
|
||||||
)
|
|
||||||
weights[q_name], k_proj_weight, v_proj_weight = torch.split(
|
|
||||||
weights[qkv_name],
|
|
||||||
[
|
|
||||||
head_size * self.base_hf_config.num_attention_heads,
|
|
||||||
head_size * self.base_hf_config.num_key_value_heads,
|
|
||||||
head_size * self.base_hf_config.num_key_value_heads,
|
|
||||||
],
|
|
||||||
dim=0,
|
|
||||||
)
|
|
||||||
weights[kv_name] = torch.stack(
|
|
||||||
[k_proj_weight, v_proj_weight],
|
|
||||||
dim=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
def normalize_gate_up_proj(
|
def normalize_gate_up_proj(
|
||||||
self, weight_names: List[str], weights: Dict[str, torch.Tensor]
|
self, weight_names: List[str], weights: Dict[str, torch.Tensor]
|
||||||
@@ -187,20 +156,14 @@ class LoRAAdapter(nn.Module):
|
|||||||
gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
|
gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
|
||||||
if up_name not in weights:
|
if up_name not in weights:
|
||||||
weights[up_name] = torch.zeros_like(weights[weight_name])
|
weights[up_name] = torch.zeros_like(weights[weight_name])
|
||||||
# FIXME: Add gate-only support for flashinfer in future implementations
|
|
||||||
assert self.lora_backend.name == "triton", (
|
assert self.lora_backend.name == "triton", (
|
||||||
f"LoRA weight initialization currently only supported for 'triton' backend. "
|
f"LoRA weight initialization currently only supported for 'triton' backend. "
|
||||||
f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
|
f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
|
||||||
f"or consider implementing custom initialization logic for other backends."
|
f"or consider implementing custom initialization logic for other backends."
|
||||||
)
|
)
|
||||||
if "lora_A" in weight_name:
|
weights[gate_up_name] = torch.cat(
|
||||||
weights[gate_up_name] = torch.cat(
|
(weights[weight_name], weights[up_name]), 0
|
||||||
(weights[weight_name], weights[up_name]), 0
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
weights[gate_up_name] = torch.stack(
|
|
||||||
[weights[weight_name], weights[up_name]], dim=0
|
|
||||||
)
|
|
||||||
weights.pop(weight_name)
|
weights.pop(weight_name)
|
||||||
if up_name in weights:
|
if up_name in weights:
|
||||||
weights.pop(up_name)
|
weights.pop(up_name)
|
||||||
@@ -209,12 +172,4 @@ class LoRAAdapter(nn.Module):
|
|||||||
gate_up_name = weight_name
|
gate_up_name = weight_name
|
||||||
if "lora_A" in weight_name:
|
if "lora_A" in weight_name:
|
||||||
weights[gate_up_name] = weights[gate_up_name].repeat(2, 1)
|
weights[gate_up_name] = weights[gate_up_name].repeat(2, 1)
|
||||||
else:
|
# else: no-op as LoRA B weight is already stacked.
|
||||||
output_dim = weights[gate_up_name].shape[0] // 2
|
|
||||||
weights[gate_up_name] = torch.stack(
|
|
||||||
[
|
|
||||||
weights[gate_up_name][:output_dim, :],
|
|
||||||
weights[gate_up_name][output_dim:, :],
|
|
||||||
],
|
|
||||||
dim=0,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -31,7 +31,6 @@ from sglang.srt.lora.mem_pool import LoRAMemoryPool
|
|||||||
from sglang.srt.lora.utils import (
|
from sglang.srt.lora.utils import (
|
||||||
LoRABatchInfo,
|
LoRABatchInfo,
|
||||||
LoRAType,
|
LoRAType,
|
||||||
get_customized_names_from_hf_names,
|
|
||||||
get_layer_id,
|
get_layer_id,
|
||||||
get_normalized_lora_weight_names,
|
get_normalized_lora_weight_names,
|
||||||
get_weight_name,
|
get_weight_name,
|
||||||
@@ -345,40 +344,19 @@ class LoRAManager:
|
|||||||
)
|
)
|
||||||
self.lora_backend.set_batch_info(batch_info)
|
self.lora_backend.set_batch_info(batch_info)
|
||||||
|
|
||||||
# TODO (lifuhuang): one potential perf optimization that is worth considering is to see if we can call
|
|
||||||
# this method only when loading/unloading LoRA adapters, instead of calling it for every micro-batch.
|
|
||||||
self.update_lora_info()
|
|
||||||
|
|
||||||
def update_lora_info(self):
|
def update_lora_info(self):
|
||||||
"""
|
"""
|
||||||
Update all LoRA modules to associate them with the latest memory buffer.
|
Update all LoRA modules to associate them with the latest memory buffer.
|
||||||
"""
|
"""
|
||||||
for layer_id, layer_modules in enumerate(self.lora_modules):
|
for layer_id, layer_modules in enumerate(self.lora_modules):
|
||||||
for module_name, module in layer_modules.items():
|
for module_name, module in layer_modules.items():
|
||||||
if "qkv_proj" in module_name:
|
weight_name = get_weight_name(
|
||||||
module.set_lora_info(
|
module_name, self.memory_pool.lora_weight_names
|
||||||
self.memory_pool.get_tensor(
|
)
|
||||||
"qkv_proj", layer_id, LoRAType.LORA_A
|
module.set_lora_info(
|
||||||
),
|
self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_A),
|
||||||
self.memory_pool.get_tensor(
|
self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_B),
|
||||||
"q_proj", layer_id, LoRAType.LORA_B
|
)
|
||||||
),
|
|
||||||
self.memory_pool.get_tensor(
|
|
||||||
"kv_proj", layer_id, LoRAType.LORA_B
|
|
||||||
),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
weight_name = get_weight_name(
|
|
||||||
module_name, self.memory_pool.lora_weight_names, LoRAType.LORA_A
|
|
||||||
)
|
|
||||||
module.set_lora_info(
|
|
||||||
self.memory_pool.get_tensor(
|
|
||||||
weight_name, layer_id, LoRAType.LORA_A
|
|
||||||
),
|
|
||||||
self.memory_pool.get_tensor(
|
|
||||||
weight_name, layer_id, LoRAType.LORA_B
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def init_state(
|
def init_state(
|
||||||
self,
|
self,
|
||||||
@@ -405,6 +383,7 @@ class LoRAManager:
|
|||||||
self.init_lora_weight_names()
|
self.init_lora_weight_names()
|
||||||
self.init_lora_modules()
|
self.init_lora_modules()
|
||||||
self.init_memory_pool()
|
self.init_memory_pool()
|
||||||
|
self.update_lora_info()
|
||||||
|
|
||||||
def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
|
def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
|
||||||
# Configs of all active LoRA adapters, indexed by LoRA ID.
|
# Configs of all active LoRA adapters, indexed by LoRA ID.
|
||||||
@@ -461,9 +440,9 @@ class LoRAManager:
|
|||||||
Add new LoRA weight names if needed based on the current `self.configs`.
|
Add new LoRA weight names if needed based on the current `self.configs`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Target lora weight names for lora_a and lora_b modules respectively.
|
self.lora_weight_names: Set[str] = get_normalized_lora_weight_names(
|
||||||
lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules)
|
self.target_modules
|
||||||
self.lora_weight_names: Tuple[Set[str]] = (set(lora_A), set(lora_B))
|
)
|
||||||
|
|
||||||
def load_lora_weights(self, lora_ref: LoRARef):
|
def load_lora_weights(self, lora_ref: LoRARef):
|
||||||
"""
|
"""
|
||||||
@@ -479,15 +458,6 @@ class LoRAManager:
|
|||||||
lora_adapter.initialize_weights()
|
lora_adapter.initialize_weights()
|
||||||
self.loras[lora_ref.lora_id] = lora_adapter
|
self.loras[lora_ref.lora_id] = lora_adapter
|
||||||
|
|
||||||
# Additional checks for flashinfer backend
|
|
||||||
# FIXME remove the restrictions after supporting multi-rank for flashinfer backend
|
|
||||||
if self.lora_backend == "flashinfer":
|
|
||||||
lora_dims = set(x.r for x in self.configs.values())
|
|
||||||
scalings = set(x.scaling for x in self.loras.values())
|
|
||||||
assert (
|
|
||||||
len(lora_dims) == 1 and len(scalings) == 1
|
|
||||||
), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "
|
|
||||||
|
|
||||||
def init_memory_pool(self):
|
def init_memory_pool(self):
|
||||||
"""(Re)initialize the LoRA memory pool based on the current configurations."""
|
"""(Re)initialize the LoRA memory pool based on the current configurations."""
|
||||||
self.memory_pool = LoRAMemoryPool(
|
self.memory_pool = LoRAMemoryPool(
|
||||||
@@ -512,12 +482,6 @@ class LoRAManager:
|
|||||||
{} for _ in range(self.base_hf_config.num_hidden_layers)
|
{} for _ in range(self.base_hf_config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Target module names of customized layers defined in python/sglang/srt/layers
|
|
||||||
# e.g., {"qkv_proj", "o_proj"}
|
|
||||||
customized_target_names = get_customized_names_from_hf_names(
|
|
||||||
self.target_modules, self.base_model
|
|
||||||
)
|
|
||||||
|
|
||||||
for module_name, module in self.base_model.named_modules():
|
for module_name, module in self.base_model.named_modules():
|
||||||
# TODO (lifuhuang): in the future, we should consider generalizing the
|
# TODO (lifuhuang): in the future, we should consider generalizing the
|
||||||
# should_apply_lora function to support mapping by full module name instead
|
# should_apply_lora function to support mapping by full module name instead
|
||||||
@@ -530,7 +494,7 @@ class LoRAManager:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# The module should be converted if it is included in target_names
|
# The module should be converted if it is included in target_names
|
||||||
if module_name.split(".")[-1] in customized_target_names:
|
if module_name.split(".")[-1] in self.lora_weight_names:
|
||||||
layer_id = get_layer_id(module_name)
|
layer_id = get_layer_id(module_name)
|
||||||
self.lora_modules[layer_id][module_name] = self.set_lora_module(
|
self.lora_modules[layer_id][module_name] = self.set_lora_module(
|
||||||
module_name, module
|
module_name, module
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ class LoRAMemoryPool:
|
|||||||
tp_size: int,
|
tp_size: int,
|
||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
max_lora_rank: int,
|
max_lora_rank: int,
|
||||||
lora_weight_names: Tuple[Set[str], Set[str]],
|
lora_weight_names: Set[str],
|
||||||
base_model: torch.nn.Module,
|
base_model: torch.nn.Module,
|
||||||
):
|
):
|
||||||
self.base_hf_config: AutoConfig = base_hf_config
|
self.base_hf_config: AutoConfig = base_hf_config
|
||||||
@@ -62,9 +62,7 @@ class LoRAMemoryPool:
|
|||||||
self.tp_size: int = tp_size
|
self.tp_size: int = tp_size
|
||||||
self.tp_rank: int = tp_rank
|
self.tp_rank: int = tp_rank
|
||||||
self.max_lora_rank: int = max_lora_rank
|
self.max_lora_rank: int = max_lora_rank
|
||||||
|
self.lora_weight_names: Set[str] = lora_weight_names
|
||||||
# lora weight names for LoRA A and B respectively.
|
|
||||||
self.lora_weight_names: Tuple[Set[str], Set[str]] = lora_weight_names
|
|
||||||
|
|
||||||
# Both A_buffer and B_buffer maps lora weight names to its buffer space.
|
# Both A_buffer and B_buffer maps lora weight names to its buffer space.
|
||||||
# A_buffer contains num_layer number of row-major tensors with shape
|
# A_buffer contains num_layer number of row-major tensors with shape
|
||||||
@@ -97,12 +95,8 @@ class LoRAMemoryPool:
|
|||||||
"""
|
"""
|
||||||
if config.r > self.max_lora_rank:
|
if config.r > self.max_lora_rank:
|
||||||
return False
|
return False
|
||||||
weights_a, weights_b = get_normalized_lora_weight_names(
|
weights = get_normalized_lora_weight_names(config.target_modules)
|
||||||
config.target_modules
|
return weights.issubset(self.lora_weight_names)
|
||||||
)
|
|
||||||
return weights_a.issubset(self.lora_weight_names[0]) and weights_b.issubset(
|
|
||||||
self.lora_weight_names[1]
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(config, LoRAConfig):
|
if isinstance(config, LoRAConfig):
|
||||||
return _can_support(config)
|
return _can_support(config)
|
||||||
@@ -132,11 +126,9 @@ class LoRAMemoryPool:
|
|||||||
Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
|
Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
|
||||||
"""
|
"""
|
||||||
_, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
|
_, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
|
||||||
c = get_stacked_multiply(module_name)
|
|
||||||
if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
|
if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
|
||||||
output_dim = divide(output_dim, self.tp_size)
|
output_dim = divide(output_dim, self.tp_size)
|
||||||
return (
|
return (
|
||||||
c,
|
|
||||||
self.max_loras_per_batch,
|
self.max_loras_per_batch,
|
||||||
output_dim,
|
output_dim,
|
||||||
max_lora_dim,
|
max_lora_dim,
|
||||||
@@ -165,13 +157,13 @@ class LoRAMemoryPool:
|
|||||||
|
|
||||||
init_buffer(
|
init_buffer(
|
||||||
self.A_buffer,
|
self.A_buffer,
|
||||||
self.lora_weight_names[0],
|
self.lora_weight_names,
|
||||||
self.get_lora_A_shape,
|
self.get_lora_A_shape,
|
||||||
)
|
)
|
||||||
|
|
||||||
init_buffer(
|
init_buffer(
|
||||||
self.B_buffer,
|
self.B_buffer,
|
||||||
self.lora_weight_names[1],
|
self.lora_weight_names,
|
||||||
self.get_lora_B_shape,
|
self.get_lora_B_shape,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -246,7 +238,7 @@ class LoRAMemoryPool:
|
|||||||
return
|
return
|
||||||
|
|
||||||
assert lora_adapter is not None
|
assert lora_adapter is not None
|
||||||
lora_rank = lora_adapter.config.hf_config["r"]
|
lora_rank = lora_adapter.config.r
|
||||||
for layer_id in range(self.num_layer):
|
for layer_id in range(self.num_layer):
|
||||||
layer_weights = lora_adapter.layers[layer_id].weights
|
layer_weights = lora_adapter.layers[layer_id].weights
|
||||||
temp_A_buffer: Dict[str, Optional[torch.Tensor]] = {
|
temp_A_buffer: Dict[str, Optional[torch.Tensor]] = {
|
||||||
@@ -256,73 +248,38 @@ class LoRAMemoryPool:
|
|||||||
weight_name: None for weight_name in self.B_buffer
|
weight_name: None for weight_name in self.B_buffer
|
||||||
}
|
}
|
||||||
for name, weights in layer_weights.items():
|
for name, weights in layer_weights.items():
|
||||||
|
lora_weight_name = get_weight_name(name, self.lora_weight_names)
|
||||||
if "lora_A" in name:
|
if "lora_A" in name:
|
||||||
lora_weight_name = get_weight_name(
|
|
||||||
name, self.lora_weight_names, LoRAType.LORA_A
|
|
||||||
)
|
|
||||||
temp_A_buffer[lora_weight_name] = weights
|
temp_A_buffer[lora_weight_name] = weights
|
||||||
else:
|
else:
|
||||||
lora_weight_name = get_weight_name(
|
|
||||||
name, self.lora_weight_names, LoRAType.LORA_B
|
|
||||||
)
|
|
||||||
temp_B_buffer[lora_weight_name] = weights
|
temp_B_buffer[lora_weight_name] = weights
|
||||||
|
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
cur_layer_modules = lora_modules[layer_id]
|
cur_layer_modules = lora_modules[layer_id]
|
||||||
for module_name, module in cur_layer_modules.items():
|
for module_name, module in cur_layer_modules.items():
|
||||||
weight_name = get_weight_name(
|
weight_name = get_weight_name(module_name, self.lora_weight_names)
|
||||||
module_name, self.lora_weight_names, LoRAType.LORA_A
|
|
||||||
)
|
|
||||||
|
|
||||||
if temp_A_buffer[weight_name] is None:
|
if temp_A_buffer[weight_name] is None:
|
||||||
# Skip weight slicing if the weight is not present in the adapter
|
# Skip weight slicing if the weight is not present in the adapter
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if "qkv_proj" in module_name:
|
temp_A_buffer[weight_name] = module.slice_lora_a_weights(
|
||||||
temp_A_buffer["qkv_proj"] = module.slice_lora_a_weights(
|
temp_A_buffer[weight_name], self.tp_rank
|
||||||
temp_A_buffer["qkv_proj"], self.tp_rank
|
)
|
||||||
)
|
temp_B_buffer[weight_name] = module.slice_lora_b_weights(
|
||||||
temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"] = (
|
temp_B_buffer[weight_name], self.tp_rank
|
||||||
module.slice_lora_b_weights(
|
)
|
||||||
[temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"]],
|
|
||||||
self.tp_rank,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# TODO (lifuhuang): Ideally, we should call `get_weight_name` separately for both A and B.
|
|
||||||
# Currently, we're reusing A's weight name as a workaround, relying on the fact that A and
|
|
||||||
# B share the same name except for `qkv_proj`. We should clean this up once we deprecate the
|
|
||||||
# FlashInfer LoRA backend.
|
|
||||||
temp_A_buffer[weight_name] = module.slice_lora_a_weights(
|
|
||||||
temp_A_buffer[weight_name], self.tp_rank
|
|
||||||
)
|
|
||||||
temp_B_buffer[weight_name] = module.slice_lora_b_weights(
|
|
||||||
temp_B_buffer[weight_name], self.tp_rank
|
|
||||||
)
|
|
||||||
|
|
||||||
for name, weights in temp_A_buffer.items():
|
for name, weights in temp_A_buffer.items():
|
||||||
c = get_stacked_multiply(name)
|
c = get_stacked_multiply(name)
|
||||||
buffer_view = self.A_buffer[name][layer_id][buffer_id][
|
target_buffer = self.A_buffer[name][layer_id]
|
||||||
: lora_rank * c, :
|
buffer_view = target_buffer[buffer_id, : lora_rank * c, :]
|
||||||
]
|
|
||||||
load_lora_weight_tensor(buffer_view, weights)
|
load_lora_weight_tensor(buffer_view, weights)
|
||||||
|
|
||||||
for name, weights in temp_B_buffer.items():
|
for name, weights in temp_B_buffer.items():
|
||||||
c = get_stacked_multiply(name)
|
target_buffer = self.B_buffer[name][layer_id]
|
||||||
if c > 1:
|
buffer_view = target_buffer[buffer_id, :, :lora_rank]
|
||||||
for stacked_id in range(c):
|
load_lora_weight_tensor(buffer_view, weights)
|
||||||
buffer_view = self.B_buffer[name][layer_id][stacked_id][
|
|
||||||
buffer_id
|
|
||||||
][:, :lora_rank]
|
|
||||||
weight_slice = (
|
|
||||||
weights[stacked_id] if weights is not None else None
|
|
||||||
)
|
|
||||||
load_lora_weight_tensor(buffer_view, weight_slice)
|
|
||||||
else:
|
|
||||||
buffer_view = self.B_buffer[name][layer_id][0][buffer_id][
|
|
||||||
:, :lora_rank
|
|
||||||
]
|
|
||||||
load_lora_weight_tensor(buffer_view, weights)
|
|
||||||
|
|
||||||
def get_tensor(
|
def get_tensor(
|
||||||
self, weight_name: str, layer_id: int, lora_type: LoRAType
|
self, weight_name: str, layer_id: int, lora_type: LoRAType
|
||||||
|
|||||||
@@ -119,7 +119,7 @@ def _qkv_lora_b_kernel(
|
|||||||
output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
|
output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
|
||||||
s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
|
s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
|
||||||
)
|
)
|
||||||
output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size)
|
output_mask = (s_offset[:, None] < seg_len) & (n_offset[None, :] < n_size)
|
||||||
partial_sum += tl.load(output_ptr, mask=output_mask)
|
partial_sum += tl.load(output_ptr, mask=output_mask)
|
||||||
tl.store(output_ptr, partial_sum, mask=output_mask)
|
tl.store(output_ptr, partial_sum, mask=output_mask)
|
||||||
|
|
||||||
|
|||||||
@@ -47,34 +47,6 @@ def get_layer_id(name: str) -> int:
|
|||||||
return int(match.group(1))
|
return int(match.group(1))
|
||||||
|
|
||||||
|
|
||||||
def get_customized_names_from_hf_names(
|
|
||||||
hf_module_names: Set[str], base_model: torch.nn.Module
|
|
||||||
) -> Set[str]:
|
|
||||||
"""
|
|
||||||
This function takes in a set of huggingface style module names:
|
|
||||||
e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
|
|
||||||
and outputs a set of module names of customized sglang layers:
|
|
||||||
e.g., {"qkv_proj", "o_proj"}
|
|
||||||
"""
|
|
||||||
if hasattr(base_model, "get_module_name"):
|
|
||||||
return {base_model.get_module_name(name) for name in hf_module_names}
|
|
||||||
else:
|
|
||||||
"""
|
|
||||||
Fallback solution of mapping from config module name to module name in model class.
|
|
||||||
Please check if it aligns with your base model.
|
|
||||||
Please implement the function in the model class if it is not.
|
|
||||||
You can reference this function in llama.py.
|
|
||||||
"""
|
|
||||||
params_mapping = {
|
|
||||||
"q_proj": "qkv_proj",
|
|
||||||
"k_proj": "qkv_proj",
|
|
||||||
"v_proj": "qkv_proj",
|
|
||||||
"gate_proj": "gate_up_proj",
|
|
||||||
"up_proj": "gate_up_proj",
|
|
||||||
}
|
|
||||||
return {params_mapping.get(name, name) for name in hf_module_names}
|
|
||||||
|
|
||||||
|
|
||||||
def get_hidden_dim(
|
def get_hidden_dim(
|
||||||
module_name: str, config: AutoConfig, base_model: torch.nn.Module
|
module_name: str, config: AutoConfig, base_model: torch.nn.Module
|
||||||
) -> Tuple[int]:
|
) -> Tuple[int]:
|
||||||
@@ -95,22 +67,9 @@ def get_hidden_dim(
|
|||||||
head_dim = getattr(
|
head_dim = getattr(
|
||||||
config, "head_dim", config.hidden_size // config.num_attention_heads
|
config, "head_dim", config.hidden_size // config.num_attention_heads
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: the special handling of qkv will be addressed in #8940.
|
|
||||||
if module_name == "qkv_proj":
|
if module_name == "qkv_proj":
|
||||||
return (
|
return config.hidden_size, head_dim * (
|
||||||
config.hidden_size,
|
config.num_attention_heads + config.num_key_value_heads * 2
|
||||||
None, # qkv_proj is only used in LoRA A
|
|
||||||
)
|
|
||||||
elif module_name == "kv_proj":
|
|
||||||
return (
|
|
||||||
None, # kv_proj is only used in LoRA B
|
|
||||||
head_dim * config.num_key_value_heads,
|
|
||||||
)
|
|
||||||
elif module_name == "q_proj":
|
|
||||||
return (
|
|
||||||
None, # q_proj is only used in LoRA B
|
|
||||||
head_dim * config.num_attention_heads,
|
|
||||||
)
|
)
|
||||||
elif module_name == "o_proj":
|
elif module_name == "o_proj":
|
||||||
return (
|
return (
|
||||||
@@ -118,7 +77,7 @@ def get_hidden_dim(
|
|||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
)
|
)
|
||||||
elif module_name == "gate_up_proj":
|
elif module_name == "gate_up_proj":
|
||||||
return config.hidden_size, config.intermediate_size
|
return config.hidden_size, config.intermediate_size * 2
|
||||||
elif module_name == "down_proj":
|
elif module_name == "down_proj":
|
||||||
return config.intermediate_size, config.hidden_size
|
return config.intermediate_size, config.hidden_size
|
||||||
else:
|
else:
|
||||||
@@ -127,26 +86,22 @@ def get_hidden_dim(
|
|||||||
|
|
||||||
def get_normalized_lora_weight_names(
|
def get_normalized_lora_weight_names(
|
||||||
target_modules: Iterable[str],
|
target_modules: Iterable[str],
|
||||||
) -> Tuple[set[str], set[str]]:
|
) -> set[str]:
|
||||||
"""
|
"""
|
||||||
Mapping a list of target module name to names of the normalized LoRA weights.
|
Mapping a list of target module name to names of the normalized LoRA weights.
|
||||||
Returned tuple contains (name for Lora A, name for Lora B)
|
|
||||||
"""
|
"""
|
||||||
params_mapping = {
|
params_mapping = {
|
||||||
"q_proj": (["qkv_proj"], ["q_proj"]),
|
"q_proj": "qkv_proj",
|
||||||
"k_proj": (["qkv_proj"], ["kv_proj"]),
|
"k_proj": "qkv_proj",
|
||||||
"v_proj": (["qkv_proj"], ["kv_proj"]),
|
"v_proj": "qkv_proj",
|
||||||
"gate_proj": (["gate_up_proj"], ["gate_up_proj"]),
|
"gate_proj": "gate_up_proj",
|
||||||
"up_proj": (["gate_up_proj"], ["gate_up_proj"]),
|
"up_proj": "gate_up_proj",
|
||||||
"qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
|
|
||||||
"gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
result = (set(), set())
|
result = set()
|
||||||
for name in target_modules:
|
for name in target_modules:
|
||||||
lora_a, lora_b = params_mapping.get(name, ([name], [name]))
|
weight_name = params_mapping.get(name, name)
|
||||||
result[0].update(lora_a)
|
result.add(weight_name)
|
||||||
result[1].update(lora_b)
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@@ -156,23 +111,21 @@ def get_stacked_multiply(module_name: str) -> int:
|
|||||||
"""
|
"""
|
||||||
stacked_rank = {
|
stacked_rank = {
|
||||||
"qkv_proj": 3,
|
"qkv_proj": 3,
|
||||||
"kv_proj": 2,
|
|
||||||
"gate_up_proj": 2,
|
"gate_up_proj": 2,
|
||||||
}
|
}
|
||||||
return stacked_rank[module_name] if module_name in stacked_rank else 1
|
return stacked_rank[module_name] if module_name in stacked_rank else 1
|
||||||
|
|
||||||
|
|
||||||
def get_weight_name(
|
def get_weight_name(
|
||||||
target_name: str, lora_weight_names: Tuple[Set[str]], lora_type: LoRAType
|
target_name: str, lora_weight_names: Tuple[Set[str]]
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
target_name is name of a given module,
|
Get the weight name in lora_weight_names that can match target_name.
|
||||||
lora_weight_names is a set of lora stacked name pairs (see get_stacked_name method above)
|
|
||||||
If there is a weight name in lora_weight_names that can match target_name, return this name
|
If there is a weight name in lora_weight_names that can match target_name, return this name
|
||||||
Else raise ValueError.
|
Else raise ValueError.
|
||||||
"""
|
"""
|
||||||
idx = 0 if lora_type == LoRAType.LORA_A else 1
|
for weight_name in lora_weight_names:
|
||||||
for weight_name in lora_weight_names[idx]:
|
|
||||||
if weight_name in target_name:
|
if weight_name in target_name:
|
||||||
return weight_name
|
return weight_name
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -180,9 +133,4 @@ def get_weight_name(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# TODO: [PR #4274] For future use to simplify the mapping between HF module names and customized module names.
|
|
||||||
VOCAB_PARALLELISM_EMBEDDING_NAMES = ["embeddings"]
|
|
||||||
COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_proj", "up_proj"]
|
|
||||||
MERGED_COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_up_proj"]
|
|
||||||
QKV_PARALLELISM_LINEAR_LORA_NAMES = ["qkv_proj"]
|
|
||||||
ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj"]
|
ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj"]
|
||||||
|
|||||||
@@ -501,23 +501,16 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
|
|||||||
|
|
||||||
def get_hidden_dim(self, module_name):
|
def get_hidden_dim(self, module_name):
|
||||||
# return input_dim, output_dim
|
# return input_dim, output_dim
|
||||||
# TODO: the special handling of qkv will be addressed in #8940.
|
|
||||||
if module_name == "qkv_proj":
|
if module_name == "qkv_proj":
|
||||||
return (
|
return (
|
||||||
self.config.hidden_size,
|
self.config.hidden_size,
|
||||||
None, # qkv_proj is only used in LoRA A
|
self.config.head_dim
|
||||||
|
* (
|
||||||
|
self.config.num_attention_heads
|
||||||
|
+ self.config.num_key_value_heads * 2
|
||||||
|
),
|
||||||
)
|
)
|
||||||
elif module_name == "kv_proj":
|
elif module_name == "o_proj":
|
||||||
return (
|
|
||||||
None, # kv_proj is only used in LoRA B
|
|
||||||
self.config.head_dim * self.config.num_key_value_heads,
|
|
||||||
)
|
|
||||||
elif module_name == "q_proj":
|
|
||||||
return (
|
|
||||||
None, # q_proj is only used in LoRA B
|
|
||||||
self.config.head_dim * self.config.num_attention_heads,
|
|
||||||
)
|
|
||||||
elif module_name in ["o_proj"]:
|
|
||||||
return (
|
return (
|
||||||
self.config.head_dim * self.config.num_attention_heads,
|
self.config.head_dim * self.config.num_attention_heads,
|
||||||
self.config.hidden_size,
|
self.config.hidden_size,
|
||||||
@@ -527,7 +520,7 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
|
|||||||
"Currently SGLang requires uniform intermediate size for all layers. "
|
"Currently SGLang requires uniform intermediate size for all layers. "
|
||||||
"Please file an issue if you need support for non-uniform intermediate sizes."
|
"Please file an issue if you need support for non-uniform intermediate sizes."
|
||||||
)
|
)
|
||||||
return self.config.hidden_size, self.config.intermediate_size[0]
|
return self.config.hidden_size, self.config.intermediate_size[0] * 2
|
||||||
elif module_name == "down_proj":
|
elif module_name == "down_proj":
|
||||||
assert len(set(self.config.intermediate_size)) == 1, (
|
assert len(set(self.config.intermediate_size)) == 1, (
|
||||||
"Currently SGLang requires uniform intermediate size for all layers. "
|
"Currently SGLang requires uniform intermediate size for all layers. "
|
||||||
|
|||||||
Reference in New Issue
Block a user