Fix logit processor bugs (#427)
This commit is contained in:
@@ -297,7 +297,6 @@ curl http://localhost:30000/generate \
|
|||||||
Learn more about the argument format [here](docs/sampling_params.md).
|
Learn more about the argument format [here](docs/sampling_params.md).
|
||||||
|
|
||||||
### OpenAI Compatible API
|
### OpenAI Compatible API
|
||||||
|
|
||||||
In addition, the server supports an experimental OpenAI-compatible API.
|
In addition, the server supports an experimental OpenAI-compatible API.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@@ -386,7 +385,6 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
|
|||||||
Instructions for supporting a new model are [here](https://github.com/sgl-project/sglang/blob/main/docs/model_support.md).
|
Instructions for supporting a new model are [here](https://github.com/sgl-project/sglang/blob/main/docs/model_support.md).
|
||||||
|
|
||||||
## Benchmark And Performance
|
## Benchmark And Performance
|
||||||
|
|
||||||
- Llama-7B on NVIDIA A10G, FP16, Tensor Parallelism=1
|
- Llama-7B on NVIDIA A10G, FP16, Tensor Parallelism=1
|
||||||

|

|
||||||
|
|
||||||
@@ -410,7 +408,4 @@ https://github.com/sgl-project/sglang/issues/157
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
[](https://huggingface.co/papers/2312.07104)
|
|
||||||
|
|
||||||
|
|
||||||
We learned from the design and reused some code of the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), [LMQL](https://github.com/eth-sri/lmql).
|
We learned from the design and reused some code of the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), [LMQL](https://github.com/eth-sri/lmql).
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""Some Public API Definitions"""
|
"""Some Public API Definitions"""
|
||||||
|
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List, Optional, Union
|
||||||
|
|
||||||
@@ -31,6 +32,7 @@ def function(
|
|||||||
|
|
||||||
def Runtime(*args, **kwargs):
|
def Runtime(*args, **kwargs):
|
||||||
# Avoid importing unnecessary dependency
|
# Avoid importing unnecessary dependency
|
||||||
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||||
from sglang.srt.server import Runtime
|
from sglang.srt.server import Runtime
|
||||||
|
|
||||||
return Runtime(*args, **kwargs)
|
return Runtime(*args, **kwargs)
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ except ImportError as e:
|
|||||||
|
|
||||||
|
|
||||||
class Anthropic(BaseBackend):
|
class Anthropic(BaseBackend):
|
||||||
def __init__(self, model_name):
|
def __init__(self, model_name, *args, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if isinstance(anthropic, Exception):
|
if isinstance(anthropic, Exception):
|
||||||
@@ -22,6 +22,7 @@ class Anthropic(BaseBackend):
|
|||||||
|
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.chat_template = get_chat_template("claude")
|
self.chat_template = get_chat_template("claude")
|
||||||
|
self.client = anthropic.Anthropic(*args, **kwargs)
|
||||||
|
|
||||||
def get_chat_template(self):
|
def get_chat_template(self):
|
||||||
return self.chat_template
|
return self.chat_template
|
||||||
@@ -41,7 +42,7 @@ class Anthropic(BaseBackend):
|
|||||||
else:
|
else:
|
||||||
system = ""
|
system = ""
|
||||||
|
|
||||||
ret = anthropic.Anthropic().messages.create(
|
ret = self.client.messages.create(
|
||||||
model=self.model_name,
|
model=self.model_name,
|
||||||
system=system,
|
system=system,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
@@ -66,7 +67,7 @@ class Anthropic(BaseBackend):
|
|||||||
else:
|
else:
|
||||||
system = ""
|
system = ""
|
||||||
|
|
||||||
with anthropic.Anthropic().messages.stream(
|
with self.client.messages.stream(
|
||||||
model=self.model_name,
|
model=self.model_name,
|
||||||
system=system,
|
system=system,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
|||||||
@@ -228,7 +228,7 @@ class OpenAI(BaseBackend):
|
|||||||
prompt_tokens.append(ret_token)
|
prompt_tokens.append(ret_token)
|
||||||
|
|
||||||
decision = choices[np.argmax(scores)]
|
decision = choices[np.argmax(scores)]
|
||||||
return decision, scores, scores
|
return decision, scores, None, None
|
||||||
|
|
||||||
|
|
||||||
def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
|
def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
|
||||||
|
|||||||
@@ -220,7 +220,6 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
"sampling_params": {"max_new_tokens": 0},
|
"sampling_params": {"max_new_tokens": 0},
|
||||||
"return_logprob": True,
|
"return_logprob": True,
|
||||||
"logprob_start_len": max(prompt_len - 2, 0),
|
"logprob_start_len": max(prompt_len - 2, 0),
|
||||||
"return_text_in_logprobs": True,
|
|
||||||
}
|
}
|
||||||
self._add_images(s, data)
|
self._add_images(s, data)
|
||||||
res = http_request(
|
res = http_request(
|
||||||
|
|||||||
@@ -42,26 +42,29 @@ class LogitsProcessor(nn.Module):
|
|||||||
for i in range(all_logprobs.shape[0]):
|
for i in range(all_logprobs.shape[0]):
|
||||||
k = input_metadata.top_logprobs_nums[i]
|
k = input_metadata.top_logprobs_nums[i]
|
||||||
t = all_logprobs[i].topk(k)
|
t = all_logprobs[i].topk(k)
|
||||||
v_cpu = t.values.cpu().tolist()
|
v_cpu = t.values.tolist()
|
||||||
p_cpu = t.indices.cpu().tolist()
|
p_cpu = t.indices.tolist()
|
||||||
decode_top_logprobs.append(list(zip(v_cpu, p_cpu)))
|
decode_top_logprobs.append(list(zip(v_cpu, p_cpu)))
|
||||||
return None, decode_top_logprobs
|
return None, decode_top_logprobs
|
||||||
else:
|
else:
|
||||||
prefill_top_logprobs, decode_top_logprobs = [], []
|
prefill_top_logprobs, decode_top_logprobs = [], []
|
||||||
pt = 0
|
pt = 0
|
||||||
# NOTE: the GPU-CPU overhead can be reduced
|
# NOTE: the GPU-CPU overhead can be reduced
|
||||||
extend_seq_lens_cpu = input_metadata.extend_seq_lens
|
extend_seq_lens_cpu = input_metadata.extend_seq_lens.cpu().numpy()
|
||||||
for i in range(len(input_metadata.extend_seq_lens)):
|
for i in range(len(extend_seq_lens_cpu)):
|
||||||
if extend_seq_lens_cpu[i] == 0:
|
if extend_seq_lens_cpu[i] == 0:
|
||||||
|
prefill_top_logprobs.append([])
|
||||||
|
decode_top_logprobs.append([])
|
||||||
continue
|
continue
|
||||||
k = input_metadata.top_logprobs_nums[i]
|
k = input_metadata.top_logprobs_nums[i]
|
||||||
t = all_logprobs[pt : pt + extend_seq_lens_cpu[i]].topk(k)
|
t = all_logprobs[pt : pt + extend_seq_lens_cpu[i]].topk(k)
|
||||||
vs_cpu = t.values.cpu().tolist()
|
vs_cpu = t.values.tolist()
|
||||||
ps_cpu = t.indices.cpu().tolist()
|
ps_cpu = t.indices.tolist()
|
||||||
prefill_top_logprobs.append(
|
prefill_top_logprobs.append(
|
||||||
[list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
|
[list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
|
||||||
)
|
)
|
||||||
decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
|
decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
|
||||||
|
pt += extend_seq_lens_cpu[i]
|
||||||
return prefill_top_logprobs, decode_top_logprobs
|
return prefill_top_logprobs, decode_top_logprobs
|
||||||
|
|
||||||
def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
|
def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
|
||||||
@@ -99,20 +102,24 @@ class LogitsProcessor(nn.Module):
|
|||||||
all_logits = all_logits[:, : self.config.vocab_size]
|
all_logits = all_logits[:, : self.config.vocab_size]
|
||||||
|
|
||||||
all_logprobs = all_logits.float()
|
all_logprobs = all_logits.float()
|
||||||
all_logits = None
|
del all_logits
|
||||||
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
|
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
|
||||||
|
|
||||||
|
return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
|
||||||
|
if return_top_logprob:
|
||||||
prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
|
prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
|
||||||
all_logprobs, input_metadata
|
all_logprobs, input_metadata
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
prefill_top_logprobs = decode_top_logprobs = None
|
||||||
|
|
||||||
if input_metadata.forward_mode == ForwardMode.DECODE:
|
if input_metadata.forward_mode == ForwardMode.DECODE:
|
||||||
last_logprobs = all_logprobs
|
last_logprobs = all_logprobs
|
||||||
return last_logits, (
|
return last_logits, (
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
decode_top_logprobs,
|
|
||||||
None,
|
None,
|
||||||
|
decode_top_logprobs,
|
||||||
last_logprobs,
|
last_logprobs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -131,9 +138,9 @@ class LogitsProcessor(nn.Module):
|
|||||||
)
|
)
|
||||||
return last_logits, (
|
return last_logits, (
|
||||||
prefill_token_logprobs,
|
prefill_token_logprobs,
|
||||||
|
normalized_prompt_logprobs,
|
||||||
prefill_top_logprobs,
|
prefill_top_logprobs,
|
||||||
decode_top_logprobs,
|
decode_top_logprobs,
|
||||||
normalized_prompt_logprobs,
|
|
||||||
last_logprobs,
|
last_logprobs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ class GenerateReqInput:
|
|||||||
return_text_in_logprobs: bool = False
|
return_text_in_logprobs: bool = False
|
||||||
# Whether to stream output
|
# Whether to stream output
|
||||||
stream: bool = False
|
stream: bool = False
|
||||||
# TODO: make all parameters a Union[List[T], T] to allow for batched requests
|
|
||||||
|
|
||||||
def post_init(self):
|
def post_init(self):
|
||||||
is_single = isinstance(self.text, str)
|
is_single = isinstance(self.text, str)
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum, auto
|
from enum import IntEnum, auto
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -9,15 +9,15 @@ from sglang.srt.managers.router.radix_cache import RadixCache
|
|||||||
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
||||||
|
|
||||||
|
|
||||||
class ForwardMode(Enum):
|
class ForwardMode(IntEnum):
|
||||||
PREFILL = auto()
|
PREFILL = auto()
|
||||||
EXTEND = auto()
|
EXTEND = auto()
|
||||||
DECODE = auto()
|
DECODE = auto()
|
||||||
|
|
||||||
|
|
||||||
class FinishReason(Enum):
|
class FinishReason(IntEnum):
|
||||||
LENGTH = auto()
|
|
||||||
EOS_TOKEN = auto()
|
EOS_TOKEN = auto()
|
||||||
|
LENGTH = auto()
|
||||||
STOP_STR = auto()
|
STOP_STR = auto()
|
||||||
|
|
||||||
|
|
||||||
@@ -31,6 +31,7 @@ class Req:
|
|||||||
# Since jump forward may retokenize the prompt with partial outputs,
|
# Since jump forward may retokenize the prompt with partial outputs,
|
||||||
# we maintain the original prompt length to report the correct usage.
|
# we maintain the original prompt length to report the correct usage.
|
||||||
self.prompt_tokens = len(input_ids)
|
self.prompt_tokens = len(input_ids)
|
||||||
|
|
||||||
# The number of decoded tokens for token usage report. Note that
|
# The number of decoded tokens for token usage report. Note that
|
||||||
# this does not include the jump forward tokens.
|
# this does not include the jump forward tokens.
|
||||||
self.completion_tokens_wo_jump_forward = 0
|
self.completion_tokens_wo_jump_forward = 0
|
||||||
@@ -41,12 +42,11 @@ class Req:
|
|||||||
self.image_offset = 0
|
self.image_offset = 0
|
||||||
self.pad_value = None
|
self.pad_value = None
|
||||||
|
|
||||||
|
# Sampling parameters
|
||||||
self.sampling_params = None
|
self.sampling_params = None
|
||||||
self.return_logprob = False
|
|
||||||
self.logprob_start_len = 0
|
|
||||||
self.top_logprobs_num = 0
|
|
||||||
self.stream = False
|
self.stream = False
|
||||||
|
|
||||||
|
# Check finish
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
self.finished = False
|
self.finished = False
|
||||||
self.finish_reason = None
|
self.finish_reason = None
|
||||||
@@ -56,13 +56,17 @@ class Req:
|
|||||||
self.prefix_indices = []
|
self.prefix_indices = []
|
||||||
self.last_node = None
|
self.last_node = None
|
||||||
|
|
||||||
|
# Logprobs
|
||||||
|
self.return_logprob = False
|
||||||
|
self.logprob_start_len = 0
|
||||||
|
self.top_logprobs_num = 0
|
||||||
|
self.normalized_prompt_logprob = None
|
||||||
self.prefill_token_logprobs = None
|
self.prefill_token_logprobs = None
|
||||||
self.decode_token_logprobs = None
|
self.decode_token_logprobs = None
|
||||||
self.normalized_prompt_logprob = None
|
|
||||||
self.prefill_top_logprobs = None
|
self.prefill_top_logprobs = None
|
||||||
self.decode_top_logprobs = None
|
self.decode_top_logprobs = None
|
||||||
|
|
||||||
# For constrained decoding
|
# Constrained decoding
|
||||||
self.regex_fsm = None
|
self.regex_fsm = None
|
||||||
self.regex_fsm_state = 0
|
self.regex_fsm_state = 0
|
||||||
self.jump_forward_map = None
|
self.jump_forward_map = None
|
||||||
@@ -165,8 +169,8 @@ class Batch:
|
|||||||
out_cache_cont_end: torch.Tensor = None
|
out_cache_cont_end: torch.Tensor = None
|
||||||
|
|
||||||
# for processing logprobs
|
# for processing logprobs
|
||||||
top_logprobs_nums: List[int] = None
|
|
||||||
return_logprob: bool = False
|
return_logprob: bool = False
|
||||||
|
top_logprobs_nums: List[int] = None
|
||||||
|
|
||||||
# for multimodal
|
# for multimodal
|
||||||
pixel_values: List[torch.Tensor] = None
|
pixel_values: List[torch.Tensor] = None
|
||||||
@@ -321,8 +325,8 @@ class Batch:
|
|||||||
)
|
)
|
||||||
|
|
||||||
retracted_reqs = []
|
retracted_reqs = []
|
||||||
seq_lens_np = self.seq_lens.cpu().numpy()
|
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
||||||
req_pool_indices_np = self.req_pool_indices.cpu().numpy()
|
req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
|
||||||
while self.token_to_kv_pool.available_size() < len(self.reqs):
|
while self.token_to_kv_pool.available_size() < len(self.reqs):
|
||||||
idx = sorted_indices.pop()
|
idx = sorted_indices.pop()
|
||||||
req = self.reqs[idx]
|
req = self.reqs[idx]
|
||||||
@@ -338,8 +342,8 @@ class Batch:
|
|||||||
# TODO: apply more fine-grained retraction
|
# TODO: apply more fine-grained retraction
|
||||||
|
|
||||||
token_indices = self.req_to_token_pool.req_to_token[
|
token_indices = self.req_to_token_pool.req_to_token[
|
||||||
req_pool_indices_np[idx]
|
req_pool_indices_cpu[idx]
|
||||||
][: seq_lens_np[idx]]
|
][: seq_lens_cpu[idx]]
|
||||||
self.token_to_kv_pool.dec_refs(token_indices)
|
self.token_to_kv_pool.dec_refs(token_indices)
|
||||||
|
|
||||||
self.filter_batch(sorted_indices)
|
self.filter_batch(sorted_indices)
|
||||||
@@ -363,7 +367,7 @@ class Batch:
|
|||||||
# insert the old request into tree_cache
|
# insert the old request into tree_cache
|
||||||
token_ids_in_memory = tuple(req.input_ids + req.output_ids)[:-1]
|
token_ids_in_memory = tuple(req.input_ids + req.output_ids)[:-1]
|
||||||
if req_pool_indices_cpu is None:
|
if req_pool_indices_cpu is None:
|
||||||
req_pool_indices_cpu = self.req_pool_indices.cpu().tolist()
|
req_pool_indices_cpu = self.req_pool_indices.tolist()
|
||||||
req_pool_idx = req_pool_indices_cpu[i]
|
req_pool_idx = req_pool_indices_cpu[i]
|
||||||
indices = self.req_to_token_pool.req_to_token[
|
indices = self.req_to_token_pool.req_to_token[
|
||||||
req_pool_idx, : len(token_ids_in_memory)
|
req_pool_idx, : len(token_ids_in_memory)
|
||||||
|
|||||||
@@ -36,7 +36,9 @@ from sglang.srt.utils import (
|
|||||||
set_random_seed,
|
set_random_seed,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger("model_rpc")
|
logger = logging.getLogger("model_rpc")
|
||||||
|
vllm_default_logger.setLevel(logging.WARN)
|
||||||
logging.getLogger("vllm.utils").setLevel(logging.WARN)
|
logging.getLogger("vllm.utils").setLevel(logging.WARN)
|
||||||
|
|
||||||
|
|
||||||
@@ -54,9 +56,6 @@ class ModelRpcServer:
|
|||||||
self.tp_size = server_args.tp_size
|
self.tp_size = server_args.tp_size
|
||||||
self.schedule_heuristic = server_args.schedule_heuristic
|
self.schedule_heuristic = server_args.schedule_heuristic
|
||||||
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
|
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
|
||||||
vllm_default_logger.setLevel(
|
|
||||||
level=getattr(logging, server_args.log_level.upper())
|
|
||||||
)
|
|
||||||
|
|
||||||
# Init model and tokenizer
|
# Init model and tokenizer
|
||||||
self.model_config = ModelConfig(
|
self.model_config = ModelConfig(
|
||||||
@@ -65,7 +64,7 @@ class ModelRpcServer:
|
|||||||
context_length=server_args.context_length,
|
context_length=server_args.context_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
# for model end global settings
|
# For model end global settings
|
||||||
server_args_dict = {
|
server_args_dict = {
|
||||||
"enable_flashinfer": server_args.enable_flashinfer,
|
"enable_flashinfer": server_args.enable_flashinfer,
|
||||||
"attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
|
"attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
|
||||||
@@ -164,7 +163,7 @@ class ModelRpcServer:
|
|||||||
logger.info("Cache flushed successfully!")
|
logger.info("Cache flushed successfully!")
|
||||||
else:
|
else:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Cache not flushed because there are pending requests. "
|
f"Cache not flushed because there are pending requests. "
|
||||||
f"#queue-req: {len(self.forward_queue)}, "
|
f"#queue-req: {len(self.forward_queue)}, "
|
||||||
f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
|
f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
|
||||||
)
|
)
|
||||||
@@ -386,12 +385,12 @@ class ModelRpcServer:
|
|||||||
f"#running_req: {running_req}. "
|
f"#running_req: {running_req}. "
|
||||||
f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
|
f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
|
||||||
)
|
)
|
||||||
logger.debug(
|
#logger.debug(
|
||||||
f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
|
# f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
|
||||||
f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
|
# f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
|
||||||
f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
|
# f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
|
||||||
f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
|
# f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
|
||||||
)
|
#)
|
||||||
|
|
||||||
new_batch = Batch.init_new(
|
new_batch = Batch.init_new(
|
||||||
can_run_list,
|
can_run_list,
|
||||||
@@ -408,47 +407,41 @@ class ModelRpcServer:
|
|||||||
self.model_config.vocab_size, self.int_token_logit_bias
|
self.model_config.vocab_size, self.int_token_logit_bias
|
||||||
)
|
)
|
||||||
|
|
||||||
prefill_token_logprobs = None
|
|
||||||
if batch.extend_num_tokens != 0:
|
if batch.extend_num_tokens != 0:
|
||||||
# Forward
|
# Forward
|
||||||
logits, (
|
logits, (
|
||||||
prefill_token_logprobs,
|
prefill_token_logprobs,
|
||||||
|
normalized_prompt_logprobs,
|
||||||
prefill_top_logprobs,
|
prefill_top_logprobs,
|
||||||
decode_top_logprobs,
|
decode_top_logprobs,
|
||||||
normalized_prompt_logprobs,
|
|
||||||
last_logprobs,
|
last_logprobs,
|
||||||
) = self.model_runner.forward(batch, ForwardMode.EXTEND)
|
) = self.model_runner.forward(batch, ForwardMode.EXTEND)
|
||||||
if prefill_token_logprobs is not None:
|
if prefill_token_logprobs is not None:
|
||||||
prefill_token_logprobs = prefill_token_logprobs.cpu().tolist()
|
prefill_token_logprobs = prefill_token_logprobs.tolist()
|
||||||
normalized_prompt_logprobs = normalized_prompt_logprobs.cpu().tolist()
|
normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()
|
||||||
|
|
||||||
next_token_ids, _ = batch.sample(logits)
|
next_token_ids, _ = batch.sample(logits)
|
||||||
next_token_ids = next_token_ids.cpu().tolist()
|
|
||||||
else:
|
|
||||||
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
|
|
||||||
(
|
|
||||||
logits,
|
|
||||||
prefill_token_logprobs,
|
|
||||||
normalized_prompt_logprobs,
|
|
||||||
last_logprobs,
|
|
||||||
) = (None,) * 4
|
|
||||||
|
|
||||||
# Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
|
# Only transfer the selected logprobs of the next token to CPU to reduce overhead.
|
||||||
reqs = batch.reqs
|
|
||||||
last_token_logprobs = None
|
|
||||||
if last_logprobs is not None:
|
if last_logprobs is not None:
|
||||||
last_token_logprobs = (
|
last_token_logprobs = (
|
||||||
last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
|
last_logprobs[torch.arange(len(batch.reqs)), next_token_ids].tolist()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
next_token_ids = next_token_ids.tolist()
|
||||||
|
else:
|
||||||
|
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
|
||||||
|
|
||||||
# Check finish condition
|
# Check finish condition
|
||||||
pt = 0
|
pt = 0
|
||||||
for i, req in enumerate(reqs):
|
for i, req in enumerate(batch.reqs):
|
||||||
req.completion_tokens_wo_jump_forward += 1
|
req.completion_tokens_wo_jump_forward += 1
|
||||||
req.output_ids = [next_token_ids[i]]
|
req.output_ids = [next_token_ids[i]]
|
||||||
req.check_finished()
|
req.check_finished()
|
||||||
|
|
||||||
if prefill_token_logprobs is not None:
|
if req.return_logprob:
|
||||||
|
req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
|
||||||
|
|
||||||
# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
|
# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
|
||||||
req.prefill_token_logprobs = list(
|
req.prefill_token_logprobs = list(
|
||||||
zip(
|
zip(
|
||||||
@@ -463,11 +456,13 @@ class ModelRpcServer:
|
|||||||
req.decode_token_logprobs = [
|
req.decode_token_logprobs = [
|
||||||
(last_token_logprobs[i], next_token_ids[i])
|
(last_token_logprobs[i], next_token_ids[i])
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if req.top_logprobs_num > 0:
|
||||||
req.prefill_top_logprobs = prefill_top_logprobs[i]
|
req.prefill_top_logprobs = prefill_top_logprobs[i]
|
||||||
if req.logprob_start_len == 0:
|
if req.logprob_start_len == 0:
|
||||||
req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
|
req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
|
||||||
req.decode_top_logprobs = [decode_top_logprobs[i]]
|
req.decode_top_logprobs = [decode_top_logprobs[i]]
|
||||||
req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
|
|
||||||
pt += req.extend_input_len
|
pt += req.extend_input_len
|
||||||
|
|
||||||
self.handle_finished_requests(batch)
|
self.handle_finished_requests(batch)
|
||||||
@@ -520,29 +515,29 @@ class ModelRpcServer:
|
|||||||
logits, (
|
logits, (
|
||||||
_,
|
_,
|
||||||
_,
|
_,
|
||||||
decode_top_logprobs,
|
|
||||||
_,
|
_,
|
||||||
|
decode_top_logprobs,
|
||||||
last_logprobs,
|
last_logprobs,
|
||||||
) = self.model_runner.forward(batch, ForwardMode.DECODE)
|
) = self.model_runner.forward(batch, ForwardMode.DECODE)
|
||||||
next_token_ids, _ = batch.sample(logits)
|
next_token_ids, _ = batch.sample(logits)
|
||||||
next_token_ids = next_token_ids.cpu().tolist()
|
next_token_ids = next_token_ids.tolist()
|
||||||
|
|
||||||
# Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
|
# Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
|
||||||
reqs = batch.reqs
|
|
||||||
new_token_logprobs = None
|
|
||||||
if last_logprobs is not None:
|
if last_logprobs is not None:
|
||||||
new_token_logprobs = last_logprobs[
|
new_token_logprobs = last_logprobs[
|
||||||
torch.arange(len(reqs)), next_token_ids
|
torch.arange(len(batch.reqs)), next_token_ids
|
||||||
].tolist()
|
].tolist()
|
||||||
|
|
||||||
# Check finish condition
|
# Check finish condition
|
||||||
for i, (req, next_token_id) in enumerate(zip(reqs, next_token_ids)):
|
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
|
||||||
req.completion_tokens_wo_jump_forward += 1
|
req.completion_tokens_wo_jump_forward += 1
|
||||||
req.output_ids.append(next_token_id)
|
req.output_ids.append(next_token_id)
|
||||||
req.check_finished()
|
req.check_finished()
|
||||||
|
|
||||||
if new_token_logprobs is not None:
|
if req.return_logprob:
|
||||||
req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))
|
req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))
|
||||||
|
|
||||||
|
if req.top_logprobs_num > 0:
|
||||||
req.decode_top_logprobs.append(decode_top_logprobs[i])
|
req.decode_top_logprobs.append(decode_top_logprobs[i])
|
||||||
|
|
||||||
self.handle_finished_requests(batch)
|
self.handle_finished_requests(batch)
|
||||||
@@ -590,8 +585,7 @@ class ModelRpcServer:
|
|||||||
+ len(req.output_ids)
|
+ len(req.output_ids)
|
||||||
- req.prompt_tokens,
|
- req.prompt_tokens,
|
||||||
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
|
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
|
||||||
"finish_reason": str(req.finish_reason),
|
"finish_reason": str(req.finish_reason), # FIXME: convert to the correct string
|
||||||
"hit_stop_str": req.hit_stop_str,
|
|
||||||
}
|
}
|
||||||
if req.return_logprob:
|
if req.return_logprob:
|
||||||
(
|
(
|
||||||
@@ -628,7 +622,7 @@ class ModelRpcServer:
|
|||||||
# Remove finished reqs
|
# Remove finished reqs
|
||||||
if finished_indices:
|
if finished_indices:
|
||||||
# Update radix cache
|
# Update radix cache
|
||||||
req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist()
|
req_pool_indices_cpu = batch.req_pool_indices.tolist()
|
||||||
for i in finished_indices:
|
for i in finished_indices:
|
||||||
req = batch.reqs[i]
|
req = batch.reqs[i]
|
||||||
req_pool_idx = req_pool_indices_cpu[i]
|
req_pool_idx = req_pool_indices_cpu[i]
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ QUANTIZATION_CONFIG_MAPPING = {
|
|||||||
logger = logging.getLogger("model_runner")
|
logger = logging.getLogger("model_runner")
|
||||||
|
|
||||||
# for server args in model endpoints
|
# for server args in model endpoints
|
||||||
global_server_args_dict: dict = None
|
global_server_args_dict = {}
|
||||||
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
@@ -86,8 +86,8 @@ class InputMetadata:
|
|||||||
out_cache_cont_end: torch.Tensor = None
|
out_cache_cont_end: torch.Tensor = None
|
||||||
|
|
||||||
other_kv_index: torch.Tensor = None
|
other_kv_index: torch.Tensor = None
|
||||||
top_logprobs_nums: List[int] = None
|
|
||||||
return_logprob: bool = False
|
return_logprob: bool = False
|
||||||
|
top_logprobs_nums: List[int] = None
|
||||||
|
|
||||||
# for flashinfer
|
# for flashinfer
|
||||||
qo_indptr: torch.Tensor = None
|
qo_indptr: torch.Tensor = None
|
||||||
@@ -107,18 +107,20 @@ class InputMetadata:
|
|||||||
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
|
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
|
||||||
)
|
)
|
||||||
self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
|
self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
|
||||||
|
self.kv_last_page_len = torch.ones(
|
||||||
|
(self.batch_size,), dtype=torch.int32, device="cuda"
|
||||||
|
)
|
||||||
|
req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
|
||||||
|
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
||||||
self.kv_indices = torch.cat(
|
self.kv_indices = torch.cat(
|
||||||
[
|
[
|
||||||
self.req_to_token_pool.req_to_token[
|
self.req_to_token_pool.req_to_token[
|
||||||
self.req_pool_indices[i].item(), : self.seq_lens[i].item()
|
req_pool_indices_cpu[i]: seq_lens_cpu[i]
|
||||||
]
|
]
|
||||||
for i in range(self.batch_size)
|
for i in range(self.batch_size)
|
||||||
],
|
],
|
||||||
dim=0,
|
dim=0,
|
||||||
).contiguous()
|
).contiguous()
|
||||||
self.kv_last_page_len = torch.ones(
|
|
||||||
(self.batch_size,), dtype=torch.int32, device="cuda"
|
|
||||||
)
|
|
||||||
|
|
||||||
workspace_buffer = torch.empty(
|
workspace_buffer = torch.empty(
|
||||||
32 * 1024 * 1024, dtype=torch.int8, device="cuda"
|
32 * 1024 * 1024, dtype=torch.int8, device="cuda"
|
||||||
@@ -195,15 +197,15 @@ class InputMetadata:
|
|||||||
req_pool_indices[0], seq_lens[0] - 1
|
req_pool_indices[0], seq_lens[0] - 1
|
||||||
].item()
|
].item()
|
||||||
else:
|
else:
|
||||||
seq_lens_np = seq_lens.cpu().numpy()
|
seq_lens_cpu = seq_lens.cpu().numpy()
|
||||||
prefix_lens_np = prefix_lens.cpu().numpy()
|
prefix_lens_cpu = prefix_lens.cpu().numpy()
|
||||||
position_ids_offsets_np = position_ids_offsets.cpu().numpy()
|
position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
|
||||||
positions = torch.tensor(
|
positions = torch.tensor(
|
||||||
np.concatenate(
|
np.concatenate(
|
||||||
[
|
[
|
||||||
np.arange(
|
np.arange(
|
||||||
prefix_lens_np[i] + position_ids_offsets_np[i],
|
prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
|
||||||
seq_lens_np[i] + position_ids_offsets_np[i],
|
seq_lens_cpu[i] + position_ids_offsets_cpu[i],
|
||||||
)
|
)
|
||||||
for i in range(batch_size)
|
for i in range(batch_size)
|
||||||
],
|
],
|
||||||
@@ -229,9 +231,9 @@ class InputMetadata:
|
|||||||
out_cache_loc=out_cache_loc,
|
out_cache_loc=out_cache_loc,
|
||||||
out_cache_cont_start=out_cache_cont_start,
|
out_cache_cont_start=out_cache_cont_start,
|
||||||
out_cache_cont_end=out_cache_cont_end,
|
out_cache_cont_end=out_cache_cont_end,
|
||||||
top_logprobs_nums=top_logprobs_nums,
|
|
||||||
return_logprob=return_logprob,
|
|
||||||
other_kv_index=other_kv_index,
|
other_kv_index=other_kv_index,
|
||||||
|
return_logprob=return_logprob,
|
||||||
|
top_logprobs_nums=top_logprobs_nums,
|
||||||
)
|
)
|
||||||
|
|
||||||
if forward_mode == ForwardMode.EXTEND:
|
if forward_mode == ForwardMode.EXTEND:
|
||||||
|
|||||||
@@ -185,7 +185,10 @@ class TokenizerManager:
|
|||||||
|
|
||||||
while True:
|
while True:
|
||||||
await event.wait()
|
await event.wait()
|
||||||
yield state.out_list[-1]
|
yield self.convert_logprob_style(state.out_list[-1],
|
||||||
|
obj.return_logprob,
|
||||||
|
obj.top_logprobs_num,
|
||||||
|
obj.return_text_in_logprobs)
|
||||||
state.out_list = []
|
state.out_list = []
|
||||||
if state.finished:
|
if state.finished:
|
||||||
del self.rid_to_state[rid]
|
del self.rid_to_state[rid]
|
||||||
@@ -231,16 +234,16 @@ class TokenizerManager:
|
|||||||
rid = obj.rid[i]
|
rid = obj.rid[i]
|
||||||
state = self.rid_to_state[rid]
|
state = self.rid_to_state[rid]
|
||||||
await state.event.wait()
|
await state.event.wait()
|
||||||
output_list.append(state.out_list[-1])
|
output_list.append(
|
||||||
|
self.convert_logprob_style(state.out_list[-1],
|
||||||
|
obj.return_logprob[i],
|
||||||
|
obj.top_logprobs_num[i],
|
||||||
|
obj.return_text_in_logprobs))
|
||||||
assert state.finished
|
assert state.finished
|
||||||
del self.rid_to_state[rid]
|
del self.rid_to_state[rid]
|
||||||
|
|
||||||
yield output_list
|
yield output_list
|
||||||
|
|
||||||
async def detokenize(self, obj: DetokenizeReqInput):
|
|
||||||
token_texts = self.tokenizer.convert_ids_to_tokens(obj.input_ids)
|
|
||||||
return [t.decode() if isinstance(t, bytes) else t for t in token_texts]
|
|
||||||
|
|
||||||
async def flush_cache(self):
|
async def flush_cache(self):
|
||||||
flush_cache_req = FlushCacheReq()
|
flush_cache_req = FlushCacheReq()
|
||||||
self.send_to_router.send_pyobj(flush_cache_req)
|
self.send_to_router.send_pyobj(flush_cache_req)
|
||||||
@@ -267,3 +270,37 @@ class TokenizerManager:
|
|||||||
state.event.set()
|
state.event.set()
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid object: {recv_obj}")
|
raise ValueError(f"Invalid object: {recv_obj}")
|
||||||
|
|
||||||
|
def convert_logprob_style(self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs):
|
||||||
|
if return_logprob:
|
||||||
|
ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
|
||||||
|
ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs
|
||||||
|
)
|
||||||
|
ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens(
|
||||||
|
ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
|
||||||
|
)
|
||||||
|
if top_logprobs_num > 0:
|
||||||
|
ret["meta_info"]["prefill_top_logprobs"] = self.detokenize_top_logprobs_tokens(
|
||||||
|
ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
|
||||||
|
)
|
||||||
|
ret["meta_info"]["decode_top_logprobs"] = self.detokenize_top_logprobs_tokens(
|
||||||
|
ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
|
||||||
|
)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
|
||||||
|
if not decode_to_text:
|
||||||
|
return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
|
||||||
|
|
||||||
|
token_ids = [tid for _, tid in token_logprobs]
|
||||||
|
token_texts = self.tokenizer.batch_decode(token_ids)
|
||||||
|
return [
|
||||||
|
(logprob, token_id, token_text)
|
||||||
|
for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
|
||||||
|
]
|
||||||
|
|
||||||
|
def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text):
|
||||||
|
for i, t in enumerate(top_logprobs):
|
||||||
|
if t:
|
||||||
|
top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
|
||||||
|
return top_logprobs
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
"""pydantic models for OpenAI API protocol"""
|
||||||
import time
|
import time
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import threading
|
|||||||
import time
|
import time
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
# Fix a Python bug
|
# Fix a bug of Python threading
|
||||||
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
@@ -53,10 +53,10 @@ from sglang.srt.managers.router.manager import start_router_process
|
|||||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
enable_show_time_cost,
|
|
||||||
allocate_init_ports,
|
allocate_init_ports,
|
||||||
jsonify_pydantic_model,
|
|
||||||
assert_pkg_version,
|
assert_pkg_version,
|
||||||
|
enable_show_time_cost,
|
||||||
|
jsonify_pydantic_model,
|
||||||
get_exception_traceback,
|
get_exception_traceback,
|
||||||
API_KEY_HEADER_NAME,
|
API_KEY_HEADER_NAME,
|
||||||
APIKeyValidatorMiddleware
|
APIKeyValidatorMiddleware
|
||||||
@@ -99,12 +99,6 @@ async def flush_cache():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def stream_generator(obj: GenerateReqInput):
|
|
||||||
async for out in tokenizer_manager.generate_request(obj):
|
|
||||||
await handle_token_logprobs_results(obj, out)
|
|
||||||
yield out
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/generate")
|
@app.post("/generate")
|
||||||
async def generate_request(obj: GenerateReqInput):
|
async def generate_request(obj: GenerateReqInput):
|
||||||
obj.post_init()
|
obj.post_init()
|
||||||
@@ -112,69 +106,16 @@ async def generate_request(obj: GenerateReqInput):
|
|||||||
if obj.stream:
|
if obj.stream:
|
||||||
|
|
||||||
async def stream_results():
|
async def stream_results():
|
||||||
async for out in stream_generator(obj):
|
async for out in tokenizer_manager.generate_request(obj):
|
||||||
yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
return StreamingResponse(stream_results(), media_type="text/event-stream")
|
return StreamingResponse(stream_results(), media_type="text/event-stream")
|
||||||
|
|
||||||
ret = await tokenizer_manager.generate_request(obj).__anext__()
|
ret = await tokenizer_manager.generate_request(obj).__anext__()
|
||||||
await handle_token_logprobs_results(obj, ret)
|
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
async def detokenize_logprob_tokens(token_logprobs, decode_to_text):
|
|
||||||
if not decode_to_text:
|
|
||||||
return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
|
|
||||||
|
|
||||||
token_ids = [tid for _, tid in token_logprobs]
|
|
||||||
token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
|
|
||||||
return [
|
|
||||||
(logprob, token_id, token_text)
|
|
||||||
for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
async def detokenize_top_logprobs_tokens(top_logprobs, decode_to_text):
|
|
||||||
for i, t in enumerate(top_logprobs):
|
|
||||||
if top_logprobs[i] is not None:
|
|
||||||
top_logprobs[i] = await detokenize_logprob_tokens(t, decode_to_text)
|
|
||||||
return top_logprobs
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_token_logprobs_results(obj: GenerateReqInput, ret):
|
|
||||||
"""Handle the token logprobs results, convert token ids to text if needed.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
obj (GenerateReqInput): The request object.
|
|
||||||
ret (Union[Dict, List[Dict]]): The response object.
|
|
||||||
"""
|
|
||||||
# NOTE: This is because the multiple requests in one http request.
|
|
||||||
|
|
||||||
async def convert_style(r, return_text):
|
|
||||||
r["meta_info"]["prefill_token_logprobs"] = await detokenize_logprob_tokens(
|
|
||||||
r["meta_info"]["prefill_token_logprobs"], return_text
|
|
||||||
)
|
|
||||||
r["meta_info"]["decode_token_logprobs"] = await detokenize_logprob_tokens(
|
|
||||||
r["meta_info"]["decode_token_logprobs"], return_text
|
|
||||||
)
|
|
||||||
r["meta_info"]["prefill_top_logprobs"] = await detokenize_top_logprobs_tokens(
|
|
||||||
r["meta_info"]["prefill_top_logprobs"], return_text
|
|
||||||
)
|
|
||||||
r["meta_info"]["decode_top_logprobs"] = await detokenize_top_logprobs_tokens(
|
|
||||||
r["meta_info"]["decode_top_logprobs"], return_text
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(obj.text, str):
|
|
||||||
if obj.return_logprob:
|
|
||||||
await convert_style(ret, obj.return_text_in_logprobs)
|
|
||||||
else:
|
|
||||||
for i, r in enumerate(ret):
|
|
||||||
if obj.return_logprob[i]:
|
|
||||||
await convert_style(r, obj.return_text_in_logprobs)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/completions")
|
@app.post("/v1/completions")
|
||||||
async def v1_completions(raw_request: Request):
|
async def v1_completions(raw_request: Request):
|
||||||
request_json = await raw_request.json()
|
request_json = await raw_request.json()
|
||||||
@@ -203,10 +144,10 @@ async def v1_completions(raw_request: Request):
|
|||||||
|
|
||||||
if adapted_request.stream:
|
if adapted_request.stream:
|
||||||
|
|
||||||
async def gnerate_stream_resp():
|
async def generate_stream_resp():
|
||||||
stream_buffer = ""
|
stream_buffer = ""
|
||||||
n_prev_token = 0
|
n_prev_token = 0
|
||||||
async for content in stream_generator(adapted_request):
|
async for content in tokenizer_manager.generate_request(adapted_request):
|
||||||
text = content["text"]
|
text = content["text"]
|
||||||
prompt_tokens = content["meta_info"]["prompt_tokens"]
|
prompt_tokens = content["meta_info"]["prompt_tokens"]
|
||||||
completion_tokens = content["meta_info"]["completion_tokens"]
|
completion_tokens = content["meta_info"]["completion_tokens"]
|
||||||
@@ -266,7 +207,7 @@ async def v1_completions(raw_request: Request):
|
|||||||
yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
|
yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
return StreamingResponse(gnerate_stream_resp(), media_type="text/event-stream")
|
return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")
|
||||||
|
|
||||||
# Non-streaming response.
|
# Non-streaming response.
|
||||||
ret = await generate_request(adapted_request)
|
ret = await generate_request(adapted_request)
|
||||||
@@ -384,7 +325,7 @@ async def v1_chat_completions(raw_request: Request):
|
|||||||
is_first = True
|
is_first = True
|
||||||
|
|
||||||
stream_buffer = ""
|
stream_buffer = ""
|
||||||
async for content in stream_generator(adapted_request):
|
async for content in tokenizer_manager.generate_request(adapted_request):
|
||||||
if is_first:
|
if is_first:
|
||||||
# First chunk with role
|
# First chunk with role
|
||||||
is_first = False
|
is_first = False
|
||||||
|
|||||||
@@ -241,7 +241,7 @@ class ServerArgs:
|
|||||||
def print_mode_args(self):
|
def print_mode_args(self):
|
||||||
return (
|
return (
|
||||||
f"enable_flashinfer={self.enable_flashinfer}, "
|
f"enable_flashinfer={self.enable_flashinfer}, "
|
||||||
f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}"
|
f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
|
||||||
f"disable_radix_cache={self.disable_radix_cache}, "
|
f"disable_radix_cache={self.disable_radix_cache}, "
|
||||||
f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
|
f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
|
||||||
f"disable_disk_cache={self.disable_disk_cache}, "
|
f"disable_disk_cache={self.disable_disk_cache}, "
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
"""Common utilities."""
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
@@ -13,6 +15,7 @@ import numpy as np
|
|||||||
import pydantic
|
import pydantic
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
from packaging import version as pkg_version
|
from packaging import version as pkg_version
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
@@ -303,6 +306,7 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
|
|||||||
response = await call_next(request)
|
response = await call_next(request)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
# FIXME: Remove this once we drop support for pydantic 1.x
|
# FIXME: Remove this once we drop support for pydantic 1.x
|
||||||
IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
|
IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
|
||||||
|
|
||||||
|
|||||||
@@ -296,7 +296,7 @@ def test_parallel_encoding(check_answer=True):
|
|||||||
def test_image_qa():
|
def test_image_qa():
|
||||||
@sgl.function
|
@sgl.function
|
||||||
def image_qa(s, question):
|
def image_qa(s, question):
|
||||||
s += sgl.user(sgl.image("test_image.png") + question)
|
s += sgl.user(sgl.image("example_image.png") + question)
|
||||||
s += sgl.assistant(sgl.gen("answer"))
|
s += sgl.assistant(sgl.gen("answer"))
|
||||||
|
|
||||||
state = image_qa.run(
|
state = image_qa.run(
|
||||||
|
|||||||
BIN
test/lang/example_image.png
Normal file
BIN
test/lang/example_image.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 56 KiB |
@@ -28,7 +28,7 @@ class TestOpenAIBackend(unittest.TestCase):
|
|||||||
if cls.backend is None:
|
if cls.backend is None:
|
||||||
cls.backend = OpenAI("gpt-3.5-turbo-instruct")
|
cls.backend = OpenAI("gpt-3.5-turbo-instruct")
|
||||||
cls.chat_backend = OpenAI("gpt-3.5-turbo")
|
cls.chat_backend = OpenAI("gpt-3.5-turbo")
|
||||||
cls.chat_vision_backend = OpenAI("gpt-4-vision-preview")
|
cls.chat_vision_backend = OpenAI("gpt-4-turbo")
|
||||||
|
|
||||||
def test_few_shot_qa(self):
|
def test_few_shot_qa(self):
|
||||||
set_default_backend(self.backend)
|
set_default_backend(self.backend)
|
||||||
@@ -88,14 +88,3 @@ if __name__ == "__main__":
|
|||||||
# t = TestOpenAIBackend()
|
# t = TestOpenAIBackend()
|
||||||
# t.setUp()
|
# t.setUp()
|
||||||
# t.test_few_shot_qa()
|
# t.test_few_shot_qa()
|
||||||
# t.test_mt_bench()
|
|
||||||
# t.test_select()
|
|
||||||
# t.test_decode_int()
|
|
||||||
# t.test_decode_json()
|
|
||||||
# t.test_expert_answer()
|
|
||||||
# t.test_tool_use()
|
|
||||||
# t.test_react()
|
|
||||||
# t.test_parallel_decoding()
|
|
||||||
# t.test_parallel_encoding()
|
|
||||||
# t.test_image_qa()
|
|
||||||
# t.test_stream()
|
|
||||||
|
|||||||
@@ -1,68 +0,0 @@
|
|||||||
from sglang import OpenAI, function, gen, set_default_backend
|
|
||||||
|
|
||||||
|
|
||||||
@function()
|
|
||||||
def gen_character_default(s):
|
|
||||||
s += "Construct a character within the following format:\n"
|
|
||||||
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n"
|
|
||||||
s += "\nPlease generate new Name, Birthday and Job.\n"
|
|
||||||
s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
|
|
||||||
s += "\nJob:" + gen("job", stop="\n") + "\nWelcome.\n"
|
|
||||||
|
|
||||||
|
|
||||||
@function(api_num_spec_tokens=512)
|
|
||||||
def gen_character_spec(s):
|
|
||||||
s += "Construct a character within the following format:\n"
|
|
||||||
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n"
|
|
||||||
s += "\nPlease generate new Name, Birthday and Job.\n"
|
|
||||||
s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
|
|
||||||
s += "\nJob:" + gen("job", stop="\n") + "\nWelcome.\n"
|
|
||||||
|
|
||||||
|
|
||||||
@function(api_num_spec_tokens=512)
|
|
||||||
def gen_character_no_stop(s):
|
|
||||||
s += "Construct a character within the following format:\n"
|
|
||||||
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n"
|
|
||||||
s += "\nPlease generate new Name, Birthday and Job.\n"
|
|
||||||
s += "Name:" + gen("name") + "\nBirthday:" + gen("birthday")
|
|
||||||
s += "\nJob:" + gen("job") + "\nWelcome.\n"
|
|
||||||
|
|
||||||
|
|
||||||
@function(api_num_spec_tokens=512)
|
|
||||||
def gen_character_multi_stop(s):
|
|
||||||
s += "Construct a character within the following format:\n"
|
|
||||||
s += (
|
|
||||||
"Name: Steve Jobs.###Birthday: February 24, 1955.###Job: Apple CEO.\nWelcome.\n"
|
|
||||||
)
|
|
||||||
s += "\nPlease generate new Name, Birthday and Job.\n"
|
|
||||||
s += "Name:" + gen("name", stop=["\n", "###"])
|
|
||||||
s += "###Birthday:" + gen("birthday", stop=["\n", "###"])
|
|
||||||
s += "###Job:" + gen("job", stop=["\n", "###"]) + "\nWelcome.\n"
|
|
||||||
|
|
||||||
|
|
||||||
set_default_backend(OpenAI("gpt-3.5-turbo-instruct"))
|
|
||||||
|
|
||||||
state = gen_character_default.run()
|
|
||||||
print(state.text())
|
|
||||||
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
state = gen_character_no_stop.run()
|
|
||||||
|
|
||||||
print("name###", state["name"])
|
|
||||||
print("birthday###:", state["birthday"])
|
|
||||||
print("job###", state["job"])
|
|
||||||
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
state = gen_character_multi_stop.run()
|
|
||||||
print(state.text())
|
|
||||||
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
state = gen_character_spec.run()
|
|
||||||
print(state.text())
|
|
||||||
|
|
||||||
print("name###", state["name"])
|
|
||||||
print("birthday###", state["birthday"])
|
|
||||||
print("job###", state["job"])
|
|
||||||
1
test/srt/example_image.png
Symbolic link
1
test/srt/example_image.png
Symbolic link
@@ -0,0 +1 @@
|
|||||||
|
../lang/example_image.png
|
||||||
@@ -3,7 +3,6 @@ Usage:
|
|||||||
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
|
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
|
||||||
python3 test_httpserver_decode.py
|
python3 test_httpserver_decode.py
|
||||||
|
|
||||||
|
|
||||||
Output:
|
Output:
|
||||||
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
|
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
|
||||||
"""
|
"""
|
||||||
@@ -23,6 +22,7 @@ def test_decode(url, return_logprob, top_logprobs_num, return_text):
|
|||||||
"temperature": 0,
|
"temperature": 0,
|
||||||
"max_new_tokens": 32,
|
"max_new_tokens": 32,
|
||||||
},
|
},
|
||||||
|
"stream": False,
|
||||||
"return_logprob": return_logprob,
|
"return_logprob": return_logprob,
|
||||||
"top_logprobs_num": top_logprobs_num,
|
"top_logprobs_num": top_logprobs_num,
|
||||||
"return_text_in_logprobs": return_text,
|
"return_text_in_logprobs": return_text,
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ def test_decode_stream(url, return_logprob, top_logprobs_num):
|
|||||||
"return_logprob": return_logprob,
|
"return_logprob": return_logprob,
|
||||||
"top_logprobs_num": top_logprobs_num,
|
"top_logprobs_num": top_logprobs_num,
|
||||||
"return_text_in_logprobs": True,
|
"return_text_in_logprobs": True,
|
||||||
|
"logprob_start_len": 0,
|
||||||
},
|
},
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ async def test_concurrent(args):
|
|||||||
url + "/generate",
|
url + "/generate",
|
||||||
{
|
{
|
||||||
"text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
|
"text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
|
||||||
"image_data": "test_image.png",
|
"image_data": "example_image.png",
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
"temperature": 0,
|
"temperature": 0,
|
||||||
"max_new_tokens": 16,
|
"max_new_tokens": 16,
|
||||||
@@ -55,7 +55,7 @@ def test_streaming(args):
|
|||||||
url + "/generate",
|
url + "/generate",
|
||||||
json={
|
json={
|
||||||
"text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
|
"text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
|
||||||
"image_data": "test_image.png",
|
"image_data": "example_image.png",
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
"temperature": 0,
|
"temperature": 0,
|
||||||
"max_new_tokens": 128,
|
"max_new_tokens": 128,
|
||||||
|
|||||||
@@ -6,10 +6,10 @@ The capital of France is Paris.\nThe capital of the United States is Washington,
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import time
|
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
||||||
|
|||||||
@@ -163,7 +163,7 @@ def test_regex(args):
|
|||||||
regex = (
|
regex = (
|
||||||
r"""\{\n"""
|
r"""\{\n"""
|
||||||
+ r""" "name": "[\w]+",\n"""
|
+ r""" "name": "[\w]+",\n"""
|
||||||
+ r""" "population": "[\w\d\s]+"\n"""
|
+ r""" "population": [\w\d\s]+\n"""
|
||||||
+ r"""\}"""
|
+ r"""\}"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user