# enginex-mlu370-vllm/vllm-v0.6.2/vllm_mlu/vllm_mlu/lora/layers.py
from typing import List, Optional

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm.config import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_reduce)
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                              RowParallelLinearWithLoRA,
                              LinearScalingRotaryEmbeddingWithLora,
                              apply_bias)

from vllm_mlu.model_executor.layers.rotary_embedding import (
    MLURotaryEmbedding, MLULinearScalingRotaryEmbedding)
from vllm_mlu.mlu_hijack_utils import MluHijackObject

vllm__lora__layers__ColumnParallelLinearWithLoRA__forward_org = ColumnParallelLinearWithLoRA.forward
'''
=============================
Modify by vllm_mlu
=============================
@brief: add smooth_quant_scale parameter.
'''
def vllm__lora__layers__ColumnParallelLinearWithLoRA__forward(
    self,
    input_,
    smooth_quant_scale: Optional[torch.Tensor] = None
):
    assert smooth_quant_scale is None, "LoRA does not support smooth quant yet."
    return vllm__lora__layers__ColumnParallelLinearWithLoRA__forward_org(
        self, input_)
'''
==================
End of MLU Hijack
==================
'''
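
# A minimal usage sketch for the widened forward() above (assumption: some MLU
# quantized-linear path forwards an extra smooth_quant_scale argument, which is
# not shown in this file; `lora_column_layer` and `scale` are illustrative
# names only):
#
#   out = lora_column_layer.forward(hidden_states)         # scale defaults to None
#   out = lora_column_layer.forward(hidden_states, scale)  # rejected by the assert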

def vllm__lora__layers__RowParallelLinearWithLoRA__apply(
    self,
    x: torch.Tensor,
    bias: Optional[torch.Tensor],
    residual: Optional[torch.Tensor]
) -> torch.Tensor:
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: add residual and bias in matmul
    '''
    output = self.base_layer.quant_method.apply(
        self.base_layer, x, bias, residual)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    if self.bias_stacked is not None:
        self.indices = self.punica_wrapper.token_lora_indices
        output = apply_bias(
            self.indices,
            output,
            self.bias_stacked,
        )
    self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
                                 self.lora_b_stacked, 1.0)
    return output
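
# In effect (a sketch of the fused path above; it assumes the MLU
# quant_method.apply folds bias and residual into the matmul rather than
# adding them afterwards, as the @brief states):
#
#   output = x @ W_shard.T (+ bias) (+ residual)   # fused in quant_method.apply
#   output += (x @ lora_a) @ lora_b * 1.0          # added by punica_wrapper.add_lora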

def vllm__lora__layers__RowParallelLinearWithLoRA__forward(
    self,
    input_: torch.Tensor,
    residual: Optional[torch.Tensor] = None
):
    # Set up backprop all-reduce.
    if self.base_layer.input_is_parallel:
        input_parallel = input_
    else:
        # TODO: simplify code below
        tp_rank = get_tensor_model_parallel_rank()
        splitted_input = split_tensor_along_last_dim(
            input_, num_partitions=self.base_layer.tp_size)
        input_parallel = splitted_input[tp_rank].contiguous()
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: 1) apply residual fusion in matmul like RowParallelLinear
            2) add bias in matmul, not after all reduce
    '''
    # Matrix multiply.
    bias_ = (None if (self.base_layer.tp_rank > 0
                      or self.base_layer.skip_bias_add)
             else self.base_layer.bias)
    residual_ = None if self.base_layer.tp_rank > 0 else residual
    output_parallel = self.apply(input_parallel, bias_, residual_)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
        output = tensor_model_parallel_all_reduce(output_parallel)
    else:
        output = output_parallel
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: do not add bias after all_reduce
    '''
    output_bias = (self.base_layer.bias
                   if self.base_layer.skip_bias_add else None)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    return output, output_bias
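
# Reasoning behind gating bias_/residual_ on tp_rank == 0 above: each rank
# produces a partial matmul result that tensor_model_parallel_all_reduce sums
# across tp_size ranks, so a bias or residual added on every rank would be
# counted tp_size times. Adding it only on rank 0 keeps the post-reduce result
# equivalent to the upstream path that adds the bias after the all-reduce.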

def vllm__lora__layers__LinearScalingRotaryEmbeddingWithLora__create_lora_weights(
    self,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> None:
    scaling_factors = (list(lora_config.long_lora_scaling_factors)
                       if lora_config.long_lora_scaling_factors else [])
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: change LinearScalingRotaryEmbedding to MLULinearScalingRotaryEmbedding
    '''
    base_scaling_factor = (self.base_layer.scaling_factor if isinstance(
        self.base_layer, MLULinearScalingRotaryEmbedding) else 1.0)
    scaling_factors = sorted(
        list(set([base_scaling_factor] + scaling_factors)))
    self.base_layer = MLULinearScalingRotaryEmbedding(
        self.base_layer.head_size,
        self.base_layer.rotary_dim,
        self.base_layer.max_position_embeddings,
        self.base_layer.base,
        self.base_layer.is_neox_style,
        scaling_factors,
        self.base_layer.dtype,
    )
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
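
# Illustrative example of the scaling-factor merge in create_lora_weights
# above (values are made up): with base_scaling_factor == 1.0 and
# lora_config.long_lora_scaling_factors == (4.0, 8.0), the rebuilt
# MLULinearScalingRotaryEmbedding receives scaling_factors == [1.0, 4.0, 8.0].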

def vllm__lora__layers__LinearScalingRotaryEmbeddingWithLora__forward(
    self,
    positions: torch.Tensor,
    qk: torch.Tensor
) -> torch.Tensor:
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: change function prototype to meet forward_mlu in rope
    '''
    return self.base_layer(
        positions,
        qk,
        offsets=self.punica_wrapper.long_lora_indices,
    )
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
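
# Note on the prototype change above: the upstream LoRA wrapper takes separate
# query and key tensors, whereas this variant passes a single packed qk tensor,
# presumably to match forward_mlu of the MLU rotary embedding (an assumption
# based on the @brief; the MLU kernel itself is not shown in this file).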

@classmethod
def vllm__lora__layers__LinearScalingRotaryEmbeddingWithLora__can_replace_layer(
    cls,
    source_layer: nn.Module,
    lora_config: LoRAConfig,
    packed_modules_list: List,
    model_config: Optional[PretrainedConfig],
) -> bool:
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: change origin rope type to mlu rope
    '''
    return (type(source_layer) is MLULinearScalingRotaryEmbedding
            or type(source_layer) is MLURotaryEmbedding)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''


MluHijackObject.apply_hijack(RowParallelLinearWithLoRA,
                             RowParallelLinearWithLoRA.apply,
                             vllm__lora__layers__RowParallelLinearWithLoRA__apply)
MluHijackObject.apply_hijack(ColumnParallelLinearWithLoRA,
                             ColumnParallelLinearWithLoRA.forward,
                             vllm__lora__layers__ColumnParallelLinearWithLoRA__forward)
MluHijackObject.apply_hijack(RowParallelLinearWithLoRA,
                             RowParallelLinearWithLoRA.forward,
                             vllm__lora__layers__RowParallelLinearWithLoRA__forward)
MluHijackObject.apply_hijack(LinearScalingRotaryEmbeddingWithLora,
                             LinearScalingRotaryEmbeddingWithLora.create_lora_weights,
                             vllm__lora__layers__LinearScalingRotaryEmbeddingWithLora__create_lora_weights)
MluHijackObject.apply_hijack(LinearScalingRotaryEmbeddingWithLora,
                             LinearScalingRotaryEmbeddingWithLora.forward,
                             vllm__lora__layers__LinearScalingRotaryEmbeddingWithLora__forward)
MluHijackObject.apply_hijack(LinearScalingRotaryEmbeddingWithLora,
                             LinearScalingRotaryEmbeddingWithLora.can_replace_layer,
                             vllm__lora__layers__LinearScalingRotaryEmbeddingWithLora__can_replace_layer)
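
# Based on its usage above, MluHijackObject.apply_hijack(cls, original, repl)
# appears to rebind the corresponding attribute on `cls` so that, for example,
# RowParallelLinearWithLoRA.forward resolves to the vllm_mlu variant while the
# original stays reachable through the *_org module-level reference. A rough,
# assumed equivalent of one registration (not the actual implementation):
#
#   setattr(RowParallelLinearWithLoRA, "forward",
#           vllm__lora__layers__RowParallelLinearWithLoRA__forward)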