Sync from v0.13
This commit is contained in:
85
vllm/distributed/ec_transfer/ec_connector/factory.py
Normal file
85
vllm/distributed/ec_transfer/ec_connector/factory.py
Normal file
@@ -0,0 +1,85 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import importlib
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.distributed.ec_transfer.ec_connector.base import (
|
||||
ECConnectorBase,
|
||||
ECConnectorRole,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ECTransferConfig, VllmConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ECConnectorFactory:
|
||||
_registry: dict[str, Callable[[], type[ECConnectorBase]]] = {}
|
||||
|
||||
@classmethod
|
||||
def register_connector(cls, name: str, module_path: str, class_name: str) -> None:
|
||||
"""Register a connector with a lazy-loading module and class name."""
|
||||
if name in cls._registry:
|
||||
raise ValueError(f"Connector '{name}' is already registered.")
|
||||
|
||||
def loader() -> type[ECConnectorBase]:
|
||||
module = importlib.import_module(module_path)
|
||||
return getattr(module, class_name)
|
||||
|
||||
cls._registry[name] = loader
|
||||
|
||||
@classmethod
|
||||
def create_connector(
|
||||
cls,
|
||||
config: "VllmConfig",
|
||||
role: ECConnectorRole,
|
||||
) -> ECConnectorBase:
|
||||
ec_transfer_config = config.ec_transfer_config
|
||||
if ec_transfer_config is None:
|
||||
raise ValueError("ec_transfer_config must be set to create a connector")
|
||||
connector_cls = cls.get_connector_class(ec_transfer_config)
|
||||
logger.info(
|
||||
"Creating connector with name: %s and engine_id: %s",
|
||||
connector_cls.__name__,
|
||||
ec_transfer_config.engine_id,
|
||||
)
|
||||
# Connector is explicitly separated into two roles.
|
||||
# Scheduler connector:
|
||||
# - Co-locate with scheduler process
|
||||
# - Should only be used inside the Scheduler class
|
||||
# Worker connector:
|
||||
# - Co-locate with worker process
|
||||
return connector_cls(config, role)
|
||||
|
||||
@classmethod
|
||||
def get_connector_class(
|
||||
cls, ec_transfer_config: "ECTransferConfig"
|
||||
) -> type[ECConnectorBase]:
|
||||
"""Get the connector class by name."""
|
||||
connector_name = ec_transfer_config.ec_connector
|
||||
if connector_name is None:
|
||||
raise ValueError("EC connect must not be None")
|
||||
elif connector_name in cls._registry:
|
||||
connector_cls = cls._registry[connector_name]()
|
||||
else:
|
||||
connector_module_path = ec_transfer_config.ec_connector_module_path
|
||||
if connector_module_path is None:
|
||||
raise ValueError(f"Unsupported connector type: {connector_name}")
|
||||
connector_module = importlib.import_module(connector_module_path)
|
||||
connector_cls = getattr(connector_module, connector_name)
|
||||
return connector_cls
|
||||
|
||||
|
||||
# Register various connectors here.
|
||||
# The registration should not be done in each individual file, as we want to
|
||||
# only load the files corresponding to the current connector.
|
||||
|
||||
ECConnectorFactory.register_connector(
|
||||
"ECExampleConnector",
|
||||
"vllm.distributed.ec_transfer.ec_connector.example_connector",
|
||||
"ECExampleConnector",
|
||||
)
|
||||
Reference in New Issue
Block a user