Signed-off-by: ybyang <ybyang7@iflytek.com>
Signed-off-by: huanglong <huanglong@linux.alibaba.com>
Co-authored-by: Huang Long <121648372+LLLL114@users.noreply.github.com>
Co-authored-by: huanglong <huanglong@linux.alibaba.com>
Co-authored-by: Shangming Cai <csmthu@gmail.com>
@@ -94,6 +94,7 @@ from sglang.srt.managers.io_struct import (
     LoadLoRAAdapterReqInput,
     LoadLoRAAdapterReqOutput,
     LoRAUpdateResult,
+    MultiTokenizerWarpper,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -131,6 +132,7 @@ from sglang.srt.utils import (
     dataclass_to_string_truncated,
     freeze_gc,
     get_bool_env_var,
+    get_origin_rid,
     get_zmq_socket,
     kill_process_tree,
 )
@@ -266,9 +268,15 @@ class TokenizerManager:
         self.recv_from_detokenizer = get_zmq_socket(
             context, zmq.PULL, port_args.tokenizer_ipc_name, True
         )
-        self.send_to_scheduler = get_zmq_socket(
-            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
-        )
+        if self.server_args.tokenizer_worker_num > 1:
+            # Use tokenizer_worker_ipc_name in multi-tokenizer mode
+            self.send_to_scheduler = get_zmq_socket(
+                context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
+            )
+        else:
+            self.send_to_scheduler = get_zmq_socket(
+                context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
+            )

         # Request states
         self.no_create_loop = False
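The last argument to get_zmq_socket appears to toggle bind versus connect: the single-tokenizer path passes True, while each worker in multi-tokenizer mode passes False and connects to the shared tokenizer_worker_ipc_name endpoint. A minimal sketch of such a helper, assuming that interpretation (the name make_ipc_socket and its body are assumptions, not the sglang.srt.utils implementation):

import zmq

def make_ipc_socket(context: zmq.Context, socket_type: int, endpoint: str, bind: bool) -> zmq.Socket:
    # Hypothetical helper mirroring the get_zmq_socket(context, type, name, bind) call shape.
    sock = context.socket(socket_type)
    if bind:
        sock.bind(endpoint)      # single owner of the endpoint binds
    else:
        sock.connect(endpoint)   # multiple tokenizer workers connect to the shared endpoint
    return sock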
@@ -312,35 +320,7 @@ class TokenizerManager:
         self.lora_update_lock = asyncio.Lock()

-        # For PD disaggregtion
-        self.disaggregation_mode = DisaggregationMode(
-            self.server_args.disaggregation_mode
-        )
-        self.disaggregation_transfer_backend = TransferBackend(
-            self.server_args.disaggregation_transfer_backend
-        )
-        # Start kv boostrap server on prefill
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            # only start bootstrap server on prefill tm
-            kv_bootstrap_server_class = get_kv_class(
-                self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
-            )
-            self.bootstrap_server = kv_bootstrap_server_class(
-                self.server_args.disaggregation_bootstrap_port
-            )
-            is_create_store = (
-                self.server_args.node_rank == 0
-                and self.server_args.disaggregation_transfer_backend == "ascend"
-            )
-            if is_create_store:
-                try:
-                    from mf_adapter import create_config_store
-
-                    ascend_url = os.getenv("ASCEND_MF_STORE_URL")
-                    create_config_store(ascend_url)
-                except Exception as e:
-                    error_message = f"Failed create mf store, invalid ascend_url."
-                    error_message += f" With exception {e}"
-                    raise error_message
+        self.init_disaggregation()

         # For load balancing
         self.current_load = 0
@@ -488,6 +468,37 @@ class TokenizerManager:
             ]
         )

+    def init_disaggregation(self):
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.disaggregation_transfer_backend = TransferBackend(
+            self.server_args.disaggregation_transfer_backend
+        )
+        # Start kv boostrap server on prefill
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            # only start bootstrap server on prefill tm
+            kv_bootstrap_server_class = get_kv_class(
+                self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
+            )
+            self.bootstrap_server = kv_bootstrap_server_class(
+                self.server_args.disaggregation_bootstrap_port
+            )
+            is_create_store = (
+                self.server_args.node_rank == 0
+                and self.server_args.disaggregation_transfer_backend == "ascend"
+            )
+            if is_create_store:
+                try:
+                    from mf_adapter import create_config_store
+
+                    ascend_url = os.getenv("ASCEND_MF_STORE_URL")
+                    create_config_store(ascend_url)
+                except Exception as e:
+                    error_message = f"Failed create mf store, invalid ascend_url."
+                    error_message += f" With exception {e}"
+                    raise error_message
+
     async def generate_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -497,6 +508,15 @@ class TokenizerManager:
         self.auto_create_handle_loop()
         obj.normalize_batch_and_arguments()

+        if self.server_args.tokenizer_worker_num > 1:
+            # Modify rid, add worker_id
+            if isinstance(obj.rid, list):
+                # If it's an array, add worker_id prefix to each element
+                obj.rid = [f"{self.worker_id}_{rid}" for rid in obj.rid]
+            else:
+                # If it's a single value, add worker_id prefix
+                obj.rid = f"{self.worker_id}_{obj.rid}"
+
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
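The worker_id prefix added here is what get_origin_rid (imported above) strips when results come back. A small sketch of the round trip, assuming the prefix is separated by the first underscore (strip_worker_prefix is a hypothetical stand-in, not the sglang implementation of get_origin_rid):

def prefix_rid(worker_id: int, rid: str) -> str:
    # Tag a request id with the tokenizer worker that issued it.
    return f"{worker_id}_{rid}"

def strip_worker_prefix(rid: str) -> str:
    # Hypothetical counterpart to get_origin_rid: drop the worker id before the first underscore.
    return rid.split("_", 1)[1]

assert strip_worker_prefix(prefix_rid(1234, "req-abc")) == "req-abc"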
@@ -1096,6 +1116,8 @@ class TokenizerManager:
     async def _wait_for_model_update_from_disk(
         self, obj: UpdateWeightFromDiskReqInput
     ) -> Tuple[bool, str]:
+        if self.server_args.tokenizer_worker_num > 1:
+            obj = MultiTokenizerWarpper(self.worker_id, obj)
         self.send_to_scheduler.send_pyobj(obj)
         self.model_update_result = asyncio.Future()
         if self.server_args.dp_size == 1:
@@ -1315,6 +1337,8 @@ class TokenizerManager:
         elif obj.session_id in self.session_futures:
             return None

+        if self.server_args.tokenizer_worker_num > 1:
+            obj = MultiTokenizerWarpper(self.worker_id, obj)
         self.send_to_scheduler.send_pyobj(obj)

         self.session_futures[obj.session_id] = asyncio.Future()
@@ -1590,7 +1614,6 @@ class TokenizerManager:

     async def handle_loop(self):
         """The event loop that handles requests"""
-
         while True:
             recv_obj = await self.recv_from_detokenizer.recv_pyobj()
             self._result_dispatcher(recv_obj)
@@ -1610,9 +1633,12 @@ class TokenizerManager:
                 )
                 continue

+            origin_rid = rid
+            if self.server_args.tokenizer_worker_num > 1:
+                origin_rid = get_origin_rid(rid)
             # Build meta_info and return value
             meta_info = {
-                "id": rid,
+                "id": origin_rid,
                 "finish_reason": recv_obj.finished_reasons[i],
                 "prompt_tokens": recv_obj.prompt_tokens[i],
                 "weight_version": self.server_args.weight_version,
@@ -1918,6 +1944,9 @@ class TokenizerManager:
         if is_health_check_generate_req(recv_obj):
             return
         state = self.rid_to_state[recv_obj.rid]
+        origin_rid = recv_obj.rid
+        if self.server_args.tokenizer_worker_num > 1:
+            origin_rid = get_origin_rid(origin_rid)
         state.finished = True
         if recv_obj.finished_reason:
             out = {
@@ -1930,7 +1959,7 @@ class TokenizerManager:
             out = {
                 "text": "",
                 "meta_info": {
-                    "id": recv_obj.rid,
+                    "id": origin_rid,
                     "finish_reason": {
                         "type": "abort",
                         "message": "Abort before prefill",
@@ -2116,6 +2145,8 @@ T = TypeVar("T")
 class _Communicator(Generic[T]):
     """Note: The communicator now only run up to 1 in-flight request at any time."""

+    enable_multi_tokenizer = False
+
     def __init__(self, sender, fan_out: int):
         self._sender = sender
         self._fan_out = fan_out
@@ -2132,6 +2163,8 @@ class _Communicator(Generic[T]):
         assert self._result_values is None

         if obj:
+            if _Communicator.enable_multi_tokenizer:
+                obj = MultiTokenizerWarpper(worker_id=os.getpid(), obj=obj)
             self._sender.send_pyobj(obj)

         self._result_event = asyncio.Event()
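MultiTokenizerWarpper (spelled as imported above) is used both in the TokenizerManager paths and here in _Communicator to tag outgoing requests with the sending worker's pid. A plausible shape for such an envelope, assumed rather than taken from sglang.srt.managers.io_struct:

import os
from dataclasses import dataclass
from typing import Any

@dataclass
class RequestEnvelope:
    # Hypothetical stand-in for MultiTokenizerWarpper: pairs a payload with the
    # worker id so replies can be routed back to the right tokenizer process.
    worker_id: int
    obj: Any

wrapped = RequestEnvelope(worker_id=os.getpid(), obj={"rid": "req-abc"})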