[misc][torchair] fix bugs around deepseek mtp, enable_shared_expert_dp and use_cached_kv_cache_bytes (#3074)

### What this PR does / why we need it?
This miscellaneous​ contains several small fixes:
1) fix initialization and forward bugs of DeepseekMTPLayer with
`shared_expert_dp` enabled.
2) fix a tensor shape mismatches after o_proj caused by a work-aroud
change in NPUModelRunner.
3) avoid unnecessary decline of kv_cache memory (default: 64MB) with
`use_cached_kv_cache_bytes` disabled.
4) fall back `fused_moe_state` from `MC2` to `All2All` since the padding
logic of `mc2_mask` is incompatible with input hidden_states when
`shared_expert_dp` enabled.

Once this PR is merged, users can launch disaggregated_prefill
deployments (large_ep) with `deepseek_mtp` and `shared_expert_dp` as
`v0.9.1-dev` branch. The remaining problem of kv_cache tokens decline
compared to `v0.9.1-dev` will be resolved by
https://github.com/vllm-project/vllm-ascend/pull/3073.
 
### Does this PR introduce _any_ user-facing change?

No.
### How was this patch tested?
E2E vllm serving about deepseek_mtp with torchair graph mode and
`enable_shared_expert_dp` with eager mode. Large ep deployments are also
tested with this PR.


- vLLM version: v0.10.2
- vLLM main:
5aeb925452

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
linfeng-yuan
2025-09-23 14:52:42 +08:00
committed by GitHub
parent 0f3939e5a9
commit d01fd1d1c3
7 changed files with 36 additions and 24 deletions

View File

@@ -72,6 +72,7 @@ def mock_dist_env(mocker: MockerFixture):
return_value=MagicMock(
torchair_graph_config=MagicMock(enabled=False),
enable_multistream_moe=False,
enable_shared_expert_dp=False,
expert_map_path=None
)), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.determine_expert_map',

View File

@@ -60,6 +60,8 @@ class MtpProposer(Proposer):
self.torchair_compiled_models = {} # type: ignore
self.torchair_graph_enabled = get_ascend_config(
).torchair_graph_config.enabled
self.enable_shared_expert_dp = get_ascend_config(
).enable_shared_expert_dp
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
@@ -79,7 +81,9 @@ class MtpProposer(Proposer):
with set_default_torch_dtype(
draft_model_config.dtype), set_current_vllm_config(
self.vllm_config):
if self.torchair_graph_enabled:
if self.torchair_graph_enabled or (
self.enable_shared_expert_dp
and self.vllm_config.model_config.use_mla):
self.model = TorchairDeepSeekMTP(
vllm_config=self.vllm_config).to(target_device)
else:

View File

@@ -813,7 +813,7 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
attn_metadata = get_forward_context().attn_metadata
if attn_metadata is not None and isinstance(attn_metadata, dict):
attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
attn_metadata = next(iter(attn_metadata.values()), None)
if attn_metadata is not None:
num_tokens = attn_metadata.num_actual_tokens
else:

View File

@@ -803,6 +803,7 @@ class TorchairAscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
try:
device_group = get_mc2_group().device_group
@@ -884,6 +885,8 @@ class TorchairAscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
fused_moe_state = get_forward_context().fused_moe_state
if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2:
fused_moe_state = FusedMoEState.All2All
if fused_moe_state == FusedMoEState.MC2:
return torchair_fused_experts_with_mc2(
@@ -1155,6 +1158,8 @@ class TorchairAscendFusedMoE(FusedMoE):
forward_context = get_forward_context()
fused_moe_state = forward_context.fused_moe_state
mc2_mask = forward_context.mc2_mask
if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2:
fused_moe_state = FusedMoEState.All2All
# For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
quantized_x_for_share, dynamic_scale_for_share = None, None
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \

View File

@@ -833,6 +833,7 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
try:
device_group = get_mc2_group().device_group
@@ -946,6 +947,8 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
)
fused_moe_state = get_forward_context().fused_moe_state
if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2:
fused_moe_state = FusedMoEState.All2All
shared_gate_up, shared_dequant_scale = None, None
if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
with npu_stream_switch("moe_secondary", 0):

View File

@@ -87,8 +87,8 @@ class NPUTorchairModelRunner(NPUModelRunner):
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
"""Override from NPUModelRunner to pad num_tokens"""
if self.enable_shared_expert_dp:
return super()._sync_metadata_across_dp(num_tokens, with_prefill,
enable_dbo)
# Padding is not required for shared_expert_dp cases in eager mode.
return num_tokens, None, with_prefill, enable_dbo
if self.dp_size == 1:
if not with_prefill:
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(

View File

@@ -35,26 +35,25 @@ class NPUTorchairWorker(NPUWorker):
ascend_config = get_ascend_config()
if ascend_config.enable_shared_expert_dp:
return available_kv_cache_memory
if ascend_config.torchair_graph_config.use_cached_kv_cache_bytes and check_kv_cache_bytes_cache_exist(
):
old_kv_cache_bytes = read_kv_cache_bytes_from_file(
torch.distributed.get_rank())
if 0 < old_kv_cache_bytes <= available_kv_cache_memory:
logger.info(
f"Use cached torchair kv_cache_bytes: {old_kv_cache_bytes}"
)
self.model_runner.new_kv_cache_bytes = old_kv_cache_bytes
return old_kv_cache_bytes
else:
logger.info(
"Cached torchair kv_cache_bytes is too big, invalidate old torchair_cache"
)
delete_torchair_cache_file()
bytes_floating_tolerance = 1024 * 1024 * envs_ascend.VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE
available_kv_cache_memory -= bytes_floating_tolerance
logger.info(f"Use new kv_cache_bytes: {available_kv_cache_memory}")
self.model_runner.new_kv_cache_bytes = available_kv_cache_memory
if ascend_config.torchair_graph_config.use_cached_kv_cache_bytes:
if check_kv_cache_bytes_cache_exist():
old_kv_cache_bytes = read_kv_cache_bytes_from_file(
torch.distributed.get_rank())
if 0 < old_kv_cache_bytes <= available_kv_cache_memory:
logger.info(
f"Use cached torchair kv_cache_bytes: {old_kv_cache_bytes}"
)
self.model_runner.new_kv_cache_bytes = old_kv_cache_bytes
return old_kv_cache_bytes
else:
logger.info(
"Cached torchair kv_cache_bytes is too big, invalidate old torchair_cache"
)
delete_torchair_cache_file()
bytes_floating_tolerance = 1024 * 1024 * envs_ascend.VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE
available_kv_cache_memory -= bytes_floating_tolerance
logger.info(f"Use new kv_cache_bytes: {available_kv_cache_memory}")
self.model_runner.new_kv_cache_bytes = available_kv_cache_memory
return available_kv_cache_memory
def init_device(self):