update
This commit is contained in:
@@ -1,10 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
|
||||
from vllm.lora.punica_wrapper.punica_selector import get_punica_wrapper
|
||||
|
||||
__all__ = [
|
||||
"PunicaWrapperBase",
|
||||
"get_punica_wrapper",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,492 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from .utils import compute_meta, convert_mapping
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# avoid circuit import
|
||||
from vllm.lora.layers import LoRAMapping
|
||||
|
||||
|
||||
class PunicaWrapperABC(ABC):
|
||||
"""
|
||||
PunicaWrapper ABC.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def update_metadata(
|
||||
self,
|
||||
mapping: "LoRAMapping",
|
||||
lora_index_to_id: list[int | None],
|
||||
max_loras: int,
|
||||
vocab_size: int,
|
||||
extra_vocab_size: int,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Update the lora-related metadata
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def add_shrink(
|
||||
self,
|
||||
y: tuple[torch.Tensor, ...] | torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: tuple[torch.Tensor, ...],
|
||||
scale: float,
|
||||
**kwargs,
|
||||
) -> torch.Tensor | None:
|
||||
"""
|
||||
Performs GEMM for multiple slices of lora_a.
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def add_expand(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: tuple[torch.Tensor, ...] | torch.Tensor,
|
||||
lora_b_stacked: tuple[torch.Tensor, ...],
|
||||
output_slices: tuple[int, ...],
|
||||
offset_start: int = 0,
|
||||
add_inputs=True,
|
||||
**kwargs,
|
||||
) -> torch.Tensor | None:
|
||||
"""
|
||||
Performs GEMM for multiple slices of lora_b.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def add_lora_embedding(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_b_stacked: torch.Tensor,
|
||||
add_inputs: bool = True,
|
||||
**kwargs,
|
||||
) -> torch.Tensor | None:
|
||||
"""
|
||||
Applies lora specifically for VocabParallelEmbeddingWithLoRA,
|
||||
and this layer only requires the expand operation.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def add_lora_linear(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: tuple[torch.Tensor, ...],
|
||||
lora_b_stacked: tuple[torch.Tensor, ...],
|
||||
scale: float,
|
||||
output_slices: tuple[int, ...],
|
||||
*,
|
||||
buffer: tuple[torch.Tensor, ...] | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor | None:
|
||||
"""
|
||||
Applicable to linear-related lora.
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def add_lora_logits(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: torch.Tensor,
|
||||
lora_b_stacked: torch.Tensor,
|
||||
scale,
|
||||
*,
|
||||
buffer: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor | None:
|
||||
"""
|
||||
Applies lora specifically for LogitsProcessorWithLoRA.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class PunicaWrapperBase(PunicaWrapperABC):
|
||||
"""
|
||||
PunicaWrapperBase is designed to manage and provide metadata for the punica
|
||||
kernel. The main function is to maintain the state information for
|
||||
Multi-LoRA, and to provide the interface for the punica.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_batched_tokens: int,
|
||||
max_batches: int,
|
||||
device: torch.device | str,
|
||||
**kwargs,
|
||||
):
|
||||
self._token_lora_indices = torch.empty(
|
||||
max_num_batched_tokens, dtype=torch.long, device=device
|
||||
)
|
||||
self._sampler_indices = torch.empty(
|
||||
max_num_batched_tokens, dtype=torch.long, device=device
|
||||
)
|
||||
self._sampler_indices_padded = torch.empty(
|
||||
max_num_batched_tokens, dtype=torch.long, device=device
|
||||
)
|
||||
self._embeddings_indices = torch.empty(
|
||||
2, max_num_batched_tokens, dtype=torch.long, device=device
|
||||
)
|
||||
|
||||
# 4 is the number of indices tensors.
|
||||
# base_indices, sampler_indices, sampler_indices_padded,
|
||||
# embeddings_indices
|
||||
self.indices_len: list[int | None] = [None] * 4
|
||||
# these attributes are the information required for sgmv kernel
|
||||
self._seq_start_locs = torch.empty(max_batches, dtype=torch.long, device=device)
|
||||
self._seq_lengths = torch.empty(max_batches, dtype=torch.long, device=device)
|
||||
self._lora_indices_per_batch = torch.empty(
|
||||
max_batches, dtype=torch.long, device=device
|
||||
)
|
||||
self.device: torch.device = device
|
||||
self.max_length: int = 0
|
||||
self.token_nums: int = 0
|
||||
self.batch_size: int = -1
|
||||
self.is_prefill = False
|
||||
self.no_lora = False
|
||||
|
||||
def _update_base_metadata(
|
||||
self,
|
||||
mapping: "LoRAMapping",
|
||||
lora_index_to_id: list[int | None],
|
||||
max_loras: int,
|
||||
vocab_size: int,
|
||||
extra_vocab_size: int,
|
||||
):
|
||||
(
|
||||
base_indices,
|
||||
sampler_indices,
|
||||
sampler_indices_padded,
|
||||
embeddings_indices,
|
||||
indices_len,
|
||||
) = convert_mapping(
|
||||
mapping,
|
||||
lora_index_to_id,
|
||||
max_loras,
|
||||
vocab_size,
|
||||
extra_vocab_size,
|
||||
self.device,
|
||||
)
|
||||
self._token_lora_indices[: base_indices.shape[0]].copy_(base_indices)
|
||||
self._sampler_indices[: sampler_indices.shape[0]].copy_(sampler_indices)
|
||||
self._sampler_indices_padded[: sampler_indices_padded.shape[0]].copy_(
|
||||
sampler_indices_padded
|
||||
)
|
||||
self._embeddings_indices[
|
||||
: embeddings_indices.shape[0], : embeddings_indices.shape[1]
|
||||
].copy_(embeddings_indices)
|
||||
|
||||
self.indices_len[:] = indices_len
|
||||
|
||||
def _update_prefill_metadata(self, token_lora_tensor: torch.Tensor) -> None:
|
||||
(
|
||||
b_seq_start_tensor,
|
||||
seq_length_tensor,
|
||||
lora_indices_tensor,
|
||||
batch_size,
|
||||
max_length,
|
||||
token_nums,
|
||||
no_lora,
|
||||
) = compute_meta(token_lora_tensor)
|
||||
|
||||
self._seq_start_locs[: b_seq_start_tensor.shape[0]].copy_(b_seq_start_tensor)
|
||||
self._seq_lengths[: seq_length_tensor.shape[0]].copy_(seq_length_tensor)
|
||||
self._lora_indices_per_batch[: lora_indices_tensor.shape[0]].copy_(
|
||||
lora_indices_tensor
|
||||
)
|
||||
self.batch_size = batch_size
|
||||
self.max_length = max_length
|
||||
self.token_nums = token_nums
|
||||
self.no_lora = no_lora
|
||||
|
||||
@property
|
||||
def prefill_metadata(
|
||||
self,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
|
||||
"""
|
||||
This property provides a convenient way to access the necessary
|
||||
metadata for prefill-related kernel computations.
|
||||
1. seq_start_locs: Tensor of sequence start positions.
|
||||
2. seq_lengths: Tensor of sequence lengths.
|
||||
3. lora_indices_per_batch: Tensor of lora indices, and an index of
|
||||
-1 means no lora should be applied.
|
||||
4. batch_size: Batch size after clustering identical lora indices.
|
||||
5. max_length: The maximum sequence length in the batch.
|
||||
6. token_nums: The token numbers in the batch.
|
||||
"""
|
||||
return (
|
||||
self._seq_start_locs[: self.batch_size],
|
||||
self._seq_lengths[: self.batch_size],
|
||||
self._lora_indices_per_batch[: self.batch_size],
|
||||
self.batch_size,
|
||||
self.max_length,
|
||||
self.token_nums,
|
||||
)
|
||||
|
||||
@property
|
||||
def token_lora_indices(self) -> torch.Tensor:
|
||||
"""
|
||||
This property provides the lora indices corresponding to each token
|
||||
in the batch. An index of -1 means no lora should be applied.
|
||||
"""
|
||||
token_lora_len = self.indices_len[0]
|
||||
return self._token_lora_indices[:token_lora_len]
|
||||
|
||||
@property
|
||||
def sampler_indices(self) -> torch.Tensor:
|
||||
"""
|
||||
This property is used to access the lora indices specifically for
|
||||
LogitsProcessorWithLoRA.
|
||||
"""
|
||||
sampler_indices_len = self.indices_len[1]
|
||||
return self._sampler_indices[:sampler_indices_len]
|
||||
|
||||
@property
|
||||
def sampler_indices_padded(self) -> torch.Tensor:
|
||||
"""
|
||||
This property provides access to padded sampler indices.
|
||||
"""
|
||||
indices_padded_len = self.indices_len[2]
|
||||
return self._sampler_indices_padded[:indices_padded_len]
|
||||
|
||||
@property
|
||||
def embeddings_indices(self) -> torch.Tensor:
|
||||
"""
|
||||
This property provides access to the indices used for lora embeddings,
|
||||
specifically for VocabParallelEmbeddingWithLoRA.
|
||||
"""
|
||||
embeddings_indices_len = self.indices_len[3]
|
||||
return self._embeddings_indices[:, :embeddings_indices_len]
|
||||
|
||||
def update_metadata(
|
||||
self,
|
||||
mapping: "LoRAMapping",
|
||||
lora_index_to_id: list[int | None],
|
||||
max_loras: int,
|
||||
vocab_size: int,
|
||||
extra_vocab_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
self._update_base_metadata(
|
||||
mapping, lora_index_to_id, max_loras, vocab_size, extra_vocab_size
|
||||
)
|
||||
|
||||
if mapping.is_prefill:
|
||||
# Update metadata required for prefill-related operators.
|
||||
self._update_prefill_metadata(self.token_lora_indices)
|
||||
self.is_prefill = True
|
||||
else:
|
||||
self.is_prefill = False
|
||||
|
||||
@abstractmethod
|
||||
def add_shrink(
|
||||
self,
|
||||
y: tuple[torch.Tensor, ...] | torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: tuple[torch.Tensor, ...],
|
||||
scale: float,
|
||||
**kwargs,
|
||||
) -> torch.Tensor | None:
|
||||
"""
|
||||
Performs GEMM for multiple slices of lora_a.
|
||||
|
||||
Semantics:
|
||||
for i in range(len(lora_a_stacked)):
|
||||
y[i] += (x @ lora_a_stacked[i]) * scale
|
||||
|
||||
Args:
|
||||
y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
|
||||
x (torch.Tensor): Input tensor
|
||||
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
|
||||
scale (float): Scaling factor for the operation
|
||||
|
||||
"""
|
||||
# TODO: implement it based on torch ops
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def add_expand(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: tuple[torch.Tensor, ...] | torch.Tensor,
|
||||
lora_b_stacked: tuple[torch.Tensor, ...],
|
||||
output_slices: tuple[int, ...],
|
||||
offset_start: int = 0,
|
||||
add_inputs=True,
|
||||
**kwargs,
|
||||
) -> torch.Tensor | None:
|
||||
"""
|
||||
Performs GEMM for multiple slices of lora_b.
|
||||
|
||||
Semantics:
|
||||
offset = offset_start
|
||||
for i in range(len(lora_b_stacked)):
|
||||
slice = output_slices[i]
|
||||
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
|
||||
offset += slice
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor.
|
||||
x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
|
||||
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
|
||||
output_slices (tuple[int, ...]): Every slice's size
|
||||
offset_start (int): The starting position of y, defaults to 0
|
||||
add_inputs (bool): Defaults to True.
|
||||
|
||||
"""
|
||||
# TODO: implement it based on torch ops
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def add_lora_embedding(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_b_stacked: torch.Tensor,
|
||||
add_inputs: bool = True,
|
||||
**kwargs,
|
||||
) -> torch.Tensor | None:
|
||||
"""
|
||||
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
|
||||
and this layer only requires the expand operation.
|
||||
Semantics:
|
||||
y += x @ lora_b_stacked
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor.
|
||||
x (torch.Tensor): Input tensor.
|
||||
lora_b_stacked (torch.Tensor): lora_b's weights.
|
||||
add_inputs (bool): Default to True.
|
||||
"""
|
||||
# TODO: implement it based on torch ops
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def add_lora_linear(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: tuple[torch.Tensor, ...],
|
||||
lora_b_stacked: tuple[torch.Tensor, ...],
|
||||
scale: float,
|
||||
output_slices: tuple[int, ...],
|
||||
*,
|
||||
buffer: tuple[torch.Tensor, ...] | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor | None:
|
||||
"""
|
||||
Applicable to linear-related lora.
|
||||
|
||||
Semantics:
|
||||
for i in range(len(lora_a_stacked)):
|
||||
y[i] += (
|
||||
x[i].unsqueeze(0)
|
||||
@ lora_a_stacked[indices[i], layer_idx, :, :]
|
||||
@ lora_b_stacked[indices[i], layer_idx, :, :]
|
||||
* scale
|
||||
).squeeze(0)
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor. Will be changed in-place.
|
||||
x (torch.Tensor): Input tensor
|
||||
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
|
||||
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
|
||||
scale (float): Scaling factor.
|
||||
output_slices (tuple[int, ...]): Every slice's size.
|
||||
buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
|
||||
"""
|
||||
# TODO: implement it based on torch ops
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def add_lora_logits(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: torch.Tensor,
|
||||
lora_b_stacked: torch.Tensor,
|
||||
scale,
|
||||
*,
|
||||
buffer: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor | None:
|
||||
"""
|
||||
Applies lora specifically for LogitsProcessorWithLoRA.
|
||||
|
||||
Semantics:
|
||||
buffer = (x @ lora_a_stacked) * scale
|
||||
y += buffer @ lora_b_stacked
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor.
|
||||
x (torch.Tensor): Input tensor.
|
||||
lora_a_stacked (torch.Tensor): lora_a's weights.
|
||||
lora_b_stacked (torch.Tensor):lora_b's weights.
|
||||
scale (float): Scaling factor.
|
||||
buffer (Optional[torch.Tensor]):Default to None.
|
||||
"""
|
||||
# TODO: implement it based on torch ops
|
||||
raise NotImplementedError
|
||||
|
||||
def moe_lora_align_block_size(
|
||||
self,
|
||||
topk_ids: torch.Tensor,
|
||||
num_tokens: int,
|
||||
block_size: int,
|
||||
num_experts: int,
|
||||
max_loras: int,
|
||||
adapter_enabled: torch.Tensor,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
pad_sorted_ids: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Aligns tokens and experts into block-sized chunks for LoRA-based
|
||||
mixture-of-experts (MoE) execution.
|
||||
"""
|
||||
# TODO: implement it based on torch ops
|
||||
raise NotImplementedError
|
||||
|
||||
def add_lora_fused_moe(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: list[torch.Tensor],
|
||||
lora_b_stacked: list[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
sorted_token_ids: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_padded: torch.Tensor,
|
||||
max_lora_rank: int,
|
||||
top_k_num: int,
|
||||
shrink_config,
|
||||
expand_config,
|
||||
adapter_enabled: torch.Tensor,
|
||||
mul_routed_weight=False,
|
||||
):
|
||||
"""
|
||||
Performs a fused forward computation for LoRA of
|
||||
Mixture-of-Experts (MoE) layer.
|
||||
"""
|
||||
# TODO: implement it based on torch ops
|
||||
raise NotImplementedError
|
||||
@@ -1,351 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.lora.ops.torch_ops import (
|
||||
bgmv_expand,
|
||||
bgmv_expand_slice,
|
||||
bgmv_shrink,
|
||||
sgmv_expand,
|
||||
sgmv_expand_slice,
|
||||
sgmv_shrink,
|
||||
)
|
||||
|
||||
from .punica_base import PunicaWrapperBase
|
||||
|
||||
|
||||
# The platforms that are compatible with the PyTorch-native implementation can
|
||||
# inherit this class
|
||||
class PunicaWrapperCPU(PunicaWrapperBase):
|
||||
"""
|
||||
PunicaWrapperCPU is designed to manage and provide metadata for the punica
|
||||
kernel. The main function is to maintain the state information for
|
||||
Multi-LoRA, and to provide the interface for the pytorch punica ops.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_batched_tokens: int,
|
||||
max_batches: int,
|
||||
device: torch.device | str,
|
||||
**kwargs,
|
||||
):
|
||||
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device)
|
||||
|
||||
def _shrink_prefill(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
scale: float,
|
||||
):
|
||||
# No LoRA request, so return directly
|
||||
if self.no_lora:
|
||||
return
|
||||
sgmv_shrink(
|
||||
x,
|
||||
w_t_all,
|
||||
y,
|
||||
*self.prefill_metadata,
|
||||
scale,
|
||||
)
|
||||
|
||||
def _shrink_decode(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
scale: float,
|
||||
):
|
||||
bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)
|
||||
|
||||
def _expand_prefill(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
add_inputs: bool,
|
||||
):
|
||||
# No LoRA request, so return directly
|
||||
if self.no_lora:
|
||||
return
|
||||
sgmv_expand(
|
||||
x,
|
||||
w_t_all,
|
||||
y,
|
||||
*self.prefill_metadata,
|
||||
add_inputs,
|
||||
)
|
||||
|
||||
def _expand_decode(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
add_inputs: bool,
|
||||
):
|
||||
bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs)
|
||||
|
||||
def _expand_slice_prefill(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
y_offset: int,
|
||||
y_slice_size: int,
|
||||
add_inputs: bool,
|
||||
):
|
||||
# No LoRA request, so return directly
|
||||
if self.no_lora:
|
||||
return
|
||||
sgmv_expand_slice(
|
||||
x,
|
||||
w_t_all,
|
||||
y,
|
||||
*self.prefill_metadata,
|
||||
y_offset,
|
||||
y_slice_size,
|
||||
add_inputs,
|
||||
)
|
||||
|
||||
def _expand_slice_decode(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
y_offset: int,
|
||||
y_slice_size: int,
|
||||
add_inputs: bool,
|
||||
):
|
||||
bgmv_expand_slice(
|
||||
x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_inputs
|
||||
)
|
||||
|
||||
def _apply_expand(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
y_offset: int,
|
||||
y_slice_size: int,
|
||||
add_inputs: bool = True,
|
||||
):
|
||||
"""
|
||||
Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
|
||||
computation, which is suitable for the
|
||||
GEMM of lora'b.
|
||||
"""
|
||||
|
||||
expand_slice_fun: Callable = (
|
||||
self._expand_slice_prefill if self.is_prefill else self._expand_slice_decode
|
||||
)
|
||||
expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs)
|
||||
|
||||
def _apply_shrink(
|
||||
self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, scale: float
|
||||
):
|
||||
"""
|
||||
Perform the ` y+=x@w_t_all` computation, which is suitable for the
|
||||
GEMM of lora'a.
|
||||
When `is_prefill is` true, it indicates that it is currently the
|
||||
prefill stage, and the `_shrink_prefill` function should be called.
|
||||
Otherwise, it is the decode stage, and the _shrink_decode function
|
||||
should be called.
|
||||
"""
|
||||
y_org = y
|
||||
y = y.view(-1, y.shape[-1])
|
||||
shrink_fun: Callable = (
|
||||
self._shrink_prefill if self.is_prefill else self._shrink_decode
|
||||
)
|
||||
shrink_fun(y, x, w_t_all, scale)
|
||||
y = y.view_as(y_org)
|
||||
|
||||
def add_shrink(
|
||||
self,
|
||||
y: tuple[torch.Tensor, ...] | torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: tuple[torch.Tensor, ...],
|
||||
scale: float,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Performs GEMM for multiple slices of lora_a.
|
||||
When `is_prefill is` true, it indicates that it is currently the
|
||||
prefill stage, and the `_shrink_prefill` function should be called.
|
||||
Otherwise, it is the decode stage, and the _shrink_decode function
|
||||
should be called.
|
||||
|
||||
Semantics:
|
||||
for i in range(len(lora_a_stacked)):
|
||||
y[i] += (x @ lora_a_stacked[i]) * scale
|
||||
|
||||
Args:
|
||||
y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
|
||||
x (torch.Tensor): Input tensor
|
||||
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
|
||||
scale (float): Scaling factor for the operation
|
||||
"""
|
||||
|
||||
x = x.view(-1, x.shape[-1])
|
||||
# TODO fuse these kernels
|
||||
for slice_idx in range(len(lora_a_stacked)):
|
||||
self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], scale)
|
||||
|
||||
def add_expand(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: tuple[torch.Tensor, ...] | torch.Tensor,
|
||||
lora_b_stacked: tuple[torch.Tensor, ...],
|
||||
output_slices: tuple[int, ...],
|
||||
offset_start: int = 0,
|
||||
add_inputs=True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Performs GEMM for multiple slices of lora_b.
|
||||
|
||||
Semantics:
|
||||
for i in range(len(lora_b_stacked)):
|
||||
slice = output_slices[i]
|
||||
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
|
||||
offset += slice
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor.
|
||||
x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
|
||||
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
|
||||
output_slices (tuple[int, ...]): Every slice's size
|
||||
add_inputs (bool): Defaults to True.
|
||||
"""
|
||||
y_org = y
|
||||
y = y.view(-1, y.shape[-1])
|
||||
offset_left = offset_start
|
||||
for slice_idx in range(len(lora_b_stacked)):
|
||||
self._apply_expand(
|
||||
y,
|
||||
x[slice_idx],
|
||||
lora_b_stacked[slice_idx],
|
||||
offset_left,
|
||||
output_slices[slice_idx],
|
||||
add_inputs=add_inputs,
|
||||
)
|
||||
offset_left += output_slices[slice_idx]
|
||||
y = y.view_as(y_org)
|
||||
|
||||
def add_lora_embedding(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_b_stacked: torch.Tensor,
|
||||
add_inputs: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
|
||||
|
||||
Semantics:
|
||||
y += x @ lora_b_stacked
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor.
|
||||
x (torch.Tensor): Input tensor.
|
||||
lora_b_stacked (torch.Tensor): lora_b's weights.
|
||||
add_inputs (bool): Default to True.
|
||||
"""
|
||||
|
||||
# Embedding layer only need expand op
|
||||
expand_fun: Callable = (
|
||||
self._expand_prefill if self.is_prefill else self._expand_decode
|
||||
)
|
||||
expand_fun(y, x, lora_b_stacked, add_inputs)
|
||||
|
||||
def add_lora_linear(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: tuple[torch.Tensor, ...],
|
||||
lora_b_stacked: tuple[torch.Tensor, ...],
|
||||
scale: float,
|
||||
output_slices: tuple[int, ...],
|
||||
*,
|
||||
buffer: tuple[torch.Tensor, ...] | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Applicable to linear-related lora.
|
||||
|
||||
Semantics:
|
||||
for i in range(len(lora_a_stacked)):
|
||||
y[i] += (
|
||||
x[i].unsqueeze(0)
|
||||
@ lora_a_stacked[indices[i], layer_idx, :, :]
|
||||
@ lora_b_stacked[indices[i], layer_idx, :, :]
|
||||
* scale
|
||||
).squeeze(0)
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor. Will be changed in-place.
|
||||
x (torch.Tensor): Input tensor
|
||||
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
|
||||
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
|
||||
scale (float): Scaling factor.
|
||||
output_slices (tuple[int, ...]): Every slice's size.
|
||||
buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
|
||||
"""
|
||||
|
||||
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
|
||||
|
||||
if buffer is None:
|
||||
r = lora_b_stacked[0].size(-1)
|
||||
# We set the buffer to be float32 by default, consistent with the
|
||||
# triton op
|
||||
buffer = tuple(
|
||||
torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device)
|
||||
for _ in range(len(output_slices))
|
||||
)
|
||||
self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
|
||||
self.add_expand(
|
||||
y, buffer, lora_b_stacked, output_slices, add_inputs=True, **kwargs
|
||||
)
|
||||
|
||||
def add_lora_logits(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: torch.Tensor,
|
||||
lora_b_stacked: torch.Tensor,
|
||||
scale,
|
||||
*,
|
||||
buffer: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Applies lora specifically for LogitsProcessorWithLoRA.
|
||||
|
||||
Semantics:
|
||||
buffer = (x @ lora_a_stacked) * scale
|
||||
y += buffer @ lora_b_stacked
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor.
|
||||
x (torch.Tensor): Input tensor.
|
||||
lora_a_stacked (torch.Tensor): lora_a's weights.
|
||||
lora_b_stacked (torch.Tensor):lora_b's weights.
|
||||
scale (float): Scaling factor.
|
||||
buffer (Optional[torch.Tensor]):Default to None.
|
||||
"""
|
||||
y_org = y
|
||||
y = y.view(-1, y.shape[-1])
|
||||
x = x.view(-1, x.shape[-1])
|
||||
r = lora_b_stacked.size(-1)
|
||||
if buffer is None:
|
||||
# We set the buffer to be float32 by default, consistent with the
|
||||
# triton op
|
||||
buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device)
|
||||
# LogitsProcessorWithLoRA always using bgmv.
|
||||
bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
|
||||
bgmv_expand(buffer, lora_b_stacked, y, self.sampler_indices, add_inputs=True)
|
||||
y = y.view_as(y_org)
|
||||
@@ -1,422 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
from typing import final
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.lora.layers import LoRAMapping
|
||||
from vllm.triton_utils import HAS_TRITON, triton
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
if HAS_TRITON:
|
||||
from vllm.lora.ops.triton_ops import (
|
||||
LoRAKernelMeta,
|
||||
fused_moe_lora,
|
||||
lora_expand,
|
||||
lora_shrink,
|
||||
)
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
from .punica_base import PunicaWrapperBase
|
||||
|
||||
|
||||
@final
|
||||
class PunicaWrapperGPU(PunicaWrapperBase):
|
||||
"""
|
||||
PunicaWrapperGPU is designed to manage and provide metadata for the punica
|
||||
kernel. The main function is to maintain the state information for
|
||||
Multi-LoRA, and to provide the interface for the punica triton kernel.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_batched_tokens: int,
|
||||
max_batches: int,
|
||||
device: torch.device | str,
|
||||
**kwargs,
|
||||
):
|
||||
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device)
|
||||
|
||||
self.max_loras = kwargs["max_loras"]
|
||||
|
||||
self.token_mapping_meta = LoRAKernelMeta.make(
|
||||
self.max_loras, max_num_batched_tokens, device=device
|
||||
)
|
||||
|
||||
# When speculative decoding is enabled, max_num_samples is
|
||||
# max_batches * (num_speculative_decoding_tokens + 1).
|
||||
# This line can be optimized by replacing max_num_batched_tokens
|
||||
# to max_batches * (num_speculative_decoding_tokens + 1).
|
||||
self.prompt_mapping_meta = LoRAKernelMeta.make(
|
||||
self.max_loras, max_num_batched_tokens, device=device
|
||||
)
|
||||
|
||||
def update_metadata(
|
||||
self,
|
||||
mapping: LoRAMapping,
|
||||
lora_index_to_id: list[int | None],
|
||||
max_loras: int,
|
||||
vocab_size: int,
|
||||
extra_vocab_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
self.is_prefill = mapping.is_prefill
|
||||
self._update_base_metadata(
|
||||
mapping, lora_index_to_id, max_loras, vocab_size, extra_vocab_size
|
||||
)
|
||||
|
||||
# Prepare cuda kernel metadata tensors
|
||||
self.token_mapping_meta.prepare_tensors(self.token_lora_indices)
|
||||
self.prompt_mapping_meta.prepare_tensors(self.sampler_indices)
|
||||
|
||||
def add_shrink(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: tuple[torch.Tensor, ...],
|
||||
scale: float,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Performs GEMM for multiple slices of lora_a.
|
||||
|
||||
Semantics:
|
||||
for i in range(len(lora_a_stacked)):
|
||||
y[i] += (x @ lora_a_stacked[i]) * scale
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensors
|
||||
x (torch.Tensor): Input tensor
|
||||
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
|
||||
scale (float): Scaling factor for the operation
|
||||
"""
|
||||
|
||||
x = x.view(-1, x.shape[-1])
|
||||
lora_shrink(
|
||||
x,
|
||||
lora_a_stacked,
|
||||
y,
|
||||
*self.token_mapping_meta.meta_args(x.size(0)),
|
||||
scale,
|
||||
)
|
||||
|
||||
def add_expand(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_b_stacked: tuple[torch.Tensor, ...],
|
||||
output_slices: tuple[int, ...],
|
||||
offset_start: int = 0,
|
||||
add_inputs=True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Performs GEMM for multiple slices of lora_b.
|
||||
|
||||
Semantics:
|
||||
for i in range(len(lora_b_stacked)):
|
||||
slice = output_slices[i]
|
||||
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
|
||||
offset += slice
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor.
|
||||
x (torch.Tensor): Input tensors
|
||||
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
|
||||
output_slices (tuple[int, ...]): Every slice's size
|
||||
add_inputs (bool): Defaults to True.
|
||||
"""
|
||||
y_org = y
|
||||
y = y.view(-1, y.shape[-1])
|
||||
|
||||
assert x.ndim == 3
|
||||
assert x.size(0) == len(output_slices)
|
||||
num_tokens = x.size(1) # first dimension is the num slices
|
||||
|
||||
lora_expand(
|
||||
x,
|
||||
lora_b_stacked,
|
||||
y,
|
||||
*self.token_mapping_meta.meta_args(num_tokens),
|
||||
offset_start=offset_start,
|
||||
add_inputs=True,
|
||||
)
|
||||
|
||||
y = y.view_as(y_org)
|
||||
|
||||
def add_lora_embedding(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_b_stacked: torch.Tensor,
|
||||
add_inputs: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
|
||||
|
||||
Semantics:
|
||||
y += x @ lora_b_stacked
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor.
|
||||
x (torch.Tensor): Input tensor.
|
||||
lora_b_stacked (torch.Tensor): lora_b's weights.
|
||||
add_inputs (bool): Default to True.
|
||||
"""
|
||||
|
||||
lora_expand(
|
||||
x.unsqueeze(dim=0),
|
||||
(lora_b_stacked,),
|
||||
y,
|
||||
*self.token_mapping_meta.meta_args(x.size(0)),
|
||||
offset_start=0,
|
||||
add_inputs=add_inputs,
|
||||
)
|
||||
|
||||
def add_lora_linear(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: tuple[torch.Tensor, ...],
|
||||
lora_b_stacked: tuple[torch.Tensor, ...],
|
||||
scale: float,
|
||||
output_slices: tuple[int, ...],
|
||||
*,
|
||||
buffer: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Applicable to linear-related lora.
|
||||
|
||||
Semantics:
|
||||
for i in range(len(lora_a_stacked)):
|
||||
y[i] += (
|
||||
x[i].unsqueeze(0)
|
||||
@ lora_a_stacked[indices[i], layer_idx, :, :]
|
||||
@ lora_b_stacked[indices[i], layer_idx, :, :]
|
||||
* scale
|
||||
).squeeze(0)
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor. Will be changed in-place.
|
||||
x (torch.Tensor): Input tensor
|
||||
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
|
||||
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
|
||||
scale (float): Scaling factor.
|
||||
output_slices (tuple[int, ...]): Every slice's size.
|
||||
buffer (Optional[torch.Tensor]): Defaults to None.
|
||||
"""
|
||||
|
||||
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
|
||||
|
||||
import vllm.envs as env
|
||||
if env.VLLM_USE_LORA_FUSION:
|
||||
import ixformer.inference.functions as ops
|
||||
|
||||
num_token, m = x.size(0), x.size(-1)
|
||||
k, n = lora_b_stacked[0].size(-1), y.size(-1)
|
||||
if len(lora_a_stacked) == 1 and ops.lora_gemv_optim_condition(num_token, m, k, n):
|
||||
ops.add_lora_linear(y, x, lora_a_stacked, lora_b_stacked,
|
||||
lora_bias_stacked = None, scale = 1.0, output_slices = (1,))
|
||||
return
|
||||
|
||||
assert buffer is None, (
|
||||
"To minimize overhead, the buffer should be created by "
|
||||
".add_lora_linear() instead of being passed in."
|
||||
)
|
||||
r = lora_b_stacked[0].size(-1)
|
||||
# We set the buffer to be float32 by default, refer to:
|
||||
# https://github.com/triton-lang/triton/issues/1387
|
||||
# Note: buffer is zeroed inside the shrink op
|
||||
buffer = torch.empty(
|
||||
(len(output_slices), x.size(0), r), dtype=torch.float32, device=x.device
|
||||
)
|
||||
|
||||
self.add_shrink(
|
||||
buffer, # type: ignore
|
||||
x,
|
||||
lora_a_stacked,
|
||||
scale,
|
||||
**kwargs,
|
||||
)
|
||||
self.add_expand(
|
||||
y,
|
||||
buffer, # type: ignore
|
||||
lora_b_stacked,
|
||||
output_slices,
|
||||
add_inputs=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def add_lora_logits(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: torch.Tensor,
|
||||
lora_b_stacked: torch.Tensor,
|
||||
scale,
|
||||
*,
|
||||
buffer: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Applies lora specifically for LogitsProcessorWithLoRA.
|
||||
|
||||
Semantics:
|
||||
buffer = (x @ lora_a_stacked) * scale
|
||||
y += buffer @ lora_b_stacked
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor.
|
||||
x (torch.Tensor): Input tensor.
|
||||
lora_a_stacked (torch.Tensor): lora_a's weights.
|
||||
lora_b_stacked (torch.Tensor): lora_b's weights.
|
||||
scale (float): Scaling factor.
|
||||
buffer (Optional[torch.Tensor]): Default to None.
|
||||
"""
|
||||
y_org = y
|
||||
y = y.view(-1, y.shape[-1])
|
||||
x = x.view(-1, x.shape[-1])
|
||||
r = lora_b_stacked.size(-1)
|
||||
|
||||
assert buffer is None, (
|
||||
"To minimize overhead, the buffer should be created by "
|
||||
".add_lora_linear() instead of being passed in."
|
||||
)
|
||||
# We set the buffer to be float32 by default, refer to:
|
||||
# https://github.com/triton-lang/triton/issues/1387
|
||||
# Note: buffer is zeroed inside the shrink op
|
||||
buffer = torch.empty((x.size(0), r), dtype=torch.float32, device=x.device)
|
||||
|
||||
lora_shrink(
|
||||
x,
|
||||
[lora_a_stacked],
|
||||
buffer.unsqueeze(dim=0),
|
||||
*self.prompt_mapping_meta.meta_args(x.size(0)),
|
||||
scale,
|
||||
)
|
||||
|
||||
lora_expand(
|
||||
buffer.unsqueeze(dim=0),
|
||||
[lora_b_stacked],
|
||||
y,
|
||||
*self.prompt_mapping_meta.meta_args(buffer.size(0)),
|
||||
add_inputs=True,
|
||||
)
|
||||
y = y.view_as(y_org)
|
||||
|
||||
def moe_lora_align_block_size(
|
||||
self,
|
||||
topk_ids: torch.Tensor,
|
||||
num_tokens: int,
|
||||
block_size: int,
|
||||
num_experts: int,
|
||||
max_loras: int,
|
||||
adapter_enabled: torch.Tensor,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
pad_sorted_ids: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Aligns tokens and experts into block-sized chunks for LoRA-based
|
||||
mixture-of-experts (MoE) execution.
|
||||
"""
|
||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||
if pad_sorted_ids:
|
||||
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
|
||||
sorted_ids = torch.empty(
|
||||
(max_loras * max_num_tokens_padded,),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device,
|
||||
)
|
||||
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
|
||||
# Expert ids must be set default to -1 to prevent a blank block
|
||||
expert_ids = torch.empty(
|
||||
(max_loras * max_num_m_blocks,),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device,
|
||||
)
|
||||
num_tokens_post_pad = torch.empty(
|
||||
(max_loras), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
|
||||
(token_lora_mapping, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(
|
||||
num_tokens
|
||||
)
|
||||
|
||||
ops.moe_lora_align_block_size(
|
||||
topk_ids,
|
||||
token_lora_mapping,
|
||||
num_experts,
|
||||
block_size,
|
||||
max_loras,
|
||||
max_num_tokens_padded,
|
||||
max_num_m_blocks,
|
||||
sorted_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_pad,
|
||||
adapter_enabled,
|
||||
lora_ids,
|
||||
)
|
||||
if expert_map is not None:
|
||||
expert_ids = expert_map[expert_ids]
|
||||
|
||||
return sorted_ids, expert_ids, num_tokens_post_pad
|
||||
|
||||
def add_lora_fused_moe(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: list[torch.Tensor],
|
||||
lora_b_stacked: list[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
sorted_token_ids: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_padded: torch.Tensor,
|
||||
max_lora_rank: int,
|
||||
top_k_num: int,
|
||||
shrink_config,
|
||||
expand_config,
|
||||
adapter_enabled: torch.Tensor,
|
||||
mul_routed_weight=False,
|
||||
):
|
||||
"""
|
||||
Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer.
|
||||
"""
|
||||
(_, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(x.size(0))
|
||||
fused_moe_lora(
|
||||
y,
|
||||
x,
|
||||
lora_a_stacked,
|
||||
lora_b_stacked,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
max_lora_rank,
|
||||
top_k_num,
|
||||
lora_ids,
|
||||
adapter_enabled,
|
||||
shrink_config.get("BLOCK_SIZE_M", 64),
|
||||
shrink_config.get("BLOCK_SIZE_N", 64),
|
||||
shrink_config.get("BLOCK_SIZE_K", 32),
|
||||
shrink_config.get("GROUP_SIZE_M", 8),
|
||||
shrink_config.get("NUM_WARPS", 4),
|
||||
shrink_config.get("NUM_STAGES", 3),
|
||||
shrink_config.get("SPLIT_K", 1),
|
||||
expand_config.get("BLOCK_SIZE_M", 64),
|
||||
expand_config.get("BLOCK_SIZE_N", 64),
|
||||
expand_config.get("BLOCK_SIZE_K", 32),
|
||||
expand_config.get("GROUP_SIZE_M", 8),
|
||||
expand_config.get("NUM_WARPS", 4),
|
||||
expand_config.get("NUM_STAGES", 3),
|
||||
expand_config.get("SPLIT_K", 1),
|
||||
mul_routed_weight,
|
||||
)
|
||||
@@ -1,21 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
|
||||
from .punica_base import PunicaWrapperBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase:
|
||||
punica_wrapper_qualname = current_platform.get_punica_wrapper()
|
||||
punica_wrapper_cls = resolve_obj_by_qualname(punica_wrapper_qualname)
|
||||
punica_wrapper = punica_wrapper_cls(*args, **kwargs)
|
||||
assert punica_wrapper is not None, (
|
||||
"the punica_wrapper_qualname(" + punica_wrapper_qualname + ") is wrong."
|
||||
)
|
||||
logger.info_once("Using %s.", punica_wrapper_qualname.rsplit(".", 1)[1])
|
||||
return punica_wrapper
|
||||
@@ -1,359 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_xla
|
||||
|
||||
from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
|
||||
from vllm.lora.punica_wrapper.utils import convert_mapping
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# avoid circuit import
|
||||
from vllm.lora.layers import LoRAMapping
|
||||
|
||||
from .punica_base import PunicaWrapperBase
|
||||
|
||||
|
||||
class PunicaWrapperTPU(PunicaWrapperBase):
|
||||
"""
|
||||
PunicaWrapperTPU is designed to manage and provide metadata for the punica
|
||||
kernel. The main function is to maintain the state information for
|
||||
Multi-LoRA, and to provide the interface for the pytorch punica ops.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_batched_tokens: int,
|
||||
max_batches: int,
|
||||
device: torch.device | str,
|
||||
**kwargs,
|
||||
):
|
||||
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device)
|
||||
|
||||
# PunicaWrapperBase defines some tensors with dtype=torch.int64, which
|
||||
# isn't supported by the TPU. So convert those tensors to int32.
|
||||
# Not all of them are used by the TPU so only convert the useful ones.
|
||||
self._token_lora_indices = self._token_lora_indices.to(dtype=torch.int32)
|
||||
self._sampler_indices = self._sampler_indices.to(dtype=torch.int32)
|
||||
self._sampler_indices_padded = self._sampler_indices_padded.to(
|
||||
dtype=torch.int32
|
||||
)
|
||||
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(self._token_lora_indices, True)
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices, True)
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices_padded, True)
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(self._embeddings_indices, True)
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(self._lora_indices_per_batch, True)
|
||||
|
||||
torch._dynamo.mark_dynamic(self._token_lora_indices, 0)
|
||||
torch._dynamo.mark_dynamic(self._embeddings_indices, 1)
|
||||
torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0)
|
||||
|
||||
def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor:
|
||||
return torch.narrow(self._token_lora_indices, 0, 0, x.size(0))
|
||||
|
||||
@property
|
||||
def embeddings_indices(self) -> torch.Tensor:
|
||||
"""
|
||||
This property provides access to the indices used for lora embeddings,
|
||||
specifically for VocabParallelEmbeddingWithLoRA.
|
||||
"""
|
||||
return self._embeddings_indices[:]
|
||||
|
||||
@property
|
||||
def sampler_indices_padded(self) -> torch.Tensor:
|
||||
"""
|
||||
This property provides access to padded sampler indices.
|
||||
"""
|
||||
return self._sampler_indices_padded[:]
|
||||
|
||||
def shrink(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
scale: float,
|
||||
):
|
||||
return bgmv_shrink(x, w_t_all, self._get_token_lora_indices(x), scale)
|
||||
|
||||
def expand(
|
||||
self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, add_inputs: bool
|
||||
):
|
||||
return bgmv_expand(x, w_t_all, y, self._get_token_lora_indices(x), add_inputs)
|
||||
|
||||
def expand_slice(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
y_offset: int,
|
||||
y_slice_size: int,
|
||||
add_inputs: bool,
|
||||
) -> torch.Tensor:
|
||||
return bgmv_expand_slice(
|
||||
x,
|
||||
w_t_all,
|
||||
y,
|
||||
self._get_token_lora_indices(x),
|
||||
y_offset,
|
||||
y_slice_size,
|
||||
add_inputs,
|
||||
)
|
||||
|
||||
def add_shrink(
|
||||
self,
|
||||
y: tuple[torch.Tensor, ...] | torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: tuple[torch.Tensor, ...],
|
||||
scale: float,
|
||||
**kwargs,
|
||||
) -> torch.Tensor | None:
|
||||
"""
|
||||
Performs GEMM for multiple slices of lora_a.
|
||||
|
||||
Semantics:
|
||||
for i in range(len(lora_a_stacked)):
|
||||
y[i] += (x @ lora_a_stacked[i]) * scale
|
||||
|
||||
Args:
|
||||
y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
|
||||
x (torch.Tensor): Input tensor
|
||||
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
|
||||
scale (float): Scaling factor for the operation
|
||||
"""
|
||||
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(y, True)
|
||||
x = x.view(-1, x.shape[-1])
|
||||
|
||||
for slice_idx in range(len(lora_a_stacked)):
|
||||
lora_s = lora_a_stacked[slice_idx]
|
||||
y_s = self.shrink(x, lora_s, scale)
|
||||
y[slice_idx, :, :] = y_s # type: ignore[index]
|
||||
return y
|
||||
|
||||
def add_expand(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: tuple[torch.Tensor, ...] | torch.Tensor,
|
||||
lora_b_stacked: tuple[torch.Tensor, ...],
|
||||
output_slices: tuple[int, ...],
|
||||
offset_start: int = 0,
|
||||
add_inputs=True,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Performs GEMM for multiple slices of lora_b.
|
||||
|
||||
Semantics:
|
||||
for i in range(len(lora_b_stacked)):
|
||||
slice = output_slices[i]
|
||||
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
|
||||
offset += slice
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor.
|
||||
x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
|
||||
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
|
||||
output_slices (tuple[int, ...]): Every slice's size
|
||||
add_inputs (bool): Defaults to True.
|
||||
"""
|
||||
y_org = y
|
||||
y = y.view(-1, y.shape[-1])
|
||||
offset_left = 0
|
||||
|
||||
for slice_idx in range(len(lora_b_stacked)):
|
||||
y = self.expand_slice(
|
||||
y,
|
||||
x[slice_idx],
|
||||
lora_b_stacked[slice_idx],
|
||||
offset_left,
|
||||
output_slices[slice_idx],
|
||||
add_inputs=add_inputs,
|
||||
)
|
||||
offset_left += output_slices[slice_idx]
|
||||
return y.view_as(y_org)
|
||||
|
||||
def add_lora_embedding(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_b_stacked: torch.Tensor,
|
||||
add_inputs: bool = True,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
|
||||
|
||||
Semantics:
|
||||
y += x @ lora_b_stacked
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor.
|
||||
x (torch.Tensor): Input tensor.
|
||||
lora_b_stacked (torch.Tensor): lora_b's weights.
|
||||
add_inputs (bool): Default to True.
|
||||
"""
|
||||
|
||||
# Embedding layer only needs the expand op
|
||||
return self.expand(y, x, lora_b_stacked, add_inputs)
|
||||
|
||||
def add_lora_linear(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: tuple[torch.Tensor, ...],
|
||||
lora_b_stacked: tuple[torch.Tensor, ...],
|
||||
scale: float,
|
||||
output_slices: tuple[int, ...],
|
||||
*,
|
||||
buffer: tuple[torch.Tensor, ...] | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Applicable to linear-related lora.
|
||||
|
||||
Semantics:
|
||||
for i in range(len(lora_a_stacked)):
|
||||
y[i] += (
|
||||
x[i].unsqueeze(0)
|
||||
@ lora_a_stacked[indices[i], layer_idx, :, :]
|
||||
@ lora_b_stacked[indices[i], layer_idx, :, :]
|
||||
* scale
|
||||
).squeeze(0)
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor. Will not be changed in-place.
|
||||
x (torch.Tensor): Input tensor (T, E)
|
||||
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
|
||||
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
|
||||
scale (float): Scaling factor.
|
||||
output_slices (tuple[int, ...]): Every slice's size.
|
||||
buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
|
||||
"""
|
||||
|
||||
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
|
||||
|
||||
if buffer is None:
|
||||
r = lora_b_stacked[0].size(-1)
|
||||
T = x.size(0)
|
||||
buffer = torch.zeros(
|
||||
(len(output_slices), T, r),
|
||||
dtype=x.dtype,
|
||||
device=x.device,
|
||||
)
|
||||
buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
|
||||
return self.add_expand(
|
||||
y, buffer, lora_b_stacked, output_slices, add_inputs=True, **kwargs
|
||||
)
|
||||
|
||||
def add_lora_logits(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: torch.Tensor,
|
||||
lora_b_stacked: torch.Tensor,
|
||||
scale,
|
||||
*,
|
||||
buffer: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Applies lora specifically for LogitsProcessorWithLoRA.
|
||||
|
||||
Semantics:
|
||||
buffer = (x @ lora_a_stacked) * scale
|
||||
y += buffer @ lora_b_stacked
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor.
|
||||
x (torch.Tensor): Input tensor.
|
||||
lora_a_stacked (torch.Tensor): lora_a's weights.
|
||||
lora_b_stacked (torch.Tensor):lora_b's weights.
|
||||
scale (float): Scaling factor.
|
||||
buffer (Optional[torch.Tensor]):Default to None.
|
||||
"""
|
||||
y_org = y
|
||||
y = y.view(-1, y.shape[-1])
|
||||
x = x.view(-1, x.shape[-1])
|
||||
|
||||
sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0))
|
||||
buffer = bgmv_shrink(x, lora_a_stacked, sampler_indices, scale)
|
||||
y = bgmv_expand(buffer, lora_b_stacked, y, sampler_indices, add_inputs=True)
|
||||
return y.view_as(y_org)
|
||||
|
||||
# This performs the same tensor ops as the base method, except it does them
|
||||
# on the CPU then transfers the results to the TPU
|
||||
def _update_base_metadata(
|
||||
self,
|
||||
mapping: "LoRAMapping",
|
||||
lora_index_to_id: list[int | None],
|
||||
max_loras: int,
|
||||
vocab_size: int,
|
||||
extra_vocab_size: int,
|
||||
):
|
||||
# Make sure we don't accidentally collect outside operations
|
||||
torch_xla.sync()
|
||||
|
||||
# Pad the prompt mapping to avoid running into recompiles on the TPU
|
||||
# TODO: Should this happen inside mapping internally? If so how can we
|
||||
# avoid having backend specific LoRAMapping classes?
|
||||
mapping.prompt_mapping = self._pad_prompt_mapping(mapping.prompt_mapping)
|
||||
|
||||
(
|
||||
base_indices,
|
||||
sampler_indices,
|
||||
sampler_indices_padded,
|
||||
embeddings_indices,
|
||||
indices_len,
|
||||
) = convert_mapping(
|
||||
mapping,
|
||||
lora_index_to_id,
|
||||
max_loras,
|
||||
vocab_size,
|
||||
extra_vocab_size,
|
||||
"cpu",
|
||||
)
|
||||
self._token_lora_indices = self._pad_to_shape(
|
||||
base_indices, self._token_lora_indices.shape, dims=1
|
||||
).to(self.device)
|
||||
self._sampler_indices = self._pad_to_shape(
|
||||
sampler_indices, self._sampler_indices.shape, dims=1
|
||||
).to(self.device)
|
||||
self._sampler_indices_padded = self._pad_to_shape(
|
||||
sampler_indices_padded, self._sampler_indices_padded.shape, dims=1
|
||||
).to(self.device)
|
||||
self._embeddings_indices = self._pad_to_shape(
|
||||
embeddings_indices, self._embeddings_indices.shape, dims=2
|
||||
).to(self.device)
|
||||
self.indices_len[:] = indices_len
|
||||
|
||||
def _update_prefill_metadata(self, token_lora_tensor: torch.Tensor) -> None:
|
||||
self.batch_size = 1
|
||||
self._lora_indices_per_batch[: self.batch_size] = token_lora_tensor[
|
||||
: self.batch_size
|
||||
]
|
||||
|
||||
def _pad_prompt_mapping(self, prompt_mapping: tuple[int, ...]) -> tuple[int, ...]:
|
||||
num_reqs = len(prompt_mapping)
|
||||
|
||||
# From vllm/v1/worker/tpu_model_runner:51, but need to avoid a circular
|
||||
# import
|
||||
MIN_NUM_SEQS = 8
|
||||
|
||||
padded_num_reqs = max(2 ** math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS)
|
||||
pad_len = padded_num_reqs - num_reqs
|
||||
|
||||
padding = [-1] * pad_len
|
||||
return tuple(list(prompt_mapping) + padding)
|
||||
|
||||
def _pad_to_shape(self, src, target_shape, dims=1):
|
||||
if dims == 1:
|
||||
pad_len = target_shape[0] - src.shape[0]
|
||||
return F.pad(src, (0, pad_len), value=0).to(torch.int32)
|
||||
else:
|
||||
pad_rows = target_shape[0] - src.shape[0]
|
||||
pad_cols = target_shape[1] - src.shape[1]
|
||||
return F.pad(src, (0, pad_cols, 0, pad_rows), value=0).to(torch.int32)
|
||||
@@ -1,279 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
from typing import final
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.lora.layers import LoRAMapping
|
||||
from vllm.lora.ops.ipex_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
|
||||
|
||||
from .punica_base import PunicaWrapperBase
|
||||
|
||||
|
||||
@final
|
||||
class PunicaWrapperXPU(PunicaWrapperBase):
|
||||
"""
|
||||
PunicaWrapperXPU is designed to manage and provide metadata for the punica
|
||||
kernel. The main function is to maintain the state information for
|
||||
Multi-LoRA, and to provide the interface for the punica ipex kernel.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_batched_tokens: int,
|
||||
max_batches: int,
|
||||
device: torch.device | str,
|
||||
**kwargs,
|
||||
):
|
||||
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device)
|
||||
torch._dynamo.mark_dynamic(self._token_lora_indices, 0)
|
||||
torch._dynamo.mark_dynamic(self._embeddings_indices, 1)
|
||||
torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0)
|
||||
|
||||
def update_metadata(
|
||||
self,
|
||||
mapping: LoRAMapping,
|
||||
lora_index_to_id: list[int | None],
|
||||
max_loras: int,
|
||||
vocab_size: int,
|
||||
extra_vocab_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
self.is_prefill = mapping.is_prefill
|
||||
self._update_base_metadata(
|
||||
mapping, lora_index_to_id, max_loras, vocab_size, extra_vocab_size
|
||||
)
|
||||
|
||||
def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor:
|
||||
return torch.narrow(self._token_lora_indices, 0, 0, x.size(0))
|
||||
|
||||
def _apply_shrink(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
scale: float,
|
||||
):
|
||||
bgmv_shrink(x, w_t_all, y, self._get_token_lora_indices(x), scale)
|
||||
|
||||
def _apply_expand(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
y_offset: int,
|
||||
y_slice_size: int,
|
||||
add_inputs: bool,
|
||||
):
|
||||
token_lora_indices = self._get_token_lora_indices(x)
|
||||
bgmv_expand_slice(
|
||||
x, w_t_all, y, token_lora_indices, y_offset, y_slice_size, add_inputs
|
||||
)
|
||||
|
||||
def add_shrink(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: tuple[torch.Tensor, ...],
|
||||
scale: float,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Performs GEMM for multiple slices of lora_a.
|
||||
|
||||
Semantics:
|
||||
for i in range(len(lora_a_stacked)):
|
||||
y[i] += (x @ lora_a_stacked[i]) * scale
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensors
|
||||
x (torch.Tensor): Input tensor
|
||||
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
|
||||
scale (float): Scaling factor for the operation
|
||||
"""
|
||||
|
||||
x = x.view(-1, x.shape[-1])
|
||||
for slice_idx in range(len(lora_a_stacked)):
|
||||
self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], scale)
|
||||
|
||||
def add_expand(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_b_stacked: tuple[torch.Tensor, ...],
|
||||
output_slices: tuple[int, ...],
|
||||
offset_start: int = 0,
|
||||
add_inputs=True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Performs GEMM for multiple slices of lora_b.
|
||||
|
||||
Semantics:
|
||||
for i in range(len(lora_b_stacked)):
|
||||
slice = output_slices[i]
|
||||
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
|
||||
offset += slice
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor.
|
||||
x (torch.Tensor): Input tensors
|
||||
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
|
||||
output_slices (tuple[int, ...]): Every slice's size
|
||||
add_inputs (bool): Defaults to True.
|
||||
"""
|
||||
y_org = y
|
||||
y = y.view(-1, y.shape[-1])
|
||||
|
||||
assert x.ndim == 3
|
||||
assert x.size(0) == len(output_slices)
|
||||
|
||||
# TODO fuse these kernels
|
||||
for slice_idx in range(len(lora_b_stacked)):
|
||||
self._apply_expand(
|
||||
y,
|
||||
x[slice_idx],
|
||||
lora_b_stacked[slice_idx],
|
||||
offset_start,
|
||||
output_slices[slice_idx],
|
||||
add_inputs=add_inputs,
|
||||
)
|
||||
offset_start += output_slices[slice_idx]
|
||||
y.view_as(y_org)
|
||||
|
||||
def add_lora_embedding(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_b_stacked: torch.Tensor,
|
||||
add_inputs: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
|
||||
|
||||
Semantics:
|
||||
y += x @ lora_b_stacked
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor.
|
||||
x (torch.Tensor): Input tensor.
|
||||
lora_b_stacked (torch.Tensor): lora_b's weights.
|
||||
add_inputs (bool): Default to True.
|
||||
"""
|
||||
token_lora_indices = self._get_token_lora_indices(x)
|
||||
bgmv_expand(x, lora_b_stacked, y, token_lora_indices, add_inputs)
|
||||
|
||||
def add_lora_linear(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: tuple[torch.Tensor, ...],
|
||||
lora_b_stacked: tuple[torch.Tensor, ...],
|
||||
scale: float,
|
||||
output_slices: tuple[int, ...],
|
||||
*,
|
||||
buffer: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Applicable to linear-related lora.
|
||||
|
||||
Semantics:
|
||||
for i in range(len(lora_a_stacked)):
|
||||
y[i] += (
|
||||
x[i].unsqueeze(0)
|
||||
@ lora_a_stacked[indices[i], layer_idx, :, :]
|
||||
@ lora_b_stacked[indices[i], layer_idx, :, :]
|
||||
* scale
|
||||
).squeeze(0)
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor. Will be changed in-place.
|
||||
x (torch.Tensor): Input tensor
|
||||
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
|
||||
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
|
||||
scale (float): Scaling factor.
|
||||
output_slices (tuple[int, ...]): Every slice's size.
|
||||
buffer (Optional[torch.Tensor]): Defaults to None.
|
||||
"""
|
||||
|
||||
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
|
||||
|
||||
if buffer is None:
|
||||
r = lora_b_stacked[0].size(-1)
|
||||
# We set the buffer to be float32 by default, refer to:
|
||||
# https://github.com/triton-lang/triton/issues/1387
|
||||
buffer = torch.zeros( # type: ignore
|
||||
(len(output_slices), x.size(0), r),
|
||||
dtype=torch.float32,
|
||||
device=x.device,
|
||||
)
|
||||
self.add_shrink(
|
||||
buffer, # type: ignore
|
||||
x,
|
||||
lora_a_stacked,
|
||||
scale,
|
||||
**kwargs,
|
||||
)
|
||||
self.add_expand(
|
||||
y,
|
||||
buffer, # type: ignore
|
||||
lora_b_stacked,
|
||||
output_slices,
|
||||
add_inputs=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def sampler_indices_padded(self) -> torch.Tensor:
|
||||
"""
|
||||
This property provides access to padded sampler indices.
|
||||
"""
|
||||
return self._sampler_indices_padded[:]
|
||||
|
||||
def add_lora_logits(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: torch.Tensor,
|
||||
lora_b_stacked: torch.Tensor,
|
||||
scale,
|
||||
*,
|
||||
buffer: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Applies lora specifically for LogitsProcessorWithLoRA.
|
||||
|
||||
Semantics:
|
||||
buffer = (x @ lora_a_stacked) * scale
|
||||
y += buffer @ lora_b_stacked
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor.
|
||||
x (torch.Tensor): Input tensor.
|
||||
lora_a_stacked (torch.Tensor): lora_a's weights.
|
||||
lora_b_stacked (torch.Tensor): lora_b's weights.
|
||||
scale (float): Scaling factor.
|
||||
buffer (Optional[torch.Tensor]): Default to None.
|
||||
"""
|
||||
y_org = y
|
||||
y = y.view(-1, y.shape[-1])
|
||||
x = x.view(-1, x.shape[-1])
|
||||
r = lora_b_stacked.size(-1)
|
||||
if buffer is None:
|
||||
# We set the buffer to be float32 by default, refer to:
|
||||
# https://github.com/triton-lang/triton/issues/1387
|
||||
buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device)
|
||||
sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0))
|
||||
bgmv_shrink(x, lora_a_stacked, buffer, sampler_indices, scale)
|
||||
bgmv_expand(buffer, lora_b_stacked, y, sampler_indices, add_inputs=True)
|
||||
return y.view_as(y_org)
|
||||
@@ -1,150 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# avoid circuit import
|
||||
from vllm.lora.layers import LoRAMapping
|
||||
|
||||
|
||||
def compute_meta(
|
||||
token_lora_tensor: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]:
|
||||
"""
|
||||
Get the information required for the sgmv kernel. With the features:
|
||||
1. If consecutive requests in the batch use the same LoRA, this function
|
||||
will combine them into a single request, improving sgmv kernel inference
|
||||
performance.
|
||||
2. At the beginning of each prefill stage inference, recalculations are
|
||||
needed based on the input, but only once.
|
||||
"""
|
||||
|
||||
lora_indices_tensor, seq_length_tensor = torch.unique_consecutive(
|
||||
token_lora_tensor, return_counts=True
|
||||
)
|
||||
cum_result = torch.cumsum(seq_length_tensor, dim=0)
|
||||
b_seq_start_tensor = torch.zeros_like(seq_length_tensor)
|
||||
b_seq_start_tensor[1:].copy_(cum_result[:-1])
|
||||
max_length = seq_length_tensor.max().item()
|
||||
token_nums = seq_length_tensor.sum().item()
|
||||
batch_size = lora_indices_tensor.size(0)
|
||||
no_lora = False
|
||||
# -1 means no lora should be applied. Use `no_lora` to determine whether
|
||||
# the current step requires LoRA. If LoRA is not needed, the prefill stage
|
||||
# does not need to launch the triton kernel, which can improve performance
|
||||
if batch_size == 1 and lora_indices_tensor == -1:
|
||||
no_lora = True
|
||||
return (
|
||||
b_seq_start_tensor,
|
||||
seq_length_tensor,
|
||||
lora_indices_tensor,
|
||||
batch_size,
|
||||
max_length,
|
||||
token_nums,
|
||||
no_lora,
|
||||
)
|
||||
|
||||
|
||||
# TODO see if this can be vectorized
|
||||
def convert_mapping(
|
||||
mapping: "LoRAMapping",
|
||||
lora_index_to_id: list[int | None],
|
||||
max_loras: int,
|
||||
vocab_size: int,
|
||||
extra_vocab_size: int,
|
||||
device: torch.device,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[int]]:
|
||||
"""Converts LoRAMapping to index tensors.
|
||||
|
||||
Args:
|
||||
mapping: LoRAMapping mapping rows in a batch to LoRA ids.
|
||||
lora_index_to_id: List mapping LoRA ids to LoRA indices.
|
||||
max_loras: Maximum number of LoRAs.
|
||||
vocab_size: Model vocab size.
|
||||
extra_vocab_size: Extra vocab size each LoRA can have.
|
||||
|
||||
Returns:
|
||||
A tuple of tensors:
|
||||
base_indices: Tensor of shape [batch_size] mapping batch rows to
|
||||
LoRA indices.
|
||||
sampler_indices: Tensor of shape [batch_size] mapping requests to
|
||||
LoRA indices for sampler. For generation, this will be the
|
||||
same as base_indices. For prefill, this will map requests
|
||||
to LoRA indices.
|
||||
sampler_indices_padded: Tensor of shape [batch_size] mapping
|
||||
requests to LoRA indices for sampler with padding.
|
||||
Same as sampler_indices, but -1 is replaced with
|
||||
max_loras.
|
||||
embeddings_indices: Tensor of shape [2, batch_size] mapping
|
||||
requests to embedding indices. First row is for embeddings
|
||||
added by the LoRAs, second row is for the LoRA.lora_a
|
||||
embeddings.
|
||||
indices_len: List of lengths of the above tensors. It contains
|
||||
(base_indices, sampler_indices, sampler_indices_padded,
|
||||
embeddings_indices).
|
||||
"""
|
||||
index_mapping_indices: list[int] = list(mapping.index_mapping).copy()
|
||||
embedding_indices = index_mapping_indices.copy()
|
||||
lora_indices = index_mapping_indices.copy()
|
||||
|
||||
prompt_mapping: list[int] = [
|
||||
lora_index_to_id.index(x) if x > 0 else -1 for x in mapping.prompt_mapping
|
||||
]
|
||||
lora_idx = None
|
||||
for i in range(len(index_mapping_indices)):
|
||||
# TODO index can be slow. optimize
|
||||
lora_idx = (
|
||||
lora_index_to_id.index(index_mapping_indices[i])
|
||||
if index_mapping_indices[i] > 0
|
||||
else -1
|
||||
)
|
||||
embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
|
||||
lora_indices[i] = lora_idx
|
||||
|
||||
indices_list: list[list[int] | torch.Tensor] = [
|
||||
index_mapping_indices,
|
||||
lora_indices,
|
||||
embedding_indices,
|
||||
]
|
||||
|
||||
indices = torch.tensor(indices_list, dtype=torch.long, device=device)
|
||||
prompt_mapping_tensor = torch.tensor(
|
||||
prompt_mapping, dtype=torch.long, device=device
|
||||
)
|
||||
embeddings_indices = torch.stack(
|
||||
[
|
||||
indices[2] * extra_vocab_size,
|
||||
indices[2] * (vocab_size + extra_vocab_size),
|
||||
]
|
||||
)
|
||||
embeddings_indices = torch.where(
|
||||
embeddings_indices == -1, max_loras - 1, embeddings_indices
|
||||
)
|
||||
base_indices = indices[1]
|
||||
sampler_indices = prompt_mapping_tensor
|
||||
sampler_indices_padded = sampler_indices.clone()
|
||||
sampler_indices_padded = torch.where(
|
||||
sampler_indices_padded == -1, max_loras - 1, sampler_indices_padded
|
||||
)
|
||||
sampler_indices_padded = torch.arange(
|
||||
0, len(sampler_indices_padded), device=device, dtype=torch.long
|
||||
) + (sampler_indices_padded * len(sampler_indices_padded))
|
||||
|
||||
# Contain length of indices tensors. Used to index into each tensor.
|
||||
indices_len = [
|
||||
base_indices.shape[-1],
|
||||
sampler_indices.shape[-1],
|
||||
sampler_indices_padded.shape[-1],
|
||||
embeddings_indices.shape[-1],
|
||||
]
|
||||
|
||||
return (
|
||||
base_indices,
|
||||
sampler_indices,
|
||||
sampler_indices_padded,
|
||||
embeddings_indices,
|
||||
indices_len,
|
||||
)
|
||||
Reference in New Issue
Block a user