[feat]Ascend NPU Gemma-3-12b and Gemma-3-27b support (#8909)
@@ -20,9 +20,11 @@ from sglang.srt.managers.schedule_batch import (
 )
 from sglang.srt.mem_cache.multimodal_cache import MultiModalCache
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import flatten_nested_list, print_warning_once
+from sglang.srt.utils import flatten_nested_list, is_npu, print_warning_once
 from sglang.utils import logger
 
+_is_npu = is_npu()
+
 # NOTE: Using the shared logger from sglang.utils instead of creating a module-specific logger
 # to ensure consistent logging behavior across the codebase. This prevents issues with log
 # propagation that can cause some log messages (like 'server is fired up') to not appear
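
Note: as a reading aid, here is a minimal sketch of the pattern this hunk introduces, namely probing for the Ascend backend once at import time and caching the result in a module-level flag. Detecting the backend via the torch_npu plugin is an assumption made for illustration only; the real check is sglang.srt.utils.is_npu.

# Illustrative sketch only -- not the sglang implementation of is_npu.
import importlib.util

def is_npu() -> bool:
    # Presence of the torch_npu plugin is used here as a proxy for
    # running on an Ascend NPU host (assumption for this sketch).
    return importlib.util.find_spec("torch_npu") is not None

# Evaluated once at import time so hot paths only read a cached boolean.
_is_npu = is_npu()
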
@@ -486,6 +488,8 @@ def get_embedding_and_mask(
     if embedding is None:
         return None, None
     # 2. Get mask
+    if _is_npu:
+        torch.npu.current_stream().synchronize()
     special_multimodal_mask = _get_multimodal_mask(input_ids, placeholder_tensor)
     # 3. Adjust embedding length if needed
     embedding = _adjust_embedding_length(embedding, special_multimodal_mask, logger)
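
The new branch synchronizes the current NPU stream before the multimodal mask is built, presumably so that work queued asynchronously on the device (such as the embedding computed above) has completed before the host-side mask and length adjustment consume it. A minimal sketch of that gate follows, assuming torch_npu has been imported elsewhere so that torch.npu is available; the helper name maybe_sync_npu is hypothetical.

# Hypothetical helper illustrating the synchronization gate in this hunk.
# Assumes `import torch_npu` has already run, which adds the torch.npu namespace.
import torch

def maybe_sync_npu(is_npu_flag: bool) -> None:
    if is_npu_flag:
        # Block the host until all kernels queued on the current
        # Ascend stream have finished executing.
        torch.npu.current_stream().synchronize()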