# Launcher: builds a TensorRT engine for the configured model (if one is not
# already present) and then starts an OpenAI-compatible API server for it.
import os
|
|
import sys
|
|
# import copy
|
|
import subprocess
|
|
from abc import ABC, abstractmethod
|
|
from typing import List
|
|
from llm_utils import ModelConfig
|
|
# from xtrt_llm.vllm.entrypoints.openai.api_server import parse_args
|
|
|
|
|
|
class Config:
    """Runtime configuration read once from environment variables.

    Defaults are chosen for a containerized deployment (model mounted at
    /model, serving on port 80).
    """

    def __init__(self):
        # Directory containing the HuggingFace model weights/config.
        self.model_path = os.getenv("MODEL_PATH", "/model")
        # Optional public name reported to API clients (--served-model-name).
        self.model_name = os.getenv("MODEL_NAME")
        # Number of GPUs for tensor parallelism. Accept the conventional
        # NUM_GPUS spelling while keeping the legacy NUM_GPUs casing working.
        self.num_gpus = int(os.getenv("NUM_GPUS", os.getenv("NUM_GPUs", "1")))
        # Port for the API server; kept as str since it goes straight to argv.
        self.port = os.getenv("PORT", "80")
        # Root directory holding the per-model-family build scripts.
        self.script_root = os.getenv("BUILD_SCRIPT_ROOT", "examples")
        # If set (e.g. "int8"), enables weight-only quantization at build time.
        self.weight_only_precision = os.getenv("WEIGHT_ONLY_PRECISION")
        # Output/lookup directory for the compiled engine.
        self.engine_dir = os.getenv("ENGINE_DIR", "./xtrt_engine")
        # Extra space-separated args appended verbatim to the build command.
        self.build_extra = os.getenv("BUILD_EXTRA")
        # Model metadata helper; model_type() selects the runner class.
        self.model_config = ModelConfig(self.model_path)
|
|
|
|
|
|
class ModelRunner(ABC):
    """Base class for model-family-specific engine builders/servers.

    Subclasses supply the build script path and its arguments; this class
    composes and runs the build and serve subprocesses.
    """

    def __init__(self, config):
        self.config = config

    @abstractmethod
    def build_script(self) -> str:
        """Path to the engine build script for this model family."""
        raise NotImplementedError()

    @abstractmethod
    def build_args(self) -> List[str]:
        """Arguments passed to the build script."""
        pass

    def build_command(self) -> List[str]:
        """Full argv for the engine-build subprocess."""
        cmd = [
            sys.executable,
            self.build_script()
        ] + self.build_args()
        if self.config.build_extra:
            # split() rather than split(' ') so repeated, leading, or
            # trailing spaces do not inject empty-string arguments.
            cmd.extend(self.config.build_extra.split())
        return cmd

    def build(self):
        """Build the engine unless one already exists at engine_dir.

        Raises:
            RuntimeError: if the build subprocess exits non-zero.
        """
        if os.path.exists(self.config.engine_dir):
            print(f"engine path {self.config.engine_dir} exists")
            return
        cmd = self.build_command()
        print(f"build command: {cmd}")
        self._run(cmd, "build")
        print("build success")

    @staticmethod
    def serve_module():
        # Module executed via `python -m ...` to start the API server.
        return 'xtrt_llm.vllm.entrypoints.openai.api_server'

    def serve_command(self) -> List[str]:
        """Full argv for the API-server subprocess.

        Any CLI args given to this launcher (sys.argv[1:]) are appended last
        so they can override the defaults below.
        """
        cmd = [
            sys.executable,
            '-m',
            self.serve_module(),
            '--port',
            self.config.port,
            '--model',
            self.config.model_path,
            '--engine_dir',
            self.config.engine_dir,
            '--trust-remote-code',
            '--tensor-parallel-size',
            str(self.config.num_gpus),
            '--dtype',
            'float16',
        ]
        if self.config.model_name:
            cmd.extend(['--served-model-name', self.config.model_name])
        cmd.extend(sys.argv[1:])
        return cmd

    def serve(self):
        """Run the API server; blocks until the server process exits.

        Raises:
            RuntimeError: if the server subprocess exits non-zero.
        """
        cmd = self.serve_command()
        print(f"serve command: {cmd}")
        self._run(cmd, "serve")

    @staticmethod
    def _run(cmd, what):
        # Run a subprocess to completion, translating failure into an error
        # (shared by build() and serve(); error strings match the originals).
        p = subprocess.Popen(cmd)
        p.wait()
        if p.returncode != 0:
            raise RuntimeError(f"{what} failed, exit code {p.returncode}")
