Fix BumpAllocator error when no input_ids (#5564)

This commit is contained in:
fzyzcjy
2025-04-20 17:20:53 +08:00
committed by GitHub
parent 80ac527d22
commit 0a0dd34e6a
2 changed files with 6 additions and 2 deletions

View File

@@ -94,7 +94,9 @@ class DeepseekModelNextN(nn.Module):
zero_allocator = BumpAllocator(
buffer_size=2,
dtype=torch.float32,
device=input_ids.device,
device=(
input_embeds.device if input_embeds is not None else input_ids.device
),
)
if input_embeds is None:

View File

@@ -1374,7 +1374,9 @@ class DeepseekV2Model(nn.Module):
# TODO for two-batch-overlap, we need a larger buffer size
buffer_size=len(self.layers) * 2,
dtype=torch.float32,
device=input_ids.device,
device=(
input_embeds.device if input_embeds is not None else input_ids.device
),
)
if input_embeds is None: