Remove redundant type conversion (#4513)

This commit is contained in:
Lianmin Zheng
2025-03-17 05:57:35 -07:00
committed by GitHub
parent 5f9b2c62ff
commit 82dec1f70b
6 changed files with 16 additions and 10 deletions

View File

@@ -1008,7 +1008,7 @@ class FlashInferMultiStepDraftBackend:
global_override_indptr_cpu = None
def init_forward_metadata(self, forward_batch: ForwardBatch):
kv_indices = torch.zeros(
kv_indices = torch.empty(
(
self.speculative_num_steps,
forward_batch.batch_size * self.topk * self.max_context_len,

View File

@@ -84,7 +84,7 @@ class TritonAttnBackend(AttentionBackend):
if spec_info is None:
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.zeros(
kv_indices = torch.empty(
forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
)
create_flashinfer_kv_indices_triton[(bs,)](
@@ -100,7 +100,7 @@ class TritonAttnBackend(AttentionBackend):
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
bs = kv_indptr.shape[0] - 1
attn_logits = torch.zeros(
attn_logits = torch.empty(
(
bs,
self.num_head,
@@ -127,7 +127,7 @@ class TritonAttnBackend(AttentionBackend):
# Different with flashinfer kv_indptr and kv_indices construction
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.zeros(
kv_indices = torch.empty(
kv_indptr[-1], dtype=torch.int32, device=self.device
)
create_flashinfer_kv_indices_triton[(bs,)](
@@ -166,7 +166,7 @@ class TritonAttnBackend(AttentionBackend):
forward_batch.extend_prefix_lens, dim=0
)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.zeros(
kv_indices = torch.empty(
forward_batch.extend_prefix_lens.sum().item(),
dtype=torch.int32,
device=self.device,
@@ -531,7 +531,7 @@ class TritonMultiStepDraftBackend:
call_fn(i, forward_batch)
def init_forward_metadata(self, forward_batch: ForwardBatch):
kv_indices = torch.zeros(
kv_indices = torch.empty(
(
self.speculative_num_steps,
forward_batch.batch_size * self.topk * self.max_context_len,

View File

@@ -168,7 +168,7 @@ class Sampler(nn.Module):
group=self.tp_sync_group,
)
return batch_next_token_ids.to(torch.int32)
return batch_next_token_ids
def _apply_custom_logit_processor(
self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo

View File

@@ -69,7 +69,7 @@ class TpModelWorkerClient:
self.future_token_ids_ct = 0
self.future_token_ids_limit = self.max_running_requests * 3
self.future_token_ids_map = torch.empty(
(self.max_running_requests * 5,), dtype=torch.int32, device=self.device
(self.max_running_requests * 5,), dtype=torch.int64, device=self.device
)
# Launch threads