89 lines
2.2 KiB
Python
89 lines
2.2 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
|
|
|
"""
|
|
Based on:
|
|
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
|
Punica: Multi-Tenant LoRA Serving.
|
|
https://arxiv.org/abs/2310.18547
|
|
"""
|
|
|
|
from typing import Optional, Tuple, Union, final
|
|
|
|
import torch
|
|
|
|
from vllm.triton_utils import HAS_TRITON
|
|
|
|
if HAS_TRITON:
|
|
from vllm_mlu.lora.ops.triton_ops import sgmv_expand_mlu
|
|
from vllm_mlu.lora.ops.triton_ops import sgmv_expand_slice_mlu
|
|
from vllm_mlu.lora.ops.triton_ops import sgmv_shrink_mlu
|
|
|
|
from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU
|
|
|
|
|
|
@final
class PunicaWrapperMLU(PunicaWrapperCPU):
    """
    Punica wrapper for MLU devices.

    Maintains the Multi-LoRA state/metadata consumed by the punica kernels
    and exposes the interface used to launch them. Only the three prefill
    (SGMV) entry points are overridden here, each dispatching to the MLU
    triton kernel of the same role; every other method is inherited
    unchanged from ``PunicaWrapperCPU``.
    """

    def _shrink_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        """Launch the MLU SGMV shrink kernel for the prefill stage.

        Forwards ``(x, w_t_all, y, *self.prefill_metadata, scale)``
        positionally to ``sgmv_shrink_mlu``. ``y`` is handed to the kernel
        as its output argument (presumably filled in place -- TODO confirm
        against the kernel). Returns ``None`` and does nothing when there is
        no LoRA request.
        """
        # The guard runs before the kernel name is looked up, so the no-LoRA
        # path never touches the (conditionally imported) triton kernel.
        if not self.no_lora:
            sgmv_shrink_mlu(x, w_t_all, y, *self.prefill_metadata, scale)

    def _expand_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_inputs: bool,
    ):
        """Launch the MLU SGMV expand kernel for the prefill stage.

        Forwards ``(x, w_t_all, y, *self.prefill_metadata, add_inputs)``
        positionally to ``sgmv_expand_mlu``. ``add_inputs`` presumably
        selects accumulate-into-``y`` versus overwrite (verify against the
        kernel). Returns ``None`` and does nothing when there is no LoRA
        request.
        """
        # Same ordering as the shrink path: check for LoRA work first.
        if not self.no_lora:
            sgmv_expand_mlu(x, w_t_all, y, *self.prefill_metadata, add_inputs)

    def _expand_slice_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: int,
        y_slice_size: int,
        add_inputs: bool,
    ):
        """Launch the MLU SGMV expand-slice kernel for the prefill stage.

        Forwards ``(x, w_t_all, y, *self.prefill_metadata, y_offset,
        y_slice_size, add_inputs)`` positionally to
        ``sgmv_expand_slice_mlu``. ``y_offset``/``y_slice_size`` presumably
        select the slice of ``y`` the kernel writes (TODO confirm). Returns
        ``None`` and does nothing when there is no LoRA request.
        """
        # Same ordering as the other prefill paths: check for LoRA work first.
        if not self.no_lora:
            sgmv_expand_slice_mlu(
                x,
                w_t_all,
                y,
                *self.prefill_metadata,
                y_offset,
                y_slice_size,
                add_inputs,
            )