This content will become publicly available on July 1, 2024

Title: Multi-scale Hierarchical Vision Transformer with Cascaded Attention Decoding for Medical Image Segmentation
Transformers have shown great success in medical image segmentation. However, transformers may exhibit limited generalization ability due to the underlying single-scale self-attention (SA) mechanism. In this paper, we address this issue by introducing a Multi-scale hiERarchical vIsion Transformer (MERIT) backbone network, which improves the generalizability of the model by computing SA at multiple scales. We also incorporate an attention-based decoder, namely Cascaded Attention Decoding (CASCADE), to further refine the multi-stage features generated by MERIT. Finally, we introduce an effective multi-stage feature mixing loss aggregation (MUTATION) method for better model training via implicit ensembling. Our experiments on two widely used medical image segmentation benchmarks (i.e., Synapse Multi-organ and ACDC) demonstrate the superior performance of MERIT over state-of-the-art methods. Our MERIT architecture and MUTATION loss aggregation can be used with other downstream medical image and semantic segmentation tasks.
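The abstract names MUTATION only as a "multi-stage feature mixing loss aggregation" method and gives no formula. The sketch below shows one plausible reading of that idea: the prediction maps of every non-empty subset of stages are summed element-wise, and a loss is computed on each mixture and accumulated. The mixing rule (element-wise sum), the stand-in loss `bce`, and the names `mixed_loss` and `stage_logits` are all assumptions made for illustration, not the authors' implementation.

```python
from itertools import combinations
import math


def bce(logits, target):
    """Mean binary cross-entropy over flat lists (stand-in loss, an assumption)."""
    total = 0.0
    for z, y in zip(logits, target):
        p = 1.0 / (1.0 + math.exp(-z))
        p = min(max(p, 1e-7), 1.0 - 1e-7)  # clamp to keep log() finite
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(target)


def mixed_loss(stage_logits, target, loss_fn=bce):
    """Accumulate loss_fn over the element-wise sum of every non-empty
    subset of per-stage prediction maps (assumed mixing rule)."""
    total, count = 0.0, 0
    for k in range(1, len(stage_logits) + 1):
        for subset in combinations(stage_logits, k):
            mixed = [sum(vals) for vals in zip(*subset)]
            total += loss_fn(mixed, target)
            count += 1
    return total, count


# Four hypothetical decoder stages, each predicting four pixels.
stages = [
    [2.0, -1.0, 0.5, -3.0],
    [1.5, -0.5, 1.0, -2.0],
    [0.5, 0.2, 0.8, -1.0],
    [1.0, -2.0, -0.2, -0.5],
]
target = [1, 0, 1, 0]
loss, n_mix = mixed_loss(stages, target)
```

With n stages this enumerates 2^n - 1 mixtures (15 for the four stages here), so each stage's output contributes to many loss terms, which is one way to interpret the "implicit ensembling" mentioned in the abstract.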
Award ID(s):
2007284
NSF-PAR ID:
10468128
Author(s) / Creator(s):
Publisher / Repository:
Medical Imaging with Deep Learning (MIDL) conference
Date Published:
Subject(s) / Keyword(s):
Medical image segmentation, Vision transformer, Multi-scale transformer, Feature-mixing augmentation, Self-attention
Format(s):
Medium: X
Sponsoring Org:
National Science Foundation
More Like this
  1. Transformers have shown great promise in medical image segmentation due to their ability to capture long-range dependencies through self-attention. However, they lack the ability to learn the local (contextual) relations among pixels. Previous works try to overcome this problem by embedding convolutional layers in either the encoder or the decoder modules of transformers, which sometimes results in inconsistent features. To address this issue, we propose a novel attention-based decoder, namely CASCaded Attention DEcoder (CASCADE), which leverages the multi-scale features of hierarchical vision transformers. CASCADE consists of i) an attention gate, which fuses features with skip connections, and ii) a convolutional attention module, which enhances the long-range and local context by suppressing background information. We use a multi-stage feature and loss aggregation framework because of its faster convergence and better performance. Our experiments demonstrate that transformers with CASCADE significantly outperform state-of-the-art CNN- and transformer-based approaches, obtaining up to 5.07% and 6.16% improvements in DICE and mIoU scores, respectively. CASCADE opens new ways of designing better attention-based decoders.
  2. Medical image segmentation plays an important role in medical analysis and has been widely developed for many clinical applications. Deep learning-based approaches have achieved high performance in semantic segmentation, but they are limited to the pixel-wise setting and suffer from the imbalanced-class data problem. In this paper, we tackle those limitations by developing a new deep learning-based model that takes into account the higher feature level (i.e., the region inside the contour), the intermediate feature level (i.e., offset curves around the contour), and the lower feature level (i.e., the contour). Our proposed Offset Curves (OsC) loss consists of three main fitting terms. The first fitting term focuses on pixel-wise segmentation, whereas the second acts as an attention model that attends to the area around the boundaries (offset curves). The third term acts as a regularization term that takes the length of the boundaries into account. We evaluate our proposed OsC loss on both 2D and 3D networks. Two common medical datasets, retina DRIVE and brain tumor BRATS 2018, are used to benchmark the proposed loss. The experiments show that the OsC loss function outperforms other mainstream loss functions such as Cross-Entropy, Dice, and Focal on the most common segmentation networks, Unet and FCN.
  3. Medical image segmentation is one of the most challenging tasks in medical image analysis and has been widely developed for many clinical applications. While deep learning-based approaches have achieved impressive performance in semantic segmentation, they are limited to pixel-wise settings, suffer from imbalanced-class data problems, and segment weak object boundaries in medical images poorly. In this paper, we tackle those limitations by developing a new two-branch deep network architecture that takes both higher-level and lower-level features into account. The first branch extracts higher-level features as region information using a common encoder-decoder network structure such as Unet or FCN, whereas the second branch focuses on lower-level features as support information around the boundary and runs in parallel with the first branch. Our key contribution is the second branch, named the Narrow Band Active Contour (NB-AC) attention model, which treats the object contour as a hyperplane and all data inside a narrow band as support information that influences the position and orientation of the hyperplane. Our proposed NB-AC attention model incorporates the contour length with the region energy involving a fixed-width band around the curve or surface. The proposed network loss contains two fitting terms: (i) a higher-level feature (i.e., region) fitting term from the first branch; and (ii) a lower-level feature (i.e., contour) fitting term from the second branch, which includes (ii1) the length of the object contour and (ii2) a regional energy functional formed by the homogeneity criterion of both the inner band and the outer band neighboring the evolving curve or surface. The proposed NB-AC loss can be incorporated into both 2D and 3D deep network architectures. The proposed network has been evaluated on several challenging medical image datasets, including DRIVE, iSeg17, MRBrainS18, and Brats18. The experimental results show that the proposed NB-AC loss outperforms other mainstream loss functions (Cross Entropy, Dice, and Focal) on two common segmentation frameworks, Unet and FCN. Our 3D network, which is built upon the proposed NB-AC loss and the 3DUnet framework, achieved state-of-the-art results on multiple volumetric datasets.
  4. Network pruning is a widely used technique to reduce computation cost and model size for deep neural networks. However, the typical three-stage pipeline (i.e., training, pruning, and retraining (fine-tuning)) significantly increases the overall training time. In this article, we develop a systematic weight-pruning optimization approach based on surrogate Lagrangian relaxation (SLR), which is tailored to overcome difficulties caused by the discrete nature of the weight-pruning problem. We further prove that our method ensures fast convergence of the model compression problem, and that the convergence of the SLR is accelerated by using quadratic penalties. Model parameters obtained by SLR during the training phase are much closer to their optimal values than those obtained by other state-of-the-art methods. We evaluate our method on image classification tasks using CIFAR-10 and ImageNet with state-of-the-art multi-layer perceptron based networks such as MLP-Mixer; attention-based networks such as Swin Transformer; and convolutional neural network based models such as VGG-16, ResNet-18, ResNet-50, ResNet-110, and MobileNetV2. We also evaluate object detection and segmentation tasks on COCO, the KITTI benchmark, and the TuSimple lane detection dataset using a variety of models. Experimental results demonstrate that our SLR-based weight-pruning optimization approach achieves a higher compression rate than state-of-the-art methods under the same accuracy requirement and also achieves higher accuracy under the same compression rate requirement. Under classification tasks, our SLR approach converges to the desired accuracy × faster on both of the datasets. Under object detection and segmentation tasks, SLR also converges 2× faster to the desired accuracy. Further, our SLR achieves high model accuracy even at the hard-pruning stage without retraining, which reduces the traditional three-stage pruning into a two-stage process. Given a limited budget of retraining epochs, our approach quickly recovers the model’s accuracy.
  5. Vision Transformers (ViTs) are built on the assumption of treating image patches as “visual tokens” and learning patch-to-patch attention. The patch-embedding-based tokenizer has a semantic gap with respect to its counterpart, the textual tokenizer. Patch-to-patch attention suffers from quadratic complexity and also makes it non-trivial to explain learned ViTs. To address these issues, this paper proposes to learn Patch-to-Cluster attention (PaCa) in ViT. Queries in our PaCa-ViT start with patches, while keys and values are directly based on clustering (with a predefined small number of clusters). The clusters are learned end-to-end, leading to better tokenizers and inducing joint clustering-for-attention and attention-for-clustering for better and more interpretable models. The quadratic complexity is relaxed to linear complexity. The proposed PaCa module is used in designing efficient and interpretable ViT backbones and semantic segmentation head networks. In experiments, the proposed methods are tested on ImageNet-1k image classification, MS-COCO object detection and instance segmentation, and MIT-ADE20k semantic segmentation. Compared with the prior art, it obtains better performance on all three benchmarks than the SWin [32] and the PVTs [47], [48], by significant margins on ImageNet-1k and MIT-ADE20k. It is also significantly more efficient than PVT models on MS-COCO and MIT-ADE20k due to the linear complexity. The learned clusters are semantically meaningful. Code and model checkpoints are available at https://github.com/iVMCL/PaCaViT.
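The patch-to-cluster idea in the last abstract (queries from N patches, keys and values from a small fixed number M of clusters) can be illustrated with a toy sketch. It assumes a soft clustering in which each cluster is a softmax-weighted average of the patches, and it omits the learned query, key, and value projections entirely. The function name `patch_to_cluster_attention`, the `cluster_proj` weights, and the clustering rule are hypothetical and are not taken from the PaCaViT code.

```python
import math


def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def patch_to_cluster_attention(patches, cluster_proj):
    """patches: N vectors of dimension d.
    cluster_proj: M vectors of dimension d (hypothetical learned weights)."""
    n, m, d = len(patches), len(cluster_proj), len(patches[0])
    # Soft assignment: for each cluster, a softmax over the N patches (M x N).
    assign = [softmax([dot(w, p) for p in patches]) for w in cluster_proj]
    # Cluster tokens: weighted averages of the patches (M x d).
    clusters = [
        [sum(a[i] * patches[i][j] for i in range(n)) for j in range(d)]
        for a in assign
    ]
    # Each patch (query) attends to the M clusters (used as keys and values).
    scale = 1.0 / math.sqrt(d)
    out = []
    for q in patches:
        w = softmax([scale * dot(q, c) for c in clusters])
        out.append([sum(w[k] * clusters[k][j] for k in range(m)) for j in range(d)])
    return out, n * m  # outputs and the number of query-key scores computed


patches = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9], [0.5, 0.5], [0.2, 0.8]]
cluster_proj = [[1.0, -1.0], [-1.0, 1.0]]
out, n_scores = patch_to_cluster_attention(patches, cluster_proj)
print(len(out), n_scores, len(patches) ** 2)  # prints "6 12 36"
```

Because M is fixed in advance, the number of attention scores grows as N × M rather than N × N, which is the sense in which the abstract describes quadratic complexity being relaxed to linear.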