[main] mlp weight prefetch in Qwen Dense Models (#2816)

### What this PR does / why we need it?
This PR prefetchs the weight of mlp layers in Qwen Dense Models to
optimize the performance in Decode phase mainly.

### Does this PR introduce _any_ user-facing change?
 No.

### How was this patch tested?
CI passed with new added/existing test.

- vLLM version: main
- vLLM main:
a1213fae5f

Signed-off-by: rjg-lyh <1318825571@qq.com>
Co-authored-by: Shuming19 <313093131@qq.com>
This commit is contained in:
rjg-lyh
2025-09-11 21:20:09 +08:00
committed by GitHub
parent c3c2221503
commit 0005479b9c
17 changed files with 313 additions and 24 deletions

View File

@@ -35,8 +35,10 @@ class AscendSiluAndMul(SiluAndMul):
from vllm_ascend.utils import is_310p
torch.ops.vllm.maybe_prefetch_mlp_down_proj(x)
if is_310p():
out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
else:
out = torch_npu.npu_swiglu(x)
torch.ops.vllm.maybe_wait_prefetch_done(out)
return out

View File

@@ -44,12 +44,7 @@ class AddRMSNormW8A8Quant(RMSNorm):
import torch_npu
if residual is not None:
# FIXME(rjg-lyh): This is a hacky way to chunk residuals when the flashcomm_v1 feature
# is enabled, without interfering with the normal operation of components like torchair.
# The final solution should be to move this check into the operator and support
# integration with torchair.
if x.size(0) != residual.size(0):
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
assert x.size(0) == residual.size(0)
x, _, residual = torch_npu.npu_add_rms_norm_quant(
x,
@@ -58,6 +53,7 @@ class AddRMSNormW8A8Quant(RMSNorm):
self.layer.aclnn_input_scale,
self.layer.aclnn_input_offset,
epsilon=self.variance_epsilon)
torch.ops.vllm.maybe_wait_prefetch_done(x)
return x, residual
x, residual = torch_npu.npu_rms_norm(x, self.weight,
@@ -76,12 +72,7 @@ class AscendRMSNorm(RMSNorm):
from vllm_ascend.utils import is_310p
if residual is not None:
# FIXME(rjg-lyh): This is a hacky way to chunk residuals when the flashcomm_v1 feature
# is enabled, without interfering with the normal operation of components like torchair.
# The final solution should be to move this check into the operator and support
# integration with torchair.
if x.size(0) != residual.size(0):
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
assert x.size(0) == residual.size(0)
if is_310p():
orig_dtype = residual.dtype
@@ -92,6 +83,7 @@ class AscendRMSNorm(RMSNorm):
else:
x, _, residual = torch_npu.npu_add_rms_norm(
x, residual, self.weight, self.variance_epsilon)
torch.ops.vllm.maybe_wait_prefetch_done(x)
return x, residual
x, residual = torch_npu.npu_rms_norm(x, self.weight,

View File

@@ -390,6 +390,7 @@ class AscendRowParallelLinear(RowParallelLinear):
input_parallel,
bias=bias_)
output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel)
torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(output, self.prefix)
output_bias = self.bias if self.skip_bias_add else None

View File

@@ -1,5 +1,6 @@
import torch
import torch.nn.functional as F
import torch_npu
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
@@ -8,10 +9,16 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.forward_context import get_forward_context
from vllm.utils import direct_register_custom_op
import vllm_ascend.envs as envs_ascend
def _maybe_chunk_residual_impl(x: torch.Tensor,
residual: torch.Tensor) -> torch.Tensor:
if get_forward_context().flashcomm_v1_enabled:
if x.size(0) != residual.size(0):
flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled
assert flashcomm_v1_enabled is True, (
"Currently, this situation only occurs "
"when flashcomm_v1 is enabled")
pad_size = get_forward_context().pad_size
if pad_size > 0:
residual = F.pad(residual, (0, 0, 0, pad_size))
@@ -44,6 +51,76 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
return tensor_model_parallel_all_reduce(x)
def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
prefix: str) -> None:
forward_context = get_forward_context()
if not forward_context.prefetch_mlp_enabled:
return
model_instance = forward_context.model_instance
prefetch_stream = forward_context.prefetch_stream
layer_idx = int(prefix.split('.')[2])
# start point of gate_up_proj weight prefetch
if prefix.split('.')[-2] == "self_attn":
forward_context.prefetch_mlp_gate_up_proj = True
if forward_context.prefetch_mlp_gate_up_proj:
prefetch_stream.wait_stream(torch.npu.current_stream())
with torch.npu.stream(prefetch_stream):
MLP_GATE_UP_PREFETCH_SIZE = envs_ascend.MLP_GATE_UP_PREFETCH_SIZE
torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight, \
x_dependency, MLP_GATE_UP_PREFETCH_SIZE)
return
def _maybe_prefetch_mlp_gate_up_proj_impl_fake(x_dependency: torch.Tensor,
prefix: str) -> None:
return
def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
forward_context = get_forward_context()
if not forward_context.prefetch_mlp_enabled:
return
forward_context.prefetch_mlp_down_proj = True
model_instance = forward_context.model_instance
prefetch_stream = forward_context.prefetch_stream
layer_idx = forward_context.layer_idx
# start point of down_proj weight prefetch
prefetch_stream.wait_stream(torch.npu.current_stream())
with torch.npu.stream(prefetch_stream):
MLP_DOWN_PREFETCH_SIZE = envs_ascend.MLP_DOWN_PREFETCH_SIZE
torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.down_proj.weight, \
x_dependency, MLP_DOWN_PREFETCH_SIZE)
forward_context.layer_idx += 1
return
def _maybe_prefetch_mlp_down_proj_impl_fake(
x_dependency: torch.Tensor) -> None:
return
def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None:
forward_context = get_forward_context()
if not forward_context.prefetch_mlp_enabled:
return
if forward_context.prefetch_mlp_gate_up_proj or \
forward_context.prefetch_mlp_down_proj:
prefetch_stream = get_forward_context().prefetch_stream
# wait until prefetch done
torch.npu.current_stream().wait_stream(prefetch_stream)
forward_context.prefetch_mlp_gate_up_proj = False
forward_context.prefetch_mlp_down_proj = False
return
def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None:
return
direct_register_custom_op(op_name="maybe_chunk_residual",
op_func=_maybe_chunk_residual_impl,
fake_impl=lambda x, residual: residual,
@@ -60,4 +137,22 @@ direct_register_custom_op(op_name="maybe_pad_and_reduce",
op_func=_maybe_pad_and_reduce_impl,
fake_impl=lambda x: x,
mutates_args=[],
dispatch_key="PrivateUse1")
dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="maybe_prefetch_mlp_gate_up_proj",
op_func=_maybe_prefetch_mlp_gate_up_proj_impl,
fake_impl=_maybe_prefetch_mlp_gate_up_proj_impl_fake,
mutates_args=[],
dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="maybe_prefetch_mlp_down_proj",
op_func=_maybe_prefetch_mlp_down_proj_impl,
fake_impl=_maybe_prefetch_mlp_down_proj_impl_fake,
mutates_args=[],
dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="maybe_wait_prefetch_done",
op_func=_maybe_wait_prefetch_done_impl,
fake_impl=_maybe_wait_prefetch_done_impl_fake,
mutates_args=[],
dispatch_key="PrivateUse1")