Fix logit processor bugs (#427)
This commit is contained in:
@@ -297,7 +297,6 @@ curl http://localhost:30000/generate \
|
||||
Learn more about the argument format [here](docs/sampling_params.md).
|
||||
|
||||
### OpenAI Compatible API
|
||||
|
||||
In addition, the server supports an experimental OpenAI-compatible API.
|
||||
|
||||
```python
|
||||
@@ -386,7 +385,6 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
|
||||
Instructions for supporting a new model are [here](https://github.com/sgl-project/sglang/blob/main/docs/model_support.md).
|
||||
|
||||
## Benchmark And Performance
|
||||
|
||||
- Llama-7B on NVIDIA A10G, FP16, Tensor Parallelism=1
|
||||

|
||||
|
||||
@@ -410,7 +408,4 @@ https://github.com/sgl-project/sglang/issues/157
|
||||
}
|
||||
```
|
||||
|
||||
[](https://huggingface.co/papers/2312.07104)
|
||||
|
||||
|
||||
We learned from the design and reused some code of the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), [LMQL](https://github.com/eth-sri/lmql).
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Some Public API Definitions"""
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
@@ -31,6 +32,7 @@ def function(
|
||||
|
||||
def Runtime(*args, **kwargs):
|
||||
# Avoid importing unnecessary dependency
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
from sglang.srt.server import Runtime
|
||||
|
||||
return Runtime(*args, **kwargs)
|
||||
|
||||
@@ -14,7 +14,7 @@ except ImportError as e:
|
||||
|
||||
|
||||
class Anthropic(BaseBackend):
|
||||
def __init__(self, model_name):
|
||||
def __init__(self, model_name, *args, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
if isinstance(anthropic, Exception):
|
||||
@@ -22,6 +22,7 @@ class Anthropic(BaseBackend):
|
||||
|
||||
self.model_name = model_name
|
||||
self.chat_template = get_chat_template("claude")
|
||||
self.client = anthropic.Anthropic(*args, **kwargs)
|
||||
|
||||
def get_chat_template(self):
|
||||
return self.chat_template
|
||||
@@ -41,7 +42,7 @@ class Anthropic(BaseBackend):
|
||||
else:
|
||||
system = ""
|
||||
|
||||
ret = anthropic.Anthropic().messages.create(
|
||||
ret = self.client.messages.create(
|
||||
model=self.model_name,
|
||||
system=system,
|
||||
messages=messages,
|
||||
@@ -66,11 +67,11 @@ class Anthropic(BaseBackend):
|
||||
else:
|
||||
system = ""
|
||||
|
||||
with anthropic.Anthropic().messages.stream(
|
||||
with self.client.messages.stream(
|
||||
model=self.model_name,
|
||||
system=system,
|
||||
messages=messages,
|
||||
**sampling_params.to_anthropic_kwargs(),
|
||||
) as stream:
|
||||
for text in stream.text_stream:
|
||||
yield text, {}
|
||||
yield text, {}
|
||||
@@ -228,7 +228,7 @@ class OpenAI(BaseBackend):
|
||||
prompt_tokens.append(ret_token)
|
||||
|
||||
decision = choices[np.argmax(scores)]
|
||||
return decision, scores, scores
|
||||
return decision, scores, None, None
|
||||
|
||||
|
||||
def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
|
||||
|
||||
@@ -220,7 +220,6 @@ class RuntimeEndpoint(BaseBackend):
|
||||
"sampling_params": {"max_new_tokens": 0},
|
||||
"return_logprob": True,
|
||||
"logprob_start_len": max(prompt_len - 2, 0),
|
||||
"return_text_in_logprobs": True,
|
||||
}
|
||||
self._add_images(s, data)
|
||||
res = http_request(
|
||||
|
||||
@@ -42,26 +42,29 @@ class LogitsProcessor(nn.Module):
|
||||
for i in range(all_logprobs.shape[0]):
|
||||
k = input_metadata.top_logprobs_nums[i]
|
||||
t = all_logprobs[i].topk(k)
|
||||
v_cpu = t.values.cpu().tolist()
|
||||
p_cpu = t.indices.cpu().tolist()
|
||||
v_cpu = t.values.tolist()
|
||||
p_cpu = t.indices.tolist()
|
||||
decode_top_logprobs.append(list(zip(v_cpu, p_cpu)))
|
||||
return None, decode_top_logprobs
|
||||
else:
|
||||
prefill_top_logprobs, decode_top_logprobs = [], []
|
||||
pt = 0
|
||||
# NOTE: the GPU-CPU overhead can be reduced
|
||||
extend_seq_lens_cpu = input_metadata.extend_seq_lens
|
||||
for i in range(len(input_metadata.extend_seq_lens)):
|
||||
extend_seq_lens_cpu = input_metadata.extend_seq_lens.cpu().numpy()
|
||||
for i in range(len(extend_seq_lens_cpu)):
|
||||
if extend_seq_lens_cpu[i] == 0:
|
||||
prefill_top_logprobs.append([])
|
||||
decode_top_logprobs.append([])
|
||||
continue
|
||||
k = input_metadata.top_logprobs_nums[i]
|
||||
t = all_logprobs[pt : pt + extend_seq_lens_cpu[i]].topk(k)
|
||||
vs_cpu = t.values.cpu().tolist()
|
||||
ps_cpu = t.indices.cpu().tolist()
|
||||
vs_cpu = t.values.tolist()
|
||||
ps_cpu = t.indices.tolist()
|
||||
prefill_top_logprobs.append(
|
||||
[list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
|
||||
)
|
||||
decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
|
||||
pt += extend_seq_lens_cpu[i]
|
||||
return prefill_top_logprobs, decode_top_logprobs
|
||||
|
||||
def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
|
||||
@@ -99,20 +102,24 @@ class LogitsProcessor(nn.Module):
|
||||
all_logits = all_logits[:, : self.config.vocab_size]
|
||||
|
||||
all_logprobs = all_logits.float()
|
||||
all_logits = None
|
||||
del all_logits
|
||||
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
|
||||
|
||||
prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
|
||||
all_logprobs, input_metadata
|
||||
)
|
||||
return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
|
||||
if return_top_logprob:
|
||||
prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
|
||||
all_logprobs, input_metadata
|
||||
)
|
||||
else:
|
||||
prefill_top_logprobs = decode_top_logprobs = None
|
||||
|
||||
if input_metadata.forward_mode == ForwardMode.DECODE:
|
||||
last_logprobs = all_logprobs
|
||||
return last_logits, (
|
||||
None,
|
||||
None,
|
||||
decode_top_logprobs,
|
||||
None,
|
||||
decode_top_logprobs,
|
||||
last_logprobs,
|
||||
)
|
||||
else:
|
||||
@@ -131,9 +138,9 @@ class LogitsProcessor(nn.Module):
|
||||
)
|
||||
return last_logits, (
|
||||
prefill_token_logprobs,
|
||||
normalized_prompt_logprobs,
|
||||
prefill_top_logprobs,
|
||||
decode_top_logprobs,
|
||||
normalized_prompt_logprobs,
|
||||
last_logprobs,
|
||||
)
|
||||
|
||||
|
||||
@@ -25,7 +25,6 @@ class GenerateReqInput:
|
||||
return_text_in_logprobs: bool = False
|
||||
# Whether to stream output
|
||||
stream: bool = False
|
||||
# TODO: make all parameters a Union[List[T], T] to allow for batched requests
|
||||
|
||||
def post_init(self):
|
||||
is_single = isinstance(self.text, str)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from enum import IntEnum, auto
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
@@ -9,15 +9,15 @@ from sglang.srt.managers.router.radix_cache import RadixCache
|
||||
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
||||
|
||||
|
||||
class ForwardMode(Enum):
|
||||
class ForwardMode(IntEnum):
|
||||
PREFILL = auto()
|
||||
EXTEND = auto()
|
||||
DECODE = auto()
|
||||
|
||||
|
||||
class FinishReason(Enum):
|
||||
LENGTH = auto()
|
||||
class FinishReason(IntEnum):
|
||||
EOS_TOKEN = auto()
|
||||
LENGTH = auto()
|
||||
STOP_STR = auto()
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ class Req:
|
||||
# Since jump forward may retokenize the prompt with partial outputs,
|
||||
# we maintain the original prompt length to report the correct usage.
|
||||
self.prompt_tokens = len(input_ids)
|
||||
|
||||
# The number of decoded tokens for token usage report. Note that
|
||||
# this does not include the jump forward tokens.
|
||||
self.completion_tokens_wo_jump_forward = 0
|
||||
@@ -41,12 +42,11 @@ class Req:
|
||||
self.image_offset = 0
|
||||
self.pad_value = None
|
||||
|
||||
# Sampling parameters
|
||||
self.sampling_params = None
|
||||
self.return_logprob = False
|
||||
self.logprob_start_len = 0
|
||||
self.top_logprobs_num = 0
|
||||
self.stream = False
|
||||
|
||||
# Check finish
|
||||
self.tokenizer = None
|
||||
self.finished = False
|
||||
self.finish_reason = None
|
||||
@@ -56,13 +56,17 @@ class Req:
|
||||
self.prefix_indices = []
|
||||
self.last_node = None
|
||||
|
||||
# Logprobs
|
||||
self.return_logprob = False
|
||||
self.logprob_start_len = 0
|
||||
self.top_logprobs_num = 0
|
||||
self.normalized_prompt_logprob = None
|
||||
self.prefill_token_logprobs = None
|
||||
self.decode_token_logprobs = None
|
||||
self.normalized_prompt_logprob = None
|
||||
self.prefill_top_logprobs = None
|
||||
self.decode_top_logprobs = None
|
||||
|
||||
# For constrained decoding
|
||||
# Constrained decoding
|
||||
self.regex_fsm = None
|
||||
self.regex_fsm_state = 0
|
||||
self.jump_forward_map = None
|
||||
@@ -165,8 +169,8 @@ class Batch:
|
||||
out_cache_cont_end: torch.Tensor = None
|
||||
|
||||
# for processing logprobs
|
||||
top_logprobs_nums: List[int] = None
|
||||
return_logprob: bool = False
|
||||
top_logprobs_nums: List[int] = None
|
||||
|
||||
# for multimodal
|
||||
pixel_values: List[torch.Tensor] = None
|
||||
@@ -321,8 +325,8 @@ class Batch:
|
||||
)
|
||||
|
||||
retracted_reqs = []
|
||||
seq_lens_np = self.seq_lens.cpu().numpy()
|
||||
req_pool_indices_np = self.req_pool_indices.cpu().numpy()
|
||||
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
||||
req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
|
||||
while self.token_to_kv_pool.available_size() < len(self.reqs):
|
||||
idx = sorted_indices.pop()
|
||||
req = self.reqs[idx]
|
||||
@@ -338,8 +342,8 @@ class Batch:
|
||||
# TODO: apply more fine-grained retraction
|
||||
|
||||
token_indices = self.req_to_token_pool.req_to_token[
|
||||
req_pool_indices_np[idx]
|
||||
][: seq_lens_np[idx]]
|
||||
req_pool_indices_cpu[idx]
|
||||
][: seq_lens_cpu[idx]]
|
||||
self.token_to_kv_pool.dec_refs(token_indices)
|
||||
|
||||
self.filter_batch(sorted_indices)
|
||||
@@ -363,7 +367,7 @@ class Batch:
|
||||
# insert the old request into tree_cache
|
||||
token_ids_in_memory = tuple(req.input_ids + req.output_ids)[:-1]
|
||||
if req_pool_indices_cpu is None:
|
||||
req_pool_indices_cpu = self.req_pool_indices.cpu().tolist()
|
||||
req_pool_indices_cpu = self.req_pool_indices.tolist()
|
||||
req_pool_idx = req_pool_indices_cpu[i]
|
||||
indices = self.req_to_token_pool.req_to_token[
|
||||
req_pool_idx, : len(token_ids_in_memory)
|
||||
|
||||
@@ -36,7 +36,9 @@ from sglang.srt.utils import (
|
||||
set_random_seed,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger("model_rpc")
|
||||
vllm_default_logger.setLevel(logging.WARN)
|
||||
logging.getLogger("vllm.utils").setLevel(logging.WARN)
|
||||
|
||||
|
||||
@@ -54,9 +56,6 @@ class ModelRpcServer:
|
||||
self.tp_size = server_args.tp_size
|
||||
self.schedule_heuristic = server_args.schedule_heuristic
|
||||
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
|
||||
vllm_default_logger.setLevel(
|
||||
level=getattr(logging, server_args.log_level.upper())
|
||||
)
|
||||
|
||||
# Init model and tokenizer
|
||||
self.model_config = ModelConfig(
|
||||
@@ -65,7 +64,7 @@ class ModelRpcServer:
|
||||
context_length=server_args.context_length,
|
||||
)
|
||||
|
||||
# for model end global settings
|
||||
# For model end global settings
|
||||
server_args_dict = {
|
||||
"enable_flashinfer": server_args.enable_flashinfer,
|
||||
"attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
|
||||
@@ -164,7 +163,7 @@ class ModelRpcServer:
|
||||
logger.info("Cache flushed successfully!")
|
||||
else:
|
||||
warnings.warn(
|
||||
"Cache not flushed because there are pending requests. "
|
||||
f"Cache not flushed because there are pending requests. "
|
||||
f"#queue-req: {len(self.forward_queue)}, "
|
||||
f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
|
||||
)
|
||||
@@ -386,12 +385,12 @@ class ModelRpcServer:
|
||||
f"#running_req: {running_req}. "
|
||||
f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
|
||||
)
|
||||
logger.debug(
|
||||
f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
|
||||
f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
|
||||
f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
|
||||
f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
|
||||
)
|
||||
#logger.debug(
|
||||
# f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
|
||||
# f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
|
||||
# f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
|
||||
# f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
|
||||
#)
|
||||
|
||||
new_batch = Batch.init_new(
|
||||
can_run_list,
|
||||
@@ -408,47 +407,41 @@ class ModelRpcServer:
|
||||
self.model_config.vocab_size, self.int_token_logit_bias
|
||||
)
|
||||
|
||||
prefill_token_logprobs = None
|
||||
if batch.extend_num_tokens != 0:
|
||||
# Forward
|
||||
logits, (
|
||||
prefill_token_logprobs,
|
||||
normalized_prompt_logprobs,
|
||||
prefill_top_logprobs,
|
||||
decode_top_logprobs,
|
||||
normalized_prompt_logprobs,
|
||||
last_logprobs,
|
||||
) = self.model_runner.forward(batch, ForwardMode.EXTEND)
|
||||
if prefill_token_logprobs is not None:
|
||||
prefill_token_logprobs = prefill_token_logprobs.cpu().tolist()
|
||||
normalized_prompt_logprobs = normalized_prompt_logprobs.cpu().tolist()
|
||||
prefill_token_logprobs = prefill_token_logprobs.tolist()
|
||||
normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()
|
||||
|
||||
next_token_ids, _ = batch.sample(logits)
|
||||
next_token_ids = next_token_ids.cpu().tolist()
|
||||
|
||||
# Only transfer the selected logprobs of the next token to CPU to reduce overhead.
|
||||
if last_logprobs is not None:
|
||||
last_token_logprobs = (
|
||||
last_logprobs[torch.arange(len(batch.reqs)), next_token_ids].tolist()
|
||||
)
|
||||
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
else:
|
||||
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
|
||||
(
|
||||
logits,
|
||||
prefill_token_logprobs,
|
||||
normalized_prompt_logprobs,
|
||||
last_logprobs,
|
||||
) = (None,) * 4
|
||||
|
||||
# Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
|
||||
reqs = batch.reqs
|
||||
last_token_logprobs = None
|
||||
if last_logprobs is not None:
|
||||
last_token_logprobs = (
|
||||
last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
|
||||
)
|
||||
|
||||
# Check finish condition
|
||||
pt = 0
|
||||
for i, req in enumerate(reqs):
|
||||
for i, req in enumerate(batch.reqs):
|
||||
req.completion_tokens_wo_jump_forward += 1
|
||||
req.output_ids = [next_token_ids[i]]
|
||||
req.check_finished()
|
||||
|
||||
if prefill_token_logprobs is not None:
|
||||
if req.return_logprob:
|
||||
req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
|
||||
|
||||
# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
|
||||
req.prefill_token_logprobs = list(
|
||||
zip(
|
||||
@@ -463,12 +456,14 @@ class ModelRpcServer:
|
||||
req.decode_token_logprobs = [
|
||||
(last_token_logprobs[i], next_token_ids[i])
|
||||
]
|
||||
|
||||
if req.top_logprobs_num > 0:
|
||||
req.prefill_top_logprobs = prefill_top_logprobs[i]
|
||||
if req.logprob_start_len == 0:
|
||||
req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
|
||||
req.decode_top_logprobs = [decode_top_logprobs[i]]
|
||||
req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
|
||||
pt += req.extend_input_len
|
||||
|
||||
pt += req.extend_input_len
|
||||
|
||||
self.handle_finished_requests(batch)
|
||||
|
||||
@@ -520,29 +515,29 @@ class ModelRpcServer:
|
||||
logits, (
|
||||
_,
|
||||
_,
|
||||
decode_top_logprobs,
|
||||
_,
|
||||
decode_top_logprobs,
|
||||
last_logprobs,
|
||||
) = self.model_runner.forward(batch, ForwardMode.DECODE)
|
||||
next_token_ids, _ = batch.sample(logits)
|
||||
next_token_ids = next_token_ids.cpu().tolist()
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
|
||||
# Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
|
||||
reqs = batch.reqs
|
||||
new_token_logprobs = None
|
||||
if last_logprobs is not None:
|
||||
new_token_logprobs = last_logprobs[
|
||||
torch.arange(len(reqs)), next_token_ids
|
||||
torch.arange(len(batch.reqs)), next_token_ids
|
||||
].tolist()
|
||||
|
||||
# Check finish condition
|
||||
for i, (req, next_token_id) in enumerate(zip(reqs, next_token_ids)):
|
||||
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
|
||||
req.completion_tokens_wo_jump_forward += 1
|
||||
req.output_ids.append(next_token_id)
|
||||
req.check_finished()
|
||||
|
||||
if new_token_logprobs is not None:
|
||||
if req.return_logprob:
|
||||
req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))
|
||||
|
||||
if req.top_logprobs_num > 0:
|
||||
req.decode_top_logprobs.append(decode_top_logprobs[i])
|
||||
|
||||
self.handle_finished_requests(batch)
|
||||
@@ -590,8 +585,7 @@ class ModelRpcServer:
|
||||
+ len(req.output_ids)
|
||||
- req.prompt_tokens,
|
||||
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
|
||||
"finish_reason": str(req.finish_reason),
|
||||
"hit_stop_str": req.hit_stop_str,
|
||||
"finish_reason": str(req.finish_reason), # FIXME: convert to the correct string
|
||||
}
|
||||
if req.return_logprob:
|
||||
(
|
||||
@@ -628,7 +622,7 @@ class ModelRpcServer:
|
||||
# Remove finished reqs
|
||||
if finished_indices:
|
||||
# Update radix cache
|
||||
req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist()
|
||||
req_pool_indices_cpu = batch.req_pool_indices.tolist()
|
||||
for i in finished_indices:
|
||||
req = batch.reqs[i]
|
||||
req_pool_idx = req_pool_indices_cpu[i]
|
||||
|
||||
@@ -29,7 +29,7 @@ QUANTIZATION_CONFIG_MAPPING = {
|
||||
logger = logging.getLogger("model_runner")
|
||||
|
||||
# for server args in model endpoints
|
||||
global_server_args_dict: dict = None
|
||||
global_server_args_dict = {}
|
||||
|
||||
|
||||
@lru_cache()
|
||||
@@ -86,8 +86,8 @@ class InputMetadata:
|
||||
out_cache_cont_end: torch.Tensor = None
|
||||
|
||||
other_kv_index: torch.Tensor = None
|
||||
top_logprobs_nums: List[int] = None
|
||||
return_logprob: bool = False
|
||||
top_logprobs_nums: List[int] = None
|
||||
|
||||
# for flashinfer
|
||||
qo_indptr: torch.Tensor = None
|
||||
@@ -107,18 +107,20 @@ class InputMetadata:
|
||||
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
|
||||
self.kv_last_page_len = torch.ones(
|
||||
(self.batch_size,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
|
||||
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
||||
self.kv_indices = torch.cat(
|
||||
[
|
||||
self.req_to_token_pool.req_to_token[
|
||||
self.req_pool_indices[i].item(), : self.seq_lens[i].item()
|
||||
req_pool_indices_cpu[i]: seq_lens_cpu[i]
|
||||
]
|
||||
for i in range(self.batch_size)
|
||||
],
|
||||
dim=0,
|
||||
).contiguous()
|
||||
self.kv_last_page_len = torch.ones(
|
||||
(self.batch_size,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
|
||||
workspace_buffer = torch.empty(
|
||||
32 * 1024 * 1024, dtype=torch.int8, device="cuda"
|
||||
@@ -195,15 +197,15 @@ class InputMetadata:
|
||||
req_pool_indices[0], seq_lens[0] - 1
|
||||
].item()
|
||||
else:
|
||||
seq_lens_np = seq_lens.cpu().numpy()
|
||||
prefix_lens_np = prefix_lens.cpu().numpy()
|
||||
position_ids_offsets_np = position_ids_offsets.cpu().numpy()
|
||||
seq_lens_cpu = seq_lens.cpu().numpy()
|
||||
prefix_lens_cpu = prefix_lens.cpu().numpy()
|
||||
position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
|
||||
positions = torch.tensor(
|
||||
np.concatenate(
|
||||
[
|
||||
np.arange(
|
||||
prefix_lens_np[i] + position_ids_offsets_np[i],
|
||||
seq_lens_np[i] + position_ids_offsets_np[i],
|
||||
prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
|
||||
seq_lens_cpu[i] + position_ids_offsets_cpu[i],
|
||||
)
|
||||
for i in range(batch_size)
|
||||
],
|
||||
@@ -229,9 +231,9 @@ class InputMetadata:
|
||||
out_cache_loc=out_cache_loc,
|
||||
out_cache_cont_start=out_cache_cont_start,
|
||||
out_cache_cont_end=out_cache_cont_end,
|
||||
top_logprobs_nums=top_logprobs_nums,
|
||||
return_logprob=return_logprob,
|
||||
other_kv_index=other_kv_index,
|
||||
return_logprob=return_logprob,
|
||||
top_logprobs_nums=top_logprobs_nums,
|
||||
)
|
||||
|
||||
if forward_mode == ForwardMode.EXTEND:
|
||||
|
||||
@@ -185,7 +185,10 @@ class TokenizerManager:
|
||||
|
||||
while True:
|
||||
await event.wait()
|
||||
yield state.out_list[-1]
|
||||
yield self.convert_logprob_style(state.out_list[-1],
|
||||
obj.return_logprob,
|
||||
obj.top_logprobs_num,
|
||||
obj.return_text_in_logprobs)
|
||||
state.out_list = []
|
||||
if state.finished:
|
||||
del self.rid_to_state[rid]
|
||||
@@ -231,16 +234,16 @@ class TokenizerManager:
|
||||
rid = obj.rid[i]
|
||||
state = self.rid_to_state[rid]
|
||||
await state.event.wait()
|
||||
output_list.append(state.out_list[-1])
|
||||
output_list.append(
|
||||
self.convert_logprob_style(state.out_list[-1],
|
||||
obj.return_logprob[i],
|
||||
obj.top_logprobs_num[i],
|
||||
obj.return_text_in_logprobs))
|
||||
assert state.finished
|
||||
del self.rid_to_state[rid]
|
||||
|
||||
yield output_list
|
||||
|
||||
async def detokenize(self, obj: DetokenizeReqInput):
|
||||
token_texts = self.tokenizer.convert_ids_to_tokens(obj.input_ids)
|
||||
return [t.decode() if isinstance(t, bytes) else t for t in token_texts]
|
||||
|
||||
async def flush_cache(self):
|
||||
flush_cache_req = FlushCacheReq()
|
||||
self.send_to_router.send_pyobj(flush_cache_req)
|
||||
@@ -267,3 +270,37 @@ class TokenizerManager:
|
||||
state.event.set()
|
||||
else:
|
||||
raise ValueError(f"Invalid object: {recv_obj}")
|
||||
|
||||
def convert_logprob_style(self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs):
|
||||
if return_logprob:
|
||||
ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
|
||||
ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs
|
||||
)
|
||||
ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens(
|
||||
ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
|
||||
)
|
||||
if top_logprobs_num > 0:
|
||||
ret["meta_info"]["prefill_top_logprobs"] = self.detokenize_top_logprobs_tokens(
|
||||
ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
|
||||
)
|
||||
ret["meta_info"]["decode_top_logprobs"] = self.detokenize_top_logprobs_tokens(
|
||||
ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
|
||||
)
|
||||
return ret
|
||||
|
||||
def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
|
||||
if not decode_to_text:
|
||||
return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
|
||||
|
||||
token_ids = [tid for _, tid in token_logprobs]
|
||||
token_texts = self.tokenizer.batch_decode(token_ids)
|
||||
return [
|
||||
(logprob, token_id, token_text)
|
||||
for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
|
||||
]
|
||||
|
||||
def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text):
|
||||
for i, t in enumerate(top_logprobs):
|
||||
if t:
|
||||
top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
|
||||
return top_logprobs
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
"""pydantic models for OpenAI API protocol"""
|
||||
import time
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ import threading
|
||||
import time
|
||||
from typing import List, Optional, Union
|
||||
|
||||
# Fix a Python bug
|
||||
# Fix a bug of Python threading
|
||||
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
||||
|
||||
import aiohttp
|
||||
@@ -53,10 +53,10 @@ from sglang.srt.managers.router.manager import start_router_process
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
enable_show_time_cost,
|
||||
allocate_init_ports,
|
||||
jsonify_pydantic_model,
|
||||
assert_pkg_version,
|
||||
enable_show_time_cost,
|
||||
jsonify_pydantic_model,
|
||||
get_exception_traceback,
|
||||
API_KEY_HEADER_NAME,
|
||||
APIKeyValidatorMiddleware
|
||||
@@ -99,12 +99,6 @@ async def flush_cache():
|
||||
)
|
||||
|
||||
|
||||
async def stream_generator(obj: GenerateReqInput):
|
||||
async for out in tokenizer_manager.generate_request(obj):
|
||||
await handle_token_logprobs_results(obj, out)
|
||||
yield out
|
||||
|
||||
|
||||
@app.post("/generate")
|
||||
async def generate_request(obj: GenerateReqInput):
|
||||
obj.post_init()
|
||||
@@ -112,69 +106,16 @@ async def generate_request(obj: GenerateReqInput):
|
||||
if obj.stream:
|
||||
|
||||
async def stream_results():
|
||||
async for out in stream_generator(obj):
|
||||
async for out in tokenizer_manager.generate_request(obj):
|
||||
yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(stream_results(), media_type="text/event-stream")
|
||||
|
||||
ret = await tokenizer_manager.generate_request(obj).__anext__()
|
||||
await handle_token_logprobs_results(obj, ret)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
async def detokenize_logprob_tokens(token_logprobs, decode_to_text):
|
||||
if not decode_to_text:
|
||||
return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
|
||||
|
||||
token_ids = [tid for _, tid in token_logprobs]
|
||||
token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
|
||||
return [
|
||||
(logprob, token_id, token_text)
|
||||
for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
|
||||
]
|
||||
|
||||
|
||||
async def detokenize_top_logprobs_tokens(top_logprobs, decode_to_text):
|
||||
for i, t in enumerate(top_logprobs):
|
||||
if top_logprobs[i] is not None:
|
||||
top_logprobs[i] = await detokenize_logprob_tokens(t, decode_to_text)
|
||||
return top_logprobs
|
||||
|
||||
|
||||
async def handle_token_logprobs_results(obj: GenerateReqInput, ret):
|
||||
"""Handle the token logprobs results, convert token ids to text if needed.
|
||||
|
||||
Args:
|
||||
obj (GenerateReqInput): The request object.
|
||||
ret (Union[Dict, List[Dict]]): The response object.
|
||||
"""
|
||||
# NOTE: This is because the multiple requests in one http request.
|
||||
|
||||
async def convert_style(r, return_text):
|
||||
r["meta_info"]["prefill_token_logprobs"] = await detokenize_logprob_tokens(
|
||||
r["meta_info"]["prefill_token_logprobs"], return_text
|
||||
)
|
||||
r["meta_info"]["decode_token_logprobs"] = await detokenize_logprob_tokens(
|
||||
r["meta_info"]["decode_token_logprobs"], return_text
|
||||
)
|
||||
r["meta_info"]["prefill_top_logprobs"] = await detokenize_top_logprobs_tokens(
|
||||
r["meta_info"]["prefill_top_logprobs"], return_text
|
||||
)
|
||||
r["meta_info"]["decode_top_logprobs"] = await detokenize_top_logprobs_tokens(
|
||||
r["meta_info"]["decode_top_logprobs"], return_text
|
||||
)
|
||||
|
||||
if isinstance(obj.text, str):
|
||||
if obj.return_logprob:
|
||||
await convert_style(ret, obj.return_text_in_logprobs)
|
||||
else:
|
||||
for i, r in enumerate(ret):
|
||||
if obj.return_logprob[i]:
|
||||
await convert_style(r, obj.return_text_in_logprobs)
|
||||
|
||||
|
||||
@app.post("/v1/completions")
|
||||
async def v1_completions(raw_request: Request):
|
||||
request_json = await raw_request.json()
|
||||
@@ -203,10 +144,10 @@ async def v1_completions(raw_request: Request):
|
||||
|
||||
if adapted_request.stream:
|
||||
|
||||
async def gnerate_stream_resp():
|
||||
async def generate_stream_resp():
|
||||
stream_buffer = ""
|
||||
n_prev_token = 0
|
||||
async for content in stream_generator(adapted_request):
|
||||
async for content in tokenizer_manager.generate_request(adapted_request):
|
||||
text = content["text"]
|
||||
prompt_tokens = content["meta_info"]["prompt_tokens"]
|
||||
completion_tokens = content["meta_info"]["completion_tokens"]
|
||||
@@ -266,7 +207,7 @@ async def v1_completions(raw_request: Request):
|
||||
yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(gnerate_stream_resp(), media_type="text/event-stream")
|
||||
return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")
|
||||
|
||||
# Non-streaming response.
|
||||
ret = await generate_request(adapted_request)
|
||||
@@ -384,7 +325,7 @@ async def v1_chat_completions(raw_request: Request):
|
||||
is_first = True
|
||||
|
||||
stream_buffer = ""
|
||||
async for content in stream_generator(adapted_request):
|
||||
async for content in tokenizer_manager.generate_request(adapted_request):
|
||||
if is_first:
|
||||
# First chunk with role
|
||||
is_first = False
|
||||
|
||||
@@ -241,7 +241,7 @@ class ServerArgs:
|
||||
def print_mode_args(self):
|
||||
return (
|
||||
f"enable_flashinfer={self.enable_flashinfer}, "
|
||||
f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}"
|
||||
f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
|
||||
f"disable_radix_cache={self.disable_radix_cache}, "
|
||||
f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
|
||||
f"disable_disk_cache={self.disable_disk_cache}, "
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Common utilities."""
|
||||
|
||||
import base64
|
||||
import os
|
||||
import random
|
||||
@@ -13,6 +15,7 @@ import numpy as np
|
||||
import pydantic
|
||||
import requests
|
||||
import torch
|
||||
from fastapi.responses import JSONResponse
|
||||
from packaging import version as pkg_version
|
||||
from pydantic import BaseModel
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
@@ -303,6 +306,7 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
|
||||
# FIXME: Remove this once we drop support for pydantic 1.x
|
||||
IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
|
||||
|
||||
@@ -310,4 +314,4 @@ IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
|
||||
def jsonify_pydantic_model(obj: BaseModel):
|
||||
if IS_PYDANTIC_1:
|
||||
return obj.json(ensure_ascii=False)
|
||||
return obj.model_dump_json()
|
||||
return obj.model_dump_json()
|
||||
@@ -296,7 +296,7 @@ def test_parallel_encoding(check_answer=True):
|
||||
def test_image_qa():
|
||||
@sgl.function
|
||||
def image_qa(s, question):
|
||||
s += sgl.user(sgl.image("test_image.png") + question)
|
||||
s += sgl.user(sgl.image("example_image.png") + question)
|
||||
s += sgl.assistant(sgl.gen("answer"))
|
||||
|
||||
state = image_qa.run(
|
||||
|
||||
BIN
test/lang/example_image.png
Normal file
BIN
test/lang/example_image.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 56 KiB |
@@ -28,7 +28,7 @@ class TestOpenAIBackend(unittest.TestCase):
|
||||
if cls.backend is None:
|
||||
cls.backend = OpenAI("gpt-3.5-turbo-instruct")
|
||||
cls.chat_backend = OpenAI("gpt-3.5-turbo")
|
||||
cls.chat_vision_backend = OpenAI("gpt-4-vision-preview")
|
||||
cls.chat_vision_backend = OpenAI("gpt-4-turbo")
|
||||
|
||||
def test_few_shot_qa(self):
|
||||
set_default_backend(self.backend)
|
||||
@@ -88,14 +88,3 @@ if __name__ == "__main__":
|
||||
# t = TestOpenAIBackend()
|
||||
# t.setUp()
|
||||
# t.test_few_shot_qa()
|
||||
# t.test_mt_bench()
|
||||
# t.test_select()
|
||||
# t.test_decode_int()
|
||||
# t.test_decode_json()
|
||||
# t.test_expert_answer()
|
||||
# t.test_tool_use()
|
||||
# t.test_react()
|
||||
# t.test_parallel_decoding()
|
||||
# t.test_parallel_encoding()
|
||||
# t.test_image_qa()
|
||||
# t.test_stream()
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
from sglang import OpenAI, function, gen, set_default_backend
|
||||
|
||||
|
||||
@function()
|
||||
def gen_character_default(s):
|
||||
s += "Construct a character within the following format:\n"
|
||||
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n"
|
||||
s += "\nPlease generate new Name, Birthday and Job.\n"
|
||||
s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
|
||||
s += "\nJob:" + gen("job", stop="\n") + "\nWelcome.\n"
|
||||
|
||||
|
||||
@function(api_num_spec_tokens=512)
|
||||
def gen_character_spec(s):
|
||||
s += "Construct a character within the following format:\n"
|
||||
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n"
|
||||
s += "\nPlease generate new Name, Birthday and Job.\n"
|
||||
s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
|
||||
s += "\nJob:" + gen("job", stop="\n") + "\nWelcome.\n"
|
||||
|
||||
|
||||
@function(api_num_spec_tokens=512)
|
||||
def gen_character_no_stop(s):
|
||||
s += "Construct a character within the following format:\n"
|
||||
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n"
|
||||
s += "\nPlease generate new Name, Birthday and Job.\n"
|
||||
s += "Name:" + gen("name") + "\nBirthday:" + gen("birthday")
|
||||
s += "\nJob:" + gen("job") + "\nWelcome.\n"
|
||||
|
||||
|
||||
@function(api_num_spec_tokens=512)
|
||||
def gen_character_multi_stop(s):
|
||||
s += "Construct a character within the following format:\n"
|
||||
s += (
|
||||
"Name: Steve Jobs.###Birthday: February 24, 1955.###Job: Apple CEO.\nWelcome.\n"
|
||||
)
|
||||
s += "\nPlease generate new Name, Birthday and Job.\n"
|
||||
s += "Name:" + gen("name", stop=["\n", "###"])
|
||||
s += "###Birthday:" + gen("birthday", stop=["\n", "###"])
|
||||
s += "###Job:" + gen("job", stop=["\n", "###"]) + "\nWelcome.\n"
|
||||
|
||||
|
||||
set_default_backend(OpenAI("gpt-3.5-turbo-instruct"))
|
||||
|
||||
state = gen_character_default.run()
|
||||
print(state.text())
|
||||
|
||||
print("=" * 60)
|
||||
|
||||
state = gen_character_no_stop.run()
|
||||
|
||||
print("name###", state["name"])
|
||||
print("birthday###:", state["birthday"])
|
||||
print("job###", state["job"])
|
||||
|
||||
print("=" * 60)
|
||||
|
||||
state = gen_character_multi_stop.run()
|
||||
print(state.text())
|
||||
|
||||
print("=" * 60)
|
||||
|
||||
state = gen_character_spec.run()
|
||||
print(state.text())
|
||||
|
||||
print("name###", state["name"])
|
||||
print("birthday###", state["birthday"])
|
||||
print("job###", state["job"])
|
||||
1
test/srt/example_image.png
Symbolic link
1
test/srt/example_image.png
Symbolic link
@@ -0,0 +1 @@
|
||||
../lang/example_image.png
|
||||
@@ -3,7 +3,6 @@ Usage:
|
||||
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
|
||||
python3 test_httpserver_decode.py
|
||||
|
||||
|
||||
Output:
|
||||
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
|
||||
"""
|
||||
@@ -23,6 +22,7 @@ def test_decode(url, return_logprob, top_logprobs_num, return_text):
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 32,
|
||||
},
|
||||
"stream": False,
|
||||
"return_logprob": return_logprob,
|
||||
"top_logprobs_num": top_logprobs_num,
|
||||
"return_text_in_logprobs": return_text,
|
||||
|
||||
@@ -26,6 +26,7 @@ def test_decode_stream(url, return_logprob, top_logprobs_num):
|
||||
"return_logprob": return_logprob,
|
||||
"top_logprobs_num": top_logprobs_num,
|
||||
"return_text_in_logprobs": True,
|
||||
"logprob_start_len": 0,
|
||||
},
|
||||
stream=True,
|
||||
)
|
||||
|
||||
@@ -34,7 +34,7 @@ async def test_concurrent(args):
|
||||
url + "/generate",
|
||||
{
|
||||
"text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
|
||||
"image_data": "test_image.png",
|
||||
"image_data": "example_image.png",
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 16,
|
||||
@@ -55,7 +55,7 @@ def test_streaming(args):
|
||||
url + "/generate",
|
||||
json={
|
||||
"text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
|
||||
"image_data": "test_image.png",
|
||||
"image_data": "example_image.png",
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 128,
|
||||
|
||||
@@ -6,10 +6,10 @@ The capital of France is Paris.\nThe capital of the United States is Washington,
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
||||
|
||||
@@ -163,7 +163,7 @@ def test_regex(args):
|
||||
regex = (
|
||||
r"""\{\n"""
|
||||
+ r""" "name": "[\w]+",\n"""
|
||||
+ r""" "population": "[\w\d\s]+"\n"""
|
||||
+ r""" "population": [\w\d\s]+\n"""
|
||||
+ r"""\}"""
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user