[Misc] Refactor additional_config (#1029)

More and more config options are being added to additional_config. This PR
provides a new AscendConfig to manage these options in one place, making
the code cleaner and more readable.

This PR also adds the `additional_config` doc for users.

Added test_ascend_config.py to make sure the new AscendConfig works as
expected.

TODO: Add an e2e test with torchair and DeepSeek once the CI resource is
available.
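
For reference, here is a minimal sketch of the new config layer. Only
`get_ascend_config` and `torchair_graph_config.enabled` are visible in the
diff below; the dataclass layout, the init helper, and the parsing of
`additional_config` are illustrative assumptions, not the PR's exact code:

    # Sketch only: the real vllm_ascend/ascend_config.py may differ.
    from dataclasses import dataclass, field
    from typing import Optional

    @dataclass
    class TorchairGraphConfig:
        # Replaces the old additional_config["enable_graph_mode"] flag.
        enabled: bool = False

    @dataclass
    class AscendConfig:
        torchair_graph_config: TorchairGraphConfig = field(
            default_factory=TorchairGraphConfig)

    _ASCEND_CONFIG: Optional[AscendConfig] = None

    def init_ascend_config(additional_config: Optional[dict]) -> AscendConfig:
        """Parse the free-form additional_config dict once into a typed object."""
        global _ASCEND_CONFIG
        config = AscendConfig()
        if additional_config:
            graph_config = additional_config.get("torchair_graph_config", {})
            config.torchair_graph_config.enabled = bool(
                graph_config.get("enabled", False))
        _ASCEND_CONFIG = config
        return config

    def get_ascend_config() -> AscendConfig:
        if _ASCEND_CONFIG is None:
            raise RuntimeError("Ascend config has not been initialized.")
        return _ASCEND_CONFIG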

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Author: wangxiyuan
Date: 2025-06-05 16:28:01 +08:00
Committer: GitHub
Parent: 7737aaa40f
Commit: e1ab6d318e
23 changed files with 456 additions and 208 deletions


@@ -32,6 +32,7 @@ from vllm.model_executor.layers.quantization.base_config import \
     QuantizationConfig
 
 import vllm_ascend.envs as envs_ascend
+from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
 
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
@@ -587,11 +588,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
         self.local_batch_size = self.global_batch_size // self.ep_size
 
-        self.enable_graph_mode = False
-        additional_config = get_current_vllm_config().additional_config
-        if additional_config:
-            self.enable_graph_mode = additional_config.get(
-                "enable_graph_mode", False)
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
 
         try:
             device_group = ep_group.device_group
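
The hunk above is the heart of the refactor: instead of each module
re-reading the raw additional_config dict with its own default handling,
every consumer now performs a single typed lookup. Side by side, the two
patterns are (a context-free sketch using only names visible in this diff):

    # Before: each consumer re-parses the untyped dict and re-applies defaults.
    additional_config = get_current_vllm_config().additional_config
    enable_graph_mode = False
    if additional_config:
        enable_graph_mode = additional_config.get("enable_graph_mode", False)

    # After: one typed attribute access; parsing lives in one place
    # (vllm_ascend/ascend_config.py).
    ascend_config = get_ascend_config()
    torchair_graph_enabled = ascend_config.torchair_graph_config.enabled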
@@ -678,7 +676,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
                 top_k=top_k,
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name)
-        elif self.enable_graph_mode or get_ep_group().world_size == 1:
+        elif self.torchair_graph_enabled or get_ep_group().world_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
                                  w2=layer.w2_weight,
@@ -772,11 +770,8 @@ class AscendFusedMoE(FusedMoE):
             self.moe_parallel_config.tp_rank = get_etp_group().rank_in_group
             self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
 
-        self.enable_graph_mode = False
-        additional_config = get_current_vllm_config().additional_config
-        if additional_config:
-            self.enable_graph_mode = additional_config.get(
-                "enable_graph_mode", False)
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
 
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
@@ -818,12 +813,6 @@ class AscendFusedMoE(FusedMoE):
         self.ep_group = get_ep_group()
         self.quant_method.create_weights(layer=self, **moe_quant_params)
-
-        self.enable_graph_mode = False
-        additional_config = get_current_vllm_config().additional_config
-        if additional_config:
-            self.enable_graph_mode = additional_config.get(
-                "enable_graph_mode", False)
 
     def forward(self,
                 hidden_states: torch.Tensor,
                 router_logits: torch.Tensor,
@@ -844,13 +833,13 @@ class AscendFusedMoE(FusedMoE):
         if self.dp_size > 1:
             if VLLM_ENABLE_MC2 and not is_prefill:
                 ...
-            elif self.enable_graph_mode:
+            elif self.torchair_graph_enabled:
                 if USING_LCCL_COM:  # type: ignore
                     hidden_states = get_dp_group().all_gather(
                         hidden_states, 0, False)
                     router_logits = get_dp_group().all_gather(
                         router_logits, 0, False)
-                elif self.enable_graph_mode and not is_prefill:
+                elif self.torchair_graph_enabled and not is_prefill:
                     hidden_states = get_dp_group().all_gather(hidden_states, 0)
                     router_logits = get_dp_group().all_gather(router_logits, 0)
             else:
@@ -878,14 +867,14 @@ class AscendFusedMoE(FusedMoE):
         if self.dp_size > 1:
             if VLLM_ENABLE_MC2 and not is_prefill:
                 ...
-            elif self.enable_graph_mode:
+            elif self.torchair_graph_enabled:
                 if USING_LCCL_COM:  # type: ignore
                     hidden_states = dist._functional_collectives.reduce_scatter_tensor(
                         hidden_states,
                         "sum",
                         scatter_dim=0,
                         group=get_dp_group().device_group)
-                elif self.enable_graph_mode and not is_prefill:
+                elif self.torchair_graph_enabled and not is_prefill:
                     hidden_states = dist._functional_collectives.reduce_scatter_tensor(
                         hidden_states,
                         "sum",