[Bugfix] Fix the method of importing environment variables in DeepSee… (#817)
### What this PR does / why we need it? Fix the method of importing environment variables in DeepSeek model to support successful compilation via aclgraph. Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
@@ -15,7 +15,6 @@
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm/tests/kernels/test_moe.py
|
||||
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
@@ -29,8 +28,12 @@ from vllm.model_executor.layers.fused_moe.layer import (
|
||||
from vllm.model_executor.layers.quantization.base_config import \
|
||||
QuantizeMethodBase
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
|
||||
|
||||
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
|
||||
USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
|
||||
|
||||
|
||||
def fused_experts_with_mc2(
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -493,7 +496,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
if os.environ.get("VLLM_ENABLE_MC2", '0') == "1" and not is_prefill:
|
||||
if VLLM_ENABLE_MC2 and not is_prefill:
|
||||
return fused_experts_with_mc2(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
@@ -624,11 +627,9 @@ class AscendFusedMoE(FusedMoE):
|
||||
real_top_k = self.top_k
|
||||
|
||||
if self.dp_size > 1:
|
||||
if int(os.environ.get("VLLM_ENABLE_MC2", '0') # type: ignore
|
||||
) == 1 and not is_prefill:
|
||||
if VLLM_ENABLE_MC2 and not is_prefill:
|
||||
...
|
||||
elif int(os.environ.get("USING_LCCL_COM",
|
||||
'0')) == 1: # type: ignore
|
||||
elif USING_LCCL_COM: # type: ignore
|
||||
hidden_states = get_dp_group().all_gather(
|
||||
hidden_states, 0, False)
|
||||
router_logits = get_dp_group().all_gather(
|
||||
@@ -655,8 +656,7 @@ class AscendFusedMoE(FusedMoE):
|
||||
is_prefill=is_prefill)
|
||||
|
||||
if self.dp_size > 1:
|
||||
if int(os.environ.get("VLLM_ENABLE_MC2", '0') # type: ignore
|
||||
) == 1 and not is_prefill:
|
||||
if VLLM_ENABLE_MC2 and not is_prefill:
|
||||
...
|
||||
else:
|
||||
final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
|
||||
|
||||
Reference in New Issue
Block a user