From f25b76c02abbc2971b5e5532c0c49e960e662e23 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Mon, 8 Jul 2024 17:46:55 -0700 Subject: [PATCH] add `LogitsMetadata` (#604) --- benchmark/line_retrieval/gen_data.py | 6 +- python/sglang/srt/layers/logits_processor.py | 69 ++++++++++++++----- python/sglang/srt/layers/radix_attention.py | 3 +- .../sglang/srt/managers/tokenizer_manager.py | 16 ++--- python/sglang/srt/models/gemma2.py | 5 -- python/sglang/srt/models/llama2.py | 6 +- python/sglang/srt/utils.py | 1 + 7 files changed, 66 insertions(+), 40 deletions(-) diff --git a/benchmark/line_retrieval/gen_data.py b/benchmark/line_retrieval/gen_data.py index c88ecba49..5763e6615 100644 --- a/benchmark/line_retrieval/gen_data.py +++ b/benchmark/line_retrieval/gen_data.py @@ -48,9 +48,9 @@ def generate_lines(random_words, num_lines, redirect_ratio): ) for i in redirect_indices: target_idx = np.random.choice(min(i * 2 + 100, num_lines)) - lines[i] = ( - f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}." - ) + lines[ + i + ] = f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}." redirects[i] = target_idx # Build links and find sources diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index ce9415337..1ed7b8f7d 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -1,7 +1,7 @@ """Logits processing.""" import dataclasses -from typing import List +from typing import List, Union import torch from torch import nn @@ -31,6 +31,27 @@ class LogitProcessorOutput: decode_top_logprobs: List +@dataclasses.dataclass +class LogitsMetadata: + forward_mode: ForwardMode + extend_seq_lens: torch.Tensor + extend_start_loc: torch.Tensor + + # For logprobs + return_logprob: bool + top_logprobs_nums: List[int] + + @classmethod + def from_input_metadata(cls, input_metadata: InputMetadata): + return cls( + forward_mode=input_metadata.forward_mode, + extend_seq_lens=input_metadata.extend_seq_lens, + extend_start_loc=input_metadata.extend_start_loc, + return_logprob=input_metadata.return_logprob, + top_logprobs_nums=input_metadata.top_logprobs_nums, + ) + + class LogitsProcessor(nn.Module): def __init__(self, config): super().__init__() @@ -38,14 +59,14 @@ class LogitsProcessor(nn.Module): self.tp_size = get_tensor_model_parallel_world_size() def _get_normalized_prompt_logprobs( - self, prefill_token_logprobs, input_metadata: InputMetadata + self, prefill_token_logprobs, logits_metadata: LogitsMetadata ): logprobs_cumsum = torch.cumsum( prefill_token_logprobs, dim=0, dtype=torch.float32 ) - start = input_metadata.extend_start_loc.clone() - end = start + input_metadata.extend_seq_lens - 2 + start = logits_metadata.extend_start_loc.clone() + end = start + logits_metadata.extend_seq_lens - 2 start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1) end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1) sum_logp = ( @@ -54,17 +75,17 @@ class LogitsProcessor(nn.Module): + prefill_token_logprobs[start] ) normalized_prompt_logprobs = sum_logp / ( - (input_metadata.extend_seq_lens - 1).clamp(min=1) + (logits_metadata.extend_seq_lens - 1).clamp(min=1) ) return normalized_prompt_logprobs - def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata): + def _get_top_logprobs(self, all_logprobs, logits_metadata: LogitsMetadata): # TODO: vectorize the code below - if input_metadata.forward_mode == ForwardMode.DECODE: + if 
logits_metadata.forward_mode == ForwardMode.DECODE: decode_top_logprobs = [] for i in range(all_logprobs.shape[0]): - k = input_metadata.top_logprobs_nums[i] + k = logits_metadata.top_logprobs_nums[i] t = all_logprobs[i].topk(k) v_cpu = t.values.tolist() p_cpu = t.indices.tolist() @@ -73,13 +94,13 @@ class LogitsProcessor(nn.Module): else: prefill_top_logprobs, decode_top_logprobs = [], [] pt = 0 - extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist() + extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist() for i, extend_seq_len in enumerate(extend_seq_lens_cpu): if extend_seq_len == 0: prefill_top_logprobs.append([]) decode_top_logprobs.append([]) continue - k = input_metadata.top_logprobs_nums[i] + k = logits_metadata.top_logprobs_nums[i] t = all_logprobs[pt : pt + extend_seq_len].topk(k) vs_cpu = t.values.tolist() ps_cpu = t.indices.tolist() @@ -91,14 +112,24 @@ class LogitsProcessor(nn.Module): return prefill_top_logprobs, decode_top_logprobs - def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata): + def forward( + self, + input_ids, + hidden_states, + weight, + logits_metadata: Union[LogitsMetadata, InputMetadata], + ): + if isinstance(logits_metadata, InputMetadata): + logits_metadata = LogitsMetadata.from_input_metadata(logits_metadata) + assert isinstance(logits_metadata, LogitsMetadata) + # Get the last hidden states and last logits for the next token prediction - if input_metadata.forward_mode == ForwardMode.DECODE: + if logits_metadata.forward_mode == ForwardMode.DECODE: last_index = None last_hidden = hidden_states else: last_index = ( - torch.cumsum(input_metadata.extend_seq_lens, dim=0, dtype=torch.long) + torch.cumsum(logits_metadata.extend_seq_lens, dim=0, dtype=torch.long) - 1 ) last_hidden = hidden_states[last_index] @@ -114,7 +145,7 @@ class LogitsProcessor(nn.Module): last_logits *= self.config.final_logit_softcapping # Return only last_logits if logprob is not requested - if not input_metadata.return_logprob: + if not logits_metadata.return_logprob: return LogitProcessorOutput( next_token_logits=last_logits, next_token_logprobs=None, @@ -125,7 +156,7 @@ class LogitsProcessor(nn.Module): ) else: # When logprob is requested, compute the logits for all tokens. 
- if input_metadata.forward_mode == ForwardMode.DECODE: + if logits_metadata.forward_mode == ForwardMode.DECODE: all_logits = last_logits else: all_logits = torch.matmul(hidden_states, weight.T) @@ -138,15 +169,15 @@ class LogitsProcessor(nn.Module): all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1) # Get the logprob of top-k tokens - return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums) + return_top_logprob = any(x > 0 for x in logits_metadata.top_logprobs_nums) if return_top_logprob: prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs( - all_logprobs, input_metadata + all_logprobs, logits_metadata ) else: prefill_top_logprobs = decode_top_logprobs = None - if input_metadata.forward_mode == ForwardMode.DECODE: + if logits_metadata.forward_mode == ForwardMode.DECODE: return LogitProcessorOutput( next_token_logits=last_logits, next_token_logprobs=all_logprobs, @@ -166,7 +197,7 @@ class LogitsProcessor(nn.Module): ] normalized_prompt_logprobs = self._get_normalized_prompt_logprobs( - prefill_token_logprobs, input_metadata + prefill_token_logprobs, logits_metadata ) return LogitProcessorOutput( diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 6ee4a31a1..a2d96e9d2 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -2,9 +2,8 @@ import numpy as np import torch -from torch import nn - from flashinfer.cascade import merge_state +from torch import nn from sglang.global_config import global_config from sglang.srt.layers.extend_attention import extend_attention_fwd diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 3f3c848e0..63cecdca3 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -334,15 +334,15 @@ class TokenizerManager: ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs ) if top_logprobs_num > 0: - ret["meta_info"]["prefill_top_logprobs"] = ( - self.detokenize_top_logprobs_tokens( - ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs - ) + ret["meta_info"][ + "prefill_top_logprobs" + ] = self.detokenize_top_logprobs_tokens( + ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs ) - ret["meta_info"]["decode_top_logprobs"] = ( - self.detokenize_top_logprobs_tokens( - ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs - ) + ret["meta_info"][ + "decode_top_logprobs" + ] = self.detokenize_top_logprobs_tokens( + ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs ) return ret diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index 4593a5731..c6c409dee 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -81,7 +81,6 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding class GemmaRotaryEmbedding(RotaryEmbedding): - def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107 inv_freq = 1.0 / ( @@ -95,7 +94,6 @@ class GemmaRotaryEmbedding(RotaryEmbedding): class Gemma2MLP(nn.Module): - def __init__( self, hidden_size: int, @@ -127,7 +125,6 @@ class Gemma2MLP(nn.Module): class Gemma2Attention(nn.Module): - def __init__( self, layer_idx: int, @@ -218,7 +215,6 @@ class Gemma2Attention(nn.Module): class 
Gemma2DecoderLayer(nn.Module): - def __init__( self, layer_idx: int, @@ -287,7 +283,6 @@ class Gemma2DecoderLayer(nn.Module): class Gemma2Model(nn.Module): - def __init__( self, config: PretrainedConfig, diff --git a/python/sglang/srt/models/llama2.py b/python/sglang/srt/models/llama2.py index 7bdab0f5d..95ba71ee9 100644 --- a/python/sglang/srt/models/llama2.py +++ b/python/sglang/srt/models/llama2.py @@ -163,9 +163,9 @@ class LlamaDecoderLayer(nn.Module): if rope_scaling is not None and getattr( config, "original_max_position_embeddings", None ): - rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings - ) + rope_scaling[ + "original_max_position_embeddings" + ] = config.original_max_position_embeddings rope_is_neox_style = getattr(config, "rope_is_neox_style", True) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = LlamaAttention( diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index e6b7b7663..732a2c71e 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -459,6 +459,7 @@ def monkey_patch_vllm_p2p_access_check(gpu_id: int): """ import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt + setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
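
The core of this patch is the new `LogitsMetadata` dataclass: it narrows what `LogitsProcessor` depends on from the full `InputMetadata` down to the five fields that logits processing actually reads, and `forward()` converts an incoming `InputMetadata` eagerly so existing callers keep working. Below is a minimal, dependency-free sketch of that adapter pattern, runnable on its own. The class and field names mirror the patch, but `FakeInputMetadata` is a hypothetical stand-in for sglang's real `InputMetadata`, and the tensor-typed fields are simplified to plain lists so the example runs without torch.

import dataclasses
from enum import Enum, auto
from typing import List


class ForwardMode(Enum):
    # Mirrors sglang's ForwardMode; only DECODE matters for this sketch.
    EXTEND = auto()
    DECODE = auto()


@dataclasses.dataclass
class FakeInputMetadata:
    # Hypothetical stand-in: the real InputMetadata also carries KV-cache
    # indices, positions, etc. that logits processing never touches.
    forward_mode: ForwardMode
    extend_seq_lens: List[int]
    extend_start_loc: List[int]
    return_logprob: bool
    top_logprobs_nums: List[int]


@dataclasses.dataclass
class LogitsMetadata:
    # Same fields as in the patch, with torch.Tensor simplified to List[int].
    forward_mode: ForwardMode
    extend_seq_lens: List[int]
    extend_start_loc: List[int]
    return_logprob: bool
    top_logprobs_nums: List[int]

    @classmethod
    def from_input_metadata(cls, input_metadata) -> "LogitsMetadata":
        # Copy only what logits processing needs.
        return cls(
            forward_mode=input_metadata.forward_mode,
            extend_seq_lens=input_metadata.extend_seq_lens,
            extend_start_loc=input_metadata.extend_start_loc,
            return_logprob=input_metadata.return_logprob,
            top_logprobs_nums=input_metadata.top_logprobs_nums,
        )


def forward(logits_metadata):
    # Same normalization step as the patched LogitsProcessor.forward():
    # accept either type, convert eagerly, then use only LogitsMetadata.
    if isinstance(logits_metadata, FakeInputMetadata):
        logits_metadata = LogitsMetadata.from_input_metadata(logits_metadata)
    assert isinstance(logits_metadata, LogitsMetadata)
    return logits_metadata


im = FakeInputMetadata(ForwardMode.DECODE, [], [], True, [5])
print(forward(im))  # old-style call passing an InputMetadata still works
print(forward(LogitsMetadata.from_input_metadata(im)))  # new-style call

Accepting `Union[LogitsMetadata, InputMetadata]` and converting at the top of `forward()` is what lets the model files in this patch (llama2.py, gemma2.py) leave their call sites untouched, while future callers can build a `LogitsMetadata` directly without constructing a full `InputMetadata`.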