Add support for Qwen3-seq-cls (#9357)
This commit is contained in:
@@ -42,7 +42,13 @@ class Qwen3ForSequenceClassification(nn.Module):
|
|||||||
# Use normalize=True for qwen3 embedding based on official implementation
|
# Use normalize=True for qwen3 embedding based on official implementation
|
||||||
# Reference: https://github.com/QwenLM/Qwen3-Embedding/blob/main/examples/qwen3_embedding_transformers.py#L55
|
# Reference: https://github.com/QwenLM/Qwen3-Embedding/blob/main/examples/qwen3_embedding_transformers.py#L55
|
||||||
# Official code: output = F.normalize(output, p=2, dim=1)
|
# Official code: output = F.normalize(output, p=2, dim=1)
|
||||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
normalize = True
|
||||||
|
|
||||||
|
# We don't want to normalize the embedding if we have a classification head
|
||||||
|
if config.id2label is not None or config.label2id is not None:
|
||||||
|
normalize = False
|
||||||
|
|
||||||
|
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=normalize)
|
||||||
|
|
||||||
self.eos_token_id = config.eos_token_id
|
self.eos_token_id = config.eos_token_id
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user