The FIRST Plug-and-Play Method for Deformable Image Registration
1 Department of Electrical & Systems Engineering, Washington University in St. Louis, St. Louis, MO, USA
2 Department of Computer Science & Engineering, Washington University in St. Louis, St. Louis, MO, USA
3 Program of Imaging Science, Washington University in St. Louis, St. Louis, MO, USA
4 Mallinckrodt Institute of Radiology, Washington University in St. Louis, St. Louis, MO, USA
5 Department of Biomedical Engineering, Washington University in St. Louis, St. Louis, MO, USA
6 Department of Neurology, Washington University in St. Louis, St. Louis, MO, USA
Junhao Hu and Weijie Gan contributed equally to this project.
Deformable image registration (DIR) is an active research topic in biomedical imaging, and there is growing interest in developing DIR methods based on deep learning (DL). A traditional DL approach to DIR trains a convolutional neural network (CNN) to estimate the registration field between two input images. While conceptually simple, this approach is limited in that it relies exclusively on a pre-trained CNN without explicitly enforcing fidelity between the registered image and the reference. We present the plug-and-play image registration network (PIRATE), a new DIR method that addresses this issue by integrating an explicit data-fidelity penalty with a CNN prior. PIRATE pre-trains a CNN denoiser on the registration field and "plugs" it into an iterative method as a regularizer. We additionally present PIRATE+, which fine-tunes the CNN prior in PIRATE using deep equilibrium models (DEQ). PIRATE+ interprets the fixed-point iteration of PIRATE as a network with effectively infinite layers and trains the resulting network end-to-end, enabling it to learn more task-specific information and boosting its performance. Our numerical results on the OASIS and CANDI datasets show that our methods achieve state-of-the-art performance on DIR.
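As a rough illustration of the plug-and-play idea described above, the following Python sketch alternates a gradient step on a data-fidelity penalty with a denoising step on a 1D displacement field. The toy signals, the function names (`warp_1d`, `data_fidelity_grad`, `smooth_denoiser`, `pnp_register`), the step size, and the moving-average filter standing in for the pre-trained CNN denoiser are all assumptions made for illustration; this is not the PIRATE implementation.

```python
import numpy as np

def warp_1d(moving, field):
    """Resample `moving` at positions x + field(x) using linear interpolation."""
    x = np.arange(moving.size, dtype=float)
    return np.interp(x + field, x, moving)

def data_fidelity_grad(moving, fixed, field):
    """Gradient of 0.5 * ||warp(moving, field) - fixed||^2 with respect to the field."""
    x = np.arange(moving.size, dtype=float)
    # Spatial derivative of the moving signal, evaluated at the warped positions.
    slope = np.interp(x + field, x, np.gradient(moving))
    return (warp_1d(moving, field) - fixed) * slope

def smooth_denoiser(field, width=5):
    """Hypothetical stand-in for the CNN denoiser: a moving-average filter."""
    kernel = np.ones(width) / width
    padded = np.pad(field, width // 2, mode="edge")
    return np.convolve(padded, kernel, mode="valid")

def pnp_register(moving, fixed, step=50.0, iters=200):
    """Plug-and-play style iteration: data-fidelity gradient step, then denoising."""
    field = np.zeros_like(moving)
    for _ in range(iters):
        field = smooth_denoiser(field - step * data_fidelity_grad(moving, fixed, field))
    return field

# Toy data (assumed): the fixed signal is the moving signal shifted by 4 samples.
x = np.arange(128, dtype=float)
moving = np.exp(-0.5 * ((x - 60.0) / 10.0) ** 2)
fixed = np.exp(-0.5 * ((x - 64.0) / 10.0) ** 2)

field = pnp_register(moving, fixed)
err_before = np.mean((moving - fixed) ** 2)
err_after = np.mean((warp_1d(moving, field) - fixed) ** 2)
print("MSE before:", err_before, "MSE after:", err_after)
```

The data term alone only updates the field where the signal has a nonzero slope; in this toy setting the denoising step is what spreads those updates into a smooth field over the whole domain.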
Video 1: Evolution of the registration field across PIRATE+ iterations, with the corresponding warped images.
Figure 1: Illustration of the PIRATE and PIRATE+ pipelines. PIRATE updates the registration field using a penalty function that measures the similarity between the warped image and the fixed image, together with a pre-trained CNN denoiser used as a regularizer. The DEQ update in PIRATE+ fine-tunes the CNN by computing gradients with implicit differentiation through the fixed point of the forward iteration. As described in the paper, the DEQ update of PIRATE+ is computed using a weighted loss consisting of a similarity loss, a smoothness loss, and a Jacobian loss.
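To make "implicit differentiation through the fixed point" more concrete, here is a minimal Python sketch on a toy problem. It assumes a linear contraction f(x, theta) = A x + theta in place of the PIRATE+ iteration and a squared-error loss in place of the weighted loss; the names `forward_fixed_point` and `implicit_gradient` are hypothetical. The gradient is obtained by a second fixed-point iteration on the transposed Jacobian rather than by backpropagating through every forward step.

```python
import numpy as np

def forward_fixed_point(A, theta, iters=200):
    """Iterate x <- A x + theta until it (numerically) reaches the fixed point x*."""
    x = np.zeros_like(theta)
    for _ in range(iters):
        x = A @ x + theta
    return x

def implicit_gradient(A, x_star, target, iters=200):
    """Gradient of 0.5 * ||x* - target||^2 with respect to theta.

    Solves g = A^T g + dL/dx by fixed-point iteration, which equals
    (I - A)^{-T} dL/dx; here df/dtheta is the identity, so g is dL/dtheta.
    """
    dl_dx = x_star - target
    g = np.zeros_like(dl_dx)
    for _ in range(iters):
        g = A.T @ g + dl_dx
    return g

# Toy setup (assumed): a random matrix rescaled to spectral norm 0.5,
# so that the forward and backward iterations both converge.
rng = np.random.default_rng(0)
Q = rng.standard_normal((4, 4))
A = 0.5 * Q / np.linalg.norm(Q, 2)
theta = rng.standard_normal(4)
target = rng.standard_normal(4)

x_star = forward_fixed_point(A, theta)
grad_implicit = implicit_gradient(A, x_star, target)

# Closed-form reference, available only because the toy map is linear.
grad_exact = np.linalg.solve((np.eye(4) - A).T, x_star - target)
print(np.max(np.abs(grad_implicit - grad_exact)))
```

The memory cost of the backward pass in this sketch does not grow with the number of forward iterations, which is the usual motivation for DEQ-style training.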
Figure 2: Visual results of warped images (top) and the corresponding warped segmentation masks (bottom) from PIRATE+ and selected benchmarks on the OASIS dataset. The regions of interest are highlighted with red boxes. The Dice similarity coefficient (DSC) for each image is displayed in the bottom row. Note that the result of PIRATE+ is more consistent with the fixed image and has fewer artifacts than the other baselines.
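For reference, the DSC between two binary masks can be computed as in the sketch below. The function name `dice_coefficient`, the example masks, and the convention of returning 1.0 when both masks are empty are assumptions for illustration; how the score is aggregated over anatomical labels is not specified here.

```python
import numpy as np

def dice_coefficient(mask_a, mask_b):
    """Dice similarity coefficient 2|A ∩ B| / (|A| + |B|) for two binary masks."""
    a = np.asarray(mask_a, dtype=bool)
    b = np.asarray(mask_b, dtype=bool)
    total = a.sum() + b.sum()
    if total == 0:
        return 1.0  # assumed convention: two empty masks agree perfectly
    return 2.0 * np.logical_and(a, b).sum() / total

# Toy masks (assumed): 3 foreground pixels each, 2 of them shared.
warped_mask = np.array([[1, 1, 0], [0, 1, 0]])
fixed_mask = np.array([[1, 0, 0], [0, 1, 1]])
dsc = dice_coefficient(warped_mask, fixed_mask)
print(dsc)  # 2 * 2 / (3 + 3)
```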
Figure 3: Visual results for the warped grid (top) and the locations of negative Jacobian determinant (JD; yellow points, bottom) of PIRATE+ and its five ablated variants on the CANDI dataset. We denote P as the penalty function, R as the denoiser regularizer, S as the smoothness regularizer, and D as DEQ. PIRATE is formulated as P+R+S, and PIRATE+ is formulated as P+D+S. The gradient loss is shown in the top row, and the ratio of negative JD is shown in the bottom row. Note that PIRATE+ performs best among the compared variants: it significantly reduces the ratio of negative JD and provides a smoother registration field.
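The ratio of negative JD can be estimated from a displacement field as in the following sketch, which assumes a 2D field stored as an array of shape (2, H, W) and uses finite differences for the spatial derivatives. The function name `negative_jacobian_ratio` and the array layout are hypothetical choices for illustration. A negative determinant of the Jacobian of the mapping x -> x + u(x) indicates local folding of the grid.

```python
import numpy as np

def negative_jacobian_ratio(disp):
    """Fraction of pixels where det(I + grad u) < 0.

    disp[0] is the displacement along rows (y), disp[1] along columns (x).
    """
    duy_dy, duy_dx = np.gradient(disp[0])  # derivatives along axis 0, then axis 1
    dux_dy, dux_dx = np.gradient(disp[1])
    det = (1.0 + duy_dy) * (1.0 + dux_dx) - duy_dx * dux_dy
    return float(np.mean(det < 0))

# Identity mapping (zero displacement): no folding anywhere.
identity = np.zeros((2, 8, 8))
ratio_identity = negative_jacobian_ratio(identity)

# Toy folding field (assumed): x -> x - 2x = -x flips the column axis everywhere.
flip = np.zeros((2, 8, 8))
flip[1] = -2.0 * np.arange(8, dtype=float)[None, :]
ratio_flip = negative_jacobian_ratio(flip)
print(ratio_identity, ratio_flip)
```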