[Lint] Add python/sglang to ruff F401 checks and remove unused imports in files (#11685)
This commit is contained in:
@@ -27,9 +27,9 @@ repos:
|
|||||||
rev: v0.11.7
|
rev: v0.11.7
|
||||||
hooks:
|
hooks:
|
||||||
- id: ruff
|
- id: ruff
|
||||||
args: [--select=F401, --fixable=F401]
|
args: [--select=F401,F821, --fixable=F401]
|
||||||
files: ^(benchmark/|docs/|examples/)
|
files: ^(benchmark/|docs/|examples/|python/sglang/)
|
||||||
exclude: \.ipynb$|^python/sglang/srt/grpc/.*_pb2\.py$|^python/sglang/srt/grpc/.*_pb2_grpc\.py$|^python/sglang/srt/grpc/.*_pb2\.pyi$|^python/sglang/srt/grpc/.*_pb2_grpc\.pyi$
|
exclude: __init__\.py$|\.ipynb$|^python/sglang/srt/grpc/.*_pb2\.py$|^python/sglang/srt/grpc/.*_pb2_grpc\.py$|^python/sglang/srt/grpc/.*_pb2\.pyi$|^python/sglang/srt/grpc/.*_pb2_grpc\.pyi$
|
||||||
- repo: https://github.com/psf/black
|
- repo: https://github.com/psf/black
|
||||||
rev: 24.10.0
|
rev: 24.10.0
|
||||||
hooks:
|
hooks:
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ if not is_hpu():
|
|||||||
# ROCm does not use vllm custom allreduce
|
# ROCm does not use vllm custom allreduce
|
||||||
if use_vllm_custom_allreduce and not is_hip():
|
if use_vllm_custom_allreduce and not is_hip():
|
||||||
try:
|
try:
|
||||||
import vllm._C
|
import vllm._C # noqa: F401
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning("Failed to import from vllm._C with %r", e)
|
logger.warning("Failed to import from vllm._C with %r", e)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from unittest.mock import patch
|
|||||||
import torch
|
import torch
|
||||||
import torch.fx as fx
|
import torch.fx as fx
|
||||||
|
|
||||||
import sglang.srt.compilation.weak_ref_tensor_jit
|
|
||||||
from sglang.srt.compilation.compilation_config import CompilationConfig
|
from sglang.srt.compilation.compilation_config import CompilationConfig
|
||||||
from sglang.srt.compilation.compilation_counter import compilation_counter
|
from sglang.srt.compilation.compilation_counter import compilation_counter
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import math
|
import math
|
||||||
import os
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,5 @@
|
|||||||
from typing import Any, List, Optional, Union
|
from transformers import AutoProcessor, PretrainedConfig
|
||||||
|
from transformers.processing_utils import ProcessingKwargs
|
||||||
from transformers import AutoProcessor, LlamaTokenizerFast, PretrainedConfig
|
|
||||||
from transformers.feature_extraction_utils import BatchFeature
|
|
||||||
from transformers.image_utils import ImageInput
|
|
||||||
from transformers.processing_utils import ProcessingKwargs, Unpack
|
|
||||||
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from transformers import Qwen2_5_VLProcessor
|
from transformers import Qwen2_5_VLProcessor
|
||||||
|
|||||||
@@ -14,17 +14,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Falcon-H1 model configuration"""
|
"""Falcon-H1 model configuration"""
|
||||||
|
|
||||||
import enum
|
|
||||||
|
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
from transformers.modeling_rope_utils import rope_config_validation
|
|
||||||
from transformers.utils import logging
|
from transformers.utils import logging
|
||||||
|
|
||||||
from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
|
from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
|
||||||
from sglang.srt.layers.dp_attention import (
|
from sglang.srt.layers.dp_attention import get_tensor_model_parallel_world_size
|
||||||
get_attention_tp_size,
|
|
||||||
get_tensor_model_parallel_world_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ from transformers.modeling_rope_utils import rope_config_validation
|
|||||||
from transformers.utils import logging
|
from transformers.utils import logging
|
||||||
|
|
||||||
from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
|
from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
|
||||||
from sglang.srt.distributed.utils import divide
|
|
||||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Generator, List, Optional, Tuple
|
from typing import Generator, Optional, Tuple
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import List, Optional
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ import time
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
|
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
@@ -48,10 +48,7 @@ from sglang.srt.disaggregation.utils import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch
|
from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch
|
||||||
from sglang.srt.mem_cache.allocator import (
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||||
BaseTokenToKVPoolAllocator,
|
|
||||||
SWATokenToKVPoolAllocator,
|
|
||||||
)
|
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.memory_pool import (
|
from sglang.srt.mem_cache.memory_pool import (
|
||||||
HybridLinearKVPool,
|
HybridLinearKVPool,
|
||||||
@@ -61,7 +58,6 @@ from sglang.srt.mem_cache.memory_pool import (
|
|||||||
ReqToTokenPool,
|
ReqToTokenPool,
|
||||||
SWAKVPool,
|
SWAKVPool,
|
||||||
)
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
|
||||||
from sglang.srt.utils import get_int_env_var, require_mlp_sync
|
from sglang.srt.utils import get_int_env_var, require_mlp_sync
|
||||||
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ Life cycle of a request in the prefill server
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import threading
|
|
||||||
import time
|
import time
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
@@ -54,7 +53,7 @@ from sglang.srt.mem_cache.memory_pool import (
|
|||||||
NSATokenToKVPool,
|
NSATokenToKVPool,
|
||||||
SWAKVPool,
|
SWAKVPool,
|
||||||
)
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
|
from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
DynamicGradMode,
|
DynamicGradMode,
|
||||||
broadcast_pyobj,
|
broadcast_pyobj,
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ try:
|
|||||||
ops.meta_size()
|
ops.meta_size()
|
||||||
else:
|
else:
|
||||||
# Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
|
# Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
|
||||||
import sgl_kernel
|
import sgl_kernel # noqa: F401
|
||||||
custom_ar = True
|
custom_ar = True
|
||||||
except Exception:
|
except Exception:
|
||||||
# For CPUs
|
# For CPUs
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import math
|
|||||||
import os
|
import os
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from typing import Any, Callable, List, Optional, TypeVar, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@@ -24,7 +24,7 @@ if _is_hip:
|
|||||||
mscclpp_is_available = False
|
mscclpp_is_available = False
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
try:
|
try:
|
||||||
import sgl_kernel
|
import sgl_kernel # noqa: F401
|
||||||
|
|
||||||
mscclpp_is_available = True
|
mscclpp_is_available = True
|
||||||
except:
|
except:
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from torch.distributed import ProcessGroup
|
|||||||
from sglang.srt.distributed.device_communicators.all_reduce_utils import (
|
from sglang.srt.distributed.device_communicators.all_reduce_utils import (
|
||||||
SYMM_MEM_ALL_REDUCE_MAX_SIZES,
|
SYMM_MEM_ALL_REDUCE_MAX_SIZES,
|
||||||
)
|
)
|
||||||
from sglang.srt.utils import get_device_capability, is_cuda, is_hip
|
from sglang.srt.utils import is_cuda, is_hip
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import torch.distributed._symmetric_memory as torch_symm_mem
|
import torch.distributed._symmetric_memory as torch_symm_mem
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import base64
|
import base64
|
||||||
import os
|
|
||||||
import pickle
|
import pickle
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# Copied from vLLM
|
# Copied from vLLM
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
# Adapted from vLLM: https://github.com/vllm-project/vllm/blob/1b9902806915040ac9b3029f2ab7522ec505afc3/vllm/entrypoints/harmony_utils.py
|
# Adapted from vLLM: https://github.com/vllm-project/vllm/blob/1b9902806915040ac9b3029f2ab7522ec505afc3/vllm/entrypoints/harmony_utils.py
|
||||||
# Slight differences in processing chat messages
|
# Slight differences in processing chat messages
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from typing import Literal, Optional, Union
|
from typing import Literal, Optional, Union
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ This file implements HTTP APIs for the inference engine via fastapi.
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import multiprocessing as multiprocessing
|
import multiprocessing as multiprocessing
|
||||||
import os
|
import os
|
||||||
|
|||||||
@@ -1,15 +1,9 @@
|
|||||||
import copy
|
|
||||||
import dataclasses
|
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import pickle
|
|
||||||
import threading
|
|
||||||
import time
|
import time
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import pybase64
|
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
|
||||||
|
|
||||||
from sglang.srt.entrypoints.EngineBase import EngineBase
|
from sglang.srt.entrypoints.EngineBase import EngineBase
|
||||||
from sglang.srt.entrypoints.http_server import launch_server
|
from sglang.srt.entrypoints.http_server import launch_server
|
||||||
|
|||||||
@@ -3,8 +3,6 @@ from typing import Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.utils import get_bool_env_var
|
|
||||||
|
|
||||||
|
|
||||||
def balanced_packing(
|
def balanced_packing(
|
||||||
weight: torch.Tensor, num_packs: int
|
weight: torch.Tensor, num_packs: int
|
||||||
|
|||||||
@@ -6,11 +6,7 @@ from typing import List
|
|||||||
|
|
||||||
from sglang.srt.entrypoints.openai.protocol import Tool
|
from sglang.srt.entrypoints.openai.protocol import Tool
|
||||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||||
from sglang.srt.function_call.core_types import (
|
from sglang.srt.function_call.core_types import StreamingParseResult, _GetInfoFunc
|
||||||
StreamingParseResult,
|
|
||||||
StructureInfo,
|
|
||||||
_GetInfoFunc,
|
|
||||||
)
|
|
||||||
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
import json
|
|
||||||
import re
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from sglang.srt.entrypoints.openai.protocol import Tool
|
from sglang.srt.entrypoints.openai.protocol import Tool
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import json
|
|
||||||
from json import JSONDecodeError, JSONDecoder
|
from json import JSONDecodeError, JSONDecoder
|
||||||
from json.decoder import WHITESPACE
|
from json.decoder import WHITESPACE
|
||||||
from typing import Any, List, Literal, Optional, Tuple, Union
|
from typing import Any, List, Literal, Optional, Tuple, Union
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ def compile_proto(proto_file: Path, output_dir: Path, verbose: bool = True) -> b
|
|||||||
|
|
||||||
# Check if grpc_tools is available
|
# Check if grpc_tools is available
|
||||||
try:
|
try:
|
||||||
import grpc_tools.protoc
|
import grpc_tools.protoc # noqa: F401
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print("Error: grpcio-tools not installed")
|
print("Error: grpcio-tools not installed")
|
||||||
print(
|
print(
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ from sglang.srt.managers.io_struct import (
|
|||||||
TokenizedEmbeddingReqInput,
|
TokenizedEmbeddingReqInput,
|
||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.scheduler import is_health_check_generate_req
|
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import get_zmq_socket, kill_process_tree
|
from sglang.srt.utils import get_zmq_socket, kill_process_tree
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
|
|||||||
@@ -380,4 +380,7 @@ if not (
|
|||||||
logger.info(
|
logger.info(
|
||||||
"sgl-kernel is not available on Non-NV, Non-AMD platforms or Non-AMX CPUs. Fallback to other kernel libraries."
|
"sgl-kernel is not available on Non-NV, Non-AMD platforms or Non-AMX CPUs. Fallback to other kernel libraries."
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
|
from vllm.model_executor.layers.activation import ( # noqa: F401
|
||||||
|
GeluAndMul,
|
||||||
|
SiluAndMul,
|
||||||
|
)
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ if TYPE_CHECKING:
|
|||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import TYPE_CHECKING, Optional, Union
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||||
|
|
||||||
import warnings
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||||
|
|
||||||
from typing import Optional, Tuple
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
|
|||||||
@@ -3,9 +3,7 @@
|
|||||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
|
||||||
|
|
||||||
from sglang.srt.layers.attention.fla.utils import tensor_cache
|
from sglang.srt.layers.attention.fla.utils import tensor_cache
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,6 @@
|
|||||||
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
|
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
|
||||||
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
|
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
|
||||||
|
|
||||||
import math
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|||||||
@@ -9,8 +9,6 @@ import triton
|
|||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
|
from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
|
||||||
from sglang.srt.layers.attention.fla.op import safe_exp
|
|
||||||
from sglang.srt.layers.attention.fla.utils import check_shared_mem
|
|
||||||
|
|
||||||
|
|
||||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||||
|
|||||||
@@ -50,7 +50,6 @@ if is_flashinfer_available():
|
|||||||
fast_decode_plan,
|
fast_decode_plan,
|
||||||
)
|
)
|
||||||
from flashinfer.cascade import merge_state
|
from flashinfer.cascade import merge_state
|
||||||
from flashinfer.decode import _get_range_buf, get_seq_lens
|
|
||||||
|
|
||||||
|
|
||||||
class WrapperDispatch(Enum):
|
class WrapperDispatch(Enum):
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Optional, Union
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,6 @@
|
|||||||
from dataclasses import astuple, dataclass
|
|
||||||
from functools import lru_cache
|
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
from sglang.srt.layers.attention.fla.chunk import chunk_gated_delta_rule
|
from sglang.srt.layers.attention.fla.chunk import chunk_gated_delta_rule
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
class IntelAMXAttnBackend(AttentionBackend):
|
class IntelAMXAttnBackend(AttentionBackend):
|
||||||
def __init__(self, model_runner: ModelRunner):
|
def __init__(self, model_runner: ModelRunner):
|
||||||
import sgl_kernel
|
import sgl_kernel # noqa: F401
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.forward_metadata = None
|
self.forward_metadata = None
|
||||||
|
|||||||
@@ -4,7 +4,6 @@
|
|||||||
|
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|||||||
@@ -10,7 +10,6 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ def is_mla_preprocess_enabled() -> bool:
|
|||||||
|
|
||||||
|
|
||||||
if is_mla_preprocess_enabled():
|
if is_mla_preprocess_enabled():
|
||||||
import sgl_kernel_npu
|
import sgl_kernel_npu # noqa: F401
|
||||||
import torch_npu
|
import torch_npu
|
||||||
|
|
||||||
torch.npu.config.allow_internal_format = True
|
torch.npu.config.allow_internal_format = True
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -547,7 +547,7 @@ class Indexer(CustomOp):
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
layer_id: int,
|
layer_id: int,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
import custom_ops
|
import custom_ops # noqa: F401
|
||||||
import torch_npu
|
import torch_npu
|
||||||
|
|
||||||
from sglang.srt.layers.dp_attention import (
|
from sglang.srt.layers.dp_attention import (
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import sys
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, TypeAlias
|
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, TypeAlias
|
||||||
|
|
||||||
@@ -34,18 +33,18 @@ _is_hip = is_hip()
|
|||||||
|
|
||||||
if _is_hip:
|
if _is_hip:
|
||||||
try:
|
try:
|
||||||
from aiter import (
|
from aiter import ( # noqa: F401
|
||||||
flash_attn_varlen_func,
|
flash_attn_varlen_func,
|
||||||
mha_batch_prefill_func,
|
mha_batch_prefill_func,
|
||||||
paged_attention_ragged,
|
paged_attention_ragged,
|
||||||
)
|
)
|
||||||
from aiter.mla import mla_decode_fwd, mla_prefill_fwd
|
from aiter.mla import mla_decode_fwd, mla_prefill_fwd # noqa: F401
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print(
|
print(
|
||||||
"aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
|
"aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
|
from sgl_kernel.flash_attn import flash_attn_with_kvcache
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
|
|||||||
@@ -372,4 +372,4 @@ if not (
|
|||||||
logger.info(
|
logger.info(
|
||||||
"sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
|
"sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
|
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm # noqa: F401
|
||||||
|
|||||||
@@ -116,8 +116,6 @@ def cutlass_fused_experts_fp8(
|
|||||||
|
|
||||||
if is_cuda:
|
if is_cuda:
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
per_group_transpose,
|
|
||||||
per_token_group_quant_fp8_hopper_moe_mn_major,
|
|
||||||
sglang_per_token_group_quant_fp8,
|
sglang_per_token_group_quant_fp8,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
"""Cutlass W4A8 MoE kernel."""
|
"""Cutlass W4A8 MoE kernel."""
|
||||||
import logging
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@@ -1,12 +1,9 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
|
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
|
from sglang.srt.utils import ceil_div, is_cuda
|
||||||
from sglang.srt.utils import ceil_div, dispose_tensor, is_cuda
|
|
||||||
from sglang.utils import is_in_ci
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict, Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
|
from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
|
||||||
|
|||||||
@@ -43,13 +43,7 @@ from sglang.srt.utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
if is_flashinfer_available():
|
if is_flashinfer_available():
|
||||||
from flashinfer import (
|
from flashinfer import RoutingMethodType, fp4_quantize
|
||||||
RoutingMethodType,
|
|
||||||
fp4_quantize,
|
|
||||||
reorder_rows_for_gated_act_gemm,
|
|
||||||
shuffle_matrix_a,
|
|
||||||
shuffle_matrix_sf_a,
|
|
||||||
)
|
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
_is_cpu_amx_available = cpu_has_amx_support()
|
_is_cpu_amx_available = cpu_has_amx_support()
|
||||||
|
|||||||
@@ -51,7 +51,9 @@ elif _is_hip:
|
|||||||
|
|
||||||
|
|
||||||
if _is_cuda or _is_hip:
|
if _is_cuda or _is_hip:
|
||||||
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
|
from sgl_kernel import ( # noqa: F401
|
||||||
|
moe_align_block_size as sgl_moe_align_block_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from functools import cache
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
|
||||||
|
|
||||||
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
||||||
from sglang.srt.layers.moe.token_dispatcher.base import (
|
from sglang.srt.layers.moe.token_dispatcher.base import (
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
use_mooncake_ep = False
|
use_mooncake_ep = False
|
||||||
|
|
||||||
from enum import Enum, IntEnum, auto
|
from enum import Enum, auto
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from __future__ import annotations
|
|||||||
import enum
|
import enum
|
||||||
import logging
|
import logging
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from compressed_tensors import CompressionFormat
|
from compressed_tensors import CompressionFormat
|
||||||
@@ -21,14 +21,7 @@ from sglang.srt.layers.quantization.utils import (
|
|||||||
per_tensor_dequantize,
|
per_tensor_dequantize,
|
||||||
replace_parameter,
|
replace_parameter,
|
||||||
)
|
)
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import get_bool_env_var, is_hip, set_weight_attrs
|
||||||
get_bool_env_var,
|
|
||||||
is_cpu,
|
|
||||||
is_cuda,
|
|
||||||
is_hip,
|
|
||||||
is_npu,
|
|
||||||
set_weight_attrs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
@@ -49,7 +42,7 @@ if _use_aiter:
|
|||||||
from sglang.srt.layers.moe.rocm_moe_utils import rocm_fused_experts_tkw1
|
from sglang.srt.layers.moe.rocm_moe_utils import rocm_fused_experts_tkw1
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import vllm
|
import vllm # noqa: F401
|
||||||
|
|
||||||
VLLM_AVAILABLE = True
|
VLLM_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ def _compute_enable_deep_gemm():
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import deep_gemm
|
import deep_gemm # noqa: F401
|
||||||
except ImportError:
|
except ImportError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from typing import Tuple
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.layers.quantization.deep_gemm_wrapper import compile_utils
|
from sglang.srt.layers.quantization.deep_gemm_wrapper import compile_utils
|
||||||
from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
|
from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import ( # noqa: F401
|
||||||
DEEPGEMM_BLACKWELL,
|
DEEPGEMM_BLACKWELL,
|
||||||
DEEPGEMM_SCALE_UE8M0,
|
DEEPGEMM_SCALE_UE8M0,
|
||||||
ENABLE_JIT_DEEPGEMM,
|
ENABLE_JIT_DEEPGEMM,
|
||||||
@@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
if ENABLE_JIT_DEEPGEMM:
|
if ENABLE_JIT_DEEPGEMM:
|
||||||
import deep_gemm
|
import deep_gemm
|
||||||
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
|
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor # noqa: F401
|
||||||
|
|
||||||
_SANITY_CHECK = get_bool_env_var("SGLANG_DEEPGEMM_SANITY_CHECK")
|
_SANITY_CHECK = get_bool_env_var("SGLANG_DEEPGEMM_SANITY_CHECK")
|
||||||
|
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ if _is_hip:
|
|||||||
raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
|
raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
import vllm._C
|
import vllm._C # noqa: F401
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError("vllm is required when SGLANG_USE_AITER is set to False")
|
raise ImportError("vllm is required when SGLANG_USE_AITER is set to False")
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from torch.nn.parameter import Parameter
|
|||||||
from sglang.srt.layers.linear import LinearBase
|
from sglang.srt.layers.linear import LinearBase
|
||||||
from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
|
from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
FusedMoEMethodBase,
|
|
||||||
LinearMethodBase,
|
LinearMethodBase,
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
@@ -28,7 +27,7 @@ from sglang.srt.layers.quantization.marlin_utils_fp8 import (
|
|||||||
prepare_fp8_layer_for_marlin,
|
prepare_fp8_layer_for_marlin,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||||
from sglang.srt.layers.quantization.utils import is_layer_skipped, replace_parameter
|
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
||||||
from sglang.srt.utils import get_bool_env_var, is_cuda
|
from sglang.srt.utils import get_bool_env_var, is_cuda
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
|
|||||||
@@ -199,7 +199,6 @@ class GPTQConfig(QuantizationConfig):
|
|||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> Optional[LinearMethodBase]:
|
) -> Optional[LinearMethodBase]:
|
||||||
# Delay the import to avoid circular dependency
|
# Delay the import to avoid circular dependency
|
||||||
from sglang.srt.layers.linear import LinearBase
|
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
|
|
||||||
if isinstance(layer, FusedMoE):
|
if isinstance(layer, FusedMoE):
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import torch
|
|||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.utils import get_bool_env_var, get_device_name, is_cuda
|
from sglang.srt.utils import get_device_name, is_cuda
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
|
|||||||
@@ -1059,16 +1059,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
|||||||
intermediate_size,
|
intermediate_size,
|
||||||
num_experts,
|
num_experts,
|
||||||
):
|
):
|
||||||
from flashinfer import (
|
from flashinfer import nvfp4_block_scale_interleave
|
||||||
RoutingMethodType,
|
|
||||||
e2m1_and_ufp8sf_scale_to_float,
|
|
||||||
fp4_quantize,
|
|
||||||
next_positive_power_of_2,
|
|
||||||
nvfp4_block_scale_interleave,
|
|
||||||
reorder_rows_for_gated_act_gemm,
|
|
||||||
shuffle_matrix_a,
|
|
||||||
shuffle_matrix_sf_a,
|
|
||||||
)
|
|
||||||
from flashinfer.fused_moe.core import (
|
from flashinfer.fused_moe.core import (
|
||||||
_maybe_get_cached_w2_permute_indices,
|
_maybe_get_cached_w2_permute_indices,
|
||||||
_maybe_get_cached_w3_w1_permute_indices,
|
_maybe_get_cached_w3_w1_permute_indices,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import regex as re
|
import regex as re
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@@ -3,16 +3,16 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Optional
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from aiter import ActivationType, QuantType, biased_grouped_topk
|
from aiter import ActivationType, QuantType
|
||||||
from aiter.fused_moe import fused_moe
|
from aiter.fused_moe import fused_moe
|
||||||
from aiter.utility.fp4_utils import e8m0_shuffle
|
from aiter.utility.fp4_utils import e8m0_shuffle
|
||||||
|
|
||||||
from sglang.srt.layers.moe import MoeRunnerConfig
|
from sglang.srt.layers.moe import MoeRunnerConfig
|
||||||
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
|
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
|
||||||
from sglang.srt.utils import get_bool_env_var, is_hip, mxfp_supported, set_weight_attrs
|
from sglang.srt.utils import is_hip, set_weight_attrs
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.moe.token_dispatcher import (
|
from sglang.srt.layers.moe.token_dispatcher import (
|
||||||
|
|||||||
@@ -2,20 +2,13 @@
|
|||||||
|
|
||||||
from typing import Any, Callable, Optional
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
import aiter
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
from aiter.ops.gemm_op_a4w4 import gemm_a4w4
|
|
||||||
from aiter.ops.shuffle import shuffle_weight
|
|
||||||
from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
|
from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
|
||||||
from aiter.ops.triton.gemm_afp4wfp4_pre_quant_atomic import gemm_afp4wfp4_pre_quant
|
from aiter.ops.triton.gemm_afp4wfp4_pre_quant_atomic import gemm_afp4wfp4_pre_quant
|
||||||
from aiter.ops.triton.quant import dynamic_mxfp4_quant
|
from aiter.ops.triton.quant import dynamic_mxfp4_quant
|
||||||
from aiter.utility import dtypes
|
|
||||||
from aiter.utility.fp4_utils import e8m0_shuffle
|
|
||||||
|
|
||||||
from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
|
from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
|
||||||
from sglang.srt.layers.quantization.quark.schemes import QuarkScheme
|
from sglang.srt.layers.quantization.quark.schemes import QuarkScheme
|
||||||
from sglang.srt.utils import get_bool_env_var
|
|
||||||
|
|
||||||
__all__ = ["QuarkW4A4MXFP4"]
|
__all__ = ["QuarkW4A4MXFP4"]
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import numpy
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
|
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
|
||||||
from sglang.srt.utils import is_cuda
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
|
|||||||
@@ -1,14 +1,13 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
|
from sglang.srt.layers.linear import UnquantizedLinearMethod
|
||||||
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
|
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
FusedMoEMethodBase,
|
FusedMoEMethodBase,
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
@@ -17,11 +16,11 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
|
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
|
||||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||||
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
||||||
from sglang.srt.utils import is_npu, set_weight_attrs
|
from sglang.srt.utils import set_weight_attrs
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.moe import MoeRunnerConfig
|
from sglang.srt.layers.moe import MoeRunnerConfig
|
||||||
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
|
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE
|
||||||
from sglang.srt.layers.moe.token_dispatcher import (
|
from sglang.srt.layers.moe.token_dispatcher import (
|
||||||
CombineInput,
|
CombineInput,
|
||||||
DeepEPNormalOutput,
|
DeepEPNormalOutput,
|
||||||
|
|||||||
@@ -1,28 +1,12 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import importlib
|
|
||||||
import sys
|
|
||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
from typing import (
|
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
|
||||||
TYPE_CHECKING,
|
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
Dict,
|
|
||||||
List,
|
|
||||||
Mapping,
|
|
||||||
Optional,
|
|
||||||
Tuple,
|
|
||||||
Union,
|
|
||||||
cast,
|
|
||||||
)
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from sglang.srt.distributed import (
|
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||||
get_tensor_model_parallel_rank,
|
|
||||||
get_tensor_model_parallel_world_size,
|
|
||||||
)
|
|
||||||
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
|
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
|
||||||
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
|
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
|
||||||
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
|
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from functools import lru_cache
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from sglang.srt.lora.triton_ops import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.lora.utils import LoRABatchInfo
|
from sglang.srt.lora.utils import LoRABatchInfo
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.server_args import ServerArgs
|
|
||||||
|
|
||||||
|
|
||||||
class TritonLoRABackend(BaseLoRABackend):
|
class TritonLoRABackend(BaseLoRABackend):
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import logging
|
|||||||
import time
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Any, Dict, List, Optional, Set
|
from typing import Optional, Set
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
# and "Punica: Multi-Tenant LoRA Serving"
|
# and "Punica: Multi-Tenant LoRA Serving"
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, Iterable, List, Optional, Set, Tuple
|
from typing import Dict, Iterable, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|||||||
@@ -14,11 +14,10 @@ limitations under the License.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from queue import Empty, Full, PriorityQueue, Queue
|
from queue import Empty, Full, Queue
|
||||||
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Set, Tuple
|
from typing import TYPE_CHECKING, List, NamedTuple, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -41,7 +40,7 @@ from sglang.srt.layers.dp_attention import (
|
|||||||
get_attention_tp_size,
|
get_attention_tp_size,
|
||||||
is_dp_attention_enabled,
|
is_dp_attention_enabled,
|
||||||
)
|
)
|
||||||
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
|
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -59,11 +59,10 @@ from sglang.srt.mem_cache.allocator import (
|
|||||||
SWATokenToKVPoolAllocator,
|
SWATokenToKVPoolAllocator,
|
||||||
)
|
)
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
from sglang.srt.mem_cache.chunk_cache import SWAChunkCache
|
||||||
from sglang.srt.mem_cache.common import (
|
from sglang.srt.mem_cache.common import (
|
||||||
alloc_for_decode,
|
alloc_for_decode,
|
||||||
alloc_for_extend,
|
alloc_for_extend,
|
||||||
alloc_token_slots,
|
|
||||||
evict_from_tree_cache,
|
evict_from_tree_cache,
|
||||||
)
|
)
|
||||||
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
|
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
|
||||||
@@ -76,7 +75,6 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
|||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
from sglang.srt.server_args import ServerArgs, get_global_server_args
|
from sglang.srt.server_args import ServerArgs, get_global_server_args
|
||||||
from sglang.srt.utils import flatten_nested_list
|
from sglang.srt.utils import flatten_nested_list
|
||||||
from sglang.srt.utils.common import next_power_of_2
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
|
|||||||
@@ -3,13 +3,10 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
|
from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
|
||||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||||
from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
|
|
||||||
from sglang.srt.managers.schedule_policy import PrefillAdder
|
from sglang.srt.managers.schedule_policy import PrefillAdder
|
||||||
from sglang.srt.managers.scheduler import Req, ScheduleBatch
|
from sglang.srt.managers.scheduler import Req, ScheduleBatch
|
||||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import multiprocessing as mp
|
from typing import TYPE_CHECKING, Optional
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
|
||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.managers.schedule_batch import Req
|
from sglang.srt.managers.schedule_batch import Req
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if num_new_pages_item < 200:
|
if num_new_pages_item < 200:
|
||||||
import sgl_kernel_npu
|
import sgl_kernel_npu # noqa: F401
|
||||||
|
|
||||||
torch.ops.npu.alloc_extend(
|
torch.ops.npu.alloc_extend(
|
||||||
prefix_lens,
|
prefix_lens,
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import TYPE_CHECKING, Any, List, NamedTuple, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import TYPE_CHECKING, List, Tuple, Union
|
from typing import TYPE_CHECKING, Tuple, Union
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.mem_cache.radix_cache import TreeNode
|
from sglang.srt.mem_cache.radix_cache import TreeNode
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ The radix tree data structure for managing the hybrid (full and Mamba) KV cache.
|
|||||||
import heapq
|
import heapq
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import partial
|
|
||||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -33,7 +32,6 @@ from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool
|
|||||||
from sglang.srt.mem_cache.radix_cache import (
|
from sglang.srt.mem_cache.radix_cache import (
|
||||||
RadixKey,
|
RadixKey,
|
||||||
_key_match_page_size1,
|
_key_match_page_size1,
|
||||||
_key_match_paged,
|
|
||||||
get_child_key,
|
get_child_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import abc
|
import abc
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from enum import IntEnum
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ import heapq
|
|||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import lru_cache, partial
|
from functools import lru_cache, partial
|
||||||
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|||||||
@@ -3,20 +3,8 @@ import os
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
from aibrix_kvcache import (
|
from aibrix_kvcache.common.absl_logging import log_every_n_seconds
|
||||||
BaseKVCacheManager,
|
|
||||||
GroupAwareKVCacheManager,
|
|
||||||
KVCacheBlockLayout,
|
|
||||||
KVCacheBlockSpec,
|
|
||||||
KVCacheConfig,
|
|
||||||
KVCacheMetrics,
|
|
||||||
KVCacheTensorSpec,
|
|
||||||
ModelSpec,
|
|
||||||
TokenListView,
|
|
||||||
)
|
|
||||||
from aibrix_kvcache.common.absl_logging import getLogger, log_every_n_seconds, log_if
|
|
||||||
from aibrix_kvcache_storage import AibrixKVCacheStorage
|
from aibrix_kvcache_storage import AibrixKVCacheStorage
|
||||||
from torch.distributed import Backend, ProcessGroup
|
|
||||||
|
|
||||||
from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
|
from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
|
||||||
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
|
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
|
||||||
|
|||||||
@@ -2,21 +2,18 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import uuid
|
from typing import Any, List, Optional, Tuple
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import eic
|
import eic
|
||||||
import torch
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
|
|
||||||
from sglang.srt.mem_cache.hicache_storage import (
|
from sglang.srt.mem_cache.hicache_storage import (
|
||||||
HiCacheStorage,
|
HiCacheStorage,
|
||||||
HiCacheStorageConfig,
|
HiCacheStorageConfig,
|
||||||
HiCacheStorageExtraInfo,
|
HiCacheStorageExtraInfo,
|
||||||
)
|
)
|
||||||
from sglang.srt.mem_cache.memory_pool_host import HostKVCache, MLATokenToKVPoolHost
|
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import threading
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
import hashlib
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
from typing import List, Optional
|
from typing import List
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ Records the latency of some functions
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Callable, List, Optional
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
from sglang.srt.metrics.utils import exponential_buckets
|
from sglang.srt.metrics.utils import exponential_buckets
|
||||||
|
|
||||||
|
|||||||
@@ -104,11 +104,7 @@ from sglang.srt.mem_cache.memory_pool import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner
|
from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner
|
||||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||||
from sglang.srt.model_executor.forward_batch_info import (
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||||
ForwardBatch,
|
|
||||||
ForwardMode,
|
|
||||||
PPProxyTensors,
|
|
||||||
)
|
|
||||||
from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
|
from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
|
||||||
from sglang.srt.model_executor.piecewise_cuda_graph_runner import (
|
from sglang.srt.model_executor.piecewise_cuda_graph_runner import (
|
||||||
PiecewiseCudaGraphRunner,
|
PiecewiseCudaGraphRunner,
|
||||||
|
|||||||
@@ -19,10 +19,9 @@ import logging
|
|||||||
import threading
|
import threading
|
||||||
from typing import TYPE_CHECKING, Optional, Union
|
from typing import TYPE_CHECKING, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.configs.model_config import AttentionArch, is_deepseek_nsa
|
from sglang.srt.configs.model_config import is_deepseek_nsa
|
||||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -19,7 +19,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""SGLang BailingMoE model."""
|
"""SGLang BailingMoE model."""
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, Iterable, Optional, Tuple, Union
|
from typing import Iterable, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -59,7 +59,6 @@ from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
|||||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||||
from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
|
from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
|
||||||
from sglang.srt.layers.moe.topk import TopK
|
from sglang.srt.layers.moe.topk import TopK
|
||||||
from sglang.srt.layers.moe.utils import DeepEPMode
|
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.rotary_embedding import get_rope
|
from sglang.srt.layers.rotary_embedding import get_rope
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
from typing import Any, Dict, Iterable, Optional, Set, Tuple
|
from typing import Iterable, Optional, Set, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|||||||
@@ -183,9 +183,9 @@ elif _is_hip:
|
|||||||
awq_dequantize_triton as awq_dequantize,
|
awq_dequantize_triton as awq_dequantize,
|
||||||
)
|
)
|
||||||
elif _is_npu:
|
elif _is_npu:
|
||||||
import custom_ops
|
import custom_ops # noqa: F401
|
||||||
import sgl_kernel_npu
|
import sgl_kernel_npu # noqa: F401
|
||||||
import torch_npu
|
import torch_npu # noqa: F401
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from typing import Iterable, List, Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from transformers.activations import ACT2FN
|
|
||||||
|
|
||||||
from sglang.srt.configs import DotsOCRConfig
|
from sglang.srt.configs import DotsOCRConfig
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
@@ -22,7 +21,6 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
|
|||||||
from sglang.srt.models.dots_vlm_vit import DotsVisionTransformer
|
from sglang.srt.models.dots_vlm_vit import DotsVisionTransformer
|
||||||
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
||||||
from sglang.srt.utils import add_prefix
|
from sglang.srt.utils import add_prefix
|
||||||
from sglang.srt.utils.hf_transformers_utils import get_processor
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from sglang.srt.configs.dots_vlm import DotsVLMConfig
|
from sglang.srt.configs.dots_vlm import DotsVLMConfig
|
||||||
from sglang.srt.distributed import parallel_state
|
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.managers.mm_utils import (
|
from sglang.srt.managers.mm_utils import (
|
||||||
MultiModalityDataPaddingPatternMultimodalTokens,
|
MultiModalityDataPaddingPatternMultimodalTokens,
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import enum
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Iterable, List, Optional, Set, Tuple
|
from typing import Any, Iterable, List, Optional, Set, Tuple
|
||||||
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user