diff --git a/README.md b/README.md index b41389671..538185cee 100644 --- a/README.md +++ b/README.md @@ -621,7 +621,6 @@ Please cite our paper, [SGLang: Efficient Execution of Structured Language Model We also learned from the design and reused code from the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), and [LMQL](https://github.com/eth-sri/lmql). -

Back To Top diff --git a/python/sglang/srt/layers/attention/__init__.py b/python/sglang/srt/layers/attention/__init__.py index f6d10170c..ae0ef6b7d 100644 --- a/python/sglang/srt/layers/attention/__init__.py +++ b/python/sglang/srt/layers/attention/__init__.py @@ -25,7 +25,11 @@ class AttentionBackend(ABC): raise NotImplementedError() def init_forward_metadata_replay_cuda_graph( - self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, ): """Init the metadata for a forward pass for replying a cuda graph.""" raise NotImplementedError() diff --git a/python/sglang/srt/layers/attention/double_sparsity_backend.py b/python/sglang/srt/layers/attention/double_sparsity_backend.py index e2cd98ec2..c83fba814 100644 --- a/python/sglang/srt/layers/attention/double_sparsity_backend.py +++ b/python/sglang/srt/layers/attention/double_sparsity_backend.py @@ -144,7 +144,11 @@ class DoubleSparseAttnBackend(AttentionBackend): ) def init_forward_metadata_replay_cuda_graph( - self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, ): self.cuda_graph_start_loc.zero_() self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index c2cfa5fb6..cd4aec859 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -127,6 +127,7 @@ class FlashInferAttnBackend(AttentionBackend): self.indices_updater_decode.update( forward_batch.req_pool_indices, forward_batch.seq_lens, + forward_batch.seq_lens_sum, ) self.forward_metadata = (self.decode_wrappers,) else: @@ -134,10 +135,7 @@ class FlashInferAttnBackend(AttentionBackend): # Some heuristics to check whether to use ragged forward use_ragged = False - if ( - torch.sum(forward_batch.seq_lens).item() >= 4096 - and self.num_wrappers == 1 - ): + if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1: use_ragged = True extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item() @@ -181,15 +179,25 @@ class FlashInferAttnBackend(AttentionBackend): ) ) - self.indices_updater_decode.update(req_pool_indices, seq_lens, decode_wrappers) + seq_lens_sum = seq_lens.sum().item() + self.indices_updater_decode.update( + req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers + ) self.cuda_graph_metadata[bs] = decode_wrappers self.forward_metadata = (decode_wrappers,) def init_forward_metadata_replay_cuda_graph( - self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, ): self.indices_updater_decode.update( - req_pool_indices[:bs], seq_lens[:bs], self.cuda_graph_metadata[bs] + req_pool_indices[:bs], + seq_lens[:bs], + seq_lens_sum, + self.cuda_graph_metadata[bs], ) def get_cuda_graph_seq_len_fill_value(self): @@ -305,13 +313,30 @@ class FlashInferIndicesUpdaterDecode: assert attn_backend.num_wrappers == 1 self.update = self.update_single_wrapper - def update_single_wrapper(self, req_pool_indices, seq_lens, decode_wrappers=None): + def update_single_wrapper( + self, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + decode_wrappers=None, + ): decode_wrappers = decode_wrappers or self.decode_wrappers self.call_begin_forward( - decode_wrappers[0], req_pool_indices, seq_lens, self.kv_indptr[0], None + decode_wrappers[0], + req_pool_indices, + seq_lens, + seq_lens_sum, + self.kv_indptr[0], + None, ) - def update_sliding_window(self, req_pool_indices, seq_lens, decode_wrappers=None): + def update_sliding_window( + self, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + decode_wrappers=None, + ): decode_wrappers = decode_wrappers or self.decode_wrappers for wrapper_id in range(2): @@ -331,6 +356,7 @@ class FlashInferIndicesUpdaterDecode: decode_wrappers[wrapper_id], req_pool_indices, paged_kernel_lens, + seq_lens_sum, self.kv_indptr[wrapper_id], kv_start_idx, ) @@ -339,13 +365,18 @@ class FlashInferIndicesUpdaterDecode: raise NotImplementedError() def call_begin_forward( - self, wrapper, req_pool_indices, paged_kernel_lens, kv_indptr, kv_start_idx + self, + wrapper, + req_pool_indices, + paged_kernel_lens, + seq_lens_sum, + kv_indptr, + kv_start_idx, ): bs = len(req_pool_indices) kv_indptr = kv_indptr[: bs + 1] - # TODO: optimize the blocking call on kv_indptr[-1] kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0) - kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda") + kv_indices = torch.empty(seq_lens_sum, dtype=torch.int32, device="cuda") create_flashinfer_kv_indices_triton[(bs,)]( self.req_to_token, diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index e1f5bf371..fb3805cfe 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -91,7 +91,11 @@ class TritonAttnBackend(AttentionBackend): ) def init_forward_metadata_replay_cuda_graph( - self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, ): self.cuda_graph_start_loc.zero_() self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 7fd153e80..b0ab2dfe5 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -416,7 +416,6 @@ class ScheduleBatch: req_to_token_pool: ReqToTokenPool = None token_to_kv_pool: BaseTokenToKVPool = None tree_cache: BasePrefixCache = None - forward_mode: ForwardMode = None sampling_info: SamplingBatchInfo = None @@ -424,9 +423,13 @@ class ScheduleBatch: input_ids: torch.Tensor = None req_pool_indices: torch.Tensor = None seq_lens: torch.Tensor = None + # The output locations of the KV cache out_cache_loc: torch.Tensor = None output_ids: torch.Tensor = None + # The sum of all sequence lengths + seq_lens_sum: int = None + # For processing logprobs return_logprob: bool = False top_logprobs_nums: Optional[List[int]] = None @@ -435,7 +438,6 @@ class ScheduleBatch: prefix_lens: List[int] = None extend_lens: List[int] = None extend_num_tokens: int = None - running_bs: int = None decoding_reqs: List[Req] = None # Stream @@ -549,10 +551,12 @@ class ScheduleBatch: self.device, non_blocking=True ) - self.extend_num_tokens = extend_num_tokens self.out_cache_loc = out_cache_loc + + self.seq_lens_sum = sum(seq_lens) if self.return_logprob: self.top_logprobs_nums = [r.top_logprobs_num for r in reqs] + self.extend_num_tokens = extend_num_tokens self.prefix_lens = [len(r.prefix_indices) for r in reqs] self.extend_lens = [r.extend_input_len for r in reqs] self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs] @@ -571,12 +575,11 @@ class ScheduleBatch: input_ids = torch.cat([self.input_ids, running_batch.input_ids]) out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc]) - extend_num_tokens = self.extend_num_tokens + running_bs self.merge_batch(running_batch) self.input_ids = input_ids self.out_cache_loc = out_cache_loc - self.extend_num_tokens = extend_num_tokens + self.extend_num_tokens += running_bs # NOTE: prefix_indices is what has been cached, but we don't cache each decode step self.prefix_lens.extend( @@ -775,6 +778,7 @@ class ScheduleBatch: (self.req_pool_indices, self.seq_lens), self.out_cache_loc ) self.seq_lens.add_(1) + self.seq_lens_sum += bs def filter_batch( self, @@ -805,6 +809,7 @@ class ScheduleBatch: self.req_pool_indices = self.req_pool_indices[new_indices] self.seq_lens = self.seq_lens[new_indices] self.out_cache_loc = None + self.seq_lens_sum = self.seq_lens.sum().item() self.output_ids = self.output_ids[new_indices] self.return_logprob = any(req.return_logprob for req in self.reqs) if self.return_logprob: @@ -828,6 +833,7 @@ class ScheduleBatch: ) self.seq_lens = torch.concat([self.seq_lens, other.seq_lens]) self.out_cache_loc = None + self.seq_lens_sum += other.seq_lens_sum if self.output_ids is not None: self.output_ids = torch.concat([self.output_ids, other.output_ids]) if self.return_logprob and other.return_logprob: @@ -873,9 +879,11 @@ class ScheduleBatch: req_pool_indices=self.req_pool_indices, seq_lens=self.seq_lens, out_cache_loc=self.out_cache_loc, + seq_lens_sum=self.seq_lens_sum, req_to_token_pool_records=self.req_to_token_pool.get_write_records(), return_logprob=self.return_logprob, top_logprobs_nums=self.top_logprobs_nums, + extend_num_tokens=self.extend_num_tokens, extend_seq_lens=extend_seq_lens, extend_prefix_lens=extend_prefix_lens, extend_logprob_start_lens=extend_logprob_start_lens, @@ -917,6 +925,9 @@ class ModelWorkerBatch: # The indices of output tokens in the token_to_kv_pool out_cache_loc: torch.Tensor + # The sum of all sequence lengths + seq_lens_sum: int + # The memory pool operation records req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]] @@ -925,6 +936,7 @@ class ModelWorkerBatch: top_logprobs_nums: Optional[List[int]] # For extend + extend_num_tokens: Optional[int] extend_seq_lens: Optional[List[int]] extend_prefix_lens: Optional[List[int]] extend_logprob_start_lens: Optional[List[int]] diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index d3ff3cd1d..37e3c8429 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -188,6 +188,7 @@ class CudaGraphRunner: req_pool_indices = self.req_pool_indices[:bs] seq_lens = self.seq_lens[:bs] out_cache_loc = self.out_cache_loc[:bs] + seq_lens_sum = seq_lens.sum().item() # Attention backend self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph( @@ -206,6 +207,7 @@ class CudaGraphRunner: token_to_kv_pool=self.model_runner.token_to_kv_pool, attn_backend=self.model_runner.attn_backend, out_cache_loc=out_cache_loc, + seq_lens_sum=seq_lens_sum, return_logprob=False, top_logprobs_nums=[0] * bs, positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64), @@ -252,7 +254,10 @@ class CudaGraphRunner: # Attention backend self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph( - bs, self.req_pool_indices, self.seq_lens + bs, + self.req_pool_indices, + self.seq_lens, + forward_batch.seq_lens_sum, ) # Replay diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 49ef754a2..f4e117b76 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -87,6 +87,9 @@ class ForwardBatch: # The indices of output tokens in the token_to_kv_pool out_cache_loc: torch.Tensor + # The sum of all sequence lengths + seq_lens_sum: int + # For logprob return_logprob: bool = False top_logprobs_nums: Optional[List[int]] = None @@ -95,6 +98,7 @@ class ForwardBatch: positions: torch.Tensor = None # For extend + extend_num_tokens: Optional[int] = None extend_seq_lens: Optional[torch.Tensor] = None extend_prefix_lens: Optional[torch.Tensor] = None extend_start_loc: Optional[torch.Tensor] = None @@ -175,21 +179,6 @@ class ForwardBatch: ) self.mrope_positions = self.mrope_positions.to(torch.int64) - def compute_positions(self, model_runner: ModelRunner, batch: ModelWorkerBatch): - device = model_runner.device - if self.forward_mode.is_decode(): - self.positions = (self.seq_lens - 1).to(torch.int64) - else: - self.positions = torch.concat( - [ - torch.arange(prefix_len, prefix_len + extend_len, device=device) - for prefix_len, extend_len in zip( - batch.extend_prefix_lens, batch.extend_seq_lens - ) - ], - axis=0, - ) - @classmethod def init_new( cls, @@ -205,6 +194,7 @@ class ForwardBatch: req_pool_indices=batch.req_pool_indices, seq_lens=batch.seq_lens, out_cache_loc=batch.out_cache_loc, + seq_lens_sum=batch.seq_lens_sum, return_logprob=batch.return_logprob, top_logprobs_nums=batch.top_logprobs_nums, lora_paths=batch.lora_paths, @@ -213,7 +203,17 @@ class ForwardBatch: # Init position information if not ret.forward_mode.is_decode(): + ret.positions = torch.concat( + [ + torch.arange(prefix_len, prefix_len + extend_len, device=device) + for prefix_len, extend_len in zip( + batch.extend_prefix_lens, batch.extend_seq_lens + ) + ], + axis=0, + ) ret.image_inputs = batch.image_inputs + ret.extend_num_tokens = batch.extend_num_tokens ret.extend_seq_lens = torch.tensor( batch.extend_seq_lens, dtype=torch.int32 ).to(device, non_blocking=True) @@ -225,12 +225,8 @@ class ForwardBatch: ret.extend_seq_lens_cpu = batch.extend_seq_lens ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens - # Init position information - is_mrope = model_runner.model_is_mrope - if is_mrope: + if model_runner.model_is_mrope: ret.compute_mrope_positions(model_runner, batch) - else: - ret.compute_positions(model_runner, batch) # Init attention information ret.req_to_token_pool = model_runner.req_to_token_pool