[Misc] Refactor additional_config (#1029)

More and more config options are being added to additional_config. This PR
introduces a new AscendConfig class to manage these options in a structured
way, making the code cleaner and more readable.
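
For context, here is a minimal sketch of the pattern this change introduces, not the actual vllm_ascend implementation: the raw `additional_config` dict is parsed once into a typed object that the rest of the code reads through `get_ascend_config()`. Only `get_ascend_config` and the `torchair_graph_config.enabled` field appear in this diff; the class layout and the `init_ascend_config` initializer are illustrative assumptions.

```python
from typing import Optional


class TorchairGraphConfig:
    """Typed view over the 'torchair_graph_config' sub-dict (assumed shape)."""

    def __init__(self, cfg: dict):
        self.enabled: bool = cfg.get("enabled", False)


class AscendConfig:
    """Parses additional_config once, replacing scattered dict lookups."""

    def __init__(self, additional_config: Optional[dict]):
        additional_config = additional_config or {}
        self.torchair_graph_config = TorchairGraphConfig(
            additional_config.get("torchair_graph_config", {}))


_ASCEND_CONFIG: Optional[AscendConfig] = None


def init_ascend_config(additional_config: Optional[dict]) -> AscendConfig:
    # Hypothetical initializer, called once at engine start-up.
    global _ASCEND_CONFIG
    _ASCEND_CONFIG = AscendConfig(additional_config)
    return _ASCEND_CONFIG


def get_ascend_config() -> AscendConfig:
    # Accessor used throughout the diff below.
    assert _ASCEND_CONFIG is not None, "AscendConfig is not initialized"
    return _ASCEND_CONFIG
```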

This PR also adds `additional_config` documentation for users.
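
As a hedged usage example of what that doc covers (the model name is a placeholder; the config key mirrors the `torchair_graph_config.enabled` field read in the diff below):

```python
from vllm import LLM

# Sketch: enabling torchair graph mode through vLLM's additional_config.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # placeholder model
    additional_config={
        "torchair_graph_config": {
            "enabled": True,
        },
    },
)
```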

Added test_ascend_config.py to make sure the new AscendConfig works as
expected.
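
A rough sketch of the kind of round-trip check such a test can make, reusing the hypothetical `init_ascend_config` from the sketch above (the real test file may differ):

```python
# Hypothetical tests; the actual test_ascend_config.py may look different.
def test_torchair_graph_config_round_trip():
    init_ascend_config({"torchair_graph_config": {"enabled": True}})
    assert get_ascend_config().torchair_graph_config.enabled


def test_torchair_graph_disabled_by_default():
    init_ascend_config(None)
    assert not get_ascend_config().torchair_graph_config.enabled
```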

TODO: Add an e2e test with torchair and deepseek once CI resources are
available.

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Author: wangxiyuan
Date: 2025-06-05 16:28:01 +08:00
Committed by: GitHub
Commit: e1ab6d318e (parent 7737aaa40f)
23 changed files with 456 additions and 208 deletions


@@ -34,8 +34,7 @@ import vllm.envs as envs
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.attention import Attention, AttentionMetadata
-from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
-                         get_current_vllm_config)
+from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed import (get_pp_group,
                               get_tensor_model_parallel_world_size,
                               get_tp_group, tensor_model_parallel_all_reduce)
@@ -67,6 +66,7 @@ from vllm.model_executor.models.utils import (
 from vllm.sequence import IntermediateTensors
 
 import vllm_ascend.envs as envs_ascend
+from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
 from vllm_ascend.utils import dispose_tensor
@@ -214,11 +214,8 @@ class CustomDeepseekV2MoE(nn.Module):
 
         self.params_dtype = torch.get_default_dtype()
 
-        self.enable_graph_mode = False
-        additional_config = get_current_vllm_config().additional_config
-        if additional_config:
-            self.enable_graph_mode = additional_config.get(
-                "enable_graph_mode", False)
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
 
     def forward(
         self,
@@ -248,7 +245,7 @@ class CustomDeepseekV2MoE(nn.Module):
         if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
             chunks = torch.chunk(hidden_states, self.tp_size, dim=0)
             hidden_states = chunks[self.tp_rank]
-        elif not self.enable_graph_mode:
+        elif not self.torchair_graph_enabled:
             num_padding_tokens = (self.tp_size -
                                   num_tokens % self.tp_size) % self.tp_size
             # Pad hidden_states to make it divisible by tp_size to avoid cross-ring AllGatherV on 910B2C
@@ -272,7 +269,7 @@ class CustomDeepseekV2MoE(nn.Module):
         ) * self.routed_scaling_factor
 
         if self.tp_size > 1:
-            if self.enable_graph_mode:
+            if self.torchair_graph_enabled:
                 if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
                     final_hidden_states = torch.zeros(
                         [num_tokens, hidden_size],
@@ -423,11 +420,9 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
         self.prefix = prefix
         self.debug_layer_idx = int(self.prefix.split(".")[-2])
 
-        self.enable_graph_mode = False
-        additional_config = get_current_vllm_config().additional_config
-        if additional_config:
-            self.enable_graph_mode = additional_config.get(
-                "enable_graph_mode", False)
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
 
     def forward(
         self,
@@ -440,7 +435,7 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
             hidden_states_or_q_c = self.q_a_layernorm(ckq)
         else:
             hidden_states_or_q_c = hidden_states
-        if self.enable_graph_mode:
+        if self.torchair_graph_enabled:
             forward_kwargs = {}
             if envs.VLLM_USE_V1:
                 output_shape = hidden_states.shape