[perf][refactor] Refactor and optimize sfa_v1.py for dsv3.2/glm5 (#6874)

### What this PR does / why we need it?
This PR refactors sfa_v1.py to improve code readability and usability,
fixes a bug, and improves performance by replacing certain operators.

### Changes
- **Improve code readability**: restructures parts of sfa_v1.py, adds
explanatory comments to key code blocks, removes some unused variables,
and improves the naming of several functions and variables.

- **Fix a duplicated write to k_cache**: the indexer_select module wrote
k_cache twice, once in the `forward` function and again in
`indexer_select_post_process`; the redundant write is removed, which
improves performance.

- **Replace `scatter` ops with `reshape_and_cache`**: replaces the two
separate cache-storage operations for `k_nope` and `k_pe` with a single
call to the `reshape_and_cache` operator. The original `scatter` operator
reorders slot_mapping for the sake of generality, which introduces
significant scalar computation; `reshape_and_cache` skips this reordering
step, reducing unnecessary computation time and improving operator
performance (see the sketch after this list).

### Performance comparison
Setup: 4×A3, 1P1D disaggregation, P dp2tp16, D dp8tp4, input/output: 64K/3K.

| Variant | TTFT | TPOT | TPS |
| --- | --- | --- | --- |
| Original | **28s** | 26ms | **820 token/s** |
| + fixed redundant double writes of k_cache | **24s** | 26ms | **840 token/s** |
| + replaced scatter ops with reshape_and_cache | **24s** | 26ms | **850 token/s** |

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with newly added and existing tests.

- vLLM version: v0.16.0
- vLLM main:
15d76f74e2

---------

Signed-off-by: rjg-lyh <1318825571@qq.com>
Author: rjg-lyh
Date: 2026-03-05 14:27:11 +08:00 (committed by GitHub)
Commit: 2bd9c35788 (parent 77e009d9fc)
4 changed files with 676 additions and 515 deletions


@@ -1138,10 +1138,24 @@ def enable_dsa_cp_with_layer_shard() -> bool:
    from vllm.config import get_current_vllm_config
    vllm_config = get_current_vllm_config()
    # Because the broadcast in layer sharding needs to be overlapped with a
    # heavy compute stream to be effectively hidden, it is enabled only
    # during the prefill stage.
    is_prefill_instance = (vllm_config.kv_transfer_config is not None
                           and vllm_config.kv_transfer_config.is_kv_producer)
    return is_prefill_instance


@lru_cache(maxsize=1)
def enable_dsa_cp_with_o_proj_tp() -> bool:
    if not enable_dsa_cp():
        return False
    from vllm.config import get_current_vllm_config
    vllm_config = get_current_vllm_config()
    # In the PD-mixed stage, use the original TP o_proj weight; the o_proj
    # weight also needs to be fully gathered for the prefill stage.
    return vllm_config.kv_transfer_config is None


def check_gdn_layer(vllm_config) -> bool:
    """
    gdn layer is marked with `linear_attention`.