diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 89fcddb..8e1cc1c 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -34,6 +34,10 @@ env_variables: Dict[str, Callable[[], Any]] = { lambda: os.getenv("CMAKE_BUILD_TYPE"), "COMPILE_CUSTOM_KERNELS": lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))), + "VLLM_ENABLE_MC2": + lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))), + "USING_LCCL_COM": + lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))), "SOC_VERSION": lambda: os.getenv("SOC_VERSION", "ASCEND910B1"), # If set, vllm-ascend will print verbose logs during compilation diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 19cfe71..5bf1126 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -25,7 +25,6 @@ # # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py # """Inference-only DeepseekV2/DeepseekV3 model.""" -import os from typing import Any, Dict, List, Optional, Union import torch @@ -66,9 +65,12 @@ from vllm.model_executor.models.utils import ( maybe_prefix) from vllm.sequence import IntermediateTensors +import vllm_ascend.envs as envs_ascend from vllm_ascend.ops.fused_moe import AscendFusedMoE from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod +VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2 + class CustomDeepseekV2MLP(nn.Module): @@ -206,7 +208,6 @@ class CustomDeepseekV2MoE(nn.Module): vllm_config = get_current_vllm_config() self.dp_size = get_dp_group().world_size batch_size = vllm_config.scheduler_config.max_num_seqs - self.enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", '0')) == 1 params_dtype = torch.get_default_dtype() self.final_hidden_states = torch.zeros( @@ -223,7 +224,7 @@ class CustomDeepseekV2MoE(nn.Module): num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - if (self.tp_size > 1 and self.enable_mc2 and not is_prefill): + if (self.tp_size > 1 and 
VLLM_ENABLE_MC2 and not is_prefill): chunks = torch.chunk(hidden_states, get_tp_group().world_size, dim=0) @@ -239,7 +240,7 @@ class CustomDeepseekV2MoE(nn.Module): top_k=CustomDeepseekV2MoE.top_k) * self.routed_scaling_factor if self.tp_size > 1: - if self.enable_mc2 and not is_prefill: + if VLLM_ENABLE_MC2 and not is_prefill: dist.all_gather_into_tensor(self.final_hidden_states, final_hidden_states, self.tp_group) final_hidden_states = self.final_hidden_states diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 2c25e0c..c912303 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -15,7 +15,6 @@ # This file is a part of the vllm-ascend project. # Adapted from vllm/tests/kernels/test_moe.py -import os from typing import Callable, Optional import torch @@ -29,8 +28,12 @@ from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.quantization.base_config import \ QuantizeMethodBase +import vllm_ascend.envs as envs_ascend from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group +VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2 +USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM + def fused_experts_with_mc2( hidden_states: torch.Tensor, @@ -493,7 +496,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): e_score_correction_bias=e_score_correction_bias, ) - if os.environ.get("VLLM_ENABLE_MC2", '0') == "1" and not is_prefill: + if VLLM_ENABLE_MC2 and not is_prefill: return fused_experts_with_mc2( hidden_states=x, w1=layer.w13_weight, @@ -624,11 +627,9 @@ class AscendFusedMoE(FusedMoE): real_top_k = self.top_k if self.dp_size > 1: - if int(os.environ.get("VLLM_ENABLE_MC2", '0') # type: ignore - ) == 1 and not is_prefill: + if VLLM_ENABLE_MC2 and not is_prefill: ... 
- elif int(os.environ.get("USING_LCCL_COM", - '0')) == 1: # type: ignore + elif USING_LCCL_COM: hidden_states = get_dp_group().all_gather( hidden_states, 0, False) router_logits = get_dp_group().all_gather( @@ -655,8 +656,7 @@ class AscendFusedMoE(FusedMoE): is_prefill=is_prefill) if self.dp_size > 1: - if int(os.environ.get("VLLM_ENABLE_MC2", '0') # type: ignore - ) == 1 and not is_prefill: + if VLLM_ENABLE_MC2 and not is_prefill: ... else: final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(