hail.harmonizer

  1import sys
  2import argparse
  3from pathlib import Path
  4from os import PathLike
  5
  6import nibabel as nib
  7import numpy as np
  8import torch
  9from torchvision.transforms import ToTensor
 10
 11from skimage.filters import threshold_otsu
 12from skimage.morphology import isotropic_closing
 13
 14from model import HAIL
 15from utils import crop, zero_pad, zero_pad2d, crop2d
 16from typing import List, Tuple
 17
 18"""
 19The harmonizer module is used to harmonize the images across different imaging locations 
 20"""
 21
 22def background_removal(image_vol: np.ndarray)-> np.ndarray:
 23    """remove background from the head using otsu threhold for MRI
 24
 25    Args:
 26        image_vol (np.ndarray): brain 3D image volume
 27
 28    Returns:
 29        np.ndarray: background removed 3D image volume
 30    """
 31    [n_row, n_col, n_slc] = image_vol.shape
 32    thresh = threshold_otsu(image_vol)
 33    mask = (image_vol >= thresh) * 1.0
 34    mask = zero_pad(mask, 256)
 35    mask = isotropic_closing(mask, radius=20)
 36    mask = crop(mask, n_row, n_col, n_slc)
 37    image_vol[mask < 1e-4] = 0.0
 38    return image_vol
 39
 40def background_removal2d(image_vol: np.ndarray) ->  np.ndarray:
 41    """remove background from the head using otsu threhold for MRI
 42
 43    Args:
 44        image_vol (np.ndarray): brain 2D slice
 45
 46    Returns:
 47        np.ndarray: background removed 2D slice
 48    """
 49    [n_row, n_col] = image_vol.shape
 50    thresh = threshold_otsu(image_vol)
 51    mask = (image_vol >= thresh) * 1.0
 52    mask = zero_pad2d(mask, 256)
 53    mask = isotropic_closing(mask, radius=20)
 54    mask = crop2d(mask, n_row, n_col)
 55    image_vol[mask < 1e-4] = 0.0
 56    return image_vol
 57
 58def obtain_single_image(image_path: PathLike, bg_removal: bool=True, a_max:float=5.0) -> Tuple[torch.Tensor, nib.Nifti1Header, Tuple[float, float]]:
 59    """get a single image from a NifTi file, and preprocess it for harmonization
 60
 61    Args:
 62        image_path (Pathlike): path to the NifTi file
 63        bg_removal (bool, optional): remove background from the image. Defaults to True.
 64        a_max (float, optional): maximum value for the image. Defaults to 5.0.
 65
 66    Returns:
 67        Tuple[torch.Tensor, nib.Nifti1Header, Tuple[float, float]]: preprocessed image, image header, and normalization values
 68    """
 69    image_obj = nib.Nifti1Image.from_filename(image_path)
 70    image_vol = np.array(image_obj.get_fdata().astype(np.float32))
 71    thresh = np.percentile(image_vol.flatten(), 95)
 72    max_thresh = image_vol.max()
 73    image_vol = image_vol / (thresh + 1e-5)
 74    image_vol = np.clip(image_vol, a_min=0.0, a_max=a_max)
 75    if bg_removal:
 76        image_vol = background_removal(image_vol)
 77
 78    n_row, n_col, n_slc = image_vol.shape
 79    # zero padding
 80    image_padded = np.zeros((224, 224, 224)).astype(np.float32)
 81    image_padded[112 - n_row // 2:112 + n_row // 2 + n_row % 2,
 82                 112 - n_col // 2:112 + n_col // 2 + n_col % 2,
 83                 112 - n_slc // 2:112 + n_slc // 2 + n_slc % 2] = image_vol
 84    return ToTensor()(image_padded), image_obj.header, (thresh, max_thresh)
 85
 86def load_source_images(image_paths:List[PathLike], bg_removal:bool =True) -> Tuple[List[torch.Tensor], nib.Nifti1Header]:
 87    """ Load all the source images from the list of paths
 88
 89    Args:
 90        image_paths (List[Pathlike]): list of paths to the images
 91        bg_removal (bool, optional): remove background from the image. Defaults to True.
 92
 93    Returns:
 94        Tuple[List[torch.Tensor], nib.Nifti1Header]: list of preprocessed images and the image header
 95    """
 96    
 97    source_images = []
 98    image_header = None
 99    for image_path in image_paths:
100        image_vol, image_header, _ = obtain_single_image(image_path, bg_removal)
101        source_images.append(image_vol.float().permute(2, 1, 0))
102    return source_images, image_header
103
104
def main() -> None:
    """Command-line entry point for HAIL harmonization.

    Example::

        python harmonizer.py --in-path /path/to/image1.nii.gz /path/to/image2.nii.gz \
            --target-image /path/to/target_image.nii.gz --out-path /path/to/output.nii.gz

    Every path option accepts several values after one flag, and the flag may
    also be repeated (``--in-path a --in-path b``).
    """
    parser = argparse.ArgumentParser(description='Harmonization Across Imaging Location(HAIL)')
    # 'extend' + nargs='+' supports the multi-value form shown above. Plain
    # 'append' without nargs rejected it. Repeated flags still work as before.
    parser.add_argument('--in-path', type=Path, action='extend', nargs='+', required=True)
    parser.add_argument('--target-image', type=Path, action='extend', nargs='+', default=[])
    parser.add_argument('--out-path', type=Path, action='extend', nargs='+', required=True)
    parser.add_argument('--harmonization-model', type=Path, default=Path('/tmp/model_weights/harmonization.pt'))
    # NOTE(review): --fusion-model is parsed and resolved but never passed to HAIL below.
    parser.add_argument('--fusion-model', type=Path, default=Path('/tmp/model_weights/fusion.pt'))
    parser.add_argument('--beta-dim', type=int, default=5)
    parser.add_argument('--theta-dim', type=int, default=2)
    parser.add_argument('--no-bg-removal', dest='bg_removal', action='store_false', default=True)
    parser.add_argument('--gpu-id', type=int, default=0)
    parser.add_argument('--num-batches', type=int, default=1)

    args = parser.parse_args()

    text_div = '=' * 10
    print(f'{text_div} BEGIN HARMONIZATION {text_div}')

    # Resolve every path argument (single paths and lists of paths) to absolute paths.
    for argname in ['in_path', 'target_image', 'out_path', 'harmonization_model',
                    'fusion_model']:
        value = getattr(args, argname)
        if isinstance(value, list):
            setattr(args, argname, [path.resolve() for path in value])
        else:
            setattr(args, argname, value.resolve())
    print(args)

    # Initialize the HAIL model.
    hail = HAIL(beta_dim=args.beta_dim,
                theta_dim=args.theta_dim,
                eta_dim=2,
                pretrained=args.harmonization_model,
                gpu_id=args.gpu_id)

    # Load source images.
    source_images, image_header = load_source_images(args.in_path, args.bg_removal)

    # Load target template images.
    # NOTE(review): the target list is truncated to len(out_path). The original
    # code did the same implicitly via zip() — confirm this is intended.
    target_images, norm_vals = [], []
    for target_image_path in args.target_image[:len(args.out_path)]:
        target_tensor, _, norm_val = obtain_single_image(target_image_path, args.bg_removal, a_max=6.0)
        # permute(2, 0, 1) is equivalent to the former .permute(2, 1, 0).permute(0, 2, 1).
        # The last dim moves to the front; slices 100..119 along it are kept.
        target_images.append(target_tensor.permute(2, 0, 1).flip(1)[100:120, ...])
        norm_vals.append(norm_val)

    # Generate the harmonized images, reconstructing in the axial orientation.
    hail.harmonize(
        source_images=[image.permute(2, 0, 1) for image in source_images],
        target_images=target_images,
        target_theta=None,
        target_eta=None,
        out_paths=args.out_path,
        header=image_header,
        recon_orientation='axial',
        norm_vals=norm_vals,
        num_batches=args.num_batches,
    )
    print(f'{text_div} END HARMONIZATION {text_div}')


if __name__ == '__main__':
    main()
def background_removal(image_vol: np.ndarray, pad_size: int = 256,
                       closing_radius: int = 20) -> np.ndarray:
    """Remove the background around the head in a 3D MRI volume using Otsu's threshold.

    Voxels at or above the Otsu threshold form a foreground mask. The mask is
    zero-padded with ``zero_pad(mask, pad_size)``, closed with
    ``isotropic_closing(..., radius=closing_radius)`` and cropped back to the
    input shape. Every voxel outside the closed mask is set to 0.

    Args:
        image_vol (np.ndarray): brain 3D image volume. It is modified in place.
        pad_size (int, optional): size handed to ``zero_pad`` before the
            closing (presumably so the closing is not affected by the volume
            border — confirm against ``utils.zero_pad``). Defaults to 256.
        closing_radius (int, optional): radius handed to
            ``isotropic_closing``. Defaults to 20.

    Returns:
        np.ndarray: the same array object as ``image_vol``, with background
        voxels set to 0.
    """
    n_row, n_col, n_slc = image_vol.shape
    thresh = threshold_otsu(image_vol)
    mask = (image_vol >= thresh) * 1.0
    mask = zero_pad(mask, pad_size)
    mask = isotropic_closing(mask, radius=closing_radius)
    # Undo the padding so the mask lines up with the original volume.
    mask = crop(mask, n_row, n_col, n_slc)
    # In-place write: callers' arrays are modified, and the same object is returned.
    image_vol[mask < 1e-4] = 0.0
    return image_vol

Remove the background around the head in an MRI volume using Otsu's threshold.

Args: image_vol (np.ndarray): brain 3D image volume

Returns: np.ndarray: background removed 3D image volume

def background_removal2d(image_vol: np.ndarray, pad_size: int = 256,
                         closing_radius: int = 20) -> np.ndarray:
    """Remove the background around the head in a 2D MRI slice using Otsu's threshold.

    This is the 2D counterpart of ``background_removal``. Pixels at or above
    the Otsu threshold form a foreground mask. The mask is zero-padded with
    ``zero_pad2d(mask, pad_size)``, closed with
    ``isotropic_closing(..., radius=closing_radius)`` and cropped back to the
    input shape. Every pixel outside the closed mask is set to 0.

    Args:
        image_vol (np.ndarray): brain 2D slice. It is modified in place.
        pad_size (int, optional): size handed to ``zero_pad2d`` before the
            closing. Defaults to 256.
        closing_radius (int, optional): radius handed to
            ``isotropic_closing``. Defaults to 20.

    Returns:
        np.ndarray: the same array object as ``image_vol``, with background
        pixels set to 0.
    """
    n_row, n_col = image_vol.shape
    thresh = threshold_otsu(image_vol)
    mask = (image_vol >= thresh) * 1.0
    mask = zero_pad2d(mask, pad_size)
    mask = isotropic_closing(mask, radius=closing_radius)
    # Undo the padding so the mask lines up with the original slice.
    mask = crop2d(mask, n_row, n_col)
    # In-place write: callers' arrays are modified, and the same object is returned.
    image_vol[mask < 1e-4] = 0.0
    return image_vol

Remove the background around the head in an MRI slice using Otsu's threshold.

Args: image_vol (np.ndarray): brain 2D slice

Returns: np.ndarray: background removed 2D slice

def obtain_single_image(image_path: PathLike, bg_removal: bool = True,
                        a_max: float = 5.0) -> Tuple[torch.Tensor, nib.Nifti1Header, Tuple[float, float]]:
    """Load a single NIfTI volume and preprocess it for harmonization.

    Processing steps:

    1. The volume is divided by its 95th-percentile intensity.
    2. The result is clipped to ``[0, a_max]``.
    3. The background is optionally removed.
    4. The volume is zero-padded, centered, into a 224 x 224 x 224 cube.

    Args:
        image_path (PathLike): path to the NIfTI file.
        bg_removal (bool, optional): remove background with
            ``background_removal``. Defaults to True.
        a_max (float, optional): upper clip value applied after
            normalization. Defaults to 5.0.

    Returns:
        Tuple[torch.Tensor, nib.Nifti1Header, Tuple[float, float]]: a tuple of:

        - the padded volume converted with ``ToTensor()``;
        - the file's NIfTI header;
        - ``(thresh, max_thresh)``: the 95th-percentile intensity used for
          normalization and the raw maximum intensity, both measured before
          normalization.

    Raises:
        ValueError: if the volume is not 3-D or any dimension exceeds 224.
            Such volumes cannot be centered in the padded cube.
    """
    pad_size = 224
    center = pad_size // 2

    image_obj = nib.Nifti1Image.from_filename(image_path)
    # astype() already returns a fresh float32 array; no extra copy is needed.
    image_vol = image_obj.get_fdata().astype(np.float32)
    if image_vol.ndim != 3 or any(dim > pad_size for dim in image_vol.shape):
        raise ValueError(
            f'expected a 3-D volume with every dimension <= {pad_size}, '
            f'got shape {image_vol.shape} from {image_path}'
        )

    thresh = np.percentile(image_vol, 95)
    max_thresh = image_vol.max()
    # The epsilon guards against division by zero when the 95th percentile is 0.
    image_vol = image_vol / (thresh + 1e-5)
    image_vol = np.clip(image_vol, a_min=0.0, a_max=a_max)
    if bg_removal:
        image_vol = background_removal(image_vol)

    # Zero-pad, centering each axis; odd sizes put the extra voxel on the high side.
    n_row, n_col, n_slc = image_vol.shape
    image_padded = np.zeros((pad_size, pad_size, pad_size), dtype=np.float32)
    image_padded[center - n_row // 2:center + n_row // 2 + n_row % 2,
                 center - n_col // 2:center + n_col // 2 + n_col % 2,
                 center - n_slc // 2:center + n_slc // 2 + n_slc % 2] = image_vol
    return ToTensor()(image_padded), image_obj.header, (thresh, max_thresh)

Get a single image from a NIfTI file and preprocess it for harmonization.

Args: image_path (PathLike): path to the NIfTI file. bg_removal (bool, optional): remove background from the image. Defaults to True. a_max (float, optional): maximum value for the image. Defaults to 5.0.

Returns: Tuple[torch.Tensor, nib.Nifti1Header, Tuple[float, float]]: preprocessed image, image header, and normalization values

def load_source_images(image_paths: List[PathLike], bg_removal: bool = True) -> Tuple[List[torch.Tensor], nib.Nifti1Header]:
    """Load and preprocess every source image in ``image_paths``.

    Each path goes through ``obtain_single_image``. The resulting tensor is
    cast with ``.float()`` and its axes are reversed with ``.permute(2, 1, 0)``.

    Args:
        image_paths (List[PathLike]): paths of the source NIfTI images.
        bg_removal (bool, optional): forwarded to ``obtain_single_image``.
            Defaults to True.

    Returns:
        Tuple[List[torch.Tensor], nib.Nifti1Header]: the preprocessed tensors
        in input order, and the header of the LAST image loaded. The header is
        ``None`` when ``image_paths`` is empty.
    """
    loaded = [obtain_single_image(path, bg_removal) for path in image_paths]
    tensors = [volume.float().permute(2, 1, 0) for volume, _header, _norm in loaded]
    last_header = loaded[-1][1] if loaded else None
    return tensors, last_header

Load all the source images from the list of paths

Args: image_paths (List[Pathlike]): list of paths to the images bg_removal (bool, optional): remove background from the image. Defaults to True.

Returns: Tuple[List[torch.Tensor], nib.Nifti1Header]: list of preprocessed images and the image header