[main] flashcomm_v1 optim in Qwen Dense Models (#2802)

### What this PR does / why we need it?
Flashcomm_v1 optimization in Qwen dense models.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with newly added and existing tests.

- vLLM version: v0.10.1.1
- vLLM main:
5e537f45b4

Co-authored-by: 1024daniel <xxltju324@gmail.com>
This commit is contained in:
rjg-lyh
2025-09-08 22:52:24 +08:00
committed by GitHub
parent 4df8df5b94
commit 1bbb20ea13
11 changed files with 362 additions and 20 deletions

View File

@@ -83,6 +83,7 @@ def set_ascend_forward_context(
forward_context = get_forward_context()
forward_context.moe_comm_method_name = moe_comm_method + "commimpl"
forward_context.with_prefill = with_prefill
tp_world_size = get_tensor_model_parallel_world_size()
ep_size = (get_ep_group().world_size if
vllm_config.parallel_config.enable_expert_parallel else 1)
@@ -103,6 +104,21 @@ def set_ascend_forward_context(
# due to multiple warmups before actual capturing
forward_context.capturing = False
# Settings for flashcomm_v1. 1000 is the batch-size concurrency threshold (compared against
# num_tokens) for enabling the flashcomm_v1 feature. It is currently an empirical value.
# In normal scenarios, when the concurrency exceeds this threshold the performance benefit
# is maximized; when the concurrency is below the threshold, performance may degrade
# because of the cost of switching communication methods.
flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \
tp_world_size > 1 and \
num_tokens is not None and num_tokens > 1000
if flashcomm_v1_enabled:
pad_size = (tp_world_size -
(num_tokens % tp_world_size)) % tp_world_size
forward_context.pad_size = pad_size
forward_context.flashcomm_v1_enabled = flashcomm_v1_enabled
if num_tokens is None and attn_metadata is not None:
num_tokens = attn_metadata.num_actual_tokens
@@ -118,7 +134,6 @@ def set_ascend_forward_context(
if num_tokens is not None:
if num_actual_tokens is None:
num_actual_tokens = num_tokens
tp_world_size = get_tensor_model_parallel_world_size()
# NOTE: token num which need to pad to when mc2
forward_context.padded_num_tokens = math.ceil(
max_tokens_across_dp / tp_world_size) * tp_world_size