[7/N] MoE Refactor: the implementation of new framework (#9269)
This commit is contained in:
@@ -11,6 +11,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
@@ -19,17 +22,19 @@ from abc import ABC
|
||||
from collections import deque
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Type
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Type
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from sglang.srt.eplb.expert_location import ExpertLocationMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import Withable, get_bool_env_var
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.eplb.expert_location import ExpertLocationMetadata
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# --------------------------------------- Entrypoint -----------------------------------------
|
||||
@@ -43,7 +48,7 @@ class ExpertDistributionRecorder(ABC):
|
||||
@staticmethod
|
||||
def init_new(
|
||||
server_args: ServerArgs,
|
||||
expert_location_metadata: "ExpertLocationMetadata",
|
||||
expert_location_metadata: ExpertLocationMetadata,
|
||||
rank: int,
|
||||
):
|
||||
if server_args.expert_distribution_recorder_mode is not None:
|
||||
@@ -118,7 +123,7 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
|
||||
def __init__(
|
||||
self,
|
||||
server_args: ServerArgs,
|
||||
expert_location_metadata: "ExpertLocationMetadata",
|
||||
expert_location_metadata: ExpertLocationMetadata,
|
||||
rank: int,
|
||||
):
|
||||
self._server_args = server_args
|
||||
@@ -279,7 +284,7 @@ class _SinglePassGatherer(ABC):
|
||||
@staticmethod
|
||||
def init_new(
|
||||
server_args: ServerArgs,
|
||||
expert_location_metadata: "ExpertLocationMetadata",
|
||||
expert_location_metadata: ExpertLocationMetadata,
|
||||
rank: int,
|
||||
) -> "_SinglePassGatherer":
|
||||
if server_args.expert_distribution_recorder_mode == "per_token":
|
||||
@@ -307,7 +312,7 @@ class _SinglePassGatherer(ABC):
|
||||
|
||||
return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
|
||||
|
||||
def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int):
|
||||
def __init__(self, expert_location_metadata: ExpertLocationMetadata, rank: int):
|
||||
self._expert_location_metadata = expert_location_metadata
|
||||
self._rank = rank
|
||||
|
||||
@@ -346,7 +351,7 @@ class _DetailSinglePassGatherer(_SinglePassGatherer):
|
||||
def __init__(
|
||||
self,
|
||||
server_args: ServerArgs,
|
||||
expert_location_metadata: "ExpertLocationMetadata",
|
||||
expert_location_metadata: ExpertLocationMetadata,
|
||||
rank: int,
|
||||
):
|
||||
super().__init__(expert_location_metadata, rank)
|
||||
@@ -561,7 +566,7 @@ class _Accumulator(ABC):
|
||||
@staticmethod
|
||||
def init_new(
|
||||
server_args: ServerArgs,
|
||||
expert_location_metadata: "ExpertLocationMetadata",
|
||||
expert_location_metadata: ExpertLocationMetadata,
|
||||
rank: int,
|
||||
) -> "_Accumulator":
|
||||
return _Accumulator.get_class(server_args)(
|
||||
@@ -580,7 +585,7 @@ class _Accumulator(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
server_args: ServerArgs,
|
||||
expert_location_metadata: "ExpertLocationMetadata",
|
||||
expert_location_metadata: ExpertLocationMetadata,
|
||||
rank: int,
|
||||
):
|
||||
self._server_args = server_args
|
||||
|
||||
Reference in New Issue
Block a user