From 3bc01ac1377001540b38fd8ccb470b29c0e74804 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 3 Jun 2024 18:11:34 -0700 Subject: [PATCH] [Minor] improve code style --- .../latency_throughput/bench_throughput.py | 8 +-- benchmark/latency_throughput/test_latency.py | 6 +- python/sglang/srt/hf_transformers_utils.py | 58 +++++++++++++++++-- python/sglang/test/test_utils.py | 12 ++-- 4 files changed, 67 insertions(+), 17 deletions(-) diff --git a/benchmark/latency_throughput/bench_throughput.py b/benchmark/latency_throughput/bench_throughput.py index 323a3f449..845cb0acb 100644 --- a/benchmark/latency_throughput/bench_throughput.py +++ b/benchmark/latency_throughput/bench_throughput.py @@ -149,12 +149,12 @@ async def send_request( "inputs": prompt, "parameters": params, } - elif backend == "xinfer": + elif backend == "ginfer": pass else: raise ValueError(f"Unknown backend: {backend}") - if backend != "xinfer": + if backend != "ginfer": timeout = aiohttp.ClientTimeout(total=3 * 3600) async with aiohttp.ClientSession(timeout=timeout) as session: while True: @@ -172,7 +172,7 @@ async def send_request( print(output) else: import grpc - from xlm.proto import sampler_pb2, sampler_pb2_grpc + from ginfer import sampler_pb2, sampler_pb2_grpc api_url = api_url.replace("http://", "").replace("/generate", "") sampler_channel = grpc.aio.insecure_channel(api_url) @@ -283,7 +283,7 @@ if __name__ == "__main__": "--backend", type=str, default="srt", - choices=["vllm", "tgi", "srt", "lightllm", "xinfer"], + choices=["vllm", "tgi", "srt", "lightllm", "ginfer"], ) parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--port", type=int, default=30000) diff --git a/benchmark/latency_throughput/test_latency.py b/benchmark/latency_throughput/test_latency.py index 37ab6aef6..593df054c 100644 --- a/benchmark/latency_throughput/test_latency.py +++ b/benchmark/latency_throughput/test_latency.py @@ -18,7 +18,7 @@ if __name__ == "__main__": args.port = 21000 elif args.backend == "lightllm": args.port = 22000 - elif args.backend == "xinfer": + elif args.backend == "ginfer": args.port = 9988 else: raise ValueError(f"Invalid backend: {args.backend}") @@ -60,9 +60,9 @@ if __name__ == "__main__": "max_tokens": max_new_tokens, }, ) - elif args.backend == "xinfer": + elif args.backend == "ginfer": import grpc - from xlm.proto import sampler_pb2, sampler_pb2_grpc + from ginfer import sampler_pb2, sampler_pb2_grpc sampler_channel = grpc.insecure_channel(url.replace("http://", "")) sampler = sampler_pb2_grpc.SamplerStub(sampler_channel) diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index f5c1654aa..ec9b1dccd 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -3,7 +3,8 @@ import json import os import warnings -from typing import List, Optional, Tuple, Union +import functools +from typing import Optional, Union, AbstractSet, Collection, Literal from huggingface_hub import snapshot_download from transformers import ( @@ -177,10 +178,57 @@ def get_processor( class TiktokenTokenizer: def __init__(self, tokenizer_path): - import xlm.tokenizers.tiktoken_wrapper as tiktoken_wrapper - tokenizer = tiktoken_wrapper.Encoding.from_xtok_json("xtok-json", tokenizer_path) + import tiktoken + PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" + + # Read JSON + name = "tmp-json" + with open(tokenizer_path, "rb") as fin: + tok_dict = json.load(fin) + + mergeable_ranks = { + bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"] + } + special_tokens = { + bytes(item["bytes"]).decode(): item["token"] for item in tok_dict["special_tokens"] + } + assert tok_dict["word_split"] == "V1" + + kwargs = { + "name": name, + "pat_str": tok_dict.get("pat_str", PAT_STR_B), + "mergeable_ranks": mergeable_ranks, + "special_tokens": special_tokens, + } + if "default_allowed_special" in tok_dict: + default_allowed_special = set( + [bytes(bytes_list).decode() for bytes_list in tok_dict["default_allowed_special"]] + ) + else: + default_allowed_special = None + if "vocab_size" in tok_dict: + kwargs["explicit_n_vocab"] = tok_dict["vocab_size"] + + tokenizer = tiktoken.Encoding(**kwargs) + tokenizer._default_allowed_special = default_allowed_special or set() + + def encode_patched( + self, + text: str, + *, + allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), # noqa: B006 + disallowed_special: Union[Literal["all"], Collection[str]] = "all", + ) -> list[int]: + if isinstance(allowed_special, set): + allowed_special |= self._default_allowed_special + return tiktoken.Encoding.encode( + self, text, allowed_special=allowed_special, disallowed_special=disallowed_special + ) + tokenizer.encode = functools.partial(encode_patched, tokenizer) + + # Convert to HF interface self.tokenizer = tokenizer - self.eos_token_id = tokenizer.eos_token + self.eos_token_id = tokenizer._special_tokens["<|eos|>"] self.vocab_size = tokenizer.n_vocab def encode(self, x, add_special_tokens=False): @@ -190,6 +238,8 @@ class TiktokenTokenizer: return self.tokenizer.decode(x) def batch_decode(self, batch, skip_special_tokens=True, spaces_between_special_tokens=False): + if isinstance(batch[0], int): + batch = [[x] for x in batch] return self.tokenizer.decode_batch(batch) def convert_ids_to_tokens(self, index): diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 18d0f6c32..693bade6f 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -88,9 +88,9 @@ def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None): return pred -def call_generate_xinfer(prompt, temperature, max_tokens, stop=None, url=None): +def call_generate_ginfer(prompt, temperature, max_tokens, stop=None, url=None): import grpc - from xlm.proto import sampler_pb2, sampler_pb2_grpc + from ginfer import sampler_pb2, sampler_pb2_grpc sampler_channel = grpc.insecure_channel(url.replace("http://", "")) sampler = sampler_pb2_grpc.SamplerStub(sampler_channel) @@ -255,7 +255,7 @@ def add_common_other_args_and_parse(parser): "vllm", "outlines", "lightllm", - "xinfer", + "ginfer", "guidance", "lmql", "srt-raw", @@ -276,7 +276,7 @@ def add_common_other_args_and_parse(parser): "lightllm": 22000, "lmql": 23000, "srt-raw": 30000, - "xinfer": 9988, + "ginfer": 9988, } args.port = default_port.get(args.backend, None) return args @@ -312,8 +312,8 @@ def _get_call_generate(args): return partial(call_generate_vllm, url=f"{args.host}:{args.port}/generate") elif args.backend == "srt-raw": return partial(call_generate_srt_raw, url=f"{args.host}:{args.port}/generate") - elif args.backend == "xinfer": - return partial(call_generate_xinfer, url=f"{args.host}:{args.port}") + elif args.backend == "ginfer": + return partial(call_generate_ginfer, url=f"{args.host}:{args.port}") elif args.backend == "outlines": return partial(call_generate_outlines, url=f"{args.host}:{args.port}/generate") elif args.backend == "guidance":