43 lines
1.3 KiB
Python
43 lines
1.3 KiB
Python
|
|
from dataclasses import dataclass
|
||
|
|
from typing import Optional
|
||
|
|
|
||
|
|
import torch
|
||
|
|
from vllm.model_executor.layers.linear import LinearBase
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
|
||
|
|
class FlashCommon3Context:
|
||
|
|
gate: Optional[LinearBase] = None
|
||
|
|
topk_weights: Optional[torch.Tensor] = None
|
||
|
|
topk_ids: Optional[torch.Tensor] = None
|
||
|
|
row_idx: Optional[torch.Tensor] = None
|
||
|
|
shared_experts: Optional[torch.nn.Module] = None
|
||
|
|
shared_out: Optional[torch.Tensor] = None
|
||
|
|
|
||
|
|
|
||
|
|
_flash_common3_context: Optional[FlashCommon3Context] = None
|
||
|
|
|
||
|
|
|
||
|
|
def get_flash_common3_context() -> Optional[FlashCommon3Context]:
|
||
|
|
return _flash_common3_context
|
||
|
|
|
||
|
|
|
||
|
|
def set_flash_common3_context(
|
||
|
|
topk_weights: Optional[torch.Tensor] = None,
|
||
|
|
topk_ids: Optional[torch.Tensor] = None,
|
||
|
|
shared_experts: Optional[torch.nn.Module] = None,
|
||
|
|
shared_out: Optional[torch.Tensor] = None,
|
||
|
|
):
|
||
|
|
global _flash_common3_context
|
||
|
|
if _flash_common3_context is None:
|
||
|
|
_flash_common3_context = FlashCommon3Context()
|
||
|
|
|
||
|
|
if topk_weights is not None:
|
||
|
|
_flash_common3_context.topk_weights = topk_weights
|
||
|
|
if topk_ids is not None:
|
||
|
|
_flash_common3_context.topk_ids = topk_ids
|
||
|
|
if shared_experts is not None:
|
||
|
|
_flash_common3_context.shared_experts = shared_experts
|
||
|
|
if shared_out is not None:
|
||
|
|
_flash_common3_context.shared_out = shared_out
|