Unverified Commit 3ca99af3 authored by Pascal's avatar Pascal Committed by GitHub
Browse files

add img_size parameter to ChemCeption model

parent e422b9b8
Loading
Loading
Loading
Loading
+16 −15
Original line number Diff line number Diff line
@@ -8,18 +8,16 @@ __license__ = "MIT"

import numpy as np
import tensorflow as tf
import os
import sys
import logging

from typing import Dict
from deepchem.data.datasets import pad_batch
from deepchem.models import KerasModel, layers
from deepchem.models import KerasModel
from deepchem.models.losses import L2Loss, SoftmaxCrossEntropy, SigmoidCrossEntropy
from deepchem.metrics import to_one_hot
from deepchem.models import chemnet_layers
from tensorflow.keras.layers import Input, Dense, Reshape, Softmax, Activation
from tensorflow.keras.layers import Dropout, Conv1D, Concatenate, Lambda, GRU, LSTM, Bidirectional
from tensorflow.keras.layers import Conv2D, ReLU, Add, GlobalAveragePooling2D
from tensorflow.keras.layers import Conv1D, GRU, LSTM, Bidirectional
from tensorflow.keras.layers import GlobalAveragePooling2D

DEFAULT_INCEPTION_BLOCKS = {"A": 3, "B": 3, "C": 3}

@@ -217,19 +215,22 @@ class ChemCeption(KerasModel):
  """

  def __init__(self,
               img_spec="std",
               base_filters=16,
               inception_blocks=DEFAULT_INCEPTION_BLOCKS,
               n_tasks=10,
               n_classes=2,
               augment=False,
               mode="regression",
               img_spec: str = "std",
               img_size: int = 80,
               base_filters: int = 16,
               inception_blocks: Dict = DEFAULT_INCEPTION_BLOCKS,
               n_tasks: int = 10,
               n_classes: int = 2,
               augment: bool = False,
               mode: str = "regression",
               **kwargs):
    """
    Parameters
    ----------
    img_spec: str, default std
        Image specification used
    img_size: int, default 80
        Image size used
    base_filters: int, default 16
        Base filters used for the different inception and reduction layers
    inception_blocks: dict,
@@ -244,9 +245,9 @@ class ChemCeption(KerasModel):
        Whether the model is used for regression or classification
    """
    if img_spec == "engd":
      self.input_shape = (80, 80, 4)
      self.input_shape = (img_size, img_size, 4)
    elif img_spec == "std":
      self.input_shape = (80, 80, 1)
      self.input_shape = (img_size, img_size, 1)
    self.base_filters = base_filters
    self.inception_blocks = inception_blocks
    self.n_tasks = n_tasks