Commit 2be848e5 authored by Stan Moore's avatar Stan Moore
Browse files

Remove team from compute_yi

parent aa2b8857
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -85,7 +85,7 @@ public:
  void operator() (TagPairSNAPComputeBi,const typename Kokkos::TeamPolicy<DeviceType, TagPairSNAPComputeBi>::member_type& team) const;

  // Functor call operators selected by the TagPairSNAPComputeYi tag.
  // Two forms are listed: one taking the member handle of a
  // Kokkos::TeamPolicy tagged with TagPairSNAPComputeYi, and one taking a
  // plain index `ii`.
  // NOTE(review): this text is a rendered diff without +/- markers, so the
  // old and new declaration both appear; going by the commit title
  // ("Remove team from compute_yi") the team-member form is presumably the
  // line being replaced -- confirm against the real header.
  KOKKOS_INLINE_FUNCTION
  void operator() (TagPairSNAPComputeYi,const typename Kokkos::TeamPolicy<DeviceType, TagPairSNAPComputeYi>::member_type& team) const;
  void operator() (TagPairSNAPComputeYi,const int& ii) const;

  // Functor call operator selected by the TagPairSNAPComputeDuidrj tag.
  // Takes the member handle of a Kokkos::TeamPolicy tagged with
  // TagPairSNAPComputeDuidrj.
  KOKKOS_INLINE_FUNCTION
  void operator() (TagPairSNAPComputeDuidrj,const typename Kokkos::TeamPolicy<DeviceType, TagPairSNAPComputeDuidrj>::member_type& team) const;
+4 −4
Original line number Diff line number Diff line
@@ -250,7 +250,8 @@ void PairSNAPKokkos<DeviceType>::compute(int eflag_in, int vflag_in)
    Kokkos::parallel_for("ComputeBeta",policy_beta,*this);

    //ComputeYi
    typename Kokkos::TeamPolicy<DeviceType, TagPairSNAPComputeYi> policy_yi(chunk_size,yi_team_size,vector_length);
    //typename Kokkos::TeamPolicy<DeviceType, TagPairSNAPComputeYi> policy_yi(chunk_size,yi_team_size,vector_length);
    typename Kokkos::RangePolicy<DeviceType, TagPairSNAPComputeYi> policy_yi(0,chunk_size);
    Kokkos::parallel_for("ComputeYi",policy_yi,*this);

    //ComputeDuidrj
