[Bugfix] Fix the method of importing environment variables in DeepSeek model (#817)
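### What this PR does / why we need it? Fix how the DeepSeek model reads environment variables: `VLLM_ENABLE_MC2` and `USING_LCCL_COM` are now registered in `vllm_ascend.envs` and read once into module-level constants, instead of calling `os.environ.get(...)` inside the model and fused-MoE code, so that the model compiles successfully via aclgraph. Signed-off-by: rjg-lyh <1318825571@qq.com>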
This commit is contained in:
@@ -34,6 +34,10 @@ env_variables: Dict[str, Callable[[], Any]] = {
|
|||||||
lambda: os.getenv("CMAKE_BUILD_TYPE"),
|
lambda: os.getenv("CMAKE_BUILD_TYPE"),
|
||||||
"COMPILE_CUSTOM_KERNELS":
|
"COMPILE_CUSTOM_KERNELS":
|
||||||
lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
|
lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
|
||||||
|
"VLLM_ENABLE_MC2":
|
||||||
|
lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))),
|
||||||
|
"USING_LCCL_COM":
|
||||||
|
lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))),
|
||||||
"SOC_VERSION":
|
"SOC_VERSION":
|
||||||
lambda: os.getenv("SOC_VERSION", "ASCEND910B1"),
|
lambda: os.getenv("SOC_VERSION", "ASCEND910B1"),
|
||||||
# If set, vllm-ascend will print verbose logs during compilation
|
# If set, vllm-ascend will print verbose logs during compilation
|
||||||
|
|||||||
@@ -25,7 +25,6 @@
|
|||||||
# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
|
# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
|
||||||
# """Inference-only DeepseekV2/DeepseekV3 model."""
|
# """Inference-only DeepseekV2/DeepseekV3 model."""
|
||||||
|
|
||||||
import os
|
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -66,9 +65,12 @@ from vllm.model_executor.models.utils import (
|
|||||||
maybe_prefix)
|
maybe_prefix)
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
|
import vllm_ascend.envs as envs_ascend
|
||||||
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
||||||
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
|
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
|
||||||
|
|
||||||
|
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
|
||||||
|
|
||||||
|
|
||||||
class CustomDeepseekV2MLP(nn.Module):
|
class CustomDeepseekV2MLP(nn.Module):
|
||||||
|
|
||||||
@@ -206,7 +208,6 @@ class CustomDeepseekV2MoE(nn.Module):
|
|||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
self.dp_size = get_dp_group().world_size
|
self.dp_size = get_dp_group().world_size
|
||||||
batch_size = vllm_config.scheduler_config.max_num_seqs
|
batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||||
self.enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", '0')) == 1
|
|
||||||
|
|
||||||
params_dtype = torch.get_default_dtype()
|
params_dtype = torch.get_default_dtype()
|
||||||
self.final_hidden_states = torch.zeros(
|
self.final_hidden_states = torch.zeros(
|
||||||
@@ -223,7 +224,7 @@ class CustomDeepseekV2MoE(nn.Module):
|
|||||||
num_tokens, hidden_dim = hidden_states.shape
|
num_tokens, hidden_dim = hidden_states.shape
|
||||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||||
|
|
||||||
if (self.tp_size > 1 and self.enable_mc2 and not is_prefill):
|
if (self.tp_size > 1 and VLLM_ENABLE_MC2 and not is_prefill):
|
||||||
chunks = torch.chunk(hidden_states,
|
chunks = torch.chunk(hidden_states,
|
||||||
get_tp_group().world_size,
|
get_tp_group().world_size,
|
||||||
dim=0)
|
dim=0)
|
||||||
@@ -239,7 +240,7 @@ class CustomDeepseekV2MoE(nn.Module):
|
|||||||
top_k=CustomDeepseekV2MoE.top_k) * self.routed_scaling_factor
|
top_k=CustomDeepseekV2MoE.top_k) * self.routed_scaling_factor
|
||||||
|
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
if self.enable_mc2 and not is_prefill:
|
if VLLM_ENABLE_MC2 and not is_prefill:
|
||||||
dist.all_gather_into_tensor(self.final_hidden_states,
|
dist.all_gather_into_tensor(self.final_hidden_states,
|
||||||
final_hidden_states, self.tp_group)
|
final_hidden_states, self.tp_group)
|
||||||
final_hidden_states = self.final_hidden_states
|
final_hidden_states = self.final_hidden_states
|
||||||
|
|||||||
@@ -15,7 +15,6 @@
|
|||||||
# This file is a part of the vllm-ascend project.
|
# This file is a part of the vllm-ascend project.
|
||||||
# Adapted from vllm/tests/kernels/test_moe.py
|
# Adapted from vllm/tests/kernels/test_moe.py
|
||||||
|
|
||||||
import os
|
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -29,8 +28,12 @@ from vllm.model_executor.layers.fused_moe.layer import (
|
|||||||
from vllm.model_executor.layers.quantization.base_config import \
|
from vllm.model_executor.layers.quantization.base_config import \
|
||||||
QuantizeMethodBase
|
QuantizeMethodBase
|
||||||
|
|
||||||
|
import vllm_ascend.envs as envs_ascend
|
||||||
from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
|
from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
|
||||||
|
|
||||||
|
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
|
||||||
|
USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
|
||||||
|
|
||||||
|
|
||||||
def fused_experts_with_mc2(
|
def fused_experts_with_mc2(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@@ -493,7 +496,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
e_score_correction_bias=e_score_correction_bias,
|
e_score_correction_bias=e_score_correction_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
if os.environ.get("VLLM_ENABLE_MC2", '0') == "1" and not is_prefill:
|
if VLLM_ENABLE_MC2 and not is_prefill:
|
||||||
return fused_experts_with_mc2(
|
return fused_experts_with_mc2(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
w1=layer.w13_weight,
|
w1=layer.w13_weight,
|
||||||
@@ -624,11 +627,9 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
real_top_k = self.top_k
|
real_top_k = self.top_k
|
||||||
|
|
||||||
if self.dp_size > 1:
|
if self.dp_size > 1:
|
||||||
if int(os.environ.get("VLLM_ENABLE_MC2", '0') # type: ignore
|
if VLLM_ENABLE_MC2 and not is_prefill:
|
||||||
) == 1 and not is_prefill:
|
|
||||||
...
|
...
|
||||||
elif int(os.environ.get("USING_LCCL_COM",
|
elif USING_LCCL_COM: # type: ignore
|
||||||
'0')) == 1: # type: ignore
|
|
||||||
hidden_states = get_dp_group().all_gather(
|
hidden_states = get_dp_group().all_gather(
|
||||||
hidden_states, 0, False)
|
hidden_states, 0, False)
|
||||||
router_logits = get_dp_group().all_gather(
|
router_logits = get_dp_group().all_gather(
|
||||||
@@ -655,8 +656,7 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
is_prefill=is_prefill)
|
is_prefill=is_prefill)
|
||||||
|
|
||||||
if self.dp_size > 1:
|
if self.dp_size > 1:
|
||||||
if int(os.environ.get("VLLM_ENABLE_MC2", '0') # type: ignore
|
if VLLM_ENABLE_MC2 and not is_prefill:
|
||||||
) == 1 and not is_prefill:
|
|
||||||
...
|
...
|
||||||
else:
|
else:
|
||||||
final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
|
final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
|
||||||
|
|||||||
Reference in New Issue
Block a user