465 lines
24 KiB
Python
465 lines
24 KiB
Python
import itertools
|
||
from abc import abstractmethod
|
||
from typing import Dict, List, Optional, Tuple
|
||
|
||
import torch
|
||
import torch.nn.functional as F
|
||
from torch.nn.parameter import Parameter, UninitializedParameter
|
||
|
||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||
get_tensor_model_parallel_world_size,
|
||
split_tensor_along_last_dim,
|
||
tensor_model_parallel_all_gather,
|
||
tensor_model_parallel_all_reduce)
|
||
from vllm.logger import init_logger
|
||
from vllm.model_executor.layers.quantization.base_config import (
|
||
QuantizationConfig, QuantizeMethodBase)
|
||
# yapf: disable
|
||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||
BlockQuantScaleParameter,
|
||
PackedColumnParameter,
|
||
PackedvLLMParameter,
|
||
PerTensorScaleParameter,
|
||
RowvLLMParameter)
|
||
# yapf: enable
|
||
from vllm.model_executor.utils import set_weight_attrs
|
||
|
||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||
ReplicatedLinear,
|
||
WEIGHT_LOADER_V2_SUPPORTED,
|
||
LinearBase,
|
||
RowParallelLinear)
|
||
|
||
def ReplicatedLinear__init__(self,
                             input_size: int,
                             output_size: int,
                             bias: bool = True,
                             skip_bias_add: bool = False,
                             params_dtype: Optional[torch.dtype] = None,
                             quant_config: Optional[QuantizationConfig] = None,
                             prefix: str = ""):
    """Replacement for ``ReplicatedLinear.__init__``.

    Behaves like the upstream constructor, but additionally computes
    ``scale_k`` / ``scale_n`` replication factors for block-quantized (FP8)
    weight scales whose input/output sizes are not multiples of the
    configured ``weight_block_size``: the effective block is shrunk to the
    gcd of the remainder and the block size, and each scale entry is later
    replicated accordingly in the weight loader.

    Args:
        input_size: first dimension of the weight matrix.
        output_size: second dimension of the weight matrix.
        bias: whether to register a bias parameter.
        skip_bias_add: return the bias instead of adding it (fusion aid).
        params_dtype: dtype of the parameters (defaults upstream).
        quant_config: optional quantization configuration.
        prefix: parameter-name prefix forwarded to the base class.
    """
    import math

    # Initialize LinearBase (ReplicatedLinear's parent class).
    super(ReplicatedLinear, self).__init__(input_size,
                                           output_size,
                                           skip_bias_add,
                                           params_dtype,
                                           quant_config,
                                           prefix=prefix)

    # All the linear layers support a quant method.
    assert self.quant_method is not None

    # scale_k: quant_block_k (e.g. 128) is divided by scale_k; scale_k == 2
    # means an effective quant_block_k of 64.  The *_slice values record how
    # many leading replicated entries are valid after expansion.
    self.scale_k = 1
    self.scale_k_slice = 1
    self.scale_n = 1
    self.scale_n_slice = 1
    if (quant_config is not None
            and getattr(quant_config, "weight_block_size", None) is not None):
        block_n = quant_config.weight_block_size[0]
        block_k = quant_config.weight_block_size[1]
        if input_size % block_k:
            # K is not a multiple of the block: shrink the effective block
            # to gcd(remainder, block) and replicate scales to compensate.
            gcd_value = math.gcd(input_size % block_k, block_k)
            self.scale_k = self.scale_k * block_k // gcd_value
            self.scale_k_slice = input_size // gcd_value
        if output_size % block_n:
            gcd_value = math.gcd(output_size % block_n, block_n)
            self.scale_n = self.scale_n * block_n // gcd_value
            self.scale_n_slice = output_size // gcd_value

    self.quant_method.create_weights(self,
                                     self.input_size, [self.output_size],
                                     self.input_size,
                                     self.output_size,
                                     self.params_dtype,
                                     scale_k=self.scale_k,
                                     scale_n=self.scale_n,
                                     weight_loader=self.weight_loader)

    if bias:
        self.bias = Parameter(
            torch.empty(self.output_size, dtype=self.params_dtype))
        set_weight_attrs(self.bias, {
            "output_dim": 0,
            "weight_loader": self.weight_loader,
        })
    else:
        self.register_parameter("bias", None)
