fix: fix MLA for ShardedModelLoader/RemoteModelLoader (#6287)

Signed-off-by: wangyu <wangyu.steph@bytedance.com>
Author: wangyu
Date: 2025-08-29 07:10:09 +08:00
Committed by: GitHub
parent a38c149758
commit 9f81d741a2
8 changed files with 37 additions and 35 deletions

View File

@@ -20,7 +20,7 @@ class ConnectorType(str, enum.Enum):
     KV = "KV"
 
 
-def create_remote_connector(url, device="cpu") -> BaseConnector:
+def create_remote_connector(url, **kwargs) -> BaseConnector:
     connector_type = parse_connector_type(url)
     if connector_type == "redis":
         return RedisConnector(url)
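Why this matters downstream: `device` stops being a baked-in positional default and becomes an ordinary keyword that only device-aware connectors consume. A minimal usage sketch (the URL and import path are assumed, not part of the diff):

    from sglang.srt.connector import create_remote_connector

    # KV connector: RedisConnector needs no device, so nothing extra is passed.
    with create_remote_connector("redis://localhost:6379/my-model") as client:
        ...

    # Device-aware callers now pass it explicitly, as RemoteModelLoader does below:
    # create_remote_connector(model_weights, device=device_config.device)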

View File

@@ -20,9 +20,8 @@ class BaseConnector(ABC):
         <connector_type>://<host>:<port>/<model_name>/files/<filename>
     """
 
-    def __init__(self, url: str, device: torch.device = "cpu"):
+    def __init__(self, url: str):
         self.url = url
-        self.device = device
         self.closed = False
         self.local_dir = tempfile.mkdtemp()
         for sig in (signal.SIGINT, signal.SIGTERM):
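The docstring's URL scheme maps directly onto urlparse fields, which is how the connectors pick themselves apart. A standalone sketch with a made-up URL:

    from urllib.parse import urlparse

    # <connector_type>://<host>:<port>/<model_name>/files/<filename>
    parsed = urlparse("redis://localhost:6379/llama-3-8b")
    parsed.scheme                  # "redis" -> selects the connector class
    parsed.hostname, parsed.port   # ("localhost", 6379)
    parsed.path.lstrip("/")        # "llama-3-8b" -> model_name, as RedisConnector does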

View File

@@ -15,10 +15,10 @@ logger = logging.getLogger(__name__)
 
 class RedisConnector(BaseKVConnector):
-    def __init__(self, url: str, device: torch.device = "cpu"):
+    def __init__(self, url: str):
         import redis
 
-        super().__init__(url, device)
+        super().__init__(url)
 
         parsed_url = urlparse(url)
         self.connection = redis.Redis(host=parsed_url.hostname, port=parsed_url.port)
         self.model_name = parsed_url.path.lstrip("/")

View File

@@ -15,7 +15,7 @@ def create_serde(serde_type: str) -> Tuple[Serializer, Deserializer]:
     if serde_type == "safe":
         s = SafeSerializer()
-        d = SafeDeserializer(torch.uint8)
+        d = SafeDeserializer()
     else:
         raise ValueError(f"Unknown serde type: {serde_type}")
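For context, a tiny sketch of how the factory's output pair is consumed (the to_bytes/from_bytes method names are assumed from the Serializer/Deserializer interfaces, which this diff does not show):

    s, d = create_serde("safe")     # now returns (SafeSerializer(), SafeDeserializer())
    blob = s.to_bytes(tensor)       # serialize on the writer side
    restored = d.from_bytes(blob)   # deserialize on the reader side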

View File

@@ -19,11 +19,12 @@ class SafeSerializer(Serializer):
 
 class SafeDeserializer(Deserializer):
-    def __init__(self, dtype):
-        super().__init__(dtype)
+    def __init__(self):
+        # TODO: dtype options
+        super().__init__(torch.float32)
 
     def from_bytes_normal(self, b: Union[bytearray, bytes]) -> torch.Tensor:
-        return load(bytes(b))["tensor_bytes"].to(dtype=self.dtype)
+        return load(bytes(b))["tensor_bytes"]
 
     def from_bytes(self, b: Union[bytearray, bytes]) -> torch.Tensor:
         return self.from_bytes_normal(b)
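The effect of dropping the `.to(dtype=self.dtype)` cast is easiest to see in a round trip. A sketch under one assumption: SafeSerializer packs the tensor with safetensors under the "tensor_bytes" key, mirroring the `load()` call above.

    import torch
    from safetensors.torch import load, save

    t = torch.randn(4, 4, dtype=torch.bfloat16)
    blob = save({"tensor_bytes": t})         # assumed serializer output
    out = load(bytes(blob))["tensor_bytes"]
    assert out.dtype == torch.bfloat16       # dtype now survives the round trip;
                                             # the old path force-cast every tensor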

View File

@@ -42,6 +42,7 @@ from sglang.srt.distributed import (
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_loader.utils import (
     get_model_architecture,
+    post_load_weights,
     set_default_torch_dtype,
 )
 from sglang.srt.model_loader.weight_utils import (
from sglang.srt.model_loader.weight_utils import (
@@ -600,18 +601,7 @@ class DummyModelLoader(BaseModelLoader):
             # random values to the weights.
             initialize_dummy_weights(model)
 
-            # Model weight loading consists of two stages:
-            # 1. Initial weight loading.
-            # 2. Post-processing of weights, including assigning specific member variables.
-            # For `dummy_init`, only the second stage is required.
-            if hasattr(model, "post_load_weights"):
-                if (
-                    model_config.hf_config.architectures[0]
-                    == "DeepseekV3ForCausalLMNextN"
-                ):
-                    model.post_load_weights(is_nextn=True)
-                else:
-                    model.post_load_weights()
+            post_load_weights(model, model_config)
 
         return model.eval()
@@ -751,6 +741,9 @@ class ShardedStateLoader(BaseModelLoader):
                     state_dict.pop(key)
             if state_dict:
                 raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
+
+        post_load_weights(model, model_config)
+
         return model.eval()
 
     @staticmethod
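This added call is where the fix lands for ShardedStateLoader: the sharded path previously returned `model.eval()` without any post-processing, so architectures that finish wiring up attention in `post_load_weights` (the DeepSeek MLA models, which presumably derive their absorbed projection weights there) came back half-initialized. The remote KV path below gets the same treatment.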
@@ -1421,18 +1414,16 @@ class RemoteModelLoader(BaseModelLoader):
                 # ignore hidden files
                 if file_name.startswith("."):
                     continue
-                if os.path.splitext(file_name)[1] not in (
-                    ".bin",
-                    ".pt",
-                    ".safetensors",
-                ):
+                if os.path.splitext(file_name)[1] in (".json", ".py"):
                     file_path = os.path.join(root, file_name)
                     with open(file_path, encoding="utf-8") as file:
                         file_content = file.read()
                         f_key = f"{model_name}/files/{file_name}"
                         client.setstr(f_key, file_content)
 
-    def _load_model_from_remote_kv(self, model: nn.Module, client):
+    def _load_model_from_remote_kv(
+        self, model: nn.Module, model_config: ModelConfig, client
+    ):
         for _, module in model.named_modules():
             quant_method = getattr(module, "quant_method", None)
             if quant_method is not None:
@@ -1460,6 +1451,8 @@ class RemoteModelLoader(BaseModelLoader):
         if state_dict:
             raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
 
+        post_load_weights(model, model_config)
+
     def _load_model_from_remote_fs(
         self, model, client, model_config: ModelConfig, device_config: DeviceConfig
     ) -> nn.Module:
@@ -1501,15 +1494,13 @@ class RemoteModelLoader(BaseModelLoader):
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
                 model = _initialize_model(model_config, self.load_config)
-                for _, module in model.named_modules():
-                    quant_method = getattr(module, "quant_method", None)
-                    if quant_method is not None:
-                        quant_method.process_weights_after_loading(module)
 
-            with create_remote_connector(model_weights, device_config.device) as client:
+            with create_remote_connector(
+                model_weights, device=device_config.device
+            ) as client:
                 connector_type = get_connector_type(client)
                 if connector_type == ConnectorType.KV:
-                    self._load_model_from_remote_kv(model, client)
+                    self._load_model_from_remote_kv(model, model_config, client)
                 elif connector_type == ConnectorType.FS:
                     self._load_model_from_remote_fs(
                         model, client, model_config, device_config

View File

@@ -105,3 +105,15 @@ def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module],
 
 def get_architecture_class_name(model_config: ModelConfig) -> str:
     return get_model_architecture(model_config)[1]
+
+
+def post_load_weights(model: nn.Module, model_config: ModelConfig):
+    # Model weight loading consists of two stages:
+    # 1. Initial weight loading.
+    # 2. Post-processing of weights, including assigning specific member variables.
+    # For `dummy_init`, only the second stage is required.
+    if hasattr(model, "post_load_weights"):
+        if model_config.hf_config.architectures[0] == "DeepseekV3ForCausalLMNextN":
+            model.post_load_weights(is_nextn=True)
+        else:
+            model.post_load_weights()
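A condensed sketch of the call pattern the three loaders now share (only `post_load_weights` is real here; the other names stand in for each loader's own logic):

    model = _initialize_model(model_config, load_config)  # build the module graph
    load_weights_somehow(model)             # dummy init / local shards / remote KV
    post_load_weights(model, model_config)  # MLA et al. finalize derived weights
    model.eval()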