"""
PCBA dataset loader.
"""
import os
import deepchem as dc
from deepchem.molnet.load_function.molnet_loader import TransformerGenerator, _MolnetLoader
from deepchem.data import Dataset
from typing import List, Optional, Tuple, Union

# URL template for the PCBA CSV files hosted on DeepChem's S3 bucket; ``%s`` is
# filled with the assay file name (e.g. "pcba.csv.gz") when downloading.
PCBA_URL = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/%s"
# The 128 bioassay task columns read from the PCBA CSV. Each name has the form
# "PCBA-<ID>", where <ID> is the PubChem BioAssay ID of that assay.
PCBA_TASKS = [
    'PCBA-1030', 'PCBA-1379', 'PCBA-1452', 'PCBA-1454', 'PCBA-1457',
    'PCBA-1458', 'PCBA-1460', 'PCBA-1461', 'PCBA-1468', 'PCBA-1469',
    'PCBA-1471', 'PCBA-1479', 'PCBA-1631', 'PCBA-1634', 'PCBA-1688',
    'PCBA-1721', 'PCBA-2100', 'PCBA-2101', 'PCBA-2147', 'PCBA-2242',
    'PCBA-2326', 'PCBA-2451', 'PCBA-2517', 'PCBA-2528', 'PCBA-2546',
    'PCBA-2549', 'PCBA-2551', 'PCBA-2662', 'PCBA-2675', 'PCBA-2676', 'PCBA-411',
    'PCBA-463254', 'PCBA-485281', 'PCBA-485290', 'PCBA-485294', 'PCBA-485297',
    'PCBA-485313', 'PCBA-485314', 'PCBA-485341', 'PCBA-485349', 'PCBA-485353',
    'PCBA-485360', 'PCBA-485364', 'PCBA-485367', 'PCBA-492947', 'PCBA-493208',
    'PCBA-504327', 'PCBA-504332', 'PCBA-504333', 'PCBA-504339', 'PCBA-504444',
    'PCBA-504466', 'PCBA-504467', 'PCBA-504706', 'PCBA-504842', 'PCBA-504845',
    'PCBA-504847', 'PCBA-504891', 'PCBA-540276', 'PCBA-540317', 'PCBA-588342',
    'PCBA-588453', 'PCBA-588456', 'PCBA-588579', 'PCBA-588590', 'PCBA-588591',
    'PCBA-588795', 'PCBA-588855', 'PCBA-602179', 'PCBA-602233', 'PCBA-602310',
    'PCBA-602313', 'PCBA-602332', 'PCBA-624170', 'PCBA-624171', 'PCBA-624173',
    'PCBA-624202', 'PCBA-624246', 'PCBA-624287', 'PCBA-624288', 'PCBA-624291',
    'PCBA-624296', 'PCBA-624297', 'PCBA-624417', 'PCBA-651635', 'PCBA-651644',
    'PCBA-651768', 'PCBA-651965', 'PCBA-652025', 'PCBA-652104', 'PCBA-652105',
    'PCBA-652106', 'PCBA-686970', 'PCBA-686978', 'PCBA-686979', 'PCBA-720504',
    'PCBA-720532', 'PCBA-720542', 'PCBA-720551', 'PCBA-720553', 'PCBA-720579',
    'PCBA-720580', 'PCBA-720707', 'PCBA-720708', 'PCBA-720709', 'PCBA-720711',
    'PCBA-743255', 'PCBA-743266', 'PCBA-875', 'PCBA-881', 'PCBA-883',
    'PCBA-884', 'PCBA-885', 'PCBA-887', 'PCBA-891', 'PCBA-899', 'PCBA-902',
    'PCBA-903', 'PCBA-904', 'PCBA-912', 'PCBA-914', 'PCBA-915', 'PCBA-924',
    'PCBA-925', 'PCBA-926', 'PCBA-927', 'PCBA-938', 'PCBA-995'
]


class _PCBALoader(_MolnetLoader):
  """MoleculeNet loader for a PCBA assay CSV file.

  Downloads the file named by ``assay_file_name`` into ``data_dir`` when it
  is not already present, then featurizes its ``smiles`` column with
  ``dc.data.CSVLoader``, using ``tasks`` as the task columns.
  """

  def __init__(self, assay_file_name: str,
               featurizer: Union[dc.feat.Featurizer, str],
               splitter: Union[dc.splits.Splitter, str, None],
               transformer_generators: List[Union[TransformerGenerator, str]],
               tasks: List[str], data_dir: Optional[str],
               save_dir: Optional[str], **kwargs):
    """
    Parameters
    ----------
    assay_file_name: str
      file name of the assay CSV (e.g. ``'pcba.csv.gz'``). It is used both to
      build the download URL from ``PCBA_URL`` and as the file name inside
      ``data_dir``.
    featurizer, splitter, transformer_generators, tasks, data_dir, save_dir
      passed unchanged to ``_MolnetLoader.__init__``.
    **kwargs
      extra keyword arguments, passed unchanged to ``_MolnetLoader.__init__``.
    """
    # Forward **kwargs: they are accepted here (and load_pcba passes its own
    # **kwargs in), so dropping them would silently ignore caller options
    # meant for the base loader.
    super(_PCBALoader, self).__init__(featurizer, splitter,
                                      transformer_generators, tasks, data_dir,
                                      save_dir, **kwargs)
    self.assay_file_name = assay_file_name

  def create_dataset(self) -> Dataset:
    """Download the assay CSV if needed and featurize it.

    Returns
    -------
    Dataset
      the dataset produced by ``dc.data.CSVLoader`` from the file's
      ``smiles`` column, with ``self.tasks`` as the task columns.
    """
    dataset_file = os.path.join(self.data_dir, self.assay_file_name)
    if not os.path.exists(dataset_file):
      # Only hit the network when the raw file is not already cached locally.
      dc.utils.data_utils.download_url(
          url=PCBA_URL % self.assay_file_name, dest_dir=self.data_dir)
    loader = dc.data.CSVLoader(
        tasks=self.tasks, feature_field="smiles", featurizer=self.featurizer)
    return loader.create_dataset(dataset_file)


def load_pcba(
    featurizer: Union[dc.feat.Featurizer, str] = 'ECFP',
    splitter: Union[dc.splits.Splitter, str, None] = 'scaffold',
    transformers: Optional[List[Union[TransformerGenerator, str]]] = None,
    reload: bool = True,
    data_dir: Optional[str] = None,
    save_dir: Optional[str] = None,
    **kwargs
) -> Tuple[List[str], Tuple[Dataset, ...], List[dc.trans.Transformer]]:
  """Load PCBA dataset

  PubChem BioAssay (PCBA) is a database consisting of biological activities of
  small molecules generated by high-throughput screening. We use a subset of
  PCBA, containing 128 bioassays measured over 400 thousand compounds,
  used by previous work to benchmark machine learning methods.

  Random splitting is recommended for this dataset. Note that the default
  ``splitter`` is ``'scaffold'``, so pass ``splitter='random'`` explicitly to
  get the recommended split.

  The raw data csv file contains columns below:

  - "mol_id" - PubChem CID of the compound
  - "smiles" - SMILES representation of the molecular structure
  - "PCBA-XXX" - Measured results (Active/Inactive) for bioassays:
        search for the assay ID at
        https://pubchem.ncbi.nlm.nih.gov/search/#collection=bioassays
        for details

  Parameters
  ----------
  featurizer: Featurizer or str
    the featurizer to use for processing the data.  Alternatively you can pass
    one of the names from dc.molnet.featurizers as a shortcut.
  splitter: Splitter or str or None
    the splitter to use for splitting the data into training, validation, and
    test sets.  Alternatively you can pass one of the names from
    dc.molnet.splitters as a shortcut.  If this is None, all the data
    will be included in a single dataset.
  transformers: list of TransformerGenerators or strings, optional
    the Transformers to apply to the data.  Each one is specified by a
    TransformerGenerator or, as a shortcut, one of the names from
    dc.molnet.transformers.  Defaults to ``['balancing']`` when omitted or
    None.
  reload: bool
    if True, the first call for a particular featurizer and splitter will cache
    the datasets to disk, and subsequent calls will reload the cached datasets.
  data_dir: str, optional
    a directory to save the raw data in
  save_dir: str, optional
    a directory to save the dataset in
  **kwargs
    additional keyword arguments passed through to the underlying MoleculeNet
    loader.

  Returns
  -------
  tasks, datasets, transformers
    the list of task names, a tuple of datasets (split according to
    ``splitter``, or a single dataset when ``splitter`` is None), and the list
    of transformers that were applied.

  References
  ----------
  .. [1] Wang, Yanli, et al. "PubChem's BioAssay database."
     Nucleic acids research 40.D1 (2011): D400-D412.
  """
  # Avoid a shared mutable default list: build a fresh one on every call.
  if transformers is None:
    transformers = ['balancing']
  loader = _PCBALoader('pcba.csv.gz', featurizer, splitter, transformers,
                       PCBA_TASKS, data_dir, save_dir, **kwargs)
  return loader.load_dataset('pcba', reload)


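# NOTE: The loaders below for the 146- and 2475-assay PCBA variants are
# disabled. They call a `load_pcba_dataset` helper that is not defined in this
# module and would need to be ported to `_PCBALoader` before being re-enabled.
# def load_pcba_146(featurizer='ECFP',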
#                   split='random',
#                   reload=True,
#                   data_dir=None,
#                   save_dir=None,
#                   **kwargs):
#   return load_pcba_dataset(
#       featurizer=featurizer,
#       split=split,
#       reload=reload,
#       assay_file_name="pcba_146.csv.gz",
#       data_dir=data_dir,
#       save_dir=save_dir,
#       **kwargs)

# def load_pcba_2475(featurizer='ECFP',
#                    split='random',
#                    reload=True,
#                    data_dir=None,
#                    save_dir=None,
#                    **kwargs):
#   return load_pcba_dataset(
#       featurizer=featurizer,
#       split=split,
#       reload=reload,
#       assay_file_name="pcba_2475.csv.gz",
#       data_dir=data_dir,
#       save_dir=save_dir,
#       **kwargs)
