Commit 10082633 authored by VIGNESHinZONE's avatar VIGNESHinZONE
Browse files

Fixed all the mypy errors

parent 3685970b
Loading
Loading
Loading
Loading
+68 −48
Original line number Diff line number Diff line
import numpy as np
from deepchem.feat import MaterialStructureFeaturizer
from collections import defaultdict
from typing import List, Dict, Tuple, Iterable, Union, DefaultDict
from typing import List, Dict, Tuple, DefaultDict, Any


class LCNNFeaturizer(MaterialStructureFeaturizer):
@@ -125,7 +125,9 @@ class LCNNFeaturizer(MaterialStructureFeaturizer):
    return {"X_Sites": np.array(xSites), "X_NSs": np.array(xNSs)}


def input_reader(text: str, template: bool = False) -> Iterable[Union[List[str], np.ndarray, List[int], int]]:
def input_reader(
    text: str, template: bool = False
) -> Tuple[np.ndarray, np.ndarray, List[str], List[str], np.ndarray, int, int]:
  """
  Read Input structures in a format which can produce the coordinate dimensions
  and axes. If it is a primitive cell, it returns the lattice cell, coordinates,
@@ -141,11 +143,24 @@ def input_reader(text: str, template: bool = False) -> Iterable[Union[List[str],

  Returns
  -------
  list of local_env : list of local_env class
  cell: np.ndarray
    cell basis vector
  coord: np.ndarray
    scaled coordinates
  st: List[str]
    Sitetype, 'S1' or 'A1'
  oss/aos: List[str]
    possible occupation state
  pbc: np.ndarray
    Periodic Boundary Condition(Valid only for data point)
  ns: int
    (Valid only for data point)
  na: int
    (Valid only for data point)
  """

  s = text.rstrip('\n').split('\n')
  nl = 0
  nl, ns, na = 0, 0, 0
  # read comment
  if template:
    datum = False
