Files
sglang/test/srt/rl/test_fp32_lm_head.py
narutolhy d17986f8c6 Enable optional FP32 compute for LM Head (#10729)
Thanks to MiniMax Team and Chenyang Zhao's support.
2025-09-29 20:45:17 -07:00

107 lines
3.6 KiB
Python

import unittest
from types import SimpleNamespace
from unittest.mock import patch
import torch
import torch.nn as nn
import torch.nn.functional as F
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.managers.schedule_batch import global_server_args_dict
class LMHeadStub(nn.Module):
    """Minimal stand-in for an LM head that exposes only a ``weight`` parameter.

    The weight is a randomly initialised ``(vocab, hidden)`` tensor created
    with the requested ``dtype`` on the requested ``device``.
    """

    def __init__(self, vocab, hidden, dtype, device="cuda"):
        super().__init__()
        initial_weight = torch.randn(vocab, hidden, dtype=dtype, device=device)
        self.weight = nn.Parameter(initial_weight)
class DummyMeta:
    """Bare-bones metadata object passed to ``_get_logits`` by the tests below."""

    # Both buffers are deliberately left unset.
    gathered_buffer = None
    next_token_logits_buffer = None

    def compute_dp_attention_metadata(self):
        """Do nothing and return ``None``."""
        return None
class TestLMHeadFP32(unittest.TestCase):
    """Check the operand dtypes of the LM-head matmul/linear call.

    Each case runs ``LogitsProcessor._get_logits`` with ``enable_fp32_lm_head``
    switched on or off and inspects the dtypes of the two tensors handed to
    the first intercepted ``torch.matmul`` / ``F.linear`` call.
    """

    @classmethod
    def setUpClass(cls):
        # Every case allocates its tensors on "cuda", so skip the whole class
        # when no GPU is present.
        if not torch.cuda.is_available():
            raise unittest.SkipTest("needs CUDA GPU")

    def _make_logprocessor(self, vocab_size, enable_fp32):
        """Build a ``LogitsProcessor`` with the FP32 LM-head flag set.

        The server-arg overrides are applied through ``patch.dict`` and undone
        by a registered cleanup: they stay active for the remainder of the
        running test but no longer leak into tests that run afterwards.
        """
        overrides = patch.dict(
            global_server_args_dict,
            {"enable_dp_lm_head": False, "enable_fp32_lm_head": enable_fp32},
        )
        overrides.start()
        self.addCleanup(overrides.stop)
        cfg = SimpleNamespace(vocab_size=vocab_size, final_logit_softcapping=None)
        return LogitsProcessor(cfg, skip_all_gather=True, logit_scale=None)

    def _run_case(
        self,
        hidden_state_dtype,
        enable_fp32,
        weights_dtype,
        expected_a_dtype,
        expected_b_dtype,
    ):
        """Run ``_get_logits`` once and assert the captured operand dtypes.

        Args:
            hidden_state_dtype: dtype of the random hidden-state input.
            enable_fp32: value written to ``enable_fp32_lm_head``.
            weights_dtype: dtype of the stub LM-head weight.
            expected_a_dtype: expected dtype of the first operand of the
                intercepted matmul/linear call.
            expected_b_dtype: expected dtype of the second operand.
        """
        device = "cuda"
        BATCH_SIZE, HIDDEN_SIZE, VOCAB_SIZE = 2, 64, 128
        hidden_state = torch.randn(
            BATCH_SIZE, HIDDEN_SIZE, dtype=hidden_state_dtype, device=device
        )
        head = LMHeadStub(VOCAB_SIZE, HIDDEN_SIZE, dtype=weights_dtype, device=device)
        meta = DummyMeta()
        logprocessor = self._make_logprocessor(VOCAB_SIZE, enable_fp32)

        # Keep the real implementations so the probes can delegate to them.
        original_matmul = torch.matmul
        original_linear = F.linear
        state = {
            "called": False,  # Whether a matmul/linear call has been intercepted yet
            "operation": None,  # Which operation was captured ("matmul" or "linear")
            "a": None,  # The dtype of the first input tensor to the operation
            "b": None,  # The dtype of the second input tensor to the operation
        }

        def probe_matmul(a, b, *args, **kw):
            # Only the first intercepted call is recorded.
            if not state["called"]:
                state.update(called=True, operation="matmul", a=a.dtype, b=b.dtype)
            return original_matmul(a, b, *args, **kw)

        def probe_linear(x, w, bias=None):
            # Only the first intercepted call is recorded.
            if not state["called"]:
                state.update(called=True, operation="linear", a=x.dtype, b=w.dtype)
            return original_linear(x, w, bias)

        with patch("torch.matmul", new=probe_matmul), patch(
            "torch.nn.functional.linear", new=probe_linear
        ):
            logprocessor._get_logits(hidden_state, head, meta)

        # The caller's tensor must keep its original dtype.
        self.assertEqual(hidden_state.dtype, hidden_state_dtype)
        self.assertTrue(state["called"], "lm head matmul/linear was never called")
        captured_via = f"captured via {state['operation']}"
        self.assertEqual(state["a"], expected_a_dtype, captured_via)
        self.assertEqual(state["b"], expected_b_dtype, captured_via)

    def test_flag_true_fp16_activations(self):
        self._run_case(torch.float16, True, torch.float16, torch.float32, torch.float32)

    def test_flag_true_bf16_activations(self):
        self._run_case(
            torch.bfloat16, True, torch.bfloat16, torch.float32, torch.float32
        )

    def test_flag_false_fp16_path(self):
        self._run_case(
            torch.float16, False, torch.float16, torch.float16, torch.float16
        )

    def test_flag_false_bf16_path(self):
        self._run_case(
            torch.bfloat16, False, torch.bfloat16, torch.bfloat16, torch.bfloat16
        )
if __name__ == "__main__":
    # Run the suite with per-test verbose output when executed as a script.
    unittest.main(verbosity=2)