import os
from abc import ABC, abstractmethod
from datetime import datetime as dt
from functools import wraps
from pathlib import Path
from typing import Callable

import cv2
import imagehash
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from sklearn.decomposition import PCA


_DATASET_AVG_MEAN = 129.38489987766278
_DATASET_AVG_STD = 54.084109207654805
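# The two constants above were presumably computed once over the full dataset.
# A sketch for recomputing them on a new dataset (assuming the ./dataset
# layout used in __main__):
#
#     stats = StatsDataFilter._compute_dataset_stats('./dataset')
#     print(stats['avg_mean'], stats['avg_std'])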


def save_to_file(location: str = './extracted_paths.txt') -> Callable:
    """Decorator factory: when the wrapped function is called with
    ``to_file=True``, append the paths it returns to ``location``."""
    def outer_wrapper(fn: Callable) -> Callable:
        @wraps(fn)
        def inner_wrapper(*args, **kwargs):
            paths: list[str] = fn(*args, **kwargs)
            if kwargs.get('to_file'):
                with open(location, 'a') as file:
                    file.write('\nFiles to remove [TIMESTAMP {}]:\n'
                               .format(dt.now().strftime('%Y%m%d%H%M%S')))
                    for p in paths:
                        file.write(f'{p}\n')
            return paths
        return inner_wrapper
    return outer_wrapper


def visualize(show_limit: int = -1) -> Callable:
    """Decorator factory: when the wrapped function is called with
    ``visualize_=True``, show the returned images in a grid. ``show_limit``
    caps how many images are displayed (-1 shows all)."""
    def outer_wrapper(fn: Callable) -> Callable:
        @wraps(fn)
        def inner_wrapper(*args, **kwargs):
            paths: list[str] = fn(*args, **kwargs)
            if kwargs.get('visualize_') and paths:
                # Truncate only the displayed list, never the returned one.
                paths_to_show = paths if show_limit == -1 else paths[:show_limit]

                num_cols = 8
                num_rows = (len(paths_to_show) + num_cols - 1) // num_cols

                fig = plt.figure(figsize=(8, 8))
                for i, path in enumerate(paths_to_show, start=1):
                    plt.subplot(num_rows, num_cols, i)
                    plt.imshow(Image.open(path), cmap='gray')
                    plt.title(Path(path).parent.name, fontsize=7)
                    plt.axis('off')
                fig.tight_layout()
                fig.subplots_adjust(hspace=0.6, top=0.97)
                plt.show()
            return paths
        return inner_wrapper
    return outer_wrapper
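

# NOTE: both decorators read their switches with ``kwargs.get(...)``, so the
# decorated ``extract`` methods only honor ``visualize_`` / ``to_file`` when
# those are passed as keyword arguments (as done in ``__main__`` below).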


class DataFilter(ABC):
    """Base class for dataset-cleaning filters: ``extract`` collects the paths
    of images to remove, ``filter`` deletes them, ``clear`` forgets them."""

    def __init__(self, console_output: bool = False):
        self.paths: list[str] = []
        self.console_output = console_output

    @abstractmethod
    def extract(self, data_dir: str | Path, visualize_: bool, to_file: bool) -> list[str]:
        pass

    def clear(self) -> None:
        self.paths.clear()
        if self.console_output:
            print(f'[{self.__class__.__name__}]: Paths memory cleared.')

    def filter(self) -> bool:
        """Delete every extracted file; return True if any path was missing."""
        has_error = False
        for path in self.paths:
            if not Path(path).exists():
                has_error = True
                continue
            os.remove(path)
            if self.console_output:
                print(f'[{self.__class__.__name__}]: Removed {path}')
        return has_error

    @staticmethod
    def _load_data(dir_: str | Path) -> tuple[list[np.ndarray], list[str], list[str]]:
        images = []
        class_names = []
        paths = []

        for path in Path(dir_).glob('**/*.jpg'):
            label = path.parent.name  # class name is the parent folder
            image = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
            if image is not None:
                images.append(image)
                class_names.append(label)
                paths.append(str(path))

        return images, class_names, paths


class DataFilterCompose(DataFilter):
    """Runs several filters as a single one, aggregating their results."""

    def __init__(self, components: list[DataFilter]):
        super().__init__()
        self.components = components

    @staticmethod
    def build(components: list[DataFilter]) -> 'DataFilterCompose':
        return DataFilterCompose(components)

    def extract(self, data_dir: str | Path, visualize_: bool, to_file: bool) -> list[str]:
        extracted_paths = []
        for component in self.components:
            extracted_paths += component.extract(data_dir,
                                                 visualize_=visualize_,
                                                 to_file=to_file)
        self.paths += extracted_paths
        return extracted_paths

    def clear(self) -> None:
        self.paths.clear()
        for component in self.components:
            component.clear()

    def filter(self) -> bool:
        has_error = False
        for component in self.components:
            has_error = component.filter() or has_error
        return has_error

    def add_component(self, component: DataFilter, position: int) -> None:
        self.components.insert(position, component)

    def rm_component(self, position: int) -> None:
        self.components.pop(position)


class StatsDataFilter(DataFilter):
    """Flags images whose mean/std deviate too far from the dataset averages."""

    _OPTIM_MEAN_THRESH = 107
    _OPTIM_STD_THRESH = 51

    def __init__(self, data_avg_mean: float | None = None, data_avg_std: float | None = None,
                 console_output: bool = False):
        super().__init__(console_output)
        self.data_avg_mean = data_avg_mean
        self.data_avg_std = data_avg_std

    @visualize()
    @save_to_file()
    def extract(self, data_dir: str | Path, visualize_: bool, to_file: bool) -> list[str]:
        # Compute the dataset statistics lazily if they were not supplied.
        if self.data_avg_mean is None or self.data_avg_std is None:
            stats = self._compute_dataset_stats(data_dir)
            self.data_avg_mean = stats['avg_mean']
            self.data_avg_std = stats['avg_std']

        extracted_paths = self._extract_outliers_by_stats(
            data_dir,
            self.data_avg_mean,
            self.data_avg_std,
            StatsDataFilter._OPTIM_MEAN_THRESH,
            StatsDataFilter._OPTIM_STD_THRESH,
            self.console_output)

        self.paths += extracted_paths
        return extracted_paths

    @classmethod
    def _extract_outliers_by_stats(cls,
                                   data_root: str | Path,
                                   dataset_avg_mean: float,
                                   dataset_avg_std: float,
                                   mean_thresh: float,
                                   std_thresh: float,
                                   console_output: bool = False) -> list[str]:
        outlier_paths = []
        images, _, paths = cls._load_data(data_root)
        total_len = len(paths)
        for count, (img, path) in enumerate(zip(images, paths), start=1):
            if (abs(dataset_avg_mean - np.mean(img)) > mean_thresh
                    or abs(dataset_avg_std - np.std(img)) > std_thresh):
                outlier_paths.append(path)
            if console_output:
                print(f'[{cls.__name__}]: Computed {count}/{total_len} images '
                      f'({count / total_len * 100:.2f}%)')
        return outlier_paths

    @staticmethod
    def _compute_dataset_stats(data_dir: str | Path) -> dict[str, float]:
        mean_sum = 0.0
        std_sum = 0.0
        num_images = 0

        for img_path in Path(data_dir).glob('**/*.jpg'):
            img = cv2.imread(str(img_path), cv2.IMREAD_GRAYSCALE)
            if img is None:  # skip unreadable files
                continue
            mean_sum += np.mean(img)
            std_sum += np.std(img)
            num_images += 1

        if num_images == 0:
            raise ValueError(f'No readable .jpg images found under {data_dir}')

        return {
            'avg_mean': mean_sum / num_images,
            'avg_std': std_sum / num_images,
        }


class PcaDataFilter(DataFilter):
    """Flags images that a low-dimensional PCA basis reconstructs poorly."""

    _OPTIM_NUM_COMPONENTS = 4
    _OPTIM_ERROR_THRESH = 87

    def __init__(self, console_output: bool = False):
        super().__init__(console_output)

    @visualize()
    @save_to_file()
    def extract(self, data_dir: str | Path, visualize_: bool, to_file: bool) -> list[str]:
        extracted_paths = self._extract_outliers_with_pca(data_dir)
        self.paths += extracted_paths
        return extracted_paths

    @staticmethod
    def _extract_outliers_with_pca(dir_: str | Path) -> list[str]:
        images, _, img_paths = PcaDataFilter._load_data(dir_)
        # Stacking assumes every image shares the same height and width.
        x = np.array(images)
        num_samples, height, width = x.shape
        x_flattened = x.reshape(num_samples, height * width)

        outlier_indices = PcaDataFilter._detect_outliers_with_pca(
            x_flattened,
            PcaDataFilter._OPTIM_NUM_COMPONENTS,
            PcaDataFilter._OPTIM_ERROR_THRESH)
        return [img_paths[i] for i in outlier_indices.tolist()]

    @staticmethod
    def _detect_outliers_with_pca(orig_data: np.ndarray,
                                  num_components: int,
                                  error_thresh: float) -> np.ndarray:
        # Project onto the top principal components and back; images the
        # subspace cannot reconstruct well (per-image RMSE above the
        # threshold) are treated as outliers.
        pca = PCA(n_components=num_components)
        x_reduced = pca.fit_transform(orig_data)
        x_reconstructed = pca.inverse_transform(x_reduced)
        reconstruction_errors = np.sqrt(np.mean((orig_data - x_reconstructed) ** 2, axis=1))
        return np.where(reconstruction_errors > error_thresh)[0]


class DHashDuplicateFilter(DataFilter):
    """Flags near-duplicate images via difference hashing (dhash)."""

    def __init__(self, hash_size: int = 8, console_output: bool = False):
        super().__init__(console_output)
        self.hash_size = hash_size

    @visualize(60)
    @save_to_file()
    def extract(self, data_dir: str | Path, visualize_: bool, to_file: bool) -> list[str]:
        _, _, paths = self._load_data(data_dir)
        hashes = set()
        duplicates = []

        for path in paths:
            hash_ = imagehash.dhash(Image.open(path), self.hash_size)
            if hash_ in hashes:
                duplicates.append(path)
                if self.console_output:
                    print(f'[{self.__class__.__name__}]: Duplicate found at {path}')
            else:
                hashes.add(hash_)

        self.paths += duplicates
        return duplicates


if __name__ == '__main__':
    dataset_dir = Path('./dataset')

    stats_filter = StatsDataFilter(_DATASET_AVG_MEAN, _DATASET_AVG_STD, console_output=True)
    pca_filter = PcaDataFilter(console_output=True)
    duplicate_filter = DHashDuplicateFilter(console_output=True)

    compose = DataFilterCompose.build([
        stats_filter,
        pca_filter,
        duplicate_filter,
    ])

    stats_filter.extract(dataset_dir, visualize_=False, to_file=False)
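    # A hypothetical composed run (sketch): extract with every filter, append
    # the paths to the log file, then delete the flagged images. Commented out
    # because ``filter()`` permanently removes files.
    # compose.extract(dataset_dir, visualize_=True, to_file=True)
    # compose.filter()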
|