@@ -57,7 +57,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix, merge_multimodal_embeddings)
-from .vision import select_patch_features
+from .vision import scatter_patch_features, select_patch_features
 
 # TODO: hard-coded for now. Consider making it configurable.
 VIT_LAYERS = [-2, -9]
@@ -71,13 +71,13 @@ POOLING_SIZE = 2
 
 
 class MolmoImageInputs(TypedDict):
-    images: Union[torch.Tensor, List[torch.Tensor]]
+    images: Union[torch.Tensor, list[torch.Tensor]]
     """Shape: `(batch_size, num_crops, num_patch, patch_dim)`"""
 
-    image_masks: Optional[Union[torch.Tensor, List[torch.Tensor]]]
+    image_masks: Optional[Union[torch.Tensor, list[torch.Tensor]]]
     """Shape: `(batch_size, num_crops, num_patch)`"""
 
-    feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+    feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
     """
     A boolean mask indicating which image features correspond
     to patch tokens.
@@ -85,7 +85,7 @@ class MolmoImageInputs(TypedDict):
     Shape: `(batch_size, num_crops, num_patch)`
     """
 
-    embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
     """
     A boolean mask indicating which image embeddings correspond
     to patch tokens.
@@ -93,7 +93,7 @@ class MolmoImageInputs(TypedDict):
     Shape: `(batch_size, num_embeds)`
     """
 
-    num_crops: torch.Tensor
+    num_crops: Union[torch.Tensor, list[torch.Tensor]]
     """Shape: `(batch_size, num_images)`"""
 
 
@@ -1144,13 +1144,7 @@ class MolmoProcessorWrapper:
 
         image_input_idx = outputs.pop("image_input_idx", None)
         if image_input_idx is not None:
-            input_is_patch = input_ids == self.image_patch_id
-            image_input_idx_flat: torch.Tensor = image_input_idx.view(-1)
-            image_valid_flat = image_input_idx_flat >= 0
-            feat_is_patch_flat = image_valid_flat.clone()
-            feat_is_patch_flat[image_valid_flat] = (
-                input_is_patch[image_input_idx_flat[image_valid_flat]])
-            feat_is_patch = feat_is_patch_flat.view(*image_input_idx.shape)
+            feat_is_patch = image_input_idx >= 0
 
             input_is_embed = torch.isin(
                 input_ids,
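The multi-line index gather above collapses to a single comparison. The simplification rests on the assumption that every non-negative entry of `image_input_idx` points at a position in `input_ids` that holds `image_patch_id`, which makes the round-trip through `input_ids` redundant. A minimal equivalence check on toy values (the token ids and tensors below are illustrative, not taken from the processor):

```python
import torch

# Toy stand-ins; the real tensors come from the Molmo processor.
image_patch_id = 9
input_ids = torch.tensor([1, 9, 9, 9, 2])        # prompt token ids
image_input_idx = torch.tensor([[1, 2, 3, -1]])  # -1 marks a padding slot

# Old formulation: gather each valid index into input_ids and test
# whether the token found there is the image-patch token.
input_is_patch = input_ids == image_patch_id
flat = image_input_idx.view(-1)
valid = flat >= 0
old = valid.clone()
old[valid] = input_is_patch[flat[valid]]
old = old.view(*image_input_idx.shape)

# New formulation: a valid slot is a patch slot by construction.
new = image_input_idx >= 0

assert torch.equal(old, new)
```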
@@ -1165,6 +1159,17 @@ class MolmoProcessorWrapper:
             embed_is_patch = embed_ids == self.image_patch_id
             assert embed_is_patch.sum() == feat_is_patch.sum()
 
+            # image_tokens = extra_joint + joint
+            # Both `extra_joint` and `joint` have `im_start_id` and `im_end_id`
+            embed_start = torch.nonzero(embed_ids == self.im_start_id)[::2, 0]
+            embed_end = torch.nonzero(embed_ids == self.im_end_id)[1::2, 0]
+            assert len(embed_start) == len(embed_end) == len(images)
+
+            embed_is_patch = [
+                embed_is_patch[start:end + 1]
+                for start, end in zip(embed_start, embed_end)
+            ]
+
             tilings = [
                 self.select_tiling(
                     image_width=image.size[0],
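The stride-2 indexing added here relies on each image emitting exactly two `im_start`/`im_end` delimited segments (the `extra_joint` and `joint` blocks named in the comment), so taking every other start token and the second of each pair of end tokens yields one `[start, end]` span per image that covers both segments. A toy illustration with made-up token ids (3, 4, and 9 are placeholders):

```python
import torch

im_start_id, im_end_id, patch_id = 3, 4, 9  # hypothetical token ids

# Two images; each contributes two start/end-delimited segments.
embed_ids = torch.tensor([
    3, 9, 4, 3, 9, 9, 4,  # image 0: extra_joint then joint
    3, 9, 4, 3, 9, 9, 4,  # image 1
])

# First start of each pair, second end of each pair: the resulting
# span covers both segments belonging to one image.
embed_start = torch.nonzero(embed_ids == im_start_id)[::2, 0]
embed_end = torch.nonzero(embed_ids == im_end_id)[1::2, 0]

assert embed_start.tolist() == [0, 7]
assert embed_end.tolist() == [6, 13]
```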
@@ -1180,7 +1185,7 @@ class MolmoProcessorWrapper:
         outputs["num_crops"] = num_crops
         outputs["img_patch_id"] = self.image_patch_id
 
-        return BatchFeature(outputs, tensor_type=return_tensors)
+        return BatchFeature(outputs)
 
 
 class MolmoProcessingInfo(BaseProcessingInfo):
@@ -1190,9 +1195,7 @@ class MolmoProcessingInfo(BaseProcessingInfo):
        return MolmoProcessorWrapper(processor)
 
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
-        # TODO: Investigate different `embed_is_patch` between cache/no-cache
-        # in multi-image case
-        return {"image": 1}
+        return {"image": None}
 
     def get_mm_max_tokens_per_item(
         self,
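Returning `None` lifts the one-image-per-prompt cap, so multi-image requests become valid input for Molmo. A hedged usage sketch; the model name, blank test images, prompt wording, and `limit_mm_per_prompt` value below are illustrative assumptions, not taken from this diff:

```python
from PIL import Image
from vllm import LLM

# Hypothetical multi-image request; blank images keep the sketch
# self-contained. Real usage would load actual pictures.
image_a = Image.new("RGB", (448, 448))
image_b = Image.new("RGB", (448, 448))

llm = LLM(model="allenai/Molmo-7B-D-0924", trust_remote_code=True,
          limit_mm_per_prompt={"image": 2})
outputs = llm.generate({
    "prompt": "Describe the differences between the two images.",
    "multi_modal_data": {"image": [image_a, image_b]},
})
```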
@@ -1325,7 +1328,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
                 "image", num_crops),
             feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
                 "image", num_crops),
-            embed_is_patch=MultiModalFieldConfig.shared("image", num_images),
+            embed_is_patch=MultiModalFieldConfig.batched("image"),
             num_crops=MultiModalFieldConfig.batched("image"),
             img_patch_id=MultiModalFieldConfig.shared("image", num_images),
         )
@@ -1499,7 +1502,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
     def _process_image_input(
         self,
         image_input: MolmoImageInputs,
-    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+    ) -> Union[torch.Tensor, list[torch.Tensor]]:
         if isinstance(image_input["images"], list):
             # Call the vision backbone on the whole batch at once
             images_flat = flatten_bn(image_input["images"], concat=True)
@@ -1530,7 +1533,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
         feat_is_patch: torch.Tensor,  # Shape: (num_crop, num_patch)
         num_crops: torch.Tensor,  # Shape: (num_images,)
         embed_is_patch: torch.Tensor,  # Shape: (num_embeds,)
-    ) -> list[torch.Tensor]:
+    ) -> tuple[torch.Tensor, ...]:
         """
         Scatter the patch features into a contiguous tensor that corresponds
         to the embedding tokens defined by the multimodal processor.
@@ -1565,16 +1568,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
         feats_per_image = features.split(num_crops_per_image)
         f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)
 
-        _, _, embed_dim = features.shape
-        (num_embeds, ) = embed_is_patch.shape
+        features = torch.cat([
+            feats[f_is_patch]
+            for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image)
+        ])
 
-        embeds_in_batch = list[torch.Tensor]()
-        for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image):
-            embeds = feats.new_full((num_embeds, embed_dim), torch.nan)
-            embeds[embed_is_patch] = feats[f_is_patch]
-            embeds_in_batch.append(embeds)
-
-        return embeds_in_batch
+        return scatter_patch_features(features, embed_is_patch)
 
     def get_multimodal_embeddings(
         self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
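The deleted NaN-padding loop now lives behind the shared `scatter_patch_features` helper imported from `.vision`. As a rough mental model only (the real helper is defined in `.vision` and its signature and internals may differ), its job mirrors the removed code: scatter the concatenated patch features into NaN-filled buffers that span each image's full embedding region:

```python
import torch

def scatter_patch_features_sketch(
    features: torch.Tensor,              # (total_num_patches, embed_dim)
    embed_is_patch: list[torch.Tensor],  # per-image boolean masks
) -> tuple[torch.Tensor, ...]:
    """Place patch features at their token positions, NaN elsewhere."""
    embeds_per_image = []
    offset = 0
    for is_patch in embed_is_patch:
        num_patches = int(is_patch.sum())
        embeds = features.new_full(
            (is_patch.numel(), features.shape[-1]), torch.nan)
        embeds[is_patch] = features[offset:offset + num_patches]
        embeds_per_image.append(embeds)
        offset += num_patches
    return tuple(embeds_per_image)
```

Downstream, the NaN rows mark the non-patch embedding tokens (start, column, and end delimiters), which the retained `select_patch_features` import presumably filters back out before merging into the language embeddings.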