[Refactor] Unify full-graph parameter update logic (#6041)

### What this PR does / why we need it?

**Refactor: Unify full-graph parameter update logic**

This PR consolidates the scattered full-graph parameter update logic
into a unified approach, improving code architecture and eliminating
duplication.

**Key improvements:**

1. **Unified interface**
- Create `update_full_graph_params` as the single entry point for all
full-graph updates
   - Replace multiple scattered update calls with one unified function
- Remove ~50 lines of duplicated if-else logic across
`model_runner_v1.py` and `eagle_proposer.py`

2. **Better architecture**
- Move update logic to respective Backend classes
(`AscendAttentionBackend`, `AscendMLABackend`)
   - Each Backend manages its own parameter update logic internally
   - Simplify caller code to just dispatch to the appropriate Backend

3. **Cleaner parameter handling**
   - Remove unnecessary `pcp_size` and `dcp_size` parameter passing
   - Get parallel configuration directly from distributed groups
   - Consistent with how other parts of the codebase obtain these values

**Why we need it:**
- **Maintainability**: Future changes only need to be made in one place
per Backend
- **Code quality**: Follows DRY principle and Single Responsibility
Principle
- **Readability**: Cleaner, more intuitive code structure

### Does this PR introduce _any_ user-facing change?

**No.** This is a pure refactoring with no functional changes - same
behavior, cleaner code.

### How was this patch tested?

- All existing unit tests pass with updated mocks
- No new tests needed (pure refactoring, no behavior changes)
- CI validates correctness

---

- vLLM version: v0.13.0

Signed-off-by: lico67373 <918688502@qq.com>
Co-authored-by: drslark <slarksblood@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
LICO67373
2026-01-24 20:12:57 +08:00
committed by GitHub
parent 8129c429ef
commit 8966a99710
10 changed files with 420 additions and 415 deletions

View File

@@ -371,6 +371,144 @@ class AscendAttentionBackendImpl(AttentionImpl):
self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
)
@staticmethod
def update_graph_params(
update_stream,
forward_context,
num_tokens,
vllm_config,
speculative_config=None,
num_dcp_pcp_tokens=None,
):
if using_paged_attention(num_tokens, vllm_config):
# Paged Attention update logic
if forward_context.is_draft_model:
graph_params = get_draft_graph_params()
else:
graph_params = get_graph_params()
with torch.npu.stream(update_stream):
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[num_tokens],
graph_params.handles[num_tokens],
graph_params.events[num_tokens],
):
(
query,
key_cache,
value_cache,
num_kv_heads,
num_heads,
scale,
block_table,
seq_lens,
output,
) = param
seq_lens = forward_context.attn_metadata[key].seq_lens
workspace = torch_npu._npu_paged_attention_get_workspace(
query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output,
)
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu._npu_paged_attention(
query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output,
workspace=workspace,
)
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)
else:
# FIA update logic
if forward_context.is_draft_model:
graph_params = get_draft_graph_params()
attn_metadata = forward_context.draft_attn_metadatas
attn_keys = list(attn_metadata[0].keys())
else:
graph_params = get_graph_params()
attn_metadata = forward_context.attn_metadata
attn_keys = list(attn_metadata.keys())
# For Qwen3-next, since the kv_cache_config has already categorized
# linear_attn and self_attn, the attn_metadata is first arranged with
# self_attn followed by linear_attn. Therefore, using zip directly
# filters out the update operations for linear_attn.
# TODO: We use a new variable `attn_keys` to ensure the loop count is
# correct after get by `zip` because of the new structure of the attn_metadata
# when running with the merged full eagle-graph. Should check it with Qwen3-next.
num_layers = len(attn_keys)
if num_layers == 0:
return
if forward_context.is_draft_model:
attn_keys = attn_keys * (len(graph_params.attn_params[num_tokens]) // num_layers)
attn_count = 0
with torch.npu.stream(update_stream):
for key, param, handle, event in zip(
attn_keys,
graph_params.attn_params[num_tokens],
graph_params.handles[num_tokens],
graph_params.events[num_tokens],
):
(
query,
key_cache,
value,
block_tables,
attn_mask,
block_size,
seq_lens,
query_start_loc,
num_kv_heads,
num_heads,
scale,
attn_output,
softmax_lse,
) = param
if forward_context.is_draft_model:
draft_step = attn_count // num_layers
seq_lens = attn_metadata[draft_step][key].seq_lens_list
actual_seq_lengths_q = attn_metadata[draft_step][key].actual_seq_lengths_q
attn_count = attn_count + 1
else:
seq_lens = attn_metadata[key].seq_lens_list
actual_seq_lengths_q = attn_metadata[key].actual_seq_lengths_q
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu.npu_fused_infer_attention_score.out(
query=query,
key=key_cache,
value=value,
block_table=block_tables,
atten_mask=attn_mask,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=actual_seq_lengths_q,
actual_seq_lengths_kv=seq_lens,
num_key_value_heads=num_kv_heads,
num_heads=num_heads,
scale=scale,
sparse_mode=3,
workspace=graph_params.workspaces.get(num_tokens),
out=[attn_output, softmax_lse],
)
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)
def process_weights_after_loading(self, act_dtype: torch.dtype):
super().process_weights_after_loading(act_dtype)
if flashcomm2_oshard_manager.flashcomm2_oshard_enable():