[Refactor] Unify full-graph parameter update logic (#6041)
### What this PR does / why we need it?

**Refactor: Unify full-graph parameter update logic**

This PR consolidates the scattered full-graph parameter update logic into a unified approach, improving the code architecture and eliminating duplication.

**Key improvements:**

1. **Unified interface**
   - Create `update_full_graph_params` as the single entry point for all full-graph updates
   - Replace multiple scattered update calls with one unified function
   - Remove ~50 lines of duplicated if-else logic across `model_runner_v1.py` and `eagle_proposer.py`
2. **Better architecture**
   - Move update logic into the respective Backend classes (`AscendAttentionBackend`, `AscendMLABackend`)
   - Each Backend manages its own parameter update logic internally
   - Simplify caller code to just dispatch to the appropriate Backend
3. **Cleaner parameter handling**
   - Remove unnecessary `pcp_size` and `dcp_size` parameter passing
   - Get the parallel configuration directly from the distributed groups
   - Consistent with how other parts of the codebase obtain these values

**Why we need it:**

- **Maintainability**: future changes only need to be made in one place per Backend
- **Code quality**: follows the DRY principle and the Single Responsibility Principle
- **Readability**: cleaner, more intuitive code structure

### Does this PR introduce _any_ user-facing change?

**No.** This is a pure refactoring with no functional changes: same behavior, cleaner code.

### How was this patch tested?

- All existing unit tests pass with updated mocks
- No new tests needed (pure refactoring, no behavior changes)
- CI validates correctness

---

- vLLM version: v0.13.0

Signed-off-by: lico67373 <918688502@qq.com>
Co-authored-by: drslark <slarksblood@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
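The dispatch pattern described above can be sketched as follows. This is a minimal illustration, not the PR's actual implementation: only `update_full_graph_params` and the Backend class names come from the PR body; the function bodies, signatures, and return values here are hypothetical stand-ins.

```python
# Hedged sketch of "one entry point, each Backend owns its update logic".
# Real backends read pcp/dcp sizes from distributed groups internally
# instead of having callers pass them in.

class AscendAttentionBackend:
    @classmethod
    def update_graph_params(cls, update_stream, forward_context, runtime_shape):
        # Backend-internal update logic (elided); returns a marker here.
        return ("attention", runtime_shape)


class AscendMLABackend:
    @classmethod
    def update_graph_params(cls, update_stream, forward_context, runtime_shape):
        return ("mla", runtime_shape)


def update_full_graph_params(backend_cls, update_stream, forward_context,
                             runtime_shape):
    """Single entry point: callers no longer branch on attention type."""
    return backend_cls.update_graph_params(update_stream, forward_context,
                                           runtime_shape)


# Callers (model_runner_v1.py, eagle_proposer.py) reduce to one dispatch:
result = update_full_graph_params(AscendMLABackend, None, None, 4)
```

The if-else that previously chose between `update_attn_dcp_pcp_params` and `update_mla_attn_dcp_pcp_params` at every call site collapses into the choice of `backend_cls`.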
```diff
@@ -24,12 +24,15 @@ from vllm.forward_context import BatchDescriptor, ForwardContext
 from tests.ut.base import TestBase
 from vllm_ascend.attention.attention_v1 import (AscendMetadata,
                                                 AscendMetadataForDecode)
+from vllm_ascend.attention.context_parallel.attention_cp import \
+    AscendAttentionCPImpl
+from vllm_ascend.attention.context_parallel.mla_cp import AscendMlaCPImpl
 from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
                                           AscendMLAMetadata)
 from vllm_ascend.compilation.acl_graph import (
     ACLGraphEntry, ACLGraphWrapper, get_draft_graph_params, get_graph_params,
-    set_draft_graph_params, set_graph_params, update_attn_dcp_pcp_params,
-    update_draft_graph_params_workspaces, update_mla_attn_dcp_pcp_params)
+    set_draft_graph_params, set_graph_params,
+    update_draft_graph_params_workspaces)


 class TestACLGraphEntry(TestBase):
```
```diff
@@ -811,8 +814,9 @@ class TestPCPDCPGraphParams(TestBase):
                                              out, lse))

         with patch("torch_npu._C._npu_setStream", return_value=None):
-            update_mla_attn_dcp_pcp_params(self.update_stream, forward_context,
-                                           4)
+            AscendMlaCPImpl.update_graph_params(
+                self.update_stream, forward_context, 4
+            )

         _mock_graph_task_end.assert_called_once()
```
```diff
@@ -852,6 +856,8 @@ class TestPCPDCPGraphParams(TestBase):
                                              out, lse, 2, 0, 0))

         with patch("torch_npu._C._npu_setStream", return_value=None):
-            update_attn_dcp_pcp_params(self.update_stream, forward_context, 4)
+            AscendAttentionCPImpl.update_graph_params(
+                self.update_stream, forward_context, 4, None
+            )

         _mock_graph_task_end.assert_called_once()
```
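The test pattern in the hunks above — patch the NPU stream setter, invoke the backend classmethod, then assert the graph-task-end hook fired once — can be reproduced in a self-contained sketch. Since `torch_npu` is only importable on-device, `FakeImpl` below is a hypothetical stand-in for `AscendAttentionCPImpl`/`AscendMlaCPImpl`; only the patching-and-assertion shape mirrors the real tests.

```python
from unittest import mock


class FakeImpl:
    """Stand-in for the CP impl classes; not the real vllm_ascend code."""

    @staticmethod
    def graph_task_end():
        pass  # placeholder for the hook the real update logic triggers

    @classmethod
    def update_graph_params(cls, update_stream, forward_context, runtime_shape):
        # The real classmethod replays captured graph params on the update
        # stream; here we only exercise the hook so the mock can observe it.
        cls.graph_task_end()


# Patch the hook the same way the tests patch torch_npu internals, then
# verify the classmethod-based entry point triggers it exactly once.
with mock.patch.object(FakeImpl, "graph_task_end") as mock_end:
    FakeImpl.update_graph_params(update_stream=None, forward_context=None,
                                 runtime_shape=4)
    mock_end.assert_called_once()
```

Because the update logic now lives on the Backend class, tests can target the classmethod directly instead of patching free functions in `acl_graph`.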