inference.postproc.postprocess_lblredef

Post-processing Label Redefinition

Provides functions for post-processing segmentation predictions by redefining labels based on threshold criteria.

  1"""
  2### Post-processing Label Redefinition
  3
  4Provides functionalities related to postprocessing segmentation predictions by redefining labels based on threshold criteria.
  5"""
  6
  7
  8import pandas as pd
  9import glob
 10import os
 11import nibabel as nib
 12import numpy as np
 13import json
 14import argparse
 15from typing import Any, List, Tuple, Dict
 16
 17
# Per-challenge lookup from integer segmentation label value to a short region
# name. Outer keys are the challenge names accepted by --challenge_name.
# In this module the names are used for the log line printed by
# postprocess_lblredef, and the outer keys are used by label_redefinition to
# reject unsupported challenges.
LABEL_MAPPING_FACTORY: Dict[str, Dict[int, str]] = {
    "BraTS-PED": {
        1: "ET",
        2: "NET",
        3: "CC",
        4: "ED"
    },
    "BraTS-SSA": {
        1: "NCR",
        2: "ED",
        3: "ET"
    },
    "BraTS-MEN-RT": {
        1: "GTV"
    },
    "BraTS-MET": {
        1: "NETC",
        2: "SNFH",
        3: "ET"
    }
}
 39
 40
 41def parse_args() -> argparse.Namespace:
 42    """
 43    Parses command-line arguments for the BraTS2024 Postprocessing Label Redefinition script.
 44
 45    Returns:
 46        argparse.Namespace: Parsed command-line arguments.
 47    """
 48    parser = argparse.ArgumentParser(description='BraTS2024 Postprocessing.')
 49    parser.add_argument('--challenge_name', type=str, 
 50                        help='The name of the challenge (e.g., BraTS-PED, BraTS-MET)')
 51    parser.add_argument('--input_folder_pred', type=str,
 52                        help='The input folder containing the predictions')
 53    parser.add_argument('--output_folder_pp_cc', type=str,
 54                        help='The output folder to save the postprocessed predictions')
 55    parser.add_argument('--thresholds_file', type=str,
 56                        help='The JSON file containing the thresholds')
 57    parser.add_argument('--clusters_file', type=str,
 58                        help='The JSON file containing the clusters')
 59    return parser.parse_args()
 60
 61
 62def get_ratio_labels_wt(seg: np.ndarray, labels: List[int] = [1, 2, 3, 4]) -> float:
 63    """
 64    Calculates the ratio of selected label voxels to whole tumor (WT) voxels.
 65
 66    Args:
 67        seg (np.ndarray): The segmentation array.
 68        labels (List[int], optional): List of label integers to include in the ratio. Defaults to [1, 2, 3, 4].
 69
 70    Returns:
 71        float: The ratio of selected label voxels to WT voxels.
 72    """
 73    selected_voxels = sum(np.sum(seg == l) for l in labels)
 74    wt_voxels = np.sum(seg != 0)  # Everything but background
 75    if wt_voxels == 0:
 76        return 1.0
 77    return selected_voxels / wt_voxels
 78
 79
 80def postprocess_lblredef(
 81    img: np.ndarray,
 82    thresholds_dict: Dict[str, Dict[str, Any]],
 83    label_mapping: Dict[int, str]
 84) -> np.ndarray:
 85    """
 86    Redefines labels in the segmentation image based on threshold criteria.
 87
 88    Args:
 89        img (np.ndarray): The original segmentation array.
 90        thresholds_dict (Dict[str, Dict[str, Any]]): Dictionary containing threshold values for each label.
 91        label_mapping (Dict[int, str]): Mapping from label numbers to label names.
 92
 93    Returns:
 94        np.ndarray: The postprocessed segmentation array.
 95    """
 96    # Make a copy of the original prediction
 97    pred = img.copy()
 98    
 99    # Iterate over each label and apply redefinition based on thresholds
