[main] mlp weight prefetch in Qwen Dense Models (#2816)

### What this PR does / why we need it?
This PR prefetches the weights of the MLP layers in Qwen dense models, mainly
to optimize performance in the decode phase.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with new added/existing test.

- vLLM version: main
- vLLM main:
a1213fae5f

Signed-off-by: rjg-lyh <1318825571@qq.com>
Co-authored-by: Shuming19 <313093131@qq.com>
This commit is contained in:
rjg-lyh
2025-09-11 21:20:09 +08:00
committed by GitHub
parent c3c2221503
commit 0005479b9c
17 changed files with 313 additions and 24 deletions

View File

@@ -258,4 +258,4 @@ jobs:
VLLM_WORKER_MULTIPROC_METHOD: spawn VLLM_WORKER_MULTIPROC_METHOD: spawn
VLLM_USE_MODELSCOPE: True VLLM_USE_MODELSCOPE: True
run: | run: |
pytest -sv tests/e2e/multicard/test_qwen3_moe.py::test_models_distributed_Qwen3_MOE_TP2_WITH_EP pytest -sv tests/e2e/multicard/test_qwen3_moe.py::test_models_distributed_Qwen3_MOE_TP2_WITH_EP

View File

@@ -226,6 +226,8 @@ jobs:
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_sp_for_qwen3_moe pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_sp_for_qwen3_moe
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen_Dense_with_flashcomm_v1
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen_Dense_with_prefetch_mlp_weight
#pytest -sv tests/e2e/multicard/test_pipeline_parallel.py #pytest -sv tests/e2e/multicard/test_pipeline_parallel.py
pytest -sv tests/e2e/multicard/test_prefix_caching.py pytest -sv tests/e2e/multicard/test_prefix_caching.py

View File

@@ -31,7 +31,9 @@ from tests.e2e.conftest import VllmRunner
os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256" os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
QWEN_DENSE_MODELS = ["Qwen/QwQ-32B", "Qwen/Qwen-32B"] QWEN_DENSE_MODELS = [
"vllm-ascend/Qwen3-8B-W8A8", "vllm-ascend/Qwen2.5-0.5B-Instruct-W8A8"
]
def test_models_distributed_QwQ(): def test_models_distributed_QwQ():
@@ -170,6 +172,29 @@ def test_models_distributed_Qwen_Dense_with_flashcomm_v1(model, enforce_eager):
max_model_len=8192, max_model_len=8192,
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
dtype="auto", dtype="auto",
tensor_parallel_size=4, tensor_parallel_size=2,
quantization="ascend",
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
@pytest.mark.parametrize("enforce_eager", [True, False])
@pytest.mark.parametrize("model", QWEN_DENSE_MODELS)
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE": "1"})
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_PREFETCH_MLP": "1"})
def test_models_distributed_Qwen_Dense_with_prefetch_mlp_weight(
model, enforce_eager):
example_prompts = [
"Hello, my name is",
]
max_tokens = 5
with VllmRunner(
snapshot_download(model),
max_model_len=8192,
enforce_eager=enforce_eager,
dtype="auto",
tensor_parallel_size=2,
quantization="ascend",
) as vllm_model: ) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens) vllm_model.generate_greedy(example_prompts, max_tokens)

View File

