
Keep Your Friends Close & Enemies Farther:
Debiasing Contrastive Learning with Spatial Priors in 3D Radiology Images

Yejia Zhang University of Notre Dame
Notre Dame, IN, 46556, USA
yzhang46@nd.edu
   Nishchal Sapkota University of Notre Dame
Notre Dame, IN, 46556, USA
nsapkota@nd.edu
   Pengfei Gu University of Notre Dame
Notre Dame, IN, 46556, USA
pgu@nd.edu
   Yaopeng Peng University of Notre Dame
Notre Dame, IN, 46556, USA
ypeng4@nd.edu
   Hao Zheng University of Notre Dame
Notre Dame, IN, 46556, USA
hzheng3@nd.edu
   Danny Z. Chen University of Notre Dame
Notre Dame, IN, 46556, USA
dchen@nd.edu
Abstract

Understanding spatial attributes is central to effective 3D radiology image analysis, where crop-based learning is the de facto standard. Given an image patch, its core spatial properties (e.g., position & orientation) provide helpful priors on expected object sizes, appearances, and structures through inherent anatomical consistencies. Spatial correspondences, in particular, can effectively gauge semantic similarities between inter-image regions, while their approximate extraction requires neither annotations nor prohibitive computational costs. However, recent 3D contrastive learning approaches either neglect correspondences or fail to maximally capitalize on them. To this end, we propose an extensible 3D contrastive framework (Spade, for Spatial Debiasing) that leverages extracted correspondences to select more effective positive & negative samples for representation learning. Our method learns both globally invariant and locally equivariant representations with downstream segmentation in mind. We also propose separate selection strategies for the global & local scopes that tailor to their respective representational requirements. Compared to recent state-of-the-art approaches, Spade shows notable improvements on three downstream segmentation tasks (CT Abdominal Organ, CT Heart, MR Heart).

Index Terms:
Self-supervised Learning, Contrastive Learning, 3D Radiology Image Segmentation

I Introduction

Deep neural networks have the ability to learn useful features, with the caveat of requiring large amounts of annotated training data. To reduce the expense of obtaining sample labels, self-supervised contrastive learning methods for natural scene images [5][16][17][4][6] use instance discrimination, with the objective of pulling latent representations of similar images (i.e., positives, or augmented views of the same image) close while contrasting features of distinct images (i.e., negatives, or views from other images).

Many differing characteristics of radiology images, however, make direct applications of this idea suboptimal. Instance discrimination is especially problematic due to recurring anatomical structures across volumes, which often produce “negatives” with similar semantics (i.e., false negatives). The presence of false negatives slows convergence, impairs representations, and discards useful semantics [10]. These methods also favor global, invariant features, whereas many medical localization tasks benefit from local, equivariant representations. To address false negatives, recent methods propose improved selections of positives by utilizing subject IDs [2], patient metadata [14], and spatial correspondences among image slices [20]. These kill two birds with one stone by both removing false negatives and expanding the diversity of positives. For better local representations, [3] additionally applies instance discrimination to image sub-crops on top of global contrastive learning, with positives selected based on slice positions.

Despite these advances, three important factors conducive to learning representations for 3D radiology image segmentation remain unaddressed. (1) The use of 2D slices discards spatial context and makes pretrained models incompatible with volumetric fine-tuning methods. Recent works affirm this by showing that pretraining [24] and fine-tuning [11] 3D tasks on 3D data outperform their 2D equivalents. Moreover, using whole slices introduces multiple anatomical structures, which may further impair representations by inviting shortcut learning, where models over-focus on discriminative regions and neglect the remaining objects. (2) The benefits of contrastive learning are not leveraged in decoder outputs, where discriminative representations are more important for downstream localization tasks. Although this was partially addressed in [3], the final method only uses local samples from the same image, which does not fully capitalize on spatial priors that facilitate learning of anatomical patterns across images. This brings us to our final point. (3) Both the diversity & quality of positive & negative sample candidates are limited, either because samples are constrained to small selections of intra-batch images [3][20][19] or because inter-image positives are precluded altogether [23][22]. Medical images exhibit less appearance variation than natural scene images, which accentuates the importance of inter-image sample diversity for positives & negatives.

Figure 1: The overall workflow of the Spade framework. $\mathcal{T}_s$ & $\mathcal{T}_i$ represent spatial & intensity transforms, respectively.

In this work, we address these drawbacks with a new 3D self-supervised framework that builds on MoCo [5] (i.e., applying contrastive learning on encoder features with a global memory bank) and MG [24] (i.e., employing reconstruction on decoder outputs) to pretrain a UNet-like model. We utilize 3D crops (addressing the points made in (1)) since cropping naturally regularizes features by restricting image contents and promotes more equitable foci among anatomical structures. For (2), we propose a local contrastive component with a separate local memory bank that boosts features from the decoder and learns spatially equivariant representations, which benefit downstream segmentation tasks more than solely promoting global, invariant features (e.g., [23][22]). Finally, to address (3), we go beyond instance discrimination and leverage spatial priors (i.e., correspondences) among crops across volumes to select higher-quality positives and negatives. Here, we emphasize correspondences for three primary reasons: they are free (i.e., require no annotations to obtain), tractable (i.e., involve little computation), and effective (i.e., disclose helpful information for gauging anatomical similarity). Thus, correspondences between images are first computed by aligning them to a template image (details in § II), since suitable templates are accessible (e.g., selecting a representative image within the dataset) and computing alignment is cheap (1.8 seconds per CT volume). With correspondences as a proxy for semantic similarity, we propose multiple global & local sampling strategies to reduce false negatives (i.e., debiasing) and increase the diversity & quality of positives. Our final framework, Spade (for Spatial Debiasing), incorporates the best local & global strategies.

To demonstrate Spade’s efficacy (see § IV), we utilize 3D torso radiology images. More specifically, we pretrain using 623 chest CT & 200 abdominal CT images, and fine-tune on three downstream segmentation tasks (CT Abdominal, CT Heart, MR Heart). We conduct thorough studies on Spade’s components and the proposed global & local sampling strategies to explore how spatial priors facilitate contrastive learning in 3D radiology images. Additionally, we study Spade’s robustness to template choice and show that rough alignment is both cheap and effective for downstream segmentation. Finally, we compare Spade to recent self-supervised methods and report sizable performance improvements over state-of-the-art predictive, reconstructive, and contrastive pretraining methods (e.g., +2.7%, +3.4%, +1.8% dice, resp., on CT Heart).

II Methodology

Before detailing Spade’s workflow & components (see Fig. 1), we describe template selection & contrastive learning.

II-A Preliminaries

Template Selection & Spatial Correspondence. The main objective of aligning images to a chosen template is to establish spatial correspondences through a shared coordinate system in which we can better gauge semantic similarity. Given an unlabeled dataset with $N$ images, $D=\{I_1,I_2,\ldots,I_N\}$, we define a template image $I_t$ (note that $I_t\in D$ is valid) such that the template contains all key anatomical structures present in $D$. We then register all the images in $D$ to $I_t$ by computing $\mathbb{T}=\{T_1,T_2,\ldots,T_N\}$, where $T_i:(x_i,y_i,z_i)\rightarrow(x^t_i,y^t_i,z^t_i)$ is a function that transforms $I_i$'s coordinates to $I_t$'s. We apply affine registration and optimize with SGD, using cross-correlation as the similarity metric.

Thus, given a patch $P_i$ from image $I_i$, we find its corresponding crop in $I_t$ as $P_t=T_i(P_i)$. Further, $P_i$'s corresponding crop in any other image $I_j\in D$ can be extracted via $P_j=T_j^{-1}\circ T_i(P_i)$. We concisely denote the corresponding patch of $P_i$ in image $I_j$ as $P_{i\rightarrow j}$. Note that we do not transform the images; rather, we only compute $\mathbb{T}$.
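To make the coordinate mapping concrete, a minimal sketch is given below, assuming each $T_i$ is stored as a 4×4 affine matrix in homogeneous voxel coordinates (the function name and array layout are illustrative, not part of the paper's released code):

```python
import numpy as np

def map_patch_corners(corners_i: np.ndarray, T_i: np.ndarray, T_j: np.ndarray) -> np.ndarray:
    """Map the 8 corner coordinates of a patch P_i from image I_i to image I_j,
    i.e., P_{i->j} = T_j^{-1} o T_i (P_i), going through the template space.

    corners_i : (8, 3) voxel coordinates of the patch corners in I_i.
    T_i, T_j  : (4, 4) affine matrices mapping each image's coordinates to the template's.
    """
    homog = np.hstack([corners_i, np.ones((len(corners_i), 1))])  # homogeneous coordinates
    in_template = T_i @ homog.T                                   # I_i -> template
    in_j = (np.linalg.inv(T_j) @ in_template).T                   # template -> I_j
    return in_j[:, :3]
```

The axis-aligned bounding box of the returned corners then defines the crop $P_{i\rightarrow j}$.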
Contrastive Representation Learning. We denote the $L_2$-normalized unit embedding of an anchor patch as $v$. Borrowing notation from [3], we define a positive sample as $v^+\in\Lambda^+$, where $\Lambda^+$ is the set of positive embedding vectors for anchor $v$. Similarly, a negative sample is defined as $v^-\in\Lambda^-$, where $\Lambda^-$ is the set of negative embedding vectors. The contrastive loss (NCE [16] for short) is then defined as:

\mathcal{L}^{NCE}=-\log\frac{\exp(v\cdot v^{+}/\tau)}{\exp(v\cdot v^{+}/\tau)+\sum\limits_{v^{-}\in\Lambda^{-}_{g/l}}\exp(v\cdot v^{-}/\tau)} (1)

where $\mathcal{L}^{NCE}$ is shorthand for the loss $\mathcal{L}^{NCE}(v,v^+)$ between an anchor $v$ and a positive $v^+$, $\tau$ is a temperature hyperparameter, and $g$ & $l$ indicate the global & local components, respectively. Incorporating all positive samples, the complete contrastive loss (CON for short) is written as:

\mathcal{L}^{CON}_{g/l}=\frac{1}{|\Lambda^{+}_{g/l}|}\sum\limits_{\forall(\dot{v},\ddot{v})\in\Lambda^{+}_{g/l}}\mathcal{L}^{NCE}(\dot{v},\ddot{v})+\mathcal{L}^{NCE}(\ddot{v},\dot{v}) (2)
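As a concrete reference, the PyTorch sketch below evaluates Eq. (1) for a single anchor against a queue of negatives; embeddings are assumed to be already $L_2$-normalized, and the function name is ours (Eq. (2) then simply averages this loss symmetrically over all positive pairs):

```python
import torch
import torch.nn.functional as F

def nce_loss(v: torch.Tensor, v_pos: torch.Tensor, negatives: torch.Tensor,
             tau: float = 0.2) -> torch.Tensor:
    """Eq. (1): v and v_pos are (C,) unit embeddings; negatives is a (K, C) queue Λ⁻."""
    l_pos = (v * v_pos).sum(dim=-1, keepdim=True) / tau  # (1,)  v · v+ / τ
    l_neg = negatives @ v / tau                          # (K,)  v · v- / τ for every queue entry
    logits = torch.cat([l_pos, l_neg]).unsqueeze(0)      # (1, K+1), positive logit at index 0
    target = torch.zeros(1, dtype=torch.long, device=v.device)
    return F.cross_entropy(logits, target)               # -log softmax at the positive index
```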

II-B The Spade Framework

Spade pretrains a randomly initialized UNet-like network containing an encoder $f_\theta$, decoder $g_\theta$, global projection module $h_{\theta_g}$, local projection module $h_{\theta_l}$, and their corresponding momentum counterparts $f_\epsilon$, $g_\epsilon$, $h_{\epsilon_g}$, and $h_{\epsilon_l}$, respectively (see Fig. 1). For each batch, an anchor image $I_a$ and $p$ pairs of anchor patches $\{P_a^j, P_a^k,\ldots\}$ are sampled, where paired patches overlap by at least $o$% and $P\in\mathbb{R}^{1\times D_{in}\times H_{in}\times W_{in}}$. Following MoCo [5], each $P$ is randomly transformed twice via spatial augmentations $\mathcal{T}_s$ (e.g., flipping, rotating) and intensity noising $\mathcal{T}_i$ (e.g., those proposed in [24]). All of these views are fed into both the regular & momentum networks.

Spade contains three pretext tasks (i.e., global contrastive learning, local contrastive learning, and reconstruction). These tasks use global logits $z=f(P)$ and local logits $Z=g\circ f(P)$, where $f$ may refer to either $f_\theta$ or $f_\epsilon$ for brevity. For global & local contrastive learning, $v$ is a unit feature embedding obtained from the global projection module, $v_g=h_g(z)$, $v_g\in\mathbb{R}^{C_g^{emb}}$, or the local projection module, $v_l=h_l(Z)$, $v_l\in\mathbb{R}^{C_l^{emb}\times H_l^{emb}\times W_l^{emb}}$. Note that $v_g$ is a vector while $v_l$ retains spatial dimensions $H_l^{emb}\times W_l^{emb}$, since we aim to preserve spatial equivariance in local representations. After each iteration, global & local embeddings from the momentum network are enqueued into $M_g$ (size $\mathcal{Q}_g$) & $M_l$ (size $\mathcal{Q}_l$), respectively, along with their corresponding positions in the template image.
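A minimal sketch of such a memory bank is shown below; storing each momentum embedding together with its template-space box is what later enables the spatial debiasing strategies (the class and method names are illustrative):

```python
from collections import deque

class SpatialMemoryBank:
    """FIFO queue standing in for M_g or M_l: each entry pairs a detached momentum
    embedding with the axis-aligned box of its patch in template coordinates."""

    def __init__(self, capacity: int):
        self.queue = deque(maxlen=capacity)  # oldest entries are evicted once capacity is reached

    def enqueue(self, embeddings, template_boxes):
        """embeddings: iterable of tensors; template_boxes: matching iterable of
        (zmin, ymin, xmin, zmax, ymax, xmax) tuples in template space."""
        for emb, box in zip(embeddings, template_boxes):
            self.queue.append((emb.detach(), box))

    def entries(self):
        return list(self.queue)
```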

Global Feature Learning. The objective of global feature learning is to learn high-level semantics from crops that are invariant to viewpoint changes or transformations. Toward this end, we introduce several strategies to improve representations by further boosting the diversity & efficacy of both positive and negative samples. For brevity, we formulate these from a single view of an anchor patch $P_a$ and its embedding $v_a$.

For global strategy 1 (G1), we augment positives with crops at corresponding positions in different images. This is similar in spirit to [3][20], but we compute correspondences in 3D instead of assuming alignment, and we sample crops instead of slices, which are both richer in spatial context & more robust to alignment errors along a single axis. Concretely, $\Lambda^-_{\textit{G1}}=M_g$ and $\Lambda^+_{\textit{G1}}=\{v_a,v_{a\rightarrow i_1},v_{a\rightarrow i_2},\ldots,v_{a\rightarrow i_{n^+}}\}$, where $v_{a\rightarrow i}=h_g\circ f(P_{a\rightarrow i})$ denotes the feature embedding of $P_{a\rightarrow i}$, and $n^+$ is the number of sampled positives in separate volumes that spatially correspond to the anchor patch $P_a$.

For a large embedding queue, quantity has a quality of its own, but naively raising this capacity increases the incidence of false negatives. In global strategy 2 (G2), we debias the negative cohort by removing crop embeddings in $M_g$ that are "close" to the anchor patch. We quantify "close" as the overlap between two patches in the template space. Given two patches $P_a^j$ & $P_a^k$, we denote their overlap region as patch $P_a^{j\cap k}$. Their overlap ratio, or IoU, is defined as $IoU(P_a^j,P_a^k)=V(P_a^{j\cap k})/[V(P_a^j)+V(P_a^k)-V(P_a^{j\cap k})]$, where $V(\cdot)$ returns the volume of a patch. Given an overlap threshold $o\in[0,1]$, we define the cohorts of negatives & positives as $\Lambda^-_{\textit{G2}}=\{v_i\ |\ IoU(v_{i\rightarrow t},v_{a\rightarrow t})\leq o,\forall v_i\in M_g\}$ & $\Lambda^+_{\textit{G2}}=\{v_a,v_{a\rightarrow i_1},v_{a\rightarrow i_2},\ldots,v_{a\rightarrow i_{n^+}}\}$, where $IoU(v_{i\rightarrow t},v_{a\rightarrow t})$ is notational shorthand for $IoU(P_{i\rightarrow t},P_{a\rightarrow t})$.

Although false negatives are alleviated, $\Lambda^+_{\textit{G1}}$ and $\Lambda^+_{\textit{G2}}$ still suffer from the same drawbacks as [20][3]: reliance on intra-batch samples greatly limits the variability of positives, just as it had for negatives. In light of this, we propose global strategy 3 (G3), which converts the negative patches removed in G2 into positives, based on the reasoning that if a patch's semantics are similar enough to make it a false negative, then its mutual information with the anchor would enrich the representations when used as an added positive. We define $\Lambda^-_{\textit{G3}}=\{v_i\ |\ IoU(v_{i\rightarrow t},v_{a\rightarrow t})\leq o,\forall v_i\in M_g\}$ & $\Lambda^+_{\textit{G3}}=\{v_a,v_{a\rightarrow i_1},v_{a\rightarrow i_2},\ldots,v_{a\rightarrow i_{n^+}}\}\cup\{v_i\ |\ IoU(v_{i\rightarrow t},v_{a\rightarrow t})>o,\forall v_i\in M_g\}$.
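Under these definitions, assembling $\Lambda^-$ and $\Lambda^+$ from the queue reduces to an IoU test in template space. A sketch with illustrative names is given below (the `box_iou` helper and the pairing of each queue entry with its template-space box are our assumptions about how the bookkeeping could be done):

```python
def box_iou(a, b):
    """IoU of two axis-aligned boxes (zmin, ymin, xmin, zmax, ymax, xmax) in template space."""
    inter = 1.0
    for d in range(3):
        lo, hi = max(a[d], b[d]), min(a[d + 3], b[d + 3])
        inter *= max(hi - lo, 0.0)
    vol = lambda box: (box[3] - box[0]) * (box[4] - box[1]) * (box[5] - box[2])
    return inter / (vol(a) + vol(b) - inter + 1e-8)

def debias_queue(queue_entries, anchor_box, o=0.2, reuse_as_positives=True):
    """queue_entries: iterable of (embedding, template_box) pairs from M_g.
    G2: keep as negatives only entries with IoU <= o against the anchor's template box.
    G3 (reuse_as_positives=True): additionally reuse the removed entries as extra positives."""
    negatives, extra_positives = [], []
    for emb, box in queue_entries:
        (negatives if box_iou(box, anchor_box) <= o else extra_positives).append(emb)
    return negatives, (extra_positives if reuse_as_positives else [])
```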

TABLE I: Main results with state-of-the-art approaches. Entries are dice scores (%) and their standard deviations (±) across 4 runs. Column headers give the fraction of labeled training data used, with the corresponding number of training volumes in parentheses.

| Methods | CT Abdominal (BCV) 10% (2) | 25% (5) | 50% (10) | CT Heart (MMWHS) 10% (1) | 25% (3) | 50% (7) | MR Heart (MMWHS) 10% (1) | 25% (3) | 50% (7) |
|---|---|---|---|---|---|---|---|---|---|
| Random Init. | 52.72 ± 1.03 | 69.37 ± 0.79 | 78.60 ± 0.40 | 62.49 ± 1.18 | 85.23 ± 0.68 | 89.01 ± 0.51 | 68.45 ± 1.15 | 75.70 ± 3.27 | 86.26 ± 0.68 |
| Predictive & Generative Approaches | | | | | | | | | |
| [24] MG19 | 54.10 ± 1.45 | 71.78 ± 0.84 | 79.39 ± 0.71 | 64.02 ± 1.59 | 86.20 ± 1.27 | 89.08 ± 0.25 | 69.66 ± 1.47 | 74.88 ± 3.99 | 86.01 ± 0.30 |
| [15] Rubik++20 | 54.99 ± 1.45 | 69.75 ± 0.56 | 79.80 ± 0.46 | 64.67 ± 3.25 | 86.31 ± 0.74 | 89.02 ± 0.17 | 72.49 ± 1.27 | 78.31 ± 1.42 | 86.86 ± 0.32 |
| [21] SAR21 | 53.14 ± 1.68 | 69.96 ± 0.38 | 78.31 ± 0.50 | 65.01 ± 1.49 | 86.50 ± 0.88 | 89.18 ± 0.35 | 73.25 ± 1.45 | 78.60 ± 1.34 | 86.72 ± 0.13 |
| [9] TransVW21 | 55.42 ± 1.86 | 71.58 ± 1.33 | 79.66 ± 1.07 | 65.21 ± 1.27 | 86.48 ± 1.23 | 90.17 ± 0.45 | 71.91 ± 2.08 | 77.32 ± 1.56 | 86.30 ± 0.57 |
| Metric Learning Approaches | | | | | | | | | |
| [18] PGL20 | 54.87 ± 1.12 | 70.45 ± 0.98 | 78.88 ± 0.27 | 63.68 ± 2.85 | 86.65 ± 0.77 | 88.84 ± 0.41 | 68.61 ± 2.86 | 74.29 ± 3.55 | 86.56 ± 0.68 |
| [5] MoCo20 | 55.64 ± 0.85 | 71.07 ± 1.07 | 79.97 ± 0.66 | 65.03 ± 1.62 | 85.96 ± 1.21 | 89.63 ± 0.34 | 69.44 ± 0.80 | 77.23 ± 1.60 | 86.51 ± 0.13 |
| [20] PCL21 | 56.05 ± 0.64 | 68.55 ± 0.55 | 76.11 ± 0.42 | 66.23 ± 1.25 | 85.14 ± 1.01 | 88.25 ± 0.32 | 73.44 ± 1.08 | 76.90 ± 1.05 | 83.50 ± 0.54 |
| [23] PCRL21 | 56.01 ± 1.39 | 71.30 ± 0.91 | 80.23 ± 0.24 | 65.58 ± 2.41 | 87.03 ± 0.98 | 90.02 ± 0.70 | 74.83 ± 1.91 | 77.72 ± 1.97 | 86.79 ± 1.05 |
| Spade (Ours) | 57.55 ± 1.22 | 72.84 ± 0.97 | 80.03 ± 0.67 | 68.07 ± 2.09 | 87.90 ± 0.84 | 90.29 ± 0.43 | 74.97 ± 1.62 | 79.04 ± 1.08 | 87.22 ± 0.56 |
Local Feature Learning. Directly extending the proposed global strategies to the voxel level would be ineffective for three reasons. 1) Decoder outputs have higher resolutions and finer localized details, and require spatial equivariance rather than invariance. 2) Processing 3D feature maps at full resolution with contrastive projection heads, which commonly contain more parameters than the backbone [5], incurs a large computational burden. 3) Spatial alignments are not accurate enough to reliably assign local positives and negatives.

Instead of relying on full-resolution features, we use local features from the overlapping regions of crop pairs. Given overlapping crops $P_a^j, P_a^k$ from image $I_a$, where $IoU(P_a^j,P_a^k)\geq o$, we sample corresponding overlapping crops from other images (e.g., $P_b^j, P_b^k$). We introduce three improvements to local feature learning. First, only logits in the overlapping regions (e.g., $Z_a^{j\cap k}, Z_b^{j\cap k}$) are used for representation learning. This strategy mitigates the deleterious effects of misalignment errors while still promoting local understanding, retaining sample diversity, and limiting computation cost. Second, we reverse the spatial transforms on the logits (e.g., $Z'^{j\cap k}_a=\mathcal{T}_s^{-1}(Z_a^{j\cap k})$) before obtaining the final embeddings (e.g., $v'^{j\cap k}_a=h_l(Z'^{j\cap k}_a)$) so that representations are contrasted in their original orientations and equivariance is maintained. Finally, spatial relations in the embeddings are preserved with 1×1 convolutions in $h_l$ instead of an MLP. Note that $h_l$ also resizes logits to the local embedding size $v_l\in\mathbb{R}^{C^{emb}\times H^{emb}\times W^{emb}}$.
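The sketch below illustrates this local branch for a single crop view; the argument names are ours, and `spatial_tf_inverse` stands in for whatever routine undoes the flips/rotations of $\mathcal{T}_s$:

```python
import torch
import torch.nn.functional as F

def local_view_embedding(Z: torch.Tensor, overlap_slices, spatial_tf_inverse, h_l) -> torch.Tensor:
    """Produce the equivariant local embedding of one crop view.

    Z                 : decoder logits for the crop, shape (C, D, H, W).
    overlap_slices    : (slice, slice, slice) indexing the j∩k overlap region within this crop.
    spatial_tf_inverse: callable that undoes the spatial augmentation T_s on the logits.
    h_l               : local projection head built from 1x1x1 convolutions (see § II-B).
    """
    z_overlap = Z[(slice(None),) + tuple(overlap_slices)]  # keep only the shared region
    z_overlap = spatial_tf_inverse(z_overlap)              # contrast in the original orientation
    v = h_l(z_overlap.unsqueeze(0)).squeeze(0)             # (C_emb, H_emb, W_emb), spatially equivariant
    return F.normalize(v, dim=0)                           # unit-normalize along the channel axis
```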

In local strategy 1 (L1), we sample positives from the overlapping regions of the same crop pair and treat all $v_l\in M_l$ as negatives. Thus, we have $\Lambda^-_{\textit{L1}}=M_l$ and $\Lambda^+_{\textit{L1}}=\{v'^{j\cap k}_a,v'^{k\cap j}_a\}$. Note that $v_a^{j\cap k}$ is the overlapping region of $v_a^j$ & $v_a^k$ within $v_a^j$, while $v_a^{k\cap j}$ is the same corresponding area except within $v_a^k$. Local strategy 2 (L2) debiases negatives that are above an overlap threshold (the same $o$ as previously defined). More specifically, $\Lambda^-_{\textit{L2}}=\{v_i\ |\ IoU(v_{i\rightarrow t},v_{a\rightarrow t})\leq o,\forall v_i\in M_l\}$, $\Lambda^+_{\textit{L2}}=\{v'^{j\cap k}_a,v'^{k\cap j}_a\}$. Local strategy 3 (L3) selects other corresponding overlapped regions as positives in addition to applying L2: $\Lambda^-_{\textit{L3}}=\{v_i\ |\ IoU(v_{i\rightarrow t},v_{a\rightarrow t})\leq o,\forall v_i\in M_l\}$, $\Lambda^+_{\textit{L3}}=\{v'^{j\cap k}_a,v'^{k\cap j}_a,v'^{j\cap k}_b,v'^{k\cap j}_b,\ldots\}$. Finally, local strategy 4 (L4) adopts the idea introduced in G3 of reusing debiased negatives as positives: $\Lambda^-_{\textit{L4}}=\{v_i\ |\ IoU(v_{i\rightarrow t},v_{a\rightarrow t})\leq o,\forall v_i\in M_l\}$, $\Lambda^+_{\textit{L4}}=\{v'^{j\cap k}_a,v'^{k\cap j}_a,v'^{j\cap k}_b,v'^{k\cap j}_b,\ldots\}\cup\{v_i\ |\ IoU(v_{i\rightarrow t},v_{a\rightarrow t})>o,\forall v_i\in M_l\}$.

Reconstruction. MG [24] fits elegantly into our pipeline by co-opting the data transformations used for contrasting views and complementing local representations via dense reconstruction. Without bells & whistles, we directly reconstruct the spatially transformed version of the original crop (i.e., $\mathcal{T}_s(P)$). The reconstruction loss is defined as:

\mathcal{L}^{MSE}_{r}=\left[\mathcal{T}_{s}(P)-\sigma(g_{\theta}\circ f_{\theta}\circ\mathcal{T}_{i}\circ\mathcal{T}_{s}(P))\right]^{2} (3)
Overall Loss & Training Details. Our overall loss $\mathcal{L}_{Spade}$ is:

\mathcal{L}_{Spade}=\lambda\mathcal{L}^{CON}_{g}+(1-\lambda)\mathcal{L}^{CON}_{l}+\lambda_{r}\mathcal{L}^{MSE}_{r} (4)

where $\lambda$ and $\lambda_r$ are loss-weighting terms for contrastive learning and reconstruction, respectively. For $\mathcal{L}^{CON}_g$ & $\mathcal{L}^{CON}_l$, $\Lambda^{-/+}$ is populated based on the global and local strategies selected. Note that the local sets $\Lambda_l$ contain embeddings with spatial dimensions rather than vectors, so the dot-products in Eq. 1 are spatially equivariant. After the weights are updated in an optimization step, we adjust the momentum parameters $\epsilon$ from the regular model parameters $\theta$ via $\epsilon\leftarrow\beta\epsilon+(1-\beta)\theta$.
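A compact sketch of one optimization step, combining Eq. (4) with the momentum update, is shown below (our own helper; the optimizer call on $\theta$ is elided):

```python
import torch

def spade_update(loss_g, loss_l, loss_rec, theta_params, eps_params,
                 lam=0.5, lam_r=10.0, beta=0.99):
    """Combine the three pretext losses (Eq. 4), backpropagate, and refresh the
    momentum parameters ε ← βε + (1 − β)θ after θ has been updated by the optimizer."""
    loss = lam * loss_g + (1.0 - lam) * loss_l + lam_r * loss_rec
    loss.backward()
    # ... optimizer.step() / optimizer.zero_grad() on θ would run here ...
    with torch.no_grad():
        for p_eps, p_theta in zip(eps_params, theta_params):
            p_eps.mul_(beta).add_(p_theta, alpha=1.0 - beta)
    return loss.detach()
```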

To compute $\mathbb{T}$ for volume alignment, we first crop out the background by thresholding at $-350$ Hounsfield Units and downsample all images by a factor of 2. Affine registration is performed using the negative normalized cross-correlation metric and tri-linear interpolation. Optimization is run until convergence (minimum 50 iterations) with a learning rate of 0.5.
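One way to reproduce this alignment step is sketched below with SimpleITK (our choice of library; the paper does not name one, and the $-350$ HU background cropping is omitted for brevity):

```python
import SimpleITK as sitk

def register_to_template(moving_path: str, template_path: str) -> sitk.Transform:
    """Affine alignment roughly matching the recipe above: 2x downsampling,
    normalized cross-correlation, gradient descent, tri-linear interpolation."""
    fixed = sitk.Cast(sitk.ReadImage(template_path), sitk.sitkFloat32)
    moving = sitk.Cast(sitk.ReadImage(moving_path), sitk.sitkFloat32)
    fixed, moving = sitk.Shrink(fixed, [2, 2, 2]), sitk.Shrink(moving, [2, 2, 2])

    reg = sitk.ImageRegistrationMethod()
    reg.SetMetricAsCorrelation()  # minimizes the negative normalized cross-correlation
    reg.SetOptimizerAsGradientDescent(learningRate=0.5, numberOfIterations=100)
    reg.SetInterpolator(sitk.sitkLinear)
    init = sitk.CenteredTransformInitializer(
        fixed, moving, sitk.AffineTransform(3),
        sitk.CenteredTransformInitializerFilter.GEOMETRY)
    reg.SetInitialTransform(init, inPlace=False)
    # Note SimpleITK's convention: the returned transform maps template (fixed) points to
    # image (moving) points, so its inverse plays the role of T_i in the text.
    return reg.Execute(fixed, moving)
```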

For network components, $f$ is a 3D Res2Net-50 (scale=4, stride=[1,2,2] in the first downsampling layer) [8], while $g$ is a lightweight decoder. The global projection module $h_g$ follows [5] with average pooling to size 1×1×1, flattening, a linear projection to 2048 channels, ReLU, and a linear projection to embeddings of size $C_g^{emb}=128$. The local projection module resizes features to 1×3×3 via average pooling, applies a 1×1×1 convolution with 1024 filters, activates with ReLU, and outputs embeddings after a 1×1×1 convolution with 64 filters ($v_l\in\mathbb{R}^{64\times 3\times 3}$, $C_l^{emb}=64$, $H_l^{emb}=3$, $W_l^{emb}=3$; note that the depth dimension is squeezed). In contrastive pretraining, we use $\mathcal{Q}_g=16000$, $\mathcal{Q}_l=1000$, $o=0.2$, $p=2$, $n^+=4$, and $\beta=0.99$. For the losses, we set $\tau=0.2$, $\lambda=0.5$, and $\lambda_r=10$.
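For reference, the two projection heads described above can be written in a few lines of PyTorch (a sketch; `in_ch` stands for the encoder's or decoder's output width and is an assumption on our part):

```python
import torch.nn as nn
import torch.nn.functional as F

class GlobalHead(nn.Module):
    """h_g: average-pool to 1x1x1, flatten, Linear(2048) -> ReLU -> Linear(128)."""
    def __init__(self, in_ch: int, hidden: int = 2048, emb: int = 128):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.mlp = nn.Sequential(nn.Flatten(), nn.Linear(in_ch, hidden),
                                 nn.ReLU(inplace=True), nn.Linear(hidden, emb))
    def forward(self, z):                        # z: (B, in_ch, D, H, W)
        return F.normalize(self.mlp(self.pool(z)), dim=1)

class LocalHead(nn.Module):
    """h_l: average-pool to 1x3x3, then two 1x1x1 convs; output keeps a 3x3 spatial grid."""
    def __init__(self, in_ch: int, hidden: int = 1024, emb: int = 64):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool3d((1, 3, 3))
        self.proj = nn.Sequential(nn.Conv3d(in_ch, hidden, kernel_size=1),
                                  nn.ReLU(inplace=True),
                                  nn.Conv3d(hidden, emb, kernel_size=1))
    def forward(self, Z):                        # Z: (B, in_ch, D, H, W)
        v = self.proj(self.pool(Z)).squeeze(2)   # squeeze the singleton depth -> (B, 64, 3, 3)
        return F.normalize(v, dim=1)
```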

III Experiments

III-A Pretraining Data and Details

Data Description. We demonstrate the efficacy of our approach with torso CTs from two datasets as pretraining data, since they present diverse anatomical structures, diseases, diagnostic tasks, and imaging settings.

AMOS [12] (CT Abdominal Multi-Organ Segmentation) provides 240 training and 120 test abdominal CTs from diverse sources. We remove all “abnormal images” with anisotropic spacing or sizing along the right/left and anterior/inferior axes. Pretraining uses the remaining 200 training scans and is validated on 100 testing scans. Note that although masks are given for 15 organs, we do not use them in any capacity.

LUNA [13] (CT Lung Nodule Analysis) consists of 888 chest CTs across 10 folds. The first 7 folds (623 images) and the remaining 3 folds (265 images) are used for pretraining and validation, respectively.

Implementation Details. To preprocess all 823 pretraining & 365 validation CTs, all images are resampled to 2×0.7×0.7 (mm) superior, anterior, and right spacings, respectively (the median spacing across the pretraining datasets). Intensities are clipped to $[-1000, 1000]$ and normalized to $[0, 1]$. Crops are obtained with scaling factors in [0.5, 2] but ultimately resized to [32, 64, 64]; crops containing mainly air are discarded. We train using SGD (l.r. 0.0075, mom. 0.9) with cosine annealing & batch size 24 for 500 epochs on a single V100 with PyTorch.
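The intensity preprocessing amounts to a clip-and-rescale; a one-function sketch is given below (resampling to the target spacing is assumed to have happened already):

```python
import numpy as np

def preprocess_ct(volume: np.ndarray, lo: float = -1000.0, hi: float = 1000.0) -> np.ndarray:
    """Clip CT intensities to [lo, hi] HU and rescale them to [0, 1]."""
    vol = np.clip(volume.astype(np.float32), lo, hi)
    return (vol - lo) / (hi - lo)
```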

TABLE II: Global feature learning approaches. Pretrained on 5.2M AMOS & LUNA crops, finetuned on 10% of BCV labels (3 volumes), and evaluated on the test set.

| # | Feature Init. | $\mathcal{Q}_g$ | Dice (%) |
|---|---|---|---|
| No Spatial Priors | | | |
| 1 | Random | - | 52.72 |
| 2 | PGL | - | 54.87 (+2.2) |
| 3 | MoCo | 4k | 55.11 (+2.4) |
| Global Debiasing Strategies | | | |
| 4 | G1 ($n^+=4$) | 4k | 55.49 (+2.8) |
| 5 | G2 ($n^+=4$, $o=0.0$) | 4k | 55.24 (+2.5) |
| 6 | G2 ($n^+=4$, $o=0.2$) | 4k | 56.37 (+3.7) |
| 7 | G2 ($n^+=4$, $o=0.4$) | 4k | 56.17 (+3.5) |
| 8 | G3 ($n^+=4$, $o=0.2$) | 4k | 56.58 (+3.9) |
| Effect of Queue Size & Debiasing | | | |
| 9 | MoCo | 1k | 55.64 (+2.9) |
| 10 | MoCo | 4k | 55.11 (+2.4) |
| 11 | MoCo | 16k | 54.98 (+2.3) |
| 12 | G3 ($n^+=4$, $o=0.2$) | 1k | 56.26 (+3.5) |
| 13 | G3 ($n^+=4$, $o=0.2$) | 4k | 56.58 (+3.9) |
| 14 | G3 ($n^+=4$, $o=0.2$) | 16k | 56.83 (+4.1) |

III-B Fine-tuning Data and Details

Data Description. To gauge the quality of the learned features, we selected three tasks covering different torso regions, organs, modalities, and scales that are separate from the pretraining data but contain previously seen structures. For all datasets, we employ a 7:1:2 (train, validation, test) split and fine-tune using 10%, 25%, and 50% of the training images (if the number of training volumes is not evenly divisible, we apply the floor function). Following [23][3][15], we use the class-averaged dice score for all segmentation evaluations.

BCV [1] (CT Beyond the Cranial Vault) contains 30 abdominal CTs with 13 anatomical annotations. The dataset is split into 21 training, 3 validation, and 6 testing images.

MMWHS [25] (CT & MR Multi-Modality Whole Heart Segmentation) presents 20 labeled CT and 20 labeled MR images with annotations for seven cardiac structures. We treat each modality as its own downstream task and evaluate them independently. For both CT & MR subsets, we split the labeled images into 14 training, 2 validation, and 4 testing.

Implementation Details. For preprocessing, we clip values to $[-1000, 1000]$ for CTs and Z-normalize the MRs. We apply nnU-Net [11] spatial preprocessing and rescale intensities to $[0, 1]$ to be consistent with pretraining. Regarding data, we sample crops (ensuring 50% contain foreground) with scaling factors in [0.75, 1.25] and resize them to [32, 96, 96]. Crops are augmented with mirroring, blurring, intensity scaling, and gamma transforms. We train using AdamW (l.r. 0.001, w.d. 0.01) with cosine annealing & batch size 8 for ≈50k iterations on a single Titan-Xp with PyTorch (run time ≈16 hours). We find a higher weight decay to be beneficial, especially with limited labels.
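A possible fine-tuning crop sampler matching the 50% foreground guarantee is sketched below (illustrative only; it assumes each image is at least crop-sized):

```python
import numpy as np

def sample_crop(image: np.ndarray, label: np.ndarray, size=(32, 96, 96),
                p_fg: float = 0.5, rng=np.random):
    """With probability p_fg, center the crop on a random labeled voxel so that
    roughly half of the sampled crops contain foreground anatomy."""
    if rng.rand() < p_fg and label.any():
        center = np.argwhere(label > 0)[rng.randint((label > 0).sum())]
    else:
        center = np.array([rng.randint(s) for s in image.shape])
    starts = [int(np.clip(c - s // 2, 0, dim - s))
              for c, s, dim in zip(center, size, image.shape)]
    sl = tuple(slice(st, st + s) for st, s in zip(starts, size))
    return image[sl], label[sl]
```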

III-C Baselines

We select state-of-the-art medical pretraining baselines such as PCRL [23], TransVW [9], Rubik++ [15], PCL [20], and others. We choose PCL over [3] for its proposed improvements regarding slices near partition borders. We make the comparison with PCL (a 2D method) as fair as possible by matching the number of parameters in its 2D Res2Net to our 3D model via width increases. FG [7] is omitted since global contrast coupled with reconstruction is explored by PCRL [23] and by us in § IV-A. We exclude HSSL [22] for similar reasons, in addition to the fact that it only uses 2D slices.

TABLE III: Local feature learning strategies & template choice. Pretrained on 5.2M AMOS & LUNA crops, finetuned on 10% of BCV labels (3 volumes), and evaluated on the test set.

| # | Feature Init. | $\mathcal{Q}_l$ | Dice (%) |
|---|---|---|---|
| Local Debiasing Strategies | | | |
| 1 | Ours_global (G3, $n^+=4$, $o=0.2$, $\mathcal{Q}_g=16$k) | - | 56.83 |
| 2 | Ours_global + Reconstruction | - | 56.94 (+0.1) |
| 3 | Ours_global + L1 ($n^+=4$, $o=0.2$) | 1k | 57.18 (+0.4) |
| 4 | Ours_global + L2 ($n^+=4$, $o=0.2$) | 1k | 57.41 (+0.6) |
| 5 | Ours_global + L3 ($n^+=4$, $o=0.2$) | 1k | 56.54 (-0.3) |
| 6 | Ours_global + L4 ($n^+=4$, $o=0.2$) | 1k | 56.13 (-0.7) |
| 7 | Ours_proposed (G3 + L2 + Reconstruction) | 1k | 57.55 (+0.7) |
| Template Selection | | | |
| 8 | Ours_proposed | 1k | 57.55 |
| 9 | Ours_proposed (AMOS Template) | 1k | 57.66 |
| 10 | Ours_proposed (BCV Template) | 1k | 57.26 |
| | Average | | 57.49 |

IV Results and Discussion

IV-A Component Studies

Global Feature Learning. We first investigate the effectiveness of our global sampling strategies (see Table II; PGL & MoCo rely on same-image positives, while G1–G3 additionally draw inter-image positives). For a fair assessment, we train each method with 5.2M sampled patches and make queue sizes comparable for MoCo. After pretraining, a randomly-initialized decoder is attached to the encoder & the entire network is finetuned.

First, we note that incorporating negatives, even without debiasing, improves over positive-only approaches (see rows 2 & 3). Also, enriching positives with spatially corresponding patches (G1, row 4) adds further benefits. Debiasing negatives (rows 5-8), however, yields the largest gains when the overlap threshold $o$ is properly selected. A threshold that is too low (row 5) may cause the undesirable removal of hard negatives (performing even worse than no debiasing in row 4). On the other hand, setting $o$ too high (row 7) prevents the removal of damaging false negatives. Usefully for contrastive learning, $o$ can modulate the amount of mutual information distilled between samples (e.g., a low $o$ yields negatives with low mutual information with the anchor). This enables pretraining flexibility in the face of a wide range of possible tasks & data.

Next, we study the interactions between queue size and debiasing. In contrastive learning without debiasing (rows 9-11), we observe declines in performance as the queue size increases. This may be attributed to the increasing number of false negatives; surprisingly, the decline is limited, possibly because the vast majority of queue entries are valid negatives. In contrast, our debiasing strategies (rows 12-14) empower larger queue sizes, where performance steadily improves. This affirms the principle that appropriate selection of both positives & negatives facilitates contrastive representation learning.

Local Feature Learning & Template Choice. In Table III, we explore the efficacy of local approaches by jointly pretraining using our best global strategy (G3 in row 1) and the indicated local strategy (note: the same $o$ & $n^+$ are adopted from the global strategy). For evaluation, we use the same procedure as the experiments in Table II, except that a single-layer segmentation head (1×1×1 conv.) is appended to the end of the decoder for pixel classification.

We first compare our equivariant contrastive method (L1, row 3) with a traditional reconstruction approach (row 2) and find the contrastive approach to be superior. We nonetheless integrate reconstruction since we observe performance improvements and qualitatively see faster pretraining convergence; we speculate that it expedites the initialization of sensible features, which shortens the feature warmup for the contrastive tasks. Next, we find that local strategies (rows 3-6) behave in stark contrast to global ones in that diversifying positives (rows 5-6) impairs features. This is likely because rough correspondences yield erroneous positives & negatives, blunting both cohorts. Our final proposed method therefore utilizes G3, L2, and reconstruction. Comparing training times, Spade takes 25.4 hours for 5.2 million patches, while PCRL [23] requires 28 hours.

To study the sensitivity of our approach to different templates, we select 3 full-torso CTs: two from within the pretraining set (in AMOS) and one outside the pretraining data (in BCV). The only selection criterion is that the entire torso is covered. From Table III, we conclude that although there is a slight reduction in performance for the BCV template, Spade performs reasonably well with all three. This affirms the notion that approximate spatial correspondences are sufficiently effective for semantic comparisons.

IV-B Comparison with State-of-the-Arts

From our main experiments in Table I, we make the following observations. 1) Spade generally outperforms recent state-of-the-art medical pretraining methods. This supports our assumption that spatial priors are effective for predicting semantic similarity and for transferring this knowledge into useful representations. It also shows that simple sampling strategies can beat fairly complex pipelines like [23], which uses multiple regularization approaches such as mixup, attention modules, and transformation embeddings. 2) Our method benefits performance the most when annotations are most limited. This is a desirable property for pretraining, since its primary objective is to improve representations when annotations are scarce due to acquisition costs. However, this is a double-edged sword: when more labels are available (e.g., MMWHS with 50% labels), the cost of pretraining may not outweigh the benefit. 3) We affirm that effective negative samples are central to representation learning in 3D radiology images. Our studies on sampling strategies (see § IV-A) indicate that debiasing improves downstream performance more than enriching the diversity of positives. Making two patches with comparable semantics similar may be an easier task (courtesy of the curse of dimensionality) than contrasting similar patches. More studies on metric learning in medical images are needed to further explore this observation.

V Conclusion

In this work, we presented a contrastive learning framework, Spade, that leverages spatial correspondences between 3D radiology images to enrich positive pairs and debias false negatives. Our studies indicate that Spade adds little additional computation and is highly effective in learning representations for downstream segmentation tasks. We hope this work spurs further efforts to discover more effective priors for contrastive learning with the abundant medical data at our disposal.

References

  • [1] Multi-atlas labeling beyond the cranial vault, 2015. Accessed Jan 2021 at https://www.synapse.org/#!Synapse:syn3193805/wiki/89480.
  • [2] S. Azizi, B. Mustafa, F. Ryan, Z. Beaver, J. Freyberg, J. Deaton, A. Loh, A. Karthikesalingam, S. Kornblith, T. Chen, et al. Big self-supervised models advance medical image classification. In ICCV, pages 3478–3488, 2021.
  • [3] K. Chaitanya, E. Erdil, N. Karani, and E. Konukoglu. Contrastive learning of global and local features for medical image segmentation with limited annotations. NIPS, 33, 2020.
  • [4] T. Chen, S. Kornblith, M. Norouzi, and G. Hinton. A simple framework for contrastive learning of visual representations. In ICML, pages 1597–1607. PMLR, 2020.
  • [5] X. Chen, H. Fan, R.B. Girshick, and K. He. Improved baselines with momentum contrastive learning. ArXiv, abs/2003.04297, 2020.
  • [6] C.Y. Chuang, J. Robinson, Y.C. Lin, A. Torralba, and S. Jegelka. Debiased contrastive learning. NIPS, 33:8765–8775, 2020.
  • [7] J. Dippel, S. Vogler, and J. Höhne. Towards fine-grained visual representations by combining contrastive learning with image reconstruction and attention-weighted pooling. ArXiv:2104.04323, 2021.
  • [8] S.H. Gao, M.M. Cheng, K. Zhao, X.Y. Zhang, M.H. Yang, and P. Torr. Res2Net: A new multi-scale backbone architecture. IEEE TPAMI, 43(2):652–662, 2019.
  • [9] F. Haghighi, M.R.H. Taher, Z. Zhou, M.B. Gotway, and J. Liang. Transferable visual words: Exploiting the semantics of anatomical patterns for self-supervised learning. IEEE TMI, 40:2857–2868, 2021.
  • [10] T. Huynh, S. Kornblith, M. R Walter, M. Maire, and M. Khademi. Boosting contrastive self-supervised learning with false negative cancellation. In IEEE/CVF WACV, pages 2785–2795, 2022.
  • [11] F. Isensee, P.F. Jaeger, S.A.A. Kohl, J. Petersen, and K.H. Maier-Hein. nnU-Net: A self-configuring method for deep learning-based biomedical image segmentation. Nature Methods, 18(2):203–211, 2021.
  • [12] Y. Ji, H. Bai, J. Yang, C. Ge, Y. Zhu, R. Zhang, Z. Li, L. Zhang, W. Ma, X. Wan, et al. AMOS: A large-scale abdominal multi-organ benchmark for versatile medical image segmentation. ArXiv:2206.08023, 2022.
  • [13] A.A.A. Setio, F. Ciompi, G.J.S. Litjens, P.K. Gerke, C. Jacobs, S.J. Riel, M.M.W. Wille, M. Naqibullah, C.I. Sánchez, and B. Ginneken. Pulmonary nodule detection in CT images: False positive reduction using multi-view convolutional networks. IEEE TMI, 35:1160–1169, 2016.
  • [14] H. Sowrirajan, J. Yang, A. Ng, and P. Rajpurkar. Moco pretraining improves representation and transferability of chest x-ray models. In MIDL, 2021.
  • [15] X. Tao, Y. Li, W. Zhou, K. Ma, and Y. Zheng. Revisiting rubik’s cube: self-supervised learning with volume-wise transformation for 3d medical image segmentation. In MICCAI, pages 238–248. Springer, 2020.
  • [16] A. van den Oord, Y. Li, and O. Vinyals. Representation learning with contrastive predictive coding. ArXiv, abs/1807.03748, 2018.
  • [17] Z. Wu, Y. Xiong, S.X. Yu, and D. Lin. Unsupervised feature learning via non-parametric instance discrimination. CVPR, pages 3733–3742, 2018.
  • [18] Y. Xie, J. Zhang, Z. Liao, Y. Xia, and C. Shen. PGL: Prior-guided local self-supervised learning for 3D medical image segmentation. ArXiv, abs/2011.12640, 2020.
  • [19] C. You, Y. Zhou, R. Zhao, L. Staib, and J.S. Duncan. SimCVD: Simple contrastive voxel-wise representation distillation for semi-supervised medical image segmentation. IEEE TMI, 2022.
  • [20] D. Zeng, Y. Wu, X. Hu, X. Xu, H. Yuan, M. Huang, J. Zhuang, J. Hu, and Y. Shi. Positional contrastive learning for volumetric medical image segmentation. In MICCAI, 2021.
  • [21] X. Zhang, S. Feng, Y. Zhou, Y. Zhang, and Y. Wang. SAR: Scale-aware restoration learning for 3D tumor segmentation. In MICCAI, 2021.
  • [22] H. Zheng, J. Han, H. Wang, L. Yang, Z. Zhao, C. Wang, and D.Z. Chen. Hierarchical self-supervised learning for medical image segmentation based on multi-domain data aggregation. In MICCAI, 2021.
  • [23] H.Y. Zhou, C.K. Lu, S. Yang, X. Han, and Y. Yu. Preservational learning improves self-supervised medical image models by reconstructing diverse contexts. ICCV, pages 3479–3489, 2021.
  • [24] Z. Zhou, V. Sodha, M.M.R. Siddiquee, R. Feng, N. Tajbakhsh, M.B. Gotway, and J. Liang. Models Genesis: Generic autodidactic models for 3D medical image analysis. MICCAI, 11767:384–393, 2019.
  • [25] X. Zhuang. Challenges and methodologies of fully automatic whole heart segmentation: a review. Journal of healthcare engineering, 4(3):371–407, 2013.