Files
xc-llm-kunlun/vllm_kunlun/ops/linear.py
2025-12-24 12:02:09 +08:00

52 lines
1.6 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import ReplicatedLinear as VllmReplicatedLinear
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.utils import set_weight_attrs
class ReplicatedLinear(VllmReplicatedLinear):
    """Replicated linear layer with cached dtype-converted views of its weight.

    On first request, ``self.weight`` is cast to the requested dtype, wrapped
    in a ``torch.nn.Parameter`` and registered on the module under a fixed
    name (``kunlun_linear_weights`` for float32,
    ``kunlun_linear_weights_half`` for float16). Later requests return the
    registered parameter without casting again.
    """

    def _kunlun_cached_weights(
        self, name: str, dtype: torch.dtype
    ) -> torch.nn.Parameter:
        """Return the parameter registered as ``name``, creating it if absent.

        Args:
            name: Attribute name under which the converted weight is cached.
            dtype: Target dtype for the one-time cast of ``self.weight``.

        Returns:
            The parameter stored on this module under ``name``.
        """
        if not hasattr(self, name):
            # The cast happens only once, on the first request for this name.
            converted = torch.nn.Parameter(self.weight.to(dtype))
            self.register_parameter(name, converted)
        return getattr(self, name)

    def get_weights(self):
        """Return the float32 version of ``self.weight`` (cached after first call)."""
        return self._kunlun_cached_weights("kunlun_linear_weights", torch.float32)

    def get_weights_half(self):
        """Return the float16 version of ``self.weight`` (cached after first call).

        The converted parameter is registered as ``kunlun_linear_weights_half``
        and returned, matching the behaviour of :meth:`get_weights`.
        """
        return self._kunlun_cached_weights(
            "kunlun_linear_weights_half", torch.float16
        )
def create_weights(
    self,
    layer: torch.nn.Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    """Allocate an unquantized weight and register it on ``layer`` as ``weight``.

    The weight is an uninitialized tensor of shape
    ``(sum(output_partition_sizes), input_size_per_partition)`` with dtype
    ``params_dtype``, wrapped in a ``Parameter`` with ``requires_grad=False``.

    Args:
        layer: Module that receives the ``weight`` parameter.
        input_size_per_partition: Size of the weight's second dimension.
        output_partition_sizes: Sizes whose sum is the weight's first dimension.
        input_size: Not used by this function; part of the signature only.
        output_size: Not used by this function; part of the signature only.
        params_dtype: Dtype of the allocated tensor.
        **extra_weight_attrs: Additional attributes set on the parameter
            after it has been registered.
    """
    total_output_rows = sum(output_partition_sizes)
    storage = torch.empty(
        total_output_rows,
        input_size_per_partition,
        dtype=params_dtype,
    )
    weight_param = Parameter(storage, requires_grad=False)

    # Dimension tags are attached before registration; caller-supplied
    # attributes are applied last, after the parameter is on the layer.
    dim_attrs = {"input_dim": 1, "output_dim": 0}
    set_weight_attrs(weight_param, dim_attrs)
    layer.register_parameter("weight", weight_param)
    set_weight_attrs(weight_param, extra_weight_attrs)
# Monkey-patch: rebind ``create_weights`` on the imported vLLM
# ``UnquantizedLinearMethod`` class to the ``create_weights`` function defined
# in this file. The assignment mutates the class object itself, so it takes
# effect as a side effect of executing this module.
UnquantizedLinearMethod.create_weights = create_weights