@@ -38,7 +38,12 @@ def test_QuickGELU_forward(mock_gelu, dummy_tensor):
@pytest.mark.parametrize("is_310p_return", [True, False]) @pytest.mark.parametrize("is_310p_return", [True, False])
@patch("torch_npu.npu_swiglu", side_effect=lambda x: x + 1) @patch("torch_npu.npu_swiglu", side_effect=lambda x: x + 1)
def test_SiluAndMul_forward(mock_swiglu, is_310p_return, dummy_tensor): @patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
@patch("torch.ops.vllm.maybe_prefetch_mlp_down_proj",
side_effect=lambda x: None)
def test_SiluAndMul_forward(mock_maybe_prefetch_mlp_down_proj,
mock_maybe_wait_prefetch_done, mock_swiglu,
is_310p_return, dummy_tensor):
with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return): with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
layer = SiluAndMul() layer = SiluAndMul()
@@ -49,9 +54,15 @@ def test_SiluAndMul_forward(mock_swiglu, is_310p_return, dummy_tensor):
else: else:
expected_arg = dummy_tensor expected_arg = dummy_tensor
# assert mock_maybe_prefetch_mlp_down_proj.call_count == 1
mock_maybe_prefetch_mlp_down_proj.assert_called_once()
# assert mock_swiglu.call_count == 1 # assert mock_swiglu.call_count == 1
mock_swiglu.assert_called_once() mock_swiglu.assert_called_once()
# assert mock_maybe_wait_prefetch_done.call_count == 1
mock_maybe_wait_prefetch_done.assert_called_once()
actual_arg = mock_swiglu.call_args[0][0] actual_arg = mock_swiglu.call_args[0][0]
assert torch.allclose( assert torch.allclose(
actual_arg, actual_arg,

View File

@@ -30,9 +30,11 @@ def mock_add_rms_norm(x, residual, weight, eps):
[None, torch.randn(4, 8, dtype=torch.float32)]) [None, torch.randn(4, 8, dtype=torch.float32)])
@patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm) @patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm) @patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
@patch("torch.ops.vllm.maybe_chunk_residual", @patch("torch.ops.vllm.maybe_chunk_residual",
side_effect=mock_maybe_chunk_residual) side_effect=mock_maybe_chunk_residual)
def test_RMSNorm_forward(mock_maybe_chunk_residual, mock_add_rmsnorm, def test_RMSNorm_forward(mock_maybe_chunk_residual,
mock_maybe_wait_prefetch_done, mock_add_rmsnorm,
mock_rmsnorm, is_310p_return, residual, dummy_tensor): mock_rmsnorm, is_310p_return, residual, dummy_tensor):
with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return): with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
@@ -45,13 +47,17 @@ def test_RMSNorm_forward(mock_maybe_chunk_residual, mock_add_rmsnorm,
expected_out_x = expected_arg_x + 1 expected_out_x = expected_arg_x + 1
expected_out_residual = expected_arg_x.to(residual.dtype) expected_out_residual = expected_arg_x.to(residual.dtype)
mock_maybe_chunk_residual.assert_called_once()
mock_rmsnorm.assert_called_once() mock_rmsnorm.assert_called_once()
mock_maybe_wait_prefetch_done.assert_called_once()
assert torch.allclose(out_x, expected_out_x) assert torch.allclose(out_x, expected_out_x)
assert torch.allclose(out_residual, expected_out_residual) assert torch.allclose(out_residual, expected_out_residual)
else: else:
expected_out_x = 2 * dummy_tensor expected_out_x = 2 * dummy_tensor
expected_out_residual = 2 * residual expected_out_residual = 2 * residual
mock_maybe_chunk_residual.assert_called_once()
mock_add_rmsnorm.assert_called_once() mock_add_rmsnorm.assert_called_once()
mock_maybe_wait_prefetch_done.assert_called_once()
assert torch.allclose(out_x, expected_out_x) assert torch.allclose(out_x, expected_out_x)
assert torch.allclose(out_residual, expected_out_residual) assert torch.allclose(out_residual, expected_out_residual)
else: else:
@@ -64,9 +70,11 @@ def test_RMSNorm_forward(mock_maybe_chunk_residual, mock_add_rmsnorm,
@patch("vllm_ascend.utils.is_310p", return_value=False) @patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm) @patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
@patch("torch.ops.vllm.maybe_chunk_residual", @patch("torch.ops.vllm.maybe_chunk_residual",
side_effect=mock_maybe_chunk_residual) side_effect=mock_maybe_chunk_residual)
def test_RMSNorm_forward_with_flashcomm_v1(mock_maybe_chunk_residual, def test_RMSNorm_forward_with_flashcomm_v1(mock_maybe_chunk_residual,
mock_maybe_wait_prefetch_done,
mock_add_rms_norm, mock_is310p): mock_add_rms_norm, mock_is310p):
x = torch.randn(4, 512, dtype=torch.bfloat16) x = torch.randn(4, 512, dtype=torch.bfloat16)
residual = torch.randn(16, 512, dtype=torch.bfloat16) residual = torch.randn(16, 512, dtype=torch.bfloat16)
@@ -79,6 +87,7 @@ def test_RMSNorm_forward_with_flashcomm_v1(mock_maybe_chunk_residual,
mock_maybe_chunk_residual.assert_called_once() mock_maybe_chunk_residual.assert_called_once()
mock_add_rms_norm.assert_called_once() mock_add_rms_norm.assert_called_once()
mock_maybe_wait_prefetch_done.assert_called_once()
assert out_residual.size(0) == 4 assert out_residual.size(0) == 4
assert torch.allclose(out_x, expected_out_x) assert torch.allclose(out_x, expected_out_x)
assert torch.allclose(out_residual, expected_out_residual) assert torch.allclose(out_residual, expected_out_residual)

View File

@@ -275,7 +275,12 @@ def test_torchair_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
@patch("torch_npu.npu_add_rms_norm") @patch("torch_npu.npu_add_rms_norm")
@patch("torch_npu.npu_rms_norm") @patch("torch_npu.npu_rms_norm")
def test_torchair_deepseek_v2_decoder_layer(mock_rms_norm, mock_add_norm, @patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
@patch("torch.ops.vllm.maybe_chunk_residual",
side_effect=lambda x, residual: residual)
def test_torchair_deepseek_v2_decoder_layer(mock_maybe_chunk_residual,
mock_maybe_wait_prefetch_done,
mock_rms_norm, mock_add_norm,
mock_distributed, base_config, mock_distributed, base_config,
vllm_config): vllm_config):
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128)) mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))

