Commit 69854eab authored by Richard Berger's avatar Richard Berger
Browse files

Expose Neighbor lists via library interface

parent 3353bffb
Loading
Loading
Loading
Loading
+68 −0
Original line number Diff line number Diff line
# 3d Lennard-Jones melt

units		lj
atom_style	atomic

lattice		fcc 0.8442
region		box block 0 2 0 2 0 2
create_box	1 box
create_atoms	1 box
mass		1 1.0

velocity	all create 3.0 87287

pair_style	lj/cut 2.5
pair_coeff	1 1 1.0 1.0 2.5

neighbor	0.1 bin

neigh_modify	every 20 delay 0 check no

python         post_force_callback here """
from __future__ import print_function
from lammps import lammps

def post_force_callback(lmp, v):
  try:
    L = lammps(ptr=lmp)
    t = L.extract_global("ntimestep", 0)
    print("### POST_FORCE ###", t)

    #mylist = L.get_neighlist(0)
    mylist = L.find_pair_neighlist("lj/cut", request=0)
    print(mylist)
    nlocal = L.extract_global("nlocal", 0)
    nghost = L.extract_global("nghost", 0)
    ntypes = L.extract_global("ntypes", 0)
    mass = L.numpy.extract_atom_darray("mass", ntypes+1)
    atype = L.numpy.extract_atom_iarray("type", nlocal+nghost)
    x = L.numpy.extract_atom_darray("x", nlocal+nghost, dim=3)
    v = L.numpy.extract_atom_darray("v", nlocal+nghost, dim=3)
    f = L.numpy.extract_atom_darray("f", nlocal+nghost, dim=3)

    for iatom, numneigh, neighs in mylist:
      print("- {}".format(iatom), x[iatom], v[iatom], f[iatom], " : ",  numneigh, "Neighbors")
      for jatom in neighs:
        if jatom < nlocal:
            print("    *  ", jatom, x[jatom], v[jatom], f[jatom])
        else:
            print("    * [GHOST]", jatom, x[jatom], v[jatom], f[jatom])
  except Exception as e:
    print(e)
"""

fix		1 all nve
fix     3 all python/invoke 1 post_force post_force_callback

#dump		id all atom 1 dump.melt

#dump		2 all image 1 image.*.jpg type type &
#		axes yes 0.8 0.02 view 60 -30
#dump_modify	2 pad 3

#dump		3 all movie 1 movie.mpg type type &
#		axes yes 0.8 0.02 view 60 -30
#dump_modify	3 pad 3

thermo		1
run		1
+72 −0
Original line number Diff line number Diff line
@@ -51,6 +51,31 @@ class MPIAbortException(Exception):
  def __str__(self):
    return repr(self.message)

class NeighList:
    def __init__(self, lmp, idx):
        self.lmp = lmp
        self.idx = idx

    def __str__(self):
        return "Neighbor List ({} atoms)".format(self.size)

    def __repr__(self):
        return self.__str__()

    @property
    def size(self):
        return self.lmp.get_neighlist_size(self.idx)

    def get(self, element):
        iatom, numneigh, neighbors = self.lmp.get_neighlist_element_neighbors(self.idx, element)
        return iatom, numneigh, neighbors

    def __iter__(self):
        inum = self.size

        for ii in range(inum):
            yield self.get(ii)

class lammps(object):

  # detect if Python is using version of mpi4py that can pass a communicator
@@ -73,6 +98,7 @@ class lammps(object):

    modpath = dirname(abspath(getsourcefile(lambda:0)))
    self.lib = None
    self.lmp = None

    # if a pointer to a LAMMPS object is handed in,
    # all symbols should already be available
@@ -137,6 +163,21 @@ class lammps(object):
      [c_void_p,c_char_p,c_int,c_int,c_int,POINTER(c_int),c_void_p]
    self.lib.lammps_scatter_atoms_subset.restype = None

    self.lib.lammps_find_pair_neighlist.argtypes = [c_void_p, c_char_p, c_int, c_int]
    self.lib.lammps_find_pair_neighlist.restype  = c_int

    self.lib.lammps_find_fix_neighlist.argtypes = [c_void_p, c_char_p, c_int]
    self.lib.lammps_find_fix_neighlist.restype  = c_int

    self.lib.lammps_find_compute_neighlist.argtypes = [c_void_p, c_char_p, c_int]
    self.lib.lammps_find_compute_neighlist.restype  = c_int

    self.lib.lammps_neighlist_num_elements.argtypes = [c_void_p, c_int]
    self.lib.lammps_neighlist_num_elements.restype  = c_int

    self.lib.lammps_neighlist_element_neighbors.argtypes = [c_void_p, c_int, c_int, POINTER(c_int), POINTER(c_int), POINTER(POINTER(c_int))]
    self.lib.lammps_neighlist_element_neighbors.restype  = None

    # if no ptr provided, create an instance of LAMMPS
    #   don't know how to pass an MPI communicator from PyPar
    #   but we can pass an MPI communicator from mpi4py v2.0.0 and later
@@ -651,6 +692,37 @@ class lammps(object):

    self.lib.lammps_set_fix_external_callback(self.lmp, fix_name.encode(), cFunc, cCaller)

  def get_neighlist(self, idx):
    if idx < 0:
        return None
    return NeighList(self, idx)

  def find_pair_neighlist(self, style, nsub=0, request=0):
    style = style.encode()
    idx = self.lib.lammps_find_pair_neighlist(self.lmp, style, nsub, request)
    return self.get_neighlist(idx)

  def find_fix_neighlist(self, fixid, request=0):
    fixid = fixid.encode()
    idx = self.lib.lammps_find_fix_neighlist(self.lmp, fixid, request)
    return self.get_neighlist(idx)

  def find_compute_neighlist(self, computeid, request):
    computeid = computeid.encode()
    idx = self.lib.lammps_find_compute_neighlist(self.lmp, computeid, request)
    return self.get_neighlist(idx)

  def get_neighlist_size(self, idx):
    return self.lib.lammps_neighlist_num_elements(self.lmp, idx)

  def get_neighlist_element_neighbors(self, idx, element):
    c_iatom = c_int()
    c_numneigh = c_int()
    c_neighbors = POINTER(c_int)()
    self.lib.lammps_neighlist_element_neighbors(self.lmp, idx, element, byref(c_iatom), byref(c_numneigh), byref(c_neighbors))
    neighbors = self.numpy.iarray(c_int, c_neighbors, c_numneigh.value, 1)
    return c_iatom.value, c_numneigh.value, neighbors

# -------------------------------------------------------------------------
# -------------------------------------------------------------------------
# -------------------------------------------------------------------------
+133 −0
Original line number Diff line number Diff line
@@ -38,6 +38,9 @@
#include "force.h"
#include "info.h"
#include "fix_external.h"
#include "neighbor.h"
#include "neigh_list.h"
#include "neigh_request.h"

#if defined(LAMMPS_EXCEPTIONS)
#include "exceptions.h"
@@ -1727,3 +1730,133 @@ int lammps_get_last_error_message(void *ptr, char * buffer, int buffer_size) {
}

#endif

/* ----------------------------------------------------------------------
   Find neighbor list index for pair style
------------------------------------------------------------------------- */
int lammps_find_pair_neighlist(void* ptr, char * style, int nsub, int request) {
  LAMMPS *  lmp = (LAMMPS *) ptr;
  Pair* pair = lmp->force->pair_match(style, 1, nsub);

  if (pair != NULL) {
    // find neigh list
    for (int i = 0; i < lmp->neighbor->nlist; i++) {
      NeighList * list = lmp->neighbor->lists[i];
      if (list->requestor_type != NeighList::PAIR || pair != list->requestor) continue;

      if (list->index == request) {
          return i;
      }
    }
  }
  return -1;
}

/* ----------------------------------------------------------------------
   Find neighbor list index for compute with given fix ID
   The request ID identifies which request it is in case of there are
   multiple neighbor lists for this fix
------------------------------------------------------------------------- */
int lammps_find_fix_neighlist(void* ptr, char * id, int request) {
  LAMMPS *  lmp = (LAMMPS *) ptr;
  Fix* fix = NULL;
  const int nfix = lmp->modify->nfix;

  // find fix with name
  for (int ifix = 0; ifix < nfix; ifix++) {
    if (strcmp(lmp->modify->fix[ifix]->id, id) == 0) {
        fix = lmp->modify->fix[ifix];
        break;
    }
  }

  if (fix != NULL) {
    // find neigh list
    for (int i = 0; i < lmp->neighbor->nlist; i++) {
      NeighList * list = lmp->neighbor->lists[i];
      if (list->requestor_type != NeighList::FIX || fix != list->requestor) continue;

      if (list->index == request) {
          return i;
      }
    }
  }
  return -1;
}

/* ----------------------------------------------------------------------
   Find neighbor list index for compute with given compute ID
   The request ID identifies which request it is in case of there are
   multiple neighbor lists for this compute
------------------------------------------------------------------------- */

int lammps_find_compute_neighlist(void* ptr, char * id, int request) {
  LAMMPS *  lmp = (LAMMPS *) ptr;
  Compute* compute = NULL;
  const int ncompute = lmp->modify->ncompute;

  // find compute with name
  for (int icompute = 0; icompute < ncompute; icompute++) {
    if (strcmp(lmp->modify->compute[icompute]->id, id) == 0) {
        compute = lmp->modify->compute[icompute];
        break;
    }
  }

  if (compute == NULL) {
    // find neigh list
    for (int i = 0; i < lmp->neighbor->nlist; i++) {
      NeighList * list = lmp->neighbor->lists[i];
      if (list->requestor_type != NeighList::COMPUTE || compute != list->requestor) continue;

      if (list->index == request) {
          return i;
      }
    }
  }
  return -1;
}

/* ----------------------------------------------------------------------
   Return the number of entries in the neighbor list with given index
------------------------------------------------------------------------- */

int lammps_neighlist_num_elements(void * ptr, int idx) {
  LAMMPS *  lmp = (LAMMPS *) ptr;
  Neighbor * neighbor = lmp->neighbor;

  if(idx < 0 || idx >= neighbor->nlist) {
    return -1;
  }

  NeighList * list = neighbor->lists[idx];
  return list->inum;
}

/* ----------------------------------------------------------------------
   Return atom index, number of neighbors and neighbor array for neighbor
   list entry
------------------------------------------------------------------------- */

void lammps_neighlist_element_neighbors(void * ptr, int idx, int element, int * iatom, int * numneigh, int ** neighbors) {
  LAMMPS *  lmp = (LAMMPS *) ptr;
  Neighbor * neighbor = lmp->neighbor;
  *iatom = -1;
  *numneigh = 0;
  *neighbors = NULL;

  if(idx < 0 || idx >= neighbor->nlist) {
    return;
  }

  NeighList * list = neighbor->lists[idx];

  if(element < 0 || element >= list->inum) {
    return;
  }

  int i = list->ilist[element];
  *iatom     = i;
  *numneigh  = list->numneigh[i];
  *neighbors = list->firstneigh[i];
}
+6 −0
Original line number Diff line number Diff line
@@ -75,6 +75,12 @@ int lammps_config_has_jpeg_support();
int lammps_config_has_ffmpeg_support();
int lammps_config_has_exceptions();

int lammps_find_pair_neighlist(void* ptr, char * style, int nsub, int request);
int lammps_find_fix_neighlist(void* ptr, char * id, int request);
int lammps_find_compute_neighlist(void* ptr, char * id, int request);
int lammps_neighlist_num_elements(void* ptr, int idx);
void lammps_neighlist_element_neighbors(void * ptr, int idx, int element, int * iatom, int * numneigh, int ** neighbors);

// lammps_create_atoms() takes tagint and imageint as args
// ifdef insures they are compatible with rest of LAMMPS
// caller must match to how LAMMPS library is built
+3 −0
Original line number Diff line number Diff line
@@ -85,6 +85,9 @@ NeighList::NeighList(LAMMPS *lmp) : Pointers(lmp)
  // USER-DPD package

  np = NULL;

  requestor = NULL;
  requestor_type = NeighList::NONE;
}

/* ---------------------------------------------------------------------- */
Loading