omega_h
Reliable mesh adaptation
Omega_h_comm.hpp
1 #ifndef OMEGA_H_COMM_HPP
2 #define OMEGA_H_COMM_HPP
3 
#include <cstddef>
#include <limits>
#include <memory>
#include <string>

#include <Omega_h_mpi.h>
#include <Omega_h_array.hpp>
#include <Omega_h_defines.hpp>
#include <Omega_h_fail.hpp>
#include <Omega_h_future.hpp>
#include <Omega_h_int128.hpp>
12 #ifndef OMEGA_H_FAIL_HPP
13 #error "included fail but guard not defined"
14 #endif
15 #ifndef OMEGA_H_CHECK
16 #error "included fail but check not defined"
17 #endif
18 
19 
20 namespace Omega_h {
21 
// The communication layer is written assuming `int` is 32 bits wide.
static_assert(4 == sizeof(int), "Omega_h::Comm assumes 32-bit int");

class Library;
class Comm;

/// Shared-ownership handle through which communicators are passed around.
using CommPtr = std::shared_ptr<Comm>;
28 
29 class Comm {
30 #ifdef OMEGA_H_USE_MPI
31  MPI_Comm impl_;
32 #endif
33  Library* library_;
34  Read<I32> srcs_;
35  Read<I32> dsts_;
36  HostRead<I32> host_srcs_;
37  HostRead<I32> host_dsts_;
38  LO self_src_;
39  LO self_dst_;
40 
41  public:
42  Comm();
43 #ifdef OMEGA_H_USE_MPI
44  Comm(Library* library, MPI_Comm impl);
45  Comm(Library* library, MPI_Comm impl, Read<I32> srcs, Read<I32> dests);
46  MPI_Comm get_impl() const { return impl_; }
47 #else
48  Comm(Library* library, bool is_graph, bool sends_to_self);
49 #endif
50  Comm(Comm const&) = delete;
51  Comm(Comm&&) = delete;
52  Comm& operator=(Comm const&) = delete;
53  Comm& operator=(Comm&&) = delete;
54  ~Comm();
55  Library* library() const;
56  I32 rank() const;
57  I32 size() const;
58  CommPtr dup() const;
59  CommPtr split(I32 color, I32 key) const;
60  CommPtr graph(Read<I32> dsts) const;
61  CommPtr graph_adjacent(Read<I32> srcs, Read<I32> dsts) const;
62  CommPtr graph_inverse() const;
63  Read<I32> sources() const;
64  Read<I32> destinations() const;
65  template <typename T>
66  T allreduce(T x, Omega_h_Op op) const;
67  bool reduce_or(bool x) const;
68  bool reduce_and(bool x) const;
69  Int128 add_int128(Int128 x) const;
70  template <typename T>
71  T exscan(T x, Omega_h_Op op) const;
72  template <typename T>
73  void bcast(T& x, int root_rank=0) const;
74  void bcast_string(std::string& s, int root_rank=0) const;
75  template <typename T>
76  Read<T> allgather(T x) const;
77  template <typename T>
78  Read<T> alltoall(Read<T> x) const;
79  template <typename T>
80  Read<T> alltoallv(
81  Read<T> sendbuf, Read<LO> sdispls, Read<LO> rdispls, Int width) const;
82  template <typename T>
83  Future<T> ialltoallv(
84  Read<T> sendbuf, Read<LO> sdispls, Read<LO> rdispls, Int width) const;
85  void barrier() const;
86  template<typename T>
87  void send(int rank, const T& x);
88  template<typename T>
89  void recv(int rank, T& x);
90 };
91 
92 #ifdef OMEGA_H_USE_MPI
93 
// Open MPI spells MPI_UNWEIGHTED as a (void*) constant, which strict
// compiler settings reject where an int* is expected, so cast it
// explicitly there. Other MPI implementations use the constant unchanged.
#if defined(OMPI_MPI_H)
#define OMEGA_H_MPI_UNWEIGHTED reinterpret_cast<int*>(MPI_UNWEIGHTED)
#else
#define OMEGA_H_MPI_UNWEIGHTED MPI_UNWEIGHTED
#endif
103 
104 template <class T>
105 struct MpiTraits;
106 
107 template <>
108 struct MpiTraits<char> {
109  static MPI_Datatype datatype() { return MPI_CHAR; }
110 };
111 
112 template <>
113 struct MpiTraits<signed char> {
114  static MPI_Datatype datatype() { return MPI_SIGNED_CHAR; }
115 };
116 
117 template <>
118 struct MpiTraits<unsigned char> {
119  static MPI_Datatype datatype() { return MPI_UNSIGNED_CHAR; }
120 };
121 
122 template <>
123 struct MpiTraits<float> {
124  static MPI_Datatype datatype() { return MPI_FLOAT; }
125 };
126 
127 template <>
128 struct MpiTraits<double> {
129  static MPI_Datatype datatype() { return MPI_DOUBLE; }
130 };
131 
132 template <>
133 struct MpiTraits<unsigned short> {
134  static MPI_Datatype datatype() { return MPI_UNSIGNED_SHORT; }
135 };
136 
137 template <>
138 struct MpiTraits<unsigned> {
139  static MPI_Datatype datatype() { return MPI_UNSIGNED; }
140 };
141 
142 template <>
143 struct MpiTraits<unsigned long> {
144  static MPI_Datatype datatype() { return MPI_UNSIGNED_LONG; }
145 };
146 
147 template <>
148 struct MpiTraits<unsigned long long> {
149  static MPI_Datatype datatype() { return MPI_UNSIGNED_LONG_LONG; }
150 };
151 
152 template <>
153 struct MpiTraits<short int> {
154  static MPI_Datatype datatype() { return MPI_SHORT; }
155 };
156 
157 template <>
158 struct MpiTraits<int> {
159  static MPI_Datatype datatype() { return MPI_INT; }
160 };
161 
162 template <>
163 struct MpiTraits<long int> {
164  static MPI_Datatype datatype() { return MPI_LONG; }
165 };
166 
167 template <>
168 struct MpiTraits<long long int> {
169  static MPI_Datatype datatype() { return MPI_LONG_LONG_INT; }
170 };
171 
172 inline MPI_Op mpi_op(Omega_h_Op op) {
173  switch (op) {
174  case OMEGA_H_MIN:
175  return MPI_MIN;
176  case OMEGA_H_MAX:
177  return MPI_MAX;
178  case OMEGA_H_SUM:
179  return MPI_SUM;
180  }
181  OMEGA_H_NORETURN(MPI_MIN);
182 }
183 
184 
185 #endif
186 
187 
188 template<typename T>
189 void Comm::send(int rank, const T& x) {
190 #ifdef OMEGA_H_USE_MPI
191  size_t sz = x.size();
192  int omega_h_mpi_error = MPI_Send(&sz, 1, MpiTraits<size_t>::datatype(), rank, 0, impl_);
193  OMEGA_H_CHECK(MPI_SUCCESS == omega_h_mpi_error);
194  omega_h_mpi_error = MPI_Send(x.data(), x.size(), MpiTraits<typename T::value_type>::datatype(), rank, 0, impl_);;
195  OMEGA_H_CHECK(MPI_SUCCESS == omega_h_mpi_error);
196 #else
197  (void)rank;
198  (void)x;
199 #endif
200 }
201 
202 template<typename T>
203 void Comm::recv(int rank, T& x) {
204 #ifdef OMEGA_H_USE_MPI
205  size_t sz = 0;
206  int omega_h_mpi_error = MPI_Recv(&sz, 1, MpiTraits<size_t>::datatype(), rank, 0, impl_, MPI_STATUS_IGNORE);
207  OMEGA_H_CHECK(MPI_SUCCESS == omega_h_mpi_error);
208  x.resize(sz);
209  omega_h_mpi_error = MPI_Recv(x.data(), x.size(), MpiTraits<typename T::value_type>::datatype(), rank, 0, impl_, MPI_STATUS_IGNORE);
210  OMEGA_H_CHECK(MPI_SUCCESS == omega_h_mpi_error);
211 #else
212  (void)rank;
213  (void)x;
214 #endif
215 }
216 
217 #define OMEGA_H_EXPL_INST_DECL(T) \
218  extern template T Comm::allreduce(T x, Omega_h_Op op) const; \
219  extern template T Comm::exscan(T x, Omega_h_Op op) const; \
220  extern template void Comm::bcast(T& x, int root_rank) const; \
221  extern template Read<T> Comm::allgather(T x) const; \
222  extern template Read<T> Comm::alltoall(Read<T> x) const; \
223  extern template Read<T> Comm::alltoallv( \
224  Read<T> sendbuf, Read<LO> sdispls, Read<LO> rdispls, Int width) const; \
225  extern template Future<T> Comm::ialltoallv( \
226  Read<T> sendbuf, Read<LO> sdispls, Read<LO> rdispls, Int width) const;
227 OMEGA_H_EXPL_INST_DECL(I8)
228 OMEGA_H_EXPL_INST_DECL(I32)
229 OMEGA_H_EXPL_INST_DECL(I64)
230 OMEGA_H_EXPL_INST_DECL(Real)
231 #undef OMEGA_H_EXPL_INST_DECL
232 
233 } // end namespace Omega_h
234 
235 #endif
Definition: Omega_h_comm.hpp:29
Abstraction for asynchronous communication.
Definition: Omega_h_future.hpp:19
Definition: Omega_h_library.hpp:10
Definition: amr_mpi_test.cpp:6
Definition: Omega_h_int128.hpp:12