Move global_server_args_dict (#642)
This commit moves global_server_args_dict out of the controller internals (infer_batch.py, which defined it, and model_runner.py, which populated it) into sglang/srt/server.py, where it is now created at module level and filled from ServerArgs when launch_server() runs.
[sglang/srt/layers/radix_attention.py (file inferred from the hunk contents)]

@@ -7,8 +7,8 @@ from torch import nn
 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.managers.controller.infer_batch import global_server_args_dict
 from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
+from sglang.srt.server import global_server_args_dict


 class RadixAttention(nn.Module):
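For context, a minimal sketch of what a consumer of the relocated dict looks like after this change. Only the import path and the "disable_flashinfer" key come from the diff; the use_flashinfer helper is hypothetical.

    # Hypothetical consumer sketch, not part of the commit: layers now read
    # server-level flags from sglang.srt.server instead of reaching into the
    # controller's infer_batch module.
    from sglang.srt.server import global_server_args_dict

    def use_flashinfer() -> bool:
        # .get() with a default keeps the read safe even before the dict
        # has been populated by _set_global_server_args().
        return not global_server_args_dict.get("disable_flashinfer", False)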
[sglang/srt/layers/token_attention.py (file inferred from the hunk contents)]

@@ -5,7 +5,7 @@ import torch
 import triton
 import triton.language as tl

-from sglang.srt.managers.controller.model_runner import global_server_args_dict
+from sglang.srt.server import global_server_args_dict
 from sglang.srt.utils import wrap_kernel_launcher

 if global_server_args_dict.get("attention_reduce_in_fp32", False):
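The module-level check above picks an accumulation dtype once, at import time. A hedged sketch of that pattern, assuming the kernels read the flag into dtype constants; the REDUCE_* names are assumptions, only the .get(...) line appears in the diff.

    # Sketch of import-time dtype selection gated by the flag above.
    # REDUCE_TRITON_TYPE / REDUCE_TORCH_TYPE are assumed names.
    import torch
    import triton.language as tl

    from sglang.srt.server import global_server_args_dict

    if global_server_args_dict.get("attention_reduce_in_fp32", False):
        REDUCE_TRITON_TYPE = tl.float32   # accumulate attention in fp32 for stability
        REDUCE_TORCH_TYPE = torch.float32
    else:
        REDUCE_TRITON_TYPE = tl.float16   # match the model's half precision
        REDUCE_TORCH_TYPE = torch.float16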
[sglang/srt/managers/controller/infer_batch.py]

@@ -16,9 +16,6 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

-# Store some global server args
-global_server_args_dict = {}
-

 class ForwardMode(IntEnum):
     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
[sglang/srt/managers/controller/model_runner.py]

@@ -20,12 +20,7 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry

 from sglang.global_config import global_config
-from sglang.srt.managers.controller.infer_batch import (
-    Batch,
-    ForwardMode,
-    InputMetadata,
-    global_server_args_dict,
-)
+from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -91,12 +86,6 @@ class ModelRunner:
             "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
         )

-        # Set some global args
-        global_server_args_dict["disable_flashinfer"] = server_args.disable_flashinfer
-        global_server_args_dict["attention_reduce_in_fp32"] = (
-            server_args.attention_reduce_in_fp32
-        )
-
         # Load the model and create memory pool
         self.load_model()
         self.init_memory_pool(total_gpu_memory)
[sglang/srt/server.py]

@@ -64,6 +64,9 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 app = FastAPI()
 tokenizer_manager = None

+# Put some args here for easy access
+global_server_args_dict = {}
+

 @app.get("/health")
 async def health() -> Response:
@@ -135,6 +138,14 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)


+def _set_global_server_args(server_args: ServerArgs):
+    global global_server_args_dict
+    global_server_args_dict = {
+        "disable_flashinfer": server_args.disable_flashinfer,
+        "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+    }
+
+
 def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args=None):
     global tokenizer_manager

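One Python subtlety in this pattern: _set_global_server_args rebinds the module-level name to a brand-new dict, and a rebind is invisible to any name bound earlier via "from sglang.srt.server import global_server_args_dict". A self-contained toy illustration (not sglang code):

    # Toy illustration of the from-import caveat; toy_server stands in for
    # sglang.srt.server.
    import types

    toy_server = types.ModuleType("toy_server")
    toy_server.global_server_args_dict = {}

    # Equivalent to: from toy_server import global_server_args_dict
    snapshot = toy_server.global_server_args_dict

    # Rebind, as _set_global_server_args() does:
    toy_server.global_server_args_dict = {"disable_flashinfer": True}

    print(snapshot.get("disable_flashinfer"))                            # None: still the old {}
    print(toy_server.global_server_args_dict.get("disable_flashinfer"))  # True: sees the rebind

This is why the consumer reads shown earlier use .get(key, default): within a single process, a module that imports the dict before launch_server() runs keeps referencing the initial empty dict.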
@@ -163,6 +174,8 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     # TODO: replace this with huggingface transformers template
     load_chat_template_for_openai_api(server_args.chat_template)

+    _set_global_server_args(server_args)
+
     # Allocate ports
     assert server_args.tp_size % server_args.nnodes == 0
     tp_size_local = server_args.tp_size // server_args.nnodes
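A hedged end-to-end sketch of the new flow. The field names, _set_global_server_args, and the launch_server signature come from the diff; constructing ServerArgs with model_path is an assumption about fields not shown here.

    # Hedged usage sketch, not a verbatim sglang example.
    from sglang.srt.server import launch_server
    from sglang.srt.server_args import ServerArgs

    server_args = ServerArgs(model_path="meta-llama/Llama-2-7b-hf")  # model_path is assumed
    server_args.disable_flashinfer = True
    server_args.attention_reduce_in_fp32 = False

    # launch_server(server_args, None) now calls _set_global_server_args(server_args)
    # right after loading the chat template, so global_server_args_dict is populated
    # before ports are allocated and the workers start.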