diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md
index e709b3a..6685bfe 100644
--- a/docs/source/user_guide/configuration/additional_config.md
+++ b/docs/source/user_guide/configuration/additional_config.md
@@ -24,24 +24,24 @@ LLM(model="Qwen/Qwen3-8B", additional_config={"config_key":"config_value"})
 
 The following table lists the additional configuration options available in vLLM Ascend:
 
-| Name | Type | Default | Description |
-|-------------------------------| ---- |------|-----------------------------------------------------------------------------------------------|
-| `torchair_graph_config` | dict | `{}` | The config options for torchair graph mode |
-| `ascend_scheduler_config` | dict | `{}` | The config options for ascend scheduler |
-| `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf or ut/e2e test case. |
-| `expert_map_path` | str | `None` | When using expert load balancing for the MOE model, an expert map path needs to be passed in. |
-| `enable_prefetch` | bool | `False` | Whether to enable weight prefetch. |
-| `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set, currently only int8 is supported. |
-| `enable_shared_expert_dp` | bool | `False` | When the shared expert in DP, it has better performance but consumes more memory. Currently only DeepSeek series models are supported to use. |
-| `lmhead_tensor_parallel_size` | int | `None` | The custom tensor parallel size of lmhead. |
-| `oproj_tensor_parallel_size` | int | `None` | The custom tensor parallel size of oproj. |
-| `multistream_overlap_shared_expert`| bool | `False` | Whether to enable multistream shared expert. This option only takes effects on moe models with shared experts. |
-| `dynamic_eplb` | bool | `False` | Whether to enable dynamic eplb |
-|`num_iterations_eplb_update`| int | `400` | Forward iterations when eplb would begin |
-|`gate_eplb`| bool | `False` | Whether to enale eplb only once. |
-|`num_wait_worker_iterations`| int | `30` | The forward iterations when eplb worker will finish cpu task. In our test default value 30 would cover most cases. |
-|`expert_map_record_path`| str | `None` | When dynamic eplb is completed, save the current expert load heatmap to the specified path. |
-|`init_redundancy_expert`| int | `0` |Specify redundant experts during initialization.|
+| Name                                | Type | Default | Description                                                                                                                                 |
+|-------------------------------------|------|---------|---------------------------------------------------------------------------------------------------------------------------------------------|
+| `torchair_graph_config`             | dict | `{}`    | The config options for torchair graph mode                                                                                                  |
+| `ascend_scheduler_config`           | dict | `{}`    | The config options for ascend scheduler                                                                                                     |
+| `weight_prefetch_config`            | dict | `{}`    | The config options for weight prefetch                                                                                                      |
+| `refresh`                           | bool | `false` | Whether to refresh the global ascend config content. This value is usually used by RLHF or UT/e2e test cases.                               |
+| `expert_map_path`                   | str  | `None`  | When using expert load balancing for the MoE model, an expert map path needs to be passed in.                                               |
+| `kv_cache_dtype`                    | str  | `None`  | When using the kv cache quantization method, the kv cache dtype needs to be set; currently only int8 is supported.                          |
+| `enable_shared_expert_dp`           | bool | `False` | Whether to run the shared experts in DP mode. This improves performance but consumes more memory. Currently only DeepSeek series models are supported. |
+| `lmhead_tensor_parallel_size`       | int  | `None`  | The custom tensor parallel size of lmhead.                                                                                                  |
+| `oproj_tensor_parallel_size`        | int  | `None`  | The custom tensor parallel size of oproj.                                                                                                   |
+| `multistream_overlap_shared_expert` | bool | `False` | Whether to enable multistream shared expert. This option only takes effect on MoE models with shared experts.                               |
+| `dynamic_eplb`                      | bool | `False` | Whether to enable dynamic EPLB.                                                                                                             |
+| `num_iterations_eplb_update`        | int  | `400`   | The number of forward iterations after which dynamic EPLB begins.                                                                           |
+| `gate_eplb`                         | bool | `False` | Whether to enable EPLB only once.                                                                                                           |
+| `num_wait_worker_iterations`        | int  | `30`    | The number of forward iterations for the EPLB worker to finish its CPU task. In our tests, the default value of 30 covers most cases.       |
+| `expert_map_record_path`            | str  | `None`  | When dynamic EPLB completes, save the current expert load heatmap to the specified path.                                                    |
+| `init_redundancy_expert`            | int  | `0`     | Specify the number of redundant experts during initialization.                                                                              |
 
 The details of each config option are as follows:
 
@@ -71,6 +71,20 @@ The details of each config option are as follows:
 ascend_scheduler_config also support the options from [vllm scheduler config](https://docs.vllm.ai/en/stable/api/vllm/config.html#vllm.config.SchedulerConfig). For example, you can add `enable_chunked_prefill: True` to ascend_scheduler_config as well.
 
+**weight_prefetch_config**
+
+| Name             | Type | Default                            | Description                                     |
+|------------------|------|------------------------------------|-------------------------------------------------|
+| `enabled`        | bool | `False`                            | Whether to enable weight prefetch.              |
+| `prefetch_ratio` | dict | `{"attn": {"qkv": 1.0, "o": 1.0}}` | Per-weight prefetch ratio, in the range [0, 1]. |
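+
+A minimal example that only enables weight prefetch with the default ratios (a sketch; every other option keeps its default):
+
+```python
+LLM(model="Qwen/Qwen3-8B",
+    additional_config={"weight_prefetch_config": {"enabled": True}})
+```
+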
 ### Example
 
 An example of additional configuration is as follows:
 
@@ -90,6 +104,15 @@ An example of additional configuration is as follows:
         "max_long_partial_prefills": 1,
         "long_prefill_token_threshold": 4096,
     },
+    "weight_prefetch_config": {
+        "enabled": True,
+        "prefetch_ratio": {
+            "attn": {
+                "qkv": 1.0,
+                "o": 1.0,
+            },
+        },
+    },
     "multistream_overlap_shared_expert": True,
     "refresh": False,
 }
diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index 0164057..6aac6df 100644
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -495,7 +495,7 @@ class TestAscendMLAImpl(TestBase):
         mock_up_proj.assert_called_once()
         mock_npu_fused_infer_attention_score.assert_called_once()
 
-    @patch("vllm_ascend.attention.mla_v1.npu_prefetch")
+    @patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
     def test_mla_preprocess(self, magic_npu_fetch):
         magic_npu_fetch.return_value = MagicMock()
         batch_size = 4
diff --git a/tests/ut/quantization/test_w8a8.py b/tests/ut/quantization/test_w8a8.py
index 3f2557b..69b33a9 100644
--- a/tests/ut/quantization/test_w8a8.py
+++ b/tests/ut/quantization/test_w8a8.py
@@ -68,16 +68,23 @@ class TestAscendW8A8LinearMethod(TestBase):
         self.assertEqual(params['weight_scale'].shape, (10, 1))
         self.assertEqual(params['weight_offset'].shape, (10, 1))
 
+    @patch("vllm_ascend.quantization.w8a8.get_forward_context")
     @patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
     @patch("torch_npu.npu_quant_matmul")
     def test_apply_with_x_not_int8(self, mock_npu_quant_matmul,
-                                   mock_quant_per_tensor):
+                                   mock_quant_per_tensor,
+                                   mock_get_forward_context):
         layer = MagicMock()
         layer.aclnn_input_scale = 0.1
         layer.aclnn_input_offset = 0.2
         layer.weight = torch.randn(128, 256)
         layer.deq_scale = 0.3
+        mock_forward_context = MagicMock()
+        mock_get_forward_context.return_value = mock_forward_context
+        mock_weight_prefetch_method = MagicMock()
+        mock_forward_context.weight_prefetch_method = mock_weight_prefetch_method
+
         x = torch.randn(32, 128)
         bias = torch.randn(256)
         mock_quant_per_tensor.return_value = torch.randint(-128,
diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index 27017c1..93579ce 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -45,6 +45,12 @@
             "ascend_scheduler_config", {})
         self.ascend_scheduler_config = AscendSchedulerConfig(
             ascend_scheduler_config)
+
+        weight_prefetch_config = additional_config.get(
+            "weight_prefetch_config", {})
+        self.weight_prefetch_config = WeightPrefetchConfig(
+            weight_prefetch_config)
+
         # Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove this config
         self.expert_map_path = additional_config.get("expert_map_path", None)
         self.expert_map_record_path = additional_config.get(
@@ -65,7 +71,6 @@
         ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
         self.multistream_overlap_shared_expert = additional_config.get(
             "multistream_overlap_shared_expert", False)
-        self.enable_prefetch = additional_config.get("enable_prefetch", False)
         self.lmhead_tensor_parallel_size = additional_config.get(
             "lmhead_tensor_parallel_size", None)
         if self.lmhead_tensor_parallel_size is not None:
@@ -185,6 +190,32 @@
             setattr(self, k, v)
 
 
+class WeightPrefetchConfig:
+    """
+    Configuration object for weight_prefetch_config from additional_config
+    """
+
+    prefetch_ratio: dict = {
+        "attn": {
+            "qkv": 1.0,
+            "o": 1.0,
+        },
+    }
+
+    def __init__(self, weight_prefetch_config: dict):
+        self.enabled = weight_prefetch_config.get("enabled", False)
+        self.prefetch_ratio = weight_prefetch_config.get(
+            "prefetch_ratio", self.prefetch_ratio)
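+
+    # Hypothetical usage, mirroring the defaults documented in
+    # additional_config.md (illustration only; not executed anywhere):
+    #
+    #   WeightPrefetchConfig({
+    #       "enabled": True,
+    #       "prefetch_ratio": {"attn": {"qkv": 1.0, "o": 1.0}},
+    #   })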
+
+
 _ASCEND_CONFIG: Optional[AscendConfig] = None
 
 
diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py
index ad61245..209e507 100644
--- a/vllm_ascend/ascend_forward_context.py
+++ b/vllm_ascend/ascend_forward_context.py
@@ -1,7 +1,7 @@
 import math
 from contextlib import contextmanager
 from enum import Enum
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 import torch
 from vllm.config import CUDAGraphMode, VllmConfig
@@ -13,6 +13,11 @@ from vllm.forward_context import (BatchDescriptor, get_forward_context,
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.utils import enable_sp
 
+if TYPE_CHECKING:
+    from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
+else:
+    WeightPrefetchMethod = None
+
 
 class FusedMoEState(Enum):
     AllGather = 0
@@ -65,7 +70,8 @@ def set_ascend_forward_context(
         aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
         batch_descriptor: Optional[BatchDescriptor] = None,
         prefetch_stream: torch.npu.Stream = None,
-        model_instance: torch.nn.Module = None):
+        model_instance: torch.nn.Module = None,
+        weight_prefetch_method: Optional[WeightPrefetchMethod] = None):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     We add some additional param into forward_context.
@@ -127,6 +133,7 @@ def set_ascend_forward_context(
             hasattr(model_instance.model, "start_layer"):
             forward_context.layer_idx = model_instance.model.start_layer
 
+        # TODO(rjg-lyh): refactor mlp weight prefetch method
         # set for mlp weight prefetch
         prefetch_mlp_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
             envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \
@@ -138,6 +145,8 @@ def set_ascend_forward_context(
             forward_context.prefetch_mlp_gate_up_proj = False
             forward_context.prefetch_mlp_down_proj = False
         forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
+        # TODO(yuzhup): integrate moe weight prefetch method
+        forward_context.weight_prefetch_method = weight_prefetch_method
 
         # TODO(rjg-lyh): The current implementation is somewhat brute force and not elegant.
         # It will be improved later by implementing operator fusion through the FX graph.
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 73cbae6..39340f7 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -24,7 +24,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
-from vllm_ascend.utils import npu_prefetch
+from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.worker.npu_input_batch import InputBatch
 
 if TYPE_CHECKING:
@@ -493,7 +493,7 @@ class AscendMLAImpl(MLAAttentionImpl):
 
         ascend_config = get_ascend_config()
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
-        self.enable_prefetch = ascend_config.enable_prefetch
+        self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
         self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
 
         vllm_config = get_current_vllm_config()
@@ -877,9 +877,9 @@ class AscendMLAImpl(MLAAttentionImpl):
         num_decode_tokens = attn_metadata.num_decode_tokens
         num_actual_tokens = attn_metadata.num_actual_tokens
         if self.q_a_proj is not None:
-            npu_prefetch(self.q_a_proj.weight,
-                         hidden_states,
-                         enabled=self.enable_prefetch)
+            maybe_npu_prefetch(inputs=self.q_a_proj.weight,
+                               dependency=hidden_states,
+                               enabled=self.enable_prefetch)
             ckq = self.q_a_proj(hidden_states)[0]
             q_c = self.q_a_layernorm(ckq)
         else:
@@ -1005,10 +1005,10 @@
         current_ms_metadata = get_multistream_comm_context()
         MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
         if current_ms_metadata is None:
-            npu_prefetch(self.o_proj.weight,
-                         o_proj_input,
-                         max_size=MAX_O_PROJ_PREFETCH_SIZE,
-                         enabled=self.enable_prefetch)
+            maybe_npu_prefetch(inputs=self.o_proj.weight,
+                               dependency=o_proj_input,
+                               max_size=MAX_O_PROJ_PREFETCH_SIZE,
+                               enabled=self.enable_prefetch)
 
             output[...] = self.o_proj(
                 o_proj_input,
@@ -1016,10 +1016,10 @@
                 is_force_scatter=self.enable_shared_expert_dp)[0]
         else:
             with torch.npu.stream(current_ms_metadata.comm_stream):
-                npu_prefetch(self.o_proj.weight,
-                             o_proj_input,
-                             max_size=MAX_O_PROJ_PREFETCH_SIZE,
-                             enabled=self.enable_prefetch)
+                maybe_npu_prefetch(inputs=self.o_proj.weight,
+                                   dependency=o_proj_input,
+                                   max_size=MAX_O_PROJ_PREFETCH_SIZE,
+                                   enabled=self.enable_prefetch)
                 output[...] = self.o_proj(
                    o_proj_input,
                    is_prefill=prefill_preprocess_res is not None,
diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py
index 438bff1..7e9cdde 100644
--- a/vllm_ascend/ops/register_custom_ops.py
+++ b/vllm_ascend/ops/register_custom_ops.py
@@ -11,6 +11,8 @@ from vllm.utils import direct_register_custom_op
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_forward_context import MoECommType
+from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
+from vllm_ascend.utils import npu_stream_switch, prefetch_stream
 
 
 def _maybe_chunk_residual_impl(x: torch.Tensor,
@@ -148,6 +150,33 @@ def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None:
     return
 
 
+def _prefetch_preprocess_impl(weight: torch.Tensor, start_flag: torch.Tensor,
+                              max_weight_size: int) -> None:
+    calculation_stream = torch_npu.npu.current_stream()
+    weight_prefetch_stream = prefetch_stream()
+    weight_prefetch_stream.wait_stream(calculation_stream)
+    with npu_stream_switch(weight_prefetch_stream):
+        maybe_npu_prefetch(inputs=weight,
+                           dependency=start_flag,
+                           max_size=max_weight_size)
+
+
+def _prefetch_preprocess_impl_fake(weight: torch.Tensor,
+                                   start_flag: torch.Tensor,
+                                   max_weight_size: int) -> None:
+    return
+
+
+def _prefetch_postprocess_impl(stop_flag: torch.Tensor) -> None:
+    calculation_stream = torch_npu.npu.current_stream()
+    weight_prefetch_stream = prefetch_stream()
+    calculation_stream.wait_stream(weight_prefetch_stream)
+
+
+def _prefetch_postprocess_impl_fake(stop_flag: torch.Tensor) -> None:
+    return
+
+
 def _maybe_all_reduce_tensor_model_parallel_impl(
         final_hidden_states: torch.Tensor) -> torch.Tensor:
     forward_context = get_forward_context()
@@ -194,6 +223,18 @@ direct_register_custom_op(op_name="maybe_wait_prefetch_done",
                           mutates_args=[],
                           dispatch_key="PrivateUse1")
 
+direct_register_custom_op(op_name="prefetch_preprocess",
+                          op_func=_prefetch_preprocess_impl,
+                          fake_impl=_prefetch_preprocess_impl_fake,
+                          mutates_args=[],
+                          dispatch_key="PrivateUse1")
+
+direct_register_custom_op(op_name="prefetch_postprocess",
+                          op_func=_prefetch_postprocess_impl,
+                          fake_impl=_prefetch_postprocess_impl_fake,
+                          mutates_args=[],
+                          dispatch_key="PrivateUse1")
+
 direct_register_custom_op(op_name="maybe_all_reduce_tensor_model_parallel",
                           op_func=_maybe_all_reduce_tensor_model_parallel_impl,
                           fake_impl=lambda x: x,
diff --git a/vllm_ascend/ops/weight_prefetch.py b/vllm_ascend/ops/weight_prefetch.py
new file mode 100644
index 0000000..a6004c5
--- /dev/null
+++ b/vllm_ascend/ops/weight_prefetch.py
@@ -0,0 +1,84 @@
+from dataclasses import dataclass, field
+
+import torch
+import torch_npu
+
+from vllm_ascend.ascend_config import WeightPrefetchConfig
+
+SUPPORTED_MODULES = ["attn", "mlp", "moe"]
+
+
+@dataclass
+class ModuleWeightPrefetchConfig:
+    module_name: str
+    enable: bool = False
+    prefetch_ratio: dict = field(default_factory=dict)
+
+    def __post_init__(self) -> None:
+        # drop ratios outside [0, 1]
+        self.prefetch_ratio = {
+            prefix: ratio
+            for prefix, ratio in self.prefetch_ratio.items() if 0 <= ratio <= 1
+        }
+
+        assert self.module_name in SUPPORTED_MODULES, (
+            f"Invalid module name {self.module_name}, should be one of {SUPPORTED_MODULES}"
+        )
+
+        if self.module_name in SUPPORTED_MODULES:
+            # prefetch is only effective when at least one ratio is non-zero
+            self.enable = self.enable and any(self.prefetch_ratio.values())
+
+
+class WeightPrefetchMethod:
+    """
+    Unified weight prefetch method.
+    """
+
+    def __init__(self, weight_prefetch_config: WeightPrefetchConfig) -> None:
+        self.attn = ModuleWeightPrefetchConfig(
+            module_name="attn",
+            enable=weight_prefetch_config.enabled,
+            prefetch_ratio=weight_prefetch_config.prefetch_ratio.get(
+                "attn", {}))
+
+    def maybe_prefetch_attn_weight_preprocess(
+            self, prefix: str, weight: torch.Tensor,
+            start_flag: torch.Tensor) -> None:
+        if not self.attn.enable:
+            return
+
+        weight_size = weight.data.element_size() * weight.data.numel(
+        ) * self.attn.prefetch_ratio.get(prefix, 0)
+
+        torch.ops.vllm.prefetch_preprocess(weight=weight,
+                                           start_flag=start_flag,
+                                           max_weight_size=int(weight_size))
+
+    def maybe_prefetch_attn_weight_postprocess(
+            self, stop_flag: torch.Tensor) -> None:
+        if not self.attn.enable:
+            return
+
+        torch.ops.vllm.prefetch_postprocess(stop_flag)
+
+
+def maybe_npu_prefetch(inputs: torch.Tensor,
+                       dependency: torch.Tensor,
+                       max_size: int = 0,
+                       offset: int = 0,
+                       *,
+                       enabled: bool = True) -> None:
+    if not enabled:
+        return
+    input_size = inputs.element_size() * inputs.numel()
+    if max_size <= 0 or max_size > input_size:
+        max_size = input_size
+    torch_npu.npu_prefetch(inputs, dependency, max_size, offset)
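+
+
+# A minimal usage sketch (illustrative only; `layer` and `x` stand for an
+# attention linear layer and its input activations):
+#
+#   method = WeightPrefetchMethod(WeightPrefetchConfig({"enabled": True}))
+#   method.maybe_prefetch_attn_weight_preprocess("qkv", layer.weight, x)
+#   ...  # overlapped computation runs while the prefetch stream copies
+#   method.maybe_prefetch_attn_weight_postprocess(x)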
+ """ + + def __init__(self, weight_prefetch_config: WeightPrefetchConfig) -> None: + self.attn = ModuleWeightPrefetchConfig( + module_name="attn", + enable=weight_prefetch_config.enabled, + prefetch_ratio=weight_prefetch_config.prefetch_ratio.get( + "attn", {})) + + def maybe_prefetch_attn_weight_preprocess( + self, prefix: str, weight: torch.Tensor, + start_flag: torch.Tensor) -> None: + if not self.attn.enable: + return + + weight_size = weight.data.element_size() * weight.data.numel( + ) * self.attn.prefetch_ratio.get(prefix, 0) + + torch.ops.vllm.prefetch_preprocess(weight=weight, + start_flag=start_flag, + max_weight_size=int(weight_size)) + + def maybe_prefetch_attn_weight_postprocess( + self, stop_flag: torch.Tensor) -> None: + if not self.attn.enable: + return + + torch.ops.vllm.prefetch_postprocess(stop_flag) + + +def maybe_npu_prefetch(inputs: torch.Tensor, + dependency: torch.Tensor, + max_size: int = 0, + offset: int = 0, + *, + enabled: bool = True) -> None: + if not enabled: + return + input_size = inputs.element_size() * inputs.numel() + if max_size <= 0 or max_size > input_size: + max_size = input_size + torch_npu.npu_prefetch(inputs, dependency, max_size, offset) diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py index 010d45d..433dbab 100644 --- a/vllm_ascend/quantization/w8a8.py +++ b/vllm_ascend/quantization/w8a8.py @@ -21,6 +21,7 @@ import torch import torch_npu from vllm.attention.backends.abstract import AttentionType from vllm.distributed.parallel_state import get_ep_group +from vllm.forward_context import get_forward_context from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.ops.moe.experts_selector import select_experts @@ -97,11 +98,32 @@ class AscendW8A8LinearMethod: tp_rank: Optional[int] = 0, ) -> torch.Tensor: if x.dtype != torch.int8: + attn_weight_map = { + "AscendQKVParallelLinear": "qkv", + "AscendRowParallelLinear": "o", + } + layer_cls_name = layer.__class__.__name__ + weight_prefetch_method = get_forward_context( + ).weight_prefetch_method + assert weight_prefetch_method is not None + + # prefetch_qkvo_proj.weight preprocess + weight_prefetch_method.maybe_prefetch_attn_weight_preprocess( + prefix=attn_weight_map.get(layer_cls_name, ""), + weight=layer.weight, + start_flag=x, + ) + # quant x = quant_per_tensor( x, layer.aclnn_input_scale_reciprocal, layer.aclnn_input_offset, ) + # prefetch_qkvo_proj.weight postprocess + if layer_cls_name in attn_weight_map.keys(): + weight_prefetch_method.maybe_prefetch_attn_weight_postprocess( + x) + quant_bias = layer.quant_bias if tp_rank == 0 else None if is_310p(): # On 300I Duo platform, we need transpose again if diff --git a/vllm_ascend/torchair/models/torchair_deepseek_v2.py b/vllm_ascend/torchair/models/torchair_deepseek_v2.py index 371b1c9..a7ab345 100644 --- a/vllm_ascend/torchair/models/torchair_deepseek_v2.py +++ b/vllm_ascend/torchair/models/torchair_deepseek_v2.py @@ -70,11 +70,12 @@ from vllm.sequence import IntermediateTensors from vllm_ascend import envs from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.models.layers.sfa import Indexer +from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch from vllm_ascend.quantization.quant_config import AscendLinearMethod from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \ TorchairAscendW8A8DynamicLinearMethod -from vllm_ascend.utils import dispose_tensor, npu_prefetch, 
+from vllm_ascend.utils import dispose_tensor, oproj_tp_enable
 
 
 class TorchairDeepseekV2SiluAndMul(SiluAndMul):
@@ -589,9 +590,9 @@ class TorchairDeepseekV2MLAAttention(DeepseekV2MLAAttention):
                                      and attn_metadata.num_decodes > 0)
         forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
         if self.q_lora_rank is not None:
-            npu_prefetch(self.q_a_proj.weight,
-                         hidden_states,
-                         enabled=enable_multistream_mla)
+            maybe_npu_prefetch(self.q_a_proj.weight,
+                               hidden_states,
+                               enabled=enable_multistream_mla)
             ckq = self.q_a_proj(hidden_states)[0]
             hidden_states_or_q_c = self.q_a_layernorm(ckq)
             forward_kwargs['ckq'] = ckq
diff --git a/vllm_ascend/torchair/torchair_mla.py b/vllm_ascend/torchair/torchair_mla.py
index 995173a..ed14fed 100644
--- a/vllm_ascend/torchair/torchair_mla.py
+++ b/vllm_ascend/torchair/torchair_mla.py
@@ -23,9 +23,9 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
+from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
                                         npu_stream_switch, npu_wait_tensor)
-from vllm_ascend.utils import npu_prefetch
 from vllm_ascend.worker.npu_input_batch import InputBatch
 
 if TYPE_CHECKING:
@@ -684,10 +684,10 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
         if hasattr(self, "running_in_graph") and not self.running_in_graph:
             return x
         MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024  # 16MB
-        npu_prefetch(self.o_proj.weight,
-                     x,
-                     max_size=MAX_O_PROJ_PREFETCH_SIZE,
-                     enabled=enable_multistream_mla)
+        maybe_npu_prefetch(self.o_proj.weight,
+                           x,
+                           max_size=MAX_O_PROJ_PREFETCH_SIZE,
+                           enabled=enable_multistream_mla)
         return self.o_proj(x, is_prefill=False)[0]
 
     # Return `ql_nope`, `q_pe`
@@ -1281,10 +1281,10 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
         current_ms_metadata = get_multistream_comm_context()
         MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024  # 16MB
         if current_ms_metadata is None:
-            npu_prefetch(self.o_proj.weight,
-                         o_proj_input,
-                         max_size=MAX_O_PROJ_PREFETCH_SIZE,
-                         enabled=enable_multistream_mla)
+            maybe_npu_prefetch(self.o_proj.weight,
+                               o_proj_input,
+                               max_size=MAX_O_PROJ_PREFETCH_SIZE,
+                               enabled=enable_multistream_mla)
 
             output[...] = self.o_proj(
                 o_proj_input,
@@ -1292,10 +1292,10 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
                 is_force_scatter=self.enable_shared_expert_dp)[0]
         else:
             with torch.npu.stream(current_ms_metadata.comm_stream):
-                npu_prefetch(self.o_proj.weight,
-                             o_proj_input,
-                             max_size=MAX_O_PROJ_PREFETCH_SIZE,
-                             enabled=enable_multistream_mla)
+                maybe_npu_prefetch(self.o_proj.weight,
+                                   o_proj_input,
+                                   max_size=MAX_O_PROJ_PREFETCH_SIZE,
+                                   enabled=enable_multistream_mla)
                 output[...] = self.o_proj(
                     o_proj_input,
                     is_prefill=True,
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 6157914..17f2eda 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -51,6 +51,7 @@ _CUSTOM_OP_ENABLED = None
 _IS_310P = None
 _SLEEP_MODE_ENABLED = None
 _CURRENT_STREAM = None
+_PREFETCH_STREAM = None
 _ASCEND_CUSTOMOP_IS_REIGISTERED = False
 
 
@@ -241,6 +242,15 @@ def current_stream() -> torch.npu.Stream:
     return _CURRENT_STREAM
 
 
+def prefetch_stream() -> torch.npu.Stream:
+    global _PREFETCH_STREAM
+    if _PREFETCH_STREAM is None:
+        # lazily create a dedicated side stream so weight prefetching
+        # can overlap with computation on the current stream.
+        _PREFETCH_STREAM = torch_npu.npu.Stream()
+    return _PREFETCH_STREAM
+
+
 def adapt_patch(is_global_patch: bool = False):
     if is_global_patch:
         from vllm_ascend.patch import platform  # noqa: F401
@@ -446,20 +456,6 @@ class ProfileExecuteDuration:
         return durations
 
 
-# TODO(wxy): Move to ops module
-def npu_prefetch(input: torch.Tensor,
-                 dependency: torch.Tensor,
-                 max_size: int = 0,
-                 *,
-                 enabled: bool = True):
-    if not enabled:
-        return
-    input_size = input.element_size() * input.numel()
-    if max_size <= 0 or max_size > input_size:
-        max_size = input_size
-    torch_npu.npu_prefetch(input, dependency, max_size)
-
-
 # TODO(ttanzhiqiang): rm_router_logits
 # dp>1 will trigger
 # In theory, this solution is only applicable to AllGather and AllGatherEP, because in the dp scenario, the previous operation was gate + two communications, and now it is changed to one communication + gate operation, which can save some communication time. In theory, all moe AllGather and AllGatherEP solutions can follow this logic, but now other moe models (qwen3-235b) dp solutions are not adjusted, so use the switch to control it to prevent code errors.
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 3576fc5..b46d1be 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -113,6 +113,7 @@ from vllm_ascend.eplb.eplb_updator import EplbUpdator
 from vllm_ascend.eplb.utils import model_register
 from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
+from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.logits_processor import build_logitsprocs
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
@@ -285,6 +286,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             self.chunked_prefill_enabled = self.scheduler_config.chunked_prefill_enabled
         else:
             self.chunked_prefill_enabled = True
+        self.weight_prefetch_method = WeightPrefetchMethod(
+            self.ascend_config.weight_prefetch_config)
 
         if self.cache_config.cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
@@ -1856,7 +1859,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 num_actual_tokens=scheduler_output.
                 total_num_scheduled_tokens,
                 prefetch_stream=self.prefetch_stream,
-                model_instance=self.model):
+                model_instance=self.model,
+                weight_prefetch_method=self.weight_prefetch_method):
             self.maybe_setup_kv_connector(scheduler_output)
 
             hidden_states = self._generate_process_reqs_hidden_states(
@@ -2370,7 +2374,8 @@
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 batch_descriptor=batch_descriptor,
                 prefetch_stream=self.prefetch_stream,
-                model_instance=self.model):
+                model_instance=self.model,
+                weight_prefetch_method=self.weight_prefetch_method):
             hidden_states = self._generate_dummy_run_hidden_states(
                 with_prefill, is_torchair_compile, input_ids, positions,
                 attn_metadata, num_tokens, intermediate_tensors,