100    for label_name, th_dict in thresholds_dict.items():
101        label_number = int(label_name.split("_")[-1])
102        label_name_mapped = label_mapping.get(label_number, f"Label_{label_number}")
103        th = th_dict.get("th", 0)
104        redefine_to = th_dict.get("redefine_to", label_number)
105        
106        print(f"Label {label_name_mapped} - value {label_number} - Redefinition to {redefine_to} - Applying threshold {th}")
107        
108        # Get ratio with respect to the whole tumor
109        ratio = get_ratio_labels_wt(pred, labels=[label_number])
110        
111        # Conditioned redefinition
112        if ratio < th:
113            pred = np.where(pred == label_number, redefine_to, pred)
114    
115    return pred
116
117
def postprocess_batch(
    input_files: List[str],
    output_folder: str,
    thresholds_dict: Dict[str, Dict[str, Any]],
    label_mapping: Dict[int, str]
) -> None:
    """
    Applies postprocessing to a batch of segmentation files.

    Each input is written to `output_folder` under its own base name. Inputs
    whose output file already exists are skipped, so an interrupted run can be
    resumed.

    Args:
        input_files (List[str]): List of input file paths.
        output_folder (str): Path to the output directory to save postprocessed files.
        thresholds_dict (Dict[str, Dict[str, Any]]): Dictionary containing threshold values for each label.
        label_mapping (Dict[int, str]): Mapping from label numbers to label names.
    """
    for f in input_files:
        print(f"Processing file {f}")
        save_path = os.path.join(output_folder, os.path.basename(f))

        if os.path.exists(save_path):
            print(f"File {save_path} already exists. Skipping.")
            continue

        # Open the image once and reuse it for both the voxel data and the affine
        source_img = nib.load(f)
        # NOTE(review): get_fdata() yields floating-point data, so the saved
        # labels are floating point too - confirm downstream tools accept that.
        pred_orig = source_img.get_fdata()

        # Apply postprocessing
        pred = postprocess_lblredef(pred_orig, thresholds_dict, label_mapping)

        # Ensure the output directory exists; an empty directory part means the
        # current working directory, which os.makedirs("") would reject
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        # Save the postprocessed segmentation with the original affine
        nib.save(nib.Nifti1Image(pred, source_img.affine), save_path)
152
153
def get_thresholds_task(challenge_name: str, input_file: str) -> Dict[str, Dict[str, Any]]:
    """
    Loads the thresholds JSON file and returns the section for one challenge.

    Args:
        challenge_name (str): The name of the challenge (e.g., BraTS-PED, BraTS-MET).
        input_file (str): Path to the JSON file containing thresholds.

    Returns:
        Dict[str, Dict[str, Any]]: The value stored under `challenge_name`
            at the top level of the JSON document.

    Raises:
        ValueError: If the challenge name is not found in the JSON file.
    """
    with open(input_file, 'r') as handle:
        all_thresholds = json.load(handle)

    if challenge_name in all_thresholds:
        return all_thresholds[challenge_name]

    raise ValueError(f"Challenge {challenge_name} not found in the thresholds JSON file.")
175
176
def get_thresholds_cluster(thresholds: Dict[str, Dict[str, Any]], cluster_name: str) -> Dict[str, Any]:
    """
    Picks the threshold settings of one cluster out of a challenge's thresholds.

    Args:
        thresholds (Dict[str, Dict[str, Any]]): Thresholds for the challenge, keyed by cluster name.
        cluster_name (str): The name of the cluster (e.g., "cluster_1").

    Returns:
        Dict[str, Any]: The value stored under `cluster_name`.

    Raises:
        ValueError: If the cluster name is not found in the thresholds.
    """
    if cluster_name in thresholds:
        return thresholds[cluster_name]

    raise ValueError(f"Cluster {cluster_name} not found in the thresholds JSON file.")
194
195
def get_files(input_folder_pred: str) -> List[str]:
    """
    Lists the `*.nii.gz` files found directly inside the input directory.

    Args:
        input_folder_pred (str): Path to the input directory containing prediction files.

    Returns:
        List[str]: Sorted list of prediction file paths (empty if none match).
    """
    pattern = os.path.join(input_folder_pred, "*.nii.gz")
    files_pred = glob.glob(pattern)
    # Sort in place so the processing order is reproducible
    files_pred.sort()
    print(f"Found {len(files_pred)} files to be processed.")
    return files_pred
