[Bugfix] fix MTP support for lmhead_tensor_parallel_size (#3921)
### What this PR does / why we need it? Fix the issue of MTP being enabled and setting Imhead_tensor_parallel_size=16 causing the inference to hang. Signed-off-by: wyh145 <1987244901@qq.com>
This commit is contained in:
@@ -51,7 +51,7 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding):
|
||||
prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
|
||||
if lmhead_tp_enable() and prefix.find("lm_head") != -1:
|
||||
if lmhead_tp_enable() and prefix.find("head") != -1:
|
||||
self.comm_group = get_lmhead_tp_group()
|
||||
else:
|
||||
self.comm_group = get_tp_group()
|
||||
|
||||
Reference in New Issue
Block a user