inference.ensembler.ped_weighted_ensemble

Weighted Ensemble for Pediatric Brain Tumor Segmentation

Provides functionalities related to running a weighted ensemble of predictions from SwinUNETR, nnUNet, and MedNeXt models.

  1"""
  2### Weighted Ensemble for Pediatric Brain Tumor Segmentation
  3
  4Provides functionalities related to running a weighted ensemble of predictions from SwinUNETR, nnUNet, and MedNeXt models.
  5"""
  6
  7
  8from pathlib import Path
  9import numpy as np
 10import nibabel as nib
 11import os
 12import subprocess
 13from tqdm import tqdm
 14from typing import List, Tuple, Optional, Union, Dict
 15
 16
 17def maybe_make_dir(path: Union[str, Path]) -> Path:
 18    """
 19    Creates a directory at the specified path if it does not exist.
 20
 21    Args:
 22        path (Union[str, Path]): Path to the directory to be created.
 23
 24    Returns:
 25        Path: The path to the created or existing directory.
 26    """
 27    os.makedirs(path, exist_ok=True)
 28    return Path(path)
 29
 30
 31def convert_npz_mednext(
 32    npz_path: Union[str, Path],
 33    pkl_path: Union[str, Path],
 34    save_nifti: bool = False,
 35    nifti_dir: Union[str, Path] = Path('./tmp_prob_nifti'),
 36    suffix: str = 'mednext'
 37) -> np.ndarray:
 38    """
 39    Converts MedNeXt .npz and .pkl files to a NumPy array and optionally saves as NIfTI files.
 40
 41    Args:
 42        npz_path (Union[str, Path]): Path to the .npz file.
 43        pkl_path (Union[str, Path]): Path to the .pkl file.
 44        save_nifti (bool, optional): Whether to save the probabilities as NIfTI files. Defaults to False.
 45        nifti_dir (Union[str, Path], optional): Directory to save NIfTI files. Defaults to './tmp_prob_nifti'.
 46        suffix (str, optional): Suffix for the NIfTI file names. Defaults to 'mednext'.
 47
 48    Returns:
 49        np.ndarray: Stacked probability array.
 50    """
 51    npz = np.load(npz_path, allow_pickle=True)
 52    pkl = np.load(pkl_path, allow_pickle=True)
 53    prob = npz[npz.files[0]].astype(np.float32)
 54    bbox = pkl['crop_bbox']
 55    shape_original_before_cropping = pkl['original_size_of_raw_data']
 56    out_list = []
 57
 58    for i in range(prob.shape[0]):
 59        if bbox is not None:
 60            seg_old_size = np.zeros(shape_original_before_cropping, dtype=np.float16)
 61            for c in range(3):
 62                bbox[c][1] = min(bbox[c][0] + prob[i].shape[c], shape_original_before_cropping[c])
 63            seg_old_size[
 64                bbox[0][0]:bbox[0][1],
 65                bbox[1][0]:bbox[1][1],
 66                bbox[2][0]:bbox[2][1]
 67            ] = prob[i]
 68        else:
 69            seg_old_size = prob[i]
 70        
 71        out = np.swapaxes(seg_old_size, 0, 2)
 72        out_list.append(out)
 73
 74        if save_nifti:
 75            nifti_dir_path = maybe_make_dir(nifti_dir)
 76            nifti_path = nifti_dir_path / f"{npz_path.name.split('.npz')[0]}_{i}_{suffix}.nii.gz"
 77            nib.save(
 78                nib.Nifti1Image(out.astype(np.float32), affine=np.eye(4)),
 79                nifti_path
 80            )
 81
 82    return np.stack(out_list, axis=0).astype(np.float32)
 83
 84
 85def convert_npz_nnunet(
 86    npz_path: Union[str, Path],
 87    save_nifti: bool = False,
 88    nifti_dir: Union[str, Path] = Path('./tmp_prob_nifti'),
 89    suffix: str = 'nnunet'
 90) -> np.ndarray:
 91    """
 92    Converts nnUNet .npz files to a NumPy array and optionally saves as NIfTI files.
 93
 94    Args:
 95        npz_path (Union[str, Path]): Path to the .npz file.
 96        save_nifti (bool, optional): Whether to save the probabilities as NIfTI files. Defaults to False.
 97        nifti_dir (Union[str, Path], optional): Directory to save NIfTI files. Defaults to './tmp_prob_nifti'.
 98        suffix (str, optional): Suffix for the NIfTI file names. Defaults to 'nnunet'.
 99
100    Returns:
101        np.ndarray: Stacked probability array.
102    """
103    npz = np.load(npz_path, allow_pickle=True)
104    prob = npz[npz.files[0]].astype(np.float32)
105    out_list = []
106
107    for i in range(prob.shape[0]):
108        out = np.swapaxes(prob[i], 0, 2)
109        out_list.append(out)
110
111        if save_nifti:
112            nifti_dir_path = maybe_make_dir(nifti_dir)
113            nifti_path = nifti_dir_path / f"{npz_path.name.split('.npz')[0]}_{i}_{suffix}.nii.gz"
114            nib.save(
115                nib.Nifti1Image(out.astype(np.float32), affine=np.eye(4)),
116                nifti_path
117            )
118
119    return np.stack(out_list, axis=0).astype(np.float32)
120
121
def convert_npz_swinunetr(
    npz_path: Union[str, Path],
    save_nifti: bool = False,
    nifti_dir: Union[str, Path] = Path('./tmp_prob_nifti'),
    suffix: str = 'swinunetr'
) -> np.ndarray:
    """
    Converts SwinUNETR .npz files to a NumPy array and optionally saves as NIfTI files.

    The first array stored in the .npz is returned as float32 without any axis
    reordering (unlike the nnUNet/MedNeXt converters).

    Args:
        npz_path (Union[str, Path]): Path to the .npz file.
        save_nifti (bool, optional): Whether to save the probabilities as NIfTI files. Defaults to False.
        nifti_dir (Union[str, Path], optional): Directory to save NIfTI files. Defaults to './tmp_prob_nifti'.
        suffix (str, optional): Suffix for the NIfTI file names. Defaults to 'swinunetr'.

    Returns:
        np.ndarray: Probability array (float32), same shape as stored.
    """
    npz_path = Path(npz_path)  # accept str too; ``.name`` is used below
    with np.load(npz_path, allow_pickle=True) as npz:  # closes the archive handle
        prob = npz[npz.files[0]].astype(np.float32)

    if save_nifti:
        nifti_dir_path = maybe_make_dir(nifti_dir)
        stem = npz_path.name.split('.npz')[0]
        for i, channel in enumerate(prob):
            nib.save(
                nib.Nifti1Image(channel, affine=np.eye(4)),
                nifti_dir_path / f"{stem}_{i}_{suffix}.nii.gz"
            )

    return prob
