| | """Image processor class for Palmyra_vision""" |
| | from typing import List, Optional, Union, Mapping |
| |
|
| | import numpy as np |
| | import einops |
| | import torch |
| | import torchvision.transforms |
| | from torchvision.transforms import InterpolationMode |
| | from torchvision.transforms.functional import convert_image_dtype |
| |
|
| | from transformers.image_utils import ( |
| | OPENAI_CLIP_MEAN, |
| | OPENAI_CLIP_STD, |
| | ImageInput, |
| | is_valid_image, |
| | ) |
| | from transformers.processing_utils import ImagesKwargs |
| | from transformers.image_processing_utils import BaseImageProcessor |
| | from transformers.utils import logging |
| |
|
| |
|
# Module-level logger, created through the `logging` helper imported from
# `transformers.utils` and named after this module.
logger = logging.get_logger(__name__)
| |
|
| |
|
def pad_to_bounding_box(
    image, offset_height, offset_width, target_height,
    target_width, value=0
):
    """Place `image` inside a larger canvas filled with a constant.

    Args:
        image: 3-D array; the first two axes are padded, the third is left alone.
        offset_height: number of rows inserted before the image.
        offset_width: number of columns inserted before the image.
        target_height: total number of rows of the result.
        target_width: total number of columns of the result.
        value: constant used for every padded element.

    Returns:
        A new array whose first two axes have size
        ``(target_height, target_width)``.
    """
    rows, cols = image.shape[:2]
    # Whatever is not consumed by the leading offset and the image itself goes
    # after the image.
    pad_bottom = target_height - offset_height - rows
    pad_right = target_width - offset_width - cols
    pad_widths = (
        (offset_height, pad_bottom),
        (offset_width, pad_right),
        (0, 0),
    )
    return np.pad(image, pad_widths, constant_values=value)
| |
|
| |
|
def normalize_image(image, offset, scale):
    """Return ``(image - offset) / scale`` with per-channel offset and scale.

    The input array is left untouched: the result is computed on a fresh array,
    so integer images are accepted and the caller's buffer is never modified.

    Args:
        image: array whose last axis is the channel axis (the offset and scale
            are broadcast over the two leading axes).
        offset: per-channel values subtracted from the image.
        scale: per-channel values the shifted image is divided by.

    Returns:
        A new array holding the normalized image.
    """
    offset = np.asarray(offset, dtype=np.float32)[None, None, :]
    scale = np.asarray(scale, dtype=np.float32)[None, None, :]
    # np.subtract allocates the output, so the in-place division below only
    # touches memory owned by this function.
    normalized = np.subtract(image, offset)
    normalized /= scale
    return normalized
