inference.postproc.postprocess_cc
Post-processing Connected Components
Provides functionalities related to computing connected components and postprocessing segmentation predictions by removing small connected components based on thresholds.
1""" 2### Post-processing Connected Components 3 4Provides functionalities related to computing connected components and postprocessing segmentation predictions by removing small connected components based on thresholds. 5""" 6 7 8import glob 9import os 10import nibabel as nib 11import numpy as np 12import cc3d 13import json 14import argparse 15import pandas as pd 16from pathlib import Path 17from typing import Any, List, Tuple, Dict 18 19 20LABEL_MAPPING_FACTORY: Dict[str, Dict[int, str]] = { 21 "BraTS-PED": { 22 1: "ET", 23 2: "NET", 24 3: "CC", 25 4: "ED" 26 }, 27 "BraTS-SSA": { 28 1: "NCR", 29 2: "ED", 30 3: "ET" 31 }, 32 "BraTS-MEN-RT": { 33 1: "GTV" 34 }, 35 "BraTS-MET": { 36 1: "NETC", 37 2: "SNFH", 38 3: "ET" 39 } 40} 41 42 43def parse_args() -> argparse.Namespace: 44 """ 45 Parses command-line arguments for the BraTS2024 CC Postprocessing script. 46 47 Returns: 48 argparse.Namespace: Parsed command-line arguments. 49 """ 50 parser = argparse.ArgumentParser(description='BraTS2024 CC Postprocessing.') 51 parser.add_argument('--challenge_name', type=str, 52 help='The name of the challenge (e.g., BraTS-PED, BraTS-MET)') 53 parser.add_argument('--input_folder_pred', type=str, 54 help='The input folder containing the predictions') 55 parser.add_argument('--output_folder_pp_cc', type=str, 56 help='The output folder to save the postprocessed predictions') 57 parser.add_argument('--thresholds_file', type=str, 58 help='The JSON file containing the thresholds') 59 parser.add_argument('--clusters_file', type=str, 60 help='The JSON file containing the clusters') 61 62 return parser.parse_args() 63 64 65def get_connected_components(img: np.ndarray, value: int) -> Tuple[np.ndarray, int]: 66 """ 67 Identifies connected components for a specific label in the segmentation. 68 69 Args: 70 img (np.ndarray): The segmentation array. 71 value (int): The label value to find connected components for. 
72 73 Returns: 74 Tuple[np.ndarray, int]: 75 - Array with labeled connected components. 76 - Number of connected components found. 77 """ 78 labels, n = cc3d.connected_components((img == value).astype(np.int8), connectivity=26, return_N=True) 79 return labels, n 80 81 82def postprocess_cc( 83 img: np.ndarray, 84 thresholds_dict: Dict[str, Dict[str, Any]], 85 label_mapping: Dict[int, str] 86) -> np.ndarray: 87 """ 88 Postprocesses the segmentation image by removing small connected components based on thresholds. 89 90 Args: 91 img (np.ndarray): The original segmentation array. 92 thresholds_dict (Dict[str, Dict[str, Any]]): 93 Dictionary containing threshold values for each label. 94 label_mapping (Dict[int, str]): Mapping from label numbers to label names. 95 96 Returns: 97 np.ndarray: The postprocessed segmentation array. 98 """ 99 # Make a copy of pred_orig 100 pred = img.copy() 101 # Iterate over each of the labels 102 for label_name, th_dict in thresholds_dict.items(): 103 label_number = int(label_name.split("_")[-1]) 104 label_name = label_mapping[label_number] 105 th = th_dict["th"] 106 print(f"Label {label_name} - value {label_number} - Applying threshold {th}") 107 # Get connected components 108 components, n = get_connected_components(pred, label_number) 109 # Apply threshold to each of the components 110 for el in range(1, n+1): 111 # get the connected component (cc) 112 cc = np.where(components == el, 1, 0) 113 # calculate the volume 114 vol = np.count_nonzero(cc) 115 # if the volume is less than the threshold, dismiss the cc 116 if vol < th: 117 pred = np.where(components == el, 0, pred) 118 # Return the postprocessed image 119 return pred 120 121 122def postprocess_batch( 123 input_files: List[str], 124 output_folder: str, 125 thresholds_dict: Dict[str, Dict[str, Any]], 126 label_mapping: Dict[int, str] 127) -> None: 128 """ 129 Applies postprocessing to a batch of segmentation files. 130 131 Args: 132 input_files (List[str]): List of input file paths. 
133 output_folder (str): Path to the output directory to save postprocessed files. 134 thresholds_dict (Dict[str, Dict[str, Any]]): 135 Dictionary containing threshold values for each label. 136 label_mapping (Dict[int, str]): Mapping from label numbers to label names. 137 """ 138 for f in input_files: 139 print(f"Processing file {f}") 140 save_path = os.path.join(output_folder, os.path.basename(f)) 141 142 if os.path.exists(save_path): 143 print(f"File {save_path} already exists. Skipping.") 144 continue 145 146 # Read the segmentation file 147 pred_orig = nib.load(f).get_fdata() 148 149 # Apply postprocessing 150 pred = postprocess_cc(pred_orig, thresholds_dict, label_mapping) 151 152 # Ensure the output directory exists 153 os.makedirs(os.path.dirname(save_path), exist_ok=True) 154 155 # Save the postprocessed segmentation 156 nib.save(nib.Nifti1Image(pred, nib.load(f).affine), save_path) 157 158 159def get_thresholds_task( 160 challenge_name: str, 161 input_file: str 162) -> Dict[str, Dict[str, Any]]: 163 """ 164 Retrieves threshold settings for a specific challenge from a JSON file. 165 166 Args: 167 challenge_name (str): The name of the challenge (e.g., BraTS-PED, BraTS-MET). 168 input_file (str): Path to the JSON file containing thresholds. 169 170 Returns: 171 Dict[str, Dict[str, Any]]: Thresholds for the specified challenge. 172 173 Raises: 174 ValueError: If the challenge name is not found in the JSON file. 175 """ 176 with open(input_file, 'r') as f: 177 thresholds = json.load(f) 178 179 if challenge_name not in thresholds: 180 raise ValueError(f"Challenge {challenge_name} not found in the thresholds JSON file.") 181 182 return thresholds[challenge_name] 183 184 185def get_thresholds_cluster( 186 thresholds: Dict[str, Dict[str, Any]], 187 cluster_name: str 188) -> Dict[str, Any]: 189 """ 190 Retrieves threshold settings for a specific cluster within a challenge. 191 192 Args: 193 thresholds (Dict[str, Dict[str, Any]]): Thresholds for the challenge. 
194 cluster_name (str): The name of the cluster (e.g., "cluster_1"). 195 196 Returns: 197 Dict[str, Any]: Threshold settings for the specified cluster. 198 199 Raises: 200 ValueError: If the cluster name is not found in the thresholds. 201 """ 202 if cluster_name not in thresholds: 203 raise ValueError(f"Cluster {cluster_name} not found in the thresholds JSON file.") 204 return thresholds[cluster_name] 205 206 207def get_files(input_folder_pred: str) -> List[str]: 208 """ 209 Retrieves all prediction file paths from the input directory. 210 211 Args: 212 input_folder_pred (str): Path to the input directory containing prediction files. 213 214 Returns: 215 List[str]: Sorted list of prediction file paths. 216 """ 217 files_pred = sorted(glob.glob(os.path.join(input_folder_pred, "*.nii.gz"))) 218 print(f"Found {len(files_pred)} files to be processed.") 219 return files_pred 220 221 222def get_cluster_files( 223 cluster_assignment: List[Dict[str, Any]], 224 cluster_id: int, 225 files_pred: List[str] 226) -> List[str]: 227 """ 228 Retrieves prediction files that belong to a specific cluster. 229 230 Args: 231 cluster_assignment (List[Dict[str, Any]]): 232 List of cluster assignments with StudyID and cluster. 233 cluster_id (int): The cluster identifier. 234 files_pred (List[str]): List of all prediction file paths. 235 236 Returns: 237 List[str]: List of prediction files belonging to the specified cluster. 
238 """ 239 # Extract StudyIDs for the given cluster 240 cluster_ids = [e["StudyID"] for e in cluster_assignment if e["cluster"] == cluster_id] 241 242 # Filter prediction files that match the StudyIDs 243 cluster_files_pred = [ 244 f for f in files_pred 245 if os.path.basename(f).replace(".nii.gz", "") in cluster_ids 246 ] 247 print(f"Cluster {cluster_id} contains {len(cluster_files_pred)} files.") 248 return cluster_files_pred 249 250 251def read_cluster_assignment(clusters_json: str) -> Tuple[List[Dict[str, Any]], List[int]]: 252 """ 253 Reads cluster assignments from a JSON file. 254 255 Args: 256 clusters_json (str): Path to the JSON file containing cluster assignments. 257 258 Returns: 259 Tuple[List[Dict[str, Any]], List[int]]: 260 - List of cluster assignments with StudyID and cluster. 261 - Sorted list of unique cluster identifiers. 262 """ 263 with open(clusters_json, "r") as f: 264 cluster_assignment = json.load(f) 265 266 # Filter relevant keys 267 cluster_assignment = [ 268 {key: value for key, value in e.items() if key in ["StudyID", "cluster"]} 269 for e in cluster_assignment 270 ] 271 272 cluster_array = np.unique([e["cluster"] for e in cluster_assignment]) 273 print(f"Found {len(cluster_array)} clusters: {sorted(cluster_array)}.") 274 return cluster_assignment.tolist(), sorted(cluster_array) 275 276 277def read_cluster_assignment_df(clusterdf: pd.DataFrame) -> Tuple[List[Dict[str, Any]], List[int]]: 278 """ 279 Reads cluster assignments from a pandas DataFrame. 280 281 Args: 282 clusterdf (pd.DataFrame): DataFrame containing cluster assignments. 283 284 Returns: 285 Tuple[List[Dict[str, Any]], List[int]]: 286 - List of cluster assignments with StudyID and cluster. 287 - Sorted list of unique cluster identifiers. 
288 """ 289 # Convert DataFrame to list of dictionaries 290 cluster_assignment = clusterdf.to_dict(orient="records") 291 292 # Filter relevant keys 293 cluster_assignment = [ 294 {key: value for key, value in e.items() if key in ["StudyID", "cluster"]} 295 for e in cluster_assignment 296 ] 297 298 cluster_array = np.unique([e["cluster"] for e in cluster_assignment]) 299 print(f"Found {len(cluster_array)} clusters: {sorted(cluster_array)}.") 300 return cluster_assignment.tolist(), sorted(cluster_array) 301 302 303def remove_small_component( 304 challenge_name: str, 305 thresholds_file: str, 306 input_folder_pred: str, 307 clustersdf: pd.DataFrame, 308 output_folder_pp_cc: str 309) -> str: 310 """ 311 Removes small connected components from segmentation predictions based on cluster-specific thresholds. 312 313 Args: 314 challenge_name (str): The name of the challenge (e.g., BraTS-PED, BraTS-MET). 315 thresholds_file (str): Path to the JSON file containing threshold settings. 316 input_folder_pred (str): Path to the input directory containing prediction files. 317 clustersdf (pd.DataFrame): DataFrame containing cluster assignments. 318 output_folder_pp_cc (str): Path to the output directory to save postprocessed predictions. 319 320 Returns: 321 str: Path to the output directory containing postprocessed predictions. 
322 """ 323 # Retrieve label mapping for the challenge 324 label_mapping = LABEL_MAPPING_FACTORY.get(challenge_name) 325 if label_mapping is None: 326 raise ValueError(f"Unsupported challenge name: {challenge_name}") 327 328 # Load threshold settings for the challenge 329 thresholds = get_thresholds_task(challenge_name, thresholds_file) 330 331 # Retrieve all prediction files 332 files_pred = get_files(input_folder_pred) 333 334 # Read cluster assignments from DataFrame 335 cluster_assignment, cluster_array = read_cluster_assignment_df(clustersdf) 336 337 # Iterate over each cluster and apply postprocessing 338 for cluster in cluster_array: 339 cluster_files_pred = get_cluster_files(cluster_assignment, cluster, files_pred) 340 cluster_key = f"cluster_{cluster}" 341 thresholds_cluster = get_thresholds_cluster(thresholds, cluster_key) 342 343 # Apply postprocessing to the cluster-specific prediction files 344 postprocess_batch( 345 input_files=cluster_files_pred, 346 output_folder=output_folder_pp_cc, 347 thresholds_dict=thresholds_cluster, 348 label_mapping=label_mapping 349 ) 350 351 return output_folder_pp_cc 352 353 354def main() -> None: 355 """ 356 Main function to execute label redefinition and connected component postprocessing. 357 358 Parses command-line arguments, loads necessary data, and performs postprocessing 359 on prediction files based on cluster assignments and threshold settings. 
360 """ 361 args = parse_args() 362 363 # Assign label mapping based on the challenge name 364 args.label_mapping = LABEL_MAPPING_FACTORY.get(args.challenge_name) 365 if args.label_mapping is None: 366 raise ValueError(f"Unsupported challenge name: {args.challenge_name}") 367 368 # Load threshold settings for the specified challenge 369 thresholds = get_thresholds_task(args.challenge_name, args.thresholds_file) 370 371 # Retrieve all prediction files from the input directory 372 files_pred = get_files(args.input_folder_pred) 373 374 # Read cluster assignments from the clusters JSON file 375 cluster_assignment, cluster_array = read_cluster_assignment(args.clusters_file) 376 377 # Iterate over each cluster and apply label redefinition postprocessing 378 for cluster in cluster_array: 379 cluster_files_pred = get_cluster_files(cluster_assignment, cluster, files_pred) 380 cluster_key = f"cluster_{cluster}" 381 thresholds_cluster = get_thresholds_cluster(thresholds, cluster_key) 382 383 # Apply postprocessing to the cluster-specific prediction files 384 postprocess_batch( 385 input_files=cluster_files_pred, 386 output_folder=args.output_folder_pp_cc, 387 thresholds_dict=thresholds_cluster, 388 label_mapping=args.label_mapping 389 ) 390 391 392if __name__ == "__main__": 393 main() 394 395 # Example Command to Run the Script: 396 # python remove_small_component.py --challenge_name BraTS-PED \ 397 # --input_folder_pred /path/to/predictions \ 398 # --output_folder_pp_cc /path/to/output \ 399 # --thresholds_file /path/to/thresholds.json \ 400 # --clusters_file /path/to/clusters.json
def parse_args() -> argparse.Namespace:
    """
    Parse the command-line arguments for the BraTS2024 CC postprocessing script.

    Returns:
        argparse.Namespace: The parsed command-line arguments.
    """
    parser = argparse.ArgumentParser(description='BraTS2024 CC Postprocessing.')
    # (flag, help text) pairs; every option is a plain string argument.
    option_specs = [
        ('--challenge_name',
         'The name of the challenge (e.g., BraTS-PED, BraTS-MET)'),
        ('--input_folder_pred',
         'The input folder containing the predictions'),
        ('--output_folder_pp_cc',
         'The output folder to save the postprocessed predictions'),
        ('--thresholds_file',
         'The JSON file containing the thresholds'),
        ('--clusters_file',
         'The JSON file containing the clusters'),
    ]
    for flag, help_text in option_specs:
        parser.add_argument(flag, type=str, help=help_text)
    return parser.parse_args()
Parses command-line arguments for the BraTS2024 CC Postprocessing script.
Returns: argparse.Namespace: Parsed command-line arguments.
def get_connected_components(img: np.ndarray, value: int) -> Tuple[np.ndarray, int]:
    """
    Identifies connected components for a specific label in the segmentation.

    Args:
        img (np.ndarray): The segmentation array.
        value (int): The label value to find connected components for.

    Returns:
        Tuple[np.ndarray, int]:
            - Array with labeled connected components.
            - Number of connected components found.
    """
    # Binary mask of the requested label; 26-connectivity groups voxels that
    # touch at faces, edges, or corners.
    binary_mask = (img == value).astype(np.int8)
    labeled, num_components = cc3d.connected_components(
        binary_mask, connectivity=26, return_N=True
    )
    return labeled, num_components
Identifies connected components for a specific label in the segmentation.
Args: img (np.ndarray): The segmentation array. value (int): The label value to find connected components for.
Returns: Tuple[np.ndarray, int]: - Array with labeled connected components. - Number of connected components found.
def postprocess_cc(
    img: np.ndarray,
    thresholds_dict: Dict[str, Dict[str, Any]],
    label_mapping: Dict[int, str]
) -> np.ndarray:
    """
    Postprocesses the segmentation image by removing small connected components based on thresholds.

    Args:
        img (np.ndarray): The original segmentation array.
        thresholds_dict (Dict[str, Dict[str, Any]]):
            Dictionary containing threshold values for each label; keys are
            expected to end with the label number (e.g., "label_1").
        label_mapping (Dict[int, str]): Mapping from label numbers to label names.

    Returns:
        np.ndarray: The postprocessed segmentation array.
    """
    # Work on a copy so the caller's array is untouched.
    pred = img.copy()
    # Iterate over each of the labels.
    for label_key, th_dict in thresholds_dict.items():
        # The label number is the last "_"-separated token of the key.
        # (Renamed loop variable: the original shadowed `label_name`.)
        label_number = int(label_key.split("_")[-1])
        label_name = label_mapping[label_number]
        th = th_dict["th"]
        print(f"Label {label_name} - value {label_number} - Applying threshold {th}")
        # Get connected components for this label.
        components, n = get_connected_components(pred, label_number)
        # Apply the threshold to each component.
        for el in range(1, n + 1):
            # Voxel count of this component; counts the boolean mask directly
            # instead of materializing an intermediate 0/1 array.
            vol = np.count_nonzero(components == el)
            # If the volume is below the threshold, dismiss the component.
            if vol < th:
                pred = np.where(components == el, 0, pred)
    return pred
Postprocesses the segmentation image by removing small connected components based on thresholds.
Args: img (np.ndarray): The original segmentation array. thresholds_dict (Dict[str, Dict[str, Any]]): Dictionary containing threshold values for each label. label_mapping (Dict[int, str]): Mapping from label numbers to label names.
Returns: np.ndarray: The postprocessed segmentation array.
def postprocess_batch(
    input_files: List[str],
    output_folder: str,
    thresholds_dict: Dict[str, Dict[str, Any]],
    label_mapping: Dict[int, str]
) -> None:
    """
    Applies postprocessing to a batch of segmentation files.

    Args:
        input_files (List[str]): List of input file paths.
        output_folder (str): Path to the output directory to save postprocessed files.
        thresholds_dict (Dict[str, Dict[str, Any]]):
            Dictionary containing threshold values for each label.
        label_mapping (Dict[int, str]): Mapping from label numbers to label names.
    """
    for f in input_files:
        print(f"Processing file {f}")
        save_path = os.path.join(output_folder, os.path.basename(f))

        # Skip files already produced by a previous (possibly interrupted) run.
        if os.path.exists(save_path):
            print(f"File {save_path} already exists. Skipping.")
            continue

        # Load the segmentation once (data + affine); the original called
        # nib.load(f) twice per file, doubling the I/O.
        nii = nib.load(f)
        pred_orig = nii.get_fdata()

        # Apply postprocessing
        pred = postprocess_cc(pred_orig, thresholds_dict, label_mapping)

        # Ensure the output directory exists
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        # Save with the source image's affine so spatial orientation is preserved.
        nib.save(nib.Nifti1Image(pred, nii.affine), save_path)
Applies postprocessing to a batch of segmentation files.
Args: input_files (List[str]): List of input file paths. output_folder (str): Path to the output directory to save postprocessed files. thresholds_dict (Dict[str, Dict[str, Any]]): Dictionary containing threshold values for each label. label_mapping (Dict[int, str]): Mapping from label numbers to label names.
def get_thresholds_task(
    challenge_name: str,
    input_file: str
) -> Dict[str, Dict[str, Any]]:
    """
    Retrieves threshold settings for a specific challenge from a JSON file.

    Args:
        challenge_name (str): The name of the challenge (e.g., BraTS-PED, BraTS-MET).
        input_file (str): Path to the JSON file containing thresholds.

    Returns:
        Dict[str, Dict[str, Any]]: Thresholds for the specified challenge.

    Raises:
        ValueError: If the challenge name is not found in the JSON file.
    """
    with open(input_file, 'r') as fh:
        all_thresholds = json.load(fh)

    # EAFP lookup: a missing challenge key is reported as a ValueError.
    try:
        return all_thresholds[challenge_name]
    except KeyError:
        raise ValueError(f"Challenge {challenge_name} not found in the thresholds JSON file.") from None
Retrieves threshold settings for a specific challenge from a JSON file.
Args: challenge_name (str): The name of the challenge (e.g., BraTS-PED, BraTS-MET). input_file (str): Path to the JSON file containing thresholds.
Returns: Dict[str, Dict[str, Any]]: Thresholds for the specified challenge.
Raises: ValueError: If the challenge name is not found in the JSON file.
def get_thresholds_cluster(
    thresholds: Dict[str, Dict[str, Any]],
    cluster_name: str
) -> Dict[str, Any]:
    """
    Retrieves threshold settings for a specific cluster within a challenge.

    Args:
        thresholds (Dict[str, Dict[str, Any]]): Thresholds for the challenge.
        cluster_name (str): The name of the cluster (e.g., "cluster_1").

    Returns:
        Dict[str, Any]: Threshold settings for the specified cluster.

    Raises:
        ValueError: If the cluster name is not found in the thresholds.
    """
    # EAFP lookup: surface a missing cluster as a ValueError.
    try:
        return thresholds[cluster_name]
    except KeyError:
        raise ValueError(f"Cluster {cluster_name} not found in the thresholds JSON file.") from None
Retrieves threshold settings for a specific cluster within a challenge.
Args: thresholds (Dict[str, Dict[str, Any]]): Thresholds for the challenge. cluster_name (str): The name of the cluster (e.g., "cluster_1").
Returns: Dict[str, Any]: Threshold settings for the specified cluster.
Raises: ValueError: If the cluster name is not found in the thresholds.
def get_files(input_folder_pred: str) -> List[str]:
    """
    Retrieves all prediction file paths from the input directory.

    Args:
        input_folder_pred (str): Path to the input directory containing prediction files.

    Returns:
        List[str]: Sorted list of prediction file paths.
    """
    # Only compressed NIfTI files are considered predictions.
    pattern = os.path.join(input_folder_pred, "*.nii.gz")
    matches = sorted(glob.glob(pattern))
    print(f"Found {len(matches)} files to be processed.")
    return matches
Retrieves all prediction file paths from the input directory.
Args: input_folder_pred (str): Path to the input directory containing prediction files.
Returns: List[str]: Sorted list of prediction file paths.
def get_cluster_files(
    cluster_assignment: List[Dict[str, Any]],
    cluster_id: int,
    files_pred: List[str]
) -> List[str]:
    """
    Retrieves prediction files that belong to a specific cluster.

    Args:
        cluster_assignment (List[Dict[str, Any]]):
            List of cluster assignments with StudyID and cluster.
        cluster_id (int): The cluster identifier.
        files_pred (List[str]): List of all prediction file paths.

    Returns:
        List[str]: List of prediction files belonging to the specified cluster.
    """
    # Collect StudyIDs for the given cluster in a set: O(1) membership tests
    # (the original used a list, making the filter below quadratic overall).
    cluster_ids = {e["StudyID"] for e in cluster_assignment if e["cluster"] == cluster_id}

    # Keep predictions whose basename (minus the .nii.gz extension) matches a StudyID.
    cluster_files_pred = [
        f for f in files_pred
        if os.path.basename(f).replace(".nii.gz", "") in cluster_ids
    ]
    print(f"Cluster {cluster_id} contains {len(cluster_files_pred)} files.")
    return cluster_files_pred
Retrieves prediction files that belong to a specific cluster.
Args: cluster_assignment (List[Dict[str, Any]]): List of cluster assignments with StudyID and cluster. cluster_id (int): The cluster identifier. files_pred (List[str]): List of all prediction file paths.
Returns: List[str]: List of prediction files belonging to the specified cluster.
def read_cluster_assignment(clusters_json: str) -> Tuple[List[Dict[str, Any]], List[int]]:
    """
    Reads cluster assignments from a JSON file.

    Args:
        clusters_json (str): Path to the JSON file containing cluster assignments.

    Returns:
        Tuple[List[Dict[str, Any]], List[int]]:
            - List of cluster assignments with StudyID and cluster.
            - Sorted list of unique cluster identifiers.
    """
    with open(clusters_json, "r") as f:
        cluster_assignment = json.load(f)

    # Keep only the relevant keys from each record.
    cluster_assignment = [
        {key: value for key, value in e.items() if key in ["StudyID", "cluster"]}
        for e in cluster_assignment
    ]

    # BUGFIX: cluster_assignment is a plain Python list, so the previous
    # `cluster_assignment.tolist()` raised AttributeError on every call.
    # A set comprehension also removes the numpy dependency here.
    cluster_ids = sorted({int(e["cluster"]) for e in cluster_assignment})
    print(f"Found {len(cluster_ids)} clusters: {cluster_ids}.")
    return cluster_assignment, cluster_ids
Reads cluster assignments from a JSON file.
Args: clusters_json (str): Path to the JSON file containing cluster assignments.
Returns: Tuple[List[Dict[str, Any]], List[int]]: - List of cluster assignments with StudyID and cluster. - Sorted list of unique cluster identifiers.
def read_cluster_assignment_df(clusterdf: pd.DataFrame) -> Tuple[List[Dict[str, Any]], List[int]]:
    """
    Reads cluster assignments from a pandas DataFrame.

    Args:
        clusterdf (pd.DataFrame): DataFrame containing cluster assignments.

    Returns:
        Tuple[List[Dict[str, Any]], List[int]]:
            - List of cluster assignments with StudyID and cluster.
            - Sorted list of unique cluster identifiers.
    """
    # Convert DataFrame to list of dictionaries
    cluster_assignment = clusterdf.to_dict(orient="records")

    # Keep only the relevant keys from each record.
    cluster_assignment = [
        {key: value for key, value in e.items() if key in ["StudyID", "cluster"]}
        for e in cluster_assignment
    ]

    # BUGFIX: cluster_assignment is a plain Python list, so the previous
    # `cluster_assignment.tolist()` raised AttributeError on every call.
    cluster_ids = sorted({int(e["cluster"]) for e in cluster_assignment})
    print(f"Found {len(cluster_ids)} clusters: {cluster_ids}.")
    return cluster_assignment, cluster_ids
Reads cluster assignments from a pandas DataFrame.
Args: clusterdf (pd.DataFrame): DataFrame containing cluster assignments.
Returns: Tuple[List[Dict[str, Any]], List[int]]: - List of cluster assignments with StudyID and cluster. - Sorted list of unique cluster identifiers.
def remove_small_component(
    challenge_name: str,
    thresholds_file: str,
    input_folder_pred: str,
    clustersdf: pd.DataFrame,
    output_folder_pp_cc: str
) -> str:
    """
    Removes small connected components from segmentation predictions based on
    cluster-specific thresholds.

    Args:
        challenge_name (str): The name of the challenge (e.g., BraTS-PED, BraTS-MET).
        thresholds_file (str): Path to the JSON file containing threshold settings.
        input_folder_pred (str): Path to the input directory containing prediction files.
        clustersdf (pd.DataFrame): DataFrame containing cluster assignments.
        output_folder_pp_cc (str): Path to the output directory to save postprocessed predictions.

    Returns:
        str: Path to the output directory containing postprocessed predictions.

    Raises:
        ValueError: If the challenge name is not supported.
    """
    # Resolve the label mapping; unknown challenges are rejected up front.
    label_mapping = LABEL_MAPPING_FACTORY.get(challenge_name)
    if label_mapping is None:
        raise ValueError(f"Unsupported challenge name: {challenge_name}")

    thresholds = get_thresholds_task(challenge_name, thresholds_file)
    files_pred = get_files(input_folder_pred)
    cluster_assignment, cluster_array = read_cluster_assignment_df(clustersdf)

    # Postprocess every cluster with its own threshold settings.
    for cluster in cluster_array:
        postprocess_batch(
            input_files=get_cluster_files(cluster_assignment, cluster, files_pred),
            output_folder=output_folder_pp_cc,
            thresholds_dict=get_thresholds_cluster(thresholds, f"cluster_{cluster}"),
            label_mapping=label_mapping,
        )

    return output_folder_pp_cc
Removes small connected components from segmentation predictions based on cluster-specific thresholds.
Args: challenge_name (str): The name of the challenge (e.g., BraTS-PED, BraTS-MET). thresholds_file (str): Path to the JSON file containing threshold settings. input_folder_pred (str): Path to the input directory containing prediction files. clustersdf (pd.DataFrame): DataFrame containing cluster assignments. output_folder_pp_cc (str): Path to the output directory to save postprocessed predictions.
Returns: str: Path to the output directory containing postprocessed predictions.
def main() -> None:
    """
    Main function to execute label redefinition and connected component postprocessing.

    Parses command-line arguments, loads necessary data, and performs postprocessing
    on prediction files based on cluster assignments and threshold settings.
    """
    args = parse_args()

    # Resolve the label mapping; unknown challenges are rejected up front.
    args.label_mapping = LABEL_MAPPING_FACTORY.get(args.challenge_name)
    if args.label_mapping is None:
        raise ValueError(f"Unsupported challenge name: {args.challenge_name}")

    thresholds = get_thresholds_task(args.challenge_name, args.thresholds_file)
    files_pred = get_files(args.input_folder_pred)
    cluster_assignment, cluster_array = read_cluster_assignment(args.clusters_file)

    # Postprocess every cluster with its own threshold settings.
    for cluster in cluster_array:
        postprocess_batch(
            input_files=get_cluster_files(cluster_assignment, cluster, files_pred),
            output_folder=args.output_folder_pp_cc,
            thresholds_dict=get_thresholds_cluster(thresholds, f"cluster_{cluster}"),
            label_mapping=args.label_mapping,
        )
Main function to execute label redefinition and connected component postprocessing.
Parses command-line arguments, loads necessary data, and performs postprocessing on prediction files based on cluster assignments and threshold settings.