diff --git a/python/pyproject.toml b/python/pyproject.toml
index 6966d1452..d60477f82 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -20,7 +20,7 @@ dependencies = [
 
 [project.optional-dependencies]
 srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
-       "zmq", "vllm>=0.4.2", "interegular", "pydantic", "pillow", "outlines>=0.0.27", "flashinfer>=0.0.4", "packaging"]
+       "zmq", "vllm>=0.4.2", "interegular", "pydantic", "pillow", "outlines>=0.0.27", "packaging"]
 openai = ["openai>=1.0", "numpy", "tiktoken"]
 anthropic = ["anthropic>=0.20.0", "numpy"]
 all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py
index 44b9d0210..19434911d 100644
--- a/python/sglang/srt/managers/router/model_rpc.py
+++ b/python/sglang/srt/managers/router/model_rpc.py
@@ -113,7 +113,8 @@ class ModelRpcServer:
             f"max_prefill_num_token={self.max_prefill_num_token}, "
             f"context_len={self.model_config.context_len}, "
         )
-        logger.info(f"server_args: {server_args.print_mode_args()}")
+        if self.tp_rank == 0:
+            logger.info(f"server_args: {server_args.print_mode_args()}")
 
         # Init cache
         self.tree_cache = RadixCache(disable=server_args.disable_radix_cache)
diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py
index d24db31c7..565e1beac 100644
--- a/python/sglang/srt/managers/router/model_runner.py
+++ b/python/sglang/srt/managers/router/model_runner.py
@@ -110,12 +110,12 @@ class InputMetadata:
         self.kv_last_page_len = torch.ones(
             (self.batch_size,), dtype=torch.int32, device="cuda"
         )
-        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
-        seq_lens_cpu = self.seq_lens.cpu().numpy()
+        req_pool_indices_cpu = self.req_pool_indices.cpu().tolist()
+        seq_lens_cpu = self.seq_lens.tolist()
         self.kv_indices = torch.cat(
             [
                 self.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[i]: seq_lens_cpu[i]
+                    req_pool_indices_cpu[i], : seq_lens_cpu[i]
                 ]
                 for i in range(self.batch_size)
             ],
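
Note on the last hunk: the old expression `req_to_token[req_pool_indices_cpu[i]: seq_lens_cpu[i]]` is a row slice of the pool, while the corrected `req_to_token[req_pool_indices_cpu[i], : seq_lens_cpu[i]]` selects one request's row and truncates it to that request's sequence length. A minimal sketch of the difference follows; the tensor shape and index values are illustrative assumptions, not taken from the repo.

    import torch

    # Hypothetical token pool: one row per request slot, one column per token position.
    req_to_token = torch.arange(24).reshape(4, 6)

    req_pool_index = 2  # assumed request slot
    seq_len = 3         # assumed number of tokens held by that request

    # Old form: a row slice. It returns whole rows 2..2 (shape [1, 6]) and would be
    # empty whenever seq_len <= req_pool_index -- not the per-request token indices.
    old = req_to_token[req_pool_index:seq_len]

    # New form: 2-D indexing. It returns the first seq_len entries of row 2
    # (shape [3]), i.e. the KV indices belonging to that request.
    new = req_to_token[req_pool_index, :seq_len]

    print(old.shape, new.shape)  # torch.Size([1, 6]) torch.Size([3])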