inference.postproc.postprocess_cc

Post-processing Connected Components

Provides functionality for computing connected components and post-processing segmentation predictions by removing small connected components below volume thresholds.
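The script consumes two auxiliary inputs: a thresholds JSON keyed by challenge, cluster, and label, and a clusters JSON with one record per case. The layout below is inferred from how `get_thresholds_task`, `get_thresholds_cluster`, `postprocess_cc`, and `read_cluster_assignment` index into these files; the sketch (written as Python literals) uses illustrative keys and values, not values shipped with the module.

# Assumed thresholds layout: thresholds[challenge][f"cluster_{id}"][f"label_{value}"]["th"]
# gives the minimum component volume (in voxels) for that label in that cluster.
thresholds_example = {
    "BraTS-PED": {
        "cluster_0": {
            "label_1": {"th": 50},    # components of label 1 (ET) below 50 voxels are removed
            "label_2": {"th": 100}    # components of label 2 (NET) below 100 voxels are removed
        }
    }
}

# Assumed clusters layout: one record per case, with at least "StudyID" and "cluster".
clusters_example = [
    {"StudyID": "BraTS-PED-00001-000", "cluster": 0},
    {"StudyID": "BraTS-PED-00002-000", "cluster": 0}
]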

  1"""
  2### Post-processing Connected Components
  3
  4Provides functionalities related to computing connected components and postprocessing segmentation predictions by removing small connected components based on thresholds.
  5"""
  6
  7
  8import glob
  9import os
 10import nibabel as nib
 11import numpy as np
 12import cc3d
 13import json
 14import argparse
 15import pandas as pd
 16from pathlib import Path
 17from typing import Any, List, Tuple, Dict
 18
 19
 20LABEL_MAPPING_FACTORY: Dict[str, Dict[int, str]] = {
 21    "BraTS-PED": {
 22        1: "ET",
 23        2: "NET",
 24        3: "CC",
 25        4: "ED"
 26    },
 27    "BraTS-SSA": {
 28        1: "NCR",
 29        2: "ED",
 30        3: "ET"
 31    },
 32    "BraTS-MEN-RT": {
 33        1: "GTV"
 34    },
 35    "BraTS-MET": {
 36        1: "NETC",
 37        2: "SNFH",
 38        3: "ET"
 39    }
 40}
 41

def parse_args() -> argparse.Namespace:
    """
    Parses command-line arguments for the BraTS2024 CC Postprocessing script.

    Returns:
        argparse.Namespace: Parsed command-line arguments.
    """
    parser = argparse.ArgumentParser(description='BraTS2024 CC Postprocessing.')
    parser.add_argument('--challenge_name', type=str,
                        help='The name of the challenge (e.g., BraTS-PED, BraTS-MET)')
    parser.add_argument('--input_folder_pred', type=str,
                        help='The input folder containing the predictions')
    parser.add_argument('--output_folder_pp_cc', type=str,
                        help='The output folder to save the postprocessed predictions')
    parser.add_argument('--thresholds_file', type=str,
                        help='The JSON file containing the thresholds')
    parser.add_argument('--clusters_file', type=str,
                        help='The JSON file containing the clusters')

    return parser.parse_args()


def get_connected_components(img: np.ndarray, value: int) -> Tuple[np.ndarray, int]:
    """
    Identifies connected components for a specific label in the segmentation.

    Args:
        img (np.ndarray): The segmentation array.
        value (int): The label value to find connected components for.

    Returns:
        Tuple[np.ndarray, int]:
            - Array with labeled connected components.
            - Number of connected components found.
    """
    labels, n = cc3d.connected_components((img == value).astype(np.int8), connectivity=26, return_N=True)
    return labels, n


def postprocess_cc(
    img: np.ndarray,
    thresholds_dict: Dict[str, Dict[str, Any]],
    label_mapping: Dict[int, str]
) -> np.ndarray:
    """
    Postprocesses the segmentation image by removing small connected components based on thresholds.

    Args:
        img (np.ndarray): The original segmentation array.
        thresholds_dict (Dict[str, Dict[str, Any]]):
            Dictionary containing threshold values for each label.
        label_mapping (Dict[int, str]): Mapping from label numbers to label names.

    Returns:
        np.ndarray: The postprocessed segmentation array.
    """
    # Work on a copy of the input segmentation
    pred = img.copy()
    # Iterate over each of the labels
    for label_key, th_dict in thresholds_dict.items():
        label_number = int(label_key.split("_")[-1])
        label_name = label_mapping[label_number]
        th = th_dict["th"]
        print(f"Label {label_name} - value {label_number} - Applying threshold {th}")
        # Get connected components
        components, n = get_connected_components(pred, label_number)
        # Apply threshold to each of the components
        for el in range(1, n + 1):
            # Get the connected component (cc)
            cc = np.where(components == el, 1, 0)
            # Calculate the volume
            vol = np.count_nonzero(cc)
            # If the volume is less than the threshold, dismiss the cc
            if vol < th:
                pred = np.where(components == el, 0, pred)
    # Return the postprocessed image
    return pred


def postprocess_batch(
    input_files: List[str],
    output_folder: str,
    thresholds_dict: Dict[str, Dict[str, Any]],
    label_mapping: Dict[int, str]
) -> None:
    """
    Applies postprocessing to a batch of segmentation files.

    Args:
        input_files (List[str]): List of input file paths.
        output_folder (str): Path to the output directory to save postprocessed files.
        thresholds_dict (Dict[str, Dict[str, Any]]):
            Dictionary containing threshold values for each label.
        label_mapping (Dict[int, str]): Mapping from label numbers to label names.
    """
    for f in input_files:
        print(f"Processing file {f}")
        save_path = os.path.join(output_folder, os.path.basename(f))

        if os.path.exists(save_path):
            print(f"File {save_path} already exists. Skipping.")
            continue

        # Read the segmentation file once, reusing both the data and the affine
        nii = nib.load(f)
        pred_orig = nii.get_fdata()

        # Apply postprocessing
        pred = postprocess_cc(pred_orig, thresholds_dict, label_mapping)

        # Ensure the output directory exists
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        # Save the postprocessed segmentation
        nib.save(nib.Nifti1Image(pred, nii.affine), save_path)


def get_thresholds_task(
    challenge_name: str,
    input_file: str
) -> Dict[str, Dict[str, Any]]:
    """
    Retrieves threshold settings for a specific challenge from a JSON file.

    Args:
        challenge_name (str): The name of the challenge (e.g., BraTS-PED, BraTS-MET).
        input_file (str): Path to the JSON file containing thresholds.

    Returns:
        Dict[str, Dict[str, Any]]: Thresholds for the specified challenge.

    Raises:
        ValueError: If the challenge name is not found in the JSON file.
    """
    with open(input_file, 'r') as f:
        thresholds = json.load(f)

    if challenge_name not in thresholds:
        raise ValueError(f"Challenge {challenge_name} not found in the thresholds JSON file.")

    return thresholds[challenge_name]


def get_thresholds_cluster(
    thresholds: Dict[str, Dict[str, Any]],
    cluster_name: str
) -> Dict[str, Any]:
    """
    Retrieves threshold settings for a specific cluster within a challenge.

    Args:
        thresholds (Dict[str, Dict[str, Any]]): Thresholds for the challenge.
        cluster_name (str): The name of the cluster (e.g., "cluster_1").

    Returns:
        Dict[str, Any]: Threshold settings for the specified cluster.

    Raises:
        ValueError: If the cluster name is not found in the thresholds.
    """
    if cluster_name not in thresholds:
        raise ValueError(f"Cluster {cluster_name} not found in the thresholds JSON file.")
    return thresholds[cluster_name]


def get_files(input_folder_pred: str) -> List[str]:
    """
    Retrieves all prediction file paths from the input directory.

    Args:
        input_folder_pred (str): Path to the input directory containing prediction files.

    Returns:
        List[str]: Sorted list of prediction file paths.
    """
    files_pred = sorted(glob.glob(os.path.join(input_folder_pred, "*.nii.gz")))
    print(f"Found {len(files_pred)} files to be processed.")
    return files_pred


def get_cluster_files(
    cluster_assignment: List[Dict[str, Any]],
    cluster_id: int,
    files_pred: List[str]
) -> List[str]:
    """
    Retrieves prediction files that belong to a specific cluster.

    Args:
        cluster_assignment (List[Dict[str, Any]]):
            List of cluster assignments with StudyID and cluster.
        cluster_id (int): The cluster identifier.
        files_pred (List[str]): List of all prediction file paths.

    Returns:
        List[str]: List of prediction files belonging to the specified cluster.
    """
    # Extract StudyIDs for the given cluster
    cluster_ids = [e["StudyID"] for e in cluster_assignment if e["cluster"] == cluster_id]

    # Filter prediction files that match the StudyIDs
    cluster_files_pred = [
        f for f in files_pred
        if os.path.basename(f).replace(".nii.gz", "") in cluster_ids
    ]
    print(f"Cluster {cluster_id} contains {len(cluster_files_pred)} files.")
    return cluster_files_pred


def read_cluster_assignment(clusters_json: str) -> Tuple[List[Dict[str, Any]], List[int]]:
    """
    Reads cluster assignments from a JSON file.

    Args:
        clusters_json (str): Path to the JSON file containing cluster assignments.

    Returns:
        Tuple[List[Dict[str, Any]], List[int]]:
            - List of cluster assignments with StudyID and cluster.
            - Sorted list of unique cluster identifiers.
    """
    with open(clusters_json, "r") as f:
        cluster_assignment = json.load(f)

    # Filter relevant keys
    cluster_assignment = [
        {key: value for key, value in e.items() if key in ["StudyID", "cluster"]}
        for e in cluster_assignment
    ]

    cluster_array = np.unique([e["cluster"] for e in cluster_assignment])
    print(f"Found {len(cluster_array)} clusters: {sorted(cluster_array)}.")
    # cluster_assignment is already a plain list; convert the unique cluster ids to a sorted Python list
    return cluster_assignment, sorted(cluster_array.tolist())


def read_cluster_assignment_df(clusterdf: pd.DataFrame) -> Tuple[List[Dict[str, Any]], List[int]]:
    """
    Reads cluster assignments from a pandas DataFrame.

    Args:
        clusterdf (pd.DataFrame): DataFrame containing cluster assignments.

    Returns:
        Tuple[List[Dict[str, Any]], List[int]]:
            - List of cluster assignments with StudyID and cluster.
            - Sorted list of unique cluster identifiers.
    """
    # Convert DataFrame to list of dictionaries
    cluster_assignment = clusterdf.to_dict(orient="records")

    # Filter relevant keys
    cluster_assignment = [
        {key: value for key, value in e.items() if key in ["StudyID", "cluster"]}
        for e in cluster_assignment
    ]

    cluster_array = np.unique([e["cluster"] for e in cluster_assignment])
    print(f"Found {len(cluster_array)} clusters: {sorted(cluster_array)}.")
    # cluster_assignment is already a plain list; convert the unique cluster ids to a sorted Python list
    return cluster_assignment, sorted(cluster_array.tolist())


def remove_small_component(
    challenge_name: str,
    thresholds_file: str,
    input_folder_pred: str,
    clustersdf: pd.DataFrame,
    output_folder_pp_cc: str
) -> str:
    """
    Removes small connected components from segmentation predictions based on cluster-specific thresholds.

    Args:
        challenge_name (str): The name of the challenge (e.g., BraTS-PED, BraTS-MET).
        thresholds_file (str): Path to the JSON file containing threshold settings.
        input_folder_pred (str): Path to the input directory containing prediction files.
        clustersdf (pd.DataFrame): DataFrame containing cluster assignments.
        output_folder_pp_cc (str): Path to the output directory to save postprocessed predictions.

    Returns:
        str: Path to the output directory containing postprocessed predictions.
    """
    # Retrieve label mapping for the challenge
    label_mapping = LABEL_MAPPING_FACTORY.get(challenge_name)
    if label_mapping is None:
        raise ValueError(f"Unsupported challenge name: {challenge_name}")

    # Load threshold settings for the challenge
    thresholds = get_thresholds_task(challenge_name, thresholds_file)

    # Retrieve all prediction files
    files_pred = get_files(input_folder_pred)

    # Read cluster assignments from DataFrame
    cluster_assignment, cluster_array = read_cluster_assignment_df(clustersdf)

    # Iterate over each cluster and apply postprocessing
    for cluster in cluster_array:
        cluster_files_pred = get_cluster_files(cluster_assignment, cluster, files_pred)
        cluster_key = f"cluster_{cluster}"
        thresholds_cluster = get_thresholds_cluster(thresholds, cluster_key)

        # Apply postprocessing to the cluster-specific prediction files
        postprocess_batch(
            input_files=cluster_files_pred,
            output_folder=output_folder_pp_cc,
            thresholds_dict=thresholds_cluster,
            label_mapping=label_mapping
        )

    return output_folder_pp_cc


def main() -> None:
    """
    Main function to execute connected component postprocessing.

    Parses command-line arguments, loads necessary data, and performs postprocessing
    on prediction files based on cluster assignments and threshold settings.
    """
    args = parse_args()

    # Assign label mapping based on the challenge name
    args.label_mapping = LABEL_MAPPING_FACTORY.get(args.challenge_name)
    if args.label_mapping is None:
        raise ValueError(f"Unsupported challenge name: {args.challenge_name}")

    # Load threshold settings for the specified challenge
    thresholds = get_thresholds_task(args.challenge_name, args.thresholds_file)

    # Retrieve all prediction files from the input directory
    files_pred = get_files(args.input_folder_pred)

    # Read cluster assignments from the clusters JSON file
    cluster_assignment, cluster_array = read_cluster_assignment(args.clusters_file)

    # Iterate over each cluster and apply connected component postprocessing
    for cluster in cluster_array:
        cluster_files_pred = get_cluster_files(cluster_assignment, cluster, files_pred)
        cluster_key = f"cluster_{cluster}"
        thresholds_cluster = get_thresholds_cluster(thresholds, cluster_key)

        # Apply postprocessing to the cluster-specific prediction files
        postprocess_batch(
            input_files=cluster_files_pred,
            output_folder=args.output_folder_pp_cc,
            thresholds_dict=thresholds_cluster,
            label_mapping=args.label_mapping
        )


if __name__ == "__main__":
    main()

    # Example Command to Run the Script:
    # python remove_small_component.py --challenge_name BraTS-PED \
    # --input_folder_pred /path/to/predictions \
    # --output_folder_pp_cc /path/to/output \
    # --thresholds_file /path/to/thresholds.json \
    # --clusters_file /path/to/clusters.json

LABEL_MAPPING_FACTORY: Dict[str, Dict[int, str]] = {'BraTS-PED': {1: 'ET', 2: 'NET', 3: 'CC', 4: 'ED'}, 'BraTS-SSA': {1: 'NCR', 2: 'ED', 3: 'ET'}, 'BraTS-MEN-RT': {1: 'GTV'}, 'BraTS-MET': {1: 'NETC', 2: 'SNFH', 3: 'ET'}}
def parse_args() -> argparse.Namespace:

Parses command-line arguments for the BraTS2024 CC Postprocessing script.

Returns:
    argparse.Namespace: Parsed command-line arguments.

def get_connected_components(img: numpy.ndarray, value: int) -> Tuple[numpy.ndarray, int]:

Identifies connected components for a specific label in the segmentation.

Args:
    img (np.ndarray): The segmentation array.
    value (int): The label value to find connected components for.

Returns:
    Tuple[np.ndarray, int]:
        - Array with labeled connected components.
        - Number of connected components found.
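A minimal usage sketch on a synthetic array, assuming the module's functions are imported; the array contents are illustrative. cc3d labels the 26-connected components of the binary mask from 1 to N:

import numpy as np

# Toy 3D "segmentation" with two separate blobs of label 1
seg = np.zeros((10, 10, 10), dtype=np.int8)
seg[1:3, 1:3, 1:3] = 1   # blob A (8 voxels)
seg[7:9, 7:9, 7:9] = 1   # blob B (8 voxels)

components, n = get_connected_components(seg, value=1)
print(n)                                   # 2 components
print(np.count_nonzero(components == 1))   # size of the first component (8 voxels)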

def postprocess_cc(img: numpy.ndarray, thresholds_dict: Dict[str, Dict[str, Any]], label_mapping: Dict[int, str]) -> numpy.ndarray:

Postprocesses the segmentation image by removing small connected components based on thresholds.

Args:
    img (np.ndarray): The original segmentation array.
    thresholds_dict (Dict[str, Dict[str, Any]]): Dictionary containing threshold values for each label.
    label_mapping (Dict[int, str]): Mapping from label numbers to label names.

Returns:
    np.ndarray: The postprocessed segmentation array.
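A minimal usage sketch. The "label_<value>" key convention and the 10-voxel threshold are illustrative assumptions consistent with how the function parses its inputs; the label names only affect the printed log message:

import numpy as np

seg = np.zeros((10, 10, 10), dtype=np.int8)
seg[0:4, 0:4, 0:4] = 1   # 64-voxel component, kept
seg[8, 8, 8] = 1         # 1-voxel component, removed

thresholds_dict = {"label_1": {"th": 10}}   # minimum volume for label 1 (illustrative)
label_mapping = {1: "ET"}                   # used only for logging

cleaned = postprocess_cc(seg, thresholds_dict, label_mapping)
assert np.count_nonzero(cleaned == 1) == 64   # the isolated voxel was dismissed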

def postprocess_batch(input_files: List[str], output_folder: str, thresholds_dict: Dict[str, Dict[str, Any]], label_mapping: Dict[int, str]) -> None:

Applies postprocessing to a batch of segmentation files.

Args:
    input_files (List[str]): List of input file paths.
    output_folder (str): Path to the output directory to save postprocessed files.
    thresholds_dict (Dict[str, Dict[str, Any]]): Dictionary containing threshold values for each label.
    label_mapping (Dict[int, str]): Mapping from label numbers to label names.
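A brief usage sketch; the paths and the threshold value are placeholders, and the thresholds dictionary follows the assumed "label_<value>" convention:

files = sorted(glob.glob("/path/to/predictions/*.nii.gz"))
postprocess_batch(
    input_files=files,
    output_folder="/path/to/output",
    thresholds_dict={"label_1": {"th": 50}},            # illustrative threshold
    label_mapping=LABEL_MAPPING_FACTORY["BraTS-MEN-RT"]  # {1: "GTV"}
)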

def get_thresholds_task(challenge_name: str, input_file: str) -> Dict[str, Dict[str, Any]]:

Retrieves threshold settings for a specific challenge from a JSON file.

Args:
    challenge_name (str): The name of the challenge (e.g., BraTS-PED, BraTS-MET).
    input_file (str): Path to the JSON file containing thresholds.

Returns:
    Dict[str, Dict[str, Any]]: Thresholds for the specified challenge.

Raises:
    ValueError: If the challenge name is not found in the JSON file.

def get_thresholds_cluster(thresholds: Dict[str, Dict[str, Any]], cluster_name: str) -> Dict[str, Any]:

Retrieves threshold settings for a specific cluster within a challenge.

Args:
    thresholds (Dict[str, Dict[str, Any]]): Thresholds for the challenge.
    cluster_name (str): The name of the cluster (e.g., "cluster_1").

Returns:
    Dict[str, Any]: Threshold settings for the specified cluster.

Raises:
    ValueError: If the cluster name is not found in the thresholds.
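Typical use chains the two lookups. The file path, cluster key, and contents below are placeholders matching the thresholds layout sketched in the module description above:

thresholds = get_thresholds_task("BraTS-PED", "/path/to/thresholds.json")
thresholds_cluster = get_thresholds_cluster(thresholds, "cluster_0")
# e.g. thresholds_cluster == {"label_1": {"th": 50}, "label_2": {"th": 100}}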

def get_files(input_folder_pred: str) -> List[str]:

Retrieves all prediction file paths from the input directory.

Args:
    input_folder_pred (str): Path to the input directory containing prediction files.

Returns:
    List[str]: Sorted list of prediction file paths.

def get_cluster_files(cluster_assignment: List[Dict[str, Any]], cluster_id: int, files_pred: List[str]) -> List[str]:

Retrieves prediction files that belong to a specific cluster.

Args:
    cluster_assignment (List[Dict[str, Any]]): List of cluster assignments with StudyID and cluster.
    cluster_id (int): The cluster identifier.
    files_pred (List[str]): List of all prediction file paths.

Returns:
    List[str]: List of prediction files belonging to the specified cluster.
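Matching is done on the file basename with the ".nii.gz" suffix stripped; a small sketch with made-up StudyIDs and paths:

cluster_assignment = [
    {"StudyID": "BraTS-PED-00001-000", "cluster": 0},
    {"StudyID": "BraTS-PED-00002-000", "cluster": 1},
]
files_pred = [
    "/preds/BraTS-PED-00001-000.nii.gz",
    "/preds/BraTS-PED-00002-000.nii.gz",
]
get_cluster_files(cluster_assignment, cluster_id=0, files_pred=files_pred)
# -> ["/preds/BraTS-PED-00001-000.nii.gz"]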

def read_cluster_assignment(clusters_json: str) -> Tuple[List[Dict[str, Any]], List[int]]:

Reads cluster assignments from a JSON file.

Args:
    clusters_json (str): Path to the JSON file containing cluster assignments.

Returns:
    Tuple[List[Dict[str, Any]], List[int]]:
        - List of cluster assignments with StudyID and cluster.
        - Sorted list of unique cluster identifiers.
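A brief sketch, assuming a clusters JSON shaped like the example in the module description above (the path is a placeholder):

cluster_assignment, cluster_ids = read_cluster_assignment("/path/to/clusters.json")
# cluster_assignment -> [{"StudyID": "BraTS-PED-00001-000", "cluster": 0}, ...]
# cluster_ids        -> [0, 1, ...]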

def read_cluster_assignment_df(clusterdf: pandas.core.frame.DataFrame) -> Tuple[List[Dict[str, Any]], List[int]]:

Reads cluster assignments from a pandas DataFrame.

Args:
    clusterdf (pd.DataFrame): DataFrame containing cluster assignments.

Returns:
    Tuple[List[Dict[str, Any]], List[int]]:
        - List of cluster assignments with StudyID and cluster.
        - Sorted list of unique cluster identifiers.
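The DataFrame only needs "StudyID" and "cluster" columns; other columns are dropped. A small sketch with made-up values:

import pandas as pd

clusterdf = pd.DataFrame({
    "StudyID": ["BraTS-PED-00001-000", "BraTS-PED-00002-000"],
    "cluster": [0, 1],
})
cluster_assignment, cluster_ids = read_cluster_assignment_df(clusterdf)
# cluster_ids -> [0, 1]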

def remove_small_component(challenge_name: str, thresholds_file: str, input_folder_pred: str, clustersdf: pandas.core.frame.DataFrame, output_folder_pp_cc: str) -> str:

Removes small connected components from segmentation predictions based on cluster-specific thresholds.

Args:
    challenge_name (str): The name of the challenge (e.g., BraTS-PED, BraTS-MET).
    thresholds_file (str): Path to the JSON file containing threshold settings.
    input_folder_pred (str): Path to the input directory containing prediction files.
    clustersdf (pd.DataFrame): DataFrame containing cluster assignments.
    output_folder_pp_cc (str): Path to the output directory to save postprocessed predictions.

Returns:
    str: Path to the output directory containing postprocessed predictions.
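An end-to-end sketch for programmatic (non-CLI) use; all paths and the cluster assignment are placeholders, and the thresholds file is assumed to follow the layout sketched in the module description above:

import pandas as pd

clustersdf = pd.DataFrame({
    "StudyID": ["BraTS-PED-00001-000", "BraTS-PED-00002-000"],
    "cluster": [0, 0],
})

output_dir = remove_small_component(
    challenge_name="BraTS-PED",
    thresholds_file="/path/to/thresholds.json",
    input_folder_pred="/path/to/predictions",
    clustersdf=clustersdf,
    output_folder_pp_cc="/path/to/output_pp_cc",
)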

def main() -> None:

Main function to execute connected component postprocessing.

Parses command-line arguments, loads necessary data, and performs postprocessing on prediction files based on cluster assignments and threshold settings.