||
def ReplicatedLinear_weight_loader(self, param: Parameter,
                                   loaded_weight: torch.Tensor):
    """Replacement for ``ReplicatedLinear.weight_loader``.

    Gives scalar on-disk scales a shape, replicates FP8 block scales along
    K (``scale_k``, sliced to ``scale_k_slice``) and N (``scale_n``, sliced
    to ``scale_n_slice``) to match the effective block size computed in
    ``__init__``, then copies the result into ``param``.

    Fix vs. original: a duplicated (and unreachable) zero-dim reshape block
    was removed — after the first ``reshape(1)`` the second identical check
    could never fire.
    """
    # If the weight on disk does not have a shape, give it one
    # (such as scales for AutoFP8).
    if len(loaded_weight.shape) == 0:
        assert loaded_weight.numel() == 1
        loaded_weight = loaded_weight.reshape(1)

    # For FP8 block quantization, tensors wider than 8 bits reaching this
    # loader are per-block scales; replicate them to the shrunk block size.
    if self.quant_method.__class__.__name__ in ['Fp8LinearMethod', 'Fp8MoEMethod'] \
            and torch.finfo(loaded_weight.dtype).bits > 8:
        if self.scale_k > 1 and len(loaded_weight.shape) == 2:
            # [n,k] -> [1,n,k] -> [scale_k,n,k] -> [n,k,scale_k]
            #       -> [n, k*scale_k], keeping the scale_k_slice valid cols.
            loaded_weight = loaded_weight.unsqueeze(0)
            loaded_weight = loaded_weight.expand(
                self.scale_k, loaded_weight.shape[1],
                loaded_weight.shape[2]).permute(1, 2, 0).reshape(
                    [loaded_weight.shape[1], -1])[:, :self.scale_k_slice]

        if self.scale_n > 1 and len(loaded_weight.shape) == 2:
            # [n,k] -> [1,n,k] -> [scale_n,n,k] -> [n,scale_n,k]
            #       -> [n*scale_n, k], keeping the scale_n_slice valid rows.
            loaded_weight = loaded_weight.unsqueeze(0)
            loaded_weight = loaded_weight.expand(
                self.scale_n, loaded_weight.shape[1],
                loaded_weight.shape[2]).permute(1, 0, 2).reshape(
                    [-1, loaded_weight.shape[2]])[:self.scale_n_slice]

    assert param.size() == loaded_weight.size(), \
        f'{param.size()}, {loaded_weight.size()}'
    param.data.copy_(loaded_weight)
