From bdc1acf6cdadf6bf08f7d2d895c8099023253d36 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Tue, 7 Jan 2025 02:52:53 -0800
Subject: [PATCH] Misc fix for min_p_sampling, --cuda-graph-bs (#2761)

---
 .github/workflows/pr-test.yml                 |  4 +-
 .../tuning_fused_moe_triton.py                |  1 +
 python/pyproject.toml                         | 12 +++--
 python/sglang/bench_serving.py                |  5 +-
 python/sglang/srt/layers/logits_processor.py  |  5 ++
 .../layers/moe/fused_moe_triton/fused_moe.py  | 21 ++++++--
 .../srt/layers/quantization/__init__.py       |  3 +-
 .../srt/managers/data_parallel_controller.py  |  2 +
 python/sglang/srt/managers/scheduler.py       |  5 +-
 python/sglang/srt/metrics/collector.py        | 52 ++++++++-----------
 .../srt/model_executor/cuda_graph_runner.py   |  9 +++-
 .../sglang/srt/model_executor/model_runner.py |  9 +++-
 .../srt/sampling/sampling_batch_info.py       |  1 +
 python/sglang/srt/server.py                   | 10 ++--
 python/sglang/srt/server_args.py              |  7 +++
 python/sglang/srt/utils.py                    | 11 ++--
 python/sglang/test/test_utils.py              | 41 ++++++++++++---
 17 files changed, 135 insertions(+), 63 deletions(-)

diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml
index c9ea1105b..f1c7871de 100644
--- a/.github/workflows/pr-test.yml
+++ b/.github/workflows/pr-test.yml
@@ -66,12 +66,14 @@ jobs:
       - name: Run test
         timeout-minutes: 25
         run: |
-          cd test/srt
           RANGE=${{ matrix.range }}
           range_begin=${RANGE%-*}
           range_end=${RANGE#*-}
+
+          cd test/srt
           python3 run_suite.py --suite per-commit --range-begin ${range_begin} --range-end ${range_end}
 
+
   unit-test-backend-2-gpu:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
     runs-on: 2-gpu-runner
diff --git a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
index 06f64813d..72715fb50 100644
--- a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
+++ b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
@@ -228,6 +228,7 @@ class BenchmarkWorker:
                 hidden_size,
                 topk,
                 dtype_str,
+                False,
             )
         else:
             config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
diff --git a/python/pyproject.toml b/python/pyproject.toml
index a824d47f3..d536f8832 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -16,14 +16,20 @@ classifiers = [
 dependencies = ["requests", "tqdm", "numpy", "IPython", "setproctitle"]
 
 [project.optional-dependencies]
-runtime_common = ["aiohttp", "decord", "fastapi",
+runtime_common = [
+    "aiohttp", "decord", "fastapi",
     "hf_transfer", "huggingface_hub", "interegular", "modelscope",
     "orjson", "outlines>=0.0.44,<0.1.0",
     "packaging", "pillow", "prometheus-client>=0.20.0",
     "psutil", "pydantic", "python-multipart",
     "pyzmq>=25.1.2", "torchao>=0.7.0", "uvicorn", "uvloop",
-    "xgrammar>=0.1.6"]
-srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "cuda-python", "flashinfer==0.1.6", "sgl-kernel>=0.0.2.post11"]
+    "xgrammar>=0.1.6"
+]
+srt = [
+    "sglang[runtime_common]", "cuda-python",
+    "sgl-kernel>=0.0.2.post11", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1",
+    "flashinfer==0.1.6"
+]
 
 # HIP (Heterogeneous-computing Interface for Portability) for AMD
 # => base docker rocm/vllm-dev:20241022, not from public vllm whl
diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py
index 6067a7444..4744ad338 100644
--- a/python/sglang/bench_serving.py
+++ b/python/sglang/bench_serving.py
@@ -563,7 +563,7 @@ def sample_sharegpt_requests(
         raise ValueError("output_len too small")
 
     # Download sharegpt if necessary
-    if not os.path.isfile(dataset_path):
+    if not os.path.isfile(dataset_path) and dataset_path == "":
         dataset_path = download_and_cache_file(SHAREGPT_URL)
 
     # Load the dataset.
@@ -1064,8 +1064,11 @@ async def benchmark(
         "total_output_tokens_retokenized": metrics.total_output_retokenized,
         "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
         "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
+        "mean_ttft_ms": metrics.mean_ttft_ms,
         "median_ttft_ms": metrics.median_ttft_ms,
+        "mean_itl_ms": metrics.mean_itl_ms,
         "median_itl_ms": metrics.median_itl_ms,
+        "input_throughput": metrics.input_throughput,
         "output_throughput": metrics.output_throughput,
         "sharegpt_output_len": args.sharegpt_output_len,
         "random_input_len": args.random_input_len,
diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index 51e73d072..7ca1d51a7 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -117,6 +117,11 @@ class LogitsProcessor(nn.Module):
         self.final_logit_softcapping = getattr(
             self.config, "final_logit_softcapping", None
         )
+        if (
+            self.final_logit_softcapping is not None
+            and self.final_logit_softcapping < 0
+        ):
+            self.final_logit_softcapping = None
 
     def forward(
         self,
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
index 1c8700783..ed132555b 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -1011,11 +1011,22 @@ def fused_experts_impl(
                 out_hidden_states[begin_chunk_idx:end_chunk_idx],
             )
         else:
-            torch.sum(
-                intermediate_cache3.view(*intermediate_cache3.shape),
-                dim=1,
-                out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
-            )
+            if topk_ids.shape[1] == 1:
+                out_hidden_states[begin_chunk_idx:end_chunk_idx].copy_(
+                    intermediate_cache3[:, 0]
+                )
+            elif topk_ids.shape[1] == 2:
+                torch.add(
+                    intermediate_cache3[:, 0],
+                    intermediate_cache3[:, 1],
+                    out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                ).squeeze(dim=1)
+            elif topk_ids.shape[1] > 2:
+                torch.sum(
+                    intermediate_cache3.view(*intermediate_cache3.shape),
+                    dim=1,
+                    out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                )
 
     return out_hidden_states
 
diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py
index 2ff570ba1..df20a7a4b 100644
--- a/python/sglang/srt/layers/quantization/__init__.py
+++ b/python/sglang/srt/layers/quantization/__init__.py
@@ -1,8 +1,7 @@
 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
 
-from typing import Callable, Dict, Optional, Type
+from typing import Dict, Type
 
-import torch
 from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
 from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py
index 8edb79417..7ae6689ee 100644
--- a/python/sglang/srt/managers/data_parallel_controller.py
+++ b/python/sglang/srt/managers/data_parallel_controller.py
@@ -20,6 +20,7 @@ import threading
 from enum import Enum, auto
 
 import psutil
+import setproctitle
 import zmq
 
 from sglang.srt.managers.io_struct import (
@@ -230,6 +231,7 @@ def run_data_parallel_controller_process(
     port_args: PortArgs,
     pipe_writer,
 ):
+    setproctitle.setproctitle("sglang::data_parallel_controller")
     configure_logger(server_args)
     parent_process = psutil.Process().parent()
 
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 180b4d96f..6022a2567 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -1516,8 +1516,9 @@ class Scheduler:
         return success, message
 
     def update_weights_from_distributed(
-        self, recv_req: UpdateWeightsFromDistributedReqInput
-    ):
+        self,
+        recv_req: UpdateWeightsFromDistributedReqInput,
+    ) -> Tuple[bool, str]:
         """Update the online model parameter."""
         success, message = self.tp_worker.update_weights_from_distributed(recv_req)
         if success:
diff --git a/python/sglang/srt/metrics/collector.py b/python/sglang/srt/metrics/collector.py
index d5ae98834..9505f012f 100644
--- a/python/sglang/srt/metrics/collector.py
+++ b/python/sglang/srt/metrics/collector.py
@@ -114,26 +114,20 @@ class TokenizerMetricsCollector:
             documentation="Histogram of time to first token in seconds.",
             labelnames=labels.keys(),
             buckets=[
-                0.001,
-                0.005,
-                0.01,
-                0.02,
-                0.04,
-                0.06,
-                0.08,
                 0.1,
                 0.25,
                 0.5,
                 0.75,
-                1.0,
-                2.5,
-                5.0,
-                7.5,
-                10.0,
-                15.0,
-                20.0,
-                25.0,
-                30.0,
+                1,
+                2,
+                5,
+                10,
+                20,
+                40,
+                60,
+                80,
+                120,
+                160,
             ],
         )
 
@@ -168,21 +162,19 @@ class TokenizerMetricsCollector:
             documentation="Histogram of End-to-end request latency in seconds",
             labelnames=labels.keys(),
             buckets=[
-                0.3,
+                0.1,
+                0.25,
                 0.5,
-                0.8,
-                1.0,
-                1.5,
-                2.0,
-                2.5,
-                5.0,
-                10.0,
-                15.0,
-                20.0,
-                30.0,
-                40.0,
-                50.0,
-                60.0,
+                1,
+                2,
+                5,
+                10,
+                20,
+                40,
+                60,
+                80,
+                120,
+                160,
             ],
         )
 
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 99e72a3d0..deaea3312 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -124,6 +124,13 @@ class CudaGraphRunner:
         self.tp_size = self.model_runner.tp_size
 
         # Batch sizes to capture
+        self.capture_bs = self.model_runner.server_args.cuda_graph_bs
+        if self.capture_bs is None:
+            if model_runner.server_args.disable_cuda_graph_padding:
+                self.capture_bs = list(range(1, 33)) + [64, 128]
+            else:
+                self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+
         if model_runner.server_args.disable_cuda_graph_padding:
             self.capture_bs = list(range(1, 33)) + [64, 128]
         else:
@@ -340,8 +347,8 @@ class CudaGraphRunner:
             top_logprobs_nums=[0] * bs,
             positions=positions,
             global_num_tokens=global_num_tokens,
-            mrope_positions=mrope_positions,
             gathered_buffer=gathered_buffer,
+            mrope_positions=mrope_positions,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=(
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 41905e272..7cd9e759a 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -89,6 +89,7 @@ class ModelRunner:
         self.is_draft_worker = is_draft_worker
         self.is_generation = model_config.is_generation
         self.is_multimodal = model_config.is_multimodal
+        self.should_log = tp_rank == 0
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
@@ -117,15 +118,21 @@ class ModelRunner:
 
         if self.is_multimodal:
             self.mem_fraction_static *= 0.95
+            logger.info(
+                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
+                f"because this is a multimodal model."
+            )
+
             if self.model_config.hf_config.architectures == [
                 "MllamaForConditionalGeneration"
             ]:
                 logger.info("Automatically turn off --chunked-prefill-size for mllama.")
                 server_args.chunked_prefill_size = -1
-            # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
+
             if self.model_config.hf_config.architectures == [
                 "Qwen2VLForConditionalGeneration"
             ]:
+                # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
                 logger.info(
                     "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
                 )
diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
index 3a46b2209..9497e53d3 100644
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -232,6 +232,7 @@ class SamplingBatchInfo:
         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other), self.device
         )
+        self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling
 
     def apply_logits_bias(self, logits: torch.Tensor):
         # Apply logit_bias
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index b58a3b032..f60af5d73 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -127,14 +127,12 @@ async def health() -> Response:
 async def health_generate(request: Request) -> Response:
     """Check the health of the inference server by generating one token."""
 
+    sampling_params = {"max_new_tokens": 1, "temperature": 0.7}
+
     if tokenizer_manager.is_generation:
-        gri = GenerateReqInput(
-            input_ids=[0], sampling_params={"max_new_tokens": 1, "temperature": 0.7}
-        )
+        gri = GenerateReqInput(input_ids=[0], sampling_params=sampling_params)
     else:
-        gri = EmbeddingReqInput(
-            input_ids=[0], sampling_params={"max_new_tokens": 1, "temperature": 0.7}
-        )
+        gri = EmbeddingReqInput(input_ids=[0], sampling_params=sampling_params)
 
     try:
         async for _ in tokenizer_manager.generate_request(gri, request):
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 5ed78dc5e..ef4df60a5 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -148,6 +148,7 @@ class ServerArgs:
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
     cuda_graph_max_bs: Optional[int] = None
+    cuda_graph_bs: Optional[List[int]] = None
    torchao_config: str = ""
     enable_nan_detection: bool = False
     enable_p2p_check: bool = False
@@ -803,6 +804,12 @@ class ServerArgs:
             default=ServerArgs.cuda_graph_max_bs,
             help="Set the maximum batch size for cuda graph.",
         )
