[v0.11.0][Bugfix][MoE] enable force_load_balance in aclgraph (#4367)
### What this PR does / why we need it?
Enable force_load_balance when aclgraph is used, solving OOM issues caused by too many tokens accumulating on a single rank during profile runs. Picked from https://github.com/vllm-project/vllm-ascend/pull/4366

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
e2e & unit tests

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
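For context, the force-load-balance trick this PR turns on under aclgraph is simple: during profile runs the routed expert IDs are overwritten with uniformly random ones, so no single expert (and hence no single rank) accumulates a disproportionate share of tokens. Below is a minimal standalone sketch of the idea; the shapes and variable names are made up for illustration, and only the `torch.randint_like(topk_ids, 0, global_num_experts)` call comes from the actual diff further down.

```python
import torch

# Hypothetical sizes, chosen only for illustration.
num_tokens, top_k, global_num_experts = 1024, 8, 64

# Worst case for memory: the router sends every token to expert 0, so the
# rank hosting that expert would have to buffer nearly all tokens.
topk_ids = torch.zeros(num_tokens, top_k, dtype=torch.int32)

# "Force load balance": overwrite the routed IDs with uniformly random
# expert IDs, the same torch.randint_like call used in the diff below.
balanced_ids = torch.randint_like(topk_ids, 0, global_num_experts)

# Tokens are now spread roughly evenly across all experts/ranks.
print(torch.bincount(balanced_ids.flatten(), minlength=global_num_experts))
```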
```diff
@@ -19,7 +19,7 @@ from typing import Any, Callable, Optional
 
 import torch
 import torch_npu
-from vllm.config import CompilationLevel, get_current_vllm_config
+from vllm.config import get_current_vllm_config
 from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
                               tensor_model_parallel_all_reduce)
 from vllm.forward_context import get_forward_context
```
```diff
@@ -51,20 +51,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
 
     def __init__(self, moe: FusedMoEConfig = None):
         super().__init__(moe=moe)
-
-        # NOTE: Currently, this self.use_aclgraph is only used in
-        # UnquantizedFusedMoEMethod.forward_oot to decide whether to use in
-        # ops/fused_moe.py:568 to circumvent torch.randint_like not supported issue.
-        # Once torch.randint_like is supported or removed, this flag can be removed.
-        vllm_config = get_current_vllm_config()
-        ascend_config = get_ascend_config()
         self.dynamic_eplb = get_ascend_config().dynamic_eplb
-        if ascend_config.torchair_graph_config.enabled:
-            self.use_aclgraph = False
-        else:
-            self.use_aclgraph = (vllm_config.compilation_config.level
-                                 == CompilationLevel.PIECEWISE and
-                                 not vllm_config.model_config.enforce_eager)
         self.transpose = True
 
     def process_weights_after_loading(self, layer):
```
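For readability, the detection logic that this hunk deletes is restated below as a standalone helper. This is a paraphrase of the removed lines, not part of the patch; the function name `_would_use_aclgraph` and the `vllm_ascend.ascend_config` import path are assumptions for the sketch. The flag only existed to skip the `torch.randint_like` call under aclgraph, and since this PR lets that call run under aclgraph as well (see the next hunk), the flag has no remaining user and the whole block is removed.

```python
from vllm.config import CompilationLevel, get_current_vllm_config

# Import path assumed; the original file already has get_ascend_config in scope.
from vllm_ascend.ascend_config import get_ascend_config


def _would_use_aclgraph() -> bool:
    """Hypothetical helper restating the removed self.use_aclgraph detection."""
    vllm_config = get_current_vllm_config()
    ascend_config = get_ascend_config()
    if ascend_config.torchair_graph_config.enabled:
        # torchair graph mode never uses aclgraph.
        return False
    # aclgraph is in play only with piecewise compilation and eager mode off.
    return (vllm_config.compilation_config.level == CompilationLevel.PIECEWISE
            and not vllm_config.model_config.enforce_eager)
```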
```diff
@@ -133,7 +120,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         # this is a naive implementation for experts load balance so as
         # to avoid accumulating too much tokens on a single rank.
         # currently it is only activated when doing profile runs.
-        if enable_force_load_balance and not self.use_aclgraph:
+        if enable_force_load_balance:
             topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
 
         moe_comm_method = get_forward_context().moe_comm_method
```
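In other words, dropping the `not self.use_aclgraph` guard lets the random reassignment of `topk_ids` run even when the MoE layer executes under aclgraph. Since this branch is only reached during profile runs (see the comment above it), that appears to be exactly the path whose token imbalance previously caused the OOM this PR fixes.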