[Fix] Fix logprob and normalized_logprob (#1428)
@@ -164,6 +164,7 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
req.prefix_indices = []
req.sampling_params = sampling_params
req.fill_ids = req.origin_input_ids
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
reqs.append(req)

return input_ids, reqs
@@ -178,6 +179,7 @@ def prepare_extend_inputs_for_correctness_test(
req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
i, : bench_args.cut_len
]
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
return reqs

@@ -194,6 +196,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
req.prefix_indices = []
req.sampling_params = sampling_params
req.fill_ids = req.origin_input_ids
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
reqs.append(req)

return reqs

@@ -239,9 +239,12 @@ class RuntimeEndpoint(BaseBackend):
# Compute logprob
data = {
"text": [s.text_ + c for c in choices],
"sampling_params": {"max_new_tokens": 0},
"sampling_params": {
"max_new_tokens": 0,
"temperature": 0,
},
"return_logprob": True,
"logprob_start_len": max(prompt_len - 2, 0),
"logprob_start_len": max(prompt_len - 2, 0),  # for token healing
}
obj = self._generate_http_request(s, data)
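For orientation, the request body assembled above is what scoring candidates over HTTP looks like against the runtime's /generate route. A minimal standalone sketch; the server URL/port and the exact meta_info field names are assumptions for illustration, not a definitive API reference:

import requests  # assumes a locally running sglang server; URL and port are illustrative

payload = {
    "text": "The capital of France is Paris",
    "sampling_params": {"max_new_tokens": 0, "temperature": 0},  # score only, generate nothing
    "return_logprob": True,
    "logprob_start_len": 0,  # report logprobs over the whole prompt
}
resp = requests.post("http://localhost:30000/generate", json=payload)
meta = resp.json()["meta_info"]  # field names assumed: normalized_prompt_logprob, input_token_logprobs
print(meta["normalized_prompt_logprob"], len(meta["input_token_logprobs"]))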

@@ -9,7 +9,7 @@ import uuid
import warnings
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional

import tqdm

@@ -56,6 +56,7 @@ class AttentionBackend(ABC):
raise NotImplementedError()

def get_cuda_graph_seq_len_fill_value(self):
"""Get the fill value for padded seq lens. Typically, it is 0 or 1."""
raise NotImplementedError()

def forward(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
@@ -66,9 +67,11 @@ class AttentionBackend(ABC):
return self.forward_extend(q, k, v, layer, input_metadata)

def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
"""Run a forward for decode."""
raise NotImplementedError()

def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
"""Run a forward for extend."""
raise NotImplementedError()

@@ -299,6 +302,7 @@ class FlashInferAttnBackend(AttentionBackend):
)

if total_num_tokens >= global_config.layer_sync_threshold:
# TODO: Revisit this. Why is this synchronize needed?
torch.cuda.synchronize()

return o.view(-1, layer.tp_q_head_num * layer.head_dim)

@@ -37,7 +37,7 @@ class LogitsProcessorOutput:

# The normalized logprobs of prompts. shape: [#seq]
normalized_prompt_logprobs: torch.Tensor
# The logprobs of input tokens. shape: [#token, vocab_size]
# The logprobs of input tokens. shape: [#token, vocab_size]
input_token_logprobs: torch.Tensor

# The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)

@@ -49,25 +49,39 @@ class LogitsProcessorOutput:
@dataclasses.dataclass
class LogitsMetadata:
forward_mode: ForwardMode
top_logprobs_nums: Optional[List[int]]

return_logprob: bool = False
return_top_logprob: bool = False

extend_seq_lens: Optional[torch.Tensor] = None
extend_start_loc: Optional[torch.Tensor] = None
top_logprobs_nums: Optional[List[int]] = None
extend_seq_lens_cpu: Optional[List[int]] = None

extend_seq_lens_cpu: List[int] = None
logprob_start_lens_cpu: List[int] = None
extend_logprob_start_lens_cpu: Optional[List[int]] = None
extend_logprob_pruned_lens_cpu: Optional[List[int]] = None

@classmethod
def from_input_metadata(cls, input_metadata: InputMetadata):
return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
if input_metadata.forward_mode.is_extend():
extend_logprob_pruned_lens_cpu = [
extend_len - start_len
for extend_len, start_len in zip(
input_metadata.extend_seq_lens,
input_metadata.extend_logprob_start_lens_cpu,
)
]
else:
extend_logprob_pruned_lens_cpu = None
return cls(
forward_mode=input_metadata.forward_mode,
extend_seq_lens=input_metadata.extend_seq_lens,
extend_start_loc=input_metadata.extend_start_loc,
return_logprob=input_metadata.return_logprob,
top_logprobs_nums=input_metadata.top_logprobs_nums,
return_logprob=input_metadata.return_logprob,
return_top_logprob=return_top_logprob,
extend_seq_lens=input_metadata.extend_seq_lens,
extend_seq_lens_cpu=input_metadata.extend_seq_lens_cpu,
logprob_start_lens_cpu=input_metadata.logprob_start_lens_cpu,
extend_logprob_start_lens_cpu=input_metadata.extend_logprob_start_lens_cpu,
extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
)
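A toy illustration of the pruned-length bookkeeping introduced above (values are made up): only the positions at or after each request's relative logprob start are kept, so the pruned length per request is its extend length minus that start.

# Hypothetical numbers, for illustration only.
extend_seq_lens = [5, 8, 1]              # tokens each request extends in this batch
extend_logprob_start_lens = [0, 6, 0]    # relative start of logprob reporting per request
pruned_lens = [e - s for e, s in zip(extend_seq_lens, extend_logprob_start_lens)]
assert pruned_lens == [5, 2, 1]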

@@ -82,57 +96,49 @@ class LogitsProcessor(nn.Module):
def _get_normalized_prompt_logprobs(
self,
input_token_logprobs: torch.Tensor,
cum_start_len0: torch.Tensor,
cum_start_len1: torch.Tensor,
logits_metadata: LogitsMetadata,
):
logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
pruned_lens = torch.tensor(
logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
)

start = logits_metadata.extend_start_loc.clone() - cum_start_len0
end = start + logits_metadata.extend_seq_lens - 2 - cum_start_len1
start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
start = torch.zeros_like(pruned_lens)
start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
end = torch.clamp(
start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
)
sum_logp = (
logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
)
normalized_prompt_logprobs = sum_logp / (
(logits_metadata.extend_seq_lens - 1).clamp(min=1)
)

normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
return normalized_prompt_logprobs
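A self-contained sketch of the cumulative-sum trick above, runnable on CPU and assuming made-up inputs: given the token logprobs of all pruned prompt tokens concatenated into one flat tensor, plus each sequence's pruned length, it sums every sequence's logprobs except the last position (whose logprob scores the next, unseen token) and divides by the number of scored positions.

import torch

def normalized_prompt_logprobs(input_token_logprobs, pruned_lens):
    # input_token_logprobs: flat tensor over all pruned prompt tokens in the batch
    cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
    pruned_lens = torch.tensor(pruned_lens)
    start = torch.zeros_like(pruned_lens)
    start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)   # index of each sequence's first token
    end = torch.clamp(start + pruned_lens - 2, min=0, max=cumsum.shape[0] - 1)
    sum_logp = cumsum[end] - cumsum[start] + input_token_logprobs[start]
    return sum_logp / (pruned_lens - 1).clamp(min=1)

logps = torch.tensor([-1.0, -2.0, -3.0, -0.5, -0.5, -4.0])   # two sequences of 3 tokens each
print(normalized_prompt_logprobs(logps, [3, 3]))             # tensor([-1.5000, -0.5000])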

@staticmethod
def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
max_k = max(logits_metadata.top_logprobs_nums)
ret = all_logprobs.topk(max_k, dim=1)
values = ret.values.tolist()
indices = ret.indices.tolist()

if logits_metadata.forward_mode.is_decode():
output_top_logprobs = []
max_k = max(logits_metadata.top_logprobs_nums)
ret = all_logprobs.topk(max_k, dim=1)
values = ret.values.tolist()
indices = ret.indices.tolist()
for i, k in enumerate(logits_metadata.top_logprobs_nums):
output_top_logprobs.append(list(zip(values[i][:k], indices[i][:k])))
return None, output_top_logprobs
else:
# TODO: vectorize the code below
input_top_logprobs, output_top_logprobs = [], []

pt = 0
extend_seq_lens_cpu = logits_metadata.extend_seq_lens_cpu

max_k = max(logits_metadata.top_logprobs_nums)
ret = all_logprobs.topk(max_k, dim=1)
values = ret.values.tolist()
indices = ret.indices.tolist()

for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
start_len = logits_metadata.logprob_start_lens_cpu[i]
pruned_len = extend_seq_len - start_len

if extend_seq_len == 0:
for k, pruned_len in zip(
logits_metadata.top_logprobs_nums,
logits_metadata.extend_logprob_pruned_lens_cpu,
):
if pruned_len <= 0:
input_top_logprobs.append([])
output_top_logprobs.append([])
continue

k = logits_metadata.top_logprobs_nums[i]
input_top_logprobs.append(
[
list(zip(values[pt + j][:k], indices[pt + j][:k]))
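For reference, a standalone sketch of the per-request top-k extraction in the extend branch above, with toy sizes and CPU tensors: the flattened [#token, vocab] logprob matrix is sliced per request using its pruned length, and only the k entries that request asked for are kept.

import torch

logprobs = torch.log_softmax(torch.randn(5, 10), dim=-1)  # 5 pruned prompt tokens, vocab size 10
top_logprobs_nums = [2, 3]   # k requested by each of the two requests
pruned_lens = [3, 2]         # pruned prompt tokens contributed by each request (3 + 2 = 5 rows)

max_k = max(top_logprobs_nums)
vals, idxs = logprobs.topk(max_k, dim=1)
pt, input_top_logprobs = 0, []
for k, plen in zip(top_logprobs_nums, pruned_lens):
    input_top_logprobs.append(
        [list(zip(vals[pt + j, :k].tolist(), idxs[pt + j, :k].tolist())) for j in range(plen)]
    )
    pt += plen
assert len(input_top_logprobs[0]) == 3 and len(input_top_logprobs[0][0]) == 2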

@@ -167,10 +173,7 @@ class LogitsProcessor(nn.Module):
last_index = None
last_hidden = hidden_states
else:
last_index = (
torch.cumsum(logits_metadata.extend_seq_lens, dim=0, dtype=torch.long)
- 1
)
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
last_hidden = hidden_states[last_index]

last_logits = torch.matmul(last_hidden, weight.T)
@@ -194,21 +197,15 @@ class LogitsProcessor(nn.Module):
output_top_logprobs=None,
)
else:
# When logprob is requested, compute the logits for all tokens.
if logits_metadata.forward_mode.is_decode():
last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)

# Get the logprob of top-k tokens
return_top_logprob = any(
x > 0 for x in logits_metadata.top_logprobs_nums
)
if return_top_logprob:
if logits_metadata.forward_mode.is_decode():
if logits_metadata.return_top_logprob:
output_top_logprobs = self.get_top_logprobs(
last_logprobs, logits_metadata
)[1]
else:
output_top_logprobs = None

return LogitsProcessorOutput(
next_token_logits=last_logits,
next_token_logprobs=last_logprobs,
@@ -218,22 +215,18 @@ class LogitsProcessor(nn.Module):
output_top_logprobs=output_top_logprobs,
)
else:
# Slice the requested tokens to compute logprob
pt, states, pruned_input_ids = 0, [], []
for i, extend_len in enumerate(logits_metadata.extend_seq_lens_cpu):
start_len = logits_metadata.logprob_start_lens_cpu[i]
for start_len, extend_len in zip(
logits_metadata.extend_logprob_start_lens_cpu,
logits_metadata.extend_seq_lens_cpu,
):
states.append(hidden_states[pt + start_len : pt + extend_len])
pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
pt += extend_len

# Compute the logits and logprobs for all required tokens
states = torch.cat(states, dim=0)
pruned_input_ids = torch.cat(pruned_input_ids, dim=0)

cum_start_len1 = torch.tensor(
logits_metadata.logprob_start_lens_cpu, device="cuda"
).cumsum(0)
cum_start_len0 = torch.zeros_like(cum_start_len1)
cum_start_len0[1:] = cum_start_len1[:-1]

all_logits = torch.matmul(states, weight.T)
if self.do_tensor_parallel_all_gather:
all_logits = tensor_model_parallel_all_gather(all_logits)
@@ -249,35 +242,29 @@ class LogitsProcessor(nn.Module):
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)

# Get the logprob of top-k tokens
return_top_logprob = any(
x > 0 for x in logits_metadata.top_logprobs_nums
)
if return_top_logprob:
if logits_metadata.return_top_logprob:
input_top_logprobs, output_top_logprobs = self.get_top_logprobs(
all_logprobs, logits_metadata
)
else:
input_top_logprobs = output_top_logprobs = None

last_logprobs = all_logprobs[last_index - cum_start_len1]

# Compute the logprobs and normalized logprobs for the prefill tokens.
# Note that we pad a zero at the end of each sequence for easy computation.
# Compute the normalized logprobs for the requested tokens.
# Note that we pad a zero at the end for easy batching.
input_token_logprobs = all_logprobs[
torch.arange(all_logprobs.shape[0], device="cuda"),
torch.cat([pruned_input_ids[1:], torch.tensor([0], device="cuda")]),
torch.cat(
[
torch.cat(pruned_input_ids)[1:],
torch.tensor([0], device="cuda"),
]
),
]

normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
input_token_logprobs,
cum_start_len0,
cum_start_len1,
logits_metadata,
)

# Remove the last token logprob for the prefill tokens.
input_token_logprobs = input_token_logprobs[:-1]

return LogitsProcessorOutput(
next_token_logits=last_logits,
next_token_logprobs=last_logprobs,

@@ -20,7 +20,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).

import copy
import uuid
from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

from sglang.srt.managers.schedule_batch import BaseFinishReason
@@ -43,6 +43,7 @@ class GenerateReqInput:
# Whether to return logprobs.
return_logprob: Optional[Union[List[bool], bool]] = None
# If return logprobs, the start location in the prompt for returning logprobs.
# By default, this value is "-1", which means it will only return logprobs for output tokens.
logprob_start_len: Optional[Union[List[int], int]] = None
# If return logprobs, the number of top logprobs to return at each position.
top_logprobs_num: Optional[Union[List[int], int]] = None

@@ -19,7 +19,7 @@ limitations under the License.

import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Union
from typing import List, Optional, Tuple, Union

import torch

@@ -53,7 +53,7 @@ class BaseFinishReason:
self.is_error = is_error

def to_json(self):
raise NotImplementedError("Subclasses must implement this method")
raise NotImplementedError()

class FINISH_MATCHED_TOKEN(BaseFinishReason):
|
||||
class Req:
|
||||
"""Store all inforamtion of a request."""
|
||||
|
||||
def __init__(self, rid, origin_input_text, origin_input_ids, lora_path=None):
|
||||
def __init__(
|
||||
self,
|
||||
rid: str,
|
||||
origin_input_text: str,
|
||||
origin_input_ids: Tuple[int],
|
||||
lora_path: Optional[str] = None,
|
||||
):
|
||||
# Input and output info
|
||||
self.rid = rid
|
||||
self.origin_input_text = origin_input_text
|
||||
@@ -118,6 +124,10 @@ class Req:
|
||||
# Memory info
|
||||
self.req_pool_idx = None
|
||||
|
||||
# Check finish
|
||||
self.tokenizer = None
|
||||
self.finished_reason = None
|
||||
|
||||
# For incremental decoding
|
||||
# ----- | --------- read_ids -------|
|
||||
# ----- | surr_ids |
|
||||
@@ -136,7 +146,7 @@ class Req:
|
||||
# this does not include the jump forward tokens.
|
||||
self.completion_tokens_wo_jump_forward = 0
|
||||
|
||||
# For vision input
|
||||
# For vision inputs
|
||||
self.pixel_values = None
|
||||
self.image_sizes = None
|
||||
self.image_offsets = None
|
||||
@@ -144,31 +154,35 @@ class Req:
|
||||
self.modalities = None
|
||||
|
||||
# Prefix info
|
||||
self.extend_input_len = 0
|
||||
self.prefix_indices = []
|
||||
self.extend_input_len = 0
|
||||
self.last_node = None
|
||||
|
||||
# Sampling parameters
|
||||
self.sampling_params = None
|
||||
self.stream = False
|
||||
|
||||
# Check finish
|
||||
self.tokenizer = None
|
||||
self.finished_reason = None
|
||||
|
||||
# Logprobs
|
||||
# Logprobs (arguments)
|
||||
self.return_logprob = False
|
||||
self.embedding = None
|
||||
self.logprob_start_len = 0
|
||||
self.top_logprobs_num = 0
|
||||
|
||||
# Logprobs (return value)
|
||||
self.normalized_prompt_logprob = None
|
||||
self.input_token_logprobs = None
|
||||
self.input_top_logprobs = None
|
||||
self.output_token_logprobs = []
|
||||
self.output_top_logprobs = []
|
||||
|
||||
# Logprobs (internal values)
|
||||
# The tokens is prefilled but need to be considered as decode tokens
|
||||
# and should be updated for the decode logprobs
|
||||
self.last_update_decode_tokens = 0
|
||||
# The relative logprob_start_len in an extend batch
|
||||
self.extend_logprob_start_len = 0
|
||||
|
||||
# Embedding
|
||||
self.embedding = None
|
||||
|
||||
# Constrained decoding
|
||||
self.regex_fsm: RegexGuide = None
|
||||
@@ -363,9 +377,13 @@ class ScheduleBatch:
|
||||
return_logprob: bool = False
|
||||
top_logprobs_nums: List[int] = None
|
||||
|
||||
# Stream
|
||||
has_stream: bool = False
|
||||
|
||||
@classmethod
|
||||
def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
|
||||
return_logprob = any(req.return_logprob for req in reqs)
|
||||
has_stream = any(req.stream for req in reqs)
|
||||
|
||||
return cls(
|
||||
reqs=reqs,
|
||||
@@ -373,18 +391,15 @@ class ScheduleBatch:
|
||||
token_to_kv_pool=token_to_kv_pool,
|
||||
tree_cache=tree_cache,
|
||||
return_logprob=return_logprob,
|
||||
has_stream=has_stream,
|
||||
)
|
||||
|
||||
def batch_size(self):
|
||||
return len(self.reqs) if self.reqs else 0
|
||||
return len(self.reqs)
|
||||
|
||||
def is_empty(self):
|
||||
return len(self.reqs) == 0
|
||||
|
||||
def has_stream(self) -> bool:
|
||||
# Return whether batch has at least 1 streaming request
|
||||
return any(r.stream for r in self.reqs)
|
||||
|
||||
def alloc_req_slots(self, num_reqs):
|
||||
req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
|
||||
if req_pool_indices is None:
|
||||
@@ -427,8 +442,8 @@ class ScheduleBatch:
|
||||
for i, req in enumerate(reqs):
|
||||
req.req_pool_idx = req_pool_indices_cpu[i]
|
||||
pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
|
||||
ext_len = seq_len - pre_len
|
||||
seq_lens.append(seq_len)
|
||||
assert seq_len - pre_len == req.extend_input_len
|
||||
|
||||
if pre_len > 0:
|
||||
self.req_to_token_pool.req_to_token[req.req_pool_idx][
|
||||
@@ -436,9 +451,19 @@ class ScheduleBatch:
|
||||
] = req.prefix_indices
|
||||
|
||||
self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
|
||||
out_cache_loc[pt : pt + ext_len]
|
||||
out_cache_loc[pt : pt + req.extend_input_len]
|
||||
)
|
||||
pt += ext_len
|
||||
|
||||
# Compute the relative logprob_start_len in an extend batch
|
||||
if req.logprob_start_len >= pre_len:
|
||||
extend_logprob_start_len = min(
|
||||
req.logprob_start_len - pre_len, req.extend_input_len - 1
|
||||
)
|
||||
else:
|
||||
extend_logprob_start_len = req.extend_input_len - 1
|
||||
|
||||
req.extend_logprob_start_len = extend_logprob_start_len
|
||||
pt += req.extend_input_len
|
||||
|
||||
# Set fields
|
||||
with torch.device("cuda"):
|
||||
@@ -451,21 +476,13 @@ class ScheduleBatch:
|
||||
self.out_cache_loc = out_cache_loc
|
||||
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
|
||||
self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]
|
||||
|
||||
self.extend_lens_cpu = [r.extend_input_len for r in reqs]
|
||||
self.extend_logprob_start_lens_cpu = [r.extend_logprob_start_len for r in reqs]
|
||||
self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
|
||||
|
||||
def mix_with_running(self, running_batch: "ScheduleBatch"):
|
||||
self.forward_mode = ForwardMode.MIXED
|
||||
self.running_bs = running_batch.batch_size()
|
||||
|
||||
# NOTE: prefix_indices is what has been cached, but we don't cache each decode step
|
||||
prefix_lens_cpu = [len(r.prefix_indices) for r in self.reqs]
|
||||
prefix_lens_cpu.extend(
|
||||
[
|
||||
len(r.origin_input_ids) + len(r.output_ids) - 1
|
||||
for r in running_batch.reqs
|
||||
]
|
||||
)
|
||||
running_bs = running_batch.batch_size()
|
||||
|
||||
for req in running_batch.reqs:
|
||||
req.fill_ids = req.origin_input_ids + req.output_ids
|
||||
@@ -473,12 +490,22 @@ class ScheduleBatch:
|
||||
|
||||
input_ids = torch.cat([self.input_ids, running_batch.input_ids])
|
||||
out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
|
||||
extend_num_tokens = self.extend_num_tokens + running_batch.batch_size()
|
||||
extend_num_tokens = self.extend_num_tokens + running_bs
|
||||
|
||||
self.merge(running_batch)
|
||||
self.input_ids = input_ids
|
||||
self.out_cache_loc = out_cache_loc
|
||||
self.extend_num_tokens = extend_num_tokens
|
||||
self.prefix_lens_cpu = prefix_lens_cpu
|
||||
|
||||
# NOTE: prefix_indices is what has been cached, but we don't cache each decode step
|
||||
self.prefix_lens_cpu.extend(
|
||||
[
|
||||
len(r.origin_input_ids) + len(r.output_ids) - 1
|
||||
for r in running_batch.reqs
|
||||
]
|
||||
)
|
||||
self.extend_lens_cpu.extend([1] * running_bs)
|
||||
self.extend_logprob_start_lens_cpu.extend([0] * running_bs)
|
||||

def check_decode_mem(self):
bs = self.batch_size()
@@ -685,6 +712,7 @@ class ScheduleBatch:
self.out_cache_loc = None
self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
self.return_logprob = any(req.return_logprob for req in self.reqs)
self.has_stream = any(req.stream for req in self.reqs)

self.sampling_info.filter(unfinished_indices, new_indices)

@@ -695,7 +723,6 @@ class ScheduleBatch:
self.sampling_info.merge(other.sampling_info)

self.reqs.extend(other.reqs)

self.req_pool_indices = torch.concat(
[self.req_pool_indices, other.req_pool_indices]
)
@@ -706,3 +733,4 @@ class ScheduleBatch:
self.out_cache_loc = None
self.top_logprobs_nums.extend(other.top_logprobs_nums)
self.return_logprob = any(req.return_logprob for req in self.reqs)
self.has_stream = any(req.stream for req in self.reqs)

||||
if not_use_index
|
||||
else obj.logprob_start_len[index]
|
||||
)
|
||||
if return_logprob and logprob_start_len == -1:
|
||||
logprob_start_len = len(input_ids) - 1
|
||||
top_logprobs_num = (
|
||||
obj.top_logprobs_num
|
||||
if not_use_index
|
||||
@@ -251,8 +249,6 @@ class TokenizerManager:
|
||||
|
||||
# Send to the controller
|
||||
if self.is_generation:
|
||||
if return_logprob and logprob_start_len == -1:
|
||||
logprob_start_len = len(input_ids) - 1
|
||||
tokenized_obj = TokenizedGenerateReqInput(
|
||||
rid,
|
||||
input_text,
|
||||
@@ -349,8 +345,6 @@ class TokenizerManager:
|
||||
sampling_params = self._get_sampling_params(obj.sampling_params[index])
|
||||
|
||||
if self.is_generation:
|
||||
if obj.return_logprob[index] and obj.logprob_start_len[index] == -1:
|
||||
obj.logprob_start_len[index] = len(input_ids) - 1
|
||||
pixel_values, image_hashes, image_sizes = (
|
||||
await self._get_pixel_values(obj.image_data[index])
|
||||
)
|
||||
|
||||
@@ -278,7 +278,7 @@ class ModelTpServer:
|
||||
self.running_batch = None
|
||||
break
|
||||
|
||||
if self.out_pyobjs and self.running_batch.has_stream():
|
||||
if self.out_pyobjs and self.running_batch.has_stream:
|
||||
break
|
||||
else:
|
||||
self.check_memory()
|
||||
@@ -360,9 +360,13 @@ class ModelTpServer:
|
||||
# Only when pixel values is not None we have modalities
|
||||
req.modalities = recv_req.modalites
|
||||
req.return_logprob = recv_req.return_logprob
|
||||
req.logprob_start_len = recv_req.logprob_start_len
|
||||
req.top_logprobs_num = recv_req.top_logprobs_num
|
||||
req.stream = recv_req.stream
|
||||
req.logprob_start_len = recv_req.logprob_start_len
|
||||
|
||||
if req.logprob_start_len == -1:
|
||||
# By default, only return the logprobs for output tokens
|
||||
req.logprob_start_len = len(recv_req.input_ids) - 1
|
||||
|
||||
# Init regex FSM
|
||||
if (
|
||||
@@ -384,7 +388,7 @@ class ModelTpServer:
|
||||
|
||||
# Truncate prompts that are too long
|
||||
if len(req.origin_input_ids) >= self.max_req_input_len:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"Request length is longer than the KV cache pool size or "
|
||||
"the max context length. Truncated!!!"
|
||||
)
|
||||
@@ -583,7 +587,7 @@ class ModelTpServer:
|
||||
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
|
||||
|
||||
# Check finish conditions
|
||||
pt = 0
|
||||
logprob_pt = 0
|
||||
for i, req in enumerate(batch.reqs):
|
||||
if req is not self.current_inflight_req:
|
||||
# Inflight reqs' prefill is not finished
|
||||
@@ -607,10 +611,9 @@ class ModelTpServer:
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
|
||||
if req.return_logprob:
|
||||
self.add_logprob_return_values(
|
||||
i, req, pt, next_token_ids, logits_output
|
||||
logprob_pt += self.add_logprob_return_values(
|
||||
i, req, logprob_pt, next_token_ids, logits_output
|
||||
)
|
||||
pt += req.extend_input_len
|
||||
else:
|
||||
assert batch.extend_num_tokens != 0
|
||||
logits_output = self.model_runner.forward(batch)
|
||||
@@ -638,48 +641,63 @@ class ModelTpServer:
|
||||
|
||||
def add_logprob_return_values(
|
||||
self,
|
||||
i,
|
||||
i: int,
|
||||
req: Req,
|
||||
pt: int,
|
||||
next_token_ids: List[int],
|
||||
output: LogitsProcessorOutput,
|
||||
):
|
||||
"""Attach logprobs to the return values."""
|
||||
req.output_token_logprobs.append(
|
||||
(output.next_token_logprobs[i], next_token_ids[i])
|
||||
)
|
||||
|
||||
# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
|
||||
num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len
|
||||
|
||||
if req.normalized_prompt_logprob is None:
|
||||
req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
|
||||
|
||||
if req.input_token_logprobs is None:
|
||||
# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
|
||||
req.input_token_logprobs = list(
|
||||
zip(
|
||||
output.input_token_logprobs[pt : pt + req.extend_input_len - 1],
|
||||
req.fill_ids[-req.extend_input_len + 1 :],
|
||||
)
|
||||
)
|
||||
if req.logprob_start_len == 0:
|
||||
input_token_logprobs = output.input_token_logprobs[
|
||||
pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
|
||||
]
|
||||
input_token_ids = req.fill_ids[
|
||||
len(req.fill_ids)
|
||||
- num_input_logprobs
|
||||
+ 1 : len(req.fill_ids)
|
||||
- req.last_update_decode_tokens
|
||||
]
|
||||
req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))
|
||||
|
||||
if (
|
||||
req.logprob_start_len == 0
|
||||
): # The first token does not have logprob, pad it.
|
||||
req.input_token_logprobs = [
|
||||
(None, req.fill_ids[0])
|
||||
] + req.input_token_logprobs
|
||||
|
||||
if req.last_update_decode_tokens != 0:
|
||||
# Some decode tokens are re-computed in an extend batch
|
||||
req.output_token_logprobs.extend(
|
||||
list(
|
||||
zip(
|
||||
output.input_token_logprobs[
|
||||
pt
|
||||
+ req.extend_input_len
|
||||
+ num_input_logprobs
|
||||
- 1
|
||||
- req.last_update_decode_tokens : pt
|
||||
+ req.extend_input_len
|
||||
+ num_input_logprobs
|
||||
- 1
|
||||
],
|
||||
req.fill_ids[-req.last_update_decode_tokens + 1 :],
|
||||
req.fill_ids[
|
||||
len(req.fill_ids)
|
||||
- req.last_update_decode_tokens : len(req.fill_ids)
|
||||
],
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
req.output_token_logprobs.append(
|
||||
(output.next_token_logprobs[i], next_token_ids[i])
|
||||
)
|
||||
|
||||
if req.top_logprobs_num > 0:
|
||||
if req.input_top_logprobs is None:
|
||||
req.input_top_logprobs = output.input_top_logprobs[i]
|
||||
@@ -688,10 +706,12 @@ class ModelTpServer:
|
||||
|
||||
if req.last_update_decode_tokens != 0:
|
||||
req.output_top_logprobs.extend(
|
||||
output.input_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
|
||||
output.input_top_logprobs[i][-req.last_update_decode_tokens :]
|
||||
)
|
||||
req.output_top_logprobs.append(output.output_top_logprobs[i])
|
||||
|
||||
return num_input_logprobs
|
||||

def forward_decode_batch(self, batch: ScheduleBatch):
# Check if decode out of memory
if not batch.check_decode_mem():

@@ -193,7 +193,7 @@ class CudaGraphRunner:
attn_backend=self.model_runner.attn_backend,
out_cache_loc=out_cache_loc,
return_logprob=False,
top_logprobs_nums=0,
top_logprobs_nums=[0] * bs,
positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
)
return forward(input_ids, input_metadata.positions, input_metadata)

||||
return_logprob: bool = False
|
||||
top_logprobs_nums: List[int] = None
|
||||
extend_seq_lens_cpu: List[int] = None
|
||||
logprob_start_lens_cpu: List[int] = None
|
||||
extend_logprob_start_lens_cpu: List[int] = None
|
||||
|
||||
# For multimodal
|
||||
pixel_values: List[torch.Tensor] = None
|
||||
@@ -138,27 +138,13 @@ class InputMetadata:
|
||||
self.positions = self.positions.to(torch.int64)
|
||||
|
||||
def compute_extend_infos(self, batch: ScheduleBatch):
|
||||
extend_lens_cpu = [
|
||||
len(r.fill_ids) - batch.prefix_lens_cpu[i] for i, r in enumerate(batch.reqs)
|
||||
]
|
||||
self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
|
||||
self.extend_seq_lens = torch.tensor(batch.extend_lens_cpu, device="cuda")
|
||||
self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
|
||||
self.extend_start_loc = torch.zeros_like(self.seq_lens)
|
||||
self.extend_start_loc = torch.zeros_like(self.extend_seq_lens)
|
||||
self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
|
||||
self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
|
||||
|
||||
self.extend_seq_lens_cpu = extend_lens_cpu
|
||||
self.logprob_start_lens_cpu = [
|
||||
(
|
||||
min(
|
||||
req.logprob_start_len - batch.prefix_lens_cpu[i],
|
||||
extend_lens_cpu[i] - 1,
|
||||
)
|
||||
if req.logprob_start_len >= batch.prefix_lens_cpu[i]
|
||||
else extend_lens_cpu[i] - 1 # Fake extend, actually decode
|
||||
)
|
||||
for i, req in enumerate(batch.reqs)
|
||||
]
|
||||
self.extend_no_prefix = all(x == 0 for x in batch.prefix_lens_cpu)
|
||||
self.extend_seq_lens_cpu = batch.extend_lens_cpu
|
||||
self.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens_cpu
|
||||
|
||||
@classmethod
|
||||
def from_schedule_batch(
|
||||

@@ -22,7 +22,7 @@ import os
import time
import uuid
from http import HTTPStatus
from typing import Dict, List, Optional
from typing import Dict, List

from fastapi import HTTPException, Request, UploadFile
from fastapi.responses import JSONResponse, StreamingResponse
@@ -472,7 +472,7 @@ def v1_generate_request(
first_prompt_type = type(all_requests[0].prompt)
for request in all_requests:
assert (
type(request.prompt) == first_prompt_type
type(request.prompt) is first_prompt_type
), "All prompts must be of the same type in file input settings"
if len(all_requests) > 1 and request.n > 1:
raise ValueError(
@@ -887,7 +887,7 @@ def v1_chat_generate_request(
input_ids.append(prompt_ids)
return_logprobs.append(request.logprobs)
logprob_start_lens.append(-1)
top_logprobs_nums.append(request.top_logprobs)
top_logprobs_nums.append(request.top_logprobs or 0)

sampling_params = {
"temperature": request.temperature,

|
||||
@classmethod
|
||||
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
|
||||
device = "cuda"
|
||||
reqs = batch.reqs
|
||||
ret = cls(vocab_size=vocab_size)
|
||||
|
||||
ret.temperatures = torch.tensor(
|
||||
[r.sampling_params.temperature for r in reqs],
|
||||
dtype=torch.float,
|
||||
device=device,
|
||||
).view(-1, 1)
|
||||
ret.top_ps = torch.tensor(
|
||||
[r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
|
||||
)
|
||||
ret.top_ks = torch.tensor(
|
||||
[r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
|
||||
)
|
||||
ret.min_ps = torch.tensor(
|
||||
[r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
|
||||
)
|
||||
with torch.device("cuda"):
|
||||
ret.temperatures = torch.tensor(
|
||||
[r.sampling_params.temperature for r in reqs],
|
||||
dtype=torch.float,
|
||||
).view(-1, 1)
|
||||
ret.top_ps = torch.tensor(
|
||||
[r.sampling_params.top_p for r in reqs], dtype=torch.float
|
||||
)
|
||||
ret.top_ks = torch.tensor(
|
||||
[r.sampling_params.top_k for r in reqs], dtype=torch.int
|
||||
)
|
||||
ret.min_ps = torch.tensor(
|
||||
[r.sampling_params.min_p for r in reqs], dtype=torch.float
|
||||
)
|
||||
|

ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)

# Each penalizers will do nothing if they evaluate themselves as not required by looking at
@@ -116,7 +116,7 @@ class SamplingBatchInfo:
ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
vocab_size=vocab_size,
batch=batch,
device=device,
device="cuda",
Penalizers={
penaltylib.BatchedFrequencyPenalizer,
penaltylib.BatchedMinNewTokensPenalizer,