From 344228a5da6130bf1d7e5f3c3d04cc8b661748ff Mon Sep 17 00:00:00 2001 From: Mengqing Cao Date: Mon, 7 Apr 2025 10:56:12 +0800 Subject: [PATCH] [deepseek][bugfix] support deepseek quant (#469) - support deepseek quant - add w8a8_dynamic quant see #391 Signed-off-by: MengqingCao Co-authored-by: zzzzwwjj <1183291235@qq.com> --- vllm_ascend/models/__init__.py | 8 + vllm_ascend/models/deepseek_v2.py | 390 +++++++++++++++++++++++ vllm_ascend/quantization/quant_config.py | 127 +++++++- 3 files changed, 522 insertions(+), 3 deletions(-) create mode 100644 vllm_ascend/models/deepseek_v2.py diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index 64b994d..e2c3fd9 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -7,3 +7,11 @@ def register_model(): ModelRegistry.register_model( "Qwen2VLForConditionalGeneration", "vllm_ascend.models.qwen2_vl:CustomQwen2VLForConditionalGeneration") + + ModelRegistry.register_model( + "DeepseekV2ForCausalLM", + "vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM") + + ModelRegistry.register_model( + "DeepseekV3ForCausalLM", + "vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM") diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py new file mode 100644 index 0000000..8d315ca --- /dev/null +++ b/vllm_ascend/models/deepseek_v2.py @@ -0,0 +1,390 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only DeepseekV2/DeepseekV3 model.""" +from typing import Iterable, List, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import PretrainedConfig +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.deepseek_v2 import ( # noqa + DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM, + DeepseekV2MLAAttention, DeepseekV2MLP, DeepseekV2MoE) +from vllm.model_executor.models.utils import ( + PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +from vllm.sequence import IntermediateTensors + + +class CustomDeepseekV2MoE(DeepseekV2MoE): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + nn.Module.__init__(self) + self.tp_size = get_tensor_model_parallel_world_size() + self.routed_scaling_factor = config.routed_scaling_factor + self.n_shared_experts = config.n_shared_experts + self.routed_scaling_factor = config.routed_scaling_factor + if self.tp_size > config.n_routed_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.n_routed_experts}.") + + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now.") + + self.gate = ReplicatedLinear(config.hidden_size, + config.n_routed_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate") + if config.topk_method == "noaux_tc": + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(config.n_routed_experts)) + else: + self.gate.e_score_correction_bias = None + + self.experts = FusedMoE( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + e_score_correction_bias=self.gate.e_score_correction_bias) + + if config.n_shared_experts is not None: + intermediate_size = (config.moe_intermediate_size * + config.n_shared_experts) + self.shared_experts = DeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + prefix=f"{prefix}.shared_experts", + ) + + +class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + model_config: ModelConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + nn.Module.__init__(self) + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # DecoderLayers are created with `make_layers` which passes the prefix + # with the layer's index. + layer_idx = int(prefix.split(sep='.')[-1]) + if model_config.use_mla: + attn_cls = DeepseekV2MLAAttention + else: + attn_cls = DeepseekV2Attention + self.self_attn = attn_cls( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=config.v_head_dim, + q_lora_rank=config.q_lora_rank + if hasattr(config, "q_lora_rank") else None, + kv_lora_rank=config.kv_lora_rank, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + if (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0): + self.mlp = CustomDeepseekV2MoE( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + else: + self.mlp = DeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + +class CustomDeepseekV2Model(nn.Module): + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens") + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: CustomDeepseekV2DecoderLayer( + config, + prefix, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + ), + prefix=f"{prefix}.layers") + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM): + # add `packed_modules_mapping` in `DeepseekV2ForCausalLM` to support weight merging + packed_modules_mapping = { + "gate_up_proj": ["gate_proj", "up_proj"], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = CustomDeepseekV2Model(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is not None: + continue # skip spec decode layers for main model + + # w8a8 weight from modelslim need flatten before load_weight + if "scale" in name or "offset" in name: + loaded_weight = loaded_weight.flatten() + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM): + pass + + +def get_spec_layer_idx_from_weight_name(config: PretrainedConfig, + weight_name: str) -> Optional[int]: + if hasattr(config, "num_nextn_predict_layers") and ( + config.num_nextn_predict_layers > 0): + layer_idx = config.num_hidden_layers + for i in range(config.num_nextn_predict_layers): + if weight_name.startswith(f"model.layers.{layer_idx+i}."): + return layer_idx + i + return None diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index 51f201e..bac531e 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -16,12 +16,16 @@ # limitations under the License. # from types import MappingProxyType -from typing import Any, Dict, List, Mapping, Optional +from typing import Any, Callable, Dict, List, Mapping, Optional import torch import torch_npu # noqa: F401 from vllm.distributed import get_tensor_model_parallel_rank from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, + FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.fused_moe.layer import \ + UnquantizedFusedMoEMethod from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, RowParallelLinear, UnquantizedLinearMethod) @@ -33,6 +37,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.parameter import (ChannelQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter) +from vllm.model_executor.utils import set_weight_attrs from .quantizer import AscendQuantizer @@ -90,9 +95,16 @@ class AscendQuantConfig(QuantizationConfig): return UnquantizedLinearMethod() return AscendLinearMethod(self, prefix, self.packed_modules_mapping) - if isinstance(layer, Attention) and \ - 'fa_quant_type' in self.quant_description.keys(): + elif isinstance(layer, Attention) and \ + 'fa_quant_type' in self.quant_description.keys() and \ + self.quant_description['fa_quant_type'] is not None: return AscendKVCacheMethod(self, prefix) + elif isinstance(layer, FusedMoE): + if self.is_layer_skipped_ascend(prefix, + self.packed_modules_mapping): + return UnquantizedFusedMoEMethod() + return AscendFusedMoEMethod(self, prefix, + self.packed_modules_mapping) return None def is_layer_skipped_ascend( @@ -253,3 +265,112 @@ class AscendKVCacheMethod(BaseKVCacheMethod): attn_metadata.slot_mapping, output, seq_lens_tensor_cpu=seq_lens_tensor_cpu) + + +def fused_moe_perchannel_weight_loader(param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, shard_id: str, + expert_id: int) -> None: + + if shard_id not in ("w1", "w2", "w3"): + raise ValueError(f"shard_id must be ['w1','w2','w3'] but " + f"got {shard_id}.") + + # Fetch the dim to shard the parameter/loaded weight + # based on the shard id. This will be whatever + # dimension intermediate_size_per_partition is used. + SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0} + + expert_data = param.data[expert_id] + tp_rank = get_tensor_model_parallel_rank() + + # is_transposed: if the dim to shard the weight + # should be flipped. Required by GPTQ, compressed-tensors + # should be whatever dimension intermediate_size_per_partition is + is_transposed = getattr(param, "is_transposed", False) + shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id] + if is_transposed: + shard_dim = int(not shard_dim) + + if shard_id == "w2": + expert_data.copy_(loaded_weight) + elif shard_id in ("w1", "w3"): + shard_size = expert_data.shape[shard_dim] // 2 + loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, + shard_size) + # Narrow parameter and load. + # w1, gate_proj: Load into first logical weight of w13. + if shard_id == "w1": + expert_data = expert_data.narrow(shard_dim, 0, shard_size) + # w3, up_proj: Load into second logical weight of w13. + else: + assert shard_id == "w3" + expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) + expert_data.copy_(loaded_weight) + + +class AscendFusedMoEMethod(FusedMoEMethodBase): + """FusedMoE method for Ascend quantization. + + This class calls AscendQuantizer to search a specific quantization + implementations supported on ascend hardware for kvcache methods. + + Args: + quant_config: The Ascend quantization config. + """ + + def __init__(self, quant_config: AscendQuantConfig, prefix: str, + packed_modules_mapping: Dict[str, Any]): + self.quantizer = AscendQuantizer.get_quantizer( + quant_config.quant_description, prefix, packed_modules_mapping) + self.quant_method = self.quantizer.build_moe_method() + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + weight_param = self.quant_method.get_weight( + num_experts, intermediate_size_per_partition, hidden_size, + params_dtype) + for param_key, param_value in weight_param.items(): + param = torch.nn.Parameter(param_value, requires_grad=False) + layer.register_parameter(param_key, param) + set_weight_attrs(param, extra_weight_attrs) + + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}) + # load `offset` weight in `fused_moe_perchannel_weight_loader`, the original weight load in vllm 0.7.3 could only load `scale` and `zero` + extra_weight_attrs.update( + {"weight_loader": fused_moe_perchannel_weight_loader}) + dynamic_quant_param = self.quant_method.get_dynamic_quant_param( + num_experts, intermediate_size_per_partition, hidden_size, + params_dtype) + for param_key, param_value in dynamic_quant_param.items(): + param = torch.nn.Parameter(param_value, requires_grad=False) + layer.register_parameter(param_key, param) + set_weight_attrs(param, extra_weight_attrs) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None + ) -> torch.Tensor: + return self.quant_method.apply(layer, x, use_grouped_topk, top_k, + router_logits, renormalize, topk_group, + num_expert_group, + custom_routing_function, scoring_func, + e_score_correction_bias)