Unverified Commit 9de39778 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #2466 from PascalIversen/master

Add img_size parameter to ChemCeption
parents 2e4b30fc 3ca99af3
Loading
Loading
Loading
Loading
+16 −15
Original line number Diff line number Diff line
@@ -8,18 +8,16 @@ __license__ = "MIT"

import numpy as np
import tensorflow as tf
import os
import sys
import logging

from typing import Dict
from deepchem.data.datasets import pad_batch
from deepchem.models import KerasModel, layers
from deepchem.models import KerasModel
from deepchem.models.losses import L2Loss, SoftmaxCrossEntropy, SigmoidCrossEntropy
from deepchem.metrics import to_one_hot
from deepchem.models import chemnet_layers
from tensorflow.keras.layers import Input, Dense, Reshape, Softmax, Activation
from tensorflow.keras.layers import Dropout, Conv1D, Concatenate, Lambda, GRU, LSTM, Bidirectional
from tensorflow.keras.layers import Conv2D, ReLU, Add, GlobalAveragePooling2D
from tensorflow.keras.layers import Conv1D, GRU, LSTM, Bidirectional
from tensorflow.keras.layers import GlobalAveragePooling2D

DEFAULT_INCEPTION_BLOCKS = {"A": 3, "B": 3, "C": 3}

@@ -217,19 +215,22 @@ class ChemCeption(KerasModel):
  """

  def __init__(self,
               img_spec="std",
               base_filters=16,
               inception_blocks=DEFAULT_INCEPTION_BLOCKS,
               n_tasks=10,
               n_classes=2,
               augment=False,
               mode="regression",
               img_spec: str = "std",
               img_size: int = 80,
               base_filters: int = 16,
               inception_blocks: Dict = DEFAULT_INCEPTION_BLOCKS,
               n_tasks: int = 10,
               n_classes: int = 2,
               augment: bool = False,
               mode: str = "regression",
               **kwargs):
    """
    Parameters
    ----------
    img_spec: str, default std
        Image specification used
    img_size: int, default 80
        Image size used
    base_filters: int, default 16
        Base filters used for the different inception and reduction layers
    inception_blocks: dict,
@@ -244,9 +245,9 @@ class ChemCeption(KerasModel):
        Whether the model is used for regression or classification
    """
    if img_spec == "engd":
      self.input_shape = (80, 80, 4)
      self.input_shape = (img_size, img_size, 4)
    elif img_spec == "std":
      self.input_shape = (80, 80, 1)
      self.input_shape = (img_size, img_size, 1)
    self.base_filters = base_filters
    self.inception_blocks = inception_blocks
    self.n_tasks = n_tasks