diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py
index 3db14c6..7d7bb93 100644
--- a/vllm_ascend/models/deepseek_v2.py
+++ b/vllm_ascend/models/deepseek_v2.py
@@ -24,12 +24,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only DeepseekV2/DeepseekV3 model."""
-from typing import List, Optional, Union
+from typing import Optional, Union
 
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -61,11 +60,6 @@ class CustomDeepseekV2MoE(DeepseekV2MoE):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
-        self.routed_scaling_factor = config.routed_scaling_factor
-        if self.tp_size > config.n_routed_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {config.n_routed_experts}.")
 
         if config.hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {config.hidden_act}. "
@@ -129,6 +123,7 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
         # DecoderLayers are created with `make_layers` which passes the prefix
         # with the layer's index.
         layer_idx = int(prefix.split(sep='.')[-1])
+        self.layer_idx = layer_idx
         if model_config.use_mla:
             attn_cls = DeepseekV2MLAAttention
         else:
@@ -171,6 +166,7 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
+        self.routed_scaling_factor = config.routed_scaling_factor
 
 
 class CustomDeepseekV2Model(nn.Module):
@@ -184,8 +180,8 @@ class CustomDeepseekV2Model(nn.Module):
         model_config = vllm_config.model_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
+        self.config = config
 
-        self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
         if get_pp_group().is_first_rank:
@@ -223,8 +219,6 @@ class CustomDeepseekV2Model(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -239,11 +233,8 @@ class CustomDeepseekV2Model(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
-            hidden_states, residual = layer(positions, hidden_states,
-                                            kv_caches[i - self.start_layer],
-                                            attn_metadata, residual)
+        for layer in self.layers[self.start_layer:self.end_layer]:
+            hidden_states, residual = layer(positions, hidden_states, residual)
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
@@ -272,9 +263,12 @@ class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
         self.model = CustomDeepseekV2Model(vllm_config=vllm_config,
                                            prefix=maybe_prefix(
                                                prefix, "model"))
-        self.lm_head = ParallelLMHead(config.vocab_size,
-                                      config.hidden_size,
-                                      quant_config=quant_config)
+        if get_pp_group().is_last_rank:
+            self.lm_head = ParallelLMHead(config.vocab_size,
+                                          config.hidden_size,
+                                          quant_config=quant_config)
+        else:
+            self.lm_head = PPMissingLayer()
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
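
The core change here is that `CustomDeepseekV2Model.forward` no longer threads `kv_caches` and `attn_metadata` through each decoder layer, and the pipeline-parallel slice of layers is iterated directly. A minimal sketch of the new calling convention, using a hypothetical `DummyLayer` stand-in for `CustomDeepseekV2DecoderLayer` (names and shapes here are illustrative, not from the diff):

```python
# Sketch of the simplified layer loop. DummyLayer is a hypothetical
# stand-in whose forward matches the new (positions, hidden_states,
# residual) signature; in real vLLM, attention metadata is resolved
# from the forward context rather than passed per call.
from typing import Optional

import torch
from torch import nn


class DummyLayer(nn.Module):

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # First layer seeds the residual stream from its input.
        if residual is None:
            residual = hidden_states
        return hidden_states + 1.0, residual


layers = nn.ModuleList([DummyLayer() for _ in range(4)])
start_layer, end_layer = 0, 4  # slice of layers owned by this PP rank

positions = torch.arange(8)
hidden_states = torch.zeros(8, 16)
residual = None

# The loop from the diff: iterate the owned slice directly instead of
# indexing kv_caches[i - start_layer] for every layer.
for layer in layers[start_layer:end_layer]:
    hidden_states, residual = layer(positions, hidden_states, residual)
```

The `PPMissingLayer` branch at the end of the diff follows the same theme: on non-last pipeline ranks the `lm_head` is replaced with a placeholder, so its weights are neither allocated nor loaded there.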