Fix logit processor bugs (#427)

This commit is contained in:
Lianmin Zheng
2024-05-12 04:54:07 -07:00
committed by GitHub
parent 7023f413c6
commit aee4f523cf
26 changed files with 166 additions and 257 deletions

View File

@@ -297,7 +297,6 @@ curl http://localhost:30000/generate \
Learn more about the argument format [here](docs/sampling_params.md). Learn more about the argument format [here](docs/sampling_params.md).
### OpenAI Compatible API ### OpenAI Compatible API
In addition, the server supports an experimental OpenAI-compatible API. In addition, the server supports an experimental OpenAI-compatible API.
```python ```python
@@ -386,7 +385,6 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
Instructions for supporting a new model are [here](https://github.com/sgl-project/sglang/blob/main/docs/model_support.md). Instructions for supporting a new model are [here](https://github.com/sgl-project/sglang/blob/main/docs/model_support.md).
## Benchmark And Performance ## Benchmark And Performance
- Llama-7B on NVIDIA A10G, FP16, Tensor Parallelism=1 - Llama-7B on NVIDIA A10G, FP16, Tensor Parallelism=1
![llama_7b](assets/llama_7b.jpg) ![llama_7b](assets/llama_7b.jpg)
@@ -410,7 +408,4 @@ https://github.com/sgl-project/sglang/issues/157
} }
``` ```
[![Paper page](https://huggingface.co/datasets/huggingface/badges/resolve/main/paper-page-md.svg)](https://huggingface.co/papers/2312.07104)
We learned from the design and reused some code of the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), [LMQL](https://github.com/eth-sri/lmql). We learned from the design and reused some code of the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), [LMQL](https://github.com/eth-sri/lmql).

View File

@@ -1,5 +1,6 @@
"""Some Public API Definitions""" """Some Public API Definitions"""
import os
import re import re
from typing import Callable, List, Optional, Union from typing import Callable, List, Optional, Union
@@ -31,6 +32,7 @@ def function(
def Runtime(*args, **kwargs): def Runtime(*args, **kwargs):
# Avoid importing unnecessary dependency # Avoid importing unnecessary dependency
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
from sglang.srt.server import Runtime from sglang.srt.server import Runtime
return Runtime(*args, **kwargs) return Runtime(*args, **kwargs)

View File

@@ -14,7 +14,7 @@ except ImportError as e:
class Anthropic(BaseBackend): class Anthropic(BaseBackend):
def __init__(self, model_name): def __init__(self, model_name, *args, **kwargs):
super().__init__() super().__init__()
if isinstance(anthropic, Exception): if isinstance(anthropic, Exception):
@@ -22,6 +22,7 @@ class Anthropic(BaseBackend):
self.model_name = model_name self.model_name = model_name
self.chat_template = get_chat_template("claude") self.chat_template = get_chat_template("claude")
self.client = anthropic.Anthropic(*args, **kwargs)
def get_chat_template(self): def get_chat_template(self):
return self.chat_template return self.chat_template
@@ -41,7 +42,7 @@ class Anthropic(BaseBackend):
else: else:
system = "" system = ""
ret = anthropic.Anthropic().messages.create( ret = self.client.messages.create(
model=self.model_name, model=self.model_name,
system=system, system=system,
messages=messages, messages=messages,
@@ -66,7 +67,7 @@ class Anthropic(BaseBackend):
else: else:
system = "" system = ""
with anthropic.Anthropic().messages.stream( with self.client.messages.stream(
model=self.model_name, model=self.model_name,
system=system, system=system,
messages=messages, messages=messages,

View File

@@ -228,7 +228,7 @@ class OpenAI(BaseBackend):
prompt_tokens.append(ret_token) prompt_tokens.append(ret_token)
decision = choices[np.argmax(scores)] decision = choices[np.argmax(scores)]
return decision, scores, scores return decision, scores, None, None
def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs): def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):

View File

@@ -220,7 +220,6 @@ class RuntimeEndpoint(BaseBackend):
"sampling_params": {"max_new_tokens": 0}, "sampling_params": {"max_new_tokens": 0},
"return_logprob": True, "return_logprob": True,
"logprob_start_len": max(prompt_len - 2, 0), "logprob_start_len": max(prompt_len - 2, 0),
"return_text_in_logprobs": True,
} }
self._add_images(s, data) self._add_images(s, data)
res = http_request( res = http_request(

View File

@@ -42,26 +42,29 @@ class LogitsProcessor(nn.Module):
for i in range(all_logprobs.shape[0]): for i in range(all_logprobs.shape[0]):
k = input_metadata.top_logprobs_nums[i] k = input_metadata.top_logprobs_nums[i]
t = all_logprobs[i].topk(k) t = all_logprobs[i].topk(k)
v_cpu = t.values.cpu().tolist() v_cpu = t.values.tolist()
p_cpu = t.indices.cpu().tolist() p_cpu = t.indices.tolist()
decode_top_logprobs.append(list(zip(v_cpu, p_cpu))) decode_top_logprobs.append(list(zip(v_cpu, p_cpu)))
return None, decode_top_logprobs return None, decode_top_logprobs
else: else:
prefill_top_logprobs, decode_top_logprobs = [], [] prefill_top_logprobs, decode_top_logprobs = [], []
pt = 0 pt = 0
# NOTE: the GPU-CPU overhead can be reduced # NOTE: the GPU-CPU overhead can be reduced
extend_seq_lens_cpu = input_metadata.extend_seq_lens extend_seq_lens_cpu = input_metadata.extend_seq_lens.cpu().numpy()
for i in range(len(input_metadata.extend_seq_lens)): for i in range(len(extend_seq_lens_cpu)):
if extend_seq_lens_cpu[i] == 0: if extend_seq_lens_cpu[i] == 0:
prefill_top_logprobs.append([])
decode_top_logprobs.append([])
continue continue
k = input_metadata.top_logprobs_nums[i] k = input_metadata.top_logprobs_nums[i]
t = all_logprobs[pt : pt + extend_seq_lens_cpu[i]].topk(k) t = all_logprobs[pt : pt + extend_seq_lens_cpu[i]].topk(k)
vs_cpu = t.values.cpu().tolist() vs_cpu = t.values.tolist()
ps_cpu = t.indices.cpu().tolist() ps_cpu = t.indices.tolist()
prefill_top_logprobs.append( prefill_top_logprobs.append(
[list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)] [list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
) )
decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1]))) decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
pt += extend_seq_lens_cpu[i]
return prefill_top_logprobs, decode_top_logprobs return prefill_top_logprobs, decode_top_logprobs
def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata): def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
@@ -99,20 +102,24 @@ class LogitsProcessor(nn.Module):
all_logits = all_logits[:, : self.config.vocab_size] all_logits = all_logits[:, : self.config.vocab_size]
all_logprobs = all_logits.float() all_logprobs = all_logits.float()
all_logits = None del all_logits
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1) all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
if return_top_logprob:
prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs( prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
all_logprobs, input_metadata all_logprobs, input_metadata
) )
else:
prefill_top_logprobs = decode_top_logprobs = None
if input_metadata.forward_mode == ForwardMode.DECODE: if input_metadata.forward_mode == ForwardMode.DECODE:
last_logprobs = all_logprobs last_logprobs = all_logprobs
return last_logits, ( return last_logits, (
None, None,
None, None,
decode_top_logprobs,
None, None,
decode_top_logprobs,
last_logprobs, last_logprobs,
) )
else: else:
@@ -131,9 +138,9 @@ class LogitsProcessor(nn.Module):
) )
return last_logits, ( return last_logits, (
prefill_token_logprobs, prefill_token_logprobs,
normalized_prompt_logprobs,
prefill_top_logprobs, prefill_top_logprobs,
decode_top_logprobs, decode_top_logprobs,
normalized_prompt_logprobs,
last_logprobs, last_logprobs,
) )

