Improve linear.py to load sharded weights & remove the dependency of Parameters from vllm (#2784)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Author:    Lianmin Zheng
Date:      2025-01-07 23:29:10 -08:00
Committer: GitHub
Parent:    694e41925e
Commit:    8a6906127a
15 changed files with 655 additions and 88 deletions


@@ -57,6 +57,7 @@ class Grok1MLP(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         reduce_results=True,
+        use_presharded_weights: bool = False,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -65,6 +66,7 @@
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.gate_up_proj",
+            use_presharded_weights=use_presharded_weights,
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
@@ -73,6 +75,7 @@
             quant_config=quant_config,
             prefix=f"{prefix}.down_proj",
             reduce_results=reduce_results,
+            use_presharded_weights=use_presharded_weights,
         )
         self.act_fn = GeluAndMul(approximate="tanh")
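The point of threading use_presharded_weights through the parallel linear layers is that the weight loader can skip the tensor-parallel slicing step when each rank's checkpoint file already holds only that rank's shard. A minimal sketch of the idea, assuming a narrow()-based loader like the parallel linear layers use; load_column_parallel_weight is a hypothetical name, not the sglang API:

import torch

def load_column_parallel_weight(
    param: torch.nn.Parameter,
    loaded_weight: torch.Tensor,
    tp_rank: int,
    use_presharded_weights: bool = False,
) -> None:
    if not use_presharded_weights:
        # Full checkpoint: slice this rank's rows out of the global tensor.
        shard_size = param.shape[0]
        loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size)
    # Pre-sharded checkpoint: loaded_weight is already this rank's shard,
    # so no narrow() is needed (nor possible, since the tensor is smaller).
    assert param.shape == loaded_weight.shape
    param.data.copy_(loaded_weight)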
@@ -103,6 +106,7 @@ class Grok1MoE(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
         reduce_results=True,
+        use_presharded_weights: bool = False,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -129,6 +133,7 @@
             renormalize=False,
             quant_config=quant_config,
             tp_size=tp_size,
+            use_presharded_weights=use_presharded_weights,
         )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
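The MoE layer needs the flag as well: under tensor parallelism each expert's weight matrices are split across ranks, and a pre-sharded checkpoint stores those splits directly, so each per-rank file holds 1/tp_size of every expert weight. A toy shape check (sizes illustrative only, not Grok-1's real dimensions):

hidden_size, intermediate_size, tp_size = 1024, 4096, 8
full_w1_shape = (intermediate_size, hidden_size)                    # unsharded expert weight
presharded_w1_shape = (intermediate_size // tp_size, hidden_size)   # what a per-rank file holds
print(full_w1_shape, presharded_w1_shape)  # (4096, 1024) (512, 1024)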
@@ -156,6 +161,7 @@ class Grok1Attention(nn.Module):
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
         quant_config: Optional[QuantizationConfig] = None,
+        reduce_results: bool = True,
     ) -> None:
         super().__init__()
         self.config = config
@@ -194,6 +200,7 @@
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            reduce_results=reduce_results,
         )
         self.rotary_emb = get_rope(
             self.head_dim,
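The new reduce_results flag on Grok1Attention is forwarded to the attention's output RowParallelLinear. In row parallelism each rank produces a partial sum of the output, and the usual epilogue is an all-reduce; exposing the flag lets a caller skip that collective when a later op will combine the partials itself. A minimal sketch of the pattern, assuming an initialized torch.distributed process group (not the sglang implementation):

import torch
import torch.distributed as dist

def row_parallel_matmul(
    x: torch.Tensor,             # this rank's slice of the input features
    weight_shard: torch.Tensor,  # this rank's slice of the weight
    reduce_results: bool = True,
) -> torch.Tensor:
    partial = x @ weight_shard.t()  # each rank holds a partial sum
    if reduce_results:
        dist.all_reduce(partial)    # sum the partials across TP ranks
    return partial                  # if False, the caller reduces later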
@@ -234,10 +241,12 @@ class Grok1DecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        use_presharded_weights: bool = False,
     ) -> None:
         super().__init__()
         self.num_experts = config.num_local_experts
         self.hidden_size = config.hidden_size
+        self.layer_id = layer_id
         rope_theta = getattr(config, "rope_theta", 10000)
         self.self_attn = Grok1Attention(
@@ -262,6 +271,7 @@
             ),
             quant_config=quant_config,
             reduce_results=True,
+            use_presharded_weights=use_presharded_weights,
         )
         self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -299,6 +309,7 @@ class Grok1Model(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        use_presharded_weights: bool = False,
     ) -> None:
         super().__init__()
         self.config = config
@@ -311,7 +322,12 @@
         )
         self.layers = nn.ModuleList(
             [
-                Grok1DecoderLayer(config, i, quant_config=quant_config)
+                Grok1DecoderLayer(
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    use_presharded_weights=use_presharded_weights,
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -347,11 +363,7 @@ class Grok1ForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = Grok1Model(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.logits_processor = LogitsProcessor(config)
-
         # Monkey patch _prepare_weights to load pre-sharded weights
         if (
             self.config.num_local_experts > 0
             and get_tensor_model_parallel_world_size() > 1
@@ -361,6 +373,14 @@
         else:
             self.use_presharded_weights = False

+        self.model = Grok1Model(
+            config,
+            quant_config=quant_config,
+            use_presharded_weights=self.use_presharded_weights,
+        )
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.logits_processor = LogitsProcessor(config)
+
     def forward(
         self,
         input_ids: torch.Tensor,
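The hunk above moves submodule construction below the presharding check, so the flag is already decided when Grok1Model is built. The "# Monkey patch _prepare_weights" comment refers to redirecting the model loader's weight-file discovery so each TP rank reads its own pre-sharded files. A hypothetical sketch of that discovery step; the file-naming scheme below is assumed, not taken from this diff:

import glob
import os

def prepare_presharded_weights(model_dir: str, tp_rank: int) -> list:
    # Assumed naming scheme: one set of safetensors files per TP rank.
    pattern = os.path.join(model_dir, f"*-rank{tp_rank}.safetensors")
    files = sorted(glob.glob(pattern))
    if not files:
        raise FileNotFoundError(f"no pre-sharded weights for rank {tp_rank} in {model_dir}")
    return files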
@@ -376,10 +396,7 @@
     def load_weights(
         self,
         weights: Iterable[Tuple[str, torch.Tensor]],
-        use_presharded_weights: Optional[bool] = None,
     ):
-        if use_presharded_weights is None:
-            use_presharded_weights = self.use_presharded_weights
         num_experts = self.config.num_local_experts

         stacked_params_mapping = [
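For context, stacked_params_mapping is the usual pattern for fused parameters: several checkpoint tensors (e.g. separate gate and up projections) land in one merged parameter, so each entry records the fused parameter name, the checkpoint name fragment to match, and a shard id telling the weight loader which slice to fill. Entries below are illustrative, not copied from the diff:

# (fused param name, checkpoint weight name, shard id)
stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

# Typical use: rewrite the checkpoint name to the fused name, then let the
# fused parameter's weight_loader place the tensor into the right slice.
name = "model.layers.0.self_attn.q_proj.weight"
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name in name:
        name = name.replace(weight_name, param_name)
        break
print(name)  # model.layers.0.self_attn.qkv_proj.weight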
@@ -435,20 +452,12 @@
                     continue
                 name = name.replace(weight_name, param_name)
-                if use_presharded_weights:
-                    extra_kwargs = {
-                        "use_presharded_weights": use_presharded_weights
-                    }
-                else:
-                    extra_kwargs = {}
                 load_weight_wrapper(
                     name,
                     loaded_weight,
                     name,
                     shard_id=shard_id,
                     expert_id=expert_id,
                 )
                 break
             else:
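The net effect of this last hunk: because the MoE and linear layers now receive use_presharded_weights at construction, load_weights no longer threads the flag through extra kwargs at call time. A minimal sketch of that design choice, capturing the flag in the layer instead of the loader call (illustrative class, not the sglang one):

import torch
from torch import nn

class ShardAwareLinear(nn.Module):
    def __init__(self, out_features: int, in_features: int,
                 tp_rank: int, tp_size: int,
                 use_presharded_weights: bool = False):
        super().__init__()
        assert out_features % tp_size == 0
        self.tp_rank = tp_rank
        # The flag is fixed once, at construction time.
        self.use_presharded_weights = use_presharded_weights
        self.weight = nn.Parameter(torch.empty(out_features // tp_size, in_features))

    def weight_loader(self, loaded_weight: torch.Tensor) -> None:
        # No extra kwargs needed: the layer already knows about presharding.
        if not self.use_presharded_weights:
            shard = self.weight.shape[0]
            loaded_weight = loaded_weight.narrow(0, self.tp_rank * shard, shard)
        self.weight.data.copy_(loaded_weight)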