View File

@@ -66,7 +66,9 @@ def set_ascend_forward_context(
moe_comm_method: str = "", moe_comm_method: str = "",
num_actual_tokens: Optional[int] = None, num_actual_tokens: Optional[int] = None,
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: Optional[BatchDescriptor] = None): batch_descriptor: Optional[BatchDescriptor] = None,
prefetch_stream: torch.npu.Stream = None,
model_instance: torch.nn.Module = None):
"""A context manager that stores the current forward context, """A context manager that stores the current forward context,
can be attention metadata, etc. can be attention metadata, etc.
We add some additional param into forward_context. We add some additional param into forward_context.
@@ -108,7 +110,8 @@ def set_ascend_forward_context(
# Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold, # Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold,
# the performance benefits can be maximized. Conversely, if the concurrency is below the threshold, # the performance benefits can be maximized. Conversely, if the concurrency is below the threshold,
# the performance may degrade due to the switching of communication methods. # the performance may degrade due to the switching of communication methods.
flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \ flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \
tp_world_size > 1 and \ tp_world_size > 1 and \
num_tokens is not None and num_tokens > 1000 num_tokens is not None and num_tokens > 1000
@@ -122,6 +125,26 @@ def set_ascend_forward_context(
# set this for rope forward_oot using # set this for rope forward_oot using
forward_context.is_first_layer = True forward_context.is_first_layer = True
# set layer_idx to enable optimization features that depend on this information.
# This is only applicable to models that contain these necessary attributes.
forward_context.layer_idx = None
if model_instance is not None and \
hasattr(model_instance, "model") and \
hasattr(model_instance.model, "start_layer"):
forward_context.layer_idx = model_instance.model.start_layer
# set for mlp weight prefetch
prefetch_mlp_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \
forward_context.layer_idx is not None and \
num_tokens is not None and num_tokens < 500
if prefetch_mlp_enabled:
forward_context.prefetch_stream = prefetch_stream
forward_context.model_instance = model_instance
forward_context.prefetch_mlp_gate_up_proj = False
forward_context.prefetch_mlp_down_proj = False
forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
if num_tokens is None and attn_metadata is not None: if num_tokens is None and attn_metadata is not None:
num_tokens = attn_metadata.num_actual_tokens num_tokens = attn_metadata.num_actual_tokens

View File

@@ -135,6 +135,15 @@ env_variables: Dict[str, Callable[[], Any]] = {
# This feature will get better performance when concurrency is large. # This feature will get better performance when concurrency is large.
"VLLM_ASCEND_ENABLE_FLASHCOMM": "VLLM_ASCEND_ENABLE_FLASHCOMM":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))), lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))),
# Whether to enable MLP weight prefetch, only used in small concurrency.
"VLLM_ASCEND_ENABLE_PREFETCH_MLP":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),
# buffer size for gate up prefetch
"MLP_GATE_UP_PREFETCH_SIZE":
lambda: int(os.getenv("MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)),
# buffer size for down proj prefetch
"MLP_DOWN_PREFETCH_SIZE":
lambda: int(os.getenv("MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024)),
# Whether to enable dense model and general optimizations for better performance. # Whether to enable dense model and general optimizations for better performance.
# Since we modified the base parent class `linear`, this optimization is also applicable to other model types. # Since we modified the base parent class `linear`, this optimization is also applicable to other model types.
# However, there might be hidden issues, and it is currently recommended to prioritize its use with dense models. # However, there might be hidden issues, and it is currently recommended to prioritize its use with dense models.

View File

@@ -35,8 +35,10 @@ class AscendSiluAndMul(SiluAndMul):
from vllm_ascend.utils import is_310p from vllm_ascend.utils import is_310p
torch.ops.vllm.maybe_prefetch_mlp_down_proj(x)
if is_310p(): if is_310p():
out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16) out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
else: else:
out = torch_npu.npu_swiglu(x) out = torch_npu.npu_swiglu(x)
torch.ops.vllm.maybe_wait_prefetch_done(out)
return out return out

View File

@@ -44,12 +44,7 @@ class AddRMSNormW8A8Quant(RMSNorm):
import torch_npu import torch_npu
if residual is not None: if residual is not None:
# FIXME(rjg-lyh): This is a hacky way to chunk residuals when the flashcomm_v1 feature residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
# is enabled, without interfering with the normal operation of components like torchair.
# The final solution should be to move this check into the operator and support
# integration with torchair.
if x.size(0) != residual.size(0):
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
assert x.size(0) == residual.size(0) assert x.size(0) == residual.size(0)
x, _, residual = torch_npu.npu_add_rms_norm_quant( x, _, residual = torch_npu.npu_add_rms_norm_quant(
x, x,
@@ -58,6 +53,7 @@ class AddRMSNormW8A8Quant(RMSNorm):
self.layer.aclnn_input_scale, self.layer.aclnn_input_scale,
self.layer.aclnn_input_offset, self.layer.aclnn_input_offset,
epsilon=self.variance_epsilon) epsilon=self.variance_epsilon)
torch.ops.vllm.maybe_wait_prefetch_done(x)
return x, residual return x, residual
x, residual = torch_npu.npu_rms_norm(x, self.weight, x, residual = torch_npu.npu_rms_norm(x, self.weight,
@@ -76,12 +72,7 @@ class AscendRMSNorm(RMSNorm):
from vllm_ascend.utils import is_310p from vllm_ascend.utils import is_310p
if residual is not None: if residual is not None:
# FIXME(rjg-lyh): This is a hacky way to chunk residuals when the flashcomm_v1 feature residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
# is enabled, without interfering with the normal operation of components like torchair.
# The final solution should be to move this check into the operator and support
# integration with torchair.
if x.size(0) != residual.size(0):
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
assert x.size(0) == residual.size(0) assert x.size(0) == residual.size(0)
if is_310p(): if is_310p():
orig_dtype = residual.dtype orig_dtype = residual.dtype
@@ -92,6 +83,7 @@ class AscendRMSNorm(RMSNorm):
else: else:
x, _, residual = torch_npu.npu_add_rms_norm( x, _, residual = torch_npu.npu_add_rms_norm(
x, residual, self.weight, self.variance_epsilon) x, residual, self.weight, self.variance_epsilon)
torch.ops.vllm.maybe_wait_prefetch_done(x)
return x, residual return x, residual
x, residual = torch_npu.npu_rms_norm(x, self.weight, x, residual = torch_npu.npu_rms_norm(x, self.weight,

View File

@@ -390,6 +390,7 @@ class AscendRowParallelLinear(RowParallelLinear):
input_parallel, input_parallel,
bias=bias_) bias=bias_)
output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel) output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel)
torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(output, self.prefix)
output_bias = self.bias if self.skip_bias_add else None output_bias = self.bias if self.skip_bias_add else None

View File

@@ -1,5 +1,6 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import torch_npu
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
@@ -8,10 +9,16 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
import vllm_ascend.envs as envs_ascend
def _maybe_chunk_residual_impl(x: torch.Tensor, def _maybe_chunk_residual_impl(x: torch.Tensor,
residual: torch.Tensor) -> torch.Tensor: residual: torch.Tensor) -> torch.Tensor:
if get_forward_context().flashcomm_v1_enabled: if x.size(0) != residual.size(0):
flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled
assert flashcomm_v1_enabled is True, (
"Currently, this situation only occurs "
"when flashcomm_v1 is enabled")
pad_size = get_forward_context().pad_size pad_size = get_forward_context().pad_size
if pad_size > 0: if pad_size > 0:
residual = F.pad(residual, (0, 0, 0, pad_size)) residual = F.pad(residual, (0, 0, 0, pad_size))
@@ -44,6 +51,76 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
return tensor_model_parallel_all_reduce(x) return tensor_model_parallel_all_reduce(x)
def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
                                          prefix: str) -> None:
    """Kick off an async prefetch of this layer's MLP gate_up_proj weight.

    Called right after the attention o_proj output is produced so the
    gate_up_proj weight can be staged into cache on a side stream while
    attention post-processing still runs on the main stream.

    Args:
        x_dependency: tensor the prefetch is ordered after; npu_prefetch
            uses it to establish the data dependency.
        prefix: dotted module path of the calling linear layer,
            e.g. "model.layers.3.self_attn.o_proj".

    No-op unless prefetch_mlp_enabled was set in the forward context.
    """
    forward_context = get_forward_context()
    if not forward_context.prefetch_mlp_enabled:
        return
    model_instance = forward_context.model_instance
    prefetch_stream = forward_context.prefetch_stream
    # assumes prefix is "model.layers.<idx>...." — TODO confirm for all
    # model architectures this op is registered on.
    layer_idx = int(prefix.split('.')[2])
    # start point of gate_up_proj weight prefetch: only the self_attn
    # projections (o_proj) trigger it; the flag is consumed/cleared later
    # by maybe_wait_prefetch_done.
    if prefix.split('.')[-2] == "self_attn":
        forward_context.prefetch_mlp_gate_up_proj = True
    if forward_context.prefetch_mlp_gate_up_proj:
        # order the side stream after work already queued on the main stream
        prefetch_stream.wait_stream(torch.npu.current_stream())

        with torch.npu.stream(prefetch_stream):
            MLP_GATE_UP_PREFETCH_SIZE = envs_ascend.MLP_GATE_UP_PREFETCH_SIZE
            torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight, \
                                   x_dependency, MLP_GATE_UP_PREFETCH_SIZE)
    return


def _maybe_prefetch_mlp_gate_up_proj_impl_fake(x_dependency: torch.Tensor,
                                               prefix: str) -> None:
    # Fake (meta) impl for graph tracing: the real op only has side effects.
    return
def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
    """Kick off an async prefetch of this layer's MLP down_proj weight.

    Called from the SiluAndMul activation (between gate_up_proj and
    down_proj) so the down_proj weight is staged on the side stream while
    the activation runs.  Also advances forward_context.layer_idx, which
    is the per-layer cursor used instead of parsing a prefix here.

    Args:
        x_dependency: tensor the prefetch is ordered after.

    No-op unless prefetch_mlp_enabled was set in the forward context.
    """
    forward_context = get_forward_context()
    if not forward_context.prefetch_mlp_enabled:
        return
    # mark in-flight so maybe_wait_prefetch_done knows to synchronize
    forward_context.prefetch_mlp_down_proj = True
    model_instance = forward_context.model_instance
    prefetch_stream = forward_context.prefetch_stream
    layer_idx = forward_context.layer_idx
    # start point of down_proj weight prefetch
    prefetch_stream.wait_stream(torch.npu.current_stream())

    with torch.npu.stream(prefetch_stream):
        MLP_DOWN_PREFETCH_SIZE = envs_ascend.MLP_DOWN_PREFETCH_SIZE
        torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.down_proj.weight, \
                               x_dependency, MLP_DOWN_PREFETCH_SIZE)
    # move the cursor to the next decoder layer for the following call
    forward_context.layer_idx += 1
    return


