xc-llm-ascend/vllm_ascend/flash_common3_context.py
AlvisGong ba28d54f35 [Perf] enable prefill flashcommon3 (#4065)
### What this PR does / why we need it?
Enable MoE multistream overlap to improve performance.
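
As a rough illustration of the technique (not this PR's actual code): shared-expert work can be launched on a side stream so it overlaps with gating and routed-expert work on the main stream. The sketch below uses torch.cuda for portability (Ascend has torch.npu analogs), and all names are illustrative:

import torch
from torch import nn

def overlapped_moe(hidden: torch.Tensor, shared_experts: nn.Module,
                   routed_fn, side_stream: torch.cuda.Stream) -> torch.Tensor:
    # Make the side stream wait for prior work on the main stream.
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        shared_out = shared_experts(hidden)  # overlaps with routed_fn below
    routed_out = routed_fn(hidden)  # gate + routed experts on the main stream
    # Join the streams before combining the partial results.
    torch.cuda.current_stream().wait_stream(side_stream)
    return routed_out + shared_out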

### How was this patch tested?
Tested by enabling the feature with `--additional-config '{"multistream_overlap_gate": true}'`.
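
Equivalently, the option can be passed programmatically; a hedged sketch assuming vLLM's LLM/EngineArgs accept additional_config (the model name is illustrative):

from vllm import LLM

# Hedged sketch: pass the Ascend-specific option through additional_config.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",
    additional_config={"multistream_overlap_gate": True},
)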

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: AlvisGong <gwly0401@163.com>
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
2025-12-14 09:34:13 +08:00


from dataclasses import dataclass
from typing import Optional

import torch
from vllm.model_executor.layers.linear import LinearBase


@dataclass
class FlashCommon3Context:
    """Per-forward-pass MoE state shared between pipeline stages.

    Holds the gate module, its top-k routing outputs, and the shared-expert
    module/output so downstream stages can reuse them instead of recomputing.
    """

    gate: Optional[LinearBase] = None
    topk_weights: Optional[torch.Tensor] = None
    topk_ids: Optional[torch.Tensor] = None
    row_idx: Optional[torch.Tensor] = None
    shared_experts: Optional[torch.nn.Module] = None
    shared_out: Optional[torch.Tensor] = None


# Module-level singleton; lazily created on first set_flash_common3_context().
_flash_common3_context: Optional[FlashCommon3Context] = None


def get_flash_common3_context() -> Optional[FlashCommon3Context]:
    return _flash_common3_context


def set_flash_common3_context(
    topk_weights: Optional[torch.Tensor] = None,
    topk_ids: Optional[torch.Tensor] = None,
    shared_experts: Optional[torch.nn.Module] = None,
    shared_out: Optional[torch.Tensor] = None,
):
    """Update only the fields that are passed; leave the rest untouched."""
    global _flash_common3_context
    if _flash_common3_context is None:
        _flash_common3_context = FlashCommon3Context()
    if topk_weights is not None:
        _flash_common3_context.topk_weights = topk_weights
    if topk_ids is not None:
        _flash_common3_context.topk_ids = topk_ids
    if shared_experts is not None:
        _flash_common3_context.shared_experts = shared_experts
    if shared_out is not None:
        _flash_common3_context.shared_out = shared_out
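
For illustration, a minimal round trip through this API, assuming it runs in the same module (shapes and the expert count below are made up):

import torch

# Sketch: a gating stage publishes its top-k routing results, and a later
# consumer reads them back instead of recomputing the gate.
num_tokens, top_k, num_experts = 4, 2, 8
topk_weights = torch.rand(num_tokens, top_k)
topk_ids = torch.randint(0, num_experts, (num_tokens, top_k))

set_flash_common3_context(topk_weights=topk_weights, topk_ids=topk_ids)

ctx = get_flash_common3_context()
if ctx is not None and ctx.topk_ids is not None:
    # A real consumer would feed these into the fused MoE kernel; here we
    # only verify the round trip.
    assert torch.equal(ctx.topk_ids, topk_ids)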