# Source code for ccs_fit.scripts.ccs_fetch

import math
import sys
import json
import itertools as it
from collections import OrderedDict, defaultdict
from matplotlib.pyplot import pink
import numpy as np
from ase import Atoms
from ase import io
import ase.db as db
from ase.cell import Cell
from sympy import true
from tqdm import tqdm
import itertools
import random
import os


def pair_dist(atoms, R_c, ch1, ch2, counter):
    """Collect pair distances and separation vectors between two species.

    For every atom of species ``ch1``, the separations to all atoms of
    species ``ch2`` are computed, including periodic images when the cell
    matrix is invertible. Only separations strictly shorter than ``R_c``
    are kept.

    Parameters
    ----------
    atoms : ase.Atoms
        Structure to analyse. Only ``get_cell()``, ``get_chemical_symbols()``,
        ``len()`` and boolean-mask indexing (returning an object with a
        ``positions`` array) are used.
    R_c : float
        Cut-off distance.
    ch1 : str
        Chemical symbol of the central atoms.
    ch2 : str
        Chemical symbol of the neighbouring atoms.
    counter : int
        Structure index; used only to build the keys of the returned
        vector dictionary.

    Returns
    -------
    r_distance : list of float
        Distances below the cut-off. When ``ch1 == ch2`` the list is sorted
        and every second entry is dropped, since each pair is visited once
        from each of its two atoms.
    forces : OrderedDict
        Maps ``"F<counter>_<atom index>"`` to the list of ``[dx, dy, dz]``
        vectors pointing from that ``ch1`` atom to each neighbour within
        the cut-off.

    Notes
    -----
    When ``ch1 == ch2`` an atom's zero separation from itself passes the
    ``< R_c`` test and is therefore present in both return values.
    """
    cell = np.asarray(atoms.get_cell())
    try:
        # Number of image cells needed along each lattice direction so that
        # everything within R_c of the home cell is covered.
        n_repeat = R_c * np.linalg.norm(np.linalg.inv(cell), axis=0)
    except np.linalg.LinAlgError:
        # Singular cell (e.g. all zeros for a non-periodic structure):
        # consider the home cell only.
        cell = np.zeros((3, 3))
        offsets = np.zeros((1, 3), dtype=int)
    else:
        n_repeat = np.ceil(n_repeat).astype(int)
        offsets = np.array(
            list(itertools.product(*[np.arange(-n, n + 1) for n in n_repeat]))
        )

    symbols = atoms.get_chemical_symbols()
    mask1 = [symbol == ch1 for symbol in symbols]
    mask2 = [symbol == ch2 for symbol in symbols]

    pos1 = atoms[mask1].positions
    index1 = np.arange(0, len(atoms))[mask1]
    positions_2 = atoms[mask2].positions

    # Cartesian shift of every image cell, then all shifted copies of the
    # ch2 atoms flattened to an (n_images * n_atoms_2, 3) array. The order
    # is: all atoms of the first image, then all atoms of the second, ...
    shifts = offsets @ cell
    pos2 = positions_2[np.newaxis, :, :] + shifts[:, np.newaxis, :]
    pos2 = np.reshape(pos2, (-1, 3))

    r_distance = []
    forces = OrderedDict()
    for p1, atom_index in zip(pos1, index1):
        vectors = pos2 - p1
        lengths = np.linalg.norm(vectors, axis=1)
        within = lengths < R_c
        r_distance.extend(lengths[within].tolist())
        forces[f"F{counter}_{atom_index}"] = vectors[within].tolist()

    if ch1 == ch2:
        # Same-species pairs were counted from both ends; keep one of each.
        r_distance.sort()
        r_distance = r_distance[::2]

    return r_distance, forces