View File

@@ -25,7 +25,6 @@ class GenerateReqInput:
return_text_in_logprobs: bool = False return_text_in_logprobs: bool = False
# Whether to stream output # Whether to stream output
stream: bool = False stream: bool = False
# TODO: make all parameters a Union[List[T], T] to allow for batched requests
def post_init(self): def post_init(self):
is_single = isinstance(self.text, str) is_single = isinstance(self.text, str)

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, auto from enum import IntEnum, auto
from typing import List from typing import List
import numpy as np import numpy as np
@@ -9,15 +9,15 @@ from sglang.srt.managers.router.radix_cache import RadixCache
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
class ForwardMode(Enum): class ForwardMode(IntEnum):
PREFILL = auto() PREFILL = auto()
EXTEND = auto() EXTEND = auto()
DECODE = auto() DECODE = auto()
class FinishReason(Enum): class FinishReason(IntEnum):
LENGTH = auto()
EOS_TOKEN = auto() EOS_TOKEN = auto()
LENGTH = auto()
STOP_STR = auto() STOP_STR = auto()
@@ -31,6 +31,7 @@ class Req:
# Since jump forward may retokenize the prompt with partial outputs, # Since jump forward may retokenize the prompt with partial outputs,
# we maintain the original prompt length to report the correct usage. # we maintain the original prompt length to report the correct usage.
self.prompt_tokens = len(input_ids) self.prompt_tokens = len(input_ids)
# The number of decoded tokens for token usage report. Note that # The number of decoded tokens for token usage report. Note that
# this does not include the jump forward tokens. # this does not include the jump forward tokens.
self.completion_tokens_wo_jump_forward = 0 self.completion_tokens_wo_jump_forward = 0
@@ -41,12 +42,11 @@ class Req:
self.image_offset = 0 self.image_offset = 0
self.pad_value = None self.pad_value = None
# Sampling parameters
self.sampling_params = None self.sampling_params = None
self.return_logprob = False
self.logprob_start_len = 0
self.top_logprobs_num = 0
self.stream = False self.stream = False
# Check finish
self.tokenizer = None self.tokenizer = None
self.finished = False self.finished = False
self.finish_reason = None self.finish_reason = None
@@ -56,13 +56,17 @@ class Req:
self.prefix_indices = [] self.prefix_indices = []
self.last_node = None self.last_node = None
# Logprobs
self.return_logprob = False
self.logprob_start_len = 0
self.top_logprobs_num = 0
self.normalized_prompt_logprob = None
self.prefill_token_logprobs = None self.prefill_token_logprobs = None
self.decode_token_logprobs = None self.decode_token_logprobs = None
self.normalized_prompt_logprob = None
self.prefill_top_logprobs = None self.prefill_top_logprobs = None
self.decode_top_logprobs = None self.decode_top_logprobs = None
# For constrained decoding # Constrained decoding
self.regex_fsm = None self.regex_fsm = None
self.regex_fsm_state = 0 self.regex_fsm_state = 0
self.jump_forward_map = None self.jump_forward_map = None
@@ -165,8 +169,8 @@ class Batch:
out_cache_cont_end: torch.Tensor = None out_cache_cont_end: torch.Tensor = None
# for processing logprobs # for processing logprobs
top_logprobs_nums: List[int] = None
return_logprob: bool = False return_logprob: bool = False
top_logprobs_nums: List[int] = None
# for multimodal # for multimodal
pixel_values: List[torch.Tensor] = None pixel_values: List[torch.Tensor] = None
@@ -321,8 +325,8 @@ class Batch:
) )
retracted_reqs = [] retracted_reqs = []
seq_lens_np = self.seq_lens.cpu().numpy() seq_lens_cpu = self.seq_lens.cpu().numpy()
req_pool_indices_np = self.req_pool_indices.cpu().numpy() req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
while self.token_to_kv_pool.available_size() < len(self.reqs): while self.token_to_kv_pool.available_size() < len(self.reqs):
idx = sorted_indices.pop() idx = sorted_indices.pop()
req = self.reqs[idx] req = self.reqs[idx]
@@ -338,8 +342,8 @@ class Batch:
# TODO: apply more fine-grained retraction # TODO: apply more fine-grained retraction
token_indices = self.req_to_token_pool.req_to_token[ token_indices = self.req_to_token_pool.req_to_token[
req_pool_indices_np[idx] req_pool_indices_cpu[idx]
][: seq_lens_np[idx]] ][: seq_lens_cpu[idx]]
self.token_to_kv_pool.dec_refs(token_indices) self.token_to_kv_pool.dec_refs(token_indices)
self.filter_batch(sorted_indices) self.filter_batch(sorted_indices)
@@ -363,7 +367,7 @@ class Batch:
# insert the old request into tree_cache # insert the old request into tree_cache
token_ids_in_memory = tuple(req.input_ids + req.output_ids)[:-1] token_ids_in_memory = tuple(req.input_ids + req.output_ids)[:-1]
if req_pool_indices_cpu is None: if req_pool_indices_cpu is None:
req_pool_indices_cpu = self.req_pool_indices.cpu().tolist() req_pool_indices_cpu = self.req_pool_indices.tolist()
req_pool_idx = req_pool_indices_cpu[i] req_pool_idx = req_pool_indices_cpu[i]
indices = self.req_to_token_pool.req_to_token[ indices = self.req_to_token_pool.req_to_token[
req_pool_idx, : len(token_ids_in_memory) req_pool_idx, : len(token_ids_in_memory)

View File

@@ -36,7 +36,9 @@ from sglang.srt.utils import (
set_random_seed, set_random_seed,
) )
logger = logging.getLogger("model_rpc") logger = logging.getLogger("model_rpc")
vllm_default_logger.setLevel(logging.WARN)
logging.getLogger("vllm.utils").setLevel(logging.WARN) logging.getLogger("vllm.utils").setLevel(logging.WARN)
@@ -54,9 +56,6 @@ class ModelRpcServer:
self.tp_size = server_args.tp_size self.tp_size = server_args.tp_size
self.schedule_heuristic = server_args.schedule_heuristic self.schedule_heuristic = server_args.schedule_heuristic
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
vllm_default_logger.setLevel(
level=getattr(logging, server_args.log_level.upper())
)
# Init model and tokenizer # Init model and tokenizer
self.model_config = ModelConfig( self.model_config = ModelConfig(
@@ -65,7 +64,7 @@ class ModelRpcServer:
context_length=server_args.context_length, context_length=server_args.context_length,
) )
# for model end global settings # For model end global settings
server_args_dict = { server_args_dict = {
"enable_flashinfer": server_args.enable_flashinfer, "enable_flashinfer": server_args.enable_flashinfer,
"attention_reduce_in_fp32": server_args.attention_reduce_in_fp32, "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
@@ -164,7 +163,7 @@ class ModelRpcServer:
logger.info("Cache flushed successfully!") logger.info("Cache flushed successfully!")
else: else:
warnings.warn( warnings.warn(
"Cache not flushed because there are pending requests. " f"Cache not flushed because there are pending requests. "
f"#queue-req: {len(self.forward_queue)}, " f"#queue-req: {len(self.forward_queue)}, "
f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}" f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
) )
@@ -386,12 +385,12 @@ class ModelRpcServer:
f"#running_req: {running_req}. " f"#running_req: {running_req}. "
f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%." f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
) )
logger.debug( #logger.debug(
f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. " # f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. " # f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. " # f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. " # f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
) #)
new_batch = Batch.init_new( new_batch = Batch.init_new(
can_run_list, can_run_list,
@@ -408,47 +407,41 @@ class ModelRpcServer:
self.model_config.vocab_size, self.int_token_logit_bias self.model_config.vocab_size, self.int_token_logit_bias
) )
prefill_token_logprobs = None
if batch.extend_num_tokens != 0: if batch.extend_num_tokens != 0:
# Forward # Forward
logits, ( logits, (
prefill_token_logprobs, prefill_token_logprobs,
normalized_prompt_logprobs,
prefill_top_logprobs, prefill_top_logprobs,
decode_top_logprobs, decode_top_logprobs,
normalized_prompt_logprobs,
last_logprobs, last_logprobs,
) = self.model_runner.forward(batch, ForwardMode.EXTEND) ) = self.model_runner.forward(batch, ForwardMode.EXTEND)
if prefill_token_logprobs is not None: if prefill_token_logprobs is not None:
prefill_token_logprobs = prefill_token_logprobs.cpu().tolist() prefill_token_logprobs = prefill_token_logprobs.tolist()
normalized_prompt_logprobs = normalized_prompt_logprobs.cpu().tolist() normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()
next_token_ids, _ = batch.sample(logits) next_token_ids, _ = batch.sample(logits)
next_token_ids = next_token_ids.cpu().tolist()
else:
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
(
logits,
prefill_token_logprobs,
normalized_prompt_logprobs,
last_logprobs,
) = (None,) * 4
# Only batch transfer the selected logprobs of the next token to CPU to reduce overhead. # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
reqs = batch.reqs
last_token_logprobs = None
if last_logprobs is not None: if last_logprobs is not None:
last_token_logprobs = ( last_token_logprobs = (
last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist() last_logprobs[torch.arange(len(batch.reqs)), next_token_ids].tolist()
) )
next_token_ids = next_token_ids.tolist()
else:
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
# Check finish condition # Check finish condition
pt = 0 pt = 0
for i, req in enumerate(reqs): for i, req in enumerate(batch.reqs):
req.completion_tokens_wo_jump_forward += 1 req.completion_tokens_wo_jump_forward += 1
req.output_ids = [next_token_ids[i]] req.output_ids = [next_token_ids[i]]
req.check_finished() req.check_finished()
if prefill_token_logprobs is not None: if req.return_logprob:
req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored. # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
req.prefill_token_logprobs = list( req.prefill_token_logprobs = list(
zip( zip(
@@ -463,11 +456,13 @@ class ModelRpcServer:
req.decode_token_logprobs = [ req.decode_token_logprobs = [
(last_token_logprobs[i], next_token_ids[i]) (last_token_logprobs[i], next_token_ids[i])
] ]
if req.top_logprobs_num > 0:
req.prefill_top_logprobs = prefill_top_logprobs[i] req.prefill_top_logprobs = prefill_top_logprobs[i]
if req.logprob_start_len == 0: if req.logprob_start_len == 0:
req.prefill_top_logprobs = [None] + req.prefill_top_logprobs req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
req.decode_top_logprobs = [decode_top_logprobs[i]] req.decode_top_logprobs = [decode_top_logprobs[i]]
req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
pt += req.extend_input_len pt += req.extend_input_len
self.handle_finished_requests(batch) self.handle_finished_requests(batch)
@@ -520,29 +515,29 @@ class ModelRpcServer:
logits, ( logits, (
_, _,
_, _,
decode_top_logprobs,
_, _,
decode_top_logprobs,
last_logprobs, last_logprobs,
) = self.model_runner.forward(batch, ForwardMode.DECODE) ) = self.model_runner.forward(batch, ForwardMode.DECODE)
next_token_ids, _ = batch.sample(logits) next_token_ids, _ = batch.sample(logits)
next_token_ids = next_token_ids.cpu().tolist() next_token_ids = next_token_ids.tolist()
# Only batch transfer the selected logprobs of the next token to CPU to reduce overhead. # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
reqs = batch.reqs
new_token_logprobs = None
if last_logprobs is not None: if last_logprobs is not None:
new_token_logprobs = last_logprobs[ new_token_logprobs = last_logprobs[
torch.arange(len(reqs)), next_token_ids torch.arange(len(batch.reqs)), next_token_ids
].tolist() ].tolist()
# Check finish condition # Check finish condition
for i, (req, next_token_id) in enumerate(zip(reqs, next_token_ids)): for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
req.completion_tokens_wo_jump_forward += 1 req.completion_tokens_wo_jump_forward += 1
req.output_ids.append(next_token_id) req.output_ids.append(next_token_id)
req.check_finished() req.check_finished()
if new_token_logprobs is not None: if req.return_logprob:
req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id)) req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))
if req.top_logprobs_num > 0:
req.decode_top_logprobs.append(decode_top_logprobs[i]) req.decode_top_logprobs.append(decode_top_logprobs[i])
self.handle_finished_requests(batch) self.handle_finished_requests(batch)
@@ -590,8 +585,7 @@ class ModelRpcServer:
+ len(req.output_ids) + len(req.output_ids)
- req.prompt_tokens, - req.prompt_tokens,
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward, "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
"finish_reason": str(req.finish_reason), "finish_reason": str(req.finish_reason), # FIXME: convert to the correct string
"hit_stop_str": req.hit_stop_str,
} }
if req.return_logprob: if req.return_logprob:
( (
@@ -628,7 +622,7 @@ class ModelRpcServer:
# Remove finished reqs # Remove finished reqs
if finished_indices: if finished_indices:
# Update radix cache # Update radix cache
req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist() req_pool_indices_cpu = batch.req_pool_indices.tolist()
for i in finished_indices: for i in finished_indices:
req = batch.reqs[i] req = batch.reqs[i]
req_pool_idx = req_pool_indices_cpu[i] req_pool_idx = req_pool_indices_cpu[i]

View File

@@ -29,7 +29,7 @@ QUANTIZATION_CONFIG_MAPPING = {
logger = logging.getLogger("model_runner") logger = logging.getLogger("model_runner")
# for server args in model endpoints # for server args in model endpoints
global_server_args_dict: dict = None global_server_args_dict = {}
@lru_cache() @lru_cache()
@@ -86,8 +86,8 @@ class InputMetadata:
out_cache_cont_end: torch.Tensor = None out_cache_cont_end: torch.Tensor = None
other_kv_index: torch.Tensor = None other_kv_index: torch.Tensor = None
top_logprobs_nums: List[int] = None
return_logprob: bool = False return_logprob: bool = False
top_logprobs_nums: List[int] = None
# for flashinfer # for flashinfer
qo_indptr: torch.Tensor = None qo_indptr: torch.Tensor = None
@@ -107,18 +107,20 @@ class InputMetadata:
(self.batch_size + 1,), dtype=torch.int32, device="cuda" (self.batch_size + 1,), dtype=torch.int32, device="cuda"
) )
self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0) self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
self.kv_last_page_len = torch.ones(
(self.batch_size,), dtype=torch.int32, device="cuda"
)
req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
seq_lens_cpu = self.seq_lens.cpu().numpy()
self.kv_indices = torch.cat( self.kv_indices = torch.cat(
[ [
self.req_to_token_pool.req_to_token[ self.req_to_token_pool.req_to_token[
self.req_pool_indices[i].item(), : self.seq_lens[i].item() req_pool_indices_cpu[i]: seq_lens_cpu[i]
] ]
for i in range(self.batch_size) for i in range(self.batch_size)
], ],
dim=0, dim=0,
).contiguous() ).contiguous()
self.kv_last_page_len = torch.ones(
(self.batch_size,), dtype=torch.int32, device="cuda"
)
workspace_buffer = torch.empty( workspace_buffer = torch.empty(
32 * 1024 * 1024, dtype=torch.int8, device="cuda" 32 * 1024 * 1024, dtype=torch.int8, device="cuda"
@@ -195,15 +197,15 @@ class InputMetadata:
req_pool_indices[0], seq_lens[0] - 1 req_pool_indices[0], seq_lens[0] - 1
].item() ].item()
else: else:
seq_lens_np = seq_lens.cpu().numpy() seq_lens_cpu = seq_lens.cpu().numpy()
prefix_lens_np = prefix_lens.cpu().numpy() prefix_lens_cpu = prefix_lens.cpu().numpy()
position_ids_offsets_np = position_ids_offsets.cpu().numpy() position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
positions = torch.tensor( positions = torch.tensor(
np.concatenate( np.concatenate(
[ [
np.arange( np.arange(
prefix_lens_np[i] + position_ids_offsets_np[i], prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
seq_lens_np[i] + position_ids_offsets_np[i], seq_lens_cpu[i] + position_ids_offsets_cpu[i],
) )
for i in range(batch_size) for i in range(batch_size)
], ],
@@ -229,9 +231,9 @@ class InputMetadata:
out_cache_loc=out_cache_loc, out_cache_loc=out_cache_loc,
out_cache_cont_start=out_cache_cont_start, out_cache_cont_start=out_cache_cont_start,
out_cache_cont_end=out_cache_cont_end, out_cache_cont_end=out_cache_cont_end,
top_logprobs_nums=top_logprobs_nums,
return_logprob=return_logprob,
other_kv_index=other_kv_index, other_kv_index=other_kv_index,
return_logprob=return_logprob,
top_logprobs_nums=top_logprobs_nums,
) )
if forward_mode == ForwardMode.EXTEND: if forward_mode == ForwardMode.EXTEND:

View File

@@ -185,7 +185,10 @@ class TokenizerManager:
while True: while True:
await event.wait() await event.wait()
yield state.out_list[-1] yield self.convert_logprob_style(state.out_list[-1],
obj.return_logprob,
obj.top_logprobs_num,
obj.return_text_in_logprobs)
state.out_list = [] state.out_list = []
if state.finished: if state.finished:
del self.rid_to_state[rid] del self.rid_to_state[rid]
@@ -231,16 +234,16 @@ class TokenizerManager:
rid = obj.rid[i] rid = obj.rid[i]
state = self.rid_to_state[rid] state = self.rid_to_state[rid]
await state.event.wait() await state.event.wait()
output_list.append(state.out_list[-1]) output_list.append(
self.convert_logprob_style(state.out_list[-1],
obj.return_logprob[i],
obj.top_logprobs_num[i],
obj.return_text_in_logprobs))
assert state.finished assert state.finished
del self.rid_to_state[rid] del self.rid_to_state[rid]
yield output_list yield output_list
async def detokenize(self, obj: DetokenizeReqInput):
token_texts = self.tokenizer.convert_ids_to_tokens(obj.input_ids)
return [t.decode() if isinstance(t, bytes) else t for t in token_texts]
async def flush_cache(self): async def flush_cache(self):
flush_cache_req = FlushCacheReq() flush_cache_req = FlushCacheReq()
self.send_to_router.send_pyobj(flush_cache_req) self.send_to_router.send_pyobj(flush_cache_req)
@@ -267,3 +270,37 @@ class TokenizerManager:
state.event.set() state.event.set()
else: else:
raise ValueError(f"Invalid object: {recv_obj}") raise ValueError(f"Invalid object: {recv_obj}")
def convert_logprob_style(self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs):
if return_logprob:
ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs
)
ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens(
ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
)
if top_logprobs_num > 0:
ret["meta_info"]["prefill_top_logprobs"] = self.detokenize_top_logprobs_tokens(
ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
)
ret["meta_info"]["decode_top_logprobs"] = self.detokenize_top_logprobs_tokens(
ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
)
return ret
def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
if not decode_to_text:
return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
token_ids = [tid for _, tid in token_logprobs]
token_texts = self.tokenizer.batch_decode(token_ids)
return [
(logprob, token_id, token_text)
for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
]
def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text):
for i, t in enumerate(top_logprobs):
if t:
top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
return top_logprobs