153
154
def _mean_probability(arrays) -> np.ndarray:
    """
    Returns the element-wise mean of an iterable of equally shaped float arrays.

    Arrays are consumed one at a time; the first one is used as the accumulator and is
    modified in place.

    Raises:
        ValueError: If ``arrays`` yields nothing.
    """
    total = None
    count = 0
    for array in arrays:
        if total is None:
            total = array
        else:
            total += array
        count += 1
    if total is None:
        raise ValueError("Cannot average an empty list of probability maps.")
    total /= count
    return total


def ped_ensemble(
    swinunetr_npz_path_list: List[Union[str, Path]],
    nnunet_npz_path_list: List[Union[str, Path]],
    mednext_npz_path_list: List[Union[str, Path]],
    mednext_pkl_path_list: List[Union[str, Path]],
    ensembled_path: Union[str, Path],
    input_img: Union[str, Path],
    weights: List[float] = [1.0, 1.0, 1.0]
) -> Union[Path, Tuple[Path, Path]]:
    """
    Performs a weighted ensemble of predictions from SwinUNETR, nnUNet, and MedNeXt models.

    Each model's probability maps are first averaged over its list of files, the three
    per-model averages are combined as a weighted mean, and the arg-max over axis 0 is
    saved as an int8 NIfTI with the affine of ``input_img``.

    Args:
        swinunetr_npz_path_list (List[Union[str, Path]]): List of paths to SwinUNETR .npz files.
        nnunet_npz_path_list (List[Union[str, Path]]): List of paths to nnUNet .npz files.
            The case name is taken from the first entry's file name.
        mednext_npz_path_list (List[Union[str, Path]]): List of paths to MedNeXt .npz files.
        mednext_pkl_path_list (List[Union[str, Path]]): MedNeXt .pkl files, paired one-to-one
            with ``mednext_npz_path_list``.
        ensembled_path (Union[str, Path]): Directory to save the ensembled predictions.
        input_img (Union[str, Path]): Path to the original input image (affine source).
        weights (List[float], optional): Weights for (SwinUNETR, nnUNet, MedNeXt), in that
            order; only read, never modified. Defaults to [1.0, 1.0, 1.0].

    Returns:
        Path: ``<ensembled_path>/<case>.nii.gz``. If that file already exists it is
        returned immediately without recomputation.

    Raises:
        ValueError: If a model has no .npz files, the MedNeXt .npz/.pkl lists differ in
            length, ``weights`` does not hold exactly three values or does not sum to a
            positive value, or the three models' probability shapes differ.
    """
    if not (swinunetr_npz_path_list and nnunet_npz_path_list and mednext_npz_path_list):
        raise ValueError("Each model needs at least one .npz prediction file.")
    if len(mednext_npz_path_list) != len(mednext_pkl_path_list):
        # zip() would otherwise silently drop files while we still divide by the npz count.
        raise ValueError(
            f"Got {len(mednext_npz_path_list)} MedNeXt .npz files but "
            f"{len(mednext_pkl_path_list)} .pkl files; they must pair up one-to-one."
        )
    if len(weights) != 3:
        raise ValueError(f"Expected 3 weights (SwinUNETR, nnUNet, MedNeXt), got {len(weights)}.")
    weight_sum = sum(weights)
    if weight_sum <= 0:
        # A zero sum divides by zero; a negative one would turn the arg-max into an arg-min.
        raise ValueError(f"Weights must sum to a positive value, got {weight_sum}.")

    ensembled_path = maybe_make_dir(ensembled_path)

    case = Path(nnunet_npz_path_list[0]).name.split('.npz')[0]
    print(f"Ensemble {case}")

    ensembled_nifti = ensembled_path / f"{case}.nii.gz"
    if ensembled_nifti.exists():
        print(f"File {ensembled_nifti} already exists. Skipping.")
        return ensembled_nifti

    # Average each model over its own prediction files.
    prob_swinunetr = _mean_probability(convert_npz_swinunetr(p) for p in swinunetr_npz_path_list)
    print(f"Probabilities SwinUNETR: {prob_swinunetr.shape}")

    prob_nnunet = _mean_probability(convert_npz_nnunet(p) for p in nnunet_npz_path_list)
    print(f"Probabilities nnUNet: {prob_nnunet.shape}")

    prob_mednext = _mean_probability(
        convert_npz_mednext(npz, pkl)
        for npz, pkl in zip(mednext_npz_path_list, mednext_pkl_path_list)
    )
    print(f"Probabilities MedNeXt: {prob_mednext.shape}")

    if not prob_swinunetr.shape == prob_nnunet.shape == prob_mednext.shape:
        raise ValueError(
            f"Probability shapes differ: SwinUNETR {prob_swinunetr.shape}, "
            f"nnUNet {prob_nnunet.shape}, MedNeXt {prob_mednext.shape}."
        )

    # Weighted ensemble, normalised by the weight sum.
    prob = (
        weights[0] * prob_swinunetr +
        weights[1] * prob_nnunet +
        weights[2] * prob_mednext
    )
    prob /= weight_sum

    # Generate segmentation by taking argmax over the class axis.
    seg = np.argmax(prob, axis=0)
    print(f"Segmentation shape: {seg.shape}")

    # Save the ensembled segmentation with the input image's affine.
    img = nib.load(input_img)
    nib.save(nib.Nifti1Image(seg.astype(np.int8), img.affine), ensembled_nifti)
    print(f"Saved ensembled segmentation to {ensembled_nifti}")

    return ensembled_nifti
