[Bug fix] Fix Gemma 2 and fix Gemma 3 multimodal with bs > 1 on NPU (#9871)
Co-authored-by: Maksim <makcum888e@mail.ru>
This commit is contained in:
@@ -288,16 +288,11 @@ class GemmaRMSNorm(CustomOp):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
residual: Optional[torch.Tensor] = None,
|
residual: Optional[torch.Tensor] = None,
|
||||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
orig_dtype = x.dtype
|
|
||||||
if residual is not None:
|
if residual is not None:
|
||||||
x = x + residual
|
x = x + residual
|
||||||
residual = x
|
residual = x
|
||||||
|
|
||||||
x = x.float()
|
x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon)
|
||||||
variance = torch_npu.mean(torch_npu.pow(x, 2), dim=-1, keepdim=True)
|
|
||||||
x = x * torch_npu.rsqrt(variance + self.variance_epsilon)
|
|
||||||
x = x * (1.0 + self.weight.float())
|
|
||||||
x = x.to(orig_dtype)
|
|
||||||
return x if residual is None else (x, residual)
|
return x if residual is None else (x, residual)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -46,10 +46,12 @@ from sglang.srt.model_executor.forward_batch_info import (
|
|||||||
ForwardBatch,
|
ForwardBatch,
|
||||||
ForwardMode,
|
ForwardMode,
|
||||||
)
|
)
|
||||||
from sglang.srt.utils import dump_to_file, use_intel_amx_backend
|
from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_is_npu = is_npu()
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class LogitsProcessorOutput:
|
class LogitsProcessorOutput:
|
||||||
@@ -517,7 +519,12 @@ class LogitsProcessor(nn.Module):
|
|||||||
logits = logits[:, : self.config.vocab_size].float()
|
logits = logits[:, : self.config.vocab_size].float()
|
||||||
|
|
||||||
if self.final_logit_softcapping:
|
if self.final_logit_softcapping:
|
||||||
fused_softcap(logits, self.final_logit_softcapping)
|
if not _is_npu:
|
||||||
|
fused_softcap(logits, self.final_logit_softcapping)
|
||||||
|
else:
|
||||||
|
logits = self.final_logit_softcapping * torch.tanh(
|
||||||
|
logits / self.final_logit_softcapping
|
||||||
|
)
|
||||||
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|||||||
@@ -550,7 +550,7 @@ class PrefillAdder:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Make sure at least one page is available
|
# Make sure at least one page is available
|
||||||
trunc_len = self.rem_chunk_tokens - self.page_size + 1
|
trunc_len = self.rem_chunk_tokens // self.page_size * self.page_size
|
||||||
if trunc_len <= 0:
|
if trunc_len <= 0:
|
||||||
return AddReqResult.OTHER
|
return AddReqResult.OTHER
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user