Temporal Clustering with External Memory Network for Disease Progression Modeling
Abstract
Disease progression modeling (DPM) uses mathematical frameworks to quantitatively measure how a given disease progresses in severity. DPM is useful in many ways, such as predicting health states, categorizing disease stages, and assessing patients’ disease trajectories. Recently, with the wider availability of electronic health records (EHR) and the broad application of data-driven machine learning methods, DPM has attracted much attention, yet two major challenges remain: (i) Due to the irregularity, heterogeneity, and long-term dependencies in EHRs, most existing DPM methods may not provide comprehensive patient representations. (ii) Many records in EHRs may be irrelevant to the target disease. Most existing models learn to focus on relevant information automatically instead of explicitly capturing target-relevant events, which may make the learned model suboptimal. To address these two issues, we propose Temporal Clustering with External Memory Network (TC-EMNet) for DPM, which groups patients with similar trajectories to form disease clusters/stages. TC-EMNet uses a variational autoencoder (VAE) to capture the internal complexity of the input data and an external memory network to capture long-distance information, both of which help produce comprehensive patient health states. Finally, the k-means algorithm is applied to the extracted patient representations to capture disease progression. Experiments on two real-world datasets show that our model achieves competitive clustering performance against state-of-the-art methods and is able to identify clinically meaningful clusters. Visualization of the patient representations shows that the proposed model generates better patient health states than the baselines.
Index Terms:
disease progression modeling, deep learning, temporal clustering
I Introduction
With the recent development of deep learning and the accumulation of electronic health records (EHR), which are inherently time-series data, there has been increasing effort in clustering EHR data to discover meaningful patterns in longitudinal health information. Moreover, chronic diseases such as Parkinson’s disease (PD) and Alzheimer’s disease (AD) can have varied outcomes even within a small group of patients. Such diseases are heterogeneous in nature and often evolve in distinct patterns that lead to different responses to therapeutic interventions under different conditions [1]. It has therefore become crucial to develop disease progression modeling (DPM) systems that capture progression patterns, enable early detection of critical conditions, and yield clinically helpful information to improve the quality of care.
Traditionally, DPM, or disease clustering/staging, is developed by domain experts with extensive clinical experience: disease stages are defined separately and based solely on the values of one or a few biomarkers [2, 3]. However, developing such a DPM system requires long-term observation and human labor, and the result typically depends on known biomarkers and established covariates, which makes it difficult to build DPM systems for diseases whose biomarkers have not been well studied. In recent years, the rapid growth of data-driven machine learning methods has motivated considerable effort in developing DPM models. There are two main approaches to DPM: 1) the problem is formulated as a risk prediction task that uses label information, with the patient representation extracted from the last layer of the model [4, 5, 6, 7, 8]; 2) the problem is formulated as a traditional unsupervised patient clustering/subtyping problem, where the model is trained to separate patients into multiple groups [9, 10, 11]. Leveraging disease outcomes during training can prevent the model from forming heterogeneous clusters. However, for certain diseases, diagnosis labels are often unavailable at each patient visit due to limited knowledge of the disease. Moreover, deep learning models designed for supervised tasks may not perform well when trained in an unsupervised fashion. Therefore, there is a need for a DPM framework that can handle both situations with respect to the availability of training labels. In addition, most existing deep learning models for disease progression modeling suffer from the following limitations:
• Irregularity and heterogeneity: Many diseases are heterogeneous in nature, and EHR data often have high internal complexity. Because effectively encoding diverse health conditions into patient representations is difficult, accurate DPM remains a challenging problem.
• Long-term dependency: RNNs have long been known to struggle with long-term dependencies because they tend to forget earlier information when the input sequence is long. Disease progression modeling, especially for chronic diseases, requires long-term observation of the patient to provide a comprehensive view for decision making.
• Target awareness: Most RNN-based methods derive patient representations directly from the hidden states of the model. Such an approach neglects the contribution of target-relevant information. In fact, real-world clinical decisions made by doctors are often based on past diagnoses as well.
To address these challenges, we propose Temporal Clustering with External Memory Network (TC-EMNet) for disease progression modeling in both supervised and unsupervised settings. TC-EMNet leverages a variational autoencoder framework and a memory network to address data irregularity and the long-term dependency problem of RNNs, respectively. At each time step, TC-EMNet takes EHR medical records as input and encodes the input features with a recurrent neural network to obtain hidden representations. TC-EMNet then samples from a distribution conditioned on the hidden state to form a latent representation. Meanwhile, the hidden state is stored in a global-level memory network, which in turn outputs a memory representation based on its current memory cells. The memory representation is concatenated with the current latent representation to form the patient representation at the current time step. When training labels are available, the model also employs a patient-level memory network that processes label information up to the most recent visit and outputs a target-aware memory representation. We combine the memory representations from the global-level and patient-level memory networks using a calibration process. TC-EMNet is trained with a reconstruction objective in the unsupervised setting and a prediction objective in the supervised setting.
In this paper, our contributions are fourfold:
• We propose a novel deep learning framework, TC-EMNet, for disease progression modeling under both supervised and unsupervised settings.
• TC-EMNet combines a recurrent neural network with a variational autoencoder (VAE) to capture the irregularity of the data and the heterogeneous nature of the disease.
• In the supervised setting, TC-EMNet employs a dual memory network architecture that leverages both the hidden representations of the input data and clinical diagnoses to produce accurate patient representations.
• Experiments on two real-world datasets show that TC-EMNet achieves competitive clustering performance against state-of-the-art methods and is able to find clinically interpretable disease clusters/stages.
The remainder of the paper is organized as follows. Section II briefly reviews existing work related to DPM, temporal clustering, and VAEs. Section III describes the technical details of the proposed model, TC-EMNet. Sections IV and V present experimental results and discussions. Finally, Section VI concludes the paper.

II Related Work
II-A Disease Progression Modeling
Disease progression modeling (DPM) plays an important role in the healthcare domain, especially for chronic diseases such as Parkinson’s disease (PD) and Alzheimer’s disease (AD). A well-performing DPM system can not only provide early detection or diagnosis but also discover clinically meaningful patterns for certain groups of trajectories. Many probabilistic models for DPM build on the hidden Markov model (HMM) and related state-space formulations. For example, [12] derived a deep probabilistic model based on a sequence-to-sequence architecture to model progression dynamics in the UK Cystic Fibrosis Registry. [9] introduced a continuous-time Markov process to learn a discrete representation of each progression state. Deep learning methods have also been developed for disease progression modeling. [13] proposed a CNN-based model that jointly learns features from MR images and demographic information to predict Alzheimer’s disease progression patterns. [14] designed a prediction framework using generative models to forecast the distribution of patients’ outcomes. DPM can be regarded as a classification problem, where diagnosis labels are leveraged for model training. Alternatively, DPM can be viewed from an unsupervised perspective, where the goal is to discover potential disease states or patient subtypes from patients’ medical histories [15]. However, DPM remains challenging due to the high complexity of the data introduced by irregular progression patterns in certain chronic diseases.
II-B Temporal clustering
Temporal clustering, also known as time-series clustering, is a data-driven method that clusters patients into subgroups based on time-series observations. Temporal clustering is challenging because of the high dimensionality of the data and the multiple time steps of each sample. Motivated by the success of recurrent neural networks (RNNs) in modeling time-series data, recent advances have focused on leveraging the latent representations learned by RNNs for temporal clustering. Moreover, with the growing availability of electronic health records (EHR), which provide large-scale and normalized context for individual patients, deep learning approaches have become capable of learning more comprehensive patterns and achieving better performance on several critical tasks. [16] introduced a time-aware mechanism into long short-term memory (LSTM) cells to capture progression patterns with irregular time intervals. [4] proposed an actor-critic algorithm for predictive clustering where, instead of defining a similarity measure for clustering, a cluster embedding is trained to represent each disease stage. [17] proposed an attention-based autoencoder that reconstructs relevant features for sepsis and showed that the model can identify interpretable patient subtypes. Nevertheless, only limited literature focuses on DPM using temporal clustering techniques.
II-C Variational Autoencoder
A variational autoencoder (VAE) is a generative model that can handle complicated distributions. VAEs are effective at modeling complex data structures and are widely adopted to solve many real-world problems, ranging from image generation to anomaly detection [18, 19, 20]. VAEs have also had several successful applications to healthcare data [21]. [22] proposed a VAE with uncertainty-aware attention to impute missing values in electronic health data, and experiments on real-world datasets show that the VAE can capture the complexity of EHR distributions. [14] leveraged the VAE framework to forecast disease states for Parkinson’s disease (PD) and Alzheimer’s disease (AD). Nonetheless, the latent representation learned by a VAE can follow an unrealistic distribution if it is trained without any constraints.
III Methodology
III-A Problem Definition
Let $X \in \mathcal{X}$ and $Y \in \mathcal{Y}$ be the random variables for the input feature space and the label space, respectively. We focus on a clustering problem in which we are given a population of time-series data $\mathcal{D} = \{(\mathbf{x}^{n}_{1:T_n}, y^{n}_{1:T_n})\}_{n=1}^{N}$ consisting of paired sequences of observations and labels for $N$ patients, where $T_n$ is the number of visits of patient $n$. The sequence $t^{n}_{1}, \ldots, t^{n}_{T_n}$ denotes the timestamps at which the observations of patient $n$ are made.
We aim to identify $K$ clusters for the time-series data, each corresponding to a disease stage. Each cluster consists of homogeneous data samples and is represented by its centroid under a chosen similarity measure.
III-B Method
This section presents our proposed framework. We discuss disease progression modeling under both supervised and unsupervised settings, where the problem requires estimating the underlying distribution of all possible disease stages. Such a DPM framework can help doctors identify meaningful characteristics both when a disease has diagnosis labels but possibly unobserved underlying stages, and when a disease has no well-defined diagnosis labels.
The framework consists of three components: the encoder network, the memory network, and the clustering network. For each patient, a recurrent neural network encodes the patient’s information. The memory network maintains long-term information at each timestamp. Specifically, when a hidden representation $\mathbf{h}_t$ is generated from the current and previous observations at timestamp $t$, the memory network reads the hidden state and updates its memory storage. Next, a latent variable $\mathbf{z}_t$ is drawn from a distribution conditioned on the hidden state, which has been processed by the memory network. We then either predict outcomes or reconstruct the current observation, depending on the setting. We take the hidden representation from the last layer of the model for clustering.
III-B1 Encoder Network
The encoder network takes the current observation and the hidden state from the previous timestamp and yields a hidden representation that interacts with the external memory network. Specifically, an LSTM cell is adopted to generate and update the hidden state:
$$\mathbf{h}_t = \mathrm{LSTM}(\mathbf{x}_t, \mathbf{h}_{t-1}) \tag{1}$$
where $\mathbf{x}_t$ is the current observation at timestamp $t$ and $\mathbf{h}_{t-1}$ is the hidden state from the previous step. At each timestamp, the encoder network maps the time-series input $\mathbf{x}_{1:t}$ to a hidden representation $\mathbf{h}_t \in \mathcal{H}$, where $\mathcal{H}$ is the latent representation subspace. The hidden representation then interacts with the external memory network to form a more accurate representation.
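To make Eq. (1) concrete, the following is a minimal, dependency-free sketch of a standard LSTM step and of running it over one patient’s visit sequence. It is an illustration under our own assumptions: the gate layout, the uniform random initialization, and the helper names `lstm_step`, `init_params`, and `encode` are hypothetical, and a PyTorch implementation would more likely rely on `torch.nn.LSTM` or `torch.nn.LSTMCell`.

```python
import math
import random


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def matvec(W, v):
    # W is a list of rows; returns the matrix-vector product W @ v.
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def init_params(d_x, d_h, scale=0.1, seed=0):
    """Hypothetical initializer: one (W, b) pair per gate, W acting on [x_t; h_{t-1}]."""
    rng = random.Random(seed)
    return {
        gate: ([[rng.uniform(-scale, scale) for _ in range(d_x + d_h)] for _ in range(d_h)],
               [0.0] * d_h)
        for gate in ("i", "f", "o", "g")
    }


def lstm_step(x_t, h_prev, c_prev, params):
    """One LSTM step: (x_t, h_{t-1}, c_{t-1}) -> (h_t, c_t)."""
    xh = list(x_t) + list(h_prev)
    pre = {gate: [a + b for a, b in zip(matvec(W, xh), bias)]
           for gate, (W, bias) in params.items()}
    i = [sigmoid(v) for v in pre["i"]]    # input gate
    f = [sigmoid(v) for v in pre["f"]]    # forget gate
    o = [sigmoid(v) for v in pre["o"]]    # output gate
    g = [math.tanh(v) for v in pre["g"]]  # candidate cell update
    c_t = [fv * cv + iv * gv for fv, cv, iv, gv in zip(f, c_prev, i, g)]
    h_t = [ov * math.tanh(cv) for ov, cv in zip(o, c_t)]
    return h_t, c_t


def encode(visits, d_h, params):
    """Run the LSTM over one patient's visits and return h_t for every visit."""
    h, c = [0.0] * d_h, [0.0] * d_h
    hidden_states = []
    for x_t in visits:
        h, c = lstm_step(x_t, h, c, params)
        hidden_states.append(h)
    return hidden_states


# Usage: three visits with 3 features each, hidden size 4.
visits = [[0.2, 1.0, -0.5], [0.1, 0.9, -0.4], [0.3, 1.2, -0.6]]
hidden = encode(visits, d_h=4, params=init_params(d_x=3, d_h=4))
```

Because $\mathbf{h}_t$ is an output gate times a tanh, every entry of each hidden state stays strictly inside $(-1, 1)$.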
III-B2 Memory Network
Long-term information plays an important role in disease progression modeling because, for chronic diseases, past health conditions affect the patient’s current disease stage. In addition, historical information should be stored efficiently so that it can provide useful guidance about the patient’s health state at different timestamps. To this end, we propose an external memory network to capture long-term information throughout the progression modeling process. Our memory network is closely related to [23], which has had several successful applications in natural language processing. Similarly, we define memory slots to represent historical information that can be written and retrieved at any given timestamp. At each timestamp, the hidden state from the encoder network is recorded and written into the memory cells. As a series of observations is processed, the memory network maintains continuous representations of each individual visit, so that a more comprehensive view of the patient can be used during the clustering/staging process.
Memory Reading
We denote a clinical record sequence as $\{\mathbf{x}_1, \ldots, \mathbf{x}_T\}$, where $t$ indexes the timestamps of the record. After receiving a hidden representation $\mathbf{h}_t$ from the encoder network, the memory network produces an external representation $\mathbf{m}_t$ based on the reading weights over the memory slots. Specifically, $\mathbf{m}_t$ can be expressed as:
$$w_t^{i} = \frac{\exp\big(\beta_t \cos(\mathbf{h}_t, \mathbf{M}^{i}_{t-1})\big)}{\sum_{j=1}^{N_m} \exp\big(\beta_t \cos(\mathbf{h}_t, \mathbf{M}^{j}_{t-1})\big)}, \qquad \mathbf{m}_t = \sum_{i=1}^{N_m} w_t^{i}\, \mathbf{M}^{i}_{t-1} \tag{2}$$
where $N_m$ denotes the number of memory slots, $\mathbf{M}^{i}_{t-1}$ is the content of the $i$-th slot, $\mathbf{m}_t \in \mathbb{R}^{d_m}$ is the memory representation with hidden size $d_m$, $\beta_t$ is the read strength learned through the reading operation, and $\cos(\cdot,\cdot)$ is the cosine similarity. The memory reading operation builds on the idea that not all records in the sequence contribute equally to the patient’s current health state. Hence, the weights are computed with a softmax over the cosine similarities between the current hidden state and all previous memories.
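The reading step can be sketched as content-based addressing: score every slot by cosine similarity with the current hidden state, sharpen the scores with the read strength, normalize them with a softmax, and return the weighted sum of slots. This sketch assumes that the hidden state and the memory slots share the same dimension and treats the strength as a scalar; `read_memory` and its arguments are hypothetical names, not the authors’ code.

```python
import math


def cosine(u, v, eps=1e-8):
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v + eps)  # eps guards against all-zero vectors


def softmax(scores):
    top = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def read_memory(h_t, memory, beta):
    """Content-based read: returns (w_t, m_t) with m_t = sum_i w_t[i] * memory[i]."""
    weights = softmax([beta * cosine(h_t, slot) for slot in memory])
    m_t = [sum(w * slot[d] for w, slot in zip(weights, memory))
           for d in range(len(memory[0]))]
    return weights, m_t


# Usage: two orthogonal slots; the hidden state points at the first one.
memory = [[1.0, 0.0], [0.0, 1.0]]
weights, m_t = read_memory([1.0, 0.0], memory, beta=5.0)
```

With `beta=5.0` about 99% of the weight goes to the first slot; setting `beta` to 0 yields uniform weights, so the strength controls how sharply the read concentrates on the most similar memory.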

Memory Writing
Memory writing stores the latent representation in the memory slots. We use a fixed number of slots, $N_m$, as the overall memory size. Each memory slot lies in a continuous space of dimension $d_m$, and we use $d_h$ to denote the dimension of the hidden representation $\mathbf{h}_t$. The hidden state is non-linearly projected into the memory space using a matrix $\mathbf{A} \in \mathbb{R}^{d_m \times d_h}$, yielding $\mathbf{a}_t$, the new input memory representation. Memory writing aims to filter out unrelated information and store only personalized information based on the current hidden state. Mathematically, memory writing can be expressed as:
(3) |
where and are gated vectors that control the information flow between the previous and current memory vector.
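The exact form of the write equation is not recoverable from the text beyond its use of two gate vectors, so the sketch below assumes an update in the style of the Neural Turing Machine’s erase/add write: each slot is partially erased and then receives gated new content, in proportion to its write weight (which might, for example, reuse the read weights). The tanh projection, the erase/add interpretation of the two gates, and all function names are assumptions for illustration only.

```python
import math


def project(A, h_t):
    """a_t = tanh(A h_t): nonlinear projection into memory space (tanh is assumed)."""
    return [math.tanh(sum(a * h for a, h in zip(row, h_t))) for row in A]


def write_memory(memory, weights, erase_gate, add_gate, a_t):
    """Assumed NTM-style gated write, applied to every slot i:

        M_t[i] = M_{t-1}[i] * (1 - w[i] * e_t) + w[i] * (u_t * a_t)

    erase_gate (e_t) and add_gate (u_t) have entries in [0, 1], e.g. sigmoid
    outputs computed from h_t. Returns a new memory; the input is not mutated.
    """
    new_memory = []
    for w, slot in zip(weights, memory):
        kept = [m * (1.0 - w * e) for m, e in zip(slot, erase_gate)]
        added = [w * u * a for u, a in zip(add_gate, a_t)]
        new_memory.append([k + d for k, d in zip(kept, added)])
    return new_memory


# Usage: write entirely into slot 0 with fully open gates.
memory = [[1.0, 2.0], [3.0, 4.0]]
updated = write_memory(memory, weights=[1.0, 0.0], erase_gate=[1.0, 1.0],
                       add_gate=[1.0, 1.0], a_t=[5.0, 6.0])
```

Here slot 0 is fully overwritten with `a_t`, while slot 1, whose write weight is 0, keeps its previous content.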
III-B3 Clustering Network
After obtaining the representation of the observation through the encoder network (i.e., the prior network) and updating the memory cell at the current timestamp $t$, we follow the standard variational autoencoder (VAE) framework [24] to compute the mean and standard deviation vectors through the posterior network. We assume that the output follows a Gaussian distribution. The computation can be expressed as:
$$\boldsymbol{\mu}_t = \varphi_{\mu}(\mathbf{h}_t, \mathbf{x}_t), \qquad \boldsymbol{\sigma}_t = \varphi_{\sigma}(\mathbf{h}_t, \mathbf{x}_t) \tag{4}$$
where $\mathbf{h}_t$ is the hidden state, $\mathbf{x}_t$ is the observation at timestamp $t$, and $\varphi_{\mu}$ and $\varphi_{\sigma}$ are posterior functions parameterized by feed-forward neural networks. We then draw samples from the posterior Gaussian distribution using the reparameterization trick:
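$$\mathbf{z}_t = \boldsymbol{\mu}_t + \boldsymbol{\sigma}_t \odot \boldsymbol{\epsilon} \tag{5}$$
where $\boldsymbol{\mu}_t$ and $\boldsymbol{\sigma}_t$ are the posterior mean and standard deviation, $\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$, $\mathbf{z}_t$ is the latent representation, and $\odot$ denotes element-wise multiplication. The reparameterization trick allows the gradient to backpropagate through the sampling process. Finally, depending on the availability of diagnosis labels, the clustering network is trained with one of two objectives. When diagnosis labels are used, the clustering network is trained directly to predict the label information: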
(6) |
where the prediction network is a feed-forward network that outputs the probability of each label. When diagnosis labels are not available, we train the framework to reconstruct the observation from the latent variable $\mathbf{z}_t$ conditioned on the memory state $\mathbf{m}_t$:
$$\hat{\mathbf{x}}_t = f_x\big([\mathbf{z}_t \,\|\, \mathbf{m}_t]\big) \tag{7}$$
where $\hat{\mathbf{x}}_t$ is the reconstructed input, $f_x$ is a feed-forward network, and $[\cdot \,\|\, \cdot]$ denotes concatenation. In the clustering phase, we apply Euclidean-distance-based k-means to the latent variable $\mathbf{z}_t$.
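The sampling step and the clustering phase can be sketched together: draw each latent vector with the reparameterization trick, then run Euclidean k-means (Lloyd’s algorithm) over the collected latent vectors. This is a plain-Python illustration under our own assumptions (initial centroids drawn from the data points, a fixed iteration cap, hypothetical helper names); in practice one would more likely use `sklearn.cluster.KMeans`. The text does not say whether clustering uses a sampled $\mathbf{z}_t$ or the posterior mean $\boldsymbol{\mu}_t$; passing `sigma` as all zeros recovers the mean.

```python
import random


def sample_latent(mu, sigma, rng):
    """Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)."""
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)]


def sq_dist(u, v):
    return sum((a - b) ** 2 for a, b in zip(u, v))


def kmeans(points, k, max_iters=100, seed=0):
    """Lloyd's algorithm with squared Euclidean distance; returns (labels, centroids)."""
    rng = random.Random(seed)
    centroids = [list(p) for p in rng.sample(points, k)]  # init from k data points
    labels = None
    for _ in range(max_iters):
        # Assignment step: each point joins its nearest centroid.
        new_labels = [min(range(k), key=lambda j: sq_dist(p, centroids[j])) for p in points]
        if new_labels == labels:  # converged: assignments no longer change
            break
        labels = new_labels
        # Update step: move each centroid to the mean of its members.
        for j in range(k):
            members = [p for p, lab in zip(points, labels) if lab == j]
            if members:  # keep the previous centroid if a cluster becomes empty
                centroids[j] = [sum(col) / len(members) for col in zip(*members)]
    return labels, centroids


# Usage: sample latent vectors around two posterior means, then cluster them.
rng = random.Random(42)
latents = [sample_latent([0.0, 0.0], [0.1, 0.1], rng) for _ in range(20)]
latents += [sample_latent([5.0, 5.0], [0.1, 0.1], rng) for _ in range(20)]
labels, centroids = kmeans(latents, k=2)
```

Lloyd’s algorithm only finds a local optimum, which is why library implementations typically use k-means++ seeding and several restarts.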
III-B4 Dual Memory Network Architecture
In clinical practice, doctors often assign diagnosis labels based on patients’ current and past medical events. Such information can be a target health condition or a diagnosis. In the supervised setting, when labels are available during training, we further use a patient-level memory network to capture diagnosis information at each visit. Unlike the global-level memory network, the patient-level memory network at timestamp $t$ can only access diagnoses up to the previous timestamp, namely $y_{1:t-1}$. The patient-level memory network only reads and writes diagnosis information, which is later combined with the global-level memory network for clustering. We propose a calibration process to integrate the representations from the two memory networks, as follows:
(8) | ||||
where and is the global-level and patient-level memory network respectively. This memory calibration process can be regarded as a point-wise attention mechanism.
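Since the calibration is only described as a point-wise attention mechanism, the sketch below shows one plausible form under explicit assumptions: an element-wise sigmoid gate, computed from both memory representations, decides per dimension how much to take from the global-level versus the patient-level memory. The per-element gate parameters `w_global`, `w_patient`, and `bias` are hypothetical; a full implementation would likely use dense weight matrices.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def calibrate(m_global, m_patient, w_global, w_patient, bias):
    """Assumed point-wise gated fusion of two memory representations.

    alpha = sigmoid(w_global * m_global + w_patient * m_patient + bias)  (element-wise)
    m_t   = alpha * m_global + (1 - alpha) * m_patient
    """
    alpha = [sigmoid(wg * g + wp * p + b)
             for wg, g, wp, p, b in zip(w_global, m_global, w_patient, m_patient, bias)]
    fused = [a * g + (1.0 - a) * p for a, g, p in zip(alpha, m_global, m_patient)]
    return alpha, fused


# Usage: with zero gate parameters, alpha = 0.5 and the result is the average.
alpha, fused = calibrate([2.0, 4.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0])
```

A convex combination of this kind keeps each fused entry between the corresponding global-level and patient-level values, so neither memory can be amplified beyond its own scale.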
III-C Objective Function and Optimization
Here, we present our training objectives and optimization process. As mentioned in previous sections, the entire network can be trained end-to-end using maximum likelihood estimation (MLE). To handle the intractable marginalization over the latent variable $\mathbf{z}$, we use a variational lower bound parameterized by $\phi$ to approximate the true posterior, which we assume to be Gaussian. After the memory network reading and writing, we use the latent variable $\mathbf{z}_t$ at timestamp $t$ to identify the disease stages. We restrict the latent variable to a multivariate Gaussian distribution, which enforces the same form on the posterior. We learn the generative parameters $\theta$ using MLE:
$$\theta^{*} = \arg\max_{\theta}\, \log p_{\theta}(\mathbf{x}) = \arg\max_{\theta}\, \log \int p_{\theta}(\mathbf{x} \mid \mathbf{z})\, p(\mathbf{z})\, d\mathbf{z} \tag{9}$$
However, the marginalization over $\mathbf{z}$ is intractable when $p_{\theta}(\mathbf{x} \mid \mathbf{z})$ is a complicated function (for instance, a neural network). We therefore derive a variational lower bound (i.e., apply the variational Bayesian method) to approximate the log marginal probability of the observation:
$$\log p_{\theta}(\mathbf{x}) = \log \mathbb{E}_{q_{\phi}(\mathbf{z}\mid\mathbf{x})}\!\left[\frac{p_{\theta}(\mathbf{x}\mid\mathbf{z})\, p(\mathbf{z})}{q_{\phi}(\mathbf{z}\mid\mathbf{x})}\right] \geq \mathbb{E}_{q_{\phi}(\mathbf{z}\mid\mathbf{x})}\big[\log p_{\theta}(\mathbf{x}\mid\mathbf{z})\big] - \mathrm{KL}\big(q_{\phi}(\mathbf{z}\mid\mathbf{x}) \,\|\, p(\mathbf{z})\big) = \mathcal{L}_{\mathrm{ELBO}}(\theta, \phi; \mathbf{x}) \tag{10}$$
where the inequality follows from Jensen’s inequality, and the variational lower bound involves the approximate posterior $q_{\phi}(\mathbf{z}\mid\mathbf{x})$ parameterized by $\phi$, which approximates the intractable true posterior $p_{\theta}(\mathbf{z}\mid\mathbf{x})$. Since health-related data are often high-dimensional and follow complicated distributions, we introduce the latent variable $\mathbf{z}$ to capture the internal stochasticity of the data. The entire clustering network can be trained end-to-end using stochastic optimization techniques. The gap between the log-likelihood and the variational lower bound is the KL divergence between the approximate and true posteriors:
$$\log p_{\theta}(\mathbf{x}) - \mathcal{L}_{\mathrm{ELBO}}(\theta, \phi; \mathbf{x}) = \mathrm{KL}\big(q_{\phi}(\mathbf{z}\mid\mathbf{x}) \,\|\, p_{\theta}(\mathbf{z}\mid\mathbf{x})\big) \geq 0 \tag{11}$$
where $\theta$ and $\phi$ denote the model parameters and the parameters of the proxy posterior, respectively. Equality holds when $q_{\phi}(\mathbf{z}\mid\mathbf{x})$ equals the true posterior. When diagnosis labels are used during training, we use the cross-entropy loss to directly predict the outcome from the combined latent representation:
(12) | ||||
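When the model is trained in an unsupervised manner, the overall objective, combined with the reconstruction loss, becomes:
$$\mathcal{L}_{\mathrm{unsup}} = \mathrm{MSE}(\mathbf{x}, \hat{\mathbf{x}}) + \beta\, \mathrm{KL}\big(q_{\phi}(\mathbf{z}\mid\mathbf{x}) \,\|\, p(\mathbf{z})\big) \tag{13}$$
where we use the mean squared error (MSE) as the reconstruction loss and $\beta$ is a hyperparameter that prevents the VAE from suffering from KL vanishing. We adopt a linear annealing schedule for $\beta$ based on training steps: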
(14) |
where is a threshold value. Last but not least, we use the k-means algorithm [25] on the patient representation to perform clustering.
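The unsupervised objective can be sketched for a single visit as an MSE reconstruction term plus an annealed KL term, using the closed-form KL divergence between diagonal Gaussians. The standard-normal default prior and the schedule `min(1, step / threshold)` for the KL weight are our assumptions about the unrecoverable annealing equation; the function names are hypothetical.

```python
import math


def gaussian_kl(mu_q, sigma_q, mu_p=None, sigma_p=None):
    """KL(N(mu_q, diag sigma_q^2) || N(mu_p, diag sigma_p^2)), summed over dimensions.

    Defaults to a standard-normal prior N(0, I).
    """
    mu_p = [0.0] * len(mu_q) if mu_p is None else mu_p
    sigma_p = [1.0] * len(mu_q) if sigma_p is None else sigma_p
    return sum(
        math.log(sp / sq) + (sq ** 2 + (mq - mp) ** 2) / (2.0 * sp ** 2) - 0.5
        for mq, sq, mp, sp in zip(mu_q, sigma_q, mu_p, sigma_p)
    )


def mse(x, x_hat):
    return sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)


def kl_weight(step, threshold):
    """Assumed linear annealing: grows from 0 to 1 over `threshold` steps, then stays at 1."""
    return min(1.0, step / threshold)


def unsupervised_loss(x, x_hat, mu, sigma, step, threshold):
    """Reconstruction (MSE) plus annealed KL term for a single visit."""
    return mse(x, x_hat) + kl_weight(step, threshold) * gaussian_kl(mu, sigma)


# Usage: perfect reconstruction, posterior N(1, 1), halfway through annealing.
loss = unsupervised_loss([1.0, 2.0], [1.0, 2.0], mu=[1.0], sigma=[1.0], step=50, threshold=100)
```

Starting the KL weight near 0 lets the decoder first learn to use $\mathbf{z}$ before the KL term pulls the posterior toward the prior, which is the usual rationale for annealing against KL vanishing.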
TABLE I: Clustering performance on the ADNI dataset.
Model | Purity (w/o label) | NMI (w/o label) | RI (w/o label) | Purity (with label) | NMI (with label) | RI (with label) |
---|---|---|---|---|---|---|
RNN | 0.6799 ± 0.00 | 0.1415 ± 0.01 | 0.1406 ± 0.02 | 0.8532 ± 0.00 | 0.4020 ± 0.01 | 0.3805 ± 0.01 |
Bi-LSTM | 0.6810 ± 0.02 | 0.1540 ± 0.02 | 0.1559 ± 0.02 | 0.8674 ± 0.00 | 0.4092 ± 0.01 | 0.4042 ± 0.02 |
RETAIN | 0.6903 ± 0.02 | 0.1787 ± 0.01 | 0.1671 ± 0.01 | 0.7144 ± 0.02 | 0.2572 ± 0.01 | 0.1838 ± 0.03 |
Dipole | 0.6839 ± 0.00 | 0.1707 ± 0.01 | 0.1452 ± 0.00 | 0.8904 ± 0.01 | 0.4674 ± 0.01 | 0.4776 ± 0.02 |
StageNet | 0.6943 ± 0.01 | 0.2002 ± 0.01 | 0.1791 ± 0.01 | 0.8513 ± 0.01 | 0.4045 ± 0.03 | 0.3744 ± 0.01 |
AC-TPC | - | - | - | 0.8214 ± 0.03 | 0.3362 ± 0.07 | 0.3827 ± 0.09 |
VAE | 0.6651 ± 0.02 | 0.1023 ± 0.02 | 0.1117 ± 0.02 | 0.6495 ± 0.04 | 0.1718 ± 0.05 | 0.1042 ± 0.04 |
Memory Network | 0.6887 ± 0.02 | 0.1392 ± 0.01 | 0.1584 ± 0.02 | 0.8262 ± 0.01 | 0.3603 ± 0.01 | 0.3538 ± 0.02 |
TC-EMNet (unsupervised) | 0.7040 ± 0.01 | 0.1967 ± 0.02 | 0.1891 ± 0.02 | 0.8904 ± 0.00 | 0.4679 ± 0.01 | 0.4889 ± 0.01 |
TC-EMNet (supervised) | - | - | - | 0.9126 ± 0.01 | 0.4789 ± 0.01 | 0.4923 ± 0.02 |
TABLE II: Clustering performance on the PPMI dataset.
Model | Purity (w/o label) | NMI (w/o label) | RI (w/o label) | Purity (with label) | NMI (with label) | RI (with label) |
---|---|---|---|---|---|---|
RNN | 0.7221 ± 0.00 | 0.3089 ± 0.01 | 0.3120 ± 0.01 | 0.7640 ± 0.02 | 0.4222 ± 0.04 | 0.3663 ± 0.03 |
Bi-LSTM | 0.7264 ± 0.00 | 0.3170 ± 0.00 | 0.2976 ± 0.01 | 0.7674 ± 0.03 | 0.4456 ± 0.05 | 0.3575 ± 0.05 |
RETAIN | 0.5241 ± 0.02 | 0.1188 ± 0.01 | 0.0619 ± 0.01 | 0.7510 ± 0.01 | 0.4072 ± 0.03 | 0.3361 ± 0.01 |
Dipole | 0.7233 ± 0.00 | 0.3200 ± 0.00 | 0.3153 ± 0.00 | 0.8033 ± 0.01 | 0.4947 ± 0.01 | 0.4476 ± 0.02 |
StageNet | 0.7252 ± 0.01 | 0.3305 ± 0.00 | 0.3234 ± 0.01 | 0.7839 ± 0.01 | 0.4700 ± 0.03 | 0.3840 ± 0.01 |
AC-TPC | - | - | - | 0.8151 ± 0.01 | 0.4984 ± 0.03 | 0.5129 ± 0.01 |
VAE | 0.7161 ± 0.00 | 0.3576 ± 0.01 | 0.3153 ± 0.00 | 0.7942 ± 0.01 | 0.4452 ± 0.00 | 0.3782 ± 0.01 |
Memory Network | 0.6996 ± 0.01 | 0.2809 ± 0.01 | 0.2581 ± 0.02 | 0.7689 ± 0.01 | 0.4482 ± 0.01 | 0.4597 ± 0.01 |
TC-EMNet (unsupervised) | 0.7452 ± 0.00 | 0.3773 ± 0.00 | 0.3742 ± 0.01 | 0.8256 ± 0.00 | 0.5053 ± 0.00 | 0.4823 ± 0.01 |
TC-EMNet (supervised) | - | - | - | 0.8339 ± 0.00 | 0.5035 ± 0.00 | 0.4993 ± 0.01 |
IV Experiments
We evaluate our proposed model on two real-world datasets: the Alzheimer’s Disease Neuroimaging Initiative (ADNI) and the Parkinson’s Progression Markers Initiative (PPMI). Both datasets can be accessed on the IDA website (https://ida.loni.usc.edu/). The code is available on GitHub (https://github.com/Ericzhang1/TC-EMNet.git).
IV-A Datasets
IV-A1 ADNI Dataset
Alzheimer’s disease (AD) is a chronic neurodegenerative disease that is often associated with behavioral and cognitive impairment. ADNI is a longitudinal study that aims to explore early detection and tracking of AD based on imaging, biomarker, and genetic data collected throughout the study [26]. The dataset consists of 11651 visits from 1346 patients at 6-month intervals. For each patient, 21 variables are collected and processed, including 16 time-varying features (brain function, cognitive tests) and 5 static features (background, demographics). At each visit, doctors assign one of three diagnosis labels, cognitively normal (CN), mild cognitive impairment (MCI), or AD, indicating how far AD symptoms have progressed in the patient.
IV-A2 PPMI Dataset
The Parkinson’s Progression Markers Initiative (PPMI) is a longitudinal study that aims to evaluate patients’ progression in Parkinson’s disease (PD) based on biomarkers [27]. The dataset consists of 13685 visits from 2145 patients with irregular time intervals. For each patient, 79 features covering motor and non-motor symptoms are collected, including cognitive assessments, lab tests, demographic information, and biospecimens. Since the dataset does not provide a per-visit diagnosis label for each patient, we use Hoehn and Yahr (HY) scores as labels for evaluation. HY scores range from 0 to 5 and indicate the severity of a patient’s Parkinson’s disease symptoms. We impute missing values using mean imputation combined with last observation carried forward (LOCF).
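The imputation step can be sketched as follows. The text does not specify how the two methods are combined, so this sketch assumes that LOCF is applied first within each patient and that dataset-level feature means fill whatever remains missing (for example, values before a feature’s first observation). Missing values are represented as `None`, and the function names are hypothetical.

```python
def feature_means(patients):
    """Per-feature mean over all observed (non-None) values in the dataset."""
    n_features = len(patients[0][0])
    sums, counts = [0.0] * n_features, [0] * n_features
    for visits in patients:
        for visit in visits:
            for j, value in enumerate(visit):
                if value is not None:
                    sums[j] += value
                    counts[j] += 1
    # Features never observed anywhere default to 0.0 (an assumption).
    return [s / c if c else 0.0 for s, c in zip(sums, counts)]


def impute_patient(visits, means):
    """LOCF within one patient, falling back to the feature mean if nothing was observed yet."""
    last_seen = [None] * len(means)
    filled = []
    for visit in visits:
        row = []
        for j, value in enumerate(visit):
            if value is not None:
                last_seen[j] = value  # carry this observation forward
            row.append(last_seen[j] if last_seen[j] is not None else means[j])
        filled.append(row)
    return filled


# Usage: two patients, two features, None marks a missing value.
patients = [
    [[None, 1.0], [2.0, None], [None, None]],
    [[4.0, 3.0]],
]
means = feature_means(patients)
imputed = [impute_patient(visits, means) for visits in patients]
```

To avoid leaking information across splits, the feature means would normally be computed on the training set only.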
IV-B Baselines
We compare our proposed model with several state-of-the-art methods, ranging from vanilla RNNs to multi-layer attention models. Since we consider disease progression modeling under both supervised and unsupervised settings, we adapt the architecture of each baseline to fit the corresponding objective. For baselines that cannot be adapted to a given setting, we do not report results under that setting. For all experiments, we apply k-means clustering to the hidden representations from the last layer to report clustering performance.
• RNN [28]: A single RNN cell followed by an additional feed-forward layer. The model is trained with the cross-entropy loss or the reconstruction objective, depending on the setting.
• Bi-LSTM [29]: Similar to the RNN model, but using a bidirectional LSTM with a reconstruction objective. The model takes both directions of the sequence into account and has been shown to capture richer information than a unidirectional model.
• RETAIN [30]: An interpretable deep learning model based on a recurrent neural network and a reverse-time attention mechanism. RETAIN learns the importance of hospital records through attention weights. We modify the last layer of RETAIN and train it with the prediction and reconstruction objectives.
• Dipole [31]: An interpretable bidirectional recurrent neural network that employs an attention mechanism to leverage both past and future visits. We use the concatenation-based attention mechanism and, as with RETAIN, adjust the last layer of the model accordingly.
• StageNet [5]: A recent risk prediction model that learns to extract disease progression patterns during training using a modified LSTM cell with an attention mechanism. The progression pattern at each timestamp is re-calibrated using a convolutional network.
• AC-TPC [4]: A recent deep predictive clustering network consisting of an encoder, a selector, and a predictor. The model is first initialized with a prediction objective and then optimized to learn cluster embeddings using an actor-critic algorithm. This method cannot be trained without label information.
• VAE [32]: A vanilla variational autoencoder with an LSTM encoder, trained with the prediction and variational objectives, respectively. This baseline also serves as an ablation of our proposed method.
• Memory Network: A vanilla global-level memory network with the reading and writing mechanisms described previously. The network reads and writes the EHR sequence directly, and the k-means algorithm is applied to the hidden memory representation.
• TC-EMNet (unsupervised): The unsupervised version of TC-EMNet. When training labels are not available, only a global-level memory network is used to produce the memory representation. We also train this model on the prediction task as an ablation against the supervised version of TC-EMNet.
• TC-EMNet (supervised): The supervised version of TC-EMNet. When training labels are available, a patient-level memory network is combined with the global-level memory network to produce target-aware memory representations.
TABLE III: Hyperparameter search space.
Hyperparameter | Range |
---|---|
hidden size | |
latent variable size | |
x | |
learning rate | |
batch size |








ADNI Dataset | Features |
---|---|
Cluster I | RAVLT_learning, Ventricles, WholeBrain, ICV, RAVLT_perc_forgetting, RAVLT_forgetting, ADAS13, RAVLT_immediate |
Cluster II | ICV, RAVLT_perc_forgetting, ADAS13, Ventricles, serial, RAVLT_immediate, CDRSB |
Cluster III | RAVLT_perc_forgetting, serial, ICV, RAVLT_learning, Entorhinal, Hippocampus, Ventricles, WholeBrain |

PPMI Dataset | Features |
---|---|
Cluster I | Global Spontaneity of Movement, Speech, Anxious Mood, Arising from Chair, Right leg, Getting Out of Bed, Pronation-Supination (left) |
Cluster II | Posture, Rest tremor amplitude, Dopamine, Rigidity, Saliva + Drooling, Anxious Mood, Global Spontaneity of Movement |
Cluster III | Postural Stability, Cognitive Impairment, Rest Tremor Amplitude, Pronation-Supination (left), Dopamine, Standing, Rigidity |
Cluster IV | Pronation-Supination (left), Standing, Postural Stability, Chewing, Cognitive Impairment, Dopamine, Right Hand |
Cluster V | Dopamine, Cognitive Impairment, Hallucinations, Chewing, Dressing, Pronation-Supination (left), Arising from Chair |
Cluster VI | Rigidity, Serial, Rigidity, Standing, Apathy, Constipation Problems, Cognitive Impairment, Dopamine |
TABLE V: Number of trainable parameters.
Model | # of trainable parameters |
---|---|
Dipole | 279k |
StageNet | 283k |
AC-TPC | 143k |
TC-EMNet (unsupervised) | 163k |
TC-EMNet (supervised) | 174k |
IV-C Model Training and Implementation Details
As mentioned previously, our proposed network is continuous and differentiable, so it can be trained with stochastic optimization techniques. Apart from the recurrent encoder, all neural networks in the proposed framework are feed-forward networks. We implemented our solution in PyTorch [33] and trained the model on a single NVIDIA Volta V100 GPU with 16GB of memory. We use gradient accumulation to handle out-of-memory problems. We select hyperparameters through random search over the space shown in Table III. For our model, we set both the hidden size and the latent variable size to 128. We adopt the Adam optimizer with a learning rate of . The model is trained with batch size for epochs. is set to . We split the dataset into training, validation, and testing sets with a ratio of and report the performance of fold cross-validation on both datasets. A detailed description of the optimization process of our proposed framework can be found in Algorithm 1. The average running time of our framework on both datasets is about 2 hours for cross-validation. For the baselines, we implement the RNN and Bi-LSTM methods in PyTorch, adopt the PyHealth [34] implementations of RETAIN, Dipole, and StageNet, and adopt the implementation from [4] for AC-TPC. All baseline methods share the same hyperparameter search space.
IV-D Evaluation Metrics
To evaluate clustering performance, we use the purity score (purity), normalized mutual information (NMI) [35], and the adjusted Rand index (ARI) [36]. The purity score ranges from 0 to 1 and indicates the extent to which each cluster consists of a single class. NMI ranges from 0 to 1 and measures the mutual information between the cluster assignments and the true labels, with 1 indicating perfect clustering. ARI is derived from the Rand index (RI), which measures the fraction of sample pairs on which the clustering and the ground truth agree; ARI corrects RI for the agreement expected by chance. Mathematically, the metrics can be expressed as follows:
(15) | ||||
where is the total number of samples, and denotes the cluster assignment and true label respectively, is the mutual information function and is the entropy, and RI are the expectation value and Rand index accordingly.
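The three metrics in Eq. (15) can all be computed from the contingency counts between cluster assignments and labels. The following is a minimal, dependency-free sketch rather than the evaluation code used in the paper; the function names are hypothetical, NMI uses the arithmetic-mean normalization $2I/(H(y)+H(c))$ (one of several variants discussed in [35]), and ARI follows the pair-counting form of [36].

```python
from collections import Counter
from math import comb, log


def purity(clusters, labels):
    """Share of samples that carry the majority label of their own cluster."""
    per_cluster = {}
    for c, y in zip(clusters, labels):
        per_cluster.setdefault(c, Counter())[y] += 1
    return sum(max(cnt.values()) for cnt in per_cluster.values()) / len(labels)


def _entropy(counts, n):
    """Shannon entropy (natural log) of a partition given its group sizes."""
    return -sum((k / n) * log(k / n) for k in counts if k > 0)


def nmi(clusters, labels):
    """Mutual information normalized by the mean entropy: 2 I(y; c) / (H(y) + H(c))."""
    n = len(labels)
    joint = Counter(zip(clusters, labels))
    size_c, size_y = Counter(clusters), Counter(labels)
    mi = sum((nij / n) * log(n * nij / (size_c[c] * size_y[y]))
             for (c, y), nij in joint.items())
    h_c, h_y = _entropy(size_c.values(), n), _entropy(size_y.values(), n)
    if h_c + h_y == 0:  # both partitions are a single group, hence identical
        return 1.0
    return 2 * mi / (h_c + h_y)


def ari(clusters, labels):
    """Adjusted Rand index via pair counts on the contingency table (needs n >= 2)."""
    n = len(labels)
    index = sum(comb(nij, 2) for nij in Counter(zip(clusters, labels)).values())
    sum_a = sum(comb(a, 2) for a in Counter(clusters).values())
    sum_b = sum(comb(b, 2) for b in Counter(labels).values())
    expected = sum_a * sum_b / comb(n, 2)
    max_index = (sum_a + sum_b) / 2
    if max_index == expected:  # both all-singletons or both one cluster: identical
        return 1.0
    return (index - expected) / (max_index - expected)
```

On the toy input `ari([0, 0, 1, 2], [0, 0, 1, 1])` the sketch returns 4/7 (about 0.571), while `purity` returns 1.0 for the same input: purity rewards splitting clusters further, which is one reason NMI and ARI are reported alongside it. If scikit-learn is installed, `normalized_mutual_info_score` (default `average_method="arithmetic"`) and `adjusted_rand_score` from `sklearn.metrics` should give matching values.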
V Results
V-A Clustering Performance
A quantitative comparison of the clustering performance on the ADNI and PPMI datasets is shown in table I and table II, respectively. We set the number of clusters to the number of classes/diagnoses in each dataset, i.e., for ADNI (diagnosis label) and (NHY score) for PPMI. We want the model to identify individual disease stages both when knowledge about a disease is limited (i.e., class/diagnosis labels are unavailable) and when diagnosis labels are available, and thereby to provide insightful, interpretable information that helps tailor treatment to individual patients. We compare our proposed method with the aforementioned baselines in terms of clustering performance. Our method demonstrates competitive performance against all baseline methods across all evaluation metrics on both datasets. We note that it is generally difficult to identify clusters without label information, as indicated by the low NMI and ARI scores. However, when clustering with labels, TC-EMNet outperforms the baselines by a large margin in terms of NMI and ARI. Training in the supervised setting yields significantly better clustering performance than training in the unsupervised setting, because the correlation between diagnosis and input features is encoded into each hidden representation. Although AC-TPC achieves a better ARI on the PPMI dataset, it relies on pre-training the model for over epochs, which could cause the model to memorize the input data. Dipole and StageNet also perform comparably; however, both models apply attention over multi-layer RNNs, which adds complexity to the model. A detailed comparison of the number of trainable parameters is shown in table V. Furthermore, we find that training with label information can lower the ARI score compared to training without labels.
This phenomenon is observed for multiple baseline methods. One explanation is that directly leveraging label information can overwhelm the training process: labels have much stronger predictive power than the input features, which biases the model toward the dominant class on imbalanced datasets, so ARI may drop as false positives and false negatives increase. The results also suggest that the external memory effectively captures long-term information and that TC-EMNet is capable of learning the internal complexity of the input data. The patient-level memory network complements the global-level memory network to produce more comprehensive memory representations.
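The imbalance explanation above can be made concrete with a hypothetical toy case (not drawn from ADNI or PPMI): $N = 10$ patients, 8 with label A and 2 with label B, and a model biased so strongly toward the dominant class that it places every patient in a single cluster. Using the pair-counting form of ARI from [36], with $n_{ij}$ the contingency counts, $a_i$ the cluster sizes, and $b_j$ the class sizes:

$$
\begin{aligned}
\sum_{ij}\binom{n_{ij}}{2} &= \binom{8}{2} + \binom{2}{2} = 28 + 1 = 29, \\
\sum_{i}\binom{a_i}{2} &= \binom{10}{2} = 45, \qquad \sum_{j}\binom{b_j}{2} = 29, \\
\mathbb{E}[\text{index}] &= \frac{45 \cdot 29}{\binom{10}{2}} = 29, \qquad \max(\text{index}) = \frac{45 + 29}{2} = 37, \\
\text{ARI} &= \frac{29 - 29}{37 - 29} = 0, \qquad \text{Purity} = \frac{8}{10} = 0.8 .
\end{aligned}
$$

Purity stays high because the single cluster coincides with the majority class, while ARI falls to chance level, which is the kind of drop a dominant-class bias can produce on imbalanced data.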


V-B Disease Stage
To interpret the disease stages and progression patterns found by TC-EMNet, we first selected three baseline models whose performance is comparable to that of TC-EMNet and visualized the hidden representations in space using PCA [37]. The results are shown in Fig. 3. We observe that most methods produce distinct clusters on the ADNI dataset. On the PPMI dataset, however, most baseline methods fail to produce effective clusters, whereas TC-EMNet produces distinct clusters. This suggests that TC-EMNet effectively models long-term information across visits to find effective representations. Next, we compute the feature importance for every cluster based on the weights of the last layer of the network; the results are shown in table IV. For both datasets, each cluster is determined by a diverse range of features, which makes it easier to identify each patient’s progression pattern through observation. We also compute the centroid values of each cluster and plot their distributions in Fig. 4 and Fig. 5 for the ADNI and PPMI datasets, respectively. For the ADNI dataset, our model identifies significant features such as RAVLT_learning, RAVLT_perc_forgetting, ICV, and ventricles. Rey’s Auditory Verbal Learning Test (RAVLT) scores are helpful for testing episodic memory and are important indicators of a patient’s progression in Alzheimer’s disease [38]. In particular, the learning test (RAVLT_learning) and the percent forgetting test (RAVLT_perc_forgetting) are highly correlated and are therefore crucial biomarkers for the early detection of AD. Fig. 4 shows that the three clusters produced by our model differ widely in their RAVLT values, which suggests three different patient subtypes. For the PPMI dataset, our model finds that dopamine dysregulation syndrome (Dopamine) is a significant feature for identifying clusters.
Studies have found that, in clinical settings, early characterization of dopamine dysregulation syndrome can aid the treatment of motor and non-motor complications of Parkinson’s disease [39]. Other studies have shown that cognitive impairment (Cognitive impairment) is a strong indicator of Parkinson’s disease, and differences in cognitive impairment scores can reflect advanced progression in PD [40].
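The per-cluster centroid analysis described in this section can be sketched as follows. This is a hypothetical, dependency-free illustration, not the authors' code: `cluster_centroids` averages each feature within a cluster, and `rank_by_centroid_spread` orders features by the gap between their largest and smallest cluster centroid. That ranking is only a simple heuristic for spotting features (such as RAVLT scores) that separate clusters; it is not the last-layer-weight feature importance reported in table IV. Features are z-scored first because raw clinical variables (for example ICV versus test scores) live on very different scales.

```python
from collections import defaultdict
from statistics import mean, pstdev


def standardize(features):
    """Z-score each feature column so centroids in different units are comparable."""
    cols = list(zip(*features))
    mus = [mean(col) for col in cols]
    # Guard against constant columns, whose population std is 0.
    sds = [pstdev(col) or 1.0 for col in cols]
    return [[(v - m) / s for v, m, s in zip(row, mus, sds)] for row in features]


def cluster_centroids(features, assignments, feature_names):
    """Mean value of every feature within each cluster: {cluster: {feature: mean}}."""
    sums = defaultdict(lambda: [0.0] * len(feature_names))
    counts = defaultdict(int)
    for row, cluster in zip(features, assignments):
        counts[cluster] += 1
        for j, value in enumerate(row):
            sums[cluster][j] += value
    return {
        cluster: {name: sums[cluster][j] / counts[cluster]
                  for j, name in enumerate(feature_names)}
        for cluster in sorted(counts)
    }


def rank_by_centroid_spread(centroids):
    """Order features by how far apart their cluster centroids are (max - min)."""
    per_cluster = list(centroids.values())
    names = list(per_cluster[0])
    spread = {n: max(c[n] for c in per_cluster) - min(c[n] for c in per_cluster)
              for n in names}
    return sorted(names, key=lambda n: spread[n], reverse=True)
```

In practice `features` would hold one row of (hypothetically) per-patient feature values and `assignments` the k-means cluster index of each patient's representation; the centroids returned by `cluster_centroids` on the unstandardized data are what a plot like Fig. 4 or Fig. 5 would display.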
VI Conclusion
In this paper, we propose TC-EMNet for disease progression modeling on time-series data. TC-EMNet leverages a VAE to model data irregularity and an external memory network to capture long-term dependency. We developed TC-EMNet to perform patient clustering/subtyping under both supervised and unsupervised settings. In the supervised setting, TC-EMNet leverages a dual memory network architecture to extract target-aware information from diagnoses to compute patient representations. Through experiments on two real-world datasets, we showed that our model outperforms state-of-the-art methods and is able to identify interpretable, clinically meaningful disease stages. TC-EMNet achieves competitive clustering performance with limited model complexity. In real-world clinical settings, we hope that our model can help physicians identify patients’ progression patterns and discover potential disease stages, leading to a better understanding of chronic and other heterogeneous diseases.
Acknowledgment
This paper was funded in part by the National Science Foundation under award number CBET-2037398.
References
- [1] A. A. Kehagia, R. A. Barker, and T. W. Robbins, “Neuropsychological and clinical heterogeneity of cognitive impairment and dementia in patients with Parkinson’s disease,” The Lancet Neurology, vol. 9, no. 12, pp. 1200–1213, 2010.
- [2] M. Ferrer, J. Alonso, J. Morera, R. M. Marrades, A. Khalaf, M. C. Aguar, V. Plaza, L. Prieto, and J. M. Anto, “Chronic obstructive pulmonary disease stage and health-related quality of life,” Annals of Internal Medicine, vol. 127, no. 12, pp. 1072–1079, 1997.
- [3] S. Auer and B. Reisberg, “The GDS/FAST staging system,” International Psychogeriatrics, vol. 9, no. S1, pp. 167–171, 1997.
- [4] C. Lee and M. Van Der Schaar, “Temporal phenotyping using deep predictive clustering of disease progression,” in International Conference on Machine Learning. PMLR, 2020, pp. 5767–5777.
- [5] J. Gao, C. Xiao, Y. Wang, W. Tang, L. M. Glass, and J. Sun, “StageNet: Stage-aware neural networks for health risk prediction,” in Proceedings of The Web Conference 2020, 2020, pp. 530–540.
- [6] T. Ma, C. Xiao, and F. Wang, “Health-ATM: A deep architecture for multifaceted patient health record representation and risk prediction,” in Proceedings of the 2018 SIAM International Conference on Data Mining. SIAM, 2018, pp. 261–269.
- [7] Z. Sun, S. Ghosh, Y. Li, Y. Cheng, A. Mohan, C. Sampaio, and J. Hu, “A probabilistic disease progression modeling approach and its application to integrated Huntington’s disease observational data,” JAMIA Open, vol. 2, no. 1, pp. 123–130, 2019.
- [8] X. Zhang, J. Chou, J. Liang, C. Xiao, Y. Zhao, H. Sarva, C. Henchcliffe, and F. Wang, “Data-driven subtyping of Parkinson’s disease using longitudinal clinical records: a cohort study,” Scientific Reports, vol. 9, no. 1, pp. 1–12, 2019.
- [9] X. Wang, D. Sontag, and F. Wang, “Unsupervised learning of disease progression models,” in Proceedings of the 20th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 2014, pp. 85–94.
- [10] V. Fortuin, M. Hüser, F. Locatello, H. Strathmann, and G. Rätsch, “SOM-VAE: Interpretable discrete representation learning on time series,” arXiv preprint arXiv:1806.02199, 2018.
- [11] L. Mou, P. Zhao, H. Xie, and Y. Chen, “T-LSTM: A long short-term memory neural network enhanced by temporal information for traffic flow prediction,” IEEE Access, vol. 7, pp. 98053–98060, 2019.
- [12] A. M. Alaa and M. van der Schaar, “Attentive state-space modeling of disease progression,” in Advances in Neural Information Processing Systems, 2019, pp. 11338–11348.
- [13] M. Liu, J. Zhang, E. Adeli, and D. Shen, “Joint classification and regression via deep multi-task multi-channel learning for Alzheimer’s disease diagnosis,” IEEE Transactions on Biomedical Engineering, vol. 66, no. 5, pp. 1195–1206, 2018.
- [14] X. Teng, S. Pei, and Y.-R. Lin, “StoCast: Stochastic disease forecasting with progression uncertainty,” IEEE Journal of Biomedical and Health Informatics, vol. 25, no. 3, pp. 850–861, 2020.
- [15] J. M. Dennis, B. M. Shields, W. E. Henley, A. G. Jones, and A. T. Hattersley, “Disease progression and treatment response in data-driven subgroups of type 2 diabetes compared with models based on simple clinical features: an analysis using clinical trial data,” The Lancet Diabetes & Endocrinology, vol. 7, no. 6, pp. 442–451, 2019.
- [16] I. M. Baytas, C. Xiao, X. Zhang, F. Wang, A. K. Jain, and J. Zhou, “Patient subtyping via time-aware LSTM networks,” in Proceedings of the 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 2017, pp. 65–74.
- [17] C. Yin, R. Liu, D. Zhang, and P. Zhang, “Identifying sepsis subphenotypes via time-aware multi-modal auto-encoder,” in Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, 2020, pp. 862–872.
- [18] C. P. Burgess, I. Higgins, A. Pal, L. Matthey, N. Watters, G. Desjardins, and A. Lerchner, “Understanding disentangling in β-VAE,” arXiv preprint arXiv:1804.03599, 2018.
- [19] I. Higgins, L. Matthey, A. Pal, C. Burgess, X. Glorot, M. Botvinick, S. Mohamed, and A. Lerchner, “beta-VAE: Learning basic visual concepts with a constrained variational framework,” 2016.
- [20] J.-T. Kuo and K.-T. Chien, “Variational recurrent neural networks for speech separation.” INTERSPEECH, 2017.
- [21] B. Shickel, P. J. Tighe, A. Bihorac, and P. Rashidi, “Deep EHR: A survey of recent advances in deep learning techniques for electronic health record (EHR) analysis,” IEEE Journal of Biomedical and Health Informatics, vol. 22, no. 5, pp. 1589–1604, 2017.
- [22] E. Jun, A. W. Mulyadi, and H.-I. Suk, “Stochastic imputation and uncertainty-aware attention to EHR for mortality prediction,” in 2019 International Joint Conference on Neural Networks (IJCNN). IEEE, 2019, pp. 1–7.
- [23] S. Sukhbaatar, A. Szlam, J. Weston, and R. Fergus, “End-to-end memory networks,” arXiv preprint arXiv:1503.08895, 2015.
- [24] M. Lopez-Martin, B. Carro, A. Sanchez-Esguevillas, and J. Lloret, “Conditional variational autoencoder for prediction and feature recovery applied to intrusion detection in IoT,” Sensors, vol. 17, no. 9, p. 1967, 2017.
- [25] K. Wagstaff, C. Cardie, S. Rogers, S. Schrödl et al., “Constrained k-means clustering with background knowledge,” in ICML, vol. 1, 2001, pp. 577–584.
- [26] C. R. Jack Jr, M. A. Bernstein, N. C. Fox, P. Thompson, G. Alexander, D. Harvey, B. Borowski, P. J. Britson, J. L. Whitwell, C. Ward et al., “The Alzheimer’s Disease Neuroimaging Initiative (ADNI): MRI methods,” Journal of Magnetic Resonance Imaging: An Official Journal of the International Society for Magnetic Resonance in Medicine, vol. 27, no. 4, pp. 685–691, 2008.
- [27] K. Marek, D. Jennings, S. Lasch, A. Siderowf, C. Tanner, T. Simuni, C. Coffey, K. Kieburtz, E. Flagg, S. Chowdhury et al., “The Parkinson Progression Marker Initiative (PPMI),” Progress in Neurobiology, vol. 95, no. 4, pp. 629–635, 2011.
- [28] T. Mikolov, M. Karafiát, L. Burget, J. Cernockỳ, and S. Khudanpur, “Recurrent neural network based language model.” in Interspeech, vol. 2, no. 3. Makuhari, 2010, pp. 1045–1048.
- [29] Z. Huang, W. Xu, and K. Yu, “Bidirectional LSTM-CRF models for sequence tagging,” arXiv preprint arXiv:1508.01991, 2015.
- [30] E. Choi, M. T. Bahadori, J. A. Kulas, A. Schuetz, W. F. Stewart, and J. Sun, “RETAIN: An interpretable predictive model for healthcare using reverse time attention mechanism,” arXiv preprint arXiv:1608.05745, 2016.
- [31] F. Ma, R. Chitta, J. Zhou, Q. You, T. Sun, and J. Gao, “Dipole: Diagnosis prediction in healthcare via attention-based bidirectional recurrent neural networks,” in Proceedings of the 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 2017, pp. 1903–1911.
- [32] M. J. Kusner, B. Paige, and J. M. Hernández-Lobato, “Grammar variational autoencoder,” in International Conference on Machine Learning. PMLR, 2017, pp. 1945–1954.
- [33] A. Paszke, S. Gross, F. Massa, A. Lerer, J. Bradbury, G. Chanan, T. Killeen, Z. Lin, N. Gimelshein, L. Antiga et al., “PyTorch: An imperative style, high-performance deep learning library,” Advances in Neural Information Processing Systems, vol. 32, pp. 8026–8037, 2019.
- [34] Y. Zhao, Z. Qiao, C. Xiao, L. Glass, and J. Sun, “PyHealth: A Python library for health predictive models,” arXiv preprint arXiv:2101.04209, 2021.
- [35] N. X. Vinh, J. Epps, and J. Bailey, “Information theoretic measures for clusterings comparison: Variants, properties, normalization and correction for chance,” The Journal of Machine Learning Research, vol. 11, pp. 2837–2854, 2010.
- [36] L. Hubert and P. Arabie, “Comparing partitions,” Journal of Classification, vol. 2, no. 1, pp. 193–218, 1985.
- [37] A. M. Martinez and A. C. Kak, “PCA versus LDA,” IEEE Transactions on Pattern Analysis and Machine Intelligence, vol. 23, no. 2, pp. 228–233, 2001.
- [38] E. Moradi, I. Hallikainen, T. Hänninen, J. Tohka, A. D. N. Initiative et al., “Rey’s auditory verbal learning test scores can be predicted from whole brain MRI in Alzheimer’s disease,” NeuroImage: Clinical, vol. 13, pp. 415–427, 2017.
- [39] A. H. Evans and A. J. Lees, “Dopamine dysregulation syndrome in Parkinson’s disease,” Current Opinion in Neurology, vol. 17, no. 4, pp. 393–398, 2004.
- [40] D. Verbaan, J. Marinus, M. Visser, S. M. van Rooden, A. M. Stiggelbout, H. A. Middelkoop, and J. J. van Hilten, “Cognitive impairment in Parkinson’s disease,” Journal of Neurology, Neurosurgery & Psychiatry, vol. 78, no. 11, pp. 1182–1187, 2007.