Commit d730cda2 authored by sjplimp's avatar sjplimp Committed by GitHub
Browse files

Merge pull request #37 from rbberger/library_interface_abort

Allow detection of MPI_Abort condition in library call
parents 6f4b7268 90ff54c4
Loading
Loading
Loading
Loading
+25 −3
Original line number Original line Diff line number Diff line
@@ -28,6 +28,15 @@ import os
import select
import select
import re
import re



class MPIAbortException(Exception):
  def __init__(self, message):
    self.message = message

  def __str__(self):
    return repr(self.message)


class lammps(object):
class lammps(object):
  # detect if Python is using version of mpi4py that can pass a communicator
  # detect if Python is using version of mpi4py that can pass a communicator


@@ -43,6 +52,7 @@ class lammps(object):
  # create instance of LAMMPS
  # create instance of LAMMPS


  def __init__(self,name="",cmdargs=None,ptr=None,comm=None):
  def __init__(self,name="",cmdargs=None,ptr=None,comm=None):
    self.comm = comm


    # determine module location
    # determine module location


@@ -150,10 +160,14 @@ class lammps(object):
    if cmd: cmd = cmd.encode()
    if cmd: cmd = cmd.encode()
    self.lib.lammps_command(self.lmp,cmd)
    self.lib.lammps_command(self.lmp,cmd)


    if self.lib.lammps_has_error(self.lmp):
    if self.uses_exceptions and self.lib.lammps_has_error(self.lmp):
      sb = create_string_buffer(100)
      sb = create_string_buffer(100)
      self.lib.lammps_get_last_error_message(self.lmp, sb, 100)
      error_type = self.lib.lammps_get_last_error_message(self.lmp, sb, 100)
      raise Exception(sb.value.decode().strip())
      error_msg = sb.value.decode().strip()

      if error_type == 2:
        raise MPIAbortException(error_msg)
      raise Exception(error_msg)


  def extract_global(self,name,type):
  def extract_global(self,name,type):
    if name: name = name.encode()
    if name: name = name.encode()
@@ -286,6 +300,14 @@ class lammps(object):
    if name: name = name.encode()
    if name: name = name.encode()
    self.lib.lammps_scatter_atoms(self.lmp,name,type,count,data)
    self.lib.lammps_scatter_atoms(self.lmp,name,type,count,data)


  @property
  def uses_exceptions(self):
    try:
      if self.lib.lammps_has_error:
        return True
    except(AttributeError):
      return False

# -------------------------------------------------------------------------
# -------------------------------------------------------------------------
# -------------------------------------------------------------------------
# -------------------------------------------------------------------------
# -------------------------------------------------------------------------
# -------------------------------------------------------------------------
+21 −4
Original line number Original line Diff line number Diff line
@@ -22,7 +22,12 @@ using namespace LAMMPS_NS;


/* ---------------------------------------------------------------------- */
/* ---------------------------------------------------------------------- */


Error::Error(LAMMPS *lmp) : Pointers(lmp), last_error_message(NULL) {}
Error::Error(LAMMPS *lmp) : Pointers(lmp) {
#ifdef LAMMPS_EXCEPTIONS
  last_error_message = NULL;
  last_error_type = ERROR_NONE;
#endif
}


/* ----------------------------------------------------------------------
/* ----------------------------------------------------------------------
   called by all procs in universe
   called by all procs in universe
@@ -198,6 +203,7 @@ void Error::done(int status)
  exit(status);
  exit(status);
}
}


#ifdef LAMMPS_EXCEPTIONS
/* ----------------------------------------------------------------------
/* ----------------------------------------------------------------------
   return the last error message reported by LAMMPS (only used if
   return the last error message reported by LAMMPS (only used if
   compiled with -DLAMMPS_EXCEPTIONS)
   compiled with -DLAMMPS_EXCEPTIONS)
@@ -208,13 +214,22 @@ char * Error::get_last_error() const
  return last_error_message;
  return last_error_message;
}
}


/* ----------------------------------------------------------------------
   return the type of the last error reported by LAMMPS (only used if
   compiled with -DLAMMPS_EXCEPTIONS)
------------------------------------------------------------------------- */

