Commit 23a48916 authored by Axel Kohlmeyer's avatar Axel Kohlmeyer
Browse files

re-factored balance command now works with group and time weights

(cherry picked from commit 3f674e5062aa8533f91da871f36f7bdcd90845db)
parent 34b34d84
Loading
Loading
Loading
Loading
+67 −169
Original line number Diff line number Diff line
@@ -63,12 +63,6 @@ Balance::Balance(LAMMPS *lmp) : Pointers(lmp)
  nimbalance = 0;
  imbalance = NULL;
  weight = NULL;

  ngroup = 0;
  group_id = NULL;
  group_weight = NULL;

  clock_imbalance = NULL;
}

/* ---------------------------------------------------------------------- */
@@ -100,11 +94,6 @@ Balance::~Balance()
  delete [] imbalance;
  delete [] weight;

#if 1
  delete [] group_id;
  delete [] group_weight;
  delete [] clock_imbalance;
#endif
  if (fp) fclose(fp);
}

@@ -238,37 +227,26 @@ void Balance::command(int narg, char **arg)
      Imbalance *imb;
      int nopt = 0;
      if (strcmp(arg[iarg+1],"group") == 0) {
        imb = new ImbalanceGroup;
        nopt = imb->options(lmp,narg-iarg-1,arg+iarg+1);
        imb = new ImbalanceGroup(lmp);
        nopt = imb->options(narg-iarg,arg+iarg+2);
        imbalance[nimbalance] = imb;
      } else if (strcmp(arg[iarg+1],"time") == 0) {
        imb = new ImbalanceTime;
        nopt = imb->options(lmp,narg-iarg-1,arg+iarg+1);
        imb = new ImbalanceTime(lmp);
        nopt = imb->options(narg-iarg,arg+iarg+2);
        imbalance[nimbalance] = imb;
      } else if (strcmp(arg[iarg+1],"neigh") == 0) {
        imb = new ImbalanceNeigh;
        nopt = imb->options(lmp,narg-iarg-1,arg+iarg+1);
        imb = new ImbalanceNeigh(lmp);
        nopt = imb->options(narg-iarg,arg+iarg+2);
        imbalance[nimbalance] = imb;
      } else if (strcmp(arg[iarg+1],"var") == 0) {
        imb = new ImbalanceVar;
        nopt = imb->options(lmp,narg-iarg-1,arg+iarg+1);
        imb = new ImbalanceVar(lmp);
        nopt = imb->options(narg-iarg,arg+iarg+2);
        imbalance[nimbalance] = imb;
      } else {
        error->all(FLERR,"Unknown balance weight method");
      }
      ++nimbalance;
      iarg += 2+nopt;
#if 1
    } else if (strcmp(arg[iarg],"clock") == 0) {
      if (iarg+2 > narg) error->all(FLERR,"Illegal balance command");
      double factor = force->numeric(FLERR,arg[iarg+1]);
      if (factor < 0.0 || factor > 1.0)
        error->all(FLERR,"Illegal balance command");
      imbalance_clock(factor,0.0);
      iarg += 2;
    } else if (strcmp(arg[iarg],"group") == 0) {
      group_setup(narg-iarg-1,arg+iarg+1);
      iarg += 2*ngroup + 2;
#endif
    } else error->all(FLERR,"Illegal balance command");
  }

@@ -321,7 +299,7 @@ void Balance::command(int narg, char **arg)
  comm->exchange();
  if (domain->triclinic) domain->lamda2x(atom->nlocal);

  // compute and apply imbalance weights
  // compute and apply imbalance weights for local atoms
  if (nimbalance > 0) {
    int i;
    const int nlocal = atom->nlocal;
@@ -329,7 +307,7 @@ void Balance::command(int narg, char **arg)
    for (i = 0; i < nlocal; ++i)
      weight[i] = 1.0;
    for (i = 0; i < nimbalance; ++i)
      imbalance[i]->compute(lmp,weight);
      imbalance[i]->compute(weight);
  } else weight = NULL;

  // imbinit = initial imbalance
