Commit 66b4c9b8 authored by Axel Kohlmeyer's avatar Axel Kohlmeyer
Browse files

implement modified version of balance and fix balance according to steve's...

implement modified version of balance and fix balance according to steve's suggestions and requirements

(cherry picked from commit 5a81288329e1f7d34f86c7e6d79082ba8b074516)
parent 85f58624
Loading
Loading
Loading
Loading
+78 −32
Original line number Diff line number Diff line
@@ -53,6 +53,10 @@ Balance::Balance(LAMMPS *lmp) : Pointers(lmp)

  fp = NULL;
  firststep = 1;

  ngroup = 0;
  group_id = NULL;
  group_weight = NULL;
}

/* ---------------------------------------------------------------------- */
@@ -80,6 +84,9 @@ Balance::~Balance()

  delete rcb;

  delete [] group_id;
  delete [] group_weight;

  if (fp) fclose(fp);
}

@@ -202,6 +209,9 @@ void Balance::command(int narg, char **arg)
        if (fp == NULL) error->one(FLERR,"Cannot open balance output file");
      }
      iarg += 2;
    } else if (strcmp(arg[iarg],"group") == 0) {
      group_setup(narg-iarg-1,arg+iarg+1);
      iarg += 2*ngroup + 2;
    } else error->all(FLERR,"Illegal balance command");
  }

