Files
enginex-mlu590-vllm/vllm_mlu/lora/punica_wrapper/punica_mlu.py
2026-04-24 09:58:03 +08:00

89 lines
2.2 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import Optional, Tuple, Union, final
import torch
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
from vllm_mlu.lora.ops.triton_ops import sgmv_expand_mlu
from vllm_mlu.lora.ops.triton_ops import sgmv_expand_slice_mlu
from vllm_mlu.lora.ops.triton_ops import sgmv_shrink_mlu
from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU
@final
class PunicaWrapperMLU(PunicaWrapperCPU):
    """
    PunicaWrapperMLU is designed to manage and provide metadata for the punica
    kernel. The main function is to maintain the state information for
    Multi-LoRA, and to provide the interface for the punica triton kernel.
    """

    def _shrink_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        # Fast path: nothing to do when no LoRA adapters are active.
        if self.no_lora:
            return
        # Unpack the sgmv prefill metadata between the tensors and the scale.
        metadata = self.prefill_metadata
        sgmv_shrink_mlu(x, w_t_all, y, *metadata, scale)

    def _expand_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_inputs: bool,
    ):
        # Fast path: nothing to do when no LoRA adapters are active.
        if self.no_lora:
            return
        metadata = self.prefill_metadata
        sgmv_expand_mlu(x, w_t_all, y, *metadata, add_inputs)

    def _expand_slice_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: int,
        y_slice_size: int,
        add_inputs: bool,
    ):
        # Fast path: nothing to do when no LoRA adapters are active.
        if self.no_lora:
            return
        # The slice bounds trail the shared prefill metadata in the kernel's
        # argument order.
        metadata = self.prefill_metadata
        sgmv_expand_slice_mlu(
            x, w_t_all, y, *metadata, y_offset, y_slice_size, add_inputs
        )