feat: remove the dependency on FusedMoE (#2153)
This commit is contained in:
@@ -31,7 +31,7 @@ import time
|
||||
import warnings
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from io import BytesIO
|
||||
from typing import Any, Dict, List, Optional, Protocol, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import psutil
|
||||
@@ -45,6 +45,7 @@ from packaging import version as pkg_version
|
||||
from starlette.routing import Mount
|
||||
from torch import nn
|
||||
from torch.func import functional_call
|
||||
from torch.library import Library
|
||||
from torch.profiler import ProfilerActivity, profile, record_function
|
||||
from triton.runtime.cache import (
|
||||
FileCacheManager,
|
||||
@@ -930,3 +931,44 @@ def get_nvgpu_memory_capacity():
|
||||
def crash_on_warnings():
    """Tell callers whether warnings should be treated as fatal.

    Returns:
        True only when the ``SGLANG_IS_IN_CI`` environment variable is set to
        exactly ``"true"``. An unset variable defaults to ``"false"``.
    """
    # CI test runs opt in to crashing on warnings through this flag.
    ci_flag = os.environ.get("SGLANG_IS_IN_CI", "false")
    return ci_flag == "true"
|
||||
|
||||
|
||||
def get_device_name(device_id: int = 0) -> str:
    """Return the device name reported by the first available torch backend.

    The ``torch`` submodules ``cuda``, ``hip``, ``xpu`` and ``hpu`` are probed
    in that order. A backend is used only if the attribute exists on ``torch``
    and its ``is_available()`` returns a truthy value.

    Args:
        device_id: Device index forwarded to the backend's
            ``get_device_name``.

    Returns:
        The name reported by the selected backend. When none of the probed
        backends is present and available, the function yields ``None``
        despite the ``str`` annotation.
    """
    for backend_name in ("cuda", "hip", "xpu", "hpu"):
        # Skip backends this torch build does not expose at all.
        if not hasattr(torch, backend_name):
            continue
        backend = getattr(torch, backend_name)
        if backend.is_available():
            return backend.get_device_name(device_id)
    # No usable backend was found.
    return None
|
||||
|
||||
|
||||
# Module-level torch.library handle for the "sglang" op namespace, opened with
# kind "FRAGMENT". direct_register_custom_op() falls back to this library when
# the caller does not pass its own target_lib.
# NOTE(review): per the torch.library docs, "FRAGMENT" presumably lets the
# namespace be extended from several places — confirm against the torch version in use.
sglang_lib = Library("sglang", "FRAGMENT") # noqa
|
||||
|
||||
|
||||
def direct_register_custom_op(
    op_name: str,
    op_func: Callable,
    mutates_args: List[str],
    fake_impl: Optional[Callable] = None,
    target_lib: Optional[Library] = None,
):
    """Register ``op_func`` as a custom torch op through a ``Library`` object.

    The operator schema is inferred from ``op_func``. The op is then defined
    on the chosen library and ``op_func`` is registered as its implementation
    under the ``"CUDA"`` dispatch key.

    Args:
        op_name: Name of the operator; the inferred schema string is appended
            to it to form the definition.
        op_func: Callable implementing the operator; also the source of the
            inferred schema.
        mutates_args: Names of the arguments that ``op_func`` mutates,
            forwarded to schema inference.
        fake_impl: Optional callable registered through the library's
            ``_register_fake`` when provided.
        target_lib: Library to register on. When falsy (e.g. ``None``), the
            module-level ``sglang_lib`` is used instead.
    """
    import torch.library

    if not hasattr(torch.library, "infer_schema"):
        # PyTorch 2.4 exposes schema inference only via the private module.
        import torch._custom_op.impl

        schema = torch._custom_op.impl.infer_schema(op_func, mutates_args)
    else:
        schema = torch.library.infer_schema(op_func, mutates_args=mutates_args)

    registry = target_lib if target_lib else sglang_lib
    registry.define(op_name + schema)
    registry.impl(op_name, op_func, "CUDA")
    if fake_impl is None:
        return
    registry._register_fake(op_name, fake_impl)
|
||||
|
||||
Reference in New Issue
Block a user