@@ -435,35 +413,39 @@ void Balance::command(int narg, char **arg)
    error->all(FLERR,str);
  }

  // recompute and apply imbalance weights for local atoms

  if (nimbalance > 0) {
    int i;
    const int nlocal = atom->nlocal;
    delete[] weight;
    weight = new double[nlocal];
    for (i = 0; i < nlocal; ++i)
      weight[i] = 1.0;
    for (i = 0; i < nimbalance; ++i)
      imbalance[i]->compute(weight);
  } else weight = NULL;

  // imbfinal = final imbalance based on final (weighted) nlocal

  int maxfinal;
  double imbfinal = imbalance_nlocal(maxfinal);

  if (me == 0) {
    double stop_time = MPI_Wtime();
    if (screen) {
      fprintf(screen,"  rebalancing time: %g seconds\n",
              MPI_Wtime()-start_time);
      fprintf(screen,"  rebalancing time: %g seconds\n",stop_time-start_time);
      fprintf(screen,"  iteration count = %d\n",niter);
      if (ngroup > 0) {
        fprintf(screen,"  group weights:");
        for (int i=0; i < ngroup; ++i)
          fprintf(screen," %s=%g", group->names[group_id[i]],group_weight[i]);
        fprintf(screen,"\n");
      }
      for (int i = 0; i < nimbalance; ++i) imbalance[i]->info(screen);
      fprintf(screen,"  initial/final max load/proc = %d %d\n",
              maxinit,maxfinal);
      fprintf(screen,"  initial/final imbalance factor = %g %g\n",
              imbinit,imbfinal);
    }
    if (logfile) {
      fprintf(logfile,"  rebalancing time: %g seconds\n",stop_time-start_time);
      fprintf(logfile,"  iteration count = %d\n",niter);
      if (ngroup > 0) {
        fprintf(logfile,"  group weights:");
        for (int i=0; i < ngroup; ++i)
          fprintf(logfile," %s=%g", group->names[group_id[i]],group_weight[i]);
        fprintf(logfile,"\n");
      }
      for (int i = 0; i < nimbalance; ++i) imbalance[i]->info(logfile);
      fprintf(logfile,"  initial/final max load/proc = %d %d\n",
              maxinit,maxfinal);
      fprintf(logfile,"  initial/final imbalance factor = %g %g\n",
@@ -505,69 +487,6 @@ void Balance::command(int narg, char **arg)
  }
}

 /* ----------------------------------------------------------------------
    compute the computational load associated with an atom
    i = atom index
    return cost = product of group weights for this atom.
------------------------------------------------------------------------- */

double Balance::getcost(int i)
{
   double cost = 1.0;
   for (int j = 0; j < ngroup; ++j) {
     if (atom->mask[i] & group->bitmask[group_id[j]])
       cost *= group_weight[j];
   }
   return cost;
}

/* ----------------------------------------------------------------------
   calculate imbalance based on timers for Pair+Bond+Kspace+Neighbor time.
------------------------------------------------------------------------- */

double Balance::imbalance_clock(double factor, double last_cost)
{

  // Compute the cost function of based on relevant timers
  if (timer->has_normal()) {
    if (!clock_imbalance) clock_imbalance = new double[nprocs+1];

    double cost = -last_cost;
    cost += timer->get_wall(Timer::PAIR);
    cost += timer->get_wall(Timer::NEIGH);
    cost += timer->get_wall(Timer::BOND);
    cost += timer->get_wall(Timer::KSPACE);

    double *clock_cost = new double[nprocs+1];
    for (int i = 0; i <= nprocs; ++i) clock_imbalance[i] = clock_cost[i] = 0.0;
    clock_cost[me] = cost;
    clock_cost[nprocs] = cost;
    MPI_Allreduce(clock_cost,clock_imbalance,nprocs+1,MPI_DOUBLE,MPI_SUM,world);

    const double avg_cost = clock_imbalance[nprocs]/nprocs;
    if (avg_cost > 0.0) {
      for (int i = 0; i < nprocs; ++i)
        clock_imbalance[i] = (1.0-factor) + factor*clock_imbalance[i]/avg_cost;
    } else {
      for (int i = 0; i < nprocs; ++i)
        clock_imbalance[i] = 1.0;
    }

#if BALANCE_DEBUG
    if (me == 0) {
      fprintf(stderr,"Clock imbalance using factor %g\n",factor);
      for (int i = 0; i < nprocs; ++i)
        fprintf(stderr," % 2d: %4.2f",i,clock_imbalance[i]);
      fputs("\n",stderr);
    }
#endif

    delete [] clock_cost;
    return cost + last_cost;
  }
  return last_cost;
}

/* ----------------------------------------------------------------------
   calculate imbalance based on (weighted) local atom counts
   return max = max atom per proc
@@ -579,10 +498,12 @@ double Balance::imbalance_nlocal(int &maxcost)
  // Compute the cost function of local atoms

  double cost = 0.0;
  for (int i=0; i < atom->nlocal; ++i) {
    cost += getcost(i);
  if (weight == NULL) {
    cost = atom->nlocal;
  } else {
    for (int i=0; i < atom->nlocal; ++i)
      cost += weight[i];
  }
  if (clock_imbalance) cost *= clock_imbalance[me];

  int intcost = (int)cost;
  int sumcost = maxcost = 0;
@@ -622,21 +543,23 @@ double Balance::imbalance_splits(int &max)
  int nlocal = atom->nlocal;
  int ix,iy,iz;

  if (weight) {
    for (int i = 0; i < nlocal; i++) {
      ix = binary(x[i][0],nx,xsplit);
      iy = binary(x[i][1],ny,ysplit);
      iz = binary(x[i][2],nz,zsplit);

    proccost[iz*nx*ny + iy*nx + ix] += getcost(i);
      proccost[iz*nx*ny + iy*nx + ix] += weight[i];
    }
  } else {
    for (int i = 0; i < nlocal; i++) {
      ix = binary(x[i][0],nx,xsplit);
      iy = binary(x[i][1],ny,ysplit);
      iz = binary(x[i][2],nz,zsplit);
      proccost[iz*nx*ny + iy*nx + ix] += 1.0;
    }

  for (int i = 0; i < nprocs; i++) {
    if (clock_imbalance)
      proccount[i] = static_cast<int>(proccost[i]*clock_imbalance[i]);
    else
      proccount[i] = static_cast<int>(proccost[i]);
  }

  for (int i = 0; i < nprocs; i++) proccount[i] = (int)(proccost[i]);
  MPI_Allreduce(proccount,allproccount,nprocs,MPI_INT,MPI_SUM,world);
  bigint sum = 0;
  max = 0;
@@ -698,20 +621,10 @@ int *Balance::bisection(int sortflag)

  // invoke RCB
  // then invert() to create list of proc assignements for my atoms
  // Use specified weightings for each atom rather than atom count

#if 1
  double factor = 1.0;
  if (clock_imbalance) factor = clock_imbalance[me];
  // Use compute weights for each atom, if available

  double *weights = new double[nlocal];
  for (int i = 0; i < nlocal; i++)
    weights[i] = getcost(i)*factor;
#endif

  rcb->compute(dim,atom->nlocal,atom->x,weights,shrinklo,shrinkhi);
  rcb->compute(dim,atom->nlocal,atom->x,weight,shrinklo,shrinkhi);
  rcb->invert(sortflag);
  delete[] weights;

  // reset RCB lo/hi bounding box to full simulation box as needed

@@ -820,28 +733,6 @@ void Balance::shift_setup(char *str, int nitermax_in, double thresh_in)
  rho = 1;
}

/* ----------------------------------------------------------------------
   setup group based load balance operations
   called from balance->command() and fix balance
------------------------------------------------------------------------- */
int Balance::group_setup(int narg, char **arg)
{
  if (narg < 3) error->all(FLERR,"Illegal balance command");

  ngroup = force->inumeric(FLERR,arg[0]);
  if (ngroup < 1) error->all(FLERR,"Illegal balance command");
  if (2*ngroup+1 > narg) error->all(FLERR,"Illegal balance command");

  group_id = new int[ngroup];
  group_weight = new double[ngroup];
  for (int i = 0; i < ngroup; ++i) {
    group_id[i] = group->find(arg[2*i+1]);
    if (group_id[i] < 0) error->all(FLERR,"Unknown group in balance command");
    group_weight[i] = force->numeric(FLERR,arg[2*i+2]);
  }
  return ngroup;
}

/* ----------------------------------------------------------------------
   load balance by changing xyz split proc boundaries in Comm
   called one time from input script command or many times from fix balance
@@ -886,10 +777,13 @@ int Balance::shift()
    tally(bdim[idim],np,split);

    double cost = 0.0;
    for (i=0; i < atom->nlocal; i++)
      cost += getcost(i);
    if (weight == NULL) {
      cost = atom->nlocal;
    } else {
      for (int i=0; i < atom->nlocal; ++i)
        cost += weight[i];
    }

    if (clock_imbalance) cost *= clock_imbalance[me];
    int intcost = (int)cost;
    int totalcost;
    MPI_Allreduce(&intcost,&totalcost,1,MPI_INT,MPI_SUM,world);
@@ -1033,12 +927,16 @@ void Balance::tally(int dim, int n, double *split)
  int nlocal = atom->nlocal;
  int index;

  double factor = 1.0;
  if (clock_imbalance) factor = clock_imbalance[me];

  if (weight) {
    for (int i = 0; i < nlocal; i++) {
      index = binary(x[i][dim],n,split);
    onecost[index] += getcost(i)*factor;
      onecost[index] += weight[i];
    }
  } else {
    for (int i = 0; i < nlocal; i++) {
      index = binary(x[i][dim],n,split);
      onecost[index] += 1.0;
    }
  }

  for (int i = 0; i < n; i++) onecount[i] = static_cast<bigint>(onecost[i]);
+0 −11
Original line number Diff line number Diff line
@@ -32,12 +32,10 @@ class Balance : protected Pointers {
  Balance(class LAMMPS *);
  ~Balance();
  void command(int, char **);
  int group_setup(int, char **);
  void shift_setup(char *, int, double);
  int shift();
  int *bisection(int sortflag = 0);
  double imbalance_nlocal(int &);
  double imbalance_clock(double, double);
  void dumpout(bigint, FILE *);

 private:
@@ -71,14 +69,6 @@ class Balance : protected Pointers {
  class Imbalance **imbalance; // list of imbalance compute classes
  double *weight;            // per (local) atom weight factor or NULL

#if 1
  int    ngroup;             // number of groups weights
  int    *group_id;          // group ids for weights
  double *group_weight;      // weights of groups

  double *clock_imbalance;   // computed wall clock imbalance, NULL if not available
#endif

  int outflag;               // for output of balance results to file
  FILE *fp;
  int firststep;
@@ -88,7 +78,6 @@ class Balance : protected Pointers {
  void tally(int, int, double *);
  int adjust(int, double *);
  int binary(double, int, double *);
  double getcost(int);
#ifdef BALANCE_DEBUG
  void debug_shift_output(int, int, int, double *);
#endif
+4 −1
Original line number Diff line number Diff line
@@ -97,9 +97,11 @@ FixBalance::FixBalance(LAMMPS *lmp, int narg, char **arg) :
      if (clock_factor < 0.0 || clock_factor > 1.0)
        error->all(FLERR,"Illegal fix balance command");
      iarg += 2;
#if 0
    } else if (strcmp(arg[iarg],"group") == 0) {
      int ngroup = balance->group_setup(narg-iarg-1,arg+iarg+1);
      iarg += 2 + 2*ngroup;
#endif
    } else error->all(FLERR,"Illegal fix balance command");
  }

@@ -229,10 +231,11 @@ void FixBalance::pre_exchange()
  if (domain->triclinic) domain->lamda2x(atom->nlocal);

  // return if imbalance < threshhold

#if 0
  if (clock_factor > 0.0)
    last_clock = balance->imbalance_clock(clock_factor,last_clock);
  imbnow = balance->imbalance_nlocal(maxperproc);
#endif
  if (imbnow <= thresh) {
    if (nevery) next_reneighbor = (update->ntimestep/nevery)*nevery + nevery;
    return;
+18 −7
Original line number Diff line number Diff line
@@ -14,25 +14,36 @@
#ifndef LMP_IMBALANCE_H
#define LMP_IMBALANCE_H

#include <stdio.h>

namespace LAMMPS_NS {
 class LAMMPS;

class Imbalance {
 public:
  Imbalance() {};
  Imbalance(LAMMPS *lmp) : _lmp(lmp) {};
  virtual ~Imbalance() {};

  // disallow copy constructor and assignment operator
  // disallow default and copy constructor, assignment operator
 private:
  Imbalance() {};
  Imbalance(const Imbalance &) {};
  Imbalance &operator=(const Imbalance &) {return *this;};

  // required member functions
  // internal use only data members
 protected:
  LAMMPS *_lmp;

  // public API
 public:
  // parse options. return number of arguments consumed.
  virtual int options(LAMMPS *lmp, int narg, char **arg) = 0;
  // compute and apply weigh factors to local atom array
  virtual void compute(LAMMPS *lmp, double *weights) = 0;
  // parse options. return number of arguments consumed. (required)
  virtual int options(int narg, char **arg) = 0;
  // reinitialize internal data (needed for fix balance) (optional)
  virtual void init() {};
  // compute and apply weight factors to local atom array (required)
  virtual void compute(double *weights) = 0;
  // print information about the state of this imbalance compute (required)
  virtual void info(FILE *fp) = 0;
};

}
+26 −10
Original line number Diff line number Diff line
@@ -21,11 +21,11 @@

using namespace LAMMPS_NS;

int ImbalanceGroup::options(LAMMPS *lmp, int narg, char **arg)
int ImbalanceGroup::options(int narg, char **arg)
{
  Error *error = lmp->error;
  Force *force = lmp->force;
  Group *group = lmp->group;
  Error *error = _lmp->error;
  Force *force = _lmp->force;
  Group *group = _lmp->group;

  if (narg < 3) error->all(FLERR,"Illegal balance weight command");

@@ -41,14 +41,16 @@ int ImbalanceGroup::options(LAMMPS *lmp, int narg, char **arg)
      error->all(FLERR,"Unknown group in balance weight command");
    _factor[i] = force->numeric(FLERR,arg[2*i+2]);
  }
  return _num;
  return 2*_num+1;
}

void ImbalanceGroup::compute(LAMMPS *lmp, double *weight)
/* -------------------------------------------------------------------- */

void ImbalanceGroup::compute(double *weight)
{
  const int * const mask = lmp->atom->mask;
  const int * const bitmask = lmp->group->bitmask;
  const int nlocal = lmp->atom->nlocal;
  const int * const mask = _lmp->atom->mask;
  const int * const bitmask = _lmp->group->bitmask;
  const int nlocal = _lmp->atom->nlocal;

  if (_num == 0) return;

@@ -62,3 +64,17 @@ void ImbalanceGroup::compute(LAMMPS *lmp, double *weight)
    weight[i] = iweight;
  }
}

/* -------------------------------------------------------------------- */

void ImbalanceGroup::info(FILE *fp)
{
  if (_num > 0) {
    const char * const * const names = _lmp->group->names;

    fprintf(fp,"  group weights:");
    for (int i = 0; i < _num; ++i)
      fprintf(fp," %s=%g",names[_id[i]],_factor[i]);
    fputs("\n",fp);
  }
}
Loading