ErrorType Error::get_last_error_type() const
{
  return last_error_type;
}


/* ----------------------------------------------------------------------
/* ----------------------------------------------------------------------
   set the last error message (only used if compiled with
   set the last error message and error type
   -DLAMMPS_EXCEPTIONS)
   (only used if compiled with -DLAMMPS_EXCEPTIONS)
------------------------------------------------------------------------- */
------------------------------------------------------------------------- */


void Error::set_last_error(const char * msg)
void Error::set_last_error(const char * msg, ErrorType type)
{
{
  delete [] last_error_message;
  delete [] last_error_message;


@@ -224,4 +239,6 @@ void Error::set_last_error(const char * msg)
  } else {
  } else {
    last_error_message = NULL;
    last_error_message = NULL;
  }
  }
  last_error_type = type;
}
}
#endif
+13 −33
Original line number Original line Diff line number Diff line
@@ -15,41 +15,14 @@
#define LMP_ERROR_H
#define LMP_ERROR_H


#include "pointers.h"
#include "pointers.h"
#include <string>
#include <exception>


namespace LAMMPS_NS {
#ifdef LAMMPS_EXCEPTIONS

#include "exceptions.h"
class LAMMPSException : public std::exception
#endif
{
public:
  std::string message;

  LAMMPSException(std::string msg) : message(msg) {
  }

  ~LAMMPSException() throw() {
  }

  virtual const char * what() const throw() {
    return message.c_str();
  }
};

class LAMMPSAbortException : public LAMMPSException {
public:
  MPI_Comm universe;


  LAMMPSAbortException(std::string msg, MPI_Comm universe) :
namespace LAMMPS_NS {
    LAMMPSException(msg),
    universe(universe)
  {
  }
};


class Error : protected Pointers {
class Error : protected Pointers {
  char * last_error_message;

 public:
 public:
  Error(class LAMMPS *);
  Error(class LAMMPS *);


@@ -63,8 +36,15 @@ class Error : protected Pointers {
  void message(const char *, int, const char *, int = 1);
  void message(const char *, int, const char *, int = 1);
  void done(int = 0); // 1 would be fully backwards compatible
  void done(int = 0); // 1 would be fully backwards compatible


#ifdef LAMMPS_EXCEPTIONS
  char *    get_last_error() const;
  char *    get_last_error() const;
  void   set_last_error(const char * msg);
  ErrorType get_last_error_type() const;
  void   set_last_error(const char * msg, ErrorType type = ERROR_NORMAL);

 private:
  char * last_error_message;
  ErrorType last_error_type;
#endif
};
};


}
}

src/exceptions.h

0 → 100644
+58 −0
Original line number Original line Diff line number Diff line
/* -*- c++ -*- ----------------------------------------------------------
   LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
   http://lammps.sandia.gov, Sandia National Laboratories
   Steve Plimpton, sjplimp@sandia.gov

   Copyright (2003) Sandia Corporation.  Under the terms of Contract
   DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains
   certain rights in this software.  This software is distributed under
   the GNU General Public License.

   See the README file in the top-level LAMMPS directory.
------------------------------------------------------------------------- */

#ifndef LMP_EXCEPTIONS_H
#define LMP_EXCEPTIONS_H

#include <mpi.h>
#include <string>
#include <exception>

namespace LAMMPS_NS {

class LAMMPSException : public std::exception
{
public:
  std::string message;

  LAMMPSException(std::string msg) : message(msg) {
  }

  ~LAMMPSException() throw() {
  }

  virtual const char * what() const throw() {
    return message.c_str();
  }
};

class LAMMPSAbortException : public LAMMPSException {
public:
  MPI_Comm universe;

