Files
engnex-r_series-llm/main.py
2025-08-06 10:30:31 +08:00

243 lines
7.1 KiB
Python

import os
import sys
# import copy
import subprocess
from abc import ABC, abstractmethod
from typing import List
from llm_utils import ModelConfig
# from xtrt_llm.vllm.entrypoints.openai.api_server import parse_args
class Config:
    """Runtime configuration assembled from environment variables.

    One instance is created per process; all runners read from it.
    """

    def __init__(self):
        # Local path of the model checkpoint to build/serve.
        self.model_path = os.getenv("MODEL_PATH", "/model")
        # Optional name reported by the server via --served-model-name.
        self.model_name = os.getenv("MODEL_NAME")
        # Tensor-parallel world size. Historically this variable was spelled
        # "NUM_GPUs"; keep honoring it (it wins if both are set) but also
        # accept the conventional all-caps "NUM_GPUS" spelling.
        self.num_gpus = int(os.getenv("NUM_GPUs", os.getenv("NUM_GPUS", "1")))
        # Port the API server listens on (kept as a string for the CLI).
        self.port = os.getenv("PORT", "80")
        # Root directory holding the per-model build scripts.
        self.script_root = os.getenv("BUILD_SCRIPT_ROOT", "examples")
        # Optional weight-only quantization precision passed to the build
        # script — presumably values like int8/int4; TODO confirm.
        self.weight_only_precision = os.getenv("WEIGHT_ONLY_PRECISION")
        # Output directory for the built engine; build is skipped if it exists.
        self.engine_dir = os.getenv("ENGINE_DIR", "./xtrt_engine")
        # Extra whitespace-separated arguments appended to the build command.
        self.build_extra = os.getenv("BUILD_EXTRA")
        # Model metadata loaded from the checkpoint (provides model_type()).
        self.model_config = ModelConfig(self.model_path)
class ModelRunner(ABC):
    """Base class that builds an inference engine for a model and then serves
    it through the OpenAI-compatible API server.

    Subclasses supply the model-specific build script path and arguments.
    """

    def __init__(self, config):
        # config: a Config instance assembled from environment variables.
        self.config = config

    @abstractmethod
    def build_script(self) -> str:
        """Return the filesystem path of the model-specific build script."""
        raise NotImplementedError()

    @abstractmethod
    def build_args(self) -> List[str]:
        """Return the model-specific argument list for the build script."""
        raise NotImplementedError()

    def build_command(self) -> List[str]:
        """Assemble the full engine-build command line."""
        cmd = [sys.executable, self.build_script()] + self.build_args()
        if self.config.build_extra:
            # str.split() without an argument collapses runs of whitespace,
            # so BUILD_EXTRA="--a  --b " no longer injects empty-string
            # arguments (the previous .split(' ') did).
            cmd.extend(self.config.build_extra.split())
        return cmd

    def build(self):
        """Build the engine unless the engine directory already exists.

        Raises RuntimeError when the build subprocess exits non-zero.
        """
        if os.path.exists(self.config.engine_dir):
            print(f"engine path {self.config.engine_dir} exists")
            return
        cmd = self.build_command()
        print(f"build command: {cmd}")
        # subprocess.run is the idiomatic "spawn and wait" form.
        result = subprocess.run(cmd)
        if result.returncode != 0:
            raise RuntimeError(f"build failed, exit code {result.returncode}")
        print("build success")

    @staticmethod
    def serve_module():
        # Dotted module path of the OpenAI-compatible API server entry point.
        return 'xtrt_llm.vllm.entrypoints.openai.api_server'

    def serve_command(self) -> List[str]:
        """Assemble the API-server command line.

        Any CLI arguments given to this launcher are forwarded verbatim.
        """
        cmd = [
            sys.executable,
            '-m',
            self.serve_module(),
            '--port',
            self.config.port,
            '--model',
            self.config.model_path,
            '--engine_dir',
            self.config.engine_dir,
            '--trust-remote-code',
            '--tensor-parallel-size',
            str(self.config.num_gpus),
            '--dtype',
            'float16',
        ]
        if self.config.model_name:
            cmd.extend(['--served-model-name', self.config.model_name])
        # Pass through the launcher's own CLI arguments to the server.
        cmd.extend(sys.argv[1:])
        return cmd

    def serve(self):
        """Run the API server in the foreground.

        Raises RuntimeError when the server process exits non-zero.
        """
        cmd = self.serve_command()
        print(f"serve command: {cmd}")
        result = subprocess.run(cmd)
        if result.returncode != 0:
            raise RuntimeError(f"serve failed, exit code {result.returncode}")
