Use dtype to control generate (#1082)

Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
Liangsheng Yin
2024-08-14 08:58:07 -07:00
committed by GitHub
parent 67c0d832a6
commit a34dd86a7d
12 changed files with 110 additions and 88 deletions

View File

@@ -383,7 +383,7 @@ class ScheduleBatch:
return out_cache_loc
def batch_sampling_params(self, vocab_size, int_token_logit_bias):
def batch_sampling_params(self, vocab_size):
device = "cuda"
bs, reqs = self.batch_size(), self.reqs
self.temperatures = torch.tensor(
@@ -419,15 +419,8 @@ class ScheduleBatch:
# Handle logit bias but only allocate when needed
self.logit_bias = None
for i in range(bs):
if reqs[i].sampling_params.dtype == "int":
if self.logit_bias is None:
self.logit_bias = torch.zeros(
(bs, vocab_size), dtype=torch.float32, device=device
)
self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
def prepare_for_extend(self, vocab_size: int):
bs = self.batch_size()
reqs = self.reqs
input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
@@ -466,7 +459,7 @@ class ScheduleBatch:
self.out_cache_loc = out_cache_loc
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
self.batch_sampling_params(vocab_size, int_token_logit_bias)
self.batch_sampling_params(vocab_size)
def check_decode_mem(self):
bs = self.batch_size()

View File

@@ -54,7 +54,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
get_int_token_logit_bias,
is_multimodal_model,
set_random_seed,
suppress_other_loggers,
@@ -132,9 +131,6 @@ class ModelTpServer:
),
self.model_runner.req_to_token_pool.size - 1,
)
self.int_token_logit_bias = torch.tensor(
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
)
self.max_req_input_len = min(
self.model_config.context_len - 1,
self.max_total_num_tokens - 1,
@@ -442,9 +438,7 @@ class ModelTpServer:
def forward_prefill_batch(self, batch: ScheduleBatch):
# Build batch tensors
batch.prepare_for_extend(
self.model_config.vocab_size, self.int_token_logit_bias
)
batch.prepare_for_extend(self.model_config.vocab_size)
if self.model_runner.is_generation:
# Forward and sample the next tokens

View File

@@ -32,6 +32,7 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)

View File

@@ -36,7 +36,6 @@ class SamplingParams:
ignore_eos: bool = False,
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True,
dtype: Optional[str] = None,
regex: Optional[str] = None,
n: int = 1,
) -> None:
@@ -53,7 +52,6 @@ class SamplingParams:
self.ignore_eos = ignore_eos
self.skip_special_tokens = skip_special_tokens
self.spaces_between_special_tokens = spaces_between_special_tokens
self.dtype = dtype
self.regex = regex
self.n = n
@@ -63,8 +61,6 @@ class SamplingParams:
self.top_k = 1
if self.top_k == -1:
self.top_k = 1 << 30 # whole vocabulary
if self.dtype == "int":
self.stop_strs = [" ", "\n"]
def verify(self):
if self.temperature < 0.0: