fix: fix MLA for ShardedModelLoader/RemoteModelLoader (#6287)
Signed-off-by: wangyu <wangyu.steph@bytedance.com>
This commit is contained in:
@@ -14,8 +14,7 @@ python save_remote_state.py \
|
|||||||
Then, the model can be loaded with
|
Then, the model can be loaded with
|
||||||
|
|
||||||
llm = Engine(
|
llm = Engine(
|
||||||
model_path="/path/to/save",
|
model_path="[protocol]://[host]:[port]/[model_name]",
|
||||||
--remote-model-url [protocol]://[host]:[port]/[model_name],
|
|
||||||
tensor_parallel_size=8,
|
tensor_parallel_size=8,
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ class ConnectorType(str, enum.Enum):
|
|||||||
KV = "KV"
|
KV = "KV"
|
||||||
|
|
||||||
|
|
||||||
def create_remote_connector(url, device="cpu") -> BaseConnector:
|
def create_remote_connector(url, **kwargs) -> BaseConnector:
|
||||||
connector_type = parse_connector_type(url)
|
connector_type = parse_connector_type(url)
|
||||||
if connector_type == "redis":
|
if connector_type == "redis":
|
||||||
return RedisConnector(url)
|
return RedisConnector(url)
|
||||||
|
|||||||
@@ -20,9 +20,8 @@ class BaseConnector(ABC):
|
|||||||
<connector_type://<host>:<port>/<model_name>/files/<filename>
|
<connector_type://<host>:<port>/<model_name>/files/<filename>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, url: str, device: torch.device = "cpu"):
|
def __init__(self, url: str):
|
||||||
self.url = url
|
self.url = url
|
||||||
self.device = device
|
|
||||||
self.closed = False
|
self.closed = False
|
||||||
self.local_dir = tempfile.mkdtemp()
|
self.local_dir = tempfile.mkdtemp()
|
||||||
for sig in (signal.SIGINT, signal.SIGTERM):
|
for sig in (signal.SIGINT, signal.SIGTERM):
|
||||||
|
|||||||
@@ -15,10 +15,10 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class RedisConnector(BaseKVConnector):
|
class RedisConnector(BaseKVConnector):
|
||||||
|
|
||||||
def __init__(self, url: str, device: torch.device = "cpu"):
|
def __init__(self, url: str):
|
||||||
import redis
|
import redis
|
||||||
|
|
||||||
super().__init__(url, device)
|
super().__init__(url)
|
||||||
parsed_url = urlparse(url)
|
parsed_url = urlparse(url)
|
||||||
self.connection = redis.Redis(host=parsed_url.hostname, port=parsed_url.port)
|
self.connection = redis.Redis(host=parsed_url.hostname, port=parsed_url.port)
|
||||||
self.model_name = parsed_url.path.lstrip("/")
|
self.model_name = parsed_url.path.lstrip("/")
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ def create_serde(serde_type: str) -> Tuple[Serializer, Deserializer]:
|
|||||||
|
|
||||||
if serde_type == "safe":
|
if serde_type == "safe":
|
||||||
s = SafeSerializer()
|
s = SafeSerializer()
|
||||||
d = SafeDeserializer(torch.uint8)
|
d = SafeDeserializer()
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown serde type: {serde_type}")
|
raise ValueError(f"Unknown serde type: {serde_type}")
|
||||||
|
|
||||||
|
|||||||
@@ -19,11 +19,12 @@ class SafeSerializer(Serializer):
|
|||||||
|
|
||||||
class SafeDeserializer(Deserializer):
|
class SafeDeserializer(Deserializer):
|
||||||
|
|
||||||
def __init__(self, dtype):
|
def __init__(self):
|
||||||
super().__init__(dtype)
|
# TODO: dtype options
|
||||||
|
super().__init__(torch.float32)
|
||||||
|
|
||||||
def from_bytes_normal(self, b: Union[bytearray, bytes]) -> torch.Tensor:
|
def from_bytes_normal(self, b: Union[bytearray, bytes]) -> torch.Tensor:
|
||||||
return load(bytes(b))["tensor_bytes"].to(dtype=self.dtype)
|
return load(bytes(b))["tensor_bytes"]
|
||||||
|
|
||||||
def from_bytes(self, b: Union[bytearray, bytes]) -> torch.Tensor:
|
def from_bytes(self, b: Union[bytearray, bytes]) -> torch.Tensor:
|
||||||
return self.from_bytes_normal(b)
|
return self.from_bytes_normal(b)
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ from sglang.srt.distributed import (
|
|||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.model_loader.utils import (
|
from sglang.srt.model_loader.utils import (
|
||||||
get_model_architecture,
|
get_model_architecture,
|
||||||
|
post_load_weights,
|
||||||
set_default_torch_dtype,
|
set_default_torch_dtype,
|
||||||
)
|
)
|
||||||
from sglang.srt.model_loader.weight_utils import (
|
from sglang.srt.model_loader.weight_utils import (
|
||||||
@@ -600,18 +601,7 @@ class DummyModelLoader(BaseModelLoader):
|
|||||||
# random values to the weights.
|
# random values to the weights.
|
||||||
initialize_dummy_weights(model)
|
initialize_dummy_weights(model)
|
||||||
|
|
||||||
# Model weight loading consists of two stages:
|
post_load_weights(model, model_config)
|
||||||
# 1. Initial weight loading.
|
|
||||||
# 2. Post-processing of weights, including assigning specific member variables.
|
|
||||||
# For `dummy_init`, only the second stage is required.
|
|
||||||
if hasattr(model, "post_load_weights"):
|
|
||||||
if (
|
|
||||||
model_config.hf_config.architectures[0]
|
|
||||||
== "DeepseekV3ForCausalLMNextN"
|
|
||||||
):
|
|
||||||
model.post_load_weights(is_nextn=True)
|
|
||||||
else:
|
|
||||||
model.post_load_weights()
|
|
||||||
|
|
||||||
return model.eval()
|
return model.eval()
|
||||||
|
|
||||||
@@ -751,6 +741,9 @@ class ShardedStateLoader(BaseModelLoader):
|
|||||||
state_dict.pop(key)
|
state_dict.pop(key)
|
||||||
if state_dict:
|
if state_dict:
|
||||||
raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
|
raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
|
||||||
|
|
||||||
|
post_load_weights(model, model_config)
|
||||||
|
|
||||||
return model.eval()
|
return model.eval()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -1421,18 +1414,16 @@ class RemoteModelLoader(BaseModelLoader):
|
|||||||
# ignore hidden files
|
# ignore hidden files
|
||||||
if file_name.startswith("."):
|
if file_name.startswith("."):
|
||||||
continue
|
continue
|
||||||
if os.path.splitext(file_name)[1] not in (
|
if os.path.splitext(file_name)[1] in (".json", ".py"):
|
||||||
".bin",
|
|
||||||
".pt",
|
|
||||||
".safetensors",
|
|
||||||
):
|
|
||||||
file_path = os.path.join(root, file_name)
|
file_path = os.path.join(root, file_name)
|
||||||
with open(file_path, encoding="utf-8") as file:
|
with open(file_path, encoding="utf-8") as file:
|
||||||
file_content = file.read()
|
file_content = file.read()
|
||||||
f_key = f"{model_name}/files/{file_name}"
|
f_key = f"{model_name}/files/{file_name}"
|
||||||
client.setstr(f_key, file_content)
|
client.setstr(f_key, file_content)
|
||||||
|
|
||||||
def _load_model_from_remote_kv(self, model: nn.Module, client):
|
def _load_model_from_remote_kv(
|
||||||
|
self, model: nn.Module, model_config: ModelConfig, client
|
||||||
|
):
|
||||||
for _, module in model.named_modules():
|
for _, module in model.named_modules():
|
||||||
quant_method = getattr(module, "quant_method", None)
|
quant_method = getattr(module, "quant_method", None)
|
||||||
if quant_method is not None:
|
if quant_method is not None:
|
||||||
@@ -1460,6 +1451,8 @@ class RemoteModelLoader(BaseModelLoader):
|
|||||||
if state_dict:
|
if state_dict:
|
||||||
raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
|
raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
|
||||||
|
|
||||||
|
post_load_weights(model, model_config)
|
||||||
|
|
||||||
def _load_model_from_remote_fs(
|
def _load_model_from_remote_fs(
|
||||||
self, model, client, model_config: ModelConfig, device_config: DeviceConfig
|
self, model, client, model_config: ModelConfig, device_config: DeviceConfig
|
||||||
) -> nn.Module:
|
) -> nn.Module:
|
||||||
@@ -1501,15 +1494,13 @@ class RemoteModelLoader(BaseModelLoader):
|
|||||||
with set_default_torch_dtype(model_config.dtype):
|
with set_default_torch_dtype(model_config.dtype):
|
||||||
with torch.device(device_config.device):
|
with torch.device(device_config.device):
|
||||||
model = _initialize_model(model_config, self.load_config)
|
model = _initialize_model(model_config, self.load_config)
|
||||||
for _, module in model.named_modules():
|
|
||||||
quant_method = getattr(module, "quant_method", None)
|
|
||||||
if quant_method is not None:
|
|
||||||
quant_method.process_weights_after_loading(module)
|
|
||||||
|
|
||||||
with create_remote_connector(model_weights, device_config.device) as client:
|
with create_remote_connector(
|
||||||
|
model_weights, device=device_config.device
|
||||||
|
) as client:
|
||||||
connector_type = get_connector_type(client)
|
connector_type = get_connector_type(client)
|
||||||
if connector_type == ConnectorType.KV:
|
if connector_type == ConnectorType.KV:
|
||||||
self._load_model_from_remote_kv(model, client)
|
self._load_model_from_remote_kv(model, model_config, client)
|
||||||
elif connector_type == ConnectorType.FS:
|
elif connector_type == ConnectorType.FS:
|
||||||
self._load_model_from_remote_fs(
|
self._load_model_from_remote_fs(
|
||||||
model, client, model_config, device_config
|
model, client, model_config, device_config
|
||||||
|
|||||||
@@ -105,3 +105,15 @@ def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module],
|
|||||||
|
|
||||||
def get_architecture_class_name(model_config: ModelConfig) -> str:
|
def get_architecture_class_name(model_config: ModelConfig) -> str:
|
||||||
return get_model_architecture(model_config)[1]
|
return get_model_architecture(model_config)[1]
|
||||||
|
|
||||||
|
|
||||||
|
def post_load_weights(model: nn.Module, model_config: ModelConfig):
|
||||||
|
# Model weight loading consists of two stages:
|
||||||
|
# 1. Initial weight loading.
|
||||||
|
# 2. Post-processing of weights, including assigning specific member variables.
|
||||||
|
# For `dummy_init`, only the second stage is required.
|
||||||
|
if hasattr(model, "post_load_weights"):
|
||||||
|
if model_config.hf_config.architectures[0] == "DeepseekV3ForCausalLMNextN":
|
||||||
|
model.post_load_weights(is_nextn=True)
|
||||||
|
else:
|
||||||
|
model.post_load_weights()
|
||||||
|
|||||||
Reference in New Issue
Block a user