228
229
def batch_ped_ensemble(
    swinunetr_pred_dirs: List[Union[str, Path]],
    nnunet_pred_dirs: List[Union[str, Path]],
    mednext_pred_dirs: List[Union[str, Path]],
    input_img_dir: Union[str, Path],
    ensembled_dir: Union[str, Path],
    weights: List[float] = [1.0, 1.0, 1.0],
    cv: bool = False
) -> None:
    """
    Runs ``ped_ensemble`` for every case, in cross-validation (cv) or validation mode.

    * ``cv=True``: the i-th directory of each of the three lists is treated as the same
      fold. The cases of a fold are the ``<case>.nii.gz`` files in its nnUNet directory,
      and each case is ensembled from that fold's ``<case>.npz`` files (plus the MedNeXt
      ``<case>.pkl``).
    * ``cv=False``: the cases are the sub-directories of ``input_img_dir``. Every
      directory of each model contributes one file per case (SwinUNETR files are named
      ``<case>-t1n.npz``), and ``ped_ensemble`` averages them per model.

    In both modes the reference image is ``input_img_dir/<case>/<case>-t1n.nii.gz``.

    Args:
        swinunetr_pred_dirs (List[Union[str, Path]]): List of directories containing SwinUNETR predictions.
        nnunet_pred_dirs (List[Union[str, Path]]): List of directories containing nnUNet predictions.
        mednext_pred_dirs (List[Union[str, Path]]): List of directories containing MedNeXt predictions.
        input_img_dir (Union[str, Path]): Directory containing the original input images.
        ensembled_dir (Union[str, Path]): Directory to save the ensembled predictions.
        weights (List[float], optional): Weights for each model in the ensemble. Defaults to [1.0, 1.0, 1.0].
        cv (bool, optional): Flag indicating whether to perform cross-validation ensemble. Defaults to False.

    Raises:
        ValueError: In cv mode, if the three directory lists do not have the same length.
    """
    ensembled_dir = maybe_make_dir(ensembled_dir)
    # Normalise every path argument so str inputs work with ``/`` and ``iterdir()``.
    input_img_dir = Path(input_img_dir)
    swinunetr_pred_dirs = [Path(d) for d in swinunetr_pred_dirs]
    nnunet_pred_dirs = [Path(d) for d in nnunet_pred_dirs]
    mednext_pred_dirs = [Path(d) for d in mednext_pred_dirs]

    def run_case(case: str, swinunetr_npz_path_list: List[Path],
                 nnunet_dirs: List[Path], mednext_dirs: List[Path]) -> None:
        # nnUNet and MedNeXt files are named after the case in both modes; only the
        # SwinUNETR naming differs, so callers build that list themselves.
        saved_path = ped_ensemble(
            swinunetr_npz_path_list,
            [pred / f"{case}.npz" for pred in nnunet_dirs],
            [pred / f"{case}.npz" for pred in mednext_dirs],
            [pred / f"{case}.pkl" for pred in mednext_dirs],
            ensembled_dir,
            input_img_dir / case / f"{case}-t1n.nii.gz",
            weights=weights
        )
        print(f"Saved {saved_path}")

    if cv:
        if not len(swinunetr_pred_dirs) == len(nnunet_pred_dirs) == len(mednext_pred_dirs):
            raise ValueError(
                "In cv mode every model needs one directory per fold, got "
                f"{len(swinunetr_pred_dirs)} SwinUNETR, {len(nnunet_pred_dirs)} nnUNet and "
                f"{len(mednext_pred_dirs)} MedNeXt directories."
            )
        for swin_dir, nnunet_dir, mednext_dir in zip(swinunetr_pred_dirs, nnunet_pred_dirs, mednext_pred_dirs):
            # The cases of a fold are its nnUNet validation predictions (<case>.nii.gz).
            cases = sorted(
                case_path.name[:-len(".nii.gz")]
                for case_path in nnunet_dir.iterdir()
                if case_path.name.endswith(".nii.gz")
            )
            for case in cases:
                run_case(case, [swin_dir / f"{case}.npz"], [nnunet_dir], [mednext_dir])
    else:
        cases = sorted(case_path.name for case_path in input_img_dir.iterdir() if case_path.is_dir())
        for case in cases:
            run_case(
                case,
                [pred / f"{case}-t1n.npz" for pred in swinunetr_pred_dirs],
                nnunet_pred_dirs,
                mednext_pred_dirs,
            )