209
210
def get_cluster_files(
    cluster_assignment: List[Dict[str, Any]],
    cluster_id: int,
    files_pred: List[str]
) -> List[str]:
    """
    Retrieves prediction files that belong to a specific cluster.

    A file belongs to the cluster when its base name, with a trailing
    ".nii.gz" removed, equals the StudyID of an entry assigned to `cluster_id`.

    Args:
        cluster_assignment (List[Dict[str, Any]]): List of cluster assignments with StudyID and cluster.
        cluster_id (int): The cluster identifier.
        files_pred (List[str]): List of all prediction file paths.

    Returns:
        List[str]: List of prediction files belonging to the specified cluster,
            in the order they appear in `files_pred`.
    """
    suffix = ".nii.gz"

    # A set gives constant-time membership tests instead of scanning a list
    # once per file
    cluster_ids = {e["StudyID"] for e in cluster_assignment if e["cluster"] == cluster_id}

    cluster_files_pred = []
    for f in files_pred:
        study_id = os.path.basename(f)
        # Strip the extension only at the end of the name; a StudyID that
        # happens to contain ".nii.gz" elsewhere must stay intact
        if study_id.endswith(suffix):
            study_id = study_id[:-len(suffix)]
        if study_id in cluster_ids:
            cluster_files_pred.append(f)

    print(f"Cluster {cluster_id} contains {len(cluster_files_pred)} files.")
    return cluster_files_pred
234
235
def read_cluster_assignment(clusters_json: str) -> Tuple[List[Dict[str, Any]], List[int]]:
    """
    Reads cluster assignments from a JSON file.

    The file is expected to hold a list of objects; only their "StudyID" and
    "cluster" entries are kept.

    Args:
        clusters_json (str): Path to the JSON file containing cluster assignments.

    Returns:
        Tuple[List[Dict[str, Any]], List[int]]:
            - List of cluster assignments with StudyID and cluster.
            - Sorted list of unique cluster identifiers as native Python values.

    Raises:
        KeyError: If an entry has no "cluster" key.
    """
    # JSON text is UTF-8; do not depend on the platform's default encoding
    with open(clusters_json, "r", encoding="utf-8") as f:
        raw_entries = json.load(f)

    # Filter relevant keys
    wanted_keys = ("StudyID", "cluster")
    cluster_assignment = [
        {key: value for key, value in e.items() if key in wanted_keys}
        for e in raw_entries
    ]

    # np.unique returns the distinct values already sorted. tolist() converts
    # NumPy scalars to plain Python values, so the return value matches its
    # annotation and the log line does not show np.int64(...) reprs.
    cluster_ids = np.unique([e["cluster"] for e in cluster_assignment]).tolist()
    print(f"Found {len(cluster_ids)} clusters: {cluster_ids}.")
    return cluster_assignment, cluster_ids
260
261
def read_cluster_assignment_df(clusterdf: pd.DataFrame) -> Tuple[List[Dict[str, Any]], List[int]]:
    """
    Reads cluster assignments from a pandas DataFrame.

    Each row becomes one assignment; only the "StudyID" and "cluster" columns
    are kept.

    Args:
        clusterdf (pd.DataFrame): DataFrame containing cluster assignments.

    Returns:
        Tuple[List[Dict[str, Any]], List[int]]:
            - List of cluster assignments with StudyID and cluster.
            - Sorted list of unique cluster identifiers as native Python values.

    Raises:
        KeyError: If the DataFrame has rows but no "cluster" column.
    """
    wanted_keys = ("StudyID", "cluster")
    cluster_assignment = [
        {key: value for key, value in row.items() if key in wanted_keys}
        for row in clusterdf.to_dict(orient="records")
    ]

    # np.unique returns the distinct values already sorted. tolist() converts
    # NumPy scalars to plain Python values, so the return value matches its
    # annotation and the log line does not show np.int64(...) reprs.
    cluster_ids = np.unique([e["cluster"] for e in cluster_assignment]).tolist()
    print(f"Found {len(cluster_ids)} clusters: {cluster_ids}.")
    return cluster_assignment, cluster_ids
282
283
def label_redefinition(
    challenge_name: str,
    thresholds_file: str,
    input_folder_pred: str,
    clustersdf: pd.DataFrame,
    output_folder_pp_cc: str
) -> str:
    """
    Runs label redefinition postprocessing cluster by cluster over a prediction folder.

    Args:
        challenge_name (str): The name of the challenge (e.g., BraTS-PED, BraTS-MET).
        thresholds_file (str): Path to the JSON file containing threshold settings.
        input_folder_pred (str): Path to the input directory containing prediction files.
        clustersdf (pd.DataFrame): DataFrame containing cluster assignments.
        output_folder_pp_cc (str): Path to the output directory to save postprocessed predictions.

    Returns:
        str: Path to the output directory containing postprocessed predictions
            (the `output_folder_pp_cc` argument, unchanged).

    Raises:
        ValueError: If the challenge has no label mapping, or the challenge or
            one of its clusters is missing from the thresholds file.
    """
    # Reject challenges without a label mapping before touching any file
    if challenge_name not in LABEL_MAPPING_FACTORY:
        raise ValueError(f"Unsupported challenge name: {challenge_name}")
    label_mapping = LABEL_MAPPING_FACTORY[challenge_name]

    # Gather the inputs: thresholds, prediction files, cluster assignments
    thresholds = get_thresholds_task(challenge_name, thresholds_file)
    files_pred = get_files(input_folder_pred)
    cluster_assignment, cluster_array = read_cluster_assignment_df(clustersdf)

    # Each cluster is processed with its own threshold settings
    for cluster in cluster_array:
        postprocess_batch(
            input_files=get_cluster_files(cluster_assignment, cluster, files_pred),
            output_folder=output_folder_pp_cc,
            thresholds_dict=get_thresholds_cluster(thresholds, f"cluster_{cluster}"),
            label_mapping=label_mapping
        )

    return output_folder_pp_cc
