Rename InputMetadata -> ForwardBatch (#1543)

This commit is contained in:
Lianmin Zheng
2024-09-30 02:41:11 -07:00
committed by GitHub
parent 3f0fe08d37
commit 36d5acfca5
44 changed files with 435 additions and 433 deletions

View File

@@ -42,7 +42,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_hip
# ROCm: flashinfer available later
@@ -193,7 +193,7 @@ class MiniCPM3Attention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
) -> torch.Tensor:
if self.q_lora_rank is not None:
q = self.q_a_proj(hidden_states)[0]
@@ -230,7 +230,7 @@ class MiniCPM3Attention(nn.Module):
v = torch.nn.functional.pad(v, [0, 128 - self.v_head_dim], value=0).view(
-1, self.num_local_heads * 128
)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
attn_output = attn_output.view(-1, self.num_local_heads, 128)[
..., : self.v_head_dim
].reshape(-1, self.num_local_heads * self.v_head_dim)
@@ -341,7 +341,7 @@ class MiniCPM3AttentionMLA(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
) -> torch.Tensor:
q_len = hidden_states.shape[0]
q_input = hidden_states.new_empty(
@@ -383,7 +383,7 @@ class MiniCPM3AttentionMLA(nn.Module):
q_input[..., self.kv_lora_rank :] = q_pe
k_input[..., self.kv_lora_rank :] = k_pe
-        attn_output = self.attn(q_input, k_input, v_input, input_metadata)
+        attn_output = self.attn(q_input, k_input, v_input, forward_batch)
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
if self.w_vc.dtype == torch.float8_e4m3fn:
@@ -472,7 +472,7 @@ class MiniCPM3DecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
@@ -481,7 +481,7 @@ class MiniCPM3DecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
)
hidden_states = residual + hidden_states * (
self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)
@@ -528,7 +528,7 @@ class MiniCPM3Model(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
@@ -542,7 +542,7 @@ class MiniCPM3Model(nn.Module):
hidden_states, residual = layer(
positions,
hidden_states,
-                input_metadata,
+                forward_batch,
residual,
)
hidden_states = self.norm(hidden_states)
@@ -581,19 +581,19 @@ class MiniCPM3ForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is not None:
input_embeds = input_embeds * self.config.scale_emb
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
hidden_states = hidden_states / self.scale_width
if self.config.tie_word_embeddings:
lm_head_weight = self.model.embed_tokens.weight
else:
lm_head_weight = self.lm_head.weight
return self.logits_processor(
-            input_ids, hidden_states, lm_head_weight, input_metadata
+            input_ids, hidden_states, lm_head_weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):