Commit 796c846c authored by nd-02110114's avatar nd-02110114
Browse files

🐛 fix load material dataset function

parent e525f10e
Loading
Loading
Loading
Loading
+4 −4
Original line number Diff line number Diff line
@@ -44,11 +44,11 @@ class SineCoulombMatrix(MaterialStructureFeaturizer):
  This class requires matminer and Pymatgen to be installed.
  """

  def __init__(self, max_atoms: int, flatten: bool = True):
  def __init__(self, max_atoms: int = 100, flatten: bool = True):
    """
    Parameters
    ----------
    max_atoms: int
    max_atoms: int (default 100)
      Maximum number of atoms for any crystal in the dataset. Used to
      pad the Coulomb matrix.
    flatten: bool (default True)
@@ -86,8 +86,8 @@ class SineCoulombMatrix(MaterialStructureFeaturizer):

    if self.flatten:
      eigs, _ = np.linalg.eig(sine_mat)
      zeros = np.zeros((1, self.max_atoms))
      zeros[:len(eigs)] = eigs
      zeros = np.zeros(self.max_atoms)
      zeros[:len(eigs[0])] = eigs[0]
      features = zeros
    else:
      features = pad_array(sine_mat, self.max_atoms)
+2 −1
Original line number Diff line number Diff line
@@ -63,10 +63,11 @@ class TestMaterialFeaturizers(unittest.TestCase):
    Test SCM featurizer.
    """

    featurizer = SineCoulombMatrix(max_atoms=1)
    featurizer = SineCoulombMatrix(max_atoms=3)
    features = featurizer.featurize([self.struct_dict])

    assert len(features) == 1
    assert features.shape == (1, 3)
    assert np.isclose(features[0], 1244, atol=.5)

  def test_cgcnn_featurizer(self):
+17 −16
Original line number Diff line number Diff line
@@ -3,19 +3,18 @@ Experimental bandgaps for inorganic crystals.
"""
import os
import logging

import deepchem
from deepchem.feat import Featurizer, MaterialStructureFeaturizer, MaterialCompositionFeaturizer
from deepchem.trans import Transformer
from deepchem.feat import MaterialCompositionFeaturizer
from deepchem.splits.splitters import Splitter
from deepchem.molnet.defaults import get_defaults

from typing import List, Tuple, Dict, Optional, Union, Any, Type
from typing import List, Tuple, Dict, Optional, Any

logger = logging.getLogger(__name__)

# TODO: Change URLs
DEFAULT_DIR = deepchem.utils.get_data_dir()
BANDGAP_URL = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/expt_gap.tar.gz'
BANDGAP_URL = 'https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/expt_gap.tar.gz'

# dict of accepted featurizers for this dataset
# modify the returned dicts for your dataset
@@ -106,9 +105,10 @@ def load_bandgap(

  References
  ----------
  .. [1] Zhuo, Y. et al. "Predicting the Band Gaps of Inorganic Solids by Machine Learning." J. Phys. Chem. Lett. (2018) DOI: 10.1021/acs.jpclett.8b00124.

  .. [2] Dunn, A. et al. "Benchmarking Materials Property Prediction Methods: The Matbench Test Set and Automatminer Reference Algorithm." https://arxiv.org/abs/2005.00707 (2020)
  .. [1] Zhuo, Y. et al. "Predicting the Band Gaps of Inorganic Solids by Machine Learning."
     J. Phys. Chem. Lett. (2018) DOI: 10.1021/acs.jpclett.8b00124.
  .. [2] Dunn, A. et al. "Benchmarking Materials Property Prediction Methods: The Matbench Test Set
     and Automatminer Reference Algorithm." https://arxiv.org/abs/2005.00707 (2020)

  Examples
  --------
@@ -159,12 +159,13 @@ def load_bandgap(

  # Load .tar.gz file
  if featurizer.__class__.__name__ in supported_featurizers:
    dataset_file = os.path.join(data_dir, 'expt_gap.tar.gz')
    deepchem.utils.untargz_file(dataset_file, dest_dir=data_dir)
    dataset_file = os.path.join(data_dir, 'expt_gap.json')

    if not os.path.exists(dataset_file):
      targz_file = os.path.join(data_dir, 'expt_gap.tar.gz')
      if not os.path.exists(targz_file):
        deepchem.utils.download_url(url=BANDGAP_URL, dest_dir=data_dir)

      deepchem.utils.untargz_file(
          os.path.join(data_dir, 'expt_gap.tar.gz'), data_dir)

+14 −13
Original line number Diff line number Diff line
@@ -4,18 +4,16 @@ Perovskite crystal structures and formation energies.
import os
import logging
import deepchem
from deepchem.feat import Featurizer, MaterialStructureFeaturizer, MaterialCompositionFeaturizer
from deepchem.trans import Transformer
from deepchem.feat import MaterialStructureFeaturizer
from deepchem.splits.splitters import Splitter
from deepchem.molnet.defaults import get_defaults

from typing import List, Tuple, Dict, Optional, Union, Any, Type, Callable
from typing import List, Tuple, Dict, Optional, Any

logger = logging.getLogger(__name__)

# TODO: Change URLs
DEFAULT_DIR = deepchem.utils.get_data_dir()
PEROVSKITE_URL = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/perovskite.tar.gz'
PEROVSKITE_URL = 'https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/perovskite.tar.gz'

# dict of accepted featurizers for this dataset
# modify the returned dicts for your dataset
@@ -104,9 +102,11 @@ def load_perovskite(

  References
  ----------
  .. [1] Castelli, I. et al. "New cubic perovskites for one- and two-photon water splitting using the computational materials repository." Energy Environ. Sci., (2012), 5, 9034-9043 DOI: 10.1039/C2EE22341D.

  .. [2] Dunn, A. et al. "Benchmarking Materials Property Prediction Methods: The Matbench Test Set and Automatminer Reference Algorithm." https://arxiv.org/abs/2005.00707 (2020)
  .. [1] Castelli, I. et al. "New cubic perovskites for one- and two-photon water splitting
     using the computational materials repository." Energy Environ. Sci., (2012), 5,
     9034-9043 DOI: 10.1039/C2EE22341D.
  .. [2] Dunn, A. et al. "Benchmarking Materials Property Prediction Methods:
     The Matbench Test Set and Automatminer Reference Algorithm." https://arxiv.org/abs/2005.00707 (2020)

  Examples
  --------
@@ -157,12 +157,13 @@ def load_perovskite(

  # Load .tar.gz file
  if featurizer.__class__.__name__ in supported_featurizers:
    dataset_file = os.path.join(data_dir, 'perovskite.tar.gz')
    deepchem.utils.untargz_file(dataset_file, dest_dir=data_dir)
    dataset_file = os.path.join(data_dir, 'perovskite.json')

    if not os.path.exists(dataset_file):
      targz_file = os.path.join(data_dir, 'perovskite.tar.gz')
      if not os.path.exists(targz_file):
        deepchem.utils.download_url(url=PEROVSKITE_URL, dest_dir=data_dir)

      deepchem.utils.untargz_file(
          os.path.join(data_dir, 'perovskite.tar.gz'), data_dir)

+0 −3
Original line number Diff line number Diff line
@@ -3,10 +3,7 @@ Tests for bandgap loader.
"""

import os
import tempfile
import shutil
import numpy as np
import deepchem as dc
from deepchem.molnet import load_bandgap


Loading