[Refactor] MLP weight prefetch for consistency with MoE model's prefetching in terms of code and usage (#6442)

### What this PR does / why we need it?
Refactor MLP weight prefetch for consistency with the MoE model's prefetching
in terms of code and usage.
The environment variables VLLM_ASCEND_ENABLE_PREFETCH_MLP,
VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE, and
VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE are removed; usage is now as follows:

--additional-config '{"weight_prefetch_config": { "enabled": true,
"prefetch_ratio": {"mlp": { "gate_up": 1.0, "down": 1.0} }}}'

### Does this PR introduce _any_ user-facing change?

Yes. The environment variables listed above are removed; enable MLP weight prefetch via `weight_prefetch_config` in `--additional-config` instead.
### How was this patch tested?

Existing unit and e2e tests were updated to exercise the new `weight_prefetch_config` configuration.
- vLLM version: v0.14.1
- vLLM main:
dc917cceb8

---------

Signed-off-by: leo-pony <nengjunma@outlook.com>
Nengjun Ma
2026-02-04 09:08:18 +08:00
committed by GitHub
parent fa56abea9f
commit 78fad4e348
18 changed files with 250 additions and 171 deletions

View File

@@ -171,9 +171,6 @@ export TASK_QUEUE_ENABLE=1
# Enable the AIVector core to directly schedule ROCE communication
export HCCL_OP_EXPANSION_MODE="AIV"
# Enable MLP prefetch for better performance.
export VLLM_ASCEND_ENABLE_PREFETCH_MLP=1
# Enable FlashComm_v1 optimization when tensor parallel is enabled.
export VLLM_ASCEND_ENABLE_FLASHCOMM1=1
@@ -187,7 +184,7 @@ vllm serve /model/Qwen3-32B-W8A8 \
--max-model-len 5500 \
--max-num-batched-tokens 40960 \
--compilation-config '{"cudagraph_mode": "FULL_DECODE_ONLY"}' \
--additional-config '{"pa_shape_list":[48,64,72,80]}' \
--additional-config '{"pa_shape_list":[48,64,72,80], "weight_prefetch_config":{"enabled":true}}' \
--port 8113 \
--block-size 128 \
--gpu-memory-utilization 0.9
@@ -348,9 +345,7 @@ Weight prefetching optimizes memory usage by preloading weights into the cache b
In dense model scenarios, the MLP's gate_up_proj and down_proj linear layers often exhibit relatively high MTE utilization. To address this, we create a separate pipeline specifically for weight prefetching, which runs in parallel with the original vector computation pipeline, such as RMSNorm and SiLU, before the MLP. This approach allows the weights to be preloaded to L2 cache ahead of time, reducing MTE utilization during the MLP computations and indirectly improving Cube computation efficiency by minimizing resource contention and optimizing data flow.
It is important to emphasize that, since we use vector computations to hide the weight prefetching pipeline, the setting of the prefetch buffer size is crucial. If the buffer size is too small, the optimization benefits will not be fully realized, while a larger buffer size may lead to resource contention, resulting in performance degradation. To accommodate different scenarios, we have exposed two environment variables `VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE` and `VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE` to allow for flexible buffer size configuration based on the specific workload.
This optimization requires setting the environment variable `VLLM_ASCEND_ENABLE_PREFETCH_MLP = 1` to be enabled.
The environment variable VLLM_ASCEND_ENABLE_PREFETCH_MLP (previously used to enable MLP weight prefetch) and the environment variables VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE and VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE (previously used to set the weight prefetch size for the MLP gate_up_proj and down_proj layers) are deprecated. Please use the following configuration instead: "weight_prefetch_config": { "enabled": true, "prefetch_ratio": { "mlp": { "gate_up": 1.0, "down": 1.0}}}. See User Guide->Feature Guide->Weight Prefetch Guide for details.
### 6. Zerolike Elimination

View File

@@ -60,7 +60,7 @@ The details of each configuration option are as follows:
| Name | Type | Default | Description |
|------------------|------|-------------------------------------------------------------|------------------------------------|
| `enabled` | bool | `False` | Whether to enable weight prefetch. |
| `prefetch_ratio` | dict | `{"attn": {"qkv": 1.0, "o": 1.0}, "moe": {"gate_up": 0.8}}` | Prefetch ratio of each weight. |
| `prefetch_ratio` | dict | `{"attn": {"qkv": 1.0, "o": 1.0}, "moe": {"gate_up": 0.8}, "mlp": { "gate_up": 1.0, "down": 1.0}}` | Prefetch ratio of each weight. |
**finegrained_tp_config**
@@ -115,6 +115,10 @@ An example of additional configuration is as follows:
},
"moe": {
"gate_up": 0.8
},
"mlp": {
"gate_up": 1.0,
"down": 1.0
}
},
},

View File

@@ -23,4 +23,5 @@ layer_sharding
speculative_decoding
context_parallel
npugraph_ex
weight_prefetch
:::

View File

@@ -0,0 +1,73 @@
# Weight Prefetch Guide
Weight prefetching optimizes memory usage by preloading weights into the cache before they are needed, minimizing delays caused by memory access during model execution. Linear layers sometimes exhibit relatively high MTE utilization. To address this, we create a separate pipeline specifically for weight prefetching, which runs in parallel with the original vector computation pipeline, such as quantization, MoE gating top-k, RMSNorm, and SwiGLU. This approach allows the weights to be preloaded into L2 cache ahead of time, reducing MTE utilization during the linear layer computations and indirectly improving Cube computation efficiency by minimizing resource contention and optimizing data flow.
Since we use vector computations to hide the weight prefetching pipeline, prefetching has an effect on computation; if you prioritize low latency over high throughput, it is best not to enable it.
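Conceptually, the prefetch is issued on a separate NPU stream that is synchronized with the compute stream. A minimal sketch of this dual-stream pattern, adapted from the code this commit refactors (the weight, dependency tensor, and size arguments are illustrative):

```python
import torch
import torch_npu  # provides torch.npu streams and npu_prefetch

def prefetch_weight_async(weight: torch.Tensor, x_dependency: torch.Tensor,
                          prefetch_stream: torch.npu.Stream,
                          max_weight_size: int) -> None:
    # Start the prefetch only after the compute stream has produced
    # x_dependency, then issue it on the dedicated prefetch stream.
    prefetch_stream.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(prefetch_stream):
        torch_npu.npu_prefetch(weight, x_dependency, max_weight_size)

def wait_prefetch_done(prefetch_stream: torch.npu.Stream) -> None:
    # Block the compute stream until prefetching finishes, so the next
    # linear layer reads its weight from L2 cache.
    torch.npu.current_stream().wait_stream(prefetch_stream)
```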
## Quick Start
Pass `--additional-config '{"weight_prefetch_config": {"enabled": true}}'` to enable weight prefetch.
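For offline inference, the same configuration can be passed programmatically. A minimal sketch, assuming your vLLM version forwards `additional_config` through its `LLM` entry point (the model path is illustrative):

```python
from vllm import LLM, SamplingParams

# Enable weight prefetch with the default prefetch ratios.
llm = LLM(
    model="Qwen/Qwen3-32B",
    additional_config={"weight_prefetch_config": {"enabled": True}},
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32))
```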
## Fine-tune Prefetch Ratio
Since weight prefetch uses vector computations to hide the weight prefetching pipeline, the setting of the prefetch size is crucial. If the size is too small, the optimization benefits will not be fully realized, while a larger size may lead to resource contention, resulting in performance degradation. To accommodate different scenarios, we have added `prefetch_ratio` to allow for flexible size configuration based on the specific workload, as detailed below.
Use `prefetch_ratio` inside `weight_prefetch_config` to customize the weight prefetch ratio for specific linear layers.
The `attn` and `moe` configuration options are used for MoE models:
`"attn": { "qkv": 1.0, "o": 1.0}, "moe": {"gate_up": 0.8}`
The `mlp` configuration option is used to optimize the performance of dense models:
`"mlp": {"gate_up": 1.0, "down": 1.0}`
The values above are the defaults; they give good performance for Qwen3-235B-A22B-W8A8 when `--max-num-seqs` is 144, and for Qwen3-32B-W8A8 when `--max-num-seqs` is 72.
However, this may not be the optimal configuration for your scenario. For higher concurrency, you can try increasing the prefetch size; for lower concurrency, prefetching may not offer any advantage, so you can decrease the size or disable prefetching. You can determine whether the prefetch size is appropriate by collecting profiling data: check whether the prefetch operation (e.g., MLP down_proj weight prefetching) overlaps with the parallel vector computation operators (e.g., the SwiGLU computation), and whether the prefetch completes no later than the vector computation operator. In the profiling timeline, a prefetch operation appears as a CMO operation on a single stream.
Notes:
1) MLP `down` projection weight prefetch depends on sequence parallelism; if you want to enable prefetch for the MLP `down` projection, please also enable sequence parallelism.
2) Due to the current size of the L2 cache, the maximum prefetch cannot exceed 18MB. If `prefetch_ratio * linear_layer_weight_size >= 18 * 1024 * 1024` bytes, the backend will only prefetch 18MB, as shown in the sketch below.
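A minimal sketch of how the effective prefetch size is derived from `prefetch_ratio`, mirroring the capping logic this commit introduces (the weight shape and ratio are illustrative):

```python
import torch

MAX_PREFETCH_WEIGHT_SIZE = 18 * 1024 * 1024  # 18MB L2-cache cap

def effective_prefetch_size(weight: torch.Tensor, prefetch_ratio: float) -> int:
    # The requested size is a fraction of the weight's size in bytes.
    size = weight.element_size() * weight.numel() * prefetch_ratio
    # The backend never prefetches more than the 18MB cap.
    return int(min(size, MAX_PREFETCH_WEIGHT_SIZE))

# Example: a 4096 x 12288 fp16 down_proj weight is ~96MB, so even with
# a ratio of 1.0 the prefetch is capped at 18MB (18874368 bytes).
w = torch.empty(4096, 12288, dtype=torch.float16)
print(effective_prefetch_size(w, prefetch_ratio=1.0))
```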
## Example
1) For MoE model:
```shell
--additional-config \
'{
"weight_prefetch_config": {
"enabled": true,
"prefetch_ratio": {
"attn": {
"qkv": 1.0,
"o": 1.0
},
"moe": {
"gate_up": 0.8
}
}
}
}'
```
2) For dense model:
The following is the default configuration, which gives good performance for Qwen3-32B-W8A8 when `--max-num-seqs` is 72:
```shell
--additional-config \
'{
"weight_prefetch_config": {
"enabled": true,
"prefetch_ratio": {
"mlp": {
"gate_up": 1.0,
"down": 1.0
}
}
}
}'
```

View File

@@ -222,7 +222,7 @@ def test_qwen3_dense_fc1_tp2(model):
@pytest.mark.parametrize("model", QWEN_DENSE_MODELS)
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_PREFETCH_MLP": "1"})
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
def test_qwen3_dense_prefetch_mlp_weight_tp2(model):
example_prompts = [
"Hello, my name is",
@@ -236,6 +236,7 @@ def test_qwen3_dense_prefetch_mlp_weight_tp2(model):
tensor_parallel_size=2,
cudagraph_capture_sizes=[1, 2, 4, 8],
quantization="ascend",
additional_config={"weight_prefetch_config": {"enabled": True}},
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)

View File

@@ -57,7 +57,6 @@ async def test_models(model: str) -> None:
env_dict = {
"TASK_QUEUE_ENABLE": "1",
"HCCL_OP_EXPANSION_MODE": "AIV",
"VLLM_ASCEND_ENABLE_PREFETCH_MLP": "1",
}
server_args = [
"--async-scheduling",
@@ -74,7 +73,7 @@ async def test_models(model: str) -> None:
"--compilation-config",
'{"cudagraph_mode": "FULL_DECODE_ONLY"}',
"--additional-config",
'{"pa_shape_list":[48,64,72,80]}',
'{"pa_shape_list":[48,64,72,80],"weight_prefetch_config":{"enabled":true}}',
"--block-size",
"128",
"--trust-remote-code",

View File

@@ -83,7 +83,6 @@ async def test_models(model: str, mode: str, tp_size: int) -> None:
"TASK_QUEUE_ENABLE": "1",
"HCCL_OP_EXPANSION_MODE": "AIV",
"VLLM_ASCEND_ENABLE_FLASHCOMM": "1",
"VLLM_ASCEND_ENABLE_PREFETCH_MLP": "1"
}
compilation_config = {
"cudagraph_mode":
@@ -98,7 +97,8 @@ async def test_models(model: str, mode: str, tp_size: int) -> None:
str(port), "--max-model-len", "40960", "--max-num-batched-tokens",
"40960", "--block-size", "128", "--trust-remote-code",
"--reasoning-parser", "qwen3", "--gpu-memory-utilization", "0.9",
"--async-scheduling"
"--async-scheduling", "--additional-config",
'{"weight_prefetch_config":{"enabled":true}}',
]
if mode == "single":
server_args.append("--enforce-eager")

View File

@@ -72,7 +72,6 @@ async def test_models(model: str, tp_size: int) -> None:
"OMP_PROC_BIND": "false",
"VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE": "1",
"VLLM_ASCEND_ENABLE_FLASHCOMM": "1",
"VLLM_ASCEND_ENABLE_PREFETCH_MLP": "1"
}
server_args = [
"--quantization", "ascend", "--tensor-parallel-size",
@@ -82,7 +81,8 @@ async def test_models(model: str, tp_size: int) -> None:
"0.9", "--block-size", "128", "--max-num-seqs", "256",
"--enforce-eager", "--max-model-len", "35840",
"--max-num-batched-tokens", "35840", "--additional-config",
'{"enable_weight_nz_layout":true}', "--compilation-config",
'{"enable_weight_nz_layout":true, "weight_prefetch_config":{"enabled": true}}',
"--compilation-config",
'{"cudagraph_mode":"FULL_DECODE_ONLY", "cudagraph_capture_sizes":[1,8,24,48,60]}'
]
with RemoteOpenAIServer(model,

View File

@@ -75,8 +75,7 @@ async def test_models(model: str, mode: str, tp_size: int) -> None:
"OMP_PROC_BIND": "false",
"HCCL_OP_EXPANSION_MODE": "AIV",
"VLLM_ASCEND_ENABLE_FLASHCOMM": "1",
"VLLM_ASCEND_ENABLE_DEBSE_OPTIMIZE": "1",
"VLLM_ASCEND_ENABLE_PREFETCH_MLP": "1"
"VLLM_ASCEND_ENABLE_DEBSE_OPTIMIZE": "1"
}
server_args = [
"--tensor-parallel-size",
@@ -86,7 +85,7 @@ async def test_models(model: str, mode: str, tp_size: int) -> None:
"--gpu-memory-utilization", "0.9", "--compilation_config",
'{"cudagraph_mode":"FULL_DECODE_ONLY", "cudagraph_capture_sizes": [1, 8, 24, 48, 60]}',
"--reasoning-parser", "deepseek_r1", "--distributed_executor_backend",
"mp"
"mp", "--additional-config", '{"weight_prefetch_config":{"enabled":true}}'
]
if mode == "single":
server_args.remove("--compilation_config")

View File

@@ -54,11 +54,7 @@ def test_QuickGELU_forward(mock_gelu, dummy_tensor, default_vllm_config):
@pytest.mark.skipif(is_310p_hw(), reason="non_310P device unittest case.")
@patch("torch_npu.npu_swiglu", side_effect=lambda x: x + 1)
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
@patch("torch.ops.vllm.maybe_prefetch_mlp_down_proj", side_effect=lambda x: None)
def test_SiluAndMul_forward(
mock_maybe_prefetch_mlp_down_proj,
mock_maybe_wait_prefetch_done,
mock_swiglu,
dummy_tensor,
default_vllm_config,
@@ -67,15 +63,9 @@ def test_SiluAndMul_forward(
out = layer.forward(dummy_tensor)
expected_arg = dummy_tensor
# assert mock_maybe_prefetch_mlp_down_proj.call_count == 1
mock_maybe_prefetch_mlp_down_proj.assert_called_once()
# assert mock_swiglu.call_count == 1
mock_swiglu.assert_called_once()
# assert mock_maybe_wait_prefetch_done.call_count == 1
mock_maybe_wait_prefetch_done.assert_called_once()
actual_arg = mock_swiglu.call_args[0][0]
assert torch.allclose(actual_arg, expected_arg), "npu_swiglu called with unexpected input"
@@ -85,11 +75,7 @@ def test_SiluAndMul_forward(
@pytest.mark.skipif(not is_310p_hw(), reason="310P device unittest case.")
@patch("torch.nn.functional.silu", side_effect=lambda x: x + 1)
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
@patch("torch.ops.vllm.maybe_prefetch_mlp_down_proj", side_effect=lambda x: None)
def test_SiluAndMul_forward_310p(
mock_maybe_prefetch_mlp_down_proj,
mock_maybe_wait_prefetch_done,
mock_silu,
dummy_tensor,
default_vllm_config,
@@ -99,15 +85,9 @@ def test_SiluAndMul_forward_310p(
h = dummy_tensor.shape[-1] // 2
expected_arg = dummy_tensor[..., :h]
# assert mock_maybe_prefetch_mlp_down_proj.call_count == 1
mock_maybe_prefetch_mlp_down_proj.assert_called_once()
# assert mock_silu.call_count == 1
mock_silu.assert_called_once()
# assert mock_maybe_wait_prefetch_done.call_count == 1
mock_maybe_wait_prefetch_done.assert_called_once()
actual_arg = mock_silu.call_args[0][0]
assert torch.allclose(actual_arg, expected_arg), "swiglu called with unexpected input"

View File

@@ -19,12 +19,16 @@ import torch
import torch.nn.functional as F
from vllm_ascend.ops.activation import AscendSiluAndMul
from vllm_ascend.utils import get_weight_prefetch_method
class AscendSiluAndMul310(AscendSiluAndMul):
def forward(self, x: torch.Tensor) -> torch.Tensor:
torch.ops.vllm.maybe_prefetch_mlp_down_proj(x)
weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_mlp_weight_preprocess(weight_prefetch_method.MLP_DOWN, x)
h = x.shape[-1] // 2
out = F.silu(x[..., :h]) * x[..., h:]
torch.ops.vllm.maybe_wait_prefetch_done(out)
out = (F.silu(x[..., :h].to(torch.float32)) * x[..., h:].to(torch.float32)).to(torch.float16)
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_mlp_weight_postprocess(out)
return out

View File

@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from typing import TYPE_CHECKING
from vllm.logger import logger
@@ -48,9 +49,7 @@ class AscendConfig:
# Dump / PrecisionDebugger configuration
self.dump_config_path = additional_config.get("dump_config_path", None)
weight_prefetch_config = additional_config.get("weight_prefetch_config", {})
self.weight_prefetch_config = WeightPrefetchConfig(weight_prefetch_config)
self._construct_weight_prefetch_config(additional_config)
self.layer_sharding = additional_config.get("layer_sharding", None)
logger.info_once(
f"Linear layer sharding enabled with config: {self.layer_sharding}. "
@@ -138,6 +137,29 @@ class AscendConfig:
"enable_kv_nz is only supported in pd scenario and can only be used in D node."
)
def _construct_weight_prefetch_config(self, additional_config):
weight_prefetch_config = additional_config.get("weight_prefetch_config", {})
self.weight_prefetch_config = WeightPrefetchConfig(weight_prefetch_config)
# Deprecated env var handling for backward compatibility
if os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", "0") == "1":
MAX_PREFETCH_WEIGHT_SIZE: int = 18 * 1024 * 1024
gate_up_prefetch_size = int(os.getenv("VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE", MAX_PREFETCH_WEIGHT_SIZE))
down_prefetch_szie = int(os.getenv("VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE", MAX_PREFETCH_WEIGHT_SIZE))
self.weight_prefetch_config.set_mlp_pre_version_compatibale_config(
gate_up_prefetch_size, down_prefetch_szie
)
logger.info_once(
f"MLP weight prefetch enabled from env variable VLLM_ASCEND_ENABLE_PREFETCH_MLP."
f"gate_up_prefetch_size={gate_up_prefetch_size}, "
f"down_prefetch_szie={down_prefetch_szie}."
)
warnings.warn(
"VLLM_ASCEND_ENABLE_PREFETCH_MLP is deprecated and will be removed in a v0.16.0 version. "
"Please use weight_prefetch_config in additional-config for now instead.",
DeprecationWarning,
stacklevel=2,
)
class FinegrainedTPConfig:
"""
@@ -305,18 +327,28 @@ class WeightPrefetchConfig:
Configuration Object for weight_prefetch_config from additional_config
"""
mlp_pre_version_compatibale_config: dict = {}
prefetch_ratio: dict = {
"attn": {
"qkv": 1.0,
"o": 1.0,
},
"moe": {"gate_up": 0.8},
"mlp": {"gate_up": 1, "down": 1.0},
}
def __init__(self, weight_prefetch_config: dict):
self.enabled = weight_prefetch_config.get("enabled", False)
self.prefetch_ratio = weight_prefetch_config.get("prefetch_ratio", self.prefetch_ratio)
def set_mlp_pre_version_compatibale_config(self, gate_up_prefetch_size: int, down_prefetch_size: int):
config = {
"gate_up": gate_up_prefetch_size,
"down": down_prefetch_size,
}
self.mlp_pre_version_compatibale_config = config
class EplbConfig:
"""

View File

@@ -119,18 +119,8 @@ def set_ascend_forward_context(
if has_layer_idx(model_instance):
forward_context.layer_idx = model_instance.model.start_layer
# TODO(rjg-lyh): refactor mlp weight prefetch method
# set for mlp weight prefetch
prefetch_mlp_enabled = (
envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP
and forward_context.layer_idx is not None
and num_tokens is not None
and num_tokens < 500
)
if prefetch_mlp_enabled:
forward_context.prefetch_mlp_gate_up_proj = False
forward_context.prefetch_mlp_down_proj = False
forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
forward_context.prefetch_mlp_gate_up_proj = False
forward_context.prefetch_mlp_down_proj = False
forward_context.model_instance = model_instance
forward_context.is_draft_model = is_draft_model

View File

@@ -17,7 +17,7 @@
import torch
from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
from vllm_ascend.utils import get_weight_prefetch_method
class AscendQuickGELU(QuickGELU):
@@ -33,7 +33,10 @@ class AscendSiluAndMul(SiluAndMul):
def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
import torch_npu
torch.ops.vllm.maybe_prefetch_mlp_down_proj(x)
weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_mlp_weight_preprocess(weight_prefetch_method.MLP_DOWN, x)
out = torch_npu.npu_swiglu(x)
torch.ops.vllm.maybe_wait_prefetch_done(out)
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_mlp_weight_postprocess(out)
return out

View File

@@ -24,7 +24,7 @@ from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm, RMSNormG
from vllm_ascend.ops.triton.layernorm_gated import layer_norm_fwd_npu
from vllm_ascend.utils import enable_custom_op
from vllm_ascend.utils import get_weight_prefetch_method
class AscendRMSNorm(RMSNorm):
@@ -67,6 +67,10 @@ class AscendRMSNorm(RMSNorm):
self.variance_epsilon)
if self.bias is not None:
x.add_(self.bias)
weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_mlp_weight_postprocess(x)
return x

View File

@@ -65,8 +65,8 @@ from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
from vllm_ascend.utils import (enable_dsa_cp, enable_dsa_cp_with_layer_shard, enable_sp, flashcomm2_enable,
get_flashcomm2_reorgnized_batch_ids,
matmul_allreduce_enable, mlp_tp_enable,
oproj_tp_enable, shared_expert_dp_enabled)
oproj_tp_enable, shared_expert_dp_enabled,
get_weight_prefetch_method)
class CustomLinearOp:
@@ -138,8 +138,10 @@ class CustomRowParallelOp(CustomLinearOp):
def apply(self, input_):
output, output_bias = self.apply_impl(input_)
if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP:
torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(output, self.prefix)
weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_mlp_weight_preprocess(weight_prefetch_method.MLP_GATE_UP, output, self.prefix)
if not self.return_bias:
return output
return output, output_bias

View File

@@ -110,33 +110,6 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor,
0)
def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
prefix: str) -> None:
try:
forward_context = get_forward_context()
except AssertionError:
return
if not getattr(forward_context, 'prefetch_mlp_enabled', False):
return
model_instance = forward_context.model_instance
weight_prefetch_stream = prefetch_stream()
layer_idx = int(prefix.split('.')[2])
# start point of gate_up_proj weight prefetch
if prefix.split('.')[-2] == "self_attn":
forward_context.prefetch_mlp_gate_up_proj = True
if forward_context.prefetch_mlp_gate_up_proj:
weight_prefetch_stream.wait_stream(torch.npu.current_stream())
with torch.npu.stream(weight_prefetch_stream):
mlp_gate_up_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE
torch_npu.npu_prefetch(
model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight,
x_dependency, mlp_gate_up_prefetch_size)
return
def _maybe_all_gather_and_maybe_unpad_fake(
x: torch.Tensor,
label: bool,
@@ -164,63 +137,6 @@ def _maybe_pad_and_reduce_fake(x: torch.Tensor,
return x
def _maybe_prefetch_mlp_gate_up_proj_impl_fake(x_dependency: torch.Tensor,
prefix: str) -> None:
return
def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
try:
forward_context = get_forward_context()
except AssertionError:
return
if not getattr(forward_context, 'prefetch_mlp_enabled', False):
return
forward_context.prefetch_mlp_down_proj = True
model_instance = forward_context.model_instance
weight_prefetch_stream = prefetch_stream()
layer_idx = forward_context.layer_idx
# start point of down_proj weight prefetch
weight_prefetch_stream.wait_stream(torch.npu.current_stream())
with torch.npu.stream(weight_prefetch_stream):
mlp_down_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE
torch_npu.npu_prefetch(
model_instance.model.layers[layer_idx].mlp.down_proj.weight,
x_dependency, mlp_down_prefetch_size)
forward_context.layer_idx += 1
return
def _maybe_prefetch_mlp_down_proj_impl_fake(
x_dependency: torch.Tensor) -> None:
return
def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None:
try:
forward_context = get_forward_context()
except AssertionError:
return
if not getattr(forward_context, 'prefetch_mlp_enabled', False):
return
if forward_context.prefetch_mlp_gate_up_proj or \
forward_context.prefetch_mlp_down_proj:
weight_prefetch_stream = prefetch_stream()
# wait until prefetch done
torch.npu.current_stream().wait_stream(weight_prefetch_stream)
forward_context.prefetch_mlp_gate_up_proj = False
forward_context.prefetch_mlp_down_proj = False
return
def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None:
return
def _prefetch_preprocess_impl(weight: torch.Tensor, start_flag: torch.Tensor,
max_weight_size: int) -> None:
calculation_stream = torch_npu.npu.current_stream()
@@ -331,24 +247,6 @@ direct_register_custom_op(op_name="maybe_pad_and_reduce",
mutates_args=[],
dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="maybe_prefetch_mlp_gate_up_proj",
op_func=_maybe_prefetch_mlp_gate_up_proj_impl,
fake_impl=_maybe_prefetch_mlp_gate_up_proj_impl_fake,
mutates_args=[],
dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="maybe_prefetch_mlp_down_proj",
op_func=_maybe_prefetch_mlp_down_proj_impl,
fake_impl=_maybe_prefetch_mlp_down_proj_impl_fake,
mutates_args=[],
dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="maybe_wait_prefetch_done",
op_func=_maybe_wait_prefetch_done_impl,
fake_impl=_maybe_wait_prefetch_done_impl_fake,
mutates_args=[],
dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="prefetch_preprocess",
op_func=_prefetch_preprocess_impl,
fake_impl=_prefetch_preprocess_impl_fake,

View File

@@ -2,15 +2,18 @@ from dataclasses import dataclass, field
import torch
import torch_npu
from vllm.forward_context import get_forward_context
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.config import get_current_vllm_config
from vllm.logger import logger
from vllm_ascend.ascend_config import WeightPrefetchConfig
from vllm_ascend.ops.linear import (AscendQKVParallelLinear,
AscendRowParallelLinear)
from vllm_ascend.utils import is_moe_model
SUPPORTED_MODULES = ["attn", "mlp", "moe"]
MOE_PREFETCH_TOKEN_THRESHOLD = 96
MAX_PREFETCH_WEIGHT_SIZE = 18 * 1024 * 1024
@dataclass
class ModuleWeightPrefetchConfig:
@@ -38,22 +41,37 @@ class WeightPrefetchMethod:
"""
Unified weight prefetch method.
"""
is_moe: bool = True
MLP_GATE_UP: str = "gate_up"
MLP_DOWN: str = "down"
# backward compatibility: delete in future versions
mlp_pre_version_compatibale_config: dict = {}
def __init__(self, weight_prefetch_config: WeightPrefetchConfig) -> None:
self.is_moe = is_moe_model(get_current_vllm_config())
self.attn = ModuleWeightPrefetchConfig(
module_name="attn",
enable=weight_prefetch_config.enabled,
prefetch_ratio=weight_prefetch_config.prefetch_ratio.get(
"attn", {}),
"attn", {}) or {'qkv': 1.0, 'o': 1.0},
linear_prefix_map={
AscendQKVParallelLinear.__name__: "qkv",
AscendRowParallelLinear.__name__: "o",
})
self.moe = ModuleWeightPrefetchConfig(
module_name="moe",
enable=weight_prefetch_config.enabled,
enable=weight_prefetch_config.enabled and self.is_moe,
prefetch_ratio=weight_prefetch_config.prefetch_ratio.get(
"moe", {}))
"moe", {}) or {'gate_up': 0.8})
self.mlp = ModuleWeightPrefetchConfig(
module_name="mlp",
enable=weight_prefetch_config.enabled and not self.is_moe,
prefetch_ratio=weight_prefetch_config.prefetch_ratio.get(
"mlp", {}) or {'gate_up': 1.0, 'down': 1.0})
self.mlp_pre_version_compatibale_config = weight_prefetch_config.mlp_pre_version_compatibale_config
def maybe_prefetch_attn_weight_preprocess(
self, layer_cls_name: str, weight: torch.Tensor,
@@ -97,6 +115,82 @@ class WeightPrefetchMethod:
torch.ops.vllm.prefetch_postprocess(stop_flag)
# x_dependency can be None only in eager mode
def maybe_prefetch_mlp_weight_preprocess(self, prefetch_layer_name: str, x_dependency: torch.Tensor | None, curr_layer_prefix: str | None = None):
if not self.mlp.enable and not self.mlp_pre_version_compatibale_config:
self.mlp.is_active_this_forward = False
return
try:
forward_context = get_forward_context()
except AssertionError:
return
self.mlp.is_active_this_forward = (
forward_context.layer_idx is not None
and forward_context.num_tokens is not None
and forward_context.num_tokens < 500
)
if not self.mlp.is_active_this_forward:
return
if prefetch_layer_name == self.MLP_GATE_UP:
self._maybe_prefetch_mlp_gate_up_weight_preprocess(x_dependency, forward_context, curr_layer_prefix)
elif prefetch_layer_name == self.MLP_DOWN:
self._maybe_prefetch_mlp_down_weight_preprocess(x_dependency, forward_context)
else:
raise ValueError(f"Unsupported prefetch weight name: {prefetch_layer_name}")
def _maybe_prefetch_mlp_gate_up_weight_preprocess(self, x_dependency: torch.Tensor, forward_context: ForwardContext, curr_layer_prefix: str | None):
if not curr_layer_prefix:
raise ValueError("curr_layer_prefix must been specified when prefetching mlp gate_up_proj weight")
# start point of gate_up_proj weight prefetch
if curr_layer_prefix.split('.')[-2] == "self_attn":
model_instance = forward_context.model_instance
layer_idx = int(curr_layer_prefix.split('.')[2])
weight = model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight
if self.mlp_pre_version_compatibale_config:
weight_size = self.mlp_pre_version_compatibale_config.get(self.MLP_GATE_UP, 0)
else:
weight_size = weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get(self.MLP_GATE_UP, 0)
if weight_size > MAX_PREFETCH_WEIGHT_SIZE:
weight_size = MAX_PREFETCH_WEIGHT_SIZE
torch.ops.vllm.prefetch_preprocess(weight=weight,
start_flag=x_dependency,
max_weight_size=int(weight_size))
forward_context.prefetch_mlp_gate_up_proj = True
def _maybe_prefetch_mlp_down_weight_preprocess(self, x_dependency: torch.Tensor, forward_context: ForwardContext):
layer_idx = forward_context.layer_idx
model_instance = forward_context.model_instance
weight = model_instance.model.layers[layer_idx].mlp.down_proj.weight
if self.mlp_pre_version_compatibale_config:
weight_size = self.mlp_pre_version_compatibale_config.get(self.MLP_DOWN, 0)
else:
weight_size = weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get(self.MLP_DOWN, 0)
if weight_size > MAX_PREFETCH_WEIGHT_SIZE:
weight_size = MAX_PREFETCH_WEIGHT_SIZE
torch.ops.vllm.prefetch_preprocess(weight=weight,
start_flag=x_dependency,
max_weight_size=int(weight_size))
forward_context.prefetch_mlp_down_proj = True
forward_context.layer_idx += 1
def maybe_prefetch_mlp_weight_postprocess(self, stop_flag: torch.Tensor):
if not self.mlp.is_active_this_forward:
return
try:
forward_context = get_forward_context()
except AssertionError:
return
if forward_context.prefetch_mlp_gate_up_proj or \
forward_context.prefetch_mlp_down_proj:
torch.ops.vllm.prefetch_postprocess(stop_flag)
forward_context.prefetch_mlp_gate_up_proj = False
forward_context.prefetch_mlp_down_proj = False
def maybe_npu_prefetch(inputs: torch.Tensor,
dependency: torch.Tensor,