Organize sampling batch info better (#1562)
This commit is contained in:
@@ -96,7 +96,9 @@ class Scheduler:
|
|||||||
|
|
||||||
if self.tp_rank == 0:
|
if self.tp_rank == 0:
|
||||||
self.recv_from_tokenizer = context.socket(zmq.PULL)
|
self.recv_from_tokenizer = context.socket(zmq.PULL)
|
||||||
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.scheduler_port}")
|
self.recv_from_tokenizer.bind(
|
||||||
|
f"tcp://127.0.0.1:{port_args.scheduler_input_port}"
|
||||||
|
)
|
||||||
|
|
||||||
self.send_to_detokenizer = context.socket(zmq.PUSH)
|
self.send_to_detokenizer = context.socket(zmq.PUSH)
|
||||||
self.send_to_detokenizer.connect(
|
self.send_to_detokenizer.connect(
|
||||||
@@ -141,9 +143,6 @@ class Scheduler:
|
|||||||
nccl_port=port_args.nccl_ports[0],
|
nccl_port=port_args.nccl_ports[0],
|
||||||
)
|
)
|
||||||
self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
|
self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
|
||||||
self.pad_input_ids_func = getattr(
|
|
||||||
self.tp_worker.model_runner.model, "pad_input_ids", None
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get token and memory info from the tp worker
|
# Get token and memory info from the tp worker
|
||||||
(
|
(
|
||||||
@@ -154,6 +153,9 @@ class Scheduler:
|
|||||||
self.random_seed,
|
self.random_seed,
|
||||||
) = self.tp_worker.get_token_and_memory_info()
|
) = self.tp_worker.get_token_and_memory_info()
|
||||||
set_random_seed(self.random_seed)
|
set_random_seed(self.random_seed)
|
||||||
|
self.pad_input_ids_func = getattr(
|
||||||
|
self.tp_worker.model_runner.model, "pad_input_ids", None
|
||||||
|
)
|
||||||
|
|
||||||
# Print debug info
|
# Print debug info
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|||||||
@@ -87,7 +87,9 @@ class TokenizerManager:
|
|||||||
self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
|
self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
|
||||||
|
|
||||||
self.send_to_scheduler = context.socket(zmq.PUSH)
|
self.send_to_scheduler = context.socket(zmq.PUSH)
|
||||||
self.send_to_scheduler.connect(f"tcp://127.0.0.1:{port_args.scheduler_port}")
|
self.send_to_scheduler.connect(
|
||||||
|
f"tcp://127.0.0.1:{port_args.scheduler_input_port}"
|
||||||
|
)
|
||||||
|
|
||||||
# Read model args
|
# Read model args
|
||||||
self.model_path = server_args.model_path
|
self.model_path = server_args.model_path
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ class ReqToTokenPool:
|
|||||||
|
|
||||||
def __init__(self, size: int, max_context_len: int, device: str):
|
def __init__(self, size: int, max_context_len: int, device: str):
|
||||||
self.size = size
|
self.size = size
|
||||||
|
self.max_context_len = max_context_len
|
||||||
self.free_slots = list(range(size))
|
self.free_slots = list(range(size))
|
||||||
self.req_to_token = torch.empty(
|
self.req_to_token = torch.empty(
|
||||||
(size, max_context_len), dtype=torch.int32, device=device
|
(size, max_context_len), dtype=torch.int32, device=device
|
||||||
@@ -54,7 +55,7 @@ class ReqToTokenPool:
|
|||||||
self.free_slots = list(range(self.size))
|
self.free_slots = list(range(self.size))
|
||||||
|
|
||||||
|
|
||||||
class BaseTokenToKVPool(ABC):
|
class BaseTokenToKVPool:
|
||||||
"""A memory pool that maps a token to its kv cache locations"""
|
"""A memory pool that maps a token to its kv cache locations"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -92,19 +93,15 @@ class BaseTokenToKVPool(ABC):
|
|||||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||||
self.free_slots = np.arange(1, self.size + 1)
|
self.free_slots = np.arange(1, self.size + 1)
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_key_buffer(self, layer_id: int) -> torch.Tensor:
|
def get_key_buffer(self, layer_id: int) -> torch.Tensor:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_value_buffer(self, layer_id: int) -> torch.Tensor:
|
def get_value_buffer(self, layer_id: int) -> torch.Tensor:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def set_kv_buffer(
|
def set_kv_buffer(
|
||||||
self,
|
self,
|
||||||
layer_id: int,
|
layer_id: int,
|
||||||
|
|||||||
@@ -411,8 +411,8 @@ class ModelRunner:
|
|||||||
|
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
self.req_to_token_pool = ReqToTokenPool(
|
self.req_to_token_pool = ReqToTokenPool(
|
||||||
max_num_reqs + 1,
|
size=max_num_reqs + 1,
|
||||||
self.model_config.context_len + 4,
|
max_context_len=self.model_config.context_len + 4,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
|
|||||||
@@ -14,16 +14,17 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class SamplingBatchInfo:
|
class SamplingBatchInfo:
|
||||||
# Basic Info
|
|
||||||
vocab_size: int
|
|
||||||
|
|
||||||
# Batched sampling params
|
# Batched sampling params
|
||||||
temperatures: torch.Tensor = None
|
temperatures: torch.Tensor
|
||||||
top_ps: torch.Tensor = None
|
top_ps: torch.Tensor
|
||||||
top_ks: torch.Tensor = None
|
top_ks: torch.Tensor
|
||||||
min_ps: torch.Tensor = None
|
min_ps: torch.Tensor
|
||||||
|
|
||||||
|
# Dispatch in CUDA graph
|
||||||
|
need_min_p_sampling: bool
|
||||||
|
|
||||||
# Bias Tensors
|
# Bias Tensors
|
||||||
|
vocab_size: int
|
||||||
logit_bias: torch.Tensor = None
|
logit_bias: torch.Tensor = None
|
||||||
vocab_mask: torch.Tensor = None
|
vocab_mask: torch.Tensor = None
|
||||||
|
|
||||||
@@ -31,9 +32,6 @@ class SamplingBatchInfo:
|
|||||||
regex_fsms: List[RegexGuide] = None
|
regex_fsms: List[RegexGuide] = None
|
||||||
regex_fsm_states: List[int] = None
|
regex_fsm_states: List[int] = None
|
||||||
|
|
||||||
# Dispatch in CUDA graph
|
|
||||||
need_min_p_sampling: bool = False
|
|
||||||
|
|
||||||
# Penalizer
|
# Penalizer
|
||||||
penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
|
penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
|
||||||
linear_penalties: torch.Tensor = None
|
linear_penalties: torch.Tensor = None
|
||||||
@@ -42,25 +40,30 @@ class SamplingBatchInfo:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
|
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
|
||||||
reqs = batch.reqs
|
reqs = batch.reqs
|
||||||
ret = cls(vocab_size=vocab_size)
|
|
||||||
|
|
||||||
with torch.device("cuda"):
|
with torch.device("cuda"):
|
||||||
ret.temperatures = torch.tensor(
|
temperatures = torch.tensor(
|
||||||
[r.sampling_params.temperature for r in reqs],
|
[r.sampling_params.temperature for r in reqs],
|
||||||
dtype=torch.float,
|
dtype=torch.float,
|
||||||
).view(-1, 1)
|
).view(-1, 1)
|
||||||
ret.top_ps = torch.tensor(
|
top_ps = torch.tensor(
|
||||||
[r.sampling_params.top_p for r in reqs], dtype=torch.float
|
[r.sampling_params.top_p for r in reqs], dtype=torch.float
|
||||||
)
|
)
|
||||||
ret.top_ks = torch.tensor(
|
top_ks = torch.tensor(
|
||||||
[r.sampling_params.top_k for r in reqs], dtype=torch.int
|
[r.sampling_params.top_k for r in reqs], dtype=torch.int
|
||||||
)
|
)
|
||||||
ret.min_ps = torch.tensor(
|
min_ps = torch.tensor(
|
||||||
[r.sampling_params.min_p for r in reqs], dtype=torch.float
|
[r.sampling_params.min_p for r in reqs], dtype=torch.float
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ret = cls(
|
||||||
|
temperatures=temperatures,
|
||||||
|
top_ps=top_ps,
|
||||||
|
top_ks=top_ks,
|
||||||
|
min_ps=min_ps,
|
||||||
|
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
|
||||||
|
vocab_size=vocab_size,
|
||||||
|
)
|
||||||
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
|
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
|
||||||
ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)
|
|
||||||
|
|
||||||
# Each penalizers will do nothing if they evaluate themselves as not required by looking at
|
# Each penalizers will do nothing if they evaluate themselves as not required by looking at
|
||||||
# the sampling_params of the requests (See {_is_required()} of each penalizers). So this
|
# the sampling_params of the requests (See {_is_required()} of each penalizers). So this
|
||||||
|
|||||||
@@ -118,6 +118,7 @@ async def health_generate(request: Request) -> Response:
|
|||||||
|
|
||||||
@app.get("/get_model_info")
|
@app.get("/get_model_info")
|
||||||
async def get_model_info():
|
async def get_model_info():
|
||||||
|
"""Get the model information."""
|
||||||
result = {
|
result = {
|
||||||
"model_path": tokenizer_manager.model_path,
|
"model_path": tokenizer_manager.model_path,
|
||||||
"is_generation": tokenizer_manager.is_generation,
|
"is_generation": tokenizer_manager.is_generation,
|
||||||
@@ -127,11 +128,13 @@ async def get_model_info():
|
|||||||
|
|
||||||
@app.get("/get_server_args")
|
@app.get("/get_server_args")
|
||||||
async def get_server_args():
|
async def get_server_args():
|
||||||
|
"""Get the server arguments."""
|
||||||
return dataclasses.asdict(tokenizer_manager.server_args)
|
return dataclasses.asdict(tokenizer_manager.server_args)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/flush_cache")
|
@app.get("/flush_cache")
|
||||||
async def flush_cache():
|
async def flush_cache():
|
||||||
|
"""Flush the radix cache."""
|
||||||
tokenizer_manager.flush_cache()
|
tokenizer_manager.flush_cache()
|
||||||
return Response(
|
return Response(
|
||||||
content="Cache flushed.\nPlease check backend logs for more details. "
|
content="Cache flushed.\nPlease check backend logs for more details. "
|
||||||
@@ -142,7 +145,7 @@ async def flush_cache():
|
|||||||
|
|
||||||
@app.post("/update_weights")
|
@app.post("/update_weights")
|
||||||
async def update_weights(obj: UpdateWeightReqInput, request: Request):
|
async def update_weights(obj: UpdateWeightReqInput, request: Request):
|
||||||
|
"""Update the weights inplace without re-launching the server."""
|
||||||
success, message = await tokenizer_manager.update_weights(obj, request)
|
success, message = await tokenizer_manager.update_weights(obj, request)
|
||||||
content = {"success": success, "message": message}
|
content = {"success": success, "message": message}
|
||||||
if success:
|
if success:
|
||||||
@@ -205,7 +208,7 @@ app.put("/encode")(encode_request)
|
|||||||
|
|
||||||
|
|
||||||
async def judge_request(obj: RewardReqInput, request: Request):
|
async def judge_request(obj: RewardReqInput, request: Request):
|
||||||
"""Handle an embedding request."""
|
"""Handle a reward model request."""
|
||||||
try:
|
try:
|
||||||
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
|
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
|
||||||
return ret
|
return ret
|
||||||
@@ -307,7 +310,7 @@ def launch_server(
|
|||||||
ports = server_args.additional_ports
|
ports = server_args.additional_ports
|
||||||
port_args = PortArgs(
|
port_args = PortArgs(
|
||||||
tokenizer_port=ports[0],
|
tokenizer_port=ports[0],
|
||||||
scheduler_port=ports[1],
|
scheduler_input_port=ports[1],
|
||||||
detokenizer_port=ports[2],
|
detokenizer_port=ports[2],
|
||||||
nccl_ports=ports[3:],
|
nccl_ports=ports[3:],
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -627,10 +627,11 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
|
|||||||
class PortArgs:
|
class PortArgs:
|
||||||
# The port for tokenizer to receive inputs from detokenizer (zmq)
|
# The port for tokenizer to receive inputs from detokenizer (zmq)
|
||||||
tokenizer_port: int
|
tokenizer_port: int
|
||||||
# The port for scheduler to receive inputs from tokenizer (zmq)
|
# The port for scheduler (rank 0) to receive inputs from tokenizer (zmq)
|
||||||
scheduler_port: int
|
scheduler_input_port: int
|
||||||
# The port for detokenizer to receive inputs from scheduler (zmq)
|
# The port for detokenizer to receive inputs from scheduler (zmq)
|
||||||
detokenizer_port: int
|
detokenizer_port: int
|
||||||
|
|
||||||
# The port for nccl initialization for multiple TP groups (torch.dist)
|
# The port for nccl initialization for multiple TP groups (torch.dist)
|
||||||
nccl_ports: List[int]
|
nccl_ports: List[int]
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}')
|
kill -9 $(ps aux | grep 'multiprocessing.spawn' | grep -v 'grep' | awk '{print $2}')
|
||||||
|
|||||||
Reference in New Issue
Block a user