Use dtype to control generate (#1082)
Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
@@ -383,7 +383,7 @@ class ScheduleBatch:
|
||||
|
||||
return out_cache_loc
|
||||
|
||||
def batch_sampling_params(self, vocab_size, int_token_logit_bias):
|
||||
def batch_sampling_params(self, vocab_size):
|
||||
device = "cuda"
|
||||
bs, reqs = self.batch_size(), self.reqs
|
||||
self.temperatures = torch.tensor(
|
||||
@@ -419,15 +419,8 @@ class ScheduleBatch:
|
||||
|
||||
# Handle logit bias but only allocate when needed
|
||||
self.logit_bias = None
|
||||
for i in range(bs):
|
||||
if reqs[i].sampling_params.dtype == "int":
|
||||
if self.logit_bias is None:
|
||||
self.logit_bias = torch.zeros(
|
||||
(bs, vocab_size), dtype=torch.float32, device=device
|
||||
)
|
||||
self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
|
||||
|
||||
def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
|
||||
def prepare_for_extend(self, vocab_size: int):
|
||||
bs = self.batch_size()
|
||||
reqs = self.reqs
|
||||
input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
|
||||
@@ -466,7 +459,7 @@ class ScheduleBatch:
|
||||
self.out_cache_loc = out_cache_loc
|
||||
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
|
||||
|
||||
self.batch_sampling_params(vocab_size, int_token_logit_bias)
|
||||
self.batch_sampling_params(vocab_size)
|
||||
|
||||
def check_decode_mem(self):
|
||||
bs = self.batch_size()
|
||||
|
||||
@@ -54,7 +54,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
get_int_token_logit_bias,
|
||||
is_multimodal_model,
|
||||
set_random_seed,
|
||||
suppress_other_loggers,
|
||||
@@ -132,9 +131,6 @@ class ModelTpServer:
|
||||
),
|
||||
self.model_runner.req_to_token_pool.size - 1,
|
||||
)
|
||||
self.int_token_logit_bias = torch.tensor(
|
||||
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
|
||||
)
|
||||
self.max_req_input_len = min(
|
||||
self.model_config.context_len - 1,
|
||||
self.max_total_num_tokens - 1,
|
||||
@@ -442,9 +438,7 @@ class ModelTpServer:
|
||||
|
||||
def forward_prefill_batch(self, batch: ScheduleBatch):
|
||||
# Build batch tensors
|
||||
batch.prepare_for_extend(
|
||||
self.model_config.vocab_size, self.int_token_logit_bias
|
||||
)
|
||||
batch.prepare_for_extend(self.model_config.vocab_size)
|
||||
|
||||
if self.model_runner.is_generation:
|
||||
# Forward and sample the next tokens
|
||||
|
||||
@@ -32,6 +32,7 @@ from vllm.model_executor.layers.linear import (
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE,
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
|
||||
@@ -36,7 +36,6 @@ class SamplingParams:
|
||||
ignore_eos: bool = False,
|
||||
skip_special_tokens: bool = True,
|
||||
spaces_between_special_tokens: bool = True,
|
||||
dtype: Optional[str] = None,
|
||||
regex: Optional[str] = None,
|
||||
n: int = 1,
|
||||
) -> None:
|
||||
@@ -53,7 +52,6 @@ class SamplingParams:
|
||||
self.ignore_eos = ignore_eos
|
||||
self.skip_special_tokens = skip_special_tokens
|
||||
self.spaces_between_special_tokens = spaces_between_special_tokens
|
||||
self.dtype = dtype
|
||||
self.regex = regex
|
||||
self.n = n
|
||||
|
||||
@@ -63,8 +61,6 @@ class SamplingParams:
|
||||
self.top_k = 1
|
||||
if self.top_k == -1:
|
||||
self.top_k = 1 << 30 # whole vocabulary
|
||||
if self.dtype == "int":
|
||||
self.stop_strs = [" ", "\n"]
|
||||
|
||||
def verify(self):
|
||||
if self.temperature < 0.0:
|
||||
|
||||
Reference in New Issue
Block a user