Fix the deadlock in multi-node tp (#1122)

This commit is contained in:
Lianmin Zheng
2024-08-16 01:39:24 -07:00
committed by GitHub
parent 6aa8ad14f8
commit 5a261bd055
7 changed files with 54 additions and 16 deletions

View File

@@ -295,8 +295,9 @@ class Grok1ModelForCausalLM(nn.Module):
self.config = config
self.quant_config = quant_config
self.model = Grok1Model(config, quant_config=quant_config)
- self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
- self.logits_processor = LogitsProcessor(config)
+ # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+ self.lm_head = ReplicatedLinear(config.hidden_size, config.vocab_size)
+ self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
# Monkey patch _prepare_weights to load pre-sharded weights
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)