View File

@@ -1,3 +1,4 @@
"""pydantic models for OpenAI API protocol"""
import time import time
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union

View File

@@ -10,7 +10,7 @@ import threading
import time import time
from typing import List, Optional, Union from typing import List, Optional, Union
# Fix a Python bug # Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None) setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
import aiohttp import aiohttp
@@ -53,10 +53,10 @@ from sglang.srt.managers.router.manager import start_router_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import (
enable_show_time_cost,
allocate_init_ports, allocate_init_ports,
jsonify_pydantic_model,
assert_pkg_version, assert_pkg_version,
enable_show_time_cost,
jsonify_pydantic_model,
get_exception_traceback, get_exception_traceback,
API_KEY_HEADER_NAME, API_KEY_HEADER_NAME,
APIKeyValidatorMiddleware APIKeyValidatorMiddleware
@@ -99,12 +99,6 @@ async def flush_cache():
) )
async def stream_generator(obj: GenerateReqInput):
async for out in tokenizer_manager.generate_request(obj):
await handle_token_logprobs_results(obj, out)
yield out
@app.post("/generate") @app.post("/generate")
async def generate_request(obj: GenerateReqInput): async def generate_request(obj: GenerateReqInput):
obj.post_init() obj.post_init()
@@ -112,69 +106,16 @@ async def generate_request(obj: GenerateReqInput):
if obj.stream: if obj.stream:
async def stream_results(): async def stream_results():
async for out in stream_generator(obj): async for out in tokenizer_manager.generate_request(obj):
yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n" yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
return StreamingResponse(stream_results(), media_type="text/event-stream") return StreamingResponse(stream_results(), media_type="text/event-stream")
ret = await tokenizer_manager.generate_request(obj).__anext__() ret = await tokenizer_manager.generate_request(obj).__anext__()
await handle_token_logprobs_results(obj, ret)
return ret return ret
async def detokenize_logprob_tokens(token_logprobs, decode_to_text):
if not decode_to_text:
return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
token_ids = [tid for _, tid in token_logprobs]
token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
return [
(logprob, token_id, token_text)
for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
]
async def detokenize_top_logprobs_tokens(top_logprobs, decode_to_text):
for i, t in enumerate(top_logprobs):
if top_logprobs[i] is not None:
top_logprobs[i] = await detokenize_logprob_tokens(t, decode_to_text)
return top_logprobs
async def handle_token_logprobs_results(obj: GenerateReqInput, ret):
"""Handle the token logprobs results, convert token ids to text if needed.
Args:
obj (GenerateReqInput): The request object.
ret (Union[Dict, List[Dict]]): The response object.
"""
# NOTE: This is because the multiple requests in one http request.
async def convert_style(r, return_text):
r["meta_info"]["prefill_token_logprobs"] = await detokenize_logprob_tokens(
r["meta_info"]["prefill_token_logprobs"], return_text
)
r["meta_info"]["decode_token_logprobs"] = await detokenize_logprob_tokens(
r["meta_info"]["decode_token_logprobs"], return_text
)
r["meta_info"]["prefill_top_logprobs"] = await detokenize_top_logprobs_tokens(
r["meta_info"]["prefill_top_logprobs"], return_text
)
r["meta_info"]["decode_top_logprobs"] = await detokenize_top_logprobs_tokens(
r["meta_info"]["decode_top_logprobs"], return_text
)
if isinstance(obj.text, str):
if obj.return_logprob:
await convert_style(ret, obj.return_text_in_logprobs)
else:
for i, r in enumerate(ret):
if obj.return_logprob[i]:
await convert_style(r, obj.return_text_in_logprobs)
@app.post("/v1/completions") @app.post("/v1/completions")
async def v1_completions(raw_request: Request): async def v1_completions(raw_request: Request):
request_json = await raw_request.json() request_json = await raw_request.json()
@@ -203,10 +144,10 @@ async def v1_completions(raw_request: Request):
if adapted_request.stream: if adapted_request.stream:
async def gnerate_stream_resp(): async def generate_stream_resp():
stream_buffer = "" stream_buffer = ""
n_prev_token = 0 n_prev_token = 0
async for content in stream_generator(adapted_request): async for content in tokenizer_manager.generate_request(adapted_request):
text = content["text"] text = content["text"]
prompt_tokens = content["meta_info"]["prompt_tokens"] prompt_tokens = content["meta_info"]["prompt_tokens"]
completion_tokens = content["meta_info"]["completion_tokens"] completion_tokens = content["meta_info"]["completion_tokens"]
@@ -266,7 +207,7 @@ async def v1_completions(raw_request: Request):
yield f"data: {jsonify_pydantic_model(chunk)}\n\n" yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
return StreamingResponse(gnerate_stream_resp(), media_type="text/event-stream") return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")
# Non-streaming response. # Non-streaming response.
ret = await generate_request(adapted_request) ret = await generate_request(adapted_request)
@@ -384,7 +325,7 @@ async def v1_chat_completions(raw_request: Request):
is_first = True is_first = True
stream_buffer = "" stream_buffer = ""
async for content in stream_generator(adapted_request): async for content in tokenizer_manager.generate_request(adapted_request):
if is_first: if is_first:
# First chunk with role # First chunk with role
is_first = False is_first = False

