hail.harmonizer
"""
The harmonizer module is used to harmonize images across different imaging locations.

Command-line usage (each path flag accepts one or more values and may be repeated)::

    python harmonizer.py --in-path /path/to/image1.nii.gz /path/to/image2.nii.gz --target-image /path/to/target_image.nii.gz --out-path /path/to/output.nii.gz
"""
import sys
import argparse
from pathlib import Path
from os import PathLike

import nibabel as nib
import numpy as np
import torch
from torchvision.transforms import ToTensor

from skimage.filters import threshold_otsu
from skimage.morphology import isotropic_closing

from model import HAIL
from utils import crop, zero_pad, zero_pad2d, crop2d
from typing import List, Tuple


def background_removal(image_vol: np.ndarray, *, pad_size: int = 256,
                       closing_radius: float = 20) -> np.ndarray:
    """Remove the background around the head in a 3D MRI volume using an Otsu threshold.

    Voxels at or above the Otsu threshold form a foreground mask. The mask is
    passed through ``zero_pad`` with ``pad_size``, closed with
    ``isotropic_closing`` of radius ``closing_radius``, and cropped back to the
    input shape with ``crop``. Voxels where the closed mask is zero are set to 0.

    Note:
        ``image_vol`` is modified in place; the same array is returned.

    Args:
        image_vol (np.ndarray): 3D brain image volume.
        pad_size (int, optional): size argument forwarded to ``zero_pad`` before
            the closing (presumably the padded edge length; confirm against
            ``utils.zero_pad``). Defaults to 256.
        closing_radius (float, optional): radius of the isotropic closing.
            Defaults to 20.

    Returns:
        np.ndarray: background-removed 3D image volume.
    """
    n_row, n_col, n_slc = image_vol.shape
    thresh = threshold_otsu(image_vol)
    mask = (image_vol >= thresh) * 1.0
    # NOTE(review): padding before the closing presumably keeps the closing
    # from being clipped at the volume border; confirm.
    mask = zero_pad(mask, pad_size)
    mask = isotropic_closing(mask, radius=closing_radius)
    mask = crop(mask, n_row, n_col, n_slc)
    image_vol[mask < 1e-4] = 0.0
    return image_vol


def background_removal2d(image_vol: np.ndarray, *, pad_size: int = 256,
                         closing_radius: float = 20) -> np.ndarray:
    """Remove the background around the head in a 2D MRI slice using an Otsu threshold.

    Same procedure as ``background_removal``, using ``zero_pad2d`` and ``crop2d``.

    Note:
        ``image_vol`` is modified in place; the same array is returned.

    Args:
        image_vol (np.ndarray): 2D brain slice.
        pad_size (int, optional): size argument forwarded to ``zero_pad2d``
            before the closing. Defaults to 256.
        closing_radius (float, optional): radius of the isotropic closing.
            Defaults to 20.

    Returns:
        np.ndarray: background-removed 2D slice.
    """
    n_row, n_col = image_vol.shape
    thresh = threshold_otsu(image_vol)
    mask = (image_vol >= thresh) * 1.0
    mask = zero_pad2d(mask, pad_size)
    mask = isotropic_closing(mask, radius=closing_radius)
    mask = crop2d(mask, n_row, n_col)
    image_vol[mask < 1e-4] = 0.0
    return image_vol


