[Perf]enable prefill flashcommon3 (#4065)
### What this PR does / why we need it?
moe multistream overlap to improve the performance.
### How was this patch tested?
--additional-config '{"multistream_overlap_gate": true}'
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: AlvisGong <gwly0401@163.com>
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
This commit is contained in:
@@ -26,6 +26,7 @@ from vllm.forward_context import get_forward_context
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.flash_common3_context import get_flash_common3_context
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
|
||||
|
||||
@@ -114,6 +115,7 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
self.use_aclgraph = (vllm_config.compilation_config.mode
|
||||
== CompilationMode.VLLM_COMPILE
|
||||
and not vllm_config.model_config.enforce_eager)
|
||||
self.multistream_overlap_gate = ascend_config.multistream_overlap_gate
|
||||
|
||||
self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
|
||||
self.in_dtype = vllm_config.model_config.dtype
|
||||
@@ -198,18 +200,25 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
assert router_logits.shape[
|
||||
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts)
|
||||
topk_weights, topk_ids = None, None
|
||||
if self.multistream_overlap_gate:
|
||||
fc3_context = get_flash_common3_context()
|
||||
assert fc3_context is not None
|
||||
topk_weights = fc3_context.topk_weights
|
||||
topk_ids = fc3_context.topk_ids
|
||||
else:
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts)
|
||||
|
||||
# this is a naive implementation for experts load balance so as
|
||||
# to avoid accumulating too much tokens on a single rank.
|
||||
@@ -222,6 +231,7 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
topk_ids = torch.argsort(
|
||||
random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)
|
||||
|
||||
assert topk_weights is not None
|
||||
topk_weights = topk_weights.to(self.in_dtype)
|
||||
|
||||
moe_comm_method = get_forward_context().moe_comm_method
|
||||
|
||||
Reference in New Issue
Block a user