Rename InputMetadata -> ForwardBatch (#1543)

This commit is contained in:
Lianmin Zheng
2024-09-30 02:41:11 -07:00
committed by GitHub
parent 3f0fe08d37
commit 36d5acfca5
44 changed files with 435 additions and 433 deletions

View File

@@ -24,7 +24,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
@@ -51,13 +51,13 @@ class LlamaForSequenceClassification(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> EmbeddingPoolerOutput:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
scores = self.score(hidden_states)
return self.pooler(scores, input_metadata)
return self.pooler(scores, forward_batch)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
@@ -102,19 +102,19 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassific
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
get_embedding: bool = True,
) -> EmbeddingPoolerOutput:
assert (
get_embedding
), "LlamaForSequenceClassification is only used for embedding"
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
logits = self.score(hidden_states)
weights = self.weights(hidden_states)
pooled_logits = self.pooler(logits, input_metadata).embeddings
pooled_weights = self.pooler(weights, input_metadata).embeddings
pooled_logits = self.pooler(logits, forward_batch).embeddings
pooled_weights = self.pooler(weights, forward_batch).embeddings
rews = pooled_logits.view(-1, self.num_labels // 2, 2)[:, :, 0].view(
-1, self.num_labels // 2