[main] mlp weight prefetch in Qwen Dense Models (#2816)
### What this PR does / why we need it?
This PR prefetchs the weight of mlp layers in Qwen Dense Models to
optimize the performance in Decode phase mainly.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI passed with new added/existing test.
- vLLM version: main
- vLLM main:
a1213fae5f
Signed-off-by: rjg-lyh <1318825571@qq.com>
Co-authored-by: Shuming19 <313093131@qq.com>
This commit is contained in:
@@ -44,12 +44,7 @@ class AddRMSNormW8A8Quant(RMSNorm):
|
||||
import torch_npu
|
||||
|
||||
if residual is not None:
|
||||
# FIXME(rjg-lyh): This is a hacky way to chunk residuals when the flashcomm_v1 feature
|
||||
# is enabled, without interfering with the normal operation of components like torchair.
|
||||
# The final solution should be to move this check into the operator and support
|
||||
# integration with torchair.
|
||||
if x.size(0) != residual.size(0):
|
||||
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
|
||||
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
|
||||
assert x.size(0) == residual.size(0)
|
||||
x, _, residual = torch_npu.npu_add_rms_norm_quant(
|
||||
x,
|
||||
@@ -58,6 +53,7 @@ class AddRMSNormW8A8Quant(RMSNorm):
|
||||
self.layer.aclnn_input_scale,
|
||||
self.layer.aclnn_input_offset,
|
||||
epsilon=self.variance_epsilon)
|
||||
torch.ops.vllm.maybe_wait_prefetch_done(x)
|
||||
return x, residual
|
||||
|
||||
x, residual = torch_npu.npu_rms_norm(x, self.weight,
|
||||
@@ -76,12 +72,7 @@ class AscendRMSNorm(RMSNorm):
|
||||
|
||||
from vllm_ascend.utils import is_310p
|
||||
if residual is not None:
|
||||
# FIXME(rjg-lyh): This is a hacky way to chunk residuals when the flashcomm_v1 feature
|
||||
# is enabled, without interfering with the normal operation of components like torchair.
|
||||
# The final solution should be to move this check into the operator and support
|
||||
# integration with torchair.
|
||||
if x.size(0) != residual.size(0):
|
||||
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
|
||||
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
|
||||
assert x.size(0) == residual.size(0)
|
||||
if is_310p():
|
||||
orig_dtype = residual.dtype
|
||||
@@ -92,6 +83,7 @@ class AscendRMSNorm(RMSNorm):
|
||||
else:
|
||||
x, _, residual = torch_npu.npu_add_rms_norm(
|
||||
x, residual, self.weight, self.variance_epsilon)
|
||||
torch.ops.vllm.maybe_wait_prefetch_done(x)
|
||||
return x, residual
|
||||
|
||||
x, residual = torch_npu.npu_rms_norm(x, self.weight,
|
||||
|
||||
Reference in New Issue
Block a user