inference.swinunetr.inference
SwinUNETR Inference
Provides functionalities related to SwinUNETR segmentation inference.
"""
### SwinUNETR Inference

Provides functionalities related to SwinUNETR segmentation inference.
"""

import logging
from pathlib import Path
import time
import torch
import numpy as np
from typing import Any, List, Tuple, Union, Optional, Dict
from monai.inferers import sliding_window_inference
from monai.networks.nets import SwinUNETR
from monai.transforms import Activations
from monai.data import decollate_batch
import nibabel as nib
from functools import partial
import argparse
import json

from monai.transforms import (
    LoadImaged,
    Compose,
    CropForegroundd,
    CopyItemsd,
    SpatialPadd,
    Spacingd,
    EnsureChannelFirstd,
    EnsureTyped,
    MapTransform,
    OneOf,
    NormalizeIntensityd,
    RandSpatialCropd,
    RandCropByPosNegLabeld,
    RandSpatialCropSamplesd,
    RandCoarseDropoutd,
    RandCoarseShuffled,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandFlipd,
    RandAdjustContrastd,
    ToTensord,
)

from monai.data import DataLoader, CacheDataset, decollate_batch
# from datasets_utils import get_loader

parser = argparse.ArgumentParser(description='Swin UNETR segmentation inference')
parser.add_argument('--datadir', default='/dataset/dataset0/', type=str, help='dataset directory')
parser.add_argument('--exp_path', default='test1', type=str, help='experiment output path')
parser.add_argument('--jsonlist', default='dataset_0.json', type=str, help='dataset json file')
parser.add_argument('--fold', default=1, type=int, help='data fold')
parser.add_argument('--pretrained_model_name', default='model.pt', type=str, help='pretrained model name')
parser.add_argument('--feature_size', default=48, type=int, help='feature size')
parser.add_argument('--infer_overlap', default=0.6, type=float, help='sliding window inference overlap')
parser.add_argument('--in_channels', default=4, type=int, help='number of input channels')
parser.add_argument('--out_channels', default=3, type=int, help='number of output channels')
parser.add_argument('--a_min', default=-175.0, type=float, help='a_min in ScaleIntensityRanged')
parser.add_argument('--a_max', default=250.0, type=float, help='a_max in ScaleIntensityRanged')
parser.add_argument('--b_min', default=0.0, type=float, help='b_min in ScaleIntensityRanged')
parser.add_argument('--b_max', default=1.0, type=float, help='b_max in ScaleIntensityRanged')
parser.add_argument('--space_x', default=1.5, type=float, help='spacing in x direction')
parser.add_argument('--space_y', default=1.5, type=float, help='spacing in y direction')
parser.add_argument('--space_z', default=2.0, type=float, help='spacing in z direction')
parser.add_argument('--roi_x', default=128, type=int, help='roi size in x direction')
parser.add_argument('--roi_y', default=128, type=int, help='roi size in y direction')
parser.add_argument('--roi_z', default=128, type=int, help='roi size in z direction')
parser.add_argument('--posrate', default=1.0, type=float, help='positive label rate')
parser.add_argument('--negrate', default=1.0, type=float, help='negative label rate')
parser.add_argument('--nsamples', default=1, type=int, help='number of cropped samples')
parser.add_argument('--dropout_rate', default=0.0, type=float, help='dropout rate')
parser.add_argument('--distributed', action='store_true', help='start distributed training')
parser.add_argument('--cacherate', default=1.0, type=float, help='cache data rate')
parser.add_argument('--workers', default=8, type=int, help='number of workers')
parser.add_argument('--batch_size', default=1, type=int, help='batch size')
parser.add_argument('--RandFlipd_prob', default=0.2, type=float, help='RandFlipd augmentation probability')
parser.add_argument('--RandRotate90d_prob', default=0.2, type=float, help='RandRotate90d augmentation probability')
parser.add_argument('--RandScaleIntensityd_prob', default=0.1, type=float, help='RandScaleIntensityd augmentation probability')
parser.add_argument('--RandShiftIntensityd_prob', default=0.1, type=float, help='RandShiftIntensityd augmentation probability')
parser.add_argument('--spatial_dims', default=3, type=int, help='spatial dimension of input data')
parser.add_argument('--use_checkpoint', action='store_true', help='use gradient checkpointing to save memory')
parser.add_argument('--pretrained_dir', default='./pretrained_models/fold1_f48_ep300_4gpu_dice0_9059/', type=str,
                    help='pretrained checkpoint directory')
parser.add_argument('--pred_label', action='store_true', help='predict labels or regions')


def load_json(path: Path) -> Any:
    """
    Loads a JSON file from the specified path.

    Args:
        path (Path): A Path representing the file path.

    Returns:
        Any: The data loaded from the JSON file.
    """
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def save_json(path: Path, data: Any) -> None:
    """
    Saves data to a JSON file at the specified path.

    Args:
        path (Path): A Path representing the file path.
        data (Any): The data to be serialized and saved.
    """
    # Explicit encoding keeps the writer symmetric with load_json.
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4)


class ConvertToMultiChannelBasedOnBratsPEDClassesd(MapTransform):
    """
    Turn an integer label map into a stack of four binary region channels.

    For every configured key the label tensor is replaced by a float tensor
    whose new leading axis holds, in this order:

        0. WT  - voxels labelled 1, 2, 3 or 4
        1. TC  - voxels labelled 1, 2 or 3
        2. NET - voxels labelled 1 or 2
        3. ET  - voxels labelled 1

    NOTE(review): the region names come from the inline comments of the
    original implementation; confirm them against the dataset's label
    definition.
    """

    # Label values merged into each output channel, in channel order.
    _REGION_LABELS = (
        (1, 2, 3, 4),  # WT
        (2, 3, 1),     # TC
        (1, 2),        # NET
        (1,),          # ET
    )

    @staticmethod
    def _union(label, values):
        """Return the element-wise OR of ``label == v`` for every v in values."""
        mask = label == values[0]
        for value in values[1:]:
            mask = torch.logical_or(mask, label == value)
        return mask

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Apply the conversion to every key in ``self.keys``.

        Args:
            data (Dict[str, Any]): Sample dictionary; each configured key
                must hold a tensor of integer labels.

        Returns:
            Dict[str, Any]: Shallow copy of ``data`` with the configured keys
            replaced by their stacked float region masks.
        """
        out = dict(data)
        for key in self.keys:
            label = out[key]
            channels = [self._union(label, group) for group in self._REGION_LABELS]
            out[key] = torch.stack(channels, axis=0).float()
        return out


def get_loader(args: argparse.Namespace) -> DataLoader:
    """
    Build the validation DataLoader for the single case in ``args.datadir``.

    The four modality paths are derived from the directory name as
    ``<datadir>/<dirname>-t1n.nii.gz``, ``-t1c``, ``-t2w`` and ``-t2f``
    (in that channel order). Images are loaded with metadata, made
    channel-first, z-score normalised per channel over non-zero voxels and
    converted to tensors.

    Args:
        args (argparse.Namespace): Parsed command-line arguments. Reads
            ``datadir``, ``cacherate`` and ``workers``.

    Returns:
        DataLoader: DataLoader for the validation dataset (batch size 1,
        no shuffling).
    """
    data_root = Path(args.datadir)
    channel_order = ['-t1n.nii.gz', '-t1c.nii.gz', '-t2w.nii.gz', '-t2f.nii.gz']

    # File names are prefixed with the case directory's own name.
    image_files = [str(data_root / f"{data_root.name}{suffix}") for suffix in channel_order]
    # 'label' is an empty placeholder; none of the transforms below read it.
    val_data = [{'image': image_files, 'label': ""}]

    val_transform = Compose(
        [
            LoadImaged(keys=["image"], image_only=False),
            EnsureChannelFirstd(keys=["image"]),
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            ToTensord(keys=["image"]),
        ]
    )
    val_ds = CacheDataset(
        data=val_data,
        transform=val_transform,
        cache_rate=args.cacherate,
        num_workers=args.workers
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=False
    )

    return val_loader


def main() -> None:
    """
    Main function to perform Swin UNETR segmentation inference.

    Parses the command-line arguments, writes a log to
    ``<exp_path>/infer.log``, loads the validation case, restores the
    SwinUNETR checkpoint and runs sliding-window inference. For every case
    it saves the activated probability maps as ``<name>.npz`` and an int8
    label map as ``<name>.nii.gz`` in the experiment directory.
    """
    time0 = time.time()
    args = parser.parse_args()
    output_directory = Path(args.exp_path)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    output_directory.mkdir(parents=True, exist_ok=True)

    # Configure logging
    logging.basicConfig(
        filename=output_directory / 'infer.log',
        filemode='w',
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    # Load validation data
    val_loader = get_loader(args)
    pretrained_pth = Path(args.pretrained_dir) / args.pretrained_model_name
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize the SwinUNETR model
    model = SwinUNETR(
        img_size=(args.roi_x, args.roi_y, args.roi_z),
        in_channels=args.in_channels,
        out_channels=args.out_channels,
        feature_size=args.feature_size,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        dropout_path_rate=0.0,
        use_checkpoint=args.use_checkpoint
    )

    # Load pretrained model weights.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    model_dict = torch.load(pretrained_pth, map_location=device)['model']
    model.load_state_dict(model_dict)
    model.eval()
    model.to(device)

    # Set up the inference function with sliding window
    model_inferer_test = partial(
        sliding_window_inference,
        roi_size=[args.roi_x, args.roi_y, args.roi_z],
        sw_batch_size=1,
        predictor=model,
        overlap=args.infer_overlap,
    )

    # Sigmoid for overlapping regions, softmax for mutually exclusive labels
    post_trans = Activations(sigmoid=not args.pred_label, softmax=args.pred_label)

    with torch.no_grad():
        for batch in val_loader:
            # Use the selected device so that CPU-only hosts work as well.
            image = batch["image"].to(device)
            affine = batch['image_meta_dict']['original_affine'][0].numpy()
            filepath = Path(batch['image_meta_dict']['filename_or_obj'][0])
            img_name = filepath.name.split('.nii.gz')[0]

            # Perform inference
            output_pred = model_inferer_test(image)
            logging.info(f"Inference on case {img_name}")
            logging.info(f"Label-wise: {args.pred_label}")
            logging.info(f"Image shape: {image.shape}")
            logging.info(f"Prediction shape: {output_pred.shape}")

            # Apply activation and convert to NumPy
            prob = [post_trans(pred) for pred in decollate_batch(output_pred)]
            prob_np = prob[0].detach().cpu().numpy()
            logging.info(f"Probmap shape: {prob_np.shape}")
            np.savez(output_directory / f"{img_name}.npz", probabilities=prob_np)

            # Save integer masks based on prediction
            if args.pred_label:
                seg_out = np.argmax(prob_np, axis=0)
            else:
                # NOTE(review): channels 0-3 are read below, so the model must
                # emit at least four channels; the --out_channels default is 3.
                seg = (prob_np > 0.5).astype(np.int8)
                # Each channel only refines voxels kept by the previous one:
                # channel 0 -> 4, then channel 1 -> 3, channel 2 -> 2,
                # channel 3 -> 1.
                seg_out = np.where(seg[0] == 1, 4, 0)
                seg_out = np.where((seg[1] == 1) & (seg_out == 4), 3, seg_out)
                seg_out = np.where((seg[2] == 1) & (seg_out == 3), 2, seg_out)
                seg_out = np.where((seg[3] == 1) & (seg_out == 2), 1, seg_out)

            # Save the segmentation result as a NIfTI file
            nib.save(
                nib.Nifti1Image(seg_out.astype(np.int8), affine),
                output_directory / f"{img_name}.nii.gz"
            )

            logging.info(f"Seg shape: {seg_out.shape}")

    logging.info(f"Finished inference! {int(time.time() - time0)} s")


if __name__ == '__main__':
    main()
def load_json(path: Path) -> Any:
    """
    Loads a JSON file from the specified path.

    The file is always decoded as UTF-8 instead of the platform's default
    encoding, so non-ASCII content loads identically on every system.

    Args:
        path (Path): A Path representing the file path.

    Returns:
        Any: The data loaded from the JSON file.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)
Loads a JSON file from the specified path.
Args: path (Path): A Path representing the file path.
Returns: Any: The data loaded from the JSON file.
def save_json(path: Path, data: Any) -> None:
    """
    Saves data to a JSON file at the specified path.

    The file is written with 4-space indentation and an explicit UTF-8
    encoding, so the result does not depend on the platform's default.

    Args:
        path (Path): A Path representing the file path.
        data (Any): The data to be serialized and saved.
    """
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4)
Saves data to a JSON file at the specified path.
Args: path (Path): A Path representing the file path. data (Any): The data to be serialized and saved.
class ConvertToMultiChannelBasedOnBratsPEDClassesd(MapTransform):
    """
    Turn an integer label map into a stack of four binary region channels.

    For every configured key the label tensor is replaced by a float tensor
    whose new leading axis holds, in this order:

        0. WT  - voxels labelled 1, 2, 3 or 4
        1. TC  - voxels labelled 1, 2 or 3
        2. NET - voxels labelled 1 or 2
        3. ET  - voxels labelled 1

    NOTE(review): the region names come from the inline comments of the
    original implementation; confirm them against the dataset's label
    definition.
    """

    # Label values merged into each output channel, in channel order.
    _REGION_LABELS = (
        (1, 2, 3, 4),  # WT
        (2, 3, 1),     # TC
        (1, 2),        # NET
        (1,),          # ET
    )

    @staticmethod
    def _union(label, values):
        """Return the element-wise OR of ``label == v`` for every v in values."""
        mask = label == values[0]
        for value in values[1:]:
            mask = torch.logical_or(mask, label == value)
        return mask

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Apply the conversion to every key in ``self.keys``.

        Args:
            data (Dict[str, Any]): Sample dictionary; each configured key
                must hold a tensor of integer labels.

        Returns:
            Dict[str, Any]: Shallow copy of ``data`` with the configured keys
            replaced by their stacked float region masks.
        """
        out = dict(data)
        for key in self.keys:
            label = out[key]
            channels = [self._union(label, group) for group in self._REGION_LABELS]
            out[key] = torch.stack(channels, axis=0).float()
        return out
Convert an integer label map into four stacked binary channels, as implemented in the code: labels 1, 2, 3 and 4 are merged into WT, labels 1, 2 and 3 into TC, labels 1 and 2 into NET, and label 1 alone forms ET.
The output channels are, in order, WT (Whole tumor), TC (Tumor core), NET, and ET (Enhancing tumor).
def get_loader(args: argparse.Namespace) -> DataLoader:
    """
    Build the validation DataLoader for the single case in ``args.datadir``.

    The four modality paths are derived from the directory name as
    ``<datadir>/<dirname>-t1n.nii.gz``, ``-t1c``, ``-t2w`` and ``-t2f``
    (in that channel order). Images are loaded with metadata, made
    channel-first, z-score normalised per channel over non-zero voxels and
    converted to tensors.

    Args:
        args (argparse.Namespace): Parsed command-line arguments. Reads
            ``datadir``, ``cacherate`` and ``workers``.

    Returns:
        DataLoader: DataLoader for the validation dataset (batch size 1,
        no shuffling).
    """
    data_root = Path(args.datadir)
    channel_order = ['-t1n.nii.gz', '-t1c.nii.gz', '-t2w.nii.gz', '-t2f.nii.gz']

    # File names are prefixed with the case directory's own name.
    image_files = [str(data_root / f"{data_root.name}{suffix}") for suffix in channel_order]
    # 'label' is an empty placeholder; none of the transforms below read it.
    val_data = [{'image': image_files, 'label': ""}]

    val_transform = Compose(
        [
            LoadImaged(keys=["image"], image_only=False),
            EnsureChannelFirstd(keys=["image"]),
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            ToTensord(keys=["image"]),
        ]
    )
    val_ds = CacheDataset(
        data=val_data,
        transform=val_transform,
        cache_rate=args.cacherate,
        num_workers=args.workers
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=False
    )

    return val_loader
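Build the validation DataLoader for the single case stored in `args.datadir`; the four modality file paths (t1n, t1c, t2w, t2f) are derived from the directory name rather than read from a JSON file.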
Args: args (argparse.Namespace): Parsed command-line arguments.
Returns: DataLoader: DataLoader for the validation dataset.
def main() -> None:
    """
    Main function to perform Swin UNETR segmentation inference.

    Parses the command-line arguments, writes a log to
    ``<exp_path>/infer.log``, loads the validation case, restores the
    SwinUNETR checkpoint and runs sliding-window inference. For every case
    it saves the activated probability maps as ``<name>.npz`` and an int8
    label map as ``<name>.nii.gz`` in the experiment directory.
    """
    time0 = time.time()
    args = parser.parse_args()
    output_directory = Path(args.exp_path)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    output_directory.mkdir(parents=True, exist_ok=True)

    # Configure logging
    logging.basicConfig(
        filename=output_directory / 'infer.log',
        filemode='w',
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    # Load validation data
    val_loader = get_loader(args)
    pretrained_pth = Path(args.pretrained_dir) / args.pretrained_model_name
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize the SwinUNETR model
    model = SwinUNETR(
        img_size=(args.roi_x, args.roi_y, args.roi_z),
        in_channels=args.in_channels,
        out_channels=args.out_channels,
        feature_size=args.feature_size,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        dropout_path_rate=0.0,
        use_checkpoint=args.use_checkpoint
    )

    # Load pretrained model weights.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    model_dict = torch.load(pretrained_pth, map_location=device)['model']
    model.load_state_dict(model_dict)
    model.eval()
    model.to(device)

    # Set up the inference function with sliding window
    model_inferer_test = partial(
        sliding_window_inference,
        roi_size=[args.roi_x, args.roi_y, args.roi_z],
        sw_batch_size=1,
        predictor=model,
        overlap=args.infer_overlap,
    )

    # Sigmoid for overlapping regions, softmax for mutually exclusive labels
    post_trans = Activations(sigmoid=not args.pred_label, softmax=args.pred_label)

    with torch.no_grad():
        for batch in val_loader:
            # Use the selected device so that CPU-only hosts work as well
            # (the original called .cuda() unconditionally).
            image = batch["image"].to(device)
            affine = batch['image_meta_dict']['original_affine'][0].numpy()
            filepath = Path(batch['image_meta_dict']['filename_or_obj'][0])
            img_name = filepath.name.split('.nii.gz')[0]

            # Perform inference
            output_pred = model_inferer_test(image)
            logging.info(f"Inference on case {img_name}")
            logging.info(f"Label-wise: {args.pred_label}")
            logging.info(f"Image shape: {image.shape}")
            logging.info(f"Prediction shape: {output_pred.shape}")

            # Apply activation and convert to NumPy
            prob = [post_trans(pred) for pred in decollate_batch(output_pred)]
            prob_np = prob[0].detach().cpu().numpy()
            logging.info(f"Probmap shape: {prob_np.shape}")
            np.savez(output_directory / f"{img_name}.npz", probabilities=prob_np)

            # Save integer masks based on prediction
            if args.pred_label:
                seg_out = np.argmax(prob_np, axis=0)
            else:
                # NOTE(review): channels 0-3 are read below, so the model must
                # emit at least four channels; the --out_channels default is 3.
                seg = (prob_np > 0.5).astype(np.int8)
                # Each channel only refines voxels kept by the previous one:
                # channel 0 -> 4, then channel 1 -> 3, channel 2 -> 2,
                # channel 3 -> 1.
                seg_out = np.where(seg[0] == 1, 4, 0)
                seg_out = np.where((seg[1] == 1) & (seg_out == 4), 3, seg_out)
                seg_out = np.where((seg[2] == 1) & (seg_out == 3), 2, seg_out)
                seg_out = np.where((seg[3] == 1) & (seg_out == 2), 1, seg_out)

            # Save the segmentation result as a NIfTI file
            nib.save(
                nib.Nifti1Image(seg_out.astype(np.int8), affine),
                output_directory / f"{img_name}.nii.gz"
            )

            logging.info(f"Seg shape: {seg_out.shape}")

    logging.info(f"Finished inference! {int(time.time() - time0)} s")
Main function to perform Swin UNETR segmentation inference.
This function parses command-line arguments, sets up logging, loads the validation data, initializes the model, performs inference on the validation dataset, and saves the segmentation results.