[Lint] Add python/sglang to ruff F401 checks and remove unused imports in files (#11685)
This commit is contained in:
@@ -15,7 +15,7 @@ if not is_hpu():
|
||||
# ROCm does not use vllm custom allreduce
|
||||
if use_vllm_custom_allreduce and not is_hip():
|
||||
try:
|
||||
import vllm._C
|
||||
import vllm._C # noqa: F401
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import from vllm._C with %r", e)
|
||||
else:
|
||||
|
||||
@@ -9,7 +9,6 @@ from unittest.mock import patch
|
||||
import torch
|
||||
import torch.fx as fx
|
||||
|
||||
import sglang.srt.compilation.weak_ref_tensor_jit
|
||||
from sglang.srt.compilation.compilation_config import CompilationConfig
|
||||
from sglang.srt.compilation.compilation_counter import compilation_counter
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import math
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
@@ -1,10 +1,5 @@
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
from transformers import AutoProcessor, LlamaTokenizerFast, PretrainedConfig
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.processing_utils import ProcessingKwargs, Unpack
|
||||
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from transformers import AutoProcessor, PretrainedConfig
|
||||
from transformers.processing_utils import ProcessingKwargs
|
||||
|
||||
try:
|
||||
from transformers import Qwen2_5_VLProcessor
|
||||
|
||||
@@ -14,17 +14,12 @@
|
||||
# limitations under the License.
|
||||
"""Falcon-H1 model configuration"""
|
||||
|
||||
import enum
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.modeling_rope_utils import rope_config_validation
|
||||
from transformers.utils import logging
|
||||
|
||||
from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
get_attention_tp_size,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from sglang.srt.layers.dp_attention import get_tensor_model_parallel_world_size
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@@ -21,7 +21,6 @@ from transformers.modeling_rope_utils import rope_config_validation
|
||||
from transformers.utils import logging
|
||||
|
||||
from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
|
||||
from sglang.srt.distributed.utils import divide
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import logging
|
||||
from typing import Generator, List, Optional, Tuple
|
||||
from typing import Generator, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import torch
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ import time
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
@@ -48,10 +48,7 @@ from sglang.srt.disaggregation.utils import (
|
||||
)
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch
|
||||
from sglang.srt.mem_cache.allocator import (
|
||||
BaseTokenToKVPoolAllocator,
|
||||
SWATokenToKVPoolAllocator,
|
||||
)
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
HybridLinearKVPool,
|
||||
@@ -61,7 +58,6 @@ from sglang.srt.mem_cache.memory_pool import (
|
||||
ReqToTokenPool,
|
||||
SWAKVPool,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.utils import get_int_env_var, require_mlp_sync
|
||||
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||
|
||||
|
||||
@@ -20,7 +20,6 @@ Life cycle of a request in the prefill server
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from http import HTTPStatus
|
||||
@@ -54,7 +53,7 @@ from sglang.srt.mem_cache.memory_pool import (
|
||||
NSATokenToKVPool,
|
||||
SWAKVPool,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
|
||||
from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
|
||||
from sglang.srt.utils import (
|
||||
DynamicGradMode,
|
||||
broadcast_pyobj,
|
||||
|
||||
@@ -32,7 +32,7 @@ try:
|
||||
ops.meta_size()
|
||||
else:
|
||||
# Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
|
||||
import sgl_kernel
|
||||
import sgl_kernel # noqa: F401
|
||||
custom_ar = True
|
||||
except Exception:
|
||||
# For CPUs
|
||||
|
||||
@@ -4,7 +4,7 @@ import math
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from enum import IntEnum
|
||||
from typing import Any, Callable, List, Optional, TypeVar, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -24,7 +24,7 @@ if _is_hip:
|
||||
mscclpp_is_available = False
|
||||
if _is_cuda:
|
||||
try:
|
||||
import sgl_kernel
|
||||
import sgl_kernel # noqa: F401
|
||||
|
||||
mscclpp_is_available = True
|
||||
except:
|
||||
|
||||
@@ -9,7 +9,7 @@ from torch.distributed import ProcessGroup
|
||||
from sglang.srt.distributed.device_communicators.all_reduce_utils import (
|
||||
SYMM_MEM_ALL_REDUCE_MAX_SIZES,
|
||||
)
|
||||
from sglang.srt.utils import get_device_capability, is_cuda, is_hip
|
||||
from sglang.srt.utils import is_cuda, is_hip
|
||||
|
||||
try:
|
||||
import torch.distributed._symmetric_memory as torch_symm_mem
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import base64
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copied from vLLM
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Union
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
# Adapted from vLLM: https://github.com/vllm-project/vllm/blob/1b9902806915040ac9b3029f2ab7522ec505afc3/vllm/entrypoints/harmony_utils.py
|
||||
# Slight differences in processing chat messages
|
||||
import datetime
|
||||
import json
|
||||
from collections.abc import Iterable
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ This file implements HTTP APIs for the inference engine via fastapi.
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing as multiprocessing
|
||||
import os
|
||||
|
||||
@@ -1,15 +1,9 @@
|
||||
import copy
|
||||
import dataclasses
|
||||
import multiprocessing
|
||||
import pickle
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import pybase64
|
||||
import requests
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from sglang.srt.entrypoints.EngineBase import EngineBase
|
||||
from sglang.srt.entrypoints.http_server import launch_server
|
||||
|
||||
@@ -3,8 +3,6 @@ from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.utils import get_bool_env_var
|
||||
|
||||
|
||||
def balanced_packing(
|
||||
weight: torch.Tensor, num_packs: int
|
||||
|
||||
@@ -6,11 +6,7 @@ from typing import List
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import Tool
|
||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||
from sglang.srt.function_call.core_types import (
|
||||
StreamingParseResult,
|
||||
StructureInfo,
|
||||
_GetInfoFunc,
|
||||
)
|
||||
from sglang.srt.function_call.core_types import StreamingParseResult, _GetInfoFunc
|
||||
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import json
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import Tool
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
from json import JSONDecodeError, JSONDecoder
|
||||
from json.decoder import WHITESPACE
|
||||
from typing import Any, List, Literal, Optional, Tuple, Union
|
||||
|
||||
@@ -70,7 +70,7 @@ def compile_proto(proto_file: Path, output_dir: Path, verbose: bool = True) -> b
|
||||
|
||||
# Check if grpc_tools is available
|
||||
try:
|
||||
import grpc_tools.protoc
|
||||
import grpc_tools.protoc # noqa: F401
|
||||
except ImportError:
|
||||
print("Error: grpcio-tools not installed")
|
||||
print(
|
||||
|
||||
@@ -27,7 +27,6 @@ from sglang.srt.managers.io_struct import (
|
||||
TokenizedEmbeddingReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
)
|
||||
from sglang.srt.managers.scheduler import is_health_check_generate_req
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import get_zmq_socket, kill_process_tree
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
@@ -380,4 +380,7 @@ if not (
|
||||
logger.info(
|
||||
"sgl-kernel is not available on Non-NV, Non-AMD platforms or Non-AMX CPUs. Fallback to other kernel libraries."
|
||||
)
|
||||
from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
|
||||
from vllm.model_executor.layers.activation import ( # noqa: F401
|
||||
GeluAndMul,
|
||||
SiluAndMul,
|
||||
)
|
||||
|
||||
@@ -20,7 +20,6 @@ if TYPE_CHECKING:
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import triton
|
||||
|
||||
@@ -3,9 +3,7 @@
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.attention.fla.utils import tensor_cache
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
|
||||
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
@@ -9,8 +9,6 @@ import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
|
||||
from sglang.srt.layers.attention.fla.op import safe_exp
|
||||
from sglang.srt.layers.attention.fla.utils import check_shared_mem
|
||||
|
||||
|
||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||
|
||||
@@ -50,7 +50,6 @@ if is_flashinfer_available():
|
||||
fast_decode_plan,
|
||||
)
|
||||
from flashinfer.cascade import merge_state
|
||||
from flashinfer.decode import _get_range_buf, get_seq_lens
|
||||
|
||||
|
||||
class WrapperDispatch(Enum):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
from dataclasses import astuple, dataclass
|
||||
from functools import lru_cache
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
from sglang.srt.layers.attention.fla.chunk import chunk_gated_delta_rule
|
||||
|
||||
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
|
||||
|
||||
class IntelAMXAttnBackend(AttentionBackend):
|
||||
def __init__(self, model_runner: ModelRunner):
|
||||
import sgl_kernel
|
||||
import sgl_kernel # noqa: F401
|
||||
|
||||
super().__init__()
|
||||
self.forward_metadata = None
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
@@ -10,7 +10,6 @@
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from einops import rearrange
|
||||
from packaging import version
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ def is_mla_preprocess_enabled() -> bool:
|
||||
|
||||
|
||||
if is_mla_preprocess_enabled():
|
||||
import sgl_kernel_npu
|
||||
import sgl_kernel_npu # noqa: F401
|
||||
import torch_npu
|
||||
|
||||
torch.npu.config.allow_internal_format = True
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -547,7 +547,7 @@ class Indexer(CustomOp):
|
||||
forward_batch: ForwardBatch,
|
||||
layer_id: int,
|
||||
) -> torch.Tensor:
|
||||
import custom_ops
|
||||
import custom_ops # noqa: F401
|
||||
import torch_npu
|
||||
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, TypeAlias
|
||||
|
||||
@@ -34,18 +33,18 @@ _is_hip = is_hip()
|
||||
|
||||
if _is_hip:
|
||||
try:
|
||||
from aiter import (
|
||||
from aiter import ( # noqa: F401
|
||||
flash_attn_varlen_func,
|
||||
mha_batch_prefill_func,
|
||||
paged_attention_ragged,
|
||||
)
|
||||
from aiter.mla import mla_decode_fwd, mla_prefill_fwd
|
||||
from aiter.mla import mla_decode_fwd, mla_prefill_fwd # noqa: F401
|
||||
except ImportError:
|
||||
print(
|
||||
"aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
|
||||
)
|
||||
else:
|
||||
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||
from sgl_kernel.flash_attn import flash_attn_with_kvcache
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
||||
@@ -372,4 +372,4 @@ if not (
|
||||
logger.info(
|
||||
"sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
|
||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm # noqa: F401
|
||||
|
||||
@@ -116,8 +116,6 @@ def cutlass_fused_experts_fp8(
|
||||
|
||||
if is_cuda:
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
per_group_transpose,
|
||||
per_token_group_quant_fp8_hopper_moe_mn_major,
|
||||
sglang_per_token_group_quant_fp8,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Cutlass W4A8 MoE kernel."""
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import triton
|
||||
|
||||
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
|
||||
from sglang.srt.utils import ceil_div, dispose_tensor, is_cuda
|
||||
from sglang.utils import is_in_ci
|
||||
from sglang.srt.utils import ceil_div, is_cuda
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
|
||||
|
||||
@@ -43,13 +43,7 @@ from sglang.srt.utils import (
|
||||
)
|
||||
|
||||
if is_flashinfer_available():
|
||||
from flashinfer import (
|
||||
RoutingMethodType,
|
||||
fp4_quantize,
|
||||
reorder_rows_for_gated_act_gemm,
|
||||
shuffle_matrix_a,
|
||||
shuffle_matrix_sf_a,
|
||||
)
|
||||
from flashinfer import RoutingMethodType, fp4_quantize
|
||||
|
||||
_is_hip = is_hip()
|
||||
_is_cpu_amx_available = cpu_has_amx_support()
|
||||
|
||||
@@ -51,7 +51,9 @@ elif _is_hip:
|
||||
|
||||
|
||||
if _is_cuda or _is_hip:
|
||||
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
|
||||
from sgl_kernel import ( # noqa: F401
|
||||
moe_align_block_size as sgl_moe_align_block_size,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from enum import IntEnum
|
||||
from functools import cache
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import logging
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
|
||||
|
||||
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
||||
from sglang.srt.layers.moe.token_dispatcher.base import (
|
||||
|
||||
@@ -22,7 +22,7 @@ try:
|
||||
except ImportError:
|
||||
use_mooncake_ep = False
|
||||
|
||||
from enum import Enum, IntEnum, auto
|
||||
from enum import Enum, auto
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
|
||||
|
||||
import torch
|
||||
|
||||
@@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
import enum
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import torch
|
||||
from compressed_tensors import CompressionFormat
|
||||
@@ -21,14 +21,7 @@ from sglang.srt.layers.quantization.utils import (
|
||||
per_tensor_dequantize,
|
||||
replace_parameter,
|
||||
)
|
||||
from sglang.srt.utils import (
|
||||
get_bool_env_var,
|
||||
is_cpu,
|
||||
is_cuda,
|
||||
is_hip,
|
||||
is_npu,
|
||||
set_weight_attrs,
|
||||
)
|
||||
from sglang.srt.utils import get_bool_env_var, is_hip, set_weight_attrs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
@@ -49,7 +42,7 @@ if _use_aiter:
|
||||
from sglang.srt.layers.moe.rocm_moe_utils import rocm_fused_experts_tkw1
|
||||
|
||||
try:
|
||||
import vllm
|
||||
import vllm # noqa: F401
|
||||
|
||||
VLLM_AVAILABLE = True
|
||||
except ImportError:
|
||||
|
||||
@@ -12,7 +12,7 @@ def _compute_enable_deep_gemm():
|
||||
return False
|
||||
|
||||
try:
|
||||
import deep_gemm
|
||||
import deep_gemm # noqa: F401
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Tuple
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.quantization.deep_gemm_wrapper import compile_utils
|
||||
from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
|
||||
from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import ( # noqa: F401
|
||||
DEEPGEMM_BLACKWELL,
|
||||
DEEPGEMM_SCALE_UE8M0,
|
||||
ENABLE_JIT_DEEPGEMM,
|
||||
@@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
if ENABLE_JIT_DEEPGEMM:
|
||||
import deep_gemm
|
||||
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
|
||||
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor # noqa: F401
|
||||
|
||||
_SANITY_CHECK = get_bool_env_var("SGLANG_DEEPGEMM_SANITY_CHECK")
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ if _is_hip:
|
||||
raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
|
||||
else:
|
||||
try:
|
||||
import vllm._C
|
||||
import vllm._C # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError("vllm is required when SGLANG_USE_AITER is set to False")
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ from torch.nn.parameter import Parameter
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
FusedMoEMethodBase,
|
||||
LinearMethodBase,
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
@@ -28,7 +27,7 @@ from sglang.srt.layers.quantization.marlin_utils_fp8 import (
|
||||
prepare_fp8_layer_for_marlin,
|
||||
)
|
||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||
from sglang.srt.layers.quantization.utils import is_layer_skipped, replace_parameter
|
||||
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
||||
from sglang.srt.utils import get_bool_env_var, is_cuda
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
@@ -199,7 +199,6 @@ class GPTQConfig(QuantizationConfig):
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional[LinearMethodBase]:
|
||||
# Delay the import to avoid circular dependency
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
|
||||
if isinstance(layer, FusedMoE):
|
||||
|
||||
@@ -8,7 +8,7 @@ import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.utils import get_bool_env_var, get_device_name, is_cuda
|
||||
from sglang.srt.utils import get_device_name, is_cuda
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
if _is_cuda:
|
||||
|
||||
@@ -1059,16 +1059,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
intermediate_size,
|
||||
num_experts,
|
||||
):
|
||||
from flashinfer import (
|
||||
RoutingMethodType,
|
||||
e2m1_and_ufp8sf_scale_to_float,
|
||||
fp4_quantize,
|
||||
next_positive_power_of_2,
|
||||
nvfp4_block_scale_interleave,
|
||||
reorder_rows_for_gated_act_gemm,
|
||||
shuffle_matrix_a,
|
||||
shuffle_matrix_sf_a,
|
||||
)
|
||||
from flashinfer import nvfp4_block_scale_interleave
|
||||
from flashinfer.fused_moe.core import (
|
||||
_maybe_get_cached_w2_permute_indices,
|
||||
_maybe_get_cached_w3_w1_permute_indices,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
|
||||
@@ -3,16 +3,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from aiter import ActivationType, QuantType, biased_grouped_topk
|
||||
from aiter import ActivationType, QuantType
|
||||
from aiter.fused_moe import fused_moe
|
||||
from aiter.utility.fp4_utils import e8m0_shuffle
|
||||
|
||||
from sglang.srt.layers.moe import MoeRunnerConfig
|
||||
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
|
||||
from sglang.srt.utils import get_bool_env_var, is_hip, mxfp_supported, set_weight_attrs
|
||||
from sglang.srt.utils import is_hip, set_weight_attrs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.token_dispatcher import (
|
||||
|
||||
@@ -2,20 +2,13 @@
|
||||
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import aiter
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from aiter.ops.gemm_op_a4w4 import gemm_a4w4
|
||||
from aiter.ops.shuffle import shuffle_weight
|
||||
from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
|
||||
from aiter.ops.triton.gemm_afp4wfp4_pre_quant_atomic import gemm_afp4wfp4_pre_quant
|
||||
from aiter.ops.triton.quant import dynamic_mxfp4_quant
|
||||
from aiter.utility import dtypes
|
||||
from aiter.utility.fp4_utils import e8m0_shuffle
|
||||
|
||||
from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
|
||||
from sglang.srt.layers.quantization.quark.schemes import QuarkScheme
|
||||
from sglang.srt.utils import get_bool_env_var
|
||||
|
||||
__all__ = ["QuarkW4A4MXFP4"]
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ import numpy
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
|
||||
from sglang.srt.utils import is_cuda
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
|
||||
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||
from sglang.srt.layers.linear import UnquantizedLinearMethod
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
FusedMoEMethodBase,
|
||||
QuantizationConfig,
|
||||
@@ -17,11 +16,11 @@ from sglang.srt.layers.quantization.base_config import (
|
||||
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
|
||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
||||
from sglang.srt.utils import is_npu, set_weight_attrs
|
||||
from sglang.srt.utils import set_weight_attrs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
|
||||
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE
|
||||
from sglang.srt.layers.moe.token_dispatcher import (
|
||||
CombineInput,
|
||||
DeepEPNormalOutput,
|
||||
|
||||
@@ -1,28 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
from types import MappingProxyType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
|
||||
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import logging
|
||||
import re
|
||||
from functools import lru_cache
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ from sglang.srt.lora.triton_ops import (
|
||||
)
|
||||
from sglang.srt.lora.utils import LoRABatchInfo
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
|
||||
class TritonLoRABackend(BaseLoRABackend):
|
||||
|
||||
@@ -20,7 +20,7 @@ import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
from typing import Optional, Set
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
# and "Punica: Multi-Tenant LoRA Serving"
|
||||
|
||||
import logging
|
||||
from typing import Dict, Iterable, List, Optional, Set, Tuple
|
||||
from typing import Dict, Iterable, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@@ -14,11 +14,10 @@ limitations under the License.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import threading
|
||||
import time
|
||||
from queue import Empty, Full, PriorityQueue, Queue
|
||||
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Set, Tuple
|
||||
from queue import Empty, Full, Queue
|
||||
from typing import TYPE_CHECKING, List, NamedTuple, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -41,7 +40,7 @@ from sglang.srt.layers.dp_attention import (
|
||||
get_attention_tp_size,
|
||||
is_dp_attention_enabled,
|
||||
)
|
||||
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
|
||||
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -59,11 +59,10 @@ from sglang.srt.mem_cache.allocator import (
|
||||
SWATokenToKVPoolAllocator,
|
||||
)
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
||||
from sglang.srt.mem_cache.chunk_cache import SWAChunkCache
|
||||
from sglang.srt.mem_cache.common import (
|
||||
alloc_for_decode,
|
||||
alloc_for_extend,
|
||||
alloc_token_slots,
|
||||
evict_from_tree_cache,
|
||||
)
|
||||
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
|
||||
@@ -76,7 +75,6 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import ServerArgs, get_global_server_args
|
||||
from sglang.srt.utils import flatten_nested_list
|
||||
from sglang.srt.utils.common import next_power_of_2
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
|
||||
@@ -3,13 +3,10 @@ from __future__ import annotations
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
|
||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||
from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
|
||||
from sglang.srt.managers.schedule_policy import PrefillAdder
|
||||
from sglang.srt.managers.scheduler import Req, ScheduleBatch
|
||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import dataclasses
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
|
||||
@@ -92,7 +92,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
||||
)
|
||||
|
||||
if num_new_pages_item < 200:
|
||||
import sgl_kernel_npu
|
||||
import sgl_kernel_npu # noqa: F401
|
||||
|
||||
torch.ops.npu.alloc_extend(
|
||||
prefix_lens,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, List, NamedTuple, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, List, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Tuple, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.mem_cache.radix_cache import TreeNode
|
||||
|
||||
@@ -22,7 +22,6 @@ The radix tree data structure for managing the hybrid (full and Mamba) KV cache.
|
||||
import heapq
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@@ -33,7 +32,6 @@ from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool
|
||||
from sglang.srt.mem_cache.radix_cache import (
|
||||
RadixKey,
|
||||
_key_match_page_size1,
|
||||
_key_match_paged,
|
||||
get_child_key,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import abc
|
||||
import logging
|
||||
import threading
|
||||
from enum import IntEnum
|
||||
from functools import wraps
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ import heapq
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from functools import lru_cache, partial
|
||||
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@@ -3,20 +3,8 @@ import os
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
from aibrix_kvcache import (
|
||||
BaseKVCacheManager,
|
||||
GroupAwareKVCacheManager,
|
||||
KVCacheBlockLayout,
|
||||
KVCacheBlockSpec,
|
||||
KVCacheConfig,
|
||||
KVCacheMetrics,
|
||||
KVCacheTensorSpec,
|
||||
ModelSpec,
|
||||
TokenListView,
|
||||
)
|
||||
from aibrix_kvcache.common.absl_logging import getLogger, log_every_n_seconds, log_if
|
||||
from aibrix_kvcache.common.absl_logging import log_every_n_seconds
|
||||
from aibrix_kvcache_storage import AibrixKVCacheStorage
|
||||
from torch.distributed import Backend, ProcessGroup
|
||||
|
||||
from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
|
||||
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
|
||||
|
||||
@@ -2,21 +2,18 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
import eic
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
|
||||
from sglang.srt.mem_cache.hicache_storage import (
|
||||
HiCacheStorage,
|
||||
HiCacheStorageConfig,
|
||||
HiCacheStorageExtraInfo,
|
||||
)
|
||||
from sglang.srt.mem_cache.memory_pool_host import HostKVCache, MLATokenToKVPoolHost
|
||||
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
|
||||
@@ -18,7 +18,7 @@ Records the latency of some functions
|
||||
import asyncio
|
||||
import time
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, List, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from sglang.srt.metrics.utils import exponential_buckets
|
||||
|
||||
|
||||
@@ -104,11 +104,7 @@ from sglang.srt.mem_cache.memory_pool import (
|
||||
)
|
||||
from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner
|
||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||
from sglang.srt.model_executor.forward_batch_info import (
|
||||
ForwardBatch,
|
||||
ForwardMode,
|
||||
PPProxyTensors,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||
from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
|
||||
from sglang.srt.model_executor.piecewise_cuda_graph_runner import (
|
||||
PiecewiseCudaGraphRunner,
|
||||
|
||||
@@ -19,10 +19,9 @@ import logging
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from sglang.srt.configs.model_config import AttentionArch, is_deepseek_nsa
|
||||
from sglang.srt.configs.model_config import is_deepseek_nsa
|
||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
# limitations under the License.
|
||||
"""SGLang BailingMoE model."""
|
||||
import logging
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple, Union
|
||||
from typing import Iterable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -59,7 +59,6 @@ from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
|
||||
from sglang.srt.layers.moe.topk import TopK
|
||||
from sglang.srt.layers.moe.utils import DeepEPMode
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.rotary_embedding import get_rope
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Any, Dict, Iterable, Optional, Set, Tuple
|
||||
from typing import Iterable, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
@@ -183,9 +183,9 @@ elif _is_hip:
|
||||
awq_dequantize_triton as awq_dequantize,
|
||||
)
|
||||
elif _is_npu:
|
||||
import custom_ops
|
||||
import sgl_kernel_npu
|
||||
import torch_npu
|
||||
import custom_ops # noqa: F401
|
||||
import sgl_kernel_npu # noqa: F401
|
||||
import torch_npu # noqa: F401
|
||||
else:
|
||||
pass
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers.activations import ACT2FN
|
||||
|
||||
from sglang.srt.configs import DotsOCRConfig
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
@@ -22,7 +21,6 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.dots_vlm_vit import DotsVisionTransformer
|
||||
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
||||
from sglang.srt.utils import add_prefix
|
||||
from sglang.srt.utils.hf_transformers_utils import get_processor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -23,7 +23,6 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
from sglang.srt.configs.dots_vlm import DotsVLMConfig
|
||||
from sglang.srt.distributed import parallel_state
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.managers.mm_utils import (
|
||||
MultiModalityDataPaddingPatternMultimodalTokens,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import enum
|
||||
import logging
|
||||
from typing import Any, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
|
||||
@@ -14,8 +14,7 @@ from transformers import (
|
||||
)
|
||||
from transformers.models.auto.modeling_auto import AutoModel
|
||||
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
|
||||
from sglang.srt.layers.linear import RowParallelLinear
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user