[feature] support for roberta embedding models (#5730)

This commit is contained in:
DavidBao
2025-04-27 09:47:06 +08:00
committed by GitHub
parent c5e1026f47
commit d8fbc7c096
3 changed files with 186 additions and 2 deletions

View File

@@ -12,6 +12,7 @@ from sglang.srt.model_executor.model_runner import ForwardBatch
class PoolingType(IntEnum):
    """Strategies the pooler can use to reduce per-token hidden states.

    The pooler compares its configured ``pooling_type`` against these
    members to decide which token's hidden state represents each sequence
    in the flattened batch.
    """

    # Take the hidden state of the last token of each sequence.
    LAST = 0
    # Take the hidden state of the first token of each sequence
    # (the position used for CLS-style pooling).
    CLS = 1
@dataclass
@@ -41,6 +42,11 @@ class Pooler(nn.Module):
if self.pooling_type == PoolingType.LAST:
last_token_indices = torch.cumsum(forward_batch.extend_seq_lens, dim=0) - 1
pooled_data = hidden_states[last_token_indices]
elif self.pooling_type == PoolingType.CLS:
prompt_lens = forward_batch.extend_seq_lens
first_token_flat_indices = torch.zeros_like(prompt_lens)
first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
pooled_data = hidden_states[first_token_flat_indices]
else:
raise ValueError(f"Invalid pooling type: {self.pooling_type}")