Remove redundant type conversion (#4513)
.github/workflows/pr-test-amd.yml
@@ -21,7 +21,8 @@ concurrency:
 
 jobs:
   accuracy-test-1-gpu-amd:
-    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft == false
+    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
+      github.event.pull_request.draft == false
     runs-on: linux-mi300-gpu-1
     steps:
       - name: Checkout code
@@ -60,7 +61,8 @@ jobs:
           docker exec -w /sglang-checkout/test/srt ci_sglang python3 models/test_qwen_models.py
 
   mla-test-1-gpu-amd:
-    if: github.event.pull_request.head.repo.fork == false && github.event.pull_request.draft == false
+    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
+      github.event.pull_request.draft == false
     runs-on: linux-mi300-gpu-1
     steps:
       - name: Checkout code
@@ -97,6 +99,7 @@ jobs:
           docker exec -w /sglang-checkout/test/srt ci_sglang python3 test_mla.py
 
   finish:
+    if: always()
     needs: [
       accuracy-test-1-gpu-amd, mla-test-1-gpu-amd
     ]
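Note on the workflow hunks above: both AMD jobs now share the same trigger condition (run on the upstream repo or on any pull request, provided the PR is not a draft), replacing the mla job's old gate on head.repo.fork == false. The finish job additionally gets if: always(), which in GitHub Actions lets it run even when a job listed in needs fails, so the aggregate status is still reported.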
@@ -1008,7 +1008,7 @@ class FlashInferMultiStepDraftBackend:
         global_override_indptr_cpu = None
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
-        kv_indices = torch.zeros(
+        kv_indices = torch.empty(
             (
                 self.speculative_num_steps,
                 forward_batch.batch_size * self.topk * self.max_context_len,
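The torch.zeros → torch.empty changes in this and the following hunks drop a redundant fill: torch.zeros allocates and then launches a kernel to zero the buffer, while torch.empty only allocates, which is safe when every element that is later read is written first (in these hunks the index and logit buffers appear to be populated by the kv-indices kernel or only read up to the lengths recorded in kv_indptr). A minimal standalone sketch of the difference, with illustrative names not taken from the patch:

    import torch

    n = 1_000_000
    zeroed = torch.zeros(n, dtype=torch.int32)   # allocate + an extra fill kernel
    scratch = torch.empty(n, dtype=torch.int32)  # allocate only; contents arbitrary until written
    scratch[:10] = torch.arange(10, dtype=torch.int32)  # writer fills what will be read
    print(scratch[:10])                          # safe: reads stay within the written range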
@@ -84,7 +84,7 @@ class TritonAttnBackend(AttentionBackend):
         if spec_info is None:
             kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.zeros(
+            kv_indices = torch.empty(
                 forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
             )
             create_flashinfer_kv_indices_triton[(bs,)](
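For readers unfamiliar with the layout, a rough sketch of how kv_indptr and kv_indices relate (a toy example with assumed shapes, not code from the patch): kv_indptr is the prefix sum of per-request KV lengths with a leading zero, and kv_indices holds one cache-slot index per cached token, so its total length is seq_lens_sum and every position that is later dereferenced is written by the Triton kernel.

    import torch

    # Hypothetical batch of three requests with KV lengths 3, 5 and 2.
    seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)
    bs = seq_lens.numel()
    kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
    kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)             # [0, 3, 8, 10]
    kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)   # 10 slots, filled before any read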
@@ -100,7 +100,7 @@ class TritonAttnBackend(AttentionBackend):
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             bs = kv_indptr.shape[0] - 1
 
-        attn_logits = torch.zeros(
+        attn_logits = torch.empty(
             (
                 bs,
                 self.num_head,
@@ -127,7 +127,7 @@ class TritonAttnBackend(AttentionBackend):
             # Different with flashinfer kv_indptr and kv_indices construction
             kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.zeros(
+            kv_indices = torch.empty(
                 kv_indptr[-1], dtype=torch.int32, device=self.device
             )
             create_flashinfer_kv_indices_triton[(bs,)](
@@ -166,7 +166,7 @@ class TritonAttnBackend(AttentionBackend):
                 forward_batch.extend_prefix_lens, dim=0
             )
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.zeros(
+            kv_indices = torch.empty(
                 forward_batch.extend_prefix_lens.sum().item(),
                 dtype=torch.int32,
                 device=self.device,
@@ -531,7 +531,7 @@ class TritonMultiStepDraftBackend:
             call_fn(i, forward_batch)
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
-        kv_indices = torch.zeros(
+        kv_indices = torch.empty(
            (
                 self.speculative_num_steps,
                 forward_batch.batch_size * self.topk * self.max_context_len,
@@ -168,7 +168,7 @@ class Sampler(nn.Module):
                 group=self.tp_sync_group,
             )
 
-        return batch_next_token_ids.to(torch.int32)
+        return batch_next_token_ids
 
     def _apply_custom_logit_processor(
         self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
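This is the "redundant type conversion" of the title: the ids coming out of the sampling ops are already int64, and the consumer-side buffer (the future-token-ids map in the next hunk) is switched to int64 as well, so the downcast to int32 on return only forced conversions later. A small sketch of the underlying dtype fact, illustrative rather than the Sampler code itself:

    import torch

    probs = torch.softmax(torch.randn(4, 32_000), dim=-1)
    sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
    greedy = probs.argmax(dim=-1)
    # Both sampling paths already yield int64 token ids; no extra cast needed.
    assert sampled.dtype == torch.int64 and greedy.dtype == torch.int64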
@@ -69,7 +69,7 @@ class TpModelWorkerClient:
        self.future_token_ids_ct = 0
        self.future_token_ids_limit = self.max_running_requests * 3
        self.future_token_ids_map = torch.empty(
-            (self.max_running_requests * 5,), dtype=torch.int32, device=self.device
+            (self.max_running_requests * 5,), dtype=torch.int64, device=self.device
        )
 
        # Launch threads
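The future_token_ids_map dtype change to int64 is the matching half of the Sampler change above: with the int32 cast removed, token ids stay int64 end to end instead of round-tripping between int32 and int64.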
@@ -44,6 +44,9 @@ class TestUpdateWeightsFromTensor(unittest.TestCase):
     def test_update_weights_from_tensor(self):
         tp_sizes = [1, 2]
         for tp_size in tp_sizes:
+            if torch.cuda.device_count() < tp_size:
+                continue
+
             with self.subTest(tp_size=tp_size):
                 test_update_weights_from_tensor(tp_size)
 
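The test tweak makes the loop skip any tp_size that asks for more GPUs than are visible (e.g. tp_size=2 on a single-GPU machine) instead of failing, so the remaining sizes still run.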