Source code for resomapper.processing.masking

import cv2

# import re
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from scipy.ndimage import rotate

from resomapper.core.misc import auto_innited_logger as lggr
import resomapper.core.utils as ut


import warnings

# NOTE(review): this runs at import time and installs an "ignore" filter with
# no category or module restriction, so it silences warnings for the whole
# process (including NumPy divide-by-zero RuntimeWarnings), not only for this
# module -- confirm this is intended.
warnings.filterwarnings("ignore")


def no_mask(nifti_file_path, mask_nii_output_path):
    """Write an all-ones mask that matches the geometry of a NIfTI study.

    The study is loaded, reduced to its first volume when it is 4D, and a
    mask of ones with that same shape is saved as a new NIfTI file that
    reuses the affine and header of the input image.

    Args:
        nifti_file_path (str): Path to the input NIfTI file.
        mask_nii_output_path (str): Path where the mask NIfTI file is written.

    Returns:
        None

    Raises:
        FileNotFoundError: Propagated from ``nib.load`` if the input file
            does not exist.
        nib.filebasedimages.ImageFileError: Propagated from ``nib.load`` if
            the input file is not a valid NIfTI file.
    """
    source_img = nib.load(nifti_file_path)
    volume = source_img.get_fdata()

    # A 4D study is reduced to its first volume so the mask stays 3D.
    if volume.ndim == 4:
        volume = volume[..., 0]

    # TODO: CHECK WHICH TYPE IS BETTER
    everything_mask = np.ones_like(volume, dtype=np.float32)

    mask_img = nib.Nifti1Image(
        everything_mask, source_img.affine, source_img.header
    )
    nib.save(mask_img, mask_nii_output_path)
def hist_strip_mask():
    """Placeholder for a masking mode that is not implemented yet.

    The function currently takes no arguments, does nothing and returns
    ``None``.
    """
    # TODO
    pass
def manual_mask(nifti_file_path, mask_nii_output_path, ask_if_repeat=False):
    """Interactively draw a mask on every slice of a NIfTI study and save it.

    The study is loaded (only the first volume of a 4D study is used),
    rescaled to the 0-255 range and shown slice by slice in an OpenCV window
    where the user draws a polygon outline with the mouse. The polygons are
    filled, resized back to the in-plane size of the study and written as a
    NIfTI file that reuses the affine and header of the input image.

    Slices for which no vertex was drawn (e.g. skipped with the 'c' key) get
    an empty mask instead of an empty vertex array being handed to OpenCV.

    Args:
        nifti_file_path (str): Path to the input NIfTI file.
        mask_nii_output_path (str): Path where the mask NIfTI file is written.
        ask_if_repeat (bool): If True, the drawn outlines are previewed and
            the user is asked to confirm them; drawing is repeated from the
            first slice until the user accepts. Defaults to False.

    Returns:
        None

    Raises:
        FileNotFoundError: Propagated from ``nib.load`` if the input file
            does not exist.
        nib.filebasedimages.ImageFileError: Propagated from ``nib.load`` if
            the input file is not a valid NIfTI file.
    """
    # 'counter' is the module-level slice index shared with click() and
    # itera(); it has to be reset before every drawing pass.
    global counter

    study_nii = nib.load(nifti_file_path)
    nii_data = study_nii.get_fdata()
    if nii_data.ndim == 4:
        nii_data = nii_data[:, :, :, 0]
    # TODO: what happens with 1-slice images

    # TODO: check if this needs to be done always or only in DTI
    nii_data = min_max_normalization(nii_data) * 255
    x_dim, y_dim = nii_data.shape[:2]  # get real dims

    # Repeat the drawing pass until the user accepts it. A loop is used
    # instead of recursion so repeated attempts neither reload the file nor
    # grow the call stack.
    while True:
        # NOTE(review): "Rigth" is a typo in this user-facing prompt; the
        # text is left unchanged here.
        print(
            f"\n{lggr.ask}Please create the mask for this study in the pop-up window\n"
            "- Left click: create lines between clicks to draw the mask outline\n"
            "- Rigth click: close the outline joining the first and last points and skip to next slice\n"
        )

        # itera() draws on the images in place, so every attempt starts from
        # freshly prepared slices.
        images = prepare_vol(nii_data)
        # list of lists (one list per slice) for storing masks vertexes
        refPT = [[] for _ in images]
        counter = 0
        for ima in images:
            refPT = itera(ima, refPT)

        if not ask_if_repeat or _confirm_mask_selection(images, refPT):
            break

    masks = _rasterize_polygons(images, refPT, x_dim, y_dim)
    nii_ima = nib.Nifti1Image(
        masks.astype(np.float32), study_nii.affine, study_nii.header
    )
    nib.save(nii_ima, mask_nii_output_path)


