[1/2] Refactor multi-tokenizer manager (#10074)
This commit is contained in:
@@ -54,19 +54,14 @@ from fastapi import BackgroundTasks
|
||||
|
||||
from sglang.srt.aio_rwlock import RWLock
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.disaggregation.base import BaseKVBootstrapServer
|
||||
from sglang.srt.disaggregation.utils import (
|
||||
DisaggregationMode,
|
||||
KVClassType,
|
||||
TransferBackend,
|
||||
get_kv_class,
|
||||
)
|
||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||
from sglang.srt.hf_transformers_utils import (
|
||||
get_processor,
|
||||
get_tokenizer,
|
||||
get_tokenizer_from_processor,
|
||||
)
|
||||
from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
|
||||
from sglang.srt.managers.disagg_service import start_disagg_service
|
||||
from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
BatchEmbeddingOut,
|
||||
@@ -321,8 +316,10 @@ class TokenizerManager:
|
||||
# LoRA updates and inference to overlap.
|
||||
self.lora_update_lock = asyncio.Lock()
|
||||
|
||||
# For PD disaggregation
|
||||
self.init_disaggregation()
|
||||
self.disaggregation_mode = DisaggregationMode(
|
||||
self.server_args.disaggregation_mode
|
||||
)
|
||||
self.bootstrap_server = start_disagg_service(self.server_args)
|
||||
|
||||
# For load balancing
|
||||
self.current_load = 0
|
||||
@@ -471,38 +468,6 @@ class TokenizerManager:
|
||||
]
|
||||
)
|
||||
|
||||
def init_disaggregation(self):
    """Initialize PD-disaggregation state from ``self.server_args``.

    Sets ``self.disaggregation_mode`` and
    ``self.disaggregation_transfer_backend`` from the corresponding
    ``server_args`` fields. When the mode is ``DisaggregationMode.PREFILL``,
    a KV bootstrap server is constructed via ``get_kv_class`` and stored on
    ``self.bootstrap_server``. On node rank 0 with the ``"ascend"`` transfer
    backend, ``mf_adapter.create_config_store`` is additionally called with
    the value of the ``ASCEND_MF_STORE_URL`` environment variable.

    Raises:
        RuntimeError: If importing ``mf_adapter`` or creating the config
            store fails; the original exception is chained as the cause.
    """
    self.disaggregation_mode = DisaggregationMode(
        self.server_args.disaggregation_mode
    )
    self.disaggregation_transfer_backend = TransferBackend(
        self.server_args.disaggregation_transfer_backend
    )

    # Start kv bootstrap server on prefill
    if self.disaggregation_mode == DisaggregationMode.PREFILL:
        # only start bootstrap server on prefill tm
        kv_bootstrap_server_class: Type[BaseKVBootstrapServer] = get_kv_class(
            self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
        )
        self.bootstrap_server: BaseKVBootstrapServer = kv_bootstrap_server_class(
            host=self.server_args.host,
            port=self.server_args.disaggregation_bootstrap_port,
        )

        # NOTE(review): nesting of the store-creation step under the PREFILL
        # branch is reconstructed from a whitespace-stripped view -- confirm
        # against the original file.
        is_create_store = (
            self.server_args.node_rank == 0
            and self.server_args.disaggregation_transfer_backend == "ascend"
        )
        if is_create_store:
            try:
                from mf_adapter import create_config_store

                ascend_url = os.getenv("ASCEND_MF_STORE_URL")
                create_config_store(ascend_url)
            except Exception as e:
                error_message = "Failed create mf store, invalid ascend_url."
                error_message += f" With exception {e}"
                # Raising a bare str is itself a TypeError in Python 3, which
                # would mask the real failure; raise a proper exception and
                # chain the cause so the original traceback is preserved.
                raise RuntimeError(error_message) from e
|
||||
|
||||
async def generate_request(
|
||||
self,
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||
|
||||
Reference in New Issue
Block a user