def obtain_single_image(image_path: PathLike, bg_removal: bool = True, a_max: float = 5.0,
                        *, pad_size: int = 224) -> Tuple[torch.Tensor, nib.Nifti1Header, Tuple[float, float]]:
    """Load a single image from a NIfTI file and preprocess it for harmonization.

    Intensities are divided by their 95th percentile (plus 1e-5) and clipped to
    ``[0, a_max]``. The background is optionally removed. The volume is then
    centered in a zero-filled ``pad_size`` cube and converted with ``ToTensor``,
    which moves the last array axis to the front.

    Args:
        image_path (PathLike): path to the NIfTI file.
        bg_removal (bool, optional): remove the background from the image. Defaults to True.
        a_max (float, optional): upper clipping bound for the normalized intensities.
            Defaults to 5.0.
        pad_size (int, optional): edge length of the cubic zero-padded canvas.
            Defaults to 224.

    Returns:
        Tuple[torch.Tensor, nib.Nifti1Header, Tuple[float, float]]: preprocessed image,
        image header, and normalization values (95th-percentile intensity,
        original maximum intensity).

    Raises:
        ValueError: if the image is not 3D or does not fit inside the canvas.
    """
    image_obj = nib.Nifti1Image.from_filename(image_path)
    image_vol = image_obj.get_fdata().astype(np.float32)  # astype already copies
    if image_vol.ndim != 3:
        raise ValueError(f'expected a 3D image, got shape {image_vol.shape} from {image_path}')
    thresh = np.percentile(image_vol, 95)
    max_thresh = image_vol.max()
    image_vol = image_vol / (thresh + 1e-5)
    image_vol = np.clip(image_vol, a_min=0.0, a_max=a_max)
    if bg_removal:
        image_vol = background_removal(image_vol)

    # Center the volume in the canvas; odd sizes put the extra voxel after the center.
    center = pad_size // 2
    region = []
    for size in image_vol.shape:
        start = center - size // 2
        stop = start + size
        if start < 0 or stop > pad_size:
            raise ValueError(
                f'image shape {image_vol.shape} from {image_path} does not fit in a '
                f'{pad_size}^3 padding canvas')
        region.append(slice(start, stop))
    image_padded = np.zeros((pad_size, pad_size, pad_size), dtype=np.float32)
    image_padded[tuple(region)] = image_vol
    return ToTensor()(image_padded), image_obj.header, (thresh, max_thresh)


def load_source_images(image_paths: List[PathLike], bg_removal: bool = True) -> Tuple[List[torch.Tensor], nib.Nifti1Header]:
    """Load and preprocess all the source images from the list of paths.

    Args:
        image_paths (List[PathLike]): list of paths to the images.
        bg_removal (bool, optional): remove the background from the images. Defaults to True.

    Returns:
        Tuple[List[torch.Tensor], nib.Nifti1Header]: list of preprocessed images
        (each permuted with ``permute(2, 1, 0)``) and the header of the last image
        loaded (``None`` if ``image_paths`` is empty).
    """
    source_images = []
    image_header = None
    for image_path in image_paths:
        image_vol, image_header, _ = obtain_single_image(image_path, bg_removal)
        source_images.append(image_vol.float().permute(2, 1, 0))
    return source_images, image_header


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Harmonization Across Imaging Location(HAIL)')
    # action='extend' + nargs='+' accepts several paths after a single flag (as in the
    # module usage) while repeated flags still accumulate like the former action='append'.
    parser.add_argument('--in-path', type=Path, action='extend', nargs='+', required=True)
    parser.add_argument('--target-image', type=Path, action='extend', nargs='+', default=[])
    parser.add_argument('--out-path', type=Path, action='extend', nargs='+', required=True)
    parser.add_argument('--harmonization-model', type=Path, default=Path('/tmp/model_weights/harmonization.pt'))
    # NOTE(review): --fusion-model is resolved below but never used in this script.
    parser.add_argument('--fusion-model', type=Path, default=Path('/tmp/model_weights/fusion.pt'))
    parser.add_argument('--beta-dim', type=int, default=5)
    parser.add_argument('--theta-dim', type=int, default=2)
    parser.add_argument('--no-bg-removal', dest='bg_removal', action='store_false', default=True)
    parser.add_argument('--gpu-id', type=int, default=0)
    parser.add_argument('--num-batches', type=int, default=1)

    args = parser.parse_args()
    print(args)

    text_div = '=' * 10
    print(f'{text_div} BEGIN HARMONIZATION {text_div}')

    # Resolve every path argument (single path or list of paths) to an absolute path.
    for argname in ('in_path', 'target_image', 'out_path', 'harmonization_model', 'fusion_model'):
        value = getattr(args, argname)
        if isinstance(value, list):
            setattr(args, argname, [path.resolve() for path in value])
        else:
            setattr(args, argname, value.resolve())

    print(args)
    # initialize the HAIL model
    hail = HAIL(beta_dim=args.beta_dim,
                theta_dim=args.theta_dim,
                eta_dim=2,
                pretrained=args.harmonization_model,
                gpu_id=args.gpu_id)

    # load source images
    source_images, image_header = load_source_images(args.in_path, args.bg_removal)

    # Load the target (template) images. zip() pairs each target with an output
    # path, so targets beyond the number of --out-path values are skipped.
    target_images, norm_vals = [], []
    for target_image_path, _ in zip(args.target_image, args.out_path):
        target_image_tmp, _, norm_val = obtain_single_image(target_image_path, args.bg_removal, a_max=6.0)
        # NOTE(review): reorients the volume and keeps indices 100-119 along the first
        # axis; presumably a central slab of axial slices. Confirm.
        target_images.append(target_image_tmp.permute(2, 1, 0).permute(0, 2, 1).flip(1)[100:120, ...])
        norm_vals.append(norm_val)

    # begin the harmonization process to generate the harmonized image in an axial fashion
    hail.harmonize(
        source_images=[image.permute(2, 0, 1) for image in source_images],
        target_images=target_images,
        target_theta=None,
        target_eta=None,
        out_paths=args.out_path,
        header=image_header,
        recon_orientation='axial',
        norm_vals=norm_vals,
        num_batches=args.num_batches,
    )
    print(f'{text_div} END HARMONIZATION {text_div}')
