Fix the chat template for llava-v1.6-34b & format code (#177)
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
"""Public API"""
|
"""Public API"""
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List, Optional, Union
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,9 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self.auth_token = auth_token
|
self.auth_token = auth_token
|
||||||
|
|
||||||
res = http_request(self.base_url + "/get_model_info", auth_token=self.auth_token)
|
res = http_request(
|
||||||
|
self.base_url + "/get_model_info", auth_token=self.auth_token
|
||||||
|
)
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
self.model_info = res.json()
|
self.model_info = res.json()
|
||||||
|
|
||||||
@@ -37,7 +39,7 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
res = http_request(
|
res = http_request(
|
||||||
self.base_url + "/generate",
|
self.base_url + "/generate",
|
||||||
json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
|
json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
|
||||||
auth_token=self.auth_token
|
auth_token=self.auth_token,
|
||||||
)
|
)
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
|
|
||||||
@@ -45,14 +47,16 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
res = http_request(
|
res = http_request(
|
||||||
self.base_url + "/generate",
|
self.base_url + "/generate",
|
||||||
json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
|
json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
|
||||||
auth_token=self.auth_token
|
auth_token=self.auth_token,
|
||||||
)
|
)
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
|
|
||||||
def fill_image(self, s: StreamExecutor):
|
def fill_image(self, s: StreamExecutor):
|
||||||
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
|
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
|
||||||
self._add_images(s, data)
|
self._add_images(s, data)
|
||||||
res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
|
res = http_request(
|
||||||
|
self.base_url + "/generate", json=data, auth_token=self.auth_token
|
||||||
|
)
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
@@ -82,7 +86,9 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
|
|
||||||
self._add_images(s, data)
|
self._add_images(s, data)
|
||||||
|
|
||||||
res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
|
res = http_request(
|
||||||
|
self.base_url + "/generate", json=data, auth_token=self.auth_token
|
||||||
|
)
|
||||||
obj = res.json()
|
obj = res.json()
|
||||||
comp = obj["text"]
|
comp = obj["text"]
|
||||||
return comp, obj["meta_info"]
|
return comp, obj["meta_info"]
|
||||||
@@ -115,7 +121,12 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
data["stream"] = True
|
data["stream"] = True
|
||||||
self._add_images(s, data)
|
self._add_images(s, data)
|
||||||
|
|
||||||
response = http_request(self.base_url + "/generate", json=data, stream=True, auth_token=self.auth_token)
|
response = http_request(
|
||||||
|
self.base_url + "/generate",
|
||||||
|
json=data,
|
||||||
|
stream=True,
|
||||||
|
auth_token=self.auth_token,
|
||||||
|
)
|
||||||
pos = 0
|
pos = 0
|
||||||
|
|
||||||
incomplete_text = ""
|
incomplete_text = ""
|
||||||
@@ -145,7 +156,9 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
# Cache common prefix
|
# Cache common prefix
|
||||||
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
|
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
|
||||||
self._add_images(s, data)
|
self._add_images(s, data)
|
||||||
res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
|
res = http_request(
|
||||||
|
self.base_url + "/generate", json=data, auth_token=self.auth_token
|
||||||
|
)
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
prompt_len = res.json()["meta_info"]["prompt_tokens"]
|
prompt_len = res.json()["meta_info"]["prompt_tokens"]
|
||||||
|
|
||||||
@@ -157,7 +170,9 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
"logprob_start_len": max(prompt_len - 2, 0),
|
"logprob_start_len": max(prompt_len - 2, 0),
|
||||||
}
|
}
|
||||||
self._add_images(s, data)
|
self._add_images(s, data)
|
||||||
res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
|
res = http_request(
|
||||||
|
self.base_url + "/generate", json=data, auth_token=self.auth_token
|
||||||
|
)
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
obj = res.json()
|
obj = res.json()
|
||||||
normalized_prompt_logprob = [
|
normalized_prompt_logprob = [
|
||||||
@@ -172,7 +187,7 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
res = http_request(
|
res = http_request(
|
||||||
self.base_url + "/concate_and_append_request",
|
self.base_url + "/concate_and_append_request",
|
||||||
json={"src_rids": src_rids, "dst_rid": dst_rid},
|
json={"src_rids": src_rids, "dst_rid": dst_rid},
|
||||||
auth_token=self.auth_token
|
auth_token=self.auth_token,
|
||||||
)
|
)
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
|
|
||||||
|
|||||||
@@ -116,6 +116,21 @@ register_chat_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_chat_template(
|
||||||
|
ChatTemplate(
|
||||||
|
name="chatml-llava",
|
||||||
|
default_system_prompt="Answer the questions.",
|
||||||
|
role_prefix_and_suffix={
|
||||||
|
"system": ("<|im_start|>system\n", "\n<|im_end|>\n"),
|
||||||
|
"user": ("<|im_start|>user\n", "\n<|im_end|>\n"),
|
||||||
|
"assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
|
||||||
|
},
|
||||||
|
style=ChatTemplateStyle.PLAIN,
|
||||||
|
stop_str=("<|im_end|>",),
|
||||||
|
image_token=" <image>\n",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
register_chat_template(
|
register_chat_template(
|
||||||
ChatTemplate(
|
ChatTemplate(
|
||||||
name="vicuna_v1.1",
|
name="vicuna_v1.1",
|
||||||
@@ -168,7 +183,7 @@ register_chat_template(
|
|||||||
def match_vicuna(model_path: str):
|
def match_vicuna(model_path: str):
|
||||||
if "vicuna" in model_path.lower():
|
if "vicuna" in model_path.lower():
|
||||||
return get_chat_template("vicuna_v1.1")
|
return get_chat_template("vicuna_v1.1")
|
||||||
if "llava" in model_path.lower():
|
if "llava-v1.5" in model_path.lower():
|
||||||
return get_chat_template("vicuna_v1.1")
|
return get_chat_template("vicuna_v1.1")
|
||||||
|
|
||||||
|
|
||||||
@@ -192,6 +207,8 @@ def match_chat_ml(model_path: str):
|
|||||||
return get_chat_template("chatml")
|
return get_chat_template("chatml")
|
||||||
if "qwen" in model_path and "chat" in model_path:
|
if "qwen" in model_path and "chat" in model_path:
|
||||||
return get_chat_template("chatml")
|
return get_chat_template("chatml")
|
||||||
|
if "llava-v1.6-34b" in model_path:
|
||||||
|
return get_chat_template("chatml-llava")
|
||||||
|
|
||||||
|
|
||||||
@register_chat_template_matching_function
|
@register_chat_template_matching_function
|
||||||
|
|||||||
@@ -74,9 +74,9 @@ class SglSamplingParams:
|
|||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
"max_tokens_to_sample": self.max_new_tokens,
|
"max_tokens_to_sample": self.max_new_tokens,
|
||||||
"stop_sequences": self.stop
|
"stop_sequences": (
|
||||||
if isinstance(self.stop, (list, tuple))
|
self.stop if isinstance(self.stop, (list, tuple)) else [self.stop]
|
||||||
else [self.stop],
|
),
|
||||||
"temperature": self.temperature,
|
"temperature": self.temperature,
|
||||||
"top_p": self.top_p,
|
"top_p": self.top_p,
|
||||||
"top_k": self.top_k,
|
"top_k": self.top_k,
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Tracing a program."""
|
"""Tracing a program."""
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Callable, Dict, List, Optional, Union
|
from typing import Any, Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Backend configurations, may vary with different serving platforms.
|
Backend configurations, may vary with different serving platforms.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -366,7 +366,8 @@ def generate_chat_conv(
|
|||||||
if content.type == "text":
|
if content.type == "text":
|
||||||
real_content += content.text
|
real_content += content.text
|
||||||
elif content.type == "image_url":
|
elif content.type == "image_url":
|
||||||
real_content += "<image>"
|
# NOTE: Only works for llava
|
||||||
|
real_content += "<image>\n"
|
||||||
conv.append_image(content.image_url.url)
|
conv.append_image(content.image_url.url)
|
||||||
conv.append_message(conv.roles[0], real_content)
|
conv.append_message(conv.roles[0], real_content)
|
||||||
elif msg_role == "assistant":
|
elif msg_role == "assistant":
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from sglang.srt.utils import (
|
|||||||
is_multimodal_model,
|
is_multimodal_model,
|
||||||
set_random_seed,
|
set_random_seed,
|
||||||
)
|
)
|
||||||
|
from vllm.logger import _default_handler as vllm_default_handler
|
||||||
|
|
||||||
logger = logging.getLogger("model_rpc")
|
logger = logging.getLogger("model_rpc")
|
||||||
|
|
||||||
@@ -50,6 +51,9 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
self.tp_size = server_args.tp_size
|
self.tp_size = server_args.tp_size
|
||||||
self.schedule_heuristic = server_args.schedule_heuristic
|
self.schedule_heuristic = server_args.schedule_heuristic
|
||||||
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
|
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
|
||||||
|
vllm_default_handler.setLevel(
|
||||||
|
level=getattr(logging, server_args.log_level.upper())
|
||||||
|
)
|
||||||
|
|
||||||
# Init model and tokenizer
|
# Init model and tokenizer
|
||||||
self.model_config = ModelConfig(
|
self.model_config = ModelConfig(
|
||||||
@@ -83,9 +87,11 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
self.max_num_running_seq = self.max_total_num_token // 2
|
self.max_num_running_seq = self.max_total_num_token // 2
|
||||||
self.max_prefill_num_token = max(
|
self.max_prefill_num_token = max(
|
||||||
self.model_config.context_len,
|
self.model_config.context_len,
|
||||||
self.max_total_num_token // 6
|
(
|
||||||
if server_args.max_prefill_num_token is None
|
self.max_total_num_token // 6
|
||||||
else server_args.max_prefill_num_token,
|
if server_args.max_prefill_num_token is None
|
||||||
|
else server_args.max_prefill_num_token
|
||||||
|
),
|
||||||
)
|
)
|
||||||
self.int_token_logit_bias = torch.tensor(
|
self.int_token_logit_bias = torch.tensor(
|
||||||
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
|
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
|
||||||
@@ -534,7 +540,7 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
output_skip_special_tokens.append(
|
output_skip_special_tokens.append(
|
||||||
req.sampling_params.skip_special_tokens
|
req.sampling_params.skip_special_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
# For the length of input_ids, which will be accumulated during jump-forward.
|
# For the length of input_ids, which will be accumulated during jump-forward.
|
||||||
# Use the original length of input_ids to calculate the token usage info.
|
# Use the original length of input_ids to calculate the token usage info.
|
||||||
meta_info = {
|
meta_info = {
|
||||||
|
|||||||
@@ -112,7 +112,9 @@ class InputMetadata:
|
|||||||
(self.batch_size,), dtype=torch.int32, device="cuda"
|
(self.batch_size,), dtype=torch.int32, device="cuda"
|
||||||
)
|
)
|
||||||
|
|
||||||
workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")
|
workspace_buffer = torch.empty(
|
||||||
|
32 * 1024 * 1024, dtype=torch.int8, device="cuda"
|
||||||
|
)
|
||||||
if (
|
if (
|
||||||
self.forward_mode == ForwardMode.PREFILL
|
self.forward_mode == ForwardMode.PREFILL
|
||||||
or self.forward_mode == ForwardMode.EXTEND
|
or self.forward_mode == ForwardMode.EXTEND
|
||||||
@@ -121,7 +123,9 @@ class InputMetadata:
|
|||||||
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
|
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
|
||||||
)
|
)
|
||||||
self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
|
self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
|
||||||
self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
|
self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
|
||||||
|
workspace_buffer, "NHD"
|
||||||
|
)
|
||||||
self.prefill_wrapper.begin_forward(
|
self.prefill_wrapper.begin_forward(
|
||||||
self.qo_indptr,
|
self.qo_indptr,
|
||||||
self.kv_indptr,
|
self.kv_indptr,
|
||||||
@@ -131,7 +135,9 @@ class InputMetadata:
|
|||||||
self.model_runner.model_config.num_key_value_heads // tp_size,
|
self.model_runner.model_config.num_key_value_heads // tp_size,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
|
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
||||||
|
workspace_buffer, "NHD"
|
||||||
|
)
|
||||||
self.decode_wrapper.begin_forward(
|
self.decode_wrapper.begin_forward(
|
||||||
self.kv_indptr,
|
self.kv_indptr,
|
||||||
self.kv_indices,
|
self.kv_indices,
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Memory pool."""
|
"""Memory pool."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Inference-only LLaVa model compatible with HuggingFace weights."""
|
"""Inference-only LLaVa model compatible with HuggingFace weights."""
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -269,7 +270,6 @@ class LlavaLlamaForCausalLM(nn.Module):
|
|||||||
raise ValueError(f"Unexpected select feature: {self.select_feature}")
|
raise ValueError(f"Unexpected select feature: {self.select_feature}")
|
||||||
|
|
||||||
# load mm_projector
|
# load mm_projector
|
||||||
# TODO: support TP?
|
|
||||||
projector_weights = {
|
projector_weights = {
|
||||||
"model.mm_projector.0": "multi_modal_projector.linear_1",
|
"model.mm_projector.0": "multi_modal_projector.linear_1",
|
||||||
"model.mm_projector.2": "multi_modal_projector.linear_2",
|
"model.mm_projector.2": "multi_modal_projector.linear_2",
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Inference-only Mistral model."""
|
"""Inference-only Mistral model."""
|
||||||
|
|
||||||
from sglang.srt.models.llama2 import LlamaForCausalLM
|
from sglang.srt.models.llama2 import LlamaForCausalLM
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -97,14 +97,16 @@ class MixtralMoE(nn.Module):
|
|||||||
|
|
||||||
self.experts = nn.ModuleList(
|
self.experts = nn.ModuleList(
|
||||||
[
|
[
|
||||||
MixtralMLP(
|
(
|
||||||
self.num_total_experts,
|
MixtralMLP(
|
||||||
config.hidden_size,
|
self.num_total_experts,
|
||||||
config.intermediate_size,
|
config.hidden_size,
|
||||||
linear_method=linear_method,
|
config.intermediate_size,
|
||||||
|
linear_method=linear_method,
|
||||||
|
)
|
||||||
|
if idx in self.expert_indicies
|
||||||
|
else None
|
||||||
)
|
)
|
||||||
if idx in self.expert_indicies
|
|
||||||
else None
|
|
||||||
for idx in range(self.num_total_experts)
|
for idx in range(self.num_total_experts)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Inference-only Yi-VL model."""
|
"""Inference-only Yi-VL model."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Sampling parameters for text generation."""
|
"""Sampling parameters for text generation."""
|
||||||
|
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
_SAMPLING_EPS = 1e-6
|
_SAMPLING_EPS = 1e-6
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""SRT: SGLang Runtime"""
|
"""SRT: SGLang Runtime"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
@@ -493,7 +494,7 @@ def launch_server(server_args, pipe_finish_writer):
|
|||||||
|
|
||||||
# Warmup
|
# Warmup
|
||||||
try:
|
try:
|
||||||
print("Warmup...", flush=True)
|
# print("Warmup...", flush=True)
|
||||||
res = requests.post(
|
res = requests.post(
|
||||||
url + "/generate",
|
url + "/generate",
|
||||||
json={
|
json={
|
||||||
@@ -505,8 +506,8 @@ def launch_server(server_args, pipe_finish_writer):
|
|||||||
},
|
},
|
||||||
timeout=60,
|
timeout=60,
|
||||||
)
|
)
|
||||||
print(f"Warmup done. model response: {res.json()['text']}")
|
# print(f"Warmup done. model response: {res.json()['text']}")
|
||||||
print("=" * 20, "Server is ready", "=" * 20, flush=True)
|
# print("=" * 20, "Server is ready", "=" * 20, flush=True)
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
if pipe_finish_writer is not None:
|
if pipe_finish_writer is not None:
|
||||||
pipe_finish_writer.send(str(e))
|
pipe_finish_writer.send(str(e))
|
||||||
|
|||||||
@@ -122,7 +122,7 @@ def handle_port_init(
|
|||||||
# first check on server port
|
# first check on server port
|
||||||
if not check_port(port):
|
if not check_port(port):
|
||||||
new_port = alloc_usable_network_port(1, used_list=[port])[0]
|
new_port = alloc_usable_network_port(1, used_list=[port])[0]
|
||||||
print(f"Port {port} is not available, using {new_port} instead.")
|
print(f"WARNING: Port {port} is not available. Use {new_port} instead.")
|
||||||
port = new_port
|
port = new_port
|
||||||
|
|
||||||
# then we check on additional ports
|
# then we check on additional ports
|
||||||
@@ -157,8 +157,6 @@ def get_int_token_logit_bias(tokenizer, vocab_size):
|
|||||||
ss = tokenizer.decode([t_id]).strip()
|
ss = tokenizer.decode([t_id]).strip()
|
||||||
if not (ss.isdigit() or len(ss) == 0 or t_id == tokenizer.eos_token_id):
|
if not (ss.isdigit() or len(ss) == 0 or t_id == tokenizer.eos_token_id):
|
||||||
logit_bias[t_id] = -1e5
|
logit_bias[t_id] = -1e5
|
||||||
# else:
|
|
||||||
# print(ss, t_id)
|
|
||||||
|
|
||||||
return logit_bias
|
return logit_bias
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Common utilities for testing and benchmarking"""
|
"""Common utilities for testing and benchmarking"""
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
from sglang.backend.openai import OpenAI
|
from sglang.backend.openai import OpenAI
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ def get_available_gpu_memory(gpu_id, distributed=True):
|
|||||||
|
|
||||||
if torch.cuda.current_device() != gpu_id:
|
if torch.cuda.current_device() != gpu_id:
|
||||||
print(
|
print(
|
||||||
f"WARN: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
|
f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
|
||||||
"which may cause useless memory allocation for torch CUDA context.",
|
"which may cause useless memory allocation for torch CUDA context.",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -95,7 +95,7 @@ def http_request(url, json=None, stream=False, auth_token=None):
|
|||||||
return requests.post(url, json=json, stream=True)
|
return requests.post(url, json=json, stream=True)
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Authentication": f"Bearer {auth_token}"
|
"Authentication": f"Bearer {auth_token}",
|
||||||
}
|
}
|
||||||
return requests.post(url, json=json, stream=True, headers=headers)
|
return requests.post(url, json=json, stream=True, headers=headers)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
|||||||
@@ -66,9 +66,9 @@ class BenchBatch:
|
|||||||
p_idx = prefix_req_idx[i // fork_num].item()
|
p_idx = prefix_req_idx[i // fork_num].item()
|
||||||
n_idx = self.req_pool_indices[i].item()
|
n_idx = self.req_pool_indices[i].item()
|
||||||
req_to_token[n_idx, :prefix_len] = req_to_token[p_idx, :prefix_len]
|
req_to_token[n_idx, :prefix_len] = req_to_token[p_idx, :prefix_len]
|
||||||
req_to_token[
|
req_to_token[n_idx, prefix_len : prefix_len + extend_len] = (
|
||||||
n_idx, prefix_len : prefix_len + extend_len
|
self.out_cache_loc[i * extend_len : (i + 1) * extend_len]
|
||||||
] = self.out_cache_loc[i * extend_len : (i + 1) * extend_len]
|
)
|
||||||
|
|
||||||
def update_decode(self, predict_ids, batch_size):
|
def update_decode(self, predict_ids, batch_size):
|
||||||
assert predict_ids.shape[0] == batch_size
|
assert predict_ids.shape[0] == batch_size
|
||||||
@@ -81,9 +81,9 @@ class BenchBatch:
|
|||||||
self.out_cache_cont_start,
|
self.out_cache_cont_start,
|
||||||
self.out_cache_cont_end,
|
self.out_cache_cont_end,
|
||||||
) = self.token_to_kv_pool.alloc_contiguous(batch_size)
|
) = self.token_to_kv_pool.alloc_contiguous(batch_size)
|
||||||
self.req_to_token_pool.req_to_token[
|
self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
|
||||||
self.req_pool_indices, self.seq_lens
|
self.out_cache_loc
|
||||||
] = self.out_cache_loc
|
)
|
||||||
self.seq_lens.add_(1)
|
self.seq_lens.add_(1)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from sglang.test.test_utils import (
|
|||||||
add_common_sglang_args_and_parse,
|
add_common_sglang_args_and_parse,
|
||||||
select_sglang_backend,
|
select_sglang_backend,
|
||||||
)
|
)
|
||||||
|
|
||||||
import sglang as sgl
|
import sglang as sgl
|
||||||
|
|
||||||
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
|
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
|
||||||
|
|||||||
@@ -155,7 +155,8 @@ def test_chat_completion_stream(args):
|
|||||||
def test_regex(args):
|
def test_regex(args):
|
||||||
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
|
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
|
||||||
|
|
||||||
regex = (r"""\{\n"""
|
regex = (
|
||||||
|
r"""\{\n"""
|
||||||
+ r""" "name": "[\w]+",\n"""
|
+ r""" "name": "[\w]+",\n"""
|
||||||
+ r""" "population": "[\w\d\s]+"\n"""
|
+ r""" "population": "[\w\d\s]+"\n"""
|
||||||
+ r"""\}"""
|
+ r"""\}"""
|
||||||
|
|||||||
Reference in New Issue
Block a user