Fix qwen3 tbo/dp-lm-head (#6652)

Yi Zhang
2025-05-27 15:38:27 +08:00
committed by GitHub
parent ce9d690ef4
commit b18416fbf8
3 changed files with 3 additions and 1 deletion

@@ -501,6 +501,7 @@ class Qwen2MoeForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)

@@ -688,6 +688,7 @@ class Qwen3MoeForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
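
Both lm_head hunks thread the enable_dp_lm_head server argument into the head as use_attn_tp_group. A minimal, self-contained sketch of that flag-threading pattern follows; DummyLMHead and the literal args dict are illustrative stand-ins, not sglang's ParallelLMHead or its real server args:

# Illustrative only: forwarding a global server flag into a module flag.
global_server_args_dict = {"enable_dp_lm_head": True}

class DummyLMHead:
    def __init__(self, hidden_size, use_attn_tp_group=False):
        self.hidden_size = hidden_size
        # When set, the head would presumably lay itself out over the
        # attention TP group, which is what a DP lm_head relies on.
        self.use_attn_tp_group = use_attn_tp_group

lm_head = DummyLMHead(
    hidden_size=4096,
    use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)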

@@ -370,7 +370,7 @@ def model_forward_maybe_tbo(
hidden_states=hidden_states,
forward_batch=forward_batch,
residual=residual,
**(dict(zero_allocator=zero_allocator) if zero_allocator is not None else {}),
zero_allocator=zero_allocator,
)
layer_input_scatter_mode = layers[0].layer_scatter_modes.layer_input_mode
operations_strategy = OperationsStrategy.init_new_tbo(
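
The last hunk toggles between the plain zero_allocator=zero_allocator pass-through and the conditional-kwargs idiom **(dict(zero_allocator=zero_allocator) if zero_allocator is not None else {}), which forwards the keyword only when a value is present, so callees whose signature has no zero_allocator parameter can still be called. Below is a small runnable sketch of that idiom; run_layer and the two toy layer functions are hypothetical, only the **-spread pattern itself is taken from the diff:

# Sketch of the conditional-kwargs idiom from the last hunk (hypothetical names).
def run_layer(layer_fn, hidden_states, zero_allocator=None):
    # Build the extra kwargs only when zero_allocator is set, so a layer_fn
    # without a zero_allocator parameter still works.
    extra = dict(zero_allocator=zero_allocator) if zero_allocator is not None else {}
    return layer_fn(hidden_states=hidden_states, **extra)

def layer_without_allocator(hidden_states):
    return hidden_states

def layer_with_allocator(hidden_states, zero_allocator):
    return hidden_states + zero_allocator

print(run_layer(layer_without_allocator, 1.0))                    # 1.0
print(run_layer(layer_with_allocator, 1.0, zero_allocator=2.0))   # 3.0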