Data loaders

In PiNN, datasets are represented with the TensorFlow Dataset class (tf.data.Dataset). Several dataset loaders are implemented in PiNN to load data from common formats. Starting from v1.0, PiNN provides a canonical data loader, pinn.io.load_ds, that handles datasets in different formats; see below for the API documentation and the available formats.

TFRecord

The TFRecord format is a serialized format for efficient data reading in TensorFlow, and PiNN can save datasets in this format. When PiNN writes a dataset, it creates a .yml file that records the data structure of the dataset and a .tfr file that holds the data. For example:

from glob import glob
from pinn.io import load_ds, write_tfrecord
filelist = glob('/home/yunqi/datasets/QM9/dsgdb9nsd/*.xyz')
train_set = load_ds(filelist, fmt='qm9', splits={'train':8, 'test':2})['train']
write_tfrecord('train.yml', train_set)  # writes train.yml (metadata) and train.tfr (data)
train_ds = load_ds('train.yml')

We advise you to convert your dataset into the TFRecord format for training. The advantage of this format is that it can store preprocessed and batched datasets.
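To make the .yml/.tfr pairing concrete, here is a small plain-Python sketch of how the metadata is interpreted. The keys format, dtype, shape, info and n_sample are the ones read by the load_tfrecord source listed further down, and the .tfr filename rule is copied from it; the helper names tfr_path and read_spec and the example spec values are hypothetical.

```python
def tfr_path(yml_path):
    """Derive the data file from the metadata file: replace the last extension with .tfr."""
    return '.'.join(yml_path.split('.')[:-1] + ['tfr'])


def read_spec(ds_spec):
    """Split parsed .yml metadata into per-key dtypes, shapes and the sample count."""
    fmt = ds_spec['format']
    dtypes = {k: v['dtype'] for k, v in fmt.items()}
    shapes = {k: v['shape'] for k, v in fmt.items()}
    return dtypes, shapes, ds_spec['info']['n_sample']


# Illustrative metadata, as yaml.safe_load might return it (values are made up)
spec = {
    'format': {
        'coord': {'dtype': 'float32', 'shape': [None, 3]},
        'elems': {'dtype': 'int32', 'shape': [None]},
    },
    'info': {'n_sample': 100},
}
dtypes, shapes, n_sample = read_spec(spec)
print(tfr_path('train.yml'), n_sample)  # train.tfr 100
```

Because only the last extension is swapped, a dotted name such as data/qm9.train.yml maps to data/qm9.train.tfr.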

Splitting the dataset

It is common practice in machine learning to split the dataset into subsets, e.g. for validation. PiNN dataset loaders support a splits option for this: a dictionary specifying the subsets and their relative ratios. The loader then returns a dictionary of datasets split according to those ratios. For example:

from pinn.io import load_ds
# files: a list of QM9-formatted .xyz files, e.g. collected with glob()
dataset = load_ds(files, fmt='qm9', splits={'train':8, 'test':2})
train = dataset['train']
test = dataset['test']

Here train and test are tf.data.Dataset objects that can be consumed by PiNN models. The loaders also accept a seed parameter to make the split reproducible; the default seed is 0.
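The idea behind a seeded, ratio-based split can be sketched in plain Python: shuffle the sample indices with a fixed seed, then cut the list according to the relative ratios. The helper split_indices is hypothetical; PiNN's own splitter (split_list in pinn.io.base) may round or order the subsets differently.

```python
import random


def split_indices(n_sample, splits, seed=0, shuffle=True):
    """Partition range(n_sample) into named subsets by relative ratio."""
    indices = list(range(n_sample))
    if shuffle:
        random.Random(seed).shuffle(indices)  # same seed -> same split
    total = sum(splits.values())
    names = list(splits)
    subsets, start = {}, 0
    for k, name in enumerate(names):
        if k == len(names) - 1:
            end = n_sample  # the last subset takes the remainder
        else:
            end = start + n_sample * splits[name] // total
        subsets[name] = indices[start:end]
        start = end
    return subsets


parts = split_indices(10, {'train': 8, 'test': 2}, seed=0)
print(len(parts['train']), len(parts['test']))  # 8 2
```

Calling it twice with the same seed returns identical subsets, which is the reproducibility the seed parameter provides.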

Batching the dataset

Most TensorFlow dataset operations (caching, repeating, shuffling) can be applied directly to the dataset. However, structures in a dataset often contain different numbers of atoms, so PiNN provides a special sparse_batch operation that creates mini-batches of the data in a sparse form. For example:

from pinn.io import load_ds, sparse_batch
dataset = load_ds(filename)  # filename: e.g. a .yml file written by write_tfrecord
batched = dataset.apply(sparse_batch(100))  # mini-batches of 100 structures
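Conceptually, a sparse batch concatenates the atoms of all structures instead of padding them to a common size, and keeps an index recording which structure each atom came from. The plain-Python sketch below illustrates only that idea; the function name sparse_concat and the index key ind_1 are assumptions for illustration, not a description of the exact output of sparse_batch.

```python
def sparse_concat(structures):
    """Concatenate variable-size structures into one flat batch.

    Illustrative only: the 'ind_1' key (atom -> structure index) is an
    assumed name, not a description of sparse_batch's exact output.
    """
    batch = {'elems': [], 'coord': [], 'ind_1': []}
    for i, s in enumerate(structures):
        batch['elems'].extend(s['elems'])  # atoms are appended, not padded
        batch['coord'].extend(s['coord'])
        batch['ind_1'].extend([i] * len(s['elems']))  # owner structure of each atom
    return batch


water = {'elems': [8, 1, 1], 'coord': [[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]]}
h2 = {'elems': [1, 1], 'coord': [[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]]}
print(sparse_concat([water, h2])['ind_1'])  # [0, 0, 0, 1, 1]
```

No padding atoms are introduced, so a batch of small and large structures costs only as much as its total atom count.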

Custom format

To be able to shuffle and split the dataset, PiNN requires the dataset to be represented as a list of data. In the simplest case, the dataset could be a list of structure files, each containing one structure and its label (i.e. one sample). PiNN provides a list_loader decorator which turns a function that reads a single sample into a function that transforms a list of samples into a dataset. For example:

from pinn.io import list_loader

@list_loader()
def load_file_list(filename):
    # read a single file here
    coord = ...
    elems = ...
    e_data = ...
    datum = {'coord': coord, 'elems':elems, 'e_data': e_data}
    return datum
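As a concrete, hypothetical instance of the single-sample reader, the sketch below parses a simple xyz-like file whose second line holds the energy. The file layout, the function name read_sample and the small symbol table are assumptions for illustration; only the returned keys coord, elems and e_data follow the example above.

```python
# Assumed file layout (not a PiNN format):
#   line 1: number of atoms
#   line 2: total energy
#   then one "symbol x y z" line per atom
SYMBOL_TO_Z = {'H': 1, 'C': 6, 'N': 7, 'O': 8, 'F': 9}


def read_sample(filename):
    """Read one structure and its energy into the datum dict used above."""
    with open(filename) as f:
        lines = f.read().splitlines()
    n_atoms = int(lines[0])
    e_data = float(lines[1])
    elems, coord = [], []
    for line in lines[2:2 + n_atoms]:
        symbol, x, y, z = line.split()[:4]
        elems.append(SYMBOL_TO_Z[symbol])
        coord.append([float(x), float(y), float(z)])
    return {'coord': coord, 'elems': elems, 'e_data': e_data}
```

Decorated with @list_loader() like load_file_list above, such a reader would be called with a list of filenames plus the split options, in the same way load_qm9 below calls its internal reader. PiNN's built-in readers return NumPy arrays, so wrapping the lists with np.asarray may be the safer choice.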

An example notebook on preparing a custom dataset is here.

Available formats

| Format | Loader | Description |
|--------|--------|-------------|
| tfr | load_tfrecord | See TFRecord |
| runner | load_runner | Loader for datasets in the RuNNer format |
| ase | load_ase | Loads files with the ase.io.read function |
| qm9 | load_qm9 | An xyz-like file format used in the QM9 dataset [1] |
| ani | load_ani | HDF5-based format used in the ANI-1 dataset [2] |
| cp2k | load_cp2k | Loader for CP2K output (experimental) |
| deepmd-kit | load_deepmd | Loader for DeePMD-kit input |

API documentation

pinn.io.load_ds

This loader tries to guess the format when dataset is a string:

  • load_tfrecord if it ends with '.yml'
  • load_runner if it ends with '.data'
  • otherwise, try to load it with load_ase

If fmt is specified, the loader will use the corresponding dataset loader.
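The dispatch rules above fit in a few lines. The hypothetical guess_format helper below mirrors the branches of the load_ds source listed further down for string input and returns the key of the loader that would be used, which can be handy for checking a filename before loading.

```python
def guess_format(dataset, fmt='auto'):
    """Return the loader key that load_ds would pick for a string input."""
    if fmt != 'auto':
        return fmt       # an explicit fmt always wins
    if dataset.endswith('.yml'):
        return 'tfr'     # TFRecord metadata file
    if dataset.endswith('.data'):
        return 'runner'  # RuNNer-formatted file
    return 'ase'         # everything else goes through ase.io.read


print(guess_format('train.yml'))  # tfr
```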

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| dataset | Dataset | a file, or the input for the loader selected by fmt | required |
| fmt | str | dataset format, see Available formats | 'auto' |
| splits | dict | key-val pairs specifying the ratio of subsets | None |
| shuffle | bool | shuffle the dataset (only used when splitting) | True |
| seed | int | random seed for shuffling | 0 |
| **kwargs | dict | extra arguments to loaders | {} |

Source code in pinn/io/__init__.py
def load_ds(dataset, fmt='auto', splits=None, shuffle=True, seed=0, **kwargs):
    """This loader tries to guess the format when dataset is a string:

    - `load_tfrecord` if it ends with '.yml'
    - `load_runner` if it ends with '.data'
    - otherwise, try to load it with `load_ase`

    If `fmt` is specified, the loader will use the corresponding dataset loader.

    Args:
        dataset (Dataset): a file, or the input for the loader selected by `fmt`
        fmt (str): dataset format, see available formats.
        splits (dict): key-val pairs specifying the ratio of subsets
        shuffle (bool): shuffle the dataset (only used when splitting)
        seed (int): random seed for shuffling
        **kwargs (dict): extra arguments to loaders
    """
    loaders = {'tfr':    load_tfrecord,
               'runner': load_runner,
               'ase':    load_ase,
               'qm9':    load_qm9,
               'ani':    load_ani,
               'cp2k':   load_cp2k}
    if fmt=='auto':
        if dataset.endswith('.yml'):
            return load_tfrecord(dataset, splits=splits, shuffle=shuffle, seed=seed)
        if dataset.endswith('.data'):
            return load_runner(dataset, splits=splits, shuffle=shuffle, seed=seed)
        else:
            return load_ase(dataset, splits=splits, shuffle=shuffle, seed=seed)
    else:
        return loaders[fmt](dataset, splits=splits, shuffle=shuffle, seed=seed, **kwargs)

pinn.io.load_tfrecord

Load tfrecord dataset.

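Note that the subsets produced by load_tfrecord should contain the same samples as those produced by the other dataset loaders with the same splits and seed. However, the order of samples within each subset is not guaranteed to match when shuffle=True.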
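Why membership is stable but order is not can be seen in a plain-Python emulation of the zip/filter step in the load_tfrecord source: samples are selected by index from a sequential stream (which preserves the stored order), and only afterwards is the subset shuffled. The helper filter_split is hypothetical, and its seeded shuffle only keeps the sketch deterministic; the source uses TensorFlow's own shuffle.

```python
import random


def filter_split(samples, subset_indices, shuffle=True, seed=0):
    """Select a subset from a sequential stream by sample index."""
    wanted = set(subset_indices)
    # Filtering a sequential stream keeps the stored order,
    # whatever order the indices were drawn in
    subset = [s for i, s in enumerate(samples) if i in wanted]
    if shuffle:
        # Reordering afterwards changes the sequence but not the membership
        random.Random(seed).shuffle(subset)
    return subset


samples = ['s0', 's1', 's2', 's3', 's4']
print(filter_split(samples, [3, 0], shuffle=False))  # ['s0', 's3']
```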

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| dataset | str | filename of the .yml metadata file to be loaded | required |
| splits | dict | key-val pairs specifying the ratio of subsets | None |
| shuffle | bool | shuffle the dataset (only used when splitting) | True |
| seed | int | random seed for shuffling | 0 |
Source code in pinn/io/tfr.py
def load_tfrecord(dataset, splits=None, shuffle=True, seed=0):
    """Load tfrecord dataset.

    Note that the splits given by `load_tfrecord` should be the same
    as other dataset loaders. However, the sequence is not guaranteed
    with `shuffle=True`.

    Args:
       dataset (str): filename of the .yml metadata file to be loaded.
       splits (dict): key-val pairs specifying the ratio of subsets
       shuffle (bool): shuffle the dataset (only used when splitting)
       seed (int): random seed for shuffling

    """
    import sys, yaml
    import numpy as np
    import tensorflow as tf
    from pinn.io.base import split_list
    from tensorflow.python.lib.io.file_io import FileIO
    # dataset
    with FileIO(dataset, 'r') as f:
        ds_spec = yaml.safe_load(f)
        format_dict = ds_spec['format']

    dtypes = {k: format_dict[k]['dtype'] for k in format_dict.keys()}
    shapes = {k: format_dict[k]['shape'] for k in format_dict.keys()}
    feature_dict = {k: tf.io.FixedLenFeature([], tf.string) for k in dtypes}

    def parser(example):
        return tf.io.parse_single_example(example, feature_dict)
    def converter(tensors):
        tensors = {k: tf.io.parse_tensor(v, dtypes[k])
                   for k, v in tensors.items()}
        [v.set_shape(shapes[k]) for k, v in tensors.items()]
        return tensors
    tfr = '.'.join(dataset.split('.')[:-1]+['tfr'])
    dataset = tf.data.TFRecordDataset(tfr).map(parser).map(converter)
    # tfr splitter
    if splits is None:
        return dataset
    else:
        n_sample = ds_spec['info']['n_sample']
        splits = split_list(np.int64(list(range(n_sample))),
                            splits=splits, shuffle=shuffle, seed=seed)
        splitted = {k: tf.data.Dataset.zip((
            dataset, tf.data.Dataset.range(n_sample))).filter(
                lambda d, i: tf.reduce_any(tf.equal(v,i))).map(
                    lambda d, i: d)
                    for k,v in splits.items()}
        if shuffle:
            splitted = {k:v.shuffle(len(splits[k])) for k,v in splitted.items()}
        return splitted

pinn.io.load_ase

Loads an ASE trajectory.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| dataset | str or trajectory | a filename or a trajectory | required |
| splits | dict | key-val pairs specifying the ratio of subsets | None |
| shuffle | bool | shuffle the dataset (only used when splitting) | True |
| seed | int | random seed for shuffling | 0 |
Source code in pinn/io/ase.py
def load_ase(dataset, splits=None, shuffle=True, seed=0):
    """
    Loads an ASE trajectory

    Args:
        dataset (str|ase.io.trajectory): a filename or trajectory
        splits (dict): key-val pairs specifying the ratio of subsets
        shuffle (bool): shuffle the dataset (only used when splitting)
        seed (int): random seed for shuffling
    """
    from ase.io import read

    if isinstance(dataset, str):
        dataset = read(dataset, index=':')

    ds_spec = _ase_spec(dataset[0])
    @list_loader(ds_spec=ds_spec)
    def _ase_loader(atoms):
        datum = {
            'elems': atoms.numbers,
            'coord': atoms.positions,
        }
        if 'cell' in ds_spec:
            datum['cell'] = atoms.cell[:]

        if 'e_data' in ds_spec:
            datum['e_data'] = atoms.get_potential_energy()

        if 'f_data' in ds_spec:
            datum['f_data'] = atoms.get_forces()

        if 'q_data' in ds_spec:
            datum['q_data'] = atoms.get_charges()

        if 'd_data' in ds_spec:
            datum['d_data'] = atoms.get_dipole_moment()
        return datum

    return _ase_loader(dataset, splits=splits, shuffle=shuffle, seed=seed)

pinn.io.load_runner

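Loads a RuNNer-formatted trajectory.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| flist | str or list | one RuNNer-formatted trajectory file, or a list of them | required |
| splits | dict | key-val pairs specifying the ratio of subsets | None |
| shuffle | bool | shuffle the dataset (only used when splitting) | True |
| seed | int | random seed for shuffling | 0 |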

Source code in pinn/io/runner.py
def load_runner(flist, splits=None, shuffle=True, seed=0):
    """
    Loads runner formatted trajectory
    Args:
        flist (str | list): one runner formatted trajectory file, or a list of them
        splits (dict): key-val pairs specifying the ratio of subsets
        shuffle (bool): shuffle the dataset (only used when splitting)
        seed (int): random seed for shuffling
    """
    if isinstance(flist, str):
        flist = [flist]
    frame_list = []
    for fname in flist:
        frame_list += _gen_frame_list(fname)
    return _frame_loader(frame_list, splits=splits, shuffle=shuffle, seed=seed)

pinn.io.load_qm9

Loads the QM9 dataset

QM9 provides a variety of labels, but typically we only train on one target, e.g. U0. The label_map option chooses the output dataset structure; by default, it only takes "U0" and maps it to "e_data", i.e. label_map={'e_data': 'U0'}.

Other available labels are

['tag', 'index', 'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo',
 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv']

Descriptions of these labels can be found in QM9's description file.
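The label_map lookup can be sketched in plain Python: the second line of a QM9 file holds the 17 fields in the order listed above, and each output key takes the value of the named column, as in the load_qm9 source below. The helper parse_qm9_labels and the example line are hypothetical; the numbers are made up.

```python
QM9_LABELS = ['tag', 'index', 'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo',
              'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv']
LABEL_INDEX = {name: i for i, name in enumerate(QM9_LABELS)}


def parse_qm9_labels(property_line, label_map=None):
    """Pick the labels named in label_map from the second line of a QM9 file."""
    if label_map is None:
        label_map = {'e_data': 'U0'}  # same default as load_qm9
    fields = property_line.split()
    # Output key -> float value of the requested QM9 column
    return {out: float(fields[LABEL_INDEX[name]]) for out, name in label_map.items()}


# A made-up properties line with 17 whitespace-separated fields in QM9 order
line = 'gdb 7 1.0 2.0 3.0 0.1 10.0 -0.3 0.1 0.4 30.0 0.05 -40.5 -40.4 -40.3 -40.6 6.0'
print(parse_qm9_labels(line))  # {'e_data': -40.5}
```

Training on several targets would then only require adding entries to label_map, e.g. {'e_data': 'U0', 'gap': 'gap'}.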

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| flist | list | list of QM9-formatted data files | required |
| splits | dict | key-val pairs specifying the ratio of subsets | None |
| shuffle | bool | shuffle the dataset (only used when splitting) | True |
| seed | int | random seed for shuffling | 0 |
| label_map | dict | dictionary mapping labels to output datasets | {'e_data': 'U0'} |
Source code in pinn/io/qm9.py
def load_qm9(flist, label_map={'e_data': 'U0'}, splits=None, shuffle=True, seed=0):
    """Loads the QM9 dataset

    QM9 provides a variety of labels, but typically we only train
    on one target, e.g. U0. A ``label_map`` option is offered to
    choose the output dataset structure; by default, it only takes
    "U0" and maps it to "e_data",
    i.e. `label_map={'e_data': 'U0'}`.

    Other available labels are

    ```python
    ['tag', 'index', 'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo',
     'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv']
    ```

    Descriptions of those tags can be found in QM9's description file.

    Args:
        flist (list): list of QM9-formatted data files.
        splits (dict): key-val pairs specifying the ratio of subsets
        shuffle (bool): shuffle the dataset (only used when splitting)
        seed (int): random seed for shuffling
        label_map (dict): dictionary mapping labels to output datasets
    """
    import numpy as np
    import tensorflow as tf
    from pinn.io.base import list_loader
    from ase.data import atomic_numbers

    _labels = ['tag', 'index', 'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap',
               'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv']
    _label_ind = {k: i for i, k in enumerate(_labels)}

    @list_loader(ds_spec=_qm9_spec(label_map))
    def _qm9_loader(fname):
        with open(fname) as f:
            lines = f.readlines()
        elems = [atomic_numbers[l.split()[0]] for l in lines[2:-3]]
        coord = [[i.replace('*^', 'E') for i in l.split()[1:4]]
                 for l in lines[2:-3]]
        elems = np.array(elems, np.int32)
        coord = np.array(coord, float)
        data = {'elems': elems, 'coord': coord}
        for k, v in label_map.items():
            data[k] = float(lines[1].split()[_label_ind[v]])
        return data
    return _qm9_loader(flist, splits=splits, shuffle=shuffle, seed=seed)

pinn.io.load_ani

Loads the ANI-1 dataset

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| filelist | list | filenames of ANI-1 h5 files | required |
| split | dict | key-val pairs specifying the ratio of subsets | False |
| shuffle | bool | shuffle the dataset (only used when splitting) | True |
| seed | int | random seed for shuffling | 0 |
| cycle_length | int | number of parallel threads to read h5 file | 4 |
Source code in pinn/io/ani.py
def load_ani(filelist, split=False, shuffle=True, seed=0, cycle_length=4):
    """Loads the ANI-1 dataset

    Args:
        filelist (list): filenames of ANI-1 h5 files.
        split (dict): key-val pairs specifying the ratio of subsets
        shuffle (bool): shuffle the dataset (only used when splitting)
        seed (int): random seed for shuffling
        cycle_length (int): number of parallel threads to read h5 file
    """
    import h5py
    import numpy as np
    import tensorflow as tf
    from pinn.io.base import split_list
    ds_spec = {
        'elems': {'dtype':  tf.int32,   'shape': [None]},
        'coord': {'dtype':  tf.keras.backend.floatx(), 'shape': [None, 3]},
        'e_data': {'dtype': tf.keras.backend.floatx(), 'shape': []}}
    # Load the list of samples
    sample_list = []
    for fname in filelist:
        store = h5py.File(fname, 'r')  # open read-only
        k1 = list(store.keys())[0]
        samples = store[k1]
        for k2 in samples.keys():
            sample_list.append((fname, '{}/{}'.format(k1, k2)))
    # Generate dataset from sample list

    def generator_fn(samplelist): return tf.data.Dataset.from_generator(
            lambda: _ani_generator(samplelist), output_signature=ds_spec).interleave(
            lambda x: tf.data.Dataset.from_tensor_slices(x),
            cycle_length=cycle_length)
    # Generate nested dataset
    subsets = split_list(sample_list, split=split, shuffle=shuffle, seed=seed)
    splitted = map_nested(generator_fn, subsets)
    return splitted

pinn.io.load_cp2k

This is an experimental loader for CP2K data.

It takes data from different sources, namely the CP2K output file and .dat files, which are specified in the files dictionary. A list of "keys" specifies which data to read and where each is read from.

| Key | Data source | Provides |
|-----|-------------|----------|
| force | files['out'] | f_data |
| energy | files['out'] | e_data |
| stress | files['out'] | coord, elems |
| cell_dat | files['cell_dat'] | cell |
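The way load_cp2k assembles its ds_spec, by looping over the requested keys and collecting the fields each one provides, can be sketched in plain Python. The PROVIDES mapping is transcribed from the table above; the helper fields_for is hypothetical and returns field names only, not the dtype and shape entries the source looks up.

```python
# Key -> dataset fields, transcribed from the table above
PROVIDES = {
    'force': ['f_data'],
    'energy': ['e_data'],
    'stress': ['coord', 'elems'],
    'cell_dat': ['cell'],
}


def fields_for(keys):
    """List the dataset fields that a selection of keys would produce, in order."""
    fields = []
    for key in keys:
        for name in PROVIDES[key]:
            if name not in fields:  # avoid duplicates if keys overlap
                fields.append(name)
    return fields


print(fields_for(['stress', 'energy', 'force']))  # ['coord', 'elems', 'e_data', 'f_data']
```

Combining keys simply unions their fields; adding cell_dat, for instance, contributes cell.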

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| files | dict | input files | required |
| keys | list | data to read | required |
| splits | dict | key-val pairs specifying the ratio of subsets | None |
| shuffle | bool | shuffle the dataset (only used when splitting) | True |
| seed | int | random seed for shuffling | 0 |
Source code in pinn/io/cp2k.py
def load_cp2k(files, keys, splits=None, shuffle=True, seed=0):
    """This is an experimental loader for CP2K data

    It takes data from different sources, the CP2K output file and dat files,
    which will be specified in the files dictionary. A list of "keys" is used to
    specify the data to read and where it is read from.

    | key        | data source         | provides         |
    |------------|---------------------|------------------|
    | `force`    | `files['out']`      | `f_data`         |
    | `energy`   | `files['out']`      | `e_data`         |
    | `stress`   | `files['out']`      | `coord`, `elems` |
    | `cell_dat` | `files['cell_dat']` | `cell`           |

    Args:
        files (dict): input files
        keys (list): data to read
        splits (dict): key-val pairs specifying the ratio of subsets
        shuffle (bool): shuffle the dataset (only used when splitting)
        seed (int): random seed for shuffling
    """
    from pinn.io import list_loader
    ds_spec = {}
    for key in keys:
        for name in provides[key]:
            ds_spec.update({name:formats[name]})

    all_list = _gen_list(files, keys)

    @list_loader(ds_spec=ds_spec)
    def _frame_loader(i):
        results = {}
        for k,v in all_list.items():
            results.update(loaders[k](v[i]))
        return results

    return _frame_loader(list(range(len(all_list['coord']))),
                         splits=splits, shuffle=shuffle, seed=seed)

pinn.io.load_deepmd

This is a loader for DeePMD-kit input data. It takes either a dict mapping keys to file paths, or the path of a directory that contains the data files. If type_map is provided, it will be used to convert the type ids to atomic numbers.

| Key | Data source | Provides |
|-----|-------------|----------|
| coord | path/coord.raw | coord |
| force | path/force.raw | f_data |
| energy | path/energy.raw | e_data |
| virial | path/virial.raw | s_data |
| box | path/box.raw | cell |
| type | path/type.raw | elems |
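DeePMD-kit's .raw files are plain text with one frame per row: coord.raw and force.raw hold 3N numbers per row, box.raw holds 9, and type.raw lists one type id per atom. Assuming that layout (as described in the DeePMD-kit data-format documentation), the pure-Python sketch below shows the two conversions the loader needs: reshaping a flattened row into per-atom coordinates, and mapping type ids to atomic numbers. The helper names and the tiny symbol table are hypothetical.

```python
def frame_to_atoms(row):
    """Reshape one flattened .raw row [x0, y0, z0, x1, ...] into per-atom triples."""
    if len(row) % 3:
        raise ValueError('row length must be a multiple of 3')
    return [row[i:i + 3] for i in range(0, len(row), 3)]


def remap_types(types, type_map):
    """Convert DeePMD type ids to atomic numbers, with type_map as {type_id: Z}."""
    # Building a new list avoids remapping ids that were already converted,
    # which an in-place, one-id-at-a-time substitution could do
    return [type_map[t] for t in types]


symbols = 'O H'.split()  # e.g. the contents of a type_map.raw file (assumed layout)
Z = {'H': 1, 'O': 8}     # tiny symbol table for this example only
type_map = {i: Z[s] for i, s in enumerate(symbols)}
print(remap_types([0, 1, 1], type_map))  # [8, 1, 1]
```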

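Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| fdict_or_fpath | dict, Path or str | input files, or the directory containing them | required |
| type_map | dict, Path, str or bool | mapping of type id to atomic number; True reads type_map.raw from the data directory | None |
| pbc | bool | flag of periodic boundary condition | True |
| shuffle | bool | shuffle the dataset (only used when splitting) | True |
| seed | int | random seed for shuffling | 0 |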
Source code in pinn/io/deepmd.py
def load_deepmd(fdict_or_fpath, type_map=None, pbc=True, shuffle=True, seed=0):

    """This is a loader for DeePMD-kit input data. It takes either a dict of keys and file paths, or the path of a directory that contains the data files. If `type_map` is provided, it will be used to convert the type ids to atomic numbers.

    | key        | data source         | provides         |
    |------------|---------------------|------------------|
    | `coord`    | `path/coord.raw`    | `coord`          |
    | `force`    | `path/force.raw`    | `f_data`         |
    | `energy`   | `path/energy.raw`   | `e_data`         |
    | `virial`   | `path/virial.raw`   | `s_data`         |
    | `box`      | `path/box.raw`      | `cell`           |
    | `type`     | `path/type.raw`     | `elems`          |

    Args:
        fdict_or_fpath (dict | Path | str): input files, or the directory containing them
        type_map (dict | Path | str | bool): mapping of type id to atomic number, a
            type_map.raw file, or True to read `path/type_map.raw`
        pbc (bool): flag of periodic boundary condition
        shuffle (bool): shuffle the dataset (only used when splitting)
        seed (int): random seed for shuffling
    """
    from pathlib import Path
    import numpy as np
    from ase.data import chemical_symbols
    from pinn.io import list_loader

    keys = ['coord', 'force', 'energy', 'virial', 'box', 'type']
    if isinstance(fdict_or_fpath, (Path, str)):
        fdict = {key: Path(fdict_or_fpath) / f'{key}.raw' for key in keys}
    else:
        assert all(key in fdict_or_fpath for key in keys)
        fdict = fdict_or_fpath

    # Each .raw row holds one frame; reshape the flattened rows per atom
    coord = np.loadtxt(fdict['coord'], ndmin=2)
    coord = coord.reshape(len(coord), -1, 3)
    force = np.loadtxt(fdict['force'], ndmin=2)
    force = force.reshape(len(force), -1, 3)
    energy = np.loadtxt(fdict['energy'], ndmin=1)
    # stress = np.loadtxt(fdict['virial'], ndmin=2).reshape(-1, 3, 3)
    cell = np.loadtxt(fdict['box'], ndmin=2).reshape(-1, 3, 3)
    elem = np.loadtxt(fdict['type'], dtype=int, ndmin=1).ravel()

    if type_map is True:
        type_map = Path(fdict_or_fpath) / 'type_map.raw'
    if isinstance(type_map, (Path, str)):
        with open(type_map, 'r') as f:
            # the i-th symbol in type_map.raw is the element of type id i
            type_map = {i: chemical_symbols.index(s) for i, s in enumerate(f.read().split())}
    if isinstance(type_map, dict):
        # map all atoms in one pass so that converted ids are never remapped
        elem = np.array([type_map[t] for t in elem], dtype=int)

    data = []
    # DeePMD .raw files use units of Å and eV. [https://docs.deepmodeling.com/projects/deepmd/en/latest/data/system.html]
    for i in range(len(coord)):
        data.append({
            'coord': coord[i],
            'f_data': force[i],
            'e_data': energy[i],
            # 's_data': stress[i],
            'cell': cell[i],
            'elems': elem
        })

    ds_spec = {
        'elems':  {'dtype': 'int32', 'shape': [None]},
        'cell':   {'dtype': 'float', 'shape': [3, 3]},
        'coord':  {'dtype': 'float', 'shape': [None, 3]},
        'e_data': {'dtype': 'float', 'shape': []},
        'f_data': {'dtype': 'float', 'shape': [None, 3]},
        # 's_data': {'dtype': 'float', 'shape': [3, 3]},
    }

    @list_loader(ds_spec=ds_spec, pbc=pbc)
    def _frame_loader(datum):
        return datum

    return _frame_loader(data, shuffle=shuffle, seed=seed)

  1. R. Ramakrishnan, P.O. Dral, M. Rupp, and O.A. von Lilienfeld, “Quantum chemistry structures and properties of 134 kilo molecules,” Sci. Data 1, 140022 (2014).

  2. J.S. Smith, O. Isayev, and A.E. Roitberg, “ANI-1, A data set of 20 million calculated off-equilibrium conformations for organic molecules,” Sci. Data 4, 170193 (2017).
