[3/N] MoE Refactor: Simplify DeepEP Output (#8421)
This commit is contained in:
@@ -1,5 +1,7 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Optional, Tuple
|
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -50,6 +52,13 @@ from sglang.srt.utils import (
|
|||||||
next_power_of_2,
|
next_power_of_2,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.moe.ep_moe.token_dispatcher import (
|
||||||
|
DeepEPLLOutput,
|
||||||
|
DeepEPNormalOutput,
|
||||||
|
DispatchOutput,
|
||||||
|
)
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
_is_npu = is_npu()
|
_is_npu = is_npu()
|
||||||
_is_fp8_fnuz = is_fp8_fnuz()
|
_is_fp8_fnuz = is_fp8_fnuz()
|
||||||
@@ -797,6 +806,24 @@ class DeepEPMoE(EPMoE):
|
|||||||
"alternatively, you can disable DeepGEMM by turning off the ENABLE_JIT_DEEPGEMM environment variable."
|
"alternatively, you can disable DeepGEMM by turning off the ENABLE_JIT_DEEPGEMM environment variable."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO: move to the beginning of the file
|
||||||
|
from sglang.srt.distributed.parallel_state import get_tp_group
|
||||||
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
|
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
|
||||||
|
|
||||||
|
self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
|
||||||
|
group=get_tp_group().device_group,
|
||||||
|
router_topk=self.top_k,
|
||||||
|
permute_fusion=True,
|
||||||
|
num_experts=self.num_experts,
|
||||||
|
num_local_experts=self.num_local_experts,
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
params_dtype=params_dtype,
|
||||||
|
deepep_mode=deepep_mode,
|
||||||
|
async_finish=True, # TODO
|
||||||
|
return_recv_hook=True,
|
||||||
|
)
|
||||||
|
|
||||||
if self.deepep_mode.enable_low_latency():
|
if self.deepep_mode.enable_low_latency():
|
||||||
assert (
|
assert (
|
||||||
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
|
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
|
||||||
@@ -837,37 +864,128 @@ class DeepEPMoE(EPMoE):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
reorder_topk_ids: torch.Tensor,
|
|
||||||
seg_indptr: torch.Tensor,
|
|
||||||
masked_m: torch.Tensor,
|
|
||||||
expected_m: int,
|
|
||||||
num_recv_tokens_per_expert: List[int],
|
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
):
|
):
|
||||||
|
dispatch_output = self.dispatch(
|
||||||
|
hidden_states, topk_idx, topk_weights, forward_batch
|
||||||
|
)
|
||||||
|
hidden_states = self.moe_impl(dispatch_output)
|
||||||
|
hidden_states = self.combine(
|
||||||
|
hidden_states,
|
||||||
|
dispatch_output.topk_idx,
|
||||||
|
dispatch_output.topk_weights,
|
||||||
|
forward_batch,
|
||||||
|
)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def dispatch(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
topk_idx: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
):
|
||||||
|
return self.deepep_dispatcher.dispatch(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
topk_idx=topk_idx,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
forward_batch=forward_batch,
|
||||||
|
)
|
||||||
|
|
||||||
|
def moe_impl(self, dispatch_output: DispatchOutput):
|
||||||
if _use_aiter:
|
if _use_aiter:
|
||||||
# in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
|
# in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
|
||||||
return self.forward_aiter(hidden_states, topk_idx, topk_weights)
|
return self.forward_aiter(dispatch_output)
|
||||||
resolved_deepep_mode = self.deepep_mode.resolve(
|
if dispatch_output.format.is_deepep_normal():
|
||||||
forward_batch.is_extend_in_batch
|
|
||||||
)
|
|
||||||
if resolved_deepep_mode == DeepEPMode.normal:
|
|
||||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
|
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
|
||||||
return self.forward_deepgemm_contiguous(
|
return self.forward_deepgemm_contiguous(dispatch_output)
|
||||||
hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
|
return self.forward_normal(dispatch_output)
|
||||||
elif resolved_deepep_mode == DeepEPMode.low_latency:
|
elif dispatch_output.format.is_deepep_ll():
|
||||||
return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
|
return self.forward_deepgemm_masked(dispatch_output)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
|
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
|
||||||
|
|
||||||
def forward_normal(
|
def combine(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
reorder_topk_ids: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
seg_indptr: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
):
|
):
|
||||||
|
return self.deepep_dispatcher.combine(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
topk_idx=topk_idx,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
forward_batch=forward_batch,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _prepare_for_normal(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
topk_idx: torch.Tensor,
|
||||||
|
):
|
||||||
|
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||||
|
deepep_permute_triton_kernel,
|
||||||
|
deepep_run_moe_deep_preprocess,
|
||||||
|
)
|
||||||
|
|
||||||
|
if hidden_states.shape[0] == 0:
|
||||||
|
reorder_topk_ids = torch.empty(
|
||||||
|
(0,), device=hidden_states.device, dtype=torch.int64
|
||||||
|
)
|
||||||
|
seg_indptr = torch.zeros(
|
||||||
|
(self.num_experts + 1,),
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=torch.int64,
|
||||||
|
)
|
||||||
|
return reorder_topk_ids, seg_indptr, hidden_states
|
||||||
|
else:
|
||||||
|
if _use_aiter:
|
||||||
|
# skip permutation here as aiter fused_moe has fused inside
|
||||||
|
reorder_topk_ids = torch.empty(
|
||||||
|
(0,), device=hidden_states.device, dtype=torch.int64
|
||||||
|
)
|
||||||
|
seg_indptr = torch.zeros(
|
||||||
|
(self.num_experts + 1,),
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=torch.int64,
|
||||||
|
)
|
||||||
|
return reorder_topk_ids, seg_indptr, hidden_states
|
||||||
|
|
||||||
|
reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
|
||||||
|
topk_idx, self.num_experts
|
||||||
|
)
|
||||||
|
num_total_tokens = reorder_topk_ids.numel()
|
||||||
|
gateup_input = torch.empty(
|
||||||
|
(int(num_total_tokens), hidden_states.shape[1]),
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
)
|
||||||
|
# PreReorder
|
||||||
|
deepep_permute_triton_kernel[(hidden_states.shape[0],)](
|
||||||
|
hidden_states,
|
||||||
|
gateup_input,
|
||||||
|
self.src2dst,
|
||||||
|
topk_idx,
|
||||||
|
None,
|
||||||
|
self.router_topk,
|
||||||
|
hidden_states.shape[1],
|
||||||
|
BLOCK_SIZE=512,
|
||||||
|
)
|
||||||
|
return reorder_topk_ids, seg_indptr, gateup_input
|
||||||
|
|
||||||
|
def forward_normal(
|
||||||
|
self,
|
||||||
|
dispatch_output: DeepEPNormalOutput,
|
||||||
|
):
|
||||||
|
hidden_states, topk_idx = (
|
||||||
|
dispatch_output.hidden_states,
|
||||||
|
dispatch_output.topk_idx,
|
||||||
|
)
|
||||||
|
reorder_topk_ids, seg_indptr, hidden_states = self._prepare_for_normal(
|
||||||
|
hidden_states, topk_idx
|
||||||
|
)
|
||||||
hidden_states_dtype = hidden_states.dtype
|
hidden_states_dtype = hidden_states.dtype
|
||||||
hidden_states_device = hidden_states.device
|
hidden_states_device = hidden_states.device
|
||||||
|
|
||||||
@@ -983,10 +1101,13 @@ class DeepEPMoE(EPMoE):
|
|||||||
|
|
||||||
def forward_aiter(
|
def forward_aiter(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
dispatch_output: DeepEPNormalOutput,
|
||||||
topk_idx: torch.Tensor,
|
|
||||||
topk_weights: torch.Tensor,
|
|
||||||
):
|
):
|
||||||
|
hidden_states, topk_idx, topk_weights = (
|
||||||
|
dispatch_output.hidden_states,
|
||||||
|
dispatch_output.topk_idx,
|
||||||
|
dispatch_output.topk_weights,
|
||||||
|
)
|
||||||
if hidden_states.shape[0] == 0:
|
if hidden_states.shape[0] == 0:
|
||||||
return hidden_states
|
return hidden_states
|
||||||
# in original deepep, idx == -1 meaning invalid and will not be processed.
|
# in original deepep, idx == -1 meaning invalid and will not be processed.
|
||||||
@@ -1014,11 +1135,11 @@ class DeepEPMoE(EPMoE):
|
|||||||
|
|
||||||
def forward_deepgemm_contiguous(
|
def forward_deepgemm_contiguous(
|
||||||
self,
|
self,
|
||||||
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
|
dispatch_output: DeepEPNormalOutput,
|
||||||
topk_idx,
|
|
||||||
topk_weights,
|
|
||||||
num_recv_tokens_per_expert: List[int],
|
|
||||||
):
|
):
|
||||||
|
hidden_states_fp8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
|
||||||
|
dispatch_output
|
||||||
|
)
|
||||||
hidden_states_fp8, hidden_states_scale = hidden_states_fp8
|
hidden_states_fp8, hidden_states_scale = hidden_states_fp8
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
assert self.activation == "silu"
|
assert self.activation == "silu"
|
||||||
@@ -1138,10 +1259,9 @@ class DeepEPMoE(EPMoE):
|
|||||||
|
|
||||||
def forward_deepgemm_masked(
|
def forward_deepgemm_masked(
|
||||||
self,
|
self,
|
||||||
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
|
dispatch_output: DeepEPLLOutput,
|
||||||
masked_m: torch.Tensor,
|
|
||||||
expected_m: int,
|
|
||||||
):
|
):
|
||||||
|
hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
assert self.activation == "silu"
|
assert self.activation == "silu"
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,27 @@
|
|||||||
|
# TODO(ch-wan): this file will be moved to sglang/srt/layers/moe/token_dispatcher/deepep.py
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
List,
|
||||||
|
NamedTuple,
|
||||||
|
Optional,
|
||||||
|
Protocol,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
runtime_checkable,
|
||||||
|
)
|
||||||
|
|
||||||
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
||||||
|
from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
|
||||||
|
BaseDispatcher,
|
||||||
|
BaseDispatcherConfig,
|
||||||
|
DispatchOutput,
|
||||||
|
DispatchOutputFormat,
|
||||||
|
)
|
||||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
@@ -24,7 +44,6 @@ except ImportError:
|
|||||||
use_deepep = False
|
use_deepep = False
|
||||||
|
|
||||||
from enum import Enum, IntEnum, auto
|
from enum import Enum, IntEnum, auto
|
||||||
from typing import Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@@ -41,6 +60,37 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DeepEPNormalOutput(NamedTuple):
|
||||||
|
"""DeepEP normal dispatch output."""
|
||||||
|
|
||||||
|
hidden_states: torch.Tensor | Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
topk_idx: torch.Tensor
|
||||||
|
topk_weights: torch.Tensor
|
||||||
|
num_recv_tokens_per_expert: List[int]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def format(self) -> DispatchOutputFormat:
|
||||||
|
return DispatchOutputFormat.deepep_normal
|
||||||
|
|
||||||
|
|
||||||
|
class DeepEPLLOutput(NamedTuple):
|
||||||
|
"""DeepEP low latency dispatch output."""
|
||||||
|
|
||||||
|
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
topk_idx: torch.Tensor
|
||||||
|
topk_weights: torch.Tensor
|
||||||
|
masked_m: torch.Tensor
|
||||||
|
expected_m: int
|
||||||
|
|
||||||
|
@property
|
||||||
|
def format(self) -> DispatchOutputFormat:
|
||||||
|
return DispatchOutputFormat.deepep_ll
|
||||||
|
|
||||||
|
|
||||||
|
assert isinstance(DeepEPNormalOutput, DispatchOutput)
|
||||||
|
assert isinstance(DeepEPLLOutput, DispatchOutput)
|
||||||
|
|
||||||
|
|
||||||
class DeepEPDispatchMode(IntEnum):
|
class DeepEPDispatchMode(IntEnum):
|
||||||
NORMAL = auto()
|
NORMAL = auto()
|
||||||
LOW_LATENCY = auto()
|
LOW_LATENCY = auto()
|
||||||
@@ -139,7 +189,7 @@ class DeepEPBuffer:
|
|||||||
cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
|
cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
|
||||||
|
|
||||||
|
|
||||||
class DeepEPConfig:
|
class DeepEPConfig(BaseDispatcherConfig):
|
||||||
_instance = None
|
_instance = None
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -255,63 +305,17 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|||||||
return hidden_states, topk_idx, topk_weights, previous_event
|
return hidden_states, topk_idx, topk_weights, previous_event
|
||||||
|
|
||||||
def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
|
def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
|
||||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
|
(
|
||||||
(
|
hidden_states,
|
||||||
hidden_states,
|
topk_idx,
|
||||||
topk_idx,
|
topk_weights,
|
||||||
topk_weights,
|
num_recv_tokens_per_expert,
|
||||||
num_recv_tokens_per_expert_list,
|
event,
|
||||||
event,
|
) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event)
|
||||||
) = self._dispatch_core(
|
event.current_stream_wait() if self.async_finish else ()
|
||||||
hidden_states, topk_idx, topk_weights, previous_event
|
return DeepEPNormalOutput(
|
||||||
)
|
hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
|
||||||
event.current_stream_wait() if self.async_finish else ()
|
)
|
||||||
return (
|
|
||||||
hidden_states,
|
|
||||||
topk_idx,
|
|
||||||
topk_weights,
|
|
||||||
None,
|
|
||||||
num_recv_tokens_per_expert_list,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
(
|
|
||||||
hidden_states,
|
|
||||||
topk_idx,
|
|
||||||
topk_weights,
|
|
||||||
num_recv_tokens_per_expert_list,
|
|
||||||
event,
|
|
||||||
) = self._dispatch_core(
|
|
||||||
hidden_states, topk_idx, topk_weights, previous_event
|
|
||||||
)
|
|
||||||
event.current_stream_wait() if self.async_finish else ()
|
|
||||||
if hidden_states.shape[0] > 0:
|
|
||||||
reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
|
|
||||||
hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
reorder_topk_ids = torch.empty(
|
|
||||||
(0,), device=hidden_states.device, dtype=torch.int64
|
|
||||||
)
|
|
||||||
seg_indptr = torch.zeros(
|
|
||||||
(self.num_experts + 1,),
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=torch.int64,
|
|
||||||
)
|
|
||||||
|
|
||||||
masked_m = expected_m = None
|
|
||||||
return (
|
|
||||||
hidden_states,
|
|
||||||
topk_idx,
|
|
||||||
topk_weights,
|
|
||||||
reorder_topk_ids,
|
|
||||||
None,
|
|
||||||
seg_indptr,
|
|
||||||
masked_m,
|
|
||||||
expected_m,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _dispatch_core(
|
def _dispatch_core(
|
||||||
self,
|
self,
|
||||||
@@ -343,7 +347,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|||||||
recv_x,
|
recv_x,
|
||||||
recv_topk_idx,
|
recv_topk_idx,
|
||||||
recv_topk_weights,
|
recv_topk_weights,
|
||||||
num_recv_tokens_per_expert_list,
|
num_recv_tokens_per_expert,
|
||||||
self.handle,
|
self.handle,
|
||||||
event,
|
event,
|
||||||
) = buffer.dispatch(
|
) = buffer.dispatch(
|
||||||
@@ -362,7 +366,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
|
get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
|
||||||
num_recv_tokens_per_expert_list,
|
num_recv_tokens_per_expert,
|
||||||
num_tokens_per_rank=num_tokens_per_rank,
|
num_tokens_per_rank=num_tokens_per_rank,
|
||||||
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
|
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
|
||||||
num_tokens_per_expert=num_tokens_per_expert,
|
num_tokens_per_expert=num_tokens_per_expert,
|
||||||
@@ -372,58 +376,10 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|||||||
recv_x,
|
recv_x,
|
||||||
recv_topk_idx,
|
recv_topk_idx,
|
||||||
recv_topk_weights,
|
recv_topk_weights,
|
||||||
num_recv_tokens_per_expert_list,
|
num_recv_tokens_per_expert,
|
||||||
event,
|
event,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _deepep_permute(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
topk_idx: torch.Tensor,
|
|
||||||
fp8_dtype: Optional[torch.dtype] = None,
|
|
||||||
use_fp8_w8a8: bool = False,
|
|
||||||
use_block_quant: bool = False,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
|
|
||||||
https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
|
|
||||||
"""
|
|
||||||
if _use_aiter:
|
|
||||||
# skip permutation here as aiter fused_moe has fused inside
|
|
||||||
reorder_topk_ids = torch.empty(
|
|
||||||
(0,), device=hidden_states.device, dtype=torch.int64
|
|
||||||
)
|
|
||||||
seg_indptr = torch.zeros(
|
|
||||||
(self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
|
|
||||||
)
|
|
||||||
return reorder_topk_ids, seg_indptr, hidden_states
|
|
||||||
|
|
||||||
reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
|
|
||||||
topk_idx, self.num_experts
|
|
||||||
)
|
|
||||||
num_total_tokens = reorder_topk_ids.numel()
|
|
||||||
gateup_input = torch.empty(
|
|
||||||
(int(num_total_tokens), hidden_states.shape[1]),
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=(
|
|
||||||
fp8_dtype
|
|
||||||
if (use_fp8_w8a8 and not use_block_quant)
|
|
||||||
else hidden_states.dtype
|
|
||||||
),
|
|
||||||
)
|
|
||||||
# PreReorder
|
|
||||||
deepep_permute_triton_kernel[(hidden_states.shape[0],)](
|
|
||||||
hidden_states,
|
|
||||||
gateup_input,
|
|
||||||
self.src2dst,
|
|
||||||
topk_idx,
|
|
||||||
None,
|
|
||||||
self.router_topk,
|
|
||||||
hidden_states.shape[1],
|
|
||||||
BLOCK_SIZE=512,
|
|
||||||
)
|
|
||||||
return reorder_topk_ids, seg_indptr, gateup_input
|
|
||||||
|
|
||||||
def combine_a(
|
def combine_a(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@@ -544,15 +500,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
|||||||
masked_m
|
masked_m
|
||||||
)
|
)
|
||||||
|
|
||||||
reorder_topk_ids = seg_indptr = None
|
return DeepEPLLOutput(
|
||||||
|
|
||||||
return (
|
|
||||||
hidden_states,
|
hidden_states,
|
||||||
topk_idx,
|
topk_idx,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
reorder_topk_ids,
|
|
||||||
None,
|
|
||||||
seg_indptr,
|
|
||||||
masked_m,
|
masked_m,
|
||||||
expected_m,
|
expected_m,
|
||||||
)
|
)
|
||||||
@@ -636,7 +587,7 @@ class _Stage(Enum):
|
|||||||
AFTER_COMBINE_A = auto()
|
AFTER_COMBINE_A = auto()
|
||||||
|
|
||||||
|
|
||||||
class DeepEPDispatcher:
|
class DeepEPDispatcher(BaseDispatcher):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
group: torch.distributed.ProcessGroup,
|
group: torch.distributed.ProcessGroup,
|
||||||
@@ -676,7 +627,7 @@ class DeepEPDispatcher:
|
|||||||
|
|
||||||
self._stage = _Stage.INITIAL
|
self._stage = _Stage.INITIAL
|
||||||
|
|
||||||
def dispatch(self, *args, **kwargs) -> Tuple:
|
def dispatch(self, *args, **kwargs) -> DispatchOutput:
|
||||||
self.dispatch_a(*args, **kwargs)
|
self.dispatch_a(*args, **kwargs)
|
||||||
ret = self.dispatch_b()
|
ret = self.dispatch_b()
|
||||||
return ret
|
return ret
|
||||||
|
|||||||
@@ -0,0 +1,48 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from enum import Enum, auto
|
||||||
|
from typing import TYPE_CHECKING, NamedTuple, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class DispatchOutputFormat(Enum):
|
||||||
|
standard = auto()
|
||||||
|
deepep_normal = auto()
|
||||||
|
deepep_ll = auto()
|
||||||
|
|
||||||
|
def is_standard(self) -> bool:
|
||||||
|
return self == DispatchOutputFormat.standard
|
||||||
|
|
||||||
|
def is_deepep_normal(self) -> bool:
|
||||||
|
return self == DispatchOutputFormat.deepep_normal
|
||||||
|
|
||||||
|
def is_deepep_ll(self) -> bool:
|
||||||
|
return self == DispatchOutputFormat.deepep_ll
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class DispatchOutput(Protocol):
|
||||||
|
"""Protocol for dispatch outputs in different formats."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def format(self) -> DispatchOutputFormat: ...
|
||||||
|
|
||||||
|
|
||||||
|
class BaseDispatcherConfig(ABC):
|
||||||
|
"""Base class for dispatcher configs."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BaseDispatcher(ABC):
|
||||||
|
"""Base class for dispatchers."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def dispatch(self, *args, **kwargs) -> DispatchOutput:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def combine(self, *args, **kwargs) -> torch.Tensor:
|
||||||
|
pass
|
||||||
19
python/sglang/srt/layers/moe/token_dispatcher/standard.py
Normal file
19
python/sglang/srt/layers/moe/token_dispatcher/standard.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import NamedTuple
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
|
||||||
|
DispatchOutput,
|
||||||
|
DispatchOutputFormat,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StandardDispatchOutput(NamedTuple):
|
||||||
|
"""Standard dispatch output."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def format(self) -> DispatchOutputFormat:
|
||||||
|
return DispatchOutputFormat.standard
|
||||||
|
|
||||||
|
|
||||||
|
assert isinstance(StandardDispatchOutput, DispatchOutput)
|
||||||
@@ -594,41 +594,13 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
topk_weights = torch.empty(
|
topk_weights = torch.empty(
|
||||||
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
|
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
|
||||||
)
|
)
|
||||||
if self.ep_size > 1:
|
|
||||||
# TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
|
|
||||||
(
|
|
||||||
hidden_states,
|
|
||||||
topk_idx,
|
|
||||||
topk_weights,
|
|
||||||
reorder_topk_ids,
|
|
||||||
num_recv_tokens_per_expert,
|
|
||||||
seg_indptr,
|
|
||||||
masked_m,
|
|
||||||
expected_m,
|
|
||||||
) = self.deepep_dispatcher.dispatch(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
topk_idx=topk_idx,
|
|
||||||
topk_weights=topk_weights,
|
|
||||||
forward_batch=forward_batch,
|
|
||||||
)
|
|
||||||
final_hidden_states = self.experts(
|
final_hidden_states = self.experts(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
topk_idx=topk_idx,
|
topk_idx=topk_idx,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
reorder_topk_ids=reorder_topk_ids,
|
|
||||||
seg_indptr=seg_indptr,
|
|
||||||
masked_m=masked_m,
|
|
||||||
expected_m=expected_m,
|
|
||||||
num_recv_tokens_per_expert=num_recv_tokens_per_expert,
|
|
||||||
forward_batch=forward_batch,
|
forward_batch=forward_batch,
|
||||||
)
|
)
|
||||||
if self.ep_size > 1:
|
|
||||||
final_hidden_states = self.deepep_dispatcher.combine(
|
|
||||||
hidden_states=final_hidden_states,
|
|
||||||
topk_idx=topk_idx,
|
|
||||||
topk_weights=topk_weights,
|
|
||||||
forward_batch=forward_batch,
|
|
||||||
)
|
|
||||||
|
|
||||||
if shared_output is not None:
|
if shared_output is not None:
|
||||||
x = shared_output
|
x = shared_output
|
||||||
@@ -689,8 +661,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
|
|
||||||
def op_dispatch_a(self, state):
|
def op_dispatch_a(self, state):
|
||||||
if self.ep_size > 1:
|
if self.ep_size > 1:
|
||||||
# TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
|
self.experts.deepep_dispatcher.dispatch_a(
|
||||||
self.deepep_dispatcher.dispatch_a(
|
|
||||||
hidden_states=state.hidden_states_mlp_input,
|
hidden_states=state.hidden_states_mlp_input,
|
||||||
topk_idx=state.pop("topk_idx_local"),
|
topk_idx=state.pop("topk_idx_local"),
|
||||||
topk_weights=state.pop("topk_weights_local"),
|
topk_weights=state.pop("topk_weights_local"),
|
||||||
@@ -703,46 +674,32 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
with get_global_expert_distribution_recorder().with_current_layer(
|
with get_global_expert_distribution_recorder().with_current_layer(
|
||||||
self.layer_id
|
self.layer_id
|
||||||
):
|
):
|
||||||
(
|
state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
|
||||||
state.hidden_states_experts_input,
|
|
||||||
state.topk_idx_dispatched,
|
|
||||||
state.topk_weights_dispatched,
|
|
||||||
state.reorder_topk_ids,
|
|
||||||
state.num_recv_tokens_per_expert,
|
|
||||||
state.seg_indptr,
|
|
||||||
state.masked_m,
|
|
||||||
state.expected_m,
|
|
||||||
) = self.deepep_dispatcher.dispatch_b(
|
|
||||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def op_experts(self, state):
|
def op_experts(self, state):
|
||||||
state.hidden_states_experts_output = self.experts(
|
state.hidden_states_experts_output = self.experts.moe_impl(
|
||||||
hidden_states=state.pop("hidden_states_experts_input"),
|
dispatch_output=state.dispatch_output,
|
||||||
topk_idx=state.topk_idx_dispatched,
|
|
||||||
topk_weights=state.topk_weights_dispatched,
|
|
||||||
reorder_topk_ids=state.pop("reorder_topk_ids"),
|
|
||||||
seg_indptr=state.pop("seg_indptr"),
|
|
||||||
masked_m=state.pop("masked_m"),
|
|
||||||
expected_m=state.pop("expected_m"),
|
|
||||||
num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
|
|
||||||
forward_batch=state.forward_batch,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def op_combine_a(self, state):
|
def op_combine_a(self, state):
|
||||||
if self.ep_size > 1:
|
if self.ep_size > 1:
|
||||||
self.deepep_dispatcher.combine_a(
|
self.experts.deepep_dispatcher.combine_a(
|
||||||
hidden_states=state.pop("hidden_states_experts_output"),
|
hidden_states=state.pop("hidden_states_experts_output"),
|
||||||
topk_idx=state.pop("topk_idx_dispatched"),
|
topk_idx=state.dispatch_output.topk_idx,
|
||||||
topk_weights=state.pop("topk_weights_dispatched"),
|
topk_weights=state.dispatch_output.topk_weights,
|
||||||
forward_batch=state.forward_batch,
|
forward_batch=state.forward_batch,
|
||||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||||
)
|
)
|
||||||
|
state.pop("dispatch_output")
|
||||||
|
|
||||||
def op_combine_b(self, state):
|
def op_combine_b(self, state):
|
||||||
if self.ep_size > 1:
|
if self.ep_size > 1:
|
||||||
state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
|
state.hidden_states_after_combine = (
|
||||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
self.experts.deepep_dispatcher.combine_b(
|
||||||
|
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def op_output(self, state):
|
def op_output(self, state):
|
||||||
|
|||||||
@@ -144,19 +144,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
|||||||
)
|
)
|
||||||
self.top_k = config.num_experts_per_tok
|
self.top_k = config.num_experts_per_tok
|
||||||
|
|
||||||
self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
|
|
||||||
group=parallel_state.get_tp_group().device_group,
|
|
||||||
router_topk=self.top_k,
|
|
||||||
permute_fusion=True,
|
|
||||||
num_experts=self.num_experts,
|
|
||||||
num_local_experts=config.num_experts // self.tp_size,
|
|
||||||
hidden_size=config.hidden_size,
|
|
||||||
params_dtype=config.torch_dtype,
|
|
||||||
deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
|
|
||||||
async_finish=True, # TODO
|
|
||||||
return_recv_hook=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
|
self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@@ -207,41 +194,12 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
|||||||
topk_weights = torch.empty(
|
topk_weights = torch.empty(
|
||||||
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
|
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
|
||||||
)
|
)
|
||||||
if self.ep_size > 1:
|
|
||||||
# TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
|
|
||||||
(
|
|
||||||
hidden_states,
|
|
||||||
topk_idx,
|
|
||||||
topk_weights,
|
|
||||||
reorder_topk_ids,
|
|
||||||
num_recv_tokens_per_expert,
|
|
||||||
seg_indptr,
|
|
||||||
masked_m,
|
|
||||||
expected_m,
|
|
||||||
) = self.deepep_dispatcher.dispatch(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
topk_idx=topk_idx,
|
|
||||||
topk_weights=topk_weights,
|
|
||||||
forward_batch=forward_batch,
|
|
||||||
)
|
|
||||||
final_hidden_states = self.experts(
|
final_hidden_states = self.experts(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
topk_idx=topk_idx,
|
topk_idx=topk_idx,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
reorder_topk_ids=reorder_topk_ids,
|
|
||||||
seg_indptr=seg_indptr,
|
|
||||||
masked_m=masked_m,
|
|
||||||
expected_m=expected_m,
|
|
||||||
num_recv_tokens_per_expert=num_recv_tokens_per_expert,
|
|
||||||
forward_batch=forward_batch,
|
forward_batch=forward_batch,
|
||||||
)
|
)
|
||||||
if self.ep_size > 1:
|
|
||||||
final_hidden_states = self.deepep_dispatcher.combine(
|
|
||||||
hidden_states=final_hidden_states,
|
|
||||||
topk_idx=topk_idx,
|
|
||||||
topk_weights=topk_weights,
|
|
||||||
forward_batch=forward_batch,
|
|
||||||
)
|
|
||||||
return final_hidden_states
|
return final_hidden_states
|
||||||
|
|
||||||
def op_gate(self, state):
|
def op_gate(self, state):
|
||||||
@@ -278,8 +236,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
|||||||
|
|
||||||
def op_dispatch_a(self, state):
|
def op_dispatch_a(self, state):
|
||||||
if self.ep_size > 1:
|
if self.ep_size > 1:
|
||||||
# TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
|
self.experts.deepep_dispatcher.dispatch_a(
|
||||||
self.deepep_dispatcher.dispatch_a(
|
|
||||||
hidden_states=state.pop("hidden_states_mlp_input"),
|
hidden_states=state.pop("hidden_states_mlp_input"),
|
||||||
topk_idx=state.pop("topk_idx_local"),
|
topk_idx=state.pop("topk_idx_local"),
|
||||||
topk_weights=state.pop("topk_weights_local"),
|
topk_weights=state.pop("topk_weights_local"),
|
||||||
@@ -292,46 +249,32 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
|||||||
with get_global_expert_distribution_recorder().with_current_layer(
|
with get_global_expert_distribution_recorder().with_current_layer(
|
||||||
self.layer_id
|
self.layer_id
|
||||||
):
|
):
|
||||||
(
|
state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
|
||||||
state.hidden_states_experts_input,
|
|
||||||
state.topk_idx_dispatched,
|
|
||||||
state.topk_weights_dispatched,
|
|
||||||
state.reorder_topk_ids,
|
|
||||||
state.num_recv_tokens_per_expert,
|
|
||||||
state.seg_indptr,
|
|
||||||
state.masked_m,
|
|
||||||
state.expected_m,
|
|
||||||
) = self.deepep_dispatcher.dispatch_b(
|
|
||||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def op_experts(self, state):
|
def op_experts(self, state):
|
||||||
state.hidden_states_experts_output = self.experts(
|
state.hidden_states_experts_output = self.experts.moe_impl(
|
||||||
hidden_states=state.pop("hidden_states_experts_input"),
|
dispatch_output=state.dispatch_output,
|
||||||
topk_idx=state.topk_idx_dispatched,
|
|
||||||
topk_weights=state.topk_weights_dispatched,
|
|
||||||
reorder_topk_ids=state.pop("reorder_topk_ids"),
|
|
||||||
seg_indptr=state.pop("seg_indptr"),
|
|
||||||
masked_m=state.pop("masked_m"),
|
|
||||||
expected_m=state.pop("expected_m"),
|
|
||||||
num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
|
|
||||||
forward_batch=state.forward_batch,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def op_combine_a(self, state):
|
def op_combine_a(self, state):
|
||||||
if self.ep_size > 1:
|
if self.ep_size > 1:
|
||||||
self.deepep_dispatcher.combine_a(
|
self.experts.deepep_dispatcher.combine_a(
|
||||||
hidden_states=state.pop("hidden_states_experts_output"),
|
hidden_states=state.pop("hidden_states_experts_output"),
|
||||||
topk_idx=state.pop("topk_idx_dispatched"),
|
topk_idx=state.dispatch_output.topk_idx,
|
||||||
topk_weights=state.pop("topk_weights_dispatched"),
|
topk_weights=state.dispatch_output.topk_weights,
|
||||||
forward_batch=state.forward_batch,
|
forward_batch=state.forward_batch,
|
||||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||||
)
|
)
|
||||||
|
state.pop("dispatch_output")
|
||||||
|
|
||||||
def op_combine_b(self, state):
|
def op_combine_b(self, state):
|
||||||
if self.ep_size > 1:
|
if self.ep_size > 1:
|
||||||
state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
|
state.hidden_states_after_combine = (
|
||||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
self.experts.deepep_dispatcher.combine_b(
|
||||||
|
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def op_output(self, state):
|
def op_output(self, state):
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import replace
|
from dataclasses import replace
|
||||||
from typing import Dict, List, Optional, Sequence, Union
|
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -20,6 +22,9 @@ from sglang.srt.operations_strategy import OperationsStrategy
|
|||||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||||
from sglang.srt.utils import BumpAllocator, DeepEPMode, get_bool_env_var
|
from sglang.srt.utils import BumpAllocator, DeepEPMode, get_bool_env_var
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DispatchOutput
|
||||||
|
|
||||||
_tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
|
_tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -802,7 +807,7 @@ class MaybeTboDeepEPDispatcher:
|
|||||||
def _execute(self, name, tbo_subbatch_index: Optional[int] = None, **kwargs):
|
def _execute(self, name, tbo_subbatch_index: Optional[int] = None, **kwargs):
|
||||||
return getattr(self._inners[tbo_subbatch_index or 0], name)(**kwargs)
|
return getattr(self._inners[tbo_subbatch_index or 0], name)(**kwargs)
|
||||||
|
|
||||||
def dispatch(self, **kwargs):
|
def dispatch(self, **kwargs) -> DispatchOutput:
|
||||||
return self._execute("dispatch", **kwargs)
|
return self._execute("dispatch", **kwargs)
|
||||||
|
|
||||||
def dispatch_a(self, **kwargs):
|
def dispatch_a(self, **kwargs):
|
||||||
@@ -811,7 +816,7 @@ class MaybeTboDeepEPDispatcher:
|
|||||||
def dispatch_b(self, **kwargs):
|
def dispatch_b(self, **kwargs):
|
||||||
return self._execute("dispatch_b", **kwargs)
|
return self._execute("dispatch_b", **kwargs)
|
||||||
|
|
||||||
def combine(self, **kwargs):
|
def combine(self, **kwargs) -> torch.Tensor:
|
||||||
return self._execute("combine", **kwargs)
|
return self._execute("combine", **kwargs)
|
||||||
|
|
||||||
def combine_a(self, **kwargs):
|
def combine_a(self, **kwargs):
|
||||||
|
|||||||
Reference in New Issue
Block a user