[Refactor] Unify full-graph parameter update logic (#6041)
### What this PR does / why we need it? **Refactor: Unify full-graph parameter update logic** This PR consolidates the scattered full-graph parameter update logic into a unified approach, improving code architecture and eliminating duplication. **Key improvements:** 1. **Unified interface** - Create `update_full_graph_params` as the single entry point for all full-graph updates - Replace multiple scattered update calls with one unified function - Remove ~50 lines of duplicated if-else logic across `model_runner_v1.py` and `eagle_proposer.py` 2. **Better architecture** - Move update logic to respective Backend classes (`AscendAttentionBackend`, `AscendMLABackend`) - Each Backend manages its own parameter update logic internally - Simplify caller code to just dispatch to the appropriate Backend 3. **Cleaner parameter handling** - Remove unnecessary `pcp_size` and `dcp_size` parameter passing - Get parallel configuration directly from distributed groups - Consistent with how other parts of the codebase obtain these values **Why we need it:** - **Maintainability**: Future changes only need to be made in one place per Backend - **Code quality**: Follows DRY principle and Single Responsibility Principle - **Readability**: Cleaner, more intuitive code structure ### Does this PR introduce _any_ user-facing change? **No.** This is a pure refactoring with no functional changes - same behavior, cleaner code. ### How was this patch tested? - All existing unit tests pass with updated mocks - No new tests needed (pure refactoring, no behavior changes) - CI validates correctness --- - vLLM version: v0.13.0 Signed-off-by: lico67373 <918688502@qq.com> Co-authored-by: drslark <slarksblood@qq.com> Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
@@ -371,6 +371,144 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def update_graph_params(
|
||||
update_stream,
|
||||
forward_context,
|
||||
num_tokens,
|
||||
vllm_config,
|
||||
speculative_config=None,
|
||||
num_dcp_pcp_tokens=None,
|
||||
):
|
||||
if using_paged_attention(num_tokens, vllm_config):
|
||||
# Paged Attention update logic
|
||||
if forward_context.is_draft_model:
|
||||
graph_params = get_draft_graph_params()
|
||||
else:
|
||||
graph_params = get_graph_params()
|
||||
with torch.npu.stream(update_stream):
|
||||
for key, param, handle, event in zip(
|
||||
forward_context.attn_metadata,
|
||||
graph_params.attn_params[num_tokens],
|
||||
graph_params.handles[num_tokens],
|
||||
graph_params.events[num_tokens],
|
||||
):
|
||||
(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
num_heads,
|
||||
scale,
|
||||
block_table,
|
||||
seq_lens,
|
||||
output,
|
||||
) = param
|
||||
seq_lens = forward_context.attn_metadata[key].seq_lens
|
||||
|
||||
workspace = torch_npu._npu_paged_attention_get_workspace(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_heads=num_heads,
|
||||
scale_value=scale,
|
||||
block_table=block_table,
|
||||
context_lens=seq_lens,
|
||||
out=output,
|
||||
)
|
||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_heads=num_heads,
|
||||
scale_value=scale,
|
||||
block_table=block_table,
|
||||
context_lens=seq_lens,
|
||||
out=output,
|
||||
workspace=workspace,
|
||||
)
|
||||
torch.npu.graph_task_update_end(update_stream)
|
||||
event.record(update_stream)
|
||||
else:
|
||||
# FIA update logic
|
||||
if forward_context.is_draft_model:
|
||||
graph_params = get_draft_graph_params()
|
||||
attn_metadata = forward_context.draft_attn_metadatas
|
||||
attn_keys = list(attn_metadata[0].keys())
|
||||
else:
|
||||
graph_params = get_graph_params()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
attn_keys = list(attn_metadata.keys())
|
||||
# For Qwen3-next, since the kv_cache_config has already categorized
|
||||
# linear_attn and self_attn, the attn_metadata is first arranged with
|
||||
# self_attn followed by linear_attn. Therefore, using zip directly
|
||||
# filters out the update operations for linear_attn.
|
||||
# TODO: We use a new variable `attn_keys` to ensure the loop count is
|
||||
# correct after get by `zip` because of the new structure of the attn_metadata
|
||||
# when running with the merged full eagle-graph. Should check it with Qwen3-next.
|
||||
num_layers = len(attn_keys)
|
||||
if num_layers == 0:
|
||||
return
|
||||
if forward_context.is_draft_model:
|
||||
attn_keys = attn_keys * (len(graph_params.attn_params[num_tokens]) // num_layers)
|
||||
attn_count = 0
|
||||
with torch.npu.stream(update_stream):
|
||||
for key, param, handle, event in zip(
|
||||
attn_keys,
|
||||
graph_params.attn_params[num_tokens],
|
||||
graph_params.handles[num_tokens],
|
||||
graph_params.events[num_tokens],
|
||||
):
|
||||
(
|
||||
query,
|
||||
key_cache,
|
||||
value,
|
||||
block_tables,
|
||||
attn_mask,
|
||||
block_size,
|
||||
seq_lens,
|
||||
query_start_loc,
|
||||
num_kv_heads,
|
||||
num_heads,
|
||||
scale,
|
||||
attn_output,
|
||||
softmax_lse,
|
||||
) = param
|
||||
|
||||
if forward_context.is_draft_model:
|
||||
draft_step = attn_count // num_layers
|
||||
seq_lens = attn_metadata[draft_step][key].seq_lens_list
|
||||
actual_seq_lengths_q = attn_metadata[draft_step][key].actual_seq_lengths_q
|
||||
attn_count = attn_count + 1
|
||||
else:
|
||||
seq_lens = attn_metadata[key].seq_lens_list
|
||||
actual_seq_lengths_q = attn_metadata[key].actual_seq_lengths_q
|
||||
|
||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||
torch_npu.npu_fused_infer_attention_score.out(
|
||||
query=query,
|
||||
key=key_cache,
|
||||
value=value,
|
||||
block_table=block_tables,
|
||||
atten_mask=attn_mask,
|
||||
input_layout="TND",
|
||||
block_size=block_size,
|
||||
actual_seq_lengths=actual_seq_lengths_q,
|
||||
actual_seq_lengths_kv=seq_lens,
|
||||
num_key_value_heads=num_kv_heads,
|
||||
num_heads=num_heads,
|
||||
scale=scale,
|
||||
sparse_mode=3,
|
||||
workspace=graph_params.workspaces.get(num_tokens),
|
||||
out=[attn_output, softmax_lse],
|
||||
)
|
||||
torch.npu.graph_task_update_end(update_stream)
|
||||
|
||||
event.record(update_stream)
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
super().process_weights_after_loading(act_dtype)
|
||||
if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
|
||||
|
||||
Reference in New Issue
Block a user