Fix return_log_probs with cuda graph (#775)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
"""Logits processing."""
|
||||
|
||||
import dataclasses
|
||||
from typing import List, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -34,11 +34,11 @@ class LogitProcessorOutput:
|
||||
@dataclasses.dataclass
|
||||
class LogitsMetadata:
|
||||
forward_mode: ForwardMode
|
||||
return_logprob: bool
|
||||
return_logprob: bool = False
|
||||
|
||||
extend_seq_lens: torch.Tensor = None
|
||||
extend_start_loc: torch.Tensor = None
|
||||
top_logprobs_nums: List[int] = None
|
||||
extend_seq_lens: Optional[torch.Tensor] = None
|
||||
extend_start_loc: Optional[torch.Tensor] = None
|
||||
top_logprobs_nums: Optional[List[int]] = None
|
||||
|
||||
@classmethod
|
||||
def from_input_metadata(cls, input_metadata: InputMetadata):
|
||||
@@ -79,7 +79,8 @@ class LogitsProcessor(nn.Module):
|
||||
|
||||
return normalized_prompt_logprobs
|
||||
|
||||
def _get_top_logprobs(self, all_logprobs, logits_metadata: LogitsMetadata):
|
||||
@staticmethod
|
||||
def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
|
||||
# TODO: vectorize the code below
|
||||
if logits_metadata.forward_mode == ForwardMode.DECODE:
|
||||
decode_top_logprobs = []
|
||||
@@ -156,36 +157,48 @@ class LogitsProcessor(nn.Module):
|
||||
else:
|
||||
# When logprob is requested, compute the logits for all tokens.
|
||||
if logits_metadata.forward_mode == ForwardMode.DECODE:
|
||||
all_logits = last_logits
|
||||
else:
|
||||
all_logits = torch.matmul(hidden_states, weight.T)
|
||||
if self.tp_size > 1:
|
||||
all_logits = tensor_model_parallel_all_gather(all_logits)
|
||||
all_logits = all_logits[:, : self.config.vocab_size].float()
|
||||
last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
|
||||
|
||||
all_logprobs = all_logits
|
||||
del all_logits
|
||||
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
|
||||
|
||||
# Get the logprob of top-k tokens
|
||||
return_top_logprob = any(x > 0 for x in logits_metadata.top_logprobs_nums)
|
||||
if return_top_logprob:
|
||||
prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
|
||||
all_logprobs, logits_metadata
|
||||
# Get the logprob of top-k tokens
|
||||
return_top_logprob = any(
|
||||
x > 0 for x in logits_metadata.top_logprobs_nums
|
||||
)
|
||||
else:
|
||||
prefill_top_logprobs = decode_top_logprobs = None
|
||||
if return_top_logprob:
|
||||
decode_top_logprobs = self.get_top_logprobs(
|
||||
last_logprobs, logits_metadata
|
||||
)[1]
|
||||
else:
|
||||
decode_top_logprobs = None
|
||||
|
||||
if logits_metadata.forward_mode == ForwardMode.DECODE:
|
||||
return LogitProcessorOutput(
|
||||
next_token_logits=last_logits,
|
||||
next_token_logprobs=all_logprobs,
|
||||
next_token_logprobs=last_logprobs,
|
||||
normalized_prompt_logprobs=None,
|
||||
prefill_token_logprobs=None,
|
||||
prefill_top_logprobs=None,
|
||||
decode_top_logprobs=decode_top_logprobs,
|
||||
)
|
||||
else:
|
||||
all_logits = torch.matmul(hidden_states, weight.T)
|
||||
if self.tp_size > 1:
|
||||
all_logits = tensor_model_parallel_all_gather(all_logits)
|
||||
all_logits = all_logits[:, : self.config.vocab_size].float()
|
||||
|
||||
all_logprobs = all_logits
|
||||
del all_logits
|
||||
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
|
||||
|
||||
# Get the logprob of top-k tokens
|
||||
return_top_logprob = any(
|
||||
x > 0 for x in logits_metadata.top_logprobs_nums
|
||||
)
|
||||
if return_top_logprob:
|
||||
prefill_top_logprobs, decode_top_logprobs = self.get_top_logprobs(
|
||||
all_logprobs, logits_metadata
|
||||
)
|
||||
else:
|
||||
prefill_top_logprobs = decode_top_logprobs = None
|
||||
|
||||
last_logprobs = all_logprobs[last_index]
|
||||
|
||||
# Compute the logprobs and normalized logprobs for the prefill tokens.
|
||||
|
||||
@@ -9,7 +9,11 @@ from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
|
||||
from vllm.distributed.parallel_state import graph_capture
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitProcessorOutput
|
||||
from sglang.srt.layers.logits_processor import (
|
||||
LogitProcessorOutput,
|
||||
LogitsMetadata,
|
||||
LogitsProcessor,
|
||||
)
|
||||
from sglang.srt.managers.controller.infer_batch import (
|
||||
Batch,
|
||||
ForwardMode,
|
||||
@@ -185,7 +189,6 @@ class CudaGraphRunner:
|
||||
|
||||
def replay(self, batch: Batch):
|
||||
assert batch.out_cache_loc is not None
|
||||
assert not batch.return_logprob
|
||||
raw_bs = len(batch.reqs)
|
||||
|
||||
# Pad
|
||||
@@ -218,23 +221,29 @@ class CudaGraphRunner:
|
||||
output = self.output_buffers[bs]
|
||||
|
||||
# Unpad
|
||||
if bs == raw_bs:
|
||||
return output
|
||||
else:
|
||||
if bs != raw_bs:
|
||||
output = LogitProcessorOutput(
|
||||
next_token_logits=output.next_token_logits[:raw_bs],
|
||||
next_token_logprobs=(
|
||||
output.next_token_logprobs[:raw_bs]
|
||||
if output.next_token_logprobs is not None
|
||||
else None
|
||||
),
|
||||
next_token_logprobs=None,
|
||||
normalized_prompt_logprobs=None,
|
||||
prefill_token_logprobs=None,
|
||||
prefill_top_logprobs=None,
|
||||
decode_top_logprobs=(
|
||||
output.decode_top_logprobs[:raw_bs]
|
||||
if output.decode_top_logprobs is not None
|
||||
else None
|
||||
),
|
||||
decode_top_logprobs=None,
|
||||
)
|
||||
|
||||
# Extract logprobs
|
||||
if batch.return_logprob:
|
||||
output.next_token_logprobs = torch.nn.functional.log_softmax(
|
||||
output.next_token_logits, dim=-1
|
||||
)
|
||||
return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
|
||||
if return_top_logprob:
|
||||
logits_metadata = LogitsMetadata(
|
||||
forward_mode=ForwardMode.DECODE,
|
||||
top_logprobs_nums=batch.top_logprobs_nums,
|
||||
)
|
||||
output.decode_top_logprobs = LogitsProcessor.get_top_logprobs(
|
||||
output.next_token_logprobs, logits_metadata
|
||||
)[1]
|
||||
|
||||
return output
|
||||
|
||||
Reference in New Issue
Block a user