331
332
def main() -> None:
    """
    Command-line entry point for label redefinition postprocessing.

    Reads the command-line options, loads the thresholds of the chosen
    challenge and the cluster assignments from their JSON files, then
    postprocesses the prediction files of every cluster with that cluster's
    threshold settings.
    """
    args = parse_args()

    thresholds = get_thresholds_task(args.challenge_name, args.thresholds_file)
    files_pred = get_files(args.input_folder_pred)
    cluster_assignment, cluster_array = read_cluster_assignment(args.clusters_file)

    # A challenge without a label mapping falls back to an empty one; labels
    # are then logged under generic "Label_<n>" names
    label_mapping = LABEL_MAPPING_FACTORY.get(args.challenge_name, {})

    for cluster in cluster_array:
        postprocess_batch(
            input_files=get_cluster_files(cluster_assignment, cluster, files_pred),
            output_folder=args.output_folder_pp_cc,
            thresholds_dict=get_thresholds_cluster(thresholds, f"cluster_{cluster}"),
            label_mapping=label_mapping
        )
362
363
# Run the command-line entry point only when the file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()

    # Example command to run the script:
    # NOTE(review): the module path shown for this file is
    # inference.postproc.postprocess_lblredef, so the script name used below
    # may be out of date - confirm.
    # python label_redefinition.py --challenge_name BraTS-PED --input_folder_pred /path/to/predictions \
    # --output_folder_pp_cc /path/to/output --thresholds_file /path/to/thresholds.json \
    # --clusters_file /path/to/clusters.json
# Per-challenge lookup from integer segmentation label value to a short region
# name. Outer keys are the challenge names accepted by --challenge_name.
LABEL_MAPPING_FACTORY: Dict[str, Dict[int, str]] = {'BraTS-PED': {1: 'ET', 2: 'NET', 3: 'CC', 4: 'ED'}, 'BraTS-SSA': {1: 'NCR', 2: 'ED', 3: 'ET'}, 'BraTS-MEN-RT': {1: 'GTV'}, 'BraTS-MET': {1: 'NETC', 2: 'SNFH', 3: 'ET'}}
def parse_args() -> argparse.Namespace:
def parse_args() -> argparse.Namespace:
    """
    Parses command-line arguments for the BraTS2024 Postprocessing Label Redefinition script.

    Every option is mandatory because `main` reads all five of them. Marking
    them as required makes argparse stop with a usage message when one is
    missing, instead of handing `None` to later file and path operations.

    Returns:
        argparse.Namespace: Parsed command-line arguments with the attributes
            `challenge_name`, `input_folder_pred`, `output_folder_pp_cc`,
            `thresholds_file` and `clusters_file` (all strings).

    Raises:
        SystemExit: Raised by argparse when an option is missing or unknown.
    """
    parser = argparse.ArgumentParser(description='BraTS2024 Postprocessing.')
    options = (
        ('--challenge_name', 'The name of the challenge (e.g., BraTS-PED, BraTS-MET)'),
        ('--input_folder_pred', 'The input folder containing the predictions'),
        ('--output_folder_pp_cc', 'The output folder to save the postprocessed predictions'),
        ('--thresholds_file', 'The JSON file containing the thresholds'),
        ('--clusters_file', 'The JSON file containing the clusters'),
    )
    for flag, help_text in options:
        parser.add_argument(flag, type=str, required=True, help=help_text)
    return parser.parse_args()

Parses command-line arguments for the BraTS2024 Postprocessing Label Redefinition script.

Returns: argparse.Namespace: Parsed command-line arguments.

