Merge pull request #59 from liwei109/aicapx-quant

[fix] remove weight_loader_v2 to support CUDA graph
This commit is contained in:
Li Wei
2025-12-29 19:56:24 +08:00
committed by GitHub
parent 7fb627c34e
commit 9cee025f41
2 changed files with 28 additions and 16 deletions

View File

@@ -15,6 +15,7 @@
# This file is a part of the vllm-ascend project.
#
# import vllm_kunlun.ops.linear
import vllm_kunlun.ops.rotary_embedding
import vllm_kunlun.ops.layernorm
import vllm_kunlun.ops.quantization.awq

View File

@@ -4,27 +4,36 @@ import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import ReplicatedLinear as VllmReplicatedLinear
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.linear import (
WEIGHT_LOADER_V2_SUPPORTED,
ReplicatedLinear,
UnquantizedLinearMethod,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.parameter import ModelWeightParameter
from vllm.logger import init_logger
logger = init_logger(__name__)
class ReplicatedLinear(VllmReplicatedLinear):
"""Replicated linear layer"""
def get_weights(self):
"""get_weights"""
if hasattr(self, "kunlun_linear_weights"):
return self.kunlun_linear_weights
weights = torch.nn.Parameter(self.weight.to(torch.float32))
self.register_parameter("kunlun_linear_weights", weights)
def get_weights(self):
"""get_weights"""
if hasattr(self, "kunlun_linear_weights"):
return self.kunlun_linear_weights
weights = torch.nn.Parameter(self.weight.to(torch.float32))
self.register_parameter("kunlun_linear_weights", weights)
return self.kunlun_linear_weights
def get_weights_half(self):
"""get_weights_half"""
if hasattr(self, "kunlun_linear_weights_half"):
return self.kunlun_linear_weights_half
weights = torch.nn.Parameter(self.weight.to(torch.float16))
def get_weights_half(self):
"""get_weights_half"""
if hasattr(self, "kunlun_linear_weights_half"):
return self.kunlun_linear_weights_half
weights = torch.nn.Parameter(self.weight.to(torch.float16))
ReplicatedLinear.get_weights = get_weights
ReplicatedLinear.get_weights_half = get_weights_half
def create_weights(
@@ -48,4 +57,6 @@ def create_weights(
set_weight_attrs(weight, extra_weight_attrs)
# rewrite create_weights and remove weight_loader_v2 to support CUDA graph
UnquantizedLinearMethod.create_weights = create_weights
WEIGHT_LOADER_V2_SUPPORTED.remove("UnquantizedLinearMethod")