298
299
def main_cv():
    """
    Runs the weighted ensemble in cross-validation mode.

    Collects five fold directories per model (SwinUNETR, nnUNet, MedNeXt) and hands them,
    together with fixed model weights, to ``batch_ped_ensemble`` with ``cv=True``.
    """
    weights = [0.330911177, 0.330839468, 0.338249355]

    swinunetr_root = Path("/home/v363/v363397/media/output_cv/ped_stratified")
    nnunet_root = Path("/home/v363/v363397/media/nnUNet_results/Dataset021_BraTS2024-PED/nnUNetTrainer_200epochs__nnUNetPlans__3d_fullres")
    mednext_root = Path("/home/v363/v363397/media/nnUNet_trained_models/nnUNet/3d_fullres/Task021_BraTS2024-PEDs/nnUNetTrainerV2_MedNeXt_M_kernel3_200epochs_StratifiedSplit__nnUNetPlansv2.1_trgSp_1x1x1")

    swinunetr_pred_dirs, nnunet_pred_dirs, mednext_pred_dirs = [], [], []
    for fold in range(5):
        # Folds 0-2 come from the 650-epoch SwinUNETR runs, folds 3-4 from the 1000-epoch runs.
        swin_run = f'swinunetr_e650_f{fold}_b1p4' if fold < 3 else f'swinunetr_e1000_f{fold}_b1p4'
        swinunetr_pred_dirs.append(swinunetr_root / swin_run)
        nnunet_pred_dirs.append(nnunet_root / f'fold_{fold}' / 'validation')
        mednext_pred_dirs.append(mednext_root / f'fold_{fold}' / 'validation_raw')

    batch_ped_ensemble(
        swinunetr_pred_dirs,
        nnunet_pred_dirs,
        mednext_pred_dirs,
        Path("/home/v363/v363397/stay/brats2024/data/MICCAI-BraTS2024-PED/BraTS-PEDs2024_Training"),
        Path("/home/v363/v363397/stay/brats2024/Task099a_postprocessed_cv/PEDs/ensembled_preds_cv"),
        weights=weights,
        cv=True
    )
331
332
def main_val():
    """
    Runs the weighted ensemble in validation mode.

    Collects five fold directories per model (SwinUNETR, nnUNet, MedNeXt) and hands them,
    together with fixed model weights, to ``batch_ped_ensemble`` in its default
    (non-cv) mode.
    """
    weights = [0.330911177, 0.330839468, 0.338249355]

    swinunetr_root = Path("/home/v363/v363397/media/output_val/ped_stratified")
    nnunet_root = Path("/home/v363/v363397/stay/brats2024/Task012a_BraTS2024-PED_nnUNet_SS/6_predict/predicted/folds_independent")
    mednext_root = Path("/home/v363/v363397/stay/brats2024/Task012b_BraTS2024-PED_MedNeXt_SS/6_predict/predicted/folds_independent")

    swinunetr_pred_dirs, nnunet_pred_dirs, mednext_pred_dirs = [], [], []
    for fold in range(5):
        # Folds 0-2 come from the 650-epoch SwinUNETR runs, folds 3-4 from the 1000-epoch runs.
        swin_run = f'swinunetr_e650_f{fold}_b1p4' if fold < 3 else f'swinunetr_e1000_f{fold}_b1p4'
        swinunetr_pred_dirs.append(swinunetr_root / swin_run)
        nnunet_pred_dirs.append(nnunet_root / f'fold_{fold}' / "nnUNetTrainer_200epochs/3d_fullres")
        mednext_pred_dirs.append(mednext_root / f'fold_{fold}' / "nnUNetTrainerV2_MedNeXt_M_kernel3_200epochs_StratifiedSplit/3d_fullres")

    batch_ped_ensemble(
        swinunetr_pred_dirs,
        nnunet_pred_dirs,
        mednext_pred_dirs,
        Path("/home/v363/v363397/stay/brats2024/data/MICCAI-BraTS2024-PED/BraTS_Validation_Data_backup"),
        Path("/home/v363/v363397/stay/brats2024/Task099b_postprocessed_val/PEDs/ensembled_preds_val"),
        weights=weights,
    )
363
364
if __name__ == '__main__':
    # Validation mode is the default; call main_cv() here instead to ensemble the
    # cross-validation folds.
    main_val()
def maybe_make_dir(path: Union[str, pathlib.Path]) -> pathlib.Path:
def maybe_make_dir(path: Union[str, Path]) -> Path:
    """
    Ensures that a directory exists and returns it as a ``Path``.

    Missing parent directories are created as well; an already existing
    directory is left untouched.

    Args:
        path (Union[str, Path]): Directory to create if needed.

    Returns:
        Path: ``path`` converted to a ``pathlib.Path``.
    """
    directory = Path(path)
    os.makedirs(directory, exist_ok=True)
    return directory

Creates a directory at the specified path if it does not exist.

Args: path (Union[str, Path]): Path to the directory to be created.

Returns: Path: The path to the created or existing directory.

def convert_npz_mednext( npz_path: Union[str, pathlib.Path], pkl_path: Union[str, pathlib.Path], save_nifti: bool = False, nifti_dir: Union[str, pathlib.Path] = PosixPath('tmp_prob_nifti'), suffix: str = 'mednext') -> numpy.ndarray:
def convert_npz_mednext(
    npz_path: Union[str, Path],
    pkl_path: Union[str, Path],
    save_nifti: bool = False,
    nifti_dir: Union[str, Path] = Path('./tmp_prob_nifti'),
    suffix: str = 'mednext'
) -> np.ndarray:
    """
    Converts MedNeXt .npz and .pkl files to a NumPy array and optionally saves as NIfTI files.

    The first array stored in the .npz is read as per-class probabilities (class axis
    first, followed by three spatial axes). If the .pkl properties contain a non-None
    ``crop_bbox`` (per-axis ``[start, stop]``), the cropped probabilities are pasted back
    into a zero-filled volume of shape ``original_size_of_raw_data``. Spatial axes 0 and 2
    of every class volume are then swapped -- presumably to go from the exporter's
    (z, y, x) order to nibabel's (x, y, z) order; NOTE(review): confirm.

    Args:
        npz_path (Union[str, Path]): Path to the .npz file.
        pkl_path (Union[str, Path]): Path to the .pkl properties file. It is unpickled,
            so only pass trusted files.
        save_nifti (bool, optional): Whether to save the probabilities as NIfTI files. Defaults to False.
        nifti_dir (Union[str, Path], optional): Directory to save NIfTI files. Defaults to './tmp_prob_nifti'.
        suffix (str, optional): Suffix for the NIfTI file names. Defaults to 'mednext'.

    Returns:
        np.ndarray: float32 array ``(C, d2, d1, d0)``, where ``(d0, d1, d2)`` is the
        uncropped spatial shape (or the stored shape when there is no bbox).
    """
    npz_path = Path(npz_path)  # accept str too; ``.name`` is used below
    with np.load(npz_path, allow_pickle=True) as npz:  # closes the archive handle
        prob = npz[npz.files[0]].astype(np.float32)
    # np.load falls back to pickle.load for non-.npy/.npz files (trusted input only).
    properties = np.load(pkl_path, allow_pickle=True)
    bbox = properties['crop_bbox']
    shape_original_before_cropping = properties['original_size_of_raw_data']

    if bbox is not None:
        # Paste the cropped region back into the uncropped volume for all classes at once.
        # Keep float32: a float16 buffer would round the probabilities before ensembling.
        uncropped = np.zeros((prob.shape[0], *shape_original_before_cropping), dtype=np.float32)
        region = tuple(
            slice(
                int(bbox[axis][0]),
                min(int(bbox[axis][0]) + prob.shape[axis + 1], int(shape_original_before_cropping[axis])),
            )
            for axis in range(3)
        )
        uncropped[(slice(None), *region)] = prob
        prob = uncropped

    # Swap spatial axes 0 and 2 of each class volume (axes 1 and 3 of the stack).
    out = np.ascontiguousarray(np.swapaxes(prob, 1, 3), dtype=np.float32)

    if save_nifti:
        nifti_dir_path = maybe_make_dir(nifti_dir)
        stem = npz_path.name.split('.npz')[0]
        for i, channel in enumerate(out):
            nib.save(
                nib.Nifti1Image(channel, affine=np.eye(4)),
                nifti_dir_path / f"{stem}_{i}_{suffix}.nii.gz"
            )

    return out

