[main] mlp weight prefetch in Qwen Dense Models (#2816)
### What this PR does / why we need it?
This PR prefetches the weights of the MLP layers in Qwen dense models, mainly to
optimize performance in the decode phase.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI passed with new added/existing test.
- vLLM version: main
- vLLM main:
a1213fae5f
Signed-off-by: rjg-lyh <1318825571@qq.com>
Co-authored-by: Shuming19 <313093131@qq.com>
This commit is contained in:
@@ -35,8 +35,10 @@ class AscendSiluAndMul(SiluAndMul):
|
||||
|
||||
from vllm_ascend.utils import is_310p
|
||||
|
||||
torch.ops.vllm.maybe_prefetch_mlp_down_proj(x)
|
||||
if is_310p():
|
||||
out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
|
||||
else:
|
||||
out = torch_npu.npu_swiglu(x)
|
||||
torch.ops.vllm.maybe_wait_prefetch_done(out)
|
||||
return out
|
||||
|
||||
@@ -44,12 +44,7 @@ class AddRMSNormW8A8Quant(RMSNorm):
|
||||
import torch_npu
|
||||
|
||||
if residual is not None:
|
||||
# FIXME(rjg-lyh): This is a hacky way to chunk residuals when the flashcomm_v1 feature
|
||||
# is enabled, without interfering with the normal operation of components like torchair.
|
||||
# The final solution should be to move this check into the operator and support
|
||||
# integration with torchair.
|
||||
if x.size(0) != residual.size(0):
|
||||
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
|
||||
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
|
||||
assert x.size(0) == residual.size(0)
|
||||
x, _, residual = torch_npu.npu_add_rms_norm_quant(
|
||||
x,
|
||||
@@ -58,6 +53,7 @@ class AddRMSNormW8A8Quant(RMSNorm):
|
||||
self.layer.aclnn_input_scale,
|
||||
self.layer.aclnn_input_offset,
|
||||
epsilon=self.variance_epsilon)
|
||||
torch.ops.vllm.maybe_wait_prefetch_done(x)
|
||||
return x, residual
|
||||
|
||||
x, residual = torch_npu.npu_rms_norm(x, self.weight,
|
||||
@@ -76,12 +72,7 @@ class AscendRMSNorm(RMSNorm):
|
||||
|
||||
from vllm_ascend.utils import is_310p
|
||||
if residual is not None:
|
||||
# FIXME(rjg-lyh): This is a hacky way to chunk residuals when the flashcomm_v1 feature
|
||||
# is enabled, without interfering with the normal operation of components like torchair.
|
||||
# The final solution should be to move this check into the operator and support
|
||||
# integration with torchair.
|
||||
if x.size(0) != residual.size(0):
|
||||
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
|
||||
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
|
||||
assert x.size(0) == residual.size(0)
|
||||
if is_310p():
|
||||
orig_dtype = residual.dtype
|
||||
@@ -92,6 +83,7 @@ class AscendRMSNorm(RMSNorm):
|
||||
else:
|
||||
x, _, residual = torch_npu.npu_add_rms_norm(
|
||||
x, residual, self.weight, self.variance_epsilon)
|
||||
torch.ops.vllm.maybe_wait_prefetch_done(x)
|
||||
return x, residual
|
||||
|
||||
x, residual = torch_npu.npu_rms_norm(x, self.weight,
|
||||
|
||||
@@ -390,6 +390,7 @@ class AscendRowParallelLinear(RowParallelLinear):
|
||||
input_parallel,
|
||||
bias=bias_)
|
||||
output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel)
|
||||
torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(output, self.prefix)
|
||||
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather,
|
||||
@@ -8,10 +9,16 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
|
||||
|
||||
def _maybe_chunk_residual_impl(x: torch.Tensor,
|
||||
residual: torch.Tensor) -> torch.Tensor:
|
||||
if get_forward_context().flashcomm_v1_enabled:
|
||||
if x.size(0) != residual.size(0):
|
||||
flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled
|
||||
assert flashcomm_v1_enabled is True, (
|
||||
"Currently, this situation only occurs "
|
||||
"when flashcomm_v1 is enabled")
|
||||
pad_size = get_forward_context().pad_size
|
||||
if pad_size > 0:
|
||||
residual = F.pad(residual, (0, 0, 0, pad_size))
|
||||
@@ -44,6 +51,76 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
|
||||
return tensor_model_parallel_all_reduce(x)
|
||||
|
||||
|
||||
def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
                                          prefix: str) -> None:
    """Kick off an async prefetch of an MLP gate_up_proj weight on a side stream.

    No-op unless weight prefetching is enabled in the current forward context.
    `x_dependency` is the tensor the prefetch is ordered against.
    """
    ctx = get_forward_context()
    if not ctx.prefetch_mlp_enabled:
        return
    # prefix looks like "<root>.layers.<idx>.<...>"; component 2 is the
    # decoder-layer index — assumes that prefix shape, TODO confirm callers.
    layer_idx = int(prefix.split('.')[2])

    # Reaching the self_attn sub-module marks the start point of the
    # gate_up_proj weight prefetch for this layer.
    if prefix.split('.')[-2] == "self_attn":
        ctx.prefetch_mlp_gate_up_proj = True
    if not ctx.prefetch_mlp_gate_up_proj:
        return

    side_stream = ctx.prefetch_stream
    # Order the side stream after all work already queued on the main stream.
    side_stream.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(side_stream):
        weight = ctx.model_instance.model.layers[
            layer_idx].mlp.gate_up_proj.weight
        torch_npu.npu_prefetch(weight, x_dependency,
                               envs_ascend.MLP_GATE_UP_PREFETCH_SIZE)
|
||||
|
||||
|
||||
def _maybe_prefetch_mlp_gate_up_proj_impl_fake(x_dependency: torch.Tensor,
|
||||
prefix: str) -> None:
|
||||
return
|
||||
|
||||
|
||||
def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
    """Kick off an async prefetch of the current layer's down_proj weight.

    No-op unless weight prefetching is enabled in the current forward context.
    Advances the context's running layer index after issuing the prefetch.
    """
    ctx = get_forward_context()
    if not ctx.prefetch_mlp_enabled:
        return
    ctx.prefetch_mlp_down_proj = True
    idx = ctx.layer_idx

    # Start point of the down_proj weight prefetch: order the side stream
    # after everything already queued on the main stream.
    side_stream = ctx.prefetch_stream
    side_stream.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(side_stream):
        weight = ctx.model_instance.model.layers[idx].mlp.down_proj.weight
        torch_npu.npu_prefetch(weight, x_dependency,
                               envs_ascend.MLP_DOWN_PREFETCH_SIZE)
    ctx.layer_idx += 1
|
||||
|
||||
|
||||
def _maybe_prefetch_mlp_down_proj_impl_fake(
|
||||
x_dependency: torch.Tensor) -> None:
|
||||
return
|
||||
|
||||
|
||||
def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None:
    """Make the main stream wait for any in-flight MLP weight prefetch.

    No-op unless weight prefetching is enabled. If either a gate_up_proj or a
    down_proj prefetch was started, synchronizes the current stream with the
    prefetch stream and clears both in-flight flags.
    """
    forward_context = get_forward_context()
    if not forward_context.prefetch_mlp_enabled:
        return
    if forward_context.prefetch_mlp_gate_up_proj or \
            forward_context.prefetch_mlp_down_proj:
        # Reuse the context fetched above instead of calling
        # get_forward_context() a second time (consistent with the other
        # prefetch impls in this module).
        prefetch_stream = forward_context.prefetch_stream
        # wait until prefetch done
        torch.npu.current_stream().wait_stream(prefetch_stream)
        forward_context.prefetch_mlp_gate_up_proj = False
        forward_context.prefetch_mlp_down_proj = False
    return
|
||||
|
||||
|
||||
def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(op_name="maybe_chunk_residual",
|
||||
op_func=_maybe_chunk_residual_impl,
|
||||
fake_impl=lambda x, residual: residual,
|
||||
@@ -60,4 +137,22 @@ direct_register_custom_op(op_name="maybe_pad_and_reduce",
|
||||
op_func=_maybe_pad_and_reduce_impl,
|
||||
fake_impl=lambda x: x,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1")
|
||||
dispatch_key="PrivateUse1")
|
||||
|
||||
# Register the prefetch hooks as vLLM custom ops so they are callable as
# torch.ops.vllm.<op_name> from the layer code above. Each registration pairs
# the real NPU implementation with a no-op fake impl for tracing/meta
# execution. mutates_args=[] declares that no *tensor argument* is mutated;
# dispatch_key="PrivateUse1" routes the op to the NPU backend.
direct_register_custom_op(op_name="maybe_prefetch_mlp_gate_up_proj",
                          op_func=_maybe_prefetch_mlp_gate_up_proj_impl,
                          fake_impl=_maybe_prefetch_mlp_gate_up_proj_impl_fake,
                          mutates_args=[],
                          dispatch_key="PrivateUse1")

direct_register_custom_op(op_name="maybe_prefetch_mlp_down_proj",
                          op_func=_maybe_prefetch_mlp_down_proj_impl,
                          fake_impl=_maybe_prefetch_mlp_down_proj_impl_fake,
                          mutates_args=[],
                          dispatch_key="PrivateUse1")

direct_register_custom_op(op_name="maybe_wait_prefetch_done",
                          op_func=_maybe_wait_prefetch_done_impl,
                          fake_impl=_maybe_wait_prefetch_done_impl_fake,
                          mutates_args=[],
                          dispatch_key="PrivateUse1")
|
||||
|
||||
Reference in New Issue
Block a user