Merge pull request #59 from liwei109/aicapx-quant
[fix] remove weight_loader_v2 to support CUDA graph
This commit is contained in:
@@ -15,6 +15,7 @@
|
|||||||
# This file is a part of the vllm-ascend project.
|
# This file is a part of the vllm-ascend project.
|
||||||
#
|
#
|
||||||
|
|
||||||
|
# import vllm_kunlun.ops.linear
|
||||||
import vllm_kunlun.ops.rotary_embedding
|
import vllm_kunlun.ops.rotary_embedding
|
||||||
import vllm_kunlun.ops.layernorm
|
import vllm_kunlun.ops.layernorm
|
||||||
import vllm_kunlun.ops.quantization.awq
|
import vllm_kunlun.ops.quantization.awq
|
||||||
|
|||||||
@@ -4,27 +4,36 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from vllm.model_executor.layers.linear import ReplicatedLinear as VllmReplicatedLinear
|
from vllm.model_executor.layers.linear import (
|
||||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
WEIGHT_LOADER_V2_SUPPORTED,
|
||||||
|
ReplicatedLinear,
|
||||||
|
UnquantizedLinearMethod,
|
||||||
|
)
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
|
from vllm.model_executor.parameter import ModelWeightParameter
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ReplicatedLinear(VllmReplicatedLinear):
|
def get_weights(self):
|
||||||
"""Replicated linear layer"""
|
"""get_weights"""
|
||||||
|
if hasattr(self, "kunlun_linear_weights"):
|
||||||
def get_weights(self):
|
|
||||||
"""get_weights"""
|
|
||||||
if hasattr(self, "kunlun_linear_weights"):
|
|
||||||
return self.kunlun_linear_weights
|
|
||||||
weights = torch.nn.Parameter(self.weight.to(torch.float32))
|
|
||||||
self.register_parameter("kunlun_linear_weights", weights)
|
|
||||||
return self.kunlun_linear_weights
|
return self.kunlun_linear_weights
|
||||||
|
weights = torch.nn.Parameter(self.weight.to(torch.float32))
|
||||||
|
self.register_parameter("kunlun_linear_weights", weights)
|
||||||
|
return self.kunlun_linear_weights
|
||||||
|
|
||||||
def get_weights_half(self):
|
|
||||||
"""get_weights_half"""
|
def get_weights_half(self):
|
||||||
if hasattr(self, "kunlun_linear_weights_half"):
|
"""get_weights_half"""
|
||||||
return self.kunlun_linear_weights_half
|
if hasattr(self, "kunlun_linear_weights_half"):
|
||||||
weights = torch.nn.Parameter(self.weight.to(torch.float16))
|
return self.kunlun_linear_weights_half
|
||||||
|
weights = torch.nn.Parameter(self.weight.to(torch.float16))
|
||||||
|
|
||||||
|
|
||||||
|
ReplicatedLinear.get_weights = get_weights
|
||||||
|
ReplicatedLinear.get_weights_half = get_weights_half
|
||||||
|
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
@@ -48,4 +57,6 @@ def create_weights(
|
|||||||
set_weight_attrs(weight, extra_weight_attrs)
|
set_weight_attrs(weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
|
||||||
|
# rewrite create_weights and remove weight_loader_v2 to support cuda graph
|
||||||
UnquantizedLinearMethod.create_weights = create_weights
|
UnquantizedLinearMethod.create_weights = create_weights
|
||||||
|
WEIGHT_LOADER_V2_SUPPORTED.remove("UnquantizedLinearMethod")
|
||||||
Reference in New Issue
Block a user