Commit 6f28058b authored by VIGNESHinZONE's avatar VIGNESHinZONE
Browse files

Changing input and output formats along with few typos

parent 276a4d72
Loading
Loading
Loading
Loading
+208 −268
Original line number Diff line number Diff line
import numpy as np
import logging
from deepchem.feat import MaterialStructureFeaturizer
from collections import defaultdict
from typing import List, Dict, Tuple, DefaultDict, Any
from deepchem.utils.typing import PymatgenStructure
from deepchem.feat.graph_data import GraphData


class LCNNFeaturizer(MaterialStructureFeaturizer):
  """
  Calculates the 2-D Surface graph features in 6 diffrent permutaions-
  Calculates the 2-D Surface graph features in 6 different permutations-

  Based on the implementation of Lattice Graph Convolution Neural
  Network (LCNN). This method produces the Atom wise features ( One Hot Encoding)
  and Adjacent neighbour in the specified order of permutations. Neighbors are
  determined using a distance metric and Each Permutation of the Neighbors are
  calculated such a manner in which, first element is the node itself followed
  by a randomly selected neighboring site. The next site shares a surface Pt atom
  with the previous site. And they are picked consecutively.

  First, the template of the Primitive cell needs to be defined and then each
  structure(Data Point i.e different configuration of adsorbate atoms) is passed
  for featurization.
  determined by first extracting a site local environment from the primitive cell,
  and perform graph matching and distance matching to find neighbors.
  First, the template of the Primitive cell needs to be defined along with periodic
  boundary conditions and active and specator site details. structure(Data Point
  i.e different configuration of adsorbate atoms) is passed for featurization.

  This particular featurisation produces a regular-graph (equal number of Neighbors)
  along with its permutation in 6 symmetric axis. This transformation can be
  applied when orderering of neighboring of nodes around a site play an important role
  in the propert predictions. Due to consideration of local neighbor environment,
  this current implementation would be fruitful in finding neighbors for calculating
  formation energy of adbsorption tasks where the local. Adsorption turns out to be important
  in many applications such as catalyst and semiconductor design.

  The permuted neighbors are calculated using the Primitive cells i.e periodic cells
  in all the data points are built via lattice transformation of the primitive cell.

  [1] The Primitive Template file must be passed as raw text string or path file.\n
  [2] The datapoint must be passed as raw text string.\n
  `Primitive cell Format:`

  1. Pymatgen structure object with site_properties key value
   - "SiteTypes" mentioning if it is a active site "A1" or spectator
     site "S1".
  2. ns , the number of spectator types elements. For "S1" its 1.
  3. na , the number of active types elements. For "A1" its 1.
  4. aos, the different species of active elements "A1".
  5. pbc, the periodic boundary conditions.

  `Data point Structure Format(Configuration of Atoms):`

  1. Pymatgen structure object with site_properties with following key value.
   - "SiteTypes", mentioning if it is a active site "A1" or spectator
     site "S1".
   - "oss", different occupational sites. For spectator sites make it -1.

  It is highly recommended that cells of data are directly redefined from
  the primitive cell, specifically, the relative coordinates between sites
  are consistent so that the lattice is non-deviated.

  References
  ----------
@@ -32,177 +58,114 @@ class LCNNFeaturizer(MaterialStructureFeaturizer):

  Examples
  --------
  >>> primitive_cell = '''#Primitive Cell
  >>> 2.81852800e+00  0.00000000e+00  0.00000000e+00 T
  >>> -1.40926400e+00  2.44091700e+00  0.00000000e+00 T
  >>> 0.00000000e+00  0.00000000e+00  2.55082550e+01 F
  >>> 1 1
  >>> 1 0 2
  >>> 6
  >>> 0.666670000000  0.333330000000  0.090220999986 S1
  >>> 0.333330000000  0.666670000000  0.180439359180 S1
  >>> 0.000000000000  0.000000000000  0.270657718374 S1
  >>> 0.666670000000  0.333330000000  0.360876077568 S1
  >>> 0.333330000000  0.666670000000  0.451094436762 S1
  >>> 0.000000000000  0.000000000000  0.496569911270 A1
  >>> '''
  >>> structure = '''2.81859800e+00  0.00000000e+00  0.00000000e+00
  >>> -1.40929900e+00  2.44097800e+00  0.00000000e+00
  >>> 0.00000000e+00  0.00000000e+00  2.55082550e+01
  >>> 6
  >>> 0.666670000000  0.333330000000  0.090220999986 S1
  >>> 0.333330000000  0.666670000000  0.180439359180 S1
  >>> 0.000000000000  0.000000000000  0.270657718374 S1
  >>> 0.666670000000  0.333330000000  0.360876077568 S1
  >>> 0.333330000000  0.666670000000  0.451094436762 S1
  >>> 0.000000000000  0.000000000000  0.496569911270 A1 0
  >>> '''
  >>> featuriser = LCNNFeaturizer(np.around(6.00), primitive_cell)
  >>> data = featuriser._featurize(structure)
  >>> print(data.keys())
  dict_keys(['X_Sites', 'X_NSs'])
  >>> PRIMITIVE_CELL = {
  >>>   "lattice": [[2.818528, 0.0, 0.0],
  >>>               [-1.409264, 2.440917, 0.0],
  >>>               [0.0, 0.0, 25.508255]],
  >>>   "coords": [[0.66667, 0.33333, 0.090221],
  >>>              [0.33333, 0.66667, 0.18043936],
  >>>              [0.0, 0.0, 0.27065772],
  >>>              [0.66667, 0.33333, 0.36087608],
  >>>              [0.33333, 0.66667, 0.45109444],
  >>>              [0.0, 0.0, 0.49656991]],
  >>>   "species": ['H', 'H', 'H', 'H', 'H', 'He'],
  >>>   "site_properties": {'SiteTypes': ['S1', 'S1', 'S1', 'S1', 'S1', 'A1']}
  >>> }
  >>> PRIMITIVE_CELL_INF0 = {
  >>>    "cutoff": np.around(6.00),
  >>>    "structure": Structure(**PRIMITIVE_CELL),
  >>>    "aos": ['1', '0', '2'],
  >>>    "pbc": [True, True, False],
  >>>    "ns": 1,
  >>>    "na": 1
  >>> }
  >>> DATA_POINT = {
  >>>   "lattice": [[1.409264, -2.440917, 0.0],
  >>>               [4.227792, 2.440917, 0.0],
  >>>               [0.0, 0.0, 23.17559]],
  >>>   "coords": [[0.0, 0.0, 0.099299],
  >>>              [0.0, 0.33333, 0.198598],
  >>>              [0.5, 0.16667, 0.297897],
  >>>              [0.0, 0.0, 0.397196],
  >>>              [0.0, 0.33333, 0.496495],
  >>>              [0.5, 0.5, 0.099299],
  >>>              [0.5, 0.83333, 0.198598],
  >>>              [0.0, 0.66667, 0.297897],
  >>>              [0.5, 0.5, 0.397196],
  >>>              [0.5, 0.83333, 0.496495],
  >>>              [0.0, 0.66667, 0.54654766],
  >>>              [0.5, 0.16667, 0.54654766]],
  >>>   "species": ['H', 'H', 'H', 'H', 'H', 'H',
  >>>               'H', 'H', 'H', 'H', 'He', 'He'],
  >>>   "site_properties": {
  >>>     "SiteTypes": ['S1', 'S1', 'S1', 'S1', 'S1',
  >>>                   'S1', 'S1', 'S1', 'S1', 'S1',
  >>>                   'A1', 'A1'],
  >>>     "oss": ['-1', '-1', '-1', '-1', '-1', '-1',
  >>>             '-1', '-1', '-1', '-1', '0', '2']
  >>>                   }
  >>> }
  >>> featuriser = LCNNFeaturizer(**PRIMITIVE_CELL_INF0)
  >>> print(type(featuriser._featurize(Structure(**DATA_POINT))))
  <class 'deepchem.feat.graph_data.GraphData'>

  Notes
  -----
  This Class requires pymatgen , networkx , scipy installed.

  `Primitive cell Format:`

  - [comment]
  - [ax][ay][az][pbc]
  - [bx][by][bz][pbc]
  - [cx][cy][cz][pbc]
  - [number of spectator site type][number of active site type]
  - [os1][os2][os3]
  - [number sites]
  - [site1a][site1b][site1c][site type]
  - [site2a][site2b][site2c][site type]

  `Data point Structure Format(Configuration of Atoms):`

  - [ax][ay][az]
  - [bx][by][bz]
  - [cx][cy][cz]
  - [number sites]
  - [site1a][site1b][site1c][site type][occupation state if active site]
  - [site2a][site2b][site2c][site type][occupation state if active site]

  [1] ax,ay, ... are cell basis vector\n
  [2] pbc is either T or F indication of the periodic boundary condition\n
  [3] os# is the name of the possible occupation state (interpretted as string)\n
  [4] site1a,site1b,site1c are the scaled coordinates of site 1\n
  [5] site type can be either S1, S2, ... or A1, A2,... indicating spectator\n

  """

  def __init__(self, cutoff: float, template: str):
  def __init__(self,
               structure: PymatgenStructure,
               aos: List[str],
               pbc: List[bool],
               ns: int = 1,
               na: int = 1,
               cutoff: float = 6.00):
    """
    Parameters
    ----------
    cutoff: float
    structure: : PymatgenStructure
      Pymatgen Structure object of the primitive cell used for calculating
      neighbors from lattice transformations.It also requires site_properties
      attribute with "Sitetypes"(Active or spectator site).
    aos: List[str]
      A list of all the active site species. For the Pt, N, NO configuration
      set it as ['0', '1', '2']
    pbc: List[bool]
      Periodic Boundary Condition
    ns: int (default 1)
      The number of spectator types elements. For "S1" its 1.
    na: int (default 1)
      the number of active types elements. For "A1" its 1.
    cutoff: float (default 6.00)
      Cutoff of radius for getting local environment.Only
      used down to 2 digits.

    template: str
      Template primitive stucture in string format
    """
    self.aos = aos
    self.cutoff = np.around(cutoff, 2)
    self.setup_env = load_primitive_cell(template, cutoff)
    self.setup_env = _load_primitive_cell(structure, aos, pbc, ns, na, cutoff)

  def _featurize(self, structure) -> Dict[str, np.ndarray]:
  def _featurize(self, structure: PymatgenStructure) -> GraphData:
    """
    Parameters
    ----------
    structure: str
      Structure information as raw text data input as a string
    structure: : PymatgenStructure
      Pymatgen Structure object of the surface configuration. It also requires
      site_properties attribute with "Sitetypes"(Active or spectator site) and
      "oss"(Species of Active site from the list of self.aos and "-1" for
      spectator sites).

    Returns
    -------
    dict: Dict[str, np.ndarray]
    graph: GraphData
      Node features, All edges for each node in diffrent permutations
    """
    xSites, xNSs = self.setup_env.read_datum(structure)
    return {"X_Sites": np.array(xSites), "X_NSs": np.array(xNSs)}


