Commit 3dac80f3 authored by Bharat123rox's avatar Bharat123rox
Browse files

Refactor code in contrib/rl/ folder

parent 4efb180c
Loading
Loading
Loading
Loading
+5 −2
Original line number Diff line number Diff line
@@ -5,7 +5,10 @@ from deepchem.models.optimizers import Adam
from deepchem.models.tensorgraph.layers import Feature, Weights, Label, Layer
import numpy as np
import tensorflow as tf
import collections
try:
  from collections.abc import Sequence as SequenceCollection
except:
  from collections import Sequence as SequenceCollection
import copy
import time

@@ -109,7 +112,7 @@ class MCTS(object):
    self.n_search_episodes = n_search_episodes
    self.discount_factor = discount_factor
    self.value_weight = value_weight
    self._state_is_list = isinstance(env.state_shape[0], collections.Sequence)
    self._state_is_list = isinstance(env.state_shape[0], SequenceCollection)
    if optimizer is None:
      self._optimizer = Adam(learning_rate=0.001, beta1=0.9, beta2=0.999)
    else: