Commit 2b42428d authored by Richard Berger's avatar Richard Berger
Browse files

Extend lib interface to set fix external callback

This allows creating a callback in Python and attaching it to
a fix external instance.
parent 8cfdf4fa
Loading
Loading
Loading
Loading
+36 −0
Original line number Diff line number Diff line
# this example requires the LAMMPS Python package (lammps.py) to be installed
# and LAMMPS to be loadable as shared library in LD_LIBRARY_PATH

import lammps

def callback(caller, ntimestep, nlocal, tag, x, fext):
    """
    This callback receives a caller object that was setup when registering the callback

    In addition to timestep and number of local atoms, the tag and x arrays are passed as
    NumPy arrays. The fext array is a force array allocated for fix external, which
    can be used to apply forces to all atoms. Simply update the value in the array,
    it will be directly written into the LAMMPS C arrays
    """
    print("Data passed by caller (optional)", caller)
    print("Timestep:", ntimestep)
    print("Number of Atoms:", nlocal)
    print("Atom Tags:", tag)
    print("Atom Positions:", x)
    print("Force Additions:", fext)
    fext.fill(1.0)
    print("Force additions after update:", fext)
    print("="*40)

L = lammps.lammps()
L.file("in.lammps")

# you can pass an arbitrary Python object to the callback every time it is called
# this can be useful if you need more state information such as the LAMMPS ptr to
# make additional library calls
custom_object = ["Some data", L]

L.set_fix_external_callback("2", callback, custom_object)
L.command("run 100")

+23 −0
Original line number Diff line number Diff line
# LAMMPS input for coupling LAMMPS with Python via fix external

units		metal
dimension	3
atom_style	atomic
atom_modify	sort 0 0.0

lattice		diamond 5.43
region		box block 0 1 0 1 0 1
create_box	1 box
create_atoms	1 box
mass		1 28.08

velocity	all create 300.0 87293 loop geom

fix		1 all nve
fix		2 all external pf/callback 1 1

#dump		2 all image 25 image.*.jpg type type &
#		axes yes 0.8 0.02 view 60 -30
#dump_modify	2 pad 3

thermo		1
+44 −2
Original line number Diff line number Diff line
@@ -219,6 +219,12 @@ class lammps(object):
    self.c_imageint = get_ctypes_int(self.extract_setting("imageint"))
    self._installed_packages = None

    # add way to insert Python callback for fix external
    self.callback = {}
    self.FIX_EXTERNAL_CALLBACK_FUNC = CFUNCTYPE(None, c_void_p, self.c_bigint, c_int, POINTER(self.c_tagint), POINTER(POINTER(c_double)), POINTER(POINTER(c_double)))
    self.lib.lammps_set_fix_external_callback.argtypes = [c_void_p, c_char_p, self.FIX_EXTERNAL_CALLBACK_FUNC, c_void_p]
    self.lib.lammps_set_fix_external_callback.restype = None

  # shut-down LAMMPS instance

  def __del__(self):
@@ -602,6 +608,42 @@ class lammps(object):
        self._installed_packages.append(sb.value.decode())
    return self._installed_packages

  def set_fix_external_callback(self, fix_name, callback, caller=None):
    import numpy as np
    def _ctype_to_numpy_int(ctype_int):
          if ctype_int == c_int32:
            return np.int32
          elif ctype_int == c_int64:
            return np.int64
          return np.intc

    def callback_wrapper(caller_ptr, ntimestep, nlocal, tag_ptr, x_ptr, fext_ptr):
      if cast(caller_ptr,POINTER(py_object)).contents:
        pyCallerObj = cast(caller_ptr,POINTER(py_object)).contents.value
      else:
        pyCallerObj = None

      tptr = cast(tag_ptr, POINTER(self.c_tagint * nlocal))
      tag = np.frombuffer(tptr.contents, dtype=_ctype_to_numpy_int(self.c_tagint))
      tag.shape = (nlocal)

      xptr = cast(x_ptr[0], POINTER(c_double * nlocal * 3))
      x = np.frombuffer(xptr.contents)
      x.shape = (nlocal, 3)

      fptr = cast(fext_ptr[0], POINTER(c_double * nlocal * 3))
      f = np.frombuffer(fptr.contents)
      f.shape = (nlocal, 3)

      callback(pyCallerObj, ntimestep, nlocal, tag, x, f)

    cFunc   = self.FIX_EXTERNAL_CALLBACK_FUNC(callback_wrapper)
    cCaller = cast(pointer(py_object(caller)), c_void_p)

    self.callback[fix_name] = { 'function': cFunc, 'caller': caller }

    self.lib.lammps_set_fix_external_callback(self.lmp, fix_name.encode(), cFunc, cCaller)

# -------------------------------------------------------------------------
# -------------------------------------------------------------------------
# -------------------------------------------------------------------------
+30 −0
Original line number Diff line number Diff line
@@ -37,6 +37,7 @@
#include "error.h"
#include "force.h"
#include "info.h"
#include "fix_external.h"

#if defined(LAMMPS_EXCEPTIONS)
#include "exceptions.h"
@@ -1605,6 +1606,35 @@ void lammps_create_atoms(void *ptr, int n, tagint *id, int *type,
  END_CAPTURE
}

void lammps_set_fix_external_callback(void *ptr, char *id, FixExternalFnPtr callback_ptr, void * caller)
{
  LAMMPS *lmp = (LAMMPS *) ptr;
  FixExternal::FnPtr callback = (FixExternal::FnPtr) callback_ptr;

  BEGIN_CAPTURE
  {
    int ifix = lmp->modify->find_fix(id);
    if (ifix < 0) {
      char str[50];
      sprintf(str, "Can not find fix with ID '%s'!", id);
      lmp->error->all(FLERR,str);
    }

    Fix *fix = lmp->modify->fix[ifix];

    if (strcmp("external",fix->style) != 0){
      char str[50];
      sprintf(str, "Fix '%s' is not of style external!", id);
      lmp->error->all(FLERR,str);
    }

    FixExternal * fext = (FixExternal*) fix;
    fext->set_callback(callback, caller);
  }
  END_CAPTURE
}


// ----------------------------------------------------------------------
// library API functions for accessing LAMMPS configuration
// ----------------------------------------------------------------------
+8 −0
Original line number Diff line number Diff line
@@ -58,6 +58,14 @@ void lammps_gather_atoms_subset(void *, char *, int, int, int, int *, void *);
void lammps_scatter_atoms(void *, char *, int, int, void *);
void lammps_scatter_atoms_subset(void *, char *, int, int, int, int *, void *);

#ifdef LAMMPS_BIGBIG
typedef void (*FixExternalFnPtr)(void *, int64_t, int, int64_t *, double **, double **);
void lammps_set_fix_external_callback(void *, char *, FixExternalFnPtr, void*);
#else
typedef void (*FixExternalFnPtr)(void *, int, int, int *, double **, double **);
void lammps_set_fix_external_callback(void *, char *, FixExternalFnPtr, void*);
#endif

int lammps_config_has_package(char * package_name);
int lammps_config_package_count();
int lammps_config_package_name(int index, char * buffer, int max_size);