[Perf]enable prefill flashcommon3 (#4065)
### What this PR does / why we need it?
moe multistream overlap to improve the performance.
### How was this patch tested?
--additional-config '{"multistream_overlap_gate": true}'
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: AlvisGong <gwly0401@163.com>
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
This commit is contained in:
@@ -29,7 +29,10 @@ from vllm.distributed.parallel_state import (
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||
|
||||
from vllm_ascend.utils import enable_sp, prefill_context_parallel_enable
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl
|
||||
from vllm_ascend.utils import (enable_sp, npu_stream_switch,
|
||||
prefill_context_parallel_enable)
|
||||
|
||||
|
||||
class QuantType(Enum):
|
||||
@@ -49,9 +52,14 @@ class PrepareAndFinalize(ABC):
|
||||
moe_config (FusedMoEConfig): Configuration object containing TP/DP/EP group info,
|
||||
sizes, ranks, and communication settings.
|
||||
"""
|
||||
quant_stream: Optional[torch.npu.Stream] = None
|
||||
|
||||
def __init__(self, moe_config: FusedMoEConfig):
|
||||
self.moe_config = moe_config
|
||||
ascend_config = get_ascend_config()
|
||||
self.multistream_overlap_gate = ascend_config.multistream_overlap_gate
|
||||
if self.multistream_overlap_gate and PrepareAndFinalize.quant_stream is None:
|
||||
PrepareAndFinalize.quant_stream = torch.npu.Stream()
|
||||
|
||||
@abstractmethod
|
||||
def prepare(
|
||||
@@ -335,12 +343,28 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
if quant_type == QuantType.W8A8:
|
||||
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states)
|
||||
|
||||
if self.multistream_overlap_gate:
|
||||
assert PrepareAndFinalize.quant_stream is not None
|
||||
PrepareAndFinalize.quant_stream.wait_stream(
|
||||
torch.npu.current_stream())
|
||||
with npu_stream_switch(PrepareAndFinalize.quant_stream,
|
||||
enabled=self.multistream_overlap_gate):
|
||||
hidden_states = fc3_all_gather_and_maybe_unpad_impl(
|
||||
hidden_states)
|
||||
else:
|
||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
hidden_states, True, True)
|
||||
router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
router_logits, True, True)
|
||||
|
||||
if pertoken_scale is not None:
|
||||
pertoken_scale = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
pertoken_scale, True, True)
|
||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
hidden_states, True, True)
|
||||
router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
router_logits, True, True)
|
||||
|
||||
if self.multistream_overlap_gate:
|
||||
torch.npu.current_stream().wait_stream(
|
||||
PrepareAndFinalize.quant_stream)
|
||||
|
||||
if pertoken_scale is not None:
|
||||
return (hidden_states, pertoken_scale), router_logits, None, None
|
||||
|
||||
Reference in New Issue
Block a user