Rename InputMetadata -> ForwardBatch (#1543)
This commit is contained in:
@@ -35,7 +35,7 @@ from sglang.srt.layers.linear import (
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
|
||||
class GPTBigCodeAttention(nn.Module):
|
||||
@@ -90,7 +90,7 @@ class GPTBigCodeAttention(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.c_attn(hidden_states)
|
||||
q, k, v = qkv.split(
|
||||
@@ -101,7 +101,7 @@ class GPTBigCodeAttention(nn.Module):
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
attn_output = self.attn(q, k, v, input_metadata)
|
||||
attn_output = self.attn(q, k, v, forward_batch)
|
||||
attn_output, _ = self.c_proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
@@ -160,12 +160,12 @@ class GPTBigCodeBlock(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_1(hidden_states)
|
||||
attn_output = self.attn(
|
||||
hidden_states=hidden_states, input_metadata=input_metadata
|
||||
hidden_states=hidden_states, forward_batch=forward_batch
|
||||
)
|
||||
# residual connection
|
||||
hidden_states = attn_output + residual
|
||||
@@ -214,7 +214,7 @@ class GPTBigCodeModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
position_embeds = self.wpe(position_ids)
|
||||
@@ -222,7 +222,7 @@ class GPTBigCodeModel(nn.Module):
|
||||
|
||||
for i in range(len(self.h)):
|
||||
layer = self.h[i]
|
||||
hidden_states = layer(hidden_states, input_metadata)
|
||||
hidden_states = layer(hidden_states, forward_batch)
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
return hidden_states
|
||||
@@ -267,11 +267,11 @@ class GPTBigCodeForCausalLM(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, input_metadata)
|
||||
hidden_states = self.transformer(input_ids, positions, forward_batch)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
Reference in New Issue
Block a user