add LogitsMetadata (#604)

This commit is contained in:
Liangsheng Yin
2024-07-08 17:46:55 -07:00
committed by GitHub
parent f4e885b7c3
commit f25b76c02a
7 changed files with 66 additions and 40 deletions

View File

@@ -48,9 +48,9 @@ def generate_lines(random_words, num_lines, redirect_ratio):
) )
for i in redirect_indices: for i in redirect_indices:
target_idx = np.random.choice(min(i * 2 + 100, num_lines)) target_idx = np.random.choice(min(i * 2 + 100, num_lines))
lines[i] = ( lines[
f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}." i
) ] = f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
redirects[i] = target_idx redirects[i] = target_idx
# Build links and find sources # Build links and find sources

View File

@@ -1,7 +1,7 @@
"""Logits processing.""" """Logits processing."""
import dataclasses import dataclasses
from typing import List from typing import List, Union
import torch import torch
from torch import nn from torch import nn
@@ -31,6 +31,27 @@ class LogitProcessorOutput:
decode_top_logprobs: List decode_top_logprobs: List
@dataclasses.dataclass
class LogitsMetadata:
forward_mode: ForwardMode
extend_seq_lens: torch.Tensor
extend_start_loc: torch.Tensor
# For logprobs
return_logprob: bool
top_logprobs_nums: List[int]
@classmethod
def from_input_metadata(cls, input_metadata: InputMetadata):
return cls(
forward_mode=input_metadata.forward_mode,
extend_seq_lens=input_metadata.extend_seq_lens,
extend_start_loc=input_metadata.extend_start_loc,
return_logprob=input_metadata.return_logprob,
top_logprobs_nums=input_metadata.top_logprobs_nums,
)
class LogitsProcessor(nn.Module): class LogitsProcessor(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
@@ -38,14 +59,14 @@ class LogitsProcessor(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
def _get_normalized_prompt_logprobs( def _get_normalized_prompt_logprobs(
self, prefill_token_logprobs, input_metadata: InputMetadata self, prefill_token_logprobs, logits_metadata: LogitsMetadata
): ):
logprobs_cumsum = torch.cumsum( logprobs_cumsum = torch.cumsum(
prefill_token_logprobs, dim=0, dtype=torch.float32 prefill_token_logprobs, dim=0, dtype=torch.float32
) )
start = input_metadata.extend_start_loc.clone() start = logits_metadata.extend_start_loc.clone()
end = start + input_metadata.extend_seq_lens - 2 end = start + logits_metadata.extend_seq_lens - 2
start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1) start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1) end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
sum_logp = ( sum_logp = (
@@ -54,17 +75,17 @@ class LogitsProcessor(nn.Module):
+ prefill_token_logprobs[start] + prefill_token_logprobs[start]
) )
normalized_prompt_logprobs = sum_logp / ( normalized_prompt_logprobs = sum_logp / (
(input_metadata.extend_seq_lens - 1).clamp(min=1) (logits_metadata.extend_seq_lens - 1).clamp(min=1)
) )
return normalized_prompt_logprobs return normalized_prompt_logprobs
def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata): def _get_top_logprobs(self, all_logprobs, logits_metadata: LogitsMetadata):
# TODO: vectorize the code below # TODO: vectorize the code below
if input_metadata.forward_mode == ForwardMode.DECODE: if logits_metadata.forward_mode == ForwardMode.DECODE:
decode_top_logprobs = [] decode_top_logprobs = []
for i in range(all_logprobs.shape[0]): for i in range(all_logprobs.shape[0]):
k = input_metadata.top_logprobs_nums[i] k = logits_metadata.top_logprobs_nums[i]
t = all_logprobs[i].topk(k) t = all_logprobs[i].topk(k)
v_cpu = t.values.tolist() v_cpu = t.values.tolist()
p_cpu = t.indices.tolist() p_cpu = t.indices.tolist()
@@ -73,13 +94,13 @@ class LogitsProcessor(nn.Module):
else: else:
prefill_top_logprobs, decode_top_logprobs = [], [] prefill_top_logprobs, decode_top_logprobs = [], []
pt = 0 pt = 0
extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist() extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
for i, extend_seq_len in enumerate(extend_seq_lens_cpu): for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
if extend_seq_len == 0: if extend_seq_len == 0:
prefill_top_logprobs.append([]) prefill_top_logprobs.append([])
decode_top_logprobs.append([]) decode_top_logprobs.append([])
continue continue
k = input_metadata.top_logprobs_nums[i] k = logits_metadata.top_logprobs_nums[i]
t = all_logprobs[pt : pt + extend_seq_len].topk(k) t = all_logprobs[pt : pt + extend_seq_len].topk(k)
vs_cpu = t.values.tolist() vs_cpu = t.values.tolist()
ps_cpu = t.indices.tolist() ps_cpu = t.indices.tolist()
@@ -91,14 +112,24 @@ class LogitsProcessor(nn.Module):
return prefill_top_logprobs, decode_top_logprobs return prefill_top_logprobs, decode_top_logprobs
def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata): def forward(
self,
input_ids,
hidden_states,
weight,
logits_metadata: Union[LogitsMetadata, InputMetadata],
):
if isinstance(logits_metadata, InputMetadata):
logits_metadata = LogitsMetadata.from_input_metadata(logits_metadata)
assert isinstance(logits_metadata, LogitsMetadata)
# Get the last hidden states and last logits for the next token prediction # Get the last hidden states and last logits for the next token prediction
if input_metadata.forward_mode == ForwardMode.DECODE: if logits_metadata.forward_mode == ForwardMode.DECODE:
last_index = None last_index = None
last_hidden = hidden_states last_hidden = hidden_states
else: else:
last_index = ( last_index = (
torch.cumsum(input_metadata.extend_seq_lens, dim=0, dtype=torch.long) torch.cumsum(logits_metadata.extend_seq_lens, dim=0, dtype=torch.long)
- 1 - 1
) )
last_hidden = hidden_states[last_index] last_hidden = hidden_states[last_index]
@@ -114,7 +145,7 @@ class LogitsProcessor(nn.Module):
last_logits *= self.config.final_logit_softcapping last_logits *= self.config.final_logit_softcapping
# Return only last_logits if logprob is not requested # Return only last_logits if logprob is not requested
if not input_metadata.return_logprob: if not logits_metadata.return_logprob:
return LogitProcessorOutput( return LogitProcessorOutput(
next_token_logits=last_logits, next_token_logits=last_logits,
next_token_logprobs=None, next_token_logprobs=None,
@@ -125,7 +156,7 @@ class LogitsProcessor(nn.Module):
) )
else: else:
# When logprob is requested, compute the logits for all tokens. # When logprob is requested, compute the logits for all tokens.
if input_metadata.forward_mode == ForwardMode.DECODE: if logits_metadata.forward_mode == ForwardMode.DECODE:
all_logits = last_logits all_logits = last_logits
else: else:
all_logits = torch.matmul(hidden_states, weight.T) all_logits = torch.matmul(hidden_states, weight.T)
@@ -138,15 +169,15 @@ class LogitsProcessor(nn.Module):
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1) all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
# Get the logprob of top-k tokens # Get the logprob of top-k tokens
return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums) return_top_logprob = any(x > 0 for x in logits_metadata.top_logprobs_nums)
if return_top_logprob: if return_top_logprob:
prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs( prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
all_logprobs, input_metadata all_logprobs, logits_metadata
) )
else: else:
prefill_top_logprobs = decode_top_logprobs = None prefill_top_logprobs = decode_top_logprobs = None
if input_metadata.forward_mode == ForwardMode.DECODE: if logits_metadata.forward_mode == ForwardMode.DECODE:
return LogitProcessorOutput( return LogitProcessorOutput(
next_token_logits=last_logits, next_token_logits=last_logits,
next_token_logprobs=all_logprobs, next_token_logprobs=all_logprobs,
@@ -166,7 +197,7 @@ class LogitsProcessor(nn.Module):
] ]
normalized_prompt_logprobs = self._get_normalized_prompt_logprobs( normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
prefill_token_logprobs, input_metadata prefill_token_logprobs, logits_metadata
) )
return LogitProcessorOutput( return LogitProcessorOutput(

View File

@@ -2,9 +2,8 @@
import numpy as np import numpy as np
import torch import torch
from torch import nn
from flashinfer.cascade import merge_state from flashinfer.cascade import merge_state
from torch import nn
from sglang.global_config import global_config from sglang.global_config import global_config
from sglang.srt.layers.extend_attention import extend_attention_fwd from sglang.srt.layers.extend_attention import extend_attention_fwd

View File

@@ -334,15 +334,15 @@ class TokenizerManager:
ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
) )
if top_logprobs_num > 0: if top_logprobs_num > 0:
ret["meta_info"]["prefill_top_logprobs"] = ( ret["meta_info"][
self.detokenize_top_logprobs_tokens( "prefill_top_logprobs"
ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs ] = self.detokenize_top_logprobs_tokens(
) ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
) )
ret["meta_info"]["decode_top_logprobs"] = ( ret["meta_info"][
self.detokenize_top_logprobs_tokens( "decode_top_logprobs"
ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs ] = self.detokenize_top_logprobs_tokens(
) ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
) )
return ret return ret

