Rename InputMetadata -> ForwardBatch (#1543)

This commit is contained in:
Lianmin Zheng
2024-09-30 02:41:11 -07:00
committed by GitHub
parent 3f0fe08d37
commit 36d5acfca5
44 changed files with 435 additions and 433 deletions

View File

@@ -44,7 +44,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import set_weight_attrs
@@ -249,14 +249,14 @@ class DbrxAttention(nn.Module):
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.Wqkv(hidden_states)
if self.clip_qkv is not None:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
hidden_states, _ = self.out_proj(attn_output)
return hidden_states
@@ -278,14 +278,14 @@ class DbrxFusedNormAttention(nn.Module):
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.norm_1(hidden_states)
x = self.attn(
position_ids=position_ids,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
hidden_states = residual + x
residual = hidden_states
@@ -310,12 +310,12 @@ class DbrxBlock(nn.Module):
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states, residual = self.norm_attn_norm(
position_ids=position_ids,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
hidden_states = self.ffn(hidden_states)
hidden_states = hidden_states + residual
@@ -349,7 +349,7 @@ class DbrxModel(nn.Module):
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
@@ -358,7 +358,7 @@ class DbrxModel(nn.Module):
hidden_states = input_embeds
for i in range(len(self.blocks)):
block = self.blocks[i]
hidden_states = block(position_ids, hidden_states, input_metadata)
hidden_states = block(position_ids, hidden_states, forward_batch)
hidden_states = self.norm_f(hidden_states)
return hidden_states
@@ -388,11 +388,11 @@ class DbrxForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, input_metadata)
hidden_states = self.transformer(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):