Use dtype to control generate (#1082)
Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
@@ -72,7 +72,7 @@ def gen(
|
||||
logprob_start_len: Optional[int] = None,
|
||||
top_logprobs_num: Optional[int] = None,
|
||||
return_text_in_logprobs: Optional[bool] = None,
|
||||
dtype: Optional[type] = None,
|
||||
dtype: Optional[Union[type, str]] = None,
|
||||
choices: Optional[List[str]] = None,
|
||||
choices_method: Optional[ChoicesSamplingMethod] = None,
|
||||
regex: Optional[str] = None,
|
||||
|
||||
@@ -195,7 +195,7 @@ def extend(reqs, model_runner):
|
||||
token_to_kv_pool=model_runner.token_to_kv_pool,
|
||||
tree_cache=None,
|
||||
)
|
||||
batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
|
||||
batch.prepare_for_extend(model_runner.model_config.vocab_size)
|
||||
output = model_runner.forward(batch, ForwardMode.EXTEND)
|
||||
next_token_ids = batch.sample(output.next_token_logits)
|
||||
return next_token_ids, output.next_token_logits, batch
|
||||
|
||||
@@ -1,21 +1,23 @@
|
||||
import json
|
||||
import warnings
|
||||
from typing import List, Optional
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.lang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.chat_template import get_chat_template_by_model_path
|
||||
from sglang.lang.choices import (
|
||||
ChoicesDecision,
|
||||
ChoicesSamplingMethod,
|
||||
token_length_normalized,
|
||||
)
|
||||
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
|
||||
from sglang.lang.interpreter import StreamExecutor
|
||||
from sglang.lang.ir import SglSamplingParams
|
||||
from sglang.lang.ir import (
|
||||
REGEX_BOOL,
|
||||
REGEX_FLOAT,
|
||||
REGEX_INT,
|
||||
REGEX_STR,
|
||||
SglSamplingParams,
|
||||
)
|
||||
from sglang.utils import http_request
|
||||
|
||||
|
||||
class RuntimeEndpoint(BaseBackend):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
@@ -95,32 +97,52 @@ class RuntimeEndpoint(BaseBackend):
|
||||
)
|
||||
self._assert_success(res)
|
||||
|
||||
def _handle_dtype_to_regex(self, sampling_params: SglSamplingParams):
    """Rewrite a ``dtype`` request on ``sampling_params`` as a regex constraint.

    Does nothing when ``sampling_params.dtype`` is ``None``. Otherwise
    ``sampling_params.regex`` is overwritten with the pattern for the
    requested dtype (``int``, ``float``, ``str`` or ``bool``, given either
    as the Python type or as its name). For ``int`` and ``float`` a space
    and a newline are also added to the stop strings.

    Args:
        sampling_params: Sampling parameters; ``regex`` and ``stop`` are
            updated on this object.

    Raises:
        RuntimeError: If ``dtype`` is not one of the supported values.
    """
    dtype = sampling_params.dtype
    if dtype is None:
        return

    # (accepted spellings, regex pattern, add whitespace stop strings?)
    dtype_table = (
        (("int", int), REGEX_INT, True),
        (("float", float), REGEX_FLOAT, True),
        (("str", str), REGEX_STR, False),
        (("bool", bool), REGEX_BOOL, False),
    )
    for aliases, dtype_regex, stop_on_whitespace in dtype_table:
        if dtype in aliases:
            break
    else:
        # Validate before touching any field so a bad dtype leaves
        # sampling_params unmodified.
        raise RuntimeError(f"Invalid dtype: {dtype}")

    stop = sampling_params.stop
    if stop_on_whitespace:
        # `stop` may be None, a single string, a tuple or a list. Build a
        # fresh list so a caller-owned sequence is never mutated in place
        # and repeated calls do not pile up duplicate stop strings.
        if stop is None:
            stop = []
        elif isinstance(stop, str):
            stop = [stop] if stop else []
        else:
            stop = list(stop)
        for stop_str in (" ", "\n"):
            if stop_str not in stop:
                stop.append(stop_str)
        sampling_params.stop = stop
    elif stop == ():
        # Keep the historical normalization of the empty-tuple default.
        sampling_params.stop = []

    if sampling_params.regex is not None:
        warnings.warn(
            f"Both dtype and regex are set. Only dtype will be used. dtype: {sampling_params.dtype}, regex: {sampling_params.regex}"
        )

    sampling_params.regex = dtype_regex