| |
|
| |
|
def resize_and_pad(
    image,
    desired_output_size,
    resize_method="torch-bilinear",
    pad_value=0,
    normalize=True,
    image_mean=OPENAI_CLIP_MEAN,
    image_std=OPENAI_CLIP_STD,
):
    """Resize an image to fit inside a target size, then pad it to that size.

    The image is scaled by a single factor (aspect ratio preserved) so that it
    fits inside ``desired_output_size``, centred, and padded with `pad_value`.

    Args:
        image: (height, width, channels) numpy array.
        desired_output_size: ``(height, width)`` of the returned image.
        resize_method: ``"torch-bilinear"`` or ``"tensorflow"``; both use
            antialiased bilinear resizing and clip the result to [0, 1].
        pad_value: constant written into the padded border (before
            normalization).
        normalize: when true, apply ``normalize_image`` with `image_mean` and
            `image_std` to the padded image.
        image_mean: per-channel offset used for normalization.
        image_std: per-channel scale used for normalization.

    Returns:
        ``(image, image_mask)`` where `image` has the desired height and width
        and `image_mask` is a boolean array of the same height and width that
        is True where pixels come from the resized image and False in padding.

    Raises:
        NotImplementedError: if `resize_method` is not one of the two
            supported values.
    """
    desired_height, desired_width = desired_output_size
    height, width = image.shape[:2]

    # The scale is computed in float32 and the scaled size is truncated with
    # int(), exactly as before, so results are unchanged for ordinary inputs.
    image_scale_y = np.array(desired_height, np.float32) / np.array(height, np.float32)
    image_scale_x = np.array(desired_width, np.float32) / np.array(width, np.float32)
    image_scale = min(image_scale_x, image_scale_y)
    # Very elongated images can truncate to a zero-sized side, which the resize
    # ops reject; keep at least one pixel per side.
    scaled_height = max(1, int(np.array(height, np.float32) * image_scale))
    scaled_width = max(1, int(np.array(width, np.float32) * image_scale))

    if resize_method == "tensorflow":
        # Imported lazily so TensorFlow is only needed when this path is used.
        import tensorflow as tf
        image = tf.image.convert_image_dtype(tf.constant(image), dtype=tf.float32)
        image = tf.image.resize(
            image,
            [scaled_height, scaled_width],
            method=tf.image.ResizeMethod.BILINEAR,
            antialias=True,
        )
        image = tf.clip_by_value(image, 0.0, 1.0)
        image = image.numpy()
    elif resize_method == "torch-bilinear":
        # HWC -> CHW for torchvision, then back to HWC afterwards.
        image = torch.permute(torch.from_numpy(image), [2, 0, 1])
        image = convert_image_dtype(image)
        image = torchvision.transforms.Resize(
            [scaled_height, scaled_width], InterpolationMode.BILINEAR, antialias=True
        )(image)
        image = torch.clip(image, 0.0, 1.0)
        image = torch.permute(image, [1, 2, 0]).numpy()
    else:
        raise NotImplementedError(resize_method)

    # Centre the resized image; any odd leftover pixel goes to the bottom/right.
    top_pad = (desired_height - scaled_height) // 2
    left_pad = (desired_width - scaled_width) // 2
    padding = [
        [top_pad, desired_height - scaled_height - top_pad],
        [left_pad, desired_width - scaled_width - left_pad],
        [0, 0]
    ]
    image_mask = np.pad(np.ones_like(image[:, :, 0], dtype=bool), padding[:2])
    image = np.pad(image, padding, constant_values=pad_value)
    if normalize:
        image = normalize_image(image, offset=image_mean, scale=image_std)
    return image, image_mask
