diff --git a/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py b/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py index 57fbcfddf..ff6280dc7 100644 --- a/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py +++ b/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py @@ -235,7 +235,7 @@ class MiniMaxText01LightningAttention(nn.Module): "... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv.to(q.dtype) ) output.append(qkv) - output = torch.concat(output, dim=-2) + output = torch.cat(output, dim=-2) # reshape output = rearrange(output, "b h n d -> b n (h d)") diff --git a/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_prefill.py b/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_prefill.py index cd298487b..3bf9054bd 100644 --- a/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_prefill.py +++ b/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_prefill.py @@ -403,7 +403,7 @@ class MiniMaxText01LightningAttention(nn.Module): "... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv.to(q.dtype) ) output.append(qkv) - output = torch.concat(output, dim=-2) + output = torch.cat(output, dim=-2) # reshape output = rearrange(output, "b h n d -> b n (h d)") # normalize diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 27f87d8a2..5cfd6d244 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1244,14 +1244,14 @@ class ScheduleBatch: self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens]) self.encoder_lens_cpu.extend(other.encoder_lens_cpu) - self.req_pool_indices = torch.concat( + self.req_pool_indices = torch.cat( [self.req_pool_indices, other.req_pool_indices] ) - self.seq_lens = torch.concat([self.seq_lens, other.seq_lens]) + self.seq_lens = torch.cat([self.seq_lens, other.seq_lens]) self.out_cache_loc = None self.seq_lens_sum += other.seq_lens_sum if self.output_ids is not None: - self.output_ids = torch.concat([self.output_ids, other.output_ids]) + self.output_ids = torch.cat([self.output_ids, other.output_ids]) if self.return_logprob and other.return_logprob: self.top_logprobs_nums.extend(other.top_logprobs_nums) self.token_ids_logprobs.extend(other.token_ids_logprobs) diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index e16e8e3ad..d2010a531 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -303,7 +303,7 @@ class HiRadixCache(RadixCache): value, last_node = self._match_prefix_helper(self.root_node, key) if value: - value = torch.concat(value) + value = torch.cat(value) else: value = torch.tensor([], dtype=torch.int32) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 126d03ab8..bdc3ae844 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -172,7 +172,7 @@ class TokenToKVPoolAllocator: return if self.is_not_in_free_group: - self.free_slots = torch.concat((self.free_slots, free_index)) + self.free_slots = torch.cat((self.free_slots, free_index)) else: self.free_group.append(free_index) @@ -183,7 +183,7 @@ class TokenToKVPoolAllocator: def free_group_end(self): self.is_not_in_free_group = True if self.free_group: - self.free(torch.concat(self.free_group)) + self.free(torch.cat(self.free_group)) def clear(self): # The padded slot 0 is used for writing dummy outputs from padded tokens. @@ -739,7 +739,7 @@ class HostKVCache(abc.ABC): @synchronized def free(self, indices: torch.Tensor) -> int: self.mem_state[indices] = MemoryStateInt.IDLE - self.free_slots = torch.concat([self.free_slots, indices]) + self.free_slots = torch.cat([self.free_slots, indices]) self.can_use_mem_size += len(indices) return len(indices) diff --git a/python/sglang/srt/mem_cache/paged_allocator.py b/python/sglang/srt/mem_cache/paged_allocator.py index 3b07aa2a3..76c74d069 100644 --- a/python/sglang/srt/mem_cache/paged_allocator.py +++ b/python/sglang/srt/mem_cache/paged_allocator.py @@ -272,7 +272,7 @@ class PagedTokenToKVPoolAllocator: def free_group_end(self): self.is_not_in_free_group = True if self.free_group: - self.free(torch.concat(self.free_group)) + self.free(torch.cat(self.free_group)) def clear(self): # The padded slot 0 is used for writing dummy outputs from padded tokens. diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index 951f4d869..58ee432b9 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -152,7 +152,7 @@ class RadixCache(BasePrefixCache): value, last_node = self._match_prefix_helper(self.root_node, key) if value: - value = torch.concat(value) + value = torch.cat(value) else: value = torch.empty((0,), dtype=torch.int32, device=self.device) return value, last_node @@ -317,7 +317,7 @@ class RadixCache(BasePrefixCache): _dfs_helper(child) _dfs_helper(self.root_node) - return torch.concat(values) + return torch.cat(values) ##### Internal Helper Functions ##### diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index b732b033e..35667465a 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -383,7 +383,7 @@ class ForwardBatch: batch.image_inputs[i].mrope_position_delta = mrope_position_delta mrope_positions_list[i] = mrope_positions - self.mrope_positions = torch.concat( + self.mrope_positions = torch.cat( [torch.tensor(pos, device=device) for pos in mrope_positions_list], axis=1, ) @@ -449,7 +449,7 @@ def compute_position_kernel( def compute_position_torch( extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor ): - positions = torch.concat( + positions = torch.cat( [ torch.arange( prefix_len, prefix_len + extend_len, device=extend_prefix_lens.device diff --git a/python/sglang/srt/models/deepseek_janus_pro.py b/python/sglang/srt/models/deepseek_janus_pro.py index 7b0dac16c..75a88b13f 100644 --- a/python/sglang/srt/models/deepseek_janus_pro.py +++ b/python/sglang/srt/models/deepseek_janus_pro.py @@ -1289,7 +1289,7 @@ class MlpProjector(nn.Module): high_x, low_x = x_or_tuple high_x = self.high_up_proj(high_x) low_x = self.low_up_proj(low_x) - x = torch.concat([high_x, low_x], dim=-1) + x = torch.cat([high_x, low_x], dim=-1) else: x = x_or_tuple diff --git a/python/sglang/srt/models/minicpmv.py b/python/sglang/srt/models/minicpmv.py index 0e98a1392..00ae8fa01 100644 --- a/python/sglang/srt/models/minicpmv.py +++ b/python/sglang/srt/models/minicpmv.py @@ -828,7 +828,7 @@ class MiniCPMVBaseModel(nn.Module): ) if isinstance(image_embeds, list): - image_embeds = torch.concat(image_embeds) + image_embeds = torch.cat(image_embeds) return MiniCPMVImageEmbeddingInputs( image_bounds=image_bounds, diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 5942b8270..70a2443bf 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -306,7 +306,7 @@ class SamplingBatchInfo: ]: self_val = getattr(self, item, None) other_val = getattr(other, item, None) - setattr(self, item, torch.concat([self_val, other_val])) + setattr(self, item, torch.cat([self_val, other_val])) self.is_all_greedy |= other.is_all_greedy self.need_min_p_sampling |= other.need_min_p_sampling diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index 086b532ac..fa8dc5f21 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -59,7 +59,7 @@ class EagleDraftInput: pt = 0 for i, extend_len in enumerate(batch.extend_lens): input_ids = batch.input_ids[pt : pt + extend_len] - batch.input_ids[pt : pt + extend_len] = torch.concat( + batch.input_ids[pt : pt + extend_len] = torch.cat( (input_ids[1:], self.verified_id[i].reshape(1)) ) pt += extend_len diff --git a/sgl-kernel/benchmark/bench_lightning_attention_decode.py b/sgl-kernel/benchmark/bench_lightning_attention_decode.py index 24872e61a..36bdccac0 100644 --- a/sgl-kernel/benchmark/bench_lightning_attention_decode.py +++ b/sgl-kernel/benchmark/bench_lightning_attention_decode.py @@ -148,7 +148,7 @@ def lightning_attention_decode_naive(q, k, v, past_kv, slope): kv.to(torch.float32), ) output.append(qkv) - output = torch.concat(output, dim=-2) + output = torch.cat(output, dim=-2) return output.to(original_dtype), kv diff --git a/sgl-kernel/tests/test_lightning_attention_decode.py b/sgl-kernel/tests/test_lightning_attention_decode.py index f2cace001..8f5d4bb77 100644 --- a/sgl-kernel/tests/test_lightning_attention_decode.py +++ b/sgl-kernel/tests/test_lightning_attention_decode.py @@ -24,7 +24,7 @@ def naive_lightning_attention_decode(q, k, v, past_kv, slope): kv.to(torch.float32), ) output.append(qkv) - output = torch.concat(output, dim=-2) + output = torch.cat(output, dim=-2) return output.to(original_dtype), kv