Files
enginex-mlu370-vllm/vllm-v0.6.2/vllm_mlu/vllm_mlu/lora/punica.py

116 lines
2.8 KiB
Python
Raw Normal View History

2026-02-04 17:22:39 +08:00
from typing import Optional
import torch
from vllm.lora.punica import PunicaWrapper
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu.lora.ops.sgmv_expand import sgmv_expand_mlu
from vllm_mlu.lora.ops.sgmv_expand_slice import sgmv_expand_slice_mlu
from vllm_mlu.lora.ops.sgmv_shrink import sgmv_shrink_mlu
def vllm__lora__punica__PunicaWrapper__shrink_prefill(
    self,
    y: torch.Tensor,
    x: torch.Tensor,
    w_t_all: torch.Tensor,
    scale: float,
):
    """MLU replacement for ``PunicaWrapper.shrink_prefill``.

    Forwards the prefill LoRA shrink step to ``sgmv_shrink_mlu``. It passes
    ``x``, ``w_t_all`` and ``y``, then the unpacked entries of
    ``self.prefill_metadata``, and finally ``scale``.

    When ``self.no_lora`` is truthy (no LoRA request), nothing is called and
    ``self.prefill_metadata`` is not read. Always returns ``None``.
    """
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: Change function from sgmv_shrink to sgmv_shrink_mlu.
    '''
    if not self.no_lora:
        # NOTE(review): y sits in the output position and is presumably written
        # in place by the kernel -- confirm against sgmv_shrink_mlu.
        sgmv_shrink_mlu(x, w_t_all, y, *self.prefill_metadata, scale)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
def vllm__lora__punica__PunicaWrapper__expand_prefill(
    self,
    y: torch.Tensor,
    x: torch.Tensor,
    w_t_all: torch.Tensor,
    add_input: bool,
):
    """MLU replacement for ``PunicaWrapper.expand_prefill``.

    Forwards the prefill LoRA expand step to ``sgmv_expand_mlu``. It passes
    ``x``, ``w_t_all`` and ``y``, then the unpacked entries of
    ``self.prefill_metadata``, and finally ``add_input``.

    When ``self.no_lora`` is truthy (no LoRA request), nothing is called and
    ``self.prefill_metadata`` is not read. Always returns ``None``.
    """
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: Change function from sgmv_expand to sgmv_expand_mlu.
    '''
    if not self.no_lora:
        # NOTE(review): add_input presumably chooses accumulate-into-y versus
        # overwrite; the actual semantics live in sgmv_expand_mlu -- confirm.
        sgmv_expand_mlu(x, w_t_all, y, *self.prefill_metadata, add_input)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
def vllm__lora__punica__PunicaWrapper__expand_slice_prefill(
    self,
    y: torch.Tensor,
    x: torch.Tensor,
    w_t_all: torch.Tensor,
    y_offset: Optional[int],
    y_slice_size: Optional[int],
    add_input: bool,
):
    """MLU replacement for ``PunicaWrapper.expand_slice_prefill``.

    Forwards the prefill LoRA sliced-expand step to ``sgmv_expand_slice_mlu``.
    It passes ``x``, ``w_t_all`` and ``y``, then the unpacked entries of
    ``self.prefill_metadata``, and finally ``y_offset``, ``y_slice_size`` and
    ``add_input``.

    When ``self.no_lora`` is truthy (no LoRA request), nothing is called and
    ``self.prefill_metadata`` is not read. Always returns ``None``.
    """
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: Change function from sgmv_expand_slice to sgmv_expand_slice_mlu.
    '''
    if not self.no_lora:
        # NOTE(review): y_offset / y_slice_size presumably select the window
        # of y that receives the result -- confirm in sgmv_expand_slice_mlu.
        sgmv_expand_slice_mlu(
            x, w_t_all, y, *self.prefill_metadata,
            y_offset, y_slice_size, add_input,
        )
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
# Register the MLU replacements on vLLM's PunicaWrapper. Each entry pairs the
# name of an upstream PunicaWrapper method with its MLU replacement defined in
# this module. Entries are applied in order, and each upstream method is
# looked up immediately before its own hijack is applied.
_PUNICA_PREFILL_HIJACKS = (
    ("shrink_prefill",
     vllm__lora__punica__PunicaWrapper__shrink_prefill),
    ("expand_prefill",
     vllm__lora__punica__PunicaWrapper__expand_prefill),
    ("expand_slice_prefill",
     vllm__lora__punica__PunicaWrapper__expand_slice_prefill),
)
for _method_name, _mlu_impl in _PUNICA_PREFILL_HIJACKS:
    MluHijackObject.apply_hijack(PunicaWrapper,
                                 getattr(PunicaWrapper, _method_name),
                                 _mlu_impl)