Converts MedNeXt .npz and .pkl files to a NumPy array and optionally saves as NIfTI files.

Args: npz_path (Union[str, Path]): Path to the .npz file. pkl_path (Union[str, Path]): Path to the .pkl file. save_nifti (bool, optional): Whether to save the probabilities as NIfTI files. Defaults to False. nifti_dir (Union[str, Path], optional): Directory to save NIfTI files. Defaults to './tmp_prob_nifti'. suffix (str, optional): Suffix for the NIfTI file names. Defaults to 'mednext'.

Returns: np.ndarray: Stacked probability array.

def convert_npz_nnunet( npz_path: Union[str, pathlib.Path], save_nifti: bool = False, nifti_dir: Union[str, pathlib.Path] = PosixPath('tmp_prob_nifti'), suffix: str = 'nnunet') -> numpy.ndarray:
 86def convert_npz_nnunet(
 87    npz_path: Union[str, Path],
 88    save_nifti: bool = False,
 89    nifti_dir: Union[str, Path] = Path('./tmp_prob_nifti'),
 90    suffix: str = 'nnunet'
 91) -> np.ndarray:
 92    """
 93    Converts nnUNet .npz files to a NumPy array and optionally saves as NIfTI files.
 94
 95    Args:
 96        npz_path (Union[str, Path]): Path to the .npz file.
 97        save_nifti (bool, optional): Whether to save the probabilities as NIfTI files. Defaults to False.
 98        nifti_dir (Union[str, Path], optional): Directory to save NIfTI files. Defaults to './tmp_prob_nifti'.
 99        suffix (str, optional): Suffix for the NIfTI file names. Defaults to 'nnunet'.
100
101    Returns:
102        np.ndarray: Stacked probability array.
103    """
104    npz = np.load(npz_path, allow_pickle=True)
105    prob = npz[npz.files[0]].astype(np.float32)
106    out_list = []
107
108    for i in range(prob.shape[0]):
109        out = np.swapaxes(prob[i], 0, 2)
110        out_list.append(out)
111
112        if save_nifti:
113            nifti_dir_path = maybe_make_dir(nifti_dir)
114            nifti_path = nifti_dir_path / f"{npz_path.name.split('.npz')[0]}_{i}_{suffix}.nii.gz"
115            nib.save(
116                nib.Nifti1Image(out.astype(np.float32), affine=np.eye(4)),
117                nifti_path
118            )
119
120    return np.stack(out_list, axis=0).astype(np.float32)

Converts nnUNet .npz files to a NumPy array and optionally saves as NIfTI files.

Args: npz_path (Union[str, Path]): Path to the .npz file. save_nifti (bool, optional): Whether to save the probabilities as NIfTI files. Defaults to False. nifti_dir (Union[str, Path], optional): Directory to save NIfTI files. Defaults to './tmp_prob_nifti'. suffix (str, optional): Suffix for the NIfTI file names. Defaults to 'nnunet'.

Returns: np.ndarray: Stacked probability array.

def convert_npz_swinunetr( npz_path: Union[str, pathlib.Path], save_nifti: bool = False, nifti_dir: Union[str, pathlib.Path] = PosixPath('tmp_prob_nifti'), suffix: str = 'swinunetr') -> numpy.ndarray:
def convert_npz_swinunetr(
    npz_path: Union[str, Path],
    save_nifti: bool = False,
    nifti_dir: Union[str, Path] = Path('./tmp_prob_nifti'),
    suffix: str = 'swinunetr'
) -> np.ndarray:
    """
    Converts SwinUNETR .npz files to a NumPy array and optionally saves as NIfTI files.

    The first array stored in the .npz is returned as float32 without any axis
    reordering (unlike the nnUNet/MedNeXt converters).

    Args:
        npz_path (Union[str, Path]): Path to the .npz file.
        save_nifti (bool, optional): Whether to save the probabilities as NIfTI files. Defaults to False.
        nifti_dir (Union[str, Path], optional): Directory to save NIfTI files. Defaults to './tmp_prob_nifti'.
        suffix (str, optional): Suffix for the NIfTI file names. Defaults to 'swinunetr'.

    Returns:
        np.ndarray: Probability array (float32), same shape as stored.
    """
    npz_path = Path(npz_path)  # accept str too; ``.name`` is used below
    with np.load(npz_path, allow_pickle=True) as npz:  # closes the archive handle
        prob = npz[npz.files[0]].astype(np.float32)

    if save_nifti:
        nifti_dir_path = maybe_make_dir(nifti_dir)
        stem = npz_path.name.split('.npz')[0]
        for i, channel in enumerate(prob):
            nib.save(
                nib.Nifti1Image(channel, affine=np.eye(4)),
                nifti_dir_path / f"{stem}_{i}_{suffix}.nii.gz"
            )

    return prob

