diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py
index f3faa75d9..13ca29c54 100644
--- a/python/sglang/srt/managers/mm_utils.py
+++ b/python/sglang/srt/managers/mm_utils.py
@@ -85,8 +85,8 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
                 "No data_token_pairs provided, RadixAttention might be influenced."
             )
             return input_ids
-        start_token_ids = [s for s, _e in data_token_pairs]
-        end_tokens_ids = [e for _s, e in data_token_pairs]
+        start_token_ids = {s for s, _e in data_token_pairs}
+        end_tokens_ids = {e for _s, e in data_token_pairs}
 
         padded_ids = []
         last_idx = 0
@@ -135,7 +135,7 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
         if not input_ids or not mm_inputs.mm_items:
             return input_ids
 
-        input_ids_tensor = torch.tensor(input_ids)
+        input_ids_tensor = torch.as_tensor(input_ids)
 
         # Create mapping of token_ids to pad_values for each modality
         token_to_pad_mapping = {}
@@ -211,7 +211,7 @@ def get_embedding_chunk(
             end_index += extend_end_index - start + 1
         elif extend_end_index > end:
             end_index += end - start + 1
-    # some models embedding is 3-dim, reshape it to 2-dim
+    # some models' embedding is 3-dim, reshape it to 2-dim
     embedding = embedding.reshape(-1, embedding.shape[-1])
     embedding_chunk = embedding[start_index:end_index]
     return embedding_chunk, start_index, end_index
@@ -428,7 +428,7 @@ def embed_mm_inputs(
         modality_id = modality.name.lower()
         embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
         if len(items) != 0 and embedder is not None:
-            placeholder_tensor = torch.tensor(
+            placeholder_tensor = torch.as_tensor(
                 [item.pad_value for item in items],
                 device=input_ids.device,
            )
@@ -473,11 +473,9 @@ def embed_mm_inputs(
     for embedding, mask in zip(embeddings, masks):
         if embedding is None or mask is None:
             continue
-        mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
-        inputs_embeds = inputs_embeds.masked_scatter(
-            mask,
-            embedding.to(inputs_embeds.device, inputs_embeds.dtype),
-        )
+        # in-place update
+        indices = torch.where(mask.squeeze(dim=-1))[0]
+        inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
 
     return inputs_embeds
 
@@ -561,34 +559,36 @@ def get_multimodal_data_bounds(
         [bounds_count, 2]
     """
     # All the multimodal data in the batch should share the same special bound token ids.
-    start_tokens = [s for s, _e in token_pairs]
-    end_tokens = [e for _s, e in token_pairs]
+    start_tokens = {s for s, _e in token_pairs}
+    end_tokens = {e for _s, e in token_pairs}
 
     assert all(isinstance(t, int) for t in start_tokens)
     assert all(isinstance(t, int) for t in end_tokens)
 
     start_cond = torch.isin(
-        input_ids, torch.tensor(start_tokens, device=input_ids.device)
+        input_ids, torch.as_tensor(list(start_tokens), device=input_ids.device)
+    )
+    end_cond = torch.isin(
+        input_ids, torch.as_tensor(list(end_tokens), device=input_ids.device)
     )
-    end_cond = torch.isin(input_ids, torch.tensor(end_tokens, device=input_ids.device))
 
     (data_start_tokens,) = torch.where(start_cond)
     (data_end_tokens,) = torch.where(end_cond)
 
+    data_start_tokens_cpu = data_start_tokens.cpu().tolist()
+    data_end_tokens_cpu = data_end_tokens.cpu().tolist()
+
     # the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the multimodal data
-    if len(data_start_tokens) != len(data_end_tokens):
+    if len(data_start_tokens_cpu) != len(data_end_tokens_cpu):
         if (
-            len(data_start_tokens) + 1 == len(data_end_tokens)
-            and input_ids[0] in pad_values
-            and data_end_tokens[0] < data_start_tokens[0]
+            len(data_start_tokens_cpu) + 1 == len(data_end_tokens_cpu)
+            and input_ids[0].item() in pad_values
+            and data_end_tokens_cpu
+            and data_start_tokens_cpu
+            and data_end_tokens_cpu[0] < data_start_tokens_cpu[0]
         ):
-            data_start_tokens = torch.cat(
-                [
-                    torch.tensor([0], device=data_start_tokens.device),
-                    data_start_tokens,
-                ]
-            )
-    valid_mm_data_nums = min(len(data_start_tokens), len(data_end_tokens))
+            data_start_tokens_cpu.insert(0, 0)
+    valid_mm_data_nums = min(len(data_start_tokens_cpu), len(data_end_tokens_cpu))
 
     if valid_mm_data_nums == 0:
         return torch.zeros((0, 2), device=input_ids.device)
@@ -596,8 +596,8 @@ def get_multimodal_data_bounds(
 
     # Filter out pairs where start_token >= end_token
     valid_pairs = []
     for i in range(valid_mm_data_nums):
-        start_token = data_start_tokens[i]
-        end_token = data_end_tokens[i]
+        start_token = data_start_tokens_cpu[i]
+        end_token = data_end_tokens_cpu[i]
         if start_token < end_token:
             valid_pairs.append((start_token + 1, end_token - 1))
@@ -605,7 +605,7 @@ def get_multimodal_data_bounds(
     if not valid_pairs:
         return torch.zeros((0, 2), device=input_ids.device)
 
     # Convert valid pairs to tensor
-    valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
+    valid_pairs_tensor = torch.as_tensor(valid_pairs, device=input_ids.device)
 
     return valid_pairs_tensor
@@ -634,11 +634,7 @@ def tensor_hash(tensor_list) -> int:
         tensor = tensor.float()
 
     assert isinstance(tensor, torch.Tensor)
-    if tensor.is_cuda:
-        # TODO: improve this
-        tensor_cpu = tensor.cpu()
-    else:
-        tensor_cpu = tensor
+    tensor_cpu = tensor.cpu()
 
     mv = memoryview(tensor_cpu.numpy())
     return data_hash(mv.tobytes())
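
Note (reviewer sketch, not part of the patch): the masked_scatter -> index-assignment change in embed_mm_inputs is an equivalence only when each embedding carries exactly one row per True entry in its mask, the same invariant the flattened masked_scatter path already assumed. The snippet below demonstrates that equivalence with invented shapes (6 token positions, hidden size 4); the names inputs_embeds, mask, and embedding mirror the patch, but the data is made up for illustration.

import torch

# invented shapes: 6 token positions, hidden size 4
inputs_embeds = torch.zeros(6, 4)
# [seq_len, 1] boolean mask marking multimodal token positions
mask = torch.tensor([0, 1, 1, 0, 1, 0], dtype=torch.bool).unsqueeze(-1)
# one embedding row per masked position, filled with distinct values
embedding = torch.arange(12, dtype=torch.float32).reshape(3, 4)

# old path: broadcast the mask over the hidden dim and scatter the
# flattened embedding into the masked elements (allocates a new tensor)
old = inputs_embeds.masked_scatter(mask.expand_as(inputs_embeds), embedding)

# new path: convert the mask to row indices and assign rows in place
new = inputs_embeds.clone()
indices = torch.where(mask.squeeze(dim=-1))[0]
new[indices] = embedding

assert torch.equal(old, new)

The .cpu().tolist() rewrite in get_multimodal_data_bounds has the same motivation: branching on individual elements of a CUDA tensor (start_token < end_token, input_ids[0] in pad_values) forces one device-to-host synchronization per element, while a single .tolist() pays that cost once up front. torch.as_tensor is preferred over torch.tensor throughout because it avoids a copy whenever the input is already a tensor or ndarray of matching dtype and device; note that it cannot infer a dtype from a Python set, which is why the set-built start_tokens/end_tokens are passed through list() before the torch.isin calls.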