[Bugfix] Fix the method of importing environment variables in DeepSeek model (#817)
### What this PR does / why we need it?

Fix how environment variables are imported in the DeepSeek model so that the model compiles successfully via aclgraph.

Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
@@ -25,7 +25,6 @@
|
||||
# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
|
||||
# """Inference-only DeepseekV2/DeepseekV3 model."""
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
@@ -66,9 +65,12 @@ from vllm.model_executor.models.utils import (
|
||||
maybe_prefix)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
||||
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
|
||||
|
||||
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
|
||||
|
||||
|
||||
class CustomDeepseekV2MLP(nn.Module):
|
||||
|
||||
@@ -206,7 +208,6 @@ class CustomDeepseekV2MoE(nn.Module):
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.dp_size = get_dp_group().world_size
|
||||
batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||
self.enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", '0')) == 1
|
||||
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.final_hidden_states = torch.zeros(
|
||||
@@ -223,7 +224,7 @@ class CustomDeepseekV2MoE(nn.Module):
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
if (self.tp_size > 1 and self.enable_mc2 and not is_prefill):
|
||||
if (self.tp_size > 1 and VLLM_ENABLE_MC2 and not is_prefill):
|
||||
chunks = torch.chunk(hidden_states,
|
||||
get_tp_group().world_size,
|
||||
dim=0)
|
||||
@@ -239,7 +240,7 @@ class CustomDeepseekV2MoE(nn.Module):
|
||||
top_k=CustomDeepseekV2MoE.top_k) * self.routed_scaling_factor
|
||||
|
||||
if self.tp_size > 1:
|
||||
if self.enable_mc2 and not is_prefill:
|
||||
if VLLM_ENABLE_MC2 and not is_prefill:
|
||||
dist.all_gather_into_tensor(self.final_hidden_states,
|
||||
final_hidden_states, self.tp_group)
|
||||
final_hidden_states = self.final_hidden_states
|
||||
|
||||
Reference in New Issue
Block a user