### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| vllm_ascend/ops/\_\_init\_\_.py |
| vllm_ascend/ops/activation.py |
| vllm_ascend/ops/flashcomm2_oshard_manager.py |
| vllm_ascend/ops/layernorm.py |
| vllm_ascend/ops/mla.py |
| vllm_ascend/ops/mm_encoder_attention.py |
| vllm_ascend/ops/register_custom_ops.py |
| vllm_ascend/ops/vocab_parallel_embedding.py |
| vllm_ascend/ops/weight_prefetch.py |
| vllm_ascend/spec_decode/\_\_init\_\_.py |
| vllm_ascend/spec_decode/eagle_proposer.py |
| vllm_ascend/spec_decode/interface.py |
| vllm_ascend/spec_decode/mtp_proposer.py |
| vllm_ascend/spec_decode/ngram_proposer.py |
| vllm_ascend/spec_decode/suffix_proposer.py |
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
d68209402d
Signed-off-by: MrZ20 <2609716663@qq.com>
Signed-off-by: SILONG ZENG <2609716663@qq.com>
This commit is contained in:
@@ -2,19 +2,18 @@ from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.logger import logger
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
|
||||
from vllm_ascend.ascend_config import WeightPrefetchConfig
|
||||
from vllm_ascend.ops.linear import (AscendQKVParallelLinear,
|
||||
AscendRowParallelLinear)
|
||||
from vllm_ascend.ops.linear import AscendQKVParallelLinear, AscendRowParallelLinear
|
||||
from vllm_ascend.utils import is_moe_model
|
||||
|
||||
SUPPORTED_MODULES = ["attn", "mlp", "moe"]
|
||||
MOE_PREFETCH_TOKEN_THRESHOLD = 96
|
||||
MAX_PREFETCH_WEIGHT_SIZE = 18 * 1024 * 1024
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModuleWeightPrefetchConfig:
|
||||
module_name: str
|
||||
@@ -24,10 +23,7 @@ class ModuleWeightPrefetchConfig:
|
||||
linear_prefix_map: dict = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.prefetch_ratio = {
|
||||
prefix: ratio
|
||||
for prefix, ratio in self.prefetch_ratio.items() if 0 <= ratio <= 1
|
||||
}
|
||||
self.prefetch_ratio = {prefix: ratio for prefix, ratio in self.prefetch_ratio.items() if 0 <= ratio <= 1}
|
||||
|
||||
assert self.module_name in SUPPORTED_MODULES, (
|
||||
f"Invalid module name {self.module_name}, should be one of {SUPPORTED_MODULES}"
|
||||
@@ -41,6 +37,7 @@ class WeightPrefetchMethod:
|
||||
"""
|
||||
Unified weight prefetch method.
|
||||
"""
|
||||
|
||||
is_moe: bool = True
|
||||
MLP_GATE_UP: str = "gate_up"
|
||||
MLP_DOWN: str = "down"
|
||||
@@ -54,60 +51,53 @@ class WeightPrefetchMethod:
|
||||
self.attn = ModuleWeightPrefetchConfig(
|
||||
module_name="attn",
|
||||
enable=weight_prefetch_config.enabled,
|
||||
prefetch_ratio=weight_prefetch_config.prefetch_ratio.get(
|
||||
"attn", {}) or {'qkv': 1.0, 'o': 1.0},
|
||||
prefetch_ratio=weight_prefetch_config.prefetch_ratio.get("attn", {}) or {"qkv": 1.0, "o": 1.0},
|
||||
linear_prefix_map={
|
||||
AscendQKVParallelLinear.__name__: "qkv",
|
||||
AscendRowParallelLinear.__name__: "o",
|
||||
})
|
||||
},
|
||||
)
|
||||
self.moe = ModuleWeightPrefetchConfig(
|
||||
module_name="moe",
|
||||
enable=weight_prefetch_config.enabled and self.is_moe,
|
||||
prefetch_ratio=weight_prefetch_config.prefetch_ratio.get(
|
||||
"moe", {}) or {'gate_up': 0.8})
|
||||
prefetch_ratio=weight_prefetch_config.prefetch_ratio.get("moe", {}) or {"gate_up": 0.8},
|
||||
)
|
||||
|
||||
self.mlp = ModuleWeightPrefetchConfig(
|
||||
module_name="mlp",
|
||||
enable=weight_prefetch_config.enabled and not self.is_moe,
|
||||
prefetch_ratio=weight_prefetch_config.prefetch_ratio.get(
|
||||
"mlp", {}) or {'gate_up': 1.0, 'down': 1.0})
|
||||
prefetch_ratio=weight_prefetch_config.prefetch_ratio.get("mlp", {}) or {"gate_up": 1.0, "down": 1.0},
|
||||
)
|
||||
self.mlp_pre_version_compatibale_config = weight_prefetch_config.mlp_pre_version_compatibale_config
|
||||
|
||||
def maybe_prefetch_attn_weight_preprocess(
|
||||
self, layer_cls_name: str, weight: torch.Tensor,
|
||||
start_flag: torch.Tensor) -> None:
|
||||
self, layer_cls_name: str, weight: torch.Tensor, start_flag: torch.Tensor
|
||||
) -> None:
|
||||
if not self.attn.enable or layer_cls_name not in self.attn.linear_prefix_map:
|
||||
return
|
||||
|
||||
prefix = self.attn.linear_prefix_map.get(layer_cls_name, "")
|
||||
weight_size = weight.data.element_size() * weight.data.numel(
|
||||
) * self.attn.prefetch_ratio.get(prefix, 0)
|
||||
weight_size = weight.data.element_size() * weight.data.numel() * self.attn.prefetch_ratio.get(prefix, 0)
|
||||
|
||||
torch.ops.vllm.prefetch_preprocess(weight=weight,
|
||||
start_flag=start_flag,
|
||||
max_weight_size=int(weight_size))
|
||||
torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=start_flag, max_weight_size=int(weight_size))
|
||||
|
||||
def maybe_prefetch_attn_weight_postprocess(
|
||||
self, layer_cls_name: str, stop_flag: torch.Tensor) -> None:
|
||||
def maybe_prefetch_attn_weight_postprocess(self, layer_cls_name: str, stop_flag: torch.Tensor) -> None:
|
||||
if not self.attn.enable or layer_cls_name not in self.attn.linear_prefix_map:
|
||||
return
|
||||
|
||||
torch.ops.vllm.prefetch_postprocess(stop_flag)
|
||||
|
||||
def maybe_prefetch_moe_weight_preprocess(self, hidden_states, prefix):
|
||||
self.moe.is_active_this_forward = hidden_states.shape[
|
||||
0] >= MOE_PREFETCH_TOKEN_THRESHOLD if self.moe.enable else False
|
||||
self.moe.is_active_this_forward = (
|
||||
hidden_states.shape[0] >= MOE_PREFETCH_TOKEN_THRESHOLD if self.moe.enable else False
|
||||
)
|
||||
if not self.moe.is_active_this_forward:
|
||||
return
|
||||
forward_context = get_forward_context()
|
||||
# layer_idx is subtracted by 1 because layer_idx was incremented by 1 at layernorm.
|
||||
weight = forward_context.model_instance.model.layers[
|
||||
forward_context.layer_idx - 1].mlp.experts.w13_weight
|
||||
weight_size = weight.data.element_size() * weight.data.numel(
|
||||
) * self.moe.prefetch_ratio.get(prefix, 0)
|
||||
torch.ops.vllm.prefetch_preprocess(weight=weight,
|
||||
start_flag=None,
|
||||
max_weight_size=int(weight_size))
|
||||
weight = forward_context.model_instance.model.layers[forward_context.layer_idx - 1].mlp.experts.w13_weight
|
||||
weight_size = weight.data.element_size() * weight.data.numel() * self.moe.prefetch_ratio.get(prefix, 0)
|
||||
torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=None, max_weight_size=int(weight_size))
|
||||
|
||||
def maybe_prefetch_moe_weight_postprocess(self, stop_flag: torch.Tensor):
|
||||
if not self.moe.is_active_this_forward:
|
||||
@@ -116,7 +106,9 @@ class WeightPrefetchMethod:
|
||||
torch.ops.vllm.prefetch_postprocess(stop_flag)
|
||||
|
||||
# x_dependency only eager mode can pass None
|
||||
def maybe_prefetch_mlp_weight_preprocess(self, prefetch_layer_name: str, x_dependency: torch.Tensor | None, curr_layer_prefix: str | None = None):
|
||||
def maybe_prefetch_mlp_weight_preprocess(
|
||||
self, prefetch_layer_name: str, x_dependency: torch.Tensor | None, curr_layer_prefix: str | None = None
|
||||
):
|
||||
if not self.mlp.enable and not self.mlp_pre_version_compatibale_config:
|
||||
self.mlp.is_active_this_forward = False
|
||||
return
|
||||
@@ -140,24 +132,26 @@ class WeightPrefetchMethod:
|
||||
else:
|
||||
raise ValueError(f"Unsupported prefetch weight name: {prefetch_layer_name}")
|
||||
|
||||
def _maybe_prefetch_mlp_gate_up_weight_preprocess(self, x_dependency: torch.Tensor, forward_context: ForwardContext, curr_layer_prefix: str | None):
|
||||
def _maybe_prefetch_mlp_gate_up_weight_preprocess(
|
||||
self, x_dependency: torch.Tensor, forward_context: ForwardContext, curr_layer_prefix: str | None
|
||||
):
|
||||
if not curr_layer_prefix:
|
||||
raise ValueError("curr_layer_prefix must been specified when prefetching mlp gate_up_proj weight")
|
||||
|
||||
# start point of gate_up_proj weight prefetch
|
||||
if curr_layer_prefix.split('.')[-2] == "self_attn":
|
||||
if curr_layer_prefix.split(".")[-2] == "self_attn":
|
||||
model_instance = forward_context.model_instance
|
||||
layer_idx = int(curr_layer_prefix.split('.')[2])
|
||||
layer_idx = int(curr_layer_prefix.split(".")[2])
|
||||
weight = model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight
|
||||
if self.mlp_pre_version_compatibale_config:
|
||||
weight_size = self.mlp_pre_version_compatibale_config.get(self.MLP_GATE_UP, 0)
|
||||
else:
|
||||
weight_size = weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get(self.MLP_GATE_UP, 0)
|
||||
weight_size = (
|
||||
weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get(self.MLP_GATE_UP, 0)
|
||||
)
|
||||
if weight_size > MAX_PREFETCH_WEIGHT_SIZE:
|
||||
weight_size = MAX_PREFETCH_WEIGHT_SIZE
|
||||
torch.ops.vllm.prefetch_preprocess(weight=weight,
|
||||
start_flag=x_dependency,
|
||||
max_weight_size=int(weight_size))
|
||||
torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=x_dependency, max_weight_size=int(weight_size))
|
||||
forward_context.prefetch_mlp_gate_up_proj = True
|
||||
|
||||
def _maybe_prefetch_mlp_down_weight_preprocess(self, x_dependency: torch.Tensor, forward_context: ForwardContext):
|
||||
@@ -167,12 +161,12 @@ class WeightPrefetchMethod:
|
||||
if self.mlp_pre_version_compatibale_config:
|
||||
weight_size = self.mlp_pre_version_compatibale_config.get(self.MLP_DOWN, 0)
|
||||
else:
|
||||
weight_size = weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get(self.MLP_DOWN, 0)
|
||||
weight_size = (
|
||||
weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get(self.MLP_DOWN, 0)
|
||||
)
|
||||
if weight_size > MAX_PREFETCH_WEIGHT_SIZE:
|
||||
weight_size = MAX_PREFETCH_WEIGHT_SIZE
|
||||
torch.ops.vllm.prefetch_preprocess(weight=weight,
|
||||
start_flag=x_dependency,
|
||||
max_weight_size=int(weight_size))
|
||||
torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=x_dependency, max_weight_size=int(weight_size))
|
||||
forward_context.prefetch_mlp_down_proj = True
|
||||
forward_context.layer_idx += 1
|
||||
|
||||
@@ -185,19 +179,15 @@ class WeightPrefetchMethod:
|
||||
except AssertionError:
|
||||
return
|
||||
|
||||
if forward_context.prefetch_mlp_gate_up_proj or \
|
||||
forward_context.prefetch_mlp_down_proj:
|
||||
if forward_context.prefetch_mlp_gate_up_proj or forward_context.prefetch_mlp_down_proj:
|
||||
torch.ops.vllm.prefetch_postprocess(stop_flag)
|
||||
forward_context.prefetch_mlp_gate_up_proj = False
|
||||
forward_context.prefetch_mlp_down_proj = False
|
||||
|
||||
|
||||
def maybe_npu_prefetch(inputs: torch.Tensor,
|
||||
dependency: torch.Tensor,
|
||||
max_size: int = 0,
|
||||
offset: int = 0,
|
||||
*,
|
||||
enabled: bool = True) -> None:
|
||||
def maybe_npu_prefetch(
|
||||
inputs: torch.Tensor, dependency: torch.Tensor, max_size: int = 0, offset: int = 0, *, enabled: bool = True
|
||||
) -> None:
|
||||
if not enabled:
|
||||
return
|
||||
input_size = inputs.element_size() * inputs.numel()
|
||||
|
||||
Reference in New Issue
Block a user