Converts SwinUNETR .npz files to a NumPy array and optionally saves as NIfTI files.

Args: npz_path (Union[str, Path]): Path to the .npz file. save_nifti (bool, optional): Whether to save the probabilities as NIfTI files. Defaults to False. nifti_dir (Union[str, Path], optional): Directory to save NIfTI files. Defaults to './tmp_prob_nifti'. suffix (str, optional): Suffix for the NIfTI file names. Defaults to 'swinunetr'.

Returns: np.ndarray: Probability array.

def ped_ensemble( swinunetr_npz_path_list: List[Union[str, pathlib.Path]], nnunet_npz_path_list: List[Union[str, pathlib.Path]], mednext_npz_path_list: List[Union[str, pathlib.Path]], mednext_pkl_path_list: List[Union[str, pathlib.Path]], ensembled_path: Union[str, pathlib.Path], input_img: Union[str, pathlib.Path], weights: List[float] = [1.0, 1.0, 1.0]) -> Union[pathlib.Path, Tuple[pathlib.Path, pathlib.Path]]:
def _mean_probability(arrays) -> np.ndarray:
    """
    Returns the element-wise mean of an iterable of equally shaped float arrays.

    Arrays are consumed one at a time; the first one is used as the accumulator and is
    modified in place.

    Raises:
        ValueError: If ``arrays`` yields nothing.
    """
    total = None
    count = 0
    for array in arrays:
        if total is None:
            total = array
        else:
            total += array
        count += 1
    if total is None:
        raise ValueError("Cannot average an empty list of probability maps.")
    total /= count
    return total


def ped_ensemble(
    swinunetr_npz_path_list: List[Union[str, Path]],
    nnunet_npz_path_list: List[Union[str, Path]],
    mednext_npz_path_list: List[Union[str, Path]],
    mednext_pkl_path_list: List[Union[str, Path]],
    ensembled_path: Union[str, Path],
    input_img: Union[str, Path],
    weights: List[float] = [1.0, 1.0, 1.0]
) -> Union[Path, Tuple[Path, Path]]:
    """
    Performs a weighted ensemble of predictions from SwinUNETR, nnUNet, and MedNeXt models.

    Each model's probability maps are first averaged over its list of files, the three
    per-model averages are combined as a weighted mean, and the arg-max over axis 0 is
    saved as an int8 NIfTI with the affine of ``input_img``.

    Args:
        swinunetr_npz_path_list (List[Union[str, Path]]): List of paths to SwinUNETR .npz files.
        nnunet_npz_path_list (List[Union[str, Path]]): List of paths to nnUNet .npz files.
            The case name is taken from the first entry's file name.
        mednext_npz_path_list (List[Union[str, Path]]): List of paths to MedNeXt .npz files.
        mednext_pkl_path_list (List[Union[str, Path]]): MedNeXt .pkl files, paired one-to-one
            with ``mednext_npz_path_list``.
        ensembled_path (Union[str, Path]): Directory to save the ensembled predictions.
        input_img (Union[str, Path]): Path to the original input image (affine source).
        weights (List[float], optional): Weights for (SwinUNETR, nnUNet, MedNeXt), in that
            order; only read, never modified. Defaults to [1.0, 1.0, 1.0].

    Returns:
        Path: ``<ensembled_path>/<case>.nii.gz``. If that file already exists it is
        returned immediately without recomputation.

    Raises:
        ValueError: If a model has no .npz files, the MedNeXt .npz/.pkl lists differ in
            length, ``weights`` does not hold exactly three values or does not sum to a
            positive value, or the three models' probability shapes differ.
    """
    if not (swinunetr_npz_path_list and nnunet_npz_path_list and mednext_npz_path_list):
        raise ValueError("Each model needs at least one .npz prediction file.")
    if len(mednext_npz_path_list) != len(mednext_pkl_path_list):
        # zip() would otherwise silently drop files while we still divide by the npz count.
        raise ValueError(
            f"Got {len(mednext_npz_path_list)} MedNeXt .npz files but "
            f"{len(mednext_pkl_path_list)} .pkl files; they must pair up one-to-one."
        )
    if len(weights) != 3:
        raise ValueError(f"Expected 3 weights (SwinUNETR, nnUNet, MedNeXt), got {len(weights)}.")
    weight_sum = sum(weights)
    if weight_sum <= 0:
        # A zero sum divides by zero; a negative one would turn the arg-max into an arg-min.
        raise ValueError(f"Weights must sum to a positive value, got {weight_sum}.")

    ensembled_path = maybe_make_dir(ensembled_path)

    case = Path(nnunet_npz_path_list[0]).name.split('.npz')[0]
    print(f"Ensemble {case}")

    ensembled_nifti = ensembled_path / f"{case}.nii.gz"
    if ensembled_nifti.exists():
        print(f"File {ensembled_nifti} already exists. Skipping.")
        return ensembled_nifti

    # Average each model over its own prediction files.
    prob_swinunetr = _mean_probability(convert_npz_swinunetr(p) for p in swinunetr_npz_path_list)
    print(f"Probabilities SwinUNETR: {prob_swinunetr.shape}")

    prob_nnunet = _mean_probability(convert_npz_nnunet(p) for p in nnunet_npz_path_list)
    print(f"Probabilities nnUNet: {prob_nnunet.shape}")

    prob_mednext = _mean_probability(
        convert_npz_mednext(npz, pkl)
        for npz, pkl in zip(mednext_npz_path_list, mednext_pkl_path_list)
    )
    print(f"Probabilities MedNeXt: {prob_mednext.shape}")

    if not prob_swinunetr.shape == prob_nnunet.shape == prob_mednext.shape:
        raise ValueError(
            f"Probability shapes differ: SwinUNETR {prob_swinunetr.shape}, "
            f"nnUNet {prob_nnunet.shape}, MedNeXt {prob_mednext.shape}."
        )

    # Weighted ensemble, normalised by the weight sum.
    prob = (
        weights[0] * prob_swinunetr +
        weights[1] * prob_nnunet +
        weights[2] * prob_mednext
    )
    prob /= weight_sum

    # Generate segmentation by taking argmax over the class axis.
    seg = np.argmax(prob, axis=0)
    print(f"Segmentation shape: {seg.shape}")

    # Save the ensembled segmentation with the input image's affine.
    img = nib.load(input_img)
    nib.save(nib.Nifti1Image(seg.astype(np.int8), img.affine), ensembled_nifti)
    print(f"Saved ensembled segmentation to {ensembled_nifti}")

    return ensembled_nifti

