[main] mlp weight prefetch in Qwen Dense Models (#2816)

### What this PR does / why we need it?
This PR prefetchs the weight of mlp layers in Qwen Dense Models to
optimize the performance in Decode phase mainly.

### Does this PR introduce _any_ user-facing change?
 No.

### How was this patch tested?
CI passed with new added/existing test.

- vLLM version: main
- vLLM main:
a1213fae5f

Signed-off-by: rjg-lyh <1318825571@qq.com>
Co-authored-by: Shuming19 <313093131@qq.com>
This commit is contained in:
rjg-lyh
2025-09-11 21:20:09 +08:00
committed by GitHub
parent c3c2221503
commit 0005479b9c
17 changed files with 313 additions and 24 deletions

View File

@@ -227,6 +227,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.device = device
if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP:
self.prefetch_stream = torch.npu.Stream(device=device)
else:
self.prefetch_stream = None
self.dtype = self.model_config.dtype
if envs_ascend.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION:
# TODO: drop the env config to use ascend sampler by default
@@ -1592,7 +1596,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
aclgraph_runtime_mode=aclgraph_runtime_mode,
batch_descriptor=batch_descriptor,
num_actual_tokens=scheduler_output.
total_num_scheduled_tokens):
total_num_scheduled_tokens,
prefetch_stream=self.prefetch_stream,
model_instance=self.model):
self.maybe_setup_kv_connector(scheduler_output)
hidden_states = self._generate_process_reqs_hidden_states(
@@ -2057,7 +2063,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
moe_comm_method=moe_comm_method,
num_actual_tokens=0,
aclgraph_runtime_mode=aclgraph_runtime_mode,
batch_descriptor=batch_descriptor):
batch_descriptor=batch_descriptor,
prefetch_stream=self.prefetch_stream,
model_instance=self.model):
hidden_states = self._generate_dummy_run_hidden_states(
with_prefill, is_torchair_compile, input_ids, positions,
attn_metadata, num_tokens, intermediate_tensors,

View File

@@ -51,6 +51,18 @@ from vllm_ascend.utils import (init_ascend_soc_version,
try_register_lib)
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
torch._dynamo.trace_rules.clear_lru_cache() # noqa: E402
from torch._dynamo.variables import TorchInGraphFunctionVariable # noqa: E402
torch_non_c_binding_in_graph_functions_npu = dict.fromkeys(
["torch.npu.current_stream"],
TorchInGraphFunctionVariable,
) # noqa: E402
torch_non_c_binding_in_graph_functions_npu[
"torch.npu.stream"] = TorchInGraphFunctionVariable # noqa: E402
torch._dynamo.trace_rules.torch_name_rule_map.append(
torch_non_c_binding_in_graph_functions_npu) # noqa: E402
class NPUWorker(WorkerBase):