feat: remove the dependency on FusedMoE (#2153)

This commit is contained in:
Yineng Zhang
2024-11-24 20:09:27 +08:00
committed by GitHub
parent dbe1729395
commit b509db5832
7 changed files with 1602 additions and 7 deletions

View File

@@ -31,7 +31,7 @@ import time
import warnings
from importlib.metadata import PackageNotFoundError, version
from io import BytesIO
from typing import Any, Dict, List, Optional, Protocol, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
import numpy as np
import psutil
@@ -45,6 +45,7 @@ from packaging import version as pkg_version
from starlette.routing import Mount
from torch import nn
from torch.func import functional_call
from torch.library import Library
from torch.profiler import ProfilerActivity, profile, record_function
from triton.runtime.cache import (
FileCacheManager,
@@ -930,3 +931,44 @@ def get_nvgpu_memory_capacity():
def crash_on_warnings():
# Crash on warning if we are running CI tests
return os.getenv("SGLANG_IS_IN_CI", "false") == "true"
def get_device_name(device_id: int = 0) -> str:
if hasattr(torch, "cuda") and torch.cuda.is_available():
return torch.cuda.get_device_name(device_id)
if hasattr(torch, "hip") and torch.hip.is_available():
return torch.hip.get_device_name(device_id)
if hasattr(torch, "xpu") and torch.xpu.is_available():
return torch.xpu.get_device_name(device_id)
if hasattr(torch, "hpu") and torch.hpu.is_available():
return torch.hpu.get_device_name(device_id)
sglang_lib = Library("sglang", "FRAGMENT") # noqa
def direct_register_custom_op(
op_name: str,
op_func: Callable,
mutates_args: List[str],
fake_impl: Optional[Callable] = None,
target_lib: Optional[Library] = None,
):
import torch.library
if hasattr(torch.library, "infer_schema"):
schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
else:
# for pytorch 2.4
import torch._custom_op.impl
schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
my_lib = target_lib or sglang_lib
my_lib.define(op_name + schema_str)
my_lib.impl(op_name, op_func, "CUDA")
if fake_impl is not None:
my_lib._register_fake(op_name, fake_impl)