model_rpc style improvement (#293)
@@ -9,8 +9,9 @@ from sglang.lang.interpreter import StreamExecutor
 from sglang.lang.ir import SglSamplingParams
 
 try:
-    import openai
     import tiktoken
+
+    import openai
 except ImportError as e:
     openai = tiktoken = e
 
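Aside (not part of the diff): `openai = tiktoken = e` is the deferred-import-error pattern — the module imports cleanly without the optional dependencies, and the failure surfaces only when the backend is actually used. A minimal sketch of the same idea; the `get_client` helper and the `openai.OpenAI()` call are illustrative assumptions, not code from this commit:

import sys

try:
    import tiktoken

    import openai
except ImportError as e:
    # Bind the exception to the module names so top-level import succeeds;
    # any real use of the missing package can then raise a clear error.
    openai = tiktoken = e

def get_client():
    if isinstance(openai, Exception):
        raise RuntimeError("openai is not installed") from openai
    return openai.OpenAI()  # assumes the openai>=1.0 client API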
@@ -6,7 +6,6 @@ import warnings
 from concurrent.futures import ThreadPoolExecutor
 from typing import List
 
-import numpy as np
 import rpyc
 import torch
 from rpyc.utils.classic import obtain
@@ -36,8 +35,8 @@ from vllm.logger import _default_handler as vllm_default_handler
 logger = logging.getLogger("model_rpc")
 
 
-class ModelRpcServer(rpyc.Service):
-    def exposed_init_model(
+class ModelRpcServer:
+    def __init__(
         self,
         tp_rank: int,
         server_args: ServerArgs,
@@ -608,14 +607,19 @@ class ModelRpcServer(rpyc.Service):
         batch.reqs = []
 
 
+class ModelRpcService(rpyc.Service):
+    exposed_ModelRpcServer = ModelRpcServer
+
+
 class ModelRpcClient:
     def __init__(self, server_args: ServerArgs, port_args: PortArgs):
         tp_size = server_args.tp_size
 
         if tp_size == 1:
             # Init model
-            self.model_server = ModelRpcServer()
-            self.model_server.exposed_init_model(0, server_args, port_args)
+            self.model_server = ModelRpcService().exposed_ModelRpcServer(
+                0, server_args, port_args
+            )
 
             # Wrap functions
             def async_wrap(f):
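Aside on the pattern: in rpyc, a `Service` attribute named `exposed_X` is visible to clients as `root.X`, so exposing the class itself lets the local (`tp_size == 1`) path and the remote path construct a `ModelRpcServer` through the same call shape. A minimal standalone sketch — the `Worker`/`WorkerService` names are made up; only the rpyc mechanics are real:

import rpyc

class Worker:                       # plain class, like ModelRpcServer after this change
    def __init__(self, rank: int):
        self.rank = rank

    def step(self):
        return f"worker {self.rank} stepped"

class WorkerService(rpyc.Service):
    # rpyc strips the "exposed_" prefix: remote clients see conn.root.Worker
    exposed_Worker = Worker

# Local path, mirroring the tp_size == 1 branch above:
worker = WorkerService().exposed_Worker(0)
print(worker.step())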
@@ -629,14 +633,16 @@ class ModelRpcClient:
             with ThreadPoolExecutor(tp_size) as executor:
                 # Launch model processes
                 rets = executor.map(start_model_process, port_args.model_rpc_ports)
-                self.model_servers = [x[0] for x in rets]
+                self.remote_services = [x[0] for x in rets]
                 self.procs = [x[1] for x in rets]
 
                 # Init model
                 def init_model(i):
-                    return self.model_servers[i].init_model(i, server_args, port_args)
+                    return self.remote_services[i].ModelRpcServer(
+                        i, server_args, port_args
+                    )
 
-                rets = [obtain(x) for x in executor.map(init_model, range(tp_size))]
+                self.model_servers = executor.map(init_model, range(tp_size))
 
             # Wrap functions
             def async_wrap(func_name):
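The body of `async_wrap` is not shown in this diff. The usual rpyc tool for it is `rpyc.async_`, which turns a netref method into one returning an `AsyncResult`. A plausible shape, purely as a hedged sketch (the real helper may differ):

import asyncio
import rpyc

def async_wrap(remote_workers, func_name):
    # Fan one call out to the same method on every tensor-parallel rank.
    fs = [rpyc.async_(getattr(w, func_name)) for w in remote_workers]

    async def _func(*args, **kwargs):
        results = [f(*args, **kwargs) for f in fs]  # rpyc AsyncResult objects
        while not all(r.ready for r in results):
            await asyncio.sleep(0)                  # yield until all ranks finish
        return [r.value for r in results]

    return _func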
@@ -654,7 +660,7 @@
 
 def _init_service(port):
     t = ThreadedServer(
-        ModelRpcServer(),
+        ModelRpcService(),
         port=port,
         protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
    )
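For completeness: the server now registers the `Service` (not the worker class) with `ThreadedServer`, and a client connects and instantiates the worker remotely. A sketch under the same made-up names as above, with a hypothetical port:

import rpyc
from rpyc.utils.server import ThreadedServer

# Server side: serve the Service; each connection gets conn.root.Worker.
t = ThreadedServer(
    WorkerService(),
    port=18080,  # hypothetical port
    protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
)
# t.start()  # blocks; run in a separate process, as start_model_process does

# Client side: construct the worker over the connection; the result is a netref.
conn = rpyc.connect("localhost", 18080, config={"allow_pickle": True})
remote_worker = conn.root.Worker(0)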
@@ -1,10 +1,10 @@
 import importlib
-import logging
+import importlib.resources
+import inspect
+import logging
+import pkgutil
 from dataclasses import dataclass
 from functools import lru_cache
 from pathlib import Path
-import importlib.resources
 
 import numpy as np
 import torch
@@ -18,11 +18,6 @@ from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.model_loader import _set_default_torch_dtype
 from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
 
-import importlib
-import pkgutil
-
-import sglang
-
 QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig, "marlin": MarlinConfig}
 
 logger = logging.getLogger("model_runner")
@@ -37,7 +32,7 @@ def import_model_classes():
     model_arch_name_to_cls = {}
     package_name = "sglang.srt.models"
     package = importlib.import_module(package_name)
-    for finder, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + '.'):
+    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
         if not ispkg:
             module = importlib.import_module(name)
             if hasattr(module, "EntryClass"):
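This loop is the standard pkgutil plugin-discovery idiom: walk the package's modules and collect an `EntryClass` attribute from each. One plausible self-contained completion of the function — the keying by `EntryClass.__name__` is an assumption, since the diff cuts off before the dict insertion:

import importlib
import pkgutil

def import_model_classes(package_name="sglang.srt.models"):
    # Map architecture name -> model class for every module exporting EntryClass.
    model_arch_name_to_cls = {}
    package = importlib.import_module(package_name)
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if not ispkg:
            module = importlib.import_module(name)
            if hasattr(module, "EntryClass"):
                cls = module.EntryClass
                model_arch_name_to_cls[cls.__name__] = cls
    return model_arch_name_to_cls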
@@ -144,9 +139,12 @@ class InputMetadata:
 
             # flashinfer >= 0.0.3
             # FIXME: Drop this when flashinfer updates to 0.0.4
-            if len(inspect.signature(self.prefill_wrapper.begin_forward).parameters) == 7:
+            if (
+                len(inspect.signature(self.prefill_wrapper.begin_forward).parameters)
+                == 7
+            ):
                 args.append(self.model_runner.model_config.head_dim)
 
             self.prefill_wrapper.begin_forward(*args)
         else:
             self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
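The `inspect.signature` check is an arity-based compatibility shim: it counts the parameters of the installed flashinfer's `begin_forward` and only appends `head_dim` when the newer 7-parameter signature is present. The same technique in isolation, with toy stand-ins rather than the flashinfer API:

import inspect

def begin_forward_v3(a, b, c, d, e, f, head_dim):  # toy stand-in, 7 parameters
    return head_dim

def begin_forward_v2(a, b, c, d, e, f):            # toy stand-in, 6 parameters
    return None

def call_compat(fn, args, head_dim):
    # Pass the extra argument only if the installed version expects it.
    if len(inspect.signature(fn).parameters) == 7:
        args = args + [head_dim]
    return fn(*args)

call_compat(begin_forward_v3, [1, 2, 3, 4, 5, 6], head_dim=128)
call_compat(begin_forward_v2, [1, 2, 3, 4, 5, 6], head_dim=128)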
@@ -307,9 +305,11 @@ class ModelRunner:
             hf_quant_method = hf_quant_config["quant_method"]
 
             # compat: autogptq uses is_marlin_format within quant config
-            if (hf_quant_method == "gptq"
-                    and "is_marlin_format" in hf_quant_config
-                    and hf_quant_config["is_marlin_format"]):
+            if (
+                hf_quant_method == "gptq"
+                and "is_marlin_format" in hf_quant_config
+                and hf_quant_config["is_marlin_format"]
+            ):
                 hf_quant_method = "marlin"
             quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_method)
 
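The reformatted block dispatches on `quant_method` from the HF quantization config, with a compatibility override because AutoGPTQ marks Marlin-serialized checkpoints via `is_marlin_format` inside a "gptq" config. A reduced sketch of that dispatch, with the vllm config classes stubbed out and the key/containment check collapsed into `dict.get`:

# Stubbed mapping; the real one binds vllm's AWQConfig/GPTQConfig/MarlinConfig.
QUANTIONCONFIG_MAPPING = {"awq": object, "gptq": object, "marlin": object}

def resolve_quant_config_class(hf_quant_config: dict):
    hf_quant_method = hf_quant_config["quant_method"]
    # compat: autogptq stores is_marlin_format inside a "gptq" quant config
    if hf_quant_method == "gptq" and hf_quant_config.get("is_marlin_format"):
        hf_quant_method = "marlin"
    quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_method)
    if quant_config_class is None:
        raise ValueError(f"Unsupported quantization method: {hf_quant_method}")
    return quant_config_class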