Remove fused_moe_grok (#2223)
This commit is contained in:
@@ -16,22 +16,17 @@
|
||||
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
|
||||
"""Inference-only Grok1 model."""
|
||||
|
||||
import warnings
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
from typing import Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.loader import DefaultModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.fused_moe_grok import FusedMoE
|
||||
from sglang.srt.layers.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import (
|
||||
QKVParallelLinear,
|
||||
@@ -41,10 +36,12 @@ from sglang.srt.layers.linear import (
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.torchao_utils import apply_torchao_config_
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
|
||||
@@ -293,17 +290,11 @@ class Grok1ForCausalLM(nn.Module):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.torchao_config = global_server_args_dict["torchao_config"]
|
||||
self.model = Grok1Model(config, quant_config=quant_config)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
# Monkey patch _prepare_weights to load pre-sharded weights
|
||||
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
|
||||
|
||||
self.use_presharded_weights = True
|
||||
|
||||
warnings.filterwarnings("ignore", category=FutureWarning)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@@ -357,28 +348,23 @@ class Grok1ForCausalLM(nn.Module):
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
|
||||
if self.use_presharded_weights:
|
||||
extra_kwargs = {
|
||||
"use_presharded_weights": self.use_presharded_weights
|
||||
}
|
||||
else:
|
||||
extra_kwargs = {}
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(
|
||||
param,
|
||||
loaded_weight,
|
||||
weight_name,
|
||||
name,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
**extra_kwargs,
|
||||
)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Skip loading kv_scale from ckpts towards new design.
|
||||
if name.endswith(".kv_scale") and name not in params_dict:
|
||||
continue
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
@@ -388,30 +374,7 @@ class Grok1ForCausalLM(nn.Module):
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")
|
||||
|
||||
|
||||
def _prepare_presharded_weights(
|
||||
self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
|
||||
) -> Tuple[str, List[str], bool]:
|
||||
import glob
|
||||
import os
|
||||
|
||||
if get_tensor_model_parallel_world_size() == 1:
|
||||
return old_prepare_weights(self, model_name_or_path, revision, fall_back_to_pt)
|
||||
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
allow_patterns = [f"*-{tp_rank:03d}.bin"]
|
||||
|
||||
hf_folder = model_name_or_path
|
||||
|
||||
hf_weights_files: List[str] = []
|
||||
for pattern in allow_patterns:
|
||||
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
|
||||
use_safetensors = False
|
||||
|
||||
return hf_folder, hf_weights_files, use_safetensors
|
||||
apply_torchao_config_(self, params_dict, set(["proj.weight"]))
|
||||
|
||||
|
||||
class Grok1ModelForCausalLM(Grok1ForCausalLM):
|
||||
|
||||
Reference in New Issue
Block a user