def _maybe_prefetch_mlp_down_proj_impl_fake(
        x_dependency: torch.Tensor) -> None:
    # Fake (meta) impl for graph tracing: the real op only has side effects.
    return
def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None:
    """Block the main stream until any in-flight weight prefetch finishes.

    Placed right before the op that consumes the prefetched weight
    (add_rms_norm before gate_up_proj, swiglu output before down_proj).
    Clears both in-flight flags afterwards.

    Args:
        x: unused except to anchor the op in the traced graph at the
           right point.

    No-op unless prefetch_mlp_enabled was set in the forward context.
    """
    forward_context = get_forward_context()
    if not forward_context.prefetch_mlp_enabled:
        return
    if forward_context.prefetch_mlp_gate_up_proj or \
            forward_context.prefetch_mlp_down_proj:
        prefetch_stream = get_forward_context().prefetch_stream
        # wait until prefetch done
        torch.npu.current_stream().wait_stream(prefetch_stream)
        forward_context.prefetch_mlp_gate_up_proj = False
        forward_context.prefetch_mlp_down_proj = False
    return


def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None:
    # Fake (meta) impl for graph tracing: the real op only has side effects.
    return
direct_register_custom_op(op_name="maybe_chunk_residual", direct_register_custom_op(op_name="maybe_chunk_residual",
op_func=_maybe_chunk_residual_impl, op_func=_maybe_chunk_residual_impl,
fake_impl=lambda x, residual: residual, fake_impl=lambda x, residual: residual,
@@ -60,4 +137,22 @@ direct_register_custom_op(op_name="maybe_pad_and_reduce",
op_func=_maybe_pad_and_reduce_impl, op_func=_maybe_pad_and_reduce_impl,
fake_impl=lambda x: x, fake_impl=lambda x: x,
mutates_args=[], mutates_args=[],
dispatch_key="PrivateUse1") dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="maybe_prefetch_mlp_gate_up_proj",
op_func=_maybe_prefetch_mlp_gate_up_proj_impl,
fake_impl=_maybe_prefetch_mlp_gate_up_proj_impl_fake,
mutates_args=[],
dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="maybe_prefetch_mlp_down_proj",
op_func=_maybe_prefetch_mlp_down_proj_impl,
fake_impl=_maybe_prefetch_mlp_down_proj_impl_fake,
mutates_args=[],
dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="maybe_wait_prefetch_done",
op_func=_maybe_wait_prefetch_done_impl,
fake_impl=_maybe_wait_prefetch_done_impl_fake,
mutates_args=[],
dispatch_key="PrivateUse1")

View File

@@ -0,0 +1,37 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import torch
def torchair_silu_and_mul_forward_oot(self, x: torch.Tensor) -> torch.Tensor:
    """AscendSiluAndMul forward used in torchair mode.

    Identical to the eager implementation except that the torch.ops.vllm
    prefetch hooks are omitted: those custom ops only function outside
    torchair, and including them would break graph compilation.
    """
    import torch_npu
    from vllm_ascend.utils import is_310p

    if not is_310p():
        return torch_npu.npu_swiglu(x)
    # 310P: npu_swiglu needs float32 compute; cast the result back down.
    return torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)

