From ad4125d1a9c4796cdbc6c6a5cdb69b09e60e5509 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Tue, 22 Oct 2024 23:20:43 -0700
Subject: [PATCH] Fuse more ops & Simplify token mapping (#1758)

---
 docs/en/benchmark_and_profiling.md                 |  6 +-
 .../layers/attention/flashinfer_backend.py         | 10 +--
 python/sglang/srt/layers/sampler.py                | 81 ++++++++++---------
 .../srt/managers/tp_worker_overlap_thread.py       | 36 +++++----
 python/sglang/srt/mem_cache/memory_pool.py         | 27 ++++---
 .../srt/model_executor/cuda_graph_runner.py        |  8 +-
 python/sglang/test/run_eval.py                     |  2 +
 test/srt/test_eval_accuracy_mini.py                |  1 +
 test/srt/test_pytorch_sampling_backend.py          |  3 +-
 9 files changed, 99 insertions(+), 75 deletions(-)

diff --git a/docs/en/benchmark_and_profiling.md b/docs/en/benchmark_and_profiling.md
index 77fbbfc1b..c0f54957d 100644
--- a/docs/en/benchmark_and_profiling.md
+++ b/docs/en/benchmark_and_profiling.md
@@ -46,4 +46,8 @@ pip install nvtx
 import nvtx
 with nvtx.annotate("description", color="color"):
     # some critical code
-```
\ No newline at end of file
+```
+
+## Other tips
+
+1. You can benchmark a model with dummy weights by providing only the `config.json` file. This allows quick testing of model variants without training them. To do so, add `--load-format dummy` to the commands above; the checkpoint folder then only needs a correct `config.json`.
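As a usage illustration of the tip above (the folder name is hypothetical, and the exact benchmark script may differ by version): assuming a local directory `./my-model-variant` that contains only a valid `config.json`, an invocation would look roughly like `python -m sglang.bench_latency --model-path ./my-model-variant --load-format dummy`. With `--load-format dummy`, the weights are randomly initialized to the shapes the config describes, so no checkpoint download is needed.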
diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index e5e7ca29c..c6b5393ee 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -337,7 +337,7 @@ class FlashInferIndicesUpdaterDecode:
     def update(
         self, req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers, encoder_lens
     ):
-        # Keep the signature for type checking, will be initialized during runtime
+        # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()

     def update_single_wrapper(
@@ -432,8 +432,8 @@ class FlashInferIndicesUpdaterDecode:
         kv_start_idx,
     ):
         bs = len(req_pool_indices)
+        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
         kv_indptr = kv_indptr[: bs + 1]
-        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
         kv_indices = torch.empty(
             paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
         )
@@ -497,7 +497,7 @@ class FlashInferIndicesUpdaterPrefill:
            self.update = self.update_single_wrapper

    def update(self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens):
-        # Keep the signature for type checking, will be initialized during runtime
+        # Keep the signature for type checking. It will be assigned during runtime.
        raise NotImplementedError()

    def update_single_wrapper(
@@ -589,8 +589,8 @@ class FlashInferIndicesUpdaterPrefill:
        use_ragged,
    ):
        bs = len(req_pool_indices)
+        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
        kv_indptr = kv_indptr[: bs + 1]
-        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
        kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
        create_flashinfer_kv_indices_triton[(bs,)](
            self.req_to_token,
@@ -602,8 +602,8 @@
            self.max_context_len,
        )

+        qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
        qo_indptr = qo_indptr[: bs + 1]
-        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

        # extend part
        if use_ragged:
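All four index updates above follow the same reordering: write the cumulative sum into a slice of the persistent indptr buffer first, then truncate with a view, instead of slicing first and writing through the freshly created view. A minimal sketch of the pattern (the buffer name and size here are illustrative, not the actual attributes of the updater classes):

```python
import torch

# Preallocated buffer sized for the maximum batch size (illustrative).
max_bs = 256
kv_indptr_buf = torch.zeros(max_bs + 1, dtype=torch.int32)

def build_indptr(paged_kernel_lens: torch.Tensor) -> torch.Tensor:
    bs = paged_kernel_lens.numel()
    # Write the cumulative sum into the persistent buffer first, then
    # take a truncated view, rather than slicing before the write.
    kv_indptr_buf[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
    return kv_indptr_buf[: bs + 1]

lens = torch.tensor([3, 5, 2], dtype=torch.int32)
print(build_indptr(lens))  # tensor([ 0,  3,  8, 10], dtype=torch.int32)
```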
diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index 454078d59..54fc47b73 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -33,56 +33,61 @@ class Sampler(nn.Module):
         if isinstance(logits, LogitsProcessorOutput):
             logits = logits.next_token_logits

-        # Post process logits
         logits = logits.contiguous()
-        logits.div_(sampling_info.temperatures)
-        probs = torch.softmax(logits, dim=-1)
-        logits = None
-        del logits

-        if self.use_nan_detectioin and torch.any(torch.isnan(probs)):
-            logger.warning("Detected errors during sampling! NaN in the probability.")
-            probs = torch.where(
-                torch.isnan(probs), torch.full_like(probs, 1e-10), probs
+        if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
+            logger.warning("Detected errors during sampling! NaN in the logits.")
+            logits = torch.where(
+                torch.isnan(logits), torch.full_like(logits, -1e5), logits
             )

         if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
-            batch_next_token_ids = torch.argmax(probs, -1)
-        elif global_server_args_dict["sampling_backend"] == "flashinfer":
-            max_top_k_round, batch_size = 32, probs.shape[0]
-            uniform_samples = torch.rand(
-                (max_top_k_round, batch_size), device=probs.device
-            )
-            if sampling_info.need_min_p_sampling:
-                probs = top_k_renorm_prob(probs, sampling_info.top_ks)
-                probs = top_p_renorm_prob(probs, sampling_info.top_ps)
-                batch_next_token_ids, success = min_p_sampling_from_probs(
-                    probs, uniform_samples, sampling_info.min_ps
+            batch_next_token_ids = torch.argmax(logits, -1)
+        else:
+            # Post process logits
+            logits.div_(sampling_info.temperatures)
+            probs = torch.softmax(logits, dim=-1)
+            logits = None
+            del logits
+
+            if global_server_args_dict["sampling_backend"] == "flashinfer":
+                max_top_k_round, batch_size = 32, probs.shape[0]
+                uniform_samples = torch.rand(
+                    (max_top_k_round, batch_size), device=probs.device
                 )
-            else:
-                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                if sampling_info.need_min_p_sampling:
+                    probs = top_k_renorm_prob(probs, sampling_info.top_ks)
+                    probs = top_p_renorm_prob(probs, sampling_info.top_ps)
+                    batch_next_token_ids, success = min_p_sampling_from_probs(
+                        probs, uniform_samples, sampling_info.min_ps
+                    )
+                else:
+                    batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                        probs,
+                        uniform_samples,
+                        sampling_info.top_ks,
+                        sampling_info.top_ps,
+                        filter_apply_order="joint",
+                    )
+
+                if not torch.all(success):
+                    logger.warning("Detected errors during sampling!")
+                    batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
+            elif global_server_args_dict["sampling_backend"] == "pytorch":
+                # A slower fallback implementation with torch native operations.
+                batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
                     probs,
-                    uniform_samples,
                     sampling_info.top_ks,
                     sampling_info.top_ps,
-                    filter_apply_order="joint",
+                    sampling_info.min_ps,
+                )
+            else:
+                raise ValueError(
+                    f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
                 )

-            if not torch.all(success):
-                logger.warning("Detected errors during sampling!")
-                batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
-        elif global_server_args_dict["sampling_backend"] == "pytorch":
-            # Here we provide a slower fallback implementation.
-            batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
-                probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
-            )
-        else:
-            raise ValueError(
-                f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
-            )
-
-        return batch_next_token_ids
+        return batch_next_token_ids.to(torch.int32)


 def top_k_top_p_min_p_sampling_from_probs_torch(
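The restructuring lets the all-greedy path take an `argmax` over raw logits and skip the temperature division and softmax entirely; NaN detection likewise moves to the logits, before any post-processing. For reference, a torch-native joint top-k/top-p/min-p filter along the lines of the `pytorch` fallback can be sketched as below. This is a simplified illustration assuming a `[batch, vocab]` probability tensor and per-request threshold tensors, not the exact helper shipped in this file:

```python
import torch

def top_k_top_p_min_p_sketch(probs, top_ks, top_ps, min_ps):
    # Sort probabilities in descending order per request.
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # min-p: drop tokens below min_p * max_prob.
    probs_sort[probs_sort < (probs_sort[:, 0] * min_ps).unsqueeze(1)] = 0.0
    # top-p: drop tokens once the cumulative mass before them exceeds top_p.
    probs_sort[(probs_sum - probs_sort) > top_ps.unsqueeze(1)] = 0.0
    # top-k: drop everything past the k-th sorted token.
    ranks = torch.arange(probs.shape[-1], device=probs.device)
    probs_sort[ranks.unsqueeze(0) >= top_ks.unsqueeze(1)] = 0.0
    # Renormalize, sample, and map back to original token ids.
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    sampled = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, 1, sampled).squeeze(1)
```

The top-ranked token always survives all three filters, so the renormalized distribution is never all-zero.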
diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py
index 8b27d2a69..8032915e7 100644
--- a/python/sglang/srt/managers/tp_worker_overlap_thread.py
+++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -32,6 +32,15 @@ from sglang.srt.server_args import ServerArgs
 logger = logging.getLogger(__name__)


+@torch.compile(dynamic=True)
+def resolve_future_token_ids(input_ids, future_token_ids_map):
+    input_ids[:] = torch.where(
+        input_ids < 0,
+        future_token_ids_map[torch.clamp(-input_ids, min=0)],
+        input_ids,
+    )
+
+
 class TpModelWorkerClient:
     """A tensor parallel model worker."""

@@ -99,33 +108,25 @@ class TpModelWorkerClient:

         # Resolve future tokens in the input
         input_ids = model_worker_batch.input_ids
-        input_ids[:] = torch.where(
-            input_ids < 0,
-            self.future_token_ids_map[torch.clamp(-input_ids, min=0)],
-            input_ids,
-        )
+        resolve_future_token_ids(input_ids, self.future_token_ids_map)

         # Run forward
         logits_output, next_token_ids = self.worker.forward_batch_generation(
             model_worker_batch
         )
-        self.launch_event.set()

         # Update the future token ids map
         bs = len(model_worker_batch.seq_lens)
-        future_next_token_ids = torch.arange(
-            -(future_token_ids_ct + bs),
-            -(future_token_ids_ct),
-            dtype=torch.int32,
-            device=self.device,
-        )
-        self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
-            torch.int32
-        )
+        self.future_token_ids_map[
+            future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
+        ] = next_token_ids

+        # Copy results to the CPU
         next_token_ids = next_token_ids.to("cpu", non_blocking=True)
         copy_event = torch.cuda.Event(blocking=True)
         copy_event.record()
+
+        self.launch_event.set()
         self.copy_queue.put((copy_event, next_token_ids))

     def copy_thread_func(self):
@@ -149,8 +150,9 @@ class TpModelWorkerClient:
         # Allocate output future objects
         bs = len(model_worker_batch.seq_lens)
         future_next_token_ids = torch.arange(
-            -(self.future_token_ids_ct + bs),
-            -(self.future_token_ids_ct),
+            -(self.future_token_ids_ct + 1),
+            -(self.future_token_ids_ct + 1 + bs),
+            -1,
             dtype=torch.int32,
             device=self.device,
         )
diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index 4277862a7..181ac7eef 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -51,7 +51,7 @@ class ReqToTokenPool:
            self.write = self.write_without_records

    def write(self, indices, values):
-        # Keep the signature for type checking, will be initialized during runtime
+        # Keep the signature for type checking. It will be assigned during runtime.
        raise NotImplementedError()

    def available_size(self):
@@ -221,16 +221,21 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         cache_v: torch.Tensor,
     ):
         layer_id = layer.layer_id
-        if cache_k.dtype != self.dtype:
-            cache_k = cache_k.to(self.dtype)
-        if cache_v.dtype != self.dtype:
-            cache_v = cache_v.to(self.dtype)
-        if self.store_dtype != self.dtype:
-            self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
-            self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
-        else:
-            self.k_buffer[layer_id][loc] = cache_k
-            self.v_buffer[layer_id][loc] = cache_v
+        copy_two_array(
+            loc,
+            self.k_buffer[layer_id],
+            cache_k,
+            self.v_buffer[layer_id],
+            cache_v,
+            self.dtype,
+            self.store_dtype,
+        )
+
+
+@torch.compile(dynamic=True)
+def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
+    dst_1[loc] = src_1.to(dtype).view(store_dtype)
+    dst_2[loc] = src_2.to(dtype).view(store_dtype)


 class MLATokenToKVPool(BaseTokenToKVPool):
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index ffa77ec4c..b859df358 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -92,6 +92,11 @@ def set_torch_compile_config():
     torch._dynamo.config.accumulated_cache_size_limit = 1024


+@torch.compile(dynamic=True)
+def clamp_position(seq_lens):
+    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
+
+
 class CudaGraphRunner:
     """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""

@@ -112,7 +117,6 @@ class CudaGraphRunner:
             self.capture_bs = list(range(1, 32)) + [64, 128]
         else:
             self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
-
         self.capture_bs = [
             bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size
         ]
@@ -253,7 +257,7 @@ class CudaGraphRunner:
             encoder_lens=encoder_lens,
             return_logprob=False,
             top_logprobs_nums=[0] * bs,
-            positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
+            positions=clamp_position(seq_lens),
         )
         return forward(input_ids, forward_batch.positions, forward_batch)
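The three new helpers above (`resolve_future_token_ids`, `copy_two_array`, `clamp_position`) are each wrapped in `@torch.compile(dynamic=True)` so their elementwise operations compile into fewer fused kernels; that is the "fuse more ops" half of the PR. The "simplify token mapping" half replaces the negative `arange` plus re-negation with a plain slice write, while reads still go through `clamp` so non-placeholder ids index a reserved slot 0. A self-contained sketch of the scheme (sizes and values hypothetical):

```python
import torch

# Slot 0 of the map is reserved: clamp(-ids, min=0) sends real
# (non-negative) token ids to that dummy slot.
future_token_ids_map = torch.zeros(16, dtype=torch.long)

ct, bs = 2, 3  # counter position and batch size (illustrative)

# Scheduler side: hand out negative placeholder ids -(ct+1) .. -(ct+bs).
placeholders = torch.arange(-(ct + 1), -(ct + 1 + bs), -1)

# Worker side: after sampling, fill the matching slots with a slice write.
sampled = torch.tensor([101, 102, 103])
future_token_ids_map[ct + 1 : ct + bs + 1] = sampled

# Next batch: inputs may mix real ids and placeholders; resolve in place.
input_ids = torch.tensor([7, -3, -5])
input_ids[:] = torch.where(
    input_ids < 0,
    future_token_ids_map[torch.clamp(-input_ids, min=0)],
    input_ids,
)
print(input_ids)  # tensor([  7, 101, 103])
```

Note also that `launch_event.set()` now fires only after the CUDA copy event is recorded, rather than immediately after the forward launch.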
diff --git a/python/sglang/test/run_eval.py b/python/sglang/test/run_eval.py
index 51b32ca01..fe88171ce 100644
--- a/python/sglang/test/run_eval.py
+++ b/python/sglang/test/run_eval.py
@@ -67,6 +67,7 @@ def run_eval(args):
        model=args.model,
        max_tokens=2048,
        base_url=base_url,
+        temperature=getattr(args, "temperature", 0.0),
    )

    # Run eval
@@ -119,6 +120,7 @@ if __name__ == "__main__":
    parser.add_argument("--eval-name", type=str, default="mmlu")
    parser.add_argument("--num-examples", type=int)
    parser.add_argument("--num-threads", type=int, default=512)
+    parser.add_argument("--temperature", type=float, default=0.0)
    args = parser.parse_args()

    run_eval(args)
diff --git a/test/srt/test_eval_accuracy_mini.py b/test/srt/test_eval_accuracy_mini.py
index 6ddd97d94..ee977a636 100644
--- a/test/srt/test_eval_accuracy_mini.py
+++ b/test/srt/test_eval_accuracy_mini.py
@@ -31,6 +31,7 @@ class TestEvalAccuracyMini(unittest.TestCase):
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
+            temperature=0.1,
        )

        metrics = run_eval(args)
diff --git a/test/srt/test_pytorch_sampling_backend.py b/test/srt/test_pytorch_sampling_backend.py
index 5507182a7..ee06de8fa 100644
--- a/test/srt/test_pytorch_sampling_backend.py
+++ b/test/srt/test_pytorch_sampling_backend.py
@@ -23,7 +23,7 @@ class TestPyTorchSamplingBackend(unittest.TestCase):
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=["--sampling-backend", "pytorch"],
+            other_args=["--sampling-backend", "pytorch", "--disable-radix-cache"],
        )

    @classmethod
@@ -37,6 +37,7 @@ class TestPyTorchSamplingBackend(unittest.TestCase):
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
+            temperature=0.1,
        )

        metrics = run_eval(args)
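Greedy decoding (temperature 0.0) remains the default in `run_eval`, so the tests opt into light randomness explicitly with `temperature=0.1`. A sketch of driving the eval programmatically, mirroring the test setup; the server URL and model name are placeholders, and the exact attribute set `run_eval` expects may include more than shown here:

```python
from types import SimpleNamespace

from sglang.test.run_eval import run_eval

# Assumes a server is already running at the placeholder URL.
args = SimpleNamespace(
    base_url="http://127.0.0.1:30000",
    model="meta-llama/Llama-3.1-8B-Instruct",
    eval_name="mmlu",
    num_examples=64,
    num_threads=32,
    temperature=0.1,
)
metrics = run_eval(args)  # returns a metrics dict, e.g. metrics["score"]
print(metrics)
```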