116 lines
2.8 KiB
Python
116 lines
2.8 KiB
Python
from typing import Optional
|
|
|
|
import torch
|
|
|
|
from vllm.lora.punica import PunicaWrapper
|
|
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
|
|
|
from vllm_mlu.lora.ops.sgmv_expand import sgmv_expand_mlu
|
|
from vllm_mlu.lora.ops.sgmv_expand_slice import sgmv_expand_slice_mlu
|
|
from vllm_mlu.lora.ops.sgmv_shrink import sgmv_shrink_mlu
|
|
|
|
|
|
def vllm__lora__punica__PunicaWrapper__shrink_prefill(
    self,
    y: torch.Tensor,
    x: torch.Tensor,
    w_t_all: torch.Tensor,
    scale: float,
):
    """MLU replacement for ``PunicaWrapper.shrink_prefill``.

    Forwards the prefill LoRA "shrink" step to ``sgmv_shrink_mlu``. The
    call is skipped entirely when the wrapper reports no LoRA request.
    Always returns ``None``.

    Args:
        self: The ``PunicaWrapper`` instance (this function is installed on
            the class via ``MluHijackObject.apply_hijack``).
        y: Passed to the kernel as its third argument (presumably the
            output buffer written in place — TODO confirm).
        x: Passed to the kernel as its first argument.
        w_t_all: Passed to the kernel as its second argument.
        scale: Passed to the kernel as its final argument.
    """
    # With no LoRA request there is no work; prefill_metadata is not touched.
    if not self.no_lora:
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: Change function from sgmv_shrink to sgmv_shrink_mlu.
        '''
        # Kernel argument order is (x, w_t_all, y, ...), not the (y, x, ...)
        # order of this function's signature. The fields of prefill_metadata
        # are spliced in between y and scale.
        kernel_args = (x, w_t_all, y, *self.prefill_metadata, scale)
        sgmv_shrink_mlu(*kernel_args)
        '''
        ==================
        End of MLU Hijack
        ==================
        '''
|
|
|
|
|
|
def vllm__lora__punica__PunicaWrapper__expand_prefill(
    self,
    y: torch.Tensor,
    x: torch.Tensor,
    w_t_all: torch.Tensor,
    add_input: bool,
):
    """MLU replacement for ``PunicaWrapper.expand_prefill``.

    Forwards the prefill LoRA "expand" step to ``sgmv_expand_mlu``. The
    call is skipped entirely when the wrapper reports no LoRA request.
    Always returns ``None``.

    Args:
        self: The ``PunicaWrapper`` instance (this function is installed on
            the class via ``MluHijackObject.apply_hijack``).
        y: Passed to the kernel as its third argument (presumably the
            output buffer written in place — TODO confirm).
        x: Passed to the kernel as its first argument.
        w_t_all: Passed to the kernel as its second argument.
        add_input: Passed to the kernel as its final argument (presumably
            selects accumulate-into-``y`` vs. overwrite — TODO confirm).
    """
    # With no LoRA request there is no work; prefill_metadata is not touched.
    if not self.no_lora:
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: Change function from sgmv_expand to sgmv_expand_mlu.
        '''
        # Kernel argument order is (x, w_t_all, y, ...), not the (y, x, ...)
        # order of this function's signature. The fields of prefill_metadata
        # are spliced in between y and add_input.
        kernel_args = (x, w_t_all, y, *self.prefill_metadata, add_input)
        sgmv_expand_mlu(*kernel_args)
        '''
        ==================
        End of MLU Hijack
        ==================
        '''
|
|
|
|
|
|
def vllm__lora__punica__PunicaWrapper__expand_slice_prefill(
    self,
    y: torch.Tensor,
    x: torch.Tensor,
    w_t_all: torch.Tensor,
    y_offset: Optional[int],
    y_slice_size: Optional[int],
    add_input: bool,
):
    """MLU replacement for ``PunicaWrapper.expand_slice_prefill``.

    Forwards the prefill LoRA sliced "expand" step to
    ``sgmv_expand_slice_mlu``. The call is skipped entirely when the
    wrapper reports no LoRA request. Always returns ``None``.

    Args:
        self: The ``PunicaWrapper`` instance (this function is installed on
            the class via ``MluHijackObject.apply_hijack``).
        y: Passed to the kernel as its third argument (presumably the
            output buffer written in place — TODO confirm).
        x: Passed to the kernel as its first argument.
        w_t_all: Passed to the kernel as its second argument.
        y_offset: Forwarded after the prefill metadata (presumably the start
            of the slice of ``y`` to write — TODO confirm).
        y_slice_size: Forwarded after ``y_offset`` (presumably the width of
            that slice — TODO confirm).
        add_input: Passed to the kernel as its final argument (presumably
            selects accumulate-into-``y`` vs. overwrite — TODO confirm).
    """
    # With no LoRA request there is no work; prefill_metadata is not touched.
    if not self.no_lora:
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: Change function from sgmv_expand_slice to sgmv_expand_slice_mlu.
        '''
        # Kernel argument order is (x, w_t_all, y, ...), not the (y, x, ...)
        # order of this function's signature. The fields of prefill_metadata
        # are spliced in before the slice parameters.
        slice_args = (y_offset, y_slice_size, add_input)
        kernel_args = (x, w_t_all, y, *self.prefill_metadata, *slice_args)
        sgmv_expand_slice_mlu(*kernel_args)
        '''
        ==================
        End of MLU Hijack
        ==================
        '''
|
|
|
|
|
|
# Install the MLU replacements defined in this module onto PunicaWrapper.
# The hijacks are applied in the listed order, and each original method is
# looked up on the class immediately before its own hijack is applied.
for _attr_name, _mlu_impl in (
    ("shrink_prefill", vllm__lora__punica__PunicaWrapper__shrink_prefill),
    ("expand_prefill", vllm__lora__punica__PunicaWrapper__expand_prefill),
    ("expand_slice_prefill",
     vllm__lora__punica__PunicaWrapper__expand_slice_prefill),
):
    MluHijackObject.apply_hijack(PunicaWrapper,
                                 getattr(PunicaWrapper, _attr_name),
                                 _mlu_impl)
# Drop the loop temporaries so they do not linger in the module namespace.
del _attr_name, _mlu_impl
|