Deprecate global_server_args_dict (#11331)
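This change deprecates the module-level `global_server_args_dict` exported by `sglang.srt.managers.schedule_batch` in favor of an explicit setter/getter pair in `sglang.srt.server_args`: a `ServerArgs` instance is installed once with `set_global_server_args_for_scheduler(...)` and read back through `get_global_server_args()`. A minimal before/after sketch, using only names that appear in the diff below (the `__post_init__` stub in the second hunk is what lets tests get away with a placeholder `model_path`):

```python
# Old pattern (deprecated): mutate a module-level dict in place.
from sglang.srt.managers.schedule_batch import global_server_args_dict

global_server_args_dict["enable_fp32_lm_head"] = True

# New pattern: install a ServerArgs instance once, then read it back
# through the accessor wherever the dict used to be consulted.
from sglang.srt.server_args import (
    ServerArgs,
    get_global_server_args,
    set_global_server_args_for_scheduler,
)

set_global_server_args_for_scheduler(ServerArgs(model_path="/path/to/model"))
get_global_server_args().enable_fp32_lm_head = True
```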
@@ -7,7 +7,11 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.server_args import (
+    ServerArgs,
+    get_global_server_args,
+    set_global_server_args_for_scheduler,
+)
 
 
 class LMHeadStub(nn.Module):
@@ -32,8 +36,10 @@ class TestLMHeadFP32(unittest.TestCase):
             raise unittest.SkipTest("needs CUDA GPU")
 
     def _make_logprocessor(self, vocab_size, enable_fp32):
-        global_server_args_dict["enable_dp_lm_head"] = False
-        global_server_args_dict["enable_fp32_lm_head"] = enable_fp32
+        ServerArgs.__post_init__ = lambda self: None  # disable validation
+        set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
+        get_global_server_args().enable_dp_lm_head = False
+        get_global_server_args().enable_fp32_lm_head = enable_fp32
         cfg = SimpleNamespace(vocab_size=vocab_size, final_logit_softcapping=None)
         return LogitsProcessor(cfg, skip_all_gather=True, logit_scale=None)
 
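The `ServerArgs.__post_init__ = lambda self: None` line above is what lets the helper construct a `ServerArgs` around a placeholder path: `ServerArgs` is a dataclass, so its validation runs in `__post_init__` at construction time, and stubbing that hook out skips it. A scoped variant of the same trick, sketched with `unittest.mock` so the patch is undone after the test (the test class and method names here are illustrative, not part of this commit):

```python
import unittest
from unittest import mock

from sglang.srt.server_args import ServerArgs


class ExampleServerArgsTest(unittest.TestCase):
    def test_placeholder_model_path(self):
        # Patch out the dataclass __post_init__ hook (where validation
        # lives) only for the duration of this block.
        with mock.patch.object(ServerArgs, "__post_init__", lambda self: None):
            args = ServerArgs(model_path="dummy")
        self.assertEqual(args.model_path, "dummy")
```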
@@ -4,6 +4,7 @@ import unittest
 import requests
 import torch
 
+from sglang.srt.server_args import set_global_server_args_for_scheduler
 from sglang.srt.utils import kill_process_tree
 from sglang.test.test_utils import (
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
@@ -16,17 +17,15 @@ from sglang.test.test_utils import (
 def check_quant_method(model_path: str, use_marlin_kernel: bool):
     from sglang.srt.configs.device_config import DeviceConfig
     from sglang.srt.configs.load_config import LoadConfig
-    from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+    from sglang.srt.configs.model_config import ModelConfig
     from sglang.srt.distributed import (
         get_tp_group,
         init_distributed_environment,
         initialize_model_parallel,
         set_custom_all_reduce,
     )
     from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
     from sglang.srt.layers.quantization.utils import get_dynamic_override
     from sglang.srt.model_loader import get_model
-    from sglang.srt.server_args import PortArgs, ServerArgs
+    from sglang.srt.server_args import ServerArgs
 
     try:
         init_distributed_environment(
@@ -43,6 +42,7 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool):
         pass
 
     server_args = ServerArgs(model_path=model_path, dtype=torch.float16)
+    set_global_server_args_for_scheduler(server_args)
     model_config = ModelConfig.from_server_args(server_args)
 
     load_config = LoadConfig()
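The ordering in this last hunk appears to be the point of the refactor: the global args must be installed before any component that reads them through `get_global_server_args()` is constructed. A condensed sketch of the call sequence, using only functions that appear above (the helper name `build_model_config` is illustrative):

```python
import torch

from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler


def build_model_config(model_path: str) -> ModelConfig:
    server_args = ServerArgs(model_path=model_path, dtype=torch.float16)
    # Install the args before building configs or loading the model, so
    # downstream code that calls get_global_server_args() sees them.
    set_global_server_args_for_scheduler(server_args)
    return ModelConfig.from_server_args(server_args)
```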