Fix return_log_probs with cuda graph (#775)

This commit is contained in:
Lianmin Zheng
2024-07-27 19:15:09 -07:00
committed by GitHub
parent e4db4e5ba5
commit 0a409bd438
2 changed files with 62 additions and 40 deletions

View File

@@ -1,7 +1,7 @@
"""Logits processing.""" """Logits processing."""
import dataclasses import dataclasses
from typing import List, Union from typing import List, Optional, Union
import torch import torch
from torch import nn from torch import nn
@@ -34,11 +34,11 @@ class LogitProcessorOutput:
@dataclasses.dataclass @dataclasses.dataclass
class LogitsMetadata: class LogitsMetadata:
forward_mode: ForwardMode forward_mode: ForwardMode
return_logprob: bool return_logprob: bool = False
extend_seq_lens: torch.Tensor = None extend_seq_lens: Optional[torch.Tensor] = None
extend_start_loc: torch.Tensor = None extend_start_loc: Optional[torch.Tensor] = None
top_logprobs_nums: List[int] = None top_logprobs_nums: Optional[List[int]] = None
@classmethod @classmethod
def from_input_metadata(cls, input_metadata: InputMetadata): def from_input_metadata(cls, input_metadata: InputMetadata):
@@ -79,7 +79,8 @@ class LogitsProcessor(nn.Module):
return normalized_prompt_logprobs return normalized_prompt_logprobs
def _get_top_logprobs(self, all_logprobs, logits_metadata: LogitsMetadata): @staticmethod
def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
# TODO: vectorize the code below # TODO: vectorize the code below
if logits_metadata.forward_mode == ForwardMode.DECODE: if logits_metadata.forward_mode == ForwardMode.DECODE:
decode_top_logprobs = [] decode_top_logprobs = []
@@ -156,7 +157,27 @@ class LogitsProcessor(nn.Module):
else: else:
# When logprob is requested, compute the logits for all tokens. # When logprob is requested, compute the logits for all tokens.
if logits_metadata.forward_mode == ForwardMode.DECODE: if logits_metadata.forward_mode == ForwardMode.DECODE:
all_logits = last_logits last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
# Get the logprob of top-k tokens
return_top_logprob = any(
x > 0 for x in logits_metadata.top_logprobs_nums
)
if return_top_logprob:
decode_top_logprobs = self.get_top_logprobs(
last_logprobs, logits_metadata
)[1]
else:
decode_top_logprobs = None
return LogitProcessorOutput(
next_token_logits=last_logits,
next_token_logprobs=last_logprobs,
normalized_prompt_logprobs=None,
prefill_token_logprobs=None,
prefill_top_logprobs=None,
decode_top_logprobs=decode_top_logprobs,
)
else: else:
all_logits = torch.matmul(hidden_states, weight.T) all_logits = torch.matmul(hidden_states, weight.T)
if self.tp_size > 1: if self.tp_size > 1:
@@ -168,24 +189,16 @@ class LogitsProcessor(nn.Module):
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1) all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
# Get the logprob of top-k tokens # Get the logprob of top-k tokens
return_top_logprob = any(x > 0 for x in logits_metadata.top_logprobs_nums) return_top_logprob = any(
x > 0 for x in logits_metadata.top_logprobs_nums
)
if return_top_logprob: if return_top_logprob:
prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs( prefill_top_logprobs, decode_top_logprobs = self.get_top_logprobs(
all_logprobs, logits_metadata all_logprobs, logits_metadata
) )
else: else:
prefill_top_logprobs = decode_top_logprobs = None prefill_top_logprobs = decode_top_logprobs = None
if logits_metadata.forward_mode == ForwardMode.DECODE:
return LogitProcessorOutput(
next_token_logits=last_logits,
next_token_logprobs=all_logprobs,
normalized_prompt_logprobs=None,
prefill_token_logprobs=None,
prefill_top_logprobs=None,
decode_top_logprobs=decode_top_logprobs,
)
else:
last_logprobs = all_logprobs[last_index] last_logprobs = all_logprobs[last_index]
# Compute the logprobs and normalized logprobs for the prefill tokens. # Compute the logprobs and normalized logprobs for the prefill tokens.

View File

@@ -9,7 +9,11 @@ from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
from vllm.distributed.parallel_state import graph_capture from vllm.distributed.parallel_state import graph_capture
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from sglang.srt.layers.logits_processor import LogitProcessorOutput from sglang.srt.layers.logits_processor import (
LogitProcessorOutput,
LogitsMetadata,
LogitsProcessor,
)
from sglang.srt.managers.controller.infer_batch import ( from sglang.srt.managers.controller.infer_batch import (
Batch, Batch,
ForwardMode, ForwardMode,
@@ -185,7 +189,6 @@ class CudaGraphRunner:
def replay(self, batch: Batch): def replay(self, batch: Batch):
assert batch.out_cache_loc is not None assert batch.out_cache_loc is not None
assert not batch.return_logprob
raw_bs = len(batch.reqs) raw_bs = len(batch.reqs)
# Pad # Pad
@@ -218,23 +221,29 @@ class CudaGraphRunner:
output = self.output_buffers[bs] output = self.output_buffers[bs]
# Unpad # Unpad
if bs == raw_bs: if bs != raw_bs:
return output
else:
output = LogitProcessorOutput( output = LogitProcessorOutput(
next_token_logits=output.next_token_logits[:raw_bs], next_token_logits=output.next_token_logits[:raw_bs],
next_token_logprobs=( next_token_logprobs=None,
output.next_token_logprobs[:raw_bs]
if output.next_token_logprobs is not None
else None
),
normalized_prompt_logprobs=None, normalized_prompt_logprobs=None,
prefill_token_logprobs=None, prefill_token_logprobs=None,
prefill_top_logprobs=None, prefill_top_logprobs=None,
decode_top_logprobs=( decode_top_logprobs=None,
output.decode_top_logprobs[:raw_bs]
if output.decode_top_logprobs is not None
else None
),
) )
# Extract logprobs
if batch.return_logprob:
output.next_token_logprobs = torch.nn.functional.log_softmax(
output.next_token_logits, dim=-1
)
return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
if return_top_logprob:
logits_metadata = LogitsMetadata(
forward_mode=ForwardMode.DECODE,
top_logprobs_nums=batch.top_logprobs_nums,
)
output.decode_top_logprobs = LogitsProcessor.get_top_logprobs(
output.next_token_logprobs, logits_metadata
)[1]
return output return output