[Minor] Many cleanup (#1357)
@@ -298,34 +298,41 @@ class BenchmarkMetrics:
median_e2e_latency_ms: float

default_sharegpt_path = "ShareGPT_V3_unfiltered_cleaned_split.json"
SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"

def download_sharegpt_dataset(path):
url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
def download_and_cache_file(url: str, filename: Optional[str] = None):
"""Read and cache a file from a url."""
if filename is None:
filename = os.path.join("/tmp", url.split("/")[-1])

print(f"Downloading dataset from {url}")
try:
response = requests.get(url, stream=True)
response.raise_for_status()
# Check if the cache file already exists
if os.path.exists(filename):
return filename

total_size = int(response.headers.get("content-length", 0))
block_size = 8192
print(f"Downloading from {url} to {filename}")

with open(path, "wb") as f, tqdm(
desc="Downloading",
total=total_size,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as progress_bar:
for data in response.iter_content(block_size):
size = f.write(data)
progress_bar.update(size)
# Stream the response to show the progress bar
response = requests.get(url, stream=True)
response.raise_for_status()  # Check for request errors

print(f"Dataset downloaded and saved to {path}")
except requests.RequestException as e:
raise Exception(f"Failed to download dataset: {e}")
# Total size of the file in bytes
total_size = int(response.headers.get("content-length", 0))
chunk_size = 1024  # Download in chunks of 1KB

# Use tqdm to display the progress bar
with open(filename, "wb") as f, tqdm(
desc=filename,
total=total_size,
unit="B",
unit_scale=True,
unit_divisor=1024,
) as bar:
for chunk in response.iter_content(chunk_size=chunk_size):
f.write(chunk)
bar.update(len(chunk))

return filename

def sample_sharegpt_requests(

@@ -338,13 +345,8 @@ def sample_sharegpt_requests(
raise ValueError("output_len too small")

# Download sharegpt if necessary
if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path):
download_sharegpt_dataset(default_sharegpt_path)
dataset_path = default_sharegpt_path
else:
dataset_path = (
dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
)
if not os.path.isfile(dataset_path):
dataset_path = download_and_cache_file(SHAREGPT_URL)

# Load the dataset.
with open(dataset_path) as f:

@@ -412,15 +414,8 @@ def sample_random_requests(
# Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens

# Download sharegpt if necessary
if not os.path.isfile(dataset_path) and not os.path.isfile(
default_sharegpt_path
):
download_sharegpt_dataset(default_sharegpt_path)
dataset_path = default_sharegpt_path
else:
dataset_path = (
dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
)
if not os.path.isfile(dataset_path):
dataset_path = download_and_cache_file(SHAREGPT_URL)

# Load the dataset.
with open(dataset_path) as f:
@@ -9,10 +9,9 @@ from sglang.srt.utils import kill_child_process

if __name__ == "__main__":
server_args = prepare_server_args(sys.argv[1:])
model_override_args = server_args.json_model_override_args

try:
launch_server(server_args, model_override_args=model_override_args)
launch_server(server_args)
except Exception as e:
raise e
finally:
@@ -1,5 +1,6 @@
"""Launch the inference server for Llava-video model."""

import json
import sys

from sglang.srt.server import launch_server, prepare_server_args

@@ -19,5 +20,6 @@ if __name__ == "__main__":
model_override_args["model_max_length"] = 4096 * 2
if "34b" in server_args.model_path.lower():
model_override_args["image_token_index"] = 64002
server_args.json_model_override_args = json.dumps(model_override_args)

launch_server(server_args, model_override_args, None)
launch_server(server_args)
@@ -16,6 +16,7 @@ limitations under the License.
"""Cache for the compressed finite state machine."""

from outlines.fsm.json_schema import build_regex_from_schema
from transformers import AutoTokenizer

from sglang.srt.constrained import RegexGuide, TransformerTokenizer
from sglang.srt.constrained.base_tool_cache import BaseToolCache

@@ -28,12 +29,9 @@ class FSMCache(BaseToolCache):
tokenizer_args_dict,
enable=True,
skip_tokenizer_init=False,
json_schema_mode=False,
):
super().__init__(enable=enable)

self.json_schema_mode = json_schema_mode

if (
skip_tokenizer_init
or tokenizer_path.endswith(".json")

@@ -42,44 +40,37 @@ class FSMCache(BaseToolCache):
# Do not support TiktokenTokenizer or SentencePieceTokenizer
return

from importlib.metadata import version
tokenizer_args_dict.setdefault("padding_side", "left")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
try:
self.outlines_tokenizer = TransformerTokenizer(tokenizer)
except AttributeError:
# FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
origin_pad_token_id = tokenizer.pad_token_id

if version("outlines") >= "0.0.35":
from transformers import AutoTokenizer
def fset(self, value):
self._value = value

tokenizer_args_dict.setdefault("padding_side", "left")
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path, **tokenizer_args_dict
type(tokenizer).pad_token_id = property(
fget=type(tokenizer).pad_token_id.fget, fset=fset
)
try:
self.outlines_tokenizer = TransformerTokenizer(tokenizer)
except AttributeError:
# FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
origin_pad_token_id = tokenizer.pad_token_id

def fset(self, value):
self._value = value

type(tokenizer).pad_token_id = property(
fget=type(tokenizer).pad_token_id.fget, fset=fset
)
self.outlines_tokenizer = TransformerTokenizer(tokenizer)
self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
self.outlines_tokenizer.pad_token_id = origin_pad_token_id
self.outlines_tokenizer.pad_token = (
self.outlines_tokenizer.tokenizer.pad_token
)
self.outlines_tokenizer.vocabulary = (
self.outlines_tokenizer.tokenizer.get_vocab()
)
else:
self.outlines_tokenizer = TransformerTokenizer(
tokenizer_path, **tokenizer_args_dict
self.outlines_tokenizer = TransformerTokenizer(tokenizer)
self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
self.outlines_tokenizer.pad_token_id = origin_pad_token_id
self.outlines_tokenizer.pad_token = (
self.outlines_tokenizer.tokenizer.pad_token
)
self.outlines_tokenizer.vocabulary = (
self.outlines_tokenizer.tokenizer.get_vocab()
)

def init_value(self, value):
if self.json_schema_mode:
regex = build_regex_from_schema(value, whitespace_pattern=r"[\n\t ]*")
return RegexGuide(regex, self.outlines_tokenizer), regex
def init_value(self, key):
key_type, key_string = key
if key_type == "json":
regex = build_regex_from_schema(key_string, whitespace_pattern=r"[\n\t ]*")
elif key_type == "regex":
regex = key_string
else:
return RegexGuide(value, self.outlines_tokenizer)
raise ValueError(f"Invalid key_type: {key_type}")

return RegexGuide(regex, self.outlines_tokenizer), regex
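For illustration only (not part of the commit; the tokenizer path, tokenizer args, and schema below are made-up examples): the two FSMCache instances used before (one with json_schema_mode=True) collapse into a single cache whose entries are keyed by a (key_type, key_string) tuple, so a caller now does roughly:

    # Hypothetical caller-side sketch of the unified FSMCache.
    fsm_cache = FSMCache(
        "meta-llama/Llama-2-7b-hf",                              # example tokenizer_path
        {"tokenizer_mode": "auto", "trust_remote_code": False},  # example tokenizer args
    )
    # JSON-schema-constrained decoding: the cache key is ("json", schema_string).
    schema = '{"type": "object", "properties": {"answer": {"type": "string"}}}'
    regex_fsm, computed_regex = fsm_cache.query(("json", schema))
    # Plain regex-constrained decoding: the cache key is ("regex", pattern).
    regex_fsm, computed_regex = fsm_cache.query(("regex", r"(yes|no)"))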
@@ -71,12 +71,10 @@ class ControllerMulti:
self,
server_args: ServerArgs,
port_args: PortArgs,
model_override_args,
):
# Parse args
self.server_args = server_args
self.port_args = port_args
self.model_override_args = model_override_args
self.load_balance_method = LoadBalanceMethod.from_str(
server_args.load_balance_method
)

@@ -114,7 +112,6 @@ class ControllerMulti:
self.server_args,
self.port_args,
pipe_controller_writer,
self.model_override_args,
True,
gpu_ids,
dp_worker_id,

@@ -189,14 +186,13 @@ def start_controller_process(
server_args: ServerArgs,
port_args: PortArgs,
pipe_writer,
model_override_args: dict,
):
"""Start a controller process."""

configure_logger(server_args)

try:
controller = ControllerMulti(server_args, port_args, model_override_args)
controller = ControllerMulti(server_args, port_args)
except Exception:
pipe_writer.send(get_exception_traceback())
raise
@@ -40,7 +40,6 @@ class ControllerSingle:
self,
server_args: ServerArgs,
port_args: PortArgs,
model_override_args: dict,
gpu_ids: List[int],
is_data_parallel_worker: bool,
dp_worker_id: int,

@@ -76,7 +75,6 @@ class ControllerSingle:
tp_rank_range,
server_args,
port_args.nccl_ports[dp_worker_id],
model_override_args,
)

# Launch tp rank 0

@@ -85,7 +83,6 @@ class ControllerSingle:
0,
server_args,
port_args.nccl_ports[dp_worker_id],
model_override_args,
)
self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group

@@ -126,7 +123,6 @@ def start_controller_process(
server_args: ServerArgs,
port_args: PortArgs,
pipe_writer: multiprocessing.connection.Connection,
model_override_args: dict,
is_data_parallel_worker: bool = False,
gpu_ids: List[int] = None,
dp_worker_id: int = None,

@@ -149,7 +145,6 @@ def start_controller_process(
controller = ControllerSingle(
server_args,
port_args,
model_override_args,
gpu_ids,
is_data_parallel_worker,
dp_worker_id,
@@ -18,6 +18,7 @@ limitations under the License.
import asyncio
import concurrent.futures
import dataclasses
import json
import logging
import multiprocessing as mp
import os

@@ -77,7 +78,6 @@ class TokenizerManager:
self,
server_args: ServerArgs,
port_args: PortArgs,
model_override_args: dict = None,
):
self.server_args = server_args

@@ -95,7 +95,7 @@ class TokenizerManager:
self.hf_config = get_config(
self.model_path,
trust_remote_code=server_args.trust_remote_code,
model_override_args=model_override_args,
model_override_args=json.loads(server_args.json_model_override_args),
)
self.is_generation = is_generation_model(
self.hf_config.architectures, self.server_args.is_embedding
@@ -15,13 +15,14 @@ limitations under the License.

"""A tensor parallel worker."""

import json
import logging
import multiprocessing
import os
import pickle
import time
import warnings
from typing import Any, List, Optional, Union
from typing import Any, List, Optional

import torch
import torch.distributed

@@ -66,6 +67,7 @@ from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)

# Crash on warning if we are running CI tests
crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"

@@ -76,11 +78,10 @@ class ModelTpServer:
tp_rank: int,
server_args: ServerArgs,
nccl_port: int,
model_override_args: dict,
):
suppress_other_loggers()

# Copy arguments
# Parse arguments
self.gpu_id = gpu_id
self.tp_rank = tp_rank
self.tp_size = server_args.tp_size

@@ -93,9 +94,8 @@ class ModelTpServer:
server_args.model_path,
server_args.trust_remote_code,
context_length=server_args.context_length,
model_override_args=model_override_args,
model_override_args=json.loads(server_args.json_model_override_args),
)

self.model_runner = ModelRunner(
model_config=self.model_config,
mem_fraction_static=server_args.mem_fraction_static,

@@ -136,7 +136,7 @@ class ModelTpServer:
self.max_total_num_tokens - 1,
)

# Sync random seed
# Sync random seed across TP workers
server_args.random_seed = broadcast_recv_input(
[server_args.random_seed],
self.tp_rank,

@@ -144,7 +144,7 @@ class ModelTpServer:
)[0]
set_random_seed(server_args.random_seed)

# Print info
# Print debug info
logger.info(
f"max_total_num_tokens={self.max_total_num_tokens}, "
f"max_prefill_tokens={self.max_prefill_tokens}, "

@@ -181,7 +181,7 @@ class ModelTpServer:
self.num_generated_tokens = 0
self.last_stats_tic = time.time()

# Chunked prefill
# Init chunked prefill
self.chunked_prefill_size = server_args.chunked_prefill_size
self.current_inflight_req = None
self.is_mixed_chunk = (

@@ -197,16 +197,6 @@ class ModelTpServer:
"trust_remote_code": server_args.trust_remote_code,
},
skip_tokenizer_init=server_args.skip_tokenizer_init,
json_schema_mode=False,
)
self.json_fsm_cache = FSMCache(
server_args.tokenizer_path,
{
"tokenizer_mode": server_args.tokenizer_mode,
"trust_remote_code": server_args.trust_remote_code,
},
skip_tokenizer_init=server_args.skip_tokenizer_init,
json_schema_mode=True,
)
self.jump_forward_cache = JumpForwardCache()
@@ -227,11 +217,12 @@ class ModelTpServer:
try:
# Recv requests
for recv_req in recv_reqs:
if isinstance(
recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
):
if isinstance(recv_req, TokenizedGenerateReqInput):
self.handle_generate_request(recv_req)
self.do_not_get_new_batch = False
elif isinstance(recv_req, TokenizedEmbeddingReqInput):
self.handle_embedding_request(recv_req)
self.do_not_get_new_batch = False
elif isinstance(recv_req, FlushCacheReq):
self.flush_cache()
elif isinstance(recv_req, AbortReq):

@@ -331,57 +322,56 @@ class ModelTpServer:

def handle_generate_request(
self,
recv_req: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
recv_req: TokenizedGenerateReqInput,
):
req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
req.tokenizer = self.tokenizer
req.sampling_params = recv_req.sampling_params
if self.model_runner.is_generation:
req.pixel_values = recv_req.pixel_values
if req.pixel_values is not None:
# Use image hash as fake token_ids, which is then used
# for prefix matching
image_hash = hash(tuple(recv_req.image_hashes))
req.pad_value = [
(image_hash) % self.model_config.vocab_size,
(image_hash >> 16) % self.model_config.vocab_size,
(image_hash >> 32) % self.model_config.vocab_size,
(image_hash >> 64) % self.model_config.vocab_size,
]
req.image_sizes = recv_req.image_sizes
(
req.origin_input_ids,
req.image_offsets,
) = self.model_runner.model.pad_input_ids(
req.origin_input_ids_unpadded,
req.pad_value,
req.pixel_values,
req.image_sizes,
)
# Only when pixel values is not None we have modalities
req.modalities = recv_req.modalites
req.return_logprob = recv_req.return_logprob
req.logprob_start_len = recv_req.logprob_start_len
req.top_logprobs_num = recv_req.top_logprobs_num
req.stream = recv_req.stream
req.pixel_values = recv_req.pixel_values
if req.pixel_values is not None:
# Use image hash as fake token_ids, which is then used
# for prefix matching
image_hash = hash(tuple(recv_req.image_hashes))
req.pad_value = [
(image_hash) % self.model_config.vocab_size,
(image_hash >> 16) % self.model_config.vocab_size,
(image_hash >> 32) % self.model_config.vocab_size,
(image_hash >> 64) % self.model_config.vocab_size,
]
req.image_sizes = recv_req.image_sizes
(
req.origin_input_ids,
req.image_offsets,
) = self.model_runner.model.pad_input_ids(
req.origin_input_ids_unpadded,
req.pad_value,
req.pixel_values,
req.image_sizes,
)
# Only when pixel values is not None we have modalities
req.modalities = recv_req.modalites
req.return_logprob = recv_req.return_logprob
req.logprob_start_len = recv_req.logprob_start_len
req.top_logprobs_num = recv_req.top_logprobs_num
req.stream = recv_req.stream

# Init regex fsm fron json
# Init regex FSM
if (
req.sampling_params.json_schema is not None
or req.sampling_params.regex is not None
):
if req.sampling_params.json_schema is not None:
req.regex_fsm, computed_regex_string = self.json_fsm_cache.query(
req.sampling_params.json_schema
req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
("json", req.sampling_params.json_schema)
)
if not self.disable_regex_jump_forward:
req.jump_forward_map = self.jump_forward_cache.query(
computed_regex_string
)

# Init regex fsm
elif req.sampling_params.regex is not None:
req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
if not self.disable_regex_jump_forward:
req.jump_forward_map = self.jump_forward_cache.query(
req.sampling_params.regex
)
req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
("regex", req.sampling_params.regex)
)
if not self.disable_regex_jump_forward:
req.jump_forward_map = self.jump_forward_cache.query(
computed_regex_string
)

# Truncate prompts that are too long
if len(req.origin_input_ids) >= self.max_req_input_len:
@@ -390,16 +380,32 @@ class ModelTpServer:
"the max context length. Truncated!!!"
)
req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
req.sampling_params.max_new_tokens = min(
(
req.sampling_params.max_new_tokens
if req.sampling_params.max_new_tokens is not None
else 1 << 30
),
self.max_req_input_len - 1 - len(req.origin_input_ids),
)

if self.model_runner.is_generation:
req.sampling_params.max_new_tokens = min(
(
req.sampling_params.max_new_tokens
if req.sampling_params.max_new_tokens is not None
else 1 << 30
),
self.max_req_input_len - 1 - len(req.origin_input_ids),
self.waiting_queue.append(req)

def handle_embedding_request(
self,
recv_req: TokenizedEmbeddingReqInput,
):
req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
req.tokenizer = self.tokenizer
req.sampling_params = recv_req.sampling_params

# Truncate prompts that are too long
if len(req.origin_input_ids) >= self.max_req_input_len:
logger.warn(
"Request length is longer than the KV cache pool size or "
"the max context length. Truncated!!!"
)
req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]

self.waiting_queue.append(req)

@@ -892,7 +898,6 @@ def run_tp_server(
tp_rank: int,
server_args: ServerArgs,
nccl_port: int,
model_override_args: dict,
):
"""Run a tensor parallel model server."""
configure_logger(server_args, prefix=f" TP{tp_rank}")

@@ -903,7 +908,6 @@ def run_tp_server(
tp_rank,
server_args,
nccl_port,
model_override_args,
)
tp_cpu_group = model_server.model_runner.tp_group.cpu_group

@@ -920,14 +924,13 @@ def launch_tp_servers(
tp_rank_range: List[int],
server_args: ServerArgs,
nccl_port: int,
model_override_args: dict,
):
"""Launch multiple tensor parallel servers."""
procs = []
for i in tp_rank_range:
proc = multiprocessing.Process(
target=run_tp_server,
args=(gpu_ids[i], i, server_args, nccl_port, model_override_args),
args=(gpu_ids[i], i, server_args, nccl_port),
)
proc.start()
procs.append(proc)
@@ -18,6 +18,7 @@ limitations under the License.
import gc
import importlib
import importlib.resources
import json
import logging
import pkgutil
from functools import lru_cache

@@ -272,7 +272,6 @@ async def retrieve_file_content(file_id: str):

def launch_server(
server_args: ServerArgs,
model_override_args: Optional[dict] = None,
pipe_finish_writer: Optional[mp.connection.Connection] = None,
):
"""Launch an HTTP server."""

@@ -317,7 +316,6 @@ def launch_server(
tp_rank_range,
server_args,
ports[3],
model_override_args,
)

try:

@@ -328,7 +326,7 @@ def launch_server(
return

# Launch processes
tokenizer_manager = TokenizerManager(server_args, port_args, model_override_args)
tokenizer_manager = TokenizerManager(server_args, port_args)
if server_args.chat_template:
load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)

@@ -341,7 +339,7 @@ def launch_server(

proc_controller = mp.Process(
target=start_controller_process,
args=(server_args, port_args, pipe_controller_writer, model_override_args),
args=(server_args, port_args, pipe_controller_writer),
)
proc_controller.start()

@@ -501,7 +499,6 @@ class Runtime:
def __init__(
self,
log_level: str = "error",
model_override_args: Optional[dict] = None,
*args,
**kwargs,
):

@@ -525,7 +522,7 @@ class Runtime:

proc = mp.Process(
target=launch_server,
args=(self.server_args, model_override_args, pipe_writer),
args=(self.server_args, pipe_writer),
)
proc.start()
pipe_writer.close()
@@ -76,6 +76,14 @@ class ServerArgs:
dp_size: int = 1
load_balance_method: str = "round_robin"

# Distributed args
nccl_init_addr: Optional[str] = None
nnodes: int = 1
node_rank: Optional[int] = None

# Model override args in JSON
json_model_override_args: str = "{}"

# Optimization/debug options
disable_flashinfer: bool = False
disable_flashinfer_sampling: bool = False

@@ -91,14 +99,6 @@ class ServerArgs:
enable_mla: bool = False
triton_attention_reduce_in_fp32: bool = False

# Distributed args
nccl_init_addr: Optional[str] = None
nnodes: int = 1
node_rank: Optional[int] = None

# Model override args in JSON
json_model_override_args: Optional[dict] = None

def __post_init__(self):
if self.tokenizer_path is None:
self.tokenizer_path = self.model_path

@@ -385,6 +385,14 @@ class ServerArgs:
)
parser.add_argument("--node-rank", type=int, help="The node rank.")

# Model override args
parser.add_argument(
"--json-model-override-args",
type=str,
help="A dictionary in JSON string format used to override default model configurations.",
default=ServerArgs.json_model_override_args,
)

# Optimization/debug options
parser.add_argument(
"--disable-flashinfer",

@@ -459,22 +467,10 @@ class ServerArgs:
help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
)

# Model override args
parser.add_argument(
"--json-model-override-args",
type=str,
help="A dictionary in JSON string format used to override default model configurations.",
)

@classmethod
def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size
args.dp_size = args.data_parallel_size
args.json_model_override_args = (
json.loads(args.json_model_override_args)
if args.json_model_override_args
else None
)
attrs = [attr.name for attr in dataclasses.fields(cls)]
return cls(**{attr: getattr(args, attr) for attr in attrs})

@@ -498,7 +494,7 @@ class ServerArgs:
self.disable_flashinfer = False

def prepare_server_args(args: argparse.Namespace) -> ServerArgs:
def prepare_server_args(argv: List[str]) -> ServerArgs:
"""
Prepare the server arguments from the command line arguments.

@@ -511,7 +507,7 @@ def prepare_server_args(args: argparse.Namespace) -> ServerArgs:
"""
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
raw_args = parser.parse_args(args)
raw_args = parser.parse_args(argv)
server_args = ServerArgs.from_cli_args(raw_args)
return server_args
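A hedged sketch of the new call path (the model path and override key below are illustrative examples, not taken from this commit): model overrides now travel as a JSON string that defaults to "{}", and prepare_server_args takes the raw argv list, so a launcher can do roughly:

    # Illustrative only -- model path and override key are example values.
    import json
    from sglang.srt.server import launch_server, prepare_server_args

    server_args = prepare_server_args(
        [
            "--model-path", "meta-llama/Llama-2-7b-hf",
            "--json-model-override-args", '{"model_max_length": 8192}',
        ]
    )
    overrides = json.loads(server_args.json_model_override_args)  # always valid; default is "{}"
    launch_server(server_args)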
python/sglang/test/few_shot_gsm8k.py (new file, 132 lines)

@@ -0,0 +1,132 @@
"""
Run few-shot GSM-8K evaluation.

Usage:
python3 -m sglang.test.few_shot_gsm8k --num-questions 200
"""

import argparse
import ast
import re
import time

import numpy as np

from sglang.api import set_default_backend
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl

INVALID = -9999999

def get_one_example(lines, i, include_answer):
ret = "Question: " + lines[i]["question"] + "\nAnswer:"
if include_answer:
ret += " " + lines[i]["answer"]
return ret

def get_few_shot_examples(lines, k):
ret = ""
for i in range(k):
ret += get_one_example(lines, i, True) + "\n\n"
return ret

def get_answer_value(answer_str):
answer_str = answer_str.replace(",", "")
numbers = re.findall(r"\d+", answer_str)
if len(numbers) < 1:
return INVALID
try:
return ast.literal_eval(numbers[-1])
except SyntaxError:
return INVALID

def main(args):
# Select backend
set_default_backend(RuntimeEndpoint(f"{args.host}:{args.port}"))

# Read data
url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
filename = download_and_cache_file(url)
lines = list(read_jsonl(filename))

# Construct prompts
num_questions = args.num_questions
num_shots = args.num_shots
few_shot_examples = get_few_shot_examples(lines, num_shots)

questions = []
labels = []
for i in range(len(lines[:num_questions])):
questions.append(get_one_example(lines, i, False))
labels.append(get_answer_value(lines[i]["answer"]))
assert all(l != INVALID for l in labels)
arguments = [{"question": q} for q in questions]

#####################################
######### SGL Program Begin #########
#####################################

import sglang as sgl

@sgl.function
def few_shot_gsm8k(s, question):
s += few_shot_examples + question
s += sgl.gen(
"answer", max_tokens=512, stop=["Question", "Assistant:", "<|separator|>"]
)

#####################################
########## SGL Program End ##########
#####################################

# Run requests
tic = time.time()
states = few_shot_gsm8k.run_batch(
arguments,
temperature=0,
num_threads=args.parallel,
progress_bar=True,
)
latency = time.time() - tic

preds = []
for i in range(len(states)):
preds.append(get_answer_value(states[i]["answer"]))

# print(f"{preds=}")
# print(f"{labels=}")

# Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels))
invalid = np.mean(np.array(preds) == INVALID)

# Compute speed
num_output_tokens = sum(
s.get_meta_info("answer")["completion_tokens"] for s in states
)
output_throughput = num_output_tokens / latency

# Print results
print(f"Accuracy: {acc:.3f}")
print(f"Invalid: {invalid:.3f}")
print(f"Latency: {latency:.3f} s")
print(f"Output throughput: {output_throughput:.3f} token/s")

# Dump results
dump_state_text("tmp_output_gsm8k.txt", states)

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num-shots", type=int, default=5)
parser.add_argument("--data-path", type=str, default="test.jsonl")
parser.add_argument("--num-questions", type=int, default=200)
parser.add_argument("--parallel", type=int, default=128)
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=30000)
args = parser.parse_args()
main(args)
@@ -7,7 +7,7 @@ import time
import numpy as np

import sglang as sgl
from sglang.utils import fetch_and_cache_jsonl
from sglang.utils import download_and_cache_file, read_jsonl

def test_few_shot_qa():

@@ -456,10 +456,6 @@ def test_chat_completion_speculative():
def test_hellaswag_select():
"""Benchmark the accuracy of sgl.select on the HellaSwag dataset."""

url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
lines = fetch_and_cache_jsonl(url)

# Construct prompts
def get_one_example(lines, i, include_answer):
ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
if include_answer:

@@ -472,6 +468,12 @@ def test_hellaswag_select():
ret += get_one_example(lines, i, True) + "\n\n"
return ret

# Read data
url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
filename = download_and_cache_file(url)
lines = list(read_jsonl(filename))

# Construct prompts
num_questions = 200
num_shots = 20
few_shot_examples = get_few_shot_examples(lines, num_shots)
@@ -12,7 +12,7 @@ import urllib.request
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from json import dumps
from typing import Union
from typing import Optional, Union

import numpy as np
import requests

@@ -38,13 +38,11 @@ def is_same_type(values: list):

def read_jsonl(filename: str):
"""Read a JSONL file."""
rets = []
with open(filename) as fin:
for line in fin:
if line.startswith("#"):
continue
rets.append(json.loads(line))
return rets
yield json.loads(line)

def dump_state_text(filename: str, states: list, mode: str = "w"):

@@ -264,38 +262,35 @@ class LazyImport:
return module(*args, **kwargs)

def fetch_and_cache_jsonl(url, cache_file="cached_data.jsonl"):
"""Read and cache a jsonl file from a url."""
def download_and_cache_file(url: str, filename: Optional[str] = None):
"""Read and cache a file from a url."""
if filename is None:
filename = os.path.join("/tmp", url.split("/")[-1])

# Check if the cache file already exists
if os.path.exists(cache_file):
print("Loading data from cache...")
with open(cache_file, "r") as f:
data = [json.loads(line) for line in f]
else:
print("Downloading data from URL...")
# Stream the response to show the progress bar
response = requests.get(url, stream=True)
response.raise_for_status()  # Check for request errors
if os.path.exists(filename):
return filename

# Total size of the file in bytes
total_size = int(response.headers.get("content-length", 0))
chunk_size = 1024  # Download in chunks of 1KB
print(f"Downloading from {url} to {filename}")

# Use tqdm to display the progress bar
with open(cache_file, "wb") as f, tqdm(
desc=cache_file,
total=total_size,
unit="B",
unit_scale=True,
unit_divisor=1024,
) as bar:
for chunk in response.iter_content(chunk_size=chunk_size):
f.write(chunk)
bar.update(len(chunk))
# Stream the response to show the progress bar
response = requests.get(url, stream=True)
response.raise_for_status()  # Check for request errors

# Convert the data to a list of dictionaries
with open(cache_file, "r") as f:
data = [json.loads(line) for line in f]
# Total size of the file in bytes
total_size = int(response.headers.get("content-length", 0))
chunk_size = 1024  # Download in chunks of 1KB

return data
# Use tqdm to display the progress bar
with open(filename, "wb") as f, tqdm(
desc=filename,
total=total_size,
unit="B",
unit_scale=True,
unit_divisor=1024,
) as bar:
for chunk in response.iter_content(chunk_size=chunk_size):
f.write(chunk)
bar.update(len(chunk))

return filename
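A small usage sketch mirroring few_shot_gsm8k.py above (the URL is the GSM-8K test split used there): read_jsonl is now a generator, so callers that need random access wrap it in list(), and download_and_cache_file reuses the copy under /tmp on later runs.

    from sglang.utils import download_and_cache_file, read_jsonl

    url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
    filename = download_and_cache_file(url)  # cached as /tmp/test.jsonl after the first call
    lines = list(read_jsonl(filename))       # materialize the generator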