Fix return_log_probs with cuda graph (#775)
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
"""Logits processing."""
|
"""Logits processing."""
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import List, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -34,11 +34,11 @@ class LogitProcessorOutput:
|
|||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class LogitsMetadata:
|
class LogitsMetadata:
|
||||||
forward_mode: ForwardMode
|
forward_mode: ForwardMode
|
||||||
return_logprob: bool
|
return_logprob: bool = False
|
||||||
|
|
||||||
extend_seq_lens: torch.Tensor = None
|
extend_seq_lens: Optional[torch.Tensor] = None
|
||||||
extend_start_loc: torch.Tensor = None
|
extend_start_loc: Optional[torch.Tensor] = None
|
||||||
top_logprobs_nums: List[int] = None
|
top_logprobs_nums: Optional[List[int]] = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_input_metadata(cls, input_metadata: InputMetadata):
|
def from_input_metadata(cls, input_metadata: InputMetadata):
|
||||||
@@ -79,7 +79,8 @@ class LogitsProcessor(nn.Module):
|
|||||||
|
|
||||||
return normalized_prompt_logprobs
|
return normalized_prompt_logprobs
|
||||||
|
|
||||||
def _get_top_logprobs(self, all_logprobs, logits_metadata: LogitsMetadata):
|
@staticmethod
|
||||||
|
def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
|
||||||
# TODO: vectorize the code below
|
# TODO: vectorize the code below
|
||||||
if logits_metadata.forward_mode == ForwardMode.DECODE:
|
if logits_metadata.forward_mode == ForwardMode.DECODE:
|
||||||
decode_top_logprobs = []
|
decode_top_logprobs = []
|
||||||
@@ -156,7 +157,27 @@ class LogitsProcessor(nn.Module):
|
|||||||
else:
|
else:
|
||||||
# When logprob is requested, compute the logits for all tokens.
|
# When logprob is requested, compute the logits for all tokens.
|
||||||
if logits_metadata.forward_mode == ForwardMode.DECODE:
|
if logits_metadata.forward_mode == ForwardMode.DECODE:
|
||||||
all_logits = last_logits
|
last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
|
||||||
|
|
||||||
|
# Get the logprob of top-k tokens
|
||||||
|
return_top_logprob = any(
|
||||||
|
x > 0 for x in logits_metadata.top_logprobs_nums
|
||||||
|
)
|
||||||
|
if return_top_logprob:
|
||||||
|
decode_top_logprobs = self.get_top_logprobs(
|
||||||
|
last_logprobs, logits_metadata
|
||||||
|
)[1]
|
||||||
|
else:
|
||||||
|
decode_top_logprobs = None
|
||||||
|
|
||||||
|
return LogitProcessorOutput(
|
||||||
|
next_token_logits=last_logits,
|
||||||
|
next_token_logprobs=last_logprobs,
|
||||||
|
normalized_prompt_logprobs=None,
|
||||||
|
prefill_token_logprobs=None,
|
||||||
|
prefill_top_logprobs=None,
|
||||||
|
decode_top_logprobs=decode_top_logprobs,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
all_logits = torch.matmul(hidden_states, weight.T)
|
all_logits = torch.matmul(hidden_states, weight.T)
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
@@ -168,24 +189,16 @@ class LogitsProcessor(nn.Module):
|
|||||||
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
|
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
|
||||||
|
|
||||||
# Get the logprob of top-k tokens
|
# Get the logprob of top-k tokens
|
||||||
return_top_logprob = any(x > 0 for x in logits_metadata.top_logprobs_nums)
|
return_top_logprob = any(
|
||||||
|
x > 0 for x in logits_metadata.top_logprobs_nums
|
||||||
|
)
|
||||||
if return_top_logprob:
|
if return_top_logprob:
|
||||||
prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
|
prefill_top_logprobs, decode_top_logprobs = self.get_top_logprobs(
|
||||||
all_logprobs, logits_metadata
|
all_logprobs, logits_metadata
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
prefill_top_logprobs = decode_top_logprobs = None
|
prefill_top_logprobs = decode_top_logprobs = None
|
||||||
|
|
||||||
if logits_metadata.forward_mode == ForwardMode.DECODE:
|
|
||||||
return LogitProcessorOutput(
|
|
||||||
next_token_logits=last_logits,
|
|
||||||
next_token_logprobs=all_logprobs,
|
|
||||||
normalized_prompt_logprobs=None,
|
|
||||||
prefill_token_logprobs=None,
|
|
||||||
prefill_top_logprobs=None,
|
|
||||||
decode_top_logprobs=decode_top_logprobs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
last_logprobs = all_logprobs[last_index]
|
last_logprobs = all_logprobs[last_index]
|
||||||
|
|
||||||
# Compute the logprobs and normalized logprobs for the prefill tokens.
|
# Compute the logprobs and normalized logprobs for the prefill tokens.
|
||||||
|
|||||||
@@ -9,7 +9,11 @@ from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
|
|||||||
from vllm.distributed.parallel_state import graph_capture
|
from vllm.distributed.parallel_state import graph_capture
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitProcessorOutput
|
from sglang.srt.layers.logits_processor import (
|
||||||
|
LogitProcessorOutput,
|
||||||
|
LogitsMetadata,
|
||||||
|
LogitsProcessor,
|
||||||
|
)
|
||||||
from sglang.srt.managers.controller.infer_batch import (
|
from sglang.srt.managers.controller.infer_batch import (
|
||||||
Batch,
|
Batch,
|
||||||
ForwardMode,
|
ForwardMode,
|
||||||
@@ -185,7 +189,6 @@ class CudaGraphRunner:
|
|||||||
|
|
||||||
def replay(self, batch: Batch):
|
def replay(self, batch: Batch):
|
||||||
assert batch.out_cache_loc is not None
|
assert batch.out_cache_loc is not None
|
||||||
assert not batch.return_logprob
|
|
||||||
raw_bs = len(batch.reqs)
|
raw_bs = len(batch.reqs)
|
||||||
|
|
||||||
# Pad
|
# Pad
|
||||||
@@ -218,23 +221,29 @@ class CudaGraphRunner:
|
|||||||
output = self.output_buffers[bs]
|
output = self.output_buffers[bs]
|
||||||
|
|
||||||
# Unpad
|
# Unpad
|
||||||
if bs == raw_bs:
|
if bs != raw_bs:
|
||||||
return output
|
|
||||||
else:
|
|
||||||
output = LogitProcessorOutput(
|
output = LogitProcessorOutput(
|
||||||
next_token_logits=output.next_token_logits[:raw_bs],
|
next_token_logits=output.next_token_logits[:raw_bs],
|
||||||
next_token_logprobs=(
|
next_token_logprobs=None,
|
||||||
output.next_token_logprobs[:raw_bs]
|
|
||||||
if output.next_token_logprobs is not None
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
normalized_prompt_logprobs=None,
|
normalized_prompt_logprobs=None,
|
||||||
prefill_token_logprobs=None,
|
prefill_token_logprobs=None,
|
||||||
prefill_top_logprobs=None,
|
prefill_top_logprobs=None,
|
||||||
decode_top_logprobs=(
|
decode_top_logprobs=None,
|
||||||
output.decode_top_logprobs[:raw_bs]
|
|
||||||
if output.decode_top_logprobs is not None
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Extract logprobs
|
||||||
|
if batch.return_logprob:
|
||||||
|
output.next_token_logprobs = torch.nn.functional.log_softmax(
|
||||||
|
output.next_token_logits, dim=-1
|
||||||
|
)
|
||||||
|
return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
|
||||||
|
if return_top_logprob:
|
||||||
|
logits_metadata = LogitsMetadata(
|
||||||
|
forward_mode=ForwardMode.DECODE,
|
||||||
|
top_logprobs_nums=batch.top_logprobs_nums,
|
||||||
|
)
|
||||||
|
output.decode_top_logprobs = LogitsProcessor.get_top_logprobs(
|
||||||
|
output.next_token_logprobs, logits_metadata
|
||||||
|
)[1]
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
Reference in New Issue
Block a user