Writing a PiNN model

Below is a simplified version of the potential model, intended as a template for implementing new models. See also the API documentation of the helper functions.

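import tensorflow as tf
import pinn
from pinn.models.base import export_model, get_train_op, MetricsCollector

# NOTE: default_params and connect_dist_grad are used below but not shown in
# this template; they are assumed to be defined or imported elsewhere.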

@export_model
def simple_potential_model(features, labels, mode, params):
    """Model function for neural network potentials"""
    network = pinn.get_network(params['network'])
    model_params = default_params.copy()
    model_params.update(params['model_params'])

    features = network.preprocess(features)
    connect_dist_grad(features)
    pred = network(features)

    if mode == tf.estimator.ModeKeys.TRAIN:
        metrics = make_metrics(features, pred, model_params, mode)
        train_op = get_train_op(params['optimizer'],
                                metrics.LOSS, metrics.ERROR, network)
        return tf.estimator.EstimatorSpec(mode, loss=metrics.LOSS, train_op=train_op)
    if mode == tf.estimator.ModeKeys.EVAL:
        metrics = make_metrics(features, pred, model_params, mode)
        return tf.estimator.EstimatorSpec(mode, loss=metrics.LOSS,
                                          eval_metric_ops=metrics.METRICS)
    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {'energy': pred}
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)

def make_metrics(features, pred, params, mode):
    metrics = MetricsCollector(mode)
    e_pred = pred
    e_data = features['e_data']
    e_mask = tf.abs(e_data) > params['max_energy'] if params['max_energy'] else None
    e_weight = params['e_loss_multiplier']
    e_weight *= features['e_weight'] if params['use_e_weight'] else 1
    metrics.add_error('E', e_data, e_pred, mask=e_mask, weight=e_weight,
                      use_error=(not params['use_e_per_atom']))
    return metrics

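In the above code, the model is defined by a model function, which builds the graph for each Estimator mode (TRAIN, EVAL, PREDICT); the optimizer is only configured through params['optimizer'] and passed to get_train_op. A more detailed introduction to Estimators and model functions can be found in the TensorFlow 1 documentation.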
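The model function in the template starts by merging the user-supplied params['model_params'] over default_params. A minimal sketch of that merge in plain Python follows. The default values and the helper name resolve_model_params are hypothetical and chosen only for illustration; only the key names are taken from the template. Unlike the template, this sketch tolerates a missing 'model_params' entry.

```python
# Hypothetical defaults: the keys mirror those read by make_metrics in the
# template, but the values are placeholders, not PiNN's actual defaults.
default_params = {
    'max_energy': False,
    'e_loss_multiplier': 1.0,
    'use_e_weight': False,
    'use_e_per_atom': False,
}


def resolve_model_params(params):
    """Return the defaults overridden by params['model_params'] (hypothetical helper)."""
    model_params = default_params.copy()  # shallow copy, so the defaults stay untouched
    model_params.update(params.get('model_params', {}))
    return model_params


# Override a single setting; every other key falls back to its default.
params = {'model_params': {'e_loss_multiplier': 10.0}}
model_params = resolve_model_params(params)
print(model_params)
```

Copying before updating keeps default_params unchanged between calls, so several models built in the same script do not leak settings into each other.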

The MetricsCollector object is a helper in PiNN for handling different forms of errors. It applies customized weights to errors, filters them, and keeps appropriate logs during the training and evaluation phases. See the API documentation for more details.
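To make the idea of filtering and weighting errors concrete, here is a small sketch in plain Python. It is not PiNN's implementation: the function masked_weighted_mse, the keep list, and the convention that True means "include this sample" are all assumptions made for this illustration, and the numbers are invented.

```python
def masked_weighted_mse(data, pred, keep=None, weight=1.0):
    """Weighted mean squared error over the samples flagged in `keep`.

    Hypothetical helper for illustration only. `keep` is a list of booleans
    (True = include the sample); None means all samples are included.
    """
    errors = [d - p for d, p in zip(data, pred)]
    if keep is not None:
        errors = [e for e, k in zip(errors, keep) if k]
    if not errors:
        return 0.0  # nothing left after filtering
    return weight * sum(e * e for e in errors) / len(errors)


# Invented reference and predicted energies.
e_data = [-1.0, -2.0, -50.0]
e_pred = [-1.5, -2.0, -10.0]

# Assumed convention: drop samples whose reference energy exceeds a threshold.
max_energy = 10.0
keep = [abs(e) <= max_energy for e in e_data]

loss = masked_weighted_mse(e_data, e_pred, keep=keep, weight=2.0)
print(loss)
```

In this sketch the third sample is excluded by the threshold, so the large error on it does not dominate the loss.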

Note

Models created in this way are not accessible from the PiNN CLI or from pinn.get_model; they must be created by calling model_fn(params) in a Python script. A mechanism to include custom models might be available in a future version.