[Feature] support aclgraph for model runner v2 (#7110)
### What this PR does / why we need it?
This PR aims to support aclgraph for model runner v2, please see RFC
#5208. The PR contains these modifications:
- adapt to newest commit of vllm main branch.
- supply a unified interface of extra forward context for both model
runner v1 and model runner v2.
- implement graph mode for main model.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
- vLLM version: v0.16.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
This commit is contained in:
@@ -37,7 +37,7 @@ if not vllm_version_is("0.16.0"):
|
||||
from vllm.model_executor.layers.fused_moe.runner.default_moe_runner import DefaultMoERunner # type: ignore
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.eplb.core.eplb_utils import init_eplb_config
|
||||
from vllm_ascend.flash_common3_context import get_flash_common3_context, set_flash_common3_context
|
||||
@@ -148,7 +148,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
random_matrix = torch.rand(topk_ids.size(0), global_num_experts, device=topk_ids.device)
|
||||
topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)
|
||||
|
||||
moe_comm_method = get_forward_context().moe_comm_method
|
||||
moe_comm_method = _EXTRA_CTX.moe_comm_method
|
||||
final_hidden_states = moe_comm_method.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
@@ -401,12 +401,13 @@ class AscendFusedMoE(FusedMoE):
|
||||
# When static kernels are enabled, the forward pass runs twice (compilation + capture),
|
||||
# causing moe_layer_index to overflow. Wrap the index to prevent out-of-bounds errors.
|
||||
if self.enable_npugraph_ex_static_kernel:
|
||||
forward_context.moe_layer_index = forward_context.moe_layer_index % (len(forward_context.all_moe_layers))
|
||||
moe_layer_index = forward_context.moe_layer_index % (len(forward_context.all_moe_layers))
|
||||
forward_context.moe_layer_index = moe_layer_index
|
||||
|
||||
# Load balancing for token distribution among experts in dummy_run
|
||||
# TODO: The community only considers load balancing when DP > 1.
|
||||
# This approach may overlook some extreme scenarios.
|
||||
enable_force_load_balance = forward_context.in_profile_run
|
||||
enable_force_load_balance = _EXTRA_CTX.in_profile_run
|
||||
|
||||
forward_context = get_forward_context()
|
||||
if self.multistream_overlap_gate:
|
||||
@@ -419,7 +420,7 @@ class AscendFusedMoE(FusedMoE):
|
||||
assert fc3_context.shared_experts is not None
|
||||
shared_out = fc3_context.shared_experts(hidden_states)
|
||||
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
|
||||
moe_comm_type = forward_context.moe_comm_type
|
||||
moe_comm_type = _EXTRA_CTX.moe_comm_type
|
||||
if (
|
||||
moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2}
|
||||
and not shared_expert_dp_enabled()
|
||||
@@ -442,16 +443,16 @@ class AscendFusedMoE(FusedMoE):
|
||||
global_num_experts=self.global_num_experts,
|
||||
)
|
||||
|
||||
if isinstance(forward_context.moe_comm_method, AllGatherCommImpl):
|
||||
if isinstance(_EXTRA_CTX.moe_comm_method, AllGatherCommImpl):
|
||||
topk_weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(topk_weights, True, True)
|
||||
topk_ids = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(topk_ids, True, True)
|
||||
|
||||
set_flash_common3_context(topk_weights=topk_weights, topk_ids=topk_ids)
|
||||
|
||||
hidden_states, router_logits, mc2_mask, context_metadata = forward_context.moe_comm_method.prepare(
|
||||
hidden_states, router_logits, mc2_mask, context_metadata = _EXTRA_CTX.moe_comm_method.prepare(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
replace_allreduce=forward_context.flash_comm_v1_enabled,
|
||||
replace_allreduce=_EXTRA_CTX.flash_comm_v1_enabled,
|
||||
enable_shared_expert_dp=self.enable_shared_expert_dp,
|
||||
quant_type=self.quant_type,
|
||||
)
|
||||
@@ -509,7 +510,7 @@ class AscendFusedMoE(FusedMoE):
|
||||
self.load_counter.add_(1)
|
||||
else:
|
||||
self.moe_load.add_(local_load)
|
||||
routed_out = forward_context.moe_comm_method.finalize(
|
||||
routed_out = _EXTRA_CTX.moe_comm_method.finalize(
|
||||
hidden_states=fused_experts_results.routed_out,
|
||||
reduce_results=self.reduce_results,
|
||||
context_metadata=context_metadata,
|
||||
@@ -670,8 +671,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
||||
|
||||
# NOTE: This is exactly the opposite of
|
||||
# `maybe_all_reduce_tensor_model_parallel`
|
||||
forward_context = get_forward_context()
|
||||
moe_comm_type = forward_context.moe_comm_type
|
||||
moe_comm_type = _EXTRA_CTX.moe_comm_type
|
||||
if (
|
||||
moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2}
|
||||
and not shared_expert_dp_enabled()
|
||||
|
||||
@@ -19,11 +19,10 @@ from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
|
||||
from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
|
||||
from vllm_ascend.ops.fused_moe.prepare_finalize import (
|
||||
PrepareAndFinalize,
|
||||
@@ -135,7 +134,7 @@ class MoECommMethod(ABC):
|
||||
# Check constraints
|
||||
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16, torch.int8]
|
||||
|
||||
moe_comm_method = get_forward_context().moe_comm_method
|
||||
moe_comm_method = _EXTRA_CTX.moe_comm_method
|
||||
assert moe_comm_method is not None, "Missing communication context"
|
||||
|
||||
before_dispatch_evt = torch.npu.current_stream().record_event()
|
||||
|
||||
@@ -18,10 +18,9 @@
|
||||
import torch
|
||||
import torch_npu
|
||||
from torch.nn.functional import pad
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
|
||||
from vllm_ascend.device.device_op import DeviceOperator
|
||||
from vllm_ascend.device.mxfp_compat import (
|
||||
ensure_mxfp8_moe_available,
|
||||
@@ -147,7 +146,7 @@ def quant_apply_mlp(
|
||||
weight_prefetch_method = get_weight_prefetch_method()
|
||||
if weight_prefetch_method:
|
||||
weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(hidden_states)
|
||||
is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
|
||||
is_mc2 = _EXTRA_CTX.moe_comm_type == MoECommType.MC2
|
||||
if w1_scale_bias is None and w1_offset is None and is_mc2:
|
||||
if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb) and not use_mxfp_quant:
|
||||
# gmm1: gate_up_proj & act_fn: swiglu
|
||||
|
||||
@@ -26,10 +26,10 @@ from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
|
||||
from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl
|
||||
from vllm_ascend.quantization.methods.base import QuantType
|
||||
from vllm_ascend.utils import enable_sp, npu_stream_switch, prefill_context_parallel_enable
|
||||
@@ -242,8 +242,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
|
||||
"""
|
||||
self.replace_allreduce = replace_allreduce
|
||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||
forward_context = get_forward_context()
|
||||
mc2_mask = forward_context.mc2_mask
|
||||
mc2_mask = _EXTRA_CTX.mc2_mask
|
||||
if self.tp_size > 1:
|
||||
# Also slice mc2_mask
|
||||
split_mc2_mask = torch.tensor_split(mc2_mask, self.tp_size, dim=0)
|
||||
@@ -252,7 +251,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
|
||||
padded_hidden_states_shape = hidden_states.shape
|
||||
if not self.replace_allreduce:
|
||||
self.num_tokens, _ = hidden_states.shape
|
||||
target_pad_length = forward_context.padded_num_tokens
|
||||
target_pad_length = _EXTRA_CTX.padded_num_tokens
|
||||
pad_size = target_pad_length - self.num_tokens
|
||||
|
||||
# Pad if necessary (unless shared expert DP is enabled)
|
||||
@@ -367,8 +366,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
"""
|
||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||
if self.moe_config.dp_size > 1:
|
||||
forward_context = get_forward_context()
|
||||
max_tokens_across_dp = forward_context.max_tokens_across_dp
|
||||
max_tokens_across_dp = _EXTRA_CTX.max_tokens_across_dp
|
||||
|
||||
self.num_tokens = hidden_states.shape[0]
|
||||
pad_size = max_tokens_across_dp - self.num_tokens
|
||||
@@ -381,8 +379,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
router_logits = self.moe_config.dp_group.all_gather(router_logits, 0)
|
||||
|
||||
if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
|
||||
forward_context = get_forward_context()
|
||||
max_tokens_across_pcp = forward_context.max_tokens_across_pcp
|
||||
max_tokens_across_pcp = _EXTRA_CTX.max_tokens_across_pcp
|
||||
|
||||
self.num_tokens_pcp = hidden_states.shape[0]
|
||||
pad_size = max_tokens_across_pcp - self.num_tokens_pcp
|
||||
|
||||
@@ -57,9 +57,9 @@ from vllm.distributed import (
|
||||
tensor_model_parallel_reduce_scatter,
|
||||
)
|
||||
from vllm.distributed.parallel_state import get_tp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
|
||||
from vllm_ascend.distributed.parallel_state import (
|
||||
get_flashcomm2_odp_group,
|
||||
get_flashcomm2_otp_group,
|
||||
@@ -311,8 +311,7 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
|
||||
input_parallel = splitted_input[tp_rank].contiguous()
|
||||
|
||||
# padding for all-to-all
|
||||
forward_context = get_forward_context()
|
||||
num_padding_tokens = forward_context.pad_size
|
||||
num_padding_tokens = _EXTRA_CTX.pad_size
|
||||
if num_padding_tokens > 0:
|
||||
input_parallel = nn.functional.pad(input_parallel, (0, 0, 0, num_padding_tokens))
|
||||
|
||||
@@ -368,7 +367,7 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
|
||||
else:
|
||||
output = output_parallel
|
||||
|
||||
if not forward_context.flash_comm_v1_enabled:
|
||||
if not _EXTRA_CTX.flash_comm_v1_enabled:
|
||||
# flashcomm1 not enabled
|
||||
output = get_tp_group().all_gather(output, 0)
|
||||
if num_padding_tokens > 0:
|
||||
@@ -514,9 +513,8 @@ class SequenceRowParallelOp(CustomRowParallelOp):
|
||||
def matmul_and_reduce(self, input_parallel: torch.Tensor, bias_: Parameter | None) -> torch.Tensor:
|
||||
assert self.quant_method is not None
|
||||
try:
|
||||
forward_context = get_forward_context()
|
||||
flash_comm_v1_enabled = forward_context.flash_comm_v1_enabled
|
||||
mmrs_fusion = forward_context.mmrs_fusion
|
||||
flash_comm_v1_enabled = _EXTRA_CTX.flash_comm_v1_enabled
|
||||
mmrs_fusion = _EXTRA_CTX.mmrs_fusion
|
||||
except AssertionError:
|
||||
flash_comm_v1_enabled = False
|
||||
mmrs_fusion = False
|
||||
@@ -527,7 +525,7 @@ class SequenceRowParallelOp(CustomRowParallelOp):
|
||||
output_parallel = self.layer.quant_method.apply(self.layer, x, bias=bias_)
|
||||
return tensor_model_parallel_all_reduce(output_parallel)
|
||||
|
||||
pad_size = forward_context.pad_size
|
||||
pad_size = _EXTRA_CTX.pad_size
|
||||
if pad_size > 0 and not (enable_dsa_cp() and "o_proj" in self.layer.prefix):
|
||||
x = F.pad(x, (0, 0, 0, pad_size))
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ from vllm.utils.torch_utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backend import AttentionMetadata # type: ignore
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
|
||||
|
||||
|
||||
class IndexerWrapper(nn.Module):
|
||||
@@ -144,7 +145,7 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
|
||||
kv_cache: torch.Tensor | None = None,
|
||||
attn_metadata: AttentionMetadata | None = None,
|
||||
) -> torch.Tensor:
|
||||
need_gather_q_kv = get_forward_context().flash_comm_v1_enabled
|
||||
need_gather_q_kv = _EXTRA_CTX.flash_comm_v1_enabled
|
||||
output_shape = hidden_states.shape
|
||||
# FIXME: This does not seem right, should make sure the buffer is fixed
|
||||
output = torch.empty(output_shape, dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
|
||||
@@ -13,7 +13,7 @@ from vllm.distributed import (
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
|
||||
from vllm_ascend.ops.rotary_embedding import rope_forward_oot
|
||||
from vllm_ascend.ops.triton.muls_add import muls_add_triton
|
||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||
@@ -22,12 +22,12 @@ from vllm_ascend.utils import npu_stream_switch, prefetch_stream
|
||||
|
||||
def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
|
||||
try:
|
||||
forward_context = get_forward_context()
|
||||
get_forward_context()
|
||||
except AssertionError:
|
||||
return residual
|
||||
|
||||
if x.size(0) != residual.size(0):
|
||||
pad_size = forward_context.pad_size
|
||||
pad_size = _EXTRA_CTX.pad_size
|
||||
if pad_size > 0:
|
||||
residual = F.pad(residual, (0, 0, 0, pad_size))
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
@@ -43,12 +43,12 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool, is_ep_c
|
||||
except AssertionError:
|
||||
return x
|
||||
|
||||
flash_comm_v1_enabled = forward_context.flash_comm_v1_enabled
|
||||
flash_comm_v1_enabled = _EXTRA_CTX.flash_comm_v1_enabled
|
||||
if flash_comm_v1_enabled and label:
|
||||
dp_metadata = forward_context.dp_metadata
|
||||
if dp_metadata is None or not is_ep_comm:
|
||||
x = tensor_model_parallel_all_gather(x, 0)
|
||||
pad_size = forward_context.pad_size
|
||||
pad_size = _EXTRA_CTX.pad_size
|
||||
if pad_size > 0:
|
||||
x = x[:-pad_size]
|
||||
else:
|
||||
@@ -57,7 +57,7 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool, is_ep_c
|
||||
num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu
|
||||
result = torch.empty((num_tokens_across_dp_cpu.sum(), *x.shape[1:]), device=x.device, dtype=x.dtype)
|
||||
dp_size = get_dp_group().world_size
|
||||
x = x.view(dp_size, forward_context.padded_length, *x.shape[1:])
|
||||
x = x.view(dp_size, _EXTRA_CTX.padded_length, *x.shape[1:])
|
||||
offset = 0
|
||||
for idx in range(dp_size):
|
||||
num_tokens_dp = num_tokens_across_dp_cpu[idx]
|
||||
@@ -79,7 +79,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor
|
||||
|
||||
dp_metadata = forward_context.dp_metadata
|
||||
if dp_metadata is None or not is_ep_comm:
|
||||
pad_size = forward_context.pad_size
|
||||
pad_size = _EXTRA_CTX.pad_size
|
||||
if pad_size > 0:
|
||||
x = F.pad(x, (0, 0, 0, pad_size))
|
||||
return tensor_model_parallel_reduce_scatter(x, 0)
|
||||
@@ -87,7 +87,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor
|
||||
# padding
|
||||
dp_size = get_dp_group().world_size
|
||||
num_tokens_across_dp_cpu = get_forward_context().dp_metadata.num_tokens_across_dp_cpu
|
||||
padded_x = torch.empty((dp_size, forward_context.padded_length, *x.shape[1:]), device=x.device, dtype=x.dtype)
|
||||
padded_x = torch.empty((dp_size, _EXTRA_CTX.padded_length, *x.shape[1:]), device=x.device, dtype=x.dtype)
|
||||
offset = 0
|
||||
for idx in range(dp_size):
|
||||
num_tokens_dp = num_tokens_across_dp_cpu[idx]
|
||||
@@ -98,7 +98,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor
|
||||
|
||||
|
||||
def _maybe_all_gather_and_maybe_unpad_fake(x: torch.Tensor, label: bool, is_ep_comm: bool = False) -> torch.Tensor:
|
||||
if get_forward_context().flash_comm_v1_enabled and label:
|
||||
if _EXTRA_CTX.flash_comm_v1_enabled and label:
|
||||
return torch.empty(
|
||||
(x.shape[0] * get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype
|
||||
)
|
||||
@@ -107,7 +107,7 @@ def _maybe_all_gather_and_maybe_unpad_fake(x: torch.Tensor, label: bool, is_ep_c
|
||||
|
||||
|
||||
def _maybe_pad_and_reduce_fake(x: torch.Tensor, is_ep_comm: bool = False) -> torch.Tensor:
|
||||
if get_forward_context().flash_comm_v1_enabled:
|
||||
if _EXTRA_CTX.flash_comm_v1_enabled:
|
||||
return torch.empty(
|
||||
(x.shape[0] // get_tensor_model_parallel_world_size(), *x.shape[1:]), device=x.device, dtype=x.dtype
|
||||
)
|
||||
@@ -138,11 +138,10 @@ def _prefetch_postprocess_impl_fake(stop_flag: torch.Tensor) -> None:
|
||||
|
||||
|
||||
def _maybe_all_reduce_tensor_model_parallel_impl(final_hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
forward_context = get_forward_context()
|
||||
moe_comm_type = forward_context.moe_comm_type
|
||||
moe_comm_type = _EXTRA_CTX.moe_comm_type
|
||||
if (
|
||||
moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2}
|
||||
or forward_context.flash_comm_v1_enabled
|
||||
or _EXTRA_CTX.flash_comm_v1_enabled
|
||||
):
|
||||
return final_hidden_states
|
||||
else:
|
||||
@@ -163,7 +162,7 @@ def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor, layer_name: str)
|
||||
forward_context = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
num_tokens = input_parallel.size(0)
|
||||
if forward_context.flash_comm_v1_enabled:
|
||||
if _EXTRA_CTX.flash_comm_v1_enabled:
|
||||
num_tokens = num_tokens // self.tp_size
|
||||
output = torch.empty(
|
||||
size=(num_tokens, self.output_size_per_partition), device=input_parallel.device, dtype=input_parallel.dtype
|
||||
|
||||
@@ -21,7 +21,6 @@ import os
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding,
|
||||
MRotaryEmbedding,
|
||||
@@ -31,6 +30,7 @@ from vllm.model_executor.layers.rotary_embedding import (
|
||||
from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, has_rope, is_vl_model
|
||||
|
||||
@@ -240,8 +240,8 @@ class AscendRotaryEmbedding(RotaryEmbedding):
|
||||
is_neox_style = self.is_neox_style
|
||||
if is_neox_style_override is not None:
|
||||
is_neox_style = is_neox_style_override
|
||||
is_draft_model = get_forward_context().is_draft_model
|
||||
flash_comm_v1_enabled = get_forward_context().flash_comm_v1_enabled
|
||||
is_draft_model = _EXTRA_CTX.is_draft_model
|
||||
flash_comm_v1_enabled = _EXTRA_CTX.flash_comm_v1_enabled
|
||||
if is_draft_model and self.use_mtp and flash_comm_v1_enabled:
|
||||
positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(positions.contiguous(), True)
|
||||
return torch.ops.vllm.npu_rotary_embedding(
|
||||
|
||||
@@ -6,6 +6,7 @@ from vllm.config import get_current_vllm_config
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
|
||||
from vllm_ascend.ascend_config import WeightPrefetchConfig
|
||||
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
|
||||
from vllm_ascend.ops.linear import AscendQKVParallelLinear, AscendRowParallelLinear
|
||||
from vllm_ascend.utils import is_moe_model
|
||||
|
||||
@@ -95,11 +96,11 @@ class WeightPrefetchMethod:
|
||||
if not self.moe.is_active_this_forward:
|
||||
return
|
||||
forward_context = get_forward_context()
|
||||
if not forward_context or forward_context.model_instance is None:
|
||||
if not forward_context or _EXTRA_CTX.model_instance is None:
|
||||
return
|
||||
|
||||
# layer_idx is subtracted by 1 because layer_idx was incremented by 1 at layernorm.
|
||||
weight = forward_context.model_instance.model.layers[forward_context.layer_idx - 1].mlp.experts.w13_weight
|
||||
weight = _EXTRA_CTX.model_instance.model.layers[_EXTRA_CTX.layer_idx - 1].mlp.experts.w13_weight # type: ignore # type: ignore
|
||||
weight_size = weight.data.element_size() * weight.data.numel() * self.moe.prefetch_ratio.get(prefix, 0)
|
||||
torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=None, max_weight_size=int(weight_size))
|
||||
|
||||
@@ -122,9 +123,7 @@ class WeightPrefetchMethod:
|
||||
except AssertionError:
|
||||
return
|
||||
self.mlp.is_active_this_forward = (
|
||||
forward_context.layer_idx is not None
|
||||
and forward_context.num_tokens is not None
|
||||
and forward_context.num_tokens < 500
|
||||
_EXTRA_CTX.layer_idx is not None and _EXTRA_CTX.num_tokens is not None and _EXTRA_CTX.num_tokens < 500
|
||||
)
|
||||
if not self.mlp.is_active_this_forward:
|
||||
return
|
||||
@@ -144,9 +143,9 @@ class WeightPrefetchMethod:
|
||||
|
||||
# start point of gate_up_proj weight prefetch
|
||||
if curr_layer_prefix.split(".")[-2] == "self_attn":
|
||||
model_instance = forward_context.model_instance
|
||||
model_instance = _EXTRA_CTX.model_instance
|
||||
layer_idx = int(curr_layer_prefix.split(".")[2])
|
||||
weight = model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight
|
||||
weight = model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight # type: ignore
|
||||
if self.mlp_pre_version_compatibale_config:
|
||||
weight_size = self.mlp_pre_version_compatibale_config.get(self.MLP_GATE_UP, 0)
|
||||
else:
|
||||
@@ -156,12 +155,12 @@ class WeightPrefetchMethod:
|
||||
if weight_size > MAX_PREFETCH_WEIGHT_SIZE:
|
||||
weight_size = MAX_PREFETCH_WEIGHT_SIZE
|
||||
torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=x_dependency, max_weight_size=int(weight_size))
|
||||
forward_context.prefetch_mlp_gate_up_proj = True
|
||||
_EXTRA_CTX.prefetch_mlp_gate_up_proj = True
|
||||
|
||||
def _maybe_prefetch_mlp_down_weight_preprocess(self, x_dependency: torch.Tensor, forward_context: ForwardContext):
|
||||
layer_idx = forward_context.layer_idx
|
||||
model_instance = forward_context.model_instance
|
||||
weight = model_instance.model.layers[layer_idx].mlp.down_proj.weight
|
||||
layer_idx = _EXTRA_CTX.layer_idx
|
||||
model_instance = _EXTRA_CTX.model_instance
|
||||
weight = model_instance.model.layers[layer_idx].mlp.down_proj.weight # type: ignore
|
||||
if self.mlp_pre_version_compatibale_config:
|
||||
weight_size = self.mlp_pre_version_compatibale_config.get(self.MLP_DOWN, 0)
|
||||
else:
|
||||
@@ -171,22 +170,22 @@ class WeightPrefetchMethod:
|
||||
if weight_size > MAX_PREFETCH_WEIGHT_SIZE:
|
||||
weight_size = MAX_PREFETCH_WEIGHT_SIZE
|
||||
torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=x_dependency, max_weight_size=int(weight_size))
|
||||
forward_context.prefetch_mlp_down_proj = True
|
||||
forward_context.layer_idx += 1
|
||||
_EXTRA_CTX.prefetch_mlp_down_proj = True
|
||||
_EXTRA_CTX.layer_idx = layer_idx + 1 # type: ignore
|
||||
|
||||
def maybe_prefetch_mlp_weight_postprocess(self, stop_flag: torch.Tensor):
|
||||
if not self.mlp.is_active_this_forward:
|
||||
return
|
||||
|
||||
try:
|
||||
forward_context = get_forward_context()
|
||||
get_forward_context()
|
||||
except AssertionError:
|
||||
return
|
||||
|
||||
if forward_context.prefetch_mlp_gate_up_proj or forward_context.prefetch_mlp_down_proj:
|
||||
if _EXTRA_CTX.prefetch_mlp_gate_up_proj or _EXTRA_CTX.prefetch_mlp_down_proj:
|
||||
torch.ops.vllm.prefetch_postprocess(stop_flag)
|
||||
forward_context.prefetch_mlp_gate_up_proj = False
|
||||
forward_context.prefetch_mlp_down_proj = False
|
||||
_EXTRA_CTX.prefetch_mlp_gate_up_proj = False
|
||||
_EXTRA_CTX.prefetch_mlp_down_proj = False
|
||||
|
||||
def maybe_prefetch_mla_or_sla_weight_in_current_stream(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user