diff --git a/python/sglang/bench_offline_throughput.py b/python/sglang/bench_offline_throughput.py
index 5320da72b..a3793f921 100644
--- a/python/sglang/bench_offline_throughput.py
+++ b/python/sglang/bench_offline_throughput.py
@@ -56,6 +56,7 @@ class BenchArgs:
     gen_output_len: int = 256
     disable_ignore_eos: bool = False
     seed: int = 1
+    do_not_exit: bool = False

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -143,6 +144,11 @@ class BenchArgs:
             help="Disable ignore EOS token",
         )
         parser.add_argument("--seed", type=int, default=1, help="The random seed.")
+        parser.add_argument(
+            "--do-not-exit",
+            action="store_true",
+            help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
+        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -309,3 +315,6 @@ if __name__ == "__main__":
     )

     throughput_test(server_args, bench_args)
+
+    while bench_args.do_not_exit:
+        pass
diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index c6b5393ee..5b3ae30c3 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -314,7 +314,6 @@ class FlashInferIndicesUpdaterDecode:
         self.head_dim = model_runner.model_config.head_dim
         self.data_type = model_runner.kv_cache_dtype
         self.q_data_type = model_runner.dtype
-        self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
         self.sliding_window_size = model_runner.sliding_window_size

         self.attn_backend = attn_backend
@@ -445,7 +444,7 @@ class FlashInferIndicesUpdaterDecode:
             kv_indptr,
             kv_start_idx,
             kv_indices,
-            self.max_context_len,
+            self.req_to_token.shape[1],
         )

         wrapper.end_forward()
@@ -474,7 +473,6 @@ class FlashInferIndicesUpdaterPrefill:
         self.head_dim = model_runner.model_config.head_dim
         self.data_type = model_runner.kv_cache_dtype
         self.q_data_type = model_runner.dtype
-        self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
         self.sliding_window_size = model_runner.sliding_window_size

         self.attn_backend = attn_backend
@@ -599,7 +597,7 @@ class FlashInferIndicesUpdaterPrefill:
             kv_indptr,
             kv_start_idx,
             kv_indices,
-            self.max_context_len,
+            self.req_to_token.shape[1],
         )

         qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
@@ -638,10 +636,11 @@ def create_flashinfer_kv_indices_triton(
     kv_indptr,
     kv_start_idx,
     kv_indices_ptr,
-    max_context_len: tl.constexpr,
+    req_to_token_ptr_stride: tl.constexpr,
 ):
     BLOCK_SIZE: tl.constexpr = 512
     pid = tl.program_id(axis=0)
+
     req_pool_index = tl.load(req_pool_indices_ptr + pid)
     kv_indices_offset = tl.load(kv_indptr + pid)

@@ -652,15 +651,15 @@ def create_flashinfer_kv_indices_triton(
         kv_end = kv_start
     kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)

-    req_to_token_ptr += req_pool_index * max_context_len
-    kv_indices_ptr += kv_indices_offset
-
-    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
-    st_offset = tl.arange(0, BLOCK_SIZE)
     num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
-    for _ in range(num_loop):
-        mask = ld_offset < kv_end
-        data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
-        tl.store(kv_indices_ptr + st_offset, data, mask=mask)
-        ld_offset += BLOCK_SIZE
-        st_offset += BLOCK_SIZE
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = offset < kv_end - kv_start
+        data = tl.load(
+            req_to_token_ptr
+            + req_pool_index * req_to_token_ptr_stride
+            + kv_start
+            + offset,
+            mask=mask,
+        )
+        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
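Note (not part of the patch): the two FlashInferIndicesUpdater classes stop caching `max_context_len` and instead pass the row stride of the `req_to_token` table (`self.req_to_token.shape[1]`), which is the same value the removed attribute used to hold. The rewritten kernel gathers, for each request, the slice `req_to_token[req_pool_index, kv_start:kv_end]` into the flat `kv_indices` buffer starting at `kv_indptr[pid]`. A rough, unoptimized PyTorch reference sketch of that gather follows; the helper name is made up and argument names mirror the kernel's:

import torch

def create_flashinfer_kv_indices_reference(
    req_to_token,      # [max_batch, max_context_len], token slot table
    req_pool_indices,  # [bs], row of req_to_token owned by each request
    page_kernel_lens,  # [bs], number of KV entries to gather per request
    kv_indptr,         # [bs + 1], exclusive prefix sum of page_kernel_lens
    kv_start_idx,      # [bs] or None, per-request start offset (e.g., sliding window)
    kv_indices,        # flat output buffer of size kv_indptr[-1]
):
    for pid in range(req_pool_indices.numel()):
        row = int(req_pool_indices[pid].item())
        start = int(kv_start_idx[pid].item()) if kv_start_idx is not None else 0
        end = start + int(page_kernel_lens[pid].item())
        out = int(kv_indptr[pid].item())
        # Copy this request's used slots into the flat kv_indices buffer.
        kv_indices[out : out + (end - start)] = req_to_token[row, start:end]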
diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index eda2c7738..7bdff8f55 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -62,21 +62,21 @@ class LogitsMetadata:

     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
+        extend_logprob_pruned_lens_cpu = None
+
         if forward_batch.return_logprob:
             return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
+            if forward_batch.forward_mode.is_extend():
+                extend_logprob_pruned_lens_cpu = [
+                    extend_len - start_len
+                    for extend_len, start_len in zip(
+                        forward_batch.extend_seq_lens_cpu,
+                        forward_batch.extend_logprob_start_lens_cpu,
+                    )
+                ]
         else:
             return_top_logprob = False

-        if forward_batch.forward_mode.is_extend():
-            extend_logprob_pruned_lens_cpu = [
-                extend_len - start_len
-                for extend_len, start_len in zip(
-                    forward_batch.extend_seq_lens,
-                    forward_batch.extend_logprob_start_lens_cpu,
-                )
-            ]
-        else:
-            extend_logprob_pruned_lens_cpu = None
         return cls(
             forward_mode=forward_batch.forward_mode,
             top_logprobs_nums=forward_batch.top_logprobs_nums,
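Note (not part of the patch): `extend_logprob_pruned_lens_cpu` is now computed only when logprobs are requested, and it reads the host-side `extend_seq_lens_cpu` list instead of the device tensor `extend_seq_lens`. A small worked example with hypothetical values, two extend requests of 5 and 3 new tokens whose kept logprobs start at offsets 2 and 0:

extend_seq_lens_cpu = [5, 3]            # new tokens computed for each request
extend_logprob_start_lens_cpu = [2, 0]  # first position whose logprob is kept

extend_logprob_pruned_lens_cpu = [
    extend_len - start_len
    for extend_len, start_len in zip(
        extend_seq_lens_cpu, extend_logprob_start_lens_cpu
    )
]
assert extend_logprob_pruned_lens_cpu == [3, 3]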
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 109f3bf6f..19bd07b88 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -34,6 +34,8 @@ import logging
 from typing import List, Optional, Tuple, Union

 import torch
+import triton
+import triton.language as tl

 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
@@ -615,12 +617,12 @@ class ScheduleBatch:
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
         extend_num_tokens = sum(len(ids) for ids in input_ids)
         seq_lens = []
+        pre_lens = []

         # Allocate memory
         req_pool_indices = self.alloc_req_slots(bs)
         out_cache_loc = self.alloc_token_slots(extend_num_tokens)

-        pt = 0
         for i, req in enumerate(reqs):
             already_computed = (
                 req.extend_logprob_start_len + 1 + req.cached_tokens
@@ -638,10 +640,6 @@ class ScheduleBatch:
             self.req_to_token_pool.write(
                 (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
             )
-            self.req_to_token_pool.write(
-                (req.req_pool_idx, slice(pre_len, seq_len)),
-                out_cache_loc[pt : pt + req.extend_input_len],
-            )

             # Compute the relative logprob_start_len in an extend batch
             if req.logprob_start_len >= pre_len:
@@ -652,8 +650,8 @@ class ScheduleBatch:
                 extend_logprob_start_len = req.extend_input_len - 1

             req.extend_logprob_start_len = extend_logprob_start_len
-            pt += req.extend_input_len
             req.is_retracted = False
+            pre_lens.append(pre_len)

         # Set fields
         self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
@@ -665,7 +663,6 @@ class ScheduleBatch:
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
             self.device, non_blocking=True
         )
-        self.out_cache_loc = out_cache_loc

         self.seq_lens_sum = sum(seq_lens)

@@ -676,9 +673,33 @@ class ScheduleBatch:
         self.extend_lens = [r.extend_input_len for r in reqs]
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]

+        # Write to req_to_token_pool
+        pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        write_req_to_token_pool_triton[(bs,)](
+            self.req_to_token_pool.req_to_token,
+            self.req_pool_indices,
+            pre_lens,
+            self.seq_lens,
+            extend_lens,
+            self.out_cache_loc,
+            self.req_to_token_pool.req_to_token.shape[1],
+        )
+        # The triton kernel is equivalent to the following python code.
+        # self.req_to_token_pool.write(
+        #     (req.req_pool_idx, slice(pre_len, seq_len)),
+        #     out_cache_loc[pt : pt + req.extend_input_len],
+        # )
+        # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
+
         if self.model_config.is_encoder_decoder:
             self.prepare_encoder_info_extend(input_ids, seq_lens)

+        # Build sampling info
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
             self,
             self.model_config.vocab_size,
@@ -1025,6 +1046,9 @@ class ScheduleBatch:
         )

     def copy(self):
+        # We need a stream synchronization here. Otherwise, there will be cuda illegal memory access errors.
+        _ = self.seq_lens[0].item()
+
         # Only contain fields that will be used by process_batch_result
         return ScheduleBatch(
             reqs=self.reqs,
@@ -1104,3 +1128,40 @@ class ModelWorkerBatch:
             for x, y in self.req_to_token_pool_records
         ]
         self.sampling_info.to(device)
+
+
+@triton.jit
+def write_req_to_token_pool_triton(
+    req_to_token_ptr,  # [max_batch, max_context_len]
+    req_pool_indices,
+    pre_lens,
+    seq_lens,
+    extend_lens,
+    out_cache_loc,
+    req_to_token_ptr_stride: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(0)
+
+    req_pool_index = tl.load(req_pool_indices + pid)
+    pre_len = tl.load(pre_lens + pid)
+    seq_len = tl.load(seq_lens + pid)
+
+    # TODO: optimize this?
+    cumsum_start = 0
+    for i in range(pid):
+        cumsum_start += tl.load(extend_lens + i)
+
+    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = offset < (seq_len - pre_len)
+        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
+        tl.store(
+            req_to_token_ptr
+            + req_pool_index * req_to_token_ptr_stride
+            + offset
+            + pre_len,
+            value,
+            mask=mask,
+        )
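Note (not part of the patch): `write_req_to_token_pool_triton` replaces the per-request `req_to_token_pool.write` calls that the Python loop used to issue, scattering each request's newly allocated cache slots into its row of `req_to_token` in a single launch, as the commented-out equivalent in the hunk above indicates. A plain PyTorch sketch of the same scatter for reference; the helper name is made up and argument names mirror the kernel's:

import torch

def write_req_to_token_pool_reference(
    req_to_token,      # [max_batch, max_context_len]
    req_pool_indices,  # [bs]
    pre_lens,          # [bs], tokens already cached per request
    seq_lens,          # [bs], total tokens per request after this extend
    extend_lens,       # [bs], seq_lens - pre_lens
    out_cache_loc,     # [sum(extend_lens)], newly allocated slots, packed per request
):
    pt = 0
    for i in range(req_pool_indices.numel()):
        row = int(req_pool_indices[i].item())
        pre_len, seq_len = int(pre_lens[i].item()), int(seq_lens[i].item())
        extend_len = int(extend_lens[i].item())  # == seq_len - pre_len
        # Place the new slots right after this request's cached prefix.
        req_to_token[row, pre_len:seq_len] = out_cache_loc[pt : pt + extend_len]
        pt += extend_len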
diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py
index 21264f1a9..8c924c442 100644
--- a/python/sglang/srt/managers/tp_worker_overlap_thread.py
+++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -56,6 +56,7 @@ class TpModelWorkerClient:
         self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
         self.max_running_requests = self.worker.max_running_requests
         self.device = self.worker.device
+        self.gpu_id = gpu_id

         # Init future mappings
         self.future_token_ids_ct = 0
@@ -73,12 +74,6 @@ class TpModelWorkerClient:
         )
         self.forward_thread.start()

-        self.copy_queue = Queue()
-        self.copy_thread = threading.Thread(
-            target=self.copy_thread_func,
-        )
-        self.copy_thread.start()
-
     def get_worker_info(self):
         return self.worker.get_worker_info()

@@ -104,12 +99,11 @@ class TpModelWorkerClient:
     @torch.inference_mode()
     def forward_thread_func_(self):
         while True:
-            self.has_inflight_batch = False
             model_worker_batch, future_token_ids_ct = self.input_queue.get()
             if not model_worker_batch:
                 break
-            self.has_inflight_batch = True
             self.launch_event = threading.Event()
+            copy_event = torch.cuda.Event()

             # Resolve future tokens in the input
             input_ids = model_worker_batch.input_ids
@@ -142,39 +136,29 @@ class TpModelWorkerClient:
                         )
                     )
             next_token_ids = next_token_ids.to("cpu", non_blocking=True)
-            copy_event = torch.cuda.Event(blocking=True)
             copy_event.record()
             self.launch_event.set()
-            self.copy_queue.put((copy_event, logits_output, next_token_ids))
-
-    def copy_thread_func(self):
-        while True:
-            copy_event, logits_output, next_token_ids = self.copy_queue.get()
-            if not copy_event:
-                break
-            while not copy_event.query():
-                time.sleep(1e-5)
-
-            if logits_output.next_token_logprobs is not None:
-                logits_output.next_token_logprobs = (
-                    logits_output.next_token_logprobs.tolist()
-                )
-            if logits_output.input_token_logprobs is not None:
-                logits_output.input_token_logprobs = (
-                    logits_output.input_token_logprobs.tolist()
-                )
-                logits_output.normalized_prompt_logprobs = (
-                    logits_output.normalized_prompt_logprobs.tolist()
-                )
-
-            self.output_queue.put((logits_output, next_token_ids.tolist()))
+            self.output_queue.put((copy_event, logits_output, next_token_ids))

     def resulve_batch_result(self, bid: int):
-        logits_output, next_token_ids = self.output_queue.get()
-        if self.has_inflight_batch:
-            # Wait until the batch is launched
-            self.launch_event.wait()
+        copy_event, logits_output, next_token_ids = self.output_queue.get()
+        while not copy_event.query():
+            time.sleep(1e-5)
+        self.launch_event.wait()
+
+        if logits_output.next_token_logprobs is not None:
+            logits_output.next_token_logprobs = (
+                logits_output.next_token_logprobs.tolist()
+            )
+        if logits_output.input_token_logprobs is not None:
+            logits_output.input_token_logprobs = (
+                logits_output.input_token_logprobs.tolist()
+            )
+            logits_output.normalized_prompt_logprobs = (
+                logits_output.normalized_prompt_logprobs.tolist()
+            )
+        next_token_ids = next_token_ids.tolist()

         return logits_output, next_token_ids

     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
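Note (not part of the patch): with the copy thread removed, the forward thread records a CUDA event right after the non-blocking device-to-host copies and pushes the event through the output queue; `resulve_batch_result` then polls the event before touching the host tensors. The pattern in isolation, as a standalone sketch (requires a CUDA device; not SGLang code):

import time
import torch

def start_async_copy(x: torch.Tensor):
    """Enqueue a non-blocking D2H copy and return (host_tensor, event)."""
    host = x.to("cpu", non_blocking=True)  # copy is only enqueued here
    done = torch.cuda.Event()
    done.record()  # fires once all work enqueued so far, including the copy, finishes
    return host, done

def wait_and_read(host, done):
    while not done.query():  # poll instead of blocking the stream
        time.sleep(1e-5)
    return host.tolist()  # safe to read after the event has fired

if __name__ == "__main__" and torch.cuda.is_available():
    values, event = start_async_copy(torch.arange(8, device="cuda"))
    print(wait_and_read(values, event))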
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 3381c9211..ea7c8d89a 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -36,6 +36,8 @@
 from enum import IntEnum, auto
 from typing import TYPE_CHECKING, List, Optional

 import torch
+import triton
+import triton.language as tl

 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
@@ -236,25 +238,16 @@ class ForwardBatch:

         # Init position information
         if not ret.forward_mode.is_decode():
-            ret.positions = torch.concat(
-                [
-                    torch.arange(prefix_len, prefix_len + extend_len, device=device)
-                    for prefix_len, extend_len in zip(
-                        batch.extend_prefix_lens, batch.extend_seq_lens
-                    )
-                ],
-                axis=0,
-            )
-            ret.extend_num_tokens = batch.extend_num_tokens
             ret.extend_seq_lens = torch.tensor(
                 batch.extend_seq_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
-
             ret.extend_prefix_lens = torch.tensor(
                 batch.extend_prefix_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
-            ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
-            ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
+            ret.extend_num_tokens = batch.extend_num_tokens
+            ret.positions, ret.extend_start_loc = compute_position_triton(
+                ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
+            )
             ret.extend_seq_lens_cpu = batch.extend_seq_lens
             ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens

@@ -271,3 +264,72 @@ class ForwardBatch:
             model_runner.lora_manager.prepare_lora_batch(ret)

         return ret
+
+
+def compute_position_triton(
+    extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
+):
+    """Compute positions. It is a fused version of `compute_position_torch`."""
+    batch_size = extend_seq_lens.shape[0]
+    positions = torch.empty(
+        extend_seq_lens_sum, dtype=torch.int64, device=extend_seq_lens.device
+    )
+    extend_start_loc = torch.empty(
+        batch_size, dtype=torch.int32, device=extend_seq_lens.device
+    )
+
+    # Launch kernel
+    compute_position_kernel[(batch_size,)](
+        positions,
+        extend_start_loc,
+        extend_prefix_lens,
+        extend_seq_lens,
+    )
+
+    return positions, extend_start_loc
+
+
+@triton.jit
+def compute_position_kernel(
+    positions,
+    extend_start_loc,
+    extend_prefix_lens,
+    extend_seq_lens,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(0)
+
+    prefix_len = tl.load(extend_prefix_lens + pid)
+    seq_len = tl.load(extend_seq_lens + pid)
+
+    # TODO: optimize this?
+    cumsum_start = 0
+    for i in range(pid):
+        cumsum_start += tl.load(extend_seq_lens + i)
+
+    num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        tl.store(
+            positions + cumsum_start + offset,
+            prefix_len + offset,
+            mask=offset < seq_len,
+        )
+    tl.store(extend_start_loc + pid, cumsum_start)
+
+
+def compute_position_torch(
+    extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor
+):
+    positions = torch.concat(
+        [
+            torch.arange(
+                prefix_len, prefix_len + extend_len, device=extend_prefix_lens.device
+            )
+            for prefix_len, extend_len in zip(extend_prefix_lens, extend_seq_lens)
+        ],
+        axis=0,
+    )
+    extend_start_loc = torch.zeros_like(extend_seq_lens)
+    extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
+    return positions.to(torch.int64), extend_start_loc
diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
index 17369d31a..a341c2b17 100644
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -73,7 +73,7 @@ class SamplingBatchInfo:
             top_ks=top_ks,
             min_ps=min_ps,
             need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
-            is_all_greedy=top_ks.max().item() <= 1,
+            is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
             vocab_size=vocab_size,
             device=device,
         )
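Note (not part of the patch): `compute_position_triton` is expected to match `compute_position_torch`, which the patch keeps as a reference, and `is_all_greedy` is now decided from the request objects on the host instead of `top_ks.max().item()`, which forces a device synchronization. A small worked example of the position computation with hypothetical lengths (the per-request lengths are pulled to Python ints here for clarity):

import torch

extend_prefix_lens = torch.tensor([2, 0], dtype=torch.int32)
extend_seq_lens = torch.tensor([3, 2], dtype=torch.int32)

# Mirrors compute_position_torch from the patch.
positions = torch.concat(
    [
        torch.arange(prefix_len, prefix_len + extend_len)
        for prefix_len, extend_len in zip(
            extend_prefix_lens.tolist(), extend_seq_lens.tolist()
        )
    ]
)
extend_start_loc = torch.zeros_like(extend_seq_lens)
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)

print(positions.tolist())         # [2, 3, 4, 0, 1]
print(extend_start_loc.tolist())  # [0, 3]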