def get_ratio_labels_wt(seg: numpy.ndarray, labels: List[int] = [1, 2, 3, 4]) -> float:
def get_ratio_labels_wt(seg: np.ndarray, labels: List[int] = [1, 2, 3, 4]) -> float:
    """
    Computes the share of whole tumor (WT) voxels that carry one of the given labels.

    Args:
        seg (np.ndarray): The segmentation array; 0 is treated as background.
        labels (List[int], optional): Label values to count. Defaults to [1, 2, 3, 4].
            The default list is only read, never modified.

    Returns:
        float: Number of voxels matching `labels` divided by the number of
            non-zero voxels, or 1.0 when the segmentation has no non-zero voxel.
    """
    # Whole tumor = every voxel that is not background
    n_whole_tumor = np.sum(seg != 0)
    if n_whole_tumor == 0:
        # Nothing segmented: avoid dividing by zero
        return 1.0

    # Each label is counted separately and the counts are added up
    n_selected = 0
    for label in labels:
        n_selected = n_selected + np.sum(seg == label)
    return n_selected / n_whole_tumor

Calculates the ratio of selected label voxels to whole tumor (WT) voxels.

Args: seg (np.ndarray): The segmentation array. labels (List[int], optional): List of label integers to include in the ratio. Defaults to [1, 2, 3, 4].

Returns: float: The ratio of selected label voxels to WT voxels.

def postprocess_lblredef( img: numpy.ndarray, thresholds_dict: Dict[str, Dict[str, Any]], label_mapping: Dict[int, str]) -> numpy.ndarray:
 81def postprocess_lblredef(
 82    img: np.ndarray,
 83    thresholds_dict: Dict[str, Dict[str, Any]],
 84    label_mapping: Dict[int, str]
 85) -> np.ndarray:
 86    """
 87    Redefines labels in the segmentation image based on threshold criteria.
 88
 89    Args:
 90        img (np.ndarray): The original segmentation array.
 91        thresholds_dict (Dict[str, Dict[str, Any]]): Dictionary containing threshold values for each label.
 92        label_mapping (Dict[int, str]): Mapping from label numbers to label names.
 93
 94    Returns:
 95        np.ndarray: The postprocessed segmentation array.
 96    """
 97    # Make a copy of the original prediction
 98    pred = img.copy()
 99    
100    # Iterate over each label and apply redefinition based on thresholds
101    for label_name, th_dict in thresholds_dict.items():
102        label_number = int(label_name.split("_")[-1])
103        label_name_mapped = label_mapping.get(label_number, f"Label_{label_number}")
104        th = th_dict.get("th", 0)
105        redefine_to = th_dict.get("redefine_to", label_number)
106        
107        print(f"Label {label_name_mapped} - value {label_number} - Redefinition to {redefine_to} - Applying threshold {th}")
108        
109        # Get ratio with respect to the whole tumor
110        ratio = get_ratio_labels_wt(pred, labels=[label_number])
111        
112        # Conditioned redefinition
113        if ratio < th:
114            pred = np.where(pred == label_number, redefine_to, pred)
115    
116    return pred

Redefines labels in the segmentation image based on threshold criteria.

Args: img (np.ndarray): The original segmentation array. thresholds_dict (Dict[str, Dict[str, Any]]): Dictionary containing threshold values for each label. label_mapping (Dict[int, str]): Mapping from label numbers to label names.

Returns: np.ndarray: The postprocessed segmentation array.

def postprocess_batch( input_files: List[str], output_folder: str, thresholds_dict: Dict[str, Dict[str, Any]], label_mapping: Dict[int, str]) -> None:
def postprocess_batch(
    input_files: List[str],
    output_folder: str,
    thresholds_dict: Dict[str, Dict[str, Any]],
    label_mapping: Dict[int, str]
) -> None:
    """
    Applies postprocessing to a batch of segmentation files.

    Each input is written to `output_folder` under its own base name. Inputs
    whose output file already exists are skipped, so an interrupted run can be
    resumed.

    Args:
        input_files (List[str]): List of input file paths.
        output_folder (str): Path to the output directory to save postprocessed files.
        thresholds_dict (Dict[str, Dict[str, Any]]): Dictionary containing threshold values for each label.
        label_mapping (Dict[int, str]): Mapping from label numbers to label names.
    """
    for f in input_files:
        print(f"Processing file {f}")
        save_path = os.path.join(output_folder, os.path.basename(f))

        if os.path.exists(save_path):
            print(f"File {save_path} already exists. Skipping.")
            continue

        # Open the image once and reuse it for both the voxel data and the affine
        source_img = nib.load(f)
        # NOTE(review): get_fdata() yields floating-point data, so the saved
        # labels are floating point too - confirm downstream tools accept that.
        pred_orig = source_img.get_fdata()

        # Apply postprocessing
        pred = postprocess_lblredef(pred_orig, thresholds_dict, label_mapping)

        # Ensure the output directory exists; an empty directory part means the
        # current working directory, which os.makedirs("") would reject
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        # Save the postprocessed segmentation with the original affine
        nib.save(nib.Nifti1Image(pred, source_img.affine), save_path)

