[Misc] Drop Prefetch MLP Env (#7357)
### What this PR does / why we need it?
Remove the deprecated environment variables related to MLP prefetching.
### Does this PR introduce _any_ user-facing change?
Yes, the deprecated environment variables can no longer be used.
- vLLM version: v0.17.0
- vLLM main:
4034c3d32e
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -43,9 +43,6 @@ class WeightPrefetchMethod:
|
||||
MLP_GATE_UP: str = "gate_up"
|
||||
MLP_DOWN: str = "down"
|
||||
|
||||
# backward compatibility: delete in future versions
|
||||
mlp_pre_version_compatibale_config: dict = {}
|
||||
|
||||
def __init__(self, weight_prefetch_config: WeightPrefetchConfig) -> None:
|
||||
self.is_moe = is_moe_model(get_current_vllm_config())
|
||||
self.mla_sfa_prefetch_enable = weight_prefetch_config.enabled
|
||||
@@ -70,7 +67,6 @@ class WeightPrefetchMethod:
|
||||
enable=weight_prefetch_config.enabled and not self.is_moe,
|
||||
prefetch_ratio=weight_prefetch_config.prefetch_ratio.get("mlp", {}) or {"gate_up": 1.0, "down": 1.0},
|
||||
)
|
||||
self.mlp_pre_version_compatibale_config = weight_prefetch_config.mlp_pre_version_compatibale_config
|
||||
|
||||
def maybe_prefetch_attn_weight_preprocess(
|
||||
self, layer_cls_name: str, weight: torch.Tensor, start_flag: torch.Tensor
|
||||
@@ -114,7 +110,7 @@ class WeightPrefetchMethod:
|
||||
def maybe_prefetch_mlp_weight_preprocess(
|
||||
self, prefetch_layer_name: str, x_dependency: torch.Tensor | None, curr_layer_prefix: str | None = None
|
||||
):
|
||||
if not self.mlp.enable and not self.mlp_pre_version_compatibale_config:
|
||||
if not self.mlp.enable:
|
||||
self.mlp.is_active_this_forward = False
|
||||
return
|
||||
|
||||
@@ -146,12 +142,9 @@ class WeightPrefetchMethod:
|
||||
model_instance = _EXTRA_CTX.model_instance
|
||||
layer_idx = int(curr_layer_prefix.split(".")[2])
|
||||
weight = model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight # type: ignore
|
||||
if self.mlp_pre_version_compatibale_config:
|
||||
weight_size = self.mlp_pre_version_compatibale_config.get(self.MLP_GATE_UP, 0)
|
||||
else:
|
||||
weight_size = (
|
||||
weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get(self.MLP_GATE_UP, 0)
|
||||
)
|
||||
weight_size = (
|
||||
weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get(self.MLP_GATE_UP, 0)
|
||||
)
|
||||
if weight_size > MAX_PREFETCH_WEIGHT_SIZE:
|
||||
weight_size = MAX_PREFETCH_WEIGHT_SIZE
|
||||
torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=x_dependency, max_weight_size=int(weight_size))
|
||||
@@ -161,12 +154,7 @@ class WeightPrefetchMethod:
|
||||
layer_idx = _EXTRA_CTX.layer_idx
|
||||
model_instance = _EXTRA_CTX.model_instance
|
||||
weight = model_instance.model.layers[layer_idx].mlp.down_proj.weight # type: ignore
|
||||
if self.mlp_pre_version_compatibale_config:
|
||||
weight_size = self.mlp_pre_version_compatibale_config.get(self.MLP_DOWN, 0)
|
||||
else:
|
||||
weight_size = (
|
||||
weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get(self.MLP_DOWN, 0)
|
||||
)
|
||||
weight_size = weight.data.element_size() * weight.data.numel() * self.mlp.prefetch_ratio.get(self.MLP_DOWN, 0)
|
||||
if weight_size > MAX_PREFETCH_WEIGHT_SIZE:
|
||||
weight_size = MAX_PREFETCH_WEIGHT_SIZE
|
||||
torch.ops.vllm.prefetch_preprocess(weight=weight, start_flag=x_dependency, max_weight_size=int(weight_size))
|
||||
|
||||
Reference in New Issue
Block a user