def _confirm_mask_selection(images, ref_points):
    """Preview the drawn outline of every slice and ask the user to accept it."""
    n_slc = len(images)
    rows = 2
    cols = int(np.ceil(n_slc / rows))
    fig, axes = plt.subplots(rows, cols, figsize=(10, 7))
    axes = axes.flatten()

    for i, (image, points) in enumerate(zip(images, ref_points)):
        preview = np.copy(image)
        poly = np.array(points, np.int32)
        if poly.size:  # nothing to outline on slices without vertices
            preview = cv2.polylines(
                preview, [poly], True, (255, 255, 255), thickness=3
            )
        axes[i].imshow(preview, cmap="gray")
        axes[i].set_title(f"Slice {i+1}")

    plt.tight_layout()
    # The grid may have more cells than slices; hide the axes of all of them.
    for axis in axes:
        axis.axis("off")
    plt.show(block=False)

    correct_selection = ut.ask_user(
        "Is the created mask ok? If not, you can repeat it."
    )
    plt.close(fig)
    return correct_selection


def _rasterize_polygons(images, ref_points, x_dim, y_dim):
    """Fill each slice polygon and stack the masks in the study's axis order."""
    masks = []
    for image, points in zip(images, ref_points):
        mask = np.zeros(image.shape)
        poly = np.array(points, np.int32)
        if poly.size:  # slices without vertices keep an all-zero mask
            mask = cv2.fillPoly(mask, [poly], 1)
        # Undo the display zoom; nearest-neighbour keeps the mask binary.
        mask = cv2.resize(mask, (x_dim, y_dim), interpolation=cv2.INTER_NEAREST)
        masks.append(mask.astype(np.int32))

    # Stacked masks are (slice, row, col); reverse the axes to recover the
    # (x, y, slice) layout of the study data.
    return np.asarray(masks).transpose(2, 1, 0)
def check_mask_shape(img, mask):
    """Check that a mask has the same spatial shape as an image.

    Only the first three dimensions are compared, so a 4D image matches a 3D
    mask with the same in-plane size and number of slices. Arguments that are
    not NumPy arrays are treated as file paths and loaded with ``nib.load``.
    When the shapes differ, an explanatory message is printed.

    Args:
        img (numpy.ndarray or str): Image array or path to the image file.
        mask (numpy.ndarray or str): Mask array or path to the mask file.

    Returns:
        bool: True if the first three dimensions of mask and image match,
        False if they do not or if a file could not be read as an image.
    """
    try:
        if not isinstance(img, np.ndarray):
            img = nib.load(img).get_fdata()
        if not isinstance(mask, np.ndarray):
            mask = nib.load(mask).get_fdata()
    except nib.filebasedimages.ImageFileError:
        print(f"\n{lggr.error}The file you have provided is not a NiFTI image.")
        return False

    if img.shape[:3] == mask.shape[:3]:
        return True

    print(
        f"\n{lggr.error}Mask and image have different shapes. "
        "Please check that you have selected a suitable mask for this study.\n\n"
        "More info:\n"
        f"- Image: {_describe_shape(img.shape)}\n"
        f"- Mask: {_describe_shape(mask.shape)}"
    )
    return False


def _describe_shape(shape):
    """Format an array shape for the mismatch message without assuming 3 dims."""
    if len(shape) >= 3:
        return f"size {shape[0]}x{shape[1]}, {shape[2]} slices."
    # Arrays with fewer than three dimensions have no slice axis to report;
    # indexing shape[2] here used to raise IndexError.
    return f"shape {tuple(shape)}."
#### Functions for manual masking mode ####
def prepare_vol(vol_3d, scale_percent=440):
    """Turn a 3D volume into a list of enlarged 8-bit slices for display.

    Each slice ``vol_3d[:, :, j]`` is rotated by 270 degrees, flipped along
    its second axis, cast to ``uint8`` and enlarged with OpenCV.

    The rotation is done with ``np.rot90`` (an exact index permutation)
    rather than spline interpolation, so no interpolation round-off is
    introduced before the truncating ``uint8`` cast. Rotating by 270 degrees
    and flipping axis 1 is the same as transposing the slice.

    Args:
        vol_3d (ndarray): Input volume indexed as (x, y, slice). Values are
            cast to ``uint8`` without rescaling, so callers are expected to
            pass data already scaled to 0-255 (as ``manual_mask`` does).
        scale_percent (int or float): Display zoom in percent applied to both
            in-plane dimensions. Defaults to 440.

    Returns:
        list: One 2D ``uint8`` array per slice, ready for visualization.
    """
    vol_prepared = []
    for j in range(vol_3d.shape[2]):  # one image per slice
        ima = vol_3d[:, :, j]
        ima = np.rot90(ima, k=3)  # 270 degrees, exact
        ima = np.flip(ima, axis=1)
        # OpenCV gets a plain C-contiguous 8-bit image.
        ima = ima.astype(np.uint8, order="C")

        # change only for better visualization purposes
        width = int(ima.shape[1] * scale_percent / 100)
        height = int(ima.shape[0] * scale_percent / 100)
        ima = cv2.resize(ima, (width, height), interpolation=cv2.INTER_AREA)

        vol_prepared.append(ima)

    return vol_prepared
