

Title: Fine-Tuning can Distort Pretrained Features and Underperform Out-of-Distribution
When transferring a pretrained model to a downstream task, two popular methods are full fine-tuning (updating all the model parameters) and linear probing (updating only the last linear layer, the “head”). It is well known that fine-tuning leads to better accuracy in-distribution (ID). However, in this paper, we find that fine-tuning can achieve worse accuracy than linear probing out-of-distribution (OOD) when the pretrained features are good and the distribution shift is large. On 10 distribution shift datasets (Breeds-Living17, Breeds-Entity30, DomainNet, CIFAR → STL, CIFAR10.1, FMoW, ImageNetV2, ImageNet-R, ImageNet-A, ImageNet-Sketch), fine-tuning obtains on average 2% higher accuracy ID but 7% lower accuracy OOD than linear probing. We show theoretically that this tradeoff between ID and OOD accuracy arises even in a simple setting: fine-tuning overparameterized two-layer linear networks. We prove that the OOD error of fine-tuning is high when we initialize with a fixed or random head. This is because while fine-tuning learns the head, the lower layers of the neural network change simultaneously and distort the pretrained features. Our analysis suggests that the simple two-step strategy of linear probing then full fine-tuning (LP-FT), sometimes used as a fine-tuning heuristic, combines the benefits of both fine-tuning and linear probing. Empirically, LP-FT outperforms both fine-tuning and linear probing on the above datasets (1% better ID, 10% better OOD than full fine-tuning).
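The feature-distortion mechanism can be illustrated with a small NumPy simulation of the two-layer linear setting, f(x) = v^T B x. This is a toy sketch under our own assumptions, not the paper's experiments: the dimensions, the Gaussian data, the learning rate, the random-head scale, and the choice of perfectly good pretrained features (B0 equal to a hypothetical ground truth B_star) are illustrative, and the function names are hypothetical. The key observation is that the squared-loss gradient with respect to B is an outer product of v with a vector in the span of the training inputs, so full fine-tuning only moves B inside the ID subspace S and leaves the component orthogonal to S at its pretrained value. Starting from a random head, B is bent on S to compensate for the wrong head, while OOD directions still see the old features combined with the new head. Linear probing keeps B0 intact, and LP-FT starts from a head that already fits the data, so its gradients are near zero and the features barely move.

```python
import numpy as np


def mse(X, B, v, w_true):
    """Mean squared error of the model x -> v^T B x against the target w_true^T x."""
    return float(np.mean((X @ B.T @ v - X @ w_true) ** 2))


def fine_tune(X, y, B, v, lr=0.01, steps=5000):
    """Full fine-tuning: gradient descent on head v and features B jointly (squared loss)."""
    B, v = B.copy(), v.copy()
    n = X.shape[0]
    for _ in range(steps):
        r = X @ B.T @ v - y        # residuals, shape (n,)
        g = X.T @ r / n            # lies in the span of the training inputs
        grad_v = B @ g             # dL/dv, shape (k,)
        grad_B = np.outer(v, g)    # dL/dB, shape (k, d): every row is a multiple of g
        v -= lr * grad_v
        B -= lr * grad_B
    return B, v


def run(d=20, k=5, m=8, n=200, seed=0):
    """Compare LP, FT and LP-FT on a toy ID/OOD problem (all sizes are illustrative)."""
    rng = np.random.default_rng(seed)
    # Hypothetical ground truth: k orthonormal feature rows B_star and a head v_star.
    B_star = np.linalg.qr(rng.standard_normal((d, k)))[0].T
    v_star = rng.standard_normal(k)
    w_true = B_star.T @ v_star
    # ID inputs lie in an m-dimensional subspace S; OOD inputs span all of R^d.
    S = np.linalg.qr(rng.standard_normal((d, m)))[0]
    X_id = rng.standard_normal((n, m)) @ S.T
    X_ood = rng.standard_normal((n, d))
    y_id = X_id @ w_true

    # Assumption: the pretrained features are exactly the "good" ones.
    B0 = B_star.copy()

    # Linear probing: freeze B0 and fit only the head by least squares.
    v_lp = np.linalg.lstsq(X_id @ B0.T, y_id, rcond=None)[0]
    # Fine-tuning from a random head (the 0.5 scale is an arbitrary choice).
    B_ft, v_ft = fine_tune(X_id, y_id, B0, 0.5 * rng.standard_normal(k))
    # LP-FT: fine-tune everything, starting from the linear-probed head.
    B_lpft, v_lpft = fine_tune(X_id, y_id, B0, v_lp)

    models = {"LP": (B0, v_lp), "FT": (B_ft, v_ft), "LP-FT": (B_lpft, v_lpft)}
    return {name: (mse(X_id, B, v, w_true), mse(X_ood, B, v, w_true))
            for name, (B, v) in models.items()}


if __name__ == "__main__":
    for name, (id_err, ood_err) in run().items():
        print(f"{name:6s} ID MSE={id_err:.2e}  OOD MSE={ood_err:.2e}")
```

`run()` returns an (ID error, OOD error) pair per method. In this noiseless toy all three methods fit the ID data, but only fine-tuning from the random head should show a clearly nonzero OOD error, because its unchanged off-subspace features no longer match its updated head.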
Award ID(s):
2343611
NSF-PAR ID:
10472125
Author(s) / Creator(s):
; ; ; ;
Publisher / Repository:
ICLR (Oral) 2022 arXiv:2202.10054
Date Published:
Subject(s) / Keyword(s):
Machine Learning (cs.LG); Computer Vision and Pattern Recognition (cs.CV)
Format(s):
Medium: X
Location:
ICLR (Oral) 2022 arXiv:2202.10054
Sponsoring Org:
National Science Foundation
More Like this
  2. Pretrained Transformers achieve remarkable performance when training and test data are from the same distribution. However, in real-world scenarios, the model often faces out-of-distribution (OOD) instances that can cause severe semantic shift problems at inference time. Therefore, in practice, a reliable model should identify such instances, and then either reject them during inference or pass them over to models that handle another distribution. In this paper, we develop an unsupervised OOD detection method, in which only the in-distribution (ID) data are used in training. We propose to fine-tune the Transformers with a contrastive loss, which improves the compactness of representations, such that OOD instances can be better differentiated from ID ones. These OOD instances can then be accurately detected using the Mahalanobis distance in the model’s penultimate layer. We experiment with a comprehensive range of settings and achieve near-perfect OOD detection performance, substantially outperforming baselines. We further investigate the reasons for the improvement, finding that it comes from the more compact representations produced by margin-based contrastive learning. We release our code to the community for future research.
  3. Abstract

    Insect pests cause significant damage to food production, so early detection and efficient mitigation strategies are crucial. There is a continual shift toward machine learning (ML)‐based approaches for automating agricultural pest detection. Although supervised learning has achieved remarkable progress in this regard, it is impeded by the need for significant expert involvement in labeling the data used for model training. This makes real‐world applications tedious and oftentimes infeasible. Recently, self‐supervised learning (SSL) approaches have provided a viable alternative to training ML models with minimal annotations. Here, we present an SSL approach to classify 22 insect pests. The framework was assessed on raw and segmented field‐captured images using three different SSL methods: Nearest Neighbor Contrastive Learning of Visual Representations (NNCLR), Bootstrap Your Own Latent (BYOL), and Barlow Twins. SSL pre‐training was done on ResNet‐18 and ResNet‐50 models using all three SSL methods on the original RGB images and on foreground‐segmented images. The SSL pre‐training methods were evaluated using both linear probing of the SSL representations and end‐to‐end fine‐tuning. The SSL‐pre‐trained convolutional neural network models were able to perform annotation‐efficient classification. NNCLR was the best‐performing SSL method for both linear probing and full model fine‐tuning. With only 5% of the images annotated, transfer learning with ImageNet initialization obtained 74% accuracy, whereas NNCLR achieved an improved classification accuracy of 79% with end‐to‐end fine‐tuning. Models created using SSL pre‐training consistently performed better, especially under very low annotation, and were robust to object class imbalances. These approaches help overcome annotation bottlenecks and are resource efficient.

     
  4. We investigate the extent to which individual attention heads in pretrained transformer language models, such as BERT and RoBERTa, implicitly capture syntactic dependency relations. We employ two methods, taking the maximum attention weight and computing the maximum spanning tree, to extract implicit dependency relations from the attention weights of each layer/head, and compare them to the ground-truth Universal Dependency (UD) trees. We show that, for some UD relation types, there exist heads that can recover the dependency type significantly better than baselines on parsed English text, suggesting that some self-attention heads act as a proxy for syntactic structure. We also analyze BERT fine-tuned on two datasets (the syntax-oriented CoLA and the semantics-oriented MNLI) to investigate whether fine-tuning affects the patterns of its self-attention, but we do not observe substantial differences in the overall dependency relations extracted using our methods. Our results suggest that these models have some specialist attention heads that track individual dependency types, but no generalist head that performs holistic parsing significantly better than a trivial baseline, and that analyzing attention weights directly may not reveal much of the syntactic knowledge that BERT-style models are known to learn.
  5. Background

    Deep learning (DL)‐based automatic segmentation models can expedite manual segmentation yet require resource‐intensive fine‐tuning before deployment on new datasets. The generalizability of DL methods to new datasets without fine‐tuning is not well characterized.

    Purpose

    Evaluate the generalizability of DL‐based models by deploying pretrained models on independent datasets varying by MR scanner, acquisition parameters, and subject population.

    Study Type

    Retrospective based on prospectively acquired data.

    Population

    Overall test dataset: 59 subjects (26 females). Study 1: 5 healthy subjects (0 females); Study 2: 8 healthy subjects (8 females); Study 3: 10 subjects with osteoarthritis (8 females); Study 4: 36 subjects with various knee pathologies (10 females).

    Field Strength/Sequence

    3‐T; quantitative double‐echo steady‐state (qDESS) sequence.

    Assessment

    Four annotators manually segmented knee cartilage; each annotator segmented one of the four qDESS datasets in the test set. Two DL models, one trained on qDESS data and another on Osteoarthritis Initiative (OAI)‐DESS data, were assessed. Manual and automatic segmentations were compared by quantifying variations in segmentation accuracy, volume, and T2 relaxation times for superficial and deep cartilage.

    Statistical Tests

    Dice similarity coefficient (DSC) for segmentation accuracy. Lin's concordance correlation coefficient (CCC), Wilcoxon rank‐sum tests, and the root‐mean‐squared‐error coefficient of variation to quantify manual vs. automatic T2 and volume variations. Bland–Altman plots for manual vs. automatic T2 agreement. A P value < 0.05 was considered statistically significant.

    Results

    DSCs for the qDESS‐trained model (0.79–0.93) were higher than those for the OAI‐DESS‐trained model (0.59–0.79). T2 and volume CCCs for the qDESS‐trained model (0.75–0.98 and 0.47–0.95) were higher than the respective CCCs for the OAI‐DESS‐trained model (0.35–0.90 and 0.13–0.84). Bland–Altman 95% limits of agreement for superficial and deep cartilage T2 were narrower for the qDESS‐trained model (±2.4 msec and ±4.0 msec) than for the OAI‐DESS‐trained model (±4.4 msec and ±5.2 msec).

    Data Conclusion

    The qDESS‐trained model may generalize well to independent qDESS datasets regardless of MR scanner, acquisition parameters, and subject population.

    Evidence Level

    1

    Technical Efficacy

    Stage 1

     
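The agreement statistics listed under Statistical Tests for the knee-cartilage study have standard textbook definitions: Dice = 2|A ∩ B| / (|A| + |B|), Lin's CCC = 2 s_xy / (s_x² + s_y² + (x̄ − ȳ)²), and Bland–Altman limits of agreement = mean difference ± 1.96 × SD of the differences. The following is a minimal NumPy sketch of those definitions, not the study's analysis code. The function names are hypothetical, and three choices are our assumptions: population (1/n) moments in Lin's CCC (implementations differ on this), the sample SD (ddof=1) for Bland–Altman, and treating two empty masks as a perfect Dice match.

```python
import numpy as np


def dice(mask_a, mask_b):
    """Dice similarity coefficient 2|A ∩ B| / (|A| + |B|) for binary masks."""
    a = np.asarray(mask_a, dtype=bool)
    b = np.asarray(mask_b, dtype=bool)
    total = a.sum() + b.sum()
    if total == 0:
        return 1.0  # assumed convention: two empty masks agree perfectly
    return float(2.0 * np.logical_and(a, b).sum() / total)


def lins_ccc(x, y):
    """Lin's concordance correlation coefficient, using population (1/n) moments."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    sxy = np.mean((x - x.mean()) * (y - y.mean()))
    return float(2.0 * sxy / (x.var() + y.var() + (x.mean() - y.mean()) ** 2))


def bland_altman_limits(a, b, z=1.96):
    """Return (bias, half-width) of the 95% limits of agreement for paired values a - b."""
    diff = np.asarray(a, dtype=float) - np.asarray(b, dtype=float)
    return float(diff.mean()), float(z * diff.std(ddof=1))
```

For example, `bias, half_width = bland_altman_limits(t2_manual, t2_auto)`, with hypothetical arrays of paired manual and automatic T2 values, gives limits of bias ± half_width, the same form as the ±msec figures under Results. Unlike Pearson correlation, Lin's CCC penalizes a constant offset between the two methods, which is why it suits manual vs. automatic agreement.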