[main] mlp weight prefetch in Qwen Dense Models (#2816)
### What this PR does / why we need it?
This PR prefetchs the weight of mlp layers in Qwen Dense Models to
optimize the performance in Decode phase mainly.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI passed with new added/existing test.
- vLLM version: main
- vLLM main:
a1213fae5f
Signed-off-by: rjg-lyh <1318825571@qq.com>
Co-authored-by: Shuming19 <313093131@qq.com>
This commit is contained in:
@@ -227,6 +227,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.dp_size = vllm_config.parallel_config.data_parallel_size
|
||||
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||
self.device = device
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP:
|
||||
self.prefetch_stream = torch.npu.Stream(device=device)
|
||||
else:
|
||||
self.prefetch_stream = None
|
||||
self.dtype = self.model_config.dtype
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION:
|
||||
# TODO: drop the env config to use ascend sampler by default
|
||||
@@ -1592,7 +1596,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||
batch_descriptor=batch_descriptor,
|
||||
num_actual_tokens=scheduler_output.
|
||||
total_num_scheduled_tokens):
|
||||
total_num_scheduled_tokens,
|
||||
prefetch_stream=self.prefetch_stream,
|
||||
model_instance=self.model):
|
||||
self.maybe_setup_kv_connector(scheduler_output)
|
||||
|
||||
hidden_states = self._generate_process_reqs_hidden_states(
|
||||
@@ -2057,7 +2063,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
moe_comm_method=moe_comm_method,
|
||||
num_actual_tokens=0,
|
||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||
batch_descriptor=batch_descriptor):
|
||||
batch_descriptor=batch_descriptor,
|
||||
prefetch_stream=self.prefetch_stream,
|
||||
model_instance=self.model):
|
||||
hidden_states = self._generate_dummy_run_hidden_states(
|
||||
with_prefill, is_torchair_compile, input_ids, positions,
|
||||
attn_metadata, num_tokens, intermediate_tensors,
|
||||
|
||||
@@ -51,6 +51,18 @@ from vllm_ascend.utils import (init_ascend_soc_version,
|
||||
try_register_lib)
|
||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||
|
||||
torch._dynamo.trace_rules.clear_lru_cache() # noqa: E402
|
||||
from torch._dynamo.variables import TorchInGraphFunctionVariable # noqa: E402
|
||||
|
||||
torch_non_c_binding_in_graph_functions_npu = dict.fromkeys(
|
||||
["torch.npu.current_stream"],
|
||||
TorchInGraphFunctionVariable,
|
||||
) # noqa: E402
|
||||
torch_non_c_binding_in_graph_functions_npu[
|
||||
"torch.npu.stream"] = TorchInGraphFunctionVariable # noqa: E402
|
||||
torch._dynamo.trace_rules.torch_name_rule_map.append(
|
||||
torch_non_c_binding_in_graph_functions_npu) # noqa: E402
|
||||
|
||||
|
||||
class NPUWorker(WorkerBase):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user