<?xml-model href='http://www.tei-c.org/release/xml/tei/custom/schema/relaxng/tei_all.rng' schematypens='http://relaxng.org/ns/structure/1.0'?><TEI xmlns="http://www.tei-c.org/ns/1.0">
	<teiHeader>
		<fileDesc>
			<titleStmt><title level='a'>Overcoming Distribution Shifts in Plug-and-Play Methods with Test-Time Training</title></titleStmt>
			<publicationStmt>
				<publisher>IEEE</publisher>
				<date>12/10/2023</date>
			</publicationStmt>
			<sourceDesc>
				<bibl> 
					<idno type="par_id">10504926</idno>
					<idno type="doi">10.1109/CAMSAP58249.2023.10403502</idno>
					<title level='j'>IEEE International Workshop on Computational Advances in Multi-Sensor Adaptive Processing</title>
<idno></idno>
<biblScope unit="volume"></biblScope>
<biblScope unit="issue"></biblScope>					

					<author>Edward P. Chandler</author><author>Shirin Shoushtari</author><author>Jiaming Liu</author><author>M. Salman Asif</author><author>Ulugbek S. Kamilov</author>
				</bibl>
			</sourceDesc>
		</fileDesc>
		<profileDesc>
			<abstract><ab><![CDATA[Plug-and-Play Priors (PnP) is a well-known class of methods for solving inverse problems in computational imaging. PnP methods combine physical forward models with learned prior models specified as image denoisers. A common issue with learned models is a drop in performance when there is a distribution shift between the training and testing data. Test-time training (TTT) was recently proposed as a general strategy for improving the performance of learned models when training and testing data come from different distributions. In this paper, we propose PnP-TTT as a new method for overcoming distribution shifts in PnP. PnP-TTT uses deep equilibrium learning (DEQ) for optimizing a self-supervised loss at the fixed points of PnP iterations. PnP-TTT can be directly applied to a single test sample to improve the generalization of PnP. We show through simulations that given a sufficient number of measurements, PnP-TTT enables the use of image priors trained on natural images for image reconstruction in magnetic resonance imaging (MRI).]]></ab></abstract>
		</profileDesc>
	</teiHeader>
	<text><body xmlns="http://www.tei-c.org/ns/1.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns:xlink="http://www.w3.org/1999/xlink">
<div xmlns="http://www.tei-c.org/ns/1.0"><head n="1.">INTRODUCTION</head><p>Many computational imaging problems can be formulated as inverse problems, where the goal is to recover an unknown image from a set of noisy measurements. It is common to solve inverse problems by integrating the measurement model characterizing the response of the imaging instrument with a regularizer infusing prior knowledge on the unknown image. There has been considerable recent interest in using deep learning (DL) for designing data-driven image priors <ref type="bibr">[1,</ref><ref type="bibr">2,</ref><ref type="bibr">3]</ref>. DL methods eliminate the need for explicit prior modeling by learning a mapping from measurements to target images using convolutional neural networks (CNNs).</p><p>Model-based DL (MBDL) is an extension of traditional DL that integrates the image prior defined through a CNN with knowledge of the measurement model. For example, plug-and-play priors (PnP) is a well-known MBDL approach that uses pre-trained image denoisers as priors <ref type="bibr">[4,</ref><ref type="bibr">5,</ref><ref type="bibr">3]</ref>. Other widely used MBDL approaches include deep unfolding (DU) and deep equilibrium (DEQ) learning, both of which rely on the integration of the measurement model during the training of the image prior <ref type="bibr">[6,</ref><ref type="bibr">7,</ref><ref type="bibr">8,</ref><ref type="bibr">9,</ref><ref type="bibr">10]</ref>. While both DU and DEQ interpret iterations of image reconstruction as neural network layers, the memory complexity of DEQ is independent of the number of unfolded iterations.</p><p>Much of the existing research on MBDL has focused on scenarios where the statistical distribution of the training data matches that of the testing data. While this strategy has led to significant theoretical and algorithmic innovations, it does not address the performance gap caused by data distribution shifts. 
For example, image priors trained on one distribution perform poorly in PnP on samples from a different distribution <ref type="bibr">[11]</ref>. Thus, distribution shifts limit the applicability of priors pre-trained on one class of images to another. Domain adaptation refers to a class of DL techniques for improving the performance of a learned model on a target task containing insufficient annotated data by using the knowledge learned by the model from another related task with adequate labeled data <ref type="bibr">[12,</ref><ref type="bibr">13]</ref>. Test-time training (TTT) was recently proposed as a domain adaptation strategy based on self-supervised optimization of the learned model utilizing only test-time measurements <ref type="bibr">[14]</ref>. The TTT strategy was also recently used in the context of imaging inverse problems to address domain shifts in end-to-end image reconstruction with DL for accelerated magnetic resonance imaging (MRI) <ref type="bibr">[15]</ref>.</p><p>In this paper, we investigate TTT in the context of PnP methods. We propose PnP-TTT as a method for overcoming the performance gap in PnP due to data distribution shifts. PnP-TTT uses DEQ to update the weights of the CNN prior in PnP at test time. The DEQ learning in PnP-TTT is used to minimize a self-supervised loss at the fixed points of PnP iterations for one test sample. We also present numerical results showing that DEQ training in PnP-TTT can significantly boost the performance of the shifted priors. We evaluate the proposed method on image reconstruction for compressed sensing MRI (CS-MRI), where we recover MRI images from subsampled Fourier measurements. Our results show that given enough measurements, PnP-TTT can close the gap due to distribution shift between test and training data. 
It is worth mentioning that our method can also be applied to other tasks and different variants of PnP, highlighting its broader applicability for inverse problems in computational imaging.</p></div>
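<div xmlns="http://www.tei-c.org/ns/1.0"><head n="2.">BACKGROUND 2.1. Inverse Problems</head><p>We consider the problem of recovering an image <formula notation="TeX">x \in \mathbb{C}^{n}</formula> from its noisy measurement <formula notation="TeX">y = Ax + e</formula>, where <formula notation="TeX">A \in \mathbb{C}^{m \times n}</formula> is the measurement operator and <formula notation="TeX">e \in \mathbb{C}^{m}</formula> is additive white Gaussian noise (AWGN). We can formulate the problem as a regularized optimization problem</p><p><formula notation="TeX" n="1">x^{\ast} = \arg\min_{x \in \mathbb{C}^{n}} \left\{ g(x) + h(x) \right\},</formula></p><p>where <formula notation="TeX">g</formula> is the data-fidelity term used to ensure the consistency of the solution with the measurement and <formula notation="TeX">h</formula> is the regularization term that infuses prior knowledge. For example, the least-squares loss <formula notation="TeX">g(x) = \frac{1}{2}\|y - Ax\|_{2}^{2}</formula> is a widely used data-fidelity term, and total variation (TV) is commonly used as the regularizer <ref type="bibr">[16]</ref>.</p></div>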
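<div xmlns="http://www.tei-c.org/ns/1.0"><head n="2.2.">Plug-and-Play Priors</head><p>The PnP framework comprises a family of methods that combine the measurement model with CNN denoisers to solve inverse problems <ref type="bibr">[3]</ref>. PnP methods can be interpreted as fixed-point iterations of a high-dimensional operator in which the CNN takes the role of the prior. For example, the proximal gradient method (PGM) variant of PnP can be expressed as</p><p><formula notation="TeX" n="2">x^{k} = T_{\theta}(x^{k-1}) \quad\text{with}\quad T_{\theta} := D_{\theta}(I - \gamma \nabla g),</formula></p><p>where <formula notation="TeX">D_{\theta}</formula> is the denoiser, <formula notation="TeX">g</formula> is the data-fidelity term, <formula notation="TeX">\nabla g</formula> is the gradient of <formula notation="TeX">g</formula>, <formula notation="TeX">I</formula> is the identity mapping, and <formula notation="TeX">\gamma &gt; 0</formula> is the step size. The PnP method in (<ref type="formula">2</ref>) is commonly referred to as PnP-PGM.</p></div>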
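<div xmlns="http://www.tei-c.org/ns/1.0"><p>As a rough illustration of the PnP-PGM fixed-point iteration, the following minimal Python sketch alternates a gradient step on the least-squares data-fidelity term with a denoising step. It is not the implementation used in this paper: the function names (grad_g, toy_denoiser, pnp_pgm) are hypothetical, the problem is assumed real-valued, and a simple soft-thresholding rule stands in for the CNN denoiser.</p>
<ab><![CDATA[
```python
import numpy as np


def grad_g(x, A, y):
    # Gradient of g(x) = 0.5 * ||y - A x||_2^2, which is A^H (A x - y).
    return A.conj().T @ (A @ x - y)


def toy_denoiser(z, strength=0.05):
    # Hypothetical stand-in for the CNN denoiser D_theta (assumption):
    # real-valued soft-thresholding, NOT a learned prior.
    return np.sign(z) * np.maximum(np.abs(z) - strength, 0.0)


def pnp_pgm(y, A, denoiser, gamma=1.0, num_iters=100):
    # Iterates x^k = D(x^{k-1} - gamma * grad g(x^{k-1})), starting from zero.
    x = np.zeros(A.shape[1], dtype=y.dtype)
    for _ in range(num_iters):
        x = denoiser(x - gamma * grad_g(x, A, y))
    return x
```
]]></ab>
<p>Any callable that maps an image to an image of the same shape can be passed as the denoiser argument, which mirrors how PnP treats the prior as a replaceable module.</p></div>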
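<div xmlns="http://www.tei-c.org/ns/1.0"><head n="2.3.">Deep Equilibrium Models</head><p>DEQ is a recent approach for training MBDL architectures in a memory-efficient way <ref type="bibr">[9]</ref>. DEQ uses implicit differentiation for training possibly infinite-depth networks by backpropagating through the fixed points of an operator.</p><p>For the operator defined in eq. (<ref type="formula">2</ref>), the output is implicitly expressed as</p><p><formula notation="TeX" n="3">\bar{x} = T_{\theta}(\bar{x}),</formula></p><p>where <formula notation="TeX">T_{\theta}</formula> is the operator parameterized by <formula notation="TeX">\theta</formula>, and <formula notation="TeX">\bar{x}</formula> is the fixed point obtained by running fixed-point iterations in the forward pass of DEQ. The connection between DEQ and PnP has inspired end-to-end training of CNN denoisers as model-dependent priors in many imaging problems such as MRI <ref type="bibr">[9]</ref> and computed tomography (CT) <ref type="bibr">[10]</ref>. The prior <formula notation="TeX">D_{\theta}</formula> in DEQ is trained by minimizing the loss between the fixed point from eq. (<ref type="formula">3</ref>) and the ground truth <formula notation="TeX">x_{\mathrm{gt}}</formula></p><p><formula notation="TeX" n="4">\ell(\theta) = \frac{1}{2}\|\bar{x}(\theta) - x_{\mathrm{gt}}\|_{2}^{2}.</formula></p><p>Implicit differentiation of the fixed points yields the gradient of the loss with respect to <formula notation="TeX">\theta</formula> in the backward pass as</p><p><formula notation="TeX" n="5">\nabla \ell(\theta) = \left(\nabla_{\theta} T_{\theta}(\bar{x})\right)^{\mathsf{T}} \left(I - \nabla_{x} T_{\theta}(\bar{x})\right)^{-\mathsf{T}} \left(\bar{x} - x_{\mathrm{gt}}\right),</formula></p><p>where <formula notation="TeX">I</formula> is the identity mapping and <formula notation="TeX">\ell</formula> is the loss.</p></div>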
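<div xmlns="http://www.tei-c.org/ns/1.0"><p>To make the implicit-differentiation step more concrete, the sketch below computes the gradient of a squared-error loss at the fixed point of a toy operator. Everything here is an illustrative assumption rather than the paper's setup: the operator is the affine map T(x) = theta * W x + b with a scalar parameter theta, the names fixed_point, loss, and implicit_grad are hypothetical, and the linear system is solved directly instead of iteratively.</p>
<ab><![CDATA[
```python
import numpy as np


def fixed_point(theta, W, b, num_iters=200):
    # Forward pass: iterate x <- T_theta(x) = theta * W x + b.
    # Assumes theta * W is a contraction so the iteration converges.
    x = np.zeros_like(b)
    for _ in range(num_iters):
        x = theta * (W @ x) + b
    return x


def loss(theta, W, b, x_true):
    # Squared-error loss between the fixed point and a reference vector.
    x_bar = fixed_point(theta, W, b)
    return 0.5 * np.sum((x_bar - x_true) ** 2)


def implicit_grad(theta, W, b, x_true):
    # Backward pass: (dT/dtheta)^T (I - dT/dx)^{-T} (x_bar - x_true).
    x_bar = fixed_point(theta, W, b)
    jac_x = theta * W          # Jacobian of T with respect to x
    jac_theta = W @ x_bar      # derivative of T with respect to scalar theta
    v = np.linalg.solve((np.eye(b.size) - jac_x).T, x_bar - x_true)
    return jac_theta @ v
```
]]></ab>
<p>The backward pass only needs the fixed point itself, not the intermediate iterates, which is why the memory cost does not grow with the number of forward iterations.</p></div>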
<div xmlns="http://www.tei-c.org/ns/1.0"><head n="2.4.">Test-Time Training</head><p>Current PnP methods are built on the premise that the prior represents the same distribution as that of the desired solution. However, it is common to observe distribution shifts between training and testing data. In some scenarios, there are insufficient training samples to train a DL network as the prior; hence, alternative priors trained on a shifted distribution are used, with suboptimal reconstruction performance. TTT has been proposed to reduce the performance gap due to distribution shift in various tasks <ref type="bibr">[14]</ref>. The key idea of TTT is to update the shifted model's weights at test time by minimizing a self-supervised loss</p><p><formula notation="TeX" n="6">\hat{\theta} = \arg\min_{\theta} \ell_{\mathrm{sup}}(D_{\theta}(y)),</formula></p><p>where <formula notation="TeX">D_{\theta}</formula> is the neural network and <formula notation="TeX">y</formula> is a test sample. Depending on the selection of <formula notation="TeX">\ell_{\mathrm{sup}}</formula>, TTT has shown improved performance in many imaging tasks. For example, it can be used to improve MRI reconstruction using DL models trained in an end-to-end manner on shifted distributions <ref type="bibr">[15]</ref>. In this scenario, the proposed self-supervised loss is</p><p><formula notation="TeX" n="7">\ell_{\mathrm{sup}}(\theta) = \frac{\|y - A D_{\theta}(A^{\dagger} y)\|_{1}}{\|y\|_{1}},</formula></p><p>where <formula notation="TeX">A</formula> is the measurement model, <formula notation="TeX">A^{\dagger}</formula> is its Hermitian transpose, and <formula notation="TeX">y</formula> is the test-time measurement. Note that, as opposed to (<ref type="formula">4</ref>), TTT in (<ref type="formula">7</ref>) does not need a ground-truth reconstruction to compute <formula notation="TeX">\ell_{\mathrm{sup}}</formula>, and one can use loss functions other than the normalized <formula notation="TeX">\ell_{1}</formula> loss <ref type="bibr">[15]</ref>.</p></div>
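<div xmlns="http://www.tei-c.org/ns/1.0"><p>The measurement-consistency idea behind such a self-supervised loss can be sketched in a few lines of Python. The snippet below evaluates a normalized l1 discrepancy between the measurement y and the re-measured network output A D(A^H y); it reflects our reading of the loss described above rather than a reference implementation, and the function name ttt_loss and the callables forward, adjoint, and network are hypothetical placeholders.</p>
<ab><![CDATA[
```python
import numpy as np


def ttt_loss(y, forward, adjoint, network):
    # forward:  callable applying the measurement operator A
    # adjoint:  callable applying its Hermitian transpose A^H
    # network:  callable playing the role of D_theta (placeholder, assumption)
    x_hat = network(adjoint(y))
    # Normalized l1 discrepancy; no ground-truth image is required.
    return np.sum(np.abs(y - forward(x_hat))) / np.sum(np.abs(y))
```
]]></ab>
<p>The loss is zero when the network output reproduces the measurement exactly and equals one when the network outputs an all-zero image, which gives a simple sanity check when wiring it into a training loop.</p></div>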
<div xmlns="http://www.tei-c.org/ns/1.0"><head n="2.5.">Our contribution</head><p>We propose PnP-TTT as a novel approach for enhancing the performance of image reconstruction for PnP methods under distribution shifts. Our approach performs domain adaptation of a shifted pre-trained image prior through TTT, using DEQ to close the distribution gap, and requires only a single test measurement. Our results show that, given sufficient measurements, PnP-TTT can significantly improve the performance of shifted priors at minimal computational cost.</p><figure><head>Fig. 1</head><figDesc>Evaluation of PnP-TTT for different sampling ratios in accelerated MRI. The leftmost chart displays the best PSNR performance achieved by PnP-TTT vs. sampling ratios. The remaining charts show PSNR at each TTT iteration. Note that the best performance is above the lower baseline for all the sampling ratios; however, TTT eventually overfits to the test-time measurement, reducing performance. Additionally, note that at larger sampling ratios, the performance of the PnP-TTT prior can surpass that of the matched prior due to the DEQ training.</figDesc></figure></div>
<div xmlns="http://www.tei-c.org/ns/1.0"><head n="3.">METHOD</head><p>We now present our method for domain adaptation of image priors in PnP. We consider the PnP-PGM algorithm in eq. (<ref type="formula">2</ref>) and run it until convergence. In practice, we find that about 100 iterations of PnP-PGM are sufficient in our configuration. We can update the weights of the image prior on a test measurement by minimizing the following self-supervised loss</p><p><formula notation="TeX" n="8">\ell_{\mathrm{sup}}(\theta) = \|y - A\,T_{\theta}(\bar{x})\|_{2}^{2},</formula></p><p>where <formula notation="TeX">\bar{x}</formula> is the fixed point of PnP-PGM defined in eq. (<ref type="formula">3</ref>) and <formula notation="TeX">T_{\theta}</formula> is the operator defined in (<ref type="formula">2</ref>). We use DEQ to compute the gradient of <formula notation="TeX">\ell_{\mathrm{sup}}</formula> at test time using implicit differentiation. We follow a method similar to <ref type="bibr">[17,</ref><ref type="bibr">18]</ref> to train the image priors using the DnCNN architecture, with batch normalization layers replaced with spectral normalization to control the Lipschitz constant of the denoisers. DnCNN is trained as a denoiser for AWGN level <formula notation="TeX">\sigma = 5</formula>. During the training stage we do not use DU or DEQ, so that the learned prior model is purely an image denoiser. We use 400 CBSD images to train the natural-image prior on grayscale images of size <formula notation="TeX">180 \times 180</formula> <ref type="bibr">[19]</ref>. We train MRI priors on MRI brain images of size <formula notation="TeX">256 \times 256</formula> <ref type="bibr">[20]</ref>.</p><p>For test-time training, we initialize PnP-PGM with <formula notation="TeX">x^{0} = 0</formula> and run 100 iterations in the forward pass of DEQ, using the trained denoising prior. We use Nesterov acceleration <ref type="bibr">[21]</ref> and set the step size <formula notation="TeX">\gamma = 1</formula>. We use 100 iterations and Anderson acceleration in the backward pass of DEQ <ref type="bibr">[22]</ref>. We allow TTT to run for 50 iterations, using SGD to update the parameters <formula notation="TeX">\theta</formula> with a step size of <formula notation="TeX">1 \times 10^{-5}</formula>. At inference, using the adapted prior, we again run PnP-PGM for 100 iterations with Nesterov acceleration and step size <formula notation="TeX">\gamma = 1</formula>. 
Note that once all 50 TTT iterations are performed for a particular measurement, <formula notation="TeX">\theta</formula> is reset to the non-domain-adapted weights.</p><p>Since the goal of PnP-TTT is to overcome the performance gap caused by a distribution shift between training and test time, performing TTT on as many measurements as are available at test time may be beneficial. Future experiments could examine whether there is any performance improvement when using multiple measurements instead of only one for PnP-TTT.</p><p>The measurement model for single-coil, accelerated MRI with radial Fourier sampling can be written as <formula notation="TeX">A = MF</formula>, where <formula notation="TeX">M</formula> is the diagonal sampling matrix and <formula notation="TeX">F</formula> is the Fourier transform. We investigate five different sampling ratios (<formula notation="TeX">m/n</formula>) in the experiments. For the experiments reported here, we consider a noiseless scenario; however, we expect similar performance of PnP-TTT under moderate amounts of noise.</p></div>
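<div xmlns="http://www.tei-c.org/ns/1.0"><p>A masked-Fourier measurement operator of the form A = MF and its adjoint can be sketched with NumPy's FFT routines as follows. This is a simplified illustration under stated assumptions: the mask is taken as a given 0/1 array of the same shape as the image (the radial sampling pattern is not generated here), unsampled k-space entries are kept as zeros rather than removed, and the function names forward, adjoint, and sampling_ratio are hypothetical.</p>
<ab><![CDATA[
```python
import numpy as np


def forward(x, mask):
    # A x = M F x: orthonormal 2-D FFT followed by k-space masking.
    return mask * np.fft.fft2(x, norm="ortho")


def adjoint(y, mask):
    # A^H y = F^H M y: masking followed by the orthonormal inverse FFT.
    return np.fft.ifft2(mask * y, norm="ortho")


def sampling_ratio(mask):
    # Fraction of retained k-space samples, i.e. m / n.
    return mask.sum() / mask.size
```
]]></ab>
<p>With the orthonormal FFT convention, the adjoint is exactly the masked inverse FFT, so the pair can be checked numerically with an inner-product (adjoint) test before it is used inside an iterative method.</p></div>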
<div xmlns="http://www.tei-c.org/ns/1.0"><head n="4.">RESULTS</head><p>We test our proposed method by reconstructing ten brain MRI images selected from the test dataset of <ref type="bibr">[20]</ref> with a mismatched DnCNN prior trained on natural images. Due to the distribution shift, natural-image priors demonstrate suboptimal performance on the MRI task. To establish performance baselines, we compare the results of the proposed method with those obtained using the mismatched natural-image prior and the matched MRI prior. Specifically, we consider the performance achieved by the natural-image prior as the lower baseline, and that achieved by the MRI prior as the upper baseline. Our proposed PnP-TTT seeks to enhance the performance of a mismatched natural-image prior so as to approach that of a matched MRI prior.</p><p>Table <ref type="table">1</ref> reports the best results achieved for five CS ratios: 10%, 20%, 30%, 40%, and 50%. It can be seen that PnP-TTT can close the performance gap for CS ratios of 20% and more, while for a CS ratio of 10% it still improves on the lower baseline (mismatched natural-image prior). The reconstruction quality is quantified using peak signal-to-noise ratio (PSNR) in dB and the structural similarity index measure (SSIM).</p></div>
<div xmlns="http://www.tei-c.org/ns/1.0"><head n="5.">CONCLUSION</head><p>We present PnP-TTT as a novel framework for closing the performance gap that arises due to mismatched priors in imaging inverse problems. PnP-TTT achieves this by adapting the mismatched priors during the testing phase, using DEQ training to update their weights. One of the main advantages of PnP-TTT is that one can use mismatched priors on a shifted distribution without the need for additional training. Instead, the prior can simply be adapted to the test-time measurements. Our results show that PnP-TTT can significantly enhance performance, achieving results comparable to those obtained with a matched prior during inference. Furthermore, this work suggests that, given sufficient measurements, priors from different tasks can be used interchangeably in scenarios with shifted distributions without loss of performance.</p></div><note xmlns="http://www.tei-c.org/ns/1.0" place="foot" n="2023" xml:id="foot_0"><p>IEEE 9th International Workshop on Computational Advances in Multi-Sensor Adaptive Processing (CAMSAP)</p></note>
		</body>
		</text>
</TEI>
