# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Any
import torch
from torch.nn.parameter import Parameter
from vllm.distributed import (divide, split_tensor_along_last_dim,
get_parallel_rank_with_group, get_parallel_world_size_with_group,
get_tp_world_group, get_tp_world_world_size, get_tp_world_rank)
from vllm.distributed.communication_op import (
tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.linear import (
WEIGHT_LOADER_V2_SUPPORTED, UnquantizedLinearMethod, LinearBase,
ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear)
from vllm.model_executor.utils import set_weight_attrs
from vllm.logger import init_logger
from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantLinearMethod
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu import _mlu_ops as mlu_ops
logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED.extend([
"GPTQMluLinearMethod",
"AWQMluLinearMethod"
])
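# Registering the MLU GPTQ/AWQ linear methods here makes vLLM route their
# layers through weight_loader_v2 (see the weight_loader selection passed to
# create_weights below).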
vllm__module_executor__layers__linear__LinearBase____init__org = LinearBase.__init__
vllm__module_executor__layers__linear__MergedColumnParallelLinear__weight_loader_org = MergedColumnParallelLinear.weight_loader
vllm__module_executor__layers__linear__RowParallelLinear__weight_loader_org = RowParallelLinear.weight_loader
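# Keep references to the original vLLM implementations so the hijacked
# versions below can delegate to them.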
'''
=============================
Modify by vllm_mlu
=============================
@brief: add a residual parameter.
@brief: dispatch unquantized_gemm to MLU ops.
'''
def vllm__module_executor__layers__linear__UnquantizedLinearMethod__apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
residual: torch.Tensor | None = None
) -> torch.Tensor:
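    # Flatten the input to 2D and run the GEMM on the MLU device. The trailing
    # positional arguments ('none', 1.0, beta) are assumed to map to the op's
    # activation, alpha and beta, i.e. out = x @ W.T + bias + beta * residual,
    # so beta toggles the fused residual add.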
beta = 0.0
if residual is not None:
beta = 1.0
residual = residual.view(-1, residual.shape[-1])
res_shape = x.shape[0:-1] + (layer.weight.shape[0], )
return mlu_ops.matmul(x.reshape(x.numel() // x.shape[-1], x.shape[-1]),
layer.weight,
bias, residual, 'none', 1.0, beta).view(res_shape)
'''
==================
End of MLU Hijack
==================
'''
'''
=============================
Modify by vllm_mlu
=============================
@brief: add tp_group and keep_full_weights parameters.
'''
def vllm__module_executor__layers__linear__LinearBase____init__(
self,
input_size: int,
output_size: int,
skip_bias_add: bool = False,
params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
*,
tp_group: Any = None,
keep_full_weights: bool = False,
return_bias: bool = True,
disable_tp: bool = False,
):
vllm__module_executor__layers__linear__LinearBase____init__org(
self=self,
input_size=input_size,
output_size=output_size,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
return_bias=return_bias,
disable_tp=disable_tp)
'''
=============================
Modify by vllm_mlu
=============================
    @brief: add self.tp_group, self.tp_world_size and self.tp_rank to support data parallelism and MoE expert parallelism.
'''
self.tp_group = tp_group
self.tp_world_size = get_parallel_world_size_with_group(self.tp_group)
self.tp_size = self.tp_world_size
self.tp_rank = get_parallel_rank_with_group(self.tp_group)
self.keep_full_weights = keep_full_weights
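    # keep_full_weights (used for the DP-replicated shared expert) makes the
    # layer behave as if TP were disabled: every rank keeps the full weights,
    # while tp_world_size_org / tp_rank_org record the real TP topology so the
    # auxiliary tp_* shards created by the subclasses can still be sliced per rank.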
if self.keep_full_weights or disable_tp:
self.tp_group = None
self.tp_world_size = 1
self.tp_size = self.tp_world_size
self.tp_rank = 0
self.tp_world_size_org = get_tp_world_world_size()
self.tp_rank_org = get_tp_world_rank()
'''
=================
End of MLU Hijack
=================
'''
'''
=================
End of MLU Hijack
=================
'''
'''
=============================
Modify by vllm_mlu
=============================
@brief: add tp_group and keep_full_weights parameters.
'''
def vllm__module_executor__layers__linear__ColumnParallelLinear____init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None,
output_sizes: list[int] | None = None,
prefix: str = "",
*,
tp_group: Any = None,
keep_full_weights: bool = False,
return_bias: bool = True,
disable_tp: bool = False,
):
super(ColumnParallelLinear, self).__init__(
input_size,
output_size,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
tp_group=tp_group,
keep_full_weights=keep_full_weights,
return_bias=return_bias,
disable_tp=disable_tp,
)
'''
=============================
Modify by vllm_mlu
=============================
    @brief: self.tp_size and self.tp_rank have already been initialized in LinearBase.__init__.
'''
# Divide the weight matrix along the last dimension.
# self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
# self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
'''
=================
End of MLU Hijack
=================
'''
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
# If QKV or MergedColumn, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = [
divide(output_size, self.tp_size) for output_size in self.output_sizes
]
self.gather_output = gather_output
if output_sizes is None:
output_sizes = [output_size]
'''
=============================
Modify by vllm_mlu
=============================
@brief: add tp_group in create_weights
'''
assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=(
self.weight_loader_v2
if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
else self.weight_loader
),
tp_group=self.tp_group,
)
'''
=================
End of MLU Hijack
=================
'''
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition, dtype=params_dtype)
)
set_weight_attrs(
self.bias,
{
"output_dim": 0,
"weight_loader": self.weight_loader,
},
)
else:
self.register_parameter("bias", None)
self.update_param_tp_status()
'''
=================
End of MLU Hijack
=================
'''
'''
=============================
Modify by vllm_mlu
=============================
@brief: add smooth_quant_scale and use_tp_weight parameters.
'''
def vllm__module_executor__layers__linear__ColumnParallelLinear__forward(
self,
input_,
smooth_quant_scale: torch.Tensor | None = None,
use_tp_weight: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
bias = self.bias if not self.skip_bias_add else None
# Matrix multiply.
assert self.quant_method is not None
'''
=============================
Modify by vllm_mlu
=============================
    @brief: add input_scale and use_tp_weight parameters.
'''
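    # Build kwargs incrementally so quant methods whose apply() does not accept
    # use_tp_weight or input_scale are still called with their original signature.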
kwargs = {'bias': bias}
if use_tp_weight:
kwargs['use_tp_weight'] = use_tp_weight
if smooth_quant_scale is not None:
kwargs['input_scale'] = smooth_quant_scale
output_parallel = self.quant_method.apply(self, input_, **kwargs)
'''
==================
End of MLU Hijack
==================
'''
if self.gather_output and self.tp_size > 1:
'''
=============================
Modify by vllm_mlu
=============================
        @brief: pass self.tp_group to tensor_model_parallel_all_gather().
'''
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel, dim=-1, tp_group=self.tp_group)
'''
=================
End of MLU Hijack
=================
'''
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
'''
=================
End of MLU Hijack
=================
'''
'''
=============================
Modify by vllm_mlu
=============================
@brief: add tp_group and keep_full_weights parameters.
'''
def vllm__module_executor__layers__linear__MergedColumnParallelLinear____init__(
self,
input_size: int,
output_sizes: list[int],
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
*,
tp_group: Any = None,
keep_full_weights: bool = False,
return_bias: bool = True,
disable_tp: bool = False,
):
self.output_sizes = output_sizes
'''
=============================
Modify by vllm_mlu
=============================
    @brief: validate output_sizes after __init__, since self.tp_world_size is only available then.
    @brief: add keep_full_weights so the shared expert can be replicated across data-parallel ranks.
'''
super(MergedColumnParallelLinear, self).__init__(
input_size=input_size,
output_size=sum(output_sizes),
bias=bias,
gather_output=gather_output,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
output_sizes=self.output_sizes,
prefix=prefix,
tp_group=tp_group,
keep_full_weights=keep_full_weights,
return_bias=return_bias,
disable_tp=disable_tp,
)
assert all(output_size % self.tp_size == 0 for output_size in output_sizes)
if self.keep_full_weights:
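        # Allocate per-rank TP shards alongside the full weights; the shards are
        # cut along the output dimension using the real TP size recorded at init.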
tp_size = self.tp_world_size_org
if isinstance(self.quant_method, UnquantizedLinearMethod):
out_dim, in_dim = self.weight.shape
out_dim_tp = divide(out_dim, tp_size)
self.tp_weight = Parameter(
self.weight.data.new_empty((out_dim_tp, in_dim)),
requires_grad=False,
)
elif (isinstance(self.quant_method, SmoothQuantLinearMethod)
and quant_config.input_quant_method == "per_token"):
out_dim, in_dim = self.qweight.shape
out_dim_tp = divide(out_dim, tp_size)
self.tp_qweight = Parameter(
self.qweight.data.new_empty((out_dim_tp, in_dim)),
requires_grad=False,
)
self.tp_per_channel_scale = Parameter(
                self.per_channel_scale.data.new_empty((out_dim_tp,)),
requires_grad=False,
)
else:
raise TypeError(f"quant method is expected to be unquantized or smoothquant per-token")
'''
=================
End of MLU Hijack
=================
'''
'''
=================
End of MLU Hijack
=================
'''
def vllm__module_executor__layers__linear__MergedColumnParallelLinear__weight_loader(
self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: int | None = None,
):
loaded_weight_orig = loaded_weight
output_dim = getattr(param, "output_dim", None)
vllm__module_executor__layers__linear__MergedColumnParallelLinear__weight_loader_org(
self=self,
param=param,
loaded_weight=loaded_weight,
loaded_shard_id=loaded_shard_id,
)
'''
=============================
Modify by vllm_mlu
=============================
    @brief: add keep_full_weights so the shared expert can be replicated across data-parallel ranks.
'''
# load into tp weight
if self.keep_full_weights:
tp_size = self.tp_world_size_org
tp_rank = self.tp_rank_org
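        # shard_offset / shard_size index into the packed per-rank tp_* parameter,
        # while start_idx selects this rank's slice of the full checkpoint tensor.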
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size
start_idx = tp_rank * shard_size
if isinstance(self.quant_method, UnquantizedLinearMethod):
tp_weight = loaded_weight_orig.narrow(output_dim, start_idx, shard_size)
tp_weight_shard = self.tp_weight.narrow(output_dim, shard_offset, shard_size)
tp_weight_shard.copy_(tp_weight)
elif isinstance(self.quant_method, SmoothQuantLinearMethod):
if output_dim is None:
return
tp_weight = loaded_weight_orig.narrow(output_dim, start_idx, shard_size)
if loaded_weight_orig.ndim == 1:
tp_weight_shard = self.tp_per_channel_scale.narrow(output_dim, shard_offset, shard_size)
elif loaded_weight_orig.ndim == 2:
tp_weight_shard = self.tp_qweight.narrow(output_dim, shard_offset, shard_size)
else:
raise ValueError("only support rank 1 and 2 when using tp_weight")
tp_weight_shard.copy_(tp_weight)
else:
raise TypeError(f"quant method is expected to be either unquantized or smoothquant")
'''
=================
End of MLU Hijack
=================
'''
def vllm__module_executor__layers__linear__RowParallelLinear____init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
input_is_parallel: bool = True,
skip_bias_add: bool = False,
    params_dtype: torch.dtype | None = None,
    reduce_results: bool = True,
    quant_config: QuantizationConfig | None = None,
prefix: str = "",
*,
tp_group: Any = None,
keep_full_weights: bool = False,
return_bias: bool = True,
disable_tp: bool = False,
):
super(RowParallelLinear, self).__init__(
input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix=prefix,
tp_group=tp_group,
keep_full_weights=keep_full_weights,
return_bias=return_bias,
disable_tp=disable_tp,
)
# Divide the weight matrix along the last dimension
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
assert self.quant_method is not None
'''
=============================
Modify by vllm_mlu
=============================
@brief: add tp_group in create_weights
'''
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=(
self.weight_loader_v2
if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
else self.weight_loader
),
tp_group=self.tp_group,
)
'''
=================
End of MLU Hijack
=================
'''
    if not reduce_results and (bias and not skip_bias_add):
        raise ValueError("When reduce_results=False, adding the bias to the "
                         "output can lead to incorrect results")
if bias:
self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
set_weight_attrs(
self.bias,
{
"output_dim": 0,
"weight_loader": self.weight_loader,
},
)
else:
self.register_parameter("bias", None)
'''
=============================
Modify by vllm_mlu
=============================
    @brief: add keep_full_weights so the shared expert can be replicated across data-parallel ranks.
'''
if self.keep_full_weights:
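        # Same idea as in MergedColumnParallelLinear, but a row-parallel layer
        # shards along the input dimension, so the tp_* shards split in_dim.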
tp_size = self.tp_world_size_org
if isinstance(self.quant_method, UnquantizedLinearMethod):
out_dim, in_dim = self.weight.data.shape
in_dim_tp = divide(in_dim, tp_size)
self.tp_weight = Parameter(self.weight.data.new_empty((out_dim, in_dim_tp)),
requires_grad=False)
elif (isinstance(self.quant_method, SmoothQuantLinearMethod)
and quant_config.input_quant_method == "per_token"):
out_dim, in_dim = self.qweight.data.shape
in_dim_tp = divide(in_dim, tp_size)
self.tp_qweight = Parameter(self.qweight.data.new_empty((out_dim, in_dim_tp)),
requires_grad=False)
if hasattr(self, "smooth"):
assert len(self.smooth.shape) == 1, "smooth should be a 1D tensor"
dim = self.smooth.shape[0]
dim_tp = divide(dim, tp_size)
                self.tp_smooth = Parameter(self.smooth.data.new_empty((dim_tp,)),
                                           requires_grad=False)
else:
raise TypeError("quant method expected to be unquantized or smoothquant per-token")
'''
=================
End of MLU Hijack
=================
'''
self.update_param_tp_status()
def vllm__module_executor__layers__linear__RowParallelLinear__weight_loader(
self, param: Parameter, loaded_weight: torch.Tensor
):
input_dim = getattr(param, "input_dim", None)
loaded_weight_orig = loaded_weight
vllm__module_executor__layers__linear__RowParallelLinear__weight_loader_org(
self=self,
param=param,
loaded_weight=loaded_weight,
)
'''
=============================
Modify by vllm_mlu
=============================
    @brief: add keep_full_weights so the shared expert can be replicated across data-parallel ranks.
'''
if self.keep_full_weights:
if input_dim is None:
return
tp_size = self.tp_world_size_org
tp_rank = self.tp_rank_org
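        # Row-parallel weights shard along the input dimension: copy this
        # rank's slice of the full weight into the tp_* shadow parameter.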
shard_size = divide(loaded_weight_orig.shape[input_dim], tp_size)
start_idx = tp_rank * shard_size
if isinstance(self.quant_method, UnquantizedLinearMethod):
shard_view = self.weight.narrow(input_dim, start_idx, shard_size)
self.tp_weight.copy_(shard_view)
elif isinstance(self.quant_method, SmoothQuantLinearMethod):
if loaded_weight_orig.ndim == 1:
shard_view = self.smooth.narrow(input_dim, start_idx, shard_size)
self.tp_smooth.copy_(shard_view)
elif loaded_weight_orig.ndim == 2:
shard_view = self.qweight.narrow(input_dim, start_idx, shard_size)
self.tp_qweight.copy_(shard_view)
else:
raise ValueError("only rank 1 and 2 is supported for tp_weight")
else:
raise TypeError("quant method is expected to be UnquantizedLinearMethod and SmoothQuant")
'''
=================
End of MLU Hijack
=================
'''
'''
=============================
Modify by vllm_mlu
=============================
@brief: add residual, smooth_quant_scale, use_tp_weight and output parameters.
'''
def vllm__module_executor__layers__linear__RowParallelLinear__forward(
self,
input_,
residual: torch.Tensor | None = None,
smooth_quant_scale: torch.Tensor | None = None,
use_tp_weight: bool = False,
output: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
if self.input_is_parallel:
input_parallel = input_
else:
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[self.tp_rank].contiguous()
# Matrix multiply.
assert self.quant_method is not None
    # Only fuse the bias add into the GEMM on rank 0 (this ensures that the
    # bias will not get added more than once in the TP > 1 case).
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add additional matmul parameters.
'''
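    # Like the bias, the residual is fused only on rank 0 so the subsequent
    # all-reduce adds it exactly once.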
residual_ = None if self.tp_rank > 0 else residual
kwargs = {'bias': bias_, 'residual': residual_}
if use_tp_weight:
kwargs['use_tp_weight'] = use_tp_weight
if smooth_quant_scale is not None:
kwargs['input_scale'] = smooth_quant_scale
if output is not None:
kwargs['output'] = output
output_parallel = self.quant_method.apply(self, input_parallel, **kwargs)
'''
=================
End of MLU Hijack
=================
'''
if self.reduce_results and self.tp_size > 1:
'''
=============================
Modify by vllm_mlu
=============================
        @brief: pass self.tp_group to tensor_model_parallel_all_reduce().
'''
output = tensor_model_parallel_all_reduce(output_parallel, tp_group=self.tp_group)
'''
=================
End of MLU Hijack
=================
'''
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
'''
=================
End of MLU Hijack
=================
'''
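# Install the hijacks, swapping the upstream vLLM implementations for the MLU
# variants defined above.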
MluHijackObject.apply_hijack(UnquantizedLinearMethod,
UnquantizedLinearMethod.apply,
vllm__module_executor__layers__linear__UnquantizedLinearMethod__apply)
MluHijackObject.apply_hijack(LinearBase,
LinearBase.__init__,
vllm__module_executor__layers__linear__LinearBase____init__)
MluHijackObject.apply_hijack(ColumnParallelLinear,
ColumnParallelLinear.__init__,
vllm__module_executor__layers__linear__ColumnParallelLinear____init__)
MluHijackObject.apply_hijack(ColumnParallelLinear,
ColumnParallelLinear.forward,
vllm__module_executor__layers__linear__ColumnParallelLinear__forward)
MluHijackObject.apply_hijack(MergedColumnParallelLinear,
MergedColumnParallelLinear.__init__,
vllm__module_executor__layers__linear__MergedColumnParallelLinear____init__)
MluHijackObject.apply_hijack(MergedColumnParallelLinear,
MergedColumnParallelLinear.weight_loader,
vllm__module_executor__layers__linear__MergedColumnParallelLinear__weight_loader)
MluHijackObject.apply_hijack(RowParallelLinear,
RowParallelLinear.__init__,
vllm__module_executor__layers__linear__RowParallelLinear____init__)
MluHijackObject.apply_hijack(RowParallelLinear,
RowParallelLinear.weight_loader,
vllm__module_executor__layers__linear__RowParallelLinear__weight_loader)
MluHijackObject.apply_hijack(RowParallelLinear,
RowParallelLinear.forward,
vllm__module_executor__layers__linear__RowParallelLinear__forward)
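# A minimal usage sketch (hypothetical layer wiring, not part of this module):
# with keep_full_weights=True a DP-replicated shared expert keeps the full,
# unsharded weights on every rank, while the tp_* shards stay available for
# TP-style execution.
#
#   shared_gate_up = MergedColumnParallelLinear(
#       hidden_size,
#       [intermediate_size] * 2,
#       bias=False,
#       quant_config=quant_config,
#       keep_full_weights=True,
#       prefix=f"{prefix}.gate_up_proj",
#   )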