Rename InputMetadata -> ForwardBatch (#1543)
This commit is contained in:
@@ -7,7 +7,7 @@ from enum import IntEnum
 import torch
 import torch.nn as nn
 
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.model_runner import ForwardBatch
 
 
 class PoolingType(IntEnum):
@@ -36,10 +36,10 @@ class Pooler(nn.Module):
         self.normalize = normalize
 
     def forward(
-        self, hidden_states: torch.Tensor, input_metadata: InputMetadata
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
     ) -> EmbeddingPoolerOutput:
         if self.pooling_type == PoolingType.LAST:
-            last_token_indices = torch.cumsum(input_metadata.extend_seq_lens, dim=0) - 1
+            last_token_indices = torch.cumsum(forward_batch.extend_seq_lens, dim=0) - 1
             pooled_data = hidden_states[last_token_indices]
         else:
             raise ValueError(f"Invalid pooling type: {self.pooling_type}")
Reference in New Issue
Block a user