|
||||
|
||||
def generate(
|
||||
self,
|
||||
s: StreamExecutor,
|
||||
sampling_params: SglSamplingParams,
|
||||
):
|
||||
if sampling_params.dtype is None:
|
||||
data = {
|
||||
"text": s.text_,
|
||||
"sampling_params": {
|
||||
"skip_special_tokens": global_config.skip_special_tokens_in_output,
|
||||
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
|
||||
**sampling_params.to_srt_kwargs(),
|
||||
},
|
||||
}
|
||||
elif sampling_params.dtype in [int, "int"]:
|
||||
data = {
|
||||
"text": s.text_,
|
||||
"sampling_params": {
|
||||
"skip_special_tokens": global_config.skip_special_tokens_in_output,
|
||||
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
|
||||
"dtype": "int",
|
||||
**sampling_params.to_srt_kwargs(),
|
||||
},
|
||||
}
|
||||
else:
|
||||
raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
|
||||
self._handle_dtype_to_regex(sampling_params)
|
||||
data = {
|
||||
"text": s.text_,
|
||||
"sampling_params": {
|
||||
"skip_special_tokens": global_config.skip_special_tokens_in_output,
|
||||
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
|
||||
**sampling_params.to_srt_kwargs(),
|
||||
},
|
||||
}
|
||||
|
||||
for item in [
|
||||
"return_logprob",
|
||||
@@ -151,27 +173,16 @@ class RuntimeEndpoint(BaseBackend):
|
||||
s: StreamExecutor,
|
||||
sampling_params: SglSamplingParams,
|
||||
):
|
||||
if sampling_params.dtype is None:
|
||||
data = {
|
||||
"text": s.text_,
|
||||
"sampling_params": {
|
||||
"skip_special_tokens": global_config.skip_special_tokens_in_output,
|
||||
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
|
||||
**sampling_params.to_srt_kwargs(),
|
||||
},
|
||||
}
|
||||
elif sampling_params.dtype in [int, "int"]:
|
||||
data = {
|
||||
"text": s.text_,
|
||||
"sampling_params": {
|
||||
"skip_special_tokens": global_config.skip_special_tokens_in_output,
|
||||
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
|
||||
"dtype": "int",
|
||||
**sampling_params.to_srt_kwargs(),
|
||||
},
|
||||
}
|
||||
else:
|
||||
raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
|
||||
self._handle_dtype_to_regex(sampling_params)
|
||||
|
||||
data = {
|
||||
"text": s.text_,
|
||||
"sampling_params": {
|
||||
"skip_special_tokens": global_config.skip_special_tokens_in_output,
|
||||
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
|
||||
**sampling_params.to_srt_kwargs(),
|
||||
},
|
||||
}
|
||||
|
||||
for item in [
|
||||
"return_logprob",
|
||||
|
||||
@@ -8,10 +8,10 @@ from typing import List, Optional, Union
|
||||
from sglang.global_config import global_config
|
||||
from sglang.lang.choices import ChoicesSamplingMethod
|
||||
|
||||
REGEX_INT = r"[-+]?[0-9]+"
|
||||
REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+"
|
||||
REGEX_INT = r"[-+]?[0-9]+[ \n]*"
|
||||
REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+[ \n]*"
|
||||
REGEX_BOOL = r"(True|False)"
|
||||
REGEX_STRING = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg
|
||||
REGEX_STR = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
||||
@@ -383,7 +383,7 @@ class ScheduleBatch:
|
||||
|
||||
return out_cache_loc
|
||||
|
||||
def batch_sampling_params(self, vocab_size, int_token_logit_bias):
|
||||
def batch_sampling_params(self, vocab_size):
|
||||
device = "cuda"
|
||||
bs, reqs = self.batch_size(), self.reqs
|
||||
self.temperatures = torch.tensor(
|
||||
@@ -419,15 +419,8 @@ class ScheduleBatch:
|
||||
|
||||
# Handle logit bias but only allocate when needed
|
||||
self.logit_bias = None
|
||||
for i in range(bs):
|
||||
if reqs[i].sampling_params.dtype == "int":
|
||||
if self.logit_bias is None:
|
||||
self.logit_bias = torch.zeros(
|
||||
(bs, vocab_size), dtype=torch.float32, device=device
|
||||
)
|
||||
self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
|
||||
|
||||
def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
|
||||
def prepare_for_extend(self, vocab_size: int):
|
||||
bs = self.batch_size()
|
||||
reqs = self.reqs
|
||||
input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
|
||||
@@ -466,7 +459,7 @@ class ScheduleBatch:
|
||||
self.out_cache_loc = out_cache_loc
|
||||
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
|
||||
|
||||
self.batch_sampling_params(vocab_size, int_token_logit_bias)
|
||||
self.batch_sampling_params(vocab_size)
|
||||
|
||||
def check_decode_mem(self):
|
||||
bs = self.batch_size()
|
||||
|
||||
@@ -54,7 +54,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
get_int_token_logit_bias,
|
||||
is_multimodal_model,
|
||||
set_random_seed,
|
||||
suppress_other_loggers,
|
||||
@@ -132,9 +131,6 @@ class ModelTpServer:
|
||||
),
|
||||
self.model_runner.req_to_token_pool.size - 1,
|
||||
)
|
||||
self.int_token_logit_bias = torch.tensor(
|
||||
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
|
||||
)
|
||||
self.max_req_input_len = min(
|
||||
self.model_config.context_len - 1,
|
||||
self.max_total_num_tokens - 1,
|
||||
@@ -442,9 +438,7 @@ class ModelTpServer:
|
||||
|
||||
def forward_prefill_batch(self, batch: ScheduleBatch):
|
||||
# Build batch tensors
|
||||
batch.prepare_for_extend(
|
||||
self.model_config.vocab_size, self.int_token_logit_bias
|
||||
)
|
||||
batch.prepare_for_extend(self.model_config.vocab_size)
|
||||
|
||||
if self.model_runner.is_generation:
|
||||
# Forward and sample the next tokens
|
||||
|
||||
@@ -32,6 +32,7 @@ from vllm.model_executor.layers.linear import (
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE,
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
|
||||
@@ -36,7 +36,6 @@ class SamplingParams:
|
||||
ignore_eos: bool = False,
|
||||
skip_special_tokens: bool = True,
|
||||
spaces_between_special_tokens: bool = True,
|
||||
dtype: Optional[str] = None,
|
||||
regex: Optional[str] = None,
|
||||
n: int = 1,
|
||||
) -> None:
|
||||
@@ -53,7 +52,6 @@ class SamplingParams:
|
||||
self.ignore_eos = ignore_eos
|
||||
self.skip_special_tokens = skip_special_tokens
|
||||
self.spaces_between_special_tokens = spaces_between_special_tokens
|
||||
self.dtype = dtype
|
||||
self.regex = regex
|
||||
self.n = n
|
||||
|
||||
@@ -63,8 +61,6 @@ class SamplingParams:
|
||||
self.top_k = 1
|
||||
if self.top_k == -1:
|
||||
self.top_k = 1 << 30 # whole vocabulary
|
||||
if self.dtype == "int":
|
||||
self.stop_strs = [" ", "\n"]
|
||||
|
||||
def verify(self):
|
||||
if self.temperature < 0.0:
|
||||
|
||||
@@ -103,13 +103,13 @@ def test_decode_int():
|
||||
def test_decode_json_regex():
|
||||
@sgl.function
|
||||
def decode_json(s):
|
||||
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING
|
||||
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR
|
||||
|
||||
s += "Generate a JSON object to describe the basic city information of Paris.\n"
|
||||
|
||||
with s.var_scope("json_output"):
|
||||
s += "{\n"
|
||||
s += ' "name": ' + sgl.gen(regex=REGEX_STRING + ",") + "\n"
|
||||
s += ' "name": ' + sgl.gen(regex=REGEX_STR + ",") + "\n"
|
||||
s += ' "population": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
|
||||
s += ' "area": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
|
||||
s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT) + "\n"
|
||||
@@ -359,6 +359,30 @@ def test_regex():
|
||||
assert re.match(regex, answer)
|
||||
|
||||
|
||||
def test_dtype_gen():
    """Check that ``sgl.gen(dtype=...)`` yields text parseable as that dtype.

    Runs a short Q/A program that requests a ``str``, an ``int``, a
    ``float`` and a ``bool`` generation, then verifies the numeric answers
    parse and the boolean answer is literally ``True`` or ``False``.

    Raises:
        ValueError: If a generated value cannot be interpreted as the
            requested dtype. The full state is printed first to aid
            debugging.
    """

    @sgl.function
    def dtype_gen(s):
        s += "Q: What is the full name of DNS?\n"
        s += "A: The full nams is " + sgl.gen("str_res", dtype=str, stop="\n") + "\n"
        s += "Q: Which year was DNS invented?\n"
        s += "A: " + sgl.gen("int_res", dtype=int) + "\n"
        s += "Q: What is the value of pi?\n"
        s += "A: " + sgl.gen("float_res", dtype=float) + "\n"
        s += "Q: Is the sky blue?\n"
        s += "A: " + sgl.gen("bool_res", dtype=bool) + "\n"

    state = dtype_gen.run()

    try:
        # int() and float() tolerate surrounding whitespace, so any trailing
        # space/newline in the generated text is fine.
        int(state["int_res"])
        float(state["float_res"])
        # bool("False") is True and bool() never raises on a string, so the
        # generated text is compared against the two allowed literals.
        bool_text = state["bool_res"].strip()
        if bool_text not in ("True", "False"):
            raise ValueError(f"Invalid bool literal: {bool_text!r}")
        # NOTE(review): the quoted-string shape of state["str_res"] is not
        # asserted; the original check was left disabled. Confirm whether
        # it can be enabled.
    except ValueError:
        print(state)
        raise
|
||||
|
||||
|
||||
def test_completion_speculative():
|
||||
@sgl.function(num_api_spec_tokens=64)
|
||||
def gen_character_spec(s):
|
||||
|
||||
Reference in New Issue
Block a user