class ChatGLMRunner(ModelRunner):
    """Runner for ChatGLM-family checkpoints."""

    def __init__(self, config):
        super().__init__(config)
        # Name handed to the build script's --model_name flag.
        self.build_model_name = os.getenv("BUILD_MODEL_NAME", "chatglm3_6b")

    def build_script(self):
        """Path of the ChatGLM engine build script."""
        return f"{self.config.script_root}/chatglm/build.py"

    def build_args(self):
        """Argument list for the ChatGLM build script."""
        world_size = str(self.config.num_gpus)
        return [
            '--model_dir', self.config.model_path,
            '--output_dir', self.config.engine_dir,
            '--model_name', self.build_model_name,
            '--dtype', 'float16',
            '--use_gpt_attention_plugin', 'float16',
            '--remove_input_padding',
            '--paged_kv_cache',
            '--world_size', world_size,
            '--tp_size', world_size,
        ]
class LlamaRunner(ModelRunner):
    """Runner for Llama-family checkpoints."""

    def build_script(self):
        """Path of the Llama engine build script."""
        return f"{self.config.script_root}/llama/build.py"

    def build_args(self):
        """Argument list for the Llama build script."""
        world_size = str(self.config.num_gpus)
        args = [
            '--model_dir', self.config.model_path,
            '--output_dir', self.config.engine_dir,
            '--dtype', 'float16',
            '--use_gpt_attention_plugin', 'float16',
            '--world_size', world_size,
            '--tp_size', world_size,
            '--use_parallel_embedding',
            '--remove_input_padding',
            '--opt_memory_use',
            '--paged_kv_cache',
            '--tokens_per_block', '64',
        ]
        precision = self.config.weight_only_precision
        if precision:
            # Enable weight-only quantization at the requested precision.
            args += ['--use_weight_only', '--weight_only_precision', precision]
        return args
class QWenRunner(ModelRunner):
    """Runner for Qwen-family checkpoints (qwen and qwen2 model types)."""

    def build_script(self):
        """Path of the Qwen engine build script."""
        return f"{self.config.script_root}/qwen/build.py"

    def build_args(self):
        """Argument list for the Qwen build script."""
        world_size = str(self.config.num_gpus)
        args = [
            '--hf_model_dir', self.config.model_path,
            '--output_dir', self.config.engine_dir,
            '--dtype', 'float16',
            '--use_gpt_attention_plugin', 'float16',
            '--world_size', world_size,
            '--tp_size', world_size,
            '--remove_input_padding',
            '--opt_memory_use',
            '--paged_kv_cache',
            '--tokens_per_block', '64',
        ]
        if self.config.model_config.model_type() == "qwen2":
            # The build script only supports version 1.5 for this model type.
            args += ["--version", "1.5"]
        precision = self.config.weight_only_precision
        if precision:
            # Enable weight-only quantization at the requested precision.
            args += ['--use_weight_only', '--weight_only_precision', precision]
        return args
# Registry: detected model type -> runner class responsible for it.
runners = dict(
    chatglm=ChatGLMRunner,
    llama=LlamaRunner,
    qwen=QWenRunner,
    qwen2=QWenRunner,
)
def new_runner() -> ModelRunner:
    """Create the runner registered for the local checkpoint's model type.

    Raises RuntimeError when no runner is registered for that type.
    """
    cfg = Config()
    mtype = cfg.model_config.model_type()
    if mtype not in runners:
        raise RuntimeError(f"model type {mtype} unsupported")
    return runners[mtype](cfg)
def check_args():
    """When -h/--help is requested, show the server's help text and exit."""
    cli_args = sys.argv[1:]
    if '-h' not in cli_args and '--help' not in cli_args:
        return
    help_cmd = [sys.executable, '-m', ModelRunner.serve_module(), '--help']
    proc = subprocess.Popen(help_cmd)
    proc.wait()
    sys.exit(proc.returncode)
if __name__ == '__main__':
    # Handle help passthrough first, then build (if needed) and serve.
    check_args()
    model_runner = new_runner()
    model_runner.build()
    model_runner.serve()