Merge pull request #59 from liwei109/aicapx-quant
[fix] remove weight_loader_v2 to support cuda graph
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
# import vllm_kunlun.ops.linear
|
||||
import vllm_kunlun.ops.rotary_embedding
|
||||
import vllm_kunlun.ops.layernorm
|
||||
import vllm_kunlun.ops.quantization.awq
|
||||
|
||||
@@ -4,27 +4,36 @@ import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear as VllmReplicatedLinear
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
from vllm.model_executor.layers.linear import (
|
||||
WEIGHT_LOADER_V2_SUPPORTED,
|
||||
ReplicatedLinear,
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.model_executor.parameter import ModelWeightParameter
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ReplicatedLinear(VllmReplicatedLinear):
|
||||
"""Replicated linear layer"""
|
||||
|
||||
def get_weights(self):
|
||||
"""get_weights"""
|
||||
if hasattr(self, "kunlun_linear_weights"):
|
||||
return self.kunlun_linear_weights
|
||||
weights = torch.nn.Parameter(self.weight.to(torch.float32))
|
||||
self.register_parameter("kunlun_linear_weights", weights)
|
||||
def get_weights(self):
|
||||
"""get_weights"""
|
||||
if hasattr(self, "kunlun_linear_weights"):
|
||||
return self.kunlun_linear_weights
|
||||
weights = torch.nn.Parameter(self.weight.to(torch.float32))
|
||||
self.register_parameter("kunlun_linear_weights", weights)
|
||||
return self.kunlun_linear_weights
|
||||
|
||||
def get_weights_half(self):
|
||||
"""get_weights_half"""
|
||||
if hasattr(self, "kunlun_linear_weights_half"):
|
||||
return self.kunlun_linear_weights_half
|
||||
weights = torch.nn.Parameter(self.weight.to(torch.float16))
|
||||
|
||||
def get_weights_half(self):
|
||||
"""get_weights_half"""
|
||||
if hasattr(self, "kunlun_linear_weights_half"):
|
||||
return self.kunlun_linear_weights_half
|
||||
weights = torch.nn.Parameter(self.weight.to(torch.float16))
|
||||
|
||||
|
||||
ReplicatedLinear.get_weights = get_weights
|
||||
ReplicatedLinear.get_weights_half = get_weights_half
|
||||
|
||||
|
||||
def create_weights(
|
||||
@@ -48,4 +57,6 @@ def create_weights(
|
||||
set_weight_attrs(weight, extra_weight_attrs)
|
||||
|
||||
|
||||
# Override create_weights and drop UnquantizedLinearMethod from WEIGHT_LOADER_V2_SUPPORTED to support CUDA graph
|
||||
UnquantizedLinearMethod.create_weights = create_weights
|
||||
WEIGHT_LOADER_V2_SUPPORTED.remove("UnquantizedLinearMethod")
|
||||
Reference in New Issue
Block a user