@@ -365,6 +375,12 @@ void Balance::command(int narg, char **arg)
  if (me == 0) {
    if (screen) {
      fprintf(screen,"  iteration count = %d\n",niter);
      if (ngroup > 0) {
        fprintf(screen,"  group weights:");
        for (int i=0; i < ngroup; ++i)
          fprintf(screen," %s=%g", group->names[group_id[i]],group_weight[i]);
        fprintf(screen,"\n");
      }
      fprintf(screen,"  initial/final max load/proc = %d %d\n",
              maxinit,maxfinal);
      fprintf(screen,"  initial/final imbalance factor = %g %g\n",
@@ -372,6 +388,12 @@ void Balance::command(int narg, char **arg)
    }
    if (logfile) {
      fprintf(logfile,"  iteration count = %d\n",niter);
      if (ngroup > 0) {
        fprintf(logfile,"  group weights:");
        for (int i=0; i < ngroup; ++i)
          fprintf(logfile," %s=%g", group->names[group_id[i]],group_weight[i]);
        fprintf(logfile,"\n");
      }
      fprintf(logfile,"  initial/final max load/proc = %d %d\n",
              maxinit,maxfinal);
      fprintf(logfile,"  initial/final imbalance factor = %g %g\n",
@@ -414,17 +436,17 @@ void Balance::command(int narg, char **arg)
}

 /* ----------------------------------------------------------------------
   compute the load associated with an atom
    compute the computational load associated with an atom
    i = atom index
   return cost = product of load factors of groups which include the atom
    return cost = product of group weights for this atom.
------------------------------------------------------------------------- */

double Balance::getcost(int i)
{
   double cost = 1.0;
   for (int j=0; j < MAX_GROUP; j++)
   {
      if (atom->mask[i] & group->bitmask[j]) cost *= group->load_factor[j];
   for (int j = 0; j < ngroup; ++j) {
     if (atom->mask[i] & group->bitmask[group_id[j]])
       cost *= group_weight[j];
   }
   return cost;
}
@@ -435,24 +457,24 @@ double Balance::getcost(int i)
   return imbalance factor = max atom per proc / ave atom per proc
------------------------------------------------------------------------- */

double Balance::imbalance_nlocal(int &max)
double Balance::imbalance_nlocal(int &maxcost)
{
  // Compute the cost function of local atoms

  double cost = 0.0;
  for (int i=0; i < atom->nlocal; i++)
  {
  for (int i=0; i < atom->nlocal; ++i) {
    cost += getcost(i);
  }

  double imbalance = 1.0;
  int intcost = (int)cost;
  int sumcost;

  int sum;
  MPI_Allreduce(&intcost,&maxcost,1,MPI_INT,MPI_MAX,world);
  MPI_Allreduce(&intcost,&sumcost,1,MPI_INT,MPI_SUM,world);
  
  MPI_Allreduce(&intcost,&max,1,MPI_INT,MPI_MAX,world);
  MPI_Allreduce(&intcost,&sum,1,MPI_INT,MPI_SUM,world);
  double imbalance = 1.0;
  if (max) imbalance = max / (1.0 * sum / nprocs);
  if (maxcost && sumcost > 0)
    imbalance = maxcost / (static_cast<double>(sumcost)/nprocs);
  return imbalance;
}

@@ -474,8 +496,8 @@ double Balance::imbalance_splits(int &max)
  int ny = comm->procgrid[1];
  int nz = comm->procgrid[2];

  double proccountd [nprocs];
  for (int i = 0; i < nprocs; i++) proccountd[i] = 0.0;
  double *proccost = new double[nprocs];
  for (int i = 0; i < nprocs; i++) proccost[i] = 0.0;

  double **x = atom->x;
  int nlocal = atom->nlocal;
@@ -486,20 +508,24 @@ double Balance::imbalance_splits(int &max)
    iy = binary(x[i][1],ny,ysplit);
    iz = binary(x[i][2],nz,zsplit);

    proccountd[iz*nx*ny + iy*nx + ix] += getcost(i);
    proccost[iz*nx*ny + iy*nx + ix] += getcost(i);
  }

  for (int i = 0; i < nprocs; i++) proccount[i] = (bigint)proccountd[i];
  for (int i = 0; i < nprocs; i++)
    proccount[i] = static_cast<int>(proccost[i]);

  MPI_Allreduce(proccount,allproccount,nprocs,MPI_INT,MPI_SUM,world);
  int sum = 0;
  bigint sum = 0;
  max = 0;
  for (int i = 0; i < nprocs; i++){
    max = MAX(max,allproccount[i]);
    sum += allproccount[i];
  }
  double imbalance = 1.0;
  if (max) imbalance = max / (1.0 * sum / nprocs);
  if (max && sum > 0)
    imbalance = max / (static_cast<double>(sum) / nprocs);

  delete [] proccost;
  return imbalance;
}

@@ -668,6 +694,28 @@ void Balance::shift_setup(char *str, int nitermax_in, double thresh_in)
  rho = 1;
}

/* ----------------------------------------------------------------------
   setup group based load balance operations
   called from balance->command() and fix balance
------------------------------------------------------------------------- */
int Balance::group_setup(int narg, char **arg)
{
  if (narg < 3) error->all(FLERR,"Illegal balance command");

  ngroup = force->inumeric(FLERR,arg[0]);
  if (ngroup < 1) error->all(FLERR,"Illegal balance command");
  if (2*ngroup+1 > narg) error->all(FLERR,"Illegal balance command");

  group_id = new int[ngroup];
  group_weight = new double[ngroup];
  for (int i = 0; i < ngroup; ++i) {
    group_id[i] = group->find(arg[2*i+1]);
    if (group_id[i] < 0) error->all(FLERR,"Unknown group in balance command");
    group_weight[i] = force->numeric(FLERR,arg[2*i+2]);
  }
  return ngroup;
}

/* ----------------------------------------------------------------------
   load balance by changing xyz split proc boundaries in Comm
   called one time from input script command or many times from fix balance
@@ -712,14 +760,11 @@ int Balance::shift()
    tally(bdim[idim],np,split);

    double cost = 0.0;

    for (i=0; i < atom->nlocal; i++)
    {
      cost += getcost(i);
    }

    int intcost = (int)cost;
    int totalcost;

    MPI_Allreduce(&intcost,&totalcost,1,MPI_INT,MPI_SUM,world);

    // target[i] = desired sum at split I
@@ -854,8 +899,8 @@ int Balance::shift()

void Balance::tally(int dim, int n, double *split)
{
  double onecountd[n];
  for (int i = 0; i < n; i++) onecountd[i] = 0.0;
  double *onecost = new double[n];
  for (int i = 0; i < n; i++) onecost[i] = 0.0;

  double **x = atom->x;
  int nlocal = atom->nlocal;
@@ -864,16 +909,17 @@ void Balance::tally(int dim, int n, double *split)

  for (int i = 0; i < nlocal; i++) {
    index = binary(x[i][dim],n,split);
    onecountd[index] += getcost(i);
    onecost[index] += getcost(i);
  }

  for (int i = 0; i < n; i++) onecount[i] = (bigint)onecountd[i];

  for (int i = 0; i < n; i++) onecount[i] = static_cast<bigint>(onecost[i]);
  MPI_Allreduce(onecount,count,n,MPI_LMP_BIGINT,MPI_SUM,world);

  sum[0] = 0;
  for (int i = 1; i < n+1; i++)
    sum[i] = sum[i-1] + count[i-1];

  delete [] onecost;
}

/* ----------------------------------------------------------------------
+5 −0
Original line number Diff line number Diff line
@@ -32,6 +32,7 @@ class Balance : protected Pointers {
  Balance(class LAMMPS *);
  ~Balance();
  void command(int, char **);
  int group_setup(int, char **);
  void shift_setup(char *, int, double);
  int shift();
  int *bisection(int sortflag = 0);
@@ -65,6 +66,10 @@ class Balance : protected Pointers {
  int *proccount;            // particle count per processor
  int *allproccount;

  int    ngroup;             // number of groups weights
  int    *group_id;          // group ids for weights
  double *group_weight;      // weights of groups

  int outflag;               // for output of balance results to file
  FILE *fp;
  int firststep;
+8 −3
Original line number Diff line number Diff line
@@ -73,6 +73,10 @@ FixBalance::FixBalance(LAMMPS *lmp, int narg, char **arg) :
    iarg++;
  }

  // create instance of Balance class. required for processing group flags.

  balance = new Balance(lmp);

  // optional args

  outflag = 0;
@@ -85,6 +89,9 @@ FixBalance::FixBalance(LAMMPS *lmp, int narg, char **arg) :
      outflag = 1;
      outarg = iarg+1;
      iarg += 2;
    } else if (strcmp(arg[iarg],"group") == 0) {
      int ngroup = balance->group_setup(narg-iarg-1,arg+iarg+1);
      iarg += 2 + 2*ngroup;
    } else error->all(FLERR,"Illegal fix balance command");
  }

@@ -106,10 +113,8 @@ FixBalance::FixBalance(LAMMPS *lmp, int narg, char **arg) :
  if (lbstyle == BISECTION && comm->style == 0)
    error->all(FLERR,"Fix balance rcb cannot be used with comm_style brick");

  // create instance of Balance class
  // if SHIFT, initialize it with params
  // if SHIFT, initialize balance class with params

  balance = new Balance(lmp);
  if (lbstyle == SHIFT) balance->shift_setup(bstr,nitermax,thresh);

  // create instance of Irregular class