diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md
index 584cebe6c..1d8bcaaf5 100644
--- a/docs/advanced_features/server_arguments.md
+++ b/docs/advanced_features/server_arguments.md
@@ -113,6 +113,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `--quantization` | The quantization method. | None |
 | `--quantization-param-path` | Path to the JSON file containing the KV cache scaling factors. This should generally be supplied, when KV cache dtype is FP8. Otherwise, KV cache scaling factors default to 1.0, which may cause accuracy issues. | None |
 | `--kv-cache-dtype` | Data type for kv cache storage. 'auto' will use model data type. 'fp8_e5m2' and 'fp8_e4m3' is supported for CUDA 11.8+. | auto |
+| `--enable-fp32-lm-head` | If set, the LM head outputs (logits) are in FP32. | False |
 
 ## Memory and scheduling
diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index e39727842..5f9651086 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -220,6 +220,7 @@ class LogitsProcessor(nn.Module):
         self.config = config
         self.logit_scale = logit_scale
         self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
+        self.use_fp32_lm_head = global_server_args_dict["enable_fp32_lm_head"]
         if self.use_attn_tp_group:
             self.attn_tp_size = get_attention_tp_size()
             self.do_tensor_parallel_all_gather = (
@@ -461,7 +462,11 @@ class LogitsProcessor(nn.Module):
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
 
         if hasattr(lm_head, "weight"):
-            if use_intel_amx_backend(lm_head):
+            if self.use_fp32_lm_head:
+                logits = torch.matmul(
+                    hidden_states.to(torch.float32), lm_head.weight.to(torch.float32).T
+                )
+            elif use_intel_amx_backend(lm_head):
                 logits = torch.ops.sgl_kernel.weight_packed_linear(
                     hidden_states.to(lm_head.weight.dtype),
                     lm_head.weight,
@@ -475,7 +480,15 @@ class LogitsProcessor(nn.Module):
         else:
             # GGUF models
             # TODO: use weight_packed_linear for GGUF models
-            logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)
+            if self.use_fp32_lm_head:
+                with torch.cuda.amp.autocast(enabled=False):
+                    logits = lm_head.quant_method.apply(
+                        lm_head, hidden_states.to(torch.float32), embedding_bias
+                    )
+            else:
+                logits = lm_head.quant_method.apply(
+                    lm_head, hidden_states, embedding_bias
+                )
 
         if self.logit_scale is not None:
             logits.mul_(self.logit_scale)
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index ecad78ba5..01859b6ef 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -90,6 +90,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "disable_flashinfer_cutlass_moe_fp4_allgather",
     "disable_radix_cache",
     "enable_dp_lm_head",
+    "enable_fp32_lm_head",
     "flashinfer_mxfp4_moe_precision",
     "enable_flashinfer_allreduce_fusion",
     "moe_dense_tp_size",
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index a32547120..701415390 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -167,6 +167,7 @@ class ServerArgs:
     quantization: Optional[str] = None
     quantization_param_path: Optional[str] = None
     kv_cache_dtype: str = "auto"
+    enable_fp32_lm_head: bool = False
 
     # Memory and scheduling
     mem_fraction_static: Optional[float] = None
@@ -1392,6 +1393,11 @@ class ServerArgs:
             choices=["auto", "fp8_e5m2", "fp8_e4m3"],
             help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
         )
+        parser.add_argument(
+            "--enable-fp32-lm-head",
+            action="store_true",
+            help="If set, the LM head outputs (logits) are in FP32.",
+        )
 
         # Memory and scheduling
         parser.add_argument(
diff --git a/test/srt/rl/test_fp32_lm_head.py b/test/srt/rl/test_fp32_lm_head.py
new file mode 100644
index 000000000..e892e3151
--- /dev/null
+++ b/test/srt/rl/test_fp32_lm_head.py
@@ -0,0 +1,106 @@
+import unittest
+from types import SimpleNamespace
+from unittest.mock import patch
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+
+
+class LMHeadStub(nn.Module):
+    def __init__(self, vocab, hidden, dtype, device="cuda"):
+        super().__init__()
+        self.weight = nn.Parameter(
+            torch.randn(vocab, hidden, dtype=dtype, device=device)
+        )
+
+
+class DummyMeta:
+    gathered_buffer = None
+    next_token_logits_buffer = None
+
+    def compute_dp_attention_metadata(self): ...
+
+
+class TestLMHeadFP32(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("needs CUDA GPU")
+
+    def _make_logprocessor(self, vocab_size, enable_fp32):
+        global_server_args_dict["enable_dp_lm_head"] = False
+        global_server_args_dict["enable_fp32_lm_head"] = enable_fp32
+        cfg = SimpleNamespace(vocab_size=vocab_size, final_logit_softcapping=None)
+        return LogitsProcessor(cfg, skip_all_gather=True, logit_scale=None)
+
+    def _run_case(
+        self,
+        hidden_state_dtype,
+        enable_fp32,
+        weights_dtype,
+        expected_a_dtype,
+        expected_b_dtype,
+    ):
+        device = "cuda"
+        BATCH_SIZE, HIDDEN_SIZE, VOCAB_SIZE = 2, 64, 128
+        hidden_state = torch.randn(
+            BATCH_SIZE, HIDDEN_SIZE, dtype=hidden_state_dtype, device=device
+        )
+        head = LMHeadStub(VOCAB_SIZE, HIDDEN_SIZE, dtype=weights_dtype, device=device)
+        meta = DummyMeta()
+        logprocessor = self._make_logprocessor(VOCAB_SIZE, enable_fp32)
+
+        original_matmul = torch.matmul
+        original_linear = F.linear
+
+        state = {
+            "called": False,  # Whether a matmul/linear call has been intercepted yet
+            "operation": None,  # Which operation was captured ("matmul" or "linear")
+            "a": None,  # The dtype of the first input tensor to the operation
+            "b": None,  # The dtype of the second input tensor to the operation
+        }
+
+        def probe_matmul(a, b, *args, **kw):
+            if not state["called"]:
+                state.update(called=True, operation="matmul", a=a.dtype, b=b.dtype)
+            return original_matmul(a, b, *args, **kw)
+
+        def probe_linear(x, w, bias=None):
+            if not state["called"]:
+                state.update(called=True, operation="linear", a=x.dtype, b=w.dtype)
+            return original_linear(x, w, bias)
+
+        with patch("torch.matmul", new=probe_matmul), patch(
+            "torch.nn.functional.linear", new=probe_linear
+        ):
+            logits = logprocessor._get_logits(hidden_state, head, meta)
+        self.assertEqual(hidden_state.dtype, hidden_state_dtype)
+        self.assertTrue(state["called"], "lm head matmul/linear was not called")
+        self.assertEqual(state["a"], expected_a_dtype)
+        self.assertEqual(state["b"], expected_b_dtype)
+
+    def test_flag_true_fp16_activations(self):
+        self._run_case(torch.float16, True, torch.float16, torch.float32, torch.float32)
+
+    def test_flag_true_bf16_activations(self):
+        self._run_case(
+            torch.bfloat16, True, torch.bfloat16, torch.float32, torch.float32
+        )
+
+    def test_flag_false_fp16_path(self):
+        self._run_case(
+            torch.float16, False, torch.float16, torch.float16, torch.float16
+        )
+
+    def test_flag_false_bf16_path(self):
+        self._run_case(
+            torch.bfloat16, False, torch.bfloat16, torch.bfloat16, torch.bfloat16
+        )
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 11837c172..a93b3f47c 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -59,6 +59,7 @@ suites = {
         TestFile("quant/test_int8_kernel.py", 8),
         TestFile("quant/test_triton_scaled_mm.py", 8),
         TestFile("quant/test_w8a8_quantization.py", 46),
+        TestFile("rl/test_fp32_lm_head.py", 30),
        TestFile("rl/test_update_weights_from_disk.py", 114),
         TestFile("rl/test_update_weights_from_tensor.py", 48),
         TestFile("test_abort.py", 51),
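
Reviewer note: below is a minimal standalone sketch of the numeric effect this flag targets. It is illustrative only and not part of the patch; the shapes, seed, dtype (bfloat16), and variable names are arbitrary assumptions. It mirrors what the new `use_fp32_lm_head` branch does in `LogitsProcessor._get_logits`: upcast the hidden states and LM head weight to FP32 before the matmul, instead of multiplying in the model's low-precision dtype.

```python
import torch

# Hypothetical repro of low-precision logit drift. With the flag on, logits
# are computed as matmul(hidden.float(), W.float().T); with it off, the
# matmul runs in the model dtype (bf16 here).
torch.manual_seed(0)
hidden = torch.randn(2, 1024, dtype=torch.bfloat16)   # [batch, hidden]
w = torch.randn(32000, 1024, dtype=torch.bfloat16)    # [vocab, hidden]

logits_bf16 = torch.matmul(hidden, w.T)                   # flag off
logits_fp32 = torch.matmul(hidden.float(), w.float().T)   # flag on

# The bf16 logits deviate from the FP32 reference; downstream logprobs
# (e.g. consumed by RL training loops) inherit that error.
print((logits_bf16.float() - logits_fp32).abs().max())
```

With the patch applied, the FP32 path would be enabled at launch, e.g. `python -m sglang.launch_server --model-path <model> --enable-fp32-lm-head` (model path elided).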