[7/N] MoE Refactor: the implementation of new framework (#9269)

This commit is contained in:
Cheng Wan
2025-09-05 21:09:09 -07:00
committed by GitHub
parent dbb1235d58
commit 3fa62da78c
34 changed files with 1727 additions and 432 deletions

View File

@@ -11,6 +11,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import annotations
import logging
import math
import os
@@ -19,17 +22,19 @@ from abc import ABC
from collections import deque
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Tuple, Type
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Type
import einops
import torch
import torch.distributed
from sglang.srt.eplb.expert_location import ExpertLocationMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import Withable, get_bool_env_var
if TYPE_CHECKING:
from sglang.srt.eplb.expert_location import ExpertLocationMetadata
logger = logging.getLogger(__name__)
# --------------------------------------- Entrypoint -----------------------------------------
@@ -43,7 +48,7 @@ class ExpertDistributionRecorder(ABC):
@staticmethod
def init_new(
server_args: ServerArgs,
expert_location_metadata: "ExpertLocationMetadata",
expert_location_metadata: ExpertLocationMetadata,
rank: int,
):
if server_args.expert_distribution_recorder_mode is not None:
@@ -118,7 +123,7 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
def __init__(
self,
server_args: ServerArgs,
expert_location_metadata: "ExpertLocationMetadata",
expert_location_metadata: ExpertLocationMetadata,
rank: int,
):
self._server_args = server_args
@@ -279,7 +284,7 @@ class _SinglePassGatherer(ABC):
@staticmethod
def init_new(
server_args: ServerArgs,
expert_location_metadata: "ExpertLocationMetadata",
expert_location_metadata: ExpertLocationMetadata,
rank: int,
) -> "_SinglePassGatherer":
if server_args.expert_distribution_recorder_mode == "per_token":
@@ -307,7 +312,7 @@ class _SinglePassGatherer(ABC):
return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int):
def __init__(self, expert_location_metadata: ExpertLocationMetadata, rank: int):
self._expert_location_metadata = expert_location_metadata
self._rank = rank
@@ -346,7 +351,7 @@ class _DetailSinglePassGatherer(_SinglePassGatherer):
def __init__(
self,
server_args: ServerArgs,
expert_location_metadata: "ExpertLocationMetadata",
expert_location_metadata: ExpertLocationMetadata,
rank: int,
):
super().__init__(expert_location_metadata, rank)
@@ -561,7 +566,7 @@ class _Accumulator(ABC):
@staticmethod
def init_new(
server_args: ServerArgs,
expert_location_metadata: "ExpertLocationMetadata",
expert_location_metadata: ExpertLocationMetadata,
rank: int,
) -> "_Accumulator":
return _Accumulator.get_class(server_args)(
@@ -580,7 +585,7 @@ class _Accumulator(ABC):
def __init__(
self,
server_args: ServerArgs,
expert_location_metadata: "ExpertLocationMetadata",
expert_location_metadata: ExpertLocationMetadata,
rank: int,
):
self._server_args = server_args