From 3815b23ccb3d3a54cad705123da2f89aafdde0d2 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 29 Dec 2024 00:45:57 -0800 Subject: [PATCH] Clean up wrapper in flashinfer backend (#2638) --- python/sglang/bench_offline_throughput.py | 1 + python/sglang/srt/configs/model_config.py | 4 +- .../sglang/srt/layers/attention/__init__.py | 1 - .../layers/attention/flashinfer_backend.py | 95 +++++++++++-------- python/sglang/srt/layers/logits_processor.py | 32 ++++++- python/sglang/srt/managers/schedule_batch.py | 4 +- python/sglang/srt/managers/scheduler.py | 28 ++++-- .../srt/model_executor/forward_batch_info.py | 45 ++++++++- python/sglang/srt/models/llama.py | 11 +++ python/sglang/srt/server.py | 4 +- python/sglang/srt/server_args.py | 47 ++++----- python/sglang/srt/speculative/spec_info.py | 19 ++++ 12 files changed, 197 insertions(+), 94 deletions(-) create mode 100644 python/sglang/srt/speculative/spec_info.py diff --git a/python/sglang/bench_offline_throughput.py b/python/sglang/bench_offline_throughput.py index f840ee878..f32063b41 100644 --- a/python/sglang/bench_offline_throughput.py +++ b/python/sglang/bench_offline_throughput.py @@ -331,6 +331,7 @@ def throughput_test( extra_request_body=extra_request_body, profile=bench_args.profile, ) + backend.shutdown() if bench_args.result_filename: with open(bench_args.result_filename, "a") as fout: diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index c3f6ba993..a2f9b8284 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -131,10 +131,8 @@ class ModelConfig: # Veirfy quantization self._verify_quantization() - # Text attrs + # Cache attributes self.hf_eos_token_id = self.get_hf_eos_token_id() - - # Multimodel attrs self.image_token_id = getattr(self.hf_config, "image_token_id", None) # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289 diff --git 
a/python/sglang/srt/layers/attention/__init__.py b/python/sglang/srt/layers/attention/__init__.py index a70e9537b..1486987dc 100644 --- a/python/sglang/srt/layers/attention/__init__.py +++ b/python/sglang/srt/layers/attention/__init__.py @@ -2,7 +2,6 @@ from abc import ABC, abstractmethod from typing import Optional import torch -from torch import nn from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.forward_batch_info import ForwardBatch diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 926874027..db1e17b5e 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -8,8 +8,9 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) an """ import os +from dataclasses import dataclass from enum import Enum, auto -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Union import torch import triton @@ -38,12 +39,25 @@ class WrapperDispatch(Enum): CROSS_ATTENTION = auto() +@dataclass +class DecodeMetadata: + decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper] + + +@dataclass +class PrefillMetadata: + prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper] + use_ragged: bool + extend_no_prefix: bool + + class FlashInferAttnBackend(AttentionBackend): """Flashinfer attention kernels.""" def __init__(self, model_runner: ModelRunner): super().__init__() + # Parse constants self.decode_use_tensor_cores = should_use_tensor_core( kv_cache_dtype=model_runner.kv_cache_dtype, num_attention_heads=model_runner.model_config.num_attention_heads @@ -52,7 +66,6 @@ class FlashInferAttnBackend(AttentionBackend): model_runner.tp_size ), ) - self.max_context_len = model_runner.model_config.context_len assert not ( @@ -120,8 +133,8 @@ class FlashInferAttnBackend(AttentionBackend): ) # Other metadata - self.forward_metadata = None - 
self.cuda_graph_metadata = {} + self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None + self.decode_cuda_graph_metadata = {} def init_forward_metadata(self, forward_batch: ForwardBatch): if forward_batch.forward_mode.is_decode(): @@ -129,10 +142,10 @@ class FlashInferAttnBackend(AttentionBackend): forward_batch.req_pool_indices, forward_batch.seq_lens, forward_batch.seq_lens_sum, - decode_wrappers=None, + decode_wrappers=self.decode_wrappers, encoder_lens=forward_batch.encoder_lens, ) - self.forward_metadata = (self.decode_wrappers,) + self.forward_metadata = DecodeMetadata(self.decode_wrappers) else: prefix_lens = forward_batch.extend_prefix_lens @@ -149,11 +162,13 @@ class FlashInferAttnBackend(AttentionBackend): forward_batch.seq_lens, forward_batch.seq_lens_sum, prefix_lens, + prefill_wrappers=self.prefill_wrappers_paged, use_ragged=use_ragged, encoder_lens=forward_batch.encoder_lens, ) - - self.forward_metadata = (use_ragged, extend_no_prefix) + self.forward_metadata = PrefillMetadata( + self.prefill_wrappers_paged, use_ragged, extend_no_prefix + ) def init_cuda_graph_state(self, max_bs: int): cuda_graph_kv_indices = torch.zeros( @@ -194,8 +209,8 @@ class FlashInferAttnBackend(AttentionBackend): decode_wrappers=decode_wrappers, encoder_lens=encoder_lens, ) - self.cuda_graph_metadata[bs] = decode_wrappers - self.forward_metadata = (decode_wrappers,) + self.decode_cuda_graph_metadata[bs] = decode_wrappers + self.forward_metadata = DecodeMetadata(decode_wrappers) def init_forward_metadata_replay_cuda_graph( self, @@ -209,7 +224,7 @@ class FlashInferAttnBackend(AttentionBackend): req_pool_indices[:bs], seq_lens[:bs], seq_lens_sum, - decode_wrappers=self.cuda_graph_metadata[bs], + decode_wrappers=self.decode_cuda_graph_metadata[bs], encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None, ) @@ -225,18 +240,16 @@ class FlashInferAttnBackend(AttentionBackend): forward_batch: ForwardBatch, save_kv_cache=True, ): - prefill_wrapper_paged = 
self.prefill_wrappers_paged[ + prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[ self._get_wrapper_idx(layer) ] - - use_ragged, extend_no_prefix = self.forward_metadata cache_loc = ( forward_batch.out_cache_loc if not layer.is_cross_attention else forward_batch.encoder_out_cache_loc ) - if not use_ragged: + if not self.forward_metadata.use_ragged: if k is not None: assert v is not None if save_kv_cache: @@ -260,7 +273,7 @@ class FlashInferAttnBackend(AttentionBackend): logits_soft_cap=layer.logit_cap, ) - if extend_no_prefix: + if self.forward_metadata.extend_no_prefix: o = o1 else: o2, s2 = prefill_wrapper_paged.forward_return_lse( @@ -287,7 +300,9 @@ class FlashInferAttnBackend(AttentionBackend): forward_batch: ForwardBatch, save_kv_cache=True, ): - decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)] + decode_wrapper = self.forward_metadata.decode_wrappers[ + self._get_wrapper_idx(layer) + ] cache_loc = ( forward_batch.out_cache_loc if not layer.is_cross_attention @@ -322,7 +337,7 @@ class FlashInferAttnBackend(AttentionBackend): class FlashInferIndicesUpdaterDecode: def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend): - # Constants + # Parse Constants self.num_qo_heads = ( model_runner.model_config.num_attention_heads // model_runner.tp_size ) @@ -340,9 +355,8 @@ class FlashInferIndicesUpdaterDecode: self.kv_indptr = attn_backend.kv_indptr self.kv_last_page_len = attn_backend.kv_last_page_len self.req_to_token = model_runner.req_to_token_pool.req_to_token - self.decode_wrappers = attn_backend.decode_wrappers - # Dispatch + # Dispatch the update function if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW: self.update = self.update_sliding_window elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION: @@ -356,7 +370,7 @@ class FlashInferIndicesUpdaterDecode: req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, - decode_wrappers: List, + 
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper], encoder_lens: torch.Tensor, ): # Keep the signature for type checking. It will be assigned during runtime. @@ -367,7 +381,7 @@ class FlashInferIndicesUpdaterDecode: req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, - decode_wrappers: List, + decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper], encoder_lens: torch.Tensor, ): decode_wrappers = decode_wrappers or self.decode_wrappers @@ -385,11 +399,9 @@ class FlashInferIndicesUpdaterDecode: req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, - decode_wrappers: List, + decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper], encoder_lens: torch.Tensor, ): - decode_wrappers = decode_wrappers or self.decode_wrappers - for wrapper_id in range(2): if wrapper_id == 0: # Sliding window attention @@ -419,11 +431,9 @@ class FlashInferIndicesUpdaterDecode: req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, - decode_wrappers: List, + decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper], encoder_lens: torch.Tensor, ): - decode_wrappers = decode_wrappers or self.decode_wrappers - for wrapper_id in range(2): if wrapper_id == 0: # Normal attention @@ -446,7 +456,7 @@ class FlashInferIndicesUpdaterDecode: def call_begin_forward( self, - wrapper, + wrapper: BatchDecodeWithPagedKVCacheWrapper, req_pool_indices: torch.Tensor, paged_kernel_lens: torch.Tensor, paged_kernel_lens_sum: int, @@ -486,7 +496,7 @@ class FlashInferIndicesUpdaterDecode: class FlashInferIndicesUpdaterPrefill: def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend): - # Constants + # Parse Constants self.num_qo_heads = ( model_runner.model_config.num_attention_heads // model_runner.tp_size ) @@ -505,10 +515,9 @@ class FlashInferIndicesUpdaterPrefill: self.kv_last_page_len = attn_backend.kv_last_page_len self.qo_indptr = attn_backend.qo_indptr self.req_to_token = 
model_runner.req_to_token_pool.req_to_token - self.wrapper_ragged = attn_backend.prefill_wrapper_ragged - self.wrappers_paged = attn_backend.prefill_wrappers_paged + self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged - # Dispatch + # Dispatch the update function if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW: self.update = self.update_sliding_window elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION: @@ -523,6 +532,7 @@ class FlashInferIndicesUpdaterPrefill: seq_lens: torch.Tensor, seq_lens_sum: int, prefix_lens: torch.Tensor, + prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], use_ragged: bool, encoder_lens: torch.Tensor, ): @@ -535,6 +545,7 @@ class FlashInferIndicesUpdaterPrefill: seq_lens: torch.Tensor, seq_lens_sum: int, prefix_lens: torch.Tensor, + prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], use_ragged: bool, encoder_lens: torch.Tensor, ): @@ -546,8 +557,8 @@ class FlashInferIndicesUpdaterPrefill: paged_kernel_lens_sum = seq_lens_sum self.call_begin_forward( - self.wrapper_ragged, - self.wrappers_paged[0], + self.prefill_wrapper_ragged, + prefill_wrappers[0], req_pool_indices, paged_kernel_lens, paged_kernel_lens_sum, @@ -565,6 +576,7 @@ class FlashInferIndicesUpdaterPrefill: seq_lens: torch.Tensor, seq_lens_sum: int, prefix_lens: torch.Tensor, + prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], use_ragged: bool, encoder_lens: torch.Tensor, ): @@ -584,8 +596,8 @@ class FlashInferIndicesUpdaterPrefill: kv_start_idx = seq_lens - paged_kernel_lens self.call_begin_forward( - self.wrapper_ragged, - self.wrappers_paged[wrapper_id], + self.prefill_wrapper_ragged, + prefill_wrappers[wrapper_id], req_pool_indices, paged_kernel_lens, paged_kernel_lens_sum, @@ -603,6 +615,7 @@ class FlashInferIndicesUpdaterPrefill: seq_lens: torch.Tensor, seq_lens_sum: int, prefix_lens: torch.Tensor, + prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], use_ragged: bool, 
encoder_lens: torch.Tensor, ): @@ -619,8 +632,8 @@ class FlashInferIndicesUpdaterPrefill: paged_kernel_lens_sum = paged_kernel_lens.sum().item() self.call_begin_forward( - self.wrapper_ragged, - self.wrappers_paged[wrapper_id], + self.prefill_wrapper_ragged, + prefill_wrappers[wrapper_id], req_pool_indices, paged_kernel_lens, paged_kernel_lens_sum, @@ -634,8 +647,8 @@ class FlashInferIndicesUpdaterPrefill: def call_begin_forward( self, - wrapper_ragged, - wrapper_paged, + wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper, + wrapper_paged: BatchPrefillWithPagedKVCacheWrapper, req_pool_indices: torch.Tensor, paged_kernel_lens: torch.Tensor, paged_kernel_lens_sum: int, diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 5bb52f5bb..31820d37a 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -24,7 +24,11 @@ from vllm.distributed import ( ) from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding -from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.model_executor.forward_batch_info import ( + CaptureHiddenMode, + ForwardBatch, + ForwardMode, +) @dataclasses.dataclass @@ -46,6 +50,10 @@ class LogitsProcessorOutput: output_top_logprobs_val: List = None output_top_logprobs_idx: List = None + # Used by speculative decoding (EAGLE) + # The output of transformer layers + hidden_states: Optional[torch.Tensor] = None + @dataclasses.dataclass class LogitsMetadata: @@ -61,6 +69,8 @@ class LogitsMetadata: extend_logprob_start_lens_cpu: Optional[List[int]] = None extend_logprob_pruned_lens_cpu: Optional[List[int]] = None + capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL + @classmethod def from_forward_batch(cls, forward_batch: ForwardBatch): extend_logprob_pruned_lens_cpu = None @@ -78,6 +88,11 @@ class LogitsMetadata: else: return_top_logprob = False + if 
forward_batch.spec_info: + capture_hidden_mode = forward_batch.spec_info.capture_hidden_mode + else: + capture_hidden_mode = CaptureHiddenMode.NULL + return cls( forward_mode=forward_batch.forward_mode, top_logprobs_nums=forward_batch.top_logprobs_nums, @@ -87,6 +102,7 @@ class LogitsMetadata: extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu, extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu, extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu, + capture_hidden_mode=capture_hidden_mode, ) @@ -116,7 +132,10 @@ class LogitsProcessor(nn.Module): assert isinstance(logits_metadata, LogitsMetadata) # Get the last hidden states and last logits for the next token prediction - if logits_metadata.forward_mode.is_decode(): + if ( + logits_metadata.forward_mode.is_decode() + or logits_metadata.forward_mode.is_target_verify() + ): last_index = None last_hidden = hidden_states else: @@ -137,6 +156,15 @@ class LogitsProcessor(nn.Module): if not logits_metadata.return_logprob: return LogitsProcessorOutput( next_token_logits=last_logits, + hidden_states=( + hidden_states + if logits_metadata.capture_hidden_mode.is_full() + else ( + last_hidden + if logits_metadata.capture_hidden_mode.is_last() + else None + ) + ), ) else: last_logprobs = self.compute_temp_top_p_normalized_logprobs( diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index ee2884df8..b78d205f2 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -843,8 +843,8 @@ class ScheduleBatch: # TODO (lianmin): Revisit this. 
It should be seq_len - 1 self.extend_logprob_start_lens.extend([0] * running_bs) - def check_decode_mem(self): - bs = len(self.reqs) + def check_decode_mem(self, buf_multiplier=1): + bs = len(self.reqs) * buf_multiplier if self.token_to_kv_pool.available_size() >= bs: return True diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index c70c61e4c..fe7bf0198 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -90,7 +90,7 @@ from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) -# Test retract decode +# Test retract decode for debugging purposes test_retract = get_bool_env_var("SGLANG_TEST_RETRACT") @@ -129,12 +129,12 @@ class Scheduler: ) if server_args.skip_tokenizer_init: - # Directly send to the tokenizer/api + # Directly send to the TokenizerManager self.send_to_detokenizer = get_zmq_socket( context, zmq.PUSH, port_args.tokenizer_ipc_name ) else: - # Send to the detokenizer + # Send to the DetokenizerManager self.send_to_detokenizer = get_zmq_socket( context, zmq.PUSH, port_args.detokenizer_ipc_name ) @@ -385,7 +385,8 @@ class Scheduler: self.process_input_requests(recv_reqs) batch = self.get_next_batch_to_run() - if self.server_args.enable_dp_attention: + + if self.server_args.enable_dp_attention: # TODO: simplify this batch = self.prepare_dp_attn_batch(batch) self.cur_batch = batch @@ -394,7 +395,7 @@ class Scheduler: result = self.run_batch(batch) self.process_batch_result(batch, result) else: - # Self-check and re-init some states when the server is idle + # When the server is idle, self-check and re-init some states self.check_memory() self.new_token_ratio = self.init_new_token_ratio @@ -411,12 +412,13 @@ class Scheduler: batch = self.get_next_batch_to_run() self.cur_batch = batch + if batch: result = self.run_batch(batch) result_queue.append((batch.copy(), result)) if self.last_batch is None: - # A dummy first batch to start the 
pipeline for overlap scheduler. + # Create a dummy first batch to start the pipeline for overlap scheduler. # It is now used for triggering the sampling_info_done event. tmp_batch = ScheduleBatch( reqs=None, @@ -426,19 +428,21 @@ class Scheduler: self.process_batch_result(tmp_batch, None) if self.last_batch: + # Process the results of the last batch tmp_batch, tmp_result = result_queue.popleft() tmp_batch.next_batch_sampling_info = ( self.tp_worker.cur_sampling_info if batch else None ) self.process_batch_result(tmp_batch, tmp_result) elif batch is None: - # Self-check and re-init some states when the server is idle + # When the server is idle, self-check and re-init some states self.check_memory() self.new_token_ratio = self.init_new_token_ratio self.last_batch = batch - def recv_requests(self): + def recv_requests(self) -> List[Req]: + """Receive requests at tp_rank = 0 and broadcast them to all other TP ranks.""" if self.tp_rank == 0 or self.server_args.enable_dp_attention: recv_reqs = [] @@ -812,6 +816,8 @@ class Scheduler: if res == AddReqResult.NO_TOKEN: self.batch_is_full = True break + if self.server_args.prefill_only_one_req: + break # Update waiting queue can_run_list = adder.can_run_list @@ -1528,18 +1534,20 @@ def run_scheduler_process( if dp_rank is None and "SGLANG_DP_RANK" in os.environ: dp_rank = int(os.environ["SGLANG_DP_RANK"]) + # Configure the logger if dp_rank is None: configure_logger(server_args, prefix=f" TP{tp_rank}") else: configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}") + suppress_other_loggers() - # set cpu affinity to this gpu process + # Set cpu affinity to this gpu process if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"): set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id) - suppress_other_loggers() parent_process = psutil.Process().parent() + # Create a scheduler and run the event loop try: scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank) pipe_writer.send( diff --git 
a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 3a5519956..4f77c8f79 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -45,6 +45,7 @@ if TYPE_CHECKING: from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo + from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm class ForwardMode(IntEnum): @@ -59,6 +60,11 @@ class ForwardMode(IntEnum): # No sequence to forward. For data parallel attention, some workers wil be IDLE if no sequence are allocated. IDLE = auto() + # Used in speculative decoding: verify a batch in the target model. + TARGET_VERIFY = auto() + # Used in speculative decoding: extend a batch in the draft model. + DRAFT_EXTEND = auto() + # A dummy first batch to start the pipeline for overlap scheduler. # It is now used for triggering the sampling_info_done event for the first prefill batch. 
DUMMY_FIRST = auto() @@ -67,7 +73,12 @@ class ForwardMode(IntEnum): return self == ForwardMode.PREFILL def is_extend(self): - return self == ForwardMode.EXTEND or self == ForwardMode.MIXED + return ( + self == ForwardMode.EXTEND + or self == ForwardMode.MIXED + or self == ForwardMode.DRAFT_EXTEND + or self == self.TARGET_VERIFY + ) def is_decode(self): return self == ForwardMode.DECODE @@ -78,6 +89,15 @@ class ForwardMode(IntEnum): def is_idle(self): return self == ForwardMode.IDLE + def is_target_verify(self): + return self == ForwardMode.TARGET_VERIFY + + def is_draft_extend(self): + return self == ForwardMode.DRAFT_EXTEND + + def is_cuda_graph(self): + return self in (ForwardMode.DECODE, ForwardMode.TARGET_VERIFY) + def is_dummy_first(self): return self == ForwardMode.DUMMY_FIRST @@ -141,14 +161,18 @@ class ForwardBatch: token_to_kv_pool: BaseTokenToKVPool = None attn_backend: AttentionBackend = None - # For Qwen2-VL - mrope_positions: torch.Tensor = None + # Speculative decoding + spec_info: SpecInfo = None + spec_algorithm: SpeculativeAlgorithm = None # For DP attention global_num_tokens: Optional[List[int]] = None gathered_buffer: Optional[torch.Tensor] = None can_run_dp_cuda_graph: bool = False + # For Qwen2-VL + mrope_positions: torch.Tensor = None + def compute_mrope_positions( self, model_runner: ModelRunner, batch: ModelWorkerBatch ): @@ -351,3 +375,18 @@ def compute_position_torch( extend_start_loc = torch.zeros_like(extend_seq_lens) extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0) return positions.to(torch.int64), extend_start_loc + + +class CaptureHiddenMode(IntEnum): + NULL = auto() + FULL = auto() + LAST = auto() + + def need_capture(self): + return self != CaptureHiddenMode.NULL + + def is_full(self): + return self == CaptureHiddenMode.FULL + + def is_last(self): + return self == CaptureHiddenMode.LAST diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 0cf0b344e..c06637962 100644 --- 
a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -516,6 +516,17 @@ class LlamaForCausalLM(nn.Module): ) return None + def get_embed_and_head(self): + return self.model.embed_tokens.weight, self.lm_head.weight + + def set_embed_and_head(self, embed, head): + del self.model.embed_tokens.weight + del self.lm_head.weight + self.model.embed_tokens.weight = embed + self.lm_head.weight = head + torch.cuda.empty_cache() + torch.cuda.synchronize() + class Phi3ForCausalLM(LlamaForCausalLM): pass diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 92df9b8bf..a0d07ca44 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -503,7 +503,7 @@ def launch_engine( ) scheduler_infos.append(data) - # Assume all schedulers have same max_total_num_tokens + # Assume all schedulers have same scheduler_info scheduler_info = scheduler_infos[0] @@ -890,7 +890,7 @@ class Runtime: using the commond line interface. It is mainly used for the frontend language. - You should use the Engine class if you want to do normal offline processing. + You should use the Engine class above if you want to do normal offline processing. 
""" def __init__( diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 4c751809a..23beb3eb8 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -55,7 +55,7 @@ class ServerArgs: is_embedding: bool = False revision: Optional[str] = None - # Port + # Port for the HTTP server host: str = "127.0.0.1" port: int = 30000 @@ -68,6 +68,7 @@ class ServerArgs: schedule_policy: str = "lpm" schedule_conservativeness: float = 1.0 cpu_offload_gb: int = 0 + prefill_only_one_req: bool = False # Other runtime options tp_size: int = 1 @@ -94,6 +95,7 @@ class ServerArgs: # Data parallelism dp_size: int = 1 load_balance_method: str = "round_robin" + # Expert parallelism ep_size: int = 1 @@ -217,6 +219,13 @@ class ServerArgs: ) self.disable_cuda_graph = True + # Expert parallelism + if self.enable_ep_moe: + self.ep_size = self.tp_size + logger.info( + f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]." + ) + # Others if self.enable_dp_attention: self.dp_size = self.tp_size @@ -229,12 +238,6 @@ class ServerArgs: "Data parallel size is adjusted to be the same as tensor parallel size. " "Overlap scheduler is disabled." ) - # Expert parallelism - if self.enable_ep_moe: - self.ep_size = self.tp_size - logger.info( - f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]." - ) # GGUF if ( @@ -430,13 +433,18 @@ class ServerArgs: default=ServerArgs.schedule_conservativeness, help="How conservative the schedule policy is. A larger value means more conservative scheduling. 
Use a larger value if you see requests being retracted frequently.", ) - parser.add_argument( "--cpu-offload-gb", type=int, default=ServerArgs.cpu_offload_gb, help="How many GBs of RAM to reserve for CPU offloading", ) + parser.add_argument( + "--prefill-only-one-req", + type=bool, + help="If true, we only prefill one request at one prefill batch", + default=ServerArgs.prefill_only_one_req, + ) # Other runtime options parser.add_argument( @@ -555,6 +563,7 @@ class ServerArgs: "shortest_queue", ], ) + # Expert parallelism parser.add_argument( "--expert-parallel-size", @@ -777,28 +786,6 @@ class ServerArgs: help="Delete the model checkpoint after loading the model.", ) - # Deprecated arguments - parser.add_argument( - "--enable-overlap-schedule", - action=DeprecatedAction, - help="'--enable-overlap-schedule' is deprecated. It is enabled by default now. Please drop this argument.", - ) - parser.add_argument( - "--disable-flashinfer", - action=DeprecatedAction, - help="'--disable-flashinfer' is deprecated. Please use '--attention-backend triton' instead.", - ) - parser.add_argument( - "--disable-flashinfer-sampling", - action=DeprecatedAction, - help="'--disable-flashinfer-sampling' is deprecated. Please use '--sampling-backend pytroch' instead.", - ) - parser.add_argument( - "--disable-disk-cache", - action=DeprecatedAction, - help="'--disable-disk-cache' is deprecated. 
Please use '--disable-outlines-disk-cache' instead.", - ) - @classmethod def from_cli_args(cls, args: argparse.Namespace): args.tp_size = args.tensor_parallel_size diff --git a/python/sglang/srt/speculative/spec_info.py b/python/sglang/srt/speculative/spec_info.py new file mode 100644 index 000000000..3b306e985 --- /dev/null +++ b/python/sglang/srt/speculative/spec_info.py @@ -0,0 +1,19 @@ +from enum import IntEnum, auto + + +class SpeculativeAlgorithm(IntEnum): + EAGLE = auto() + + def is_eagle(self): + return self == SpeculativeAlgorithm.EAGLE + + @staticmethod + def from_string(name: str): + name_map = { + "EAGLE": SpeculativeAlgorithm.EAGLE, + } + return name_map[name] + + +class SpecInfo: + pass