Revert "Deprecate global_server_args_dict" (#11520)
@@ -6,6 +6,9 @@
class GlobalConfig:
"""
Store some global constants.

See also python/sglang/srt/managers/schedule_batch.py::global_server_args_dict, which stores
many global runtime arguments as well.
"""

def __init__(self):

@@ -5,7 +5,7 @@ from packaging import version
from torch.cuda.memory import CUDAPluggableAllocator

from sglang.srt.distributed.parallel_state import GroupCoordinator
from sglang.srt.server_args import get_global_server_args
from sglang.srt.managers.schedule_batch import global_server_args_dict

nccl_allocator_source = """
#include <nccl.h>
@@ -32,7 +32,7 @@ _graph_pool_id = None


def is_symmetric_memory_enabled():
return get_global_server_args().enable_symm_mem
return global_server_args_dict["enable_symm_mem"]


def set_graph_pool_id(graph_pool_id):

@@ -18,7 +18,7 @@ from typing import Literal, Optional
import torch

from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.server_args import get_global_server_args
from sglang.srt.managers.schedule_batch import global_server_args_dict


@dataclass
@@ -34,7 +34,7 @@ class ExpertLocationDispatchInfo:

@classmethod
def init_new(cls, layer_id: int):
ep_dispatch_algorithm = get_global_server_args().ep_dispatch_algorithm
ep_dispatch_algorithm = global_server_args_dict["ep_dispatch_algorithm"]
expert_location_metadata = get_global_expert_location_metadata()
assert expert_location_metadata is not None


@@ -24,7 +24,7 @@ from sglang.srt.eplb.expert_location import (
ExpertLocationMetadata,
get_global_expert_location_metadata,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import get_bool_env_var

logger = logging.getLogger(__name__)
@@ -97,7 +97,7 @@ def _update_expert_weights_with_canary(
canary_tensor = (
_get_canary_value(old_expert_location_metadata, layer_id)
.clone()
.to(device=get_global_server_args().device, non_blocking=True)
.to(device=global_server_args_dict["device"], non_blocking=True)
)
routed_experts_weights_of_layer[layer_id].append(canary_tensor)


@@ -5,8 +5,8 @@ from typing import TYPE_CHECKING
import torch

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import get_global_server_args

if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
@@ -42,7 +42,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
# TODO: Change the hard-coded block_seq_num
self.BLOCK_SEQ = 128

if get_global_server_args().triton_attention_reduce_in_fp32:
if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
self.reduce_dtype = torch.float32
else:
self.reduce_dtype = torch.float16

@@ -11,8 +11,8 @@ import triton.language as tl
from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.spec_info import SpecInput

if TYPE_CHECKING:
@@ -830,7 +830,7 @@ class FlashAttentionBackend(AttentionBackend):
):
# Do multi-head attention with chunked prefix cache
if forward_batch.attn_attend_prefix_cache:
assert not get_global_server_args().disable_chunked_prefix_cache
assert not global_server_args_dict["disable_chunked_prefix_cache"]
# MHA for chunked prefix kv cache when running model with MLA
assert forward_batch.prefix_chunk_idx is not None
assert forward_batch.prefix_chunk_cu_seq_lens is not None
@@ -28,8 +28,8 @@ from sglang.srt.layers.attention.flashinfer_backend import (
create_flashinfer_kv_indices_triton,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import (
is_flashinfer_available,
@@ -193,9 +193,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
self.skip_prefill = skip_prefill
self.enable_chunk_kv = (
not skip_prefill
and get_global_server_args().disaggregation_mode != "decode"
and not get_global_server_args().disable_chunked_prefix_cache
and not get_global_server_args().flashinfer_mla_disable_ragged
and global_server_args_dict["disaggregation_mode"] != "decode"
and not global_server_args_dict["disable_chunked_prefix_cache"]
and not global_server_args_dict["flashinfer_mla_disable_ragged"]
)
self.page_size = model_runner.page_size

@@ -306,7 +306,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
prefix_lens = forward_batch.extend_prefix_lens
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
use_ragged = (
not get_global_server_args().flashinfer_mla_disable_ragged
not global_server_args_dict["flashinfer_mla_disable_ragged"]
and extend_no_prefix
)


@@ -23,9 +23,9 @@ from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.rotary_embedding import get_rope_wrapper
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import get_global_server_args

if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
@@ -162,7 +162,7 @@ class Indexer(CustomOp):
base=rope_theta, # type: ignore
rope_scaling=rope_scaling,
is_neox_style=False,
device=get_global_server_args().device,
device=global_server_args_dict["device"],
)
self.block_size = block_size
self.scale_fmt = scale_fmt

@@ -2,7 +2,7 @@ import torch
import triton
import triton.language as tl

from sglang.srt.server_args import get_global_server_args
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import is_cuda, is_hip

_is_cuda = is_cuda()
@@ -11,7 +11,7 @@ if _is_cuda:

_is_hip = is_hip()

if get_global_server_args().triton_attention_reduce_in_fp32:
if global_server_args_dict.get("attention_reduce_in_fp32", False):
REDUCE_TRITON_TYPE = tl.float32
REDUCE_TORCH_TYPE = torch.float32
else:

@@ -20,8 +20,8 @@ from sglang.srt.layers.attention.utils import (
create_flashmla_kv_indices_triton,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import is_cuda, is_flashinfer_available

if is_flashinfer_available():
@@ -123,9 +123,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None
self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None

self.disable_chunked_prefix_cache = (
get_global_server_args().disable_chunked_prefix_cache
)
self.disable_chunked_prefix_cache = global_server_args_dict[
"disable_chunked_prefix_cache"
]

self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens


@@ -45,7 +45,7 @@ from sglang.srt.layers.linear import (
)
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
from sglang.srt.server_args import get_global_server_args
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import add_prefix

ROTARY_EMBED_CLASSES = {
@@ -468,7 +468,7 @@ class VisionAttention(nn.Module):
_passed_backend = qkv_backend
qkv_backend = self._determine_attention_backend(_passed_backend)
if (
get_global_server_args().mm_attention_backend is None
global_server_args_dict["mm_attention_backend"] is None
and _passed_backend is None
):
print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
@@ -528,7 +528,7 @@ class VisionAttention(nn.Module):
- CUDA: "triton_attn"
- Non-CUDA: "sdpa"
"""
override_backend = get_global_server_args().mm_attention_backend
override_backend = global_server_args_dict["mm_attention_backend"]
if override_backend is not None:
backend = override_backend
elif passed_backend is not None:
@@ -40,9 +40,8 @@ from sglang.srt.layers.moe import (
get_moe_a2a_backend,
should_use_flashinfer_cutlass_moe_fp4_allgather,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
get_bool_env_var,
is_cuda,
@@ -169,7 +168,7 @@ class LayerScatterModes:


def enable_moe_dense_fully_dp():
return get_global_server_args().moe_dense_tp_size == 1
return global_server_args_dict["moe_dense_tp_size"] == 1


class LayerCommunicator:
@@ -315,9 +314,7 @@ class LayerCommunicator:
def should_fuse_mlp_allreduce_with_next_layer(
self, forward_batch: ForwardBatch
) -> bool:
speculative_algo = SpeculativeAlgorithm.from_string(
get_global_server_args().speculative_algorithm
)
speculative_algo = global_server_args_dict.get("speculative_algorithm", None)
if (
is_dp_attention_enabled()
and speculative_algo is not None
@@ -336,7 +333,7 @@ class LayerCommunicator:
static_conditions_met = (
(not self.is_last_layer)
and (self._context.tp_size > 1)
and get_global_server_args().enable_flashinfer_allreduce_fusion
and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
and _is_flashinfer_available
)

@@ -534,7 +531,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
(_is_sm100_supported or _is_sm90_supported)
and _is_flashinfer_available
and hasattr(layernorm, "forward_with_allreduce_fusion")
and get_global_server_args().enable_flashinfer_allreduce_fusion
and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
and hidden_states.shape[0] <= 4096
):
hidden_states, residual = layernorm.forward_with_allreduce_fusion(

@@ -38,15 +38,17 @@ from sglang.srt.layers.dp_attention import (
get_dp_device,
get_dp_dtype,
get_dp_hidden_size,
get_global_dp_buffer,
get_local_attention_dp_size,
set_dp_buffer_len,
)
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardMode,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend

logger = logging.getLogger(__name__)
@@ -228,8 +230,8 @@ class LogitsProcessor(nn.Module):
super().__init__()
self.config = config
self.logit_scale = logit_scale
self.use_attn_tp_group = get_global_server_args().enable_dp_lm_head
self.use_fp32_lm_head = get_global_server_args().enable_fp32_lm_head
self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
self.use_fp32_lm_head = global_server_args_dict["enable_fp32_lm_head"]
if self.use_attn_tp_group:
self.attn_tp_size = get_attention_tp_size()
self.do_tensor_parallel_all_gather = (
@@ -252,8 +254,8 @@ class LogitsProcessor(nn.Module):
):
self.final_logit_softcapping = None

self.debug_tensor_dump_output_folder = (
get_global_server_args().debug_tensor_dump_output_folder
self.debug_tensor_dump_output_folder = global_server_args_dict.get(
"debug_tensor_dump_output_folder", None
)

def compute_logprobs_for_multi_item_scoring(
@@ -370,7 +372,9 @@ class LogitsProcessor(nn.Module):
logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)

# Check if multi-item scoring is enabled via server args (only for prefill-only requests)
multi_item_delimiter = get_global_server_args().multi_item_scoring_delimiter
multi_item_delimiter = global_server_args_dict.get(
"multi_item_scoring_delimiter"
)
if multi_item_delimiter is not None and logits_metadata.is_prefill_only:
return self.compute_logprobs_for_multi_item_scoring(
input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter
@@ -27,10 +27,12 @@ from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod
from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
from sglang.srt.utils import (
cpu_has_amx_support,

@@ -31,7 +31,7 @@ from sglang.srt.layers.quantization.base_config import (
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.server_args import get_global_server_args
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import (
direct_register_custom_op,
is_cuda,
@@ -265,9 +265,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
self.with_bias = False
self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
self.flashinfer_mxfp4_moe_precision = (
get_global_server_args().flashinfer_mxfp4_moe_precision
)
self.flashinfer_mxfp4_moe_precision = global_server_args_dict[
"flashinfer_mxfp4_moe_precision"
]

self.triton_kernel_moe_forward = None
self.triton_kernel_moe_with_bias_forward = None

@@ -11,8 +11,8 @@ from sglang.srt.layers.dp_attention import (
is_dp_attention_enabled,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda

if is_cuda():
@@ -33,7 +33,7 @@ RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
class Sampler(nn.Module):
def __init__(self):
super().__init__()
self.use_nan_detection = get_global_server_args().enable_nan_detection
self.use_nan_detection = global_server_args_dict["enable_nan_detection"]
self.tp_sync_group = get_tp_group().device_group

if is_dp_attention_enabled():
@@ -103,7 +103,7 @@ class Sampler(nn.Module):
del logits

if True: # Keep this redundant check to simplify some internal code sync
if get_global_server_args().sampling_backend == "flashinfer":
if global_server_args_dict["sampling_backend"] == "flashinfer":
if sampling_info.need_min_p_sampling:
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
probs = top_p_renorm_prob(probs, sampling_info.top_ps)
@@ -118,7 +118,7 @@ class Sampler(nn.Module):
filter_apply_order="joint",
check_nan=self.use_nan_detection,
)
elif get_global_server_args().sampling_backend == "pytorch":
elif global_server_args_dict["sampling_backend"] == "pytorch":
# A slower fallback implementation with torch native operations.
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
probs,
@@ -131,7 +131,7 @@ class Sampler(nn.Module):
)
else:
raise ValueError(
f"Invalid sampling backend: {get_global_server_args().sampling_backend}"
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
)

if return_logprob:


@@ -16,10 +16,10 @@ from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
global_server_args_dict,
)
from sglang.srt.mem_cache.multimodal_cache import MultiModalCache
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import flatten_nested_list, is_npu, print_warning_once
from sglang.utils import logger

@@ -428,7 +428,7 @@ def _adjust_embedding_length(
f"tokens from multimodal embeddings."
)
if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding:
chunked_prefill_size = get_global_server_args().chunked_prefill_size
chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
if chunked_prefill_size != -1:
logger.warning(
"You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked prefill"

@@ -72,7 +72,7 @@ from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs, get_global_server_args
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import flatten_nested_list
from sglang.srt.utils.common import next_power_of_2
@@ -82,6 +82,47 @@ if TYPE_CHECKING:

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

GLOBAL_SERVER_ARGS_KEYS = [
"attention_backend",
"mm_attention_backend",
"debug_tensor_dump_inject",
"debug_tensor_dump_output_folder",
"chunked_prefill_size",
"device",
"disable_chunked_prefix_cache",
"disable_flashinfer_cutlass_moe_fp4_allgather",
"disable_radix_cache",
"enable_dp_lm_head",
"enable_fp32_lm_head",
"flashinfer_mxfp4_moe_precision",
"enable_flashinfer_allreduce_fusion",
"moe_dense_tp_size",
"ep_dispatch_algorithm",
"ep_num_redundant_experts",
"enable_nan_detection",
"flashinfer_mla_disable_ragged",
"pp_max_micro_batch_size",
"disable_shared_experts_fusion",
"sampling_backend",
"speculative_accept_threshold_single",
"speculative_accept_threshold_acc",
"speculative_attention_mode",
"torchao_config",
"triton_attention_reduce_in_fp32",
"num_reserved_decode_tokens",
"weight_loader_disable_mmap",
"enable_multimodal",
"enable_symm_mem",
"enable_custom_logit_processor",
"disaggregation_mode",
"enable_deterministic_inference",
"nsa_prefill",
"nsa_decode",
"multi_item_scoring_delimiter",
]

# Put some global args for easy access
global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS}

logger = logging.getLogger(__name__)
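For orientation: the module-level dict restored in the hunk above is seeded from the ServerArgs class defaults at import time and later refreshed with the live values. Below is a minimal, self-contained sketch of that pattern; the ServerArgs fields and key list here are illustrative stand-ins, not the real definitions.

from dataclasses import dataclass

@dataclass
class ServerArgs:  # stand-in for sglang.srt.server_args.ServerArgs
    enable_dp_lm_head: bool = False
    sampling_backend: str = "flashinfer"

GLOBAL_SERVER_ARGS_KEYS = ["enable_dp_lm_head", "sampling_backend"]

# Seeded from class-level defaults, as in the hunk above.
global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS}

def refresh_from(server_args: ServerArgs) -> None:
    # Called once the concrete ServerArgs instance is known.
    global_server_args_dict.update(
        {k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
    )

refresh_from(ServerArgs(enable_dp_lm_head=True))
assert global_server_args_dict["enable_dp_lm_head"] is True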
@@ -642,9 +683,12 @@ class Req:
def is_prefill_only(self) -> bool:
"""Check if this request is prefill-only (no token generation needed)."""
# NOTE: when spec is enabled, prefill_only optimizations are disabled
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

spec_alg = get_global_server_args().speculative_algorithm
return self.sampling_params.max_new_tokens == 0 and spec_alg is None
spec_alg = global_server_args_dict["speculative_algorithm"]
return self.sampling_params.max_new_tokens == 0 and (
spec_alg is None or spec_alg == SpeculativeAlgorithm.NONE
)

def add_latency(self, stage: RequestStage):
if self.metrics_collector is None:
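The restored is_prefill_only check above accepts both a missing algorithm and an explicit NONE value as "no speculation". A rough, self-contained illustration of that behavior (the enum values are stand-ins):

from enum import Enum, auto

class SpeculativeAlgorithm(Enum):  # stand-in for the real enum
    NONE = auto()
    EAGLE = auto()

def is_prefill_only(max_new_tokens: int, spec_alg) -> bool:
    # Prefill-only fast paths stay disabled whenever speculation is active.
    return max_new_tokens == 0 and (
        spec_alg is None or spec_alg == SpeculativeAlgorithm.NONE
    )

assert is_prefill_only(0, None)
assert is_prefill_only(0, SpeculativeAlgorithm.NONE)
assert not is_prefill_only(0, SpeculativeAlgorithm.EAGLE)
assert not is_prefill_only(8, None)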
@@ -122,6 +122,7 @@ from sglang.srt.managers.schedule_batch import (
Req,
RequestStage,
ScheduleBatch,
global_server_args_dict,
)
from sglang.srt.managers.schedule_policy import (
AddReqResult,
@@ -149,7 +150,7 @@ from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.parser.reasoning_parser import ReasoningParser
from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.eagle_info import EagleDraftInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.tracing.trace import (
@@ -446,12 +447,13 @@ class Scheduler(
self.max_req_input_len,
self.random_seed,
self.device,
worker_global_server_args_dict,
_,
_,
_,
) = self.tp_worker.get_worker_info()
if get_global_server_args().pp_max_micro_batch_size is None:
get_global_server_args().pp_max_micro_batch_size = max(
if global_server_args_dict["pp_max_micro_batch_size"] is None:
global_server_args_dict["pp_max_micro_batch_size"] = max(
self.max_running_requests // server_args.pp_size, 1
)

@@ -463,6 +465,7 @@ class Scheduler(
self.world_group = get_world_group()

self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
global_server_args_dict.update(worker_global_server_args_dict)
set_random_seed(self.random_seed)

# Hybrid memory pool
@@ -1863,7 +1866,7 @@ class Scheduler(
return ret

def get_num_allocatable_reqs(self, running_bs):
res = get_global_server_args().pp_max_micro_batch_size - running_bs
res = global_server_args_dict["pp_max_micro_batch_size"] - running_bs
if self.pp_size > 1:
res = min(res, self.req_to_token_pool.available_size())
return res
@@ -2607,7 +2610,7 @@ class Scheduler(
)

def get_internal_state(self, recv_req: GetInternalStateReq):
ret = vars(get_global_server_args())
ret = dict(global_server_args_dict)
ret["last_gen_throughput"] = self.last_gen_throughput
ret["memory_usage"] = {
"weight": round(
@@ -2663,11 +2666,11 @@ class Scheduler(
logger.info(f"{avg_spec_accept_length=}")
self.cum_spec_accept_length = self.cum_spec_accept_count = 0
for k, v in server_args_dict.items():
setattr(get_global_server_args(), k, v)
logger.info(f"Global server args updated! {get_global_server_args()=}")
global_server_args_dict[k] = v
logger.info(f"Global server args updated! {global_server_args_dict=}")
return SetInternalStateReqOutput(
updated=True,
server_args=vars(get_global_server_args()),
server_args=global_server_args_dict,
)

def handle_rpc_request(self, recv_req: RpcReqInput):

@@ -33,7 +33,7 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
from sglang.srt.managers.scheduler import GenerationBatchResult
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
@@ -190,6 +190,7 @@ class TpModelWorker:
self.max_req_input_len,
self.random_seed,
self.device,
global_server_args_dict,
self.model_runner.req_to_token_pool.size,
self.model_runner.req_to_token_pool.max_context_len,
self.model_runner.token_to_kv_pool.size,
@@ -11,7 +11,7 @@ from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.server_args import get_global_server_args
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import support_triton

if TYPE_CHECKING:
@@ -19,6 +19,10 @@ if TYPE_CHECKING:

logger = logging.getLogger(__name__)

GLOBAL_SERVER_ARGS_KEYS = ["attention_backend"]

global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS}


@triton.jit
def write_req_to_token_pool_triton(
@@ -84,7 +88,7 @@ def write_cache_indices(
prefix_tensors: list[torch.Tensor],
req_to_token_pool: ReqToTokenPool,
):
if support_triton(get_global_server_args().attention_backend):
if support_triton(global_server_args_dict.get("attention_backend")):
prefix_pointers = torch.tensor(
[t.data_ptr() for t in prefix_tensors],
device=req_to_token_pool.device,
@@ -125,8 +129,8 @@ def get_last_loc(
prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
if (
get_global_server_args().attention_backend != "ascend"
and get_global_server_args().attention_backend != "torch_native"
global_server_args_dict["attention_backend"] != "ascend"
and global_server_args_dict["attention_backend"] != "torch_native"
):
impl = get_last_loc_triton
else:
@@ -83,6 +83,10 @@ from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.managers.schedule_batch import (
GLOBAL_SERVER_ARGS_KEYS,
global_server_args_dict,
)
from sglang.srt.mem_cache.allocator import (
BaseTokenToKVPoolAllocator,
PagedTokenToKVPoolAllocator,
@@ -121,11 +125,7 @@ from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import (
ServerArgs,
get_global_server_args,
set_global_server_args_for_scheduler,
)
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
MultiprocessingSerializer,
@@ -275,12 +275,15 @@ class ModelRunner:
# Model-specific adjustment
self.model_specific_adjustment()

# Set the global server_args in the scheduler process
set_global_server_args_for_scheduler(server_args)
global_server_args = get_global_server_args()

# FIXME: hacky set `use_mla_backend`
global_server_args.use_mla_backend = self.use_mla_backend
# Global vars
global_server_args_dict.update(
{k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
| {
# TODO it is indeed not a "server args"
"use_mla_backend": self.use_mla_backend,
"speculative_algorithm": self.spec_algorithm,
}
)

# Init OpenMP threads binding for CPU
if self.device == "cpu":
@@ -429,7 +432,7 @@ class ModelRunner:
# In layered loading, torchao may have been applied
if not torchao_applied:
apply_torchao_config_to_model(
self.model, get_global_server_args().torchao_config
self.model, global_server_args_dict["torchao_config"]
)

# Apply torch TP if the model supports it
@@ -1835,10 +1838,12 @@ class ModelRunner:
self.server_args.attention_backend
)

(
get_global_server_args().prefill_attention_backend,
get_global_server_args().decode_attention_backend,
) = (self.prefill_attention_backend_str, self.decode_attention_backend_str)
global_server_args_dict.update(
{
"decode_attention_backend": self.decode_attention_backend_str,
"prefill_attention_backend": self.prefill_attention_backend_str,
}
)
return attn_backend

def _get_attention_backend_from_str(self, backend_str: str):
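The ModelRunner hunk above repopulates the dict with the whitelisted server args plus a few derived values via a dict union. A minimal sketch of that update shape, with hypothetical names and values standing in for the real ones:

GLOBAL_SERVER_ARGS_KEYS = ["attention_backend", "device"]  # illustrative subset

class _Args:  # stand-in for the ServerArgs instance held by ModelRunner
    attention_backend = "triton"
    device = "cuda"

global_server_args_dict = {}
use_mla_backend = True  # derived at runtime, not a real server arg
spec_algorithm = None

global_server_args_dict.update(
    {k: getattr(_Args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
    | {
        "use_mla_backend": use_mla_backend,
        "speculative_algorithm": spec_algorithm,
    }
)
assert global_server_args_dict["attention_backend"] == "triton"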
@@ -4,6 +4,7 @@ from __future__ import annotations

# ruff: noqa: SIM117
import collections
import concurrent
import dataclasses
import fnmatch
import glob
@@ -11,10 +12,12 @@ import json
import logging
import math
import os
import re
import socket
import threading
import time
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager, suppress
from typing import (
TYPE_CHECKING,
@@ -30,10 +33,10 @@ from typing import (

import huggingface_hub
import numpy as np
import requests
import safetensors.torch
import torch

from sglang.srt.server_args import get_global_server_args

# Try to import accelerate (optional dependency)
try:
from accelerate import infer_auto_device_map, init_empty_weights
@@ -78,6 +81,8 @@ DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION = (
0.8 # Reserve 20% GPU memory headroom for ModelOpt calibration
)
from sglang.srt.model_loader.weight_utils import (
_BAR_FORMAT,
default_weight_loader,
download_safetensors_index_file_from_hf,
download_weights_from_hf,
filter_duplicate_safetensors_files,
@@ -440,8 +445,10 @@ class DefaultModelLoader(BaseModelLoader):
hf_weights_files,
)
elif use_safetensors:
weight_loader_disable_mmap = (
get_global_server_args().weight_loader_disable_mmap
from sglang.srt.managers.schedule_batch import global_server_args_dict

weight_loader_disable_mmap = global_server_args_dict.get(
"weight_loader_disable_mmap"
)

if extra_config.get("enable_multithread_load"):
@@ -609,9 +616,9 @@ class LayeredModelLoader(DefaultModelLoader):
device_config: DeviceConfig,
) -> nn.Module:
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
from sglang.srt.server_args import get_global_server_args
from sglang.srt.managers.schedule_batch import global_server_args_dict

torchao_config = get_global_server_args().torchao_config
torchao_config = global_server_args_dict.get("torchao_config")
target_device = torch.device(device_config.device)

with set_default_torch_dtype(model_config.dtype):

@@ -46,14 +46,15 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
kv_cache_scales_loader,
maybe_remap_kv_scale_name,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, make_layers
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)

@@ -446,7 +447,7 @@ class ApertusForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

@@ -42,13 +42,13 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
kv_cache_scales_loader,
maybe_remap_kv_scale_name,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, make_layers

logger = logging.getLogger(__name__)
@@ -407,7 +407,7 @@ class ArceeForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@@ -17,7 +17,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SGLang BailingMoE model."""
""" SGLang BailingMoE model."""
import logging
from typing import Any, Dict, Iterable, Optional, Tuple, Union

@@ -68,6 +68,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -75,7 +76,6 @@ from sglang.srt.models.utils import (
create_fused_set_kv_buffer_arg,
enable_fused_set_kv_buffer,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty, make_layers

LoraConfig = None
@@ -204,8 +204,8 @@ class BailingMoESparseMoeBlock(nn.Module):
else:
self.router_dtype = torch.bfloat16

# TODO global_server_args.ep_num_redundant_experts is used for eplb, not supported now
assert get_global_server_args().ep_num_redundant_experts == 0
# TODO global_server_args_dict["ep_num_redundant_experts"] is used for eplb, not supported now
assert global_server_args_dict["ep_num_redundant_experts"] == 0
# check group topk
self.num_expert_group = getattr(config, "n_group", 0)
self.topk_group = getattr(config, "topk_group", 0)
@@ -220,7 +220,7 @@ class BailingMoESparseMoeBlock(nn.Module):
self.use_grouped_topk = False

self.num_experts = (
config.num_experts + get_global_server_args().ep_num_redundant_experts
config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
)

self.gate = BailingMoEGate(
@@ -824,7 +824,7 @@ class BailingMoEForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)


@@ -17,7 +17,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SGLang BailingMoENextN model."""
""" SGLang BailingMoENextN model."""
import logging
from typing import Iterable, Optional, Tuple

@@ -29,14 +29,15 @@ from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.dp_attention import is_dp_attention_enabled
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.bailing_moe import BailingMoEBlock, BailingMoEForCausalLM
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix

LoraConfig = None
@@ -144,7 +145,7 @@ class BailingMoeForCausalLMNextN(BailingMoEForCausalLM):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("model.shared_head.head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)


@@ -30,9 +30,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda

logger = logging.getLogger(__name__)
@@ -152,7 +152,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("model.shared_head.head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
@@ -35,6 +35,7 @@ from sglang.srt.configs.model_config import (
get_nsa_index_topk,
is_deepseek_nsa,
)
from sglang.srt.debug_utils.dumper import dumper
from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_pp_group,
@@ -107,11 +108,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.server_args import get_global_server_args
from sglang.srt.single_batch_overlap import SboFlags
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.two_batch_overlap import (
MaybeTboDeepEPDispatcher,
model_forward_maybe_tbo,
@@ -520,7 +520,7 @@ class DeepseekV2MoE(nn.Module):
self.n_shared_experts = config.n_shared_experts
self.num_fused_shared_experts = (
0
if get_global_server_args().disable_shared_experts_fusion
if global_server_args_dict["disable_shared_experts_fusion"]
else config.n_shared_experts
)
self.config = config
@@ -549,7 +549,7 @@ class DeepseekV2MoE(nn.Module):
self.experts = get_moe_impl_class(quant_config)(
num_experts=config.n_routed_experts
+ self.num_fused_shared_experts
+ get_global_server_args().ep_num_redundant_experts,
+ global_server_args_dict["ep_num_redundant_experts"],
num_fused_shared_experts=self.num_fused_shared_experts,
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
hidden_size=config.hidden_size,
@@ -627,7 +627,7 @@ class DeepseekV2MoE(nn.Module):
self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = (
config.n_routed_experts
+ get_global_server_args().ep_num_redundant_experts
+ global_server_args_dict["ep_num_redundant_experts"]
)
self.renormalize = config.norm_topk_prob
self.topk_group = config.topk_group
@@ -1125,7 +1125,7 @@ class DeepseekV2AttentionMLA(nn.Module):
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=False,
device=get_global_server_args().device,
device=global_server_args_dict["device"],
)

if rope_scaling:
@@ -1169,12 +1169,12 @@ class DeepseekV2AttentionMLA(nn.Module):
self.w_scale_v = None
self.use_deep_gemm_bmm = False

self.flashinfer_mla_disable_ragged = (
get_global_server_args().flashinfer_mla_disable_ragged,
)
self.disable_chunked_prefix_cache = (
get_global_server_args().disable_chunked_prefix_cache
)
self.flashinfer_mla_disable_ragged = global_server_args_dict[
"flashinfer_mla_disable_ragged"
]
self.disable_chunked_prefix_cache = global_server_args_dict[
"disable_chunked_prefix_cache"
]

self.current_attention_backend = (
None # Attention backend used by current forward batch
@@ -1253,18 +1253,18 @@ class DeepseekV2AttentionMLA(nn.Module):
) -> AttnForwardMethod:
# Determine attention backend used by current forward batch
if forward_batch.forward_mode.is_decode_or_idle():
attention_backend = get_global_server_args().decode_attention_backend
attention_backend = global_server_args_dict["decode_attention_backend"]
elif (
forward_batch.forward_mode.is_target_verify()
or forward_batch.forward_mode.is_draft_extend()
):
# Use the specified backend for speculative operations (both verify and draft extend)
if get_global_server_args().speculative_attention_mode == "decode":
attention_backend = get_global_server_args().decode_attention_backend
if global_server_args_dict["speculative_attention_mode"] == "decode":
attention_backend = global_server_args_dict["decode_attention_backend"]
else: # default to prefill
attention_backend = get_global_server_args().prefill_attention_backend
attention_backend = global_server_args_dict["prefill_attention_backend"]
else:
attention_backend = get_global_server_args().prefill_attention_backend
attention_backend = global_server_args_dict["prefill_attention_backend"]
self.current_attention_backend = attention_backend

handler = AttentionBackendRegistry.get_handler(attention_backend)
@@ -2365,9 +2365,7 @@ class DeepseekV2DecoderLayer(nn.Module):
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
get_global_server_args().speculative_algorithm
)
self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
self.layer_id = layer_id
self.is_nextn = is_nextn
self.self_attn = DeepseekV2AttentionMLA(
@@ -2819,7 +2817,7 @@ class DeepseekV2ForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)

@@ -2839,7 +2837,7 @@ class DeepseekV2ForCausalLM(nn.Module):
self, architecture: str = "DeepseekV3ForCausalLM"
):
self.num_fused_shared_experts = 0
if get_global_server_args().disable_shared_experts_fusion:
if global_server_args_dict["disable_shared_experts_fusion"]:
return

# Only Deepseek V3/R1 can use shared experts fusion optimization now.
@@ -2858,7 +2856,7 @@ class DeepseekV2ForCausalLM(nn.Module):
disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts."

if disable_reason is not None:
get_global_server_args().disable_shared_experts_fusion = True
global_server_args_dict["disable_shared_experts_fusion"] = True
self.num_fused_shared_experts = 0
log_info_on_rank0(
logger,
@@ -33,9 +33,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, is_cuda, make_layers

logger = logging.getLogger(__name__)
@@ -483,7 +483,7 @@ class FalconH1ForCausalLM(nn.Module):
quant_config=quant_config,
org_num_embeddings=config.vocab_size,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.lm_head = self.lm_head.float()
self.lm_head_multiplier = config.lm_head_multiplier

@@ -56,13 +56,18 @@ from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
from sglang.srt.layers.quantization.fp8_kernel import (
is_fp8_fnuz,
per_tensor_quant_mla_fp8,
per_token_group_quant_mla_deep_gemm_masked_fp8,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -72,7 +77,6 @@ from sglang.srt.models.deepseek_v2 import (
DeepseekV2Model,
DeepseekV2MoE,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
from sglang.srt.utils import (
BumpAllocator,
@@ -391,7 +395,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
self.n_shared_experts = config.n_shared_experts
self.num_fused_shared_experts = (
0
if get_global_server_args().disable_shared_experts_fusion
if global_server_args_dict["disable_shared_experts_fusion"]
else config.n_shared_experts
)
self.config = config
@@ -428,7 +432,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
self.experts = get_moe_impl_class(quant_config)(
num_experts=config.n_routed_experts
+ self.num_fused_shared_experts
+ get_global_server_args().ep_num_redundant_experts,
+ global_server_args_dict["ep_num_redundant_experts"],
num_fused_shared_experts=self.num_fused_shared_experts,
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
hidden_size=config.hidden_size,
@@ -472,7 +476,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = (
config.n_routed_experts
+ get_global_server_args().ep_num_redundant_experts
+ global_server_args_dict["ep_num_redundant_experts"]
)
self.renormalize = config.norm_topk_prob
self.topk_group = config.topk_group
@@ -754,7 +758,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)

@@ -770,7 +774,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
self, architecture: str = "Glm4MoeForCausalLM"
):
self.num_fused_shared_experts = 0
if get_global_server_args().disable_shared_experts_fusion:
if global_server_args_dict["disable_shared_experts_fusion"]:
return

# Only Deepseek V3/R1 can use shared experts fusion optimization now.
@@ -786,7 +790,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
disable_reason = "Deepseek and GLM-4.5 or GLM-4.6 can not use shared experts fusion optimization under expert parallelism."

if disable_reason is not None:
get_global_server_args().disable_shared_experts_fusion = True
global_server_args_dict["disable_shared_experts_fusion"] = True
self.num_fused_shared_experts = 0
log_info_on_rank0(
logger,
@@ -30,9 +30,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.glm4_moe import Glm4MoeDecoderLayer, Glm4MoeForCausalLM
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import BumpAllocator, add_prefix

logger = logging.getLogger(__name__)
@@ -145,7 +145,7 @@ class Glm4MoeForCausalLMNextN(Glm4MoeForCausalLM):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("model.shared_head.head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)


@@ -16,10 +16,10 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.glm4_moe import Glm4MoeModel
from sglang.srt.models.glm4v import Glm4vForConditionalGeneration, Glm4vVisionModel
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, is_cuda, log_info_on_rank0
from sglang.srt.utils.hf_transformers_utils import get_processor

@@ -47,7 +47,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
self.num_fused_shared_experts = (
0
if get_global_server_args().disable_shared_experts_fusion
if global_server_args_dict["disable_shared_experts_fusion"]
else config.n_shared_experts
)

@@ -68,7 +68,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@@ -81,7 +81,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
self, architecture: str = "Glm4MoeForCausalLM"
):
self.num_fused_shared_experts = 0
if get_global_server_args().disable_shared_experts_fusion:
if global_server_args_dict["disable_shared_experts_fusion"]:
return

# Only Deepseek V3/R1 can use shared experts fusion optimization now.
@@ -97,7 +97,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."

if disable_reason is not None:
get_global_server_args().disable_shared_experts_fusion = True
global_server_args_dict["disable_shared_experts_fusion"] = True
self.num_fused_shared_experts = 0
log_info_on_rank0(
logger,
@@ -63,13 +63,13 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.utils import (
|
||||
create_fused_set_kv_buffer_arg,
|
||||
enable_fused_set_kv_buffer,
|
||||
)
|
||||
from sglang.srt.server_args import get_global_server_args
|
||||
from sglang.srt.utils import (
|
||||
LazyValue,
|
||||
add_prefix,
|
||||
@@ -138,7 +138,7 @@ class GptOssSparseMoeBlock(nn.Module):
|
||||
}
|
||||
self.experts = experts_type(
|
||||
num_experts=config.num_local_experts
|
||||
+ get_global_server_args().ep_num_redundant_experts,
|
||||
+ global_server_args_dict["ep_num_redundant_experts"],
|
||||
top_k=config.num_experts_per_tok,
|
||||
layer_id=layer_id,
|
||||
hidden_size=config.hidden_size,
|
||||
@@ -259,7 +259,7 @@ class GptOssAttention(nn.Module):
|
||||
|
||||
# Choose dtype of sinks based on attention backend: trtllm_mha requires float32,
|
||||
# others can use bfloat16
|
||||
attn_backend = get_global_server_args().attention_backend
|
||||
attn_backend = global_server_args_dict.get("attention_backend")
|
||||
sinks_dtype = torch.float32 if attn_backend == "trtllm_mha" else torch.bfloat16
|
||||
self.sinks = nn.Parameter(
|
||||
torch.empty(self.num_heads, dtype=sinks_dtype), requires_grad=False
|
||||
@@ -591,7 +591,7 @@ class GptOssForCausalLM(nn.Module):
|
||||
config.hidden_size,
|
||||
# quant_config=quant_config,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
|
||||
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.capture_aux_hidden_states = False
|
||||
|
||||
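
One detail in the gpt_oss hunk above: the reverted code reads attention_backend with
dict.get() rather than direct indexing, so an unpopulated key yields None instead of a
KeyError. A short illustrative sketch of the difference (not part of the diff):

    import torch
    from sglang.srt.managers.schedule_batch import global_server_args_dict

    # Direct indexing raises KeyError if the key was never populated.
    backend = global_server_args_dict["attention_backend"]

    # .get() returns None (or a supplied default) when the key is absent,
    # which the sinks-dtype selection above tolerates.
    backend = global_server_args_dict.get("attention_backend")
    sinks_dtype = torch.float32 if backend == "trtllm_mha" else torch.bfloat16
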
@@ -28,6 +28,7 @@ from torch import nn
from transformers import PretrainedConfig

from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
@@ -35,6 +36,7 @@ from sglang.srt.distributed import (
)
from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.elementwise import (
experts_combine_triton,
fused_dual_residual_rmsnorm,
fused_rmsnorm,
gelu_and_mul_triton,
@@ -62,10 +64,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.loader import DefaultModelLoader
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, dispose_tensor, dump_to_file

logger = logging.getLogger(__name__)
@@ -864,10 +866,10 @@ class Grok1ForCausalLM(nn.Module):

# Dump tensors for debugging
global debug_tensor_dump_output_folder, debug_tensor_dump_inject
debug_tensor_dump_output_folder = (
get_global_server_args().debug_tensor_dump_output_folder
)
debug_tensor_dump_inject = get_global_server_args().debug_tensor_dump_inject
debug_tensor_dump_output_folder = global_server_args_dict[
"debug_tensor_dump_output_folder"
]
debug_tensor_dump_inject = global_server_args_dict["debug_tensor_dump_inject"]
warnings.filterwarnings("ignore", category=FutureWarning)

if get_tensor_model_parallel_rank() == 0:

@@ -45,13 +45,13 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
kv_cache_scales_loader,
maybe_remap_kv_scale_name,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, make_layers
from sglang.utils import get_exception_traceback

@@ -433,7 +433,7 @@ class LlamaForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

@@ -32,10 +32,14 @@

import concurrent.futures
import logging
from typing import Iterable, Optional, Tuple
import os
from enum import IntEnum, auto
from typing import Any, Dict, Iterable, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm

from sglang.srt.configs import LongcatFlashConfig
from sglang.srt.distributed import (
@@ -81,10 +85,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_v2 import DeepseekV2AttentionMLA
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import (
BumpAllocator,
LazyValue,
@@ -591,7 +595,7 @@ class LongcatFlashForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)

@@ -31,9 +31,9 @@ from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
global_server_args_dict,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import is_cpu

_is_cpu = is_cpu()
@@ -448,7 +448,7 @@ class Llama4ForConditionalGeneration(nn.Module):
)

self.has_vision = (
self.has_vision_weights and get_global_server_args().enable_multimodal
self.has_vision_weights and global_server_args_dict["enable_multimodal"]
)

if self.has_vision:

@@ -64,10 +64,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.server_args import get_global_server_args
from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
from sglang.srt.utils import add_prefix, is_cuda, make_layers

@@ -156,7 +156,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
layer_id=self.layer_id,
top_k=config.num_experts_per_tok,
num_experts=config.num_experts
+ get_global_server_args().ep_num_redundant_experts,
+ global_server_args_dict["ep_num_redundant_experts"],
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
quant_config=quant_config,
@@ -192,7 +192,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
# TODO: we will support tp < ep in the future
self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = (
config.num_experts + get_global_server_args().ep_num_redundant_experts
config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
)
self.top_k = config.num_experts_per_tok

@@ -643,7 +643,7 @@ class Qwen2MoeForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
# For EAGLE3 support

@@ -54,6 +54,7 @@ from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding, get_rope
from sglang.srt.layers.utils import get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -63,7 +64,6 @@ from sglang.srt.models.utils import (
create_fused_set_kv_buffer_arg,
enable_fused_set_kv_buffer,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import (
add_prefix,
is_cuda,
@@ -104,7 +104,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):

self.experts = get_moe_impl_class(quant_config)(
num_experts=config.num_experts
+ get_global_server_args().ep_num_redundant_experts,
+ global_server_args_dict["ep_num_redundant_experts"],
top_k=config.num_experts_per_tok,
layer_id=layer_id,
hidden_size=config.hidden_size,
@@ -125,7 +125,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
# TODO: we will support tp < ep in the future
self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = (
config.num_experts + get_global_server_args().ep_num_redundant_experts
config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
)
self.top_k = config.num_experts_per_tok

@@ -693,7 +693,7 @@ class Qwen3MoeForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
self.capture_aux_hidden_states = False

@@ -39,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
@@ -46,7 +47,6 @@ from sglang.srt.model_loader.weight_utils import (
sharded_weight_loader,
)
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import (
LazyValue,
add_prefix,
@@ -905,7 +905,7 @@ class Qwen3NextForCausalLM(nn.Module):
quant_config=quant_config,
org_num_embeddings=config.vocab_size,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.lm_head = self.lm_head.float()
self.logits_processor = LogitsProcessor(config)

@@ -21,13 +21,14 @@ from torch import nn
from transformers import PretrainedConfig

from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
from sglang.srt.layers.layernorm import GemmaRMSNorm
from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.qwen3_moe import Qwen3MoeModel
from sglang.srt.models.qwen3_next import Qwen3NextForCausalLM, Qwen3NextModel
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix

logger = logging.getLogger(__name__)
@@ -68,7 +69,7 @@ class Qwen3NextForCausalLMMTP(Qwen3NextForCausalLM):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("model.shared_head.head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)

@@ -38,12 +38,20 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.utils import get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.mm_utils import general_mm_embed_routine
from sglang.srt.managers.schedule_batch import MultimodalDataItem
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import (
MultimodalDataItem,
MultimodalInputs,
global_server_args_dict,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen3_moe import Qwen3MoeModel
from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel
from sglang.srt.models.qwen3_vl import (
Qwen3_VisionTransformer,
Qwen3VLForConditionalGeneration,

@@ -57,6 +57,7 @@ from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
global_server_args_dict,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -299,7 +300,7 @@ class Step3TextDecoderLayer(nn.Module):
# self.n_shared_experts = 1
# self.num_fused_shared_experts = (
# 0
# if global_server_args.disable_shared_experts_fusion
# if global_server_args_dict["disable_shared_experts_fusion"]
# else self.n_shared_experts
# )
self.num_fused_shared_experts = 0
@@ -773,7 +774,7 @@ class Step3VLForConditionalGeneration(nn.Module):
# self.n_shared_experts = 1
# self.num_fused_shared_experts = (
# 0
# if global_server_args.disable_shared_experts_fusion
# if global_server_args_dict["disable_shared_experts_fusion"]
# else self.n_shared_experts
# )
self.num_fused_shared_experts = 0

@@ -2,6 +2,7 @@ from __future__ import annotations

import dataclasses
import logging
import threading
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

import torch
@@ -9,7 +10,6 @@ import torch
import sglang.srt.sampling.penaltylib as penaltylib
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
from sglang.srt.sampling.sampling_params import TOP_K_ALL
from sglang.srt.server_args import get_global_server_args

if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -66,10 +66,16 @@ class SamplingBatchInfo:
# Handle logit bias
logit_bias: Optional[torch.Tensor] = None

@classmethod
def _get_global_server_args_dict(cls):
from sglang.srt.managers.schedule_batch import global_server_args_dict

return global_server_args_dict

@classmethod
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
global_server_args = get_global_server_args()
enable_deterministic = global_server_args.enable_deterministic_inference
global_server_args_dict = cls._get_global_server_args_dict()
enable_deterministic = global_server_args_dict["enable_deterministic_inference"]

reqs = batch.reqs
device = batch.device
@@ -106,9 +112,10 @@ class SamplingBatchInfo:
logit_bias[i, int(key)] = value

# Check if any request has custom logit processor
has_custom_logit_processor = (
global_server_args.enable_custom_logit_processor
and any(r.custom_logit_processor for r in reqs) # check the flag first.
has_custom_logit_processor = global_server_args_dict[
"enable_custom_logit_processor"
] and any( # check the flag first.
r.custom_logit_processor for r in reqs
) # then check the requests.

if has_custom_logit_processor:
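
In the SamplingBatchInfo hunk above, the restored helper imports global_server_args_dict
inside the method body rather than at module scope, presumably to keep sampling_batch_info
from importing schedule_batch at import time. A minimal sketch of the pattern under that
assumption (illustrative only):

    class SamplingBatchInfo:
        @classmethod
        def _get_global_server_args_dict(cls):
            # Deferred import: resolve the dict only when it is first needed,
            # avoiding a module-level import cycle with schedule_batch.
            from sglang.srt.managers.schedule_batch import global_server_args_dict

            return global_server_args_dict
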
@@ -53,7 +53,6 @@ from sglang.utils import is_in_ci

logger = logging.getLogger(__name__)

# Define constants
LOAD_FORMAT_CHOICES = [
"auto",
@@ -3324,22 +3323,6 @@ class ServerArgs:
)

# NOTE: This is a global variable to hold the server args for scheduler.
_global_server_args: Optional[ServerArgs] = None

def set_global_server_args_for_scheduler(server_args: ServerArgs):
global _global_server_args
_global_server_args = server_args

def get_global_server_args() -> ServerArgs:
if _global_server_args is None:
raise ValueError("Global server args is not set yet!")

return _global_server_args

def prepare_server_args(argv: List[str]) -> ServerArgs:
"""
Prepare the server arguments from the command line arguments.
@@ -3374,8 +3357,8 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
raw_args = parser.parse_args(argv)

return ServerArgs.from_cli_args(raw_args)
server_args = ServerArgs.from_cli_args(raw_args)
return server_args

ZMQ_TCP_PORT_DELTA = 233
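
The server_args.py hunk above removes the module-level ServerArgs holder together with its
set/get helpers; after the revert, call sites read runtime options from the
global_server_args_dict mapping instead. A rough consumer-side sketch of the two schemes
(illustrative; the option name is taken from the sampling hunk earlier in this diff):

    # Removed pattern: typed singleton behind an explicit setter/getter.
    set_global_server_args_for_scheduler(server_args)
    deterministic = get_global_server_args().enable_deterministic_inference

    # Restored pattern: plain dict keyed by option name.
    deterministic = global_server_args_dict["enable_deterministic_inference"]
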
@@ -6,6 +6,7 @@ import torch
from sglang.srt.layers.moe import get_moe_runner_backend
from sglang.srt.layers.moe.utils import is_sbo_enabled
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import get_int_env_var

@@ -11,7 +11,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import apply_custom_logit_processor
from sglang.srt.managers.overlap_utils import FutureIndices
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.common import (
alloc_paged_token_slots_extend,
@@ -19,7 +19,6 @@ from sglang.srt.mem_cache.common import (
get_last_loc,
)
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.eagle_info_v2 import (
EagleDraftInputV2Mixin,
EagleVerifyInputV2Mixin,
@@ -333,8 +332,12 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
uniform_samples_for_final_sampling=coins_for_final_sampling,
target_probs=target_probs,
draft_probs=draft_probs,
threshold_single=get_global_server_args().speculative_accept_threshold_single,
threshold_acc=get_global_server_args().speculative_accept_threshold_acc,
threshold_single=global_server_args_dict[
"speculative_accept_threshold_single"
],
threshold_acc=global_server_args_dict[
"speculative_accept_threshold_acc"
],
deterministic=True,
)

@@ -11,6 +11,7 @@ import triton.language as tl

from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.scheduler import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
@@ -18,7 +19,6 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardMode,
)
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.build_eagle_tree import TreeMaskMode
from sglang.srt.speculative.spec_utils import (
SIMULATE_ACC_LEN,
@@ -265,8 +265,12 @@ class EagleVerifyInputV2Mixin:
uniform_samples_for_final_sampling=coins_for_final_sampling,
target_probs=target_probs,
draft_probs=draft_probs,
threshold_single=get_global_server_args().speculative_accept_threshold_single,
threshold_acc=get_global_server_args().speculative_accept_threshold_acc,
threshold_single=global_server_args_dict[
"speculative_accept_threshold_single"
],
threshold_acc=global_server_args_dict[
"speculative_accept_threshold_acc"
],
deterministic=True,
)

@@ -14,7 +14,7 @@ from sglang.srt.distributed import (
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.managers.scheduler import GenerationBatchResult
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.mem_cache.common import (
@@ -27,7 +27,7 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
ForwardMode,
)
from sglang.srt.server_args import ServerArgs, get_global_server_args
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
EAGLEDraftCudaGraphRunner,
@@ -261,7 +261,7 @@ class EAGLEWorker(TpModelWorker):
)

def _create_flashinfer_decode_backend(self):
if not get_global_server_args().use_mla_backend:
if not global_server_args_dict["use_mla_backend"]:
from sglang.srt.layers.attention.flashinfer_backend import (
FlashInferMultiStepDraftBackend,
)
@@ -325,7 +325,7 @@ class EAGLEWorker(TpModelWorker):
)

def _create_trtllm_mla_decode_backend(self):
if not get_global_server_args().use_mla_backend:
if not global_server_args_dict["use_mla_backend"]:
raise ValueError(
"trtllm_mla backend requires MLA model (use_mla_backend=True)."
)
@@ -340,7 +340,7 @@ class EAGLEWorker(TpModelWorker):
)

def _create_flashinfer_prefill_backend(self):
if not get_global_server_args().use_mla_backend:
if not global_server_args_dict["use_mla_backend"]:
from sglang.srt.layers.attention.flashinfer_backend import (
FlashInferAttnBackend,
)
@@ -376,7 +376,7 @@ class EAGLEWorker(TpModelWorker):
return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=False)

def _create_trtllm_mla_prefill_backend(self):
if not get_global_server_args().use_mla_backend:
if not global_server_args_dict["use_mla_backend"]:
raise ValueError(
"trtllm_mla backend requires MLA model (use_mla_backend=True)."
)

@@ -7,8 +7,6 @@ from typing import Optional, Tuple
import torch
import triton

from sglang.srt.server_args import get_global_server_args

logger = logging.getLogger(__name__)

from dataclasses import dataclass
@@ -18,7 +16,7 @@ import torch.nn.functional as F
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import apply_custom_logit_processor
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.mem_cache.common import (
alloc_paged_token_slots_extend,
alloc_token_slots,
@@ -352,8 +350,10 @@ class NgramVerifyInput(SpecInput):
uniform_samples_for_final_sampling=coins_for_final_sampling,
target_probs=target_probs,
draft_probs=draft_probs,
threshold_single=get_global_server_args().speculative_accept_threshold_single,
threshold_acc=get_global_server_args().speculative_accept_threshold_acc,
threshold_single=global_server_args_dict[
"speculative_accept_threshold_single"
],
threshold_acc=global_server_args_dict["speculative_accept_threshold_acc"],
deterministic=True,
)

@@ -22,7 +22,7 @@ from sglang.srt.layers.moe import (
)
from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
ForwardMode,
@@ -30,7 +30,6 @@ from sglang.srt.model_executor.forward_batch_info import (
)
from sglang.srt.operations import execute_operations, execute_overlapped_operations
from sglang.srt.operations_strategy import OperationsStrategy
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import BumpAllocator, empty_context, get_bool_env_var, is_hip

@@ -154,7 +153,7 @@ def _update_device_and_sum_field_from_cpu_field(
cpu_value
if isinstance(cpu_value, torch.Tensor)
else torch.tensor(cpu_value, dtype=old_device_value.dtype)
).to(device=get_global_server_args().device, non_blocking=True)
).to(device=global_server_args_dict["device"], non_blocking=True)
setattr(batch, device_field, new_device_value)

if sum_field is not None:
@@ -583,7 +582,7 @@ class TboForwardBatchPreparer:
sum_field=None,
)
_, child_b.extend_start_loc = compute_position(
get_global_server_args().attention_backend,
global_server_args_dict["attention_backend"],
child_b.extend_prefix_lens,
child_b.extend_seq_lens,
child_b.extend_num_tokens,
@@ -688,7 +687,7 @@ class TboForwardBatchPreparer:

# TODO improve, e.g. unify w/ `init_raw`
if (
get_global_server_args().moe_dense_tp_size == 1
global_server_args_dict["moe_dense_tp_size"] == 1
and batch.global_dp_buffer_len is not None
):
sum_len = end_token_index - start_token_index
@@ -756,7 +755,7 @@ class TboForwardBatchPreparer:
value_a = min(tbo_split_token_index, num_token_non_padded)
value_b = max(0, num_token_non_padded - tbo_split_token_index)
return torch.tensor([value_a, value_b], dtype=torch.int32).to(
device=get_global_server_args().device, non_blocking=True
device=global_server_args_dict["device"], non_blocking=True
)

@classmethod