def ccs_fetch(
    mode=None,
    DFT_DB=None,
    R_c=6.0,
    Ns='all',
    DFTB_DB=None,
    charge_dict=None,
    include_forces=False,
    verbose=False,
):
    """Read structures from ASE databases and write ``structures.json``.

    Parameters
    ----------
    mode : str
        Fitting target. One of ``"CCS"``, ``"CCS+Q"``, ``"DFTB"`` or
        ``"PRUNED_CCS"``. ``"CCS"`` and ``"CCS+Q"`` read structures from
        ``DFT_DB``; ``"DFTB"`` and ``"PRUNED_CCS"`` read them from
        ``DFTB_DB``.
    DFT_DB : str
        Name of the DFT reference database, passed to ``ase.db.connect``.
    R_c : float, optional
        Distance cut-off handed to ``pair_dist``. Defaults to 6.0.
    Ns : int or ``'all'``, optional
        Number of structures to include. ``'all'`` or a non-positive value
        selects every row; a positive value selects that many rows at
        random (via ``random.shuffle`` on a selection mask).
    DFTB_DB : str, optional
        Name of the DFTB database (used in ``"DFTB"``/``"PRUNED_CCS"``).
    charge_dict : dict, optional
        Maps chemical symbol to charge; required in ``"CCS+Q"`` mode.
    include_forces : bool, optional
        If true, per-atom reference forces and pair separation vectors are
        written under the ``"forces"`` key of the output.
    verbose : bool, optional
        Accepted for interface compatibility; not used by this function.

    Returns
    -------
    None
        The result is written to ``structures.json`` in the current
        working directory.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported values, or if ``mode`` is
        ``"CCS+Q"`` and ``charge_dict`` is not given.
    """
    DFT_DB = db.connect(DFT_DB)

    if mode in ("CCS", "CCS+Q"):
        REF_DB = DFT_DB
    elif mode in ("DFTB", "PRUNED_CCS"):
        REF_DB = db.connect(DFTB_DB)
    else:
        raise ValueError(
            "Unknown mode {!r}; expected one of "
            "'CCS', 'CCS+Q', 'DFTB', 'PRUNED_CCS'.".format(mode)
        )

    if mode == "CCS+Q":
        if charge_dict is None:
            raise ValueError("Mode 'CCS+Q' requires charge_dict to be given.")
        # Imported lazily so pymatgen is only needed for electrostatics.
        from pymatgen.core import Lattice, Structure
        from pymatgen.analysis import ewald

    n_rows = len(REF_DB)

    if Ns == 'all':
        Ns = -1  # CONVERT TO INTEGER INPUT FORMAT

    if Ns > 0:
        # Ns True entries (at most n_rows), shuffled to pick rows at random.
        mask = [i < Ns for i in range(n_rows)]
        random.shuffle(mask)
    else:
        mask = [True] * n_rows

    energies = OrderedDict()
    force_data = OrderedDict()

    rows = tqdm(
        REF_DB.select(),
        total=n_rows,
        desc=" Fetching data",
        colour="#008080",
    )
    for counter, row in enumerate(rows):
        if not mask[counter]:
            continue

        struct = row.toatoms()
        ce = OrderedDict()

        FREF = row.forces if include_forces else None
        EREF = row.energy

        FDFT = None
        if mode == "DFTB":
            # The matching DFT entry is looked up by the same row id.
            dft_row = DFT_DB.get("id=" + str(row.id))
            if include_forces:
                FDFT = dft_row.forces
            ce["energy_dft"] = dft_row.energy
            ce["energy_dftb"] = EREF
        else:
            ce["energy_dft"] = EREF

        symbols = struct.get_chemical_symbols()
        species_count = defaultdict(int)
        for elem in symbols:
            species_count[elem] += 1
        dict_species = dict(sorted(species_count.items()))
        atom_pair = it.combinations_with_replacement(dict_species.keys(), 2)

        ES_forces = None
        if mode == "CCS+Q":
            charges = [charge_dict[elem] for elem in symbols]
            lattice = Lattice(struct.get_cell())
            coords = struct.get_scaled_positions()
            ew_struct = Structure(
                lattice,
                symbols,
                coords,
                site_properties={"charge": charges},
            )
            Ew = ewald.EwaldSummation(ew_struct, compute_forces=True)
            ES_forces = Ew.forces
            ce["ewald"] = Ew.total_energy

        if include_forces:
            # NOTE: "PRUNED_CCS" has no branch here, so no per-atom force
            # entries are created in that mode.
            for i in range(len(struct)):
                atom_key = "F" + str(counter) + "_" + str(i)
                if mode == "CCS":
                    force_data[atom_key] = {"force_dft": list(FREF[i, :])}
                elif mode == "DFTB":
                    force_data[atom_key] = {
                        "force_dft": list(FDFT[i, :]),
                        "force_dftb": list(FREF[i, :]),
                    }
                elif mode == "CCS+Q":
                    force_data[atom_key] = {
                        "force_dft": list(FREF[i, :]),
                        "force_ewald": list(ES_forces[i, :]),
                    }

        ce["atoms"] = dict_species

        for x, y in atom_pair:
            pair_key = str(x) + "-" + str(y)
            pair_distances, vectors = pair_dist(struct, R_c, x, y, counter)
            ce[pair_key] = pair_distances

            # Attach separation vectors only to atoms that already have a
            # force entry (i.e. forces were requested for this mode).
            for atom_key, atom_vectors in vectors.items():
                if atom_key in force_data:
                    force_data[atom_key][pair_key] = atom_vectors

            # FORCES SHOULD BE DOUBLE COUNTED!
            # For mixed pairs the atoms of species y need their vectors
            # too; they are stored under the same "x-y" key. Without
            # requested forces the result could never be stored, so the
            # call is skipped.
            if include_forces and x != y:
                _, vectors = pair_dist(struct, R_c, y, x, counter)
                for atom_key, atom_vectors in vectors.items():
                    if atom_key in force_data:
                        force_data[atom_key][pair_key] = atom_vectors

        # Energies are keyed 1-based ("S1", "S2", ...); force keys use the
        # 0-based counter.
        energies["S" + str(counter + 1)] = ce

    st = OrderedDict()
    st["energies"] = energies
    if include_forces:
        st['forces'] = force_data

    with open("structures.json", "w") as f:
        json.dump(st, f, indent=8)
def main():
    """Command-line entry point for the CCS fetching tool.

    Parses the command-line options, forwards them unchanged to
    ``ccs_fetch`` as keyword arguments, and afterwards prints a decorative
    banner framed by terminal-wide rules. The banner is best-effort: it is
    skipped silently when the terminal size cannot be determined or the
    optional ``art`` package is unavailable.
    """
    import argparse

    parser = argparse.ArgumentParser(description='CCS fetching tool')
    parser.add_argument(
        "-m", "--mode", type=str, metavar="", default='CCS',
        help="Mode. Available options: CCS, CCS+Q, DFTB")
    parser.add_argument(
        "-d", "--DFT_DB", type=str, metavar="", default='DFT.db',
        help="Name of DFT reference database")
    parser.add_argument(
        "-dd", "--DFTB_DB", type=str, metavar="", default=None,
        help="Name of DFTB reference database")
    parser.add_argument(
        "-r", "--R_c", type=float, metavar="", default=6.0,
        help="Cut-off radius")
    parser.add_argument(
        "-n", "--Ns", type=int, metavar="", default=-1,
        help="Number of structures to include")
    parser.add_argument(
        "-v", "--verbose", action="store_true",
        help="Verbose output")
    parser.add_argument(
        "-chg", "--charge_dict", type=json.loads, metavar="",
        help="Specify atomic charges in json format, e.g.: \n \'{ \"Zn\" : 2.0 , \"O\" : -2.0 }\' ")
    parser.add_argument(
        "-f", "--include_forces", action="store_true",
        help='Include forces.')

    args = parser.parse_args()

    # Option names match the keyword parameters of ccs_fetch one-to-one.
    ccs_fetch(**vars(args))

    # The banner is purely cosmetic, so ordinary failures are ignored.
    # ``Exception`` (rather than a bare ``except``) lets KeyboardInterrupt
    # and SystemExit propagate.
    try:
        rule = "-" * os.get_terminal_size().columns
        print("")
        print(rule)
        import art
        print(art.text2art('CCS:Fetch'))
    except Exception:
        pass

    # Closing rule; kept in its own block so it is still printed when only
    # the banner rendering above failed.
    try:
        rule = "-" * os.get_terminal_size().columns
        print(rule)
        print("")
    except Exception:
        pass
# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__": main()