enginex-mlu590-vllm/vllm_mlu/lora/layers/column_parallel_linear.py

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from vllm.lora.layers.column_parallel_linear import ColumnParallelLinearWithLoRA
from vllm_mlu.mlu_hijack_utils import MluHijackObject
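# Keep a handle on the upstream forward so the MLU wrapper below can delegate to it.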
vllm__lora__layers__column_parallel_linear__ColumnParallelLinearWithLoRA__forward_org = ColumnParallelLinearWithLoRA.forward
'''
=============================
Modified by vllm_mlu
=============================
@brief: Add smooth_quant_scale and use_tp_weight parameters to forward.
'''
def vllm__lora__layers__column_parallel_linear__ColumnParallelLinearWithLoRA__forward(
    self,
    input_: torch.Tensor,
    smooth_quant_scale: torch.Tensor | None = None,
    use_tp_weight: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
    # The MLU-specific kwargs are accepted for API compatibility only;
    # neither code path is implemented for LoRA layers yet.
    assert not use_tp_weight, "LoRA does not support use_tp_weight yet."
    assert smooth_quant_scale is None, "LoRA does not support smooth quant yet."
    return vllm__lora__layers__column_parallel_linear__ColumnParallelLinearWithLoRA__forward_org(
        self, input_)
'''
==================
End of MLU Hijack
==================
'''
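# Register the wrapper with the hijack utility so it replaces
# ColumnParallelLinearWithLoRA.forward.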
MluHijackObject.apply_hijack(
    ColumnParallelLinearWithLoRA,
    ColumnParallelLinearWithLoRA.forward,
    vllm__lora__layers__column_parallel_linear__ColumnParallelLinearWithLoRA__forward,
)
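# A minimal sketch of what MluHijackObject.apply_hijack is assumed to do
# (the real utility may also track patches so they can be restored later):
#
#     def apply_hijack(target_cls, org_fn, new_fn):
#         setattr(target_cls, org_fn.__name__, new_fn)
#
# Once applied, MLU callers can pass the extra keyword arguments without
# breaking the unquantized LoRA path:
#
#     out = layer.forward(x)                           # delegates upstream
#     out = layer.forward(x, smooth_quant_scale=None)  # also fine
#     layer.forward(x, use_tp_weight=True)             # raises AssertionError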