[Refactor] Align MLP weight prefetch with the MoE model's prefetching in code and usage (#6442)

### What this PR does / why we need it?
Refactor MLP weight prefetch so that it is consistent with the MoE model's
prefetching, in both code structure and usage.
The environment variables VLLM_ASCEND_ENABLE_PREFETCH_MLP,
VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE and
VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE are removed (they are still honored
with a deprecation warning until v0.16.0). The new usage is:

--additional-config '{"weight_prefetch_config": { "enabled": true,
"prefetch_ratio": {"mlp": { "gate_up": 1.0, "down": 1.0} }}}'

### Does this PR introduce _any_ user-facing change?
Yes. The environment variables VLLM_ASCEND_ENABLE_PREFETCH_MLP,
VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE and VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE
are deprecated; MLP weight prefetch is now configured via
weight_prefetch_config in --additional-config.

### How was this patch tested?

- vLLM version: v0.14.1
- vLLM main:
dc917cceb8

---------

Signed-off-by: leo-pony <nengjunma@outlook.com>
Nengjun Ma
2026-02-04 09:08:18 +08:00
committed by GitHub
parent fa56abea9f
commit 78fad4e348
18 changed files with 250 additions and 171 deletions

View File

@@ -19,12 +19,16 @@ import torch
import torch.nn.functional as F
from vllm_ascend.ops.activation import AscendSiluAndMul
from vllm_ascend.utils import get_weight_prefetch_method
class AscendSiluAndMul310(AscendSiluAndMul):
def forward(self, x: torch.Tensor) -> torch.Tensor:
torch.ops.vllm.maybe_prefetch_mlp_down_proj(x)
weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_mlp_weight_preprocess(weight_prefetch_method.MLP_DOWN, x)
h = x.shape[-1] // 2
out = F.silu(x[..., :h]) * x[..., h:]
torch.ops.vllm.maybe_wait_prefetch_done(out)
out = (F.silu(x[..., :h].to(torch.float32)) * x[..., h:].to(torch.float32)).to(torch.float16)
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_mlp_weight_postprocess(out)
return out

View File

@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from typing import TYPE_CHECKING
from vllm.logger import logger
@@ -48,9 +49,7 @@ class AscendConfig:
# Dump / PrecisionDebugger configuration
self.dump_config_path = additional_config.get("dump_config_path", None)
weight_prefetch_config = additional_config.get("weight_prefetch_config", {})
self.weight_prefetch_config = WeightPrefetchConfig(weight_prefetch_config)
self._construct_weight_prefetch_config(additional_config)
self.layer_sharding = additional_config.get("layer_sharding", None)
logger.info_once(
f"Linear layer sharding enabled with config: {self.layer_sharding}. "
@@ -138,6 +137,29 @@ class AscendConfig:
"enable_kv_nz is only supported in pd scenario and can only be used in D node."
)
def _construct_weight_prefetch_config(self, additional_config):
weight_prefetch_config = additional_config.get("weight_prefetch_config", {})
self.weight_prefetch_config = WeightPrefetchConfig(weight_prefetch_config)
# Deprecated env var handling for backward compatibility
if os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", "0") == "1":
MAX_PREFETCH_WEIGHT_SIZE: int = 18 * 1024 * 1024
gate_up_prefetch_size = int(os.getenv("VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE", MAX_PREFETCH_WEIGHT_SIZE))
down_prefetch_szie = int(os.getenv("VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE", MAX_PREFETCH_WEIGHT_SIZE))
self.weight_prefetch_config.set_mlp_pre_version_compatibale_config(
gate_up_prefetch_size, down_prefetch_szie
)
logger.info_once(
f"MLP weight prefetch enabled from env variable VLLM_ASCEND_ENABLE_PREFETCH_MLP."
f"gate_up_prefetch_size={gate_up_prefetch_size}, "
f"down_prefetch_szie={down_prefetch_szie}."
)
warnings.warn(
"VLLM_ASCEND_ENABLE_PREFETCH_MLP is deprecated and will be removed in a v0.16.0 version. "
"Please use weight_prefetch_config in additional-config for now instead.",
DeprecationWarning,
stacklevel=2,
)
class FinegrainedTPConfig:
"""
@@ -305,18 +327,28 @@ class WeightPrefetchConfig:
Configuration Object for weight_prefetch_config from additional_config
"""
mlp_pre_version_compatibale_config: dict = {}
prefetch_ratio: dict = {
"attn": {
"qkv": 1.0,
"o": 1.0,
},
"moe": {"gate_up": 0.8},
"mlp": {"gate_up": 1, "down": 1.0},
}
def __init__(self, weight_prefetch_config: dict):
self.enabled = weight_prefetch_config.get("enabled", False)
self.prefetch_ratio = weight_prefetch_config.get("prefetch_ratio", self.prefetch_ratio)
def set_mlp_pre_version_compatibale_config(self, gate_up_prefetch_size: int, down_prefetch_size: int):
config = {
"gate_up": gate_up_prefetch_size,
"down": down_prefetch_size,
}
self.mlp_pre_version_compatibale_config = config
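# ---------------------------------------------------------------------------
# Editor's sketch, not part of this file: how a prefetch_ratio entry turns into
# the byte count handed to the prefetch op. The weight shape and dtype below
# are illustrative; the 18 MiB cap mirrors MAX_PREFETCH_WEIGHT_SIZE used in
# weight_prefetch.py.
import torch

MAX_PREFETCH_WEIGHT_SIZE = 18 * 1024 * 1024

def prefetch_bytes(weight: torch.Tensor, ratio: float) -> int:
    # bytes to prefetch = element size * element count * configured ratio, capped
    size = weight.element_size() * weight.numel() * ratio
    return min(int(size), MAX_PREFETCH_WEIGHT_SIZE)

gate_up = torch.empty(2 * 11008, 4096, dtype=torch.bfloat16)  # hypothetical gate_up_proj weight
print(prefetch_bytes(gate_up, 1.0))  # 180,355,072 bytes requested -> capped to 18,874,368
# ---------------------------------------------------------------------------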
class EplbConfig:
"""

View File

@@ -119,18 +119,8 @@ def set_ascend_forward_context(
if has_layer_idx(model_instance):
forward_context.layer_idx = model_instance.model.start_layer
# TODO(rjg-lyh): refactor mlp weight prefetch method
# set for mlp weight prefetch
prefetch_mlp_enabled = (
envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP
and forward_context.layer_idx is not None
and num_tokens is not None
and num_tokens < 500
)
if prefetch_mlp_enabled:
forward_context.prefetch_mlp_gate_up_proj = False
forward_context.prefetch_mlp_down_proj = False
forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
forward_context.prefetch_mlp_gate_up_proj = False
forward_context.prefetch_mlp_down_proj = False
forward_context.model_instance = model_instance
forward_context.is_draft_model = is_draft_model

View File

@@ -17,7 +17,7 @@
import torch
from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
from vllm_ascend.utils import get_weight_prefetch_method
class AscendQuickGELU(QuickGELU):
@@ -33,7 +33,10 @@ class AscendSiluAndMul(SiluAndMul):
def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
import torch_npu
torch.ops.vllm.maybe_prefetch_mlp_down_proj(x)
weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_mlp_weight_preprocess(weight_prefetch_method.MLP_DOWN, x)
out = torch_npu.npu_swiglu(x)
torch.ops.vllm.maybe_wait_prefetch_done(out)
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_mlp_weight_postprocess(out)
return out

View File

@@ -24,7 +24,7 @@ from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm, RMSNormG
from vllm_ascend.ops.triton.layernorm_gated import layer_norm_fwd_npu
from vllm_ascend.utils import enable_custom_op
from vllm_ascend.utils import get_weight_prefetch_method
class AscendRMSNorm(RMSNorm):
@@ -67,6 +67,10 @@ class AscendRMSNorm(RMSNorm):
self.variance_epsilon)
if self.bias is not None:
x.add_(self.bias)
weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_mlp_weight_postprocess(x)
return x

View File

@@ -65,8 +65,8 @@ from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
from vllm_ascend.utils import (enable_dsa_cp, enable_dsa_cp_with_layer_shard, enable_sp, flashcomm2_enable,
get_flashcomm2_reorgnized_batch_ids,
matmul_allreduce_enable, mlp_tp_enable,
oproj_tp_enable, shared_expert_dp_enabled)
oproj_tp_enable, shared_expert_dp_enabled,
get_weight_prefetch_method)
class CustomLinearOp:
@@ -138,8 +138,10 @@ class CustomRowParallelOp(CustomLinearOp):
def apply(self, input_):
output, output_bias = self.apply_impl(input_)
if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP:
torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(output, self.prefix)
weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_mlp_weight_preprocess(weight_prefetch_method.MLP_GATE_UP, output, self.prefix)
if not self.return_bias:
return output
return output, output_bias

View File

@@ -110,33 +110,6 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor,
0)
def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
prefix: str) -> None:
try:
forward_context = get_forward_context()
except AssertionError:
return
if not getattr(forward_context, 'prefetch_mlp_enabled', False):
return
model_instance = forward_context.model_instance
weight_prefetch_stream = prefetch_stream()
layer_idx = int(prefix.split('.')[2])
# start point of gate_up_proj weight prefetch
if prefix.split('.')[-2] == "self_attn":
forward_context.prefetch_mlp_gate_up_proj = True
if forward_context.prefetch_mlp_gate_up_proj:
weight_prefetch_stream.wait_stream(torch.npu.current_stream())
with torch.npu.stream(weight_prefetch_stream):
mlp_gate_up_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE
torch_npu.npu_prefetch(
model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight,
x_dependency, mlp_gate_up_prefetch_size)
return
def _maybe_all_gather_and_maybe_unpad_fake(
x: torch.Tensor,
label: bool,
@@ -164,63 +137,6 @@ def _maybe_pad_and_reduce_fake(x: torch.Tensor,
return x
def _maybe_prefetch_mlp_gate_up_proj_impl_fake(x_dependency: torch.Tensor,
prefix: str) -> None:
return
def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
try:
forward_context = get_forward_context()
except AssertionError:
return
if not getattr(forward_context, 'prefetch_mlp_enabled', False):
return
forward_context.prefetch_mlp_down_proj = True
model_instance = forward_context.model_instance
weight_prefetch_stream = prefetch_stream()
layer_idx = forward_context.layer_idx
# start point of down_proj weight prefetch
weight_prefetch_stream.wait_stream(torch.npu.current_stream())
with torch.npu.stream(weight_prefetch_stream):
mlp_down_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE
torch_npu.npu_prefetch(
model_instance.model.layers[layer_idx].mlp.down_proj.weight,
x_dependency, mlp_down_prefetch_size)
forward_context.layer_idx += 1
return
def _maybe_prefetch_mlp_down_proj_impl_fake(
x_dependency: torch.Tensor) -> None:
return
def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None:
try:
forward_context = get_forward_context()
except AssertionError:
return
if not getattr(forward_context, 'prefetch_mlp_enabled', False):
return
if forward_context.prefetch_mlp_gate_up_proj or \
forward_context.prefetch_mlp_down_proj:
weight_prefetch_stream = prefetch_stream()
# wait until prefetch done
torch.npu.current_stream().wait_stream(weight_prefetch_stream)
forward_context.prefetch_mlp_gate_up_proj = False
forward_context.prefetch_mlp_down_proj = False
return
def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None:
return
def _prefetch_preprocess_impl(weight: torch.Tensor, start_flag: torch.Tensor,
max_weight_size: int) -> None:
calculation_stream = torch_npu.npu.current_stream()
@@ -331,24 +247,6 @@ direct_register_custom_op(op_name="maybe_pad_and_reduce",
mutates_args=[],
dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="maybe_prefetch_mlp_gate_up_proj",
op_func=_maybe_prefetch_mlp_gate_up_proj_impl,
fake_impl=_maybe_prefetch_mlp_gate_up_proj_impl_fake,
mutates_args=[],
dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="maybe_prefetch_mlp_down_proj",
op_func=_maybe_prefetch_mlp_down_proj_impl,
fake_impl=_maybe_prefetch_mlp_down_proj_impl_fake,
mutates_args=[],
dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="maybe_wait_prefetch_done",
op_func=_maybe_wait_prefetch_done_impl,
fake_impl=_maybe_wait_prefetch_done_impl_fake,
mutates_args=[],
dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="prefetch_preprocess",
op_func=_prefetch_preprocess_impl,
fake_impl=_prefetch_preprocess_impl_fake,

View File

@@ -2,15 +2,18 @@ from dataclasses import dataclass, field
import torch
import torch_npu
from vllm.forward_context import get_forward_context
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.config import get_current_vllm_config
from vllm.logger import logger
from vllm_ascend.ascend_config import WeightPrefetchConfig
from vllm_ascend.ops.linear import (AscendQKVParallelLinear,
AscendRowParallelLinear)
from vllm_ascend.utils import is_moe_model
SUPPORTED_MODULES = ["attn", "mlp", "moe"]
MOE_PREFETCH_TOKEN_THRESHOLD = 96
MAX_PREFETCH_WEIGHT_SIZE = 18 * 1024 * 1024
@dataclass
class ModuleWeightPrefetchConfig:
@@ -38,22 +41,37 @@ class WeightPrefetchMethod:
"""
Unified weight prefetch method.
"""
is_moe: bool = True
MLP_GATE_UP: str = "gate_up"
MLP_DOWN: str = "down"
# backward compatibility: delete in future versions
mlp_pre_version_compatibale_config: dict = {}
def __init__(self, weight_prefetch_config: WeightPrefetchConfig) -> None:
self.is_moe = is_moe_model(get_current_vllm_config())
self.attn = ModuleWeightPrefetchConfig(
module_name="attn",
enable=weight_prefetch_config.enabled,
prefetch_ratio=weight_prefetch_config.prefetch_ratio.get(
"attn", {}),
"attn", {}) or {'qkv': 1.0, 'o': 1.0},
linear_prefix_map={
AscendQKVParallelLinear.__name__: "qkv",
AscendRowParallelLinear.__name__: "o",
})
self.moe = ModuleWeightPrefetchConfig(
module_name="moe",
enable=weight_prefetch_config.enabled,
enable=weight_prefetch_config.enabled and self.is_moe,
prefetch_ratio=weight_prefetch_config.prefetch_ratio.get(
"moe", {}))
"moe", {}) or {'gate_up': 0.8})
self.mlp = ModuleWeightPrefetchConfig(
module_name="mlp",
enable=weight_prefetch_config.enabled and not self.is_moe,
prefetch_ratio=weight_prefetch_config.prefetch_ratio.get(
"mlp", {}) or {'gate_up': 1.0, 'down': 1.0})
self.mlp_pre_version_compatibale_config = weight_prefetch_config.mlp_pre_version_compatibale_config
def maybe_prefetch_attn_weight_preprocess(
self, layer_cls_name: str, weight: torch.Tensor,
@@ -97,6 +115,82 @@ class WeightPrefetchMethod:
torch.ops.vllm.prefetch_postprocess(stop_flag)
# x_dependency may be None only in eager mode
def maybe_prefetch_mlp_weight_preprocess(self, prefetch_layer_name: str, x_dependency: torch.Tensor | None, curr_layer_prefix: str | None = None):
if not self.mlp.enable and not self.mlp_pre_version_compatibale_config:
self.mlp.is_active_this_forward = False
return
try:
forward_context = get_forward_context()
except AssertionError:
return
self.mlp.is_active_this_forward = (
forward_context.layer_idx is not None
and forward_context.num_tokens is not None
and forward_context.num_tokens < 500
)
if not self.mlp.is_active_this_forward:
return
if prefetch_layer_name == self.MLP_GATE_UP:
self._maybe_prefetch_mlp_gate_up_weight_preprocess(x_dependency, forward_context, curr_layer_prefix)
elif prefetch_layer_name == self.MLP_DOWN:
self._maybe_prefetch_mlp_down_weight_preprocess(x_dependency, forward_context)
else:
raise ValueError(f"Unsupported prefetch weight name: {prefetch_layer_name}")
def _maybe_prefetch_mlp_gate_up_weight_preprocess(self, x_dependency: torch.Tensor, forward_context: ForwardContext, curr_layer_prefix: str | None):
if not curr_layer_prefix:
raise ValueError("curr_layer_prefix must been specified when prefetching mlp gate_up_proj weight")
# start point of gate_up_proj weight prefetch
if curr_layer_prefix.split('.')[-2] == "self_attn":
model_instance = forward_context.model_instance
layer_idx = int(curr_layer_prefix.split('.')[2])
weight = model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight
if self.mlp_pre_version_compatibale_config:
weight_size = self.mlp_pre_version_compatibale_config.get(self.MLP_GATE_UP, 0)
else:
weight_size = weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get(self.MLP_GATE_UP, 0)
if weight_size > MAX_PREFETCH_WEIGHT_SIZE:
weight_size = MAX_PREFETCH_WEIGHT_SIZE
torch.ops.vllm.prefetch_preprocess(weight=weight,
start_flag=x_dependency,
max_weight_size=int(weight_size))
forward_context.prefetch_mlp_gate_up_proj = True
def _maybe_prefetch_mlp_down_weight_preprocess(self, x_dependency: torch.Tensor, forward_context: ForwardContext):
layer_idx = forward_context.layer_idx
model_instance = forward_context.model_instance
weight = model_instance.model.layers[layer_idx].mlp.down_proj.weight
if self.mlp_pre_version_compatibale_config:
weight_size = self.mlp_pre_version_compatibale_config.get(self.MLP_DOWN, 0)
else:
weight_size = weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get(self.MLP_DOWN, 0)
if weight_size > MAX_PREFETCH_WEIGHT_SIZE:
weight_size = MAX_PREFETCH_WEIGHT_SIZE
torch.ops.vllm.prefetch_preprocess(weight=weight,
start_flag=x_dependency,
max_weight_size=int(weight_size))
forward_context.prefetch_mlp_down_proj = True
forward_context.layer_idx += 1
def maybe_prefetch_mlp_weight_postprocess(self, stop_flag: torch.Tensor):
if not self.mlp.is_active_this_forward:
return
try:
forward_context = get_forward_context()
except AssertionError:
return
if forward_context.prefetch_mlp_gate_up_proj or \
forward_context.prefetch_mlp_down_proj:
torch.ops.vllm.prefetch_postprocess(stop_flag)
forward_context.prefetch_mlp_gate_up_proj = False
forward_context.prefetch_mlp_down_proj = False
def maybe_npu_prefetch(inputs: torch.Tensor,
dependency: torch.Tensor,