||
def refine_block(block_size: List[int],
                 weight_size: List[int],
                 dim: int = 0,
                 pingpong_size: float = 2.5 * 1024 * 1024,  # bytes
                 core_number: int = 4,
                 data_type: int = 2,  # bytes per element (2 == bfloat16)
                 max_iter_number: int = 2) -> int:
    '''
    Refine the quantization block size along ``dim`` so an uneven split of
    blocks across cores still fits each core's ping-pong buffer.

    When the block count does not divide ``core_number``, the "big" cores
    hold one extra block, and each core must stay <= ``pingpong_size`` bytes
    to allow ping-pong buffering.  The per-core imbalance is
    ``block_size[dim] * weight_size[1 - dim]`` elements, so halving the
    block size narrows the gap until the big cores fit.  If the small cores
    already overflow, or nothing overflows, no adjustment is needed.

    Returns the (possibly halved) block size along ``dim``.
    '''
    if dim < 0:
        # Normalize negative dims for the 2-D weight.
        dim = 2 + dim

    # Convert the byte budget into an element count.
    pingpong_size = pingpong_size / data_type

    block_size_refine = block_size[dim]
    all_block_number = weight_size[dim] // block_size_refine

    if all_block_number % core_number == 0:
        # Even split: nothing to gain by refining, whether it fits or not.
        return block_size_refine

    block_number_tiny = all_block_number // core_number
    block_number_big = all_block_number // core_number + 1
    if block_number_tiny * block_size_refine * weight_size[1 - dim] >= pingpong_size or \
            block_number_big * block_size_refine * weight_size[1 - dim] <= pingpong_size:
        # Small cores already overflow (refining cannot fix that), or the
        # big cores fit (nothing to fix).
        return block_size_refine

    all_block_number_tmp = all_block_number
    block_size_refine_tmp = block_size_refine
    for _ in range(max_iter_number):
        all_block_number_tmp = all_block_number_tmp * 2
        block_size_refine_tmp = block_size_refine_tmp // 2
        if all_block_number_tmp % core_number == 0:
            # BUG FIX: the even-split check must use the doubled block count
            # (all_block_number_tmp), not the original all_block_number,
            # otherwise the per-core load is underestimated ~2x.
            block_number_tiny = all_block_number_tmp // core_number
            if block_number_tiny * block_size_refine_tmp * weight_size[1 - dim] <= pingpong_size:
                return block_size_refine_tmp
            # Even split still overflows: further refinement cannot help.
            return block_size_refine
        else:
            block_number_big = all_block_number_tmp // core_number + 1
            if block_number_big * block_size_refine_tmp * weight_size[1 - dim] <= pingpong_size:
                return block_size_refine_tmp

    return block_size_refine
