Clean up logits processor (#558)
This commit is contained in:
@@ -1,5 +1,8 @@
|
|||||||
"""Logits processing."""
|
"""Logits processing."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from vllm.distributed import (
|
from vllm.distributed import (
|
||||||
@@ -10,6 +13,24 @@ from vllm.distributed import (
|
|||||||
from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
|
from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class LogitProcessorOutput:
|
||||||
|
# The logits of the next tokens. shape: [#seq, vocab_size]
|
||||||
|
next_token_logits: torch.Tensor
|
||||||
|
# The logprobs of the next tokens. shape: [#seq, vocab_size]
|
||||||
|
next_token_logprobs: torch.Tensor
|
||||||
|
|
||||||
|
# The normlaized logprobs of prompts. shape: [#seq]
|
||||||
|
normalized_prompt_logprobs: torch.Tensor
|
||||||
|
# The logprobs of prefill tokens. shape: [#token, vocab_size]
|
||||||
|
prefill_token_logprobs: torch.Tensor
|
||||||
|
|
||||||
|
# The logprob and id of the top-k tokens in prefill positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
|
||||||
|
prefill_top_logprobs: List
|
||||||
|
# The logprob and id of the top-k tokens in decode positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
|
||||||
|
decode_top_logprobs: List
|
||||||
|
|
||||||
|
|
||||||
class LogitsProcessor(nn.Module):
|
class LogitsProcessor(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -39,6 +60,7 @@ class LogitsProcessor(nn.Module):
|
|||||||
return normalized_prompt_logprobs
|
return normalized_prompt_logprobs
|
||||||
|
|
||||||
def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata):
|
def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata):
|
||||||
|
# TODO: vectorize the code below
|
||||||
if input_metadata.forward_mode == ForwardMode.DECODE:
|
if input_metadata.forward_mode == ForwardMode.DECODE:
|
||||||
decode_top_logprobs = []
|
decode_top_logprobs = []
|
||||||
for i in range(all_logprobs.shape[0]):
|
for i in range(all_logprobs.shape[0]):
|
||||||
@@ -51,7 +73,6 @@ class LogitsProcessor(nn.Module):
|
|||||||
else:
|
else:
|
||||||
prefill_top_logprobs, decode_top_logprobs = [], []
|
prefill_top_logprobs, decode_top_logprobs = [], []
|
||||||
pt = 0
|
pt = 0
|
||||||
# NOTE: the GPU-CPU overhead can be reduced
|
|
||||||
extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist()
|
extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist()
|
||||||
for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
|
for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
|
||||||
if extend_seq_len == 0:
|
if extend_seq_len == 0:
|
||||||
@@ -71,18 +92,15 @@ class LogitsProcessor(nn.Module):
|
|||||||
return prefill_top_logprobs, decode_top_logprobs
|
return prefill_top_logprobs, decode_top_logprobs
|
||||||
|
|
||||||
def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
|
def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
|
||||||
# Get last index for next token prediction, except for DECODE mode.
|
# Get the last hidden states and last logits for the next token prediction
|
||||||
last_index = None
|
if input_metadata.forward_mode == ForwardMode.DECODE:
|
||||||
if input_metadata.forward_mode != ForwardMode.DECODE:
|
last_index = None
|
||||||
|
last_hidden = hidden_states
|
||||||
|
else:
|
||||||
last_index = (
|
last_index = (
|
||||||
torch.cumsum(input_metadata.extend_seq_lens, dim=0, dtype=torch.long)
|
torch.cumsum(input_metadata.extend_seq_lens, dim=0, dtype=torch.long)
|
||||||
- 1
|
- 1
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the last hidden states and last logits
|
|
||||||
if input_metadata.forward_mode == ForwardMode.DECODE:
|
|
||||||
last_hidden = hidden_states
|
|
||||||
else:
|
|
||||||
last_hidden = hidden_states[last_index]
|
last_hidden = hidden_states[last_index]
|
||||||
|
|
||||||
last_logits = torch.matmul(last_hidden, weight.T)
|
last_logits = torch.matmul(last_hidden, weight.T)
|
||||||
@@ -92,8 +110,14 @@ class LogitsProcessor(nn.Module):
|
|||||||
|
|
||||||
# Return only last_logits if logprob is not requested
|
# Return only last_logits if logprob is not requested
|
||||||
if not input_metadata.return_logprob:
|
if not input_metadata.return_logprob:
|
||||||
hidden_states = None
|
return LogitProcessorOutput(
|
||||||
return last_logits, (None, None, None, None, None)
|
next_token_logits=last_logits,
|
||||||
|
next_token_logprobs=None,
|
||||||
|
normalized_prompt_logprobs=None,
|
||||||
|
prefill_token_logprobs=None,
|
||||||
|
prefill_top_logprobs=None,
|
||||||
|
decode_top_logprobs=None,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# When logprob is requested, compute the logits for all tokens.
|
# When logprob is requested, compute the logits for all tokens.
|
||||||
if input_metadata.forward_mode == ForwardMode.DECODE:
|
if input_metadata.forward_mode == ForwardMode.DECODE:
|
||||||
@@ -108,6 +132,7 @@ class LogitsProcessor(nn.Module):
|
|||||||
del all_logits
|
del all_logits
|
||||||
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
|
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
|
||||||
|
|
||||||
|
# Get the logprob of top-k tokens
|
||||||
return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
|
return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
|
||||||
if return_top_logprob:
|
if return_top_logprob:
|
||||||
prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
|
prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
|
||||||
@@ -117,16 +142,15 @@ class LogitsProcessor(nn.Module):
|
|||||||
prefill_top_logprobs = decode_top_logprobs = None
|
prefill_top_logprobs = decode_top_logprobs = None
|
||||||
|
|
||||||
if input_metadata.forward_mode == ForwardMode.DECODE:
|
if input_metadata.forward_mode == ForwardMode.DECODE:
|
||||||
last_logprobs = all_logprobs
|
return LogitProcessorOutput(
|
||||||
return last_logits, (
|
next_token_logits=last_logits,
|
||||||
None,
|
next_token_logprobs=all_logprobs,
|
||||||
None,
|
normalized_prompt_logprobs=None,
|
||||||
None,
|
prefill_token_logprobs=None,
|
||||||
decode_top_logprobs,
|
prefill_top_logprobs=None,
|
||||||
last_logprobs,
|
decode_top_logprobs=decode_top_logprobs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Compute the logprobs for the last token of each request.
|
|
||||||
last_logprobs = all_logprobs[last_index]
|
last_logprobs = all_logprobs[last_index]
|
||||||
|
|
||||||
# Compute the logprobs and normalized logprobs for the prefill tokens.
|
# Compute the logprobs and normalized logprobs for the prefill tokens.
|
||||||
@@ -139,12 +163,14 @@ class LogitsProcessor(nn.Module):
|
|||||||
normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
|
normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
|
||||||
prefill_token_logprobs, input_metadata
|
prefill_token_logprobs, input_metadata
|
||||||
)
|
)
|
||||||
return last_logits, (
|
|
||||||
prefill_token_logprobs,
|
return LogitProcessorOutput(
|
||||||
normalized_prompt_logprobs,
|
next_token_logits=last_logits,
|
||||||
prefill_top_logprobs,
|
next_token_logprobs=last_logprobs,
|
||||||
decode_top_logprobs,
|
normalized_prompt_logprobs=normalized_prompt_logprobs,
|
||||||
last_logprobs,
|
prefill_token_logprobs=prefill_token_logprobs,
|
||||||
|
prefill_top_logprobs=prefill_top_logprobs,
|
||||||
|
decode_top_logprobs=decode_top_logprobs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -441,33 +441,25 @@ class ModelTpServer:
|
|||||||
self.model_config.vocab_size, self.int_token_logit_bias
|
self.model_config.vocab_size, self.int_token_logit_bias
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Forward and sample the next tokens
|
||||||
if batch.extend_num_tokens != 0:
|
if batch.extend_num_tokens != 0:
|
||||||
# Forward
|
output = self.model_runner.forward(batch, ForwardMode.EXTEND)
|
||||||
logits, (
|
next_token_ids, _ = batch.sample(output.next_token_logits)
|
||||||
prefill_token_logprobs,
|
|
||||||
normalized_prompt_logprobs,
|
|
||||||
prefill_top_logprobs,
|
|
||||||
decode_top_logprobs,
|
|
||||||
last_logprobs,
|
|
||||||
) = self.model_runner.forward(batch, ForwardMode.EXTEND)
|
|
||||||
if prefill_token_logprobs is not None:
|
|
||||||
prefill_token_logprobs = prefill_token_logprobs.tolist()
|
|
||||||
normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()
|
|
||||||
|
|
||||||
next_token_ids, _ = batch.sample(logits)
|
# Move logprobs to cpu
|
||||||
|
if output.next_token_logprobs is not None:
|
||||||
# Only transfer the selected logprobs of the next token to CPU to reduce overhead.
|
output.next_token_logprobs = output.next_token_logprobs[
|
||||||
if last_logprobs is not None:
|
torch.arange(len(next_token_ids), device=next_token_ids.device),
|
||||||
last_token_logprobs = last_logprobs[
|
|
||||||
torch.arange(len(batch.reqs), device=next_token_ids.device),
|
|
||||||
next_token_ids,
|
next_token_ids,
|
||||||
].tolist()
|
].tolist()
|
||||||
|
output.prefill_token_logprobs = output.prefill_token_logprobs.tolist()
|
||||||
|
output.normalized_prompt_logprobs = output.normalized_prompt_logprobs.tolist()
|
||||||
|
|
||||||
next_token_ids = next_token_ids.tolist()
|
next_token_ids = next_token_ids.tolist()
|
||||||
else:
|
else:
|
||||||
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
|
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
|
||||||
|
|
||||||
# Check finish condition
|
# Check finish conditions
|
||||||
pt = 0
|
pt = 0
|
||||||
for i, req in enumerate(batch.reqs):
|
for i, req in enumerate(batch.reqs):
|
||||||
req.completion_tokens_wo_jump_forward += 1
|
req.completion_tokens_wo_jump_forward += 1
|
||||||
@@ -475,58 +467,60 @@ class ModelTpServer:
|
|||||||
req.check_finished()
|
req.check_finished()
|
||||||
|
|
||||||
if req.return_logprob:
|
if req.return_logprob:
|
||||||
if req.normalized_prompt_logprob is None:
|
self.add_logprob_return_values(i, req, pt, next_token_ids, output)
|
||||||
req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
|
pt += req.extend_input_len
|
||||||
|
|
||||||
if req.prefill_token_logprobs is None:
|
|
||||||
# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
|
|
||||||
req.prefill_token_logprobs = list(
|
|
||||||
zip(
|
|
||||||
prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
|
|
||||||
req.input_ids[-req.extend_input_len + 1 :],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if req.logprob_start_len == 0:
|
|
||||||
req.prefill_token_logprobs = [
|
|
||||||
(None, req.input_ids[0])
|
|
||||||
] + req.prefill_token_logprobs
|
|
||||||
|
|
||||||
if req.last_update_decode_tokens != 0:
|
|
||||||
req.decode_token_logprobs.extend(
|
|
||||||
list(
|
|
||||||
zip(
|
|
||||||
prefill_token_logprobs[
|
|
||||||
pt
|
|
||||||
+ req.extend_input_len
|
|
||||||
- req.last_update_decode_tokens : pt
|
|
||||||
+ req.extend_input_len
|
|
||||||
- 1
|
|
||||||
],
|
|
||||||
req.input_ids[-req.last_update_decode_tokens + 1 :],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
req.decode_token_logprobs.append(
|
|
||||||
(last_token_logprobs[i], next_token_ids[i])
|
|
||||||
)
|
|
||||||
|
|
||||||
if req.top_logprobs_num > 0:
|
|
||||||
if req.prefill_top_logprobs is None:
|
|
||||||
req.prefill_top_logprobs = prefill_top_logprobs[i]
|
|
||||||
if req.logprob_start_len == 0:
|
|
||||||
req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
|
|
||||||
|
|
||||||
if req.last_update_decode_tokens != 0:
|
|
||||||
req.decode_top_logprobs.extend(
|
|
||||||
prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
|
|
||||||
)
|
|
||||||
req.decode_top_logprobs.append(decode_top_logprobs[i])
|
|
||||||
|
|
||||||
pt += req.extend_input_len
|
|
||||||
|
|
||||||
self.handle_finished_requests(batch)
|
self.handle_finished_requests(batch)
|
||||||
|
|
||||||
|
def add_logprob_return_values(self, i, req, pt, next_token_ids, output):
|
||||||
|
if req.normalized_prompt_logprob is None:
|
||||||
|
req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
|
||||||
|
|
||||||
|
if req.prefill_token_logprobs is None:
|
||||||
|
# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
|
||||||
|
req.prefill_token_logprobs = list(
|
||||||
|
zip(
|
||||||
|
output.prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
|
||||||
|
req.input_ids[-req.extend_input_len + 1 :],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if req.logprob_start_len == 0:
|
||||||
|
req.prefill_token_logprobs = [
|
||||||
|
(None, req.input_ids[0])
|
||||||
|
] + req.prefill_token_logprobs
|
||||||
|
|
||||||
|
if req.last_update_decode_tokens != 0:
|
||||||
|
req.decode_token_logprobs.extend(
|
||||||
|
list(
|
||||||
|
zip(
|
||||||
|
output.prefill_token_logprobs[
|
||||||
|
pt
|
||||||
|
+ req.extend_input_len
|
||||||
|
- req.last_update_decode_tokens : pt
|
||||||
|
+ req.extend_input_len
|
||||||
|
- 1
|
||||||
|
],
|
||||||
|
req.input_ids[-req.last_update_decode_tokens + 1 :],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
req.decode_token_logprobs.append(
|
||||||
|
(output.next_token_logprobs[i], next_token_ids[i])
|
||||||
|
)
|
||||||
|
|
||||||
|
if req.top_logprobs_num > 0:
|
||||||
|
if req.prefill_top_logprobs is None:
|
||||||
|
req.prefill_top_logprobs = output.prefill_top_logprobs[i]
|
||||||
|
if req.logprob_start_len == 0:
|
||||||
|
req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
|
||||||
|
|
||||||
|
if req.last_update_decode_tokens != 0:
|
||||||
|
req.decode_top_logprobs.extend(
|
||||||
|
output.prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
|
||||||
|
)
|
||||||
|
req.decode_top_logprobs.append(output.decode_top_logprobs[i])
|
||||||
|
|
||||||
def cache_filled_batch(self, batch: Batch):
|
def cache_filled_batch(self, batch: Batch):
|
||||||
req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
|
req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
|
||||||
for i, req in enumerate(batch.reqs):
|
for i, req in enumerate(batch.reqs):
|
||||||
@@ -540,7 +534,7 @@ class ModelTpServer:
|
|||||||
req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
|
req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
|
||||||
|
|
||||||
def forward_decode_batch(self, batch: Batch):
|
def forward_decode_batch(self, batch: Batch):
|
||||||
# check if decode out of memory
|
# Check if decode out of memory
|
||||||
if not batch.check_decode_mem():
|
if not batch.check_decode_mem():
|
||||||
old_ratio = self.new_token_ratio
|
old_ratio = self.new_token_ratio
|
||||||
self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)
|
self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)
|
||||||
@@ -559,9 +553,8 @@ class ModelTpServer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not self.disable_regex_jump_forward:
|
if not self.disable_regex_jump_forward:
|
||||||
# check for jump-forward
|
# Check for jump-forward
|
||||||
jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
|
jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
|
||||||
|
|
||||||
self.forward_queue.extend(jump_forward_reqs)
|
self.forward_queue.extend(jump_forward_reqs)
|
||||||
if batch.is_empty():
|
if batch.is_empty():
|
||||||
return
|
return
|
||||||
@@ -570,23 +563,19 @@ class ModelTpServer:
|
|||||||
self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
|
self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
|
||||||
batch.prepare_for_decode()
|
batch.prepare_for_decode()
|
||||||
|
|
||||||
# Forward
|
# Forward and sample the next tokens
|
||||||
logits, (
|
output = self.model_runner.forward(batch, ForwardMode.DECODE)
|
||||||
_,
|
next_token_ids, _ = batch.sample(output.next_token_logits)
|
||||||
_,
|
|
||||||
_,
|
|
||||||
decode_top_logprobs,
|
|
||||||
last_logprobs,
|
|
||||||
) = self.model_runner.forward(batch, ForwardMode.DECODE)
|
|
||||||
next_token_ids, _ = batch.sample(logits)
|
|
||||||
next_token_ids = next_token_ids.tolist()
|
|
||||||
|
|
||||||
# Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
|
# Move logprobs to cpu
|
||||||
if last_logprobs is not None:
|
if output.next_token_logprobs is not None:
|
||||||
new_token_logprobs = last_logprobs[
|
next_token_logprobs = output.next_token_logprobs[
|
||||||
torch.arange(len(batch.reqs)), next_token_ids
|
torch.arange(len(next_token_ids), device=next_token_ids.device),
|
||||||
|
next_token_ids,
|
||||||
].tolist()
|
].tolist()
|
||||||
|
|
||||||
|
next_token_ids = next_token_ids.tolist()
|
||||||
|
|
||||||
# Check finish condition
|
# Check finish condition
|
||||||
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
|
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
|
||||||
req.completion_tokens_wo_jump_forward += 1
|
req.completion_tokens_wo_jump_forward += 1
|
||||||
@@ -594,10 +583,9 @@ class ModelTpServer:
|
|||||||
req.check_finished()
|
req.check_finished()
|
||||||
|
|
||||||
if req.return_logprob:
|
if req.return_logprob:
|
||||||
req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))
|
req.decode_token_logprobs.append((next_token_logprobs[i], next_token_id))
|
||||||
|
if req.top_logprobs_num > 0:
|
||||||
if req.top_logprobs_num > 0:
|
req.decode_top_logprobs.append(output.decode_top_logprobs[i])
|
||||||
req.decode_top_logprobs.append(decode_top_logprobs[i])
|
|
||||||
|
|
||||||
self.handle_finished_requests(batch)
|
self.handle_finished_requests(batch)
|
||||||
|
|
||||||
|
|||||||
@@ -253,7 +253,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
|
|||||||
try:
|
try:
|
||||||
requests.get(url + "/get_model_info", timeout=5, headers=headers)
|
requests.get(url + "/get_model_info", timeout=5, headers=headers)
|
||||||
break
|
break
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Send a warmup request
|
# Send a warmup request
|
||||||
@@ -265,14 +265,14 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
|
|||||||
"text": "The capital city of France is",
|
"text": "The capital city of France is",
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
"temperature": 0,
|
"temperature": 0,
|
||||||
"max_new_tokens": 16,
|
"max_new_tokens": 8,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
headers=headers,
|
headers=headers,
|
||||||
timeout=600,
|
timeout=600,
|
||||||
)
|
)
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
except Exception:
|
except Exception as e:
|
||||||
if pipe_finish_writer is not None:
|
if pipe_finish_writer is not None:
|
||||||
pipe_finish_writer.send(get_exception_traceback())
|
pipe_finish_writer.send(get_exception_traceback())
|
||||||
print(f"Initialization failed. warmup error: {e}")
|
print(f"Initialization failed. warmup error: {e}")
|
||||||
|
|||||||
Reference in New Issue
Block a user