add LogitsMetadata (#604)
This commit is contained in:
@@ -48,9 +48,9 @@ def generate_lines(random_words, num_lines, redirect_ratio):
|
|||||||
)
|
)
|
||||||
for i in redirect_indices:
|
for i in redirect_indices:
|
||||||
target_idx = np.random.choice(min(i * 2 + 100, num_lines))
|
target_idx = np.random.choice(min(i * 2 + 100, num_lines))
|
||||||
lines[i] = (
|
lines[
|
||||||
f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
|
i
|
||||||
)
|
] = f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
|
||||||
redirects[i] = target_idx
|
redirects[i] = target_idx
|
||||||
|
|
||||||
# Build links and find sources
|
# Build links and find sources
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Logits processing."""
|
"""Logits processing."""
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import List
|
from typing import List, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -31,6 +31,27 @@ class LogitProcessorOutput:
|
|||||||
decode_top_logprobs: List
|
decode_top_logprobs: List
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class LogitsMetadata:
|
||||||
|
forward_mode: ForwardMode
|
||||||
|
extend_seq_lens: torch.Tensor
|
||||||
|
extend_start_loc: torch.Tensor
|
||||||
|
|
||||||
|
# For logprobs
|
||||||
|
return_logprob: bool
|
||||||
|
top_logprobs_nums: List[int]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_input_metadata(cls, input_metadata: InputMetadata):
|
||||||
|
return cls(
|
||||||
|
forward_mode=input_metadata.forward_mode,
|
||||||
|
extend_seq_lens=input_metadata.extend_seq_lens,
|
||||||
|
extend_start_loc=input_metadata.extend_start_loc,
|
||||||
|
return_logprob=input_metadata.return_logprob,
|
||||||
|
top_logprobs_nums=input_metadata.top_logprobs_nums,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LogitsProcessor(nn.Module):
|
class LogitsProcessor(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -38,14 +59,14 @@ class LogitsProcessor(nn.Module):
|
|||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
|
||||||
def _get_normalized_prompt_logprobs(
|
def _get_normalized_prompt_logprobs(
|
||||||
self, prefill_token_logprobs, input_metadata: InputMetadata
|
self, prefill_token_logprobs, logits_metadata: LogitsMetadata
|
||||||
):
|
):
|
||||||
logprobs_cumsum = torch.cumsum(
|
logprobs_cumsum = torch.cumsum(
|
||||||
prefill_token_logprobs, dim=0, dtype=torch.float32
|
prefill_token_logprobs, dim=0, dtype=torch.float32
|
||||||
)
|
)
|
||||||
|
|
||||||
start = input_metadata.extend_start_loc.clone()
|
start = logits_metadata.extend_start_loc.clone()
|
||||||
end = start + input_metadata.extend_seq_lens - 2
|
end = start + logits_metadata.extend_seq_lens - 2
|
||||||
start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
|
start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
|
||||||
end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
|
end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
|
||||||
sum_logp = (
|
sum_logp = (
|
||||||
@@ -54,17 +75,17 @@ class LogitsProcessor(nn.Module):
|
|||||||
+ prefill_token_logprobs[start]
|
+ prefill_token_logprobs[start]
|
||||||
)
|
)
|
||||||
normalized_prompt_logprobs = sum_logp / (
|
normalized_prompt_logprobs = sum_logp / (
|
||||||
(input_metadata.extend_seq_lens - 1).clamp(min=1)
|
(logits_metadata.extend_seq_lens - 1).clamp(min=1)
|
||||||
)
|
)
|
||||||
|
|
||||||
return normalized_prompt_logprobs
|
return normalized_prompt_logprobs
|
||||||
|
|
||||||
def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata):
|
def _get_top_logprobs(self, all_logprobs, logits_metadata: LogitsMetadata):
|
||||||
# TODO: vectorize the code below
|
# TODO: vectorize the code below
|
||||||
if input_metadata.forward_mode == ForwardMode.DECODE:
|
if logits_metadata.forward_mode == ForwardMode.DECODE:
|
||||||
decode_top_logprobs = []
|
decode_top_logprobs = []
|
||||||
for i in range(all_logprobs.shape[0]):
|
for i in range(all_logprobs.shape[0]):
|
||||||
k = input_metadata.top_logprobs_nums[i]
|
k = logits_metadata.top_logprobs_nums[i]
|
||||||
t = all_logprobs[i].topk(k)
|
t = all_logprobs[i].topk(k)
|
||||||
v_cpu = t.values.tolist()
|
v_cpu = t.values.tolist()
|
||||||
p_cpu = t.indices.tolist()
|
p_cpu = t.indices.tolist()
|
||||||
@@ -73,13 +94,13 @@ class LogitsProcessor(nn.Module):
|
|||||||
else:
|
else:
|
||||||
prefill_top_logprobs, decode_top_logprobs = [], []
|
prefill_top_logprobs, decode_top_logprobs = [], []
|
||||||
pt = 0
|
pt = 0
|
||||||
extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist()
|
extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
|
||||||
for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
|
for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
|
||||||
if extend_seq_len == 0:
|
if extend_seq_len == 0:
|
||||||
prefill_top_logprobs.append([])
|
prefill_top_logprobs.append([])
|
||||||
decode_top_logprobs.append([])
|
decode_top_logprobs.append([])
|
||||||
continue
|
continue
|
||||||
k = input_metadata.top_logprobs_nums[i]
|
k = logits_metadata.top_logprobs_nums[i]
|
||||||
t = all_logprobs[pt : pt + extend_seq_len].topk(k)
|
t = all_logprobs[pt : pt + extend_seq_len].topk(k)
|
||||||
vs_cpu = t.values.tolist()
|
vs_cpu = t.values.tolist()
|
||||||
ps_cpu = t.indices.tolist()
|
ps_cpu = t.indices.tolist()
|
||||||
@@ -91,14 +112,24 @@ class LogitsProcessor(nn.Module):
|
|||||||
|
|
||||||
return prefill_top_logprobs, decode_top_logprobs
|
return prefill_top_logprobs, decode_top_logprobs
|
||||||
|
|
||||||
def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids,
|
||||||
|
hidden_states,
|
||||||
|
weight,
|
||||||
|
logits_metadata: Union[LogitsMetadata, InputMetadata],
|
||||||
|
):
|
||||||
|
if isinstance(logits_metadata, InputMetadata):
|
||||||
|
logits_metadata = LogitsMetadata.from_input_metadata(logits_metadata)
|
||||||
|
assert isinstance(logits_metadata, LogitsMetadata)
|
||||||
|
|
||||||
# Get the last hidden states and last logits for the next token prediction
|
# Get the last hidden states and last logits for the next token prediction
|
||||||
if input_metadata.forward_mode == ForwardMode.DECODE:
|
if logits_metadata.forward_mode == ForwardMode.DECODE:
|
||||||
last_index = None
|
last_index = None
|
||||||
last_hidden = hidden_states
|
last_hidden = hidden_states
|
||||||
else:
|
else:
|
||||||
last_index = (
|
last_index = (
|
||||||
torch.cumsum(input_metadata.extend_seq_lens, dim=0, dtype=torch.long)
|
torch.cumsum(logits_metadata.extend_seq_lens, dim=0, dtype=torch.long)
|
||||||
- 1
|
- 1
|
||||||
)
|
)
|
||||||
last_hidden = hidden_states[last_index]
|
last_hidden = hidden_states[last_index]
|
||||||
@@ -114,7 +145,7 @@ class LogitsProcessor(nn.Module):
|
|||||||
last_logits *= self.config.final_logit_softcapping
|
last_logits *= self.config.final_logit_softcapping
|
||||||
|
|
||||||
# Return only last_logits if logprob is not requested
|
# Return only last_logits if logprob is not requested
|
||||||
if not input_metadata.return_logprob:
|
if not logits_metadata.return_logprob:
|
||||||
return LogitProcessorOutput(
|
return LogitProcessorOutput(
|
||||||
next_token_logits=last_logits,
|
next_token_logits=last_logits,
|
||||||
next_token_logprobs=None,
|
next_token_logprobs=None,
|
||||||
@@ -125,7 +156,7 @@ class LogitsProcessor(nn.Module):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# When logprob is requested, compute the logits for all tokens.
|
# When logprob is requested, compute the logits for all tokens.
|
||||||
if input_metadata.forward_mode == ForwardMode.DECODE:
|
if logits_metadata.forward_mode == ForwardMode.DECODE:
|
||||||
all_logits = last_logits
|
all_logits = last_logits
|
||||||
else:
|
else:
|
||||||
all_logits = torch.matmul(hidden_states, weight.T)
|
all_logits = torch.matmul(hidden_states, weight.T)
|
||||||
@@ -138,15 +169,15 @@ class LogitsProcessor(nn.Module):
|
|||||||
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
|
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
|
||||||
|
|
||||||
# Get the logprob of top-k tokens
|
# Get the logprob of top-k tokens
|
||||||
return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
|
return_top_logprob = any(x > 0 for x in logits_metadata.top_logprobs_nums)
|
||||||
if return_top_logprob:
|
if return_top_logprob:
|
||||||
prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
|
prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
|
||||||
all_logprobs, input_metadata
|
all_logprobs, logits_metadata
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
prefill_top_logprobs = decode_top_logprobs = None
|
prefill_top_logprobs = decode_top_logprobs = None
|
||||||
|
|
||||||
if input_metadata.forward_mode == ForwardMode.DECODE:
|
if logits_metadata.forward_mode == ForwardMode.DECODE:
|
||||||
return LogitProcessorOutput(
|
return LogitProcessorOutput(
|
||||||
next_token_logits=last_logits,
|
next_token_logits=last_logits,
|
||||||
next_token_logprobs=all_logprobs,
|
next_token_logprobs=all_logprobs,
|
||||||
@@ -166,7 +197,7 @@ class LogitsProcessor(nn.Module):
|
|||||||
]
|
]
|
||||||
|
|
||||||
normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
|
normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
|
||||||
prefill_token_logprobs, input_metadata
|
prefill_token_logprobs, logits_metadata
|
||||||
)
|
)
|
||||||
|
|
||||||
return LogitProcessorOutput(
|
return LogitProcessorOutput(
|
||||||
|
|||||||
@@ -2,9 +2,8 @@
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
from flashinfer.cascade import merge_state
|
from flashinfer.cascade import merge_state
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.layers.extend_attention import extend_attention_fwd
|
from sglang.srt.layers.extend_attention import extend_attention_fwd
|
||||||
|
|||||||
@@ -334,15 +334,15 @@ class TokenizerManager:
|
|||||||
ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
|
ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
|
||||||
)
|
)
|
||||||
if top_logprobs_num > 0:
|
if top_logprobs_num > 0:
|
||||||
ret["meta_info"]["prefill_top_logprobs"] = (
|
ret["meta_info"][
|
||||||
self.detokenize_top_logprobs_tokens(
|
"prefill_top_logprobs"
|
||||||
ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
|
] = self.detokenize_top_logprobs_tokens(
|
||||||
)
|
ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
|
||||||
)
|
)
|
||||||
ret["meta_info"]["decode_top_logprobs"] = (
|
ret["meta_info"][
|
||||||
self.detokenize_top_logprobs_tokens(
|
"decode_top_logprobs"
|
||||||
ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
|
] = self.detokenize_top_logprobs_tokens(
|
||||||
)
|
ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
|
||||||
)
|
)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|||||||
@@ -81,7 +81,6 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
|||||||
|
|
||||||
|
|
||||||
class GemmaRotaryEmbedding(RotaryEmbedding):
|
class GemmaRotaryEmbedding(RotaryEmbedding):
|
||||||
|
|
||||||
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
||||||
# https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
|
# https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
|
||||||
inv_freq = 1.0 / (
|
inv_freq = 1.0 / (
|
||||||
@@ -95,7 +94,6 @@ class GemmaRotaryEmbedding(RotaryEmbedding):
|
|||||||
|
|
||||||
|
|
||||||
class Gemma2MLP(nn.Module):
|
class Gemma2MLP(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
@@ -127,7 +125,6 @@ class Gemma2MLP(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class Gemma2Attention(nn.Module):
|
class Gemma2Attention(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
layer_idx: int,
|
layer_idx: int,
|
||||||
@@ -218,7 +215,6 @@ class Gemma2Attention(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class Gemma2DecoderLayer(nn.Module):
|
class Gemma2DecoderLayer(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
layer_idx: int,
|
layer_idx: int,
|
||||||
@@ -287,7 +283,6 @@ class Gemma2DecoderLayer(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class Gemma2Model(nn.Module):
|
class Gemma2Model(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
|
|||||||
@@ -163,9 +163,9 @@ class LlamaDecoderLayer(nn.Module):
|
|||||||
if rope_scaling is not None and getattr(
|
if rope_scaling is not None and getattr(
|
||||||
config, "original_max_position_embeddings", None
|
config, "original_max_position_embeddings", None
|
||||||
):
|
):
|
||||||
rope_scaling["original_max_position_embeddings"] = (
|
rope_scaling[
|
||||||
config.original_max_position_embeddings
|
"original_max_position_embeddings"
|
||||||
)
|
] = config.original_max_position_embeddings
|
||||||
rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
|
rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
|
||||||
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
||||||
self.self_attn = LlamaAttention(
|
self.self_attn = LlamaAttention(
|
||||||
|
|||||||
@@ -459,6 +459,7 @@ def monkey_patch_vllm_p2p_access_check(gpu_id: int):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt
|
import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt
|
||||||
|
|
||||||
setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
|
setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user