||
def ColumnParallelLinear__init__(self,
                                 input_size: int,
                                 output_size: int,
                                 bias: bool = True,
                                 gather_output: bool = False,
                                 skip_bias_add: bool = False,
                                 params_dtype: Optional[torch.dtype] = None,
                                 quant_config: Optional[QuantizationConfig] = None,
                                 output_sizes: Optional[List[int]] = None,
                                 prefix: str = "",
                                 *,
                                 return_bias: bool = True,
                                 disable_tp: bool = False):
    """Replacement for ``ColumnParallelLinear.__init__``.

    Mirrors the upstream constructor, but additionally computes a
    ``scale_n`` replication factor for FP8 block-quantized scales whose
    output partition sizes are not multiples of ``weight_block_size[0]``,
    and (for two equal merged parts, i.e. MLP w13) refines the block via
    :func:`refine_block` so uneven per-core splits still fit the ping-pong
    buffer.

    Cleanups vs. original: single ``import math``; removed a redundant
    ``hasattr(self, "output_sizes")`` re-check inside the already-guarded
    branch; comments translated to English.
    """
    import math

    # Divide the weight matrix along the last dimension.
    self.tp_rank = (get_tensor_model_parallel_rank()
                    if not disable_tp else 0)
    self.tp_size = (get_tensor_model_parallel_world_size()
                    if not disable_tp else 1)
    self.input_size_per_partition = input_size
    self.output_size_per_partition = divide(output_size, self.tp_size)
    self.output_partition_sizes = [self.output_size_per_partition]
    # If QKV or MergedColumn, use output size of each partition.
    if hasattr(self, "output_sizes"):
        self.output_partition_sizes = [
            divide(output_size, self.tp_size)
            for output_size in self.output_sizes
        ]
    super(ColumnParallelLinear, self).__init__(input_size,
                                               output_size,
                                               skip_bias_add,
                                               params_dtype,
                                               quant_config,
                                               prefix,
                                               return_bias=return_bias,
                                               disable_tp=disable_tp)

    self.gather_output = gather_output

    if output_sizes is None:
        output_sizes = [output_size]

    self.scale_n = 1
    if (quant_config is not None
            and getattr(quant_config, "weight_block_size", None) is not None):
        block_n = quant_config.weight_block_size[0]
        gcd_value = block_n

        if hasattr(self, "output_sizes"):
            # Merged ColumnParallelLinear: the gcd must account for the
            # shape of every part-linear's partition.
            output_size_no_merge = self.output_partition_sizes
            block_values = [o % block_n for o in output_size_no_merge]
            is_gcd_recompute = sum(block_values)

            if is_gcd_recompute:
                block_values.append(block_n)
                gcd_value = math.gcd(*block_values)
            # NOTE(review): non-aligned part weights may need validation of
            # this flow.  For DeepSeek, merged column linears in MLP/MoE all
            # share one part-weight shape; for Qwen3, QKV column linears
            # have different part shapes, but the current split scheme is
            # insensitive to gcd_value and does not reach the recompute,
            # so no recompute is needed there for now.
            if len(output_size_no_merge) == 2 and output_size_no_merge[0] == output_size_no_merge[1]:
                # Only refine MLP w13 (two equal merged parts).
                gcd_value = refine_block(
                    [gcd_value, quant_config.weight_block_size[1]],
                    [output_size_no_merge[0], input_size])
            self.scale_n = self.scale_n * block_n // gcd_value
        else:
            # Plain ColumnParallelLinear: gcd from this partition only.
            output_size_no_merge = self.output_size_per_partition
            is_gcd_recompute = output_size_no_merge % block_n
            if is_gcd_recompute:
                gcd_value = math.gcd(output_size_no_merge % block_n, block_n)
            self.scale_n = self.scale_n * block_n // gcd_value

    self.quant_method.create_weights(
        layer=self,
        input_size_per_partition=self.input_size,
        output_partition_sizes=self.output_partition_sizes,
        input_size=self.input_size,
        output_size=self.output_size,
        params_dtype=self.params_dtype,
        scale_n=self.scale_n,
        weight_loader=(
            self.weight_loader_v2 if self.quant_method.__class__.__name__
            in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
    if bias:
        self.bias = Parameter(
            torch.empty(self.output_size_per_partition,
                        dtype=params_dtype))
        set_weight_attrs(self.bias, {
            "output_dim": 0,
            "weight_loader": self.weight_loader,
        })
    else:
        self.register_parameter("bias", None)

    self.update_param_tp_status()
||
def ColumnParallelLinear_weight_loader_v2(self, param: Parameter,
                                          loaded_weight: torch.Tensor):
    """Replacement for ``ColumnParallelLinear.weight_loader_v2``.

    Scalar scales loaded off disk get a shape, and 2-D FP8 block-scale
    tensors are replicated ``scale_n`` times along the output dimension,
    before delegating to the parameter's column-parallel loader.
    """
    # Special case for loading scales off disk, which often do not
    # have a shape (such as in the case of AutoFP8).
    if loaded_weight.dim() == 0:
        assert loaded_weight.numel() == 1
        loaded_weight = loaded_weight.reshape(1)

    is_fp8_method = self.quant_method.__class__.__name__ in [
        'Fp8LinearMethod', 'Fp8MoEMethod'
    ]
    if is_fp8_method and torch.finfo(loaded_weight.dtype).bits > 8:
        if self.scale_n > 1 and loaded_weight.dim() == 2:
            # [n,k] -> [1,n,k] -> [scale_n,n,k] -> [n,scale_n,k]
            # -> [n*scale_n, k]: each scale row repeated scale_n times.
            expanded = loaded_weight.unsqueeze(0)
            n_rows, n_cols = expanded.shape[1], expanded.shape[2]
            loaded_weight = (expanded.expand(self.scale_n, n_rows, n_cols)
                             .permute(1, 0, 2).reshape([-1, n_cols]))

    param.load_column_parallel_weight(loaded_weight=loaded_weight)
||
class MergedColumnParallelLinear(ColumnParallelLinear):
    """Merged column-parallel linear with block-quant scale replication.

    Overrides ``weight_loader_v2`` so that FP8 block scales and GPTQ group
    scales are replicated to the effective (possibly shrunk) quantization
    block before the regular merged-column shard logic runs, and so that
    block-scale shard offsets account for ``self.scale_n``.
    """

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[int] = None):
        """Load ``loaded_weight`` (or one merged shard of it) into ``param``.

        Args:
            param: destination vLLM parameter.
            loaded_weight: checkpoint tensor (weight, scale, or scalar scale).
            loaded_shard_id: index into ``self.output_sizes`` when loading a
                single merged shard; ``None`` for fused/full tensors.
        """
        # FP8 path: 2-D tensors wider than 8 bits here are per-block scales;
        # repeat each scale row ``scale_n`` times along N.
        if self.quant_method.__class__.__name__ in ['Fp8LinearMethod', 'Fp8MoEMethod'] and torch.finfo(loaded_weight.dtype).bits > 8:
            if self.scale_n > 1 and len(loaded_weight.shape) == 2:
                loaded_weight = loaded_weight.unsqueeze(0)  # [1,n,k]
                loaded_weight = loaded_weight.expand(self.scale_n, loaded_weight.shape[1], loaded_weight.shape[2]).permute(1,0,2).reshape([-1, loaded_weight.shape[-1]])
                # [1,n,k] -> [scale_n,n,k] -> [n,scale_n,k] -> [n*scale_n, k]

        # GPTQ path: floating-point 2-D tensors are group scales; repeat each
        # scale row ``scale_k`` times along K.
        if self.quant_method.__class__.__name__ in ['GPTQLinearMethod']:
            if self.quant_method.scale_k > 1 and len(loaded_weight.shape) == 2 and loaded_weight.dtype in [torch.float16, torch.bfloat16, torch.float32]:
                loaded_weight = loaded_weight.unsqueeze(1)  # [k,1,n]
                loaded_weight = loaded_weight.expand(loaded_weight.shape[0], self.quant_method.scale_k, loaded_weight.shape[2]).reshape([-1, loaded_weight.shape[2]])
                # [k,1,n] -> [k,scale_k,n] -> [k*scale_k, n]

        if loaded_shard_id is None:
            # Fused/full tensor: dispatch on the parameter type.
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight,
                                                shard_id=0)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        tp_size = get_tensor_model_parallel_world_size()

        if isinstance(param, BlockQuantScaleParameter):
            from vllm.model_executor.layers.quantization.fp8 import (
                Fp8LinearMethod, Fp8MoEMethod)
            assert self.quant_method is not None
            assert isinstance(self.quant_method,
                              (Fp8LinearMethod, Fp8MoEMethod))
            weight_block_size = self.quant_method.quant_config.weight_block_size
            assert weight_block_size is not None
            # Effective block along N after the scale_n refinement; shard
            # offsets/sizes are expressed in (ceil-divided) block units.
            block_n, _ = weight_block_size[0] // self.scale_n, weight_block_size[1]
            shard_offset = (
                (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
                block_n) // tp_size
            shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
                          block_n // tp_size)
        else:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size

        param.load_merged_column_weight(loaded_weight=loaded_weight,
                                        shard_id=loaded_shard_id,
                                        shard_offset=shard_offset,
                                        shard_size=shard_size)
||
def RowParallelLinear__init__(
    self,
    input_size: int,
    output_size: int,
    bias: bool = True,
    input_is_parallel: bool = True,
    skip_bias_add: bool = False,
    params_dtype: Optional[torch.dtype] = None,
    reduce_results: bool = True,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    *,
    return_bias: bool = True,
    disable_tp: bool = False,
):
    """Replacement for ``RowParallelLinear.__init__``.

    Mirrors the upstream constructor, but additionally computes
    ``scale_k`` / ``scale_n`` replication factors for FP8 block-quantized
    scales (and ``quant_method.scale_k`` for GPTQ group scales) when the
    partition sizes are not multiples of the quantization block.

    Fix vs. original: after ``super().__init__`` the original recomputed
    ``tp_rank`` / ``tp_size`` / ``input_size_per_partition`` WITHOUT
    honoring ``disable_tp``, clobbering the correct values computed below;
    that redundant recomputation is removed.
    """
    import math

    # Divide the weight matrix along the first dimension.
    self.tp_rank = (get_tensor_model_parallel_rank()
                    if not disable_tp else 0)
    self.tp_size = (get_tensor_model_parallel_world_size()
                    if not disable_tp else 1)
    self.input_size_per_partition = divide(input_size, self.tp_size)
    self.output_size_per_partition = output_size
    self.output_partition_sizes = [output_size]
    super(RowParallelLinear, self).__init__(input_size,
                                            output_size,
                                            skip_bias_add,
                                            params_dtype,
                                            quant_config,
                                            prefix,
                                            return_bias=return_bias,
                                            disable_tp=disable_tp)

    self.input_is_parallel = input_is_parallel
    self.reduce_results = reduce_results

    assert self.quant_method is not None

    # scale_k: quant_block_k (e.g. 128) is divided by scale_k; scale_k == 2
    # means an effective quant_block_k of 64.
    self.scale_k = 1
    self.scale_n = 1
    self.scale_n_slice = 1

    if (quant_config is not None
            and getattr(quant_config, "weight_block_size", None) is not None):
        block_n = quant_config.weight_block_size[0]
        block_k = quant_config.weight_block_size[1]
        if self.input_size_per_partition % block_k:
            gcd_value = math.gcd(self.input_size_per_partition % block_k,
                                 block_k)
            self.scale_k = self.scale_k * block_k // gcd_value
        if output_size % block_n:
            gcd_value = math.gcd(output_size % block_n, block_n)
            self.scale_n = self.scale_n * block_n // gcd_value
            self.scale_n_slice = output_size // gcd_value
        # Example (N = 576, block = 128): expanding scales along N needs two
        # facts: 1) how many copies (scale_n); 2) how many leading entries
        # are valid (scale_n_slice).
        # scale = [s0,s1,s2,s3,s4] copied scale_n=2 times ->
        # [s0,s0,s1,s1,s2,s2,s3,s3,s4,s4]; sliced to scale_n_slice=9 ->
        # [s0,s0,s1,s1,s2,s2,s3,s3,s4]

    if self.quant_method.__class__.__name__ in ['GPTQLinearMethod']:
        # GPTQ group scales use the same gcd trick on the group size.
        if self.input_size_per_partition % quant_config.group_size:
            gcd_value = math.gcd(
                self.input_size_per_partition % quant_config.group_size,
                quant_config.group_size)
            self.quant_method.scale_k = (
                self.quant_method.scale_k * quant_config.group_size //
                gcd_value)

    self.quant_method.create_weights(
        layer=self,
        input_size_per_partition=self.input_size_per_partition,
        output_partition_sizes=[self.output_size],
        input_size=self.input_size,
        output_size=self.output_size,
        params_dtype=self.params_dtype,
        scale_k=self.scale_k,
        scale_n=self.scale_n,
        weight_loader=(
            self.weight_loader_v2 if self.quant_method.__class__.__name__
            in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
    if not reduce_results and (bias and not skip_bias_add):
        raise ValueError("When not reduce the results, adding bias to the "
                         "results can lead to incorrect results")

    if bias:
        self.bias = Parameter(
            torch.empty(self.output_size, dtype=params_dtype))
        set_weight_attrs(self.bias, {
            "output_dim": 0,
            "weight_loader": self.weight_loader,
        })
    else:
        self.register_parameter("bias", None)
||
def RowParallelLinear_weight_loader_v2_vacc(self, param: BasevLLMParameter,
|
||
loaded_weight: torch.Tensor):
|
||
# Special case for loading scales off disk, which often do not
|
||
# have a shape (such as in the case of AutoFP8).
|
||
if len(loaded_weight.shape) == 0:
|
||
assert loaded_weight.numel() == 1
|
||
loaded_weight = loaded_weight.reshape(1)
|
||
if self.quant_method.__class__.__name__ in ['Fp8LinearMethod', 'Fp8MoEMethod'] and torch.finfo(loaded_weight.dtype).bits > 8:
|
||
if self.scale_k > 1 and len(loaded_weight.shape) == 2:
|
||
loaded_weight = loaded_weight.unsqueeze(0) #[1,n,k]
|
||
loaded_weight = loaded_weight.expand(self.scale_k, loaded_weight.shape[1], loaded_weight.shape[2]).permute(1,2,0).reshape([loaded_weight.shape[1], -1])
|
||
#[1,n,k] -> [scale_k,n,k] -> [n,k,scale_k] -> [n, k*scale_k]
|
||
|
||
if self.scale_n > 1 and len(loaded_weight.shape) == 2:
|
||
loaded_weight = loaded_weight.unsqueeze(0) #[1,n,k]
|
||
loaded_weight = loaded_weight.expand(self.scale_n, loaded_weight.shape[1], loaded_weight.shape[2]).permute(1,0,2).reshape([-1, loaded_weight.shape[2]])[:self.scale_n_slice]
|
||
#[1,n,k] -> [scale_n,n,k] -> [n,scale_n,k] -> [n*scale_n,k]
|
||
|
||
elif self.quant_method.__class__.__name__ in ['GPTQLinearMethod']:
|
||
# broadcast scale TODO: broadcast zero
|
||
if self.quant_method.scale_k > 1 and len(loaded_weight.shape) == 2 and loaded_weight.dtype in [torch.float16, torch.float32, torch.bfloat16]:
|
||
loaded_weight = loaded_weight.unsqueeze(1) #[k,1,n]
|
||
loaded_weight = loaded_weight.expand(loaded_weight.shape[0], self.quant_method.scale_k, loaded_weight.shape[2]).reshape([-1, loaded_weight.shape[2]])
|
||
#[k,1,n] -> [k,scale_k,n]] -> [k*scale_k, n]
|
||
|
||
param.load_row_parallel_weight(loaded_weight=loaded_weight)
|
||
|
||
class UnquantizedLinearMethod():
    """Linear method without quantization.

    NOTE(review): unlike the upstream vLLM class of the same name, this
    override does not visibly inherit from ``QuantizeMethodBase`` — confirm
    how it is installed at the patch site.
    """

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute ``x @ layer.weight.T (+ bias)``.

        With a bias, defers to the platform-dispatched unquantized GEMM;
        without one, runs ``torch.mm`` directly so the output can be written
        into a recycled buffer when one with a matching leading dimension is
        available (avoids a fresh allocation per forward).
        """
        if bias is not None:
            from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
            return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)

        # Bias-free path: opportunistically reuse the recycler's embedding
        # output buffer as the matmul destination.
        from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
        parallel_embedding_output = None
        if memory_recycler is not None:
            # NOTE(review): the buffer is matched on x.size(0) only, while
            # the matmul below flattens x to 2-D — this presumably assumes x
            # is already 2-D on this path; confirm against callers.
            if memory_recycler.EMBEDDING_OUT_BUFFER.size(0) == x.size(0):
                parallel_embedding_output = memory_recycler.EMBEDDING_OUT_BUFFER
        return torch.mm(x.view(-1, x.shape[-1]), layer.weight.transpose(1,0), out=parallel_embedding_output).view(*(x.shape[:-1]), layer.weight.shape[0])