[main] mlp weight prefetch in Qwen Dense Models (#2816)

### What this PR does / why we need it?
This PR prefetches the weights of the MLP layers in Qwen dense models,
mainly to improve performance in the decode phase.

### Does this PR introduce _any_ user-facing change?
 No.

### How was this patch tested?
CI passed with newly added and existing tests.

- vLLM version: main
- vLLM main:
a1213fae5f

Signed-off-by: rjg-lyh <1318825571@qq.com>
Co-authored-by: Shuming19 <313093131@qq.com>
This commit is contained in:
rjg-lyh
2025-09-11 21:20:09 +08:00
committed by GitHub
parent c3c2221503
commit 0005479b9c
17 changed files with 313 additions and 24 deletions

View File

@@ -44,12 +44,7 @@ class AddRMSNormW8A8Quant(RMSNorm):
import torch_npu
if residual is not None:
# FIXME(rjg-lyh): This is a hacky way to chunk residuals when the flashcomm_v1 feature
# is enabled, without interfering with the normal operation of components like torchair.
# The final solution should be to move this check into the operator and support
# integration with torchair.
if x.size(0) != residual.size(0):
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
assert x.size(0) == residual.size(0)
x, _, residual = torch_npu.npu_add_rms_norm_quant(
x,
@@ -58,6 +53,7 @@ class AddRMSNormW8A8Quant(RMSNorm):
self.layer.aclnn_input_scale,
self.layer.aclnn_input_offset,
epsilon=self.variance_epsilon)
torch.ops.vllm.maybe_wait_prefetch_done(x)
return x, residual
x, residual = torch_npu.npu_rms_norm(x, self.weight,
@@ -76,12 +72,7 @@ class AscendRMSNorm(RMSNorm):
from vllm_ascend.utils import is_310p
if residual is not None:
# FIXME(rjg-lyh): This is a hacky way to chunk residuals when the flashcomm_v1 feature
# is enabled, without interfering with the normal operation of components like torchair.
# The final solution should be to move this check into the operator and support
# integration with torchair.
if x.size(0) != residual.size(0):
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
assert x.size(0) == residual.size(0)
if is_310p():
orig_dtype = residual.dtype
@@ -92,6 +83,7 @@ class AscendRMSNorm(RMSNorm):
else:
x, _, residual = torch_npu.npu_add_rms_norm(
x, residual, self.weight, self.variance_epsilon)
torch.ops.vllm.maybe_wait_prefetch_done(x)
return x, residual
x, residual = torch_npu.npu_rms_norm(x, self.weight,