Rename InputMetadata -> ForwardBatch (#1543)

This commit is contained in:
Lianmin Zheng
2024-09-30 02:41:11 -07:00
committed by GitHub
parent 3f0fe08d37
commit 36d5acfca5
44 changed files with 435 additions and 433 deletions

View File

@@ -63,7 +63,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import set_weight_attrs
@@ -220,14 +220,14 @@ class CohereAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.use_qk_norm:
q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
@@ -255,7 +255,7 @@ class CohereDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
@@ -264,7 +264,7 @@ class CohereDecoderLayer(nn.Module):
hidden_states_attention = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
hidden_states_mlp = self.mlp(hidden_states)
# Add everything together
@@ -299,7 +299,7 @@ class CohereModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
@@ -308,7 +308,7 @@ class CohereModel(nn.Module):
hidden_states, residual = layer(
positions,
hidden_states,
input_metadata,
forward_batch,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
@@ -333,15 +333,15 @@ class CohereForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.model(
input_ids,
positions,
input_metadata,
forward_batch,
)
return self.logits_processor(
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):