def background_removal(image_vol: np.ndarray, *, pad_size: int = 256,
                       closing_radius: float = 20) -> np.ndarray:
    """Remove the background around the head in a 3D MRI volume using an Otsu threshold.

    Voxels at or above the Otsu threshold form a foreground mask. The mask is
    passed through ``zero_pad`` with ``pad_size``, closed with
    ``isotropic_closing`` of radius ``closing_radius``, and cropped back to the
    input shape with ``crop``. Voxels where the closed mask is zero are set to 0.

    Note:
        ``image_vol`` is modified in place; the same array is returned.

    Args:
        image_vol (np.ndarray): 3D brain image volume.
        pad_size (int, optional): size argument forwarded to ``zero_pad`` before
            the closing (presumably the padded edge length; confirm against
            ``utils.zero_pad``). Defaults to 256.
        closing_radius (float, optional): radius of the isotropic closing.
            Defaults to 20.

    Returns:
        np.ndarray: background-removed 3D image volume.
    """
    n_row, n_col, n_slc = image_vol.shape
    thresh = threshold_otsu(image_vol)
    mask = (image_vol >= thresh) * 1.0
    # NOTE(review): padding before the closing presumably keeps the closing
    # from being clipped at the volume border; confirm.
    mask = zero_pad(mask, pad_size)
    mask = isotropic_closing(mask, radius=closing_radius)
    mask = crop(mask, n_row, n_col, n_slc)
    image_vol[mask < 1e-4] = 0.0
    return image_vol
Remove the background around the head in a 3D MRI volume using an Otsu threshold.
Args: image_vol (np.ndarray): brain 3D image volume
Returns: np.ndarray: background removed 3D image volume
def background_removal2d(image_vol: np.ndarray, *, pad_size: int = 256,
                         closing_radius: float = 20) -> np.ndarray:
    """Remove the background around the head in a 2D MRI slice using an Otsu threshold.

    Pixels at or above the Otsu threshold form a foreground mask. The mask is
    passed through ``zero_pad2d`` with ``pad_size``, closed with
    ``isotropic_closing`` of radius ``closing_radius``, and cropped back to the
    input shape with ``crop2d``. Pixels where the closed mask is zero are set to 0.

    Note:
        ``image_vol`` is modified in place; the same array is returned.

    Args:
        image_vol (np.ndarray): 2D brain slice.
        pad_size (int, optional): size argument forwarded to ``zero_pad2d``
            before the closing (presumably the padded edge length; confirm
            against ``utils.zero_pad2d``). Defaults to 256.
        closing_radius (float, optional): radius of the isotropic closing.
            Defaults to 20.

    Returns:
        np.ndarray: background-removed 2D slice.
    """
    n_row, n_col = image_vol.shape
    thresh = threshold_otsu(image_vol)
    mask = (image_vol >= thresh) * 1.0
    mask = zero_pad2d(mask, pad_size)
    mask = isotropic_closing(mask, radius=closing_radius)
    mask = crop2d(mask, n_row, n_col)
    image_vol[mask < 1e-4] = 0.0
    return image_vol