def input_reader(
    text: str, template: bool = False
) -> Tuple[np.ndarray, np.ndarray, List[str], List[str], np.ndarray, int, int]:
  """
  Read Input structures in a format which can produce the coordinate dimensions
  and axes. If it is a primitive cell, it returns the lattice cell, coordinates,
  site types and occupation state. Else if it is a data point, it returns
  the additional periodic boundary and type of occupation state.

  Parameters
  ----------
  text : str
    structure as a string
  template: bool(default False)
    Set to true for primitive cell, and false for data point

  Returns
  -------
  cell: np.ndarray
    cell basis vector
  coord: np.ndarray
    scaled coordinates
  st: List[str]
    Sitetype, 'S1' or 'A1'
  oss/aos: List[str]
    possible occupation state
  pbc: np.ndarray
    Periodic Boundary Condition(Valid only for data point)
  ns: int
    (Valid only for data point)
  na: int
    (Valid only for data point)
  """

  s = text.rstrip('\n').split('\n')
  nl, ns, na = 0, 0, 0
  # read comment
  if template:
    datum = False
    nl += 1
  else:
    datum = True

  # load cell and pbc
  cell = np.zeros((3, 3))
  pbc = np.array([True, True, True])
  for i in range(3):
    t = s[nl].split()
    cell[i, :] = [float(i) for i in t[0:3]]
    if not datum and t[3] == 'F':
      pbc[i] = False
    nl += 1
  # read sites if primitive
  if not datum:
    t = s[nl].split()
    ns = int(t[0])
    na = int(t[1])
    nl += 1
    aos = s[nl].split()
    nl += 1
  # read positions
  nS = int(s[nl])
  nl += 1
  coord = np.zeros((nS, 3))
  st = []
  oss = []
  for i in range(nS):
    t = s[nl].split()
    coord[i, :] = [float(i) for i in t[0:3]]
    st.append(t[3])
    if datum and len(t) == 5:
      oss.append(t[4])
    nl += 1

  if datum:
    return cell, coord, st, oss, pbc, ns, na
  else:
    return cell, coord, st, aos, pbc, ns, na
    config_size = xNSs.shape
    v = np.arange(0, len(xSites)).repeat(config_size[2] * config_size[3])
    u = xNSs.flatten()
    graph = GraphData(node_features=xSites, edge_index=np.array([u, v]))
    return graph


class _SiteEnvironment(object):
@@ -212,13 +175,13 @@ class _SiteEnvironment(object):
               sitetypes: List[str],
               env2config: List[int],
               permutations: List[List[int]],
               cutoff: float,
               cutoff: float = 6.00,
               Grtol: float = 0.0,
               Gatol: float = 0.01,
               rtol: float = 0.01,
               atol: float = 0.0,
               tol: float = 0.01,
               grtol: float = 0.01):
               grtol: float = 1e-3):
    """
    Initialize site environment

@@ -235,11 +198,13 @@ class _SiteEnvironment(object):
      number. S indicates a spectator sites and A indicates a active
      sites.
    env2config: List[int]
      A particular permutation of the neighbors around an active
      site. These indexes will be used for lattice transformation.
    permutations : List[List[int]]
      p x n list of list of integer. p is the permutation
      index and n is the number of sites.
    cutoff : float
      cutoff used for pooling neighbors. for aesthetics only
      cutoff used for pooling neighbors.
    Grtol : float (default 0.0)
      relative tolerance in distance for forming an edge in graph
    Gatol : float (default 0.01)
@@ -277,7 +242,7 @@ class _SiteEnvironment(object):
    self.Grtol = Grtol
    self.Gatol = Gatol
    # tolerance for grouping nodes
    self.grtol = 1e-3
    self.grtol = grtol
    # determine minimum distance between sitetypes.
    # This is used to determine the existence of an edge
    dists = squareform(pdist(pos))
@@ -365,7 +330,7 @@ class _SiteEnvironment(object):

    env : Dict[str, Any]
      dictionary that contains information of local environment of a
      site in datum. See _get_SiteEnvironments defintion in the class
      site in datum. See _get_SiteEnvironments definition in the class
      _SiteEnvironments for what this variable should be.

    Returns
@@ -377,6 +342,10 @@ class _SiteEnvironment(object):
      import networkx.algorithms.isomorphism as iso
    except:
      raise ImportError("This class requires networkx to be installed.")
    try:
      from scipy.spatial.transform import Rotation
    except:
      raise ImportError("This class requires scipy to be installed.")
    # construct graph

    G = self._construct_graph(env['pos'], env['sitetypes'])
@@ -384,7 +353,8 @@ class _SiteEnvironment(object):
      s = 'Number of nodes is not equal.\n'
      raise ValueError(s)
    elif len(self.G.edges) != len(G.edges):
      print(len(self.G.edges), len(G.edges))
      logging.warning("Expected the number of edges to be equal",
                      len(self.G.edges), len(G.edges))
      s = 'Number of edges is not equal.\n'
      s += "- Is the data point's cell a redefined lattice of primitive cell?\n"
      s += '- If relaxed structure is used, you may want to check structure or increase Gatol\n'
@@ -405,7 +375,8 @@ class _SiteEnvironment(object):
      xyz = np.zeros((len(self.pos), 3))
      for i in am:
        xyz[i, :] = env['pos'][am[i], :]
      R = self._kabsch(self.pos, xyz)
      rotation, _ = Rotation.align_vectors(self.pos, xyz)
      R = rotation.as_matrix()
      # RMSD
      rmsd.append(
          np.sqrt(
@@ -419,30 +390,6 @@ class _SiteEnvironment(object):
      s += '-Consider increasing neighbor finding tolerance'
      raise ValueError(s)

  def _kabsch(self, P: np.ndarray, Q: np.ndarray) -> np.ndarray:
    """
    Returns rotation matrix to align coordinates using
    Kabsch algorithm.

    Parameters
    ----------
    P: np.ndarray
    Q: np.ndarray

    Returns
    -------
    R: np.ndarray
      Rotation matrix
    """
    C = np.dot(np.transpose(P), Q)
    V, S, W = np.linalg.svd(C)
    d = (np.linalg.det(V) * np.linalg.det(W)) < 0.0
    if d:
      S[-1] = -S[-1]
      V[:, -1] = -V[:, -1]
    R = np.dot(V, W)
    return R


class _SiteEnvironments(object):

@@ -451,7 +398,7 @@ class _SiteEnvironments(object):
               cutoff: float):
    """
    Initialize
    Use Load to intialize this class.
    Use Load to initialize this class.

    Parameters
    ----------
@@ -462,7 +409,7 @@ class _SiteEnvironments(object):
    na : int
      number of active sites types
    aos : List[str]
      Avilable occupational states for active sites
      Available occupational states for active sites
      string should be the name of the occupancy. (consistent with the input data)
    eigen_tol : float
      tolerance for eigenanalysis of point group analysis in pymatgen.
@@ -482,15 +429,18 @@ class _SiteEnvironments(object):
    self.pbc = pbc
    self.cutoff = cutoff

  def read_datum(self, text: str, cutoff_factor: float = 1.1
                ) -> Tuple[List[float], List[List[int]]]:
  def read_datum(self, struct,
                 cutoff_factor: float = 1.1) -> Tuple[np.ndarray, np.ndarray]:
    """
    Load structure data and return neighbor information

    Parameters
    ----------
    text : str
      raw string of the structure
    struct: : PymatgenStructure
      Pymatgen Structure object of the surface configuration. It also requires
      site_properties attribute with "Sitetypes"(Active or spectator site) and
      "oss"(Species of Active site from the list of self.aos and "-1" for
      spectator sites).
    cutoff_factor : float
      this is extra buffer factor multiplied to cutoff to
      ensure pooling all relevant sites.
@@ -500,9 +450,11 @@ class _SiteEnvironments(object):
    XSites : List[float]
      One hot encoding features of the site.
    XNSs : List[List[int]]
      Neighbors calculated in diffrent permutations
      Neighbors calculated in different permutations
    """
    cell, coord, st, oss, _, _, _ = input_reader(text)
    oss = [
        species for species in struct.site_properties["oss"] if species != '-1'
    ]
    # Construct one hot encoding
    XSites = np.zeros((len(oss), len(self.aos)))
    for i, o in enumerate(oss):
@@ -510,16 +462,14 @@ class _SiteEnvironments(object):
    # get mapping between all site index to active site index
    alltoactive = {}
    n = 0
    for i, s in enumerate(st):
    for i, s in enumerate(struct.site_properties["SiteTypes"]):
      if 'A' in s:
        alltoactive[i] = n
        n += 1
    # Get Neighbors
    # Read Data
    site_envs = _get_SiteEnvironments(
        coord,
        cell,
        st,
        struct,
        self.cutoff * cutoff_factor,
        self.pbc,
        get_permutations=False,
@@ -543,14 +493,14 @@ class _SiteEnvironments(object):
      # map it to active sites
      nni_perm = np.vectorize(alltoactive.__getitem__)(nni_perm)
      XNSs[i].append(nni_perm.tolist())
    return XSites.tolist(), XNSs
    return np.array(XSites), np.array(XNSs)

  @classmethod
  def _truncate(cls, env_ref: _SiteEnvironment,
                env: Dict[str, Any]) -> Dict[str, Any]:
    """
    When cutoff_factor is used, it will pool more site than cutoff factor specifies.
    This will rule out nonrelevant sites by distance.
    When cutoff_factor is used, it will pool more site than cutoff
    factor specifies. This will rule out non-relevant sites by distance.

    Parameters
    ----------
@@ -586,7 +536,12 @@ class _SiteEnvironments(object):
    return env


def load_primitive_cell(path: str, cutoff: float,
def _load_primitive_cell(struct: PymatgenStructure,
                         aos: List[str],
                         pbc: List[bool],
                         ns: int,
                         na: int,
                         cutoff: float,
                         eigen_tol: float = 1e-5) -> _SiteEnvironments:
  """
  This loads the primitive cell, along with all the permutations
@@ -595,11 +550,22 @@ def load_primitive_cell(path: str, cutoff: float,

  Parameters
  ----------
  path : str
    Primitive Cell as a raw string
  cutoff : float.
    cutoff distance in angstrom for collecting local
    environment.
  struct: PymatgenStructure
    Pymatgen Structure object of the primitive cell used for calculating
    neighbors from lattice transformations.It also requires site_properties
    attribute with "Sitetypes"(Active or spectator site).
  aos: List[str]
    A list of all the active site species. For the Pt, N, NO configuration
    set it as ['0', '1', '2'].
  pbc: List[bool]
    Periodic Boundary Condition
  ns: int (default 1)
    The number of spectator types elements. For "S1" its 1.
  na: int (default 1)
    The number of active types elements. For "A1" its 1.
  cutoff: float (default 6.00)
    Cutoff of radius for getting local environment.Only
    used down to 2 digits.
  eigen_tol : float (default)
    tolerance for eigenanalysis of point group analysis in
    pymatgen.
@@ -609,9 +575,8 @@ def load_primitive_cell(path: str, cutoff: float,
  SiteEnvironments: _SiteEnvironments
    Instance of the _SiteEnvironments object
  """
  cell, coord, st, aos, pbc, ns, na = input_reader(path, template=True)
  site_envs = _get_SiteEnvironments(
      coord, cell, st, cutoff, pbc, True, eigen_tol=eigen_tol)
      struct, cutoff, pbc, True, eigen_tol=eigen_tol)
  site_envs_format = [
      _SiteEnvironment(e['pos'], e['sitetypes'], e['env2config'],
                       e['permutations'], cutoff) for e in site_envs
@@ -624,47 +589,43 @@ def load_primitive_cell(path: str, cutoff: float,
                           cutoff)


def _get_SiteEnvironments(coord: np.ndarray,
                          cell: np.ndarray,
                          SiteTypes: List[str],
def _get_SiteEnvironments(struct: PymatgenStructure,
                          cutoff: float,
                          pbc: np.ndarray,
                          PBC: List[bool],
                          get_permutations: bool = True,
                          eigen_tol: float = 1e-5) -> List[Dict[str, Any]]:
  """
  Used to extract information about both primitve cells and data points.
  Extract local environments from primitive cell. Using the two diffrent types
  Used to extract information about both primitive cells and data points.
  Extract local environments from Structure object by calculating neighbors
  based on gaussian distance. For primitive cell, Different permutations of the
  neighbors are calculated and will be later will mapped for data point in the
  _SiteEnvironment.get_mapping() function.
  site types ,

  Parameters
  ----------
  coord : np.ndarray
    n x 3 list or numpy array of scaled positions. n is the number
    of atom.
  cell : np.ndarray
    3 x 3 list or numpy array
  SiteTypes : List[str]
    n list of string. String must be S or A followed by a
    number. S indicates a spectator sites and A indicates a active
    sites.
  struct: PymatgenStructure
    Pymatgen Structure object of the primitive cell used for calculating
    neighbors from lattice transformations.It also requires site_properties
    attribute with "Sitetypes"(Active or spectator site).
  cutoff : float
    cutoff distance in angstrom for collecting local
    environment.
  pbc : np.ndarray
    Periodic boundary condition
  get_permutations : bool (default True)
    Whether to find permutatated neighbor list or not.
    Whether to find permuted neighbor list or not.
  eigen_tol : float (default 1e-5)
    Tolerance for eigenanalysis of point group analysis in
    pymatgen.

  Returns
  ------
  site_envs : List[_SiteEnvironment]
  site_envs : List[Dict[str, Any]]
    list of local_env class
  """
  try:
    from pymatgen import Element, Structure, Molecule, Lattice
    from pymatgen import Molecule
    from pymatgen.symmetry.analyzer import PointGroupAnalyzer
  except:
    raise ImportError("This class requires pymatgen to be installed.")
@@ -673,41 +634,20 @@ def _get_SiteEnvironments(coord: np.ndarray,
    from scipy.spatial.distance import cdist
  except:
    raise ImportError("This class requires scipy to be installed.")

  assert isinstance(coord, (list, np.ndarray))
  assert isinstance(cell, (list, np.ndarray))
  assert len(coord) == len(SiteTypes)

  coord = np.mod(coord, 1)
  pbc = np.array(pbc)

  # Available pymatgne functions are very limited when DummySpecie is
  # involved. This may be perhaps fixed in the future. Until then, we
  # simply bypass this by mapping site to an element
  # Find available atomic number to map site to it
  availableAN = [i + 1 for i in reversed(range(0, 118))]

  # Organize Symbols and record mapping
  symbols = []
  site_idxs = []
  SiteSymMap = {}  # mapping
  SymSiteMap = {}
  for i, SiteType in enumerate(SiteTypes):
    if SiteType not in SiteSymMap:
      symbol = Element.from_Z(availableAN.pop())
      SiteSymMap[SiteType] = symbol
      SymSiteMap[symbol] = SiteType

    else:
      symbol = SiteSymMap[SiteType]
    symbols.append(symbol)
    if 'A' in SiteType:
      site_idxs.append(i)

  # Find neighbors and permutations using pymatgen
  lattice = Lattice(cell)
  structure = Structure(lattice, symbols, coord)
  pbc = np.array(PBC)
  structure = struct
  neighbors = structure.get_all_neighbors(cutoff, include_index=True)
  symbols = structure.species
  site_idxs = [
      i for i, sitetype in enumerate(structure.site_properties['SiteTypes'])
      if sitetype == 'A1'
  ]
  site_sym_map = {}
  sym_site_map = {}
  for i, new_ele in enumerate(structure.species):
    sym_site_map[new_ele] = structure.site_properties['SiteTypes'][i]
    site_sym_map[structure.site_properties['SiteTypes'][i]] = new_ele

  site_envs = []
  for site_idx in site_idxs:
    local_env_sym = [symbols[site_idx]]
@@ -736,7 +676,7 @@ def _get_SiteEnvironments(coord: np.ndarray,

    site_env = {
        'pos': local_env_xyz,
        'sitetypes': [SymSiteMap[s] for s in local_env_sym],
        'sitetypes': [sym_site_map[s] for s in local_env_sym],
        'env2config': local_env_sitemap,
        'permutations': perm,
        'dist': local_env_dist
+486 −0

File added.

Preview size limit exceeded, changes collapsed.

+17 −9

File changed.

Preview size limit exceeded, changes collapsed.

+30 −29

File changed.

Preview size limit exceeded, changes collapsed.