Applies postprocessing to a batch of segmentation files.

Args: input_files (List[str]): List of input file paths. output_folder (str): Path to the output directory to save postprocessed files. thresholds_dict (Dict[str, Dict[str, Any]]): Dictionary containing threshold values for each label. label_mapping (Dict[int, str]): Mapping from label numbers to label names.

def get_thresholds_task(challenge_name: str, input_file: str) -> Dict[str, Dict[str, Any]]:
def get_thresholds_task(challenge_name: str, input_file: str) -> Dict[str, Dict[str, Any]]:
    """
    Loads the thresholds JSON file and returns the section for one challenge.

    Args:
        challenge_name (str): The name of the challenge (e.g., BraTS-PED, BraTS-MET).
        input_file (str): Path to the JSON file containing thresholds.

    Returns:
        Dict[str, Dict[str, Any]]: The value stored under `challenge_name`
            at the top level of the JSON document.

    Raises:
        ValueError: If the challenge name is not found in the JSON file.
    """
    with open(input_file, 'r') as handle:
        all_thresholds = json.load(handle)

    if challenge_name in all_thresholds:
        return all_thresholds[challenge_name]

    raise ValueError(f"Challenge {challenge_name} not found in the thresholds JSON file.")

Retrieves threshold settings for a specific challenge from a JSON file.

Args: challenge_name (str): The name of the challenge (e.g., BraTS-PED, BraTS-MET). input_file (str): Path to the JSON file containing thresholds.

Returns: Dict[str, Dict[str, Any]]: Thresholds for the specified challenge.

Raises: ValueError: If the challenge name is not found in the JSON file.

def get_thresholds_cluster( thresholds: Dict[str, Dict[str, Any]], cluster_name: str) -> Dict[str, Any]:
def get_thresholds_cluster(thresholds: Dict[str, Dict[str, Any]], cluster_name: str) -> Dict[str, Any]:
    """
    Picks the threshold settings of one cluster out of a challenge's thresholds.

    Args:
        thresholds (Dict[str, Dict[str, Any]]): Thresholds for the challenge, keyed by cluster name.
        cluster_name (str): The name of the cluster (e.g., "cluster_1").

    Returns:
        Dict[str, Any]: The value stored under `cluster_name`.

    Raises:
        ValueError: If the cluster name is not found in the thresholds.
    """
    if cluster_name in thresholds:
        return thresholds[cluster_name]

    raise ValueError(f"Cluster {cluster_name} not found in the thresholds JSON file.")

Retrieves threshold settings for a specific cluster within a challenge.

Args: thresholds (Dict[str, Dict[str, Any]]): Thresholds for the challenge. cluster_name (str): The name of the cluster (e.g., "cluster_1").

Returns: Dict[str, Any]: Threshold settings for the specified cluster.

Raises: ValueError: If the cluster name is not found in the thresholds.

def get_files(input_folder_pred: str) -> List[str]:
def get_files(input_folder_pred: str) -> List[str]:
    """
    Lists the `*.nii.gz` files found directly inside the input directory.

    Args:
        input_folder_pred (str): Path to the input directory containing prediction files.

    Returns:
        List[str]: Sorted list of prediction file paths (empty if none match).
    """
    pattern = os.path.join(input_folder_pred, "*.nii.gz")
    files_pred = glob.glob(pattern)
    # Sort in place so the processing order is reproducible
    files_pred.sort()
    print(f"Found {len(files_pred)} files to be processed.")
    return files_pred

Retrieves all prediction file paths from the input directory.

Args: input_folder_pred (str): Path to the input directory containing prediction files.

Returns: List[str]: Sorted list of prediction file paths.

