NestedFormer: Nested Modality-Aware Transformer for Brain Tumor Segmentation

Zhaohu Xing¹, Lequan Yu², Liang Wan¹ (🖂), Tong Han³, and Lei Zhu⁴,⁵

¹ Medical College of Tianjin University, Tianjin, China (xingzhaohu@tju.edu.cn)
² The University of Hong Kong, Hong Kong, China
³ Brain Medical Center of Tianjin University, Huanhu Hospital, Tianjin, China
⁴ The Hong Kong University of Science and Technology (Guangzhou), Guangzhou, China
⁵ The Hong Kong University of Science and Technology, Hong Kong, China
email: lwan@tju.edu.cn
Abstract

Multi-modal MR imaging is routinely used in clinical practice to diagnose and investigate brain tumors because it provides rich complementary information. Previous multi-modal MRI segmentation methods usually perform modal fusion by concatenating multi-modal MRIs at an early or middle stage of the network, which can hardly explore non-linear dependencies between modalities. In this work, we propose a novel Nested Modality-Aware Transformer (NestedFormer) to explicitly explore the intra-modality and inter-modality relationships of multi-modal MRIs for brain tumor segmentation. Built on a transformer-based multi-encoder, single-decoder structure, we perform nested multi-modal fusion on the high-level representations of the different modalities and apply modality-sensitive gating (MSG) at lower scales for more effective skip connections. Specifically, the multi-modal fusion is conducted in our proposed Nested Modality-aware Feature Aggregation (NMaFA) module, which enhances long-range dependencies within individual modalities via a tri-orientated spatial-attention transformer, and further complements key contextual information among modalities via a cross-modality attention transformer. Extensive experiments on the BraTS2020 benchmark and a private meningioma segmentation (MeniSeg) dataset show that NestedFormer clearly outperforms state-of-the-art methods. The code is available at https://github.com/920232796/NestedFormer.

Keywords:
Multi-modal MRI, Brain Tumor Segmentation, Nested Modality-Aware Feature Aggregation, Modality-Sensitive Gating

1 Introduction

Brain tumors are among the most common cancers in the world [Bray2018Cancers]. Among them, gliomas are the most common malignant brain tumors, with different levels of aggressiveness, and meningiomas are the most prevalent primary intracranial tumors in adults [Ostrom2020BrainTumor]. Multi-modal magnetic resonance imaging (MRI) is routinely used in the clinic because it provides rich complementary information for analyzing brain tumors. Specifically, for gliomas, the commonly used MRI sequences are T1-weighted (T1), post-contrast T1-weighted (T1Gd), T2-weighted (T2), and T2 Fluid-Attenuated Inversion Recovery (T2-FLAIR) images (see Fig. 1(a)), each playing a different role in distinguishing tumor, peritumoral edema, and tumor core [menze2014multimodal, bakas2017advancing, bakas2018identifying]. Meningiomas have characteristic appearances on T1Gd [li2019presurgical] and contrast-enhanced T2-FLAIR (FLAIR-C) MRI images; see Fig. 1(b). Thus, automatic segmentation of brain tumor structures from multi-modal MRIs is important for clinical diagnosis and treatment planning.

Figure 1: Multi-modal MRIs for (a) Gliomas; and (b) Meningiomas.

In recent years, convolutional neural networks (CNNs) have achieved promising success in brain tumor segmentation. Mainstream models are built upon the encoder-decoder architecture with skip connections [ronneberger2015u], including S3D-UNet [chen2018s3d], SegResNet [myronenko20183d], and HPU-Net [kong2018hybrid]. Recent works [wang2021transbts, zhang2021DualTrans, hatamizadeh2022unetr] also explore transformers [vaswani2017attention] to model long-range dependencies within images. For instance, TransBTS [wang2021transbts] uses a 3D CNN to extract local spatial features and applies a transformer to model global dependencies on the high-level features. UNETR [hatamizadeh2022unetr] uses a ViT transformer as the encoder to learn contextual information, which is merged with the CNN-based decoder via skip connections at multiple resolutions. However, the transformers in these methods are used to enhance the encoding path, without a specific design for multi-modal fusion.

To utilize multi-modal information, most existing methods adopt an early-fusion strategy, in which the multi-modal images are concatenated as the network input. However, this strategy can hardly explore the non-linear relationships between different modalities. To alleviate this problem, recent works follow a layer-fusion strategy [Dolz2019HyperDense, zhou20213d, zhang2021modality], where the modality-specific features extracted by different encoders are fused in the middle layers of the network and share the same decoder. In HyperDense-Net [Dolz2019HyperDense], each modality has a separate stream, and dense connections are introduced between layers both within the same stream and across different streams. MAML [zhang2021modality] embeds the multi-modal images with different modality-specific FCNs and then applies a modality-aware module to regress attention maps that fuse the modality-specific features. Nevertheless, these multi-modal fusion methods do not build long-range spatial dependencies within and across modalities, so they cannot fully utilize the complementary information of different modalities.

In this paper, we propose a novel nested modality-aware transformer, called NestedFormer, for effective and robust multi-modal brain tumor segmentation. We first design an effective Global Poolformer to extract discriminative volumetric spatial features, with more emphasis on global dependencies, from the different MRI modalities. To better extract complementary features and support the fusion of an arbitrary number of modalities, we propose a novel Nested Modality-aware Feature Aggregation (NMaFA) module. It explicitly considers both single-modality spatial coherence and cross-modality coherence, and leverages nested transformers to establish intra- and inter-modality long-range dependencies, resulting in more effective feature representations. Moreover, we design a computationally efficient Tri-orientated Spatial Attention (TSA) paradigm to accelerate the 3D spatial-coherence calculation. To improve feature reuse in decoding, a novel modality-sensitive gating (MSG) module is developed to dynamically filter modality-aware low-resolution features for effective skip connections. Extensive experiments on the BraTS2020 benchmark and a privately collected meningioma segmentation dataset (MeniSeg) show that our model clearly outperforms state-of-the-art methods.

Figure 2: An overview of the proposed NestedFormer. We design a Nested Modality-aware Feature Aggregation (NMaFA) module to model both the intra- and inter-modality features for multi-modal fusion.

2 Method

Fig. 2 illustrates the overview of the proposed NestedFormer, which consists of three components: 1) multiple encoders that obtain multi-scale representations of the different modalities, 2) an NMaFA fusion module that explores correlated features within and between the multi-modal high-level embeddings, and 3) a gating strategy that selectively transfers modality-sensitive low-resolution features to the decoder.

2.1 Global Poolformer Encoder

Recent works show that transformers are more conducive to modeling global information than CNNs. To better extract local context information for each modality, we extend Poolformer [yu2021metaformer] as the modality-specific encoder. As discussed in [yu2021metaformer], replacing the computation-intensive attention module of the Transformer with average pooling can outperform recent transformer and MLP-like models. Therefore, to enhance global information, we design the Global Poolformer Block (GPB), which uses global pooling instead of the average pooling in Poolformer, followed by a fully connected (FC) layer. As shown in Fig. 2, given the input feature embedding $X$, a GPB block consists of a learnable global pooling (GP) and an MLP sub-block. The output is computed as

$$\hat{X} = \mathrm{GP}(\mathrm{LN}(X)) + X, \qquad \mathrm{GP}(Z) = \mathrm{FC}(Z; W), \qquad Y = \mathrm{MLP}(\mathrm{LN}(\hat{X})) + \hat{X}, \tag{1}$$

where $\mathrm{LN}(\cdot)$ denotes layer normalization and $W$ is the learnable parameter of the FC layer. Our Global Poolformer encoder contains five groups, each consisting of one feature embedding (FE) layer and two GPB blocks. Each FE layer is a 3D convolution: the first FE layer has a convolution patch size of …, and the remaining layers have a patch size of … and a stride of 2. The encoders gradually encode each modality image into a high-level feature $F_m$ ($m = 1, \dots, M$), whose size depends on the input spatial resolution $H \times W$, the depth dimension $D$, the channel dimension $C$, and the layer number $L$, where $M$ is the number of modal images; $C$ and $L$ are set as … and …, respectively.
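The GPB computation in Eq. (1) can be sketched in plain Python as below. This is a minimal, hypothetical illustration, not the authors' implementation: it assumes the learnable global pooling is an N × N weight matrix `w_gp` that mixes all N spatial tokens (in contrast to Poolformer's local average pooling), uses GELU inside the MLP, and omits biases. The helper names (`layer_norm`, `matmul`, `gpb_block`) are illustrative.

```python
import math


def layer_norm(tokens, eps=1e-5):
    """Normalize each token (row) to zero mean and unit variance over its channels."""
    out = []
    for t in tokens:
        mu = sum(t) / len(t)
        var = sum((x - mu) ** 2 for x in t) / len(t)
        out.append([(x - mu) / math.sqrt(var + eps) for x in t])
    return out


def matmul(a, b):
    """Product of an (n x k) and a (k x m) matrix stored as nested lists."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def add(a, b):
    """Element-wise sum of two equally shaped matrices (a residual connection)."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def gelu(x):
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gpb_block(x, w_gp, w1, w2):
    """Hypothetical Global Poolformer Block on N tokens with C channels.

    x    : N x C input tokens (one token per spatial position)
    w_gp : N x N weights of the assumed learnable global pooling; output
           position i is a learned weighted sum over ALL N positions
    w1   : C x H and w2 : H x C weights of the MLP sub-block (biases omitted)
    """
    # Token mixing with a residual: X_hat = GP(LN(X)) + X
    x_hat = add(matmul(w_gp, layer_norm(x)), x)
    # Channel mixing with a residual: Y = MLP(LN(X_hat)) + X_hat
    hidden = [[gelu(v) for v in row] for row in matmul(layer_norm(x_hat), w1)]
    return add(matmul(hidden, w2), x_hat)
```

Because both sub-blocks are residual, setting `w_gp` and `w2` to zero turns the block into an identity map, which is a quick sanity check on the wiring.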

2.2 Nested Modality-Aware Feature Aggregation

Given the high-level features $\{F_m\}_{m=1}^{M}$ of the $M$ modalities, NMaFA leverages a spatial-attention-based transformer $\mathcal{T}_{\mathrm{TSA}}$ and a cross-modality-attention-based transformer $\mathcal{T}_{\mathrm{CMA}}$ in a nested manner; see Fig. 3. First, $\mathcal{T}_{\mathrm{TSA}}$ uses self-attention to compute the long-range correlations between different spatial patches within each modality. Specifically, the features $F_m$ are concatenated along the channel dimension to obtain the high-level embedding $F$. In this work, each location of $F$ is considered as one "patch". A patch embedding layer then maps $F$ to a token sequence $T$. $\mathcal{T}_{\mathrm{TSA}}$ takes $T$ and the position encoding [vaswani2017attention] as input and outputs the spatially-enhanced feature $T_s$.

Second, $\mathcal{T}_{\mathrm{CMA}}$ uses a cross-attention scheme to further compute the global relations among different modalities and thereby achieve inter-modality fusion. To this end, the tokens of the individual modalities are concatenated along the spatial dimension to obtain the flattened sequence $T_m$ of length $M \cdot S$. Here, $S$ denotes the number of dominant tokens learned via the TokenLearner strategy [ryoo2021tokenlearner], which reduces the computational scope, especially since the number of tokens grows rapidly as more modalities are added. After that, both $T_s$ and $T_m$ are fed into $\mathcal{T}_{\mathrm{CMA}}$ to obtain the modality-enhanced feature embedding $T_c$.
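Building the cross-modality sequence can be sketched as follows, assuming a simplified TokenLearner in which each of the S learned tokens is obtained from a sigmoid spatial attention map followed by spatial average pooling of the weighted features. The scoring weights being shared across modalities, and all function names, are assumptions for illustration.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def token_learner(tokens, score_weights):
    """Simplified TokenLearner-style pooling for one modality.

    tokens        : N x C spatial tokens
    score_weights : S x C; row s scores every position for learned token s
    returns       : S x C learned tokens
    """
    n, c = len(tokens), len(tokens[0])
    learned = []
    for w in score_weights:
        # Spatial attention map alpha_s in (0, 1): one weight per position
        alpha = [sigmoid(sum(wi * xi for wi, xi in zip(w, t))) for t in tokens]
        # Spatial average pooling of the attention-weighted features
        learned.append([sum(alpha[i] * tokens[i][j] for i in range(n)) / n
                        for j in range(c)])
    return learned


def modality_sequence(per_modality_tokens, score_weights):
    """Concatenate the S learned tokens of each of the M modalities: (M*S) x C."""
    seq = []
    for tokens in per_modality_tokens:
        seq.extend(token_learner(tokens, score_weights))
    return seq
```

The sequence length is M·S regardless of the spatial size N, which is what keeps the subsequent cross attention affordable as modalities are added.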

Note also that our two modules differ from traditional channel-spatial attention networks, which reweight feature maps channel-wise and spatial-wise. Our NMaFA relies on the transformer mechanism, and the two transformers are combined in a nested form rather than by serial [khanh2020enhancing] or parallel [mou2019cs] fusion.

Transformer with Tri-orientated Spatial Attention. To improve the computational efficiency of spatial attention for volumetric embeddings, and inspired by [ho2019axial, liu2021swin], we leverage axial-wise attention $A_a$, plane-wise attention $A_p$, and window-wise attention $A_w$; see Fig. 3(b). Concretely, $A_a$ models the long-range relationships among feature tokens along the vertical direction; $A_p$ models the long-range relationships within each slice; and $A_w$ uses sliding windows to model the relationships across local 3D windows. We employ learnable absolute axial and planar position encodings [vaswani2017attention] for $A_a$ and $A_p$, respectively, and use a relative position encoding for the window-wise attention [liu2021swin]. The resulting attention is computed as

$$\mathrm{TSA}(Z) = A_a(Z) + A_p(Z) + A_w(Z), \tag{2}$$

where $Z \in \mathbb{R}^{N \times C}$ denotes the embedding tokens with sequence length $N$ and embedding dimension $C$ after layer normalization. In this way, the model not only enhances feature extraction in important local regions, but also captures global feature dependencies with less computation.
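To make the efficiency argument concrete, the sketch below counts how many query-key pairs each attention pattern scores on a D × H × W token grid. The grid size is a hypothetical example, the axial direction is assumed to be the depth axis, and windows are assumed to tile the grid without overlap; the default window (2, 2, 2) matches the BraTS2020 setting in Sect. 3.1.

```python
def attention_pairs(d, h, w, window=(2, 2, 2)):
    """Count query-key pairs scored by each attention pattern on a D x H x W grid.

    Assumptions: axial attention runs along the depth axis, plane attention
    covers one H x W slice, and window attention uses non-overlapping windows
    that tile the grid exactly.
    """
    wd, wh, ww = window
    if d % wd or h % wh or w % ww:
        raise ValueError("window must tile the grid")
    n = d * h * w  # one token per spatial location
    return {
        "full": n * n,               # every token attends to every token
        "axial": n * d,              # each token attends to its depth column
        "plane": n * h * w,          # each token attends to its own slice
        "window": n * wd * wh * ww,  # each token attends to its local window
    }
```

On a 4 × 8 × 8 grid, the three patterns together score 19,456 pairs versus 65,536 for full self-attention, and the gap widens with grid size because full attention grows quadratically in N.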

Figure 3: NMaFA: Nested Modality-aware Feature Aggregation. (a) The overall architecture. (b) The transformer with tri-orientated spatial attention. (c) The transformer with cross-modality attention.

Transformer with Cross-Modality Attention. Because the features are concatenated along the channel dimension, the tri-orientated spatial attention transformer mainly enhances the dependencies within each modality and yields the spatially-enhanced tokens $T_s$, although some inter-modality integration also takes place via the patch embedding. To explicitly explore the relationships among modalities, we concatenate the feature tokens of the different modalities along the spatial dimension, yielding $T_m$, and then use a cross-attention transformer to inject the modality dependency information into $T_s$; see Fig. 3(c). The input triplet (Query, Key, Value) to the cross attention is computed as

$$Q = T_s W_Q, \qquad K = T_m W_K, \qquad V = T_m W_V, \tag{3}$$

where $W_Q$, $W_K$, and $W_V$ are the weight matrices, and $d$ is the dimension of $K$. The cross attention is then formulated as

$$\mathrm{CA}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d}}\right) V. \tag{4}$$

The resulting token sequence fuses and enhances the input features with increased receptive fields and cross-modal global relevance.
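Eqs. (3) and (4) can be sketched as a single-head cross attention in plain Python, with queries taken from the spatially-enhanced tokens and keys/values from the modality tokens. This is an illustrative sketch only: layer normalization, multiple heads, and the output projection are omitted, and all names are hypothetical.

```python
import math


def matmul(a, b):
    """Product of an (n x k) and a (k x m) matrix stored as nested lists."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(col) for col in zip(*a)]


def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    e = [math.exp(x - m) for x in row]
    s = sum(e)
    return [x / s for x in e]


def cross_attention(t_s, t_m, w_q, w_k, w_v):
    """Single-head cross attention in the form of Eqs. (3)-(4).

    t_s : N x C spatially-enhanced tokens (queries)
    t_m : (M*S) x C modality tokens (keys and values)
    w_q, w_k, w_v : C x d projection matrices
    """
    q = matmul(t_s, w_q)              # N x d
    k = matmul(t_m, w_k)              # (M*S) x d
    v = matmul(t_m, w_v)              # (M*S) x d
    d = len(k[0])
    scores = matmul(q, transpose(k))  # N x (M*S)
    attn = [softmax([s / math.sqrt(d) for s in row]) for row in scores]
    return matmul(attn, v)            # N x d: each query mixes all modality tokens
```

Since every attention row sums to one, each output token is a convex combination of the value vectors; if all modality tokens are identical, every output equals that shared value.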

2.3 Modality-Sensitive Gating

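In feature decoding, we first fold the output tokens of NMaFA back into a high-level 4D feature map $F_o$. $F_o$ is progressively processed in a regular bottom-up style with 3D convolutions and upsampling operations to recover a full-resolution feature map with $K$ channels for segmentation, where $K$ is the number of segments. Note that the encoder features are multi-modal. Hence, we design a modality-sensitive gating strategy in the skip connections to filter the encoder features according to modality importance. Specifically, for the $l$-th layer, a modality importance map $G_l$, with one channel per modality, is learned from $F_o$, the output of NMaFA, as follows:

$$G_l = \sigma\big(\mathrm{Up}_{k}(\mathrm{FC}(F_o))\big), \tag{5}$$

where $\mathrm{FC}$ is a fully connected layer, $\mathrm{Up}_{k}$ denotes upsampling $k$ times so that $G_l$ matches the resolution of the $l$-th layer encoder features, and $\sigma$ is the sigmoid function. Let $\odot$ denote element-wise multiplication. Then the filtered encoder feature of the $m$-th modality is formulated as

$$\hat{F}^{l}_{m} = G^{(m)}_{l} \odot F^{l}_{m}, \qquad m = 1, \dots, M, \tag{6}$$

where $F^{l}_{m}$ is the $l$-th layer encoder feature of the $m$-th modality, $M$ is the number of modalities, and $G^{(m)}_{l}$ is the $m$-th channel of $G_l$, broadcast over the feature channels.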
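A toy version of the gating in Eqs. (5) and (6) is sketched below. To stay short, it assumes a 1-D spatial axis, one feature channel per modality, nearest-neighbour upsampling by a factor of 2 per step, and treats the FC output as given logits; all names are hypothetical.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def upsample_nearest(values, factor):
    """Nearest-neighbour upsampling of a 1-D signal by an integer factor."""
    return [v for v in values for _ in range(factor)]


def modality_sensitive_gate(importance_logits, encoder_feats, k):
    """Gate each modality's skip feature by its upsampled importance map.

    importance_logits : M lists of coarse logits (stand-in for FC(F_o))
    encoder_feats     : M lists of encoder features, 2**k times finer
    k                 : number of 2x upsampling steps (assumed factor 2 each)
    """
    gated = []
    for logits, feat in zip(importance_logits, encoder_feats):
        # For nearest-neighbour upsampling, sigmoid(Up(z)) equals Up(sigmoid(z))
        gate = upsample_nearest([sigmoid(z) for z in logits], 2 ** k)
        if len(gate) != len(feat):
            raise ValueError("resolution mismatch between gate and feature")
        gated.append([g * f for g, f in zip(gate, feat)])
    return gated
```

A logit of 0 halves a modality's skip feature, while a large positive logit passes it through almost unchanged, so the gate can down-weight less informative modalities per location.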

3 Experiment

3.1 Implementation Details

Our NestedFormer was implemented in PyTorch 1.7.0 on an NVIDIA GeForce RTX 3090 GPU. The parameters were initialized via Xavier initialization [glorot2010understanding]. The loss function was a combination of soft Dice loss and cross-entropy loss, and we adopted the AdamW optimizer [loshchilov2017decoupled] with a weight decay of …. The learning rate was empirically set as …. We stacked two tri-orientated spatial attention transformers sequentially and used just one cross-modality attention transformer. In the window-wise attention, the window size was set as (2, 2, 2) for BraTS2020 and (2, 4, 4) for MeniSeg.
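The combined objective can be sketched for one binary region map as below. The equal weighting, the smoothing constant `eps`, and the per-region binary formulation are assumptions made for illustration, since the exact weighting is not specified; function names are hypothetical.

```python
import math


def soft_dice_loss(probs, targets, eps=1e-6):
    """1 - soft Dice between predicted probabilities and binary targets."""
    inter = sum(p * t for p, t in zip(probs, targets))
    denom = sum(probs) + sum(targets)
    return 1.0 - (2.0 * inter + eps) / (denom + eps)


def binary_cross_entropy(probs, targets, eps=1e-7):
    """Mean binary cross-entropy; probabilities are clamped to avoid log(0)."""
    total = 0.0
    for p, t in zip(probs, targets):
        p = min(max(p, eps), 1.0 - eps)
        total -= t * math.log(p) + (1 - t) * math.log(1 - p)
    return total / len(probs)


def dice_ce_loss(probs, targets, w_dice=1.0, w_ce=1.0):
    """Weighted sum of soft Dice and cross-entropy (weights are assumptions)."""
    return (w_dice * soft_dice_loss(probs, targets)
            + w_ce * binary_cross_entropy(probs, targets))
```

The Dice term is insensitive to class imbalance at the region level, while the cross-entropy term gives dense per-voxel gradients; combining them is a common choice for small structures such as the enhancing tumor.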

3.2 Datasets and Evaluation Metrics

For evaluation, we use the public brain tumor segmentation dataset BraTS2020 [menze2014multimodal] and a private 3D meningioma segmentation dataset (MeniSeg) collected from the Brain Medical Center of Tianjin University, Tianjin Huanhu Hospital. The Dice score and the 95% Hausdorff Distance (HD95) are adopted for quantitative comparison.
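For reference, the two metrics can be sketched as follows. Dice is computed on sets of foreground voxel coordinates; HD95 is taken here as the maximum of the two directed 95th-percentile nearest-neighbour distances with a nearest-rank percentile. This is one common convention, and toolkits differ (for example, in restricting to surface voxels or interpolating percentiles), so treat it as an illustrative assumption rather than the evaluation code used here.

```python
import math


def dice_score(pred, gt):
    """Dice = 2|P ∩ G| / (|P| + |G|) for sets of foreground voxel coordinates."""
    if not pred and not gt:
        return 1.0  # both empty: treated as a perfect match (a common convention)
    return 2.0 * len(pred & gt) / (len(pred) + len(gt))


def percentile_nearest_rank(values, q):
    """q-th percentile (0 < q <= 100) using the nearest-rank method."""
    ordered = sorted(values)
    rank = max(1, math.ceil(q / 100.0 * len(ordered)))
    return ordered[rank - 1]


def hd95(pred_pts, gt_pts):
    """Symmetric 95th-percentile Hausdorff distance between two non-empty point sets."""
    def directed(a, b):
        # Distance from every point in a to its nearest neighbour in b
        return [min(math.dist(p, q) for q in b) for p in a]
    return max(percentile_nearest_rank(directed(pred_pts, gt_pts), 95),
               percentile_nearest_rank(directed(gt_pts, pred_pts), 95))
```

Taking the 95th percentile instead of the maximum discards the most extreme 5% of boundary distances, which makes HD95 less dominated by a single outlier voxel than the plain Hausdorff distance.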

BraTS2020 Dataset. The BraTS2020 training dataset contains 369 aligned four-modal MRI scans (i.e., T1, T1Gd, T2, and T2-FLAIR) with expert segmentation masks (i.e., GD-enhancing tumor, peritumoral edema, and tumor core). Each modality has a volume of … and is already resampled and co-registered. The segmentation task aims to segment the whole tumor (WT), enhancing tumor (ET), and tumor core (TC) regions. Following the recent work [larrazabal2021orthogonal], we randomly divide the dataset into training (315 cases), validation (17 cases), and test (37 cases) sets.
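The three evaluated regions are nested unions of the BraTS voxel labels. The sketch below uses the BraTS 2020 label convention (1 = necrotic/non-enhancing tumor core, 2 = peritumoral edema, 4 = GD-enhancing tumor) and also shows a seeded random 315/17/37 case split. The seed and function names are hypothetical, and the sketch does not reproduce the actual split used here.

```python
import random


def brats_regions(labels):
    """Map BraTS 2020 voxel labels to the three nested evaluation regions.

    Label convention: 0 = background, 1 = necrotic / non-enhancing tumor core,
    2 = peritumoral edema, 4 = GD-enhancing tumor.
    """
    wt = [int(v in (1, 2, 4)) for v in labels]  # whole tumor: all tumor labels
    tc = [int(v in (1, 4)) for v in labels]     # tumor core: excludes edema
    et = [int(v == 4) for v in labels]          # enhancing tumor only
    return wt, tc, et


def random_split(case_ids, n_train=315, n_val=17, seed=0):
    """Shuffle case IDs and split into train / val / test (the rest is test)."""
    ids = list(case_ids)
    random.Random(seed).shuffle(ids)
    return ids[:n_train], ids[n_train:n_train + n_val], ids[n_train + n_val:]
```

Because ET ⊂ TC ⊂ WT, each region is usually predicted as its own binary map (for example, with a per-region sigmoid) rather than as mutually exclusive classes.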

Meningioma Dataset. The MeniSeg dataset contains annotated two-modal MRIs (i.e., T1Gd and FLAIR-C) of meningioma patients who had undergone tumor resection between March 2016 and March 2021. MRI scans were acquired with four 3.0T MRI scanners (Skyra, Trio, Avanto, and Prisma from Siemens). Two radiologists annotated the meningioma tumor and edema masks on the T1Gd and FLAIR-C MRIs, and a third, highly experienced radiologist checked the annotations. Each modality volume has a size of …, is aligned into the same space, and is resampled to a volume size of [32, 192, 192] for training. Two-fold cross-validation is conducted for all the compared methods.
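A case-level two-fold cross-validation split can be sketched as follows; the seed and helper name are illustrative assumptions, not the actual partition used in the experiments.

```python
import random


def two_fold_splits(case_ids, seed=0):
    """Return [(train, test), (train, test)] for two-fold cross-validation."""
    ids = list(case_ids)
    random.Random(seed).shuffle(ids)
    half = len(ids) // 2
    fold_a, fold_b = ids[:half], ids[half:]
    # Each fold is used once for testing while the other is used for training
    return [(fold_a, fold_b), (fold_b, fold_a)]
```

Splitting by case (patient) rather than by slice keeps all slices of one patient in the same fold, so no test patient is seen during training.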

Table 1: Quantitative comparison on the BraTS2020 dataset.

Methods | Param (M) | FLOPs (G) | WT Dice | WT HD95 | TC Dice | TC HD95 | ET Dice | ET HD95 | Ave Dice | Ave HD95
3D-UNet [cciccek20163d] | 5.75 | 1449.59 | 0.882 | 5.113 | 0.830 | 6.604 | 0.782 | 6.715 | 0.831 | 6.144
SegResNet [myronenko20183d] | 18.79 | 185.23 | 0.903 | 4.578 | 0.845 | 5.667 | 0.796 | 7.064 | 0.848 | 5.763
MAML [zhang2021modality] | 5.76 | 577.65 | 0.914 | 4.804 | 0.854 | 5.594 | 0.796 | 5.221 | 0.855 | 5.206
nnUNet [isensee2021nnu] | 5.75 | 1449.59 | 0.907 | 6.94 | 0.848 | 5.069 | 0.814 | 5.851 | 0.856 | 5.953
SwinUNet(2D) [cao2021swin] | 27.17 | 357.49 | 0.872 | 6.752 | 0.809 | 8.071 | 0.744 | 10.644 | 0.808 | 8.489
TransBTS [wang2021transbts] | 32.99 | 333 | 0.910 | 4.141 | 0.855 | 5.894 | 0.791 | 5.463 | 0.852 | 5.166
UNETR [hatamizadeh2022unetr] | 92.58 | 41.19 | 0.899 | 4.314 | 0.842 | 5.843 | 0.788 | 5.598 | 0.843 | 5.251
NestedFormer | 10.48 | 71.77 | 0.920 | 4.567 | 0.864 | 5.316 | 0.800 | 5.269 | 0.861 | 5.051

3.3 Comparison with SOTA Methods

We compare our network against seven SOTA segmentation methods, including four CNN-based methods (3D-UNet [cciccek20163d], SegResNet [myronenko20183d], MAML [zhang2021modality], and nnUNet [isensee2021nnu]) and three transformer-based methods (SwinUNet(2D) [cao2021swin], TransBTS [wang2021transbts], and UNETR [hatamizadeh2022unetr]). For a fair comparison, we use the public implementations of the compared methods and re-train their networks to generate their best segmentation results. Considering the computational cost, all the methods are trained for at most 300 epochs on BraTS2020 and 200 epochs on MeniSeg.

BraTS2020. Table 1 reports the Dice and HD95 scores on the three regions (WT, TC, and ET), as well as the averaged scores, of all the methods on BraTS2020. Our NestedFormer achieves the largest Dice scores on WT and TC, ranks second in Dice on ET, second in HD95 on TC and ET, and third in HD95 on WT. More importantly, our method has the best overall quantitative performance, with average Dice and HD95 scores of 0.861 and 5.051, respectively. Note that HD95 measures the distance between two sets of points and is more sensitive than Dice [wang2021transbts]; hence, Dice is often used as the main metric and HD95 as a reference. We also ran two-fold cross-validation for UNETR, TransBTS, and our method; our method outperforms the other two methods on WT and TC and is quite close to the best result on ET. As for model complexity, our model has 10.48M parameters and 71.77G FLOPs, which is a moderate model size.
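The rankings stated above can be checked directly against Table 1 with the short script below, which copies the table values and ranks each column (higher is better for Dice, lower is better for HD95).

```python
# Scores copied from Table 1 (BraTS2020).
COLUMNS = ["WT Dice", "WT HD95", "TC Dice", "TC HD95",
           "ET Dice", "ET HD95", "Ave Dice", "Ave HD95"]
TABLE1 = {
    "3D-UNet":      (0.882, 5.113, 0.830, 6.604, 0.782, 6.715, 0.831, 6.144),
    "SegResNet":    (0.903, 4.578, 0.845, 5.667, 0.796, 7.064, 0.848, 5.763),
    "MAML":         (0.914, 4.804, 0.854, 5.594, 0.796, 5.221, 0.855, 5.206),
    "nnUNet":       (0.907, 6.94, 0.848, 5.069, 0.814, 5.851, 0.856, 5.953),
    "SwinUNet(2D)": (0.872, 6.752, 0.809, 8.071, 0.744, 10.644, 0.808, 8.489),
    "TransBTS":     (0.910, 4.141, 0.855, 5.894, 0.791, 5.463, 0.852, 5.166),
    "UNETR":        (0.899, 4.314, 0.842, 5.843, 0.788, 5.598, 0.843, 5.251),
    "NestedFormer": (0.920, 4.567, 0.864, 5.316, 0.800, 5.269, 0.861, 5.051),
}


def rank_of(method, column):
    """1-based rank of a method in one column (Dice: higher is better; HD95: lower)."""
    idx = COLUMNS.index(column)
    higher_is_better = column.endswith("Dice")
    ordered = sorted((row[idx] for row in TABLE1.values()), reverse=higher_is_better)
    return ordered.index(TABLE1[method][idx]) + 1


if __name__ == "__main__":
    for col in COLUMNS:
        print(col, rank_of("NestedFormer", col))
```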

MeniSeg. In Table 2, we list the Dice and HD95 scores of our network and the compared methods on the tumor and edema regions of the MeniSeg dataset, as well as the average metrics. Among all the compared methods, MAML has the largest Dice score (0.819) for tumor segmentation, while UNETR has the largest Dice score (0.693) for edema segmentation and the largest average Dice score (0.755). In comparison, our method improves Dice by 1.5% on meningioma tumor, 0.2% on edema, and 1.0% on average. Regarding HD95, our method achieves the fourth-smallest score (2.647) for tumor segmentation and the smallest score (6.173) for edema segmentation.
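The Dice gains quoted above are absolute differences between NestedFormer and the best compared method in each column of Table 2, as the following check on the copied table values shows.

```python
# Dice scores copied from Table 2 (MeniSeg): (Tumor, Edema, Ave)
TABLE2_DICE = {
    "3D-UNet":      (0.799, 0.676, 0.737),
    "SegResNet":    (0.813, 0.665, 0.739),
    "MAML":         (0.819, 0.682, 0.750),
    "SwinUNet(2D)": (0.807, 0.679, 0.743),
    "TransBTS":     (0.809, 0.679, 0.744),
    "UNETR":        (0.818, 0.693, 0.755),
    "NestedFormer": (0.834, 0.695, 0.765),
}


def gain_over_best(table, ours="NestedFormer"):
    """Absolute Dice gain of `ours` over the best other method, per column."""
    others = [row for name, row in table.items() if name != ours]
    return tuple(round(table[ours][i] - max(r[i] for r in others), 3)
                 for i in range(len(table[ours])))
```

Calling `gain_over_best(TABLE2_DICE)` yields gains of 0.015, 0.002, and 0.010 Dice, i.e., the 1.5%, 0.2%, and 1.0% quoted in the text.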

Table 2: Quantitative comparison on the MeniSeg dataset.

Methods | Tumor Dice | Tumor HD95 | Edema Dice | Edema HD95 | Ave Dice | Ave HD95
3D-UNet [cciccek20163d] | 0.799 | 5.099 | 0.676 | 9.655 | 0.737 | 7.377
SegResNet [myronenko20183d] | 0.813 | 2.970 | 0.665 | 10.438 | 0.739 | 6.704
MAML [zhang2021modality] | 0.819 | 2.112 | 0.682 | 9.158 | 0.750 | 5.635
SwinUNet(2D) [cao2021swin] | 0.807 | 1.817 | 0.679 | 7.986 | 0.743 | 4.901
TransBTS [wang2021transbts] | 0.809 | 1.742 | 0.679 | 6.388 | 0.744 | 4.065
UNETR [hatamizadeh2022unetr] | 0.818 | 3.279 | 0.693 | 7.837 | 0.755 | 5.813
NestedFormer | 0.834 | 2.647 | 0.695 | 6.173 | 0.765 | 4.410

Table 3: Ablation study for different modules on MeniSeg.

Methods | Tumor Dice | Edema Dice | Ave Dice
baseline 1 | 0.805 | 0.675 | 0.74
baseline 2 | 0.810 | 0.679 | 0.75
baseline 3 | 0.816 | 0.688 | 0.752
baseline 4 | 0.825 | 0.699 | 0.762
baseline 5 | 0.823 | 0.697 | 0.76
NestedFormer | 0.834 | 0.695 | 0.765

Figure 4: The visual comparison results on BraTS2020 and MeniSeg dataset.

Visual Comparisons on BraTS2020 and MeniSeg. Fig. 4 visually compares the segmentation results predicted by our network and the SOTA methods on BraTS2020 and MeniSeg. These visualizations show that our method segments brain tumor and peritumoral edema regions more accurately than all the compared methods. We attribute this to our method's ability to better fuse multi-modal MRIs by explicitly exploring the intra-modality and inter-modality relationships among multiple modalities.

3.4 Ablation study

We conduct ablation studies on the MeniSeg dataset to evaluate the contributions of the main modules in our method; see Table 3. We not only compare the effects of three different encoder backbones based on CNN, PB, and GPB, but also verify the effect of our proposed fusion modules. Among them, baseline 1 uses multiple U-Net encoders to extract features of the different modal images and performs feature fusion by concatenation. Baselines 2 and 3 use multiple GPB encoders to extract features and conduct skip connections via simple convolution, each without one of the two NMaFA transformers (see Fig. 3). Baseline 5 replaces the GPB block with the original PoolFormer block (PB) in the encoder, while using the proposed NMaFA module (including both transformers) as well as MSG. It can be clearly observed that, compared with the baselines lacking the full NMaFA module, using NMaFA enhances the extraction of long-range dependency information and effectively improves the segmentation results, while GPB outperforms PB by considering global information. Moreover, adding the MSG module increases the feature reuse capability of the skip connections, which further improves segmentation and achieves the best average Dice (0.765) on the MeniSeg dataset.

4 Conclusion

We propose a novel multi-modal segmentation framework, dubbed NestedFormer. This architecture extracts the features of each modality using multiple Global Poolformer encoders. The high-level features are then effectively fused by the NMaFA module, and the low-level features are selected by the modality-sensitive gating (MSG) module. Through these modules, the network effectively extracts and hierarchically fuses features from different modalities. The effectiveness of NestedFormer is validated on the BraTS2020 and MeniSeg datasets. Our framework is modality-agnostic and can be extended to other multi-modal medical data. In future work, we will explore more efficient feature fusion at low levels to further improve segmentation performance.

Acknowledgments. This work was supported by the grant from Tianjin Natural Science Foundation (Grant No. 20JCYBJC00960) and HKU Seed Fund for Basic Research (Project No. 202111159073).

References