View File

@@ -241,7 +241,7 @@ class ServerArgs:
def print_mode_args(self): def print_mode_args(self):
return ( return (
f"enable_flashinfer={self.enable_flashinfer}, " f"enable_flashinfer={self.enable_flashinfer}, "
f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}" f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
f"disable_radix_cache={self.disable_radix_cache}, " f"disable_radix_cache={self.disable_radix_cache}, "
f"disable_regex_jump_forward={self.disable_regex_jump_forward}, " f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
f"disable_disk_cache={self.disable_disk_cache}, " f"disable_disk_cache={self.disable_disk_cache}, "

View File

@@ -1,3 +1,5 @@
"""Common utilities."""
import base64 import base64
import os import os
import random import random
@@ -13,6 +15,7 @@ import numpy as np
import pydantic import pydantic
import requests import requests
import torch import torch
from fastapi.responses import JSONResponse
from packaging import version as pkg_version from packaging import version as pkg_version
from pydantic import BaseModel from pydantic import BaseModel
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
@@ -303,6 +306,7 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
response = await call_next(request) response = await call_next(request)
return response return response
# FIXME: Remove this once we drop support for pydantic 1.x # FIXME: Remove this once we drop support for pydantic 1.x
IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1 IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1

View File

@@ -296,7 +296,7 @@ def test_parallel_encoding(check_answer=True):
def test_image_qa(): def test_image_qa():
@sgl.function @sgl.function
def image_qa(s, question): def image_qa(s, question):
s += sgl.user(sgl.image("test_image.png") + question) s += sgl.user(sgl.image("example_image.png") + question)
s += sgl.assistant(sgl.gen("answer")) s += sgl.assistant(sgl.gen("answer"))
state = image_qa.run( state = image_qa.run(

BIN
test/lang/example_image.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 56 KiB

View File

@@ -28,7 +28,7 @@ class TestOpenAIBackend(unittest.TestCase):
if cls.backend is None: if cls.backend is None:
cls.backend = OpenAI("gpt-3.5-turbo-instruct") cls.backend = OpenAI("gpt-3.5-turbo-instruct")
cls.chat_backend = OpenAI("gpt-3.5-turbo") cls.chat_backend = OpenAI("gpt-3.5-turbo")
cls.chat_vision_backend = OpenAI("gpt-4-vision-preview") cls.chat_vision_backend = OpenAI("gpt-4-turbo")
def test_few_shot_qa(self): def test_few_shot_qa(self):
set_default_backend(self.backend) set_default_backend(self.backend)
@@ -88,14 +88,3 @@ if __name__ == "__main__":
# t = TestOpenAIBackend() # t = TestOpenAIBackend()
# t.setUp() # t.setUp()
# t.test_few_shot_qa() # t.test_few_shot_qa()
# t.test_mt_bench()
# t.test_select()
# t.test_decode_int()
# t.test_decode_json()
# t.test_expert_answer()
# t.test_tool_use()
# t.test_react()
# t.test_parallel_decoding()
# t.test_parallel_encoding()
# t.test_image_qa()
# t.test_stream()

View File

@@ -1,68 +0,0 @@
from sglang import OpenAI, function, gen, set_default_backend
@function()
def gen_character_default(s):
s += "Construct a character within the following format:\n"
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n"
s += "\nPlease generate new Name, Birthday and Job.\n"
s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
s += "\nJob:" + gen("job", stop="\n") + "\nWelcome.\n"
@function(api_num_spec_tokens=512)
def gen_character_spec(s):
s += "Construct a character within the following format:\n"
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n"
s += "\nPlease generate new Name, Birthday and Job.\n"
s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
s += "\nJob:" + gen("job", stop="\n") + "\nWelcome.\n"
@function(api_num_spec_tokens=512)
def gen_character_no_stop(s):
s += "Construct a character within the following format:\n"
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n"
s += "\nPlease generate new Name, Birthday and Job.\n"
s += "Name:" + gen("name") + "\nBirthday:" + gen("birthday")
s += "\nJob:" + gen("job") + "\nWelcome.\n"
@function(api_num_spec_tokens=512)
def gen_character_multi_stop(s):
s += "Construct a character within the following format:\n"
s += (
"Name: Steve Jobs.###Birthday: February 24, 1955.###Job: Apple CEO.\nWelcome.\n"
)
s += "\nPlease generate new Name, Birthday and Job.\n"
s += "Name:" + gen("name", stop=["\n", "###"])
s += "###Birthday:" + gen("birthday", stop=["\n", "###"])
s += "###Job:" + gen("job", stop=["\n", "###"]) + "\nWelcome.\n"
set_default_backend(OpenAI("gpt-3.5-turbo-instruct"))
state = gen_character_default.run()
print(state.text())
print("=" * 60)
state = gen_character_no_stop.run()
print("name###", state["name"])
print("birthday###:", state["birthday"])
print("job###", state["job"])
print("=" * 60)
state = gen_character_multi_stop.run()
print(state.text())
print("=" * 60)
state = gen_character_spec.run()
print(state.text())
print("name###", state["name"])
print("birthday###", state["birthday"])
print("job###", state["job"])

1
test/srt/example_image.png Symbolic link
View File

@@ -0,0 +1 @@
../lang/example_image.png

View File

@@ -3,7 +3,6 @@ Usage:
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000 python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
python3 test_httpserver_decode.py python3 test_httpserver_decode.py
Output: Output:
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
""" """
@@ -23,6 +22,7 @@ def test_decode(url, return_logprob, top_logprobs_num, return_text):
"temperature": 0, "temperature": 0,
"max_new_tokens": 32, "max_new_tokens": 32,
}, },
"stream": False,
"return_logprob": return_logprob, "return_logprob": return_logprob,
"top_logprobs_num": top_logprobs_num, "top_logprobs_num": top_logprobs_num,
"return_text_in_logprobs": return_text, "return_text_in_logprobs": return_text,

View File

@@ -26,6 +26,7 @@ def test_decode_stream(url, return_logprob, top_logprobs_num):
"return_logprob": return_logprob, "return_logprob": return_logprob,
"top_logprobs_num": top_logprobs_num, "top_logprobs_num": top_logprobs_num,
"return_text_in_logprobs": True, "return_text_in_logprobs": True,
"logprob_start_len": 0,
}, },
stream=True, stream=True,
) )

