[main] mlp weight prefetch in Qwen Dense Models (#2816)

### What this PR does / why we need it?
This PR prefetchs the weight of mlp layers in Qwen Dense Models to
optimize the performance in Decode phase mainly.

### Does this PR introduce _any_ user-facing change?
 No.

### How was this patch tested?
CI passed with new added/existing test.

- vLLM version: main
- vLLM main:
a1213fae5f

Signed-off-by: rjg-lyh <1318825571@qq.com>
Co-authored-by: Shuming19 <313093131@qq.com>
This commit is contained in:
rjg-lyh
2025-09-11 21:20:09 +08:00
committed by GitHub
parent c3c2221503
commit 0005479b9c
17 changed files with 313 additions and 24 deletions

View File

@@ -66,7 +66,9 @@ def set_ascend_forward_context(
moe_comm_method: str = "",
num_actual_tokens: Optional[int] = None,
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: Optional[BatchDescriptor] = None):
batch_descriptor: Optional[BatchDescriptor] = None,
prefetch_stream: torch.npu.Stream = None,
model_instance: torch.nn.Module = None):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
We add some additional param into forward_context.
@@ -108,7 +110,8 @@ def set_ascend_forward_context(
# Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold,
# the performance benefits can be maximized. Conversely, if the concurrency is below the threshold,
# the performance may degrade due to the switching of communication methods.
flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \
flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \
tp_world_size > 1 and \
num_tokens is not None and num_tokens > 1000
@@ -122,6 +125,26 @@ def set_ascend_forward_context(
# set this for rope forward_oot using
forward_context.is_first_layer = True
# set layer_idx to enable optimization features that depend on this information.
# This is only applicable to models that contain these necessary attributes.
forward_context.layer_idx = None
if model_instance is not None and \
hasattr(model_instance, "model") and \
hasattr(model_instance.model, "start_layer"):
forward_context.layer_idx = model_instance.model.start_layer
# set for mlp weight prefetch
prefetch_mlp_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \
forward_context.layer_idx is not None and \
num_tokens is not None and num_tokens < 500
if prefetch_mlp_enabled:
forward_context.prefetch_stream = prefetch_stream
forward_context.model_instance = model_instance
forward_context.prefetch_mlp_gate_up_proj = False
forward_context.prefetch_mlp_down_proj = False
forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
if num_tokens is None and attn_metadata is not None:
num_tokens = attn_metadata.num_actual_tokens