[feature] support for roberta embedding models (#5730)
This commit is contained in:
@@ -12,6 +12,7 @@ from sglang.srt.model_executor.model_runner import ForwardBatch
 
 class PoolingType(IntEnum):
     LAST = 0
+    CLS = 1
 
 
 @dataclass
@@ -41,6 +42,11 @@ class Pooler(nn.Module):
         if self.pooling_type == PoolingType.LAST:
             last_token_indices = torch.cumsum(forward_batch.extend_seq_lens, dim=0) - 1
             pooled_data = hidden_states[last_token_indices]
+        elif self.pooling_type == PoolingType.CLS:
+            prompt_lens = forward_batch.extend_seq_lens
+            first_token_flat_indices = torch.zeros_like(prompt_lens)
+            first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
+            pooled_data = hidden_states[first_token_flat_indices]
         else:
             raise ValueError(f"Invalid pooling type: {self.pooling_type}")
 
Reference in New Issue
Block a user