Rename InputMetadata -> ForwardBatch (#1543)

This commit is contained in:
Lianmin Zheng
2024-09-30 02:41:11 -07:00
committed by GitHub
parent 3f0fe08d37
commit 36d5acfca5
44 changed files with 435 additions and 433 deletions

View File

@@ -7,7 +7,7 @@ from enum import IntEnum
import torch
import torch.nn as nn
from sglang.srt.model_executor.model_runner import InputMetadata
from sglang.srt.model_executor.model_runner import ForwardBatch
class PoolingType(IntEnum):
@@ -36,10 +36,10 @@ class Pooler(nn.Module):
self.normalize = normalize
def forward(
self, hidden_states: torch.Tensor, input_metadata: InputMetadata
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
) -> EmbeddingPoolerOutput:
if self.pooling_type == PoolingType.LAST:
last_token_indices = torch.cumsum(input_metadata.extend_seq_lens, dim=0) - 1
last_token_indices = torch.cumsum(forward_batch.extend_seq_lens, dim=0) - 1
pooled_data = hidden_states[last_token_indices]
else:
raise ValueError(f"Invalid pooling type: {self.pooling_type}")