Clean up logits processor (#558)
@@ -1,5 +1,8 @@
+"""Logits processing."""
+
+import dataclasses
 from typing import List
 
 import torch
 from torch import nn
 from vllm.distributed import (
@@ -10,6 +13,24 @@ from vllm.distributed import (
 from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
 
 
+@dataclasses.dataclass
+class LogitProcessorOutput:
+    # The logits of the next tokens. shape: [#seq, vocab_size]
+    next_token_logits: torch.Tensor
+    # The logprobs of the next tokens. shape: [#seq, vocab_size]
+    next_token_logprobs: torch.Tensor
+
+    # The normalized logprobs of prompts. shape: [#seq]
+    normalized_prompt_logprobs: torch.Tensor
+    # The logprobs of prefill tokens. shape: [#token, vocab_size]
+    prefill_token_logprobs: torch.Tensor
+
+    # The logprob and id of the top-k tokens in prefill positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
+    prefill_top_logprobs: List
+    # The logprob and id of the top-k tokens in decode positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
+    decode_top_logprobs: List
+
+
 class LogitsProcessor(nn.Module):
     def __init__(self, config):
         super().__init__()
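
For illustration only (not part of the commit): the dataclass replaces the positional five-tuple that `forward()` previously returned, so callers address results by field name instead of by position. A minimal, self-contained sketch with dummy tensors:

```python
# Standalone sketch; mirrors the dataclass above with dummy data.
import dataclasses
from typing import List

import torch


@dataclasses.dataclass
class LogitProcessorOutput:
    next_token_logits: torch.Tensor
    next_token_logprobs: torch.Tensor
    normalized_prompt_logprobs: torch.Tensor
    prefill_token_logprobs: torch.Tensor
    prefill_top_logprobs: List
    decode_top_logprobs: List


output = LogitProcessorOutput(
    next_token_logits=torch.randn(2, 32000),  # [#seq, vocab_size]
    next_token_logprobs=None,
    normalized_prompt_logprobs=None,
    prefill_token_logprobs=None,
    prefill_top_logprobs=None,
    decode_top_logprobs=None,
)

# Named access cannot silently shift meaning if fields are added or reordered,
# unlike `logits, (a, b, c, d, e) = ...` tuple unpacking.
print(output.next_token_logits.shape)  # torch.Size([2, 32000])
```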
@@ -39,6 +60,7 @@ class LogitsProcessor(nn.Module):
         return normalized_prompt_logprobs
 
     def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata):
+        # TODO: vectorize the code below
         if input_metadata.forward_mode == ForwardMode.DECODE:
             decode_top_logprobs = []
             for i in range(all_logprobs.shape[0]):
@@ -51,7 +73,6 @@ class LogitsProcessor(nn.Module):
         else:
             prefill_top_logprobs, decode_top_logprobs = [], []
             pt = 0
-            # NOTE: the GPU-CPU overhead can be reduced
             extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist()
             for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
                 if extend_seq_len == 0:
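
The loop body that builds the per-request top-k lists is elided by the hunk above. As context, a hedged sketch of what such a loop typically looks like; `top_logprobs_nums` (the per-request k) follows the field comments earlier, and everything else is dummy data:

```python
import torch

all_logprobs = torch.log_softmax(torch.randn(3, 1000), dim=-1)  # [#seq, vocab]
top_logprobs_nums = [5, 2, 3]  # assumed per-request k values

decode_top_logprobs = []
for i in range(all_logprobs.shape[0]):
    k = top_logprobs_nums[i]
    values, indices = all_logprobs[i].topk(k)
    # One list of (logprob, token_id) pairs per sequence.
    decode_top_logprobs.append(list(zip(values.tolist(), indices.tolist())))
```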
@@ -71,18 +92,15 @@ class LogitsProcessor(nn.Module):
         return prefill_top_logprobs, decode_top_logprobs
 
     def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
-        # Get last index for next token prediction, except for DECODE mode.
-        last_index = None
-        if input_metadata.forward_mode != ForwardMode.DECODE:
+        # Get the last hidden states and last logits for the next token prediction
+        if input_metadata.forward_mode == ForwardMode.DECODE:
+            last_index = None
+            last_hidden = hidden_states
+        else:
             last_index = (
                 torch.cumsum(input_metadata.extend_seq_lens, dim=0, dtype=torch.long)
                 - 1
             )
-
-        # Get the last hidden states and last logits
-        if input_metadata.forward_mode == ForwardMode.DECODE:
-            last_hidden = hidden_states
-        else:
             last_hidden = hidden_states[last_index]
 
         last_logits = torch.matmul(last_hidden, weight.T)
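
A worked example (dummy numbers, not from the commit) of the cumsum arithmetic above: in EXTEND mode the batch is packed into one flat token dimension, and the running sum of `extend_seq_lens` minus one lands on each request's last token:

```python
import torch

extend_seq_lens = torch.tensor([3, 2, 4])  # three packed requests
last_index = torch.cumsum(extend_seq_lens, dim=0, dtype=torch.long) - 1
print(last_index)  # tensor([2, 4, 8])

hidden_states = torch.randn(9, 16)       # [#token, hidden_size]
last_hidden = hidden_states[last_index]  # [#seq, hidden_size]
```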
@@ -92,8 +110,14 @@ class LogitsProcessor(nn.Module):
 
         # Return only last_logits if logprob is not requested
         if not input_metadata.return_logprob:
-            hidden_states = None
-            return last_logits, (None, None, None, None, None)
+            return LogitProcessorOutput(
+                next_token_logits=last_logits,
+                next_token_logprobs=None,
+                normalized_prompt_logprobs=None,
+                prefill_token_logprobs=None,
+                prefill_top_logprobs=None,
+                decode_top_logprobs=None,
+            )
         else:
             # When logprob is requested, compute the logits for all tokens.
             if input_metadata.forward_mode == ForwardMode.DECODE:
@@ -108,6 +132,7 @@ class LogitsProcessor(nn.Module):
             del all_logits
             all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
 
+            # Get the logprob of top-k tokens
             return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
             if return_top_logprob:
                 prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
@@ -117,16 +142,15 @@ class LogitsProcessor(nn.Module):
                 prefill_top_logprobs = decode_top_logprobs = None
 
             if input_metadata.forward_mode == ForwardMode.DECODE:
-                last_logprobs = all_logprobs
-                return last_logits, (
-                    None,
-                    None,
-                    None,
-                    decode_top_logprobs,
-                    last_logprobs,
+                return LogitProcessorOutput(
+                    next_token_logits=last_logits,
+                    next_token_logprobs=all_logprobs,
+                    normalized_prompt_logprobs=None,
+                    prefill_token_logprobs=None,
+                    prefill_top_logprobs=None,
+                    decode_top_logprobs=decode_top_logprobs,
                 )
             else:
                 # Compute the logprobs for the last token of each request.
                 last_logprobs = all_logprobs[last_index]
-
                 # Compute the logprobs and normalized logprobs for the prefill tokens.
@@ -139,12 +163,14 @@ class LogitsProcessor(nn.Module):
                 normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
                     prefill_token_logprobs, input_metadata
                 )
-                return last_logits, (
-                    prefill_token_logprobs,
-                    normalized_prompt_logprobs,
-                    prefill_top_logprobs,
-                    decode_top_logprobs,
-                    last_logprobs,
+
+                return LogitProcessorOutput(
+                    next_token_logits=last_logits,
+                    next_token_logprobs=last_logprobs,
+                    normalized_prompt_logprobs=normalized_prompt_logprobs,
+                    prefill_token_logprobs=prefill_token_logprobs,
+                    prefill_top_logprobs=prefill_top_logprobs,
+                    decode_top_logprobs=decode_top_logprobs,
                 )
 
 
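`_get_normalized_prompt_logprobs` itself is outside this diff; a plausible reading of "normalized" here, shown with toy numbers, is the length-normalized sum of the prompt token logprobs, which makes scores comparable across prompts of different lengths:

```python
import torch

# Toy per-token logprobs for one prompt (assumed semantics, see note above).
prompt_token_logprobs = torch.tensor([-1.2, -0.3, -2.1, -0.7])
normalized = prompt_token_logprobs.sum() / prompt_token_logprobs.numel()
print(normalized)  # tensor(-1.0750)
```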
@@ -441,33 +441,25 @@ class ModelTpServer:
                 self.model_config.vocab_size, self.int_token_logit_bias
             )
 
+        # Forward and sample the next tokens
         if batch.extend_num_tokens != 0:
-            # Forward
-            logits, (
-                prefill_token_logprobs,
-                normalized_prompt_logprobs,
-                prefill_top_logprobs,
-                decode_top_logprobs,
-                last_logprobs,
-            ) = self.model_runner.forward(batch, ForwardMode.EXTEND)
-            if prefill_token_logprobs is not None:
-                prefill_token_logprobs = prefill_token_logprobs.tolist()
-                normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()
+            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            next_token_ids, _ = batch.sample(output.next_token_logits)
 
-            next_token_ids, _ = batch.sample(logits)
-
-            # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
-            if last_logprobs is not None:
-                last_token_logprobs = last_logprobs[
-                    torch.arange(len(batch.reqs), device=next_token_ids.device),
+            # Move logprobs to cpu
+            if output.next_token_logprobs is not None:
+                output.next_token_logprobs = output.next_token_logprobs[
+                    torch.arange(len(next_token_ids), device=next_token_ids.device),
                     next_token_ids,
                 ].tolist()
+                output.prefill_token_logprobs = output.prefill_token_logprobs.tolist()
+                output.normalized_prompt_logprobs = output.normalized_prompt_logprobs.tolist()
 
             next_token_ids = next_token_ids.tolist()
         else:
             next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
 
-        # Check finish condition
+        # Check finish conditions
         pt = 0
         for i, req in enumerate(batch.reqs):
             req.completion_tokens_wo_jump_forward += 1
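
The "Move logprobs to cpu" block above copies only one scalar per request across the GPU-CPU boundary. A standalone illustration of that indexing pattern with dummy data:

```python
import torch

next_token_logprobs = torch.log_softmax(torch.randn(4, 32000), dim=-1)
next_token_ids = torch.tensor([11, 7, 0, 523])

# Row i is paired with column next_token_ids[i]; the result holds 4 floats,
# not 4 x 32000 values, so the .tolist() transfer is cheap.
chosen = next_token_logprobs[
    torch.arange(len(next_token_ids), device=next_token_ids.device),
    next_token_ids,
].tolist()
```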
@@ -475,58 +467,60 @@ class ModelTpServer:
             req.check_finished()
 
             if req.return_logprob:
-                if req.normalized_prompt_logprob is None:
-                    req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
-
-                if req.prefill_token_logprobs is None:
-                    # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
-                    req.prefill_token_logprobs = list(
-                        zip(
-                            prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
-                            req.input_ids[-req.extend_input_len + 1 :],
-                        )
-                    )
-                    if req.logprob_start_len == 0:
-                        req.prefill_token_logprobs = [
-                            (None, req.input_ids[0])
-                        ] + req.prefill_token_logprobs
-
-                if req.last_update_decode_tokens != 0:
-                    req.decode_token_logprobs.extend(
-                        list(
-                            zip(
-                                prefill_token_logprobs[
-                                    pt
-                                    + req.extend_input_len
-                                    - req.last_update_decode_tokens : pt
-                                    + req.extend_input_len
-                                    - 1
-                                ],
-                                req.input_ids[-req.last_update_decode_tokens + 1 :],
-                            )
-                        )
-                    )
-
-                req.decode_token_logprobs.append(
-                    (last_token_logprobs[i], next_token_ids[i])
-                )
-
-                if req.top_logprobs_num > 0:
-                    if req.prefill_top_logprobs is None:
-                        req.prefill_top_logprobs = prefill_top_logprobs[i]
-                        if req.logprob_start_len == 0:
-                            req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
-
-                    if req.last_update_decode_tokens != 0:
-                        req.decode_top_logprobs.extend(
-                            prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
-                        )
-                    req.decode_top_logprobs.append(decode_top_logprobs[i])
-
-            pt += req.extend_input_len
+                self.add_logprob_return_values(i, req, pt, next_token_ids, output)
+                pt += req.extend_input_len
 
         self.handle_finished_requests(batch)
 
+    def add_logprob_return_values(self, i, req, pt, next_token_ids, output):
+        if req.normalized_prompt_logprob is None:
+            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
+
+        if req.prefill_token_logprobs is None:
+            # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
+            req.prefill_token_logprobs = list(
+                zip(
+                    output.prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
+                    req.input_ids[-req.extend_input_len + 1 :],
+                )
+            )
+            if req.logprob_start_len == 0:
+                req.prefill_token_logprobs = [
+                    (None, req.input_ids[0])
+                ] + req.prefill_token_logprobs
+
+        if req.last_update_decode_tokens != 0:
+            req.decode_token_logprobs.extend(
+                list(
+                    zip(
+                        output.prefill_token_logprobs[
+                            pt
+                            + req.extend_input_len
+                            - req.last_update_decode_tokens : pt
+                            + req.extend_input_len
+                            - 1
+                        ],
+                        req.input_ids[-req.last_update_decode_tokens + 1 :],
+                    )
+                )
+            )
+
+        req.decode_token_logprobs.append(
+            (output.next_token_logprobs[i], next_token_ids[i])
+        )
+
+        if req.top_logprobs_num > 0:
+            if req.prefill_top_logprobs is None:
+                req.prefill_top_logprobs = output.prefill_top_logprobs[i]
+                if req.logprob_start_len == 0:
+                    req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
+
+            if req.last_update_decode_tokens != 0:
+                req.decode_top_logprobs.extend(
+                    output.prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
+                )
+            req.decode_top_logprobs.append(output.decode_top_logprobs[i])
+
     def cache_filled_batch(self, batch: Batch):
         req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
         for i, req in enumerate(batch.reqs):
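
To see why `add_logprob_return_values` offsets by `pt` and special-cases the first prompt token, here is a toy version of the pairing (dummy values; the real code reads `output.prefill_token_logprobs`): logprobs for all requests are packed into one flat sequence, and logprob k scores token k+1, so the first token gets a `(None, id)` placeholder:

```python
# Flat logprobs for two packed requests with extend lengths 4 and 2.
prefill_token_logprobs = [-0.5, -1.1, -0.2, -0.9, -1.7, -0.4]
input_ids = [101, 5, 9, 12]  # first request, extend_input_len == 4
pt, extend_input_len = 0, 4

pairs = [(None, input_ids[0])] + list(
    zip(
        prefill_token_logprobs[pt : pt + extend_input_len - 1],
        input_ids[-extend_input_len + 1 :],
    )
)
print(pairs)  # [(None, 101), (-0.5, 5), (-1.1, 9), (-0.2, 12)]
pt += extend_input_len  # the next request starts at flat offset 4
```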
@@ -540,7 +534,7 @@ class ModelTpServer:
             req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
 
     def forward_decode_batch(self, batch: Batch):
-        # check if decode out of memory
+        # Check if decode out of memory
         if not batch.check_decode_mem():
             old_ratio = self.new_token_ratio
             self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)
@@ -559,9 +553,8 @@ class ModelTpServer:
             )
 
         if not self.disable_regex_jump_forward:
-            # check for jump-forward
+            # Check for jump-forward
             jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
-
             self.forward_queue.extend(jump_forward_reqs)
             if batch.is_empty():
                 return
@@ -570,23 +563,19 @@ class ModelTpServer:
         self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
         batch.prepare_for_decode()
 
-        # Forward
-        logits, (
-            _,
-            _,
-            _,
-            decode_top_logprobs,
-            last_logprobs,
-        ) = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids, _ = batch.sample(logits)
-        next_token_ids = next_token_ids.tolist()
+        # Forward and sample the next tokens
+        output = self.model_runner.forward(batch, ForwardMode.DECODE)
+        next_token_ids, _ = batch.sample(output.next_token_logits)
 
-        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
-        if last_logprobs is not None:
-            new_token_logprobs = last_logprobs[
-                torch.arange(len(batch.reqs)), next_token_ids
+        # Move logprobs to cpu
+        if output.next_token_logprobs is not None:
+            next_token_logprobs = output.next_token_logprobs[
+                torch.arange(len(next_token_ids), device=next_token_ids.device),
+                next_token_ids,
             ].tolist()
+
+        next_token_ids = next_token_ids.tolist()
 
         # Check finish condition
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             req.completion_tokens_wo_jump_forward += 1
@@ -594,10 +583,9 @@ class ModelTpServer:
             req.check_finished()
 
             if req.return_logprob:
-                req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))
-
-            if req.top_logprobs_num > 0:
-                req.decode_top_logprobs.append(decode_top_logprobs[i])
+                req.decode_token_logprobs.append((next_token_logprobs[i], next_token_id))
+                if req.top_logprobs_num > 0:
+                    req.decode_top_logprobs.append(output.decode_top_logprobs[i])
 
         self.handle_finished_requests(batch)
 
@@ -253,7 +253,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     try:
         requests.get(url + "/get_model_info", timeout=5, headers=headers)
         break
-    except requests.exceptions.RequestException as e:
+    except requests.exceptions.RequestException:
         pass
 
     # Send a warmup request
@@ -265,14 +265,14 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
                 "text": "The capital city of France is",
                 "sampling_params": {
                     "temperature": 0,
-                    "max_new_tokens": 16,
+                    "max_new_tokens": 8,
                 },
             },
             headers=headers,
             timeout=600,
         )
         assert res.status_code == 200
-    except Exception:
+    except Exception as e:
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(get_exception_traceback())
+        print(f"Initialization failed. warmup error: {e}")
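
The warmup flow above, in outline: poll a cheap endpoint until the HTTP server answers, then issue one small generation so the first user request does not pay cold-start costs. A hedged sketch (the `/generate` path, polling delay, and URL are assumptions, not shown in this diff):

```python
import time

import requests

url = "http://127.0.0.1:30000"  # placeholder address

# Wait until the server is reachable.
while True:
    try:
        requests.get(url + "/get_model_info", timeout=5)
        break
    except requests.exceptions.RequestException:
        time.sleep(1)  # assumed delay between retries

# One tiny request to warm up the model.
res = requests.post(
    url + "/generate",  # assumed endpoint name
    json={
        "text": "The capital city of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 8},
    },
    timeout=600,
)
assert res.status_code == 200
```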