diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 9679bb0ae..988844004 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -96,7 +96,9 @@ class Scheduler:
 
         if self.tp_rank == 0:
             self.recv_from_tokenizer = context.socket(zmq.PULL)
-            self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.scheduler_port}")
+            self.recv_from_tokenizer.bind(
+                f"tcp://127.0.0.1:{port_args.scheduler_input_port}"
+            )
 
             self.send_to_detokenizer = context.socket(zmq.PUSH)
             self.send_to_detokenizer.connect(
@@ -141,9 +143,6 @@ class Scheduler:
             nccl_port=port_args.nccl_ports[0],
         )
         self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
-        self.pad_input_ids_func = getattr(
-            self.tp_worker.model_runner.model, "pad_input_ids", None
-        )
 
         # Get token and memory info from the tp worker
         (
@@ -154,6 +153,9 @@ class Scheduler:
             self.random_seed,
         ) = self.tp_worker.get_token_and_memory_info()
         set_random_seed(self.random_seed)
+        self.pad_input_ids_func = getattr(
+            self.tp_worker.model_runner.model, "pad_input_ids", None
+        )
 
         # Print debug info
         logger.info(
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 00c75fcda..c0a0ff34c 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -87,7 +87,9 @@ class TokenizerManager:
         self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
 
         self.send_to_scheduler = context.socket(zmq.PUSH)
-        self.send_to_scheduler.connect(f"tcp://127.0.0.1:{port_args.scheduler_port}")
+        self.send_to_scheduler.connect(
+            f"tcp://127.0.0.1:{port_args.scheduler_input_port}"
+        )
 
         # Read model args
         self.model_path = server_args.model_path
diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index 7dea057d1..7ce80b57c 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -30,6 +30,7 @@ class ReqToTokenPool:
 
     def __init__(self, size: int, max_context_len: int, device: str):
         self.size = size
+        self.max_context_len = max_context_len
         self.free_slots = list(range(size))
         self.req_to_token = torch.empty(
             (size, max_context_len), dtype=torch.int32, device=device
@@ -54,7 +55,7 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))
 
 
-class BaseTokenToKVPool(ABC):
+class BaseTokenToKVPool:
     """A memory pool that maps a token to its kv cache locations"""
 
     def __init__(
@@ -92,19 +93,15 @@ class BaseTokenToKVPool(ABC):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.free_slots = np.arange(1, self.size + 1)
 
-    @abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
 
-    @abstractmethod
     def get_value_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
 
-    @abstractmethod
     def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError()
 
-    @abstractmethod
     def set_kv_buffer(
         self,
         layer_id: int,
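Dropping `ABC` and `@abstractmethod` from `BaseTokenToKVPool` in the memory_pool.py hunk above, while keeping the `raise NotImplementedError()` bodies, trades construction-time enforcement for call-time failure: the base class becomes instantiable, and an unimplemented method only errors when actually invoked. A self-contained sketch of that observable difference (the class names here are illustrative, not the real pool classes):

```python
from abc import ABC, abstractmethod


class AbstractPool(ABC):
    @abstractmethod
    def get_key_buffer(self, layer_id: int):
        raise NotImplementedError()


class PlainPool:
    def get_key_buffer(self, layer_id: int):
        raise NotImplementedError()


try:
    AbstractPool()  # TypeError: can't instantiate abstract class AbstractPool
except TypeError as exc:
    print(exc)

pool = PlainPool()  # instantiation now succeeds
try:
    pool.get_key_buffer(0)  # fails only when the stub is actually called
except NotImplementedError:
    print("get_key_buffer is still a stub")
```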
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index f596acbe9..83273bc43 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -411,8 +411,8 @@ class ModelRunner:
             device = "cuda"
 
         self.req_to_token_pool = ReqToTokenPool(
-            max_num_reqs + 1,
-            self.model_config.context_len + 4,
+            size=max_num_reqs + 1,
+            max_context_len=self.model_config.context_len + 4,
             device=device,
         )
         if (
diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
index fb0dfc53c..247c15d8e 100644
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -14,16 +14,17 @@ if TYPE_CHECKING:
 
 @dataclasses.dataclass
 class SamplingBatchInfo:
-    # Basic Info
-    vocab_size: int
-
     # Batched sampling params
-    temperatures: torch.Tensor = None
-    top_ps: torch.Tensor = None
-    top_ks: torch.Tensor = None
-    min_ps: torch.Tensor = None
+    temperatures: torch.Tensor
+    top_ps: torch.Tensor
+    top_ks: torch.Tensor
+    min_ps: torch.Tensor
+
+    # Dispatch in CUDA graph
+    need_min_p_sampling: bool
 
     # Bias Tensors
+    vocab_size: int
     logit_bias: torch.Tensor = None
     vocab_mask: torch.Tensor = None
 
@@ -31,9 +32,6 @@ class SamplingBatchInfo:
     regex_fsms: List[RegexGuide] = None
     regex_fsm_states: List[int] = None
 
-    # Dispatch in CUDA graph
-    need_min_p_sampling: bool = False
-
     # Penalizer
     penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
     linear_penalties: torch.Tensor = None
@@ -42,25 +40,30 @@ class SamplingBatchInfo:
     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         reqs = batch.reqs
-        ret = cls(vocab_size=vocab_size)
-
         with torch.device("cuda"):
-            ret.temperatures = torch.tensor(
+            temperatures = torch.tensor(
                 [r.sampling_params.temperature for r in reqs],
                 dtype=torch.float,
             ).view(-1, 1)
-            ret.top_ps = torch.tensor(
+            top_ps = torch.tensor(
                 [r.sampling_params.top_p for r in reqs], dtype=torch.float
             )
-            ret.top_ks = torch.tensor(
+            top_ks = torch.tensor(
                 [r.sampling_params.top_k for r in reqs], dtype=torch.int
             )
-            ret.min_ps = torch.tensor(
+            min_ps = torch.tensor(
                 [r.sampling_params.min_p for r in reqs], dtype=torch.float
             )
 
+        ret = cls(
+            temperatures=temperatures,
+            top_ps=top_ps,
+            top_ks=top_ks,
+            min_ps=min_ps,
+            need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
+            vocab_size=vocab_size,
+        )
         # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
-        ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)
 
         # Each penalizers will do nothing if they evaluate themselves as not required by looking at
         # the sampling_params of the requests (See {_is_required()} of each penalizers).
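The sampling_batch_info.py hunks turn the batched tensors and `need_min_p_sampling` into required dataclass fields and build the object in a single `cls(...)` call, instead of mutating a half-initialized instance attribute by attribute. A minimal sketch of that construction pattern, assuming a simplified standalone dataclass (`SamplingInfoSketch` and `from_params` are invented names, not the real sglang class):

```python
import dataclasses
from typing import List, Optional

import torch


@dataclasses.dataclass
class SamplingInfoSketch:
    # Required fields: forgetting one is now a TypeError at construction,
    # instead of a silent `None` that fails later inside sampling.
    temperatures: torch.Tensor
    top_ps: torch.Tensor
    top_ks: torch.Tensor
    min_ps: torch.Tensor
    need_min_p_sampling: bool
    vocab_size: int
    # Optional tensors keep their None defaults.
    logit_bias: Optional[torch.Tensor] = None


def from_params(params: List[dict], vocab_size: int) -> SamplingInfoSketch:
    # Collect everything first, then construct in one call.
    return SamplingInfoSketch(
        temperatures=torch.tensor([p["temperature"] for p in params]).view(-1, 1),
        top_ps=torch.tensor([p["top_p"] for p in params]),
        top_ks=torch.tensor([p["top_k"] for p in params], dtype=torch.int),
        min_ps=torch.tensor([p["min_p"] for p in params]),
        need_min_p_sampling=any(p["min_p"] > 0 for p in params),
        vocab_size=vocab_size,
    )


info = from_params(
    [{"temperature": 0.7, "top_p": 0.9, "top_k": 40, "min_p": 0.0}],
    vocab_size=32000,
)
assert info.need_min_p_sampling is False
```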
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 258ddc303..583e60989 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -118,6 +118,7 @@ async def health_generate(request: Request) -> Response:
 
 @app.get("/get_model_info")
 async def get_model_info():
+    """Get the model information."""
     result = {
         "model_path": tokenizer_manager.model_path,
         "is_generation": tokenizer_manager.is_generation,
@@ -127,11 +128,13 @@ async def get_model_info():
 
 @app.get("/get_server_args")
 async def get_server_args():
+    """Get the server arguments."""
    return dataclasses.asdict(tokenizer_manager.server_args)
 
 
 @app.get("/flush_cache")
 async def flush_cache():
+    """Flush the radix cache."""
     tokenizer_manager.flush_cache()
     return Response(
         content="Cache flushed.\nPlease check backend logs for more details. "
@@ -142,7 +145,7 @@ async def flush_cache():
 
 @app.post("/update_weights")
 async def update_weights(obj: UpdateWeightReqInput, request: Request):
-
+    """Update the weights in place without re-launching the server."""
     success, message = await tokenizer_manager.update_weights(obj, request)
     content = {"success": success, "message": message}
     if success:
@@ -205,7 +208,7 @@ app.put("/encode")(encode_request)
 
 
 async def judge_request(obj: RewardReqInput, request: Request):
-    """Handle an embedding request."""
+    """Handle a reward model request."""
     try:
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
@@ -307,7 +310,7 @@ def launch_server(
     ports = server_args.additional_ports
     port_args = PortArgs(
         tokenizer_port=ports[0],
-        scheduler_port=ports[1],
+        scheduler_input_port=ports[1],
         detokenizer_port=ports[2],
         nccl_ports=ports[3:],
     )
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index cf54df1b1..ceacd9364 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -627,10 +627,11 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
 class PortArgs:
     # The port for tokenizer to receive inputs from detokenizer (zmq)
     tokenizer_port: int
-    # The port for scheduler to receive inputs from tokenizer (zmq)
-    scheduler_port: int
+    # The port for scheduler (rank 0) to receive inputs from tokenizer (zmq)
+    scheduler_input_port: int
     # The port for detokenizer to receive inputs from scheduler (zmq)
     detokenizer_port: int
+    # The port for nccl initialization for multiple TP groups (torch.dist)
     nccl_ports: List[int]
 
diff --git a/test/killall_sglang.sh b/test/killall_sglang.sh
index c536548d4..b26d86b6f 100644
--- a/test/killall_sglang.sh
+++ b/test/killall_sglang.sh
@@ -1 +1 @@
-kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}')
+kill -9 $(ps aux | grep 'multiprocessing.spawn' | grep -v 'grep' | awk '{print $2}')
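The `scheduler_port` to `scheduler_input_port` rename threads through `PortArgs` and `launch_server` above to the two socket endpoints changed in scheduler.py and tokenizer_manager.py: the rank-0 scheduler binds a `zmq.PULL` socket on that port and the tokenizer connects a `zmq.PUSH` socket to it. A standalone pyzmq sketch of that wiring (the port number and payload are invented for illustration; the real processes read the port from `PortArgs`):

```python
import zmq

scheduler_input_port = 30001  # hypothetical; real value comes from PortArgs

context = zmq.Context()

# Scheduler side (tp_rank == 0): bind a PULL socket, as in Scheduler.__init__.
recv_from_tokenizer = context.socket(zmq.PULL)
recv_from_tokenizer.bind(f"tcp://127.0.0.1:{scheduler_input_port}")

# Tokenizer side: connect a PUSH socket, as in TokenizerManager.__init__.
send_to_scheduler = context.socket(zmq.PUSH)
send_to_scheduler.connect(f"tcp://127.0.0.1:{scheduler_input_port}")

send_to_scheduler.send_pyobj({"rid": "req-0", "input_ids": [1, 2, 3]})
print(recv_from_tokenizer.recv_pyobj())  # {'rid': 'req-0', 'input_ids': [1, 2, 3]}
```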