inference.postproc.postprocess_lblredef
Post-processing Label Redefinition
Provides functions for post-processing segmentation predictions by redefining labels whose share of the whole-tumor volume falls below per-cluster thresholds.
"""
### Post-processing Label Redefinition

Provides functions for post-processing segmentation predictions by redefining
labels whose share of the whole-tumor volume falls below per-cluster thresholds.
"""


import argparse
import glob
import json
import os
from typing import Any, Dict, List, Sequence, Tuple

import nibabel as nib
import numpy as np
import pandas as pd


# Human-readable names of the integer labels of each supported challenge.
# Used for log messages and, in `label_redefinition`, to reject unknown challenges.
LABEL_MAPPING_FACTORY: Dict[str, Dict[int, str]] = {
    "BraTS-PED": {
        1: "ET",
        2: "NET",
        3: "CC",
        4: "ED"
    },
    "BraTS-SSA": {
        1: "NCR",
        2: "ED",
        3: "ET"
    },
    "BraTS-MEN-RT": {
        1: "GTV"
    },
    "BraTS-MET": {
        1: "NETC",
        2: "SNFH",
        3: "ET"
    }
}


def parse_args() -> argparse.Namespace:
    """
    Parses command-line arguments for the BraTS2024 Postprocessing Label Redefinition script.

    Every option is required: `main` dereferences all of them unconditionally,
    so a missing option would otherwise surface later as an opaque
    ``TypeError``/``ValueError`` instead of an argparse usage error.

    Returns:
        argparse.Namespace: Parsed command-line arguments.
    """
    parser = argparse.ArgumentParser(description='BraTS2024 Postprocessing.')
    parser.add_argument('--challenge_name', type=str, required=True,
                        help='The name of the challenge (e.g., BraTS-PED, BraTS-MET)')
    parser.add_argument('--input_folder_pred', type=str, required=True,
                        help='The input folder containing the predictions')
    parser.add_argument('--output_folder_pp_cc', type=str, required=True,
                        help='The output folder to save the postprocessed predictions')
    parser.add_argument('--thresholds_file', type=str, required=True,
                        help='The JSON file containing the thresholds')
    parser.add_argument('--clusters_file', type=str, required=True,
                        help='The JSON file containing the clusters')
    return parser.parse_args()


def get_ratio_labels_wt(seg: np.ndarray, labels: Sequence[int] = (1, 2, 3, 4)) -> float:
    """
    Calculates the ratio of selected label voxels to whole tumor (WT) voxels.

    The whole tumor is every non-zero voxel of ``seg``.

    Args:
        seg (np.ndarray): The segmentation array.
        labels (Sequence[int], optional): Label values to include in the ratio.
            A label listed more than once is counted once. Defaults to (1, 2, 3, 4).

    Returns:
        float: The ratio of selected label voxels to WT voxels, or 1.0 when the
        segmentation contains no non-zero voxel.
    """
    wt_voxels = np.count_nonzero(seg)  # Everything but background
    if wt_voxels == 0:
        return 1.0
    # np.isin counts each voxel at most once, even for duplicated labels.
    selected_voxels = np.count_nonzero(np.isin(seg, list(labels)))
    return selected_voxels / wt_voxels


def postprocess_lblredef(
    img: np.ndarray,
    thresholds_dict: Dict[str, Dict[str, Any]],
    label_mapping: Dict[int, str]
) -> np.ndarray:
    """
    Redefines labels in the segmentation image based on threshold criteria.

    For each entry of ``thresholds_dict``, the key's last ``_``-separated token is
    parsed as the label value (e.g. ``"label_2"`` -> 2). If that label's share of
    the whole tumor is below ``th`` (default 0, i.e. never), all its voxels are set
    to ``redefine_to`` (default: unchanged). Entries are applied in dict order on
    the progressively updated copy, so earlier redefinitions affect later ratios.

    Args:
        img (np.ndarray): The original segmentation array (not modified).
        thresholds_dict (Dict[str, Dict[str, Any]]): Per-label settings with optional
            ``"th"`` and ``"redefine_to"`` keys.
        label_mapping (Dict[int, str]): Mapping from label numbers to label names,
            used only for the log message.

    Returns:
        np.ndarray: The postprocessed segmentation array.
    """
    # Make a copy of the original prediction
    pred = img.copy()

    # Iterate over each label and apply redefinition based on thresholds
    for label_name, th_dict in thresholds_dict.items():
        label_number = int(label_name.split("_")[-1])
        label_name_mapped = label_mapping.get(label_number, f"Label_{label_number}")
        th = th_dict.get("th", 0)
        redefine_to = th_dict.get("redefine_to", label_number)

        print(f"Label {label_name_mapped} - value {label_number} - Redefinition to {redefine_to} - Applying threshold {th}")

        # Get ratio with respect to the whole tumor
        ratio = get_ratio_labels_wt(pred, labels=[label_number])

        # Conditioned redefinition
        if ratio < th:
            pred = np.where(pred == label_number, redefine_to, pred)

    return pred


def postprocess_batch(
    input_files: List[str],
    output_folder: str,
    thresholds_dict: Dict[str, Dict[str, Any]],
    label_mapping: Dict[int, str]
) -> None:
    """
    Applies label-redefinition postprocessing to a batch of segmentation files.

    Each result is written to ``output_folder`` under the input's basename. Inputs
    whose output path already exists are skipped, so a rerun only processes
    missing outputs. The output keeps the input's affine, header and on-disk dtype.

    Args:
        input_files (List[str]): List of input file paths.
        output_folder (str): Path to the output directory to save postprocessed files.
        thresholds_dict (Dict[str, Dict[str, Any]]): Dictionary containing threshold values for each label.
        label_mapping (Dict[int, str]): Mapping from label numbers to label names.
    """
    for f in input_files:
        print(f"Processing file {f}")
        save_path = os.path.join(output_folder, os.path.basename(f))

        if os.path.exists(save_path):
            print(f"File {save_path} already exists. Skipping.")
            continue

        # Load once; the same image provides the data and the spatial metadata.
        img = nib.load(f)
        pred = postprocess_lblredef(img.get_fdata(), thresholds_dict, label_mapping)

        # dirname() is "" when output_folder is "", and os.makedirs("") raises.
        os.makedirs(os.path.dirname(save_path) or os.curdir, exist_ok=True)

        # get_fdata() always yields float64; cast back to the on-disk dtype so the
        # array matches the reused header and is written as-is (no float64 label maps).
        out_dtype = img.get_data_dtype()
        nib.save(nib.Nifti1Image(pred.astype(out_dtype), img.affine, img.header), save_path)


def get_thresholds_task(challenge_name: str, input_file: str) -> Dict[str, Dict[str, Any]]:
    """
    Retrieves threshold settings for a specific challenge from a JSON file.

    Args:
        challenge_name (str): The name of the challenge (e.g., BraTS-PED, BraTS-MET).
        input_file (str): Path to the JSON file containing thresholds.

    Returns:
        Dict[str, Dict[str, Any]]: Thresholds for the specified challenge.

    Raises:
        ValueError: If the challenge name is not found in the JSON file.
    """
    with open(input_file, 'r') as f:
        thresholds = json.load(f)

    if challenge_name not in thresholds:
        raise ValueError(f"Challenge {challenge_name} not found in the thresholds JSON file.")

    return thresholds[challenge_name]


def get_thresholds_cluster(thresholds: Dict[str, Dict[str, Any]], cluster_name: str) -> Dict[str, Any]:
    """
    Retrieves threshold settings for a specific cluster within a challenge.

    Args:
        thresholds (Dict[str, Dict[str, Any]]): Thresholds for the challenge.
        cluster_name (str): The name of the cluster (e.g., "cluster_1").

    Returns:
        Dict[str, Any]: Threshold settings for the specified cluster.

    Raises:
        ValueError: If the cluster name is not found in the thresholds.
    """
    if cluster_name not in thresholds:
        raise ValueError(f"Cluster {cluster_name} not found in the thresholds JSON file.")
    return thresholds[cluster_name]


def get_files(input_folder_pred: str) -> List[str]:
    """
    Retrieves all ``*.nii.gz`` prediction file paths from the input directory.

    The folder path is escaped before globbing, so directory names containing
    glob metacharacters (``*``, ``?``, ``[``) are matched literally.

    Args:
        input_folder_pred (str): Path to the input directory containing prediction files.

    Returns:
        List[str]: Sorted list of prediction file paths.
    """
    pattern = os.path.join(glob.escape(input_folder_pred), "*.nii.gz")
    files_pred = sorted(glob.glob(pattern))
    print(f"Found {len(files_pred)} files to be processed.")
    return files_pred


def get_cluster_files(
    cluster_assignment: List[Dict[str, Any]],
    cluster_id: int,
    files_pred: List[str]
) -> List[str]:
    """
    Retrieves prediction files that belong to a specific cluster.

    A file belongs to the cluster when its basename, without a trailing
    ``.nii.gz``, equals one of the cluster's ``StudyID`` values.

    Args:
        cluster_assignment (List[Dict[str, Any]]): List of cluster assignments with StudyID and cluster.
        cluster_id (int): The cluster identifier.
        files_pred (List[str]): List of all prediction file paths.

    Returns:
        List[str]: Prediction files of the cluster, in the order of ``files_pred``.
    """
    suffix = ".nii.gz"
    # A set gives O(1) membership tests instead of a list scan per file.
    cluster_ids = {e["StudyID"] for e in cluster_assignment if e["cluster"] == cluster_id}

    def _study_id(path: str) -> str:
        # Strip the extension only at the end of the name, never mid-name.
        name = os.path.basename(path)
        return name[:-len(suffix)] if name.endswith(suffix) else name

    cluster_files_pred = [f for f in files_pred if _study_id(f) in cluster_ids]
    print(f"Cluster {cluster_id} contains {len(cluster_files_pred)} files.")
    return cluster_files_pred


def read_cluster_assignment(clusters_json: str) -> Tuple[List[Dict[str, Any]], List[int]]:
    """
    Reads cluster assignments from a JSON file.

    Args:
        clusters_json (str): Path to a UTF-8 JSON file holding a list of records,
            each with at least ``"StudyID"`` and ``"cluster"`` keys.

    Returns:
        Tuple[List[Dict[str, Any]], List[int]]:
            - List of cluster assignments reduced to StudyID and cluster.
            - Sorted list of unique cluster identifiers as native Python values.
    """
    with open(clusters_json, "r", encoding="utf-8") as f:
        cluster_assignment = json.load(f)

    # Filter relevant keys
    cluster_assignment = [
        {key: value for key, value in e.items() if key in ("StudyID", "cluster")}
        for e in cluster_assignment
    ]

    # np.unique already sorts; tolist() turns NumPy scalars into plain Python values.
    cluster_array = np.unique([e["cluster"] for e in cluster_assignment]).tolist()
    print(f"Found {len(cluster_array)} clusters: {cluster_array}.")
    return cluster_assignment, cluster_array


def read_cluster_assignment_df(clusterdf: pd.DataFrame) -> Tuple[List[Dict[str, Any]], List[int]]:
    """
    Reads cluster assignments from a pandas DataFrame.

    Args:
        clusterdf (pd.DataFrame): DataFrame with at least ``StudyID`` and ``cluster`` columns.

    Returns:
        Tuple[List[Dict[str, Any]], List[int]]:
            - List of cluster assignments reduced to StudyID and cluster.
            - Sorted list of unique cluster identifiers as native Python values.
    """
    cluster_assignment = [
        {key: value for key, value in e.items() if key in ("StudyID", "cluster")}
        for e in clusterdf.to_dict(orient="records")
    ]
    # np.unique already sorts; tolist() turns NumPy scalars into plain Python values.
    cluster_array = np.unique([e["cluster"] for e in cluster_assignment]).tolist()
    print(f"Found {len(cluster_array)} clusters: {cluster_array}.")
    return cluster_assignment, cluster_array


def _redefine_clusters(
    cluster_assignment: List[Dict[str, Any]],
    cluster_array: List[int],
    files_pred: List[str],
    thresholds: Dict[str, Dict[str, Any]],
    output_folder: str,
    label_mapping: Dict[int, str]
) -> None:
    """Run `postprocess_batch` on each cluster's files with that cluster's thresholds."""
    for cluster in cluster_array:
        cluster_files_pred = get_cluster_files(cluster_assignment, cluster, files_pred)
        thresholds_cluster = get_thresholds_cluster(thresholds, f"cluster_{cluster}")
        postprocess_batch(
            input_files=cluster_files_pred,
            output_folder=output_folder,
            thresholds_dict=thresholds_cluster,
            label_mapping=label_mapping
        )


def label_redefinition(
    challenge_name: str,
    thresholds_file: str,
    input_folder_pred: str,
    clustersdf: pd.DataFrame,
    output_folder_pp_cc: str
) -> str:
    """
    Performs label redefinition postprocessing for all clusters in the dataset.

    Args:
        challenge_name (str): The name of the challenge (e.g., BraTS-PED, BraTS-MET).
        thresholds_file (str): Path to the JSON file containing threshold settings.
        input_folder_pred (str): Path to the input directory containing prediction files.
        clustersdf (pd.DataFrame): DataFrame containing cluster assignments.
        output_folder_pp_cc (str): Path to the output directory to save postprocessed predictions.

    Returns:
        str: Path to the output directory containing postprocessed predictions.

    Raises:
        ValueError: If the challenge is not in LABEL_MAPPING_FACTORY, or if the
            challenge or a cluster is missing from the thresholds file.
    """
    # Retrieve label mapping for the challenge
    label_mapping = LABEL_MAPPING_FACTORY.get(challenge_name)
    if label_mapping is None:
        raise ValueError(f"Unsupported challenge name: {challenge_name}")

    thresholds = get_thresholds_task(challenge_name, thresholds_file)
    files_pred = get_files(input_folder_pred)
    cluster_assignment, cluster_array = read_cluster_assignment_df(clustersdf)

    _redefine_clusters(cluster_assignment, cluster_array, files_pred,
                       thresholds, output_folder_pp_cc, label_mapping)
    return output_folder_pp_cc


def main() -> None:
    """
    Main function to execute label redefinition postprocessing.

    Parses command-line arguments, loads the thresholds and the cluster
    assignments (JSON), and postprocesses each cluster's prediction files.
    Unlike `label_redefinition`, an unknown challenge name is not rejected here;
    labels are then logged as ``Label_<n>``.
    """
    args = parse_args()

    thresholds = get_thresholds_task(args.challenge_name, args.thresholds_file)
    files_pred = get_files(args.input_folder_pred)
    cluster_assignment, cluster_array = read_cluster_assignment(args.clusters_file)

    _redefine_clusters(cluster_assignment, cluster_array, files_pred, thresholds,
                       args.output_folder_pp_cc,
                       LABEL_MAPPING_FACTORY.get(args.challenge_name, {}))


if __name__ == "__main__":
    main()

    # Example Command to Run the Script:
    # python label_redefinition.py --challenge_name BraTS-PED --input_folder_pred /path/to/predictions \
    #     --output_folder_pp_cc /path/to/output --thresholds_file /path/to/thresholds.json \
    #     --clusters_file /path/to/clusters.json
def parse_args() -> argparse.Namespace:
    """
    Parses command-line arguments for the BraTS2024 Postprocessing Label Redefinition script.

    Every option is required: `main` dereferences all of them unconditionally,
    so a missing option would otherwise surface later as an opaque
    ``TypeError``/``ValueError`` instead of an argparse usage error.

    Returns:
        argparse.Namespace: Parsed command-line arguments.
    """
    parser = argparse.ArgumentParser(description='BraTS2024 Postprocessing.')
    parser.add_argument('--challenge_name', type=str, required=True,
                        help='The name of the challenge (e.g., BraTS-PED, BraTS-MET)')
    parser.add_argument('--input_folder_pred', type=str, required=True,
                        help='The input folder containing the predictions')
    parser.add_argument('--output_folder_pp_cc', type=str, required=True,
                        help='The output folder to save the postprocessed predictions')
    parser.add_argument('--thresholds_file', type=str, required=True,
                        help='The JSON file containing the thresholds')
    parser.add_argument('--clusters_file', type=str, required=True,
                        help='The JSON file containing the clusters')
    return parser.parse_args()
Parses command-line arguments for the BraTS2024 Postprocessing Label Redefinition script.
Returns: argparse.Namespace: Parsed command-line arguments.
def get_ratio_labels_wt(seg: np.ndarray, labels: Sequence[int] = (1, 2, 3, 4)) -> float:
    """
    Calculates the ratio of selected label voxels to whole tumor (WT) voxels.

    The whole tumor is every non-zero voxel of ``seg``.

    Args:
        seg (np.ndarray): The segmentation array.
        labels (Sequence[int], optional): Label values to include in the ratio.
            A label listed more than once is counted once. Defaults to (1, 2, 3, 4).

    Returns:
        float: The ratio of selected label voxels to WT voxels, or 1.0 when the
        segmentation contains no non-zero voxel.
    """
    wt_voxels = np.count_nonzero(seg)  # Everything but background
    if wt_voxels == 0:
        return 1.0
    # np.isin counts each voxel at most once, even for duplicated labels.
    selected_voxels = np.count_nonzero(np.isin(seg, list(labels)))
    return selected_voxels / wt_voxels
Calculates the ratio of selected label voxels to whole tumor (WT) voxels.
Args: seg (np.ndarray): The segmentation array. labels (List[int], optional): List of label integers to include in the ratio. Defaults to [1, 2, 3, 4].
Returns: float: The ratio of selected label voxels to WT voxels.
def postprocess_lblredef(
    img: np.ndarray,
    thresholds_dict: Dict[str, Dict[str, Any]],
    label_mapping: Dict[int, str]
) -> np.ndarray:
    """
    Conditionally relabel classes of a segmentation according to per-label thresholds.

    Every key of ``thresholds_dict`` ends in ``_<label>`` (e.g. ``"label_2"``). For
    each entry, the label's share of the whole tumor is measured on the working
    copy; if that share is below ``th`` (default 0), all of its voxels become
    ``redefine_to`` (default: the label itself). Entries are handled in dict order,
    so an earlier relabelling changes the ratios seen by later entries.

    Args:
        img (np.ndarray): Segmentation to process; it is copied, never modified.
        thresholds_dict (Dict[str, Dict[str, Any]]): Per-label settings with optional
            ``"th"`` and ``"redefine_to"`` keys.
        label_mapping (Dict[int, str]): Label value -> display name, used for logging only.

    Returns:
        np.ndarray: The relabelled segmentation.
    """
    working = img.copy()

    for key, settings in thresholds_dict.items():
        value = int(key.split("_")[-1])
        display_name = label_mapping.get(value, f"Label_{value}")
        threshold = settings.get("th", 0)
        target = settings.get("redefine_to", value)

        print(f"Label {display_name} - value {value} - Redefinition to {target} - Applying threshold {threshold}")

        share_of_tumor = get_ratio_labels_wt(working, labels=[value])
        if share_of_tumor < threshold:
            working = np.where(working == value, target, working)

    return working
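Redefines labels in the segmentation image: each configured label whose share of the whole-tumor (non-zero) voxels is below its threshold is reassigned to its `redefine_to` value.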
Args: img (np.ndarray): The original segmentation array. thresholds_dict (Dict[str, Dict[str, Any]]): Dictionary containing threshold values for each label. label_mapping (Dict[int, str]): Mapping from label numbers to label names.
Returns: np.ndarray: The postprocessed segmentation array.
def postprocess_batch(
    input_files: List[str],
    output_folder: str,
    thresholds_dict: Dict[str, Dict[str, Any]],
    label_mapping: Dict[int, str]
) -> None:
    """
    Applies label-redefinition postprocessing to a batch of segmentation files.

    Each result is written to ``output_folder`` under the input's basename. Inputs
    whose output path already exists are skipped, so a rerun only processes
    missing outputs. The output keeps the input's affine, header and on-disk dtype.

    Args:
        input_files (List[str]): List of input file paths.
        output_folder (str): Path to the output directory to save postprocessed files.
        thresholds_dict (Dict[str, Dict[str, Any]]): Dictionary containing threshold values for each label.
        label_mapping (Dict[int, str]): Mapping from label numbers to label names.
    """
    for f in input_files:
        print(f"Processing file {f}")
        save_path = os.path.join(output_folder, os.path.basename(f))

        if os.path.exists(save_path):
            print(f"File {save_path} already exists. Skipping.")
            continue

        # Load once; the same image provides the data and the spatial metadata.
        img = nib.load(f)
        pred = postprocess_lblredef(img.get_fdata(), thresholds_dict, label_mapping)

        # dirname() is "" when output_folder is "", and os.makedirs("") raises.
        os.makedirs(os.path.dirname(save_path) or os.curdir, exist_ok=True)

        # get_fdata() always yields float64; cast back to the on-disk dtype so the
        # array matches the reused header and is written as-is (no float64 label maps).
        out_dtype = img.get_data_dtype()
        nib.save(nib.Nifti1Image(pred.astype(out_dtype), img.affine, img.header), save_path)
Applies postprocessing to a batch of segmentation files.
Args: input_files (List[str]): List of input file paths. output_folder (str): Path to the output directory to save postprocessed files. thresholds_dict (Dict[str, Dict[str, Any]]): Dictionary containing threshold values for each label. label_mapping (Dict[int, str]): Mapping from label numbers to label names.
def get_thresholds_task(challenge_name: str, input_file: str) -> Dict[str, Dict[str, Any]]:
    """
    Load the threshold configuration of one challenge from a JSON file.

    The top level of the JSON document is looked up by challenge name.

    Args:
        challenge_name (str): Challenge whose settings are wanted (e.g., BraTS-PED, BraTS-MET).
        input_file (str): Path of the thresholds JSON file.

    Returns:
        Dict[str, Dict[str, Any]]: The entry stored under ``challenge_name``.

    Raises:
        ValueError: If ``challenge_name`` does not appear in the file.
    """
    with open(input_file, 'r') as handle:
        all_thresholds = json.load(handle)

    if challenge_name in all_thresholds:
        return all_thresholds[challenge_name]
    raise ValueError(f"Challenge {challenge_name} not found in the thresholds JSON file.")
Retrieves threshold settings for a specific challenge from a JSON file.
Args: challenge_name (str): The name of the challenge (e.g., BraTS-PED, BraTS-MET). input_file (str): Path to the JSON file containing thresholds.
Returns: Dict[str, Dict[str, Any]]: Thresholds for the specified challenge.
Raises: ValueError: If the challenge name is not found in the JSON file.
def get_thresholds_cluster(thresholds: Dict[str, Dict[str, Any]], cluster_name: str) -> Dict[str, Any]:
    """
    Select the threshold settings of one cluster from a challenge's thresholds.

    Args:
        thresholds (Dict[str, Dict[str, Any]]): Thresholds of a single challenge, keyed by cluster.
        cluster_name (str): Key of the wanted cluster (e.g., "cluster_1").

    Returns:
        Dict[str, Any]: The settings stored under ``cluster_name``.

    Raises:
        ValueError: If ``cluster_name`` is not a key of ``thresholds``.
    """
    if cluster_name in thresholds:
        return thresholds[cluster_name]
    raise ValueError(f"Cluster {cluster_name} not found in the thresholds JSON file.")
Retrieves threshold settings for a specific cluster within a challenge.
Args: thresholds (Dict[str, Dict[str, Any]]): Thresholds for the challenge. cluster_name (str): The name of the cluster (e.g., "cluster_1").
Returns: Dict[str, Any]: Threshold settings for the specified cluster.
Raises: ValueError: If the cluster name is not found in the thresholds.
def get_files(input_folder_pred: str) -> List[str]:
    """
    Retrieves all ``*.nii.gz`` prediction file paths from the input directory.

    The folder path is escaped before globbing, so directory names containing
    glob metacharacters (``*``, ``?``, ``[``) are matched literally.

    Args:
        input_folder_pred (str): Path to the input directory containing prediction files.

    Returns:
        List[str]: Sorted list of prediction file paths.
    """
    pattern = os.path.join(glob.escape(input_folder_pred), "*.nii.gz")
    files_pred = sorted(glob.glob(pattern))
    print(f"Found {len(files_pred)} files to be processed.")
    return files_pred
Retrieves all prediction file paths from the input directory.
Args: input_folder_pred (str): Path to the input directory containing prediction files.
Returns: List[str]: Sorted list of prediction file paths.
def get_cluster_files(
    cluster_assignment: List[Dict[str, Any]],
    cluster_id: int,
    files_pred: List[str]
) -> List[str]:
    """
    Retrieves prediction files that belong to a specific cluster.

    A file belongs to the cluster when its basename, without a trailing
    ``.nii.gz``, equals one of the cluster's ``StudyID`` values.

    Args:
        cluster_assignment (List[Dict[str, Any]]): List of cluster assignments with StudyID and cluster.
        cluster_id (int): The cluster identifier.
        files_pred (List[str]): List of all prediction file paths.

    Returns:
        List[str]: Prediction files of the cluster, in the order of ``files_pred``.
    """
    suffix = ".nii.gz"
    # A set gives O(1) membership tests instead of a list scan per file.
    cluster_ids = {e["StudyID"] for e in cluster_assignment if e["cluster"] == cluster_id}

    def _study_id(path: str) -> str:
        # Strip the extension only at the end of the name, never mid-name.
        name = os.path.basename(path)
        return name[:-len(suffix)] if name.endswith(suffix) else name

    cluster_files_pred = [f for f in files_pred if _study_id(f) in cluster_ids]
    print(f"Cluster {cluster_id} contains {len(cluster_files_pred)} files.")
    return cluster_files_pred
Retrieves prediction files that belong to a specific cluster.
Args: cluster_assignment (List[Dict[str, Any]]): List of cluster assignments with StudyID and cluster. cluster_id (int): The cluster identifier. files_pred (List[str]): List of all prediction file paths.
Returns: List[str]: List of prediction files belonging to the specified cluster.
def read_cluster_assignment(clusters_json: str) -> Tuple[List[Dict[str, Any]], List[int]]:
    """
    Reads cluster assignments from a JSON file.

    Args:
        clusters_json (str): Path to a UTF-8 JSON file holding a list of records,
            each with at least ``"StudyID"`` and ``"cluster"`` keys.

    Returns:
        Tuple[List[Dict[str, Any]], List[int]]:
            - List of cluster assignments reduced to StudyID and cluster.
            - Sorted list of unique cluster identifiers as native Python values.
    """
    with open(clusters_json, "r", encoding="utf-8") as f:
        cluster_assignment = json.load(f)

    # Filter relevant keys
    cluster_assignment = [
        {key: value for key, value in e.items() if key in ("StudyID", "cluster")}
        for e in cluster_assignment
    ]

    # np.unique already sorts; tolist() turns NumPy scalars into plain Python values.
    cluster_array = np.unique([e["cluster"] for e in cluster_assignment]).tolist()
    print(f"Found {len(cluster_array)} clusters: {cluster_array}.")
    return cluster_assignment, cluster_array
Reads cluster assignments from a JSON file.
Args: clusters_json (str): Path to the JSON file containing cluster assignments.
Returns: Tuple[List[Dict[str, Any]], List[int]]: - List of cluster assignments with StudyID and cluster. - Sorted list of unique cluster identifiers.
def read_cluster_assignment_df(clusterdf: pd.DataFrame) -> Tuple[List[Dict[str, Any]], List[int]]:
    """
    Reads cluster assignments from a pandas DataFrame.

    Args:
        clusterdf (pd.DataFrame): DataFrame with at least ``StudyID`` and ``cluster`` columns.

    Returns:
        Tuple[List[Dict[str, Any]], List[int]]:
            - List of cluster assignments reduced to StudyID and cluster.
            - Sorted list of unique cluster identifiers as native Python values.
    """
    cluster_assignment = [
        {key: value for key, value in e.items() if key in ("StudyID", "cluster")}
        for e in clusterdf.to_dict(orient="records")
    ]
    # np.unique already sorts; tolist() turns NumPy scalars into plain Python values.
    cluster_array = np.unique([e["cluster"] for e in cluster_assignment]).tolist()
    print(f"Found {len(cluster_array)} clusters: {cluster_array}.")
    return cluster_assignment, cluster_array
Reads cluster assignments from a pandas DataFrame.
Args: clusterdf (pd.DataFrame): DataFrame containing cluster assignments.
Returns: Tuple[List[Dict[str, Any]], List[int]]: - List of cluster assignments with StudyID and cluster. - Sorted list of unique cluster identifiers.
def label_redefinition(
    challenge_name: str,
    thresholds_file: str,
    input_folder_pred: str,
    clustersdf: pd.DataFrame,
    output_folder_pp_cc: str
) -> str:
    """
    Run cluster-wise label redefinition over every prediction in a folder.

    Args:
        challenge_name (str): Challenge identifier (e.g., BraTS-PED, BraTS-MET).
        thresholds_file (str): JSON file with the per-challenge, per-cluster thresholds.
        input_folder_pred (str): Folder containing the ``*.nii.gz`` predictions.
        clustersdf (pd.DataFrame): Cluster assignment table (StudyID / cluster).
        output_folder_pp_cc (str): Destination folder for the postprocessed files.

    Returns:
        str: ``output_folder_pp_cc``, unchanged.

    Raises:
        ValueError: If the challenge has no entry in LABEL_MAPPING_FACTORY, or if
            the challenge or a cluster is missing from the thresholds file.
    """
    mapping = LABEL_MAPPING_FACTORY.get(challenge_name)
    if mapping is None:
        raise ValueError(f"Unsupported challenge name: {challenge_name}")

    challenge_thresholds = get_thresholds_task(challenge_name, thresholds_file)
    all_predictions = get_files(input_folder_pred)
    assignment, cluster_ids = read_cluster_assignment_df(clustersdf)

    for cid in cluster_ids:
        members = get_cluster_files(assignment, cid, all_predictions)
        per_cluster = get_thresholds_cluster(challenge_thresholds, f"cluster_{cid}")
        postprocess_batch(members, output_folder_pp_cc, per_cluster, mapping)

    return output_folder_pp_cc
Performs label redefinition postprocessing for all clusters in the dataset.
Args: challenge_name (str): The name of the challenge (e.g., BraTS-PED, BraTS-MET). thresholds_file (str): Path to the JSON file containing threshold settings. input_folder_pred (str): Path to the input directory containing prediction files. clustersdf (pd.DataFrame): DataFrame containing cluster assignments. output_folder_pp_cc (str): Path to the output directory to save postprocessed predictions.
Returns: str: Path to the output directory containing postprocessed predictions.
def main() -> None:
    """
    Command-line entry point for label redefinition postprocessing.

    Reads the CLI options, loads the challenge thresholds, the prediction file
    list and the JSON cluster assignment, then postprocesses each cluster's files
    with that cluster's thresholds. An unknown challenge name falls back to an
    empty label mapping, so labels are only logged as ``Label_<n>``.
    """
    args = parse_args()
    names = LABEL_MAPPING_FACTORY.get(args.challenge_name, {})

    challenge_thresholds = get_thresholds_task(args.challenge_name, args.thresholds_file)
    all_predictions = get_files(args.input_folder_pred)
    assignment, cluster_ids = read_cluster_assignment(args.clusters_file)

    for cid in cluster_ids:
        members = get_cluster_files(assignment, cid, all_predictions)
        per_cluster = get_thresholds_cluster(challenge_thresholds, f"cluster_{cid}")
        postprocess_batch(members, args.output_folder_pp_cc, per_cluster, names)
Main function to execute label redefinition postprocessing.
Parses command-line arguments, loads necessary data, and performs postprocessing on prediction files based on cluster assignments and threshold settings.