inference.ensembler.ped_weighted_ensemble
Weighted Ensemble for Pediatric Brain Tumor Segmentation
Provides functions for running a weighted ensemble of the class-probability predictions produced by SwinUNETR, nnUNet, and MedNeXt models.
"""
### Weighted Ensemble for Pediatric Brain Tumor Segmentation

Provides functions for running a weighted ensemble of the class-probability
predictions produced by SwinUNETR, nnUNet, and MedNeXt models.
"""


from pathlib import Path
import numpy as np
import nibabel as nib
import os
import subprocess
from tqdm import tqdm
from typing import List, Tuple, Optional, Union, Dict


def maybe_make_dir(path: Union[str, Path]) -> Path:
    """
    Creates a directory at the specified path if it does not exist.

    Missing parent directories are created too (``os.makedirs``).

    Args:
        path (Union[str, Path]): Path to the directory to be created.

    Returns:
        Path: The path to the created or existing directory.
    """
    os.makedirs(path, exist_ok=True)
    return Path(path)


def convert_npz_mednext(
    npz_path: Union[str, Path],
    pkl_path: Union[str, Path],
    save_nifti: bool = False,
    nifti_dir: Union[str, Path] = Path('./tmp_prob_nifti'),
    suffix: str = 'mednext'
) -> np.ndarray:
    """
    Converts MedNeXt .npz and .pkl files to a NumPy array and optionally saves
    each channel as a NIfTI file.

    Every channel of the first array stored in the .npz file is pasted into a
    zero-filled volume of shape ``original_size_of_raw_data`` at the position
    given by ``crop_bbox`` (both read from the .pkl properties). When
    ``crop_bbox`` is None, the channel is used as is. Each result is then
    transposed with ``np.swapaxes(..., 0, 2)``.

    Args:
        npz_path (Union[str, Path]): Path to the .npz file.
        pkl_path (Union[str, Path]): Path to the .pkl properties file. It is
            unpickled, so it must come from a trusted source.
        save_nifti (bool, optional): Whether to save each channel as a NIfTI
            file (identity affine). Defaults to False.
        nifti_dir (Union[str, Path], optional): Directory to save NIfTI files.
            Defaults to './tmp_prob_nifti'.
        suffix (str, optional): Suffix for the NIfTI file names. Defaults to 'mednext'.

    Returns:
        np.ndarray: float32 array stacking the processed channels along axis 0.
    """
    npz_path = Path(npz_path)  # accept str as advertised; '.name' is used below
    # The context manager closes the file handle kept open by the NpzFile.
    with np.load(npz_path, allow_pickle=True) as npz:
        prob = npz[npz.files[0]].astype(np.float32)
    # NOTE(review): np.load falls back to pickle.load for .pkl files; trusted input only.
    pkl = np.load(pkl_path, allow_pickle=True)
    bbox = pkl['crop_bbox']
    shape_original_before_cropping = pkl['original_size_of_raw_data']
    out_list = []

    for i, channel in enumerate(prob):
        if bbox is not None:
            # float32 rather than float16 so no probability precision is lost.
            seg_old_size = np.zeros(shape_original_before_cropping, dtype=np.float32)
            # Crop end = start + channel extent, clipped to the original size.
            region = tuple(
                slice(
                    bbox[c][0],
                    min(bbox[c][0] + channel.shape[c], shape_original_before_cropping[c]),
                )
                for c in range(3)
            )
            seg_old_size[region] = channel
        else:
            seg_old_size = channel

        out = np.swapaxes(seg_old_size, 0, 2)
        out_list.append(out)

        if save_nifti:
            nifti_dir_path = maybe_make_dir(nifti_dir)
            nifti_path = nifti_dir_path / f"{npz_path.name.split('.npz')[0]}_{i}_{suffix}.nii.gz"
            nib.save(
                nib.Nifti1Image(out.astype(np.float32), affine=np.eye(4)),
                nifti_path
            )

    return np.stack(out_list, axis=0).astype(np.float32)


def convert_npz_nnunet(
    npz_path: Union[str, Path],
    save_nifti: bool = False,
    nifti_dir: Union[str, Path] = Path('./tmp_prob_nifti'),
    suffix: str = 'nnunet'
) -> np.ndarray:
    """
    Converts nnUNet .npz files to a NumPy array and optionally saves each
    channel as a NIfTI file.

    Each channel of the first array stored in the .npz file is transposed
    with ``np.swapaxes(..., 0, 2)``.

    Args:
        npz_path (Union[str, Path]): Path to the .npz file.
        save_nifti (bool, optional): Whether to save each channel as a NIfTI
            file (identity affine). Defaults to False.
        nifti_dir (Union[str, Path], optional): Directory to save NIfTI files.
            Defaults to './tmp_prob_nifti'.
        suffix (str, optional): Suffix for the NIfTI file names. Defaults to 'nnunet'.

    Returns:
        np.ndarray: float32 array stacking the transposed channels along axis 0.
    """
    npz_path = Path(npz_path)  # accept str as advertised; '.name' is used below
    # The context manager closes the file handle kept open by the NpzFile.
    with np.load(npz_path, allow_pickle=True) as npz:
        prob = npz[npz.files[0]].astype(np.float32)
    out_list = []

    for i, channel in enumerate(prob):
        out = np.swapaxes(channel, 0, 2)
        out_list.append(out)

        if save_nifti:
            nifti_dir_path = maybe_make_dir(nifti_dir)
            nifti_path = nifti_dir_path / f"{npz_path.name.split('.npz')[0]}_{i}_{suffix}.nii.gz"
            nib.save(
                nib.Nifti1Image(out.astype(np.float32), affine=np.eye(4)),
                nifti_path
            )

    return np.stack(out_list, axis=0).astype(np.float32)


def convert_npz_swinunetr(
    npz_path: Union[str, Path],
    save_nifti: bool = False,
    nifti_dir: Union[str, Path] = Path('./tmp_prob_nifti'),
    suffix: str = 'swinunetr'
) -> np.ndarray:
    """
    Converts SwinUNETR .npz files to a NumPy array and optionally saves each
    channel as a NIfTI file.

    Unlike the nnUNet/MedNeXt converters, no axes are swapped.

    Args:
        npz_path (Union[str, Path]): Path to the .npz file.
        save_nifti (bool, optional): Whether to save each channel as a NIfTI
            file (identity affine). Defaults to False.
        nifti_dir (Union[str, Path], optional): Directory to save NIfTI files.
            Defaults to './tmp_prob_nifti'.
        suffix (str, optional): Suffix for the NIfTI file names. Defaults to 'swinunetr'.

    Returns:
        np.ndarray: The first array stored in the .npz file, as float32.
    """
    npz_path = Path(npz_path)  # accept str as advertised; '.name' is used below
    # The context manager closes the file handle kept open by the NpzFile.
    with np.load(npz_path, allow_pickle=True) as npz:
        prob = npz[npz.files[0]].astype(np.float32)

    if save_nifti:
        nifti_dir_path = maybe_make_dir(nifti_dir)
        stem = npz_path.name.split('.npz')[0]
        for i, channel in enumerate(prob):
            nib.save(
                nib.Nifti1Image(channel.astype(np.float32), affine=np.eye(4)),
                nifti_dir_path / f"{stem}_{i}_{suffix}.nii.gz"
            )

    return prob


def ped_ensemble(
    swinunetr_npz_path_list: List[Union[str, Path]],
    nnunet_npz_path_list: List[Union[str, Path]],
    mednext_npz_path_list: List[Union[str, Path]],
    mednext_pkl_path_list: List[Union[str, Path]],
    ensembled_path: Union[str, Path],
    input_img: Union[str, Path],
    weights: Optional[List[float]] = None
) -> Path:
    """
    Performs ensemble of predictions from SwinUNETR, nnUNet, and MedNeXt models.

    The processing steps are:

    1. For each model, the probability arrays of all given files are averaged.
    2. The three model averages are combined as a weighted mean, with
       ``weights`` ordered SwinUNETR, nnUNet, MedNeXt.
    3. The label map is the argmax over axis 0.
    4. The label map is saved as an int8 NIfTI named ``<case>.nii.gz`` using
       the affine of ``input_img``. The case name is the first nnUNet file
       name without '.npz'.

    If the output file already exists, nothing is recomputed.

    Args:
        swinunetr_npz_path_list (List[Union[str, Path]]): List of paths to SwinUNETR .npz files.
        nnunet_npz_path_list (List[Union[str, Path]]): List of paths to nnUNet .npz files.
        mednext_npz_path_list (List[Union[str, Path]]): List of paths to MedNeXt .npz files.
        mednext_pkl_path_list (List[Union[str, Path]]): List of paths to MedNeXt .pkl files,
            paired index-wise with ``mednext_npz_path_list``.
        ensembled_path (Union[str, Path]): Directory to save the ensembled predictions.
        input_img (Union[str, Path]): Path to the original input image.
        weights (List[float], optional): Weights for each model in the ensemble.
            Defaults to [1.0, 1.0, 1.0] when None.

    Returns:
        Path: Path to the ensembled NIfTI segmentation (also when it already existed).

    Raises:
        ValueError: If ``nnunet_npz_path_list`` is empty. When the output does
            not exist yet, also if another model has no files or the MedNeXt
            .npz/.pkl lists differ in length.
    """
    if weights is None:
        weights = [1.0, 1.0, 1.0]

    ensembled_path = maybe_make_dir(ensembled_path)

    if not nnunet_npz_path_list:
        raise ValueError("nnunet_npz_path_list must contain at least one .npz file.")
    case = Path(nnunet_npz_path_list[0]).name.split('.npz')[0]
    print(f"Ensemble {case}")

    ensembled_nifti = ensembled_path / f"{case}.nii.gz"
    if ensembled_nifti.exists():
        print(f"File {ensembled_nifti} already exists. Skipping.")
        return ensembled_nifti

    if not swinunetr_npz_path_list or not mednext_npz_path_list:
        raise ValueError("Every model needs at least one prediction file.")
    # zip() would silently drop unpaired files while still dividing by the npz count.
    if len(mednext_npz_path_list) != len(mednext_pkl_path_list):
        raise ValueError(
            "mednext_npz_path_list and mednext_pkl_path_list must have the same length "
            f"({len(mednext_npz_path_list)} != {len(mednext_pkl_path_list)})."
        )

    # SwinUNETR: mean over the given files
    prob_swinunetr = convert_npz_swinunetr(swinunetr_npz_path_list[0])
    for swin_npz in swinunetr_npz_path_list[1:]:
        prob_swinunetr += convert_npz_swinunetr(swin_npz)
    prob_swinunetr /= len(swinunetr_npz_path_list)
    print(f"Probabilities SwinUNETR: {prob_swinunetr.shape}")

    # nnUNet: mean over the given files
    prob_nnunet = convert_npz_nnunet(nnunet_npz_path_list[0])
    for nnunet_npz in nnunet_npz_path_list[1:]:
        prob_nnunet += convert_npz_nnunet(nnunet_npz)
    prob_nnunet /= len(nnunet_npz_path_list)
    print(f"Probabilities nnUNet: {prob_nnunet.shape}")

    # MedNeXt: mean over the given (npz, pkl) pairs
    prob_mednext = convert_npz_mednext(mednext_npz_path_list[0], mednext_pkl_path_list[0])
    for mednext_npz, mednext_pkl in zip(mednext_npz_path_list[1:], mednext_pkl_path_list[1:]):
        prob_mednext += convert_npz_mednext(mednext_npz, mednext_pkl)
    prob_mednext /= len(mednext_npz_path_list)
    print(f"Probabilities MedNeXt: {prob_mednext.shape}")

    # Weighted Ensemble
    prob = (
        weights[0] * prob_swinunetr +
        weights[1] * prob_nnunet +
        weights[2] * prob_mednext
    )
    prob /= sum(weights)

    # Generate segmentation by taking argmax over the class axis
    seg = np.argmax(prob, axis=0)
    print(f"Segmentation shape: {seg.shape}")

    # Save the ensembled segmentation with the reference image's affine
    img = nib.load(input_img)
    nib.save(nib.Nifti1Image(seg.astype(np.int8), img.affine), ensembled_nifti)
    print(f"Saved ensembled segmentation to {ensembled_nifti}")

    return ensembled_nifti


def batch_ped_ensemble(
    swinunetr_pred_dirs: List[Union[str, Path]],
    nnunet_pred_dirs: List[Union[str, Path]],
    mednext_pred_dirs: List[Union[str, Path]],
    input_img_dir: Union[str, Path],
    ensembled_dir: Union[str, Path],
    weights: Optional[List[float]] = None,
    cv: bool = False
) -> None:
    """
    Performs ensemble of predictions for multiple cases either in cross-validation (cv) or validation mode.

    In cv mode:

    * The directories of the three lists are paired by fold index.
    * The cases of fold ``f`` are the ``*.nii.gz`` files in ``nnunet_pred_dirs[f]``.
    * Each case uses only that fold's predictions (SwinUNETR file ``<case>.npz``).

    Otherwise:

    * The cases are the sub-directories of ``input_img_dir``.
    * Each case uses the predictions from every listed directory (SwinUNETR
      file ``<case>-t1n.npz``).

    In both modes the reference image is ``input_img_dir/<case>/<case>-t1n.nii.gz``.

    Args:
        swinunetr_pred_dirs (List[Union[str, Path]]): List of directories containing SwinUNETR predictions.
        nnunet_pred_dirs (List[Union[str, Path]]): List of directories containing nnUNet predictions.
        mednext_pred_dirs (List[Union[str, Path]]): List of directories containing MedNeXt predictions.
        input_img_dir (Union[str, Path]): Directory containing the original input images.
        ensembled_dir (Union[str, Path]): Directory to save the ensembled predictions.
        weights (List[float], optional): Weights for each model in the ensemble.
            Defaults to [1.0, 1.0, 1.0] when None.
        cv (bool, optional): Flag indicating whether to perform cross-validation ensemble. Defaults to False.
    """
    if weights is None:
        weights = [1.0, 1.0, 1.0]
    input_img_dir = Path(input_img_dir)
    ensembled_dir = maybe_make_dir(ensembled_dir)

    if cv:
        # Cases of fold f: '<case>.nii.gz' files in the nnUNet fold directory (suffix stripped).
        cases_per_fold = [
            sorted(
                case_path.name[:-len(".nii.gz")]
                for case_path in Path(pred_dir).iterdir()
                if case_path.name.endswith(".nii.gz")
            )
            for pred_dir in nnunet_pred_dirs
        ]

        for fold, fold_cases in enumerate(cases_per_fold):
            for case in fold_cases:
                mednext_dir = Path(mednext_pred_dirs[fold])
                # In cv mode SwinUNETR files are '<case>.npz' (no '-t1n' suffix).
                saved_path = ped_ensemble(
                    [Path(swinunetr_pred_dirs[fold]) / f"{case}.npz"],
                    [Path(nnunet_pred_dirs[fold]) / f"{case}.npz"],
                    [mednext_dir / f"{case}.npz"],
                    [mednext_dir / f"{case}.pkl"],
                    ensembled_dir,
                    input_img_dir / case / f"{case}-t1n.nii.gz",
                    weights=weights
                )
                print(f"Saved {saved_path}")

    else:
        swin_dirs = [Path(pred) for pred in swinunetr_pred_dirs]
        nnunet_dirs = [Path(pred) for pred in nnunet_pred_dirs]
        mednext_dirs = [Path(pred) for pred in mednext_pred_dirs]
        cases = sorted(case_path.name for case_path in input_img_dir.iterdir() if case_path.is_dir())

        for case in cases:
            saved_path = ped_ensemble(
                [pred / f"{case}-t1n.npz" for pred in swin_dirs],
                [pred / f"{case}.npz" for pred in nnunet_dirs],
                [pred / f"{case}.npz" for pred in mednext_dirs],
                [pred / f"{case}.pkl" for pred in mednext_dirs],
                ensembled_dir,
                input_img_dir / case / f"{case}-t1n.nii.gz",
                weights=weights
            )
            print(f"Saved {saved_path}")


def main_cv():
    """
    Executes ensemble predictions for cross-validation (cv) mode.

    All directories are hard-coded for the original cluster environment.
    """
    swinunetr_pred_path = Path("/home/v363/v363397/media/output_cv/ped_stratified")
    swinunetr_pred_dirs = [
        swinunetr_pred_path / f'swinunetr_e650_f{i}_b1p4' for i in [0, 1, 2]
    ] + [
        swinunetr_pred_path / f'swinunetr_e1000_f{i}_b1p4' for i in [3, 4]
    ]

    nnunet_pred_path = Path("/home/v363/v363397/media/nnUNet_results/Dataset021_BraTS2024-PED/nnUNetTrainer_200epochs__nnUNetPlans__3d_fullres")
    nnunet_pred_dirs = [nnunet_pred_path / f'fold_{i}' / 'validation' for i in range(5)]

    mednext_pred_path = Path("/home/v363/v363397/media/nnUNet_trained_models/nnUNet/3d_fullres/Task021_BraTS2024-PEDs/nnUNetTrainerV2_MedNeXt_M_kernel3_200epochs_StratifiedSplit__nnUNetPlansv2.1_trgSp_1x1x1")
    mednext_pred_dirs = [mednext_pred_path / f'fold_{i}' / 'validation_raw' for i in range(5)]

    input_img_dir = Path("/home/v363/v363397/stay/brats2024/data/MICCAI-BraTS2024-PED/BraTS-PEDs2024_Training")
    ensembled_dir = Path("/home/v363/v363397/stay/brats2024/Task099a_postprocessed_cv/PEDs/ensembled_preds_cv")

    weights = [0.330911177, 0.330839468, 0.338249355]

    batch_ped_ensemble(
        swinunetr_pred_dirs,
        nnunet_pred_dirs,
        mednext_pred_dirs,
        input_img_dir,
        ensembled_dir,
        weights=weights,
        cv=True
    )


def main_val():
    """
    Executes ensemble predictions for validation mode.

    All directories are hard-coded for the original cluster environment.
    """
    swinunetr_pred_path = Path("/home/v363/v363397/media/output_val/ped_stratified")
    swinunetr_pred_dirs = [
        swinunetr_pred_path / f'swinunetr_e650_f{i}_b1p4' for i in [0, 1, 2]
    ] + [
        swinunetr_pred_path / f'swinunetr_e1000_f{i}_b1p4' for i in [3, 4]
    ]

    nnunet_pred_path = Path("/home/v363/v363397/stay/brats2024/Task012a_BraTS2024-PED_nnUNet_SS/6_predict/predicted/folds_independent")
    nnunet_pred_dirs = [nnunet_pred_path / f'fold_{i}' / "nnUNetTrainer_200epochs/3d_fullres" for i in range(5)]

    mednext_pred_path = Path("/home/v363/v363397/stay/brats2024/Task012b_BraTS2024-PED_MedNeXt_SS/6_predict/predicted/folds_independent")
    mednext_pred_dirs = [mednext_pred_path / f'fold_{i}' / "nnUNetTrainerV2_MedNeXt_M_kernel3_200epochs_StratifiedSplit/3d_fullres" for i in range(5)]

    input_img_dir = Path("/home/v363/v363397/stay/brats2024/data/MICCAI-BraTS2024-PED/BraTS_Validation_Data_backup")
    ensembled_dir = Path("/home/v363/v363397/stay/brats2024/Task099b_postprocessed_val/PEDs/ensembled_preds_val")

    weights = [0.330911177, 0.330839468, 0.338249355]

    batch_ped_ensemble(
        swinunetr_pred_dirs,
        nnunet_pred_dirs,
        mednext_pred_dirs,
        input_img_dir,
        ensembled_dir,
        weights=weights,
    )


if __name__ == '__main__':
    # Uncomment the desired mode to run
    # main_cv()
    main_val()
def maybe_make_dir(path: Union[str, Path]) -> Path:
    """
    Ensure a directory exists at ``path``, creating it and any missing parents.

    An already existing directory is left untouched.

    Args:
        path (Union[str, Path]): Location of the directory.

    Returns:
        Path: ``path`` as a ``Path`` object.
    """
    os.makedirs(name=path, exist_ok=True)
    directory = Path(path)
    return directory
Creates a directory at the specified path if it does not exist.
Args: path (Union[str, Path]): Path to the directory to be created.
Returns: Path: The path to the created or existing directory.
def convert_npz_mednext(
    npz_path: Union[str, Path],
    pkl_path: Union[str, Path],
    save_nifti: bool = False,
    nifti_dir: Union[str, Path] = Path('./tmp_prob_nifti'),
    suffix: str = 'mednext'
) -> np.ndarray:
    """
    Converts MedNeXt .npz and .pkl files to a NumPy array and optionally saves
    each channel as a NIfTI file.

    Every channel of the first array stored in the .npz file is pasted into a
    zero-filled volume of shape ``original_size_of_raw_data`` at the position
    given by ``crop_bbox`` (both read from the .pkl properties). When
    ``crop_bbox`` is None, the channel is used as is. Each result is then
    transposed with ``np.swapaxes(..., 0, 2)``.

    Args:
        npz_path (Union[str, Path]): Path to the .npz file.
        pkl_path (Union[str, Path]): Path to the .pkl properties file. It is
            unpickled, so it must come from a trusted source.
        save_nifti (bool, optional): Whether to save each channel as a NIfTI
            file (identity affine). Defaults to False.
        nifti_dir (Union[str, Path], optional): Directory to save NIfTI files.
            Defaults to './tmp_prob_nifti'.
        suffix (str, optional): Suffix for the NIfTI file names. Defaults to 'mednext'.

    Returns:
        np.ndarray: float32 array stacking the processed channels along axis 0.
    """
    npz_path = Path(npz_path)  # accept str as advertised; '.name' is used below
    # The context manager closes the file handle kept open by the NpzFile.
    with np.load(npz_path, allow_pickle=True) as npz:
        prob = npz[npz.files[0]].astype(np.float32)
    # NOTE(review): np.load falls back to pickle.load for .pkl files; trusted input only.
    pkl = np.load(pkl_path, allow_pickle=True)
    bbox = pkl['crop_bbox']
    shape_original_before_cropping = pkl['original_size_of_raw_data']
    out_list = []

    for i, channel in enumerate(prob):
        if bbox is not None:
            # float32 rather than float16 so no probability precision is lost.
            seg_old_size = np.zeros(shape_original_before_cropping, dtype=np.float32)
            # Crop end = start + channel extent, clipped to the original size
            # (computed locally instead of mutating the loaded bbox).
            region = tuple(
                slice(
                    bbox[c][0],
                    min(bbox[c][0] + channel.shape[c], shape_original_before_cropping[c]),
                )
                for c in range(3)
            )
            seg_old_size[region] = channel
        else:
            seg_old_size = channel

        out = np.swapaxes(seg_old_size, 0, 2)
        out_list.append(out)

        if save_nifti:
            nifti_dir_path = maybe_make_dir(nifti_dir)
            nifti_path = nifti_dir_path / f"{npz_path.name.split('.npz')[0]}_{i}_{suffix}.nii.gz"
            nib.save(
                nib.Nifti1Image(out.astype(np.float32), affine=np.eye(4)),
                nifti_path
            )

    return np.stack(out_list, axis=0).astype(np.float32)
Converts MedNeXt .npz and .pkl files to a NumPy array and optionally saves each channel as a NIfTI file.
Args: npz_path (Union[str, Path]): Path to the .npz file. pkl_path (Union[str, Path]): Path to the .pkl file. save_nifti (bool, optional): Whether to save the probabilities as NIfTI files. Defaults to False. nifti_dir (Union[str, Path], optional): Directory to save NIfTI files. Defaults to './tmp_prob_nifti'. suffix (str, optional): Suffix for the NIfTI file names. Defaults to 'mednext'.
Returns: np.ndarray: Stacked probability array.
def convert_npz_nnunet(
    npz_path: Union[str, Path],
    save_nifti: bool = False,
    nifti_dir: Union[str, Path] = Path('./tmp_prob_nifti'),
    suffix: str = 'nnunet'
) -> np.ndarray:
    """
    Converts nnUNet .npz files to a NumPy array and optionally saves each
    channel as a NIfTI file.

    Each channel of the first array stored in the .npz file is transposed
    with ``np.swapaxes(..., 0, 2)``.

    Args:
        npz_path (Union[str, Path]): Path to the .npz file.
        save_nifti (bool, optional): Whether to save each channel as a NIfTI
            file (identity affine). Defaults to False.
        nifti_dir (Union[str, Path], optional): Directory to save NIfTI files.
            Defaults to './tmp_prob_nifti'.
        suffix (str, optional): Suffix for the NIfTI file names. Defaults to 'nnunet'.

    Returns:
        np.ndarray: float32 array stacking the transposed channels along axis 0.
    """
    npz_path = Path(npz_path)  # accept str as advertised; '.name' is used below
    # The context manager closes the file handle kept open by the NpzFile.
    with np.load(npz_path, allow_pickle=True) as npz:
        prob = npz[npz.files[0]].astype(np.float32)
    out_list = []

    for i, channel in enumerate(prob):
        out = np.swapaxes(channel, 0, 2)
        out_list.append(out)

        if save_nifti:
            nifti_dir_path = maybe_make_dir(nifti_dir)
            nifti_path = nifti_dir_path / f"{npz_path.name.split('.npz')[0]}_{i}_{suffix}.nii.gz"
            nib.save(
                nib.Nifti1Image(out.astype(np.float32), affine=np.eye(4)),
                nifti_path
            )

    return np.stack(out_list, axis=0).astype(np.float32)
Converts nnUNet .npz files to a NumPy array and optionally saves each channel as a NIfTI file.
Args: npz_path (Union[str, Path]): Path to the .npz file. save_nifti (bool, optional): Whether to save the probabilities as NIfTI files. Defaults to False. nifti_dir (Union[str, Path], optional): Directory to save NIfTI files. Defaults to './tmp_prob_nifti'. suffix (str, optional): Suffix for the NIfTI file names. Defaults to 'nnunet'.
Returns: np.ndarray: Stacked probability array.
def convert_npz_swinunetr(
    npz_path: Union[str, Path],
    save_nifti: bool = False,
    nifti_dir: Union[str, Path] = Path('./tmp_prob_nifti'),
    suffix: str = 'swinunetr'
) -> np.ndarray:
    """
    Converts SwinUNETR .npz files to a NumPy array and optionally saves each
    channel as a NIfTI file.

    Unlike the nnUNet/MedNeXt converters, no axes are swapped.

    Args:
        npz_path (Union[str, Path]): Path to the .npz file.
        save_nifti (bool, optional): Whether to save each channel as a NIfTI
            file (identity affine). Defaults to False.
        nifti_dir (Union[str, Path], optional): Directory to save NIfTI files.
            Defaults to './tmp_prob_nifti'.
        suffix (str, optional): Suffix for the NIfTI file names. Defaults to 'swinunetr'.

    Returns:
        np.ndarray: The first array stored in the .npz file, as float32.
    """
    npz_path = Path(npz_path)  # accept str as advertised; '.name' is used below
    # The context manager closes the file handle kept open by the NpzFile.
    with np.load(npz_path, allow_pickle=True) as npz:
        prob = npz[npz.files[0]].astype(np.float32)

    # Only walk the channels when something is actually written.
    if save_nifti:
        nifti_dir_path = maybe_make_dir(nifti_dir)
        stem = npz_path.name.split('.npz')[0]
        for i, channel in enumerate(prob):
            nib.save(
                nib.Nifti1Image(channel.astype(np.float32), affine=np.eye(4)),
                nifti_dir_path / f"{stem}_{i}_{suffix}.nii.gz"
            )

    # Already a fresh float32 copy from astype above; no second copy needed.
    return prob
Converts SwinUNETR .npz files to a NumPy array and optionally saves each channel as a NIfTI file.
Args: npz_path (Union[str, Path]): Path to the .npz file. save_nifti (bool, optional): Whether to save the probabilities as NIfTI files. Defaults to False. nifti_dir (Union[str, Path], optional): Directory to save NIfTI files. Defaults to './tmp_prob_nifti'. suffix (str, optional): Suffix for the NIfTI file names. Defaults to 'swinunetr'.
Returns: np.ndarray: Probability array.
def ped_ensemble(
    swinunetr_npz_path_list: List[Union[str, Path]],
    nnunet_npz_path_list: List[Union[str, Path]],
    mednext_npz_path_list: List[Union[str, Path]],
    mednext_pkl_path_list: List[Union[str, Path]],
    ensembled_path: Union[str, Path],
    input_img: Union[str, Path],
    weights: Optional[List[float]] = None
) -> Path:
    """
    Performs ensemble of predictions from SwinUNETR, nnUNet, and MedNeXt models.

    The processing steps are:

    1. For each model, the probability arrays of all given files are averaged.
    2. The three model averages are combined as a weighted mean, with
       ``weights`` ordered SwinUNETR, nnUNet, MedNeXt.
    3. The label map is the argmax over axis 0.
    4. The label map is saved as an int8 NIfTI named ``<case>.nii.gz`` using
       the affine of ``input_img``. The case name is the first nnUNet file
       name without '.npz'.

    If the output file already exists, nothing is recomputed.

    Args:
        swinunetr_npz_path_list (List[Union[str, Path]]): List of paths to SwinUNETR .npz files.
        nnunet_npz_path_list (List[Union[str, Path]]): List of paths to nnUNet .npz files.
        mednext_npz_path_list (List[Union[str, Path]]): List of paths to MedNeXt .npz files.
        mednext_pkl_path_list (List[Union[str, Path]]): List of paths to MedNeXt .pkl files,
            paired index-wise with ``mednext_npz_path_list``.
        ensembled_path (Union[str, Path]): Directory to save the ensembled predictions.
        input_img (Union[str, Path]): Path to the original input image.
        weights (List[float], optional): Weights for each model in the ensemble.
            Defaults to [1.0, 1.0, 1.0] when None.

    Returns:
        Path: Path to the ensembled NIfTI segmentation (also when it already existed).

    Raises:
        ValueError: If ``nnunet_npz_path_list`` is empty. When the output does
            not exist yet, also if another model has no files or the MedNeXt
            .npz/.pkl lists differ in length.
    """
    if weights is None:
        weights = [1.0, 1.0, 1.0]

    ensembled_path = maybe_make_dir(ensembled_path)

    if not nnunet_npz_path_list:
        raise ValueError("nnunet_npz_path_list must contain at least one .npz file.")
    case = Path(nnunet_npz_path_list[0]).name.split('.npz')[0]
    print(f"Ensemble {case}")

    ensembled_nifti = ensembled_path / f"{case}.nii.gz"
    if ensembled_nifti.exists():
        print(f"File {ensembled_nifti} already exists. Skipping.")
        return ensembled_nifti

    if not swinunetr_npz_path_list or not mednext_npz_path_list:
        raise ValueError("Every model needs at least one prediction file.")
    # zip() would silently drop unpaired files while still dividing by the npz count.
    if len(mednext_npz_path_list) != len(mednext_pkl_path_list):
        raise ValueError(
            "mednext_npz_path_list and mednext_pkl_path_list must have the same length "
            f"({len(mednext_npz_path_list)} != {len(mednext_pkl_path_list)})."
        )

    # SwinUNETR: mean over the given files
    prob_swinunetr = convert_npz_swinunetr(swinunetr_npz_path_list[0])
    for swin_npz in swinunetr_npz_path_list[1:]:
        prob_swinunetr += convert_npz_swinunetr(swin_npz)
    prob_swinunetr /= len(swinunetr_npz_path_list)
    print(f"Probabilities SwinUNETR: {prob_swinunetr.shape}")

    # nnUNet: mean over the given files
    prob_nnunet = convert_npz_nnunet(nnunet_npz_path_list[0])
    for nnunet_npz in nnunet_npz_path_list[1:]:
        prob_nnunet += convert_npz_nnunet(nnunet_npz)
    prob_nnunet /= len(nnunet_npz_path_list)
    print(f"Probabilities nnUNet: {prob_nnunet.shape}")

    # MedNeXt: mean over the given (npz, pkl) pairs
    prob_mednext = convert_npz_mednext(mednext_npz_path_list[0], mednext_pkl_path_list[0])
    for mednext_npz, mednext_pkl in zip(mednext_npz_path_list[1:], mednext_pkl_path_list[1:]):
        prob_mednext += convert_npz_mednext(mednext_npz, mednext_pkl)
    prob_mednext /= len(mednext_npz_path_list)
    print(f"Probabilities MedNeXt: {prob_mednext.shape}")

    # Weighted Ensemble
    prob = (
        weights[0] * prob_swinunetr +
        weights[1] * prob_nnunet +
        weights[2] * prob_mednext
    )
    prob /= sum(weights)

    # Generate segmentation by taking argmax over the class axis
    seg = np.argmax(prob, axis=0)
    print(f"Segmentation shape: {seg.shape}")

    # Save the ensembled segmentation with the reference image's affine
    img = nib.load(input_img)
    nib.save(nib.Nifti1Image(seg.astype(np.int8), img.affine), ensembled_nifti)
    print(f"Saved ensembled segmentation to {ensembled_nifti}")

    return ensembled_nifti
Performs ensemble of predictions from SwinUNETR, nnUNet, and MedNeXt models.
Args: swinunetr_npz_path_list (List[Union[str, Path]]): List of paths to SwinUNETR .npz files. nnunet_npz_path_list (List[Union[str, Path]]): List of paths to nnUNet .npz files. mednext_npz_path_list (List[Union[str, Path]]): List of paths to MedNeXt .npz files. mednext_pkl_path_list (List[Union[str, Path]]): List of paths to MedNeXt .pkl files. ensembled_path (Union[str, Path]): Directory to save the ensembled predictions. input_img (Union[str, Path]): Path to the original input image. weights (List[float], optional): Weights for each model in the ensemble. Defaults to [1.0, 1.0, 1.0].
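Returns: Union[Path, Tuple[Path, Path]]: Path to the ensembled NIfTI segmentation file. In practice a single Path is always returned, including when the file already exists and the computation is skipped.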
def batch_ped_ensemble(
    swinunetr_pred_dirs: List[Union[str, Path]],
    nnunet_pred_dirs: List[Union[str, Path]],
    mednext_pred_dirs: List[Union[str, Path]],
    input_img_dir: Union[str, Path],
    ensembled_dir: Union[str, Path],
    weights: Optional[List[float]] = None,
    cv: bool = False
) -> None:
    """
    Performs ensemble of predictions for multiple cases either in cross-validation (cv) or validation mode.

    In cv mode:

    * The directories of the three lists are paired by fold index.
    * The cases of fold ``f`` are the ``*.nii.gz`` files in ``nnunet_pred_dirs[f]``.
    * Each case uses only that fold's predictions (SwinUNETR file ``<case>.npz``).

    Otherwise:

    * The cases are the sub-directories of ``input_img_dir``.
    * Each case uses the predictions from every listed directory (SwinUNETR
      file ``<case>-t1n.npz``).

    In both modes the reference image is ``input_img_dir/<case>/<case>-t1n.nii.gz``.
    Cases are processed in sorted order.

    Args:
        swinunetr_pred_dirs (List[Union[str, Path]]): List of directories containing SwinUNETR predictions.
        nnunet_pred_dirs (List[Union[str, Path]]): List of directories containing nnUNet predictions.
        mednext_pred_dirs (List[Union[str, Path]]): List of directories containing MedNeXt predictions.
        input_img_dir (Union[str, Path]): Directory containing the original input images.
        ensembled_dir (Union[str, Path]): Directory to save the ensembled predictions.
        weights (List[float], optional): Weights for each model in the ensemble.
            Defaults to [1.0, 1.0, 1.0] when None.
        cv (bool, optional): Flag indicating whether to perform cross-validation ensemble. Defaults to False.
    """
    if weights is None:
        weights = [1.0, 1.0, 1.0]
    # Accept str as advertised: '/' and iterdir() below need Path objects.
    input_img_dir = Path(input_img_dir)
    ensembled_dir = maybe_make_dir(ensembled_dir)

    if cv:
        # Cases of fold f: '<case>.nii.gz' files in the nnUNet fold directory (suffix stripped).
        cases_per_fold = [
            sorted(
                case_path.name[:-len(".nii.gz")]
                for case_path in Path(pred_dir).iterdir()
                if case_path.name.endswith(".nii.gz")
            )
            for pred_dir in nnunet_pred_dirs
        ]

        for fold, fold_cases in enumerate(cases_per_fold):
            for case in fold_cases:
                mednext_dir = Path(mednext_pred_dirs[fold])
                # In cv mode SwinUNETR files are '<case>.npz' (no '-t1n' suffix).
                saved_path = ped_ensemble(
                    [Path(swinunetr_pred_dirs[fold]) / f"{case}.npz"],
                    [Path(nnunet_pred_dirs[fold]) / f"{case}.npz"],
                    [mednext_dir / f"{case}.npz"],
                    [mednext_dir / f"{case}.pkl"],
                    ensembled_dir,
                    input_img_dir / case / f"{case}-t1n.nii.gz",
                    weights=weights
                )
                print(f"Saved {saved_path}")

    else:
        swin_dirs = [Path(pred) for pred in swinunetr_pred_dirs]
        nnunet_dirs = [Path(pred) for pred in nnunet_pred_dirs]
        mednext_dirs = [Path(pred) for pred in mednext_pred_dirs]
        cases = sorted(case_path.name for case_path in input_img_dir.iterdir() if case_path.is_dir())

        for case in cases:
            saved_path = ped_ensemble(
                [pred / f"{case}-t1n.npz" for pred in swin_dirs],
                [pred / f"{case}.npz" for pred in nnunet_dirs],
                [pred / f"{case}.npz" for pred in mednext_dirs],
                [pred / f"{case}.pkl" for pred in mednext_dirs],
                ensembled_dir,
                input_img_dir / case / f"{case}-t1n.nii.gz",
                weights=weights
            )
            print(f"Saved {saved_path}")
Performs ensemble of predictions for multiple cases either in cross-validation (cv) or validation mode.
Args: swinunetr_pred_dirs (List[Union[str, Path]]): List of directories containing SwinUNETR predictions. nnunet_pred_dirs (List[Union[str, Path]]): List of directories containing nnUNet predictions. mednext_pred_dirs (List[Union[str, Path]]): List of directories containing MedNeXt predictions. input_img_dir (Union[str, Path]): Directory containing the original input images. ensembled_dir (Union[str, Path]): Directory to save the ensembled predictions. weights (List[float], optional): Weights for each model in the ensemble. Defaults to [1.0, 1.0, 1.0]. cv (bool, optional): Flag indicating whether to perform cross-validation ensemble. Defaults to False.
def main_cv():
    """
    Run the weighted ensemble over the five cross-validation folds.

    Each fold's validation cases are ensembled from that fold's SwinUNETR,
    nnUNet and MedNeXt predictions. All directories are hard-coded for the
    original cluster environment; edit them before running elsewhere.
    """
    swin_root = Path("/home/v363/v363397/media/output_cv/ped_stratified")
    # Folds 0-2 use the 650-epoch SwinUNETR runs, folds 3-4 the 1000-epoch runs.
    swinunetr_pred_dirs = [swin_root / f'swinunetr_e650_f{i}_b1p4' for i in (0, 1, 2)]
    swinunetr_pred_dirs += [swin_root / f'swinunetr_e1000_f{i}_b1p4' for i in (3, 4)]

    folds = range(5)

    nnunet_root = Path("/home/v363/v363397/media/nnUNet_results/Dataset021_BraTS2024-PED/nnUNetTrainer_200epochs__nnUNetPlans__3d_fullres")
    nnunet_pred_dirs = [nnunet_root / f'fold_{k}' / 'validation' for k in folds]

    mednext_root = Path("/home/v363/v363397/media/nnUNet_trained_models/nnUNet/3d_fullres/Task021_BraTS2024-PEDs/nnUNetTrainerV2_MedNeXt_M_kernel3_200epochs_StratifiedSplit__nnUNetPlansv2.1_trgSp_1x1x1")
    mednext_pred_dirs = [mednext_root / f'fold_{k}' / 'validation_raw' for k in folds]

    batch_ped_ensemble(
        swinunetr_pred_dirs,
        nnunet_pred_dirs,
        mednext_pred_dirs,
        Path("/home/v363/v363397/stay/brats2024/data/MICCAI-BraTS2024-PED/BraTS-PEDs2024_Training"),
        Path("/home/v363/v363397/stay/brats2024/Task099a_postprocessed_cv/PEDs/ensembled_preds_cv"),
        weights=[0.330911177, 0.330839468, 0.338249355],
        cv=True,
    )
Executes ensemble predictions for cross-validation (cv) mode.
def main_val():
    """
    Run the weighted ensemble on the validation set (non-cv mode).

    Every case folder in the validation image directory is ensembled from the
    predictions of all five folds of each model. All directories are
    hard-coded for the original cluster environment; edit them before running
    elsewhere.
    """
    swin_root = Path("/home/v363/v363397/media/output_val/ped_stratified")
    # Folds 0-2 use the 650-epoch SwinUNETR runs, folds 3-4 the 1000-epoch runs.
    swinunetr_pred_dirs = [swin_root / f'swinunetr_e650_f{i}_b1p4' for i in (0, 1, 2)]
    swinunetr_pred_dirs += [swin_root / f'swinunetr_e1000_f{i}_b1p4' for i in (3, 4)]

    folds = range(5)

    nnunet_root = Path("/home/v363/v363397/stay/brats2024/Task012a_BraTS2024-PED_nnUNet_SS/6_predict/predicted/folds_independent")
    nnunet_pred_dirs = [nnunet_root / f'fold_{k}' / "nnUNetTrainer_200epochs/3d_fullres" for k in folds]

    mednext_root = Path("/home/v363/v363397/stay/brats2024/Task012b_BraTS2024-PED_MedNeXt_SS/6_predict/predicted/folds_independent")
    mednext_pred_dirs = [mednext_root / f'fold_{k}' / "nnUNetTrainerV2_MedNeXt_M_kernel3_200epochs_StratifiedSplit/3d_fullres" for k in folds]

    batch_ped_ensemble(
        swinunetr_pred_dirs,
        nnunet_pred_dirs,
        mednext_pred_dirs,
        Path("/home/v363/v363397/stay/brats2024/data/MICCAI-BraTS2024-PED/BraTS_Validation_Data_backup"),
        Path("/home/v363/v363397/stay/brats2024/Task099b_postprocessed_val/PEDs/ensembled_preds_val"),
        weights=[0.330911177, 0.330839468, 0.338249355],
    )
Executes ensemble predictions for validation mode.