{ "cells": [ { "cell_type": "markdown", "id": "08b18d01-1d05-459a-b463-b748618268db", "metadata": {}, "source": [ "# Building a waveform dataset\n", "\n", "For training neural networks, the more training samples the better. With too little training data, one runs the risk of overfitting. Waveforms, however, can be expensive to generate and take up significant storage. Dingo adopts several strategies to mitigate these problems:\n", "\n", "* Dingo partitions parameters into two types---intrinsic and extrinsic---and builds a training set based only on the intrinsic parameters. This training set consists of the waveform polarizations $h_+$ and $h_\\times$. Extrinsic parameters are sampled during training and applied to generate the detector waveforms $h_I$. This augments the training set to provide unlimited samples of the extrinsic parameters.\n", "\n", "* Saved waveforms are compressed using a singular value decomposition. Although this is lossy, waveform mismatches can be monitored to ensure that they fall below the intrinsic error in the waveform model. \n", "\n", "\n" ] }, { "cell_type": "markdown", "id": "697a2506-0b8c-40ee-8036-2e2911b22e08", "metadata": {}, "source": [ "## The `WaveformDataset` class\n", "\n", "The `WaveformDataset` is a storage container for waveform polarizations and parameters, which can be used to serve samples to a neural network during training:\n", "\n", "```{eval-rst}\n", ".. autoclass:: dingo.gw.dataset.WaveformDataset\n", " :members:\n", " :inherited-members:\n", " :show-inheritance:\n", "```\n", "\n", "`WaveformDataset` subclasses `dingo.core.dataset.DingoDataset` and `torch.utils.data.Dataset`. The former provides generic functionality for saving and loading datasets as HDF5 files and dictionaries, and is used in several components of Dingo. The latter allows the `WaveformDataset` to be used with a PyTorch `DataLoader`. 
In general, we follow the PyTorch design framework for training, including [Datasets, DataLoaders,](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html) and [Transforms](https://pytorch.org/tutorials/beginner/basics/transforms_tutorial.html)." ] }, { "cell_type": "markdown", "id": "e16c9039-fbef-4996-8eac-96cc8f42dd52", "metadata": { "tags": [] }, "source": [ "## Generating a simple dataset\n", "\n", "As described above, the `WaveformDataset` class is just a container, and does not generate the contents itself. Dataset generation is instead carried out using functions in the `dingo.gw.dataset.generate_dataset` module. Although in practice, datasets are likely to be generated from a settings file using the command line interface, here we describe how to generate one interactively.\n", "\n", "A dataset is based on an intrinsic prior and a waveform generator, so we build these as described [here](generating_waveforms.ipynb)." ] }, { "metadata": { "ExecuteTime": { "end_time": "2025-03-06T12:26:13.184578Z", "start_time": "2025-03-06T12:26:13.100578Z" } }, "cell_type": "code", "source": [ "import warnings\n", "warnings.filterwarnings(\"ignore\", \"Wswiglal-redir-stdio\")\n", "import lal" ], "id": "c05d01b6b5c311a2", "outputs": [], "execution_count": 1 }, { "cell_type": "code", "id": "a40003ad-e6c4-4def-8935-dedb3e448935", "metadata": { "ExecuteTime": { "end_time": "2025-03-06T12:26:14.282156Z", "start_time": "2025-03-06T12:26:13.197947Z" } }, "source": [ "from dingo.gw.waveform_generator import WaveformGenerator\n", "from bilby.core.prior import PriorDict\n", "from dingo.gw.prior import default_intrinsic_dict\n", "from dingo.gw.domains import FrequencyDomain\n", "\n", "domain = FrequencyDomain(f_min=20.0, f_max=1024.0, delta_f=0.125)\n", "wfg = WaveformGenerator(approximant='IMRPhenomXPHM', domain=domain, f_ref=20.0)\n", "prior = PriorDict(default_intrinsic_dict)" ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Setting 
spin_conversion_phase = None. Using phase parameter for conversion to cartesian spins.\n" ] } ], "execution_count": 2 }, { "cell_type": "markdown", "id": "99729561-9b18-4be7-8a67-3b6d59a69858", "metadata": {}, "source": [ "We can use the following function to generate sets of parameters and associated waveforms:" ] }, { "cell_type": "code", "id": "6ea82e56-dcff-4e84-8105-27fcee0c7566", "metadata": { "ExecuteTime": { "end_time": "2025-03-06T12:26:15.636350Z", "start_time": "2025-03-06T12:26:14.323162Z" } }, "source": [ "from dingo.gw.dataset.generate_dataset import generate_parameters_and_polarizations\n", "\n", "parameters, polarizations = generate_parameters_and_polarizations(wfg,\n", " prior,\n", " num_samples=100,\n", " num_processes=1)" ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Generating dataset of size 100\n" ] } ], "execution_count": 3 }, { "cell_type": "code", "id": "93896ee8-100c-44aa-88d2-7ffab93ac1c3", "metadata": { "ExecuteTime": { "end_time": "2025-03-06T12:26:15.651837Z", "start_time": "2025-03-06T12:26:15.644358Z" } }, "source": [ "parameters" ], "outputs": [ { "data": { "text/plain": [ " mass_ratio chirp_mass luminosity_distance theta_jn phase a_1 \\\n", "0 0.218187 73.845050 1000.0 1.255204 1.966362 0.197980 \n", "1 0.381173 87.704762 1000.0 2.033628 3.888862 0.460440 \n", "2 0.510406 93.479307 1000.0 1.859908 3.469898 0.023533 \n", "3 0.678305 92.145038 1000.0 0.758713 2.841377 0.172021 \n", "4 0.624489 33.540545 1000.0 1.582852 1.577590 0.413280 \n", ".. ... ... ... ... ... ... 
\n", "95 0.540129 87.451546 1000.0 2.696406 5.270380 0.201667 \n", "96 0.803457 66.013454 1000.0 0.379665 0.175340 0.437341 \n", "97 0.861454 75.908534 1000.0 1.805871 1.334242 0.505140 \n", "98 0.380818 45.702456 1000.0 1.684684 3.820672 0.092019 \n", "99 0.941143 69.169888 1000.0 2.045144 0.209135 0.925224 \n", "\n", " a_2 tilt_1 tilt_2 phi_12 phi_jl geocent_time \n", "0 0.240156 1.972606 1.376228 2.186446 4.752777 0.0 \n", "1 0.692240 1.754236 0.661015 0.790942 5.066653 0.0 \n", "2 0.296818 2.552577 0.359922 2.138755 3.489143 0.0 \n", "3 0.934613 0.359660 2.157047 3.599841 0.860001 0.0 \n", "4 0.964930 1.929234 2.084173 1.543995 5.298489 0.0 \n", ".. ... ... ... ... ... ... \n", "95 0.187635 0.447384 1.944557 0.052446 0.952740 0.0 \n", "96 0.730075 1.475004 2.752046 5.595977 2.047529 0.0 \n", "97 0.566819 0.965326 0.194196 0.807147 2.357237 0.0 \n", "98 0.228797 1.478859 1.849281 5.860794 0.562862 0.0 \n", "99 0.975578 1.644663 1.359320 3.098630 4.976837 0.0 \n", "\n", "[100 rows x 12 columns]" ], "text/html": [ "
|  | mass_ratio | chirp_mass | luminosity_distance | theta_jn | phase | a_1 | a_2 | tilt_1 | tilt_2 | phi_12 | phi_jl | geocent_time |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.218187 | 73.845050 | 1000.0 | 1.255204 | 1.966362 | 0.197980 | 0.240156 | 1.972606 | 1.376228 | 2.186446 | 4.752777 | 0.0 |
| 1 | 0.381173 | 87.704762 | 1000.0 | 2.033628 | 3.888862 | 0.460440 | 0.692240 | 1.754236 | 0.661015 | 0.790942 | 5.066653 | 0.0 |
| 2 | 0.510406 | 93.479307 | 1000.0 | 1.859908 | 3.469898 | 0.023533 | 0.296818 | 2.552577 | 0.359922 | 2.138755 | 3.489143 | 0.0 |
| 3 | 0.678305 | 92.145038 | 1000.0 | 0.758713 | 2.841377 | 0.172021 | 0.934613 | 0.359660 | 2.157047 | 3.599841 | 0.860001 | 0.0 |
| 4 | 0.624489 | 33.540545 | 1000.0 | 1.582852 | 1.577590 | 0.413280 | 0.964930 | 1.929234 | 2.084173 | 1.543995 | 5.298489 | 0.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 95 | 0.540129 | 87.451546 | 1000.0 | 2.696406 | 5.270380 | 0.201667 | 0.187635 | 0.447384 | 1.944557 | 0.052446 | 0.952740 | 0.0 |
| 96 | 0.803457 | 66.013454 | 1000.0 | 0.379665 | 0.175340 | 0.437341 | 0.730075 | 1.475004 | 2.752046 | 5.595977 | 2.047529 | 0.0 |
| 97 | 0.861454 | 75.908534 | 1000.0 | 1.805871 | 1.334242 | 0.505140 | 0.566819 | 0.965326 | 0.194196 | 0.807147 | 2.357237 | 0.0 |
| 98 | 0.380818 | 45.702456 | 1000.0 | 1.684684 | 3.820672 | 0.092019 | 0.228797 | 1.478859 | 1.849281 | 5.860794 | 0.562862 | 0.0 |
| 99 | 0.941143 | 69.169888 | 1000.0 | 2.045144 | 0.209135 | 0.925224 | 0.975578 | 1.644663 | 1.359320 | 3.098630 | 4.976837 | 0.0 |

100 rows × 12 columns
\n", "