# Source code for cleanvision.utils.viz_manager

from typing import List, Tuple, Optional

import math
import matplotlib.axes
import matplotlib.pyplot as plt
from PIL import Image


class VizManager:
    """Static entry points that render PIL images as matplotlib grids."""

    @staticmethod
    def individual_images(
        images: List[Image.Image],
        titles: List[str],
        ncols: int,
        cell_size: Tuple[int, int],
        additional_info: Optional[List[str]] = None,
    ) -> None:
        """Plots a list of images in a grid.

        Args:
            images: Images to draw, one per grid cell.
            titles: Title text for each image.
            ncols: Maximum number of grid columns.
            cell_size: (width, height) of one grid cell, used to size the figure.
            additional_info: Optional extra text per image, placed on a second
                title line.
        """
        plot_image_grid(
            images=images,
            titles=titles,
            ncols=ncols,
            cell_size=cell_size,
            additional_info=additional_info,
        )

    @staticmethod
    def image_sets(
        image_sets: List[List[Image.Image]],
        title_sets: List[List[str]],
        ncols: int,
        cell_size: Tuple[int, int],
    ) -> None:
        """Plot each group of images as its own grid.

        Before each grid, a ``Set: <index>`` line is printed so the printed
        output can be matched to the figure that follows it.

        Args:
            image_sets: Groups of images; each group becomes one grid.
            title_sets: Titles for each group, indexed in parallel with
                ``image_sets``.
            ncols: Maximum number of grid columns per group.
            cell_size: (width, height) of one grid cell, used to size the figure.
        """
        for set_number, images_in_set in enumerate(image_sets):
            print(f"Set: {set_number}")
            plot_image_grid(
                images=images_in_set,
                titles=title_sets[set_number],
                ncols=ncols,
                cell_size=cell_size,
            )
def set_image_on_axes(image: Image.Image, ax: matplotlib.axes.Axes, title: str) -> None:
    """Draw ``image`` on ``ax`` under a small title, with both axes hidden.

    Single-channel ("L" mode) images are drawn with the gray colormap. The
    display range is pinned to 0-255 through ``vmin``/``vmax``.

    Args:
        image: The PIL image to render.
        ax: The matplotlib Axes to draw into.
        title: Text shown above the image at font size 7.
    """
    is_grayscale = image.mode == "L"
    for axis in (ax.get_xaxis(), ax.get_yaxis()):
        axis.set_visible(False)
    ax.set_title(title, fontsize=7)
    ax.imshow(image, cmap="gray" if is_grayscale else None, vmin=0, vmax=255)
def _truncate_titles(titles: List[str], cell_width: float) -> List[str]:
    """Return a new list of ``titles``, each shortened with "..." if too wide.

    Title width is estimated from font size 7 (the size used when titles are
    drawn) at 72 points per inch. The input list is not modified.

    Args:
        titles: Titles to shorten.
        cell_width: Width available for each title.

    Returns:
        A list the same length as ``titles``.
    """
    # Font size 7 in points, converted to inches (72 points per inch).
    char_size_inches = 7 * (1 / 72)
    # Keep 4 characters of slack for the "..." marker. Clamp at 0 so very
    # narrow cells never produce negative (wrap-around) slice bounds.
    chars_allowed = max(math.ceil(cell_width / char_size_inches) - 4, 0)

    # Decide which end of the titles to cut by comparing the first two. If they
    # share a longer prefix than suffix, the distinguishing text is at the end,
    # so keep the end and drop the front. Otherwise keep the beginning.
    # With fewer than two titles, default to cutting the end.
    truncate_from_front = False
    if len(titles) > 1:
        first, second = titles[0], titles[1]
        k1 = 1
        while k1 <= chars_allowed and first[:k1] == second[:k1]:
            k1 += 1
        k2 = 1
        while k2 <= chars_allowed and first[-k2:] == second[-k2:]:
            k2 += 1
        truncate_from_front = k1 > k2

    shortened = []
    for title in titles:
        if len(title) * char_size_inches < cell_width:
            shortened.append(title)
        elif truncate_from_front:
            shortened.append("..." + title[len(title) - chars_allowed :])
        else:
            shortened.append(title[:chars_allowed] + "...")
    return shortened


def plot_image_grid(
    images: List[Image.Image],
    titles: List[str],
    ncols: int,
    cell_size: Tuple[int, int],
    additional_info: Optional[List[str]] = None,
) -> None:
    """Plot ``images`` in a grid of at most ``ncols`` columns and show the figure.

    Images fill the grid row by row. Unused cells in the last row are blanked.

    Args:
        images: Images to draw. If empty, nothing is plotted.
        titles: One title per image. Titles too wide for a cell are shortened
            with "..."; the caller's list is left unchanged.
        ncols: Maximum number of columns; must be at least 1.
        cell_size: (width, height) of each grid cell, in matplotlib ``figsize``
            units (inches).
        additional_info: Optional extra text, one entry per image, shown on a
            second title line.

    Raises:
        ValueError: If ``ncols`` is less than 1.
    """
    if ncols < 1:
        raise ValueError(f"ncols must be at least 1, got {ncols}")
    num_images = len(images)
    if num_images == 0:
        return

    nrows = math.ceil(num_images / ncols)
    ncols = min(ncols, num_images)
    # squeeze=False always yields a 2-D array of Axes, so a single cell, a
    # single row and a single column are all handled the same way.
    _, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(cell_size[0] * ncols, cell_size[1] * nrows),
        squeeze=False,
    )

    shown_titles = _truncate_titles(titles[:num_images], cell_size[0])
    if additional_info is not None:
        shown_titles = [
            f"{title}\n{additional_info[i]}" for i, title in enumerate(shown_titles)
        ]

    # axes.flat walks the grid row-major, matching the image order.
    for idx, ax in enumerate(axes.flat):
        if idx < num_images:
            set_image_on_axes(images[idx], ax, shown_titles[idx])
        else:
            ax.axis("off")
    plt.show()