diff --git a/python/sglang/srt/models/qwen3_classification.py b/python/sglang/srt/models/qwen3_classification.py index 54802b558..a59d6769b 100644 --- a/python/sglang/srt/models/qwen3_classification.py +++ b/python/sglang/srt/models/qwen3_classification.py @@ -42,7 +42,13 @@ class Qwen3ForSequenceClassification(nn.Module): # Use normalize=True for qwen3 embedding based on official implementation # Reference: https://github.com/QwenLM/Qwen3-Embedding/blob/main/examples/qwen3_embedding_transformers.py#L55 # Official code: output = F.normalize(output, p=2, dim=1) - self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + normalize = True + + # We don't want to normalize the embedding if we have a classification head + if config.id2label is not None or config.label2id is not None: + normalize = False + + self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=normalize) self.eos_token_id = config.eos_token_id