inference.swinunetr.inference

SwinUNETR Inference

Provides functionalities related to SwinUNETR segmentation inference.

  1"""
  2### SwinUNETR Inference
  3
  4Provides functionalities related to SwinUNETR segmentation inference.
  5"""
  6
  7import logging
  8from pathlib import Path
  9import time
 10import torch
 11import numpy as np
 12from typing import Any, List, Tuple, Union, Optional, Dict
 13from monai.inferers import sliding_window_inference
 14from monai.networks.nets import SwinUNETR
 15from monai.transforms import Activations
 16from monai.data import decollate_batch
 17import nibabel as nib
 18from functools import partial
 19import argparse
 20import json
 21
 22from monai.transforms import (
 23    LoadImaged,
 24    Compose,
 25    CropForegroundd,
 26    CopyItemsd,
 27    SpatialPadd,
 28    Spacingd,
 29    EnsureChannelFirstd,
 30    EnsureTyped,
 31    MapTransform,
 32    OneOf,
 33    NormalizeIntensityd,
 34    RandSpatialCropd,
 35    RandCropByPosNegLabeld,
 36    RandSpatialCropSamplesd,
 37    RandCoarseDropoutd,
 38    RandCoarseShuffled,
 39    RandScaleIntensityd,
 40    RandShiftIntensityd,
 41    RandFlipd,
 42    RandAdjustContrastd,
 43    ToTensord,
 44)
 45
 46from monai.data import DataLoader, CacheDataset, decollate_batch
 47# from datasets_utils import get_loader
 48
 49parser = argparse.ArgumentParser(description='Swin UNETR segmentation inference')
 50parser.add_argument('--datadir', default='/dataset/dataset0/', type=str, help='dataset directory')
 51parser.add_argument('--exp_path', default='test1', type=str, help='experiment output path')
 52parser.add_argument('--jsonlist', default='dataset_0.json', type=str, help='dataset json file')
 53parser.add_argument('--fold', default=1, type=int, help='data fold')
 54parser.add_argument('--pretrained_model_name', default='model.pt', type=str, help='pretrained model name')
 55parser.add_argument('--feature_size', default=48, type=int, help='feature size')
 56parser.add_argument('--infer_overlap', default=0.6, type=float, help='sliding window inference overlap')
 57parser.add_argument('--in_channels', default=4, type=int, help='number of input channels')
 58parser.add_argument('--out_channels', default=3, type=int, help='number of output channels')
 59parser.add_argument('--a_min', default=-175.0, type=float, help='a_min in ScaleIntensityRanged')
 60parser.add_argument('--a_max', default=250.0, type=float, help='a_max in ScaleIntensityRanged')
 61parser.add_argument('--b_min', default=0.0, type=float, help='b_min in ScaleIntensityRanged')
 62parser.add_argument('--b_max', default=1.0, type=float, help='b_max in ScaleIntensityRanged')
 63parser.add_argument('--space_x', default=1.5, type=float, help='spacing in x direction')
 64parser.add_argument('--space_y', default=1.5, type=float, help='spacing in y direction')
 65parser.add_argument('--space_z', default=2.0, type=float, help='spacing in z direction')
 66parser.add_argument('--roi_x', default=128, type=int, help='roi size in x direction')
 67parser.add_argument('--roi_y', default=128, type=int, help='roi size in y direction')
 68parser.add_argument('--roi_z', default=128, type=int, help='roi size in z direction')
 69parser.add_argument('--posrate', default=1.0, type=float, help='positive label rate')
 70parser.add_argument('--negrate', default=1.0, type=float, help='negative label rate')
 71parser.add_argument('--nsamples', default=1, type=int, help='number of cropped samples')
 72parser.add_argument('--dropout_rate', default=0.0, type=float, help='dropout rate')
 73parser.add_argument('--distributed', action='store_true', help='start distributed training')
 74parser.add_argument('--cacherate', default=1.0, type=float, help='cache data rate')
 75parser.add_argument('--workers', default=8, type=int, help='number of workers')
 76parser.add_argument('--batch_size', default=1, type=int, help='batch size')
 77parser.add_argument('--RandFlipd_prob', default=0.2, type=float, help='RandFlipd augmentation probability')
 78parser.add_argument('--RandRotate90d_prob', default=0.2, type=float, help='RandRotate90d augmentation probability')
 79parser.add_argument('--RandScaleIntensityd_prob', default=0.1, type=float, help='RandScaleIntensityd augmentation probability')
 80parser.add_argument('--RandShiftIntensityd_prob', default=0.1, type=float, help='RandShiftIntensityd augmentation probability')
 81parser.add_argument('--spatial_dims', default=3, type=int, help='spatial dimension of input data')
 82parser.add_argument('--use_checkpoint', action='store_true', help='use gradient checkpointing to save memory')
 83parser.add_argument('--pretrained_dir', default='./pretrained_models/fold1_f48_ep300_4gpu_dice0_9059/', type=str,
 84                    help='pretrained checkpoint directory')
 85parser.add_argument('--pred_label', action='store_true', help='predict labels or regions')
 86
 87
 88def load_json(path: Path) -> Any:
 89    """
 90    Loads a JSON file from the specified path.
 91
 92    Args:
 93        path (Path): A Path representing the file path.
 94
 95    Returns:
 96        Any: The data loaded from the JSON file.
 97    """
 98    with open(path, 'r') as f:
 99        return json.load(f)
100
101
def save_json(path: Union[str, Path], data: Any) -> None:
    """
    Save data as a JSON file (4-space indentation) at the specified path.

    The whole document is serialized before the file is opened. A value that
    cannot be encoded therefore raises without truncating or partially
    overwriting an existing file.

    Args:
        path (Union[str, Path]): Destination file path.
        data (Any): JSON-serializable data to write.

    Raises:
        TypeError: If ``data`` contains an object ``json`` cannot serialize.
    """
    text = json.dumps(data, indent=4)
    # Explicit encoding keeps the output independent of the platform locale.
    with open(path, 'w', encoding='utf-8') as f:
        f.write(text)
112
113
class ConvertToMultiChannelBasedOnBratsPEDClassesd(MapTransform):
    """
    Convert labels to four nested binary region channels (BraTS2023 PED).

    For every key in ``self.keys``, the integer label map is replaced by a
    float tensor of shape ``(4, *label.shape)``. Its channels are:

    * channel 0 - WT (whole tumor): labels 1, 2, 3 and 4
    * channel 1 - TC (tumor core): labels 1, 2 and 3
    * channel 2 - NET: labels 1 and 2
    * channel 3 - ET (enhancing tumor): label 1

    Each channel is a subset of the previous one. Keys not listed in
    ``self.keys`` are passed through unchanged.

    NOTE(review): the channel names follow this code's mapping (label 1 alone
    is ET). Confirm them against the dataset's label definition.
    """

    # Raw label values merged into each output channel, in channel order.
    _REGION_LABELS = (
        (1, 2, 3, 4),  # WT
        (1, 2, 3),     # TC
        (1, 2),        # NET
        (1,),          # ET
    )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Return a shallow copy of ``data`` with each selected label map
        converted to stacked region channels.

        The label values are compared and combined with ``torch.logical_or``,
        so each selected entry must be a torch tensor.
        """
        d = dict(data)
        for key in self.keys:
            label = d[key]
            channels = []
            for values in self._REGION_LABELS:
                mask = label == values[0]
                for value in values[1:]:
                    mask = torch.logical_or(mask, label == value)
                channels.append(mask)
            d[key] = torch.stack(channels, dim=0).float()
        return d
145
146
def get_loader(args: argparse.Namespace) -> DataLoader:
    """
    Build a DataLoader over the single case stored in ``args.datadir``.

    The case directory must contain four modality files named after the
    directory itself. For ``.../CASE`` these are ``CASE-t1n.nii.gz``,
    ``CASE-t1c.nii.gz``, ``CASE-t2w.nii.gz`` and ``CASE-t2f.nii.gz``. They are
    stacked, in that order, as the input channels of one sample. No JSON data
    list is read.

    Args:
        args (argparse.Namespace): Parsed command-line arguments. Reads
            ``datadir``, ``cacherate`` and ``workers``.

    Returns:
        DataLoader: Loader yielding the case as a batch of size 1.

    Raises:
        FileNotFoundError: If any of the expected modality files is missing.
    """
    data_root = Path(args.datadir)
    # Modality suffixes, in the order they become input channels.
    # NOTE(review): presumably this must match the channel order used during
    # training; confirm against the training data list.
    channel_order = ['-t1n.nii.gz', '-t1c.nii.gz', '-t2w.nii.gz', '-t2f.nii.gz']
    image_paths = [data_root / f"{data_root.name}{suffix}" for suffix in channel_order]

    # Fail early with a readable message rather than deep inside image loading.
    missing = [str(p) for p in image_paths if not p.is_file()]
    if missing:
        raise FileNotFoundError(f"Missing input image(s): {', '.join(missing)}")

    # One-item data list. 'label' is an empty placeholder that none of the
    # transforms below reads.
    val_data = [{'image': [str(p) for p in image_paths], 'label': ""}]

    val_transform = Compose(
        [
            # image_only=False keeps the loader's metadata alongside the image.
            LoadImaged(keys=["image"], image_only=False),
            EnsureChannelFirstd(keys=["image"]),
            # Normalize each channel separately, using its non-zero voxels.
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            ToTensord(keys=["image"]),
        ]
    )
    val_ds = CacheDataset(
        data=val_data,
        transform=val_transform,
        cache_rate=args.cacherate,
        num_workers=args.workers
    )
    # Batch size is fixed at 1: the consumer indexes element [0] of each batch.
    val_loader = DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=False
    )

    return val_loader
195
196
def main() -> None:
    """
    Run SwinUNETR segmentation inference on the case given by ``--datadir``.

    Steps:
        1. Parse the command-line arguments and create ``--exp_path``.
           The log is written to ``<exp_path>/infer.log``.
        2. Build the data loader (``get_loader``) and a ``SwinUNETR`` model.
           Load its weights from the ``'model'`` entry of
           ``<pretrained_dir>/<pretrained_model_name>``.
        3. For each batch, run sliding-window inference and apply the
           activation: softmax with ``--pred_label``, sigmoid otherwise.
           Save the probabilities to ``<name>.npz`` and an integer label map
           to ``<name>.nii.gz`` in the output directory.

    Raises:
        ValueError: In region mode (no ``--pred_label``), if the model outputs
            fewer than the 4 region channels the label decoding needs.
    """
    time0 = time.time()
    args = parser.parse_args()
    output_directory = Path(args.exp_path)
    output_directory.mkdir(parents=True, exist_ok=True)

    # Configure logging
    logging.basicConfig(
        filename=output_directory / 'infer.log',
        filemode='w',
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    val_loader = get_loader(args)
    pretrained_pth = Path(args.pretrained_dir) / args.pretrained_model_name
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Inference-only model: all dropout rates are zero.
    model = SwinUNETR(
        img_size=(args.roi_x, args.roi_y, args.roi_z),
        in_channels=args.in_channels,
        out_channels=args.out_channels,
        feature_size=args.feature_size,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        dropout_path_rate=0.0,
        use_checkpoint=args.use_checkpoint
    )

    # NOTE: torch.load unpickles the checkpoint, which can execute arbitrary
    # code. Only load checkpoints from a trusted source.
    model_dict = torch.load(pretrained_pth, map_location=device)['model']
    model.load_state_dict(model_dict)
    model.eval()
    model.to(device)

    model_inferer_test = partial(
        sliding_window_inference,
        roi_size=[args.roi_x, args.roi_y, args.roi_z],
        sw_batch_size=1,
        predictor=model,
        overlap=args.infer_overlap,
    )

    # Softmax for mutually exclusive labels, sigmoid for overlapping regions.
    post_trans = Activations(sigmoid=not args.pred_label, softmax=args.pred_label)

    # Region decoding: thresholded channel k (k = 1..3) relabels voxels that
    # currently hold `parent` to `label`. Channel 0 sets label 4 directly.
    # Assumes channels are ordered WT, TC, NET, ET (see the label converter).
    nested_labels = ((4, 3), (3, 2), (2, 1))

    with torch.no_grad():
        for batch in val_loader:
            # Move to the selected device rather than calling .cuda(), so
            # the CPU fallback chosen above also works.
            image = batch["image"].to(device)
            affine = batch['image_meta_dict']['original_affine'][0].numpy()
            filepath = Path(batch['image_meta_dict']['filename_or_obj'][0])
            img_name = filepath.name.split('.nii.gz')[0]

            output_pred = model_inferer_test(image)
            logging.info(f"Inference on case {img_name}")
            logging.info(f"Label-wise: {args.pred_label}")
            logging.info(f"Image shape: {image.shape}")
            logging.info(f"Prediction shape: {output_pred.shape}")

            # Batches hold a single case, so only the first item is kept.
            prob_np = post_trans(decollate_batch(output_pred)[0]).detach().cpu().numpy()
            logging.info(f"Probmap shape: {prob_np.shape}")
            np.savez(output_directory / f"{img_name}.npz", probabilities=prob_np)

            if args.pred_label:
                # Mutually exclusive classes: most probable class per voxel.
                seg_out = np.argmax(prob_np, axis=0)
            else:
                if prob_np.shape[0] < 4:
                    raise ValueError(
                        f"Region decoding needs 4 output channels, got {prob_np.shape[0]}; "
                        "use --out_channels 4 or --pred_label"
                    )
                seg = (prob_np > 0.5).astype(np.int8)
                # Each inner region may only refine voxels already inside its
                # parent region, which keeps the label map consistently nested.
                seg_out = np.where(seg[0] == 1, 4, 0)
                for channel, (parent, label) in zip(seg[1:4], nested_labels):
                    seg_out = np.where((channel == 1) & (seg_out == parent), label, seg_out)

            nib.save(
                nib.Nifti1Image(seg_out.astype(np.int8), affine),
                output_directory / f"{img_name}.nii.gz"
            )

            logging.info(f"Seg shape: {seg_out.shape}")

    logging.info(f"Finished inference! {int(time.time() - time0)} s")
300
301
if __name__ == '__main__':
    # Entry point: importing the module only defines its functions; running
    # it as a script starts inference.
    main()
parser = ArgumentParser(prog='pdoc', usage=None, description='Swin UNETR segmentation inference', formatter_class=<class 'argparse.HelpFormatter'>, conflict_handler='error', add_help=True)
def load_json(path: pathlib.Path) -> Any:
 89def load_json(path: Path) -> Any:
 90    """
 91    Loads a JSON file from the specified path.
 92
 93    Args:
 94        path (Path): A Path representing the file path.
 95
 96    Returns:
 97        Any: The data loaded from the JSON file.
 98    """
 99    with open(path, 'r') as f:
100        return json.load(f)

Loads a JSON file from the specified path.

Args: path (Path): A Path representing the file path.

Returns: Any: The data loaded from the JSON file.

def save_json(path: pathlib.Path, data: Any) -> None:
def save_json(path: Union[str, Path], data: Any) -> None:
    """
    Save data as a JSON file (4-space indentation) at the specified path.

    The whole document is serialized before the file is opened. A value that
    cannot be encoded therefore raises without truncating or partially
    overwriting an existing file.

    Args:
        path (Union[str, Path]): Destination file path.
        data (Any): JSON-serializable data to write.

    Raises:
        TypeError: If ``data`` contains an object ``json`` cannot serialize.
    """
    text = json.dumps(data, indent=4)
    # Explicit encoding keeps the output independent of the platform locale.
    with open(path, 'w', encoding='utf-8') as f:
        f.write(text)

Saves data to a JSON file at the specified path.

Args: path (Path): A Path representing the file path. data (Any): The data to be serialized and saved.

class ConvertToMultiChannelBasedOnBratsPEDClassesd(monai.transforms.transform.MapTransform):
class ConvertToMultiChannelBasedOnBratsPEDClassesd(MapTransform):
    """
    Convert labels to four nested binary region channels (BraTS2023 PED).

    For every key in ``self.keys``, the integer label map is replaced by a
    float tensor of shape ``(4, *label.shape)``. Its channels are:

    * channel 0 - WT (whole tumor): labels 1, 2, 3 and 4
    * channel 1 - TC (tumor core): labels 1, 2 and 3
    * channel 2 - NET: labels 1 and 2
    * channel 3 - ET (enhancing tumor): label 1

    Each channel is a subset of the previous one. Keys not listed in
    ``self.keys`` are passed through unchanged.

    NOTE(review): the channel names follow this code's mapping (label 1 alone
    is ET). Confirm them against the dataset's label definition.
    """

    # Raw label values merged into each output channel, in channel order.
    _REGION_LABELS = (
        (1, 2, 3, 4),  # WT
        (1, 2, 3),     # TC
        (1, 2),        # NET
        (1,),          # ET
    )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Return a shallow copy of ``data`` with each selected label map
        converted to stacked region channels.

        The label values are compared and combined with ``torch.logical_or``,
        so each selected entry must be a torch tensor.
        """
        d = dict(data)
        for key in self.keys:
            label = d[key]
            channels = []
            for values in self._REGION_LABELS:
                mask = label == values[0]
                for value in values[1:]:
                    mask = torch.logical_or(mask, label == value)
                channels.append(mask)
            d[key] = torch.stack(channels, dim=0).float()
        return d

Convert labels to four nested binary region channels based on the BraTS2023 (PED) classes: channel 0 is WT (whole tumor, labels 1–4), channel 1 is TC (tumor core, labels 1–3), channel 2 is NET (labels 1–2) and channel 3 is ET (enhancing tumor, label 1).

Each channel is contained in the previous one; the channels are stacked along a new leading dimension and returned as a float tensor.

def get_loader(args: argparse.Namespace) -> monai.data.dataloader.DataLoader:
def get_loader(args: argparse.Namespace) -> DataLoader:
    """
    Build a DataLoader over the single case stored in ``args.datadir``.

    The case directory must contain four modality files named after the
    directory itself. For ``.../CASE`` these are ``CASE-t1n.nii.gz``,
    ``CASE-t1c.nii.gz``, ``CASE-t2w.nii.gz`` and ``CASE-t2f.nii.gz``. They are
    stacked, in that order, as the input channels of one sample. No JSON data
    list is read.

    Args:
        args (argparse.Namespace): Parsed command-line arguments. Reads
            ``datadir``, ``cacherate`` and ``workers``.

    Returns:
        DataLoader: Loader yielding the case as a batch of size 1.

    Raises:
        FileNotFoundError: If any of the expected modality files is missing.
    """
    data_root = Path(args.datadir)
    # Modality suffixes, in the order they become input channels.
    # NOTE(review): presumably this must match the channel order used during
    # training; confirm against the training data list.
    channel_order = ['-t1n.nii.gz', '-t1c.nii.gz', '-t2w.nii.gz', '-t2f.nii.gz']
    image_paths = [data_root / f"{data_root.name}{suffix}" for suffix in channel_order]

    # Fail early with a readable message rather than deep inside image loading.
    missing = [str(p) for p in image_paths if not p.is_file()]
    if missing:
        raise FileNotFoundError(f"Missing input image(s): {', '.join(missing)}")

    # One-item data list. 'label' is an empty placeholder that none of the
    # transforms below reads.
    val_data = [{'image': [str(p) for p in image_paths], 'label': ""}]

    val_transform = Compose(
        [
            # image_only=False keeps the loader's metadata alongside the image.
            LoadImaged(keys=["image"], image_only=False),
            EnsureChannelFirstd(keys=["image"]),
            # Normalize each channel separately, using its non-zero voxels.
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            ToTensord(keys=["image"]),
        ]
    )
    val_ds = CacheDataset(
        data=val_data,
        transform=val_transform,
        cache_rate=args.cacherate,
        num_workers=args.workers
    )
    # Batch size is fixed at 1: the consumer indexes element [0] of each batch.
    val_loader = DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=False
    )

    return val_loader

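Build a DataLoader for the single case in `args.datadir`: the four modality files `<case>-t1n.nii.gz`, `<case>-t1c.nii.gz`, `<case>-t2w.nii.gz` and `<case>-t2f.nii.gz` (where `<case>` is the directory name) are stacked, in that order, as the input channels. No JSON data list is read.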

Args: args (argparse.Namespace): Parsed command-line arguments.

Returns: DataLoader: DataLoader for the validation dataset.

def main() -> None:
def main() -> None:
    """
    Run SwinUNETR segmentation inference on the case given by ``--datadir``.

    Steps:
        1. Parse the command-line arguments and create ``--exp_path``.
           The log is written to ``<exp_path>/infer.log``.
        2. Build the data loader (``get_loader``) and a ``SwinUNETR`` model.
           Load its weights from the ``'model'`` entry of
           ``<pretrained_dir>/<pretrained_model_name>``.
        3. For each batch, run sliding-window inference and apply the
           activation: softmax with ``--pred_label``, sigmoid otherwise.
           Save the probabilities to ``<name>.npz`` and an integer label map
           to ``<name>.nii.gz`` in the output directory.

    Raises:
        ValueError: In region mode (no ``--pred_label``), if the model outputs
            fewer than the 4 region channels the label decoding needs.
    """
    time0 = time.time()
    args = parser.parse_args()
    output_directory = Path(args.exp_path)
    output_directory.mkdir(parents=True, exist_ok=True)

    # Configure logging
    logging.basicConfig(
        filename=output_directory / 'infer.log',
        filemode='w',
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    val_loader = get_loader(args)
    pretrained_pth = Path(args.pretrained_dir) / args.pretrained_model_name
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Inference-only model: all dropout rates are zero.
    model = SwinUNETR(
        img_size=(args.roi_x, args.roi_y, args.roi_z),
        in_channels=args.in_channels,
        out_channels=args.out_channels,
        feature_size=args.feature_size,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        dropout_path_rate=0.0,
        use_checkpoint=args.use_checkpoint
    )

    # NOTE: torch.load unpickles the checkpoint, which can execute arbitrary
    # code. Only load checkpoints from a trusted source.
    model_dict = torch.load(pretrained_pth, map_location=device)['model']
    model.load_state_dict(model_dict)
    model.eval()
    model.to(device)

    model_inferer_test = partial(
        sliding_window_inference,
        roi_size=[args.roi_x, args.roi_y, args.roi_z],
        sw_batch_size=1,
        predictor=model,
        overlap=args.infer_overlap,
    )

    # Softmax for mutually exclusive labels, sigmoid for overlapping regions.
    post_trans = Activations(sigmoid=not args.pred_label, softmax=args.pred_label)

    # Region decoding: thresholded channel k (k = 1..3) relabels voxels that
    # currently hold `parent` to `label`. Channel 0 sets label 4 directly.
    # Assumes channels are ordered WT, TC, NET, ET (see the label converter).
    nested_labels = ((4, 3), (3, 2), (2, 1))

    with torch.no_grad():
        for batch in val_loader:
            # Move to the selected device rather than calling .cuda(), so
            # the CPU fallback chosen above also works.
            image = batch["image"].to(device)
            affine = batch['image_meta_dict']['original_affine'][0].numpy()
            filepath = Path(batch['image_meta_dict']['filename_or_obj'][0])
            img_name = filepath.name.split('.nii.gz')[0]

            output_pred = model_inferer_test(image)
            logging.info(f"Inference on case {img_name}")
            logging.info(f"Label-wise: {args.pred_label}")
            logging.info(f"Image shape: {image.shape}")
            logging.info(f"Prediction shape: {output_pred.shape}")

            # Batches hold a single case, so only the first item is kept.
            prob_np = post_trans(decollate_batch(output_pred)[0]).detach().cpu().numpy()
            logging.info(f"Probmap shape: {prob_np.shape}")
            np.savez(output_directory / f"{img_name}.npz", probabilities=prob_np)

            if args.pred_label:
                # Mutually exclusive classes: most probable class per voxel.
                seg_out = np.argmax(prob_np, axis=0)
            else:
                if prob_np.shape[0] < 4:
                    raise ValueError(
                        f"Region decoding needs 4 output channels, got {prob_np.shape[0]}; "
                        "use --out_channels 4 or --pred_label"
                    )
                seg = (prob_np > 0.5).astype(np.int8)
                # Each inner region may only refine voxels already inside its
                # parent region, which keeps the label map consistently nested.
                seg_out = np.where(seg[0] == 1, 4, 0)
                for channel, (parent, label) in zip(seg[1:4], nested_labels):
                    seg_out = np.where((channel == 1) & (seg_out == parent), label, seg_out)

            nib.save(
                nib.Nifti1Image(seg_out.astype(np.int8), affine),
                output_directory / f"{img_name}.nii.gz"
            )

            logging.info(f"Seg shape: {seg_out.shape}")

    logging.info(f"Finished inference! {int(time.time() - time0)} s")

Main function to perform Swin UNETR segmentation inference.

This function parses command-line arguments, sets up logging, loads the validation data, initializes the model, performs inference on the validation dataset, and saves the segmentation results.