Remove the background around the head in a 2D MRI slice using an Otsu threshold.
Args: image_vol (np.ndarray): brain 2D slice
Returns: np.ndarray: background removed 2D slice
def obtain_single_image(image_path: PathLike, bg_removal: bool = True, a_max: float = 5.0,
                        *, pad_size: int = 224) -> Tuple[torch.Tensor, nib.Nifti1Header, Tuple[float, float]]:
    """Load a single image from a NIfTI file and preprocess it for harmonization.

    Intensities are divided by their 95th percentile (plus 1e-5) and clipped to
    ``[0, a_max]``. The background is optionally removed. The volume is then
    centered in a zero-filled ``pad_size`` cube and converted with ``ToTensor``,
    which moves the last array axis to the front.

    Args:
        image_path (PathLike): path to the NIfTI file.
        bg_removal (bool, optional): remove the background from the image. Defaults to True.
        a_max (float, optional): upper clipping bound for the normalized intensities.
            Defaults to 5.0.
        pad_size (int, optional): edge length of the cubic zero-padded canvas.
            Defaults to 224.

    Returns:
        Tuple[torch.Tensor, nib.Nifti1Header, Tuple[float, float]]: preprocessed image,
        image header, and normalization values (95th-percentile intensity,
        original maximum intensity).

    Raises:
        ValueError: if the image is not 3D or does not fit inside the canvas.
    """
    image_obj = nib.Nifti1Image.from_filename(image_path)
    image_vol = image_obj.get_fdata().astype(np.float32)  # astype already copies
    if image_vol.ndim != 3:
        raise ValueError(f'expected a 3D image, got shape {image_vol.shape} from {image_path}')
    thresh = np.percentile(image_vol, 95)
    max_thresh = image_vol.max()
    image_vol = image_vol / (thresh + 1e-5)
    image_vol = np.clip(image_vol, a_min=0.0, a_max=a_max)
    if bg_removal:
        image_vol = background_removal(image_vol)

    # Center the volume in the canvas; odd sizes put the extra voxel after the center.
    center = pad_size // 2
    region = []
    for size in image_vol.shape:
        start = center - size // 2
        stop = start + size
        if start < 0 or stop > pad_size:
            raise ValueError(
                f'image shape {image_vol.shape} from {image_path} does not fit in a '
                f'{pad_size}^3 padding canvas')
        region.append(slice(start, stop))
    image_padded = np.zeros((pad_size, pad_size, pad_size), dtype=np.float32)
    image_padded[tuple(region)] = image_vol
    return ToTensor()(image_padded), image_obj.header, (thresh, max_thresh)
Load a single image from a NIfTI file and preprocess it for harmonization.
Args: image_path (PathLike): path to the NIfTI file. bg_removal (bool, optional): remove the background from the image. Defaults to True. a_max (float, optional): upper clipping bound for the normalized intensities. Defaults to 5.0.
Returns: Tuple[torch.Tensor, nib.Nifti1Header, Tuple[float, float]]: preprocessed image, image header, and normalization values (95th-percentile intensity, original maximum intensity).
def load_source_images(image_paths:List[PathLike], bg_removal:bool =True) -> Tuple[List[torch.Tensor], nib.Nifti1Header]:
    """Load and preprocess every source image listed in ``image_paths``.

    Each path is passed through ``obtain_single_image``; its normalization
    values are discarded. Each resulting tensor is cast with ``.float()`` and
    permuted with ``permute(2, 1, 0)``.

    Args:
        image_paths (List[PathLike]): paths to the source images.
        bg_removal (bool, optional): remove the background from each image. Defaults to True.

    Returns:
        Tuple[List[torch.Tensor], nib.Nifti1Header]: the preprocessed images, and the
        header of the last image loaded (``None`` when ``image_paths`` is empty).
    """
    loaded = [obtain_single_image(path, bg_removal) for path in image_paths]
    volumes = [tensor.float().permute(2, 1, 0) for tensor, _header, _norm in loaded]
    last_header = loaded[-1][1] if loaded else None
    return volumes, last_header
Load all the source images from the list of paths
Args: image_paths (List[PathLike]): list of paths to the images. bg_removal (bool, optional): remove the background from the images. Defaults to True.
Returns: Tuple[List[torch.Tensor], nib.Nifti1Header]: list of preprocessed images and the image header