model_rpc style improvement (#293)
This commit is contained in:
@@ -9,8 +9,9 @@ from sglang.lang.interpreter import StreamExecutor
|
|||||||
from sglang.lang.ir import SglSamplingParams
|
from sglang.lang.ir import SglSamplingParams
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import openai
|
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
|
||||||
|
import openai
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
openai = tiktoken = e
|
openai = tiktoken = e
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import warnings
|
|||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import rpyc
|
import rpyc
|
||||||
import torch
|
import torch
|
||||||
from rpyc.utils.classic import obtain
|
from rpyc.utils.classic import obtain
|
||||||
@@ -36,8 +35,8 @@ from vllm.logger import _default_handler as vllm_default_handler
|
|||||||
logger = logging.getLogger("model_rpc")
|
logger = logging.getLogger("model_rpc")
|
||||||
|
|
||||||
|
|
||||||
class ModelRpcServer(rpyc.Service):
|
class ModelRpcServer:
|
||||||
def exposed_init_model(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
@@ -608,14 +607,19 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
batch.reqs = []
|
batch.reqs = []
|
||||||
|
|
||||||
|
|
||||||
|
class ModelRpcService(rpyc.Service):
|
||||||
|
exposed_ModelRpcServer = ModelRpcServer
|
||||||
|
|
||||||
|
|
||||||
class ModelRpcClient:
|
class ModelRpcClient:
|
||||||
def __init__(self, server_args: ServerArgs, port_args: PortArgs):
|
def __init__(self, server_args: ServerArgs, port_args: PortArgs):
|
||||||
tp_size = server_args.tp_size
|
tp_size = server_args.tp_size
|
||||||
|
|
||||||
if tp_size == 1:
|
if tp_size == 1:
|
||||||
# Init model
|
# Init model
|
||||||
self.model_server = ModelRpcServer()
|
self.model_server = ModelRpcService().exposed_ModelRpcServer(
|
||||||
self.model_server.exposed_init_model(0, server_args, port_args)
|
0, server_args, port_args
|
||||||
|
)
|
||||||
|
|
||||||
# Wrap functions
|
# Wrap functions
|
||||||
def async_wrap(f):
|
def async_wrap(f):
|
||||||
@@ -629,14 +633,16 @@ class ModelRpcClient:
|
|||||||
with ThreadPoolExecutor(tp_size) as executor:
|
with ThreadPoolExecutor(tp_size) as executor:
|
||||||
# Launch model processes
|
# Launch model processes
|
||||||
rets = executor.map(start_model_process, port_args.model_rpc_ports)
|
rets = executor.map(start_model_process, port_args.model_rpc_ports)
|
||||||
self.model_servers = [x[0] for x in rets]
|
self.remote_services = [x[0] for x in rets]
|
||||||
self.procs = [x[1] for x in rets]
|
self.procs = [x[1] for x in rets]
|
||||||
|
|
||||||
# Init model
|
# Init model
|
||||||
def init_model(i):
|
def init_model(i):
|
||||||
return self.model_servers[i].init_model(i, server_args, port_args)
|
return self.remote_services[i].ModelRpcServer(
|
||||||
|
i, server_args, port_args
|
||||||
|
)
|
||||||
|
|
||||||
rets = [obtain(x) for x in executor.map(init_model, range(tp_size))]
|
self.model_servers = executor.map(init_model, range(tp_size))
|
||||||
|
|
||||||
# Wrap functions
|
# Wrap functions
|
||||||
def async_wrap(func_name):
|
def async_wrap(func_name):
|
||||||
@@ -654,7 +660,7 @@ class ModelRpcClient:
|
|||||||
|
|
||||||
def _init_service(port):
|
def _init_service(port):
|
||||||
t = ThreadedServer(
|
t = ThreadedServer(
|
||||||
ModelRpcServer(),
|
ModelRpcService(),
|
||||||
port=port,
|
port=port,
|
||||||
protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
|
protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
import importlib
|
import importlib
|
||||||
import logging
|
import importlib.resources
|
||||||
import inspect
|
import inspect
|
||||||
|
import logging
|
||||||
|
import pkgutil
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pathlib import Path
|
|
||||||
import importlib.resources
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -18,11 +18,6 @@ from vllm.model_executor.layers.quantization.marlin import MarlinConfig
|
|||||||
from vllm.model_executor.model_loader import _set_default_torch_dtype
|
from vllm.model_executor.model_loader import _set_default_torch_dtype
|
||||||
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
|
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
|
||||||
|
|
||||||
import importlib
|
|
||||||
import pkgutil
|
|
||||||
|
|
||||||
import sglang
|
|
||||||
|
|
||||||
QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig, "marlin": MarlinConfig}
|
QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig, "marlin": MarlinConfig}
|
||||||
|
|
||||||
logger = logging.getLogger("model_runner")
|
logger = logging.getLogger("model_runner")
|
||||||
@@ -37,7 +32,7 @@ def import_model_classes():
|
|||||||
model_arch_name_to_cls = {}
|
model_arch_name_to_cls = {}
|
||||||
package_name = "sglang.srt.models"
|
package_name = "sglang.srt.models"
|
||||||
package = importlib.import_module(package_name)
|
package = importlib.import_module(package_name)
|
||||||
for finder, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + '.'):
|
for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
|
||||||
if not ispkg:
|
if not ispkg:
|
||||||
module = importlib.import_module(name)
|
module = importlib.import_module(name)
|
||||||
if hasattr(module, "EntryClass"):
|
if hasattr(module, "EntryClass"):
|
||||||
@@ -144,9 +139,12 @@ class InputMetadata:
|
|||||||
|
|
||||||
# flashinfer >= 0.0.3
|
# flashinfer >= 0.0.3
|
||||||
# FIXME: Drop this when flashinfer updates to 0.0.4
|
# FIXME: Drop this when flashinfer updates to 0.0.4
|
||||||
if len(inspect.signature(self.prefill_wrapper.begin_forward).parameters) == 7:
|
if (
|
||||||
|
len(inspect.signature(self.prefill_wrapper.begin_forward).parameters)
|
||||||
|
== 7
|
||||||
|
):
|
||||||
args.append(self.model_runner.model_config.head_dim)
|
args.append(self.model_runner.model_config.head_dim)
|
||||||
|
|
||||||
self.prefill_wrapper.begin_forward(*args)
|
self.prefill_wrapper.begin_forward(*args)
|
||||||
else:
|
else:
|
||||||
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
||||||
@@ -307,9 +305,11 @@ class ModelRunner:
|
|||||||
hf_quant_method = hf_quant_config["quant_method"]
|
hf_quant_method = hf_quant_config["quant_method"]
|
||||||
|
|
||||||
# compat: autogptq uses is_marlin_format within quant config
|
# compat: autogptq uses is_marlin_format within quant config
|
||||||
if (hf_quant_method == "gptq"
|
if (
|
||||||
and "is_marlin_format" in hf_quant_config
|
hf_quant_method == "gptq"
|
||||||
and hf_quant_config["is_marlin_format"]):
|
and "is_marlin_format" in hf_quant_config
|
||||||
|
and hf_quant_config["is_marlin_format"]
|
||||||
|
):
|
||||||
hf_quant_method = "marlin"
|
hf_quant_method = "marlin"
|
||||||
quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_method)
|
quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_method)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user