Commit 9ca39c89 authored by Steve Plimpton's avatar Steve Plimpton Committed by GitHub
Browse files

Merge pull request #639 from rbberger/python_interface_improvements

Python interface improvements
parents 99791ce0 0b5a2e19
Loading
Loading
Loading
Loading
+57 −0
Original line number Diff line number Diff line
@@ -32,6 +32,13 @@ import select
import re
import sys

def get_ctypes_int(size):
  if size == 4:
    return c_int32
  elif size == 8:
    return c_int64
  return c_int

class MPIAbortException(Exception):
  def __init__(self, message):
    self.message = message
@@ -162,6 +169,14 @@ class lammps(object):
        pythonapi.PyCObject_AsVoidPtr.argtypes = [py_object]
        self.lmp = c_void_p(pythonapi.PyCObject_AsVoidPtr(ptr))

    # optional numpy support (lazy loading)
    self._numpy = None

    # set default types
    self.c_bigint = get_ctypes_int(self.extract_setting("bigint"))
    self.c_tagint = get_ctypes_int(self.extract_setting("tagint"))
    self.c_imageint = get_ctypes_int(self.extract_setting("imageint"))

  def __del__(self):
    if self.lmp and self.opened:
      self.lib.lammps_close(self.lmp)
@@ -236,6 +251,48 @@ class lammps(object):
    ptr = self.lib.lammps_extract_atom(self.lmp,name)
    return ptr

  # extract lammps type byte sizes

  def extract_setting(self, name):
    if name: name = name.encode()
    self.lib.lammps_extract_atom.restype = c_int
    return int(self.lib.lammps_extract_setting(self.lmp,name))

  @property
  def numpy(self):
    if not self._numpy:
      import numpy as np
      class LammpsNumpyWrapper:
        def __init__(self, lmp):
          self.lmp = lmp

        def extract_atom_iarray(self, name, nelem, dim=1):
          if dim == 1:
              tmp = self.lmp.extract_atom(name, 0)
              ptr = cast(tmp, POINTER(c_int * nelem))
          else:
              tmp = self.lmp.extract_atom(name, 1)
              ptr = cast(tmp[0], POINTER(c_int * nelem * dim))

          a = np.frombuffer(ptr.contents, dtype=np.intc)
          a.shape = (nelem, dim)
          return a

        def extract_atom_darray(self, name, nelem, dim=1):
          if dim == 1:
              tmp = self.lmp.extract_atom(name, 2)
              ptr = cast(tmp, POINTER(c_double * nelem))
          else:
              tmp = self.lmp.extract_atom(name, 3)
              ptr = cast(tmp[0], POINTER(c_double * nelem * dim))

          a = np.frombuffer(ptr.contents)
          a.shape = (nelem, dim)
          return a

      self._numpy = LammpsNumpyWrapper(self)
    return self._numpy

  # extract compute info
  
  def extract_compute(self,id,style,type):
+24 −0
Original line number Diff line number Diff line
@@ -37,6 +37,7 @@
#include "comm.h"
#include "memory.h"
#include "error.h"
#include "force.h"

using namespace LAMMPS_NS;

@@ -370,6 +371,7 @@ void *lammps_extract_global(void *ptr, char *name)
  if (strcmp(name,"nlocal") == 0) return (void *) &lmp->atom->nlocal;
  if (strcmp(name,"nghost") == 0) return (void *) &lmp->atom->nghost;
  if (strcmp(name,"nmax") == 0) return (void *) &lmp->atom->nmax;
  if (strcmp(name,"ntypes") == 0) return (void *) &lmp->atom->ntypes;
  if (strcmp(name,"ntimestep") == 0) return (void *) &lmp->update->ntimestep;

  if (strcmp(name,"units") == 0) return (void *) lmp->update->unit_style;
@@ -384,6 +386,28 @@ void *lammps_extract_global(void *ptr, char *name)
  if (strcmp(name,"atime") == 0) return (void *) &lmp->update->atime;
  if (strcmp(name,"atimestep") == 0) return (void *) &lmp->update->atimestep;

  // global constants defined by units

  if (strcmp(name,"boltz") == 0) return (void *) &lmp->force->boltz;
  if (strcmp(name,"hplanck") == 0) return (void *) &lmp->force->hplanck;
  if (strcmp(name,"mvv2e") == 0) return (void *) &lmp->force->mvv2e;
  if (strcmp(name,"ftm2v") == 0) return (void *) &lmp->force->ftm2v;
  if (strcmp(name,"mv2d") == 0) return (void *) &lmp->force->mv2d;
  if (strcmp(name,"nktv2p") == 0) return (void *) &lmp->force->nktv2p;
  if (strcmp(name,"qqr2e") == 0) return (void *) &lmp->force->qqr2e;
  if (strcmp(name,"qe2f") == 0) return (void *) &lmp->force->qe2f;
  if (strcmp(name,"vxmu2f") == 0) return (void *) &lmp->force->vxmu2f;
  if (strcmp(name,"xxt2kmu") == 0) return (void *) &lmp->force->xxt2kmu;
  if (strcmp(name,"dielectric") == 0) return (void *) &lmp->force->dielectric;
  if (strcmp(name,"qqrd2e") == 0) return (void *) &lmp->force->qqrd2e;
  if (strcmp(name,"e_mass") == 0) return (void *) &lmp->force->e_mass;
  if (strcmp(name,"hhmrr2e") == 0) return (void *) &lmp->force->hhmrr2e;
  if (strcmp(name,"mvh2r") == 0) return (void *) &lmp->force->mvh2r;

  if (strcmp(name,"angstrom") == 0) return (void *) &lmp->force->angstrom;
  if (strcmp(name,"femtosecond") == 0) return (void *) &lmp->force->femtosecond;
  if (strcmp(name,"qelectron") == 0) return (void *) &lmp->force->qelectron;

  return NULL;
}