
Compact Binary Systems Waveform Generation with Generative Pre-trained Transformer

Ruijun Shi Equal contribution Department of Astronomy, Beijing Normal University, Beijing 100875, China    Yue Zhou Equal contribution Peng Cheng Laboratory, Shenzhen 518055, China    Tianyu Zhao Department of Astronomy, Beijing Normal University, Beijing 100875, China    Zhoujian Cao Department of Astronomy, Beijing Normal University, Beijing 100875, China    Zhixiang Ren Corresponding author renzhx@pcl.ac.cn Peng Cheng Laboratory, Shenzhen 518055, China
(April 2, 2025)
Abstract

Space-based gravitational wave (GW) detection is one of the most anticipated GW detection projects of the next decade and promises to detect abundant compact binary systems. At present, deep learning methods have not been widely explored for GW waveform generation and extrapolation. To address the data processing difficulty and the increased waveform complexity caused by the detector response and second-generation time-delay interferometry (TDI 2.0), we propose an interpretable pre-trained large model named CBS-GPT (Compact Binary Systems Waveform Generation with Generative Pre-trained Transformer). Three models were trained to predict the waveforms of massive black hole binaries (MBHB), extreme mass-ratio inspirals (EMRIs), and galactic binaries (GB), achieving prediction accuracies of up to 99%, 91%, and 99%, respectively. CBS-GPT exhibits notable generalization ability and interpretability: its hidden parameters effectively capture the intricate information of the waveforms, even under the complex instrument response and across a wide parameter range. Our research demonstrates the potential of large models in the GW realm, opening up new opportunities and guidance for future research such as complex waveform generation, gap completion, and deep learning model design for GW science.

preprint: APS/123-QED

I Introduction

Figure 1: Overview of CBS-GPT. The CBS-GPT model was trained separately for three kinds of GW sources (MBHB, EMRIs, and GB). The subsequent waveform can be extrapolated after feeding its corresponding preceding waveform into CBS-GPT. Details of data and model description are in Section II.

The first direct detection of a binary black hole merger (GW150914) [1, 2] by the Laser Interferometer Gravitational-Wave Observatory (LIGO) opened an innovative window for understanding the universe and provided direct evidence for the validity of Einstein's General Relativity. Gravitational wave (GW) observations will clarify many questions in astrophysics, cosmology, and fundamental physics [3, 4, 5, 6, 7, 8, 9, 10]. So far, ground-based GW detectors have reported over a hundred compact binary coalescence (CBC) events [11], and Pulsar Timing Arrays (PTA) have recently found solid evidence for the existence of a stochastic GW background [12, 13, 14, 15]. To gain a deeper understanding and an overall picture of GW cosmology [16], the field of low-frequency GWs needs to be widely covered.

Space-based GW detection avoids terrestrial noise [17] and makes the detection of low-frequency ($10^{-4}$–$0.1\,\mathrm{Hz}$) GW signals more promising. Space-based GW detectors such as the Laser Interferometer Space Antenna (LISA) [18], Taiji [19, 20], and TianQin [21] have been planned and are scheduled for the 2030s. In particular, future space-based GW detection is expected to observe a richer variety of GW sources, including massive black hole binaries (MBHB), extreme mass-ratio inspirals (EMRIs), and galactic binaries (GB) [18].

GW signals are extremely weak and usually buried in instrumental noise. With the improvement of detector sensitivity and the increasing amount of data, the computational complexity and timeliness demands of detection and parameter estimation are growing, which poses challenging problems for traditional methods based on central processing unit (CPU) computing power. With the rapid development of graphics processing unit (GPU) computing power, Artificial Intelligence (AI) methods have shed new light on this issue. Specifically, AI techniques have been successfully applied to tasks such as GW signal detection [22, 23, 24, 25, 26, 27], parameter estimation [28, 29, 30], and signal extraction and noise reduction [31, 32, 33, 34, 35, 36, 37, 38] with promising results. The targets of space-based GW detectors are likewise complex, multi-scale waveforms (such as MBHB, EMRIs, and GB). Some previous studies focused on generating binary black hole (BBH) waveforms. Lee et al. [39] employed a Recurrent Neural Network (RNN) capable of generating BBH waveforms during the merger and ringdown phases of non-spinning binary black hole coalescence. Khan et al. [40] demonstrated that a vanilla transformer can learn quasi-circular and non-precessing BBH waveforms. Similarly, Chua et al. [41] used a greedy algorithm to build a reduced basis, enabling the rapid generation of BBH waveforms. Recently, large language models (LLM) based on the attention mechanism have shown tremendous power in computer vision (CV) and natural language processing (NLP) [42, 43, 44]. Some studies indicate that similar architectures can be applied to GW data analysis [36, 35]. Space-based GW detectors will observe more signals along with complicating factors such as source confusion, gaps, and glitches [45]. It is critical to provide a set of data processing tools to address these issues, and deep learning holds promise for meeting these challenges.

In contrast to previous studies on AI waveform generation, which did not consider the second-generation time-delay interferometry (TDI 2.0) response, our paper takes a step further. Moreover, the parameter ranges of the waveforms in prior investigations were relatively narrow. In this paper, we investigate more complex waveforms and train a model that facilitates solving downstream problems. We introduce the CBS-GPT (Compact Binary Systems Waveform Generation with Generative Pre-trained Transformer) model, an interpretable, transformer-based, self-supervised large model for the prediction of compact binary sources (MBHB, EMRIs, and GB). In CBS-GPT (Figure 1), patching and hybrid embedding mechanisms are proposed for full extraction of waveform features. Using the self-attention mechanism and a mean squared error loss, CBS-GPT is trained separately for each GW source. The experimental results illustrate that CBS-GPT can accurately predict the subsequent waveform from the input waveform. In this study, two models were trained to achieve extrapolation with different input-to-prediction length ratios. In the 20:1 extrapolation, the average overlap between the predicted and target waveforms of MBHB, EMRIs, and GB reaches 0.981, 0.912, and 0.991, respectively. In the 1:1 extrapolation, the average overlaps reach 0.990, 0.807, and 0.992 for MBHB, EMRIs, and GB, respectively. We also find that waveform complexity can significantly influence the model's prediction performance, and that CBS-GPT matches the key frequencies effectively. Finally, through attention-map visualization and correlation calculation, we discover that the attention map and its corresponding waveform present similar periodic distributions, which illustrates that CBS-GPT is able to learn waveform features even under the complex instrument response and across a wide parameter range.

The rest of this paper is organized as follows. Section II describes data generation and the CBS-GPT model architecture. In Section III, we present our overlap and attention map results, and discuss interpretability outcomes as well as potential applications. Finally, Section IV highlights our findings based on the results.

II Methodology

II.1 Data

(a) MBHB
(b) EMRIs
(c) GB
Figure 2: The TDI 2.0 response complicates the waveforms. To simplify waveform comparison, all waveforms are normalized to a maximum amplitude of 1. The $\Delta t$ in the figure denotes the sampling interval. The effects of different parameters on the time and frequency domains are shown in the left and right panels, respectively. (a) MBHB waveforms at different $M_{tot}$. At high frequencies, the TDI response function has a greater impact. The gray line shows the TDI 2.0 transfer function in the frequency domain. (b) EMRIs waveforms at different $e_{0}$. As the eccentricity increases, the EMRIs waveform becomes more and more complex in the frequency domain. (c) GB waveforms at different $f$. The GB signal is comparatively simple and monochromatic.

The targets of space-based GW detectors are GW signals at frequencies of $[10^{-4}, 0.1]\,\mathrm{Hz}$. We focus on three compact binary sources of major interest for LISA: MBHB, EMRIs, and GB. Figure 2 displays data examples. Detailed information on the data generation process is given below.

II.1.1 MBHB

MBHB are among the main detection targets of space-based GW detectors [18]. In this paper, SEOBNRv4_opt [46] (the $l=m=2$ mode) is used to generate the MBHB waveforms. The parameter space of the MBHB dataset is shown in Table 2(a). In Figure 2(a), the TDI 2.0 transfer function significantly affects the high-frequency part for lower total masses of MBHB. First, we generate MBHB time-series waveforms with a length of 20,000 points at a sampling interval of 5 seconds. We train two models with different input-to-prediction token length ratios; the two scenarios are referred to as 20:1 and 1:1 extrapolation in the subsequent sections. Table 2 summarizes the token information for each source. The 20:1 extrapolation predicts the 200 points after the merger from an input sequence of the preceding 4,000 points. The 1:1 extrapolation predicts the same 200 points after the merger, but with the input sequence limited to the preceding 200 points. During the inference phase, the 4,000 valid points (or 200 valid points) before the merger time are fed into CBS-GPT to predict the succeeding 200 points, hence achieving a 20:1 (or 1:1) extrapolation of the MBHB waveforms.

Table 1: Parameter distributions of the training and test datasets
Parameter Description Parameter distribution
$M_{tot}$ Total mass of the massive black hole binary, $m_{1}+m_{2}$ log-Uniform $[10^{5.5}, 10^{7}]\,M_{\odot}$
$q$ Mass ratio $m_{2}/m_{1}$ Uniform $[0.1, 1]$
$S^{z}_{1}, S^{z}_{2}$ Spin parameters of the two black holes Uniform $[-0.99, 0.99]$
$\iota$, $\psi$ Inclination angle and polarization angle Uniform $[0, \pi]$
$\phi_{c}$ Coalescence phase Uniform $[0, 2\pi]$
$\lambda$ Ecliptic longitude Uniform $[0, 2\pi]$
$\beta$ Ecliptic latitude Uniform $[0, \pi]$
(a) Parameter space of the MBHB dataset
Parameter Description Parameter distribution
$M$ The mass of the MBH Uniform $[10^{5}, 10^{7}]\,M_{\odot}$
$m$ The mass of the stellar-mass compact object Fixed $10\,M_{\odot}$
$a$ Spin parameter of the MBH Uniform $[10^{-3}, 0.8]$
$p_{0}$ Semi-latus rectum Uniform $[10, 16]$
$e_{0}$ Eccentricity Uniform $[10^{-3}, 0.4]$
$\iota_{0}$ The cosine of the orbit's inclination angle from the equatorial plane Uniform $[-0.98, 0.98]$
$\theta_{S}$, $\theta_{K}$ The polar angles describing the sky location and the orientation of the spin angular momentum vector of the MBH Uniform $[10^{-3}, \pi]$
$\phi_{S}$, $\phi_{K}$ The azimuthal angles describing the sky location and the orientation of the spin angular momentum vector of the MBH Uniform $[10^{-3}, 2\pi]$
$\Phi_{\varphi,0}, \Phi_{\theta,0}, \Phi_{r,0}$ The phases of the azimuthal, polar, and radial modes Fixed $0$
(b) Parameter space of the EMRIs dataset
Parameter Description Parameter distribution
$f$ Frequency log-Uniform $[10^{-4}, 10^{-2}]$ Hz
$\dot{f}$ The derivative of $f$ Fixed $10^{-14}$
$A$ Amplitude Uniform $[10^{-23}, 10^{-21}]$
$\iota_{0}$, $\psi$, $\phi_{0}$ The inclination angle, polarization angle, and initial phase Uniform $[0, \pi]$
$\lambda$ Ecliptic longitude Uniform $[0, 2\pi]$
$\beta$ Ecliptic latitude Uniform $[0, \pi]$
(c) Parameter space of the GB dataset
Table 2: The waveform information of different sources.
20:1 extrapolation 1:1 extrapolation
Input tokens 1000 50
Prediction tokens 50 50
MBHB 4 points/token 4 points/token
EMRIs 4 points/token 16 points/token
GB 4 points/token 32 points/token

II.1.2 EMRIs

EMRIs are black hole binary systems with a mass ratio of $m/M \simeq 10^{-4}$–$10^{-7}$, in which the massive black hole (MBH) has a mass in the range $M \simeq 10^{5}$–$10^{7}\,M_{\odot}$. EMRIs waveforms encapsulate the properties of spacetime near a massive black hole. EMRIs are among the primary detection targets of space-based GW detectors and possess the potential to unveil new physical phenomena [47, 48, 49]. We employ the FastEMRIWaveforms (FEW) package [50] to generate EMRIs waveforms at a sampling interval of 5 s. The EMRIs signals, each with a duration of 1 year, are randomly sliced into five waveform segments containing 4,200 points for 20:1 extrapolation (or 1,600 points for 1:1 extrapolation). For continuous GWs, random slicing simulates variations in the phase and amplitude of the same signal, enhancing the model's generalization capability. The parameter space of the EMRIs dataset is shown in Table 2(b). The complexity of the EMRIs waveform is visible in Figure 2(b): as the eccentricity increases, so does the complexity, which becomes particularly prominent in the frequency domain.

II.1.3 Galactic binary

Within the Milky Way galaxy, a substantial population of binary white dwarf systems exists, posing foreground-noise challenges for space-based GW detectors. We use the following GB model to generate GB waveforms [51]:

\begin{aligned}
h_{+}^{\text{src}}(t) &= \mathcal{A}\left(1+\cos^{2}\iota\right)\cos\Phi(t), \\
h_{\times}^{\text{src}}(t) &= 2\mathcal{A}\sin\iota\sin\Phi(t), \\
\Phi(t) &= \phi_{0}+2\pi f_{0}t+\pi\dot{f}_{0}t^{2}+\frac{\pi}{3}\ddot{f}_{0}t^{3}, \\
\ddot{f}_{0} &= \frac{11}{3}\frac{\dot{f}_{0}^{2}}{f_{0}}.
\end{aligned} \quad (1)
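As a concrete illustration, Eq. (1) can be evaluated directly. The following NumPy sketch (function name and parameter values are our own, not from the paper) generates the source-frame polarizations:

```python
import numpy as np

def gb_waveform(t, amp, f0, fdot, iota, phi0):
    """Source-frame GB polarizations following Eq. (1)."""
    fddot = (11.0 / 3.0) * fdot**2 / f0          # second derivative of f
    phase = (phi0 + 2 * np.pi * f0 * t
             + np.pi * fdot * t**2
             + (np.pi / 3.0) * fddot * t**3)
    h_plus = amp * (1 + np.cos(iota)**2) * np.cos(phase)
    h_cross = 2 * amp * np.sin(iota) * np.sin(phase)
    return h_plus, h_cross

# one year of data at the paper's 1/15 Hz sampling rate (illustrative parameters)
t = np.arange(0, 365 * 24 * 3600, 15.0)
hp, hc = gb_waveform(t, amp=1e-22, f0=1e-3, fdot=1e-14, iota=0.5, phi0=0.0)
```

A waveform generated this way would still need the detector response and TDI combination of Section II.1.4 before it matches the training data.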

Similar to EMRIs, GB waveforms are generated with a duration of 1 year and a sampling rate of 1/15 Hz. Five slices of 4,200 points for 20:1 extrapolation (or 3,200 points for 1:1 extrapolation) are randomly truncated for training and inference. The parameter space of the GB dataset is shown in Table 2(c).

II.1.4 Detector response and TDI 2.0

After generating a waveform, we project it onto the LISA detector [52]. For LISA, the signals are processed with the TDI combination to suppress the overpowering laser noise. The response of space-based GW detectors is more intricate than that of ground-based detectors, accounting for factors such as satellite orbits and arm-length delays. The strain induced on link 12 is:

H_{12}(t) = h^{\mathrm{SSB}}_{+}(t)\,\xi_{+}(\hat{\bm{u}},\hat{\bm{v}},\hat{\bm{n}}_{12}) + h^{\mathrm{SSB}}_{\times}(t)\,\xi_{\times}(\hat{\bm{u}},\hat{\bm{v}},\hat{\bm{n}}_{12}). \quad (2)

Here $\xi_{+,\times}$ refers to the antenna pattern functions:

\begin{aligned}
\xi_{+}(\hat{\bm{u}},\hat{\bm{v}},\hat{\bm{n}}_{12}) &= \left(\hat{\bm{u}}\cdot\hat{\bm{n}}_{12}\right)^{2}-\left(\hat{\bm{v}}\cdot\hat{\bm{n}}_{12}\right)^{2}, \\
\xi_{\times}(\hat{\bm{u}},\hat{\bm{v}},\hat{\bm{n}}_{12}) &= 2(\hat{\bm{u}}\cdot\hat{\bm{n}}_{12})(\hat{\bm{v}}\cdot\hat{\bm{n}}_{12}),
\end{aligned} \quad (3)

where $\hat{\bm{n}}_{12}$ is the link unit vector, and $\hat{\bm{u}}$ and $\hat{\bm{v}}$ are the polarization vectors, defined as the opposite directions of the polar and azimuthal angles in the Solar System Barycenter (SSB) frame, respectively. Because of the long arm lengths of space-based GW detectors, the influence of the arm length must be taken into account. Denote the time of transmission from spacecraft 2 as $t_{2}$; after propagating the arm-length distance to spacecraft 1, the reception time $t_{1}$ satisfies

t_{1} \approx t_{2} + \frac{L_{12}}{c} - \frac{1}{2c}\int_{0}^{L_{12}} H(x(\lambda), t(\lambda))\,d\lambda, \quad (4)

where $L_{12}$ is the arm length of the detector and the variable $\lambda$ parameterizes the photon path; to first order we approximate $t(\lambda) \approx t_{2} + \lambda/c \approx t_{2} + L_{12}/c$. Because of the slow motion of the space-based GW detector, the frequency shift is given by

y_{12}(t_{1}) \approx \frac{1}{2\left(1-\hat{\bm{k}}\cdot\hat{\bm{n}}_{12}(t_{1})\right)}\left[H_{12}\left(t_{1}-\frac{L_{12}(t_{1})}{c}-\frac{\hat{\bm{k}}\cdot\bm{x}_{2}(t_{1})}{c}\right)-H_{12}\left(t_{1}-\frac{\hat{\bm{k}}\cdot\bm{x}_{1}(t_{1})}{c}\right)\right], \quad (5)

where $\hat{\bm{k}}$ represents the propagation vector of the GW from the source.

Space-based GW detectors have unequal arm lengths, which results in significant laser frequency noise. To mitigate this, TDI techniques are commonly employed to suppress the laser frequency noise [53, 54]. The first- and second-generation Michelson combinations, $X_{1}$ and $X_{2}$, are defined by [53],

\begin{aligned}
X_{1} ={}& y_{13} + \bm{D}_{13}y_{31} + \bm{D}_{131}y_{12} + \bm{D}_{1312}y_{21} \\
& - \left[y_{12} + \bm{D}_{12}y_{21} + \bm{D}_{121}y_{13} + \bm{D}_{1213}y_{31}\right],
\end{aligned} \quad (6)

\begin{aligned}
X_{2} ={}& X_{1} + \bm{D}_{13121}y_{12} + \bm{D}_{131212}y_{21} + \bm{D}_{1312121}y_{13} \\
& + \bm{D}_{13121213}y_{31} - \left[\bm{D}_{12131}y_{13} + \bm{D}_{121313}y_{31}\right. \\
& \left. + \bm{D}_{1213131}y_{12} + \bm{D}_{12131312}y_{21}\right],
\end{aligned} \quad (7)

where the delay operators are defined by,

\bm{D}_{i_{1},i_{2},\ldots,i_{n}}\,x(t) = x\left(t - \sum_{k=1}^{n-1} L_{i_{k}i_{k+1}}(t)\right). \quad (8)
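Operationally, each delay operator just evaluates a sampled time series at a retarded time. A minimal NumPy sketch (our own illustration, assuming constant arm lengths expressed in seconds and simple linear interpolation, rather than the high-order Lagrange interpolation used by production TDI codes):

```python
import numpy as np

def apply_delays(x, t, links, L):
    """Nested delay operator of Eq. (8): evaluate x at t minus the sum of the
    light travel times along consecutive links. `L[(i, j)]` is the (constant)
    arm length of link i -> j in seconds."""
    total_delay = sum(L[(links[k], links[k + 1])] for k in range(len(links) - 1))
    # evaluate x at the retarded time t - total_delay by linear interpolation
    return np.interp(t - total_delay, t, x, left=0.0)

# toy example: a sinusoid delayed along links 1 -> 3 -> 1 (hypothetical 8.3 s arms)
t = np.arange(0.0, 1000.0, 1.0)
x = np.sin(2 * np.pi * 0.01 * t)
L = {(1, 3): 8.3, (3, 1): 8.3}
xd = apply_delays(x, t, links=(1, 3, 1), L=L)
```

In a full TDI pipeline the arm lengths are time-dependent, which is why Eq. (8) writes $L_{i_{k}i_{k+1}}(t)$.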

The detector response and TDI 2.0 response of the GW are calculated using Fastlisaresponse [52]. TDI 2.0 generates three channels X, Y, and Z; the variables Y and Z may be produced via cyclic permutation of the indices in Eq. 7. A more detailed derivation can be found in Section IV of Ref. [52]. By combining X, Y, and Z, three independent channels A, E, and T are obtained,

\begin{aligned}
A &= (Z-X)/\sqrt{2}, \\
E &= (X-2Y+Z)/\sqrt{6}, \\
T &= (X+Y+Z)/\sqrt{3}.
\end{aligned} \quad (9)
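Eq. (9) is a fixed linear combination, so it is a one-liner in practice. A sketch (function name is ours):

```python
import numpy as np

def aet_from_xyz(X, Y, Z):
    """A, E, T channels from the Michelson X, Y, Z combinations, Eq. (9)."""
    A = (Z - X) / np.sqrt(2.0)
    E = (X - 2.0 * Y + Z) / np.sqrt(6.0)
    T = (X + Y + Z) / np.sqrt(3.0)
    return A, E, T
```

In this work the E channel of each waveform is the one actually used for training (Section II.3).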

The incorporation of response functions and TDI 2.0 combination introduces increased complexity to the waveform, especially in the high-frequency part. As depicted in Figure 2, MBHB waveforms exhibit significant differences at various parameter values.

II.2 CBS-GPT Model

Transformers [55] are a class of deep learning models that have exhibited excellent performance in various tasks, such as NLP [43] and CV [44]. We incorporate the masked self-attention mechanism and feed-forward neural network to build our CBS-GPT model.

Patching. Firstly, the input waveform is preprocessed by standardization, which facilitates the model in capturing waveform information more effectively:

I = \text{standard}(s) = \frac{s-\text{mean}(s)}{\text{std}(s)}, \quad (10)

where $s=\{s_{i} \mid i\in[0,N)\}$ is the input waveform, and $\mu=\text{mean}(s)=\frac{1}{N}\sum_{i=1}^{N}s_{i}$ and $\text{std}(s)=\sqrt{\frac{1}{N}\sum_{i=1}^{N}(s_{i}-\mu)^{2}}$ are the mean and standard deviation of the waveform, respectively. Standardization centers the original data to zero mean and unit standard deviation, which gives features equal weight in various analyses and suits machine learning algorithms that are sensitive to feature scales. Then, $I=\{x_{i} \mid i\in[0,N)\}$ is divided into non-overlapping patches, each of which we call a "token". In our 20:1 extrapolation experiment, for example, an input waveform with $N=4200$ sampling points is segmented into $num=1050$ tokens of 4 points each. Each token is treated as a vector; after patching, the standardized waveform $I$ becomes the input matrix $I^{\prime}\in\mathbb{R}^{1050\times 4}$.
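The standardization and patching steps can be sketched in a few lines of NumPy (function names are ours; the stand-in sine waveform is purely illustrative):

```python
import numpy as np

def standardize(s):
    """Eq. (10): zero-mean, unit-std waveform."""
    return (s - s.mean()) / s.std()

def patch(waveform, points_per_token=4):
    """Split a standardized waveform into non-overlapping tokens."""
    n = len(waveform) // points_per_token * points_per_token
    return waveform[:n].reshape(-1, points_per_token)

s = np.sin(np.linspace(0, 20 * np.pi, 4200))   # stand-in for a 4200-point waveform
tokens = patch(standardize(s), points_per_token=4)
print(tokens.shape)   # (1050, 4)
```

The resulting $(1050, 4)$ matrix is exactly the $I^{\prime}$ fed into the hybrid embedding below.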

Hybrid Embedding. A hybrid embedding module is used in our model because each token carries richer physical information than an NLP token and cannot be tokenized by a simple vocabulary lookup. As Figure 1 shows, it combines a token embedding layer and a positional embedding layer (Eq. 11). The token embedding layer performs a linear projection to match the dimension of the following encoder blocks while preserving the full information of the input waveform. The positional embedding is also a linear layer that encodes the positional relationships between tokens, which is important for prediction accuracy [55].

\begin{aligned}
E_{e} &= I^{\prime} W_{e}, \\
E_{p} &= \mathbf{I}_{num} W_{p}, \\
E_{hybrid} &= E_{e} + E_{p},
\end{aligned} \quad (11)

where $W_{e}\in\mathbb{R}^{4\times d_{model}}$ (with $d_{model}=2048$) and $W_{p}\in\mathbb{R}^{num\times d_{model}}$ are learnable parameters, and $\mathbf{I}_{num}$ is the $num\times num$ identity matrix.
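Since $\mathbf{I}_{num}W_{p}=W_{p}$, the hybrid embedding is simply a per-token linear projection plus a learned positional table. A NumPy sketch (random weights stand in for learned ones; all names are ours):

```python
import numpy as np

rng = np.random.default_rng(0)
num, d_model = 1050, 2048

W_e = rng.normal(size=(4, d_model)) * 0.02      # token-embedding weights
W_p = rng.normal(size=(num, d_model)) * 0.02    # positional-embedding weights

def hybrid_embedding(I_prime):
    """Eq. (11): E_hybrid = I' W_e + I_num W_p; the identity factor means the
    positional term is just the table W_p itself."""
    return I_prime @ W_e + W_p

I_prime = rng.normal(size=(num, 4))             # patched input, shape (1050, 4)
E = hybrid_embedding(I_prime)
print(E.shape)   # (1050, 2048)
```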

Encoder block. The encoder contains $n_{block}=36$ blocks. Each block consists mainly of an attention module and a feed-forward neural network. For the attention module, masked multi-head self-attention (MMHSA) is adopted, which projects the information into matrices in different ways and thereby enhances the expressive capacity of the model. The computation of the attention module is as follows:

\begin{aligned}
Q_{ji} &= W_{ji}^{Q}x_{j}, \\
K_{ji} &= W_{ji}^{K}x_{j}, \\
V_{ji} &= W_{ji}^{V}x_{j},
\end{aligned} \quad (12)

head_{ji}(Q_{ji},K_{ji},V_{ji}) = \text{softmax}\left(\frac{Q_{ji}K_{ji}^{T}\cdot\mathrm{mask}}{\sqrt{d}}\right)V_{ji}, \quad (13)

\mathrm{MMHSA}_{j}(Q,K,V) = \mathrm{Concat}(head_{j1},\ldots,head_{jH})\,W_{j}^{E}, \quad (14)

H_{j}^{\prime} = \mathrm{LayerNorm}(\mathrm{MMHSA}_{j}(Q,K,V)) + x_{j}, \quad (15)

In each encoder block there are $H=d_{model}/64=32$ heads. $W_{ji}^{Q},W_{ji}^{K},W_{ji}^{V}$ are the learnable query, key, and value parameters of the $i$-th attention head in the $j$-th encoder block, and $\mathrm{mask}$ is a lower-triangular matrix of ones.

x_{j} = \begin{cases} E_{hybrid}, & j=0 \\ y^{\prime}_{j-1}, & 0<j<n_{block}, \end{cases} \quad (16)

where xjx_{j} represents the hybrid embedding or the output of the previous encoder block. The feed-forward network (FFN) is composed of two dense layers and is connected to each attention module.

We employ the residual connection (Eq. 18), which is helpful to alleviate the gradient-vanishing problem.

\mathrm{Inter}(H_{j}^{\prime}) = \mathrm{GeLU}(H_{j}^{\prime}W_{j1}+b_{j1})W_{j2}+b_{j2}, \quad (17)

y_{j} = \mathrm{FFN}(H_{j}^{\prime}) = \mathrm{LayerNorm}(\mathrm{Inter}(H_{j}^{\prime})) + H_{j}^{\prime}, \quad (18)

where $W_{j1},b_{j1},W_{j2},b_{j2}$ are all learnable parameters and $\mathrm{GeLU}$ is the activation function. Finally, the output of the last encoder block is projected back to the same shape as $I^{\prime}$:

y^{\prime} = y_{n_{block}} W_{i}^{T}, \quad y^{\prime}\in\mathbb{R}^{1050\times 4}. \quad (19)
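The key ingredient of Eqs. (12)–(15) is the causal mask, which prevents a token from attending to later tokens. A minimal single-head NumPy sketch of one encoder block's attention path (toy dimensions and random weights, GeLU/FFN omitted; all names are ours):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def layer_norm(x, eps=1e-5):
    return (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)

def masked_self_attention(x, Wq, Wk, Wv):
    """One head of Eqs. (12)-(13): scaled dot-product attention with a causal
    (lower-triangular) mask so position m only attends to positions <= m."""
    Q, K, V = x @ Wq, x @ Wk, x @ Wv
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)
    causal = np.tril(np.ones(scores.shape, dtype=bool))
    scores = np.where(causal, scores, -np.inf)   # masked positions get zero weight
    return softmax(scores) @ V

# tiny single-head demonstration (the paper uses 32 heads and d_model = 2048)
rng = np.random.default_rng(1)
num, d = 8, 16
x = rng.normal(size=(num, d))
Wq, Wk, Wv = (rng.normal(size=(d, d)) * 0.1 for _ in range(3))
attn = masked_self_attention(x, Wq, Wk, Wv)
h = layer_norm(attn) + x          # residual + LayerNorm as in Eq. (15)
```

The causality is what makes next-token prediction well posed: changing a later token cannot alter the output at earlier positions.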

Loss Function. A next-token-prediction error is adopted to train CBS-GPT: the predicted token $y^{\prime}_{m}$ is matched to the input token $I^{\prime}_{m+1}$ at position $m+1$. Hence only $num-1$ tokens contribute to the training loss. Specifically, the mean squared error (MSE) between the predictions and their targets is used:

\mathcal{L} = \frac{1}{(num-1)\times 4}\sum_{m=0}^{num-2}\sum_{t=0}^{3}\left\|y^{\prime}_{m,t}-I^{\prime}_{m+1,t}\right\|^{2}. \quad (20)
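The shift by one position in Eq. (20) is the only subtlety; a NumPy sketch (function name is ours) makes it explicit:

```python
import numpy as np

def next_token_mse(pred, target):
    """Eq. (20): the prediction at position m is compared with the input token
    at position m+1, so only num-1 token pairs enter the loss."""
    diff = pred[:-1] - target[1:]        # shapes (num-1, 4)
    return np.mean(diff**2)

# sanity check: a model that perfectly predicts the next token has zero loss
tokens = np.arange(20.0).reshape(5, 4)               # toy (num, 4) token matrix
perfect = np.vstack([tokens[1:], np.zeros((1, 4))])  # row m holds token m+1
loss = next_token_mse(perfect, tokens)
print(loss)   # 0.0
```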

II.3 Training and Inference

During training, the Adam [56] optimizer with $\beta_{1}=0.9$, $\beta_{2}=0.999$ is used, and the initial learning rate is $2\times 10^{-4}$. The training dataset of each model contains 1.6 million waveforms, with the parameters of each waveform randomly drawn from the corresponding parameter space. After passing through the LISA response, each waveform is divided into three TDI channels (A, E, and T); in this study, the E channel is selected to train the model. The model was trained on two NVIDIA V100 GPUs for approximately 30 hours. During inference, 10,000 waveforms per signal source are generated to test CBS-GPT's performance. For each waveform, the initial input contains 1,000 (or 50) valid tokens and 50 masked tokens that are filled with zeros; their corresponding entries in the mask matrix are also zero, which guarantees that no attention is paid to the to-be-extrapolated tokens. In the first step, the 1,001-st (or 51-st) token is predicted and replaces the zeroed 1,001-st (or 51-st) token, and so forth, until 50 successive tokens have been predicted from the 1,000 (or 50) valid input tokens.
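The autoregressive inference loop described above can be sketched as follows (our own illustration; `model` stands in for the trained network, here replaced by a trivial repeat-last-token toy):

```python
import numpy as np

def extrapolate(model, tokens, n_predict=50):
    """Token-by-token inference: the to-be-extrapolated slots start zeroed
    (and masked), and each newly predicted token overwrites the next slot
    before the following step. `model` maps a (num, 4) token matrix to a
    (num, 4) matrix of next-token predictions."""
    seq = np.vstack([tokens, np.zeros((n_predict, tokens.shape[1]))])
    n_valid = len(tokens)
    for step in range(n_predict):
        pred = model(seq)                                # predictions everywhere
        seq[n_valid + step] = pred[n_valid + step - 1]   # next-token prediction
    return seq[n_valid:]

# toy model: predicts that the next token repeats the current one
identity_model = lambda seq: seq
out = extrapolate(identity_model, np.ones((1000, 4)), n_predict=50)
print(out.shape)   # (50, 4)
```

With the real network, the causal mask guarantees that the zeroed future slots never influence the prediction being written.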

III Results and discussion

During inference, the overlap is used to evaluate the extrapolation accuracy of the predicted waveform. The overlap between the target waveform and the waveform predicted by CBS-GPT is calculated as in Eq. 21. The overlap $\mathcal{O}$ ranges over $[0,1]$, with values closer to 1 indicating that the predicted waveform is more similar to the target waveform.

\mathcal{O}(h_{t},h_{p}) = \max_{t_{c}}\left(\hat{h}_{t}\,\big|\,\hat{h}_{p}[t_{c}]\right)^{1/2}, \quad (21)

with

\begin{aligned}
(h|s) &= 2\int_{f_{\min}}^{f_{\max}}\frac{\tilde{h}^{*}(f)\tilde{s}(f)+\tilde{h}(f)\tilde{s}^{*}(f)}{S_{n}(f)}\,df, \\
\hat{h} &= \frac{h}{\sqrt{(h|h)}},
\end{aligned} \quad (22)

where $t_{c}$ denotes the relative time shift, and we set $S_{n}(f)=1$.
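With $S_{n}(f)=1$, the frequency-domain inner product of Eq. (22) reduces by Parseval's theorem to a time-domain one, and the maximization over $t_{c}$ becomes a cross-correlation. A NumPy sketch of a normalized, shift-maximized overlap (our own simplified version using circular shifts; function name is ours):

```python
import numpy as np

def overlap(h_t, h_p):
    """Normalized overlap maximized over the time shift t_c, with S_n(f) = 1.
    The circular cross-correlation over all shifts is evaluated with FFTs."""
    Ht, Hp = np.fft.fft(h_t), np.fft.fft(h_p)
    corr = np.fft.ifft(np.conj(Ht) * Hp).real      # inner product for every shift
    norm = np.sqrt(np.sum(h_t**2) * np.sum(h_p**2))
    return corr.max() / norm

# a time-shifted copy of a periodic signal should give overlap ~ 1
t = np.linspace(0.0, 1.0, 256, endpoint=False)
h1 = np.sin(2 * np.pi * 8 * t)
h2 = np.roll(h1, 17)
print(round(overlap(h1, h2), 6))   # 1.0
```

A production matched-filter overlap would use a non-trivial $S_{n}(f)$ and zero-padding instead of circular shifts; this sketch only illustrates the structure of Eqs. (21)-(22).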

Overall, in the 20:1 extrapolation tasks targeting MBHB, GB, and EMRIs signals, CBS-GPT demonstrates remarkable efficacy, with over 50% of the overlaps exceeding 0.99. Figures 3 and 4 showcase the prediction performance under varying parameter conditions, revealing that CBS-GPT can learn waveform features across a wide range of parameters. Figures 5 and 6 demonstrate the generalization ability and strong interpretability of CBS-GPT.

III.1 Results of MBHB

The MBHB overlap results are shown in Table 4(a). The CBS-GPT model is sensitive to the total mass, mass ratio, and spin parameters. Here we use $\chi_{\text{eff}}$ to represent the spin parameter [57]:

\chi_{\text{eff}} = \frac{S_{1}^{z}}{1+q} + \frac{qS_{2}^{z}}{1+q}. \quad (23)

20:1 extrapolation. The overlap distribution and waveform examples are shown in Figure 3(a) and Figure 5(a); the mean and median overlaps equal 0.981 and 0.992, respectively. The overlap results reveal that CBS-GPT can forecast the merger-ringdown waveform from the characteristics of the inspiral phase. CBS-GPT exhibits its best inference performance when the total mass is approximately $10^{6.5}M_{\odot}$, as shown in Figure 3(b); this phenomenon is also observed for the other signal sources. The overlap is lower for waveforms with low total mass and high effective spin $\chi_{\text{eff}}$: compared with the low- and high-mass cases, prediction in the mid-frequency band performs best. Since the TDI 2.0 transfer function is more complex in the high-frequency part [58, 59], the waveform is also more complex there, and the model's performance decreases slightly. Even under such less ideal circumstances, however, CBS-GPT can still recover a significant portion of the signals.

1:1 extrapolation. We find that the model pays little attention to the early-stage waveform and concentrates mainly on the late-stage inspiral when forecasting the merger waveform of an MBHB (a detailed explanation is given in Section III.3). This demonstrates the marginal contribution of the early-stage inspiral to subsequent waveform generation. Hence we retrained a model whose input contains only the 200 points before the merger time and which predicts the subsequent 200 points, achieving a 1:1 extrapolation. The mean and median overlaps reach 0.990 and 0.996, respectively. These results are slightly better than the 20:1 extrapolation, which validates our earlier conclusion. In Table 4(a), we observe a noticeable improvement in overlap for masses greater than $10^{6}M_{\odot}$, which illustrates that shorter input waveforms allow the model's attention to be more focused, leading to improved inference performance. Figure 5(b) showcases the predictive performance of CBS-GPT in the 1:1 extrapolation scenario.

Generalization ability refers to the performance of a model on data it has not seen before. To evaluate the generalization capability of CBS-GPT, we selected MBHB signals with mass ratios ranging from 1:10 to 1:100 for the 1:1 extrapolation model. Figure 5(g) showcases waveform examples from this generalization experiment. The average overlap reaches 0.970, with more than half of the overlaps surpassing 0.993, demonstrating the strong generalization ability of our method and its capacity to learn the essence of the data.

(a) MBHB overlap distribution.
(b) MBHB: 20:1 extrapolation.
(c) MBHB: 1:1 extrapolation
(d) MBHB generalization: 1:1 extrapolation
Figure 3: The overlap distribution of MBHB is shown in (a). Panels (b, c, d) portray heat maps over the $M_{tot}$ and $\chi_{\mathrm{eff}}$ parameters, which have the greatest impact on the overlap. A darker color corresponds to a higher overlap value.
Table 3: The overlap results
MBHB 20:1 1:1 generalization
i. mean 0.981 0.990 0.970
ii. median 0.992 0.996 0.993
iii. mass $<10^{6}M_{\odot}$ 0.980 0.979 0.938
iv. mass $\geq 10^{6}M_{\odot}$ 0.982 0.995 0.986
(a) The overlap results of MBHB. Groups iii. and iv. give the mean overlaps for $M_{tot}<10^{6}M_{\odot}$ and $M_{tot}\geq 10^{6}M_{\odot}$, respectively.
EMRIs 20:1 1:1
i. mean 0.912 0.807
ii. median 0.997 0.910
iii. $e_{0}<0.1$ 0.962 0.905
iv. $e_{0}\geq 0.1$ 0.896 0.778
(b) The overlap results of EMRIs. Groups iii. and iv. give the mean overlaps for $e_{0}<0.1$ and $e_{0}\geq 0.1$, respectively.
GB 20:1 1:1
i. mean 0.991 0.992
ii. median 0.996 0.994
iii. $f<10^{-3}$ Hz 0.987 0.990
iv. $f\geq 10^{-3}$ Hz 0.995 0.993
(c) The overlap results of GB. Groups iii. and iv. give the mean overlaps for $f<10^{-3}$ Hz and $f\geq 10^{-3}$ Hz, respectively.

III.2 Results of Continuous Waveforms: EMRIs and GB

The overlap distributions of EMRIs and GB are shown in Figure 4, and their mean and median values are listed in Table 4(b) and Table 4(c). Examples of predicted EMRIs and GB waveforms are shown in Figure 5(c)-5(f).

20:1 extrapolation. For GB, the mean and median overlaps both exceed 0.99. The mean and median overlaps of EMRIs are 0.912 and 0.997, respectively. While the mean overlap of EMRIs is slightly lower, its median overlap aligns with that observed for the MBHB and GB waveforms.

Specifically, the overlap distribution of EMRIs is significantly influenced by the mass and eccentricity parameters. As depicted in Table 4(b), when $e_{0}$ is greater than $0.1$, the majority of overlaps fall below 0.9. As the eccentricity increases, the waveform amplitude exhibits more complex features; therefore, higher eccentricity tends to yield lower overlap.

In contrast to the MBHB and EMRIs signals, the GB signal presents a comparatively straightforward, single-frequency waveform. For GB, the frequency parameter has the greatest impact on the waveform. When the frequency is larger than $10^{-3.5}$ Hz, the overlap is generally higher than 0.9. The results on GB signals demonstrate the model's sensitivity to frequency, with a distinct preference for learning the characteristics of intermediate-frequency signals.

1:1 extrapolation. In this scenario, the mean and median overlaps for EMRIs were found to be 0.807 and 0.910, while for GB they were 0.992 and 0.994, respectively.

The performance impact was negligible for GB, but there was a significant decrease for EMRIs waveforms. This can be attributed to the larger eccentricity and wider range of scales exhibited by EMRIs, as well as their continuous periodic transitions. Because of the high complexity of EMRI waveforms, shorter input waveforms fail to capture their features. Therefore, for complex waveforms, CBS-GPT requires longer input waveforms to learn more distinctive features.

Refer to caption
(a) EMRIs overlap distribution.
Refer to caption
(b) EMRIs: 20:1 extrapolation.
Refer to caption
(c) EMRIs: 1:1 extrapolation
Refer to caption
(d) GB overlap distribution.
Refer to caption
(e) GB: 20:1 extrapolation.
Refer to caption
(f) GB: 1:1 extrapolation
Figure 4: The overlap distributions of EMRIs and GB are shown in (a, d). (b, c) portray the heat maps of the $e_{0}$ and $M$ parameters, which have the greatest impact on the overlap of EMRIs. Similarly, (e, f) portray the heat maps of the frequency parameter $f$, which has the greatest impact on the overlap of GB. A darker color corresponds to a higher overlap value.
Refer to caption
(a) MBHB 20:1 extrapolation
Refer to caption
(b) MBHB 1:1 extrapolation
Refer to caption
(c) EMRIs 20:1 extrapolation
Refer to caption
(d) EMRIs 1:1 extrapolation
Refer to caption
(e) GB 20:1 extrapolation
Refer to caption
(f) GB 1:1 extrapolation
Refer to caption
(g) MBHB Generalization 1:1 extrapolation.
Figure 5: CBS-GPT prediction results. (a, b) MBHB results. (c, d) EMRIs results. (e, f) GB results. (g) Generalization results of MBHB waveforms with $1/q\approx$ 10, 40, 70, and 100, respectively. We set the predicted starting point at time zero. The blue line represents the conjunction of the last part of the input waveform and the target label, the orange line is the predicted waveform, and the gray line is the difference between the predicted and target waveforms. The inset figure in each subfigure shows the predicted and target waveforms in the frequency domain, as well as the differences between them.

III.3 Interpretability

The attention map (Eq. 24) allows us to understand the extrapolation process and attention mechanism while forecasting waveforms, making it easier to gain insight into how CBS-GPT interprets GW data.

$$A=\frac{1}{H}\sum_{i=1}^{H}\text{softmax}\left(\frac{Q_{ji}K_{ji}^{T}\cdot\mathrm{mask}}{\sqrt{d}}\right),\quad(24)$$

where $H$ represents the number of attention heads of the last encoder block. In Figure 6, the vertical axis represents the model input waveform, and the horizontal axis represents the predicted waveform.
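Eq. (24) can be sketched in NumPy as follows. The tensor shapes are assumptions for illustration, and the mask is applied additively (setting disallowed scores to a large negative number before the softmax), which is the standard implementation of the multiplicative mask notation in Eq. (24):

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_map(Q: np.ndarray, K: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Head-averaged attention map of Eq. (24).

    Q, K : (H, L, d) queries/keys of the last encoder block (assumed shapes).
    mask : (L, L) matrix, 1 where attention is allowed, 0 where masked.
    """
    H, L, d = Q.shape
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d)   # (H, L, L) scaled dot products
    scores = np.where(mask == 1, scores, -1e9)       # masked positions -> ~0 after softmax
    return softmax(scores, axis=-1).mean(axis=0)     # average over the H heads
```

Each row of the resulting map is a probability distribution over input tokens, which is what makes the grid-like phase patterns discussed below directly comparable across waveforms.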

When predicting continuous gravitational waveforms (EMRIs and GB), the attention maps (Figure 6(d) - 6(i)) exhibit grid-like patterns that are closely related to the phase of the waveforms, with the scale of the grid expanding as the frequency decreases. To measure the similarity between the attention map and the input waveform, we introduce a correlation coefficient (with details described in Appendix A). Overall, the average correlation coefficient of continuous waveforms exceeds 0.8, which demonstrates that the model can accurately match the waveform's frequency and phase information. This pattern assists CBS-GPT in successfully extrapolating waveforms.

As showcased in Figure 6(a) - 6(c), during the prediction of the merger-ringdown phase of MBHB waveforms, attention primarily focuses on near-diagonal elements. In contrast to continuous GW signals, the amplitude of MBHB drops to zero after the merger-ringdown, and the attention mechanism mainly focuses on the merger and post-merger stages, with relatively less attention paid to the inspiral phase.

Refer to caption
(a) MBHB 20:1 extrapolation, $M_{tot}=10^{5.5}M_{\odot}$
Refer to caption
(b) MBHB 20:1 extrapolation, $M_{tot}=10^{6.25}M_{\odot}$
Refer to caption
(c) MBHB 1:1 extrapolation, $M_{tot}=10^{6.25}M_{\odot}$
Refer to caption
(d) EMRIs 20:1 extrapolation: $e_{0}=0.01$
Refer to caption
(e) EMRIs 20:1 extrapolation: $e_{0}=0.3$
Refer to caption
(f) EMRIs 1:1 extrapolation: $e_{0}=0.3$
Refer to caption
(g) GB 20:1 extrapolation: $f=10^{-3}$ Hz
Refer to caption
(h) GB 20:1 extrapolation: $f=10^{-3.5}$ Hz
Refer to caption
(i) GB 1:1 extrapolation: $f=10^{-3.52}$ Hz
Figure 6: Attention maps of the last encoder layer. For a clear presentation, only part of the attention map is displayed. The blue lines on the left and bottom panels represent the input waveforms, whose 1,001-st (or 101-st) to 1,050-th (or 50-th) tokens are padded with zero value during inference, and the orange line represents the waveform predicted by CBS-GPT. The term ’Similarity’ in the title of each figure denotes the correlation coefficient between the waveform and the attention map.

III.4 Potential Applications

Complex waveform generation. Currently, waveform generation for high-mass-ratio binary black holes remains a challenging problem because of its high computational cost. Our approach can partially alleviate this problem, since a CBS-GPT model trained on low-mass-ratio waveforms at relatively low computational cost can be applied to high-mass-ratio waveform generation. This generalization characteristic, as shown in Figure 5(g), demonstrates that the model can learn intrinsic features and can be applied to waveform extrapolation over a broader parameter space. By incorporating simulations based on numerical relativity, we may build a waveform template bank by extrapolating more complex and computation-intensive waveforms. For burst sources such as MBHB, the generation time of CBS-GPT for a single waveform is less than 100 ms on a single NVIDIA V100 GPU. With the rapid development of GPU computing power, CBS-GPT presents the potential for high-speed template waveform generation.

Gap imputation. In space-based GW detectors, data gaps caused by data transmission, satellite attitude adjustments, and unidentified glitches can significantly degrade the precision of waveform parameter estimation. Our waveform extrapolation method is a promising route to waveform imputation, and by integrating it with successive denoising models [33, 34, 35, 36, 37, 38, 60], parameter estimation accuracy can be further enhanced [61].

Model design guidance. We established a more convenient method for visualizing and quantifying attention maps, offering guidance for transformer-based model design in the GW research realm. Our results also demonstrate that the attention mechanism can be leveraged to build more robust deep learning models specifically tailored for GW astronomy.

IV Conclusion

In this paper, we introduce the CBS-GPT model, consisting of hybrid embedding and encoder blocks. CBS-GPT is applied to predict GW waveforms after the TDI 2.0 response. We investigated two scenarios with different extrapolation ratios between the input and predicted waveform lengths. Separate models are trained for MBHB, EMRIs, and GB. In the 20:1 and 1:1 extrapolation scenarios, the average overlaps between the predicted and target waveforms of MBHB, EMRIs, and GB reach 0.981, 0.912, 0.991 and 0.990, 0.807, 0.992, respectively. EMRIs exhibited poorer performance in the 1:1 extrapolation due to their complex waveform patterns and rich amplitude variations caused by eccentricity. We also demonstrated the strong generalization of CBS-GPT on MBHB waveforms.

Moreover, we introduced a correlation coefficient and found that the correlation between the hidden parameters of CBS-GPT and the waveform is relatively high, which indicates that the model learns the waveform's phase information extremely well. Overall, our results show that CBS-GPT can comprehend detailed waveform properties and make predictions over varied frequencies. We are confident that in the future, large AI models such as CBS-GPT can be applied to GW data processing tasks including complex waveform generation and gap imputation.

Acknowledgements.
This research was supported by the Peng Cheng Laboratory and Peng Cheng Cloud-Brain. This work was also supported in part by the National Key Research and Development Program of China Grant No. 2021YFC2203001 and in part by the NSFC (No. 11920101003 and No. 12021003). Z.C. was supported by the “Interdisciplinary Research Funds of Beijing Normal University" and CAS Project for Young Scientists in Basic Research YSBR-006.

Appendix A Correlation coefficient between waveform and hidden parameters

To evaluate the correlation between the attention map's grid-like pattern and the waveform, we introduce a correlation coefficient between the waveform and the hidden parameters (or attention map). This coefficient assesses the level of correlation and demonstrates the attention map's ability to capture phase information. First, we compute the mean value of each token of the patched waveform to obtain the sequence $M$. Subsequently, the outer product of $M$ with itself is computed, resulting in the auto-correlation matrix. As the attention map $A$ (Eq. 24) is processed by masking and normalization, we apply a similar adjustment to the auto-correlation matrix:

$$R_{\text{mask}}=\text{Mask}\left(M\otimes M-\min(M\otimes M)\right),\quad(25)$$
$$R_{\text{Norm}}=\text{RowNorm}\left(R_{\text{mask}}\right),$$

where $\text{RowNorm}(\cdot)$ denotes the normalization of each row of the matrix and $\text{Mask}(\cdot)$ is consistent with the mask method of Section II.3.

To assess the correlation between the two matrices, we calculate the Pearson correlation coefficient between the flattened attention map $A$ and the flattened $R_{\text{Norm}}$:

$$\rho_{A,R_{\text{Norm}}}=\rho\left\{\text{Flatten}(A),\text{Flatten}(R_{\text{Norm}})\right\}=\frac{n\sum_{i=1}^{n}{A_F}_{i}{R_F}_{i}-\sum_{i=1}^{n}{A_F}_{i}\sum_{i=1}^{n}{R_F}_{i}}{\sqrt{n\sum_{i=1}^{n}{A_F}_{i}^{2}-\left(\sum_{i=1}^{n}{A_F}_{i}\right)^{2}}\sqrt{n\sum_{i=1}^{n}{R_F}_{i}^{2}-\left(\sum_{i=1}^{n}{R_F}_{i}\right)^{2}}},\quad(26)$$

where $\text{Flatten}(\cdot)$ denotes flattening the matrix into one dimension, $A_F$ and $R_F$ represent the flattened vectors of $A$ and $R_{\text{Norm}}$ respectively, and $n$ represents the length after flattening. Finally, $\rho_{A,R_{\text{Norm}}}$ is defined as the correlation coefficient between the waveform and the hidden parameters.
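The appendix pipeline can be sketched end to end as follows. Three details are assumptions for illustration: a fixed token (patch) length, a lower-triangular causal $\text{Mask}(\cdot)$, and a $\text{RowNorm}(\cdot)$ that divides each row by its sum so that rows sum to one, mirroring softmax rows of the attention map.

```python
import numpy as np

def token_means(waveform: np.ndarray, patch_len: int) -> np.ndarray:
    """Mean of each token of the patched waveform -> sequence M (Appendix A)."""
    n = len(waveform) // patch_len
    return waveform[: n * patch_len].reshape(n, patch_len).mean(axis=1)

def correlation_coefficient(A: np.ndarray, M: np.ndarray) -> float:
    """Pearson correlation between the attention map A and the masked,
    row-normalized auto-correlation matrix of M (Eqs. 25-26)."""
    R = np.outer(M, M)                          # auto-correlation matrix M (x) M
    R = np.tril(R - R.min())                    # shift, then apply the (assumed causal) mask
    row_sums = R.sum(axis=1, keepdims=True)
    R_norm = np.divide(R, row_sums,             # RowNorm: each row sums to one
                       out=np.zeros_like(R), where=row_sums != 0)
    return float(np.corrcoef(A.ravel(), R_norm.ravel())[0, 1])
```

By construction, an attention map that exactly reproduces the normalized auto-correlation pattern of the waveform yields a coefficient of 1, which is the sense in which the reported values above 0.8 indicate phase matching.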

References