Implement served_model_name to customize model id when use local mode… (#749)
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
This commit is contained in:
@@ -79,6 +79,7 @@ class TokenizerManager:
|
|||||||
self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}")
|
self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}")
|
||||||
|
|
||||||
self.model_path = server_args.model_path
|
self.model_path = server_args.model_path
|
||||||
|
self.served_model_name = server_args.served_model_name
|
||||||
self.hf_config = get_config(
|
self.hf_config = get_config(
|
||||||
self.model_path,
|
self.model_path,
|
||||||
trust_remote_code=server_args.trust_remote_code,
|
trust_remote_code=server_args.trust_remote_code,
|
||||||
|
|||||||
@@ -190,10 +190,10 @@ async def retrieve_file_content(file_id: str):
|
|||||||
@app.get("/v1/models")
|
@app.get("/v1/models")
|
||||||
def available_models():
|
def available_models():
|
||||||
"""Show available models."""
|
"""Show available models."""
|
||||||
model_names = [tokenizer_manager.model_path]
|
served_model_names = [tokenizer_manager.served_model_name]
|
||||||
model_cards = []
|
model_cards = []
|
||||||
for model_name in model_names:
|
for served_model_name in served_model_names:
|
||||||
model_cards.append(ModelCard(id=model_name, root=model_name))
|
model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
|
||||||
return ModelList(data=model_cards)
|
return ModelList(data=model_cards)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ class ServerArgs:
|
|||||||
trust_remote_code: bool = True
|
trust_remote_code: bool = True
|
||||||
context_length: Optional[int] = None
|
context_length: Optional[int] = None
|
||||||
quantization: Optional[str] = None
|
quantization: Optional[str] = None
|
||||||
|
served_model_name: Optional[str] = None
|
||||||
chat_template: Optional[str] = None
|
chat_template: Optional[str] = None
|
||||||
|
|
||||||
# Port
|
# Port
|
||||||
@@ -90,6 +91,10 @@ class ServerArgs:
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.tokenizer_path is None:
|
if self.tokenizer_path is None:
|
||||||
self.tokenizer_path = self.model_path
|
self.tokenizer_path = self.model_path
|
||||||
|
|
||||||
|
if self.served_model_name is None:
|
||||||
|
self.served_model_name = self.model_path
|
||||||
|
|
||||||
if self.mem_fraction_static is None:
|
if self.mem_fraction_static is None:
|
||||||
if self.tp_size >= 16:
|
if self.tp_size >= 16:
|
||||||
self.mem_fraction_static = 0.79
|
self.mem_fraction_static = 0.79
|
||||||
@@ -202,6 +207,12 @@ class ServerArgs:
|
|||||||
],
|
],
|
||||||
help="The quantization method.",
|
help="The quantization method.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--served-model-name",
|
||||||
|
type=str,
|
||||||
|
default=ServerArgs.served_model_name,
|
||||||
|
help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--chat-template",
|
"--chat-template",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
Reference in New Issue
Block a user