from typing import List, Optional

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm.config import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_reduce)
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                              RowParallelLinearWithLoRA,
                              LinearScalingRotaryEmbeddingWithLora,
                              apply_bias)
from vllm_mlu.model_executor.layers.rotary_embedding import (
    MLURotaryEmbedding, MLULinearScalingRotaryEmbedding)
from vllm_mlu.mlu_hijack_utils import MluHijackObject
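

# This module hot-patches vLLM's LoRA layers with MLU-aware replacements via
# MluHijackObject.apply_hijack (registrations at the bottom of the file):
# ColumnParallelLinearWithLoRA.forward gains a smooth_quant_scale parameter,
# RowParallelLinearWithLoRA fuses bias and residual into the matmul, and
# LinearScalingRotaryEmbeddingWithLora is rebased onto the MLU rotary
# embedding implementations.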
vllm__lora__layers__ColumnParallelLinearWithLoRA__forward_org = ColumnParallelLinearWithLoRA.forward
'''
=============================
Modify by vllm_mlu
=============================
@brief: add smooth_quant_scale parameter.
'''
def vllm__lora__layers__ColumnParallelLinearWithLoRA__forward(
    self,
    input_,
    smooth_quant_scale: Optional[torch.Tensor] = None
):
    assert smooth_quant_scale is None, "LoRA does not support smooth quant yet."
    return vllm__lora__layers__ColumnParallelLinearWithLoRA__forward_org(self, input_)
'''
==================
End of MLU Hijack
==================
'''
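

# On MLU, the base layer's quant_method.apply takes bias and residual so both
# are fused into the matmul itself; the stacked LoRA bias and the punica
# add_lora delta are then applied on top of that fused output.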
def vllm__lora__layers__RowParallelLinearWithLoRA__apply(
    self,
    x: torch.Tensor,
    bias: Optional[torch.Tensor],
    residual: Optional[torch.Tensor]
) -> torch.Tensor:
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: add residual and bias in matmul
    '''
    output = self.base_layer.quant_method.apply(
        self.base_layer, x, bias, residual)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    if self.bias_stacked is not None:
        self.indices = self.punica_wrapper.token_lora_indices
        output = apply_bias(
            self.indices,
            output,
            self.bias_stacked,
        )
    self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
                                 self.lora_b_stacked, 1.0)
    return output
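

# Bias and residual are handed to apply() above and fused into the matmul
# there, so this forward only shards the input and performs the
# tensor-parallel all-reduce; the bias is no longer added after the
# all-reduce, and output_bias is returned only when skip_bias_add is set.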
def vllm__lora__layers__RowParallelLinearWithLoRA__forward(
    self,
    input_: torch.Tensor,
    residual: Optional[torch.Tensor] = None
):
    # Set up backprop all-reduce.
    if self.base_layer.input_is_parallel:
        input_parallel = input_
    else:
        # TODO: simplify code below
        tp_rank = get_tensor_model_parallel_rank()
        splitted_input = split_tensor_along_last_dim(
            input_, num_partitions=self.base_layer.tp_size)
        input_parallel = splitted_input[tp_rank].contiguous()

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: 1) apply residual fusion in matmul like RowParallelLinear
            2) add bias in matmul, not after all reduce
    '''
    # Matrix multiply.
    bias_ = (None if (self.base_layer.tp_rank > 0
                      or self.base_layer.skip_bias_add)
             else self.base_layer.bias)
    residual_ = None if self.base_layer.tp_rank > 0 else residual
    output_parallel = self.apply(input_parallel, bias_, residual_)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
        output = tensor_model_parallel_all_reduce(output_parallel)
    else:
        output = output_parallel
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: do not add bias after all_reduce
    '''
    output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    return output, output_bias
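

# create_lora_weights rebuilds the wrapped rope as an
# MLULinearScalingRotaryEmbedding so the long-LoRA scaling factors are served
# by the MLU rotary embedding rather than the stock
# LinearScalingRotaryEmbedding.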
def vllm__lora__layers__LinearScalingRotaryEmbeddingWithLora__create_lora_weights(
    self,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> None:
    scaling_factors = (list(lora_config.long_lora_scaling_factors)
                       if lora_config.long_lora_scaling_factors else [])
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: change LinearScalingRotaryEmbedding to MLULinearScalingRotaryEmbedding
    '''
    base_scaling_factor = (self.base_layer.scaling_factor if isinstance(
        self.base_layer, MLULinearScalingRotaryEmbedding) else 1.0)
    scaling_factors = sorted(
        list(set([base_scaling_factor] + scaling_factors)))
    self.base_layer = MLULinearScalingRotaryEmbedding(
        self.base_layer.head_size,
        self.base_layer.rotary_dim,
        self.base_layer.max_position_embeddings,
        self.base_layer.base,
        self.base_layer.is_neox_style,
        scaling_factors,
        self.base_layer.dtype,
    )
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
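

# The MLU rope kernel consumes query and key packed into a single qk tensor,
# so this forward takes (positions, qk) to match forward_mlu and passes the
# long-LoRA offsets through to the base layer.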
def vllm__lora__layers__LinearScalingRotaryEmbeddingWithLora__forward(
    self,
    positions: torch.Tensor,
    qk: torch.Tensor
) -> torch.Tensor:
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: change function prototype to meet forward_mlu in rope
    '''
    return self.base_layer(
        positions,
        qk,
        offsets=self.punica_wrapper.long_lora_indices,
    )
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
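

# The replacement check accepts the MLU rope classes, on the assumption that
# vllm_mlu swaps the stock vLLM rotary embeddings for MLURotaryEmbedding /
# MLULinearScalingRotaryEmbedding before LoRA layer replacement runs.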
@classmethod
def vllm__lora__layers__LinearScalingRotaryEmbeddingWithLora__can_replace_layer(
    cls,
    source_layer: nn.Module,
    lora_config: LoRAConfig,
    packed_modules_list: List,
    model_config: Optional[PretrainedConfig],
) -> bool:
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: change origin rope type to mlu rope
    '''
    return (type(source_layer) is MLULinearScalingRotaryEmbedding
            or type(source_layer) is MLURotaryEmbedding)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
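

# The calls below install the MLU-aware replacements at import time. A
# minimal sketch of what MluHijackObject.apply_hijack is assumed to do (the
# real helper lives in vllm_mlu.mlu_hijack_utils and may also record the
# original attribute so the patch can be reverted):
#
#     def apply_hijack(cls, org_attr, new_attr):
#         setattr(cls, org_attr.__name__, new_attr)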
MluHijackObject.apply_hijack(RowParallelLinearWithLoRA,
                             RowParallelLinearWithLoRA.apply,
                             vllm__lora__layers__RowParallelLinearWithLoRA__apply)
MluHijackObject.apply_hijack(ColumnParallelLinearWithLoRA,
                             ColumnParallelLinearWithLoRA.forward,
                             vllm__lora__layers__ColumnParallelLinearWithLoRA__forward)
MluHijackObject.apply_hijack(RowParallelLinearWithLoRA,
                             RowParallelLinearWithLoRA.forward,
                             vllm__lora__layers__RowParallelLinearWithLoRA__forward)
MluHijackObject.apply_hijack(LinearScalingRotaryEmbeddingWithLora,
                             LinearScalingRotaryEmbeddingWithLora.create_lora_weights,
                             vllm__lora__layers__LinearScalingRotaryEmbeddingWithLora__create_lora_weights)
MluHijackObject.apply_hijack(LinearScalingRotaryEmbeddingWithLora,
                             LinearScalingRotaryEmbeddingWithLora.forward,
                             vllm__lora__layers__LinearScalingRotaryEmbeddingWithLora__forward)
MluHijackObject.apply_hijack(LinearScalingRotaryEmbeddingWithLora,
                             LinearScalingRotaryEmbeddingWithLora.can_replace_layer,
                             vllm__lora__layers__LinearScalingRotaryEmbeddingWithLora__can_replace_layer)