Fix qwen3 tbo/dp-lm-head (#6652)
This commit is contained in:
@@ -501,6 +501,7 @@ class Qwen2MoeForCausalLM(nn.Module):
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
|
||||
@@ -688,6 +688,7 @@ class Qwen3MoeForCausalLM(nn.Module):
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
|
||||
@@ -370,7 +370,7 @@ def model_forward_maybe_tbo(
|
||||
hidden_states=hidden_states,
|
||||
forward_batch=forward_batch,
|
||||
residual=residual,
|
||||
**(dict(zero_allocator=zero_allocator) if zero_allocator is not None else {}),
|
||||
zero_allocator=zero_allocator,
|
||||
)
|
||||
layer_input_scatter_mode = layers[0].layer_scatter_modes.layer_input_mode
|
||||
operations_strategy = OperationsStrategy.init_new_tbo(
|
||||
|
||||
Reference in New Issue
Block a user