| |
|
| |
|
def select_tiling(h, w, patch_size, max_num_patches):
    """Decide how to divide an image of size [h, w] into up to `max_num_patches` tiles.

    Every grid ``(rows, cols)`` with ``rows * cols <= max_num_patches`` is a
    candidate, covering ``rows * patch_size`` by ``cols * patch_size`` pixels.
    For each candidate the scale factor needed to fit the image inside it is
    computed. If every candidate would require shrinking the image, the one
    that shrinks it least is chosen; otherwise the candidate with the smallest
    scale factor that is at least 1 is chosen. Ties go to the candidate with
    fewer tiles, then fewer rows.

    Args:
        h: image height.
        w: image width.
        patch_size: side length of one tile.
        max_num_patches: maximum number of tiles in the grid (at least 1).

    Returns:
        int32 array ``[rows, cols]`` describing the selected grid.

    Raises:
        ValueError: if `max_num_patches` is smaller than 1.
    """
    if max_num_patches < 1:
        raise ValueError(f"max_num_patches must be at least 1, got {max_num_patches}")

    # For a given row count only column counts up to max // rows can satisfy
    # rows * cols <= max, so enumerate exactly those.
    tilings = [
        (rows, cols)
        for rows in range(1, max_num_patches + 1)
        for cols in range(1, max_num_patches // rows + 1)
    ]
    # Sort so argmin/argmax tie-breaking prefers fewer tiles, then fewer rows.
    tilings.sort(key=lambda t: (t[0] * t[1], t[0]))
    candidate_tilings = np.array(tilings, dtype=np.int32)
    candidate_resolutions = candidate_tilings * patch_size

    # np.array (rather than np.stack with dtype=) also works on NumPy versions
    # whose np.stack has no dtype argument.
    original_size = np.array([h, w], dtype=np.float32)
    required_scale_d = candidate_resolutions.astype(np.float32) / original_size
    # The limiting axis determines the scale at which the image fits.
    required_scale = np.min(required_scale_d, axis=-1, keepdims=True)
    if np.all(required_scale < 1):
        # Every candidate requires downscaling: pick the one that shrinks least.
        ix = np.argmax(required_scale)
    else:
        # Ignore candidates that would require downscaling and take the one
        # needing the least upscaling.
        required_scale = np.where(required_scale < 1.0, 10e9, required_scale)
        ix = np.argmin(required_scale)
    return candidate_tilings[ix]
| |
|
| |
|
class Palmyra_visionImagesKwargs(ImagesKwargs, total=False):
    """Optional image keyword arguments added on top of ``ImagesKwargs``.

    All keys are optional (``total=False``). Their names match the constructor
    arguments of ``Palmyra_visionImageProcessor`` in this module.
    """

    # Upper bound on the number of high-resolution crops per image.
    max_crops: Optional[int]
    # Pair of margins, in patches, on the two sides of a crop.
    overlap_margins: Optional[List[int]]
    # (height, width) in pixels of a single crop.
    base_image_input_size: Optional[List[int]]
    # Width and height of the token grid produced for one crop.
    image_token_length_w: Optional[int]
    image_token_length_h: Optional[int]
    # Side length in pixels of one square patch.
    image_patch_size: Optional[int]
    # Whether the per-patch padding mask is returned.
    image_padding_mask: Optional[bool]
| |
|
| |
|
class Palmyra_visionImageProcessor(BaseImageProcessor):
    """Preprocess images and multi-modal inputs.

    Each image becomes one global crop (the whole image resized and padded to
    ``base_image_input_size``) followed by a grid of overlapping crops taken
    from a larger resized copy. Every crop is cut into ``image_patch_size`` x
    ``image_patch_size`` patches, and a matching sequence of special tokens is
    built that marks where the patch features belong in the token stream.
    """

    def __init__(
        self,
        max_crops: int = 12,
        overlap_margins: List[int] = (4, 4),
        base_image_input_size: List[int] = (336, 336),
        image_token_length_w: int = 12,
        image_token_length_h: int = 12,
        image_patch_size: int = 14,
        image_padding_mask: bool = True,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        **kwargs,
    ):
        """Store the default preprocessing configuration.

        Args:
            max_crops: maximum number of high-resolution crops per image.
            overlap_margins: ``(left, right)`` margins, in patches, on the two
                sides of each crop.
            base_image_input_size: ``(height, width)`` in pixels of one crop.
            image_token_length_w: width of the token grid of one crop.
            image_token_length_h: height of the token grid of one crop.
            image_patch_size: side length in pixels of one square patch.
            image_padding_mask: whether `multimodal_preprocess` returns the
                per-patch padding mask.
            do_normalize: stored on the instance.
            image_mean: stored on the instance; defaults to ``OPENAI_CLIP_MEAN``.
            image_std: stored on the instance; defaults to ``OPENAI_CLIP_STD``.
            **kwargs: forwarded to ``BaseImageProcessor.__init__``.
        """
        super().__init__(**kwargs)
        self.max_crops = max_crops
        self.overlap_margins = overlap_margins
        self.base_image_input_size = base_image_input_size
        self.image_token_length_w = image_token_length_w
        self.image_token_length_h = image_token_length_h
        self.image_patch_size = image_patch_size
        self.image_padding_mask = image_padding_mask
        # NOTE(review): the three normalization settings below are stored but
        # no method of this class reads them; `resize_and_pad` is called with
        # its own defaults. Confirm whether they should be passed through.
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
        self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD

    def image_to_patches_and_tokens(
        self,
        image: ImageInput,
        image_patch_token_id: int,
        image_col_token_id: int,
        image_start_token_id: int,
        image_end_token_id: int,
        max_crops: Optional[int] = None,
        overlap_margins: Optional[List[int]] = None,
        base_image_input_size: Optional[Union[int, List[int]]] = None,
        image_token_length_w: Optional[int] = None,
        image_token_length_h: Optional[int] = None,
        image_patch_size: Optional[int] = None,
    ):
        """Split one image into patch crops and build its image-token sequence.

        Args:
            image: (height, width, channels) numpy array.
            image_patch_token_id: token id marking a slot for patch features.
            image_col_token_id: token id appended at the end of each token row.
            image_start_token_id: token id placed before each image block.
            image_end_token_id: token id placed after each image block.
            max_crops, overlap_margins, base_image_input_size,
            image_token_length_w, image_token_length_h, image_patch_size:
                per-call overrides; ``None`` falls back to the instance value.
                `base_image_input_size` may be a single int for a square crop.

        Returns:
            A tuple ``(patches, joint, patch_ordering, img_mask)``:
            patches: (1 + n_crops, n_patches, patch_dim) array; index 0 is the
                global crop, followed by the grid crops in row-major order.
            joint: 1-D token array: the token block of the global crop followed
                by the token block of the high-resolution grid.
            patch_ordering: 1-D array with one entry per pooled patch slot of
                every crop; the rank of that slot among the patch tokens, or
                -1 for slots that are not used.
            img_mask: (n_crops + 1, n_patches) array; per-patch mean of the
                image mask for the grid crops, followed by one row of -1.
        """
        max_crops = self.max_crops if max_crops is None else max_crops
        overlap_margins = self.overlap_margins if overlap_margins is None else overlap_margins
        if base_image_input_size is None:
            base_image_input_size = self.base_image_input_size
        if image_token_length_w is None:
            image_token_length_w = self.image_token_length_w
        if image_token_length_h is None:
            image_token_length_h = self.image_token_length_h
        if image_patch_size is None:
            image_patch_size = self.image_patch_size

        if isinstance(base_image_input_size, int):
            base_image_input_size = (base_image_input_size, base_image_input_size)

        base_image_input_d = image_patch_size
        tokens_per_image = image_token_length_w * image_token_length_h
        image_base_patch_w = base_image_input_size[1] // base_image_input_d
        image_base_patch_h = base_image_input_size[0] // base_image_input_d

        original_image_h, original_image_w = image.shape[:2]
        crop_size = base_image_input_size[0]

        # Margins are expressed in patches. `left_margin // 2` is used below as
        # an offset in the 2x-pooled token grid, hence the evenness requirement.
        left_margin, right_margin = overlap_margins
        assert left_margin % 2 == 0, "left overlap margin must be even"
        total_margin_pixels = base_image_input_d*(right_margin + left_margin)
        crop_patches = base_image_input_size[0] // base_image_input_d
        # Patches of a crop that remain once both margins are removed.
        crop_window_patches = crop_patches - (right_margin + left_margin)
        crop_window_size = crop_window_patches * base_image_input_d
        tiling = select_tiling(
            original_image_h - total_margin_pixels,
            original_image_w - total_margin_pixels,
            crop_window_size,
            max_crops
        )
        src, img_mask = resize_and_pad(
            image,
            [tiling[0]*crop_window_size+total_margin_pixels, tiling[1]*crop_window_size+total_margin_pixels]
        )

        # Cut the resized image into crops while recording, for every pooled
        # patch slot of every crop, its running index (`on`) or -1 if unused.
        patches_arr = []
        mask_arr = []
        patch_ordering_arr = []

        # The pooled grid of one crop must match the configured token grid.
        assert (crop_patches+1)//2 == image_token_length_h, \
            "image_token_length_h does not match the pooled crop height"
        assert (crop_patches+1)//2 == image_token_length_w, \
            "image_token_length_w does not match the pooled crop width"
        on = 0
        for i in range(tiling[0]):
            y0 = i*crop_window_size
            crop_y0 = 0 if i == 0 else left_margin // 2

            # Only crops on the outer border keep their margin on that side.
            crop_h = image_base_patch_h - (right_margin + left_margin)
            if i == 0:
                crop_h += left_margin
            if i == (tiling[0]-1):
                crop_h += right_margin
            for j in range(tiling[1]):
                x0 = j*crop_window_size
                crop_x0 = 0 if j == 0 else left_margin // 2

                crop_w = image_base_patch_w - (right_margin + left_margin)
                if j == 0:
                    crop_w += left_margin
                if j == (tiling[1]-1):
                    crop_w += right_margin

                pooled_w = (crop_w + 1) // 2
                pooled_h = (crop_h + 1) // 2
                patch_ordering_arr.append(
                    pad_to_bounding_box(
                        np.reshape(np.arange(on, on+pooled_h*pooled_w, dtype=np.int32), (pooled_h, pooled_w, 1)),
                        crop_y0, crop_x0, image_token_length_h, image_token_length_w, value=-1
                    )[:, :, 0]
                )
                patches_arr.append(src[y0:y0+crop_size, x0:x0+crop_size])
                mask_arr.append(img_mask[y0:y0+crop_size, x0:x0+crop_size])

                on += pooled_h*pooled_w
        patches = np.stack(patches_arr)
        patch_ordering = np.stack(patch_ordering_arr)
        img_mask = np.stack(mask_arr)

        # Switch to [n_crops, n_patches, pixels_per_patch] layout.
        patches = einops.rearrange(
            patches, 'p (h dh) (w dw) c -> p (h w) (dh dw c)',
            dh=base_image_input_d,
            dw=base_image_input_d,
            h=image_base_patch_h,
            w=image_base_patch_w
        )
        img_mask = einops.rearrange(
            img_mask, 'p (h dh) (w dw) -> p (h w) (dh dw)',
            dh=base_image_input_d,
            dw=base_image_input_d,
            h=image_base_patch_h,
            w=image_base_patch_w
        )

        # Mean of the boolean mask: fraction of each patch's pixels that come
        # from the image rather than from padding.
        img_mask = img_mask.astype(np.float32).mean(axis=-1)
        patch_ordering = np.reshape(patch_ordering, [-1])
        valid = patch_ordering >= 0

        # Reorder from crop-by-crop to row-by-row across the whole grid.
        patch_ordering_rh = np.reshape(
            patch_ordering,
            [tiling[0], tiling[1], image_token_length_h, image_token_length_w]
        )
        patch_ordering_rh = np.transpose(patch_ordering_rh, [0, 2, 1, 3])
        patch_ordering_rh = np.reshape(patch_ordering_rh, [-1])

        # The transpose moves the -1 entries around as well; write the reordered
        # valid values back into the original valid positions.
        patch_ordering[valid] = patch_ordering_rh[patch_ordering_rh >= 0]

        # Token block of the high-resolution grid: one patch token per pooled
        # patch, with a column token closing every row.
        h = tiling[0] * crop_window_patches + (right_margin+left_margin)
        w = tiling[1] * crop_window_patches + (right_margin+left_margin)
        per_row = np.full(
            ((w+1)//2,),
            image_patch_token_id,
        )
        per_row = np.concatenate([per_row, [image_col_token_id]], 0)

        joint = np.tile(per_row, [(h+1)//2])
        joint = [
            [image_start_token_id],
            joint,
            [image_end_token_id]
        ]

        # Global crop: the whole image resized to a single crop.
        resized, _ = resize_and_pad(image, base_image_input_size)
        resized = einops.rearrange(
            resized, '(h dh) (w dw) c -> (h w) (dh dw c)',
            dh=base_image_input_d,
            dw=base_image_input_d,
            h=image_base_patch_h,
            w=image_base_patch_w
        )
        patches = np.concatenate([np.expand_dims(resized, 0), patches], 0)

        # The global crop comes first, so shift the ordering of the grid crops
        # by the number of tokens it occupies.
        patch_ordering = np.where(
            patch_ordering >= 0,
            patch_ordering + tokens_per_image,
            -1
        )
        patch_ordering = np.concatenate([np.arange(0, tokens_per_image), patch_ordering], 0)
        per_row = np.full(
            (image_token_length_w,),
            image_patch_token_id,
        )
        per_row = np.concatenate([per_row, [image_col_token_id]], 0)
        extra_tokens = np.tile(per_row, [image_token_length_h])
        joint = [
            [image_start_token_id],
            extra_tokens,
            [image_end_token_id],
        ] + joint

        joint = np.concatenate(joint, 0)
        # NOTE(review): the extra mask row (-1) is appended at the end, whereas
        # the global crop is prepended to `patches`. Confirm this ordering is
        # what the consumer of the mask expects.
        img_mask = np.pad(img_mask, [[0, 1], [0, 0]], constant_values=-1)
        return patches, joint, patch_ordering, img_mask

    def build_image_input_idx(
        self,
        image_tokens: np.ndarray,
        patch_order: np.ndarray,
        image_patch_token_id: int,
        no_image: Optional[bool] = None,
        image_token_length_w: Optional[int] = None,
        image_token_length_h: Optional[int] = None,
    ):
        """Converts `patch_order` into a mapping of token_id -> patch_id

        Args:
            image_tokens: 1-D token array for one image.
            patch_order: per-slot rank among the patch tokens, -1 for unused
                slots, or ``None`` to skip reordering.
            image_patch_token_id: token id that marks a patch-feature slot.
            no_image: when true, an empty ``(0, tokens_per_image)`` array is
                returned.
            image_token_length_w, image_token_length_h: token grid of one
                crop; ``None`` falls back to the instance value.

        Returns:
            When `patch_order` is given: an array of shape
            ``(-1, tokens_per_image)`` holding, for every slot, the position in
            `image_tokens` of its patch token, or -100 for unused slots.
            When `patch_order` is ``None``: the 1-D positions of the patch
            tokens in `image_tokens`.
        """
        if image_token_length_w is None:
            image_token_length_w = self.image_token_length_w
        if image_token_length_h is None:
            image_token_length_h = self.image_token_length_h

        tokens_per_image = image_token_length_w * image_token_length_h
        if no_image:
            return np.zeros((0, tokens_per_image), np.int32)

        # Positions of the patch tokens inside `image_tokens`.
        image_input_idx = image_tokens == image_patch_token_id
        image_input_idx = np.nonzero(image_input_idx)[0].astype(np.int32)

        if patch_order is not None:
            n_tokens = image_input_idx.shape[0]
            patch_order = np.reshape(patch_order, [-1])

            valid = patch_order >= 0
            n_valid_patches = valid.sum()
            assert len(image_input_idx) == n_valid_patches, \
                "number of patch tokens does not match the number of valid patches"

            # Invert the ordering: sorted_patch_ixs[rank] = index among the
            # valid slots.
            sorted_patch_ixs = np.zeros([n_tokens], np.int32)
            sorted_patch_ixs[patch_order[valid]] = np.arange(n_valid_patches, dtype=np.int32)

            # Scatter the inverted mapping back into the sparse slot layout.
            sorted_patch_ixs_ex = np.full(np.shape(patch_order), -1)
            sorted_patch_ixs_ex[valid] = sorted_patch_ixs

            # Gather token positions; unused slots index position 0 during the
            # gather and are then overwritten with -100.
            valid = (sorted_patch_ixs_ex >= 0).astype(np.int32)
            image_input_idx = image_input_idx[sorted_patch_ixs_ex*valid]
            image_input_idx = image_input_idx*valid - 100*(1 - valid)
            image_input_idx = np.reshape(image_input_idx, [-1, tokens_per_image])
        return image_input_idx

    def preprocess(
        self,
        image: np.ndarray,
        image_patch_token_id: int,
        image_col_token_id: int,
        image_start_token_id: int,
        image_end_token_id: int,
        max_crops: Optional[int] = None,
        overlap_margins: Optional[List[int]] = None,
        base_image_input_size: Optional[Union[int, List[int]]] = None,
        image_token_length_w: Optional[int] = None,
        image_token_length_h: Optional[int] = None,
        image_patch_size: Optional[int] = None,
        **kwargs,
    ):
        """Preprocesses an image

        Falsy overrides (``None``, 0, empty) fall back to the instance values;
        extra keyword arguments are accepted and ignored.

        Returns:
            crops: (n_crops, n_patches, patch_dim) individual crops, `n_crops` might
                   change between images but the other dimension are fixed
            tokens: (n_tokens,) tokens, patch tokens indicate where to insert the
                    patch features, might include other special tokens as well
            image_idx: (n_crops, tokens_per_image) index in `tokens` to put the patch
                       features from the crops after pooling, negative values indicate
                       patch features to exclude
            padding_mask: (n_crops, n_patches) per-patch mean of the image mask, i.e.
                          the fraction of each patch that comes from the image rather
                          than padding; the last row is filled with -1
        """
        max_crops = max_crops or self.max_crops
        overlap_margins = overlap_margins or self.overlap_margins
        base_image_input_size = base_image_input_size or self.base_image_input_size
        image_token_length_w = image_token_length_w or self.image_token_length_w
        image_token_length_h = image_token_length_h or self.image_token_length_h
        image_patch_size = image_patch_size or self.image_patch_size

        crops, image_tokens, patch_ordering, img_mask = self.image_to_patches_and_tokens(
            image,
            image_patch_token_id,
            image_col_token_id,
            image_start_token_id,
            image_end_token_id,
            max_crops,
            overlap_margins,
            base_image_input_size,
            image_token_length_w,
            image_token_length_h,
            image_patch_size,
        )
        patch_idx = self.build_image_input_idx(
            image_tokens,
            patch_ordering,
            image_patch_token_id,
            image_token_length_w=image_token_length_w,
            image_token_length_h=image_token_length_h,
        )
        return crops, image_tokens, patch_idx, img_mask

    def multimodal_preprocess(
        self,
        images: np.ndarray,
        tokens: List[int],
        image_idx: np.ndarray,
        sequence_length: int,
        image_patch_token_id: int,
        image_col_token_id: int,
        image_start_token_id: int,
        image_end_token_id: int,
        **kwargs,
    ):
        """Merge images and text tokens into multi-modal features for the model

        :param images: images to use as input, or None for text-only input
        :param tokens: input text tokens
        :param image_idx: where to insert the images into `tokens`; -1 inserts
            the image at the very start, otherwise the token at that position is
            replaced by the image tokens
        :param sequence_length: accepted for interface compatibility; not used
        :params image_patch_token_id: id to use of tokens that will contain image features
        :params image_col_token_id: token id for image column special tokens
        :params image_start_token_id: token id for image start special tokens
        :params image_end_token_id: token id for image end special tokens
        :params kwargs: override preprocessor default args
        :return: dict with "input_ids" and, when images are given, "images",
            "image_input_idx" and (if the padding mask is enabled) "image_masks"
        """
        if images is None:
            return {
                "input_ids": tokens,
            }

        # Use an explicit None check so that passing `image_padding_mask=False`
        # really disables the mask even when the instance default is True.
        image_padding_mask = kwargs.get("image_padding_mask")
        if image_padding_mask is None:
            image_padding_mask = self.image_padding_mask

        n = len(images)
        all_crops = []
        all_image_idx = []
        out_tokens = []
        all_crop_masks = []
        # Number of tokens already emitted into `out_tokens`.
        n_emitted = 0

        for ix in range(n):
            token_ix = image_idx[ix]
            crops, image_tokens, patch_idx, img_mask = self.preprocess(
                images[ix],
                image_patch_token_id,
                image_col_token_id,
                image_start_token_id,
                image_end_token_id,
                **kwargs,
            )

            if token_ix == -1:
                # Image inserted at the very start: no text precedes it and no
                # text token is consumed.
                start = 0
                token_ix = 0
                end = 0
            else:
                start = 0 if ix == 0 else image_idx[ix-1] + 1
                end = token_ix + 1

            text_before = tokens[start:token_ix]
            # Position of this image's first token in the merged sequence. This
            # also counts the tokens of previously inserted images.
            offset = n_emitted + len(text_before)

            # Shift only real indices; negative entries mark excluded patches
            # and must stay negative whatever the offset is.
            all_image_idx.append(np.where(patch_idx >= 0, patch_idx + offset, patch_idx))
            all_crops.append(crops)
            out_tokens.append(text_before)
            out_tokens.append(image_tokens)
            n_emitted = offset + len(image_tokens)
            if ix == (n - 1):
                out_tokens.append(tokens[end:])
            if image_padding_mask:
                all_crop_masks.append(img_mask)

        out = {
            "input_ids": np.concatenate(out_tokens, 0),
            "images": np.concatenate(all_crops, 0),
            "image_input_idx": np.concatenate(all_image_idx, 0),
        }
        if image_padding_mask:
            out["image_masks"] = np.concatenate(all_crop_masks, 0)
        return out
| |
|
| |
|
# Executed at import time: registers this custom processor class through the
# `register_for_auto_class` hook (called with its default arguments).
Palmyra_visionImageProcessor.register_for_auto_class()