Add support for Qwen3-seq-cls (#9357)

This commit is contained in:
nathan
2025-08-21 01:51:56 +02:00
committed by GitHub
parent ef3004d90a
commit 8f5b9910c1

View File

@@ -42,7 +42,13 @@ class Qwen3ForSequenceClassification(nn.Module):
# Use normalize=True for Qwen3 embeddings, matching the official implementation
# Reference: https://github.com/QwenLM/Qwen3-Embedding/blob/main/examples/qwen3_embedding_transformers.py#L55
# Official code: output = F.normalize(output, p=2, dim=1)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
normalize = True
# Do not normalize the pooled output when the config declares a classification
# head (i.e. id2label or label2id is set)
if config.id2label is not None or config.label2id is not None:
normalize = False
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=normalize)
self.eos_token_id = config.eos_token_id