View File

@@ -81,7 +81,6 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
class GemmaRotaryEmbedding(RotaryEmbedding): class GemmaRotaryEmbedding(RotaryEmbedding):
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
# https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107 # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
inv_freq = 1.0 / ( inv_freq = 1.0 / (
@@ -95,7 +94,6 @@ class GemmaRotaryEmbedding(RotaryEmbedding):
class Gemma2MLP(nn.Module): class Gemma2MLP(nn.Module):
def __init__( def __init__(
self, self,
hidden_size: int, hidden_size: int,
@@ -127,7 +125,6 @@ class Gemma2MLP(nn.Module):
class Gemma2Attention(nn.Module): class Gemma2Attention(nn.Module):
def __init__( def __init__(
self, self,
layer_idx: int, layer_idx: int,
@@ -218,7 +215,6 @@ class Gemma2Attention(nn.Module):
class Gemma2DecoderLayer(nn.Module): class Gemma2DecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
layer_idx: int, layer_idx: int,
@@ -287,7 +283,6 @@ class Gemma2DecoderLayer(nn.Module):
class Gemma2Model(nn.Module): class Gemma2Model(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,

View File

@@ -163,9 +163,9 @@ class LlamaDecoderLayer(nn.Module):
if rope_scaling is not None and getattr( if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None config, "original_max_position_embeddings", None
): ):
rope_scaling["original_max_position_embeddings"] = ( rope_scaling[
config.original_max_position_embeddings "original_max_position_embeddings"
) ] = config.original_max_position_embeddings
rope_is_neox_style = getattr(config, "rope_is_neox_style", True) rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192) max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.self_attn = LlamaAttention( self.self_attn = LlamaAttention(

View File

@@ -459,6 +459,7 @@ def monkey_patch_vllm_p2p_access_check(gpu_id: int):
""" """
import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt
setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True) setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)