Commit bb3ee752 authored by Richard Berger's avatar Richard Berger
Browse files

Added optional numpy access to atom data

The library interface was extended to provide direct access to atom data using
numpy arrays. No data copies are made and numpy operations directly manipulate
memory of the native code.

To keep this numpy dependency optional, all functions are wrapped into the
lammps.numpy sub-object which is only loaded when accessed.
parent 6b2d321d
Loading
Loading
Loading
Loading
+38 −0
Original line number Diff line number Diff line
@@ -162,6 +162,9 @@ class lammps(object):
        pythonapi.PyCObject_AsVoidPtr.argtypes = [py_object]
        self.lmp = c_void_p(pythonapi.PyCObject_AsVoidPtr(ptr))

      # optional numpy support (lazy loading)
      self._numpy = None

  def __del__(self):
    if self.lmp and self.opened:
      self.lib.lammps_close(self.lmp)
@@ -236,6 +239,41 @@ class lammps(object):
    ptr = self.lib.lammps_extract_atom(self.lmp,name)
    return ptr

  @property
  def numpy(self):
    if not self._numpy:
      import numpy as np
      class LammpsNumpyWrapper:
        def __init__(self, lmp):
          self.lmp = lmp

        def extract_atom_iarray(self, name, nelem, dim=1):
          if dim == 1:
              tmp = self.lmp.extract_atom(name, 0)
              ptr = cast(tmp, POINTER(c_int * nelem))
          else:
              tmp = self.lmp.extract_atom(name, 1)
              ptr = cast(tmp[0], POINTER(c_int * nelem * dim))

          a = np.frombuffer(ptr.contents, dtype=np.intc)
          a.shape = (nelem, dim)
          return a

        def extract_atom_darray(self, name, nelem, dim=1):
          if dim == 1:
              tmp = self.lmp.extract_atom(name, 2)
              ptr = cast(tmp, POINTER(c_double * nelem))
          else:
              tmp = self.lmp.extract_atom(name, 3)
              ptr = cast(tmp[0], POINTER(c_double * nelem * dim))

          a = np.frombuffer(ptr.contents)
          a.shape = (nelem, dim)
          return a

      self._numpy = LammpsNumpyWrapper(self)
    return self._numpy

  # extract compute info
  
  def extract_compute(self,id,style,type):