diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 203fc0e82..67e72d465 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -501,6 +501,7 @@ class Qwen2MoeForCausalLM(nn.Module): config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], ) self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index 27e4ae62c..af7a47651 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -688,6 +688,7 @@ class Qwen3MoeForCausalLM(nn.Module): config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], ) self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py index bb527aaa6..78bc6b431 100644 --- a/python/sglang/srt/two_batch_overlap.py +++ b/python/sglang/srt/two_batch_overlap.py @@ -370,7 +370,7 @@ def model_forward_maybe_tbo( hidden_states=hidden_states, forward_batch=forward_batch, residual=residual, - **(dict(zero_allocator=zero_allocator) if zero_allocator is not None else {}), + zero_allocator=zero_allocator, ) layer_input_scatter_mode = layers[0].layer_scatter_modes.layer_input_mode operations_strategy = OperationsStrategy.init_new_tbo(