Commit 15e2fdbf authored by pvskand's avatar pvskand
Browse files

removed unet_model and added architecture in the constructor itself

parent b2c82de7
Loading
Loading
Loading
Loading
+14 −17
Original line number Diff line number Diff line
@@ -13,22 +13,20 @@ from deepchem.models.tensorgraph.layers import Conv2D, MaxPool2D, Conv2DTranspos
from deepchem.models import TensorGraph

class UNet(TensorGraph):
# """
# U-Net architecture implementation.
# Parameters
# ----------
# img_rows : int
#  number of rows of the image.
# img_cols: int
#  number of columns of the image
# """
    def __init__(self,img_rows=512, img_cols=512, **kwargs):
    """
        U-Net architecture implementation.
        Parameters
        ----------
        img_rows : int
         number of rows of the image.
        img_cols: int
         number of columns of the image
    """
    def __init__(self,img_rows=512, img_cols=512, model=dc.models.TensorGraph(), **kwargs):
        super(UNet, self).__init__(use_queue=False, **kwargs)
        self.img_cols = img_cols
        self.img_rows = img_rows

    def unet_model(self):
        model = dc.models.TensorGraph()
        self.model = dc.models.TensorGraph()

        input = Feature(shape=(None, self.img_rows, self.img_cols))

@@ -74,4 +72,3 @@ class UNet(TensorGraph):
        conv10 = Conv2D(num_outputs=1, kernel_size=1, activation='sigmoid', in_layers=[conv9])

        model.add_output(conv10)
        return model