Rename InputMetadata -> ForwardBatch (#1543)

This commit is contained in:
Lianmin Zheng
2024-09-30 02:41:11 -07:00
committed by GitHub
parent 3f0fe08d37
commit 36d5acfca5
44 changed files with 435 additions and 433 deletions

View File

@@ -17,7 +17,7 @@ limitations under the License.
from torch import nn
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class RadixAttention(nn.Module):
@@ -48,11 +48,11 @@ class RadixAttention(nn.Module):
self.logit_cap = logit_cap
self.sliding_window_size = sliding_window_size or -1
# NOTE(review): this span appears to be a rendered diff hunk whose +/- markers
# and indentation were stripped, so the signature and the return statement each
# appear twice: first with the old name (input_metadata: InputMetadata), then
# with the new name (forward_batch: ForwardBatch). Only one of each pair exists
# in the real file -- confirm against the repository before editing.
#
# forward: optionally reshapes k and v to per-head layout, then hands q, k, v,
# this layer, and the batch object to the batch object's attn_backend.
def forward(self, q, k, v, input_metadata: InputMetadata):
def forward(self, q, k, v, forward_batch: ForwardBatch):
if k is not None:
# For cross-layer sharing, kv can be None
assert v is not None
# View k and v as (-1, head count, head dim) using this layer's
# tp_k_head_num/qk_head_dim and tp_v_head_num/v_head_dim attributes.
k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
# Delegate to the backend stored on the batch object; q is passed through
# unchanged, and k/v may still be None when the reshape above was skipped.
return input_metadata.attn_backend.forward(q, k, v, self, input_metadata)
return forward_batch.attn_backend.forward(q, k, v, self, forward_batch)