Fix qwen3 tbo/dp-lm-head (#6652)
@@ -501,6 +501,7 @@ class Qwen2MoeForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
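The one-line addition above wires the `enable_dp_lm_head` server arg into the LM head as `use_attn_tp_group`; the same change repeats for `Qwen3MoeForCausalLM` in the next hunk. The intent, roughly: under DP attention, sharding the vocab projection over the (smaller) attention TP group lets each DP rank compute its logits locally instead of all-gathering hidden states across DP ranks first. A minimal sketch of that group choice, with hypothetical names rather than sglang's actual `ParallelLMHead` internals:

```python
# Sketch only: hypothetical names, not sglang's ParallelLMHead internals.
from dataclasses import dataclass

@dataclass
class GroupInfo:
    name: str
    world_size: int

def pick_lm_head_group(
    use_attn_tp_group: bool, tp_group: GroupInfo, attn_tp_group: GroupInfo
) -> GroupInfo:
    # With DP attention, the attention TP group is a subgroup of the full TP
    # group; sharding the LM head over it keeps the logits computation local
    # to each DP rank instead of gathering hidden states across DP ranks.
    return attn_tp_group if use_attn_tp_group else tp_group

# Example: TP=8 arranged as 4 DP ranks x attention-TP=2.
print(pick_lm_head_group(True, GroupInfo("tp", 8), GroupInfo("attn_tp", 2)).name)
```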
@@ -688,6 +688,7 @@ class Qwen3MoeForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
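Both model classes read the same `global_server_args_dict["enable_dp_lm_head"]` entry, which presumably maps to the `--enable-dp-lm-head` server flag and is only meaningful together with DP attention; with the flag off, the keyword is False and the LM head keeps its previous full-TP sharding.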
@@ -370,7 +370,7 @@ def model_forward_maybe_tbo(
         hidden_states=hidden_states,
         forward_batch=forward_batch,
         residual=residual,
-        **(dict(zero_allocator=zero_allocator) if zero_allocator is not None else {}),
+        zero_allocator=zero_allocator,
     )
     layer_input_scatter_mode = layers[0].layer_scatter_modes.layer_input_mode
     operations_strategy = OperationsStrategy.init_new_tbo(
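The TBO hunk replaces the conditional kwarg splat with an unconditional `zero_allocator=zero_allocator`, so the inputs dict always carries the key even when the value is None (as for Qwen3, which does not use the allocator). One plausible reading of why that matters, sketched with hypothetical names (`split_inputs_for_tbo` and `EXPECTED_KEYS` are not sglang APIs): a key-driven micro-batch split tolerates a None value but not a missing key.

```python
# Sketch only: hypothetical stand-in for a key-driven TBO input split.
EXPECTED_KEYS = ("hidden_states", "forward_batch", "residual", "zero_allocator")

def split_inputs_for_tbo(inputs: dict) -> tuple[dict, dict]:
    # Indexing by a fixed key list: an explicit zero_allocator=None passes
    # through harmlessly, while a *missing* key raises KeyError.
    half = {k: inputs[k] for k in EXPECTED_KEYS}
    return dict(half), dict(half)

inputs = dict(
    hidden_states="h", forward_batch="fb", residual=None,
    zero_allocator=None,  # always present after the fix, even when unused
)
batch_a, batch_b = split_inputs_for_tbo(inputs)  # works; dropping the key would not
```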