[2/2] Introduce Chunked-SGMV kernels and corresponding LoRA backend for improved performance (#10286)
This commit is contained in:
@@ -35,7 +35,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"* `max_loaded_loras`: If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `max-loras-per-batch`.\n",
|
"* `max_loaded_loras`: If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `max-loras-per-batch`.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"* `lora_backend`: The backend of running GEMM kernels for Lora modules. Currently we only support Triton LoRA backend. In the future, faster backend built upon Cutlass or Cuda kernels will be added.\n",
|
"* `lora_backend`: The backend of running GEMM kernels for Lora modules. Currently we support Triton LoRA backend (`triton`) and Chunked SGMV backend (`csgmv`). In the future, faster backend built upon Cutlass or Cuda kernels will be added.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n",
|
"* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n",
|
||||||
"\n",
|
"\n",
|
||||||
@@ -79,7 +79,7 @@
|
|||||||
"python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
|
"python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
|
||||||
" --enable-lora \\\n",
|
" --enable-lora \\\n",
|
||||||
" --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n",
|
" --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n",
|
||||||
" --max-loras-per-batch 1 --lora-backend triton \\\n",
|
" --max-loras-per-batch 1 \\\n",
|
||||||
" --log-level warning \\\n",
|
" --log-level warning \\\n",
|
||||||
"\"\"\"\n",
|
"\"\"\"\n",
|
||||||
")\n",
|
")\n",
|
||||||
@@ -139,7 +139,7 @@
|
|||||||
" --enable-lora \\\n",
|
" --enable-lora \\\n",
|
||||||
" --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n",
|
" --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n",
|
||||||
" lora1=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16 \\\n",
|
" lora1=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16 \\\n",
|
||||||
" --max-loras-per-batch 2 --lora-backend triton \\\n",
|
" --max-loras-per-batch 2 \\\n",
|
||||||
" --log-level warning \\\n",
|
" --log-level warning \\\n",
|
||||||
"\"\"\"\n",
|
"\"\"\"\n",
|
||||||
")\n",
|
")\n",
|
||||||
@@ -214,7 +214,7 @@
|
|||||||
" python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
|
" python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
|
||||||
" --enable-lora \\\n",
|
" --enable-lora \\\n",
|
||||||
" --cuda-graph-max-bs 2 \\\n",
|
" --cuda-graph-max-bs 2 \\\n",
|
||||||
" --max-loras-per-batch 2 --lora-backend triton \\\n",
|
" --max-loras-per-batch 2 \\\n",
|
||||||
" --max-lora-rank 256\n",
|
" --max-lora-rank 256\n",
|
||||||
" --lora-target-modules all\n",
|
" --lora-target-modules all\n",
|
||||||
" --log-level warning\n",
|
" --log-level warning\n",
|
||||||
@@ -413,7 +413,7 @@
|
|||||||
" python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
|
" python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
|
||||||
" --enable-lora \\\n",
|
" --enable-lora \\\n",
|
||||||
" --cuda-graph-max-bs 8 \\\n",
|
" --cuda-graph-max-bs 8 \\\n",
|
||||||
" --max-loras-per-batch 3 --lora-backend triton \\\n",
|
" --max-loras-per-batch 3 \\\n",
|
||||||
" --max-lora-rank 256 \\\n",
|
" --max-lora-rank 256 \\\n",
|
||||||
" --lora-target-modules all \\\n",
|
" --lora-target-modules all \\\n",
|
||||||
" --lora-paths \\\n",
|
" --lora-paths \\\n",
|
||||||
@@ -501,6 +501,48 @@
|
|||||||
"terminate_process(server_process)"
|
"terminate_process(server_process)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Choosing LoRA Backend\n",
|
||||||
|
"\n",
|
||||||
|
"SGLang supports two LoRA backends that you can choose from using the `--lora-backend` argument:\n",
|
||||||
|
"\n",
|
||||||
|
"- `triton`: Default basic Triton-based backend.\n",
|
||||||
|
"- `csgmv`: Chunked SGMV backend optimized for high concurrency scenarios.\n",
|
||||||
|
"\n",
|
||||||
|
"The `csgmv` backend was recently introduced to improve performance especially at high-concurrency scenarios. Our benchmark shows that it achieves 20% to 80% latency improvements over the basic triton backend.\n",
|
||||||
|
"Currently it is at preview phase, we expect to make it our the default LoRA backend in future release. Before that, you can adopt it by manually setting the `--lora-backend` server config."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"server_process, port = launch_server_cmd(\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" python3 -m sglang.launch_server \\\n",
|
||||||
|
" --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
|
||||||
|
" --enable-lora \\\n",
|
||||||
|
" --lora-backend csgmv \\\n",
|
||||||
|
" --max-loras-per-batch 16 \\\n",
|
||||||
|
" --lora-paths lora1=path/to/lora1 lora2=path/to/lora2\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"terminate_process(server_process)"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
|||||||
@@ -143,10 +143,10 @@ def get_backend_from_name(name: str) -> BaseLoRABackend:
|
|||||||
from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
|
from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
|
||||||
|
|
||||||
return TritonLoRABackend
|
return TritonLoRABackend
|
||||||
# elif name == "csgmv":
|
elif name == "csgmv":
|
||||||
# from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
|
from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
|
||||||
|
|
||||||
# return ChunkedSgmvLoRABackend
|
return ChunkedSgmvLoRABackend
|
||||||
elif name == "flashinfer":
|
elif name == "flashinfer":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"FlashInfer LoRA backend has been deprecated, please use `triton` instead."
|
"FlashInfer LoRA backend has been deprecated, please use `triton` instead."
|
||||||
|
|||||||
306
python/sglang/srt/lora/backend/chunked_backend.py
Normal file
306
python/sglang/srt/lora/backend/chunked_backend.py
Normal file
@@ -0,0 +1,306 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
|
||||||
|
from sglang.srt.lora.triton_ops import (
|
||||||
|
chunked_sgmv_lora_expand_forward,
|
||||||
|
chunked_sgmv_lora_shrink_forward,
|
||||||
|
)
|
||||||
|
from sglang.srt.lora.utils import LoRABatchInfo
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
|
|
||||||
|
|
||||||
|
class ChunkedSgmvLoRABackend(BaseLoRABackend):
|
||||||
|
"""
|
||||||
|
Chunked LoRA backend using segmented matrix-vector multiplication.
|
||||||
|
|
||||||
|
This backend is largely based on the SGMV (Segmented Gather Matrix-Vector multiplication) algorithm
|
||||||
|
introduced in the Punica paper (https://arxiv.org/pdf/2310.18547). One main variation made here is to
|
||||||
|
segment the input sequences into fixed-size chunks, which reduces excessive kernel launches especially
|
||||||
|
when the LoRA distribution is skewed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "csgmv"
|
||||||
|
|
||||||
|
def __init__(self, max_loras_per_batch: int, device: torch.device):
|
||||||
|
super().__init__(max_loras_per_batch, device)
|
||||||
|
self.segment_size = 16 # TODO (lifuhuang): make it configurable?
|
||||||
|
|
||||||
|
def run_lora_a_sgemm(
|
||||||
|
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return chunked_sgmv_lora_shrink_forward(
|
||||||
|
x,
|
||||||
|
weights,
|
||||||
|
self.batch_info,
|
||||||
|
)
|
||||||
|
|
||||||
|
def run_lora_b_sgemm(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
weights: torch.Tensor,
|
||||||
|
output_offset: torch.Tensor,
|
||||||
|
base_output: torch.Tensor = None,
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# For simple lora B, we use slice offsets [0, output_dim]
|
||||||
|
output_dim = weights.shape[-2]
|
||||||
|
max_slice_size = output_dim
|
||||||
|
return chunked_sgmv_lora_expand_forward(
|
||||||
|
x=x,
|
||||||
|
lora_weight_b=weights,
|
||||||
|
batch_info=self.batch_info,
|
||||||
|
slice_offsets=output_offset,
|
||||||
|
max_slice_size=max_slice_size,
|
||||||
|
base_output=base_output,
|
||||||
|
)
|
||||||
|
|
||||||
|
def run_qkv_lora(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
qkv_lora_a: torch.Tensor,
|
||||||
|
qkv_lora_b: torch.Tensor,
|
||||||
|
output_offset: torch.Tensor,
|
||||||
|
max_qkv_out_dim: int,
|
||||||
|
base_output: torch.Tensor = None,
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
# x: (s, input_dim)
|
||||||
|
# qkv_lora_a: (num_lora, 3 * r, input_dim)
|
||||||
|
# qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
|
||||||
|
assert isinstance(qkv_lora_b, torch.Tensor)
|
||||||
|
|
||||||
|
lora_a_output = chunked_sgmv_lora_shrink_forward(
|
||||||
|
x,
|
||||||
|
qkv_lora_a,
|
||||||
|
self.batch_info,
|
||||||
|
num_slices=3,
|
||||||
|
)
|
||||||
|
lora_output = chunked_sgmv_lora_expand_forward(
|
||||||
|
x=lora_a_output,
|
||||||
|
lora_weight_b=qkv_lora_b,
|
||||||
|
batch_info=self.batch_info,
|
||||||
|
slice_offsets=output_offset,
|
||||||
|
max_slice_size=max_qkv_out_dim,
|
||||||
|
base_output=base_output,
|
||||||
|
)
|
||||||
|
return lora_output
|
||||||
|
|
||||||
|
def run_gate_up_lora(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
gate_up_lora_a: torch.Tensor,
|
||||||
|
gate_up_lora_b: torch.Tensor,
|
||||||
|
output_offset: torch.Tensor,
|
||||||
|
base_output: torch.Tensor = None,
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
# x: (s, input_dim)
|
||||||
|
# gate_up_lora_a: (num_lora, 2 * r, input_dim)
|
||||||
|
# gate_up_lora_b: (num_lora, 2 * output_dim, r)
|
||||||
|
assert isinstance(gate_up_lora_b, torch.Tensor)
|
||||||
|
output_dim = gate_up_lora_b.shape[-2] // 2
|
||||||
|
|
||||||
|
# lora_a_output: (s, 2 * r)
|
||||||
|
lora_a_output = chunked_sgmv_lora_shrink_forward(
|
||||||
|
x,
|
||||||
|
gate_up_lora_a,
|
||||||
|
self.batch_info,
|
||||||
|
num_slices=2,
|
||||||
|
)
|
||||||
|
lora_output = chunked_sgmv_lora_expand_forward(
|
||||||
|
x=lora_a_output,
|
||||||
|
lora_weight_b=gate_up_lora_b,
|
||||||
|
batch_info=self.batch_info,
|
||||||
|
slice_offsets=output_offset,
|
||||||
|
max_slice_size=output_dim,
|
||||||
|
base_output=base_output,
|
||||||
|
)
|
||||||
|
return lora_output
|
||||||
|
|
||||||
|
def prepare_lora_batch(
|
||||||
|
self,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
weight_indices: list[int],
|
||||||
|
lora_ranks: list[int],
|
||||||
|
scalings: list[float],
|
||||||
|
batch_info: Optional[LoRABatchInfo] = None,
|
||||||
|
):
|
||||||
|
permutation, weight_indices_reordered = ChunkedSgmvLoRABackend._get_permutation(
|
||||||
|
weight_indices, forward_batch
|
||||||
|
)
|
||||||
|
|
||||||
|
seg_weight_indices, seg_indptr = self._get_segments_info(
|
||||||
|
weight_indices_reordered
|
||||||
|
)
|
||||||
|
num_segments = len(seg_weight_indices)
|
||||||
|
|
||||||
|
lora_ranks_tensor = torch.tensor(
|
||||||
|
lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
|
||||||
|
)
|
||||||
|
scalings_tensor = torch.tensor(
|
||||||
|
scalings, dtype=torch.float, pin_memory=True, device="cpu"
|
||||||
|
)
|
||||||
|
|
||||||
|
if batch_info is None:
|
||||||
|
batch_info = LoRABatchInfo(
|
||||||
|
bs=forward_batch.batch_size,
|
||||||
|
num_segments=num_segments,
|
||||||
|
use_cuda_graph=False,
|
||||||
|
seg_indptr=torch.empty(
|
||||||
|
(num_segments + 1,), dtype=torch.int32, device=self.device
|
||||||
|
),
|
||||||
|
weight_indices=torch.empty(
|
||||||
|
(num_segments,), dtype=torch.int32, device=self.device
|
||||||
|
),
|
||||||
|
lora_ranks=torch.empty(
|
||||||
|
(self.max_loras_per_batch,), dtype=torch.int32, device=self.device
|
||||||
|
),
|
||||||
|
scalings=torch.empty(
|
||||||
|
(self.max_loras_per_batch,), dtype=torch.float, device=self.device
|
||||||
|
),
|
||||||
|
permutation=torch.empty(
|
||||||
|
(len(permutation),), dtype=torch.int32, device=self.device
|
||||||
|
),
|
||||||
|
# Not used in chunked kernels
|
||||||
|
max_len=None,
|
||||||
|
seg_lens=None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
batch_info.bs = forward_batch.batch_size
|
||||||
|
batch_info.num_segments = num_segments
|
||||||
|
|
||||||
|
# Copy to device asynchronously
|
||||||
|
batch_info.lora_ranks[: self.max_loras_per_batch].copy_(
|
||||||
|
lora_ranks_tensor, non_blocking=True
|
||||||
|
)
|
||||||
|
batch_info.scalings[: self.max_loras_per_batch].copy_(
|
||||||
|
scalings_tensor, non_blocking=True
|
||||||
|
)
|
||||||
|
batch_info.weight_indices[:num_segments].copy_(
|
||||||
|
seg_weight_indices, non_blocking=True
|
||||||
|
)
|
||||||
|
batch_info.seg_indptr[: num_segments + 1].copy_(seg_indptr, non_blocking=True)
|
||||||
|
batch_info.permutation[: len(permutation)].copy_(permutation, non_blocking=True)
|
||||||
|
|
||||||
|
self.batch_info = batch_info
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_permutation(seq_weight_indices, forward_batch: ForwardBatch):
|
||||||
|
"""
|
||||||
|
Computes permutation indices for reordering tokens by their LoRA adapter assignments.
|
||||||
|
|
||||||
|
This function implements the "gather" step in Chunked Segmented Gather Matrix Vector
|
||||||
|
multiplication by creating a permutation that groups tokens by their LoRA adapter.
|
||||||
|
Tokens using the same LoRA adapter are placed together to enable efficient batched
|
||||||
|
computation.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
seq_weight_indices = [0, 1, 0] # 3 sequences using adapters [0, 1, 0]
|
||||||
|
extend_seq_lens = [2, 1, 3] # sequence lengths [2, 1, 3 tokens]
|
||||||
|
|
||||||
|
# Creates row_weight_indices: [0, 0, 1, 0, 0, 0] (6 tokens total)
|
||||||
|
# Returns permutation: [0, 1, 3, 4, 5, 2] (groups adapter 0 tokens together)
|
||||||
|
# weights_reordered: [0, 0, 0, 0, 0, 1] (sorted by adapter)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seq_weight_indices: List of LoRA adapter indices for each sequence
|
||||||
|
forward_batch (ForwardBatch): Batch information containing sequence lengths
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (permutation, weights_reordered) where:
|
||||||
|
- permutation: Token reordering indices to group by adapter
|
||||||
|
- weights_reordered: Sorted adapter indices for each token
|
||||||
|
"""
|
||||||
|
with torch.device("cpu"):
|
||||||
|
seq_weight_indices = torch.tensor(seq_weight_indices, dtype=torch.int32)
|
||||||
|
|
||||||
|
seg_lens_cpu = (
|
||||||
|
torch.tensor(
|
||||||
|
forward_batch.extend_seq_lens_cpu,
|
||||||
|
dtype=torch.int32,
|
||||||
|
)
|
||||||
|
if forward_batch.forward_mode.is_extend()
|
||||||
|
else torch.ones(forward_batch.batch_size, dtype=torch.int32)
|
||||||
|
)
|
||||||
|
|
||||||
|
row_weight_indices = torch.repeat_interleave(
|
||||||
|
seq_weight_indices, seg_lens_cpu
|
||||||
|
)
|
||||||
|
permutation = torch.empty(
|
||||||
|
(len(row_weight_indices),), dtype=torch.long, pin_memory=True
|
||||||
|
)
|
||||||
|
torch.argsort(row_weight_indices, stable=True, out=permutation)
|
||||||
|
weights_reordered = row_weight_indices[permutation]
|
||||||
|
|
||||||
|
return permutation, weights_reordered
|
||||||
|
|
||||||
|
def _get_segments_info(self, weights_reordered: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Computes segment information for chunked SGMV operations.
|
||||||
|
|
||||||
|
This function takes the reordered weight indices and creates segments of fixed size
|
||||||
|
(self.segment_size) for efficient kernel execution. Each segment contains tokens
|
||||||
|
that use the same LoRA adapter, enabling vectorized computation.
|
||||||
|
|
||||||
|
The segmentation is necessary because:
|
||||||
|
1. GPU kernels work efficiently on fixed-size blocks
|
||||||
|
2. Large groups of tokens using the same adapter are split into manageable chunks
|
||||||
|
3. Each segment can be processed independently in parallel
|
||||||
|
|
||||||
|
Example:
|
||||||
|
weights_reordered = [0, 0, 0, 0, 0, 1] # 5 tokens with adapter 0, 1 with adapter 1
|
||||||
|
segment_size = 3
|
||||||
|
|
||||||
|
# Creates segments:
|
||||||
|
# Segment 0: tokens 0-2 (adapter 0), length=3
|
||||||
|
# Segment 1: tokens 3-4 (adapter 0), length=2
|
||||||
|
# Segment 2: token 5 (adapter 1), length=1
|
||||||
|
|
||||||
|
# Returns:
|
||||||
|
# weight_indices_list: [0, 0, 1] (adapter for each segment)
|
||||||
|
# seg_indptr: [0, 3, 5, 6] (cumulative segment boundaries)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
weights_reordered (torch.Tensor): Sorted adapter indices for each token
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (weight_indices_list, seg_indptr) where:
|
||||||
|
- weight_indices_list: LoRA adapter index for each segment
|
||||||
|
- seg_indptr: Cumulative segment boundaries (CSR-style indptr)
|
||||||
|
"""
|
||||||
|
with torch.device("cpu"):
|
||||||
|
unique_weights, counts = torch.unique_consecutive(
|
||||||
|
weights_reordered, return_counts=True
|
||||||
|
)
|
||||||
|
|
||||||
|
weight_indices_list = []
|
||||||
|
seg_lens_list = []
|
||||||
|
|
||||||
|
for weight_idx, group_len in zip(unique_weights, counts):
|
||||||
|
group_len = group_len.item()
|
||||||
|
num_segs = (group_len + self.segment_size - 1) // self.segment_size
|
||||||
|
|
||||||
|
weight_indices_list.extend([weight_idx.item()] * num_segs)
|
||||||
|
seg_lens_list.extend([self.segment_size] * (num_segs - 1))
|
||||||
|
seg_lens_list.append(group_len - (num_segs - 1) * self.segment_size)
|
||||||
|
|
||||||
|
seg_lens = torch.tensor(seg_lens_list, dtype=torch.int32)
|
||||||
|
|
||||||
|
weight_indices_list = torch.tensor(
|
||||||
|
weight_indices_list, dtype=torch.int32, pin_memory=True
|
||||||
|
)
|
||||||
|
|
||||||
|
seg_indptr = torch.empty(
|
||||||
|
(len(seg_lens) + 1,), dtype=torch.int32, pin_memory=True
|
||||||
|
)
|
||||||
|
seg_indptr[0] = 0
|
||||||
|
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
|
||||||
|
|
||||||
|
return weight_indices_list, seg_indptr
|
||||||
@@ -28,14 +28,15 @@ from torch import nn
|
|||||||
from sglang.srt.configs.load_config import LoadConfig
|
from sglang.srt.configs.load_config import LoadConfig
|
||||||
from sglang.srt.hf_transformers_utils import AutoConfig
|
from sglang.srt.hf_transformers_utils import AutoConfig
|
||||||
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
|
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
|
||||||
|
from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
|
||||||
# from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
|
|
||||||
from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
|
from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
|
||||||
from sglang.srt.lora.lora_config import LoRAConfig
|
from sglang.srt.lora.lora_config import LoRAConfig
|
||||||
from sglang.srt.model_loader.loader import DefaultModelLoader
|
from sglang.srt.model_loader.loader import DefaultModelLoader
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
SUPPORTED_BACKENDS = (TritonLoRABackend, ChunkedSgmvLoRABackend)
|
||||||
|
|
||||||
|
|
||||||
class LoRALayer(nn.Module):
|
class LoRALayer(nn.Module):
|
||||||
def __init__(self, config: LoRAConfig, base_hf_config: AutoConfig):
|
def __init__(self, config: LoRAConfig, base_hf_config: AutoConfig):
|
||||||
@@ -48,6 +49,7 @@ class LoRALayer(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class LoRAAdapter(nn.Module):
|
class LoRAAdapter(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
uid: str,
|
uid: str,
|
||||||
@@ -159,8 +161,8 @@ class LoRAAdapter(nn.Module):
|
|||||||
gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
|
gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
|
||||||
if up_name not in weights:
|
if up_name not in weights:
|
||||||
weights[up_name] = torch.zeros_like(weights[weight_name])
|
weights[up_name] = torch.zeros_like(weights[weight_name])
|
||||||
assert isinstance(self.lora_backend, TritonLoRABackend), (
|
assert isinstance(self.lora_backend, SUPPORTED_BACKENDS), (
|
||||||
f"LoRA weight initialization currently only supported for 'triton' backend. "
|
f"LoRA weight initialization currently only supported for LoRA backends: {', '.join(b.name for b in SUPPORTED_BACKENDS)}"
|
||||||
f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
|
f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
|
||||||
f"or consider implementing custom initialization logic for other backends."
|
f"or consider implementing custom initialization logic for other backends."
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from .chunked_sgmv_expand import chunked_sgmv_lora_expand_forward
|
||||||
|
from .chunked_sgmv_shrink import chunked_sgmv_lora_shrink_forward
|
||||||
from .gate_up_lora_b import gate_up_lora_b_fwd
|
from .gate_up_lora_b import gate_up_lora_b_fwd
|
||||||
from .qkv_lora_b import qkv_lora_b_fwd
|
from .qkv_lora_b import qkv_lora_b_fwd
|
||||||
from .sgemm_lora_a import sgemm_lora_a_fwd
|
from .sgemm_lora_a import sgemm_lora_a_fwd
|
||||||
@@ -8,4 +10,6 @@ __all__ = [
|
|||||||
"qkv_lora_b_fwd",
|
"qkv_lora_b_fwd",
|
||||||
"sgemm_lora_a_fwd",
|
"sgemm_lora_a_fwd",
|
||||||
"sgemm_lora_b_fwd",
|
"sgemm_lora_b_fwd",
|
||||||
|
"chunked_sgmv_lora_shrink_forward",
|
||||||
|
"chunked_sgmv_lora_expand_forward",
|
||||||
]
|
]
|
||||||
|
|||||||
211
python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py
Normal file
211
python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
from sglang.srt.lora.utils import LoRABatchInfo
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _chunked_lora_expand_kernel(
|
||||||
|
# Pointers to matrices
|
||||||
|
x,
|
||||||
|
weights,
|
||||||
|
output,
|
||||||
|
# Parameters of size
|
||||||
|
# Strides
|
||||||
|
x_stride_0,
|
||||||
|
x_stride_1,
|
||||||
|
w_stride_0,
|
||||||
|
w_stride_1,
|
||||||
|
w_stride_2,
|
||||||
|
output_stride_0,
|
||||||
|
output_stride_1,
|
||||||
|
# Information on sequence lengths and weight id
|
||||||
|
seg_indptr,
|
||||||
|
weight_indices,
|
||||||
|
lora_ranks,
|
||||||
|
permutation,
|
||||||
|
num_segs,
|
||||||
|
# For fused output scaling
|
||||||
|
scalings,
|
||||||
|
# Offsets of q/k/v slice on output dimension
|
||||||
|
slice_offsets,
|
||||||
|
# Meta parameters
|
||||||
|
NUM_SLICES: tl.constexpr,
|
||||||
|
MAX_RANK: tl.constexpr, # K = R
|
||||||
|
BLOCK_S: tl.constexpr,
|
||||||
|
BLOCK_N: tl.constexpr,
|
||||||
|
BLOCK_K: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Computes a chunked SGMV for LoRA expand operations.
|
||||||
|
|
||||||
|
When a sequence's rank is 0, the kernel is essentially a no-op, following
|
||||||
|
the convention in pytorch where the product of two matrices of shape (m, 0)
|
||||||
|
and (0, n) is an all-zero matrix of shape (m, n).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): The input tensor, which is the result of the LoRA A projection.
|
||||||
|
Shape: (s, num_slices * K), where s is the sum of all sequence lengths in the
|
||||||
|
batch and K is the maximum LoRA rank.
|
||||||
|
weights (Tensor): The LoRA B weights for all adapters.
|
||||||
|
Shape: (num_lora, output_dim, K).
|
||||||
|
output (Tensor): The output tensor where the result is stored.
|
||||||
|
Shape: (s, output_dim).
|
||||||
|
"""
|
||||||
|
tl.static_assert(NUM_SLICES <= 3)
|
||||||
|
|
||||||
|
pid_s = tl.program_id(axis=2)
|
||||||
|
if pid_s >= num_segs:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Current block computes sequence with batch_id,
|
||||||
|
# which starts from row seg_start of x with length seg_len.
|
||||||
|
# qkv_id decides which of q,k,v to compute (0: q, 1: k, 2: v)
|
||||||
|
w_index = tl.load(weight_indices + pid_s)
|
||||||
|
cur_rank = tl.load(lora_ranks + w_index)
|
||||||
|
|
||||||
|
# If rank is 0, this kernel is a no-op.
|
||||||
|
if cur_rank == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
seg_start = tl.load(seg_indptr + pid_s)
|
||||||
|
seg_end = tl.load(seg_indptr + pid_s + 1)
|
||||||
|
|
||||||
|
slice_id = tl.program_id(axis=1)
|
||||||
|
slice_start = tl.load(slice_offsets + slice_id)
|
||||||
|
slice_end = tl.load(slice_offsets + slice_id + 1)
|
||||||
|
|
||||||
|
scaling = tl.load(scalings + w_index)
|
||||||
|
# Adjust K (rank) according to the specific LoRA adapter
|
||||||
|
cur_rank = tl.minimum(MAX_RANK, cur_rank)
|
||||||
|
|
||||||
|
# Map logical sequence index to physical index
|
||||||
|
s_offset_logical = tl.arange(0, BLOCK_S) + seg_start
|
||||||
|
s_offset_physical = tl.load(
|
||||||
|
permutation + s_offset_logical, mask=s_offset_logical < seg_end
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create pointers for the first block of x and weights[batch_id][n_start: n_end][:]
|
||||||
|
# The pointers will be advanced as we move in the K direction
|
||||||
|
# and accumulate
|
||||||
|
pid_n = tl.program_id(axis=0)
|
||||||
|
n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_start
|
||||||
|
k_offset = tl.arange(0, BLOCK_K)
|
||||||
|
|
||||||
|
x_ptrs = (
|
||||||
|
x
|
||||||
|
+ slice_id * cur_rank * x_stride_1
|
||||||
|
+ (s_offset_physical[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1)
|
||||||
|
)
|
||||||
|
w_ptrs = (weights + w_index * w_stride_0) + (
|
||||||
|
k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Iterate to compute the block in output matrix
|
||||||
|
partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
|
||||||
|
for k in range(0, tl.cdiv(cur_rank, BLOCK_K)):
|
||||||
|
x_tile = tl.load(
|
||||||
|
x_ptrs,
|
||||||
|
mask=(s_offset_logical[:, None] < seg_end)
|
||||||
|
& (k_offset[None, :] < cur_rank - k * BLOCK_K),
|
||||||
|
other=0.0,
|
||||||
|
)
|
||||||
|
w_tile = tl.load(
|
||||||
|
w_ptrs,
|
||||||
|
mask=(k_offset[:, None] < cur_rank - k * BLOCK_K)
|
||||||
|
& (n_offset[None, :] < slice_end),
|
||||||
|
other=0.0,
|
||||||
|
)
|
||||||
|
partial_sum += tl.dot(x_tile, w_tile)
|
||||||
|
|
||||||
|
x_ptrs += BLOCK_K * x_stride_1
|
||||||
|
w_ptrs += BLOCK_K * w_stride_2
|
||||||
|
|
||||||
|
# Store result to output matrix
|
||||||
|
partial_sum *= scaling
|
||||||
|
partial_sum = partial_sum.to(x.dtype.element_ty)
|
||||||
|
output_ptr = output + (
|
||||||
|
s_offset_physical[:, None] * output_stride_0
|
||||||
|
+ n_offset[None, :] * output_stride_1
|
||||||
|
)
|
||||||
|
output_mask = (s_offset_logical[:, None] < seg_end) & (
|
||||||
|
n_offset[None, :] < slice_end
|
||||||
|
)
|
||||||
|
partial_sum += tl.load(output_ptr, mask=output_mask, other=0.0)
|
||||||
|
tl.store(output_ptr, partial_sum, mask=output_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def chunked_sgmv_lora_expand_forward(
|
||||||
|
x: torch.Tensor,
|
||||||
|
lora_weight_b: torch.Tensor,
|
||||||
|
batch_info: LoRABatchInfo,
|
||||||
|
slice_offsets: torch.Tensor,
|
||||||
|
max_slice_size: int,
|
||||||
|
base_output: torch.Tensor = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
# x: (s, slice_num * r)
|
||||||
|
# lora_weight_b: (num_lora, output_dim, r)
|
||||||
|
# slice_offsets: boundaries for different slices in the output dimension
|
||||||
|
# output: (s, output_dim)
|
||||||
|
|
||||||
|
# Compute lora_output with shape (s, output_dim) as follows:
|
||||||
|
# For each slice i, accumulates:
|
||||||
|
# lora_output[:, slice_offsets[i]:slice_offsets[i+1]] += scaling * sgemm(x[:, i*cur_rank:(i+1)*cur_rank], lora_weight_b[:, slice_offsets[i]:slice_offsets[i+1], :])
|
||||||
|
|
||||||
|
# Get dims
|
||||||
|
s = x.shape[0]
|
||||||
|
input_dim = x.shape[1]
|
||||||
|
max_lora_rank = lora_weight_b.shape[-1]
|
||||||
|
output_dim = lora_weight_b.shape[-2]
|
||||||
|
num_slices = len(slice_offsets) - 1
|
||||||
|
assert input_dim == num_slices * max_lora_rank
|
||||||
|
|
||||||
|
# TODO (lifuhuang): fine-tune per operation
|
||||||
|
BLOCK_M = 16
|
||||||
|
BLOCK_K = 16
|
||||||
|
BLOCK_N = 64
|
||||||
|
|
||||||
|
num_segments = batch_info.num_segments
|
||||||
|
|
||||||
|
grid = (
|
||||||
|
triton.cdiv(max_slice_size, BLOCK_N),
|
||||||
|
num_slices, # number of slices in the input/output
|
||||||
|
batch_info.bs if batch_info.use_cuda_graph else num_segments,
|
||||||
|
)
|
||||||
|
|
||||||
|
if base_output is None:
|
||||||
|
output = torch.zeros((s, output_dim), device=x.device, dtype=x.dtype)
|
||||||
|
else:
|
||||||
|
output = base_output
|
||||||
|
|
||||||
|
_chunked_lora_expand_kernel[grid](
|
||||||
|
x=x,
|
||||||
|
weights=lora_weight_b,
|
||||||
|
output=output,
|
||||||
|
x_stride_0=x.stride(0),
|
||||||
|
x_stride_1=x.stride(1),
|
||||||
|
w_stride_0=lora_weight_b.stride(0),
|
||||||
|
w_stride_1=lora_weight_b.stride(1),
|
||||||
|
w_stride_2=lora_weight_b.stride(2),
|
||||||
|
output_stride_0=output.stride(0),
|
||||||
|
output_stride_1=output.stride(1),
|
||||||
|
seg_indptr=batch_info.seg_indptr,
|
||||||
|
weight_indices=batch_info.weight_indices,
|
||||||
|
lora_ranks=batch_info.lora_ranks,
|
||||||
|
permutation=batch_info.permutation,
|
||||||
|
num_segs=num_segments,
|
||||||
|
scalings=batch_info.scalings,
|
||||||
|
slice_offsets=slice_offsets,
|
||||||
|
# constants
|
||||||
|
NUM_SLICES=num_slices,
|
||||||
|
MAX_RANK=max_lora_rank,
|
||||||
|
BLOCK_S=BLOCK_M,
|
||||||
|
BLOCK_N=BLOCK_N,
|
||||||
|
BLOCK_K=BLOCK_K,
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
177
python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py
Normal file
177
python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
from sglang.srt.lora.utils import LoRABatchInfo
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _chunked_lora_shrink_kernel(
|
||||||
|
# Pointers to matrices
|
||||||
|
x,
|
||||||
|
weights,
|
||||||
|
output,
|
||||||
|
# Strides
|
||||||
|
x_stride_0,
|
||||||
|
x_stride_1,
|
||||||
|
w_stride_0,
|
||||||
|
w_stride_1,
|
||||||
|
w_stride_2,
|
||||||
|
output_stride_0,
|
||||||
|
output_stride_1,
|
||||||
|
# Information on sequence lengths,ranks and weight id
|
||||||
|
seg_indptr,
|
||||||
|
weight_indices,
|
||||||
|
lora_ranks,
|
||||||
|
permutation,
|
||||||
|
num_segs,
|
||||||
|
# Meta parameters
|
||||||
|
N: tl.constexpr, # num_slices * r
|
||||||
|
K: tl.constexpr, # input_dim
|
||||||
|
NUM_SLICES: tl.constexpr,
|
||||||
|
BLOCK_S: tl.constexpr,
|
||||||
|
BLOCK_N: tl.constexpr,
|
||||||
|
BLOCK_K: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Computes a chunked SGMV for LoRA shrink operations.
|
||||||
|
|
||||||
|
The kernel ensures that output[seg_start:seg_start + seg_len, :rank * num_slices]
|
||||||
|
stores the product of the input `x` and the LoRA weights for the corresponding
|
||||||
|
sequence. This implies that when rank is 0, the kernel is essentially a no-op,
|
||||||
|
as output[seg_start:seg_start + seg_len, :0] is trivially correct (empty).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): The input activations tensor of shape `(s, K)`, where `s`
|
||||||
|
is the sum of all sequence lengths in the batch.
|
||||||
|
weights (torch.Tensor): The LoRA A weights for all available adapters,
|
||||||
|
with shape `(num_lora, N, K)` where N = num_slices * r.
|
||||||
|
output (torch.Tensor): The output tensor of shape `(s, N)`.
|
||||||
|
"""
|
||||||
|
pid_s = tl.program_id(1)
|
||||||
|
if pid_s >= num_segs:
|
||||||
|
return
|
||||||
|
|
||||||
|
pid_n = tl.program_id(0)
|
||||||
|
|
||||||
|
# Current block computes sequence with batch_id,
|
||||||
|
# which starts from row seg_start of x with length seg_len
|
||||||
|
w_index = tl.load(weight_indices + pid_s)
|
||||||
|
rank = tl.load(lora_ranks + w_index)
|
||||||
|
|
||||||
|
# If rank is 0, this kernel becomes a no-op as the output is always trivially correct.
|
||||||
|
if rank == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
seg_start = tl.load(seg_indptr + pid_s)
|
||||||
|
seg_end = tl.load(seg_indptr + pid_s + 1)
|
||||||
|
|
||||||
|
# Adjust N dim according to the specific LoRA adapter
|
||||||
|
cur_n = tl.minimum(N, rank * NUM_SLICES)
|
||||||
|
|
||||||
|
# Map logical sequence index to physical index
|
||||||
|
s_offset_logical = tl.arange(0, BLOCK_S) + seg_start
|
||||||
|
s_offset_physical = tl.load(
|
||||||
|
permutation + s_offset_logical, mask=s_offset_logical < seg_end
|
||||||
|
)
|
||||||
|
|
||||||
|
n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
|
||||||
|
k_offset = tl.arange(0, BLOCK_K)
|
||||||
|
x_ptrs = x + (
|
||||||
|
s_offset_physical[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
|
||||||
|
)
|
||||||
|
w_ptrs = (weights + w_index * w_stride_0) + (
|
||||||
|
k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Iterate to compute the block in output matrix
|
||||||
|
partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
|
||||||
|
for k in range(0, tl.cdiv(K, BLOCK_K)):
|
||||||
|
x_tile = tl.load(
|
||||||
|
x_ptrs,
|
||||||
|
mask=(s_offset_logical[:, None] < seg_end)
|
||||||
|
& (k_offset[None, :] < K - k * BLOCK_K),
|
||||||
|
other=0.0,
|
||||||
|
)
|
||||||
|
w_tile = tl.load(
|
||||||
|
w_ptrs,
|
||||||
|
mask=(k_offset[:, None] < K - k * BLOCK_K) & (n_offset[None, :] < cur_n),
|
||||||
|
other=0.0,
|
||||||
|
)
|
||||||
|
partial_sum += tl.dot(x_tile, w_tile)
|
||||||
|
|
||||||
|
x_ptrs += BLOCK_K * x_stride_1
|
||||||
|
w_ptrs += BLOCK_K * w_stride_2
|
||||||
|
|
||||||
|
# Store result to output matrix
|
||||||
|
partial_sum = partial_sum.to(x.dtype.element_ty)
|
||||||
|
output_ptr = output + (
|
||||||
|
s_offset_physical[:, None] * output_stride_0
|
||||||
|
+ n_offset[None, :] * output_stride_1
|
||||||
|
)
|
||||||
|
output_mask = (s_offset_logical[:, None] < seg_end) & (n_offset[None, :] < cur_n)
|
||||||
|
tl.store(output_ptr, partial_sum, mask=output_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def chunked_sgmv_lora_shrink_forward(
|
||||||
|
x: torch.Tensor,
|
||||||
|
weights: torch.Tensor,
|
||||||
|
batch_info: LoRABatchInfo,
|
||||||
|
num_slices: int = 1,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# x: (s, input_dim)
|
||||||
|
# weights: (num_lora, num_slices * r, input_dim)
|
||||||
|
# output: (s, num_slices * r)
|
||||||
|
# num_slices: qkv=3, gate_up=2, others=1
|
||||||
|
# when called with multiple slices, the weights.shape[-2] will be num_slices * r
|
||||||
|
# input_dim is much larger than r
|
||||||
|
|
||||||
|
assert x.is_contiguous()
|
||||||
|
assert weights.is_contiguous()
|
||||||
|
assert len(x.shape) == 2
|
||||||
|
assert len(weights.shape) == 3
|
||||||
|
|
||||||
|
# Block shapes
|
||||||
|
# TODO (lifuhuang): experiment with split-k
|
||||||
|
BLOCK_S = 16
|
||||||
|
BLOCK_N = 16
|
||||||
|
BLOCK_K = 256
|
||||||
|
|
||||||
|
S = x.shape[0]
|
||||||
|
N = weights.shape[1]
|
||||||
|
K = weights.shape[2]
|
||||||
|
assert x.shape[-1] == K
|
||||||
|
|
||||||
|
num_segments = batch_info.num_segments
|
||||||
|
grid = (
|
||||||
|
triton.cdiv(N, BLOCK_N),
|
||||||
|
batch_info.bs if batch_info.use_cuda_graph else num_segments,
|
||||||
|
)
|
||||||
|
|
||||||
|
output = torch.empty((S, N), device=x.device, dtype=x.dtype)
|
||||||
|
_chunked_lora_shrink_kernel[grid](
|
||||||
|
x=x,
|
||||||
|
weights=weights,
|
||||||
|
output=output,
|
||||||
|
x_stride_0=x.stride(0),
|
||||||
|
x_stride_1=x.stride(1),
|
||||||
|
w_stride_0=weights.stride(0),
|
||||||
|
w_stride_1=weights.stride(1),
|
||||||
|
w_stride_2=weights.stride(2),
|
||||||
|
output_stride_0=output.stride(0),
|
||||||
|
output_stride_1=output.stride(1),
|
||||||
|
seg_indptr=batch_info.seg_indptr,
|
||||||
|
weight_indices=batch_info.weight_indices,
|
||||||
|
lora_ranks=batch_info.lora_ranks,
|
||||||
|
permutation=batch_info.permutation,
|
||||||
|
num_segs=num_segments,
|
||||||
|
# constants
|
||||||
|
N=N,
|
||||||
|
K=K,
|
||||||
|
NUM_SLICES=num_slices,
|
||||||
|
BLOCK_S=BLOCK_S,
|
||||||
|
BLOCK_N=BLOCK_N,
|
||||||
|
BLOCK_K=BLOCK_K,
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
@@ -110,6 +110,8 @@ ATTENTION_BACKEND_CHOICES = [
|
|||||||
"ascend",
|
"ascend",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
LORA_BACKEND_CHOICES = ["triton", "csgmv"]
|
||||||
|
|
||||||
DISAGG_TRANSFER_BACKEND_CHOICES = ["mooncake", "nixl", "ascend", "fake"]
|
DISAGG_TRANSFER_BACKEND_CHOICES = ["mooncake", "nixl", "ascend", "fake"]
|
||||||
|
|
||||||
GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]
|
GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]
|
||||||
@@ -1601,7 +1603,8 @@ class ServerArgs:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lora-backend",
|
"--lora-backend",
|
||||||
type=str,
|
type=str,
|
||||||
default="triton",
|
choices=LORA_BACKEND_CHOICES,
|
||||||
|
default=ServerArgs.lora_backend,
|
||||||
help="Choose the kernel backend for multi-LoRA serving.",
|
help="Choose the kernel backend for multi-LoRA serving.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
740
test/srt/lora/test_chunked_sgmv_backend.py
Normal file
740
test/srt/lora/test_chunked_sgmv_backend.py
Normal file
@@ -0,0 +1,740 @@
|
|||||||
|
import random
|
||||||
|
import unittest
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
|
||||||
|
from sglang.srt.lora.triton_ops import (
|
||||||
|
chunked_sgmv_lora_expand_forward,
|
||||||
|
chunked_sgmv_lora_shrink_forward,
|
||||||
|
)
|
||||||
|
from sglang.srt.lora.utils import LoRABatchInfo
|
||||||
|
|
||||||
|
|
||||||
|
def safe_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Matrix multiplication with mixed precision handling for float16"""
|
||||||
|
result = torch.matmul(a.float(), b.float())
|
||||||
|
return result.to(a.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchComposition(Enum):
|
||||||
|
UNIFORM = "uniform"
|
||||||
|
MIXED = "mixed"
|
||||||
|
SKEWED = "skewed"
|
||||||
|
NONE = "_NO_LORA_"
|
||||||
|
|
||||||
|
|
||||||
|
class BatchMode(Enum):
|
||||||
|
PREFILL = "prefill"
|
||||||
|
DECODE = "decode"
|
||||||
|
|
||||||
|
|
||||||
|
def reference_sgmv_shrink(
|
||||||
|
x: torch.Tensor,
|
||||||
|
weights: torch.Tensor,
|
||||||
|
batch_info: LoRABatchInfo,
|
||||||
|
seq_lengths: List[int],
|
||||||
|
lora_assignments: List[str],
|
||||||
|
num_slices: int = 1,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Simple sequence-level reference implementation of SGMV shrink operation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: (total_seq_len, input_dim) - Input activations
|
||||||
|
weights: (num_loras, num_slices * max_rank, input_dim) - LoRA A weights
|
||||||
|
batch_info: Batch information (only used for lora_ranks)
|
||||||
|
seq_lengths: Length of each sequence
|
||||||
|
lora_assignments: LoRA name for each sequence
|
||||||
|
num_slices: Number of slices (3 for QKV, 2 for gate_up, 1 for others)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
output: (total_seq_len, num_slices * max_rank) - Intermediate activations
|
||||||
|
"""
|
||||||
|
if weights.numel() == 0:
|
||||||
|
total_seq_len = x.shape[0]
|
||||||
|
return torch.zeros(total_seq_len, 0, dtype=x.dtype, device=x.device)
|
||||||
|
|
||||||
|
total_seq_len, input_dim = x.shape
|
||||||
|
num_loras, weight_out_dim, _ = weights.shape
|
||||||
|
max_rank = weight_out_dim // num_slices
|
||||||
|
|
||||||
|
output = torch.zeros(
|
||||||
|
total_seq_len, num_slices * max_rank, dtype=x.dtype, device=x.device
|
||||||
|
)
|
||||||
|
|
||||||
|
unique_loras = sorted(set(lora_assignments))
|
||||||
|
lora_name_to_idx = {name: idx for idx, name in enumerate(unique_loras)}
|
||||||
|
lora_ranks = batch_info.lora_ranks.cpu().numpy()
|
||||||
|
|
||||||
|
token_offset = 0
|
||||||
|
for seq_len, lora_name in zip(seq_lengths, lora_assignments):
|
||||||
|
if seq_len == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
lora_idx = lora_name_to_idx[lora_name]
|
||||||
|
rank = lora_ranks[lora_idx]
|
||||||
|
|
||||||
|
if rank > 0:
|
||||||
|
x_seq = x[token_offset : token_offset + seq_len, :]
|
||||||
|
w_seq = weights[lora_idx, : num_slices * rank, :]
|
||||||
|
|
||||||
|
result = safe_matmul(x_seq, w_seq.t())
|
||||||
|
output[token_offset : token_offset + seq_len, : num_slices * rank] = result
|
||||||
|
|
||||||
|
token_offset += seq_len
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def reference_sgmv_expand(
|
||||||
|
x: torch.Tensor,
|
||||||
|
weights: torch.Tensor,
|
||||||
|
batch_info: LoRABatchInfo,
|
||||||
|
seq_lengths: List[int],
|
||||||
|
lora_assignments: List[str],
|
||||||
|
slice_offsets: torch.Tensor,
|
||||||
|
max_slice_size: int,
|
||||||
|
base_output: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Simple sequence-level reference implementation of SGMV expand operation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: (total_seq_len, num_slices * max_rank) - Intermediate activations
|
||||||
|
weights: (num_loras, output_dim, max_rank) - LoRA B weights
|
||||||
|
batch_info: Batch information (only used for lora_ranks)
|
||||||
|
seq_lengths: Length of each sequence
|
||||||
|
lora_assignments: LoRA name for each sequence
|
||||||
|
slice_offsets: Tensor defining slice boundaries
|
||||||
|
max_slice_size: Maximum slice size for chunking
|
||||||
|
base_output: Optional base output to accumulate into
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
output: (total_seq_len, total_output_dim) - Final output
|
||||||
|
"""
|
||||||
|
if weights.numel() == 0:
|
||||||
|
total_seq_len = x.shape[0]
|
||||||
|
total_output_dim = slice_offsets[-1].item() if len(slice_offsets) > 0 else 0
|
||||||
|
return torch.zeros(
|
||||||
|
total_seq_len, total_output_dim, dtype=x.dtype, device=x.device
|
||||||
|
)
|
||||||
|
|
||||||
|
total_seq_len, _ = x.shape
|
||||||
|
|
||||||
|
num_slices = len(slice_offsets) - 1
|
||||||
|
|
||||||
|
if base_output is not None:
|
||||||
|
output = base_output.clone()
|
||||||
|
else:
|
||||||
|
total_output_dim = slice_offsets[-1].item()
|
||||||
|
output = torch.zeros(
|
||||||
|
total_seq_len, total_output_dim, dtype=x.dtype, device=x.device
|
||||||
|
)
|
||||||
|
|
||||||
|
unique_loras = sorted(set(lora_assignments))
|
||||||
|
lora_name_to_idx = {name: idx for idx, name in enumerate(unique_loras)}
|
||||||
|
lora_ranks = batch_info.lora_ranks.cpu().numpy()
|
||||||
|
|
||||||
|
token_offset = 0
|
||||||
|
for seq_len, lora_name in zip(seq_lengths, lora_assignments):
|
||||||
|
if seq_len == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
lora_idx = lora_name_to_idx[lora_name]
|
||||||
|
lora_rank = lora_ranks[lora_idx]
|
||||||
|
|
||||||
|
if lora_rank > 0:
|
||||||
|
# Extract sequence intermediate activations
|
||||||
|
x_seq = x[
|
||||||
|
token_offset : token_offset + seq_len, : num_slices * lora_rank
|
||||||
|
] # (seq_len, num_slices * rank)
|
||||||
|
|
||||||
|
for slice_idx in range(num_slices):
|
||||||
|
slice_start_input = slice_idx * lora_rank
|
||||||
|
slice_end_input = (slice_idx + 1) * lora_rank
|
||||||
|
|
||||||
|
slice_start_output = slice_offsets[slice_idx].item()
|
||||||
|
slice_end_output = slice_offsets[slice_idx + 1].item()
|
||||||
|
|
||||||
|
x_slice = x_seq[:, slice_start_input:slice_end_input] # (seq_len, rank)
|
||||||
|
w_slice = weights[
|
||||||
|
lora_idx, slice_start_output:slice_end_output, :lora_rank
|
||||||
|
] # (slice_dim, rank)
|
||||||
|
|
||||||
|
result = safe_matmul(x_slice, w_slice.t()) # (seq_len, slice_dim)
|
||||||
|
output[
|
||||||
|
token_offset : token_offset + seq_len,
|
||||||
|
slice_start_output:slice_end_output,
|
||||||
|
] += result
|
||||||
|
|
||||||
|
token_offset += seq_len
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class TestChunkedSGMV(unittest.TestCase):
|
||||||
|
|
||||||
|
# Test configuration constants
|
||||||
|
RTOL = 1e-3
|
||||||
|
ATOL = 1e-3
|
||||||
|
DEFAULT_BATCH_SIZE = 8
|
||||||
|
|
||||||
|
def _compare_shrink_outputs(
|
||||||
|
self,
|
||||||
|
chunked_output: torch.Tensor,
|
||||||
|
reference_output: torch.Tensor,
|
||||||
|
seq_lengths: List[int],
|
||||||
|
lora_assignments: List[str],
|
||||||
|
batch_info: LoRABatchInfo,
|
||||||
|
num_slices: int,
|
||||||
|
test_name: str,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Compare only the valid portions of shrink outputs.
|
||||||
|
|
||||||
|
The chunked SGMV shrink kernel only guarantees correctness for
|
||||||
|
output[seq_start:seq_end, :rank * num_slices] for each sequence.
|
||||||
|
"""
|
||||||
|
# Create mapping from LoRA names to indices and ranks
|
||||||
|
unique_loras = sorted(set(lora_assignments))
|
||||||
|
lora_name_to_idx = {name: idx for idx, name in enumerate(unique_loras)}
|
||||||
|
lora_ranks = batch_info.lora_ranks.cpu().numpy()
|
||||||
|
|
||||||
|
token_offset = 0
|
||||||
|
for seq_idx, (seq_len, lora_name) in enumerate(
|
||||||
|
zip(seq_lengths, lora_assignments)
|
||||||
|
):
|
||||||
|
if seq_len == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
lora_idx = lora_name_to_idx[lora_name]
|
||||||
|
rank = lora_ranks[lora_idx]
|
||||||
|
|
||||||
|
if rank > 0:
|
||||||
|
# Only compare the valid columns for this sequence
|
||||||
|
valid_cols = num_slices * rank
|
||||||
|
|
||||||
|
chunked_seq = chunked_output[
|
||||||
|
token_offset : token_offset + seq_len, :valid_cols
|
||||||
|
]
|
||||||
|
reference_seq = reference_output[
|
||||||
|
token_offset : token_offset + seq_len, :valid_cols
|
||||||
|
]
|
||||||
|
|
||||||
|
torch.testing.assert_close(
|
||||||
|
chunked_seq,
|
||||||
|
reference_seq,
|
||||||
|
rtol=self.RTOL,
|
||||||
|
atol=self.ATOL,
|
||||||
|
msg=f"Shrink operation failed for {test_name}, sequence {seq_idx} ({lora_name})",
|
||||||
|
)
|
||||||
|
|
||||||
|
token_offset += seq_len
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""Set up common test parameters"""
|
||||||
|
torch.manual_seed(42)
|
||||||
|
random.seed(42)
|
||||||
|
|
||||||
|
self.device = torch.device("cuda")
|
||||||
|
self.dtype = torch.float16
|
||||||
|
self.input_dim = 2560 # Hidden dimension
|
||||||
|
self.max_seq_len = 1024
|
||||||
|
|
||||||
|
# LoRA configurations: name -> (rank, output_q, output_k, output_v)
|
||||||
|
self.lora_configs = {
|
||||||
|
"lora_A": (8, 4096, 1024, 1024),
|
||||||
|
"lora_B": (16, 4096, 1024, 1024),
|
||||||
|
"lora_C": (32, 4096, 1024, 1024),
|
||||||
|
"_NO_LORA_": (0, 4096, 1024, 1024),
|
||||||
|
}
|
||||||
|
|
||||||
|
# QKV slice offsets: 4096 (Q) + 1024 (K) + 1024 (V) = 6144 total
|
||||||
|
self.slice_offsets = torch.tensor(
|
||||||
|
[0, 4096, 5120, 6144], dtype=torch.int32, device=self.device
|
||||||
|
)
|
||||||
|
self.max_slice_size = 4096
|
||||||
|
|
||||||
|
def generate_sequence_lengths(
|
||||||
|
self,
|
||||||
|
batch_size: int,
|
||||||
|
batch_mode: BatchMode = BatchMode.PREFILL,
|
||||||
|
min_len: int = 1,
|
||||||
|
max_len: int = None,
|
||||||
|
) -> List[int]:
|
||||||
|
"""Generate sequence lengths for a batch based on mode"""
|
||||||
|
if batch_mode == BatchMode.DECODE:
|
||||||
|
return [1] * batch_size
|
||||||
|
else:
|
||||||
|
if max_len is None:
|
||||||
|
max_len = self.max_seq_len
|
||||||
|
return [random.randint(min_len, max_len) for _ in range(batch_size)]
|
||||||
|
|
||||||
|
def create_lora_weights(
|
||||||
|
self, lora_name: str, include_missing_k: bool = False
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Create LoRA A and B weights for given configuration"""
|
||||||
|
rank, out_q, out_k, out_v = self.lora_configs[lora_name]
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
lora_a = torch.empty(
|
||||||
|
0, self.input_dim, dtype=self.dtype, device=self.device
|
||||||
|
)
|
||||||
|
lora_b = torch.empty(
|
||||||
|
out_q + out_k + out_v, 0, dtype=self.dtype, device=self.device
|
||||||
|
)
|
||||||
|
return lora_a, lora_b
|
||||||
|
|
||||||
|
# Create LoRA A weights (3 slices for QKV)
|
||||||
|
lora_a = torch.randn(
|
||||||
|
3 * rank, self.input_dim, dtype=self.dtype, device=self.device
|
||||||
|
)
|
||||||
|
|
||||||
|
if include_missing_k:
|
||||||
|
lora_a[rank : 2 * rank, :] = 0.0
|
||||||
|
|
||||||
|
# Create LoRA B weights (stacked Q, K, V)
|
||||||
|
total_output_dim = out_q + out_k + out_v
|
||||||
|
lora_b = torch.randn(
|
||||||
|
total_output_dim, rank, dtype=self.dtype, device=self.device
|
||||||
|
)
|
||||||
|
|
||||||
|
if include_missing_k:
|
||||||
|
lora_b[out_q : out_q + out_k, :] = 0.0
|
||||||
|
|
||||||
|
return lora_a, lora_b
|
||||||
|
|
||||||
|
def create_batch_info(
|
||||||
|
self,
|
||||||
|
seq_lengths: List[int],
|
||||||
|
lora_assignments: List[Optional[str]],
|
||||||
|
batch_mode: BatchMode = BatchMode.PREFILL,
|
||||||
|
) -> LoRABatchInfo:
|
||||||
|
"""Create LoRABatchInfo using the same logic as chunked backend"""
|
||||||
|
unique_loras = sorted(set(lora_assignments))
|
||||||
|
lora_name_to_idx = {name: idx for idx, name in enumerate(unique_loras)}
|
||||||
|
|
||||||
|
seq_weight_indices = [lora_name_to_idx[name] for name in lora_assignments]
|
||||||
|
|
||||||
|
lora_ranks = [self.lora_configs[name][0] for name in unique_loras]
|
||||||
|
|
||||||
|
def create_mock_batch():
|
||||||
|
# Create a minimal mock ForwardBatch for the test
|
||||||
|
class MockForwardBatch:
|
||||||
|
def __init__(self, batch_size, seq_lengths):
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.extend_seq_lens_cpu = seq_lengths
|
||||||
|
self.forward_mode = MockForwardMode()
|
||||||
|
|
||||||
|
class MockForwardMode:
|
||||||
|
def is_extend(self):
|
||||||
|
return batch_mode == BatchMode.PREFILL
|
||||||
|
|
||||||
|
return MockForwardBatch(len(seq_lengths), seq_lengths)
|
||||||
|
|
||||||
|
mock_batch = create_mock_batch()
|
||||||
|
|
||||||
|
# Use the same functions as chunked backend
|
||||||
|
permutation, weights_reordered = ChunkedSgmvLoRABackend._get_permutation(
|
||||||
|
seq_weight_indices, mock_batch
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a minimal backend instance to access _get_segments_info
|
||||||
|
mock_backend = ChunkedSgmvLoRABackend(max_loras_per_batch=8, device=self.device)
|
||||||
|
weight_indices_list, seg_indptr = mock_backend._get_segments_info(
|
||||||
|
weights_reordered
|
||||||
|
)
|
||||||
|
|
||||||
|
scalings = [1.0] * len(unique_loras)
|
||||||
|
seg_indptr_tensor = seg_indptr.to(self.device)
|
||||||
|
weight_indices_tensor = weight_indices_list.to(self.device)
|
||||||
|
lora_ranks_tensor = (
|
||||||
|
torch.tensor(lora_ranks, dtype=torch.int32, device=self.device)
|
||||||
|
if lora_ranks
|
||||||
|
else torch.empty(0, dtype=torch.int32, device=self.device)
|
||||||
|
)
|
||||||
|
scalings_tensor = (
|
||||||
|
torch.tensor(scalings, dtype=torch.float32, device=self.device)
|
||||||
|
if scalings
|
||||||
|
else torch.empty(0, dtype=torch.float32, device=self.device)
|
||||||
|
)
|
||||||
|
permutation_tensor = permutation.to(
|
||||||
|
self.device, dtype=torch.int32
|
||||||
|
) # Convert to int32 for LoRABatchInfo
|
||||||
|
seq_lens_tensor = torch.tensor(
|
||||||
|
seq_lengths, dtype=torch.int32, device=self.device
|
||||||
|
)
|
||||||
|
|
||||||
|
return LoRABatchInfo(
|
||||||
|
use_cuda_graph=False,
|
||||||
|
bs=len(seq_lengths),
|
||||||
|
num_segments=len(weight_indices_list), # Number of segments, not sequences!
|
||||||
|
seg_indptr=seg_indptr_tensor,
|
||||||
|
weight_indices=weight_indices_tensor,
|
||||||
|
lora_ranks=lora_ranks_tensor,
|
||||||
|
scalings=scalings_tensor,
|
||||||
|
seg_lens=seq_lens_tensor, # Original sequence lengths for reference
|
||||||
|
max_len=max(seq_lengths) if seq_lengths else 0,
|
||||||
|
permutation=permutation_tensor, # Token reordering permutation
|
||||||
|
)
|
||||||
|
|
||||||
|
def stack_lora_weights(
|
||||||
|
self, weight_list: List[torch.Tensor], is_lora_a: bool
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Stack LoRA weights from different adapters into a single tensor"""
|
||||||
|
if not weight_list:
|
||||||
|
return torch.empty(0, 0, 0, dtype=self.dtype, device=self.device)
|
||||||
|
|
||||||
|
first_non_empty = next((w for w in weight_list if w.numel() > 0), None)
|
||||||
|
if first_non_empty is None:
|
||||||
|
return torch.empty(
|
||||||
|
len(weight_list), 0, 0, dtype=self.dtype, device=self.device
|
||||||
|
)
|
||||||
|
if is_lora_a:
|
||||||
|
# LoRA A: (slice_num * rank, input_dim) -> (num_loras, slice_num * max_rank, input_dim)
|
||||||
|
max_rank = max(w.shape[0] // 3 if w.numel() > 0 else 0 for w in weight_list)
|
||||||
|
final_shape = (len(weight_list), 3 * max_rank, self.input_dim)
|
||||||
|
else:
|
||||||
|
# LoRA B: (output_dim, rank) -> (num_loras, output_dim, max_rank)
|
||||||
|
max_rank = max(w.shape[1] if w.numel() > 0 else 0 for w in weight_list)
|
||||||
|
output_dim = first_non_empty.shape[0]
|
||||||
|
final_shape = (len(weight_list), output_dim, max_rank)
|
||||||
|
|
||||||
|
stacked = torch.zeros(final_shape, dtype=self.dtype, device=self.device)
|
||||||
|
|
||||||
|
for i, weight in enumerate(weight_list):
|
||||||
|
if weight.numel() > 0:
|
||||||
|
if is_lora_a:
|
||||||
|
stacked[i, : weight.shape[0], :] = weight
|
||||||
|
else:
|
||||||
|
stacked[i, :, : weight.shape[1]] = weight
|
||||||
|
|
||||||
|
return stacked
|
||||||
|
|
||||||
|
def create_test_batch(
|
||||||
|
self,
|
||||||
|
batch_composition: BatchComposition,
|
||||||
|
batch_size: int,
|
||||||
|
batch_mode: BatchMode = BatchMode.PREFILL,
|
||||||
|
include_missing_k: bool = False,
|
||||||
|
) -> Tuple[
|
||||||
|
torch.Tensor,
|
||||||
|
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
LoRABatchInfo,
|
||||||
|
List[int],
|
||||||
|
List[str],
|
||||||
|
]:
|
||||||
|
"""Create test batch with specified composition and mode"""
|
||||||
|
seq_lengths = self.generate_sequence_lengths(
|
||||||
|
batch_size, batch_mode, 1, self.max_seq_len
|
||||||
|
)
|
||||||
|
if batch_composition == BatchComposition.UNIFORM:
|
||||||
|
lora_assignments = ["lora_A"] * batch_size
|
||||||
|
elif batch_composition == BatchComposition.MIXED:
|
||||||
|
lora_names = ["lora_A", "lora_B", "lora_C", None]
|
||||||
|
lora_assignments = [
|
||||||
|
lora_names[i % len(lora_names)] for i in range(batch_size)
|
||||||
|
]
|
||||||
|
elif batch_composition == BatchComposition.SKEWED:
|
||||||
|
num_minority = max(1, batch_size // 8)
|
||||||
|
lora_assignments = ["lora_A"] * num_minority + ["lora_B"] * (
|
||||||
|
batch_size - num_minority
|
||||||
|
)
|
||||||
|
random.shuffle(lora_assignments)
|
||||||
|
elif batch_composition == BatchComposition.NONE:
|
||||||
|
lora_assignments = [None] * batch_size
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown batch composition: {batch_composition}")
|
||||||
|
|
||||||
|
total_seq_len = sum(seq_lengths)
|
||||||
|
x = torch.randn(
|
||||||
|
total_seq_len, self.input_dim, dtype=self.dtype, device=self.device
|
||||||
|
)
|
||||||
|
|
||||||
|
normalized_assignments = [
|
||||||
|
name if name is not None else "_NO_LORA_" for name in lora_assignments
|
||||||
|
]
|
||||||
|
unique_loras = set(normalized_assignments)
|
||||||
|
weights = {}
|
||||||
|
for lora_name in unique_loras:
|
||||||
|
weights[lora_name] = self.create_lora_weights(lora_name, include_missing_k)
|
||||||
|
|
||||||
|
batch_info = self.create_batch_info(
|
||||||
|
seq_lengths, normalized_assignments, batch_mode
|
||||||
|
)
|
||||||
|
|
||||||
|
return x, weights, batch_info, seq_lengths, normalized_assignments
|
||||||
|
|
||||||
|
def run_test_comparison(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
weights: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
batch_info: LoRABatchInfo,
|
||||||
|
seq_lengths: List[int],
|
||||||
|
lora_assignments: List[str],
|
||||||
|
test_name: str,
|
||||||
|
):
|
||||||
|
"""Run comparison between chunked and reference implementations"""
|
||||||
|
if not weights: # Handle case with no LoRA weights
|
||||||
|
return
|
||||||
|
|
||||||
|
# Stack LoRA A weights
|
||||||
|
lora_a_weights = [weights[name][0] for name in sorted(weights.keys())]
|
||||||
|
stacked_lora_a = self.stack_lora_weights(lora_a_weights, is_lora_a=True)
|
||||||
|
|
||||||
|
# Stack LoRA B weights
|
||||||
|
lora_b_weights = [weights[name][1] for name in sorted(weights.keys())]
|
||||||
|
stacked_lora_b = self.stack_lora_weights(lora_b_weights, is_lora_a=False)
|
||||||
|
|
||||||
|
# Test shrink operation
|
||||||
|
chunked_shrink = chunked_sgmv_lora_shrink_forward(
|
||||||
|
x, stacked_lora_a, batch_info, num_slices=3
|
||||||
|
)
|
||||||
|
reference_shrink = reference_sgmv_shrink(
|
||||||
|
x, stacked_lora_a, batch_info, seq_lengths, lora_assignments, num_slices=3
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only compare valid portions of shrink output (first rank * num_slices columns per sequence)
|
||||||
|
self._compare_shrink_outputs(
|
||||||
|
chunked_shrink,
|
||||||
|
reference_shrink,
|
||||||
|
seq_lengths,
|
||||||
|
lora_assignments,
|
||||||
|
batch_info,
|
||||||
|
num_slices=3,
|
||||||
|
test_name=test_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test expand operation
|
||||||
|
chunked_expand = chunked_sgmv_lora_expand_forward(
|
||||||
|
reference_shrink,
|
||||||
|
stacked_lora_b,
|
||||||
|
batch_info,
|
||||||
|
self.slice_offsets,
|
||||||
|
self.max_slice_size,
|
||||||
|
)
|
||||||
|
reference_expand = reference_sgmv_expand(
|
||||||
|
reference_shrink,
|
||||||
|
stacked_lora_b,
|
||||||
|
batch_info,
|
||||||
|
seq_lengths,
|
||||||
|
lora_assignments,
|
||||||
|
self.slice_offsets,
|
||||||
|
self.max_slice_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.testing.assert_close(
|
||||||
|
chunked_expand,
|
||||||
|
reference_expand,
|
||||||
|
rtol=self.RTOL,
|
||||||
|
atol=self.ATOL,
|
||||||
|
msg=f"Expand operation failed for {test_name}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# === Basic Operations Tests ===
|
||||||
|
|
||||||
|
def test_shrink_basic(self):
|
||||||
|
"""Test basic shrink operation against PyTorch reference"""
|
||||||
|
for batch_size in [1, 2, 16, 64]:
|
||||||
|
with self.subTest(batch_size=batch_size):
|
||||||
|
x, weights, batch_info, seq_lengths, lora_assignments = (
|
||||||
|
self.create_test_batch(BatchComposition.UNIFORM, batch_size)
|
||||||
|
)
|
||||||
|
|
||||||
|
lora_a_weights = [weights[name][0] for name in sorted(weights.keys())]
|
||||||
|
stacked_lora_a = self.stack_lora_weights(lora_a_weights, is_lora_a=True)
|
||||||
|
|
||||||
|
chunked_shrink = chunked_sgmv_lora_shrink_forward(
|
||||||
|
x, stacked_lora_a, batch_info, num_slices=3
|
||||||
|
)
|
||||||
|
reference_shrink = reference_sgmv_shrink(
|
||||||
|
x,
|
||||||
|
stacked_lora_a,
|
||||||
|
batch_info,
|
||||||
|
seq_lengths,
|
||||||
|
lora_assignments,
|
||||||
|
num_slices=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.testing.assert_close(
|
||||||
|
chunked_shrink, reference_shrink, rtol=self.RTOL, atol=self.ATOL
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_expand_basic(self):
|
||||||
|
"""Test basic expand operation against PyTorch reference"""
|
||||||
|
for batch_size in [1, 2, 16, 64]:
|
||||||
|
with self.subTest(batch_size=batch_size):
|
||||||
|
x, weights, batch_info, seq_lengths, lora_assignments = (
|
||||||
|
self.create_test_batch(BatchComposition.UNIFORM, batch_size)
|
||||||
|
)
|
||||||
|
|
||||||
|
lora_a_weights = [weights[name][0] for name in sorted(weights.keys())]
|
||||||
|
stacked_lora_a = self.stack_lora_weights(lora_a_weights, is_lora_a=True)
|
||||||
|
|
||||||
|
intermediate = reference_sgmv_shrink(
|
||||||
|
x,
|
||||||
|
stacked_lora_a,
|
||||||
|
batch_info,
|
||||||
|
seq_lengths,
|
||||||
|
lora_assignments,
|
||||||
|
num_slices=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
lora_b_weights = [weights[name][1] for name in sorted(weights.keys())]
|
||||||
|
stacked_lora_b = self.stack_lora_weights(
|
||||||
|
lora_b_weights, is_lora_a=False
|
||||||
|
)
|
||||||
|
|
||||||
|
chunked_expand = chunked_sgmv_lora_expand_forward(
|
||||||
|
intermediate,
|
||||||
|
stacked_lora_b,
|
||||||
|
batch_info,
|
||||||
|
self.slice_offsets,
|
||||||
|
self.max_slice_size,
|
||||||
|
)
|
||||||
|
reference_expand = reference_sgmv_expand(
|
||||||
|
intermediate,
|
||||||
|
stacked_lora_b,
|
||||||
|
batch_info,
|
||||||
|
seq_lengths,
|
||||||
|
lora_assignments,
|
||||||
|
self.slice_offsets,
|
||||||
|
self.max_slice_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.testing.assert_close(
|
||||||
|
chunked_expand, reference_expand, rtol=self.RTOL, atol=self.ATOL
|
||||||
|
)
|
||||||
|
|
||||||
|
# === QKV Operations Test ===
|
||||||
|
|
||||||
|
def test_qkv_missing_projections(self):
|
||||||
|
"""Test QKV operations with missing k_proj (Qwen3 scenario)"""
|
||||||
|
for batch_size in [1, 2, 16, 64]:
|
||||||
|
with self.subTest(batch_size=batch_size):
|
||||||
|
x, weights, batch_info, seq_lengths, lora_assignments = (
|
||||||
|
self.create_test_batch(
|
||||||
|
BatchComposition.MIXED, batch_size, include_missing_k=True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.run_test_comparison(
|
||||||
|
x,
|
||||||
|
weights,
|
||||||
|
batch_info,
|
||||||
|
seq_lengths,
|
||||||
|
lora_assignments,
|
||||||
|
f"QKV missing k_proj batch_size={batch_size}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# === Batch Composition Tests ===
|
||||||
|
|
||||||
|
def test_uniform_lora_batch(self):
|
||||||
|
"""All sequences use same LoRA, random sequence lengths"""
|
||||||
|
for batch_size in [1, 2, 16, 64]:
|
||||||
|
with self.subTest(batch_size=batch_size):
|
||||||
|
x, weights, batch_info, seq_lengths, lora_assignments = (
|
||||||
|
self.create_test_batch(BatchComposition.UNIFORM, batch_size)
|
||||||
|
)
|
||||||
|
self.run_test_comparison(
|
||||||
|
x,
|
||||||
|
weights,
|
||||||
|
batch_info,
|
||||||
|
seq_lengths,
|
||||||
|
lora_assignments,
|
||||||
|
f"uniform batch_size={batch_size}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_evenly_mixed_lora_batch(self):
|
||||||
|
"""Sequences evenly distributed across LoRAs, random lengths"""
|
||||||
|
for batch_size in [1, 2, 16, 64]:
|
||||||
|
with self.subTest(batch_size=batch_size):
|
||||||
|
x, weights, batch_info, seq_lengths, lora_assignments = (
|
||||||
|
self.create_test_batch(BatchComposition.MIXED, batch_size)
|
||||||
|
)
|
||||||
|
self.run_test_comparison(
|
||||||
|
x,
|
||||||
|
weights,
|
||||||
|
batch_info,
|
||||||
|
seq_lengths,
|
||||||
|
lora_assignments,
|
||||||
|
f"mixed batch_size={batch_size}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_highly_skewed_lora_batch(self):
|
||||||
|
"""Highly uneven LoRA distribution, random lengths"""
|
||||||
|
for batch_size in [1, 2, 16, 64]:
|
||||||
|
with self.subTest(batch_size=batch_size):
|
||||||
|
x, weights, batch_info, seq_lengths, lora_assignments = (
|
||||||
|
self.create_test_batch(BatchComposition.SKEWED, batch_size)
|
||||||
|
)
|
||||||
|
self.run_test_comparison(
|
||||||
|
x,
|
||||||
|
weights,
|
||||||
|
batch_info,
|
||||||
|
seq_lengths,
|
||||||
|
lora_assignments,
|
||||||
|
f"skewed batch_size={batch_size}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# === Decode Mode Tests ===
|
||||||
|
|
||||||
|
def test_decode_uniform_lora_batch(self):
|
||||||
|
"""Decode mode: All sequences use same LoRA, all length 1"""
|
||||||
|
for batch_size in [1, 2, 16, 64]:
|
||||||
|
with self.subTest(batch_size=batch_size):
|
||||||
|
x, weights, batch_info, seq_lengths, lora_assignments = (
|
||||||
|
self.create_test_batch(
|
||||||
|
BatchComposition.UNIFORM, batch_size, BatchMode.DECODE
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.run_test_comparison(
|
||||||
|
x,
|
||||||
|
weights,
|
||||||
|
batch_info,
|
||||||
|
seq_lengths,
|
||||||
|
lora_assignments,
|
||||||
|
f"decode uniform batch_size={batch_size}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_decode_mixed_lora_batch(self):
|
||||||
|
"""Decode mode: Sequences distributed across LoRAs, all length 1"""
|
||||||
|
for batch_size in [1, 2, 16, 64]:
|
||||||
|
with self.subTest(batch_size=batch_size):
|
||||||
|
x, weights, batch_info, seq_lengths, lora_assignments = (
|
||||||
|
self.create_test_batch(
|
||||||
|
BatchComposition.MIXED, batch_size, BatchMode.DECODE
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.run_test_comparison(
|
||||||
|
x,
|
||||||
|
weights,
|
||||||
|
batch_info,
|
||||||
|
seq_lengths,
|
||||||
|
lora_assignments,
|
||||||
|
f"decode mixed batch_size={batch_size}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_decode_skewed_lora_batch(self):
|
||||||
|
"""Decode mode: Highly uneven LoRA distribution, all length 1"""
|
||||||
|
for batch_size in [1, 2, 16, 64]:
|
||||||
|
with self.subTest(batch_size=batch_size):
|
||||||
|
x, weights, batch_info, seq_lengths, lora_assignments = (
|
||||||
|
self.create_test_batch(
|
||||||
|
BatchComposition.SKEWED, batch_size, BatchMode.DECODE
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.run_test_comparison(
|
||||||
|
x,
|
||||||
|
weights,
|
||||||
|
batch_info,
|
||||||
|
seq_lengths,
|
||||||
|
lora_assignments,
|
||||||
|
f"decode skewed batch_size={batch_size}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -24,6 +24,7 @@ suites = {
|
|||||||
TestFile("lora/test_lora_update.py", 400),
|
TestFile("lora/test_lora_update.py", 400),
|
||||||
TestFile("lora/test_lora_qwen3.py", 97),
|
TestFile("lora/test_lora_qwen3.py", 97),
|
||||||
TestFile("lora/test_lora_radix_cache.py", 100),
|
TestFile("lora/test_lora_radix_cache.py", 100),
|
||||||
|
TestFile("lora/test_chunked_sgmv_backend.py", 30),
|
||||||
TestFile("models/test_embedding_models.py", 73),
|
TestFile("models/test_embedding_models.py", 73),
|
||||||
# TestFile("models/test_clip_models.py", 52),
|
# TestFile("models/test_clip_models.py", 52),
|
||||||
TestFile("models/test_encoder_embedding_models.py", 100),
|
TestFile("models/test_encoder_embedding_models.py", 100),
|
||||||
|
|||||||
Reference in New Issue
Block a user