[deepseek][bugfix] support deepseek quant (#469)
- support deepseek quant
- add w8a8_dynamic quant, see #391

Signed-off-by: MengqingCao <cmq0113@163.com>
Co-authored-by: zzzzwwjj <1183291235@qq.com>
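A note on the quantization scheme for reviewers: "w8a8_dynamic" means INT8 weights with one scale per output channel and INT8 activations whose scales are computed per token at runtime instead of being calibrated offline. The snippet below is only a plain-PyTorch sketch of those numerics; every function name in it is invented for illustration and it does not correspond to the vllm-ascend / torch_npu kernels this PR wires up.

import torch

def quantize_weight_per_channel(w: torch.Tensor):
    # Symmetric INT8 quantization with one scale per output channel of a
    # [out_features, in_features] weight.
    scale = (w.abs().amax(dim=1, keepdim=True) / 127.0).clamp(min=1e-8)
    w_q = torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8)
    return w_q, scale

def quantize_activation_per_token(x: torch.Tensor):
    # The "dynamic" part: activation scales are derived per token at runtime.
    scale = (x.abs().amax(dim=-1, keepdim=True) / 127.0).clamp(min=1e-8)
    x_q = torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8)
    return x_q, scale

def w8a8_dynamic_matmul(x: torch.Tensor, w_q: torch.Tensor,
                        w_scale: torch.Tensor) -> torch.Tensor:
    # Real kernels run an INT8 GEMM with int32 accumulation; float matmul is
    # used here only to keep the sketch runnable anywhere.
    x_q, x_scale = quantize_activation_per_token(x)
    acc = x_q.float() @ w_q.float().t()
    return acc * x_scale * w_scale.t()

# Toy usage: y approximates x @ w.t() through the quantized path.
w = torch.randn(8, 16)
x = torch.randn(4, 16)
w_q, w_scale = quantize_weight_per_channel(w)
y = w8a8_dynamic_matmul(x, w_q, w_scale)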
@@ -16,12 +16,16 @@
# limitations under the License.
#
from types import MappingProxyType
from typing import Any, Dict, List, Mapping, Optional
from typing import Any, Callable, Dict, List, Mapping, Optional

import torch
import torch_npu  # noqa: F401
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                  FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.layer import \
    UnquantizedFusedMoEMethod
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               RowParallelLinear,
                                               UnquantizedLinearMethod)
@@ -33,6 +37,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                           ModelWeightParameter,
                                           PerTensorScaleParameter)
from vllm.model_executor.utils import set_weight_attrs

from .quantizer import AscendQuantizer
@@ -90,9 +95,16 @@ class AscendQuantConfig(QuantizationConfig):
                return UnquantizedLinearMethod()
            return AscendLinearMethod(self, prefix,
                                      self.packed_modules_mapping)
        if isinstance(layer, Attention) and \
            'fa_quant_type' in self.quant_description.keys():
        elif isinstance(layer, Attention) and \
            'fa_quant_type' in self.quant_description.keys() and \
            self.quant_description['fa_quant_type'] is not None:
            return AscendKVCacheMethod(self, prefix)
        elif isinstance(layer, FusedMoE):
            if self.is_layer_skipped_ascend(prefix,
                                            self.packed_modules_mapping):
                return UnquantizedFusedMoEMethod()
            return AscendFusedMoEMethod(self, prefix,
                                        self.packed_modules_mapping)
        return None

    def is_layer_skipped_ascend(
@@ -253,3 +265,112 @@ class AscendKVCacheMethod(BaseKVCacheMethod):
                                attn_metadata.slot_mapping,
                                output,
                                seq_lens_tensor_cpu=seq_lens_tensor_cpu)


def fused_moe_perchannel_weight_loader(param: torch.nn.Parameter,
                                       loaded_weight: torch.Tensor,
                                       weight_name: str, shard_id: str,
                                       expert_id: int) -> None:

    if shard_id not in ("w1", "w2", "w3"):
        raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
                         f"got {shard_id}.")

    # Fetch the dim to shard the parameter/loaded weight
    # based on the shard id. This will be whatever
    # dimension intermediate_size_per_partition is used.
    SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}

    expert_data = param.data[expert_id]
    tp_rank = get_tensor_model_parallel_rank()

    # is_transposed: if the dim to shard the weight
    # should be flipped. Required by GPTQ, compressed-tensors
    # should be whatever dimension intermediate_size_per_partition is
    is_transposed = getattr(param, "is_transposed", False)
    shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
    if is_transposed:
        shard_dim = int(not shard_dim)

    if shard_id == "w2":
        expert_data.copy_(loaded_weight)
    elif shard_id in ("w1", "w3"):
        shard_size = expert_data.shape[shard_dim] // 2
        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
                                             shard_size)
        # Narrow parameter and load.
        # w1, gate_proj: Load into first logical weight of w13.
        if shard_id == "w1":
            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
        # w3, up_proj: Load into second logical weight of w13.
        else:
            assert shard_id == "w3"
            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
        expert_data.copy_(loaded_weight)


class AscendFusedMoEMethod(FusedMoEMethodBase):
    """FusedMoE method for Ascend quantization.

    This class calls AscendQuantizer to search for a quantization
    implementation of fused MoE supported on Ascend hardware.

    Args:
        quant_config: The Ascend quantization config.
    """

    def __init__(self, quant_config: AscendQuantConfig, prefix: str,
                 packed_modules_mapping: Dict[str, Any]):
        self.quantizer = AscendQuantizer.get_quantizer(
            quant_config.quant_description, prefix, packed_modules_mapping)
        self.quant_method = self.quantizer.build_moe_method()

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        weight_param = self.quant_method.get_weight(
            num_experts, intermediate_size_per_partition, hidden_size,
            params_dtype)
        for param_key, param_value in weight_param.items():
            param = torch.nn.Parameter(param_value, requires_grad=False)
            layer.register_parameter(param_key, param)
            set_weight_attrs(param, extra_weight_attrs)

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
        # Use `fused_moe_perchannel_weight_loader` to load the `offset` weight
        # as well; the original weight loader in vllm 0.7.3 can only load
        # `scale` and `zero`.
        extra_weight_attrs.update(
            {"weight_loader": fused_moe_perchannel_weight_loader})
        dynamic_quant_param = self.quant_method.get_dynamic_quant_param(
            num_experts, intermediate_size_per_partition, hidden_size,
            params_dtype)
        for param_key, param_value in dynamic_quant_param.items():
            param = torch.nn.Parameter(param_value, requires_grad=False)
            layer.register_parameter(param_key, param)
            set_weight_attrs(param, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        return self.quant_method.apply(layer, x, use_grouped_topk, top_k,
                                       router_logits, renormalize, topk_group,
                                       num_expert_group,
                                       custom_routing_function, scoring_func,
                                       e_score_correction_bias)
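Reviewer note on the w13 layout used by fused_moe_perchannel_weight_loader above: it follows the packing vLLM's FusedMoE uses for the gate/up projections, where w1 (gate_proj) fills the first half of the fused dimension and w3 (up_proj) the second half, and each takes only its tensor-parallel slice of the checkpoint tensor. The standalone sketch below shows just that slicing step; the helper name and the toy shapes are invented for illustration.

import torch

def load_w13_shard(expert_param: torch.Tensor, loaded: torch.Tensor,
                   shard_id: str, tp_rank: int, shard_dim: int = 0) -> None:
    # expert_param holds the fused [w1; w3] weight of one expert, so its
    # shard_dim is twice the per-partition intermediate size.
    shard_size = expert_param.shape[shard_dim] // 2
    # Take this rank's slice of the full checkpoint tensor.
    src = loaded.narrow(shard_dim, shard_size * tp_rank, shard_size)
    # w1 (gate_proj) goes into the first half, w3 (up_proj) into the second.
    offset = 0 if shard_id == "w1" else shard_size
    expert_param.narrow(shard_dim, offset, shard_size).copy_(src)

# Toy usage: 2-way tensor parallelism, full intermediate size 16, hidden size 4.
full_gate = torch.arange(16 * 4, dtype=torch.float32).reshape(16, 4)
fused_w13 = torch.zeros(16, 4)  # per-rank fused weight: [2 * 8, hidden]
load_w13_shard(fused_w13, full_gate, "w1", tp_rank=1)

The loader in the diff additionally honors an is_transposed attribute that flips shard_dim for transposed parameter layouts, and copies w2 (down_proj) without narrowing.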
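AscendFusedMoEMethod itself is a thin adapter: AscendQuantizer.get_quantizer picks an implementation from quant_description, build_moe_method returns an object exposing get_weight, get_dynamic_quant_param and apply, and the adapter registers whatever tensors those dicts contain. The toy stand-in below only mirrors that surface so the data flow is easy to follow; every class name, parameter key and shape in it is invented for illustration and is not the real .quantizer implementation.

import torch

class ToyMoEQuantMethod:
    # Invented stand-in: mirrors only the interface the adapter calls.

    def get_weight(self, num_experts, intermediate_per_rank, hidden, dtype):
        # INT8 expert weights in a fused-w13 / w2 layout (keys are made up).
        return {
            "w13_weight": torch.empty(num_experts, 2 * intermediate_per_rank,
                                      hidden, dtype=torch.int8),
            "w2_weight": torch.empty(num_experts, hidden,
                                     intermediate_per_rank, dtype=torch.int8),
        }

    def get_dynamic_quant_param(self, num_experts, intermediate_per_rank,
                                hidden, dtype):
        # Per-channel scale/offset tensors that a per-channel loader fills in.
        return {
            "w13_weight_scale": torch.empty(num_experts,
                                            2 * intermediate_per_rank, 1,
                                            dtype=dtype),
            "w13_weight_offset": torch.empty(num_experts,
                                             2 * intermediate_per_rank, 1,
                                             dtype=dtype),
        }

    def apply(self, layer, x, *args, **kwargs):
        # The real method dispatches to NPU kernels; this is a placeholder.
        raise NotImplementedError

# How the adapter consumes it, mirroring create_weights in the diff:
layer = torch.nn.Module()
method = ToyMoEQuantMethod()
tensors = {**method.get_weight(4, 8, 16, torch.float16),
           **method.get_dynamic_quant_param(4, 8, 16, torch.float16)}
for name, tensor in tensors.items():
    layer.register_parameter(name,
                             torch.nn.Parameter(tensor, requires_grad=False))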