Performs ensemble of predictions from SwinUNETR, nnUNet, and MedNeXt models.

Args: swinunetr_npz_path_list (List[Union[str, Path]]): List of paths to SwinUNETR .npz files. nnunet_npz_path_list (List[Union[str, Path]]): List of paths to nnUNet .npz files. mednext_npz_path_list (List[Union[str, Path]]): List of paths to MedNeXt .npz files. mednext_pkl_path_list (List[Union[str, Path]]): List of paths to MedNeXt .pkl files. ensembled_path (Union[str, Path]): Directory to save the ensembled predictions. input_img (Union[str, Path]): Path to the original input image. weights (List[float], optional): Weights for each model in the ensemble. Defaults to [1.0, 1.0, 1.0].

Returns: Union[Path, Tuple[Path, Path]]: Path to the ensembled NIfTI file. Only a single Path is ever returned; if the file already exists, it is returned without recomputation.

def batch_ped_ensemble( swinunetr_pred_dirs: List[Union[str, pathlib.Path]], nnunet_pred_dirs: List[Union[str, pathlib.Path]], mednext_pred_dirs: List[Union[str, pathlib.Path]], input_img_dir: Union[str, pathlib.Path], ensembled_dir: Union[str, pathlib.Path], weights: List[float] = [1.0, 1.0, 1.0], cv: bool = False) -> None:
def batch_ped_ensemble(
    swinunetr_pred_dirs: List[Union[str, Path]],
    nnunet_pred_dirs: List[Union[str, Path]],
    mednext_pred_dirs: List[Union[str, Path]],
    input_img_dir: Union[str, Path],
    ensembled_dir: Union[str, Path],
    weights: List[float] = [1.0, 1.0, 1.0],
    cv: bool = False
) -> None:
    """
    Runs ``ped_ensemble`` for every case, in cross-validation (cv) or validation mode.

    * ``cv=True``: the i-th directory of each of the three lists is treated as the same
      fold. The cases of a fold are the ``<case>.nii.gz`` files in its nnUNet directory,
      and each case is ensembled from that fold's ``<case>.npz`` files (plus the MedNeXt
      ``<case>.pkl``).
    * ``cv=False``: the cases are the sub-directories of ``input_img_dir``. Every
      directory of each model contributes one file per case (SwinUNETR files are named
      ``<case>-t1n.npz``), and ``ped_ensemble`` averages them per model.

    In both modes the reference image is ``input_img_dir/<case>/<case>-t1n.nii.gz``.

    Args:
        swinunetr_pred_dirs (List[Union[str, Path]]): List of directories containing SwinUNETR predictions.
        nnunet_pred_dirs (List[Union[str, Path]]): List of directories containing nnUNet predictions.
        mednext_pred_dirs (List[Union[str, Path]]): List of directories containing MedNeXt predictions.
        input_img_dir (Union[str, Path]): Directory containing the original input images.
        ensembled_dir (Union[str, Path]): Directory to save the ensembled predictions.
        weights (List[float], optional): Weights for each model in the ensemble. Defaults to [1.0, 1.0, 1.0].
        cv (bool, optional): Flag indicating whether to perform cross-validation ensemble. Defaults to False.

    Raises:
        ValueError: In cv mode, if the three directory lists do not have the same length.
    """
    ensembled_dir = maybe_make_dir(ensembled_dir)
    # Normalise every path argument so str inputs work with ``/`` and ``iterdir()``.
    input_img_dir = Path(input_img_dir)
    swinunetr_pred_dirs = [Path(d) for d in swinunetr_pred_dirs]
    nnunet_pred_dirs = [Path(d) for d in nnunet_pred_dirs]
    mednext_pred_dirs = [Path(d) for d in mednext_pred_dirs]

    def run_case(case: str, swinunetr_npz_path_list: List[Path],
                 nnunet_dirs: List[Path], mednext_dirs: List[Path]) -> None:
        # nnUNet and MedNeXt files are named after the case in both modes; only the
        # SwinUNETR naming differs, so callers build that list themselves.
        saved_path = ped_ensemble(
            swinunetr_npz_path_list,
            [pred / f"{case}.npz" for pred in nnunet_dirs],
            [pred / f"{case}.npz" for pred in mednext_dirs],
            [pred / f"{case}.pkl" for pred in mednext_dirs],
            ensembled_dir,
            input_img_dir / case / f"{case}-t1n.nii.gz",
            weights=weights
        )
        print(f"Saved {saved_path}")

    if cv:
        if not len(swinunetr_pred_dirs) == len(nnunet_pred_dirs) == len(mednext_pred_dirs):
            raise ValueError(
                "In cv mode every model needs one directory per fold, got "
                f"{len(swinunetr_pred_dirs)} SwinUNETR, {len(nnunet_pred_dirs)} nnUNet and "
                f"{len(mednext_pred_dirs)} MedNeXt directories."
            )
        for swin_dir, nnunet_dir, mednext_dir in zip(swinunetr_pred_dirs, nnunet_pred_dirs, mednext_pred_dirs):
            # The cases of a fold are its nnUNet validation predictions (<case>.nii.gz).
            cases = sorted(
                case_path.name[:-len(".nii.gz")]
                for case_path in nnunet_dir.iterdir()
                if case_path.name.endswith(".nii.gz")
            )
            for case in cases:
                run_case(case, [swin_dir / f"{case}.npz"], [nnunet_dir], [mednext_dir])
    else:
        cases = sorted(case_path.name for case_path in input_img_dir.iterdir() if case_path.is_dir())
        for case in cases:
            run_case(
                case,
                [pred / f"{case}-t1n.npz" for pred in swinunetr_pred_dirs],
                nnunet_pred_dirs,
                mednext_pred_dirs,
            )