def get_cluster_files( cluster_assignment: List[Dict[str, Any]], cluster_id: int, files_pred: List[str]) -> List[str]:
def get_cluster_files(
    cluster_assignment: List[Dict[str, Any]],
    cluster_id: int,
    files_pred: List[str]
) -> List[str]:
    """
    Retrieves prediction files that belong to a specific cluster.

    A file belongs to the cluster when its base name, with a trailing
    ".nii.gz" removed, equals the StudyID of an entry assigned to `cluster_id`.

    Args:
        cluster_assignment (List[Dict[str, Any]]): List of cluster assignments with StudyID and cluster.
        cluster_id (int): The cluster identifier.
        files_pred (List[str]): List of all prediction file paths.

    Returns:
        List[str]: List of prediction files belonging to the specified cluster,
            in the order they appear in `files_pred`.
    """
    suffix = ".nii.gz"

    # A set gives constant-time membership tests instead of scanning a list
    # once per file
    cluster_ids = {e["StudyID"] for e in cluster_assignment if e["cluster"] == cluster_id}

    cluster_files_pred = []
    for f in files_pred:
        study_id = os.path.basename(f)
        # Strip the extension only at the end of the name; a StudyID that
        # happens to contain ".nii.gz" elsewhere must stay intact
        if study_id.endswith(suffix):
            study_id = study_id[:-len(suffix)]
        if study_id in cluster_ids:
            cluster_files_pred.append(f)

    print(f"Cluster {cluster_id} contains {len(cluster_files_pred)} files.")
    return cluster_files_pred

Retrieves prediction files that belong to a specific cluster.

Args: cluster_assignment (List[Dict[str, Any]]): List of cluster assignments with StudyID and cluster. cluster_id (int): The cluster identifier. files_pred (List[str]): List of all prediction file paths.

Returns: List[str]: List of prediction files belonging to the specified cluster.

def read_cluster_assignment(clusters_json: str) -> Tuple[List[Dict[str, Any]], List[int]]:
def read_cluster_assignment(clusters_json: str) -> Tuple[List[Dict[str, Any]], List[int]]:
    """
    Reads cluster assignments from a JSON file.

    The file is expected to hold a list of objects; only their "StudyID" and
    "cluster" entries are kept.

    Args:
        clusters_json (str): Path to the JSON file containing cluster assignments.

    Returns:
        Tuple[List[Dict[str, Any]], List[int]]:
            - List of cluster assignments with StudyID and cluster.
            - Sorted list of unique cluster identifiers as native Python values.

    Raises:
        KeyError: If an entry has no "cluster" key.
    """
    # JSON text is UTF-8; do not depend on the platform's default encoding
    with open(clusters_json, "r", encoding="utf-8") as f:
        raw_entries = json.load(f)

    # Filter relevant keys
    wanted_keys = ("StudyID", "cluster")
    cluster_assignment = [
        {key: value for key, value in e.items() if key in wanted_keys}
        for e in raw_entries
    ]

    # np.unique returns the distinct values already sorted. tolist() converts
    # NumPy scalars to plain Python values, so the return value matches its
    # annotation and the log line does not show np.int64(...) reprs.
    cluster_ids = np.unique([e["cluster"] for e in cluster_assignment]).tolist()
    print(f"Found {len(cluster_ids)} clusters: {cluster_ids}.")
    return cluster_assignment, cluster_ids

Reads cluster assignments from a JSON file.

Args: clusters_json (str): Path to the JSON file containing cluster assignments.

Returns: Tuple[List[Dict[str, Any]], List[int]]: - List of cluster assignments with StudyID and cluster. - Sorted list of unique cluster identifiers.

def read_cluster_assignment_df( clusterdf: pandas.core.frame.DataFrame) -> Tuple[List[Dict[str, Any]], List[int]]:
def read_cluster_assignment_df(clusterdf: pd.DataFrame) -> Tuple[List[Dict[str, Any]], List[int]]:
    """
    Reads cluster assignments from a pandas DataFrame.

    Each row becomes one assignment; only the "StudyID" and "cluster" columns
    are kept.

    Args:
        clusterdf (pd.DataFrame): DataFrame containing cluster assignments.

    Returns:
        Tuple[List[Dict[str, Any]], List[int]]:
            - List of cluster assignments with StudyID and cluster.
            - Sorted list of unique cluster identifiers as native Python values.

    Raises:
        KeyError: If the DataFrame has rows but no "cluster" column.
    """
    wanted_keys = ("StudyID", "cluster")
    cluster_assignment = [
        {key: value for key, value in row.items() if key in wanted_keys}
        for row in clusterdf.to_dict(orient="records")
    ]

    # np.unique returns the distinct values already sorted. tolist() converts
    # NumPy scalars to plain Python values, so the return value matches its
    # annotation and the log line does not show np.int64(...) reprs.
    cluster_ids = np.unique([e["cluster"] for e in cluster_assignment]).tolist()
    print(f"Found {len(cluster_ids)} clusters: {cluster_ids}.")
    return cluster_assignment, cluster_ids

