Rename InputMetadata -> ForwardBatch (#1543)
This commit is contained in:
@@ -42,7 +42,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
# ROCm: flashinfer available later
|
||||
@@ -193,7 +193,7 @@ class MiniCPM3Attention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
if self.q_lora_rank is not None:
|
||||
q = self.q_a_proj(hidden_states)[0]
|
||||
@@ -230,7 +230,7 @@ class MiniCPM3Attention(nn.Module):
|
||||
v = torch.nn.functional.pad(v, [0, 128 - self.v_head_dim], value=0).view(
|
||||
-1, self.num_local_heads * 128
|
||||
)
|
||||
attn_output = self.attn(q, k, v, input_metadata)
|
||||
attn_output = self.attn(q, k, v, forward_batch)
|
||||
attn_output = attn_output.view(-1, self.num_local_heads, 128)[
|
||||
..., : self.v_head_dim
|
||||
].reshape(-1, self.num_local_heads * self.v_head_dim)
|
||||
@@ -341,7 +341,7 @@ class MiniCPM3AttentionMLA(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
q_len = hidden_states.shape[0]
|
||||
q_input = hidden_states.new_empty(
|
||||
@@ -383,7 +383,7 @@ class MiniCPM3AttentionMLA(nn.Module):
|
||||
q_input[..., self.kv_lora_rank :] = q_pe
|
||||
k_input[..., self.kv_lora_rank :] = k_pe
|
||||
|
||||
attn_output = self.attn(q_input, k_input, v_input, input_metadata)
|
||||
attn_output = self.attn(q_input, k_input, v_input, forward_batch)
|
||||
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
|
||||
|
||||
if self.w_vc.dtype == torch.float8_e4m3fn:
|
||||
@@ -472,7 +472,7 @@ class MiniCPM3DecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
@@ -481,7 +481,7 @@ class MiniCPM3DecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
input_metadata=input_metadata,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
hidden_states = residual + hidden_states * (
|
||||
self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)
|
||||
@@ -528,7 +528,7 @@ class MiniCPM3Model(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
if input_embeds is None:
|
||||
@@ -542,7 +542,7 @@ class MiniCPM3Model(nn.Module):
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
input_metadata,
|
||||
forward_batch,
|
||||
residual,
|
||||
)
|
||||
hidden_states = self.norm(hidden_states)
|
||||
@@ -581,19 +581,19 @@ class MiniCPM3ForCausalLM(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
if input_embeds is not None:
|
||||
input_embeds = input_embeds * self.config.scale_emb
|
||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
hidden_states = hidden_states / self.scale_width
|
||||
if self.config.tie_word_embeddings:
|
||||
lm_head_weight = self.model.embed_tokens.weight
|
||||
else:
|
||||
lm_head_weight = self.lm_head.weight
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, lm_head_weight, input_metadata
|
||||
input_ids, hidden_states, lm_head_weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
Reference in New Issue
Block a user