From 62832bb2728e0e8ac5f97dc7687eaf263aaa927f Mon Sep 17 00:00:00 2001
From: Ke Bao
Date: Mon, 18 Nov 2024 08:29:20 +0800
Subject: [PATCH] Support cuda graph for DP attention (#2061)

---
 python/sglang/srt/managers/schedule_batch.py  | 10 ++++
 python/sglang/srt/managers/scheduler.py       | 32 ++++++++----
 python/sglang/srt/managers/tp_worker.py       |  3 --
 .../srt/managers/tp_worker_overlap_thread.py  |  5 +-
 .../srt/model_executor/cuda_graph_runner.py   | 50 ++++++++++++++++---
 .../srt/model_executor/forward_batch_info.py  |  2 +
 .../sglang/srt/model_executor/model_runner.py |  3 ++
 python/sglang/srt/server_args.py              |  5 +-
 scripts/playground/reference_hf.py            |  4 +-
 9 files changed, 88 insertions(+), 26 deletions(-)

diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 19bd07b88..054d1fcf8 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -455,6 +455,7 @@ class ScheduleBatch:
 
     # For DP attention
     global_num_tokens: Optional[List[int]] = None
+    can_run_dp_cuda_graph: bool = False
 
     # For processing logprobs
     return_logprob: bool = False
@@ -891,6 +892,13 @@ class ScheduleBatch:
         self.seq_lens = torch.empty(0, dtype=torch.int32).to(
             self.device, non_blocking=True
         )
+        self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        self.req_pool_indices = torch.empty(0, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        self.seq_lens_sum = 0
         self.extend_num_tokens = 0
 
     def prepare_for_decode(self, enable_overlap: bool = False):
@@ -1032,6 +1040,7 @@ class ScheduleBatch:
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
             global_num_tokens=self.global_num_tokens,
+            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
             extend_num_tokens=self.extend_num_tokens,
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
@@ -1093,6 +1102,7 @@ class ModelWorkerBatch:
 
     # For DP attention
     global_num_tokens: Optional[List[int]]
+    can_run_dp_cuda_graph: bool
 
     # For extend
     extend_num_tokens: Optional[int]
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index bb97efe2e..e25af5583 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -337,7 +337,7 @@ class Scheduler:
 
         kill_parent_process()
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def event_loop_normal(self):
         """A normal blocking scheduler loop."""
         self.last_batch = None
@@ -375,7 +375,7 @@ class Scheduler:
 
         self.last_batch = batch
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def event_loop_overlap(self):
         """A scheduler loop that overlaps the CPU processing and GPU computation."""
         result_queue = deque()
@@ -411,16 +411,12 @@ class Scheduler:
         else:
             num_tokens = local_batch.extend_num_tokens
 
-        local_num_tokens = torch.tensor(
-            num_tokens, dtype=torch.int64, device=self.device
-        )
-        global_num_tokens = torch.empty(
-            self.tp_size, dtype=torch.int64, device=self.device
-        )
+        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
+        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
         torch.distributed.all_gather_into_tensor(
             global_num_tokens,
             local_num_tokens,
-            group=self.tp_worker.get_tp_device_group(),
+            group=self.tp_cpu_group,
         )
 
         if local_batch is None and global_num_tokens.max().item() > 0:
@@ -429,6 +425,24 @@ class Scheduler:
         if local_batch is not None:
             local_batch.global_num_tokens = global_num_tokens.tolist()
 
+            # Check forward mode for cuda graph
+            if not self.server_args.disable_cuda_graph:
+                forward_mode_state = torch.tensor(
+                    (
+                        1
+                        if local_batch.forward_mode.is_decode()
+                        or local_batch.forward_mode.is_idle()
+                        else 0
+                    ),
+                    dtype=torch.int32,
+                )
+                torch.distributed.all_reduce(
+                    forward_mode_state,
+                    op=torch.distributed.ReduceOp.MIN,
+                    group=self.tp_cpu_group,
+                )
+                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
+
         return local_batch
 
     def get_idle_batch(self):
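Note on the scheduler change above: CUDA graph replay issues a fixed sequence of kernels and collectives, so every TP rank must agree on whether the graph path is taken before any rank commits to it. Each rank votes 1 if its local batch is decode or idle and 0 otherwise; a MIN all-reduce over the TP CPU group then yields 1 only when every rank voted 1. A minimal standalone sketch of that consensus, assuming an already-initialized torch.distributed process group (the names here are illustrative, not part of the patch):

    import torch
    import torch.distributed as dist

    def dp_cuda_graph_consensus(is_decode_or_idle: bool, cpu_group) -> bool:
        # Each rank votes 1 (graph-compatible batch) or 0 (e.g. prefill).
        vote = torch.tensor(1 if is_decode_or_idle else 0, dtype=torch.int32)
        # MIN-reduce: the result stays 1 only if every rank voted 1.
        dist.all_reduce(vote, op=dist.ReduceOp.MIN, group=cpu_group)
        return vote.item() == 1

One plausible reason the patch also moves these small collectives onto the CPU (gloo) group is to keep them off the CUDA device entirely, where a graph may be capturing or replaying.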
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 361febfac..4900575ee 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -128,9 +128,6 @@ class TpModelWorker:
     def get_tp_cpu_group(self):
         return self.model_runner.tp_group.cpu_group
 
-    def get_tp_device_group(self):
-        return self.model_runner.tp_group.device_group
-
     def get_memory_pool(self):
         return (
             self.model_runner.req_to_token_pool,
diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py
index 3ae1e37b3..412680f98 100644
--- a/python/sglang/srt/managers/tp_worker_overlap_thread.py
+++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -83,9 +83,6 @@ class TpModelWorkerClient:
     def get_tp_cpu_group(self):
         return self.worker.get_tp_cpu_group()
 
-    def get_tp_device_group(self):
-        return self.worker.get_tp_device_group()
-
     def get_memory_pool(self):
         return (
             self.worker.model_runner.req_to_token_pool,
@@ -96,7 +93,7 @@ class TpModelWorkerClient:
         with torch.cuda.stream(self.forward_stream):
             self.forward_thread_func_()
 
-    @torch.inference_mode()
+    @torch.no_grad()
    def forward_thread_func_(self):
         while True:
             model_worker_batch, future_token_ids_ct = self.input_queue.get()
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 236a57f1a..db185599f 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -111,6 +111,8 @@ class CudaGraphRunner:
         self.use_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
+        self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention
+        self.tp_size = self.model_runner.tp_size
 
         # Batch sizes to capture
         if model_runner.server_args.disable_cuda_graph_padding:
@@ -165,6 +167,16 @@ class CudaGraphRunner:
         else:
             self.encoder_lens = None
 
+        if self.enable_dp_attention:
+            self.global_num_tokens = [0] * self.tp_size
+            self.gathered_buffer = torch.zeros(
+                (
+                    self.max_bs * self.tp_size,
+                    self.model_runner.model_config.hidden_size,
+                ),
+                dtype=self.model_runner.dtype,
+            )
+
         # Capture
         try:
             with self.model_capture_mode():
@@ -190,11 +202,21 @@ class CudaGraphRunner:
         self.model_runner.model.capture_mode = False
 
     def can_run(self, forward_batch: ForwardBatch):
-        is_bs_supported = (
-            forward_batch.batch_size in self.graphs
-            if self.disable_padding
-            else forward_batch.batch_size <= self.max_bs
-        )
+        if self.enable_dp_attention:
+            min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
+                forward_batch.global_num_tokens
+            )
+            is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
+                (min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
+                if self.disable_padding
+                else max_num_tokens <= self.max_bs
+            )
+        else:
+            is_bs_supported = (
+                forward_batch.batch_size in self.graphs
+                if self.disable_padding
+                else forward_batch.batch_size <= self.max_bs
+            )
 
         # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
         # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
@@ -239,6 +261,13 @@ class CudaGraphRunner:
         seq_lens_sum = seq_lens.sum().item()
         mrope_positions = self.mrope_positions[:, :bs]
 
+        if self.enable_dp_attention:
+            self.global_num_tokens[:] = [bs] * self.tp_size
+            gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
+        else:
+            self.global_num_tokens = None
+            gathered_buffer = None
+
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
             bs,
@@ -265,6 +294,8 @@ class CudaGraphRunner:
             top_logprobs_nums=[0] * bs,
             positions=clamp_position(seq_lens),
             mrope_positions=mrope_positions,
+            global_num_tokens=self.global_num_tokens,
+            gathered_buffer=gathered_buffer,
         )
         logits_output = forward(input_ids, forward_batch.positions, forward_batch)
         return logits_output.next_token_logits
@@ -295,7 +326,12 @@ class CudaGraphRunner:
         raw_bs = forward_batch.batch_size
 
         # Pad
-        index = bisect.bisect_left(self.capture_bs, raw_bs)
+        if self.enable_dp_attention:
+            index = bisect.bisect_left(
+                self.capture_bs, max(forward_batch.global_num_tokens)
+            )
+        else:
+            index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
             self.seq_lens.fill_(1)
@@ -310,6 +346,8 @@ class CudaGraphRunner:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
+        if self.enable_dp_attention:
+            self.global_num_tokens[:] = [bs] * self.tp_size
 
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
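Note on the can_run/replay changes above: under DP attention the graph batch size is driven by the largest per-rank token count, so every rank pads up to one common captured size and the gathered buffer keeps a uniform shape across ranks. A small sketch of the padding rule (illustrative helper name; as in the runner, capture_bs must be sorted ascending):

    import bisect

    def padded_capture_bs(capture_bs, global_num_tokens):
        # Smallest captured batch size that fits the busiest rank.
        return capture_bs[bisect.bisect_left(capture_bs, max(global_num_tokens))]

    # Example: ranks report [3, 5, 2, 4] decode tokens; with captured sizes
    # [1, 2, 4, 8, 16], every rank replays the bs=8 graph and pads to 8.
    assert padded_capture_bs([1, 2, 4, 8, 16], [3, 5, 2, 4]) == 8

With --disable-cuda-graph-padding no padding is allowed, which is why can_run() then additionally requires all ranks to report the same token count.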
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index ea7c8d89a..c4af97957 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -138,6 +138,7 @@ class ForwardBatch:
     # For DP attention
     global_num_tokens: Optional[List[int]] = None
     gathered_buffer: Optional[torch.Tensor] = None
+    can_run_dp_cuda_graph: bool = False
 
     def compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
@@ -221,6 +222,7 @@ class ForwardBatch:
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
             global_num_tokens=batch.global_num_tokens,
+            can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
         )
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 55bf9afd8..35d81050a 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -592,6 +592,9 @@ class ModelRunner:
         )
 
     def forward_idle(self, forward_batch: ForwardBatch):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
+            return self.cuda_graph_runner.replay(forward_batch)
+
         return self.model.forward(
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 2a4b0d67e..8508d15d8 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -191,11 +191,12 @@ class ServerArgs:
         if self.enable_dp_attention:
             self.dp_size = self.tp_size
             self.chunked_prefill_size = self.chunked_prefill_size // 2
-            self.disable_cuda_graph = True
+            self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96)
             self.enable_overlap_schedule = False
             logger.warning(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE workload issue. "
-                "The CUDA graph is disabled. Data parallel size is adjust to be the same as tensor parallel size."
+                f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. "
+                "Data parallel size is adjusted to be the same as tensor parallel size."
             )
 
         if self.enable_overlap_schedule:
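Note on the server-args change above: enabling DP attention no longer disables CUDA graphs outright; it only caps the largest captured batch size at 96 and, as before, halves the chunked prefill size. A worked example of the derived settings (the starting values 8192 and 160 are assumptions for illustration, not defaults taken from this patch):

    chunked_prefill_size = 8192
    cuda_graph_max_bs = 160

    # What the ServerArgs adjustment above does when enable_dp_attention is set:
    chunked_prefill_size //= 2                      # 8192 -> 4096
    cuda_graph_max_bs = min(cuda_graph_max_bs, 96)  # 160 -> 96

Together with the forward_idle() change in model_runner.py, a rank that currently has no requests still replays an idle graph, keeping its collective calls aligned with the busy ranks.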
diff --git a/scripts/playground/reference_hf.py b/scripts/playground/reference_hf.py
index bf56fc3c9..7901145c6 100644
--- a/scripts/playground/reference_hf.py
+++ b/scripts/playground/reference_hf.py
@@ -31,7 +31,7 @@ from transformers import AutoModelForCausalLM
 from sglang.srt.hf_transformers_utils import get_tokenizer
 
 
-@torch.inference_mode()
+@torch.no_grad()
 def normal_text(args):
     t = get_tokenizer(args.model_path, trust_remote_code=True)
     m = AutoModelForCausalLM.from_pretrained(
@@ -69,7 +69,7 @@ def normal_text(args):
     print(output_str)
 
 
-@torch.inference_mode()
+@torch.no_grad()
 def synthetic_tokens(args):
     m = AutoModelForCausalLM.from_pretrained(
         args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
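Note on the torch.inference_mode() -> torch.no_grad() swaps throughout this patch: tensors created inside inference_mode stay flagged as inference tensors and can raise errors when they are later mutated or reused outside the mode, which sits poorly with buffers that must live across CUDA graph capture and replay; no_grad only disables gradient recording. A minimal illustration of the difference (stock PyTorch behavior, not code from this patch):

    import torch

    with torch.inference_mode():
        x = torch.ones(2)
    print(x.is_inference())  # True: x stays an inference tensor after the block

    @torch.no_grad()
    def step(t: torch.Tensor) -> torch.Tensor:
        # No autograd graph is recorded, but t remains an ordinary tensor.
        return t * 2

    print(step(torch.ones(2)).is_inference())  # False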