@@ -185,17 +200,17 @@ def input_reader(text: str, template: bool = False) -> Iterable[Union[List[str],
    nl += 1

  if datum:
    return cell, coord, st, oss
    return cell, coord, st, oss, pbc, ns, na
  else:
    return cell, pbc, coord, st, ns, na, aos
    return cell, coord, st, aos, pbc, ns, na


class _SiteEnvironment(object):

  def __init__(self,
               pos: List[np.ndarray],
               pos: np.ndarray,
               sitetypes: List[str],
               env2config: Union[List[int], np.ndarray],
               env2config: List[int],
               permutations: List[List[int]],
               cutoff: float,
               Grtol: float = 0.0,
@@ -212,14 +227,15 @@ class _SiteEnvironment(object):

    Parameters
    ----------
    pos : Union[list, np.ndarray]
    pos : np.ndarray
      n x 3 list or numpy array of (non-scaled) positions. n is the
      number of atom.
    sitetypes : list
    sitetypes : List[str]
      n list of string. String must be S or A followed by a
      number. S indicates a spectator sites and A indicates a active
      sites.
    permutations : list
    env2config: List[int]
    permutations : List[List[int]]
      p x n list of list of integer. p is the permutation
      index and n is the number of sites.
    cutoff : float
@@ -285,16 +301,17 @@ class _SiteEnvironment(object):
    self._nm = iso.categorical_node_match('n', '')
    self._em = iso.numerical_edge_match('d', 0, rtol, 0)

  def _construct_graph(self, pos: List[np.ndarray], sitetypes: List[str]):
  def _construct_graph(self, pos: np.ndarray, sitetypes: List[str]):
    """
    Returns local environment graph using networkx and
    tolerance specified.

    Parameters
    ----------
    pos: list
    pos: np.ndarray
      ns x 3. coordinates of positions. ns is the number of sites.
      sitetypes: ns. sitetype for each site
    sitetypes: List[str]

    Returns
    ------
@@ -333,9 +350,7 @@ class _SiteEnvironment(object):
        n += 1
    return G

  def get_mapping(self,
                  env: Dict[str, Union[List[int], List[str], np.ndarray]]
                  ) -> Dict[int, int]:
  def get_mapping(self, env: Dict[str, Any]) -> Dict[int, int]:
    """
    Returns mapping of sites from input to this object

@@ -348,15 +363,15 @@ class _SiteEnvironment(object):
    Parameters
    ----------

    env : Dict[str, list]
    env : Dict[str, Any]
      dictionary that contains information of local environment of a
      site in datum. See _get_SiteEnvironments definition in the class
      _SiteEnvironments for what this variable should be.

    Returns
    -------
    dict : Union[Dict[int, int], None]
      Atom mapping. None if there is no mapping
    dict : Dict[int, int]
      Atom mapping from Primitive cell to data point.
    """
    try:
      import networkx.algorithms.isomorphism as iso
@@ -404,7 +419,7 @@ class _SiteEnvironment(object):
      s += '-Consider increasing neighbor finding tolerance'
      raise ValueError(s)

  def _kabsch(self, P: List[np.ndarray], Q: List[np.ndarray]) -> List[np.ndarray]:
  def _kabsch(self, P: np.ndarray, Q: np.ndarray) -> np.ndarray:
    """
    Returns rotation matrix to align coordinates using
    Kabsch algorithm.
@@ -432,7 +447,8 @@ class _SiteEnvironment(object):
class _SiteEnvironments(object):

  def __init__(self, site_envs: List[_SiteEnvironment], ns: int, na: int,
               aos: List[str], eigen_tol: float, pbc: List[bool], cutoff: float):
               aos: List[str], eigen_tol: float, pbc: np.ndarray,
               cutoff: float):
    """
    Initialize
    Use Load to initialize this class.
@@ -451,12 +467,14 @@ class _SiteEnvironments(object):
    eigen_tol : float
      tolerance for eigenanalysis of point group analysis in pymatgen.
    pbc : np.ndarray
      periodic boundary condition.
      Boolean array, periodic boundary condition.
    cutoff : float
      Cutoff radius in angstrom for pooling sites to construct local environment
    """
    self.site_envs = site_envs
    self.unique_site_types: List[str] = [env.sitetypes[0] for env in self.site_envs]
    self.unique_site_types: List[str] = [
        env.sitetypes[0] for env in self.site_envs
    ]
    self.ns = ns
    self.na = na
    self.aos = aos
@@ -464,8 +482,8 @@ class _SiteEnvironments(object):
    self.pbc = pbc
    self.cutoff = cutoff

  def read_datum(self, text: str,
                 cutoff_factor: float = 1.1) -> Tuple[List[float], List[list]]:
  def read_datum(self, text: str, cutoff_factor: float = 1.1
                ) -> Tuple[List[float], List[List[int]]]:
    """
    Load structure data and return neighbor information

@@ -481,10 +499,10 @@ class _SiteEnvironments(object):
    ------
    XSites : List[float]
      One hot encoding features of the site.
    XNSs : List[list]
    XNSs : List[List[int]]
      Neighbors calculated in different permutations
    """
    cell, coord, st, oss = input_reader(text)
    cell, coord, st, oss, _, _, _ = input_reader(text)
    # Construct one hot encoding
    XSites = np.zeros((len(oss), len(self.aos)))
    for i, o in enumerate(oss):
@@ -515,7 +533,8 @@ class _SiteEnvironments(object):
      mapping = self.site_envs[i].get_mapping(new_env)
      # align input to the primitive cell (reference)
      aligned_idx = [
          new_env['env2config'][mapping[i]] for i in range(len(new_env['env2config']))
          new_env['env2config'][mapping[i]]
          for i in range(len(new_env['env2config']))
      ]
      # apply permutations
      nni_perm = np.take(aligned_idx, self.site_envs[i].permutations)
@@ -528,7 +547,7 @@ class _SiteEnvironments(object):

  @classmethod
  def _truncate(cls, env_ref: _SiteEnvironment,
                env: Dict[str, Union[List[Union[int, float, str]], np.ndarray]]) -> Dict[str, Union[List[int], List[str], np.ndarray]]:
                env: Dict[str, Any]) -> Dict[str, Any]:
    """
    When cutoff_factor is used, it will pool more site than cutoff factor specifies.
    This will rule out nonrelevant sites by distance.
@@ -536,7 +555,9 @@ class _SiteEnvironments(object):
    Parameters
    ----------
    env_ref: _SiteEnvironment
    env: _SiteEnvironment
      Site information of the primitive cell
    env: Dict[str, Any]
      Site information of the data point

    Returns
    -------
@@ -560,13 +581,12 @@ class _SiteEnvironments(object):
        for i in range(len(env['sitetypes']))
        if i in siteidx
    ]
    env['env2config']: List[int] = [env['env2config'][i] for i in siteidx]
    env['env2config'] = [env['env2config'][i] for i in siteidx]
    del env['dist']
    return env


def load_primitive_cell(path: str,
                        cutoff: float,
def load_primitive_cell(path: str, cutoff: float,
                        eigen_tol: float = 1e-5) -> _SiteEnvironments:
  """
  This loads the primitive cell, along with all the permutations
@@ -580,7 +600,7 @@ def load_primitive_cell(path: str,
  cutoff : float.
    cutoff distance in angstrom for collecting local
    environment.
  eigen_tol : float
  eigen_tol : float (default 1e-5)
    tolerance for eigenanalysis of point group analysis in
    pymatgen.

@@ -589,28 +609,28 @@ def load_primitive_cell(path: str,
  SiteEnvironments: _SiteEnvironments
    Instance of the _SiteEnvironments object
  """
  cell, pbc, coord, st, ns, na, aos = input_reader(path, template=True)
  cell, coord, st, aos, pbc, ns, na = input_reader(path, template=True)
  site_envs = _get_SiteEnvironments(
      coord, cell, st, cutoff, pbc, True, eigen_tol=eigen_tol)
  site_envs = [
  site_envs_format = [
      _SiteEnvironment(e['pos'], e['sitetypes'], e['env2config'],
                       e['permutations'], cutoff) for e in site_envs
  ]

  ust = [env.sitetypes[0] for env in site_envs]
  ust = [env.sitetypes[0] for env in site_envs_format]
  usi = np.unique(ust, return_index=True)[1]
  site_envs = [site_envs[i] for i in usi]
  return _SiteEnvironments(site_envs, ns, na, aos, eigen_tol, pbc, cutoff)
  site_envs_format = [site_envs_format[i] for i in usi]
  return _SiteEnvironments(site_envs_format, ns, na, aos, eigen_tol, pbc,
                           cutoff)


def _get_SiteEnvironments(coord: Union[List[np.ndarray], np.ndarray],
                          cell: Union[List[np.ndarray], np.ndarray],
def _get_SiteEnvironments(coord: np.ndarray,
                          cell: np.ndarray,
                          SiteTypes: List[str],
                          cutoff: float,
                          pbc: List[bool],
                          pbc: np.ndarray,
                          get_permutations: bool = True,
                          eigen_tol: float = 1e-5) -> List[Dict[str, Union[List[Union[int, str, float]], np.ndarray]]]:

                          eigen_tol: float = 1e-5) -> List[Dict[str, Any]]:
  """
  Used to extract information about both primitive cells and data points.
  Extract local environments from primitive cell. Using the two different types
@@ -618,7 +638,7 @@ def _get_SiteEnvironments(coord: Union[List[np.ndarray], np.ndarray],

  Parameters
  ----------
  coord : Union[list, np.ndarray]
  coord : np.ndarray
    n x 3 list or numpy array of scaled positions. n is the number
    of atom.
  cell : np.ndarray
@@ -630,11 +650,11 @@ def _get_SiteEnvironments(coord: Union[List[np.ndarray], np.ndarray],
  cutoff : float
    cutoff distance in angstrom for collecting local
    environment.
  pbc : list[str]
  pbc : np.ndarray
    Periodic boundary condition
  get_permutations : bool (default True)
    Whether to find permuted neighbor list or not.
  eigen_tol : float
  eigen_tol : float (default 1e-5)
    Tolerance for eigenanalysis of point group analysis in
    pymatgen.

+0 −72
Original line number Diff line number Diff line
# Primitive-cell (template) description used as test input.
# Layout as written below: one '#' comment line, three rows of three floats
# each followed by a T/F flag, the lines "1 1" and "1 0 2", a site count (6),
# and then one row per site holding three coordinates and a site-type label
# (S1 or A1).
# NOTE(review): presumably the T/F flags are the per-axis periodic boundary
# conditions, "1 1" is ns/na and "1 0 2" lists the possible occupation states
# as read by input_reader(..., template=True) -- confirm against the parser.
primitive_cell = """#primitive strucutre
2.81852800e+00  0.00000000e+00  0.00000000e+00 T
-1.40926400e+00  2.44091700e+00  0.00000000e+00 T
0.00000000e+00  0.00000000e+00  2.55082550e+01 F
1 1
1 0 2
6
0.666670000000  0.333330000000  0.090220999986 S1
0.333330000000  0.666670000000  0.180439359180 S1
0.000000000000  0.000000000000  0.270657718374 S1
0.666670000000  0.333330000000  0.360876077568 S1
0.333330000000  0.666670000000  0.451094436762 S1
0.000000000000  0.000000000000  0.496569911270 A1
"""

# Data-point structure used as test input.
# Layout as written below: three rows of three floats, a site count (24), and
# then one row per site holding three coordinates and a site-type label. The
# twenty S1 rows end at the label; the four A1 rows carry one extra trailing
# integer (0, 0, 2, 1).
# NOTE(review): presumably the first three rows are the cell vectors and the
# trailing integer is the occupation state of the active site -- confirm
# against input_reader.
structure = """-4.22779200e+00 -2.44091700e+00  0.00000000e+00
 2.81852800e+00 -4.88183400e+00  0.00000000e+00
 0.00000000e+00  0.00000000e+00  2.31755900e+01
24
 0.333330000000  0.500000000000  0.099298999982 S1
 0.000000000000  0.500000000000  0.198598000008 S1
 0.166670000000  0.250000000000  0.297897000033 S1
 0.333330000000  0.500000000000  0.397196000059 S1
 0.000000000000  0.500000000000  0.496495000084 S1
 0.333330000000  0.000000000000  0.099298999982 S1
 0.000000000000  0.000000000000  0.198598000008 S1
 0.166670000000  0.750000000000  0.297897000033 S1
 0.333330000000  0.000000000000  0.397196000059 S1
 0.000000000000  0.000000000000  0.496495000084 S1
 0.833330000000  0.250000000000  0.099298999982 S1
 0.500000000000  0.250000000000  0.198598000008 S1
 0.666670000000  0.000000000000  0.297897000033 S1
 0.833330000000  0.250000000000  0.397196000059 S1
 0.500000000000  0.250000000000  0.496495000084 S1
 0.833330000000  0.750000000000  0.099298999982 S1
 0.500000000000  0.750000000000  0.198598000008 S1
 0.666670000000  0.500000000000  0.297897000033 S1
 0.833330000000  0.750000000000  0.397196000059 S1
 0.500000000000  0.750000000000  0.496495000084 S1
 0.666670000000  0.500000000000  0.546547663253 A1 0
 0.666670000000  0.000000000000  0.546547663253 A1 0
 0.166670000000  0.750000000000  0.546547663253 A1 2
 0.166670000000  0.250000000000  0.546547663253 A1 1
"""
# Reference node features: four rows of three floats, each row a one-hot
# vector (exactly one entry is 1.0).
check_feature = [
    [0., 1., 0.],
    [0., 1., 0.],
    [0., 0., 1.],
    [1., 0., 0.],
]

# Reference neighbour indices, nested as 1 x 4 x 6 x 19: a single outer entry
# holding four groups of six rows, each row listing 19 integer indices in the
# range 0..3. One row per line so individual rows can be read and diffed.
check_edges = [
    [
        [
            [0, 2, 3, 1, 1, 2, 3, 1, 0, 2, 3, 1, 0, 1, 2, 0, 0, 3, 1],
            [0, 3, 2, 1, 1, 3, 2, 1, 0, 3, 2, 1, 0, 1, 3, 0, 0, 2, 1],
            [0, 1, 2, 0, 1, 2, 0, 3, 1, 0, 3, 1, 2, 0, 3, 1, 3, 1, 2],
            [0, 1, 3, 0, 1, 3, 0, 2, 1, 0, 2, 1, 3, 0, 2, 1, 2, 1, 3],
            [0, 3, 0, 1, 0, 1, 3, 2, 1, 2, 1, 0, 2, 1, 0, 1, 3, 2, 3],
            [0, 2, 0, 1, 0, 1, 2, 3, 1, 3, 1, 0, 3, 1, 0, 1, 2, 3, 2],
        ],
        [
            [1, 3, 2, 0, 0, 3, 2, 0, 1, 3, 2, 0, 1, 0, 3, 1, 1, 2, 0],
            [1, 2, 3, 0, 0, 2, 3, 0, 1, 2, 3, 0, 1, 0, 2, 1, 1, 3, 0],
            [1, 0, 3, 1, 0, 3, 1, 2, 0, 1, 2, 0, 3, 1, 2, 0, 2, 0, 3],
            [1, 0, 2, 1, 0, 2, 1, 3, 0, 1, 3, 0, 2, 1, 3, 0, 3, 0, 2],
            [1, 2, 1, 0, 1, 0, 2, 3, 0, 3, 0, 1, 3, 0, 1, 0, 2, 3, 2],
            [1, 3, 1, 0, 1, 0, 3, 2, 0, 2, 0, 1, 2, 0, 1, 0, 3, 2, 3],
        ],
        [
            [2, 1, 0, 3, 3, 1, 0, 3, 2, 1, 0, 3, 2, 3, 1, 2, 2, 0, 3],
            [2, 0, 1, 3, 3, 0, 1, 3, 2, 0, 1, 3, 2, 3, 0, 2, 2, 1, 3],
            [2, 3, 1, 2, 3, 1, 2, 0, 3, 2, 0, 3, 1, 2, 0, 3, 0, 3, 1],
            [2, 3, 0, 2, 3, 0, 2, 1, 3, 2, 1, 3, 0, 2, 1, 3, 1, 3, 0],
            [2, 0, 2, 3, 2, 3, 0, 1, 3, 1, 3, 2, 1, 3, 2, 3, 0, 1, 0],
            [2, 1, 2, 3, 2, 3, 1, 0, 3, 0, 3, 2, 0, 3, 2, 3, 1, 0, 1],
        ],
        [
            [3, 0, 1, 2, 2, 0, 1, 2, 3, 0, 1, 2, 3, 2, 0, 3, 3, 1, 2],
            [3, 1, 0, 2, 2, 1, 0, 2, 3, 1, 0, 2, 3, 2, 1, 3, 3, 0, 2],
            [3, 2, 0, 3, 2, 0, 3, 1, 2, 3, 1, 2, 0, 3, 1, 2, 1, 2, 0],
            [3, 2, 1, 3, 2, 1, 3, 0, 2, 3, 0, 2, 1, 3, 0, 2, 0, 2, 1],
            [3, 1, 3, 2, 3, 2, 1, 0, 2, 0, 2, 3, 0, 2, 3, 2, 1, 0, 1],
            [3, 0, 3, 2, 3, 2, 0, 1, 2, 1, 2, 3, 1, 2, 3, 2, 0, 1, 0],
        ],
    ],
]
+46 −0
Original line number Diff line number Diff line
{
  "primitive_cell": "#primitive strucutre\n2.81852800e+00  0.00000000e+00  0.00000000e+00 T\n-1.40926400e+00  2.44091700e+00  0.00000000e+00 T\n0.00000000e+00  0.00000000e+00  2.55082550e+01 F\n1 1\n1 0 2\n6\n0.666670000000  0.333330000000  0.090220999986 S1\n0.333330000000  0.666670000000  0.180439359180 S1\n0.000000000000  0.000000000000  0.270657718374 S1\n0.666670000000  0.333330000000  0.360876077568 S1\n0.333330000000  0.666670000000  0.451094436762 S1\n0.000000000000  0.000000000000  0.496569911270 A1\n",
  "structure": "-4.22779200e+00 -2.44091700e+00  0.00000000e+00\n 2.81852800e+00 -4.88183400e+00  0.00000000e+00\n 0.00000000e+00  0.00000000e+00  2.31755900e+01\n24\n 0.333330000000  0.500000000000  0.099298999982 S1\n 0.000000000000  0.500000000000  0.198598000008 S1\n 0.166670000000  0.250000000000  0.297897000033 S1\n 0.333330000000  0.500000000000  0.397196000059 S1\n 0.000000000000  0.500000000000  0.496495000084 S1\n 0.333330000000  0.000000000000  0.099298999982 S1\n 0.000000000000  0.000000000000  0.198598000008 S1\n 0.166670000000  0.750000000000  0.297897000033 S1\n 0.333330000000  0.000000000000  0.397196000059 S1\n 0.000000000000  0.000000000000  0.496495000084 S1\n 0.833330000000  0.250000000000  0.099298999982 S1\n 0.500000000000  0.250000000000  0.198598000008 S1\n 0.666670000000  0.000000000000  0.297897000033 S1\n 0.833330000000  0.250000000000  0.397196000059 S1\n 0.500000000000  0.250000000000  0.496495000084 S1\n 0.833330000000  0.750000000000  0.099298999982 S1\n 0.500000000000  0.750000000000  0.198598000008 S1\n 0.666670000000  0.500000000000  0.297897000033 S1\n 0.833330000000  0.750000000000  0.397196000059 S1\n 0.500000000000  0.750000000000  0.496495000084 S1\n 0.666670000000  0.500000000000  0.546547663253 A1 0\n 0.666670000000  0.000000000000  0.546547663253 A1 0\n 0.166670000000  0.750000000000  0.546547663253 A1 2\n 0.166670000000  0.250000000000  0.546547663253 A1 1\n",
  "node_feature": [
    [0.0, 1.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
    [1.0, 0.0, 0.0]
  ],
  "edges": [
    [
      [
        [0, 2, 3, 1, 1, 2, 3, 1, 0, 2, 3, 1, 0, 1, 2, 0, 0, 3, 1],
        [0, 3, 2, 1, 1, 3, 2, 1, 0, 3, 2, 1, 0, 1, 3, 0, 0, 2, 1],
        [0, 1, 2, 0, 1, 2, 0, 3, 1, 0, 3, 1, 2, 0, 3, 1, 3, 1, 2],
        [0, 1, 3, 0, 1, 3, 0, 2, 1, 0, 2, 1, 3, 0, 2, 1, 2, 1, 3],
        [0, 3, 0, 1, 0, 1, 3, 2, 1, 2, 1, 0, 2, 1, 0, 1, 3, 2, 3],
        [0, 2, 0, 1, 0, 1, 2, 3, 1, 3, 1, 0, 3, 1, 0, 1, 2, 3, 2]
      ],
      [
        [1, 3, 2, 0, 0, 3, 2, 0, 1, 3, 2, 0, 1, 0, 3, 1, 1, 2, 0],
        [1, 2, 3, 0, 0, 2, 3, 0, 1, 2, 3, 0, 1, 0, 2, 1, 1, 3, 0],
        [1, 0, 3, 1, 0, 3, 1, 2, 0, 1, 2, 0, 3, 1, 2, 0, 2, 0, 3],
        [1, 0, 2, 1, 0, 2, 1, 3, 0, 1, 3, 0, 2, 1, 3, 0, 3, 0, 2],
        [1, 2, 1, 0, 1, 0, 2, 3, 0, 3, 0, 1, 3, 0, 1, 0, 2, 3, 2],
        [1, 3, 1, 0, 1, 0, 3, 2, 0, 2, 0, 1, 2, 0, 1, 0, 3, 2, 3]
      ],
      [
        [2, 1, 0, 3, 3, 1, 0, 3, 2, 1, 0, 3, 2, 3, 1, 2, 2, 0, 3],
        [2, 0, 1, 3, 3, 0, 1, 3, 2, 0, 1, 3, 2, 3, 0, 2, 2, 1, 3],
        [2, 3, 1, 2, 3, 1, 2, 0, 3, 2, 0, 3, 1, 2, 0, 3, 0, 3, 1],
        [2, 3, 0, 2, 3, 0, 2, 1, 3, 2, 1, 3, 0, 2, 1, 3, 1, 3, 0],
        [2, 0, 2, 3, 2, 3, 0, 1, 3, 1, 3, 2, 1, 3, 2, 3, 0, 1, 0],
        [2, 1, 2, 3, 2, 3, 1, 0, 3, 0, 3, 2, 0, 3, 2, 3, 1, 0, 1]
      ],
      [
        [3, 0, 1, 2, 2, 0, 1, 2, 3, 0, 1, 2, 3, 2, 0, 3, 3, 1, 2],
        [3, 1, 0, 2, 2, 1, 0, 2, 3, 1, 0, 2, 3, 2, 1, 3, 3, 0, 2],
        [3, 2, 0, 3, 2, 0, 3, 1, 2, 3, 1, 2, 0, 3, 1, 2, 1, 2, 0],
        [3, 2, 1, 3, 2, 1, 3, 0, 2, 3, 0, 2, 1, 3, 0, 2, 0, 2, 1],
        [3, 1, 3, 2, 3, 2, 1, 0, 2, 0, 2, 3, 0, 2, 3, 2, 1, 0, 1],
        [3, 0, 3, 2, 3, 2, 0, 1, 2, 1, 2, 3, 1, 2, 3, 2, 0, 1, 0]
      ]
    ]
  ]
}
+13 −7
Original line number Diff line number Diff line
import os
import json
import numpy as np
from deepchem.feat.material_featurizers.lcnn_featurizer import LCNNFeaturizer
from data.lcnn_test_data import primitive_cell, structure, check_edges, check_feature


def test_LCNNFeaturizer():
  """
  Check LCNNFeaturizer output against stored reference values.

  The featurizer is built from a primitive-cell description, applied to one
  structure through `_featurize`, and its 'X_Sites' and 'X_NSs' entries are
  compared element-wise with the expected node features and edges.

  NOTE(review): the body performs this check twice, first with the
  module-level fixtures (primitive_cell, structure, check_feature,
  check_edges) and then with the same-named entries loaded from a JSON file.
  This looks like the old and new versions of the test shown back to back;
  confirm which half is meant to stay.
  """
  # First pass: fixtures imported from data.lcnn_test_data.
  # NOTE(review): 6.00 is presumably the cutoff radius in angstrom -- confirm
  # against LCNNFeaturizer.__init__.
  featuriser = LCNNFeaturizer(np.around(6.00), primitive_cell)
  data = featuriser._featurize(structure)
  assert np.all(data['X_Sites'] == np.array(check_feature))
  assert np.all(data['X_NSs'] == np.array(check_edges))
  # Second pass: fixtures read from a JSON file located next to this test
  # module, so the test does not depend on the working directory.
  current_dir = os.path.dirname(os.path.realpath(__file__))
  strucutre_file = os.path.join(current_dir,
                                'platinum_absorption_strucutre.json')
  with open(strucutre_file, 'r') as f:
    # Keys read from the file: "primitive_cell", "structure", "node_feature"
    # and "edges".
    test_data = json.load(f)
    featuriser = LCNNFeaturizer(np.around(6.00), test_data["primitive_cell"])
    data = featuriser._featurize(test_data["structure"])
    assert np.all(data['X_Sites'] == np.array(test_data["node_feature"]))
    # The right-hand side is a plain nested list; the comparison relies on
    # NumPy converting it for the element-wise ==.
    assert np.all(data['X_NSs'] == test_data["edges"])
    # Explicit shape checks guard against a comparison that only passes
    # through broadcasting.
    assert data['X_Sites'].shape == (4, 3)
    assert data['X_NSs'].shape == (1, 4, 6, 19)