move more files under srt/utils (#11285)
This commit is contained in:
4
.github/workflows/pr-test.yml
vendored
4
.github/workflows/pr-test.yml
vendored
@@ -292,7 +292,7 @@ jobs:
|
|||||||
needs: [check-changes, unit-test-backend-2-gpu, sgl-kernel-build-wheels]
|
needs: [check-changes, unit-test-backend-2-gpu, sgl-kernel-build-wheels]
|
||||||
if: always() && !failure() && !cancelled() &&
|
if: always() && !failure() && !cancelled() &&
|
||||||
((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))
|
((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))
|
||||||
runs-on: 4-gpu-runner
|
runs-on: 4-gpu-h100
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
@@ -614,7 +614,7 @@ jobs:
|
|||||||
needs: [check-changes, unit-test-backend-2-gpu, sgl-kernel-build-wheels]
|
needs: [check-changes, unit-test-backend-2-gpu, sgl-kernel-build-wheels]
|
||||||
if: always() && !failure() && !cancelled() &&
|
if: always() && !failure() && !cancelled() &&
|
||||||
((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))
|
((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))
|
||||||
runs-on: 4-gpu-runner
|
runs-on: 4-gpu-h100
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|||||||
@@ -51,8 +51,8 @@ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
|||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
|
from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
|
||||||
from sglang.srt.utils import get_int_env_var, require_mlp_sync
|
from sglang.srt.utils import get_int_env_var, require_mlp_sync
|
||||||
|
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -68,7 +68,6 @@ from sglang.srt.managers.scheduler import run_scheduler_process
|
|||||||
from sglang.srt.managers.template_manager import TemplateManager
|
from sglang.srt.managers.template_manager import TemplateManager
|
||||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
MultiprocessingSerializer,
|
MultiprocessingSerializer,
|
||||||
assert_pkg_version,
|
assert_pkg_version,
|
||||||
@@ -82,6 +81,7 @@ from sglang.srt.utils import (
|
|||||||
set_prometheus_multiproc_dir,
|
set_prometheus_multiproc_dir,
|
||||||
set_ulimit,
|
set_ulimit,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||||
from sglang.version import __version__
|
from sglang.version import __version__
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -35,8 +35,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
from sglang.srt.managers.scheduler import run_scheduler_process
|
from sglang.srt.managers.scheduler import run_scheduler_process
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams as SGLSamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams as SGLSamplingParams
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
|
||||||
from sglang.srt.utils import configure_logger, prepare_model_and_tokenizer
|
from sglang.srt.utils import configure_logger, prepare_model_and_tokenizer
|
||||||
|
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -12,7 +12,10 @@ from sglang.srt.custom_op import CustomOp
|
|||||||
from sglang.srt.utils import add_prefix, align, is_cuda, is_hip, is_npu
|
from sglang.srt.utils import add_prefix, align, is_cuda, is_hip, is_npu
|
||||||
|
|
||||||
if is_cuda():
|
if is_cuda():
|
||||||
|
try:
|
||||||
import deep_gemm
|
import deep_gemm
|
||||||
|
except ImportError as e:
|
||||||
|
deep_gemm = e
|
||||||
|
|
||||||
from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM, NSA_USE_REAL_INDEXER
|
from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM, NSA_USE_REAL_INDEXER
|
||||||
from sglang.srt.layers.dp_attention import get_attention_tp_group
|
from sglang.srt.layers.dp_attention import get_attention_tp_group
|
||||||
|
|||||||
@@ -30,9 +30,9 @@ from sglang.srt.layers.quantization.modelopt_quant import (
|
|||||||
ModelOptNvFp4FusedMoEMethod,
|
ModelOptNvFp4FusedMoEMethod,
|
||||||
)
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.offloader import get_offloader
|
|
||||||
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
|
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
|
||||||
from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
|
from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
|
||||||
|
from sglang.srt.utils.offloader import get_offloader
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.moe.token_dispatcher import (
|
from sglang.srt.layers.moe.token_dispatcher import (
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ _is_hip = is_hip()
|
|||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def fused_moe_router_kernel(
|
def fused_moe_router_cudacore_kernel(
|
||||||
input_ptr, # input (bs, hidden_dim)
|
input_ptr, # input (bs, hidden_dim)
|
||||||
moe_router_weight_ptr, # input (num_experts, hidden_dim)
|
moe_router_weight_ptr, # input (num_experts, hidden_dim)
|
||||||
topk_weights_ptr, # output (bs, topk)
|
topk_weights_ptr, # output (bs, topk)
|
||||||
@@ -114,7 +114,7 @@ def fused_moe_router_kernel(
|
|||||||
# assert not moe_renormalize, "moe weight renormalization not implemented"
|
# assert not moe_renormalize, "moe weight renormalization not implemented"
|
||||||
|
|
||||||
|
|
||||||
def fused_moe_router_impl(
|
def fused_moe_router_cudacore(
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
router_weight: torch.Tensor,
|
router_weight: torch.Tensor,
|
||||||
topk: int,
|
topk: int,
|
||||||
@@ -138,7 +138,7 @@ def fused_moe_router_impl(
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
fused_moe_router_kernel[(bs,)](
|
fused_moe_router_cudacore_kernel[(bs,)](
|
||||||
x,
|
x,
|
||||||
router_weight,
|
router_weight,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
@@ -157,7 +157,7 @@ def fused_moe_router_impl(
|
|||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def fused_moe_router_large_bs_kernel(
|
def fused_moe_router_tensorcore_kernel(
|
||||||
a_ptr, # input (bs, hidden_dim)
|
a_ptr, # input (bs, hidden_dim)
|
||||||
b_ptr, # input (num_experts, hidden_dim)
|
b_ptr, # input (num_experts, hidden_dim)
|
||||||
topk_weights_ptr, # output (bs, topk)
|
topk_weights_ptr, # output (bs, topk)
|
||||||
@@ -167,12 +167,15 @@ def fused_moe_router_large_bs_kernel(
|
|||||||
topk: tl.constexpr, # only support topk <= 2
|
topk: tl.constexpr, # only support topk <= 2
|
||||||
moe_softcapping: tl.constexpr,
|
moe_softcapping: tl.constexpr,
|
||||||
moe_renormalize: tl.constexpr, # not supported
|
moe_renormalize: tl.constexpr, # not supported
|
||||||
|
correction_bias_ptr,
|
||||||
|
is_correction_bias: tl.constexpr,
|
||||||
K: tl.constexpr,
|
K: tl.constexpr,
|
||||||
BLOCK_SIZE_M: tl.constexpr,
|
BLOCK_SIZE_M: tl.constexpr,
|
||||||
BLOCK_SIZE_N: tl.constexpr,
|
BLOCK_SIZE_N: tl.constexpr,
|
||||||
BLOCK_SIZE_K: tl.constexpr,
|
BLOCK_SIZE_K: tl.constexpr,
|
||||||
stride_am: tl.constexpr,
|
stride_am: tl.constexpr,
|
||||||
stride_bn: tl.constexpr,
|
stride_bn: tl.constexpr,
|
||||||
|
dp_attn_workaround_flag: tl.constexpr,
|
||||||
):
|
):
|
||||||
|
|
||||||
# 1. get block id
|
# 1. get block id
|
||||||
@@ -217,6 +220,20 @@ def fused_moe_router_large_bs_kernel(
|
|||||||
exped = tl.exp(2 * logits_scaled)
|
exped = tl.exp(2 * logits_scaled)
|
||||||
logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
|
logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
|
||||||
|
|
||||||
|
# Add bias after softcapping
|
||||||
|
if is_correction_bias:
|
||||||
|
bias = tl.load(
|
||||||
|
correction_bias_ptr + tl.arange(0, BLOCK_SIZE_N)[None, :],
|
||||||
|
mask=expert_mask.T,
|
||||||
|
other=0.0,
|
||||||
|
)
|
||||||
|
logits_softcapped = logits_softcapped + bias
|
||||||
|
|
||||||
|
if dp_attn_workaround_flag:
|
||||||
|
logits_softcapped = tl.where(
|
||||||
|
logits_softcapped != logits_softcapped, -1e9, logits_softcapped
|
||||||
|
)
|
||||||
|
|
||||||
# 5. top1
|
# 5. top1
|
||||||
arange_block_size_n = tl.arange(0, BLOCK_SIZE_N)[None, :]
|
arange_block_size_n = tl.arange(0, BLOCK_SIZE_N)[None, :]
|
||||||
cond_top1 = arange_block_size_n < num_experts
|
cond_top1 = arange_block_size_n < num_experts
|
||||||
@@ -266,7 +283,7 @@ def fused_moe_router_large_bs_kernel(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def fused_moe_router_large_bs_impl(
|
def fused_moe_router_tensorcore(
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
router_weight: torch.Tensor,
|
router_weight: torch.Tensor,
|
||||||
topk: int,
|
topk: int,
|
||||||
@@ -274,6 +291,7 @@ def fused_moe_router_large_bs_impl(
|
|||||||
BLOCK_SIZE_M: int,
|
BLOCK_SIZE_M: int,
|
||||||
BLOCK_SIZE_N: int,
|
BLOCK_SIZE_N: int,
|
||||||
BLOCK_SIZE_K: int,
|
BLOCK_SIZE_K: int,
|
||||||
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
|
assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
|
||||||
bs, hidden_dim = x.shape
|
bs, hidden_dim = x.shape
|
||||||
@@ -285,10 +303,17 @@ def fused_moe_router_large_bs_impl(
|
|||||||
|
|
||||||
topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
|
topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
|
||||||
topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
|
topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
|
||||||
|
is_correction_bias = correction_bias is not None
|
||||||
|
|
||||||
grid = (triton.cdiv(bs, BLOCK_SIZE_M) * triton.cdiv(num_experts, BLOCK_SIZE_N),)
|
grid = (triton.cdiv(bs, BLOCK_SIZE_M) * triton.cdiv(num_experts, BLOCK_SIZE_N),)
|
||||||
|
|
||||||
fused_moe_router_large_bs_kernel[grid](
|
# TODO(ch-wan): temporary workaround for dp attention. We should support masked
|
||||||
|
# router to skip padded tokens.
|
||||||
|
from sglang.srt.layers.dp_attention import is_dp_attention_enabled
|
||||||
|
|
||||||
|
dp_attn_workaround_flag = is_dp_attention_enabled()
|
||||||
|
|
||||||
|
fused_moe_router_tensorcore_kernel[grid](
|
||||||
a_ptr=x,
|
a_ptr=x,
|
||||||
b_ptr=router_weight,
|
b_ptr=router_weight,
|
||||||
topk_weights_ptr=topk_weights,
|
topk_weights_ptr=topk_weights,
|
||||||
@@ -299,11 +324,14 @@ def fused_moe_router_large_bs_impl(
|
|||||||
moe_softcapping=moe_softcapping,
|
moe_softcapping=moe_softcapping,
|
||||||
moe_renormalize=False,
|
moe_renormalize=False,
|
||||||
K=hidden_dim,
|
K=hidden_dim,
|
||||||
|
correction_bias_ptr=correction_bias,
|
||||||
|
is_correction_bias=is_correction_bias,
|
||||||
BLOCK_SIZE_M=BLOCK_SIZE_M,
|
BLOCK_SIZE_M=BLOCK_SIZE_M,
|
||||||
BLOCK_SIZE_N=BLOCK_SIZE_N,
|
BLOCK_SIZE_N=BLOCK_SIZE_N,
|
||||||
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||||
stride_am=hidden_dim,
|
stride_am=hidden_dim,
|
||||||
stride_bn=hidden_dim,
|
stride_bn=hidden_dim,
|
||||||
|
dp_attn_workaround_flag=dp_attn_workaround_flag,
|
||||||
)
|
)
|
||||||
|
|
||||||
return topk_weights, topk_ids
|
return topk_weights, topk_ids
|
||||||
@@ -316,6 +344,7 @@ def fused_moe_router_shim(
|
|||||||
topk,
|
topk,
|
||||||
renormalize,
|
renormalize,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
|
enable_deterministic_inference: bool = False,
|
||||||
):
|
):
|
||||||
assert not renormalize
|
assert not renormalize
|
||||||
assert (
|
assert (
|
||||||
@@ -324,16 +353,22 @@ def fused_moe_router_shim(
|
|||||||
)
|
)
|
||||||
bs, hidden_dim = hidden_states.shape
|
bs, hidden_dim = hidden_states.shape
|
||||||
num_experts = gating_output.shape[0]
|
num_experts = gating_output.shape[0]
|
||||||
|
|
||||||
BLOCK_SIZE_M = 32
|
BLOCK_SIZE_M = 32
|
||||||
BLOCK_SIZE_N = 16
|
|
||||||
BLOCK_SIZE_K = 256
|
BLOCK_SIZE_N = max(num_experts, 16)
|
||||||
|
BLOCK_SIZE_K = (
|
||||||
|
256 if num_experts < 256 else 64
|
||||||
|
) # if experts are large, need to use smaller k block or shared memory OOM
|
||||||
|
|
||||||
if (
|
if (
|
||||||
bs >= 512
|
(bs >= 512 or num_experts > 8)
|
||||||
and topk <= 2
|
|
||||||
and num_experts <= BLOCK_SIZE_N
|
|
||||||
and hidden_dim % BLOCK_SIZE_K == 0
|
and hidden_dim % BLOCK_SIZE_K == 0
|
||||||
|
# we keep using single kernel to avoid non-deterministic behavior
|
||||||
|
and not enable_deterministic_inference
|
||||||
):
|
):
|
||||||
return fused_moe_router_large_bs_impl(
|
# for a large batch size or many experts, use the kernel whose matmul runs on tensor cores
|
||||||
|
return fused_moe_router_tensorcore(
|
||||||
x=hidden_states,
|
x=hidden_states,
|
||||||
router_weight=gating_output,
|
router_weight=gating_output,
|
||||||
topk=topk,
|
topk=topk,
|
||||||
@@ -341,9 +376,11 @@ def fused_moe_router_shim(
|
|||||||
BLOCK_SIZE_M=BLOCK_SIZE_M,
|
BLOCK_SIZE_M=BLOCK_SIZE_M,
|
||||||
BLOCK_SIZE_N=BLOCK_SIZE_N,
|
BLOCK_SIZE_N=BLOCK_SIZE_N,
|
||||||
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||||
|
correction_bias=correction_bias,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return fused_moe_router_impl(
|
# if smaller, use kernel that does not use tensorcore in matmul
|
||||||
|
return fused_moe_router_cudacore(
|
||||||
x=hidden_states,
|
x=hidden_states,
|
||||||
router_weight=gating_output,
|
router_weight=gating_output,
|
||||||
topk=topk,
|
topk=topk,
|
||||||
@@ -380,11 +417,10 @@ class FusedMoeRouter:
|
|||||||
renormalize=False,
|
renormalize=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward_vllm(
|
def forward_torch(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
# g, _ = self.router_linear.forward(x)
|
|
||||||
g = x.float() @ self.router_linear.weight.T.float()
|
g = x.float() @ self.router_linear.weight.T.float()
|
||||||
|
|
||||||
g = torch.tanh(g.float() / self.moe_softcapping) * self.moe_softcapping
|
g = torch.tanh(g.float() / self.moe_softcapping) * self.moe_softcapping
|
||||||
|
|||||||
@@ -2,11 +2,10 @@ from typing import Callable, List, Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt import offloader
|
|
||||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
|
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
|
||||||
from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
|
from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
|
||||||
from sglang.srt.utils import is_sm100_supported
|
from sglang.srt.utils import is_sm100_supported, offloader
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
@@ -29,7 +28,6 @@ from sglang.srt.layers.quantization.fp8_kernel import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
align,
|
align,
|
||||||
ceil_div,
|
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
get_cuda_version,
|
get_cuda_version,
|
||||||
get_device_capability,
|
get_device_capability,
|
||||||
|
|||||||
@@ -18,8 +18,8 @@ from dataclasses import dataclass, field, fields
|
|||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from sglang.srt.aio_rwlock import RWLock
|
|
||||||
from sglang.srt.utils import ConcurrentCounter
|
from sglang.srt.utils import ConcurrentCounter
|
||||||
|
from sglang.srt.utils.aio_rwlock import RWLock
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
|
|||||||
@@ -37,13 +37,13 @@ from sglang.srt.managers.io_struct import (
|
|||||||
from sglang.srt.managers.schedule_batch import Req
|
from sglang.srt.managers.schedule_batch import Req
|
||||||
from sglang.srt.managers.scheduler import run_scheduler_process
|
from sglang.srt.managers.scheduler import run_scheduler_process
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
bind_port,
|
bind_port,
|
||||||
configure_logger,
|
configure_logger,
|
||||||
get_zmq_socket,
|
get_zmq_socket,
|
||||||
kill_itself_when_parent_died,
|
kill_itself_when_parent_died,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||||
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
|
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -36,7 +36,6 @@ else:
|
|||||||
Image = Any
|
Image = Any
|
||||||
|
|
||||||
|
|
||||||
# Parameters for a session
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BaseReq(ABC):
|
class BaseReq(ABC):
|
||||||
rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True)
|
rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True)
|
||||||
@@ -60,9 +59,11 @@ class BaseBatchReq(ABC):
|
|||||||
return self.rids
|
return self.rids
|
||||||
|
|
||||||
|
|
||||||
|
# Parameters for a session
|
||||||
@dataclass
|
@dataclass
|
||||||
class SessionParams:
|
class SessionParams:
|
||||||
id: Optional[str] = None
|
id: Optional[str] = None
|
||||||
|
rid: Optional[str] = None
|
||||||
offset: Optional[int] = None
|
offset: Optional[int] = None
|
||||||
replace: Optional[bool] = None
|
replace: Optional[bool] = None
|
||||||
drop_previous_output: Optional[bool] = None
|
drop_previous_output: Optional[bool] = None
|
||||||
|
|||||||
@@ -156,7 +156,6 @@ from sglang.srt.model_executor.forward_batch_info import (
|
|||||||
from sglang.srt.parser.reasoning_parser import ReasoningParser
|
from sglang.srt.parser.reasoning_parser import ReasoningParser
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
|
||||||
from sglang.srt.tracing.trace import (
|
from sglang.srt.tracing.trace import (
|
||||||
process_tracing_init,
|
process_tracing_init,
|
||||||
trace_set_proc_propagate_context,
|
trace_set_proc_propagate_context,
|
||||||
@@ -192,6 +191,7 @@ from sglang.srt.utils.hf_transformers_utils import (
|
|||||||
get_tokenizer,
|
get_tokenizer,
|
||||||
get_tokenizer_from_processor,
|
get_tokenizer_from_processor,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||||
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
|
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -40,7 +40,6 @@ import zmq
|
|||||||
import zmq.asyncio
|
import zmq.asyncio
|
||||||
from fastapi import BackgroundTasks
|
from fastapi import BackgroundTasks
|
||||||
|
|
||||||
from sglang.srt.aio_rwlock import RWLock
|
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||||
from sglang.srt.lora.lora_registry import LoRARegistry
|
from sglang.srt.lora.lora_registry import LoRARegistry
|
||||||
@@ -94,6 +93,7 @@ from sglang.srt.utils import (
|
|||||||
get_zmq_socket,
|
get_zmq_socket,
|
||||||
kill_process_tree,
|
kill_process_tree,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.utils.aio_rwlock import RWLock
|
||||||
from sglang.srt.utils.hf_transformers_utils import (
|
from sglang.srt.utils.hf_transformers_utils import (
|
||||||
get_processor,
|
get_processor,
|
||||||
get_tokenizer,
|
get_tokenizer,
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from dataclasses import dataclass
|
|||||||
from sglang.srt.configs.mamba_utils import Mamba2CacheParams
|
from sglang.srt.configs.mamba_utils import Mamba2CacheParams
|
||||||
from sglang.srt.layers.attention.nsa import index_buf_accessor
|
from sglang.srt.layers.attention.nsa import index_buf_accessor
|
||||||
from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
|
from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
|
||||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Memory pool.
|
Memory pool.
|
||||||
|
|||||||
@@ -117,15 +117,9 @@ from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.model_loader.utils import set_default_torch_dtype
|
from sglang.srt.model_loader.utils import set_default_torch_dtype
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.offloader import (
|
|
||||||
create_offloader_from_server_args,
|
|
||||||
get_offloader,
|
|
||||||
set_offloader,
|
|
||||||
)
|
|
||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
MultiprocessingSerializer,
|
MultiprocessingSerializer,
|
||||||
cpu_has_amx_support,
|
cpu_has_amx_support,
|
||||||
@@ -148,7 +142,13 @@ from sglang.srt.utils import (
|
|||||||
set_cuda_arch,
|
set_cuda_arch,
|
||||||
slow_rank_detector,
|
slow_rank_detector,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.utils.offloader import (
|
||||||
|
create_offloader_from_server_args,
|
||||||
|
get_offloader,
|
||||||
|
set_offloader,
|
||||||
|
)
|
||||||
from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
|
from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
|
||||||
|
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||||
from sglang.srt.weight_sync.tensor_bucket import (
|
from sglang.srt.weight_sync.tensor_bucket import (
|
||||||
FlattenedTensorBucket,
|
FlattenedTensorBucket,
|
||||||
FlattenedTensorMetadata,
|
FlattenedTensorMetadata,
|
||||||
|
|||||||
@@ -43,10 +43,8 @@
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import math
|
|
||||||
from collections.abc import Mapping
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Iterable, List, Optional, Tuple
|
from typing import Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -56,10 +54,6 @@ from sglang.srt.configs import KimiVLConfig
|
|||||||
from sglang.srt.configs.deepseekvl2 import DeepseekV2Config
|
from sglang.srt.configs.deepseekvl2 import DeepseekV2Config
|
||||||
from sglang.srt.configs.kimi_vl import KimiVLConfig
|
from sglang.srt.configs.kimi_vl import KimiVLConfig
|
||||||
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
|
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
|
||||||
from sglang.srt.distributed import (
|
|
||||||
get_tensor_model_parallel_rank,
|
|
||||||
get_tensor_model_parallel_world_size,
|
|
||||||
)
|
|
||||||
from sglang.srt.layers.activation import QuickGELU
|
from sglang.srt.layers.activation import QuickGELU
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ from typing import List, Optional, Sequence, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from transformers.activations import ACT2FN, GELUTanh
|
from transformers.activations import ACT2FN
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -596,6 +596,8 @@ class MoonVitPretrainedModel(PreTrainedModel):
|
|||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
|
|
||||||
def __init__(self, config: MoonViTConfig, *inputs, **kwargs):
|
def __init__(self, config: MoonViTConfig, *inputs, **kwargs):
|
||||||
|
from transformers.activations import GELUTanh
|
||||||
|
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
config = deepcopy(config)
|
config = deepcopy(config)
|
||||||
self.merge_kernel_size = config.merge_kernel_size
|
self.merge_kernel_size = config.merge_kernel_size
|
||||||
|
|||||||
@@ -238,6 +238,7 @@ class ServerArgs:
|
|||||||
log_requests: bool = False
|
log_requests: bool = False
|
||||||
log_requests_level: int = 2
|
log_requests_level: int = 2
|
||||||
crash_dump_folder: Optional[str] = None
|
crash_dump_folder: Optional[str] = None
|
||||||
|
crash_on_nan: bool = False
|
||||||
show_time_cost: bool = False
|
show_time_cost: bool = False
|
||||||
enable_metrics: bool = False
|
enable_metrics: bool = False
|
||||||
enable_metrics_for_all_schedulers: bool = False
|
enable_metrics_for_all_schedulers: bool = False
|
||||||
@@ -1733,6 +1734,12 @@ class ServerArgs:
|
|||||||
default=ServerArgs.crash_dump_folder,
|
default=ServerArgs.crash_dump_folder,
|
||||||
help="Folder path to dump requests from the last 5 min before a crash (if any). If not specified, crash dumping is disabled.",
|
help="Folder path to dump requests from the last 5 min before a crash (if any). If not specified, crash dumping is disabled.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--crash-on-nan",
|
||||||
|
type=str,
|
||||||
|
default=ServerArgs.crash_on_nan,
|
||||||
|
help="Crash the server on nan logprobs.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--show-time-cost",
|
"--show-time-cost",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|||||||
@@ -133,9 +133,9 @@ class TiktokenTokenizer:
|
|||||||
)
|
)
|
||||||
return self.encode(ret) if tokenize else ret
|
return self.encode(ret) if tokenize else ret
|
||||||
|
|
||||||
def __call__(self, text, **kwargs):
|
def __call__(self, text: List[str], **kwargs):
|
||||||
return {
|
return {
|
||||||
"input_ids": self.encode(text),
|
"input_ids": [self.encode(x) for x in text],
|
||||||
}
|
}
|
||||||
|
|
||||||
def init_xgrammar(self):
|
def init_xgrammar(self):
|
||||||
|
|||||||
@@ -1,2 +1,2 @@
|
|||||||
# Temporarily do this to avoid changing all imports in the repo
|
# Temporarily do this to avoid changing all imports in the repo
|
||||||
from .common import *
|
from sglang.srt.utils.common import *
|
||||||
|
|||||||
@@ -487,7 +487,7 @@ def make_layers(
|
|||||||
# circular imports
|
# circular imports
|
||||||
from sglang.srt.distributed import get_pp_indices
|
from sglang.srt.distributed import get_pp_indices
|
||||||
from sglang.srt.layers.utils import PPMissingLayer
|
from sglang.srt.layers.utils import PPMissingLayer
|
||||||
from sglang.srt.offloader import get_offloader
|
from sglang.srt.utils.offloader import get_offloader
|
||||||
|
|
||||||
assert not pp_size or num_hidden_layers >= pp_size
|
assert not pp_size or num_hidden_layers >= pp_size
|
||||||
start_layer, end_layer = (
|
start_layer, end_layer = (
|
||||||
|
|||||||
@@ -11,14 +11,14 @@ from sglang.srt.distributed.naive_distributed import (
|
|||||||
get_naive_distributed,
|
get_naive_distributed,
|
||||||
set_naive_distributed,
|
set_naive_distributed,
|
||||||
)
|
)
|
||||||
from sglang.srt.host_shared_memory import (
|
from sglang.srt.layers.parameter import ModelWeightParameter
|
||||||
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
from sglang.srt.utils import MultiprocessingSerializer, is_pin_memory_available
|
||||||
|
from sglang.srt.utils.host_shared_memory import (
|
||||||
HostSharedMemoryManager,
|
HostSharedMemoryManager,
|
||||||
get_host_shared_memory_manager,
|
get_host_shared_memory_manager,
|
||||||
set_host_shared_memory_manager,
|
set_host_shared_memory_manager,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.parameter import ModelWeightParameter
|
|
||||||
from sglang.srt.server_args import ServerArgs
|
|
||||||
from sglang.srt.utils import MultiprocessingSerializer, is_pin_memory_available
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -7,7 +7,6 @@ from pathlib import Path
|
|||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
|
|
||||||
from sglang.srt.bench_utils import bench_kineto
|
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
create_per_token_group_quant_fp8_output_scale,
|
create_per_token_group_quant_fp8_output_scale,
|
||||||
)
|
)
|
||||||
@@ -16,6 +15,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
|
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
|
||||||
from sglang.srt.utils import is_hip
|
from sglang.srt.utils import is_hip
|
||||||
|
from sglang.srt.utils.bench_utils import bench_kineto
|
||||||
|
|
||||||
# CI environment detection
|
# CI environment detection
|
||||||
IS_CI = (
|
IS_CI = (
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from sgl_kernel.testing.rotary_embedding import (
|
|||||||
create_inputs,
|
create_inputs,
|
||||||
)
|
)
|
||||||
|
|
||||||
from sglang.srt.bench_utils import bench_kineto
|
from sglang.srt.utils.bench_utils import bench_kineto
|
||||||
|
|
||||||
# CI environment detection
|
# CI environment detection
|
||||||
IS_CI = (
|
IS_CI = (
|
||||||
|
|||||||
Reference in New Issue
Block a user