Fix the deadlock in multi-node tensor parallelism (TP) (#1122)
This commit is contained in:
@@ -295,8 +295,9 @@ class Grok1ModelForCausalLM(nn.Module):
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = Grok1Model(config, quant_config=quant_config)
|
||||
- self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
- self.logits_processor = LogitsProcessor(config)
|
||||
+ # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
+ self.lm_head = ReplicatedLinear(config.hidden_size, config.vocab_size)
|
||||
+ self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
|
||||
|
||||
# Monkey patch _prepare_weights to load pre-sharded weights
|
||||
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
|
||||
|
||||
Reference in New Issue
Block a user