[perf][refactor] Refactor and optimize sfa_v1.py for dsv3.2/glm5 (#6874)
### What this PR does / why we need it?
This PR refactors sfa_v1.py to improve code readability and usability,
fixes a code bug, and enhances performance through the replacement of
certain operators.
### changes
- **improve code readability**: Optimizes parts of the code structure in
sfa_v1.py, supplementary comments for key code blocks, removes some
unused variables, and improves the naming of certain functions and
variables.
- **resolved a duplicated double write to k_cache**: Fixed redundant
double writes of k_cache in the indexer_select module (in both the
`forward` function and `indexer_select_post_process`), improving
performance to some extent.
- **replace `scatter` ops with `reshape_and_cache`**: This optimization
replaces two separate cache storage operations on `k_nope` and `k_pe`
with a single call to the `reshape_and_cache` operator, improving
performance. The original `scatter` operator involves reordering
slot_mapping for generality, introducing significant scalar
computations. In contrast, the `reshape_and_cache` operator eliminates
this redundant reordering step, thus reducing unnecessary computation
time and enhancing the operator's performance.
### performance comparison
4*A3, 1P1D, P dp2tp16, D dp8tp4, input/output: 64K/3K
origin:
TTFT: **28s**, TPOT: 26ms, TPS: **820 token/s**
fixed redundant double writes of k_cache:
TTFT: **24s**, TPOT: 26ms, TPS: **840 token/s**
replace scatter ops with reshape_and_cache:
TTFT: **24s**, TPOT: 26ms, TPS: **850 token/s**
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI passed with new added/existing test.
- vLLM version: v0.16.0
- vLLM main:
15d76f74e2
---------
Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
@@ -1138,10 +1138,24 @@ def enable_dsa_cp_with_layer_shard() -> bool:
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
# because the broadcast in layer sharding needs to be overlapped with a heavy compute stream to be
|
||||
# effectively hidden, it is enabled only during the prefill stage.
|
||||
is_prefill_instance = vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_producer
|
||||
return is_prefill_instance
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def enable_dsa_cp_with_o_proj_tp() -> bool:
|
||||
if not enable_dsa_cp():
|
||||
return False
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
# if is PD mix stage, using original TP o_proj weight, and also need to
|
||||
# full gather for o_proj weight for prefill stage.
|
||||
return vllm_config.kv_transfer_config is None
|
||||
|
||||
|
||||
def check_gdn_layer(vllm_config) -> bool:
|
||||
"""
|
||||
gdn layer is marked with `linear_attention`.
|
||||
|
||||
Reference in New Issue
Block a user