  LAMMPSAbortException(std::string msg, MPI_Comm universe) :
    LAMMPSException(msg),
    universe(universe)
  {
  }
};

enum ErrorType {
   ERROR_NONE   = 0,
   ERROR_NORMAL = 1,
   ERROR_ABORT  = 2
};

}

#endif
+325 −218
Original line number Original line Diff line number Diff line
@@ -38,6 +38,45 @@


using namespace LAMMPS_NS;
using namespace LAMMPS_NS;


/* ----------------------------------------------------------------------
   Utility macros for optional code path which captures all exceptions
   and stores the last error message. These assume there is a variable lmp
   which is a pointer to the current LAMMPS instance.

   Usage:

   BEGIN_CAPTURE
   {
     // code paths which might throw exception
     ...
   }
   END_CAPTURE
------------------------------------------------------------------------- */

#ifdef LAMMPS_EXCEPTIONS
#define BEGIN_CAPTURE \
  Error * error = lmp->error; \
  try

#define END_CAPTURE \
  catch(LAMMPSAbortException & ae) { \
    int nprocs = 0; \
    MPI_Comm_size(ae.universe, &nprocs ); \
    \
    if (nprocs > 1) { \
      error->set_last_error(ae.message.c_str(), ERROR_ABORT); \
    } else { \
      error->set_last_error(ae.message.c_str(), ERROR_NORMAL); \
    } \
  } catch(LAMMPSException & e) { \
    error->set_last_error(e.message.c_str(), ERROR_NORMAL); \
  }
#else
#define BEGIN_CAPTURE
#define END_CAPTURE
#endif


/* ----------------------------------------------------------------------
/* ----------------------------------------------------------------------
   create an instance of LAMMPS and return pointer to it
   create an instance of LAMMPS and return pointer to it
   pass in command-line args and MPI communicator to run on
   pass in command-line args and MPI communicator to run on
@@ -45,8 +84,20 @@ using namespace LAMMPS_NS;


void lammps_open(int argc, char **argv, MPI_Comm communicator, void **ptr)
void lammps_open(int argc, char **argv, MPI_Comm communicator, void **ptr)
{
{
#ifdef LAMMPS_EXCEPTIONS
  try
  {
    LAMMPS *lmp = new LAMMPS(argc,argv,communicator);
    *ptr = (void *) lmp;
  }
  catch(LAMMPSException & e) {
    fprintf(stderr, "LAMMPS Exception: %s", e.message.c_str());
    *ptr = (void*) NULL;
  }
#else
  LAMMPS *lmp = new LAMMPS(argc,argv,communicator);
  LAMMPS *lmp = new LAMMPS(argc,argv,communicator);
  *ptr = (void *) lmp;
  *ptr = (void *) lmp;
#endif
}
}


/* ----------------------------------------------------------------------
/* ----------------------------------------------------------------------
@@ -68,9 +119,21 @@ void lammps_open_no_mpi(int argc, char **argv, void **ptr)


  MPI_Comm communicator = MPI_COMM_WORLD;
  MPI_Comm communicator = MPI_COMM_WORLD;


#ifdef LAMMPS_EXCEPTIONS
  try
  {
    LAMMPS *lmp = new LAMMPS(argc,argv,communicator);
    LAMMPS *lmp = new LAMMPS(argc,argv,communicator);
    *ptr = (void *) lmp;
    *ptr = (void *) lmp;
  }
  }
  catch(LAMMPSException & e) {
    fprintf(stderr, "LAMMPS Exception: %s", e.message.c_str());
    *ptr = (void*) NULL;
  }
#else
  LAMMPS *lmp = new LAMMPS(argc,argv,communicator);
  *ptr = (void *) lmp;
#endif
}


/* ----------------------------------------------------------------------
/* ----------------------------------------------------------------------
   destruct an instance of LAMMPS
   destruct an instance of LAMMPS
@@ -99,8 +162,13 @@ int lammps_version(void *ptr)
void lammps_file(void *ptr, char *str)
void lammps_file(void *ptr, char *str)
{
{
  LAMMPS *lmp = (LAMMPS *) ptr;
  LAMMPS *lmp = (LAMMPS *) ptr;

  BEGIN_CAPTURE
  {
    lmp->input->file(str);
    lmp->input->file(str);
  }
  }
  END_CAPTURE
}


/* ----------------------------------------------------------------------
/* ----------------------------------------------------------------------
   process a single input command in str
   process a single input command in str
@@ -109,14 +177,15 @@ void lammps_file(void *ptr, char *str)
char *lammps_command(void *ptr, char *str)
char *lammps_command(void *ptr, char *str)
{
{
  LAMMPS *lmp = (LAMMPS *) ptr;
  LAMMPS *lmp = (LAMMPS *) ptr;
  Error * error = lmp->error;
  char * result = NULL;


  try {
  BEGIN_CAPTURE
    return lmp->input->one(str);
  {
  } catch(LAMMPSException & e) {
    result = lmp->input->one(str);
    error->set_last_error(e.message.c_str());
    return NULL;
  }
  }
  END_CAPTURE

  return result;
}
}


/* ----------------------------------------------------------------------
/* ----------------------------------------------------------------------
@@ -215,6 +284,8 @@ void *lammps_extract_compute(void *ptr, char *id, int style, int type)
{
{
  LAMMPS *lmp = (LAMMPS *) ptr;
  LAMMPS *lmp = (LAMMPS *) ptr;


  BEGIN_CAPTURE
  {
    int icompute = lmp->modify->find_compute(id);
    int icompute = lmp->modify->find_compute(id);
    if (icompute < 0) return NULL;
    if (icompute < 0) return NULL;
    Compute *compute = lmp->modify->compute[icompute];
    Compute *compute = lmp->modify->compute[icompute];
@@ -267,6 +338,8 @@ void *lammps_extract_compute(void *ptr, char *id, int style, int type)
        return (void *) compute->array_local;
        return (void *) compute->array_local;
      }
      }
    }
    }
  }
  END_CAPTURE


  return NULL;
  return NULL;
}
}
@@ -300,6 +373,8 @@ void *lammps_extract_fix(void *ptr, char *id, int style, int type,
{
{
  LAMMPS *lmp = (LAMMPS *) ptr;
  LAMMPS *lmp = (LAMMPS *) ptr;


  BEGIN_CAPTURE
  {
    int ifix = lmp->modify->find_fix(id);
    int ifix = lmp->modify->find_fix(id);
    if (ifix < 0) return NULL;
    if (ifix < 0) return NULL;
    Fix *fix = lmp->modify->fix[ifix];
    Fix *fix = lmp->modify->fix[ifix];
@@ -334,6 +409,8 @@ void *lammps_extract_fix(void *ptr, char *id, int style, int type,
      if (type == 1) return (void *) fix->vector_local;
      if (type == 1) return (void *) fix->vector_local;
      if (type == 2) return (void *) fix->array_local;
      if (type == 2) return (void *) fix->array_local;
    }
    }
  }
  END_CAPTURE


  return NULL;
  return NULL;
}
}
@@ -369,6 +446,8 @@ void *lammps_extract_variable(void *ptr, char *name, char *group)
{
{
  LAMMPS *lmp = (LAMMPS *) ptr;
  LAMMPS *lmp = (LAMMPS *) ptr;


  BEGIN_CAPTURE
  {
    int ivar = lmp->input->variable->find(name);
    int ivar = lmp->input->variable->find(name);
    if (ivar < 0) return NULL;
    if (ivar < 0) return NULL;


@@ -386,6 +465,8 @@ void *lammps_extract_variable(void *ptr, char *name, char *group)
      lmp->input->variable->compute_atom(ivar,igroup,vector,1,0);
      lmp->input->variable->compute_atom(ivar,igroup,vector,1,0);
      return (void *) vector;
      return (void *) vector;
    }
    }
  }
  END_CAPTURE


  return NULL;
  return NULL;
}
}
@@ -399,7 +480,14 @@ void *lammps_extract_variable(void *ptr, char *name, char *group)
int lammps_set_variable(void *ptr, char *name, char *str)
int lammps_set_variable(void *ptr, char *name, char *str)
{
{
  LAMMPS *lmp = (LAMMPS *) ptr;
  LAMMPS *lmp = (LAMMPS *) ptr;
  int err = lmp->input->variable->set_string(name,str);
  int err = -1;

  BEGIN_CAPTURE
  {
    err = lmp->input->variable->set_string(name,str);
  }
  END_CAPTURE

  return err;
  return err;
}
}


@@ -414,9 +502,14 @@ int lammps_set_variable(void *ptr, char *name, char *str)
double lammps_get_thermo(void *ptr, char *name)
double lammps_get_thermo(void *ptr, char *name)
{
{
  LAMMPS *lmp = (LAMMPS *) ptr;
  LAMMPS *lmp = (LAMMPS *) ptr;
  double dval;
  double dval = 0.0;


  BEGIN_CAPTURE
  {
    lmp->output->thermo->evaluate_keyword(name,&dval);
    lmp->output->thermo->evaluate_keyword(name,&dval);
  }
  END_CAPTURE

  return dval;
  return dval;
}
}


@@ -449,6 +542,8 @@ void lammps_gather_atoms(void *ptr, char *name,
{
{
  LAMMPS *lmp = (LAMMPS *) ptr;
  LAMMPS *lmp = (LAMMPS *) ptr;


  BEGIN_CAPTURE
  {
    // error if tags are not defined or not consecutive
    // error if tags are not defined or not consecutive


    int flag = 0;
    int flag = 0;
@@ -523,6 +618,8 @@ void lammps_gather_atoms(void *ptr, char *name,
      lmp->memory->destroy(copy);
      lmp->memory->destroy(copy);
    }
    }
  }
  }
  END_CAPTURE
}


/* ----------------------------------------------------------------------
/* ----------------------------------------------------------------------
   scatter the named atom-based entity across all processors
   scatter the named atom-based entity across all processors
@@ -538,6 +635,8 @@ void lammps_scatter_atoms(void *ptr, char *name,
{
{
  LAMMPS *lmp = (LAMMPS *) ptr;
  LAMMPS *lmp = (LAMMPS *) ptr;


  BEGIN_CAPTURE
  {
    // error if tags are not defined or not consecutive or no atom map
    // error if tags are not defined or not consecutive or no atom map


    int flag = 0;
    int flag = 0;
@@ -600,7 +699,10 @@ void lammps_scatter_atoms(void *ptr, char *name,
      }
      }
    }
    }
  }
  }
  END_CAPTURE
}


#ifdef LAMMPS_EXCEPTIONS
/* ----------------------------------------------------------------------
/* ----------------------------------------------------------------------
   Check if a new error message
   Check if a new error message
------------------------------------------------------------------------- */
------------------------------------------------------------------------- */
@@ -613,6 +715,9 @@ int lammps_has_error(void *ptr) {


/* ----------------------------------------------------------------------
/* ----------------------------------------------------------------------
   Copy the last error message of LAMMPS into a character buffer
   Copy the last error message of LAMMPS into a character buffer
   The return value encodes which type of error it is.
   1 = normal error (recoverable)
   2 = abort error (non-recoverable)
------------------------------------------------------------------------- */
------------------------------------------------------------------------- */


int lammps_get_last_error_message(void *ptr, char * buffer, int buffer_size) {
int lammps_get_last_error_message(void *ptr, char * buffer, int buffer_size) {
@@ -620,9 +725,11 @@ int lammps_get_last_error_message(void *ptr, char * buffer, int buffer_size) {
  Error * error = lmp->error;
  Error * error = lmp->error;


  if(error->get_last_error()) {
  if(error->get_last_error()) {
    int error_type = error->get_last_error_type();
    strncpy(buffer, error->get_last_error(), buffer_size-1);
    strncpy(buffer, error->get_last_error(), buffer_size-1);
    error->set_last_error(NULL);
    error->set_last_error(NULL, ERROR_NONE);
    return 1;
    return error_type;
  }
  }
  return 0;
  return 0;
}
}
#endif
Loading