model_rpc style improvement (#293)

This commit is contained in:
Liangsheng Yin
2024-03-24 15:41:24 +08:00
committed by GitHub
parent 64ee9c030e
commit 7523541962
3 changed files with 31 additions and 24 deletions

View File

@@ -9,8 +9,9 @@ from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams
try:
import openai
import tiktoken import tiktoken
import openai
except ImportError as e:
openai = tiktoken = e

View File

@@ -6,7 +6,6 @@ import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import List
import numpy as np
import rpyc
import torch
from rpyc.utils.classic import obtain
@@ -36,8 +35,8 @@ from vllm.logger import _default_handler as vllm_default_handler
logger = logging.getLogger("model_rpc")
class ModelRpcServer(rpyc.Service): class ModelRpcServer:
def exposed_init_model( def __init__(
self,
tp_rank: int,
server_args: ServerArgs,
@@ -608,14 +607,19 @@ class ModelRpcServer(rpyc.Service):
batch.reqs = []
class ModelRpcService(rpyc.Service):
exposed_ModelRpcServer = ModelRpcServer
class ModelRpcClient:
def __init__(self, server_args: ServerArgs, port_args: PortArgs):
tp_size = server_args.tp_size
if tp_size == 1:
# Init model
self.model_server = ModelRpcServer() self.model_server = ModelRpcService().exposed_ModelRpcServer(
self.model_server.exposed_init_model(0, server_args, port_args) 0, server_args, port_args
)
# Wrap functions
def async_wrap(f):
@@ -629,14 +633,16 @@ class ModelRpcClient:
with ThreadPoolExecutor(tp_size) as executor:
# Launch model processes
rets = executor.map(start_model_process, port_args.model_rpc_ports)
self.model_servers = [x[0] for x in rets] self.remote_services = [x[0] for x in rets]
self.procs = [x[1] for x in rets]
# Init model
def init_model(i):
return self.model_servers[i].init_model(i, server_args, port_args) return self.remote_services[i].ModelRpcServer(
i, server_args, port_args
)
rets = [obtain(x) for x in executor.map(init_model, range(tp_size))] self.model_servers = executor.map(init_model, range(tp_size))
# Wrap functions
def async_wrap(func_name):
@@ -654,7 +660,7 @@ class ModelRpcClient:
def _init_service(port):
t = ThreadedServer(
ModelRpcServer(), ModelRpcService(),
port=port,
protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
)

View File

@@ -1,10 +1,10 @@
import importlib
import logging import importlib.resources
import inspect
import logging
import pkgutil
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
import importlib.resources
import numpy as np
import torch
@@ -18,11 +18,6 @@ from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.model_loader import _set_default_torch_dtype
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
import importlib
import pkgutil
import sglang
QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig, "marlin": MarlinConfig}
logger = logging.getLogger("model_runner")
@@ -37,7 +32,7 @@ def import_model_classes():
model_arch_name_to_cls = {}
package_name = "sglang.srt.models"
package = importlib.import_module(package_name)
for finder, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + '.'): for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
if not ispkg:
module = importlib.import_module(name)
if hasattr(module, "EntryClass"):
@@ -144,9 +139,12 @@ class InputMetadata:
# flashinfer >= 0.0.3
# FIXME: Drop this when flashinfer updates to 0.0.4
if len(inspect.signature(self.prefill_wrapper.begin_forward).parameters) == 7: if (
len(inspect.signature(self.prefill_wrapper.begin_forward).parameters)
== 7
):
args.append(self.model_runner.model_config.head_dim)
self.prefill_wrapper.begin_forward(*args)
else:
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
@@ -307,9 +305,11 @@ class ModelRunner:
hf_quant_method = hf_quant_config["quant_method"]
# compat: autogptq uses is_marlin_format within quant config
if (hf_quant_method == "gptq" if (
and "is_marlin_format" in hf_quant_config hf_quant_method == "gptq"
and hf_quant_config["is_marlin_format"]): and "is_marlin_format" in hf_quant_config
and hf_quant_config["is_marlin_format"]
):
hf_quant_method = "marlin"
quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_method)