Support loading of larger models with on-the-fly quantization (#3061)
This commit is contained in:
@@ -460,7 +460,12 @@ class TorchNativeLlamaForCausalLM(nn.Module):
|
||||
params_dict = dict(self.named_parameters())
|
||||
return len(params_dict)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights_to_module(
|
||||
self,
|
||||
fqn: str,
|
||||
weights: Iterable[Tuple[str, torch.Tensor]],
|
||||
):
|
||||
"""Load weights onto submodule pointed by path `fqn`."""
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
@@ -469,7 +474,8 @@ class TorchNativeLlamaForCausalLM(nn.Module):
|
||||
(".gate_up_proj", ".gate_proj", 0),
|
||||
(".gate_up_proj", ".up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
module = self.get_submodule(fqn)
|
||||
params_dict = dict(module.named_parameters(prefix=fqn, recurse=False))
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name or "projector" in name:
|
||||
@@ -486,7 +492,7 @@ class TorchNativeLlamaForCausalLM(nn.Module):
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
if name.endswith(".bias") or name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
@@ -494,12 +500,19 @@ class TorchNativeLlamaForCausalLM(nn.Module):
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
if name.endswith(".bias") or name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
weights: Iterable[Tuple[str, torch.Tensor]],
|
||||
):
|
||||
"""Load weights onto the full model."""
|
||||
self.load_weights_to_module("", weights)
|
||||
|
||||
|
||||
class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user