|
|
|
|
|
|
class ChatGLMRunner(ModelRunner):
    """Builds and serves engines for the ChatGLM model family."""

    def __init__(self, config):
        super().__init__(config)
        # Architecture label consumed by the chatglm build script.
        self.build_model_name = os.getenv("BUILD_MODEL_NAME", "chatglm3_6b")

    def build_script(self):
        return f"{self.config.script_root}/chatglm/build.py"

    def build_args(self):
        gpus = str(self.config.num_gpus)
        return [
            '--model_dir', self.config.model_path,
            '--output_dir', self.config.engine_dir,
            '--model_name', self.build_model_name,
            '--dtype', 'float16',
            '--use_gpt_attention_plugin', 'float16',
            '--remove_input_padding',
            '--paged_kv_cache',
            '--world_size', gpus,
            '--tp_size', gpus,
        ]
|
|
|
|
|
|
class LlamaRunner(ModelRunner):
    """Builds and serves engines for the Llama model family."""

    def __init__(self, config):
        super().__init__(config)

    def build_script(self):
        return f"{self.config.script_root}/llama/build.py"

    def build_args(self):
        gpus = str(self.config.num_gpus)
        args = [
            '--model_dir', self.config.model_path,
            '--output_dir', self.config.engine_dir,
            '--dtype', 'float16',
            '--use_gpt_attention_plugin', 'float16',
            '--world_size', gpus,
            '--tp_size', gpus,
            '--use_parallel_embedding',
            '--remove_input_padding',
            '--opt_memory_use',
            '--paged_kv_cache',
            '--tokens_per_block', '64',
        ]
        # Optional weight-only quantization (precision comes from env).
        if self.config.weight_only_precision:
            args += [
                '--use_weight_only',
                '--weight_only_precision',
                self.config.weight_only_precision,
            ]
        return args
|
|
|
|
|
|
class QWenRunner(ModelRunner):
    """Builds and serves engines for the Qwen model family (incl. Qwen1.5)."""

    def __init__(self, config):
        super().__init__(config)

    def build_script(self):
        return f"{self.config.script_root}/qwen/build.py"

    def build_args(self):
        gpus = str(self.config.num_gpus)
        args = [
            '--hf_model_dir', self.config.model_path,
            '--output_dir', self.config.engine_dir,
            '--dtype', 'float16',
            '--use_gpt_attention_plugin', 'float16',
            '--world_size', gpus,
            '--tp_size', gpus,
            '--remove_input_padding',
            '--opt_memory_use',
            '--paged_kv_cache',
            '--tokens_per_block', '64',
        ]
        # The "qwen2" HF model_type is handled by the same script; only
        # version 1.5 is supported there.
        if self.config.model_config.model_type() == "qwen2":
            args += ["--version", "1.5"]
        # Optional weight-only quantization (precision comes from env).
        if self.config.weight_only_precision:
            args += [
                '--use_weight_only',
                '--weight_only_precision',
                self.config.weight_only_precision,
            ]
        return args
|
|
|
|
|
|
# Maps a model's HF-config model_type to its runner class. "qwen" and
# "qwen2" share the same build/serve path (QWenRunner adds --version 1.5
# for qwen2 itself).
runners = {
    "chatglm": ChatGLMRunner,
    "llama": LlamaRunner,
    "qwen": QWenRunner,
    "qwen2": QWenRunner,
}
|
|
|
|
|
|
def new_runner() -> ModelRunner:
    """Instantiate the runner matching the configured model's type.

    Raises:
        RuntimeError: if the model type has no registered runner.
    """
    config = Config()
    model_type = config.model_config.model_type()
    if model_type not in runners:
        raise RuntimeError(f"model type {model_type} unsupported")
    return runners[model_type](config)
|
|
|
|
|
|
def check_args():
    """Intercept -h/--help and delegate to the API server's own help text.

    Exits the process with the help subprocess's return code; otherwise
    returns without side effects.
    """
    passthrough = sys.argv[1:]
    if '-h' not in passthrough and '--help' not in passthrough:
        return
    help_cmd = [sys.executable, '-m', ModelRunner.serve_module(), '--help']
    proc = subprocess.Popen(help_cmd)
    proc.wait()
    sys.exit(proc.returncode)
|
|
|
|
|
|
if __name__ == '__main__':
    # Forward -h/--help to the underlying server CLI before doing any work.
    check_args()
    # Pick the runner from the model's config, build the engine if it is
    # missing, then serve; serve() blocks for the lifetime of the server.
    runner = new_runner()
    runner.build()
    runner.serve()
|