Improve linear.py to load sharded weights & remove the dependency of Parameters from vllm (#2784)
Co-authored-by: SangBin Cho rkooo567@gmail.com
This commit is contained in:
@@ -57,6 +57,7 @@ class Grok1MLP(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
reduce_results=True,
|
||||
use_presharded_weights: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
@@ -65,6 +66,7 @@ class Grok1MLP(nn.Module):
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
use_presharded_weights=use_presharded_weights,
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
@@ -73,6 +75,7 @@ class Grok1MLP(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
reduce_results=reduce_results,
|
||||
use_presharded_weights=use_presharded_weights,
|
||||
)
|
||||
self.act_fn = GeluAndMul(approximate="tanh")
|
||||
|
||||
@@ -103,6 +106,7 @@ class Grok1MoE(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
reduce_results=True,
|
||||
use_presharded_weights: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@@ -129,6 +133,7 @@ class Grok1MoE(nn.Module):
|
||||
renormalize=False,
|
||||
quant_config=quant_config,
|
||||
tp_size=tp_size,
|
||||
use_presharded_weights=use_presharded_weights,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
@@ -156,6 +161,7 @@ class Grok1Attention(nn.Module):
|
||||
max_position: int = 4096 * 32,
|
||||
rope_theta: float = 10000,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
reduce_results: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -194,6 +200,7 @@ class Grok1Attention(nn.Module):
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
@@ -234,10 +241,12 @@ class Grok1DecoderLayer(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
layer_id: int = 0,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
use_presharded_weights: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_experts = config.num_local_experts
|
||||
self.hidden_size = config.hidden_size
|
||||
self.layer_id = layer_id
|
||||
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
self.self_attn = Grok1Attention(
|
||||
@@ -262,6 +271,7 @@ class Grok1DecoderLayer(nn.Module):
|
||||
),
|
||||
quant_config=quant_config,
|
||||
reduce_results=True,
|
||||
use_presharded_weights=use_presharded_weights,
|
||||
)
|
||||
self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@@ -299,6 +309,7 @@ class Grok1Model(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
use_presharded_weights: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -311,7 +322,12 @@ class Grok1Model(nn.Module):
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
Grok1DecoderLayer(config, i, quant_config=quant_config)
|
||||
Grok1DecoderLayer(
|
||||
config,
|
||||
i,
|
||||
quant_config=quant_config,
|
||||
use_presharded_weights=use_presharded_weights,
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
@@ -347,11 +363,7 @@ class Grok1ForCausalLM(nn.Module):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = Grok1Model(config, quant_config=quant_config)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
# Monkey patch _prepare_weights to load pre-sharded weights
|
||||
if (
|
||||
self.config.num_local_experts > 0
|
||||
and get_tensor_model_parallel_world_size() > 1
|
||||
@@ -361,6 +373,14 @@ class Grok1ForCausalLM(nn.Module):
|
||||
else:
|
||||
self.use_presharded_weights = False
|
||||
|
||||
self.model = Grok1Model(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
use_presharded_weights=self.use_presharded_weights,
|
||||
)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@@ -376,10 +396,7 @@ class Grok1ForCausalLM(nn.Module):
|
||||
def load_weights(
|
||||
self,
|
||||
weights: Iterable[Tuple[str, torch.Tensor]],
|
||||
use_presharded_weights: Optional[bool] = None,
|
||||
):
|
||||
if use_presharded_weights is None:
|
||||
use_presharded_weights = self.use_presharded_weights
|
||||
num_experts = self.config.num_local_experts
|
||||
|
||||
stacked_params_mapping = [
|
||||
@@ -435,20 +452,12 @@ class Grok1ForCausalLM(nn.Module):
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
|
||||
if use_presharded_weights:
|
||||
extra_kwargs = {
|
||||
"use_presharded_weights": use_presharded_weights
|
||||
}
|
||||
else:
|
||||
extra_kwargs = {}
|
||||
|
||||
load_weight_wrapper(
|
||||
name,
|
||||
loaded_weight,
|
||||
name,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
**extra_kwargs,
|
||||
)
|
||||
break
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user