Reads cluster assignments from a pandas DataFrame.

Args: clusterdf (pd.DataFrame): DataFrame containing cluster assignments.

Returns: Tuple[List[Dict[str, Any]], List[int]]: - List of cluster assignments with StudyID and cluster. - Sorted list of unique cluster identifiers.

def label_redefinition( challenge_name: str, thresholds_file: str, input_folder_pred: str, clustersdf: pandas.core.frame.DataFrame, output_folder_pp_cc: str) -> str:
def label_redefinition(
    challenge_name: str,
    thresholds_file: str,
    input_folder_pred: str,
    clustersdf: pd.DataFrame,
    output_folder_pp_cc: str
) -> str:
    """
    Runs label redefinition postprocessing cluster by cluster over a prediction folder.

    Args:
        challenge_name (str): The name of the challenge (e.g., BraTS-PED, BraTS-MET).
        thresholds_file (str): Path to the JSON file containing threshold settings.
        input_folder_pred (str): Path to the input directory containing prediction files.
        clustersdf (pd.DataFrame): DataFrame containing cluster assignments.
        output_folder_pp_cc (str): Path to the output directory to save postprocessed predictions.

    Returns:
        str: Path to the output directory containing postprocessed predictions
            (the `output_folder_pp_cc` argument, unchanged).

    Raises:
        ValueError: If the challenge has no label mapping, or the challenge or
            one of its clusters is missing from the thresholds file.
    """
    # Reject challenges without a label mapping before touching any file
    if challenge_name not in LABEL_MAPPING_FACTORY:
        raise ValueError(f"Unsupported challenge name: {challenge_name}")
    label_mapping = LABEL_MAPPING_FACTORY[challenge_name]

    # Gather the inputs: thresholds, prediction files, cluster assignments
    thresholds = get_thresholds_task(challenge_name, thresholds_file)
    files_pred = get_files(input_folder_pred)
    cluster_assignment, cluster_array = read_cluster_assignment_df(clustersdf)

    # Each cluster is processed with its own threshold settings
    for cluster in cluster_array:
        postprocess_batch(
            input_files=get_cluster_files(cluster_assignment, cluster, files_pred),
            output_folder=output_folder_pp_cc,
            thresholds_dict=get_thresholds_cluster(thresholds, f"cluster_{cluster}"),
            label_mapping=label_mapping
        )

    return output_folder_pp_cc

Performs label redefinition postprocessing for all clusters in the dataset.

Args: challenge_name (str): The name of the challenge (e.g., BraTS-PED, BraTS-MET). thresholds_file (str): Path to the JSON file containing threshold settings. input_folder_pred (str): Path to the input directory containing prediction files. clustersdf (pd.DataFrame): DataFrame containing cluster assignments. output_folder_pp_cc (str): Path to the output directory to save postprocessed predictions.

Returns: str: Path to the output directory containing postprocessed predictions.

def main() -> None:
def main() -> None:
    """
    Command-line entry point for label redefinition postprocessing.

    Reads the command-line options, loads the thresholds of the chosen
    challenge and the cluster assignments from their JSON files, then
    postprocesses the prediction files of every cluster with that cluster's
    threshold settings.
    """
    args = parse_args()

    thresholds = get_thresholds_task(args.challenge_name, args.thresholds_file)
    files_pred = get_files(args.input_folder_pred)
    cluster_assignment, cluster_array = read_cluster_assignment(args.clusters_file)

    # A challenge without a label mapping falls back to an empty one; labels
    # are then logged under generic "Label_<n>" names
    label_mapping = LABEL_MAPPING_FACTORY.get(args.challenge_name, {})

    for cluster in cluster_array:
        postprocess_batch(
            input_files=get_cluster_files(cluster_assignment, cluster, files_pred),
            output_folder=args.output_folder_pp_cc,
            thresholds_dict=get_thresholds_cluster(thresholds, f"cluster_{cluster}"),
            label_mapping=label_mapping
        )

Main function to execute label redefinition postprocessing.

Parses command-line arguments, loads necessary data, and performs postprocessing on prediction files based on cluster assignments and threshold settings.