Fix the chat template for llava-v1.6-34b & format code (#177)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
"""Public API"""
|
||||
|
||||
import re
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
|
||||
@@ -19,7 +19,9 @@ class RuntimeEndpoint(BaseBackend):
|
||||
self.base_url = base_url
|
||||
self.auth_token = auth_token
|
||||
|
||||
res = http_request(self.base_url + "/get_model_info", auth_token=self.auth_token)
|
||||
res = http_request(
|
||||
self.base_url + "/get_model_info", auth_token=self.auth_token
|
||||
)
|
||||
assert res.status_code == 200
|
||||
self.model_info = res.json()
|
||||
|
||||
@@ -37,7 +39,7 @@ class RuntimeEndpoint(BaseBackend):
|
||||
res = http_request(
|
||||
self.base_url + "/generate",
|
||||
json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
|
||||
auth_token=self.auth_token
|
||||
auth_token=self.auth_token,
|
||||
)
|
||||
assert res.status_code == 200
|
||||
|
||||
@@ -45,14 +47,16 @@ class RuntimeEndpoint(BaseBackend):
|
||||
res = http_request(
|
||||
self.base_url + "/generate",
|
||||
json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
|
||||
auth_token=self.auth_token
|
||||
auth_token=self.auth_token,
|
||||
)
|
||||
assert res.status_code == 200
|
||||
|
||||
def fill_image(self, s: StreamExecutor):
|
||||
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
|
||||
self._add_images(s, data)
|
||||
res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
|
||||
res = http_request(
|
||||
self.base_url + "/generate", json=data, auth_token=self.auth_token
|
||||
)
|
||||
assert res.status_code == 200
|
||||
|
||||
def generate(
|
||||
@@ -82,7 +86,9 @@ class RuntimeEndpoint(BaseBackend):
|
||||
|
||||
self._add_images(s, data)
|
||||
|
||||
res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
|
||||
res = http_request(
|
||||
self.base_url + "/generate", json=data, auth_token=self.auth_token
|
||||
)
|
||||
obj = res.json()
|
||||
comp = obj["text"]
|
||||
return comp, obj["meta_info"]
|
||||
@@ -115,7 +121,12 @@ class RuntimeEndpoint(BaseBackend):
|
||||
data["stream"] = True
|
||||
self._add_images(s, data)
|
||||
|
||||
response = http_request(self.base_url + "/generate", json=data, stream=True, auth_token=self.auth_token)
|
||||
response = http_request(
|
||||
self.base_url + "/generate",
|
||||
json=data,
|
||||
stream=True,
|
||||
auth_token=self.auth_token,
|
||||
)
|
||||
pos = 0
|
||||
|
||||
incomplete_text = ""
|
||||
@@ -145,7 +156,9 @@ class RuntimeEndpoint(BaseBackend):
|
||||
# Cache common prefix
|
||||
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
|
||||
self._add_images(s, data)
|
||||
res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
|
||||
res = http_request(
|
||||
self.base_url + "/generate", json=data, auth_token=self.auth_token
|
||||
)
|
||||
assert res.status_code == 200
|
||||
prompt_len = res.json()["meta_info"]["prompt_tokens"]
|
||||
|
||||
@@ -157,7 +170,9 @@ class RuntimeEndpoint(BaseBackend):
|
||||
"logprob_start_len": max(prompt_len - 2, 0),
|
||||
}
|
||||
self._add_images(s, data)
|
||||
res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
|
||||
res = http_request(
|
||||
self.base_url + "/generate", json=data, auth_token=self.auth_token
|
||||
)
|
||||
assert res.status_code == 200
|
||||
obj = res.json()
|
||||
normalized_prompt_logprob = [
|
||||
@@ -172,7 +187,7 @@ class RuntimeEndpoint(BaseBackend):
|
||||
res = http_request(
|
||||
self.base_url + "/concate_and_append_request",
|
||||
json={"src_rids": src_rids, "dst_rid": dst_rid},
|
||||
auth_token=self.auth_token
|
||||
auth_token=self.auth_token,
|
||||
)
|
||||
assert res.status_code == 200
|
||||
|
||||
|
||||
@@ -116,6 +116,21 @@ register_chat_template(
|
||||
)
|
||||
|
||||
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="chatml-llava",
|
||||
default_system_prompt="Answer the questions.",
|
||||
role_prefix_and_suffix={
|
||||
"system": ("<|im_start|>system\n", "\n<|im_end|>\n"),
|
||||
"user": ("<|im_start|>user\n", "\n<|im_end|>\n"),
|
||||
"assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
|
||||
},
|
||||
style=ChatTemplateStyle.PLAIN,
|
||||
stop_str=("<|im_end|>",),
|
||||
image_token=" <image>\n",
|
||||
)
|
||||
)
|
||||
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="vicuna_v1.1",
|
||||
@@ -168,7 +183,7 @@ register_chat_template(
|
||||
def match_vicuna(model_path: str):
|
||||
if "vicuna" in model_path.lower():
|
||||
return get_chat_template("vicuna_v1.1")
|
||||
if "llava" in model_path.lower():
|
||||
if "llava-v1.5" in model_path.lower():
|
||||
return get_chat_template("vicuna_v1.1")
|
||||
|
||||
|
||||
@@ -192,6 +207,8 @@ def match_chat_ml(model_path: str):
|
||||
return get_chat_template("chatml")
|
||||
if "qwen" in model_path and "chat" in model_path:
|
||||
return get_chat_template("chatml")
|
||||
if "llava-v1.6-34b" in model_path:
|
||||
return get_chat_template("chatml-llava")
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
|
||||
@@ -74,9 +74,9 @@ class SglSamplingParams:
|
||||
)
|
||||
return {
|
||||
"max_tokens_to_sample": self.max_new_tokens,
|
||||
"stop_sequences": self.stop
|
||||
if isinstance(self.stop, (list, tuple))
|
||||
else [self.stop],
|
||||
"stop_sequences": (
|
||||
self.stop if isinstance(self.stop, (list, tuple)) else [self.stop]
|
||||
),
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"top_k": self.top_k,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Tracing a program."""
|
||||
|
||||
import uuid
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Backend configurations, may vary with different serving platforms.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
|
||||
@@ -366,7 +366,8 @@ def generate_chat_conv(
|
||||
if content.type == "text":
|
||||
real_content += content.text
|
||||
elif content.type == "image_url":
|
||||
real_content += "<image>"
|
||||
# NOTE: Only works for llava
|
||||
real_content += "<image>\n"
|
||||
conv.append_image(content.image_url.url)
|
||||
conv.append_message(conv.roles[0], real_content)
|
||||
elif msg_role == "assistant":
|
||||
|
||||
@@ -31,6 +31,7 @@ from sglang.srt.utils import (
|
||||
is_multimodal_model,
|
||||
set_random_seed,
|
||||
)
|
||||
from vllm.logger import _default_handler as vllm_default_handler
|
||||
|
||||
logger = logging.getLogger("model_rpc")
|
||||
|
||||
@@ -50,6 +51,9 @@ class ModelRpcServer(rpyc.Service):
|
||||
self.tp_size = server_args.tp_size
|
||||
self.schedule_heuristic = server_args.schedule_heuristic
|
||||
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
|
||||
vllm_default_handler.setLevel(
|
||||
level=getattr(logging, server_args.log_level.upper())
|
||||
)
|
||||
|
||||
# Init model and tokenizer
|
||||
self.model_config = ModelConfig(
|
||||
@@ -83,9 +87,11 @@ class ModelRpcServer(rpyc.Service):
|
||||
self.max_num_running_seq = self.max_total_num_token // 2
|
||||
self.max_prefill_num_token = max(
|
||||
self.model_config.context_len,
|
||||
self.max_total_num_token // 6
|
||||
if server_args.max_prefill_num_token is None
|
||||
else server_args.max_prefill_num_token,
|
||||
(
|
||||
self.max_total_num_token // 6
|
||||
if server_args.max_prefill_num_token is None
|
||||
else server_args.max_prefill_num_token
|
||||
),
|
||||
)
|
||||
self.int_token_logit_bias = torch.tensor(
|
||||
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
|
||||
@@ -534,7 +540,7 @@ class ModelRpcServer(rpyc.Service):
|
||||
output_skip_special_tokens.append(
|
||||
req.sampling_params.skip_special_tokens
|
||||
)
|
||||
|
||||
|
||||
# For the length of input_ids, which will be accumulated during jump-forward.
|
||||
# Use the original length of input_ids to calculate the token usage info.
|
||||
meta_info = {
|
||||
|
||||
@@ -112,7 +112,9 @@ class InputMetadata:
|
||||
(self.batch_size,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
|
||||
workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")
|
||||
workspace_buffer = torch.empty(
|
||||
32 * 1024 * 1024, dtype=torch.int8, device="cuda"
|
||||
)
|
||||
if (
|
||||
self.forward_mode == ForwardMode.PREFILL
|
||||
or self.forward_mode == ForwardMode.EXTEND
|
||||
@@ -121,7 +123,9 @@ class InputMetadata:
|
||||
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
|
||||
self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
|
||||
self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
|
||||
workspace_buffer, "NHD"
|
||||
)
|
||||
self.prefill_wrapper.begin_forward(
|
||||
self.qo_indptr,
|
||||
self.kv_indptr,
|
||||
@@ -131,7 +135,9 @@ class InputMetadata:
|
||||
self.model_runner.model_config.num_key_value_heads // tp_size,
|
||||
)
|
||||
else:
|
||||
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
|
||||
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
||||
workspace_buffer, "NHD"
|
||||
)
|
||||
self.decode_wrapper.begin_forward(
|
||||
self.kv_indptr,
|
||||
self.kv_indices,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Memory pool."""
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Inference-only LLaVa model compatible with HuggingFace weights."""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
@@ -269,7 +270,6 @@ class LlavaLlamaForCausalLM(nn.Module):
|
||||
raise ValueError(f"Unexpected select feature: {self.select_feature}")
|
||||
|
||||
# load mm_projector
|
||||
# TODO: support TP?
|
||||
projector_weights = {
|
||||
"model.mm_projector.0": "multi_modal_projector.linear_1",
|
||||
"model.mm_projector.2": "multi_modal_projector.linear_2",
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Inference-only Mistral model."""
|
||||
|
||||
from sglang.srt.models.llama2 import LlamaForCausalLM
|
||||
|
||||
|
||||
|
||||
@@ -97,14 +97,16 @@ class MixtralMoE(nn.Module):
|
||||
|
||||
self.experts = nn.ModuleList(
|
||||
[
|
||||
MixtralMLP(
|
||||
self.num_total_experts,
|
||||
config.hidden_size,
|
||||
config.intermediate_size,
|
||||
linear_method=linear_method,
|
||||
(
|
||||
MixtralMLP(
|
||||
self.num_total_experts,
|
||||
config.hidden_size,
|
||||
config.intermediate_size,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
if idx in self.expert_indicies
|
||||
else None
|
||||
)
|
||||
if idx in self.expert_indicies
|
||||
else None
|
||||
for idx in range(self.num_total_experts)
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Inference-only Yi-VL model."""
|
||||
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Sampling parameters for text generation."""
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
_SAMPLING_EPS = 1e-6
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""SRT: SGLang Runtime"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import multiprocessing as mp
|
||||
@@ -493,7 +494,7 @@ def launch_server(server_args, pipe_finish_writer):
|
||||
|
||||
# Warmup
|
||||
try:
|
||||
print("Warmup...", flush=True)
|
||||
# print("Warmup...", flush=True)
|
||||
res = requests.post(
|
||||
url + "/generate",
|
||||
json={
|
||||
@@ -505,8 +506,8 @@ def launch_server(server_args, pipe_finish_writer):
|
||||
},
|
||||
timeout=60,
|
||||
)
|
||||
print(f"Warmup done. model response: {res.json()['text']}")
|
||||
print("=" * 20, "Server is ready", "=" * 20, flush=True)
|
||||
# print(f"Warmup done. model response: {res.json()['text']}")
|
||||
# print("=" * 20, "Server is ready", "=" * 20, flush=True)
|
||||
except requests.exceptions.RequestException as e:
|
||||
if pipe_finish_writer is not None:
|
||||
pipe_finish_writer.send(str(e))
|
||||
|
||||
@@ -122,7 +122,7 @@ def handle_port_init(
|
||||
# first check on server port
|
||||
if not check_port(port):
|
||||
new_port = alloc_usable_network_port(1, used_list=[port])[0]
|
||||
print(f"Port {port} is not available, using {new_port} instead.")
|
||||
print(f"WARNING: Port {port} is not available. Use {new_port} instead.")
|
||||
port = new_port
|
||||
|
||||
# then we check on additional ports
|
||||
@@ -157,8 +157,6 @@ def get_int_token_logit_bias(tokenizer, vocab_size):
|
||||
ss = tokenizer.decode([t_id]).strip()
|
||||
if not (ss.isdigit() or len(ss) == 0 or t_id == tokenizer.eos_token_id):
|
||||
logit_bias[t_id] = -1e5
|
||||
# else:
|
||||
# print(ss, t_id)
|
||||
|
||||
return logit_bias
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Common utilities for testing and benchmarking"""
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from sglang.backend.openai import OpenAI
|
||||
|
||||
@@ -22,7 +22,7 @@ def get_available_gpu_memory(gpu_id, distributed=True):
|
||||
|
||||
if torch.cuda.current_device() != gpu_id:
|
||||
print(
|
||||
f"WARN: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
|
||||
f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
|
||||
"which may cause useless memory allocation for torch CUDA context.",
|
||||
)
|
||||
|
||||
@@ -95,7 +95,7 @@ def http_request(url, json=None, stream=False, auth_token=None):
|
||||
return requests.post(url, json=json, stream=True)
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authentication": f"Bearer {auth_token}"
|
||||
"Authentication": f"Bearer {auth_token}",
|
||||
}
|
||||
return requests.post(url, json=json, stream=True, headers=headers)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user