Unverified Commit 0b8136a3 authored by Richard Berger's avatar Richard Berger
Browse files

Add extract_compute, extract_fix, and extract_variable to lammps.numpy

parent a216d3f5
Loading
Loading
Loading
Loading
+48 −0
Original line number Diff line number Diff line
@@ -435,6 +435,54 @@ class lammps(object):

          return self.darray(raw_ptr, nelem, dim)

        def extract_compute(self, cid, style, datatype):
          value = self.lmp.extract_compute(cid, style, datatype)

          if style in (LMP_STYLE_GLOBAL, LMP_STYLE_LOCAL):
            if datatype == LMP_TYPE_VECTOR:
              nrows = self.lmp.extract_compute(cid, style, LMP_SIZE_VECTOR)
              print("NROWS", nrows)
              return self.darray(value, nrows)
            elif datatype == LMP_TYPE_ARRAY:
              nrows = self.lmp.extract_compute(cid, style, LMP_SIZE_ROWS)
              ncols = self.lmp.extract_compute(cid, style, LMP_SIZE_COLS)
              return self.darray(value, nrows, ncols)
          elif style == LMP_STYLE_ATOM:
            if datatype == LMP_TYPE_VECTOR:
              nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT)
              return self.darray(value, nlocal)
            elif datatype == LMP_TYPE_ARRAY:
              nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT)
              ncols = self.lmp.extract_compute(cid, style, LMP_SIZE_COLS)
              return self.darray(value, nlocal, ncols)
          return value

        def extract_fix(self, fid, style, datatype, nrow=0, ncol=0):
          value = self.lmp.extract_fix(fid, style, datatype, nrow, ncol)
          if style == LMP_STYLE_ATOM:
            if datatype == LMP_TYPE_VECTOR:
              nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT)
              return self.darray(value, nlocal)
            elif datatype == LMP_TYPE_ARRAY:
              nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT)
              ncols = self.lmp.extract_fix(fid, style, LMP_SIZE_COLS, 0, 0)
              return self.darray(value, nlocal, ncols)
          elif style == LMP_STYLE_LOCAL:
            if datatype == LMP_TYPE_VECTOR:
              nrows = self.lmp.extract_fix(fid, style, LMP_SIZE_ROWS, 0, 0)
              return self.darray(value, nrows)
            elif datatype == LMP_TYPE_ARRAY:
              nrows = self.lmp.extract_fix(fid, style, LMP_SIZE_ROWS, 0, 0)
              ncols = self.lmp.extract_fix(fid, style, LMP_SIZE_COLS, 0, 0)
              return self.darray(value, nrows, ncols)
          return value

        def extract_variable(self, name, group=None, datatype=LMP_VAR_EQUAL):
          value = self.lmp.extract_variable(name, group, datatype)
          if datatype == LMP_VAR_ATOM:
            return np.ctypeslib.as_array(value)
          return value

        def iarray(self, c_int_type, raw_ptr, nelem, dim=1):
          np_int_type = self._ctype_to_numpy_int(c_int_type)