View File

@@ -34,7 +34,7 @@ async def test_concurrent(args):
url + "/generate", url + "/generate",
{ {
"text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:", "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
"image_data": "test_image.png", "image_data": "example_image.png",
"sampling_params": { "sampling_params": {
"temperature": 0, "temperature": 0,
"max_new_tokens": 16, "max_new_tokens": 16,
@@ -55,7 +55,7 @@ def test_streaming(args):
url + "/generate", url + "/generate",
json={ json={
"text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:", "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
"image_data": "test_image.png", "image_data": "example_image.png",
"sampling_params": { "sampling_params": {
"temperature": 0, "temperature": 0,
"max_new_tokens": 128, "max_new_tokens": 128,

View File

@@ -6,10 +6,10 @@ The capital of France is Paris.\nThe capital of the United States is Washington,
""" """
import argparse import argparse
import time
import requests import requests
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1") parser.add_argument("--host", type=str, default="http://127.0.0.1")

View File

@@ -163,7 +163,7 @@ def test_regex(args):
regex = ( regex = (
r"""\{\n""" r"""\{\n"""
+ r""" "name": "[\w]+",\n""" + r""" "name": "[\w]+",\n"""
+ r""" "population": "[\w\d\s]+"\n""" + r""" "population": [\w\d\s]+\n"""
+ r"""\}""" + r"""\}"""
) )