Fix the chat template for llava-v1.6-34b & format code (#177)

Lianmin Zheng
2024-02-11 05:50:13 -08:00
committed by GitHub
parent 50afed4eaa
commit c51020cf0c
23 changed files with 101 additions and 44 deletions


@@ -1,4 +1,5 @@
"""Public API""" """Public API"""
import re import re
from typing import Callable, List, Optional, Union from typing import Callable, List, Optional, Union


@@ -19,7 +19,9 @@ class RuntimeEndpoint(BaseBackend):
         self.base_url = base_url
         self.auth_token = auth_token
-        res = http_request(self.base_url + "/get_model_info", auth_token=self.auth_token)
+        res = http_request(
+            self.base_url + "/get_model_info", auth_token=self.auth_token
+        )
         assert res.status_code == 200
         self.model_info = res.json()
@@ -37,7 +39,7 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
-            auth_token=self.auth_token
+            auth_token=self.auth_token,
         )
         assert res.status_code == 200
@@ -45,14 +47,16 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
-            auth_token=self.auth_token
+            auth_token=self.auth_token,
         )
         assert res.status_code == 200

     def fill_image(self, s: StreamExecutor):
         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
         self._add_images(s, data)
-        res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
+        res = http_request(
+            self.base_url + "/generate", json=data, auth_token=self.auth_token
+        )
         assert res.status_code == 200

     def generate(
@@ -82,7 +86,9 @@ class RuntimeEndpoint(BaseBackend):
         self._add_images(s, data)
-        res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
+        res = http_request(
+            self.base_url + "/generate", json=data, auth_token=self.auth_token
+        )
         obj = res.json()
         comp = obj["text"]
         return comp, obj["meta_info"]
@@ -115,7 +121,12 @@ class RuntimeEndpoint(BaseBackend):
data["stream"] = True data["stream"] = True
self._add_images(s, data) self._add_images(s, data)
response = http_request(self.base_url + "/generate", json=data, stream=True, auth_token=self.auth_token) response = http_request(
self.base_url + "/generate",
json=data,
stream=True,
auth_token=self.auth_token,
)
pos = 0 pos = 0
incomplete_text = "" incomplete_text = ""
@@ -145,7 +156,9 @@ class RuntimeEndpoint(BaseBackend):
         # Cache common prefix
         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
         self._add_images(s, data)
-        res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
+        res = http_request(
+            self.base_url + "/generate", json=data, auth_token=self.auth_token
+        )
         assert res.status_code == 200
         prompt_len = res.json()["meta_info"]["prompt_tokens"]
@@ -157,7 +170,9 @@ class RuntimeEndpoint(BaseBackend):
"logprob_start_len": max(prompt_len - 2, 0), "logprob_start_len": max(prompt_len - 2, 0),
} }
self._add_images(s, data) self._add_images(s, data)
res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token) res = http_request(
self.base_url + "/generate", json=data, auth_token=self.auth_token
)
assert res.status_code == 200 assert res.status_code == 200
obj = res.json() obj = res.json()
normalized_prompt_logprob = [ normalized_prompt_logprob = [
@@ -172,7 +187,7 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/concate_and_append_request",
             json={"src_rids": src_rids, "dst_rid": dst_rid},
-            auth_token=self.auth_token
+            auth_token=self.auth_token,
         )
         assert res.status_code == 200
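
All of the calls reformatted above go through the small http_request helper, a thin wrapper around requests.post against the runtime's HTTP API. Below is a minimal sketch of the equivalent raw request, assuming a runtime is already listening on localhost:30000; the prompt, token budget, and port are illustrative and not part of this commit.

import requests

base_url = "http://127.0.0.1:30000"  # assumed local runtime, not from this commit
auth_token = None                    # set if the server was launched with an API key

headers = {"Content-Type": "application/json"}
if auth_token is not None:
    # The runtime reads the token from an "Authentication" header (see utils.py below).
    headers["Authentication"] = f"Bearer {auth_token}"

res = requests.post(
    base_url + "/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 16},
    },
    headers=headers,
    timeout=60,
)
assert res.status_code == 200
obj = res.json()
print(obj["text"], obj["meta_info"])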


@@ -116,6 +116,21 @@ register_chat_template(
 )

+register_chat_template(
+    ChatTemplate(
+        name="chatml-llava",
+        default_system_prompt="Answer the questions.",
+        role_prefix_and_suffix={
+            "system": ("<|im_start|>system\n", "\n<|im_end|>\n"),
+            "user": ("<|im_start|>user\n", "\n<|im_end|>\n"),
+            "assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
+        },
+        style=ChatTemplateStyle.PLAIN,
+        stop_str=("<|im_end|>",),
+        image_token=" <image>\n",
+    )
+)
+
 register_chat_template(
     ChatTemplate(
         name="vicuna_v1.1",
@@ -168,7 +183,7 @@ register_chat_template(
 def match_vicuna(model_path: str):
     if "vicuna" in model_path.lower():
         return get_chat_template("vicuna_v1.1")
-    if "llava" in model_path.lower():
+    if "llava-v1.5" in model_path.lower():
         return get_chat_template("vicuna_v1.1")
@@ -192,6 +207,8 @@ def match_chat_ml(model_path: str):
         return get_chat_template("chatml")
     if "qwen" in model_path and "chat" in model_path:
         return get_chat_template("chatml")
+    if "llava-v1.6-34b" in model_path:
+        return get_chat_template("chatml-llava")


 @register_chat_template_matching_function
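
For reference, here is a hand-rolled sketch of the prompt the new chatml-llava template produces. It only re-applies the role prefixes and suffixes registered above to a toy conversation; it does not call any sglang helper, and the user question is invented for illustration.

# Role wrappers copied from the chatml-llava registration above.
roles = {
    "system": ("<|im_start|>system\n", "\n<|im_end|>\n"),
    "user": ("<|im_start|>user\n", "\n<|im_end|>\n"),
    "assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
}

# image_token (" <image>\n") is spliced into the user turn where the image appears.
messages = [
    ("system", "Answer the questions."),
    ("user", " <image>\nWhat is shown in this image?"),
    ("assistant", ""),  # left empty so generation starts after the assistant prefix
]

prompt = ""
for role, content in messages:
    prefix, suffix = roles[role]
    if role == "assistant" and content == "":
        prompt += prefix  # keep the final turn open for the model to complete
    else:
        prompt += prefix + content + suffix

print(prompt)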


@@ -74,9 +74,9 @@ class SglSamplingParams:
         )
         return {
             "max_tokens_to_sample": self.max_new_tokens,
-            "stop_sequences": self.stop
-            if isinstance(self.stop, (list, tuple))
-            else [self.stop],
+            "stop_sequences": (
+                self.stop if isinstance(self.stop, (list, tuple)) else [self.stop]
+            ),
             "temperature": self.temperature,
             "top_p": self.top_p,
             "top_k": self.top_k,


@@ -1,4 +1,5 @@
"""Tracing a program.""" """Tracing a program."""
import uuid import uuid
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union


@@ -1,6 +1,7 @@
""" """
Backend configurations, may vary with different serving platforms. Backend configurations, may vary with different serving platforms.
""" """
from dataclasses import dataclass from dataclasses import dataclass


@@ -366,7 +366,8 @@ def generate_chat_conv(
                 if content.type == "text":
                     real_content += content.text
                 elif content.type == "image_url":
-                    real_content += "<image>"
+                    # NOTE: Only works for llava
+                    real_content += "<image>\n"
                     conv.append_image(content.image_url.url)
             conv.append_message(conv.roles[0], real_content)
         elif msg_role == "assistant":
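
A dict-based sketch of the same collapsing step: each image part contributes one "<image>\n" placeholder to the text and its URL is collected separately. The real code works on the parsed OpenAI-style request objects; plain dicts and a placeholder URL are used here.

content_parts = [
    {"type": "text", "text": "Describe this picture."},
    {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},  # placeholder
]

real_content = ""
image_urls = []
for part in content_parts:
    if part["type"] == "text":
        real_content += part["text"]
    elif part["type"] == "image_url":
        # NOTE: only works for llava-style models, as in the diff above.
        real_content += "<image>\n"
        image_urls.append(part["image_url"]["url"])

print(repr(real_content))  # 'Describe this picture.<image>\n'
print(image_urls)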


@@ -31,6 +31,7 @@ from sglang.srt.utils import (
     is_multimodal_model,
     set_random_seed,
 )
+from vllm.logger import _default_handler as vllm_default_handler

 logger = logging.getLogger("model_rpc")
@@ -50,6 +51,9 @@ class ModelRpcServer(rpyc.Service):
         self.tp_size = server_args.tp_size
         self.schedule_heuristic = server_args.schedule_heuristic
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
+        vllm_default_handler.setLevel(
+            level=getattr(logging, server_args.log_level.upper())
+        )

         # Init model and tokenizer
         self.model_config = ModelConfig(
@@ -83,9 +87,11 @@ class ModelRpcServer(rpyc.Service):
         self.max_num_running_seq = self.max_total_num_token // 2
         self.max_prefill_num_token = max(
             self.model_config.context_len,
-            self.max_total_num_token // 6
-            if server_args.max_prefill_num_token is None
-            else server_args.max_prefill_num_token,
+            (
+                self.max_total_num_token // 6
+                if server_args.max_prefill_num_token is None
+                else server_args.max_prefill_num_token
+            ),
         )
         self.int_token_logit_bias = torch.tensor(
             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
@@ -534,7 +540,7 @@ class ModelRpcServer(rpyc.Service):
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
                 )

                 # For the length of input_ids, which will be accumulated during jump-forward.
                 # Use the original length of input_ids to calculate the token usage info.
                 meta_info = {
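
A worked example of the reflowed max_prefill_num_token expression above, with made-up numbers (a 4096-token context window, a 120k-token KV pool, and no explicit override):

context_len = 4096             # model_config.context_len (illustrative)
max_total_num_token = 120_000  # KV pool capacity (illustrative)
override = None                # server_args.max_prefill_num_token

max_prefill_num_token = max(
    context_len,
    (max_total_num_token // 6 if override is None else override),
)
print(max_prefill_num_token)  # 20000: the pool-based heuristic exceeds the context length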


@@ -112,7 +112,9 @@ class InputMetadata:
             (self.batch_size,), dtype=torch.int32, device="cuda"
         )

-        workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")
+        workspace_buffer = torch.empty(
+            32 * 1024 * 1024, dtype=torch.int8, device="cuda"
+        )
         if (
             self.forward_mode == ForwardMode.PREFILL
             or self.forward_mode == ForwardMode.EXTEND
@@ -121,7 +123,9 @@ class InputMetadata:
                 (self.batch_size + 1,), dtype=torch.int32, device="cuda"
             )
             self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
-            self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
+            self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
+                workspace_buffer, "NHD"
+            )
             self.prefill_wrapper.begin_forward(
                 self.qo_indptr,
                 self.kv_indptr,
@@ -131,7 +135,9 @@ class InputMetadata:
                 self.model_runner.model_config.num_key_value_heads // tp_size,
             )
         else:
-            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
+            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                workspace_buffer, "NHD"
+            )
             self.decode_wrapper.begin_forward(
                 self.kv_indptr,
                 self.kv_indices,


@@ -1,4 +1,5 @@
"""Memory pool.""" """Memory pool."""
import logging import logging
import torch import torch


@@ -1,4 +1,5 @@
"""Inference-only LLaVa model compatible with HuggingFace weights.""" """Inference-only LLaVa model compatible with HuggingFace weights."""
from typing import List, Optional from typing import List, Optional
import numpy as np import numpy as np
@@ -269,7 +270,6 @@ class LlavaLlamaForCausalLM(nn.Module):
             raise ValueError(f"Unexpected select feature: {self.select_feature}")

         # load mm_projector
-        # TODO: support TP?
         projector_weights = {
             "model.mm_projector.0": "multi_modal_projector.linear_1",
             "model.mm_projector.2": "multi_modal_projector.linear_2",


@@ -1,4 +1,5 @@
"""Inference-only Mistral model.""" """Inference-only Mistral model."""
from sglang.srt.models.llama2 import LlamaForCausalLM from sglang.srt.models.llama2 import LlamaForCausalLM


@@ -97,14 +97,16 @@ class MixtralMoE(nn.Module):
         self.experts = nn.ModuleList(
             [
-                MixtralMLP(
-                    self.num_total_experts,
-                    config.hidden_size,
-                    config.intermediate_size,
-                    linear_method=linear_method,
-                )
-                if idx in self.expert_indicies
-                else None
+                (
+                    MixtralMLP(
+                        self.num_total_experts,
+                        config.hidden_size,
+                        config.intermediate_size,
+                        linear_method=linear_method,
+                    )
+                    if idx in self.expert_indicies
+                    else None
+                )
                 for idx in range(self.num_total_experts)
             ]
         )
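
The list comprehension above places a real expert module only at the indices owned by the current tensor-parallel rank and None everywhere else. A standalone sketch with strings standing in for MixtralMLP (the counts are illustrative; the original spelling expert_indicies is kept):

num_total_experts = 8
expert_indicies = [2, 3]  # experts assigned to this rank

experts = [
    (f"MixtralMLP(expert={idx})" if idx in expert_indicies else None)
    for idx in range(num_total_experts)
]
print(experts)
# [None, None, 'MixtralMLP(expert=2)', 'MixtralMLP(expert=3)', None, None, None, None]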


@@ -1,4 +1,5 @@
"""Inference-only Yi-VL model.""" """Inference-only Yi-VL model."""
import os import os
from typing import List, Optional from typing import List, Optional


@@ -1,4 +1,5 @@
"""Sampling parameters for text generation.""" """Sampling parameters for text generation."""
from typing import List, Optional, Union from typing import List, Optional, Union
_SAMPLING_EPS = 1e-6 _SAMPLING_EPS = 1e-6


@@ -1,4 +1,5 @@
"""SRT: SGLang Runtime""" """SRT: SGLang Runtime"""
import asyncio import asyncio
import json import json
import multiprocessing as mp import multiprocessing as mp
@@ -493,7 +494,7 @@ def launch_server(server_args, pipe_finish_writer):
     # Warmup
     try:
-        print("Warmup...", flush=True)
+        # print("Warmup...", flush=True)
         res = requests.post(
             url + "/generate",
             json={
@@ -505,8 +506,8 @@ def launch_server(server_args, pipe_finish_writer):
             },
             timeout=60,
         )
-        print(f"Warmup done. model response: {res.json()['text']}")
-        print("=" * 20, "Server is ready", "=" * 20, flush=True)
+        # print(f"Warmup done. model response: {res.json()['text']}")
+        # print("=" * 20, "Server is ready", "=" * 20, flush=True)
     except requests.exceptions.RequestException as e:
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(str(e))


@@ -122,7 +122,7 @@ def handle_port_init(
     # first check on server port
     if not check_port(port):
         new_port = alloc_usable_network_port(1, used_list=[port])[0]
-        print(f"Port {port} is not available, using {new_port} instead.")
+        print(f"WARNING: Port {port} is not available. Use {new_port} instead.")
         port = new_port

     # then we check on additional ports
@@ -157,8 +157,6 @@ def get_int_token_logit_bias(tokenizer, vocab_size):
         ss = tokenizer.decode([t_id]).strip()
         if not (ss.isdigit() or len(ss) == 0 or t_id == tokenizer.eos_token_id):
             logit_bias[t_id] = -1e5
-        # else:
-        #     print(ss, t_id)

     return logit_bias
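
For context, get_int_token_logit_bias builds the bias vector used to constrain generation to integers: any token that does not decode to digits (or to an empty string, or the EOS token) is pushed to -1e5. A toy sketch with an invented vocabulary standing in for the real tokenizer:

import numpy as np

fake_vocab = {0: "<eos>", 1: "123", 2: "7", 3: "hello", 4: " ", 5: "cat"}  # invented
eos_token_id = 0

logit_bias = np.zeros(len(fake_vocab), dtype=np.float32)
for t_id, piece in fake_vocab.items():
    ss = piece.strip()
    if not (ss.isdigit() or len(ss) == 0 or t_id == eos_token_id):
        logit_bias[t_id] = -1e5

print(logit_bias)  # only "hello" and "cat" (ids 3 and 5) are biased to -1e5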


@@ -1,4 +1,5 @@
"""Common utilities for testing and benchmarking""" """Common utilities for testing and benchmarking"""
import numpy as np import numpy as np
import requests import requests
from sglang.backend.openai import OpenAI from sglang.backend.openai import OpenAI


@@ -22,7 +22,7 @@ def get_available_gpu_memory(gpu_id, distributed=True):
     if torch.cuda.current_device() != gpu_id:
         print(
-            f"WARN: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
+            f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
             "which may cause useless memory allocation for torch CUDA context.",
         )
@@ -95,7 +95,7 @@ def http_request(url, json=None, stream=False, auth_token=None):
             return requests.post(url, json=json, stream=True)
         headers = {
             "Content-Type": "application/json",
-            "Authentication": f"Bearer {auth_token}"
+            "Authentication": f"Bearer {auth_token}",
         }
         return requests.post(url, json=json, stream=True, headers=headers)
     else:


@@ -1,6 +1,7 @@
""" """
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
""" """
import json import json
import unittest import unittest


@@ -66,9 +66,9 @@ class BenchBatch:
             p_idx = prefix_req_idx[i // fork_num].item()
             n_idx = self.req_pool_indices[i].item()
             req_to_token[n_idx, :prefix_len] = req_to_token[p_idx, :prefix_len]
-            req_to_token[
-                n_idx, prefix_len : prefix_len + extend_len
-            ] = self.out_cache_loc[i * extend_len : (i + 1) * extend_len]
+            req_to_token[n_idx, prefix_len : prefix_len + extend_len] = (
+                self.out_cache_loc[i * extend_len : (i + 1) * extend_len]
+            )

     def update_decode(self, predict_ids, batch_size):
         assert predict_ids.shape[0] == batch_size
@@ -81,9 +81,9 @@ class BenchBatch:
             self.out_cache_cont_start,
             self.out_cache_cont_end,
         ) = self.token_to_kv_pool.alloc_contiguous(batch_size)
-        self.req_to_token_pool.req_to_token[
-            self.req_pool_indices, self.seq_lens
-        ] = self.out_cache_loc
+        self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
+            self.out_cache_loc
+        )
         self.seq_lens.add_(1)
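
Both rewritten assignments above write a 1-D tensor into a slice of the request-to-token map. A minimal, self-contained sketch of the prefix-copy plus extend-write pattern (shapes and values are illustrative):

import torch

req_to_token = torch.zeros((4, 16), dtype=torch.int64)
prefix_len, extend_len = 4, 3
p_idx, n_idx, i = 0, 1, 0  # parent row, new row, position in the forked batch

req_to_token[p_idx, :prefix_len] = torch.tensor([10, 11, 12, 13])  # parent's cached slots
out_cache_loc = torch.tensor([100, 101, 102])                      # newly allocated slots

req_to_token[n_idx, :prefix_len] = req_to_token[p_idx, :prefix_len]
req_to_token[n_idx, prefix_len : prefix_len + extend_len] = (
    out_cache_loc[i * extend_len : (i + 1) * extend_len]
)
print(req_to_token[n_idx, : prefix_len + extend_len])
# tensor([ 10,  11,  12,  13, 100, 101, 102])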


@@ -7,6 +7,7 @@ from sglang.test.test_utils import (
     add_common_sglang_args_and_parse,
     select_sglang_backend,
 )
+import sglang as sgl

 IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"


@@ -155,7 +155,8 @@ def test_chat_completion_stream(args):
 def test_regex(args):
     client = openai.Client(api_key="EMPTY", base_url=args.base_url)

-    regex = (r"""\{\n"""
+    regex = (
+        r"""\{\n"""
         + r""" "name": "[\w]+",\n"""
         + r""" "population": "[\w\d\s]+"\n"""
         + r"""\}"""