From 9f81d741a28667f05037d14c83491a740fb2251a Mon Sep 17 00:00:00 2001
From: wangyu
Date: Fri, 29 Aug 2025 07:10:09 +0800
Subject: [PATCH] fix: fix MLA for ShardedModelLoader/RemoteModelLoader (#6287)

Signed-off-by: wangyu
---
 examples/runtime/engine/save_remote_state.py  |  3 +-
 python/sglang/srt/connector/__init__.py       |  2 +-
 python/sglang/srt/connector/base_connector.py |  3 +-
 python/sglang/srt/connector/redis.py          |  4 +-
 python/sglang/srt/connector/serde/__init__.py |  2 +-
 .../sglang/srt/connector/serde/safe_serde.py  |  7 ++--
 python/sglang/srt/model_loader/loader.py      | 39 +++++++------------
 python/sglang/srt/model_loader/utils.py       | 12 ++++++
 8 files changed, 37 insertions(+), 35 deletions(-)

diff --git a/examples/runtime/engine/save_remote_state.py b/examples/runtime/engine/save_remote_state.py
index 89afa5949..a428195ca 100644
--- a/examples/runtime/engine/save_remote_state.py
+++ b/examples/runtime/engine/save_remote_state.py
@@ -14,8 +14,7 @@ python save_remote_state.py \
 Then, the model can be loaded with
 
 llm = Engine(
-    model_path="/path/to/save",
-    --remote-model-url [protocol]://[host]:[port]/[model_name],
+    model_path="[protocol]://[host]:[port]/[model_name]",
     tensor_parallel_size=8,
 )
 """
diff --git a/python/sglang/srt/connector/__init__.py b/python/sglang/srt/connector/__init__.py
index 829644c91..38e1d5eab 100644
--- a/python/sglang/srt/connector/__init__.py
+++ b/python/sglang/srt/connector/__init__.py
@@ -20,7 +20,7 @@ class ConnectorType(str, enum.Enum):
     KV = "KV"
 
 
-def create_remote_connector(url, device="cpu") -> BaseConnector:
+def create_remote_connector(url, **kwargs) -> BaseConnector:
     connector_type = parse_connector_type(url)
     if connector_type == "redis":
         return RedisConnector(url)
diff --git a/python/sglang/srt/connector/base_connector.py b/python/sglang/srt/connector/base_connector.py
index a9c00d0c9..c9a1c36e2 100644
--- a/python/sglang/srt/connector/base_connector.py
+++ b/python/sglang/srt/connector/base_connector.py
@@ -20,9 +20,8 @@ class BaseConnector(ABC):
     <connector_type>://<host>:<port>/<model_name>/files/<file_name>
     """
 
-    def __init__(self, url: str, device: torch.device = "cpu"):
+    def __init__(self, url: str):
         self.url = url
-        self.device = device
         self.closed = False
         self.local_dir = tempfile.mkdtemp()
         for sig in (signal.SIGINT, signal.SIGTERM):
diff --git a/python/sglang/srt/connector/redis.py b/python/sglang/srt/connector/redis.py
index 761594f78..cb1db3f7c 100644
--- a/python/sglang/srt/connector/redis.py
+++ b/python/sglang/srt/connector/redis.py
@@ -15,10 +15,10 @@ logger = logging.getLogger(__name__)
 
 
 class RedisConnector(BaseKVConnector):
-    def __init__(self, url: str, device: torch.device = "cpu"):
+    def __init__(self, url: str):
         import redis
 
-        super().__init__(url, device)
+        super().__init__(url)
         parsed_url = urlparse(url)
         self.connection = redis.Redis(host=parsed_url.hostname, port=parsed_url.port)
         self.model_name = parsed_url.path.lstrip("/")
diff --git a/python/sglang/srt/connector/serde/__init__.py b/python/sglang/srt/connector/serde/__init__.py
index 394dba0a6..c05b20afa 100644
--- a/python/sglang/srt/connector/serde/__init__.py
+++ b/python/sglang/srt/connector/serde/__init__.py
@@ -15,7 +15,7 @@ def create_serde(serde_type: str) -> Tuple[Serializer, Deserializer]:
 
     if serde_type == "safe":
         s = SafeSerializer()
-        d = SafeDeserializer(torch.uint8)
+        d = SafeDeserializer()
     else:
         raise ValueError(f"Unknown serde type: {serde_type}")
 
diff --git a/python/sglang/srt/connector/serde/safe_serde.py b/python/sglang/srt/connector/serde/safe_serde.py
index 0163af9f5..3e75f9bfc 100644
--- a/python/sglang/srt/connector/serde/safe_serde.py
+++ b/python/sglang/srt/connector/serde/safe_serde.py
@@ -19,11 +19,12 @@ class SafeSerializer(Serializer):
 
 
 class SafeDeserializer(Deserializer):
-    def __init__(self, dtype):
-        super().__init__(dtype)
+    def __init__(self):
+        # TODO: dtype options
+        super().__init__(torch.float32)
 
     def from_bytes_normal(self, b: Union[bytearray, bytes]) -> torch.Tensor:
-        return load(bytes(b))["tensor_bytes"].to(dtype=self.dtype)
+        return load(bytes(b))["tensor_bytes"]
 
     def from_bytes(self, b: Union[bytearray, bytes]) -> torch.Tensor:
         return self.from_bytes_normal(b)
diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py
index 23d70be44..1abfee2f4 100644
--- a/python/sglang/srt/model_loader/loader.py
+++ b/python/sglang/srt/model_loader/loader.py
@@ -42,6 +42,7 @@ from sglang.srt.distributed import (
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_loader.utils import (
     get_model_architecture,
+    post_load_weights,
     set_default_torch_dtype,
 )
 from sglang.srt.model_loader.weight_utils import (
@@ -600,18 +601,7 @@ class DummyModelLoader(BaseModelLoader):
                 # random values to the weights.
                 initialize_dummy_weights(model)
 
-                # Model weight loading consists of two stages:
-                # 1. Initial weight loading.
-                # 2. Post-processing of weights, including assigning specific member variables.
-                # For `dummy_init`, only the second stage is required.
-                if hasattr(model, "post_load_weights"):
-                    if (
-                        model_config.hf_config.architectures[0]
-                        == "DeepseekV3ForCausalLMNextN"
-                    ):
-                        model.post_load_weights(is_nextn=True)
-                    else:
-                        model.post_load_weights()
+                post_load_weights(model, model_config)
 
         return model.eval()
 
@@ -751,6 +741,9 @@ class ShardedStateLoader(BaseModelLoader):
                     state_dict.pop(key)
             if state_dict:
                 raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
+
+            post_load_weights(model, model_config)
+
             return model.eval()
 
     @staticmethod
@@ -1421,18 +1414,16 @@ class RemoteModelLoader(BaseModelLoader):
                 # ignore hidden files
                 if file_name.startswith("."):
                     continue
-                if os.path.splitext(file_name)[1] not in (
-                    ".bin",
-                    ".pt",
-                    ".safetensors",
-                ):
+                if os.path.splitext(file_name)[1] in (".json", ".py"):
                     file_path = os.path.join(root, file_name)
                     with open(file_path, encoding="utf-8") as file:
                         file_content = file.read()
                     f_key = f"{model_name}/files/{file_name}"
                     client.setstr(f_key, file_content)
 
-    def _load_model_from_remote_kv(self, model: nn.Module, client):
+    def _load_model_from_remote_kv(
+        self, model: nn.Module, model_config: ModelConfig, client
+    ):
         for _, module in model.named_modules():
             quant_method = getattr(module, "quant_method", None)
             if quant_method is not None:
@@ -1460,6 +1451,8 @@
         if state_dict:
             raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
 
+        post_load_weights(model, model_config)
+
     def _load_model_from_remote_fs(
         self, model, client, model_config: ModelConfig, device_config: DeviceConfig
     ) -> nn.Module:
@@ -1501,15 +1494,13 @@
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
                 model = _initialize_model(model_config, self.load_config)
-                for _, module in model.named_modules():
-                    quant_method = getattr(module, "quant_method", None)
-                    if quant_method is not None:
-                        quant_method.process_weights_after_loading(module)
-        with create_remote_connector(model_weights, device_config.device) as client:
+        with create_remote_connector(
+            model_weights, device=device_config.device
+        ) as client:
             connector_type = get_connector_type(client)
             if connector_type == ConnectorType.KV:
-                self._load_model_from_remote_kv(model, client)
+                self._load_model_from_remote_kv(model, model_config, client)
             elif connector_type == ConnectorType.FS:
                 self._load_model_from_remote_fs(
                     model, client, model_config, device_config
diff --git a/python/sglang/srt/model_loader/utils.py b/python/sglang/srt/model_loader/utils.py
index dfbbd154d..f6ad79010 100644
--- a/python/sglang/srt/model_loader/utils.py
+++ b/python/sglang/srt/model_loader/utils.py
@@ -105,3 +105,15 @@ def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module],
 
 def get_architecture_class_name(model_config: ModelConfig) -> str:
     return get_model_architecture(model_config)[1]
+
+
+def post_load_weights(model: nn.Module, model_config: ModelConfig):
+    # Model weight loading consists of two stages:
+    # 1. Initial weight loading.
+    # 2. Post-processing of weights, including assigning specific member variables.
+    # For `dummy_init`, only the second stage is required.
+    if hasattr(model, "post_load_weights"):
+        if model_config.hf_config.architectures[0] == "DeepseekV3ForCausalLMNextN":
+            model.post_load_weights(is_nextn=True)
+        else:
+            model.post_load_weights()
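
For context, loading a checkpoint saved with save_remote_state.py now goes
entirely through model_path, as the updated docstring shows. A usage sketch:
the host, port, and model name are placeholders, and passing
load_format="remote" to select RemoteModelLoader is an assumption, not part of
this patch.

from sglang import Engine

llm = Engine(
    model_path="redis://localhost:6379/my-model",
    load_format="remote",  # assumed flag for RemoteModelLoader
    tensor_parallel_size=8,
)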
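
On the connector change: create_remote_connector now takes **kwargs, so the
loader can keep passing device=... as a keyword while connectors such as
RedisConnector, which no longer accept or store a device, never see it. A
minimal sketch, assuming a reachable Redis instance at a placeholder address:

from sglang.srt.connector import create_remote_connector

# The extra keyword is swallowed by the factory; RedisConnector is
# constructed from the URL alone.
with create_remote_connector("redis://localhost:6379/my-model", device="cuda") as client:
    print(type(client).__name__)  # RedisConnector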
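
The core of the fix: MLA models such as DeepSeek derive extra attention
weights in post_load_weights() after the checkpoint tensors are in place, and
previously only some loaders invoked that hook. The sketch below is a
hypothetical stand-in (TinyMLAModel, kv_b_proj, w_kc, and w_vc are invented
names) showing why a loader that skips the hook leaves the model broken:

import torch
import torch.nn as nn


class TinyMLAModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Fused low-rank KV projection, loaded from the checkpoint.
        self.kv_b_proj = nn.Linear(16, 32, bias=False)
        self.w_kc = None  # derived in post_load_weights
        self.w_vc = None

    def load_weights(self, weights):
        # Stage 1 of loading: copy checkpoint tensors into parameters.
        state = dict(weights)
        self.kv_b_proj.weight.data.copy_(state["kv_b_proj.weight"])

    def post_load_weights(self):
        # Stage 2 of loading: split the fused projection into the K/V
        # halves that the attention kernel consumes at runtime.
        w = self.kv_b_proj.weight.detach()
        self.w_kc, self.w_vc = w.chunk(2, dim=0)


model = TinyMLAModel()
model.load_weights([("kv_b_proj.weight", torch.randn(32, 16))])
# Before this patch, ShardedStateLoader and RemoteModelLoader stopped here,
# leaving w_kc/w_vc unset; the shared post_load_weights(model, model_config)
# helper now runs the hook for every loader.
model.post_load_weights()
assert model.w_kc is not None and model.w_vc is not None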
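
On the serde change: safetensors records each tensor's dtype in the payload,
so the deserializer no longer casts to a fixed dtype on load; the
torch.float32 passed to the base class is just a default until dtype options
are added (per the TODO). A round-trip sketch using the same "tensor_bytes"
key that the deserializer reads:

import torch
from safetensors.torch import load, save

buf = save({"tensor_bytes": torch.randn(4, dtype=torch.bfloat16)})
out = load(buf)["tensor_bytes"]
assert out.dtype == torch.bfloat16  # stored dtype survives, no forced cast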