Add support for Qwen3-seq-cls (#9357)
This commit is contained in:
@@ -42,7 +42,13 @@ class Qwen3ForSequenceClassification(nn.Module):
|
||||
# Use normalize=True for qwen3 embedding based on official implementation
|
||||
# Reference: https://github.com/QwenLM/Qwen3-Embedding/blob/main/examples/qwen3_embedding_transformers.py#L55
|
||||
# Official code: output = F.normalize(output, p=2, dim=1)
|
||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
normalize = True
|
||||
|
||||
# We don't want to normalize the embedding if we have a classification head
|
||||
if config.id2label is not None or config.label2id is not None:
|
||||
normalize = False
|
||||
|
||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=normalize)
|
||||
|
||||
self.eos_token_id = config.eos_token_id
|
||||
|
||||
|
||||
Reference in New Issue
Block a user