Commit 38c2207a authored by Bharat123rox's avatar Bharat123rox
Browse files

Refactor code in rl/ folder

parent 3f42da48
Loading
Loading
Loading
Loading
+5 −2
Original line number Diff line number Diff line
@@ -55,8 +55,11 @@ class Environment(object):
    if state_dtype is None:
      # Assume all arrays are float32.
      import numpy
      import collections
      if isinstance(state_shape[0], collections.Sequence):
      try:
        from collections.abc import Sequence as SequenceCollection
      except:
        from collections import Sequence as SequenceCollection
      if isinstance(state_shape[0], SequenceCollection):
        self._state_dtype = [numpy.float32] * len(state_shape)
      else:
        self._state_dtype = numpy.float32
+5 −3
Original line number Diff line number Diff line
"""Advantage Actor-Critic (A2C) algorithm for reinforcement learning."""
import time
import collections

try:
  from collections.abc import Sequence as SequenceCollection
except:
  from collections import Sequence as SequenceCollection
import numpy as np
import tensorflow as tf

@@ -171,7 +173,7 @@ class A2C(object):
    self.value_weight = value_weight
    self.entropy_weight = entropy_weight
    self.use_hindsight = use_hindsight
    self._state_is_list = isinstance(env.state_shape[0], collections.Sequence)
    self._state_is_list = isinstance(env.state_shape[0], SequenceCollection)
    if optimizer is None:
      self._optimizer = Adam(learning_rate=0.001, beta1=0.9, beta2=0.999)
    else:
+5 −2
Original line number Diff line number Diff line
"""Proximal Policy Optimization (PPO) algorithm for reinforcement learning."""
import copy
import time
import collections
try:
  from collections.abc import Sequence as SequenceCollection
except:
  from collections import Sequence as SequenceCollection
from multiprocessing.dummy import Pool

import numpy as np
@@ -149,7 +152,7 @@ class PPO(object):
    self.value_weight = value_weight
    self.entropy_weight = entropy_weight
    self.use_hindsight = use_hindsight
    self._state_is_list = isinstance(env.state_shape[0], collections.Sequence)
    self._state_is_list = isinstance(env.state_shape[0], SequenceCollection)
    if optimizer is None:
      self._optimizer = Adam(learning_rate=0.001, beta1=0.9, beta2=0.999)
    else: