Commit a9ca1f5c authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Changes

parent a3b4a3aa
Loading
Loading
Loading
Loading
+14 −5
Original line number Diff line number Diff line
@@ -53,10 +53,12 @@ class WeaveModel(KerasModel):

  def __init__(self,
               n_tasks: int,
               n_atom_feat: int = 75,
               n_pair_feat: int = 14,
               n_hidden: int = 50,
               n_graph_feat: int = 128,
               n_atom_feat: OneOrMany[int] = 75,
               n_pair_feat: OneOrMany[int] = 14,
               n_hidden: OneOrMany[int] = 50,
               n_graph_feat: OneOrMany[int] = 128,
               n_weave: int = 2,
               fully_connected_layer_sizes: List[int] = [2000, 1000],
               mode: str = "classification",
               n_classes: int = 2,
               batch_size: int = 100,
@@ -74,6 +76,8 @@ class WeaveModel(KerasModel):
      Number of units(convolution depths) in corresponding hidden layer
    n_graph_feat: int, optional
      Number of output features for each molecule(graph)
    n_weave: int, optional
      The number of weave layers in this model.
    mode: str
      Either "classification" or "regression" for type of model.
    n_classes: int
@@ -81,7 +85,12 @@ class WeaveModel(KerasModel):
    """
    if mode not in ['classification', 'regression']:
      raise ValueError("mode must be either 'classification' or 'regression'")
    self.n_tasks = n_tasks

    if not isinstance(n_atom_feat, collections.Sequence):
      n_atom_feat = [n_atom_feat] * n_weave
    if not isinstance(n_pair_feat, collections.Sequence):
      n_pair_feat = [n_pair_feat] * n_weave

    self.n_atom_feat = n_atom_feat
    self.n_pair_feat = n_pair_feat
    self.n_hidden = n_hidden