Performs ensemble of predictions for multiple cases either in cross-validation (cv) or validation mode.

Args: swinunetr_pred_dirs (List[Union[str, Path]]): List of directories containing SwinUNETR predictions. nnunet_pred_dirs (List[Union[str, Path]]): List of directories containing nnUNet predictions. mednext_pred_dirs (List[Union[str, Path]]): List of directories containing MedNeXt predictions. input_img_dir (Union[str, Path]): Directory containing the original input images. ensembled_dir (Union[str, Path]): Directory to save the ensembled predictions. weights (List[float], optional): Weights for each model in the ensemble. Defaults to [1.0, 1.0, 1.0]. cv (bool, optional): Flag indicating whether to perform cross-validation ensemble. Defaults to False.

def main_cv():
def main_cv():
    """
    Runs the weighted ensemble in cross-validation mode.

    Collects five fold directories per model (SwinUNETR, nnUNet, MedNeXt) and hands them,
    together with fixed model weights, to ``batch_ped_ensemble`` with ``cv=True``.
    """
    weights = [0.330911177, 0.330839468, 0.338249355]

    swinunetr_root = Path("/home/v363/v363397/media/output_cv/ped_stratified")
    nnunet_root = Path("/home/v363/v363397/media/nnUNet_results/Dataset021_BraTS2024-PED/nnUNetTrainer_200epochs__nnUNetPlans__3d_fullres")
    mednext_root = Path("/home/v363/v363397/media/nnUNet_trained_models/nnUNet/3d_fullres/Task021_BraTS2024-PEDs/nnUNetTrainerV2_MedNeXt_M_kernel3_200epochs_StratifiedSplit__nnUNetPlansv2.1_trgSp_1x1x1")

    swinunetr_pred_dirs, nnunet_pred_dirs, mednext_pred_dirs = [], [], []
    for fold in range(5):
        # Folds 0-2 come from the 650-epoch SwinUNETR runs, folds 3-4 from the 1000-epoch runs.
        swin_run = f'swinunetr_e650_f{fold}_b1p4' if fold < 3 else f'swinunetr_e1000_f{fold}_b1p4'
        swinunetr_pred_dirs.append(swinunetr_root / swin_run)
        nnunet_pred_dirs.append(nnunet_root / f'fold_{fold}' / 'validation')
        mednext_pred_dirs.append(mednext_root / f'fold_{fold}' / 'validation_raw')

    batch_ped_ensemble(
        swinunetr_pred_dirs,
        nnunet_pred_dirs,
        mednext_pred_dirs,
        Path("/home/v363/v363397/stay/brats2024/data/MICCAI-BraTS2024-PED/BraTS-PEDs2024_Training"),
        Path("/home/v363/v363397/stay/brats2024/Task099a_postprocessed_cv/PEDs/ensembled_preds_cv"),
        weights=weights,
        cv=True
    )

Executes ensemble predictions for cross-validation (cv) mode.

def main_val():
def main_val():
    """
    Runs the weighted ensemble in validation mode.

    Collects five fold directories per model (SwinUNETR, nnUNet, MedNeXt) and hands them,
    together with fixed model weights, to ``batch_ped_ensemble`` in its default
    (non-cv) mode.
    """
    weights = [0.330911177, 0.330839468, 0.338249355]

    swinunetr_root = Path("/home/v363/v363397/media/output_val/ped_stratified")
    nnunet_root = Path("/home/v363/v363397/stay/brats2024/Task012a_BraTS2024-PED_nnUNet_SS/6_predict/predicted/folds_independent")
    mednext_root = Path("/home/v363/v363397/stay/brats2024/Task012b_BraTS2024-PED_MedNeXt_SS/6_predict/predicted/folds_independent")

    swinunetr_pred_dirs, nnunet_pred_dirs, mednext_pred_dirs = [], [], []
    for fold in range(5):
        # Folds 0-2 come from the 650-epoch SwinUNETR runs, folds 3-4 from the 1000-epoch runs.
        swin_run = f'swinunetr_e650_f{fold}_b1p4' if fold < 3 else f'swinunetr_e1000_f{fold}_b1p4'
        swinunetr_pred_dirs.append(swinunetr_root / swin_run)
        nnunet_pred_dirs.append(nnunet_root / f'fold_{fold}' / "nnUNetTrainer_200epochs/3d_fullres")
        mednext_pred_dirs.append(mednext_root / f'fold_{fold}' / "nnUNetTrainerV2_MedNeXt_M_kernel3_200epochs_StratifiedSplit/3d_fullres")

    batch_ped_ensemble(
        swinunetr_pred_dirs,
        nnunet_pred_dirs,
        mednext_pred_dirs,
        Path("/home/v363/v363397/stay/brats2024/data/MICCAI-BraTS2024-PED/BraTS_Validation_Data_backup"),
        Path("/home/v363/v363397/stay/brats2024/Task099b_postprocessed_val/PEDs/ensembled_preds_val"),
        weights=weights,
    )

Executes ensemble predictions for validation mode.