Files
xc-llm-kunlun/vllm_kunlun/ops/linear.py
2025-12-10 12:05:39 +08:00

24 lines
917 B
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import ReplicatedLinear as VllmReplicatedLinear
class ReplicatedLinear(VllmReplicatedLinear):
"""Replicated linear layer"""
def get_weights(self):
    """Return a float32 copy of ``self.weight``, built once and then cached.

    On the first call the weight is cast with ``.to(torch.float32)``, wrapped
    in a new ``Parameter`` and registered on this module under the name
    ``kunlun_linear_weights``. Every later call returns that registered
    parameter without recomputing it.

    If ``self.weight`` is already float32, ``Tensor.to`` returns the same
    tensor rather than a copy, so the registered parameter shares its storage.

    Returns:
        The ``kunlun_linear_weights`` parameter of this module.
    """
    try:
        # Fast path: an earlier call already registered the float32 copy.
        return self.kunlun_linear_weights
    except AttributeError:
        pass

    # NOTE(review): the cached copy is never refreshed if ``self.weight`` is
    # modified or reloaded after the first call. Because it is added with
    # ``register_parameter`` it also appears in ``named_parameters()`` and
    # ``state_dict()``, and ``Parameter`` defaults to ``requires_grad=True``.
    # Confirm both are intended for this derived weight.
    fp32_weight = Parameter(self.weight.to(torch.float32))
    self.register_parameter("kunlun_linear_weights", fp32_weight)
    return self.kunlun_linear_weights
def get_weights_half(self):
"""get_weights_half"""
if hasattr(self, 'kunlun_linear_weights_half'):
return self.kunlun_linear_weights_half
weights = torch.nn.Parameter(self.weight.to(torch.float16))