From 82dec1f70b4c559c0d320a6c5c75aaa5493b5bc9 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 17 Mar 2025 05:57:35 -0700 Subject: [PATCH] Remove redundant type conversion (#4513) --- .github/workflows/pr-test-amd.yml | 7 +++++-- .../sglang/srt/layers/attention/flashinfer_backend.py | 2 +- python/sglang/srt/layers/attention/triton_backend.py | 10 +++++----- python/sglang/srt/layers/sampler.py | 2 +- python/sglang/srt/managers/tp_worker_overlap_thread.py | 2 +- test/srt/test_update_weights_from_tensor.py | 3 +++ 6 files changed, 16 insertions(+), 10 deletions(-) diff --git a/.github/workflows/pr-test-amd.yml b/.github/workflows/pr-test-amd.yml index 38b810938..0ba7994ff 100644 --- a/.github/workflows/pr-test-amd.yml +++ b/.github/workflows/pr-test-amd.yml @@ -21,7 +21,8 @@ concurrency: jobs: accuracy-test-1-gpu-amd: - if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft == false + if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && + github.event.pull_request.draft == false runs-on: linux-mi300-gpu-1 steps: - name: Checkout code @@ -60,7 +61,8 @@ jobs: docker exec -w /sglang-checkout/test/srt ci_sglang python3 models/test_qwen_models.py mla-test-1-gpu-amd: - if: github.event.pull_request.head.repo.fork == false && github.event.pull_request.draft == false + if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && + github.event.pull_request.draft == false runs-on: linux-mi300-gpu-1 steps: - name: Checkout code @@ -97,6 +99,7 @@ jobs: docker exec -w /sglang-checkout/test/srt ci_sglang python3 test_mla.py finish: + if: always() needs: [ accuracy-test-1-gpu-amd, mla-test-1-gpu-amd ] diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 791cbeec0..fba806010 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -1008,7 +1008,7 @@ class FlashInferMultiStepDraftBackend: global_override_indptr_cpu = None def init_forward_metadata(self, forward_batch: ForwardBatch): - kv_indices = torch.zeros( + kv_indices = torch.empty( ( self.speculative_num_steps, forward_batch.batch_size * self.topk * self.max_context_len, diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index b942dee5c..f5cb29a0f 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -84,7 +84,7 @@ class TritonAttnBackend(AttentionBackend): if spec_info is None: kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0) kv_indptr = kv_indptr[: bs + 1] - kv_indices = torch.zeros( + kv_indices = torch.empty( forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device ) create_flashinfer_kv_indices_triton[(bs,)]( @@ -100,7 +100,7 @@ class TritonAttnBackend(AttentionBackend): kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices bs = kv_indptr.shape[0] - 1 - attn_logits = torch.zeros( + attn_logits = torch.empty( ( bs, self.num_head, @@ -127,7 +127,7 @@ class TritonAttnBackend(AttentionBackend): # Different with flashinfer kv_indptr and kv_indices construction kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0) kv_indptr = kv_indptr[: bs + 1] - kv_indices = torch.zeros( + kv_indices = torch.empty( kv_indptr[-1], dtype=torch.int32, device=self.device ) create_flashinfer_kv_indices_triton[(bs,)]( @@ -166,7 +166,7 @@ class TritonAttnBackend(AttentionBackend): forward_batch.extend_prefix_lens, dim=0 ) kv_indptr = kv_indptr[: bs + 1] - kv_indices = torch.zeros( + kv_indices = torch.empty( forward_batch.extend_prefix_lens.sum().item(), dtype=torch.int32, device=self.device, @@ -531,7 +531,7 @@ class TritonMultiStepDraftBackend: call_fn(i, forward_batch) def init_forward_metadata(self, forward_batch: ForwardBatch): - kv_indices = torch.zeros( + kv_indices = torch.empty( ( self.speculative_num_steps, forward_batch.batch_size * self.topk * self.max_context_len, diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 37f22ec21..fcf2af9ea 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -168,7 +168,7 @@ class Sampler(nn.Module): group=self.tp_sync_group, ) - return batch_next_token_ids.to(torch.int32) + return batch_next_token_ids def _apply_custom_logit_processor( self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index deb2fbe59..fb4fdc6d5 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -69,7 +69,7 @@ class TpModelWorkerClient: self.future_token_ids_ct = 0 self.future_token_ids_limit = self.max_running_requests * 3 self.future_token_ids_map = torch.empty( - (self.max_running_requests * 5,), dtype=torch.int32, device=self.device + (self.max_running_requests * 5,), dtype=torch.int64, device=self.device ) # Launch threads diff --git a/test/srt/test_update_weights_from_tensor.py b/test/srt/test_update_weights_from_tensor.py index 64780b78f..1e0134715 100644 --- a/test/srt/test_update_weights_from_tensor.py +++ b/test/srt/test_update_weights_from_tensor.py @@ -44,6 +44,9 @@ class TestUpdateWeightsFromTensor(unittest.TestCase): def test_update_weights_from_tensor(self): tp_sizes = [1, 2] for tp_size in tp_sizes: + if torch.cuda.device_count() < tp_size: + continue + with self.subTest(tp_size=tp_size): test_update_weights_from_tensor(tp_size)