Commit 64a57200 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Cleanup

parent 46648977
Loading
Loading
Loading
Loading
+13 −17
Original line number Diff line number Diff line
@@ -66,23 +66,19 @@ class MetaLearner(object):
class MAML(object):
  """Implements the Model-Agnostic Meta-Learning algorithm for low data learning.

  The algorithm is described in Finn et al., "Model-Agnostic
  Meta-Learning for Fast Adaptation of Deep Networks"
  (https://arxiv.org/abs/1703.03400).  It is used for training
  models that can perform a variety of tasks, depending on what
  data they are trained on.  It assumes you have training data
  for many tasks, but only a small amount for each one.  It
  performs "meta-learning" by looping over tasks and trying to
  minimize the loss on each one *after* one or a few steps of
  gradient descent.  That is, it does not try to create a model
  that can directly solve the tasks, but rather tries to create
  a model that is very easy to train.

  To use this class, create a subclass of MetaLearner that
  encapsulates the model and data for your learning problem.
  Pass it to a MAML object and call fit().  You can then use
  train_on_current_task() to fine tune the model for a
  particular task.
  The algorithm is described in Finn et al., "Model-Agnostic Meta-Learning for Fast
  Adaptation of Deep Networks" (https://arxiv.org/abs/1703.03400).  It is used for
  training models that can perform a variety of tasks, depending on what data they
  are trained on.  It assumes you have training data for many tasks, but only a small
  amount for each one.  It performs "meta-learning" by looping over tasks and trying
  to minimize the loss on each one *after* one or a few steps of gradient descent.
  That is, it does not try to create a model that can directly solve the tasks, but
  rather tries to create a model that is very easy to train.

  To use this class, create a subclass of MetaLearner that encapsulates the model
  and data for your learning problem.  Pass it to a MAML object and call fit().
  You can then use train_on_current_task() to fine tune the model for a particular
  task.
  """

  def __init__(self,