Commit 7f20afe1 authored by Axel Kohlmeyer's avatar Axel Kohlmeyer
Browse files

convert from using fix property/atom to using fix store

(cherry picked from commit 280aef55d2b7d8dbe4878247075f8c5aad94b73e)
parent 7e0dc7a7
Loading
Loading
Loading
Loading
+34 −60
Original line number Diff line number Diff line
@@ -36,6 +36,8 @@
#include "imbalance_neigh.h"
#include "imbalance_var.h"

#include "fix_store.h"

using namespace LAMMPS_NS;

enum{XYZ,SHIFT,BISECTION};
@@ -43,8 +45,6 @@ enum{NONE,UNIFORM,USER};
enum{X,Y,Z};
enum{LAYOUT_UNIFORM,LAYOUT_NONUNIFORM,LAYOUT_TILED};    // several files

const char * const Balance::bal_id = (const char * const) "BALANCE_WEIGHTS";

/* ---------------------------------------------------------------------- */

Balance::Balance(LAMMPS *lmp) : Pointers(lmp)
@@ -65,6 +65,7 @@ Balance::Balance(LAMMPS *lmp) : Pointers(lmp)

  nimbalance = 0;
  imbalance = NULL;
  imb_fix = NULL;
}

/* ---------------------------------------------------------------------- */
@@ -94,8 +95,8 @@ Balance::~Balance()
  for (int i; i < nimbalance; ++i)
    delete imbalance[i];
  delete [] imbalance;
  int ifix = modify->find_fix(bal_id);
  if (ifix >= 0) modify->delete_fix(bal_id);
  if (imb_fix) modify->delete_fix(imb_fix->id);
  imb_fix = NULL;

  if (fp) fclose(fp);
}
@@ -216,6 +217,7 @@ void Balance::command(int narg, char **arg)
  if (nimbalance) imbalance = new Imbalance*[nimbalance];

  nimbalance = outflag = 0;
  imb_fix = NULL;
  while (iarg < narg) {
    if (strcmp(arg[iarg],"out") == 0) {
      if (iarg+2 > narg) error->all(FLERR,"Illegal balance command");
@@ -303,29 +305,22 @@ void Balance::command(int narg, char **arg)
  if (domain->triclinic) domain->lamda2x(atom->nlocal);

  // compute and apply imbalance weights for local atoms
  int iweight = -1;
  if (nimbalance > 0) {
    int dflag = 0;
    iweight = atom->find_custom(bal_id,dflag);

    // add fix property/atom, for storing weights with atoms, if needed.
    if (iweight < 0 || dflag != 1) {
      char *fixargs[4];
      char *val_id = new char[strlen(bal_id)+3];
      strcpy(val_id,"d_");
      strcat(val_id,bal_id);
    char *fixargs[6];

      fixargs[0] = (char *)bal_id;
    fixargs[0] = (char *) "IMBALANCE_WEIGHTS";
    fixargs[1] = (char *) "all";
      fixargs[2] = (char *)"property/atom";
      fixargs[3] = val_id;
    fixargs[2] = (char *) "STORE";
    fixargs[3] = (char *) "peratom";
    fixargs[4] = (char *) "1";
    fixargs[5] = (char *) "1";

      modify->add_fix(4,fixargs);
      iweight = atom->find_custom(bal_id,dflag);
      delete[] val_id;
    }
    modify->add_fix(6,fixargs);
    imb_fix = (FixStore *) modify->fix[modify->nfix-1];

    double * const weight = atom->dvector[iweight];
    double * const weight = imb_fix->vstore;
    for (int i = 0; i < atom->nlocal; ++i)
      weight[i] = 1.0;
    for (int n = 0; n < nimbalance; ++n)
@@ -435,19 +430,6 @@ void Balance::command(int narg, char **arg)
    error->all(FLERR,str);
  }

  // recompute and apply imbalance weights for local atoms

  if (nimbalance > 0) {
    int i;
    const int nlocal = atom->nlocal;
    double * const weight = atom->dvector[iweight];

    for (i = 0; i < nlocal; ++i)
      weight[i] = 1.0;
    for (i = 0; i < nimbalance; ++i)
      imbalance[i]->compute(weight);
  }

  // imbfinal = final imbalance based on final (weighted) nlocal

  int maxfinal;
@@ -518,22 +500,20 @@ void Balance::command(int narg, char **arg)
double Balance::imbalance_nlocal(int &maxcost)
{
  // Compute the cost function of local atoms
  const int nlocal = atom->nlocal;
  int intcost, sumcost;
  intcost = sumcost = maxcost = 0;

  if (imb_fix) {
    const double * const weight = imb_fix->vstore;
    double cost = 0.0;
  int dflag = 0;
  int iweight = atom->find_custom(bal_id,dflag);

  if (iweight < 0 || dflag != 1) {
    cost = atom->nlocal;
  } else {
    const double * const weight = atom->dvector[iweight];
    const int nlocal = atom->nlocal;
    for (int i=0; i < nlocal; ++i)
      cost += weight[i];
  }

  int intcost = (int)cost;
  int sumcost = maxcost = 0;
    intcost = (int)cost;
  } else {
    intcost = nlocal;
  }

  MPI_Allreduce(&intcost,&maxcost,1,MPI_INT,MPI_MAX,world);
  MPI_Allreduce(&intcost,&sumcost,1,MPI_INT,MPI_SUM,world);
@@ -648,10 +628,7 @@ int *Balance::bisection(int sortflag)

  // Use pre-computed weights for each atom, if available

  int dflag = 0;
  int iweight = atom->find_custom(bal_id,dflag);
  double * const weight =
    (iweight < 0 || dflag != 1) ? NULL : atom->dvector[iweight];
  double * const weight = (imb_fix) ? imb_fix->vstore : NULL;

  // invoke RCB
  // then invert() to create list of proc assignements for my atoms
@@ -806,24 +783,21 @@ int Balance::shift()

    // intial count and sum

    int dflag = 0;
    int iweight = atom->find_custom(bal_id,dflag);
    const double * const weight =
      (iweight < 0 || dflag != 1) ? NULL : atom->dvector[iweight];
    double cost = 0.0;
    const double * const weight = (imb_fix) ? imb_fix->vstore : NULL;
    int intcost, totalcost;

    np = procgrid[bdim[idim]];
    tally(bdim[idim],np,split,weight);

    if (weight) {
      double cost = 0.0;
      for (int i=0; i < atom->nlocal; ++i)
        cost += weight[i];
      intcost = (int) cost;
    } else {
      cost = atom->nlocal;
      intcost = atom->nlocal;
    }

    int intcost = (int)cost;
    int totalcost;
    MPI_Allreduce(&intcost,&totalcost,1,MPI_INT,MPI_SUM,world);

    // target[i] = desired sum at split I
+6 −2
Original line number Diff line number Diff line
@@ -26,8 +26,9 @@ CommandStyle(balance,Balance)
namespace LAMMPS_NS {

class Balance : protected Pointers {
  friend class FixBalance;

 public:
  class RCB *rcb;

  Balance(class LAMMPS *);
  ~Balance();
@@ -38,7 +39,9 @@ class Balance : protected Pointers {
  double imbalance_nlocal(int &);
  void dumpout(bigint, FILE *);

  static const char * const bal_id; // name of custom atom property for weights
 protected:
  class RCB *rcb;
  void set_imb_fix(class FixStore *fix) { imb_fix = fix; };

 private:
  int me,nprocs;
@@ -69,6 +72,7 @@ class Balance : protected Pointers {

  int nimbalance;              // number of imbalance weight computes
  class Imbalance **imbalance; // list of imbalance compute classes
  class FixStore *imb_fix;     // fix for storing per-atom weights

  int outflag;                 // for output of balance results to file
  FILE *fp;
+42 −42
Original line number Diff line number Diff line
@@ -32,6 +32,8 @@
#include "imbalance_neigh.h"
#include "imbalance_var.h"

#include "fix_store.h"

using namespace LAMMPS_NS;
using namespace FixConst;

@@ -97,7 +99,7 @@ FixBalance::FixBalance(LAMMPS *lmp, int narg, char **arg) :
  int outarg = 0;
  fp = NULL;
  nimbalance = 0;
  imb_id = NULL;
  imb_fix = NULL;

  while (iarg < narg) {
    if (strcmp(arg[iarg],"out") == 0) {
@@ -179,15 +181,19 @@ FixBalance::FixBalance(LAMMPS *lmp, int narg, char **arg) :

FixBalance::~FixBalance()
{
  if (fp) fclose(fp);
  delete balance;
  delete irregular;

  for (int i = 0; i < nimbalance; ++i)
    delete imbalance[i];
  delete[] imbalance;
  if (imb_id) modify->delete_fix(imb_id);
  delete [] imb_id;

  if (imb_fix && (modify->nfix > 0)) {
    modify->delete_fix(imb_fix->id);
    imb_fix = NULL;
    balance->set_imb_fix(NULL);
  }

  delete balance;
  delete irregular;
  if (fp) fclose(fp);
}

/* ---------------------------------------------------------------------- */
@@ -202,38 +208,38 @@ int FixBalance::setmask()

/* ---------------------------------------------------------------------- */

void FixBalance::init()
void FixBalance::post_constructor()
{
  if (force->kspace) kspace_flag = 1;
  else kspace_flag = 0;

  // add per atom weight property, if weighted balancing is requested

  if (nimbalance > 0) {
    int dflag = 0;
    int iweight = atom->find_custom(Balance::bal_id,dflag);

    if (iweight < 0 || dflag != 1) {
      char *fixargs[4];
      
      imb_id = new char[strlen(this->id)+strlen(Balance::bal_id)+2];
      char *val_id = new char[strlen(Balance::bal_id)+3];
    char *fixargs[6];
    char *imb_id = new char[strlen(this->id)+19];

    strcpy(imb_id,this->id);
      strcat(imb_id,"_");
      strcat(imb_id,Balance::bal_id);
      strcpy(val_id,"d_");
      strcat(val_id,Balance::bal_id);
    strcat(imb_id,"_IMBALANCE_WEIGHTS");

    fixargs[0] = imb_id;
    fixargs[1] = (char *) "all";
      fixargs[2] = (char *)"property/atom";
      fixargs[3] = val_id;
    fixargs[2] = (char *) "STORE";
    fixargs[3] = (char *) "peratom";
    fixargs[4] = (char *) "1";
    fixargs[5] = (char *) "1";

      modify->add_fix(4,fixargs);
      delete[] val_id;
    modify->add_fix(6,fixargs);
    imb_fix = (FixStore *) modify->fix[modify->nfix-1];
    balance->set_imb_fix(imb_fix);

    delete[] imb_id;
  }
}

/* ---------------------------------------------------------------------- */

void FixBalance::init()
{
  if (force->kspace) kspace_flag = 1;
  else kspace_flag = 0;
}

/* ---------------------------------------------------------------------- */
@@ -263,11 +269,8 @@ void FixBalance::setup_pre_exchange()

  // compute and apply imbalance weights for local atoms

  int iweight = -1;
  if (nimbalance > 0) {
    int dflag = 0;
    iweight = atom->find_custom(Balance::bal_id,dflag);
    double * const weight = atom->dvector[iweight];
    double * const weight = imb_fix->vstore;
    for (int i = 0; i < atom->nlocal; ++i)
      weight[i] = 1.0;
    for (int n = 0; n < nimbalance; ++n)
@@ -304,11 +307,8 @@ void FixBalance::pre_exchange()

  // compute and apply imbalance weights for local atoms

  int iweight = -1;
  if (nimbalance > 0) {
    int dflag = 0;
    iweight = atom->find_custom(Balance::bal_id,dflag);
    double * const weight = atom->dvector[iweight];
    double * const weight = imb_fix->vstore;
    for (int i = 0; i < atom->nlocal; ++i)
      weight[i] = 1.0;
    for (int n = 0; n < nimbalance; ++n)
+3 −2
Original line number Diff line number Diff line
@@ -35,6 +35,7 @@ class FixBalance : public Fix {
  void setup_pre_exchange();
  void pre_exchange();
  void pre_neighbor();
  void post_constructor();
  double compute_scalar();
  double compute_vector(int);
  double memory_usage();
@@ -55,7 +56,7 @@ class FixBalance : public Fix {

  int nimbalance;               // number of imbalance weight computes
  class Imbalance **imbalance;  // list of imbalance compute classes
  char *imb_id;                 // id of property/atom fix for storing weights
  class FixStore *imb_fix;      // fix for storing per-atom weights

  class Balance *balance;
  class Irregular *irregular;