Cleanup readme, llava examples, usage examples and nccl init (#1194)
@@ -295,12 +295,14 @@ class Grok1ModelForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.model = Grok1Model(config, quant_config=quant_config)
-        # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.lm_head = ReplicatedLinear(config.hidden_size, config.vocab_size)
-        self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.logits_processor = LogitsProcessor(config)
+
         # Monkey patch _prepare_weights to load pre-sharded weights
         setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)

+        self.use_presharded_weights = True
+
         warnings.filterwarnings("ignore", category=FutureWarning)

     def forward(
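The first hunk swaps the ReplicatedLinear lm_head (paired with a skip_all_gather logits processor) back to a vocab-parallel ParallelLMHead, keeps the existing monkey patch on DefaultModelLoader._prepare_weights, and records use_presharded_weights on the instance. A minimal sketch of that patching technique, using a stand-in loader class and a hypothetical rank lookup rather than the real vLLM loader (whose _prepare_weights signature is assumed, not shown in this diff):

import glob
import os


class DefaultModelLoader:
    """Stand-in for the real loader; the signature below is an assumption."""

    def _prepare_weights(self, model_path, revision, fall_back_to_pt):
        # Default behavior: every rank sees the full set of shard files.
        return sorted(glob.glob(os.path.join(model_path, "*.safetensors")))


def _prepare_presharded_weights(self, model_path, revision, fall_back_to_pt):
    # Pre-sharded checkpoints store one file set per tensor-parallel rank;
    # each rank lists only the files carrying its own rank suffix.
    tp_rank = int(os.environ.get("TP_RANK", "0"))  # hypothetical rank lookup
    pattern = os.path.join(model_path, f"*-{tp_rank:03d}.safetensors")
    return sorted(glob.glob(pattern))


# Same mechanism as the diff: rebinding the method on the class means every
# loader instance created afterwards resolves weight files the pre-sharded way.
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)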
@@ -356,6 +358,13 @@ class Grok1ModelForCausalLM(nn.Module):
                     continue
                 name = name.replace(weight_name, param_name)
+
+                if self.use_presharded_weights:
+                    extra_kwargs = {
+                        "use_presharded_weights": self.use_presharded_weights
+                    }
+                else:
+                    extra_kwargs = {}

                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(
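The new if/else builds the keyword arguments as a dict, so use_presharded_weights is only forwarded when pre-sharded loading is active and loaders whose signatures lack that parameter keep working unchanged. A small self-contained sketch of the pattern; the loader function and its behavior here are illustrative, not the real weight_loader:

def expert_weight_loader(param, loaded_weight, weight_name, shard_id,
                         expert_id, use_presharded_weights=False):
    # Illustrative loader: with a pre-sharded checkpoint the tensor can be
    # copied as-is, otherwise it would be narrowed along the TP dimension.
    mode = "as-is" if use_presharded_weights else "narrowed"
    print(f"{weight_name} (expert {expert_id}, shard {shard_id}): copy {mode}")


use_presharded_weights = True
if use_presharded_weights:
    extra_kwargs = {"use_presharded_weights": use_presharded_weights}
else:
    extra_kwargs = {}

expert_weight_loader(None, None, "w1.weight", shard_id="w1", expert_id=3,
                     **extra_kwargs)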
@@ -364,7 +373,7 @@ class Grok1ModelForCausalLM(nn.Module):
                     weight_name,
                     shard_id=shard_id,
                     expert_id=expert_id,
-                    pre_sharded=get_tensor_model_parallel_world_size() > 1,
+                    **extra_kwargs,
                 )
                 break
             else:
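The final hunk replaces the call-site flag pre_sharded=get_tensor_model_parallel_world_size() > 1 with the **extra_kwargs spread. One plausible reading, sketched with stand-in values: the old expression inferred pre-sharding from the world size and so degraded to False at TP=1, whereas the instance attribute states it directly:

def get_tensor_model_parallel_world_size():
    # Stand-in for the distributed helper; pretend we run without TP.
    return 1


use_presharded_weights = True  # what the model object actually knows

old_flag = get_tensor_model_parallel_world_size() > 1  # False at TP=1
new_flag = use_presharded_weights                      # True regardless of TP
print(old_flag, new_flag)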