[Feature] support aclgraph for model runner v2 (#7110)
### What this PR does / why we need it?
This PR aims to support aclgraph for model runner v2; see RFC
#5208 for details. The PR contains these modifications:
- adapt to the latest commit of the vLLM main branch.
- provide a unified interface for the extra forward context shared by
model runner v1 and model runner v2.
- implement graph mode for the main model.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
- vLLM version: v0.16.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
This commit is contained in:
@@ -6,6 +6,7 @@ from vllm.config import get_current_vllm_config
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
|
||||
from vllm_ascend.ascend_config import WeightPrefetchConfig
|
||||
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
|
||||
from vllm_ascend.ops.linear import AscendQKVParallelLinear, AscendRowParallelLinear
|
||||
from vllm_ascend.utils import is_moe_model
|
||||
|
||||
@@ -95,11 +96,11 @@ class WeightPrefetchMethod:
|
||||
if not self.moe.is_active_this_forward:
|
||||
return
|
||||
forward_context = get_forward_context()
|
||||
if not forward_context or forward_context.model_instance is None:
|
||||
if not forward_context or _EXTRA_CTX.model_instance is None:
|
||||
return
|
||||
|
||||
# layer_idx is subtracted by 1 because layer_idx was incremented by 1 at layernorm.
|
||||
weight = forward_context.model_instance.model.layers[forward_context.layer_idx - 1].mlp.experts.w13_weight
|
||||
weight = _EXTRA_CTX.model_instance.model.layers[_EXTRA_CTX.layer_idx - 1].mlp.experts.w13_weight # type: ignore # type: ignore
|
||||
weight_size = weight.data.element_size() * weight.data.numel() * self.moe.prefetch_ratio.get(prefix, 0)
|
||||
torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=None, max_weight_size=int(weight_size))
|
||||
|
||||
@@ -122,9 +123,7 @@ class WeightPrefetchMethod:
|
||||
except AssertionError:
|
||||
return
|
||||
self.mlp.is_active_this_forward = (
|
||||
forward_context.layer_idx is not None
|
||||
and forward_context.num_tokens is not None
|
||||
and forward_context.num_tokens < 500
|
||||
_EXTRA_CTX.layer_idx is not None and _EXTRA_CTX.num_tokens is not None and _EXTRA_CTX.num_tokens < 500
|
||||
)
|
||||
if not self.mlp.is_active_this_forward:
|
||||
return
|
||||
@@ -144,9 +143,9 @@ class WeightPrefetchMethod:
|
||||
|
||||
# start point of gate_up_proj weight prefetch
|
||||
if curr_layer_prefix.split(".")[-2] == "self_attn":
|
||||
model_instance = forward_context.model_instance
|
||||
model_instance = _EXTRA_CTX.model_instance
|
||||
layer_idx = int(curr_layer_prefix.split(".")[2])
|
||||
weight = model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight
|
||||
weight = model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight # type: ignore
|
||||
if self.mlp_pre_version_compatibale_config:
|
||||
weight_size = self.mlp_pre_version_compatibale_config.get(self.MLP_GATE_UP, 0)
|
||||
else:
|
||||
@@ -156,12 +155,12 @@ class WeightPrefetchMethod:
|
||||
if weight_size > MAX_PREFETCH_WEIGHT_SIZE:
|
||||
weight_size = MAX_PREFETCH_WEIGHT_SIZE
|
||||
torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=x_dependency, max_weight_size=int(weight_size))
|
||||
forward_context.prefetch_mlp_gate_up_proj = True
|
||||
_EXTRA_CTX.prefetch_mlp_gate_up_proj = True
|
||||
|
||||
def _maybe_prefetch_mlp_down_weight_preprocess(self, x_dependency: torch.Tensor, forward_context: ForwardContext):
|
||||
layer_idx = forward_context.layer_idx
|
||||
model_instance = forward_context.model_instance
|
||||
weight = model_instance.model.layers[layer_idx].mlp.down_proj.weight
|
||||
layer_idx = _EXTRA_CTX.layer_idx
|
||||
model_instance = _EXTRA_CTX.model_instance
|
||||
weight = model_instance.model.layers[layer_idx].mlp.down_proj.weight # type: ignore
|
||||
if self.mlp_pre_version_compatibale_config:
|
||||
weight_size = self.mlp_pre_version_compatibale_config.get(self.MLP_DOWN, 0)
|
||||
else:
|
||||
@@ -171,22 +170,22 @@ class WeightPrefetchMethod:
|
||||
if weight_size > MAX_PREFETCH_WEIGHT_SIZE:
|
||||
weight_size = MAX_PREFETCH_WEIGHT_SIZE
|
||||
torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=x_dependency, max_weight_size=int(weight_size))
|
||||
forward_context.prefetch_mlp_down_proj = True
|
||||
forward_context.layer_idx += 1
|
||||
_EXTRA_CTX.prefetch_mlp_down_proj = True
|
||||
_EXTRA_CTX.layer_idx = layer_idx + 1 # type: ignore
|
||||
|
||||
def maybe_prefetch_mlp_weight_postprocess(self, stop_flag: torch.Tensor):
|
||||
if not self.mlp.is_active_this_forward:
|
||||
return
|
||||
|
||||
try:
|
||||
forward_context = get_forward_context()
|
||||
get_forward_context()
|
||||
except AssertionError:
|
||||
return
|
||||
|
||||
if forward_context.prefetch_mlp_gate_up_proj or forward_context.prefetch_mlp_down_proj:
|
||||
if _EXTRA_CTX.prefetch_mlp_gate_up_proj or _EXTRA_CTX.prefetch_mlp_down_proj:
|
||||
torch.ops.vllm.prefetch_postprocess(stop_flag)
|
||||
forward_context.prefetch_mlp_gate_up_proj = False
|
||||
forward_context.prefetch_mlp_down_proj = False
|
||||
_EXTRA_CTX.prefetch_mlp_gate_up_proj = False
|
||||
_EXTRA_CTX.prefetch_mlp_down_proj = False
|
||||
|
||||
def maybe_prefetch_mla_or_sla_weight_in_current_stream(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user