[Feature] add multi-rank support for Lora (#4492)
Co-authored-by: rudy152 <czh1137892874@gmail.com>
@@ -965,7 +965,7 @@ async def benchmark(
request_rate: float,
max_concurrency: Optional[int],
disable_tqdm: bool,
lora_name: str,
lora_names: List[str],
extra_request_body: Dict[str, Any],
profile: bool,
pd_seperated: bool = False,
@@ -988,6 +988,11 @@ async def benchmark(
# Warmup
print("Starting initial single prompt test run...")
test_prompt, test_prompt_len, test_output_len = input_requests[0]
if lora_names != None and len(lora_names) != 0:
lora_name = lora_names[0]
else:
lora_name = None

test_input = RequestFuncInput(
model=model_id,
prompt=test_prompt,
@@ -1028,6 +1033,12 @@ async def benchmark(
tasks: List[asyncio.Task] = []
async for request in get_request(input_requests, request_rate):
prompt, prompt_len, output_len = request
if lora_names != None and len(lora_names) != 0:
idx = random.randint(0, len(lora_names) - 1)
lora_name = lora_names[idx]
else:
lora_name = None

request_func_input = RequestFuncInput(
model=model_id,
prompt=prompt,
@@ -1347,7 +1358,7 @@ def run_benchmark(args_: argparse.Namespace):
request_rate=args.request_rate,
max_concurrency=args.max_concurrency,
disable_tqdm=args.disable_tqdm,
lora_name=args.lora_name,
lora_names=args.lora_name,
extra_request_body=extra_request_body,
profile=args.profile,
pd_seperated=args.pd_seperated,
@@ -1366,6 +1377,13 @@ def set_ulimit(target_soft_limit=65535):
print(f"Fail to set RLIMIT_NOFILE: {e}")


class LoRAPathAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, [])
for lora_name in values:
getattr(namespace, self.dest).append(lora_name)


if __name__ == "__main__":
parser = ArgumentParser(description="Benchmark the online serving throughput.")
parser.add_argument(
@@ -1509,8 +1527,10 @@ if __name__ == "__main__":
parser.add_argument(
"--lora-name",
type=str,
nargs="*",
default=None,
help="The name of LoRA adapter",
action=LoRAPathAction,
help="The names of LoRA adapters. You can provide a list of names in the format {name} {name} {name}...",
)
parser.add_argument(
"--prompt-suffix",
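Note: the updated `--lora-name` flag accepts zero or more adapter names via `nargs="*"` and collects them into a list with `LoRAPathAction`; the benchmark then picks one name at random per request. A minimal standalone sketch of the parsing behavior (the parser below is illustrative, not the benchmark script itself):

```python
import argparse

class LoRAPathAction(argparse.Action):
    # Same behavior as the action added above: collect every value into a list.
    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, [])
        for lora_name in values:
            getattr(namespace, self.dest).append(lora_name)

parser = argparse.ArgumentParser()
parser.add_argument("--lora-name", type=str, nargs="*", default=None, action=LoRAPathAction)

args = parser.parse_args(["--lora-name", "adapter_a", "adapter_b"])
print(args.lora_name)  # ['adapter_a', 'adapter_b']
```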
@@ -5,7 +5,7 @@ import torch
from sglang.srt.lora.utils import LoRABatchInfo


def get_fuse_output_scaling_add_from_name(name: str) -> bool:
def get_fuse_output_add_from_name(name: str) -> bool:
mapping = {
"triton": True,
"flashinfer": False,
@@ -28,14 +28,14 @@ class BaseLoRABackend:
Args:
name: name of backend
batch_info: information of current batch for use
fuse_output_scaling_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward,
and the operation of scaling and adding will be fused into kernel
fuse_output_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward,
and the operation of adding will be fused into kernel
"""

def __init__(self, name: str, batch_info: LoRABatchInfo = None):
self.name = name
self.batch_info = batch_info
self.fuse_output_scaling_add = get_fuse_output_scaling_add_from_name(name)
self.fuse_output_add = get_fuse_output_add_from_name(name)
self.fuse_stacked_lora_b = get_fuse_stacked_lora_b_from_name(name)

def run_lora_a_sgemm(
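For orientation, the computation every backend implements is the standard LoRA update: base output plus scaling times `(x @ A^T) @ B^T`. A framework-free reference in plain PyTorch (shapes below are assumptions for illustration; the real backends run this as segmented GEMMs over the whole batch, with per-adapter rank and scaling handled in the kernels):

```python
import torch

def lora_forward_reference(x, lora_a, lora_b, scaling, base_output):
    # x: (s, input_dim), lora_a: (r, input_dim), lora_b: (output_dim, r)
    lora_a_output = x @ lora_a.T           # what run_lora_a_sgemm computes per sequence
    lora_delta = lora_a_output @ lora_b.T  # what run_lora_b_sgemm computes per sequence
    return base_output + scaling * lora_delta

x = torch.randn(4, 32)
out = lora_forward_reference(x, torch.randn(8, 32), torch.randn(64, 8), 2.0, torch.randn(4, 64))
print(out.shape)  # torch.Size([4, 64])
```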
@@ -37,13 +37,16 @@ class FlashInferLoRABackend(BaseLoRABackend):
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:

return self.segment_gemm.run(
x=x,
weights=weights,
batch_size=self.batch_info.bs,
weight_column_major=True,
seg_indptr=self.batch_info.seg_indptr,
weight_indices=self.batch_info.weight_indices,
return (
self.segment_gemm.run(
x=x,
weights=weights,
batch_size=self.batch_info.bs,
weight_column_major=True,
seg_indptr=self.batch_info.seg_indptr,
weight_indices=self.batch_info.weight_indices,
)
* self.batch_info.scalings[0]
)

def run_qkv_lora(
@@ -90,7 +93,7 @@ class FlashInferLoRABackend(BaseLoRABackend):
weights=kv_lora_b[1],
)

return lora_output
return lora_output * self.batch_info.scalings[0]

def run_gate_up_lora(
self,
@@ -125,4 +128,4 @@ class FlashInferLoRABackend(BaseLoRABackend):
weights=gate_up_lora_b[1],
)

return lora_output
return lora_output * self.batch_info.scalings[0]
@@ -25,11 +25,10 @@ class TritonLoRABackend(BaseLoRABackend):
x: torch.Tensor,
weights: torch.Tensor,
base_output: torch.Tensor = None,
scaling: float = 1.0,
*args,
**kwargs
) -> torch.Tensor:
return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output, scaling)
return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output)

def run_qkv_lora(
self,
@@ -39,7 +38,6 @@ class TritonLoRABackend(BaseLoRABackend):
output_offset: torch.Tensor,
max_qkv_out_dim: int,
base_output: torch.Tensor = None,
scaling: float = 1.0,
*args,
**kwargs
) -> torch.Tensor:
@@ -49,7 +47,7 @@ class TritonLoRABackend(BaseLoRABackend):
# qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
assert isinstance(qkv_lora_b, torch.Tensor)

lora_a_output = sgemm_lora_a_fwd(x, qkv_lora_a, self.batch_info)
lora_a_output = sgemm_lora_a_fwd(x, qkv_lora_a, self.batch_info, stack_num=3)
lora_output = qkv_lora_b_fwd(
lora_a_output,
qkv_lora_b,
@@ -57,7 +55,6 @@ class TritonLoRABackend(BaseLoRABackend):
output_offset,
max_qkv_out_dim,
base_output,
scaling,
)
return lora_output

@@ -67,7 +64,6 @@ class TritonLoRABackend(BaseLoRABackend):
gate_up_lora_a: torch.Tensor,
gate_up_lora_b: torch.Tensor,
base_output: torch.Tensor = None,
scaling: float = 1.0,
*args,
**kwargs
) -> torch.Tensor:
@@ -79,13 +75,14 @@ class TritonLoRABackend(BaseLoRABackend):
output_dim = gate_up_lora_b.shape[-2] // 2

# lora_a_output: (s, 2 * r)
lora_a_output = sgemm_lora_a_fwd(x, gate_up_lora_a, self.batch_info)
lora_a_output = sgemm_lora_a_fwd(
x, gate_up_lora_a, self.batch_info, stack_num=2
)
lora_output = gate_up_lora_b_fwd(
lora_a_output,
gate_up_lora_b,
self.batch_info,
output_dim,
base_output,
scaling,
)
return lora_output
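The new `stack_num` argument tells `sgemm_lora_a_fwd` how many rank-`r` projections are stacked in the A weight: 3 for the fused q/k/v path, 2 for gate/up. Illustrative shapes (the sizes below are made up for the sketch):

```python
import torch

s, r, input_dim, num_lora = 4, 8, 32, 2
x = torch.randn(s, input_dim)
qkv_lora_a = torch.randn(num_lora, 3 * r, input_dim)      # stack_num = 3
gate_up_lora_a = torch.randn(num_lora, 2 * r, input_dim)   # stack_num = 2

# Per adapter, the A-side GEMM is x @ A.T, so the output width is stack_num * r:
print((x @ qkv_lora_a[0].T).shape)      # torch.Size([4, 24])
print((x @ gate_up_lora_a[0].T).shape)  # torch.Size([4, 16])
```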
@@ -23,14 +23,10 @@ class BaseLayerWithLoRA(nn.Module):
def __init__(
self,
base_layer: nn.Module,
lora_rank: int,
scaling: float,
lora_backend: BaseLoRABackend,
):
super().__init__()
self.base_layer: nn.Module = base_layer
self.lora_rank: int = lora_rank
self.scaling: float = scaling
self.set_lora: bool = False
self.lora_backend: BaseLoRABackend = lora_backend

@@ -59,11 +55,9 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def __init__(
self,
base_layer: VocabParallelEmbedding,
lora_rank: int,
scaling: float,
lora_backend: BaseLoRABackend,
) -> None:
super().__init__(base_layer, lora_rank, scaling, lora_backend)
super().__init__(base_layer, lora_backend)
self.weight = base_layer.weight


@@ -71,11 +65,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(
self,
base_layer: ColumnParallelLinear,
lora_rank: int,
scaling: float,
lora_backend: BaseLoRABackend,
) -> None:
super().__init__(base_layer, lora_rank, scaling, lora_backend)
super().__init__(base_layer, lora_backend)

def set_lora_info(
self,
@@ -87,7 +79,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
self.B_buffer = B_buffer

def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
backend_kwargs = {"base_output": base_output}
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
lora_output = self.lora_backend.run_lora_b_sgemm(
lora_a_output,
@@ -96,8 +88,8 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
)
return (
lora_output
if self.lora_backend.fuse_output_scaling_add
else base_output + lora_output * self.scaling
if self.lora_backend.fuse_output_add
else base_output + lora_output
)

def forward(self, input_: torch.Tensor):
@@ -132,11 +124,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def __init__(
self,
base_layer: MergedColumnParallelLinear,
lora_rank: int,
scaling: float,
lora_backend: BaseLoRABackend,
) -> None:
super().__init__(base_layer, lora_rank, scaling, lora_backend)
super().__init__(base_layer, lora_backend)

def set_lora_info(
self,
@@ -155,7 +145,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])

def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
backend_kwargs = {"base_output": base_output}

lora_output = self.lora_backend.run_gate_up_lora(
x,
@@ -165,8 +155,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
)
return (
lora_output
if self.lora_backend.fuse_output_scaling_add
else base_output + lora_output * self.scaling
if self.lora_backend.fuse_output_add
else base_output + lora_output
)

def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
@@ -184,11 +174,9 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def init__(
self,
base_layer: QKVParallelLinear,
lora_rank: int,
scaling: float,
lora_backend: BaseLoRABackend,
) -> None:
super().__init__(base_layer, lora_rank, scaling, lora_backend)
super().__init__(base_layer, lora_backend)

def set_lora_info(
self,
@@ -230,7 +218,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
)

def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
backend_kwargs = {"base_output": base_output}
if self.lora_backend.fuse_stacked_lora_b:
backend_kwargs["output_offset"] = self.output_offset
backend_kwargs["max_qkv_out_dim"] = self.max_qkv_out_dim
@@ -243,8 +231,8 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
)
return (
lora_output
if self.lora_backend.fuse_output_scaling_add
else base_output + lora_output * self.scaling
if self.lora_backend.fuse_output_add
else base_output + lora_output
)

def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
@@ -273,11 +261,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(
self,
base_layer: RowParallelLinear,
lora_rank: int,
scaling: float,
lora_backend: BaseLoRABackend,
) -> None:
super().__init__(base_layer, lora_rank, scaling, lora_backend)
super().__init__(base_layer, lora_backend)

def set_lora_info(self, A_buffer: torch.Tensor, B_buffer: torch.Tensor):
self.set_lora = True
@@ -285,7 +271,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
self.B_buffer = B_buffer

def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
backend_kwargs = {"base_output": base_output}
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
lora_output = self.lora_backend.run_lora_b_sgemm(
lora_a_output,
@@ -294,8 +280,8 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
)
return (
lora_output
if self.lora_backend.fuse_output_scaling_add
else base_output + lora_output * self.scaling
if self.lora_backend.fuse_output_add
else base_output + lora_output
)

def forward(self, input_: torch.Tensor):
@@ -344,7 +330,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):


def get_lora_layer(
layer: nn.Module, lora_rank: int, scaling: int, lora_backend: BaseLoRABackend
layer: nn.Module, lora_backend: BaseLoRABackend
) -> BaseLayerWithLoRA:
supported_layer_types = {
# the order matters
@@ -356,6 +342,6 @@ def get_lora_layer(
}
for src_layer_type, lora_layer_type in supported_layer_types.items():
if isinstance(layer, src_layer_type): # pylint: disable=unidiomatic-typecheck
ret = lora_layer_type(layer, lora_rank, scaling, lora_backend)
ret = lora_layer_type(layer, lora_backend)
return ret
raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")
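The layer-side contract after this change: layers no longer pass `scaling` to the backend, only `base_output`; per-adapter scaling is applied inside the kernels via `batch_info.scalings`, and the residual add is either fused by the backend (`fuse_output_add=True`) or done by the layer afterwards. A minimal mock, not the real backends, illustrating that both paths agree:

```python
import torch

class MockFusedBackend:
    fuse_output_add = True
    def run_lora_b_sgemm(self, lora_a_output, weights, base_output=None, **kwargs):
        return base_output + lora_a_output @ weights.T  # add happens "inside the kernel"

class MockUnfusedBackend:
    fuse_output_add = False
    def run_lora_b_sgemm(self, lora_a_output, weights, **kwargs):
        return lora_a_output @ weights.T                # caller adds base_output itself

def apply_lora(backend, base_output, lora_a_output, weights):
    out = backend.run_lora_b_sgemm(lora_a_output, weights, base_output=base_output)
    return out if backend.fuse_output_add else base_output + out

a, w, base = torch.randn(4, 8), torch.randn(16, 8), torch.randn(4, 16)
assert torch.allclose(apply_lora(MockFusedBackend(), base, a, w),
                      apply_lora(MockUnfusedBackend(), base, a, w))
```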
@@ -103,11 +103,14 @@ class LoRAManager:
self.loras[name] = lora_adapter

# misc lora configs
# FIXME remove the restrictions after implementing unified paging
self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
self.scaling: float = list(self.loras.values())[0].scaling
assert all(x.hf_config["r"] == self.max_lora_dim for x in self.configs.values())
assert all(x.scaling == self.scaling for x in self.loras.values())

if self.lora_backend == "flashinfer":
# FIXME remove the restrictions after supporting multi-rank for flashinfer backend
max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
scaling = list(self.loras.values())[0].scaling
assert all(x.hf_config["r"] == max_lora_dim for x in self.configs.values())
assert all(x.scaling == scaling for x in self.loras.values())

# Convert original model layers to layers with LoRA
self.convert_to_lora_layers()
@@ -133,6 +136,10 @@ class LoRAManager:
assert len(cur_uids) <= self.max_loras_per_batch
self.memory_pool.prepare_lora_batch(cur_uids, self.loras)

# FIXME: Handle lora uid with None more safely
if cur_uids == set([None]):
return

# set up batch info shared by all lora moruldes
bs = forward_batch.batch_size
seg_lens = (
@@ -144,8 +151,18 @@ class LoRAManager:
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
max_len = int(torch.max(seg_lens))
weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)

lora_ranks = torch.empty(
(self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
)
scalings = torch.empty(
(self.max_loras_per_batch,), dtype=torch.float, device="cuda"
)
for i, lora_path in enumerate(forward_batch.lora_paths):
weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
lora = self.loras[lora_path]
lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
scalings[weight_indices[i]] = lora.scaling

batch_info = LoRABatchInfo(
bs=bs,
@@ -153,6 +170,8 @@ class LoRAManager:
seg_indptr=seg_indptr,
max_len=max_len,
weight_indices=weight_indices,
lora_ranks=lora_ranks,
scalings=scalings,
)
self.lora_backend.set_batch_info(batch_info)

@@ -185,9 +204,7 @@ class LoRAManager:
)

def set_lora_module(self, module_name, module):
lora_module = get_lora_layer(
module, self.max_lora_dim, self.scaling, self.lora_backend
)
lora_module = get_lora_layer(module, self.lora_backend)
replace_submodule(self.base_model, module_name, lora_module)
return lora_module
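What the new bookkeeping above builds: `lora_ranks` and `scalings` are indexed by the memory-pool buffer slot (not by sequence), while `weight_indices` maps each sequence to its slot. A toy illustration with assumed values:

```python
import torch

max_loras_per_batch = 4
# one entry per sequence in the batch: which buffer slot its adapter lives in
weight_indices = torch.tensor([0, 1, 0], dtype=torch.int64)

lora_ranks = torch.zeros(max_loras_per_batch, dtype=torch.int64)
scalings = torch.zeros(max_loras_per_batch, dtype=torch.float)
lora_ranks[0], scalings[0] = 8, 1.0    # adapter loaded into slot 0 (e.g. r=8)
lora_ranks[1], scalings[1] = 16, 2.0   # adapter loaded into slot 1 (e.g. r=16)

# Inside the kernels each sequence then looks up its own rank/scaling:
seq_id = 2
slot = weight_indices[seq_id]
print(int(lora_ranks[slot]), float(scalings[slot]))  # 8 1.0
```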
@@ -167,6 +167,7 @@ class LoRAMemoryPool:
return

assert lora_adapter is not None
lora_rank = lora_adapter.config.hf_config["r"]
for layer_id in range(self.num_layer):
layer_weights = lora_adapter.layers[layer_id].weights
temp_A_buffer: Dict[str, torch.Tensor] = {}
@@ -208,17 +209,22 @@ class LoRAMemoryPool:
)

for name, weights in temp_A_buffer.items():
self.A_buffer[name][layer_id][buffer_id].copy_(weights)
c = get_stacked_multiply(name)
self.A_buffer[name][layer_id][buffer_id][: lora_rank * c, :].copy_(
weights
)

for name, weights in temp_B_buffer.items():
c = get_stacked_multiply(name)
if c > 1:
for stacked_id in range(c):
self.B_buffer[name][layer_id][stacked_id][buffer_id].copy_(
weights[stacked_id]
)
self.B_buffer[name][layer_id][stacked_id][buffer_id][
:, :lora_rank
].copy_(weights[stacked_id])
else:
self.B_buffer[name][layer_id][0][buffer_id].copy_(weights)
self.B_buffer[name][layer_id][0][buffer_id][:, :lora_rank].copy_(
weights
)

def get_tensor(
self, weight_name: str, layer_id: int, lora_type: LoRAType
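The pool buffers stay sized for the maximum rank across adapters; a lower-rank adapter now only fills the leading `lora_rank * c` rows of its A slot (and the leading `lora_rank` columns of its B slot), where `c = get_stacked_multiply(name)`. A small shape sketch with assumed sizes:

```python
import torch

max_r, lora_rank, c, input_dim = 16, 8, 3, 64    # c = 3 for the stacked qkv A weight
A_slot = torch.zeros(max_r * c, input_dim)       # one buffer slot, padded to the max rank
weights = torch.randn(lora_rank * c, input_dim)  # this adapter's (smaller) A weight

A_slot[: lora_rank * c, :].copy_(weights)        # only the leading rows are written
# The kernels never read past the adapter's own rank, because the kernel diffs
# below clamp the reduction/output dims with tl.minimum(..., rank).
```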
@@ -22,17 +22,18 @@ def _gate_up_lora_b_kernel(
w_stride_2,
output_stride_0,
output_stride_1,
# Information on sequence lengths and weight id
# Information on sequence lengths,ranks and weight id
seg_lens,
seg_indptr,
weight_indices,
lora_ranks,
# Meta parameters
BLOCK_S: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
# For fused output scaling and adding
fuse_scaling_add,
scaling,
scalings,
):
# This kernel packs 2 sgemms (gate/up) into a single kernel.

@@ -51,6 +52,11 @@ def _gate_up_lora_b_kernel(
w_index = tl.load(weight_indices + batch_id)
seg_start = tl.load(seg_indptr + batch_id)
n_start = gate_up_id * output_dim # offset on output dim
rank = tl.load(lora_ranks + w_index)
scaling = tl.load(scalings + w_index)

# Adjust K (rank) according to the specific LoRA adapter
K = tl.minimum(K, rank)

# The tile in output matrix will have (pid_s, pid_n) as id
num_pid_n = tl.cdiv(output_dim, BLOCK_N)
@@ -109,7 +115,6 @@ def gate_up_lora_b_fwd(
batch_info: LoRABatchInfo,
output_dim: int,
base_output: torch.Tensor = None,
scaling: float = 1.0,
) -> torch.Tensor:

# x: (s, 2 * r)
@@ -160,11 +165,12 @@ def gate_up_lora_b_fwd(
batch_info.seg_lens,
batch_info.seg_indptr,
batch_info.weight_indices,
batch_info.lora_ranks,
BLOCK_S,
BLOCK_OUT,
BLOCK_R,
fuse_scaling_add,
scaling,
batch_info.scalings,
)

return output
@@ -26,6 +26,7 @@ def _qkv_lora_b_kernel(
seg_lens,
seg_indptr,
weight_indices,
lora_ranks,
# Offsets of q/k/v slice on output dimension
n_offs,
# Meta parameters
@@ -34,7 +35,7 @@ def _qkv_lora_b_kernel(
BLOCK_K: tl.constexpr,
# For fused output scaling and adding
fuse_scaling_add,
scaling,
scalings,
):
# This kernel packs 3 sgemms (q/k/v) into a single kernel.

@@ -54,6 +55,10 @@ def _qkv_lora_b_kernel(
seg_start = tl.load(seg_indptr + batch_id)
n_start = tl.load(n_offs + qkv_id)
n_size = tl.load(n_offs + qkv_id + 1) - n_start
rank = tl.load(lora_ranks + w_index)
scaling = tl.load(scalings + w_index)
# Adjust K (rank) according to the specific LoRA adapter
K = tl.minimum(K, rank)

# The tile in output matrix will have (pid_s, pid_n) as id
num_pid_n = tl.cdiv(max_qkv_out_dim, BLOCK_N)
@@ -112,7 +117,6 @@ def qkv_lora_b_fwd(
output_offset: torch.Tensor,
max_qkv_out_dim: int,
base_output: torch.Tensor = None,
scaling: float = 1.0,
) -> torch.Tensor:

# x: (s, 3 * r)
@@ -171,12 +175,13 @@ def qkv_lora_b_fwd(
batch_info.seg_lens,
batch_info.seg_indptr,
batch_info.weight_indices,
batch_info.lora_ranks,
output_offset,
BLOCK_S,
BLOCK_OUT,
BLOCK_R,
fuse_scaling_add,
scaling,
batch_info.scalings,
)

return output
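In the q/k/v kernel, `n_offs` holds the output-dimension boundaries of the q, k and v slices, so `qkv_id` selects which slice a program writes. Illustrative offsets (the head sizes below are assumptions):

```python
output_dim_q, output_dim_kv = 64, 16
n_offs = [0,
          output_dim_q,                       # end of q slice
          output_dim_q + output_dim_kv,       # end of k slice
          output_dim_q + 2 * output_dim_kv]   # end of v slice

for qkv_id, name in enumerate(("q", "k", "v")):
    n_start = n_offs[qkv_id]
    n_size = n_offs[qkv_id + 1] - n_start
    print(name, n_start, n_size)  # q 0 64 / k 64 16 / v 80 16
```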
@@ -12,8 +12,9 @@ def _sgemm_lora_a_kernel(
weights,
output,
# Matrix dimensions
N, # r
N, # stack_num * r
K, # input_dim
stack_num,
# Strides
x_stride_0,
x_stride_1,
@@ -22,10 +23,11 @@ def _sgemm_lora_a_kernel(
w_stride_2,
output_stride_0,
output_stride_1,
# Information on sequence lengths and weight id
# Information on sequence lengths,ranks and weight id
seg_lens,
seg_indptr,
weight_indices,
lora_ranks,
# Meta parameters
BLOCK_S: tl.constexpr,
BLOCK_N: tl.constexpr,
@@ -43,6 +45,9 @@ def _sgemm_lora_a_kernel(
seg_len = tl.load(seg_lens + batch_id)
w_index = tl.load(weight_indices + batch_id)
seg_start = tl.load(seg_indptr + batch_id)
rank = tl.load(lora_ranks + w_index)
# Adjust N (stack_num * max_rank) according to the specific LoRA adapter
N = tl.minimum(N, rank * stack_num)

# The tile in output matrix will have (pid_s, pid_n) as id
num_pid_n = tl.cdiv(N, BLOCK_N)
@@ -91,11 +96,15 @@ def _sgemm_lora_a_kernel(


def sgemm_lora_a_fwd(
x: torch.Tensor, weights: torch.Tensor, batch_info: LoRABatchInfo
x: torch.Tensor,
weights: torch.Tensor,
batch_info: LoRABatchInfo,
stack_num: int = 1,
) -> torch.Tensor:
# x: (s, input_dim)
# weights: (num_lora, r, input_dim)
# output: (s, r)
# weights: (num_lora, stack_num * r, input_dim)
# output: (s, stack_num * r)
# stack_num: run_qkv_lora: 3, run_gate_up_lora: 2
# when called by run_qkv_lora, the weights.shape[-2] will be 3 * r
# input_dim is much larger than r

@@ -126,6 +135,7 @@ def sgemm_lora_a_fwd(
output,
R,
K,
stack_num,
x.stride(0),
x.stride(1),
weights.stride(0),
@@ -136,6 +146,7 @@ def sgemm_lora_a_fwd(
batch_info.seg_lens,
batch_info.seg_indptr,
batch_info.weight_indices,
batch_info.lora_ranks,
BLOCK_S,
BLOCK_R,
BLOCK_K,
@@ -26,13 +26,14 @@ def _sgemm_lora_b_kernel(
seg_lens,
seg_indptr,
weight_indices,
lora_ranks,
# Meta parameters
BLOCK_S: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
# For fused output scaling and adding
fuse_scaling_add,
scaling,
scalings,
):
# x: (s, K), s is the sum of sequence lengths
# weights: (num_lora, N, K)
@@ -45,6 +46,10 @@ def _sgemm_lora_b_kernel(
seg_len = tl.load(seg_lens + batch_id)
w_index = tl.load(weight_indices + batch_id)
seg_start = tl.load(seg_indptr + batch_id)
rank = tl.load(lora_ranks + w_index)
scaling = tl.load(scalings + w_index)
# Adjust K (rank) according to the specific LoRA adapter
K = tl.minimum(K, rank)

# The tile in output matrix will have (pid_s, pid_n) as id
num_pid_n = tl.cdiv(N, BLOCK_N)
@@ -100,12 +105,11 @@ def sgemm_lora_b_fwd(
weights: torch.Tensor,
batch_info: LoRABatchInfo,
base_output: torch.Tensor = None,
scaling: float = 1.0,
) -> torch.Tensor:
# x: (s, r)
# weights: (num_lora, output_dim, r)
# x: (s, max_r)
# weights: (num_lora, output_dim, max_r)
# output: (s, output_dim)
# output_dim is much larger than r
# output_dim is much larger than max_r

assert x.is_contiguous()
assert weights.is_contiguous()
@@ -150,10 +154,11 @@ def sgemm_lora_b_fwd(
batch_info.seg_lens,
batch_info.seg_indptr,
batch_info.weight_indices,
batch_info.lora_ranks,
BLOCK_S,
BLOCK_N,
BLOCK_R,
fuse_scaling_add,
scaling,
batch_info.scalings,
)
return output
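A pure-PyTorch emulation of what the updated B-side kernel now does per sequence segment: look up the adapter's rank and scaling through `weight_indices`, read only the first `rank` columns of the padded buffers, and fuse the scaled add into the output. This is a reference sketch (loops instead of Triton tiles), not the production path:

```python
import torch

def sgemm_lora_b_reference(x, weights, seg_indptr, weight_indices, lora_ranks, scalings, base_output):
    # x: (s, max_r), weights: (num_lora, output_dim, max_r), base_output: (s, output_dim)
    out = base_output.clone()
    for i in range(len(seg_indptr) - 1):
        s0, s1 = int(seg_indptr[i]), int(seg_indptr[i + 1])
        w = int(weight_indices[i])          # buffer slot of this sequence's adapter
        r = int(lora_ranks[w])              # per-adapter rank (<= max_r)
        out[s0:s1] += scalings[w] * (x[s0:s1, :r] @ weights[w, :, :r].T)
    return out
```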
@@ -25,6 +25,12 @@ class LoRABatchInfo:
# The index of lora adapter used by each sequence, in shape (bs,)
weight_indices: torch.Tensor

# ranks of each lora adapter, in shape (lora_num,)
lora_ranks: torch.Tensor

# scaling of each lora adapter, in shape (lora_num,)
scalings: torch.Tensor


class LoRAType(Enum):
LORA_A = 0
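For reference, the full dataclass after this change presumably looks like the sketch below; the field list is taken from the `LoRABatchInfo(...)` construction in `LoRAManager` above, and the comments restate the shapes noted in the diff:

```python
from dataclasses import dataclass

import torch

@dataclass
class LoRABatchInfo:
    bs: int                        # batch size
    seg_lens: torch.Tensor         # (bs,) length of each sequence segment
    seg_indptr: torch.Tensor       # (bs + 1,) cumulative offsets into the token dimension
    max_len: int                   # longest segment in the batch
    weight_indices: torch.Tensor   # (bs,) buffer slot of the adapter used by each sequence
    lora_ranks: torch.Tensor       # (lora_num,) rank of each adapter
    scalings: torch.Tensor         # (lora_num,) scaling of each adapter
```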