@@ -529,10 +530,9 @@ void PairSNAPKokkos<DeviceType>::operator() (TagPairSNAPComputeUi,const typename

// Functor body for the TagPairSNAPComputeYi tag: makes a local copy of snaKK
// and calls compute_yi on that copy for index ii, passing d_beta.
// NOTE(review): this text is a rendered diff without +/- markers, so the
// pre-change lines (team-member signature, `team.league_rank()`, and the
// three-argument compute_yi call) sit next to their post-change replacements
// (`const int &ii` signature and the two-argument call). Presumably only the
// latter remain in the committed file -- confirm against the real source.
template<class DeviceType>
KOKKOS_INLINE_FUNCTION
void PairSNAPKokkos<DeviceType>::operator() (TagPairSNAPComputeYi,const typename Kokkos::TeamPolicy<DeviceType, TagPairSNAPComputeYi>::member_type& team) const {
  // Team form: the index is taken from the league rank of the calling team.
  int ii = team.league_rank();
void PairSNAPKokkos<DeviceType>::operator() (TagPairSNAPComputeYi,const int &ii) const {
  // Work on a by-value copy of the snaKK member rather than on snaKK itself.
  SNAKokkos<DeviceType> my_sna = snaKK;
  my_sna.compute_yi(team,ii,d_beta);
  my_sna.compute_yi(ii,d_beta);
}

template<class DeviceType>
+1 −1
Original line number Diff line number Diff line
@@ -90,7 +90,7 @@ inline
  // compute_zi: takes a Kokkos::TeamPolicy member handle and one unnamed int.
  // NOTE(review): the int is presumably an atom index -- confirm against the
  // definition.
  KOKKOS_INLINE_FUNCTION
  void compute_zi(const typename Kokkos::TeamPolicy<DeviceType>::member_type& team, int);    // ForceSNAP
  // compute_yi: takes an unnamed int and a two-dimensional F_FLOAT
  // Kokkos::View `beta`.
  // NOTE(review): this text is a rendered diff without +/- markers, so two
  // opening lines appear: one that also takes a TeamPolicy member handle and
  // one that does not. Going by the commit title ("Remove team from
  // compute_yi") the team-taking line is presumably the one being replaced --
  // confirm against the real header.
  KOKKOS_INLINE_FUNCTION
  void compute_yi(const typename Kokkos::TeamPolicy<DeviceType>::member_type& team, int,
  void compute_yi(int,
   const Kokkos::View<F_FLOAT**, DeviceType> &beta); // ForceSNAP
  // compute_bi: takes a Kokkos::TeamPolicy member handle and one unnamed int.
  // NOTE(review): the int is presumably an atom index -- confirm against the
  // definition.
  KOKKOS_INLINE_FUNCTION
  void compute_bi(const typename Kokkos::TeamPolicy<DeviceType>::member_type& team, int);    // ForceSNAP
+14 −12
Original line number Diff line number Diff line
@@ -386,25 +386,27 @@ void SNAKokkos<DeviceType>::compute_zi(const typename Kokkos::TeamPolicy<DeviceT

// compute_yi: for row `iatom` of ylist, first sets every entry's .re and .im
// to zero, then loops over jjz in [0, idxz_max) and adds betaj*ztmp_r and
// betaj*ztmp_i into ylist(iatom,jju) with Kokkos::atomic_add.
//
// Parameters visible here:
//   iatom - first index used for ylist and (through `ii`) for beta.
//   beta  - two-dimensional F_FLOAT view, read as beta(ii,jjb).
//
// NOTE(review): this text is a rendered diff without +/- markers. Pre-change
// lines (the `team` parameter, Kokkos::parallel_for over TeamThreadRange,
// Kokkos::single) are interleaved with their post-change replacements (plain
// `for` loops and commented-out team constructs), so the span is not
// compilable as shown. The middle of the function is also elided at the
// second hunk header; jju, jjb, ztmp_r and ztmp_i are set in lines that are
// not visible here.
template<class DeviceType>
KOKKOS_INLINE_FUNCTION
void SNAKokkos<DeviceType>::compute_yi(const typename Kokkos::TeamPolicy<DeviceType>::member_type& team, int iatom,
void SNAKokkos<DeviceType>::compute_yi(int iatom,
 const Kokkos::View<F_FLOAT**, DeviceType> &beta)
{
  double betaj;
  // ii is an alias of iatom; it is the row index used when reading beta.
  const int ii = iatom;

  // Zero row `iatom` of ylist across its second extent.
  // NOTE(review): in the plain `for` form, `i` is int while Kokkos documents
  // View::extent() as returning an unsigned size type, so `i < ylist.extent(1)`
  // is presumably a signed/unsigned comparison -- confirm and consider a cast.
  {
    Kokkos::parallel_for(Kokkos::TeamThreadRange(team,ylist.extent(1)),
        [&] (const int& i) {
  //{
    //Kokkos::parallel_for(Kokkos::TeamThreadRange(team,ylist.extent(1)),
    //    [&] (const int& i) {
  for (int i = 0; i < ylist.extent(1); i++) {
      ylist(iatom,i).re = 0.0;
      ylist(iatom,i).im = 0.0;
    });
    }
  //  });
  //}

  //int flopsum = 0;

  // Main loop over jjz in [0, idxz_max); j1, j2 and j are read from idxz[jjz].
  Kokkos::parallel_for(Kokkos::TeamThreadRange(team,idxz_max),
      [&] (const int& jjz) {
  //for(int jjz = 0; jjz < idxz_max; jjz++) {
  //Kokkos::parallel_for(Kokkos::TeamThreadRange(team,idxz_max),
  //    [&] (const int& jjz) {
  for (int jjz = 0; jjz < idxz_max; jjz++) {
    const int j1 = idxz[jjz].j1;
    const int j2 = idxz[jjz].j2;
    const int j = idxz[jjz].j;
@@ -474,12 +476,12 @@ void SNAKokkos<DeviceType>::compute_yi(const typename Kokkos::TeamPolicy<DeviceT
      // Last visible branch of the betaj selection (earlier branches are in
      // the elided lines): scale beta(ii,jjb) by (j1+1)/(j+1.0).
      betaj = beta(ii,jjb)*(j1+1)/(j+1.0);
    }

  // Accumulate the betaj-weighted ztmp_r / ztmp_i into ylist(iatom,jju).
  // NOTE(review): in the non-team form this loop runs serially inside a single
  // call; if each iatom is handled by exactly one work item, the atomic adds
  // may no longer be required -- confirm against the launch policy before
  // changing them.
  Kokkos::single(Kokkos::PerThread(team), [&] () {
  //Kokkos::single(Kokkos::PerThread(team), [&] () {
    Kokkos::atomic_add(&(ylist(iatom,jju).re), betaj*ztmp_r);
    Kokkos::atomic_add(&(ylist(iatom,jju).im), betaj*ztmp_i);
  });
  //});

  }); // end loop over jjz
  }//); // end loop over jjz

  //printf("sum %i\n",flopsum);
}