def min_max_normalization(img):
    """Rescale an image linearly to the [0, 1] range.

    The result is computed in ``float32`` as
    ``(img - min_val) / (max_val - min_val)``; the input is never modified.

    A constant image has ``max_val == min_val``; instead of dividing by zero
    (which would fill the result with NaN), an all-zero image of the same
    shape is returned. An empty input is returned as an empty ``float32``
    array.

    Args:
        img (numpy.ndarray): Input image array (any array-like is accepted).

    Returns:
        numpy.ndarray: Normalized ``float32`` array with the shape of ``img``.
    """
    values = np.asarray(img, dtype=np.float32)
    if values.size == 0:
        return values.copy()

    min_val = values.min()
    value_range = values.max() - min_val
    if value_range == 0:
        # Flat image: every pixel sits at the minimum, i.e. maps to 0.
        return np.zeros_like(values)

    return (values - min_val) / value_range
def click(event, x, y, flags, param):
    """Mouse callback that records polygon vertices for the current slice.

    Args:
        event: OpenCV mouse event code.
        x: The x-coordinate of the mouse position.
        y: The y-coordinate of the mouse position.
        flags: Additional flags passed by OpenCV (unused).
        param: Indexable container with one vertex list per slice; the
            vertex is appended to ``param[counter]``.

    Both a left and a right button press append ``[(x, y)]`` to the vertex
    list of the slice selected by the global 'counter'. A right button press
    additionally sets the global 'status' to 0 to signal that the outline of
    this slice is finished. Any other event is ignored.

    Note:
        - The global 'counter' is read and the global 'status' is written
          by this function.

    Example:
        mouse_params = [[] for _ in range(5)]  # one vertex list per slice
        cv2.setMouseCallback("window", click, param=mouse_params)
    """
    global status
    global counter

    is_right_press = event == cv2.EVENT_RBUTTONDOWN
    if event != cv2.EVENT_LBUTTONDOWN and not is_right_press:
        return

    param[counter].append([(x, y)])
    if is_right_press:
        status = 0  # finish
def itera(ima, refPT):
    """Show one slice and collect the outline the user draws on it.

    A window named "Mask_drawing" is opened with 'click' registered as its
    mouse callback and 'refPT' as the callback parameter. The function then
    polls in a loop: once the vertex list of the current slice
    (``refPT[counter]``) holds at least two entries, a white line is drawn on
    'ima' between the two most recent ones. The loop ends when the 'c' key is
    pressed or when the global 'status' has been set to 0 (which 'click' does
    on a right click); the windows are then closed, 'status' is restored to 1
    and the global 'counter' is incremented so the next call works on the
    next slice.

    Args:
        ima (numpy.ndarray): Image of the slice to display; lines are drawn
            on it.
        refPT (list): One vertex list per slice, filled in by 'click'.

    Note:
        - The global variables 'counter' and 'status' are used and updated
          within this function; 'counter' must be initialised by the caller.
        - If an IndexError is raised inside the drawing/advance step, the
          windows are closed and the loop ends without advancing 'counter'.

    Returns:
        list: The same 'refPT' list, updated with the vertices of this slice.
    """
    global counter
    global status

    status = 1

    cv2.namedWindow("Mask_drawing")  # creates a new window
    cv2.setMouseCallback("Mask_drawing", click, refPT)

    while True:
        if refPT[counter] == []:
            # show the unmodified image while the vertex list is still empty
            cv2.imshow("Mask_drawing", ima)
            # cv2.waitKey(1)
        # waitKey also lets the GUI process events, so the mouse callback can
        # run here
        key = cv2.waitKey(1) & 0xFF

        try:
            if len(refPT[counter]) > 1:  # after two clicks
                ver = len(refPT[counter])  # number of vertices so far
                line = refPT[counter][ver - 2 : ver]  # the two latest vertices
                # each entry is [(x, y)], hence the extra [0]
                ima = cv2.line(
                    ima, line[0][0], line[1][0], (255, 255, 255), thickness=2
                )
                cv2.imshow("Mask_drawing", ima)
                cv2.waitKey(1)

            if key == ord("c") or status == 0:  # if 'c' key or right click
                cv2.destroyAllWindows()
                status = 1  # restore to 1
                counter += 1  # pass to the next slice
                break

        except IndexError:
            cv2.destroyAllWindows()
            break

    return refPT