Commit 8d38db07 authored by Axel Kohlmeyer's avatar Axel Kohlmeyer
Browse files

convert weight array from class member to local pointer to temporary storage

(cherry picked from commit ecbbdc2e7ff025d6ad1f6b808a13b2521ec825a8)
parent 4114bafc
Loading
Loading
Loading
Loading
+17 −17
Original line number Diff line number Diff line
@@ -62,7 +62,6 @@ Balance::Balance(LAMMPS *lmp) : Pointers(lmp)

  nimbalance = 0;
  imbalance = NULL;
  weight = NULL;
}

/* ---------------------------------------------------------------------- */
@@ -92,7 +91,6 @@ Balance::~Balance()
  for (int i; i < nimbalance; ++i)
    delete imbalance[i];
  delete [] imbalance;
  delete [] weight;

  if (fp) fclose(fp);
}
@@ -210,7 +208,7 @@ void Balance::command(int narg, char **arg)
  nimbalance = 0;
  for (int i=iarg; i < narg; ++i)
    if (strcmp(arg[iarg],"weight") == 0) ++nimbalance;
  imbalance = new Imbalance*[nimbalance];
  if (nimbalance) imbalance = new Imbalance*[nimbalance];

  nimbalance = outflag = 0;
  while (iarg < narg) {
@@ -300,6 +298,7 @@ void Balance::command(int narg, char **arg)
  if (domain->triclinic) domain->lamda2x(atom->nlocal);

  // compute and apply imbalance weights for local atoms
  double *weight = NULL;
  if (nimbalance > 0) {
    int i;
    const int nlocal = atom->nlocal;
@@ -308,12 +307,12 @@ void Balance::command(int narg, char **arg)
      weight[i] = 1.0;
    for (i = 0; i < nimbalance; ++i)
      imbalance[i]->compute(weight);
  } else weight = NULL;
  }

  // imbinit = initial imbalance

  int maxinit;
  double imbinit = imbalance_nlocal(maxinit);
  double imbinit = imbalance_nlocal(maxinit,weight);

  // no load-balance if imbalance doesn't exceed threshhold
  // unless switching from tiled to non tiled layout, then force rebalance
@@ -372,14 +371,14 @@ void Balance::command(int narg, char **arg)
  if (style == SHIFT) {
    comm->layout = LAYOUT_NONUNIFORM;
    shift_setup_static(bstr);
    niter = shift();
    niter = shift(weight);
  }

  // style BISECTION = recursive coordinate bisectioning

  if (style == BISECTION) {
    comm->layout = LAYOUT_TILED;
    bisection(1);
    bisection(weight,1);
  }

  // reset proc sub-domains
@@ -424,12 +423,13 @@ void Balance::command(int narg, char **arg)
      weight[i] = 1.0;
    for (i = 0; i < nimbalance; ++i)
      imbalance[i]->compute(weight);
  } else weight = NULL;
  }

  // imbfinal = final imbalance based on final (weighted) nlocal

  int maxfinal;
  double imbfinal = imbalance_nlocal(maxfinal);
  double imbfinal = imbalance_nlocal(maxfinal,weight);
  delete[] weight;

  if (me == 0) {
    double stop_time = MPI_Wtime();
@@ -493,7 +493,7 @@ void Balance::command(int narg, char **arg)
   return imbalance factor = max atom per proc / ave atom per proc
------------------------------------------------------------------------- */

double Balance::imbalance_nlocal(int &maxcost)
double Balance::imbalance_nlocal(int &maxcost, double *weight)
{
  // Compute the cost function of local atoms

@@ -526,7 +526,7 @@ double Balance::imbalance_nlocal(int &maxcost)
   return imbalance factor = max atom per proc / ave atom per proc
------------------------------------------------------------------------- */

double Balance::imbalance_splits(int &max)
double Balance::imbalance_splits(int &max, double *weight)
{
  double *xsplit = comm->xsplit;
  double *ysplit = comm->ysplit;
@@ -580,7 +580,7 @@ double Balance::imbalance_splits(int &max)
   sortflag = flag for sorting order of received messages by proc ID
------------------------------------------------------------------------- */

int *Balance::bisection(int sortflag)
int *Balance::bisection(double *weight, int sortflag)
{
  if (!rcb) rcb = new RCB(lmp);

@@ -739,7 +739,7 @@ void Balance::shift_setup(char *str, int nitermax_in, double thresh_in)
   return niter = iteration count
------------------------------------------------------------------------- */

int Balance::shift()
int Balance::shift(double *weight)
{
  int i,j,k,m,np,max;
  double *split;
@@ -774,7 +774,7 @@ int Balance::shift()
    // intial count and sum

    np = procgrid[bdim[idim]];
    tally(bdim[idim],np,split);
    tally(bdim[idim],np,split,weight);

    double cost = 0.0;
    if (weight == NULL) {
@@ -827,7 +827,7 @@ int Balance::shift()
    int change = 1;
    for (m = 0; m < nitermax; m++) {
      change = adjust(np,split);
      tally(bdim[idim],np,split);
      tally(bdim[idim],np,split,weight);
      niter++;

#ifdef BALANCE_DEBUG
@@ -898,7 +898,7 @@ int Balance::shift()
    // stop at this point in bstr if imbalance factor < threshhold
    // this is a true 3d test of particle count per processor

    double imbfactor = imbalance_splits(max);
    double imbfactor = imbalance_splits(max,weight);
    if (imbfactor <= stopthresh) break;
  }

@@ -918,7 +918,7 @@ int Balance::shift()
   use binary search to find which slice each atom is in
------------------------------------------------------------------------- */

void Balance::tally(int dim, int n, double *split)
void Balance::tally(int dim, int n, double *split, double *weight)
{
  double *onecost = new double[n];
  for (int i = 0; i < n; i++) onecost[i] = 0.0;
+21 −22
Original line number Diff line number Diff line
@@ -33,9 +33,9 @@ class Balance : protected Pointers {
  ~Balance();
  void command(int, char **);
  void shift_setup(char *, int, double);
  int shift();
  int *bisection(int sortflag = 0);
  double imbalance_nlocal(int &);
  int shift(double *);
  int *bisection(double *, int sortflag = 0);
  double imbalance_nlocal(int &, double *);
  void dumpout(bigint, FILE *);

 private:
@@ -62,20 +62,19 @@ class Balance : protected Pointers {
  int rho;                      // 0 for geometric recursion
                                // 1 for density weighted recursion

  int *proccount;            // particle count per processor
  int *proccount;               // (weighted) particle count per processor
  int *allproccount;

  int nimbalance;
  int nimbalance;               // number of imbalance weight computes
  class Imbalance **imbalance;  // list of imbalance compute classes
  double *weight;            // per (local) atom weight factor or NULL

  int outflag;                  // for output of balance results to file
  FILE *fp;
  int firststep;

  double imbalance_splits(int &);
  double imbalance_splits(int &, double *);
  void shift_setup_static(char *);
  void tally(int, int, double *);
  void tally(int, int, double *, double *);
  int adjust(int, double *);
  int binary(double, int, double *);
#ifdef BALANCE_DEBUG