View File

@@ -0,0 +1,51 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from typing import Optional, Tuple, Union
import torch
def torchair_rmsnorm_forward_oot(
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """AscendRMSNorm forward in torchair mode.

    The key difference from the original implementation is the removal of
    operators from the torch.ops.vllm class (maybe_chunk_residual /
    maybe_wait_prefetch_done), as these operators only function in
    non-torchair modes.  Adding them back would cause the graph
    compilation to fail.

    Returns (normed_x, residual) when a residual is supplied, otherwise
    just normed_x.
    """
    import torch_npu
    from vllm_ascend.utils import is_310p

    if residual is not None:
        if is_310p():
            # 310P has no fused add_rms_norm: add in x's dtype, keep the
            # updated residual in its original dtype, then plain rms_norm.
            orig_dtype = residual.dtype
            x = x + residual.to(x.dtype)
            residual = x.to(orig_dtype)
            x, _ = torch_npu.npu_rms_norm(x, self.weight,
                                          self.variance_epsilon)
        else:
            x, _, residual = torch_npu.npu_add_rms_norm(
                x, residual, self.weight, self.variance_epsilon)
        return x, residual

    # NOTE(review): second return value of npu_rms_norm is bound to
    # "residual" here but discarded — presumably the rstd output.
    x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
    return x

View File

@@ -199,8 +199,12 @@ def torchair_quant_method_register():
def torchair_ops_patch(): def torchair_ops_patch():
from vllm_ascend.ops.activation import AscendSiluAndMul
from vllm_ascend.ops.layernorm import AscendRMSNorm
from vllm_ascend.ops.rotary_embedding import ( from vllm_ascend.ops.rotary_embedding import (
AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding) AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding)
from vllm_ascend.torchair.ops import (torchair_activation,
torchair_layernorm)
from vllm_ascend.torchair.ops.torchair_rotary_embedding import ( from vllm_ascend.torchair.ops.torchair_rotary_embedding import (
deepseek_rope_init_func, native_rope_deepseek_forward, deepseek_rope_init_func, native_rope_deepseek_forward,
qwen_rope_init_func, rope_forward) qwen_rope_init_func, rope_forward)
@@ -210,3 +214,6 @@ def torchair_ops_patch():
AscendDeepseekScalingRotaryEmbedding.__init__ = deepseek_rope_init_func # type: ignore[method-assign] AscendDeepseekScalingRotaryEmbedding.__init__ = deepseek_rope_init_func # type: ignore[method-assign]
AscendDeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward # type: ignore[method-assign] AscendDeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward # type: ignore[method-assign]
AscendRMSNorm.forward_oot = torchair_layernorm.torchair_rmsnorm_forward_oot # type: ignore[method-assign]
AscendSiluAndMul.forward_oot = torchair_activation.torchair_silu_and_mul_forward_oot # type: ignore[method-assign]

View File

@@ -227,6 +227,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.dp_size = vllm_config.parallel_config.data_parallel_size self.dp_size = vllm_config.parallel_config.data_parallel_size
self.dp_rank = vllm_config.parallel_config.data_parallel_rank self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.device = device self.device = device
if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP:
self.prefetch_stream = torch.npu.Stream(device=device)
else:
self.prefetch_stream = None
self.dtype = self.model_config.dtype self.dtype = self.model_config.dtype
if envs_ascend.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION: if envs_ascend.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION:
# TODO: drop the env config to use ascend sampler by default # TODO: drop the env config to use ascend sampler by default
@@ -1592,7 +1596,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
aclgraph_runtime_mode=aclgraph_runtime_mode, aclgraph_runtime_mode=aclgraph_runtime_mode,
batch_descriptor=batch_descriptor, batch_descriptor=batch_descriptor,
num_actual_tokens=scheduler_output. num_actual_tokens=scheduler_output.
total_num_scheduled_tokens): total_num_scheduled_tokens,
prefetch_stream=self.prefetch_stream,
model_instance=self.model):
self.maybe_setup_kv_connector(scheduler_output) self.maybe_setup_kv_connector(scheduler_output)
hidden_states = self._generate_process_reqs_hidden_states( hidden_states = self._generate_process_reqs_hidden_states(
@@ -2057,7 +2063,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
moe_comm_method=moe_comm_method, moe_comm_method=moe_comm_method,
num_actual_tokens=0, num_actual_tokens=0,
aclgraph_runtime_mode=aclgraph_runtime_mode, aclgraph_runtime_mode=aclgraph_runtime_mode,
batch_descriptor=batch_descriptor): batch_descriptor=batch_descriptor,
prefetch_stream=self.prefetch_stream,
model_instance=self.model):
hidden_states = self._generate_dummy_run_hidden_states( hidden_states = self._generate_dummy_run_hidden_states(
with_prefill, is_torchair_compile, input_ids, positions, with_prefill, is_torchair_compile, input_ids, positions,
attn_metadata, num_tokens, intermediate_tensors, attn_metadata, num_tokens, intermediate_tensors,

View File

@@ -51,6 +51,18 @@ from vllm_ascend.utils import (init_ascend_soc_version,
try_register_lib) try_register_lib)
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
torch._dynamo.trace_rules.clear_lru_cache() # noqa: E402
from torch._dynamo.variables import TorchInGraphFunctionVariable # noqa: E402
torch_non_c_binding_in_graph_functions_npu = dict.fromkeys(
["torch.npu.current_stream"],
TorchInGraphFunctionVariable,
) # noqa: E402
torch_non_c_binding_in_graph_functions_npu[
"torch.npu.stream"] = TorchInGraphFunctionVariable # noqa: E402
torch._dynamo.trace_rules.torch_name_rule_map.append(
torch_non_c_binding_in_graph_functions_npu) # noqa: E402
class NPUWorker(WorkerBase): class NPUWorker(WorkerBase):