diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index b7f87a9a9..96eaf8566 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -321,9 +321,12 @@ class FusedMoE(torch.nn.Module): # Index the loaded weight for tp sharding. # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim shard_size = expert_data.shape[shard_dim] // 2 - loaded_weight = loaded_weight.narrow( - shard_dim, shard_size * tp_rank, shard_size - ) + + if not self.use_presharded_weights: + loaded_weight = loaded_weight.narrow( + shard_dim, shard_size * tp_rank, shard_size + ) + # Narrow parameter and load. # w1, gate_proj: Load into first logical weight of w13. if shard_id == "w1": @@ -347,9 +350,12 @@ class FusedMoE(torch.nn.Module): # down_proj: "RowParallel" so tp sharding on input_dim # Narrow parameter and load. shard_size = expert_data.shape[shard_dim] - loaded_weight = loaded_weight.narrow( - shard_dim, shard_size * tp_rank, shard_size - ) + + if not self.use_presharded_weights: + loaded_weight = loaded_weight.narrow( + shard_dim, shard_size * tp_rank, shard_size + ) + # w2, down_proj: Load into only logical weight of w2. expert_data.copy_(loaded_weight) @@ -389,7 +395,9 @@ class FusedMoE(torch.nn.Module): weight_name: str, shard_id: str, expert_id: int, + use_presharded_weights: bool = False, ) -> None: + self.use_presharded_weights = use_presharded_weights # compressed-tensors checkpoints with packed weights are stored flipped # TODO (mgoin): check self.quant_method.quant_config.quant_format diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index cb6a72a3f..e55a99465 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -16,13 +16,16 @@ # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1 """Inference-only Grok1 model.""" -from typing import Iterable, Optional, Tuple +from typing import Iterable, List, Optional, Tuple import torch import torch.nn.functional as F from torch import nn from transformers import PretrainedConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.layers.activation import GeluAndMul @@ -42,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.loader import DefaultModelLoader from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -347,6 +351,16 @@ class Grok1ForCausalLM(nn.Module): self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + # Monkey patch _prepare_weights to load pre-sharded weights + if ( + self.config.num_local_experts > 0 + and get_tensor_model_parallel_world_size() > 1 + ): + self.use_presharded_weights = True + setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights) + else: + self.use_presharded_weights = False + def forward( self, input_ids: torch.Tensor, @@ -359,7 +373,15 @@ class Grok1ForCausalLM(nn.Module): input_ids, hidden_states, self.lm_head, forward_batch ) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights( + self, + weights: Iterable[Tuple[str, torch.Tensor]], + use_presharded_weights: bool | None = None, + ): + if use_presharded_weights is None: + use_presharded_weights = self.use_presharded_weights + num_experts = self.config.num_local_experts + stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -375,10 +397,23 @@ class Grok1ForCausalLM(nn.Module): ckpt_gate_proj_name="w1", ckpt_down_proj_name="w2", ckpt_up_proj_name="w3", - num_experts=self.config.num_local_experts, + num_experts=num_experts, ) params_dict = dict(self.named_parameters()) + all_names = set(params_dict.keys()) + hit_names = set() + + def load_weight_wrapper(name, loaded_weight, *args, **kwargs): + if name not in params_dict: + return + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight, *args, **kwargs) + + hit_names.add(name) + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -391,9 +426,7 @@ class Grok1ForCausalLM(nn.Module): if name.endswith(".bias") and name not in params_dict: continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) + load_weight_wrapper(name, loaded_weight, shard_id) break else: for mapping in expert_params_mapping: @@ -402,38 +435,76 @@ class Grok1ForCausalLM(nn.Module): continue name = name.replace(weight_name, param_name) - if ( - name.endswith(".bias") or name.endswith("_bias") - ) and name not in params_dict: - continue + if use_presharded_weights: + extra_kwargs = { + "use_presharded_weights": use_presharded_weights + } + else: + extra_kwargs = {} - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader( - param, + load_weight_wrapper( + name, loaded_weight, name, shard_id=shard_id, expert_id=expert_id, + **extra_kwargs, ) break else: # Skip loading extra bias for GPTQ models. - if ( - name.endswith(".bias") or name.endswith("_bias") - ) and name not in params_dict: - continue - # Skip loading kv_scale from ckpts towards new design. - if name.endswith(".kv_scale") and name not in params_dict: + if name.endswith(".bias") and name not in params_dict: continue if name is None: continue - param = params_dict[name] - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) - weight_loader(param, loaded_weight) + load_weight_wrapper(name=name, loaded_weight=loaded_weight) + + +old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights") + + +def _prepare_presharded_weights( + self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool +) -> Tuple[str, List[str], bool]: + import glob + import os + + if get_tensor_model_parallel_world_size() == 1: + return old_prepare_weights(self, model_name_or_path, revision, fall_back_to_pt) + + if not os.path.isdir(model_name_or_path): + from sglang.srt.model_loader.weight_utils import download_weights_from_hf + + allow_patterns = ["*.safetensors", "*.bin"] + hf_folder = download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + allow_patterns, + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + else: + hf_folder = model_name_or_path + + tp_rank = get_tensor_model_parallel_rank() + + # The old format + allow_patterns = [f"*-{tp_rank:03d}.bin"] + + # The new format + allow_patterns += [f"*-TP-{tp_rank:03d}.safetensors", "*-TP-common.safetensors"] + + hf_weights_files: List[str] = [] + for pattern in allow_patterns: + hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) + + if hf_weights_files[0].endswith("safetensors"): + use_safetensors = True + else: + use_safetensors = False + + return hf_folder, hf_weights_files, use_safetensors class Grok1ModelForCausalLM(Grok1ForCausalLM):