[Misc] Remove redundant imported envs, using envs_ascend instead (#2193)

### What this PR does / why we need it?
Remove the ambiguous bare `envs` imports and standardize on explicit aliases, so it is always clear whether a flag comes from vLLM or from vLLM Ascend:

```python
import vllm.envs as envs_vllm
import vllm_ascend.envs as envs_ascend
```
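
For illustration, a minimal sketch (not part of this PR) of the ambiguity the aliases avoid: both modules expose similarly named environment variables, so a bare `envs` hides which project a flag belongs to.

```python
import vllm.envs as envs_vllm
import vllm_ascend.envs as envs_ascend

# With explicit aliases, the origin of each flag is clear at the call site:
if envs_vllm.VLLM_USE_V1:               # upstream vLLM setting
    pass
if envs_ascend.VLLM_ASCEND_ENABLE_DBO:  # vllm-ascend-specific setting
    pass
```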

- vLLM version: v0.10.0
- vLLM main: 71683ca6f6

---------

Signed-off-by: shen-shanshan <467638484@qq.com>
Shanshan Shen authored on 2025-08-14 09:33:39 +08:00, committed by GitHub
parent 55d0790597
commit 103654ccd6
14 changed files with 46 additions and 46 deletions

View File

```diff
@@ -5,8 +5,8 @@ import torch
 import vllm
 from pytest_mock import MockerFixture

+import vllm_ascend.envs as envs_ascend
 from tests.ut.base import PytestBase
-from vllm_ascend import envs
 from vllm_ascend.patch.worker.patch_common import patch_linear
@@ -158,10 +158,10 @@ class TestAscendRowParallelLinear(PytestBase):
         assert torch.allclose(ret, expected)

     def test_enable_allreduce_matmul(self, mocker: MockerFixture):
-        mocker.patch.object(envs,
+        mocker.patch.object(envs_ascend,
                             "VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE",
                             new=True)
        reload(patch_linear)
-        assert envs.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE
+        assert envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE
         assert id(vllm.model_executor.layers.linear.RowParallelLinear) == id(
             patch_linear.AscendRowParallelLinear)
```
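
Note for reviewers: this test leans on `importlib.reload` to re-execute `patch_linear`'s module-level guard after the flag is mocked. A minimal sketch of that pattern, with hypothetical modules `settings` (holds the flag) and `patcher` (applies a monkey-patch at import time):

```python
from importlib import reload

import patcher   # hypothetical: runs `if settings.ENABLE: <apply patch>` at import
import settings  # hypothetical: defines ENABLE


def test_patch_applied(mocker):
    mocker.patch.object(settings, "ENABLE", new=True)
    reload(patcher)  # re-runs patcher's top-level code so the guard sees the mock
    assert settings.ENABLE
```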

View File

```diff
@@ -15,25 +15,26 @@
 import inspect
 import os

+import vllm_ascend.envs as envs_ascend
 from tests.ut.base import TestBase
-from vllm_ascend import envs


 class TestEnvVariables(TestBase):

     def setUp(self):
-        self.env_vars = list(envs.env_variables.keys())
+        self.env_vars = list(envs_ascend.env_variables.keys())

     def test_env_vars_behavior(self):
         for var_name in self.env_vars:
             with self.subTest(var=var_name):
                 original_val = os.environ.get(var_name)
-                var_handler = envs.env_variables[var_name]
+                var_handler = envs_ascend.env_variables[var_name]
                 try:
                     if var_name in os.environ:
                         del os.environ[var_name]
-                    self.assertEqual(getattr(envs, var_name), var_handler())
+                    self.assertEqual(getattr(envs_ascend, var_name),
+                                     var_handler())

                    handler_source = inspect.getsource(var_handler)
                    if 'int(' in handler_source:
@@ -45,7 +46,7 @@ class TestEnvVariables(TestBase):
                     for test_val in test_vals:
                         os.environ[var_name] = test_val
-                        self.assertEqual(getattr(envs, var_name),
+                        self.assertEqual(getattr(envs_ascend, var_name),
                                          var_handler())
                 finally:
@@ -55,7 +56,7 @@ class TestEnvVariables(TestBase):
                         os.environ[var_name] = original_val

     def test_dir_and_getattr(self):
-        self.assertEqual(sorted(envs.__dir__()), sorted(self.env_vars))
+        self.assertEqual(sorted(envs_ascend.__dir__()), sorted(self.env_vars))
         for var_name in self.env_vars:
             with self.subTest(var=var_name):
-                getattr(envs, var_name)
+                getattr(envs_ascend, var_name)
```
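
For readers unfamiliar with the module under test, a hedged sketch of the lazy env-variable pattern these assertions exercise (the real `vllm_ascend/envs.py` defines many more entries; the single entry shown is illustrative):

```python
import os

# name -> zero-argument handler that re-parses the variable on every access
env_variables = {
    "VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":
    lambda: bool(int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", "0"))),
}


def __getattr__(name: str):
    # Module-level __getattr__ (PEP 562) evaluates the handler lazily, which
    # is why getattr(envs_ascend, var_name) above tracks os.environ changes.
    if name in env_variables:
        return env_variables[name]()
    raise AttributeError(f"module has no attribute {name!r}")


def __dir__():
    return list(env_variables.keys())
```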

View File

```diff
@@ -9,7 +9,7 @@ from vllm.distributed import (get_dp_group, get_ep_group,
                               get_tensor_model_parallel_world_size)
 from vllm.forward_context import get_forward_context, set_forward_context

-import vllm_ascend.envs as envs
+import vllm_ascend.envs as envs_ascend
 from vllm_ascend.distributed.moe_comm_method import MoECommMethod
@@ -27,7 +27,7 @@ def _get_fused_moe_state(ep_size: int, with_prefill: bool,
                          is_deepseek_v3_r1: bool):
     # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
     # only supports deepseek v3/r1
-    if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
+    if (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
             and is_deepseek_v3_r1):
         return FusedMoEState.AllGatherEP
     elif ep_size == 1:
@@ -35,7 +35,7 @@ def _get_fused_moe_state(ep_size: int, with_prefill: bool,
             return FusedMoEState.NaiveMulticast
         else:
             return FusedMoEState.AllGather
-    elif envs.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ:
+    elif envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ:
         # MC2 Dispatch/Combine performs better than alltoall_seq in decoding stage.
         return (FusedMoEState.All2AllSeq if
                 (ep_size < 16 or with_prefill) else FusedMoEState.MC2)
```
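
As a usage note (values illustrative, not part of this PR), both flags consulted above are ordinary environment variables and can be toggled before engine start:

```python
import os

# Prefer the fused allgather-EP path (effective only for DeepSeek v3/r1 with
# ep_size > 1, per the guard above); leave the alltoall_seq path disabled.
os.environ["VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP"] = "1"
os.environ["VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ"] = "0"
```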

View File

```diff
@@ -14,7 +14,7 @@ from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.utils import cdiv, round_down

-from vllm_ascend import envs
+import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
@@ -1054,7 +1054,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         # be removed after the torch_npu contains `torch_npu.atb.npu_multi_head_latent_attention` become
         # public available
         assert len(kv_c_and_k_pe_cache) > 1
-        if envs.VLLM_ASCEND_MLA_PA:
+        if envs_ascend.VLLM_ASCEND_MLA_PA:
             attn_output = torch_npu.atb.npu_multi_head_latent_attention(
                 q_nope, q_pe, kv_c_and_k_pe_cache[0],
                 kv_c_and_k_pe_cache[1], attn_metadata.decode.block_table,
```

View File

```diff
@@ -23,7 +23,7 @@ from unittest.mock import patch

 import torch
 import torch.fx as fx
-import vllm.envs as envs
+import vllm.envs as envs_vllm
 from vllm.compilation.backends import VllmBackend
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.monitor import end_monitoring_torch_compile
@@ -93,7 +93,7 @@ class NPUPiecewiseBackend:
         self.sym_shape_indices = sym_shape_indices

-        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
+        self.is_debugging_mode = envs_vllm.VLLM_LOGGING_LEVEL == "DEBUG"

         # the entries for different shapes that we need to either
         # compile or capture aclgraph
```

View File

```diff
@@ -27,7 +27,7 @@ from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.request import Request, RequestStatus

-from vllm_ascend import envs
+import vllm_ascend.envs as envs_ascend
 from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version

 TORCH_DTYPE_TO_NPU_DTYPE = {
@@ -181,7 +181,7 @@ class LLMDataDistCMgrConnectorScheduler():
         dp_rank_local = self.vllm_config.parallel_config.data_parallel_rank_local
         tp_size = self.vllm_config.parallel_config.tensor_parallel_size
-        self.port = dp_rank_local * tp_size + envs.VLLM_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs.VLLM_LLMDD_RPC_PORT
+        self.port = dp_rank_local * tp_size + envs_ascend.VLLM_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs_ascend.VLLM_LLMDD_RPC_PORT

         self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
@@ -344,7 +344,7 @@ class LLMDataDistCMgrConnectorWorker():
     def listen_for_agent_metadata_req(self, event: threading.Event):
         assert self.local_agent_metadata is not None
-        port = envs.VLLM_LLMDD_RPC_PORT + self.local_dp_rank * self.tp_size + self.tp_rank if self.local_dp_rank is not None else envs.VLLM_LLMDD_RPC_PORT + self.tp_size + self.tp_rank
+        port = envs_ascend.VLLM_LLMDD_RPC_PORT + self.local_dp_rank * self.tp_size + self.tp_rank if self.local_dp_rank is not None else envs_ascend.VLLM_LLMDD_RPC_PORT + self.tp_size + self.tp_rank
         url = f"tcp://0.0.0.0:{port}"
         msg_encoder = msgspec.msgpack.Encoder()
         msg_decoder = msgspec.msgpack.Decoder()
@@ -427,9 +427,9 @@ class LLMDataDistCMgrConnectorWorker():
     def read_offline_rank_table(self):
         assert (
-            envs.DISAGGREGATED_PREFILL_RANK_TABLE_PATH
+            envs_ascend.DISAGGREGATED_PREFILL_RANK_TABLE_PATH
         ), "Please set path of rank_table to env variable DISAGGREGATED_PREFILL_RANK_TABLE_PATH"
-        rank_table_path = envs.DISAGGREGATED_PREFILL_RANK_TABLE_PATH
+        rank_table_path = envs_ascend.DISAGGREGATED_PREFILL_RANK_TABLE_PATH
         with open(rank_table_path, "r", encoding="utf-8") as f:
             global_rank_table = json.load(f)
         decode_device_list = global_rank_table["decode_device_list"]
```
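
To make the two long port expressions above easier to follow, a hedged reduction of the layout they implement: each (dp_rank, tp_rank) pair claims a distinct port above the `VLLM_LLMDD_RPC_PORT` base (the helper name and base value here are illustrative; the scheduler-side expression is this formula with `tp_rank == 0`):

```python
def llmdd_rpc_port(base: int, dp_rank, tp_size: int, tp_rank: int = 0) -> int:
    # Worker-side formula from the hunk above.
    if dp_rank is not None:
        return base + dp_rank * tp_size + tp_rank
    return base + tp_size + tp_rank


# e.g. with an illustrative base of 5557, dp_rank 1, tp_size 4, tp_rank 2:
assert llmdd_rpc_port(5557, 1, 4, 2) == 5563
```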

View File

```diff
@@ -1,6 +1,6 @@
 from vllm import ModelRegistry

-import vllm_ascend.envs as envs
+import vllm_ascend.envs as envs_ascend


 def register_model():
@@ -21,7 +21,7 @@ def register_model():
         "Qwen2VLForConditionalGeneration",
         "vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration")

-    if envs.USE_OPTIMIZED_MODEL:
+    if envs_ascend.USE_OPTIMIZED_MODEL:
         ModelRegistry.register_model(
             "Qwen2_5_VLForConditionalGeneration",
             "vllm_ascend.models.qwen2_5_vl:AscendQwen2_5_VLForConditionalGeneration"
@@ -32,7 +32,7 @@ def register_model():
             "vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen2_5_VLForConditionalGeneration_Without_Padding"
         )

-    if envs.VLLM_ASCEND_ENABLE_DBO:
+    if envs_ascend.VLLM_ASCEND_ENABLE_DBO:
         ModelRegistry.register_model(
             "DeepseekV2ForCausalLM",
             "vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")
```

View File

```diff
@@ -18,7 +18,7 @@
 # This file is a part of the vllm-ascend project.

 import torch
-import vllm.envs as envs
+import vllm.envs as envs_vllm
 from vllm.config import ParallelConfig

 from vllm_ascend.utils import is_310p
@@ -37,7 +37,7 @@ def parallel_config_get_dp_port(self) -> int:
     self.data_parallel_master_port += 1

     # NOTE: Get port from envs directly when using torchrun
-    port = envs.VLLM_DP_MASTER_PORT if envs.VLLM_DP_MASTER_PORT else answer
+    port = envs_vllm.VLLM_DP_MASTER_PORT if envs_vllm.VLLM_DP_MASTER_PORT else answer
     return port
```
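
A small hedged reduction of the fallback semantics in the hunk above: the env value wins only when it is truthy (set and nonzero); otherwise the computed `answer` is used. The helper name and port values here are illustrative:

```python
def pick_dp_master_port(env_port: int, answer: int) -> int:
    # Mirrors: port = envs_vllm.VLLM_DP_MASTER_PORT if envs_vllm.VLLM_DP_MASTER_PORT else answer
    return env_port if env_port else answer


assert pick_dp_master_port(0, 29555) == 29555      # unset/zero -> computed port
assert pick_dp_master_port(12355, 29555) == 12355  # torchrun-provided port wins
```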

View File

```diff
@@ -28,7 +28,7 @@ from vllm.distributed.parallel_state import get_tp_group
 from vllm.logger import logger
 from vllm.model_executor.layers.linear import RowParallelLinear

-from vllm_ascend import envs
+import vllm_ascend.envs as envs_ascend

 _HCOMM_INFO = None
@@ -142,6 +142,6 @@ class AscendRowParallelLinear(RowParallelLinear):
         return output


-if envs.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE:
+if envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE:
     logger.info("AscendRowParallelLinear: Matmul all-reduce is enabled. ")
     vllm.model_executor.layers.linear.RowParallelLinear = AscendRowParallelLinear
```

View File

```diff
@@ -20,7 +20,7 @@ from datetime import timedelta
 from typing import TYPE_CHECKING, Optional, Tuple

 import torch
-import vllm.envs as envs
+import vllm.envs as envs_vllm
 from torch.distributed import ProcessGroup
 from torch.distributed.distributed_c10d import PrefixStore
 from vllm.logger import logger
@@ -116,7 +116,7 @@ class NPUPlatform(Platform):
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
-        if not envs.VLLM_USE_V1:
+        if not envs_vllm.VLLM_USE_V1:
             raise ValueError("vLLM Ascend does not support V0 engine.")
         # initialize ascend config from vllm additional_config
         ascend_config = init_ascend_config(vllm_config)
```

View File

```diff
@@ -23,7 +23,7 @@ import torch_npu
 from vllm.distributed import GroupCoordinator, get_ep_group
 from vllm.forward_context import get_forward_context

-import vllm_ascend.envs as envs
+import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import FusedMoEState
 from vllm_ascend.distributed.parallel_state import get_mc2_group
@@ -1019,7 +1019,7 @@ class AscendW8A8DynamicFusedMoEMethod:
             1, 2).contiguous()
         layer.w2_weight.data = layer.w2_weight.data.transpose(
             1, 2).contiguous()
-        if envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP:
+        if envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP:
             torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
         layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
             layer.w13_weight_scale.data.shape[0], -1)
```

View File

```diff
@@ -31,7 +31,7 @@ from packaging.version import InvalidVersion, Version
 from torch_npu.npu.streams import Event
 from vllm.logger import logger

-import vllm_ascend.envs as envs
+import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config

 if TYPE_CHECKING:
@@ -236,7 +236,7 @@ def find_hccl_library() -> str:
     After importing `torch`, `libhccl.so` can be
     found by `ctypes` automatically.
     """
-    so_file = envs.HCCL_SO_PATH
+    so_file = envs_ascend.HCCL_SO_PATH

     # manually load the hccl library
     if so_file:
@@ -277,8 +277,8 @@ def adapt_patch(is_global_patch: bool = False):

 @functools.cache
 def vllm_version_is(target_vllm_version: str):
-    if envs.VLLM_VERSION is not None:
-        vllm_version = envs.VLLM_VERSION
+    if envs_ascend.VLLM_VERSION is not None:
+        vllm_version = envs_ascend.VLLM_VERSION
     else:
         import vllm
         vllm_version = vllm.__version__
@@ -389,7 +389,7 @@ class ProfileExecuteDuration:

     @contextmanager
     def capture_async(self, duration_tag: str):
-        if not envs.VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE:
+        if not envs_ascend.VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE:
             yield
             return
@@ -407,7 +407,7 @@ class ProfileExecuteDuration:
     def pop_captured_sync(self) -> dict:
         """Pop and synchronize all events in the observation list"""
         durations: dict[str, float] = {}
-        if not envs.VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE:
+        if not envs_ascend.VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE:
             return durations

         while self._observations:
@@ -441,7 +441,7 @@ def get_rm_router_logits_state(ep_size: int, dp_size: int,
     # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
     # only supports deepseek v3/r1
     if dp_size > 1:
-        if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
+        if (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
                 and is_deepseek_v3_r1):
             return True
     elif ep_size == 1 and is_deepseek_v3_r1:
@@ -455,7 +455,7 @@ def get_rm_router_logits_state(ep_size: int, dp_size: int,
 def get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool):
     # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
     # only supports deepseek v3/r1
-    if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
+    if (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
             and is_deepseek_v3_r1):
         return True
     elif ep_size == 1 and is_deepseek_v3_r1:
```
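
A usage sketch for the cached helper touched above (call site hypothetical): `VLLM_VERSION` lets a user override the detected vLLM version, e.g. when running a pre-release build, and `vllm_version_is` then compares against that override.

```python
from vllm_ascend.utils import vllm_version_is

# hypothetical call site: pick a compatibility code path by vLLM version
if vllm_version_is("0.10.0"):
    ...  # v0.10.0-specific behavior
```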

View File

```diff
@@ -71,7 +71,6 @@ from vllm.v1.worker.utils import (bind_kv_cache, gather_mm_placeholders,
                                   sanity_check_mm_encoder_outputs,
                                   scatter_mm_placeholders)

-from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
@@ -172,7 +171,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.dp_rank = vllm_config.parallel_config.data_parallel_rank
         self.device = device
         self.dtype = self.model_config.dtype
-        if envs.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION:
+        if envs_ascend.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION:
             # TODO: drop the env config to use ascend sampler by default
             from vllm_ascend.sample.sampler import AscendSampler
```

View File

```diff
@@ -23,8 +23,8 @@ from typing import Optional

 import torch
 import torch.nn as nn
 import torch_npu
+import vllm.envs as envs_vllm
 from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions
-from vllm import envs
 from vllm.config import VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
@@ -317,8 +317,8 @@ class NPUWorker(WorkerBase):
     def _init_profiler(self):
         # Torch profiler. Enabled and configured through env vars:
         # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
-        if envs.VLLM_TORCH_PROFILER_DIR:
-            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
+        if envs_vllm.VLLM_TORCH_PROFILER_DIR:
+            torch_profiler_trace_dir = envs_vllm.VLLM_TORCH_PROFILER_DIR
             logger.info("Profiling enabled. Traces will be saved to: %s",
                         torch_profiler_trace_dir)
```
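
Finally, an illustrative way to exercise the profiler branch above (the path is an example, and the variable is normally exported in the shell before launch; vLLM's env module reads it lazily on access):

```python
import os

# Must be set before the worker initializes its profiler.
os.environ["VLLM_TORCH_PROFILER_DIR"] = "/tmp/vllm_traces"  # example path
```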