+        parser.add_argument(
+            "--cuda-graph-bs",
+            type=int,
+            nargs="+",
+            help="Set the list of batch sizes for cuda graph.",
+        )
         parser.add_argument(
             "--torchao-config",
             type=str,
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 8ee9d205c..6f3144ca6 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -709,13 +709,14 @@ def broadcast_pyobj(
     data: List[Any],
     rank: int,
     dist_group: Optional[torch.distributed.ProcessGroup] = None,
+    src: int = 0,
 ):
     """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
 
     if rank == 0:
         if len(data) == 0:
             tensor_size = torch.tensor([0], dtype=torch.long)
-            dist.broadcast(tensor_size, src=0, group=dist_group)
+            dist.broadcast(tensor_size, src=src, group=dist_group)
         else:
             serialized_data = pickle.dumps(data)
             size = len(serialized_data)
@@ -724,19 +725,19 @@ def broadcast_pyobj(
             )
 
             tensor_size = torch.tensor([size], dtype=torch.long)
-            dist.broadcast(tensor_size, src=0, group=dist_group)
-            dist.broadcast(tensor_data, src=0, group=dist_group)
+            dist.broadcast(tensor_size, src=src, group=dist_group)
+            dist.broadcast(tensor_data, src=src, group=dist_group)
         return data
     else:
         tensor_size = torch.tensor([0], dtype=torch.long)
-        dist.broadcast(tensor_size, src=0, group=dist_group)
+        dist.broadcast(tensor_size, src=src, group=dist_group)
         size = tensor_size.item()
 
         if size == 0:
             return []
 
         tensor_data = torch.empty(size, dtype=torch.uint8)
-        dist.broadcast(tensor_data, src=0, group=dist_group)
+        dist.broadcast(tensor_data, src=src, group=dist_group)
 
         serialized_data = bytes(tensor_data.cpu().numpy())
         data = pickle.loads(serialized_data)
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index 32c6e08b6..cd21c896a 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -532,6 +532,8 @@ def run_bench_serving(
     request_rate,
     other_server_args,
     dataset_name="random",
+    dataset_path="",
+    tokenizer=None,
     random_input_len=4096,
     random_output_len=2048,
     disable_stream=False,
@@ -553,9 +555,9 @@ def run_bench_serving(
         host=None,
         port=None,
         dataset_name=dataset_name,
-        dataset_path="",
+        dataset_path=dataset_path,
         model=None,
-        tokenizer=None,
+        tokenizer=tokenizer,
         num_prompts=num_prompts,
         sharegpt_output_len=None,
         random_input_len=random_input_len,
@@ -657,16 +659,16 @@ STDERR_FILENAME = "stderr.txt"
 STDOUT_FILENAME = "stdout.txt"
 
 
-def read_output(output_lines):
+def read_output(output_lines: List[str], filename: str = STDERR_FILENAME):
     """Print the output in real time with another thread."""
-    while not os.path.exists(STDERR_FILENAME):
+    while not os.path.exists(filename):
         time.sleep(1)
 
     pt = 0
     while pt >= 0:
-        if pt > 0 and not os.path.exists(STDERR_FILENAME):
+        if pt > 0 and not os.path.exists(filename):
             break
-        lines = open(STDERR_FILENAME).readlines()
+        lines = open(filename).readlines()
         for line in lines[pt:]:
             print(line, end="", flush=True)
             output_lines.append(line)
@@ -747,6 +749,33 @@ def run_and_check_memory_leak(
         assert has_abort
 
 
+def run_command_and_capture_output(command, env: Optional[dict] = None):
+    stdout = open(STDOUT_FILENAME, "w")
+    stderr = open(STDERR_FILENAME, "w")
+    process = subprocess.Popen(
+        command, stdout=stdout, stderr=stderr, env=env, text=True
+    )
+
+    # Launch a thread to stream the output
+    output_lines = []
+    t = threading.Thread(target=read_output, args=(output_lines, STDOUT_FILENAME))
+    t.start()
+
+    # Join the process
+    process.wait()
+
+    stdout.close()
+    stderr.close()
+    if os.path.exists(STDOUT_FILENAME):
+        os.remove(STDOUT_FILENAME)
+    if os.path.exists(STDERR_FILENAME):
+        os.remove(STDERR_FILENAME)
+    kill_process_tree(process.pid)
+    t.join()
+
+    return output_lines
+
+
 def run_mmlu_test(
     disable_radix_cache=False,
     enable_mixed_chunk=False,