Cleanup readme, llava examples, usage examples and nccl init (#1194)

This commit is contained in:
Lianmin Zheng
2024-08-24 08:02:23 -07:00
committed by GitHub
parent c9064e6fd9
commit f6af3a6561
65 changed files with 174 additions and 317 deletions

View File

@@ -295,12 +295,14 @@ class Grok1ModelForCausalLM(nn.Module):
self.config = config
self.quant_config = quant_config
self.model = Grok1Model(config, quant_config=quant_config)
# self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.lm_head = ReplicatedLinear(config.hidden_size, config.vocab_size)
self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
# Monkey patch _prepare_weights to load pre-sharded weights
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
self.use_presharded_weights = True
warnings.filterwarnings("ignore", category=FutureWarning)
def forward(
@@ -356,6 +358,13 @@ class Grok1ModelForCausalLM(nn.Module):
continue
name = name.replace(weight_name, param_name)
if self.use_presharded_weights:
extra_kwargs = {
"use_presharded_weights": self.use_presharded_weights
}
else:
extra_kwargs = {}
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
@@ -364,7 +373,7 @@ class Grok1ModelForCausalLM(nn.Module):
weight_name,
shard_id=shard_id,
expert_id=expert_id,
pre_sharded=get_tensor_model_parallel_world_size() > 1,
**extra_kwargs,
)
break
else: