Commit 3685970b authored by VIGNESHinZONE's avatar VIGNESHinZONE
Browse files

Fixing mypy errors

parent 12bd97b7
Loading
Loading
Loading
Loading
+176 −175
Original line number Diff line number Diff line
import numpy as np
from deepchem.feat import MaterialStructureFeaturizer
from collections import defaultdict
from typing import List, Dict, Tuple, Iterable, Union, Any
from typing import List, Dict, Tuple, Iterable, Union, DefaultDict


class LCNNFeaturizer(MaterialStructureFeaturizer):
@@ -107,7 +107,7 @@ class LCNNFeaturizer(MaterialStructureFeaturizer):
      Template primitive stucture in string format
    """
    self.cutoff = np.around(cutoff, 2)
    self.setup_env = _SiteEnvironments.load_primitive_cell(template, cutoff)
    self.setup_env = load_primitive_cell(template, cutoff)

  def _featurize(self, structure) -> Dict[str, np.ndarray]:
    """
@@ -125,7 +125,7 @@ class LCNNFeaturizer(MaterialStructureFeaturizer):
    return {"X_Sites": np.array(xSites), "X_NSs": np.array(xNSs)}


def input_reader(text: str, template: bool = False) -> Iterable[list]:
def input_reader(text: str, template: bool = False) -> Iterable[Union[List[str], np.ndarray, List[int], int]]:
  """
  Read Input structures in a format which can produce the coordinate dimensions
  and axes. If it is a primitive cell, it returns the lattice cell, coordinates,
@@ -193,10 +193,10 @@ def input_reader(text: str, template: bool = False) -> Iterable[list]:
class _SiteEnvironment(object):

  def __init__(self,
               pos: Union[list, np.ndarray],
               sitetypes: list,
               env2config: Union[list, np.ndarray],
               permutations: list,
               pos: List[np.ndarray],
               sitetypes: List[str],
               env2config: Union[List[int], np.ndarray],
               permutations: List[List[int]],
               cutoff: float,
               Grtol: float = 0.0,
               Gatol: float = 0.01,
@@ -249,7 +249,7 @@ class _SiteEnvironment(object):
    self.pos = pos
    self.sitetypes = sitetypes
    self.activesiteidx = [i for i, s in enumerate(self.sitetypes) if 'A' in s]
    self.formula = defaultdict(int)
    self.formula: DefaultDict[str, int] = defaultdict(int)
    for s in sitetypes:
      self.formula[s] += 1
    self.permutations = permutations
@@ -285,7 +285,7 @@ class _SiteEnvironment(object):
    self._nm = iso.categorical_node_match('n', '')
    self._em = iso.numerical_edge_match('d', 0, rtol, 0)

  def _construct_graph(self, pos: list, sitetypes: str):
  def _construct_graph(self, pos: List[np.ndarray], sitetypes: List[str]):
    """
    Returns local environment graph using networkx and
    tolerance specified.
@@ -333,7 +333,9 @@ class _SiteEnvironment(object):
        n += 1
    return G

  def get_mapping(self, env: Dict[str, list]) -> Union[Dict[int, int], None]:
  def get_mapping(self,
                  env: Dict[str, Union[List[int], List[str], np.ndarray]]
                  ) -> Dict[int, int]:
    """
    Returns mapping of sites from input to this object

@@ -361,6 +363,7 @@ class _SiteEnvironment(object):
    except:
      raise ImportError("This class requires networkx to be installed.")
    # construct graph

    G = self._construct_graph(env['pos'], env['sitetypes'])
    if len(self.G.nodes) != len(G.nodes):
      s = 'Number of nodes is not equal.\n'
@@ -401,7 +404,7 @@ class _SiteEnvironment(object):
      s += '-Consider increasing neighbor finding tolerance'
      raise ValueError(s)

  def _kabsch(self, P: np.ndarray, Q: np.ndarray) -> np.ndarray:
  def _kabsch(self, P: List[np.ndarray], Q: List[np.ndarray]) -> List[np.ndarray]:
    """
    Returns rotation matrix to align coordinates using
    Kabsch algorithm.
@@ -429,7 +432,7 @@ class _SiteEnvironment(object):
class _SiteEnvironments(object):

  def __init__(self, site_envs: List[_SiteEnvironment], ns: int, na: int,
               aos: List[str], eigen_tol: float, pbc: List[str], cutoff: float):
               aos: List[str], eigen_tol: float, pbc: List[bool], cutoff: float):
    """
    Initialize
    Use Load to intialize this class.
@@ -453,7 +456,7 @@ class _SiteEnvironments(object):
      Cutoff radius in angstrom for pooling sites to construct local environment
    """
    self.site_envs = site_envs
    self.unique_site_types = [env.sitetypes[0] for env in self.site_envs]
    self.unique_site_types: List[str] = [env.sitetypes[0] for env in self.site_envs]
    self.ns = ns
    self.na = na
    self.aos = aos
@@ -495,7 +498,7 @@ class _SiteEnvironments(object):
        n += 1
    # Get Neighbors
    # Read Data
    site_envs = self._get_SiteEnvironments(
    site_envs = _get_SiteEnvironments(
        coord,
        cell,
        st,
@@ -503,16 +506,16 @@ class _SiteEnvironments(object):
        self.pbc,
        get_permutations=False,
        eigen_tol=self.eigen_tol)
    XNSs = [[] for _ in range(len(self.site_envs))]
    XNSs: List[list] = [[] for _ in range(len(self.site_envs))]
    for env in site_envs:
      i = self.unique_site_types.index(env['sitetypes'][0])
      env = self._truncate(self.site_envs[i], env)
      new_env = self._truncate(self.site_envs[i], env)

      # get map between two environment
      mapping = self.site_envs[i].get_mapping(env)
      mapping = self.site_envs[i].get_mapping(new_env)
      # align input to the primitive cell (reference)
      aligned_idx = [
          env['env2config'][mapping[i]] for i in range(len(env['env2config']))
          new_env['env2config'][mapping[i]] for i in range(len(new_env['env2config']))
      ]
      # apply permutations
      nni_perm = np.take(aligned_idx, self.site_envs[i].permutations)
@@ -525,7 +528,7 @@ class _SiteEnvironments(object):

  @classmethod
  def _truncate(cls, env_ref: _SiteEnvironment,
                env: _SiteEnvironment) -> Dict[str, np.ndarray]:
                env: Dict[str, Union[List[Union[int, float, str]], np.ndarray]]) -> Dict[str, Union[List[int], List[str], np.ndarray]]:
    """
    When cutoff_factor is used, it will pool more site than cutoff factor specifies.
    This will rule out nonrelevant sites by distance.
@@ -537,7 +540,7 @@ class _SiteEnvironments(object):

    Returns
    -------
    env: Dict[str, np.ndarray]
    env: Dict[str, Union[list, np.ndarray]]
    """
    # Extract the right number of sites by distance
    dists = defaultdict(list)
@@ -557,15 +560,14 @@ class _SiteEnvironments(object):
        for i in range(len(env['sitetypes']))
        if i in siteidx
    ]
    env['env2config'] = [env['env2config'][i] for i in siteidx]
    env['env2config']: List[int] = [env['env2config'][i] for i in siteidx]
    del env['dist']
    return env

  @classmethod
  def load_primitive_cell(cls,
                          path: str,

def load_primitive_cell(path: str,
                        cutoff: float,
                          eigen_tol: float = 1e-5) -> Any:
                        eigen_tol: float = 1e-5) -> _SiteEnvironments:
  """
  This loads the primitive cell, along with all the permutations
  required for creating a neighbor. This produces the site environments of
@@ -588,7 +590,7 @@ class _SiteEnvironments(object):
    Instance of the _SiteEnvironments object
  """
  cell, pbc, coord, st, ns, na, aos = input_reader(path, template=True)
    site_envs = cls._get_SiteEnvironments(
  site_envs = _get_SiteEnvironments(
      coord, cell, st, cutoff, pbc, True, eigen_tol=eigen_tol)
  site_envs = [
      _SiteEnvironment(e['pos'], e['sitetypes'], e['env2config'],
@@ -598,17 +600,17 @@ class _SiteEnvironments(object):
  ust = [env.sitetypes[0] for env in site_envs]
  usi = np.unique(ust, return_index=True)[1]
  site_envs = [site_envs[i] for i in usi]
    return cls(site_envs, ns, na, aos, eigen_tol, pbc, cutoff)
  return _SiteEnvironments(site_envs, ns, na, aos, eigen_tol, pbc, cutoff)

  @classmethod
  def _get_SiteEnvironments(cls,
                            coord: Union[list, np.ndarray],
                            cell: np.ndarray,

def _get_SiteEnvironments(coord: Union[List[np.ndarray], np.ndarray],
                          cell: Union[List[np.ndarray], np.ndarray],
                          SiteTypes: List[str],
                          cutoff: float,
                            pbc: List[str],
                          pbc: List[bool],
                          get_permutations: bool = True,
                            eigen_tol: float = 1e-5) -> List[_SiteEnvironment]:
                          eigen_tol: float = 1e-5) -> List[Dict[str, Union[List[Union[int, str, float]], np.ndarray]]]:

  """
  Used to extract information about both primitve cells and data points.
  Extract local environments from primitive cell. Using the two diffrent types
@@ -644,7 +646,6 @@ class _SiteEnvironments(object):
  try:
    from pymatgen import Element, Structure, Molecule, Lattice
    from pymatgen.symmetry.analyzer import PointGroupAnalyzer

  except:
    raise ImportError("This class requires pymatgen to be installed.")