diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md index 6685bfe..fa2e4fc 100644 --- a/docs/source/user_guide/configuration/additional_config.md +++ b/docs/source/user_guide/configuration/additional_config.md @@ -73,10 +73,10 @@ ascend_scheduler_config also support the options from [vllm scheduler config](ht **weight_prefetch_config** -| Name | Type | Default | Description | -|------------------|------|------------------------------------|------------------------------------| -| `enabled` | bool | `False` | Whether to enable weight prefetch. | -| `prefetch_ratio` | dict | `{"attn": {"qkv": 1.0, "o": 1.0}}` | Prefetch ratio of each weights. | +| Name | Type | Default | Description | +|------------------|------|-------------------------------------------------------------|------------------------------------| +| `enabled` | bool | `False` | Whether to enable weight prefetch. | +| `prefetch_ratio` | dict | `{"attn": {"qkv": 1.0, "o": 1.0}, "moe": {"gate_up": 0.8}}` | Prefetch ratio of each weights. 
| ### Example @@ -104,6 +104,9 @@ An example of additional configuration is as follows: "qkv": 1.0, "o": 1.0, }, + "moe": { + "gate_up": 0.8 + } }, }, "multistream_overlap_shared_expert": True, diff --git a/tests/e2e/singlecard/ops/test_fused_moe.py b/tests/e2e/singlecard/ops/test_fused_moe.py index c6da287..5e9d9c3 100644 --- a/tests/e2e/singlecard/ops/test_fused_moe.py +++ b/tests/e2e/singlecard/ops/test_fused_moe.py @@ -291,7 +291,9 @@ def test_select_experts( custom_routing_function.return_value = (mock_weights, mock_ids) with patch("vllm_ascend.ops.moe.experts_selector._native_grouped_topk" - ) as mock_native_grouped_topk: + ) as mock_native_grouped_topk, \ + patch('vllm_ascend.ops.moe.experts_selector.get_forward_context', + return_value=MagicMock(weight_prefetch_method=MagicMock())): mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like( x) @@ -325,7 +327,9 @@ def test_select_experts( @pytest.mark.parametrize("device", DEVICE) def test_select_experts_invalid_scoring_func(device: str): - with pytest.raises(ValueError, + with patch('vllm_ascend.ops.moe.experts_selector.get_forward_context', + return_value=MagicMock(weight_prefetch_method=MagicMock())), \ + pytest.raises(ValueError, match="Unsupported scoring function: invalid"): select_experts(hidden_states=torch.randn(1, 128, device=device), router_logits=torch.randn(1, 8, device=device), diff --git a/tests/ut/ops/test_fused_ops.py b/tests/ut/ops/test_fused_ops.py index 870b2be..eba948f 100644 --- a/tests/ut/ops/test_fused_ops.py +++ b/tests/ut/ops/test_fused_ops.py @@ -92,14 +92,16 @@ def mock_dist_env(mocker: MockerFixture): mock_moe_comm_method.finalize.side_effect = mock_finalize dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5]) - mock_forward_context_obj = MagicMock(moe_comm_method=mock_moe_comm_method, - moe_comm_type=MoECommType.MC2, - max_tokens_across_dp=10, - dp_metadata=dp_metadata, - mc2_mask=torch.zeros( - 16, dtype=torch.bool), - padded_num_tokens=16, - 
with_quant=False) + mock_weight_prefetch_method = MagicMock() + mock_forward_context_obj = MagicMock( + moe_comm_method=mock_moe_comm_method, + moe_comm_type=MoECommType.MC2, + max_tokens_across_dp=10, + dp_metadata=dp_metadata, + mc2_mask=torch.zeros(16, dtype=torch.bool), + padded_num_tokens=16, + with_quant=False, + weight_prefetch_method=mock_weight_prefetch_method) with patch('torch.distributed.get_rank', return_value=0), \ patch('torch.distributed.get_world_size', return_value=4), \ @@ -132,7 +134,9 @@ def mock_dist_env(mocker: MockerFixture): patch('vllm_ascend.ops.moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher', return_value=None), \ patch('vllm_ascend.ops.moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher', - return_value=None): + return_value=None), \ + patch('vllm_ascend.ops.moe.experts_selector.get_forward_context', + return_value=mock_forward_context_obj): yield { 'mock_forward_context_obj': mock_forward_context_obj, diff --git a/tests/ut/quantization/test_w8a8.py b/tests/ut/quantization/test_w8a8.py index 98dd8f4..3bccb1e 100644 --- a/tests/ut/quantization/test_w8a8.py +++ b/tests/ut/quantization/test_w8a8.py @@ -755,6 +755,14 @@ class TestSelectExperts(TestBase): self.hidden_states = torch.randn(self.num_tokens, self.hidden_size) self.router_logits = torch.randn(self.num_tokens, self.num_experts) + self.mock_ctx = MagicMock() + self.mock_ctx.weight_prefetch_method = MagicMock() + patcher = patch( + 'vllm_ascend.ops.moe.experts_selector.get_forward_context', + return_value=self.mock_ctx) + self.addCleanup(patcher.stop) + patcher.start() + @patch('torch_npu.npu_moe_gating_top_k_softmax') def test_softmax_scoring(self, mock_topk): """Test softmax scoring function""" diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 9efb37a..1e68558 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -216,6 +216,9 @@ class WeightPrefetchConfig: "qkv": 1.0, "o": 1.0, }, + "moe": { + "gate_up": 
0.8 + } } def __init__(self, weight_prefetch_config: dict): diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 93633ae..ade8268 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -145,7 +145,7 @@ def set_ascend_forward_context( forward_context.prefetch_mlp_gate_up_proj = False forward_context.prefetch_mlp_down_proj = False forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled - # TODO(yuzhup): integrate moe weight prefetch method + forward_context.model_instance = model_instance forward_context.weight_prefetch_method = weight_prefetch_method # TODO(rjg-lyh): The current implementation is somewhat brute force and not elegant. diff --git a/vllm_ascend/ops/moe/experts_selector.py b/vllm_ascend/ops/moe/experts_selector.py index eace164..6236113 100644 --- a/vllm_ascend/ops/moe/experts_selector.py +++ b/vllm_ascend/ops/moe/experts_selector.py @@ -18,6 +18,7 @@ from typing import Callable, Optional import torch import torch_npu +from vllm.forward_context import get_forward_context def return_row_idx(hidden_states, top_k): @@ -65,7 +66,11 @@ def select_experts(hidden_states: torch.Tensor, topk_weights: router weights of shape (num_tokens, top_k). topk_ids: selected expert IDs of shape (num_tokens, top_k). 
""" - + # prefetch w1_w3_proj.weight preprocess + weight_prefetch_method = get_forward_context().weight_prefetch_method + if weight_prefetch_method: + weight_prefetch_method.maybe_prefetch_moe_weight_preprocess( + hidden_states, "gate_up") topk_weights, topk_ids, row_idx = _select_experts_with_fusion_ops( hidden_states=hidden_states, router_logits=router_logits, diff --git a/vllm_ascend/ops/moe/moe_mlp.py b/vllm_ascend/ops/moe/moe_mlp.py index 6526e56..05a1a2e 100644 --- a/vllm_ascend/ops/moe/moe_mlp.py +++ b/vllm_ascend/ops/moe/moe_mlp.py @@ -78,6 +78,10 @@ def quant_apply_mlp(hidden_states: torch.Tensor, bias1, bias2 = None, None _output_dtype = w2_scale.dtype + weight_prefetch_method = get_forward_context().weight_prefetch_method + if weight_prefetch_method: + weight_prefetch_method.maybe_prefetch_moe_weight_postprocess( + hidden_states) is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2 if w1_scale_bias is None and is_mc2: if fusion and not dynamic_eplb: diff --git a/vllm_ascend/ops/weight_prefetch.py b/vllm_ascend/ops/weight_prefetch.py index 080d56b..36b3a18 100644 --- a/vllm_ascend/ops/weight_prefetch.py +++ b/vllm_ascend/ops/weight_prefetch.py @@ -1,83 +1,112 @@ -from dataclasses import dataclass, field - -import torch -import torch_npu - -from vllm_ascend.ascend_config import WeightPrefetchConfig -from vllm_ascend.ops.linear import (AscendQKVParallelLinear, - AscendRowParallelLinear) - -SUPPORTED_MODULES = ["attn", "mlp", "moe"] - - -@dataclass -class ModuleWeightPrefetchConfig: - module_name: str - enable: bool = False - prefetch_ratio: dict = field(default_factory=dict) - linear_prefix_map: dict = field(default_factory=dict) - - def __post_init__(self) -> None: - self.prefetch_ratio = { - prefix: ratio - for prefix, ratio in self.prefetch_ratio.items() if 0 <= ratio <= 1 - } - - assert self.module_name in SUPPORTED_MODULES, ( - f"Invalid module name {self.module_name}, should be one of {SUPPORTED_MODULES}" - ) - - if self.module_name in 
from dataclasses import dataclass, field

import torch
import torch_npu
from vllm.forward_context import get_forward_context

from vllm_ascend.ascend_config import WeightPrefetchConfig
from vllm_ascend.ops.linear import (AscendQKVParallelLinear,
                                    AscendRowParallelLinear)

# Module groups that may participate in weight prefetching.
SUPPORTED_MODULES = ["attn", "mlp", "moe"]
# Minimum batch size (tokens) before MoE weight prefetch is worthwhile;
# below this the prefetch overhead outweighs the overlap benefit.
MOE_PREFETCH_TOKEN_THRESHOLD = 96


@dataclass
class ModuleWeightPrefetchConfig:
    # Per-module prefetch settings derived from the global WeightPrefetchConfig.
    #   module_name: one of SUPPORTED_MODULES ("attn", "mlp", "moe").
    #   enable: user switch; forced off when no prefix has a positive ratio.
    #   is_active_this_forward: per-forward flag used by the MoE path only.
    #   prefetch_ratio: weight-prefix -> fraction of the weight to prefetch,
    #       entries outside [0, 1] are dropped as misconfigured.
    #   linear_prefix_map: linear-layer class name -> weight prefix key.
    module_name: str
    enable: bool = False
    is_active_this_forward: bool = False
    prefetch_ratio: dict = field(default_factory=dict)
    linear_prefix_map: dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validate the module name and normalize the configured ratios."""
        # Silently discard out-of-range ratios instead of clamping them.
        self.prefetch_ratio = {
            prefix: ratio
            for prefix, ratio in self.prefetch_ratio.items() if 0 <= ratio <= 1
        }

        assert self.module_name in SUPPORTED_MODULES, (
            f"Invalid module name {self.module_name}, should be one of {SUPPORTED_MODULES}"
        )

        # Disable prefetch when every remaining ratio is zero.
        # Fix: the original wrote `any(self.prefetch_ratio.values()) > 0`,
        # which compares the *bool* returned by any() against 0 and only
        # worked by accident; test each ratio explicitly instead.  The
        # redundant `if self.module_name in SUPPORTED_MODULES:` guard
        # (always true after the assert above) is removed.
        self.enable = self.enable and any(
            ratio > 0 for ratio in self.prefetch_ratio.values())


class WeightPrefetchMethod:
    """
    Unified weight prefetch method.

    Wraps per-module ModuleWeightPrefetchConfig instances and exposes
    pre/post hooks that issue torch.ops.vllm.prefetch_* calls for the
    attention linears and the MoE gate_up (w13) expert weights.
    """

    def __init__(self, weight_prefetch_config: WeightPrefetchConfig) -> None:
        self.attn = ModuleWeightPrefetchConfig(
            module_name="attn",
            enable=weight_prefetch_config.enabled,
            prefetch_ratio=weight_prefetch_config.prefetch_ratio.get(
                "attn", {}),
            linear_prefix_map={
                AscendQKVParallelLinear.__name__: "qkv",
                AscendRowParallelLinear.__name__: "o",
            })
        # MoE prefetch has no per-linear-class mapping; the prefix ("gate_up")
        # is passed explicitly by the caller.
        self.moe = ModuleWeightPrefetchConfig(
            module_name="moe",
            enable=weight_prefetch_config.enabled,
            prefetch_ratio=weight_prefetch_config.prefetch_ratio.get(
                "moe", {}))

    def maybe_prefetch_attn_weight_preprocess(
            self, layer_cls_name: str, weight: torch.Tensor,
            start_flag: torch.Tensor) -> None:
        """Kick off prefetch of an attention linear's weight.

        No-op unless attn prefetch is enabled and ``layer_cls_name`` is one
        of the mapped linear classes (QKV / output projection).
        """
        if not self.attn.enable or layer_cls_name not in self.attn.linear_prefix_map:
            return

        prefix = self.attn.linear_prefix_map.get(layer_cls_name, "")
        # Prefetch only the configured fraction of the weight's bytes.
        weight_size = weight.data.element_size() * weight.data.numel(
        ) * self.attn.prefetch_ratio.get(prefix, 0)

        torch.ops.vllm.prefetch_preprocess(weight=weight,
                                           start_flag=start_flag,
                                           max_weight_size=int(weight_size))

    def maybe_prefetch_attn_weight_postprocess(
            self, layer_cls_name: str, stop_flag: torch.Tensor) -> None:
        """Signal completion of the attention weight prefetch window."""
        if not self.attn.enable or layer_cls_name not in self.attn.linear_prefix_map:
            return

        torch.ops.vllm.prefetch_postprocess(stop_flag)

    def maybe_prefetch_moe_weight_preprocess(self, hidden_states: torch.Tensor,
                                             prefix: str) -> None:
        """Kick off prefetch of the current layer's MoE w13 (gate_up) weight.

        Active only when MoE prefetch is enabled and the batch has at least
        MOE_PREFETCH_TOKEN_THRESHOLD tokens; the decision is cached in
        ``self.moe.is_active_this_forward`` for the matching postprocess call.
        Advances ``forward_context.layer_idx`` as a side effect.
        """
        # Equivalent to the original conditional-expression form:
        # active iff enabled AND enough tokens this forward.
        self.moe.is_active_this_forward = (
            self.moe.enable
            and hidden_states.shape[0] >= MOE_PREFETCH_TOKEN_THRESHOLD)
        if not self.moe.is_active_this_forward:
            return
        forward_context = get_forward_context()
        # NOTE(review): assumes forward_context.layer_idx is initialized
        # (e.g. reset to 0 at the start of each forward) before the first
        # call — not visible here, confirm in set_ascend_forward_context.
        weight = forward_context.model_instance.model.layers[
            forward_context.layer_idx].mlp.experts.w13_weight
        weight_size = weight.data.element_size() * weight.data.numel(
        ) * self.moe.prefetch_ratio.get(prefix, 0)
        torch.ops.vllm.prefetch_preprocess(weight=weight,
                                           start_flag=None,
                                           max_weight_size=int(weight_size))
        forward_context.layer_idx += 1

    def maybe_prefetch_moe_weight_postprocess(self,
                                              stop_flag: torch.Tensor) -> None:
        """Close the MoE prefetch window opened by the preprocess call."""
        if not self.moe.is_active_this_forward:
            return

        torch.ops.vllm.prefetch_postprocess(stop_flag)


def maybe_npu_prefetch(inputs: torch.Tensor,
                       dependency: torch.Tensor,
                       max_size: int = 0,
                       offset: int = 0,
                       *,
                       enabled: bool = True) -> None:
    """Prefetch up to ``max_size`` bytes of ``inputs`` after ``dependency``.

    A non-positive or oversized ``max_size`` is clamped to the full tensor
    size; a no-op when ``enabled`` is False.
    """
    if not enabled:
        return
    input_size = inputs.element_size() * inputs.numel()
    if max_size <= 0 or max_size > input_size:
        max_size = input_size
    torch_npu.npu_prefetch(inputs, dependency, max_size, offset)