dune-istl 2.3.1
multitypeblockmatrix.hh
Go to the documentation of this file.
1 // -*- tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 2 -*-
2 // vi: set et ts=4 sw=2 sts=2:
3 #ifndef DUNE_MultiTypeMATRIX_HH
4 #define DUNE_MultiTypeMATRIX_HH
5 
6 #include <cmath>
7 #include <iostream>
8 
9 #include "istlexception.hh"
10 
11 #if HAVE_DUNE_BOOST
12 #ifdef HAVE_BOOST_FUSION
13 
14 #include <boost/fusion/sequence.hpp>
15 #include <boost/fusion/container.hpp>
16 #include <boost/fusion/iterator.hpp>
17 #include <boost/typeof/typeof.hpp>
18 #include <boost/fusion/algorithm.hpp>
19 
20 namespace mpl=boost::mpl;
21 namespace fusion=boost::fusion;
22 
23 // forward decl
24 namespace Dune
25 {
26  template<typename T1, typename T2=fusion::void_, typename T3=fusion::void_, typename T4=fusion::void_,
27  typename T5=fusion::void_, typename T6=fusion::void_, typename T7=fusion::void_,
28  typename T8=fusion::void_, typename T9=fusion::void_>
29  class MultiTypeBlockMatrix;
30 
31  template<int I, int crow, int remain_row>
32  class MultiTypeBlockMatrix_Solver;
33 }
34 
35 #include "gsetc.hh"
36 
37 namespace Dune {
38 
56  template<int crow, int remain_rows, int ccol, int remain_cols,
57  typename TMatrix>
58  class MultiTypeBlockMatrix_Print {
59  public:
60 
64  static void print(const TMatrix& m) {
65  std::cout << "\t(" << crow << ", " << ccol << "): \n" << fusion::at_c<ccol>( fusion::at_c<crow>(m));
66  MultiTypeBlockMatrix_Print<crow,remain_rows,ccol+1,remain_cols-1,TMatrix>::print(m); //next column
67  }
68  };
69  template<int crow, int remain_rows, int ccol, typename TMatrix> //specialization for remain_cols=0
70  class MultiTypeBlockMatrix_Print<crow,remain_rows,ccol,0,TMatrix> {
71  public: static void print(const TMatrix& m) {
72  static const int xlen = mpl::size< typename mpl::at_c<TMatrix,crow>::type >::value;
73  MultiTypeBlockMatrix_Print<crow+1,remain_rows-1,0,xlen,TMatrix>::print(m); //next row
74  }
75  };
76 
77  template<int crow, int ccol, int remain_cols, typename TMatrix> //recursion end: specialization for remain_rows=0
78  class MultiTypeBlockMatrix_Print<crow,0,ccol,remain_cols,TMatrix> {
79  public:
80  static void print(const TMatrix& m)
81  {std::cout << std::endl;}
82  };
83 
84 
85 
86  //make MultiTypeBlockVector_Ident known (for MultiTypeBlockMatrix_Ident)
87  template<int count, typename T1, typename T2>
88  class MultiTypeBlockVector_Ident;
89 
90 
103  template<int rowcount, typename T1, typename T2>
104  class MultiTypeBlockMatrix_Ident {
105  public:
106 
111  static void equalize(T1& a, const T2& b) {
112  MultiTypeBlockVector_Ident< mpl::size< typename mpl::at_c<T1,rowcount-1>::type >::value ,T1,T2>::equalize(a,b); //rows are cvectors
113  MultiTypeBlockMatrix_Ident<rowcount-1,T1,T2>::equalize(a,b); //iterate over rows
114  }
115  };
116 
117  //recursion end for rowcount=0
118  template<typename T1, typename T2>
119  class MultiTypeBlockMatrix_Ident<0,T1,T2> {
120  public:
121  static void equalize (T1& a, const T2& b)
122  {}
123  };
124 
130  template<int crow, int remain_rows, int ccol, int remain_cols,
131  typename TVecY, typename TMatrix, typename TVecX>
132  class MultiTypeBlockMatrix_VectMul {
133  public:
134 
138  static void umv(TVecY& y, const TMatrix& A, const TVecX& x) {
139  fusion::at_c<ccol>( fusion::at_c<crow>(A) ).umv( fusion::at_c<ccol>(x), fusion::at_c<crow>(y) );
140  MultiTypeBlockMatrix_VectMul<crow,remain_rows,ccol+1,remain_cols-1,TVecY,TMatrix,TVecX>::umv(y, A, x);
141  }
142 
146  static void mmv(TVecY& y, const TMatrix& A, const TVecX& x) {
147  fusion::at_c<ccol>( fusion::at_c<crow>(A) ).mmv( fusion::at_c<ccol>(x), fusion::at_c<crow>(y) );
148  MultiTypeBlockMatrix_VectMul<crow,remain_rows,ccol+1,remain_cols-1,TVecY,TMatrix,TVecX>::mmv(y, A, x);
149  }
150 
151  template<typename AlphaType>
152  static void usmv(const AlphaType& alpha, TVecY& y, const TMatrix& A, const TVecX& x) {
153  fusion::at_c<ccol>( fusion::at_c<crow>(A) ).usmv(alpha, fusion::at_c<ccol>(x), fusion::at_c<crow>(y) );
154  MultiTypeBlockMatrix_VectMul<crow,remain_rows,ccol+1,remain_cols-1,TVecY,TMatrix,TVecX>::usmv(alpha,y, A, x);
155  }
156 
157 
158  };
159 
160  //specialization for remain_cols = 0
161  template<int crow, int remain_rows,int ccol, typename TVecY,
162  typename TMatrix, typename TVecX>
163  class MultiTypeBlockMatrix_VectMul<crow,remain_rows,ccol,0,TVecY,TMatrix,TVecX> { //start iteration over next row
164 
165  public:
169  static void umv(TVecY& y, const TMatrix& A, const TVecX& x) {
170  static const int rowlen = mpl::size< typename mpl::at_c<TMatrix,crow>::type >::value;
171  MultiTypeBlockMatrix_VectMul<crow+1,remain_rows-1,0,rowlen,TVecY,TMatrix,TVecX>::umv(y, A, x);
172  }
173 
177  static void mmv(TVecY& y, const TMatrix& A, const TVecX& x) {
178  static const int rowlen = mpl::size< typename mpl::at_c<TMatrix,crow>::type >::value;
179  MultiTypeBlockMatrix_VectMul<crow+1,remain_rows-1,0,rowlen,TVecY,TMatrix,TVecX>::mmv(y, A, x);
180  }
181 
182  template <typename AlphaType>
183  static void usmv(const AlphaType& alpha, TVecY& y, const TMatrix& A, const TVecX& x) {
184  static const int rowlen = mpl::size< typename mpl::at_c<TMatrix,crow>::type >::value;
185  MultiTypeBlockMatrix_VectMul<crow+1,remain_rows-1,0,rowlen,TVecY,TMatrix,TVecX>::usmv(alpha,y, A, x);
186  }
187  };
188 
189  //specialization for remain_rows = 0
190  template<int crow, int ccol, int remain_cols, typename TVecY,
191  typename TMatrix, typename TVecX>
192  class MultiTypeBlockMatrix_VectMul<crow,0,ccol,remain_cols,TVecY,TMatrix,TVecX> {
193  //end recursion
194  public:
195  static void umv(TVecY& y, const TMatrix& A, const TVecX& x) {}
196  static void mmv(TVecY& y, const TMatrix& A, const TVecX& x) {}
197 
198  template<typename AlphaType>
199  static void usmv(const AlphaType& alpha, TVecY& y, const TMatrix& A, const TVecX& x) {}
200  };
201 
202 
203 
204 
205 
206 
215  template<typename T1, typename T2, typename T3, typename T4,
216  typename T5, typename T6, typename T7, typename T8, typename T9>
217  class MultiTypeBlockMatrix : public fusion::vector<T1, T2, T3, T4, T5, T6, T7, T8, T9> {
218 
219  public:
220 
224  typedef MultiTypeBlockMatrix<T1, T2, T3, T4, T5, T6, T7, T8, T9> type;
225 
226  typedef typename mpl::at_c<T1,0>::type field_type;
227 
231  template<typename T>
232  void operator= (const T& newval) {MultiTypeBlockMatrix_Ident<mpl::size<type>::value,type,T>::equalize(*this, newval); }
233 
237  template<typename X, typename Y>
238  void mv (const X& x, Y& y) const {
239  BOOST_STATIC_ASSERT(mpl::size<X>::value == mpl::size<T1>::value); //make sure x's length matches row length
240  BOOST_STATIC_ASSERT(mpl::size<Y>::value == mpl::size<type>::value); //make sure y's length matches row count
241 
242  y = 0; //reset y (for mv uses umv)
243  MultiTypeBlockMatrix_VectMul<0,mpl::size<type>::value,0,mpl::size<T1>::value,Y,type,X>::umv(y, *this, x); //iterate over all matrix elements
244  }
245 
249  template<typename X, typename Y>
250  void umv (const X& x, Y& y) const {
251  BOOST_STATIC_ASSERT(mpl::size<X>::value == mpl::size<T1>::value); //make sure x's length matches row length
252  BOOST_STATIC_ASSERT(mpl::size<Y>::value == mpl::size<type>::value); //make sure y's length matches row count
253 
254  MultiTypeBlockMatrix_VectMul<0,mpl::size<type>::value,0,mpl::size<T1>::value,Y,type,X>::umv(y, *this, x); //iterate over all matrix elements
255  }
256 
260  template<typename X, typename Y>
261  void mmv (const X& x, Y& y) const {
262  BOOST_STATIC_ASSERT(mpl::size<X>::value == mpl::size<T1>::value); //make sure x's length matches row length
263  BOOST_STATIC_ASSERT(mpl::size<Y>::value == mpl::size<type>::value); //make sure y's length matches row count
264 
265  MultiTypeBlockMatrix_VectMul<0,mpl::size<type>::value,0,mpl::size<T1>::value,Y,type,X>::mmv(y, *this, x); //iterate over all matrix elements
266  }
267 
269  template<typename AlphaType, typename X, typename Y>
270  void usmv (const AlphaType& alpha, const X& x, Y& y) const {
271  BOOST_STATIC_ASSERT(mpl::size<X>::value == mpl::size<T1>::value); //make sure x's length matches row length
272  BOOST_STATIC_ASSERT(mpl::size<Y>::value == mpl::size<type>::value); //make sure y's length matches row count
273 
274  MultiTypeBlockMatrix_VectMul<0,mpl::size<type>::value,0,mpl::size<T1>::value,Y,type,X>::usmv(alpha,y, *this, x); //iterate over all matrix elements
275 
276  }
277 
278 
279 
280  };
281 
282 
283 
289  template<typename T1, typename T2, typename T3, typename T4, typename T5,
290  typename T6, typename T7, typename T8, typename T9>
291  std::ostream& operator<< (std::ostream& s, const MultiTypeBlockMatrix<T1,T2,T3,T4,T5,T6,T7,T8,T9>& m) {
292  static const int i = mpl::size<MultiTypeBlockMatrix<T1,T2,T3,T4,T5,T6,T7,T8,T9> >::value; //row count
293  static const int j = mpl::size< typename mpl::at_c<MultiTypeBlockMatrix<T1,T2,T3,T4,T5,T6,T7,T8,T9>,0>::type >::value; //col count of first row
294  MultiTypeBlockMatrix_Print<0,i,0,j,MultiTypeBlockMatrix<T1,T2,T3,T4,T5,T6,T7,T8,T9> >::print(m);
295  return s;
296  }
297 
298 
299 
300 
301 
302  //make algmeta_itsteps known
303  template<int I>
304  struct algmeta_itsteps;
305 
306 
307 
308 
309 
310 
317  template<int I, int crow, int ccol, int remain_col> //MultiTypeBlockMatrix_Solver_Col: iterating over one row
318  class MultiTypeBlockMatrix_Solver_Col { //calculating b- A[i][j]*x[j]
319  public:
323  template <typename Trhs, typename TVector, typename TMatrix, typename K>
324  static void calc_rhs(const TMatrix& A, TVector& x, TVector& v, Trhs& b, const K& w) {
325  fusion::at_c<ccol>( fusion::at_c<crow>(A) ).mmv( fusion::at_c<ccol>(x), b );
326  MultiTypeBlockMatrix_Solver_Col<I, crow, ccol+1, remain_col-1>::calc_rhs(A,x,v,b,w); //next column element
327  }
328 
329  };
330  template<int I, int crow, int ccol> //MultiTypeBlockMatrix_Solver_Col recursion end
331  class MultiTypeBlockMatrix_Solver_Col<I,crow,ccol,0> {
332  public:
333  template <typename Trhs, typename TVector, typename TMatrix, typename K>
334  static void calc_rhs(const TMatrix& A, TVector& x, TVector& v, Trhs& b, const K& w) {}
335  };
336 
337 
338 
345  template<int I, int crow, int remain_row>
346  class MultiTypeBlockMatrix_Solver {
347  public:
348 
352  template <typename TVector, typename TMatrix, typename K>
353  static void dbgs(const TMatrix& A, TVector& x, const TVector& b, const K& w) {
354  TVector xold(x);
355  xold=x; //store old x values
357  x *= w;
358  x.axpy(1-w,xold); //improve x
359  }
360  template <typename TVector, typename TMatrix, typename K>
361  static void dbgs(const TMatrix& A, TVector& x, TVector& v, const TVector& b, const K& w) {
362  typename mpl::at_c<TVector,crow>::type rhs;
363  rhs = fusion::at_c<crow> (b);
364 
365  MultiTypeBlockMatrix_Solver_Col<I,crow,0, mpl::size<typename mpl::at_c<TMatrix,crow>::type>::value>::calc_rhs(A,x,v,rhs,w); // calculate right side of equation
366  //solve on blocklevel I-1
367  algmeta_itsteps<I-1>::dbgs(fusion::at_c<crow>( fusion::at_c<crow>(A)), fusion::at_c<crow>(x),rhs,w);
369  }
370 
371 
372 
376  template <typename TVector, typename TMatrix, typename K>
377  static void bsorf(const TMatrix& A, TVector& x, const TVector& b, const K& w) {
378  TVector v;
379  v=x; //use latest x values in right side calculation
381 
382  }
383  template <typename TVector, typename TMatrix, typename K> //recursion over all matrix rows (A)
384  static void bsorf(const TMatrix& A, TVector& x, TVector& v, const TVector& b, const K& w) {
385  typename mpl::at_c<TVector,crow>::type rhs;
386  rhs = fusion::at_c<crow> (b);
387 
388  MultiTypeBlockMatrix_Solver_Col<I,crow,0, mpl::size<typename mpl::at_c<TMatrix,crow>::type>::value>::calc_rhs(A,x,v,rhs,w); // calculate right side of equation
389  //solve on blocklevel I-1
390  algmeta_itsteps<I-1>::bsorf(fusion::at_c<crow>( fusion::at_c<crow>(A)), fusion::at_c<crow>(v),rhs,w);
391  fusion::at_c<crow>(x).axpy(w,fusion::at_c<crow>(v));
393  }
394 
398  template <typename TVector, typename TMatrix, typename K>
399  static void bsorb(const TMatrix& A, TVector& x, const TVector& b, const K& w) {
400  TVector v;
401  v=x; //use latest x values in right side calculation
403 
404  }
405  template <typename TVector, typename TMatrix, typename K> //recursion over all matrix rows (A)
406  static void bsorb(const TMatrix& A, TVector& x, TVector& v, const TVector& b, const K& w) {
407  typename mpl::at_c<TVector,crow>::type rhs;
408  rhs = fusion::at_c<crow> (b);
409 
410  MultiTypeBlockMatrix_Solver_Col<I,crow,0, mpl::size<typename mpl::at_c<TMatrix,crow>::type>::value>::calc_rhs(A,x,v,rhs,w); // calculate right side of equation
411  //solve on blocklevel I-1
412  algmeta_itsteps<I-1>::bsorb(fusion::at_c<crow>( fusion::at_c<crow>(A)), fusion::at_c<crow>(v),rhs,w);
413  fusion::at_c<crow>(x).axpy(w,fusion::at_c<crow>(v));
415  }
416 
417 
421  template <typename TVector, typename TMatrix, typename K>
422  static void dbjac(const TMatrix& A, TVector& x, const TVector& b, const K& w) {
423  TVector v(x);
424  v=0; //calc new x in v
426  x.axpy(w,v); //improve x
427  }
428  template <typename TVector, typename TMatrix, typename K>
429  static void dbjac(const TMatrix& A, TVector& x, TVector& v, const TVector& b, const K& w) {
430  typename mpl::at_c<TVector,crow>::type rhs;
431  rhs = fusion::at_c<crow> (b);
432 
433  MultiTypeBlockMatrix_Solver_Col<I,crow,0, mpl::size<typename mpl::at_c<TMatrix,crow>::type>::value>::calc_rhs(A,x,v,rhs,w); // calculate right side of equation
434  //solve on blocklevel I-1
435  algmeta_itsteps<I-1>::dbjac(fusion::at_c<crow>( fusion::at_c<crow>(A)), fusion::at_c<crow>(v),rhs,w);
437  }
438 
439 
440 
441 
442  };
443  template<int I, int crow> //recursion end for remain_row = 0
444  class MultiTypeBlockMatrix_Solver<I,crow,0> {
445  public:
446  template <typename TVector, typename TMatrix, typename K>
447  static void dbgs(const TMatrix& A, TVector& x, TVector& v,
448  const TVector& b, const K& w) {}
449 
450  template <typename TVector, typename TMatrix, typename K>
451  static void bsorf(const TMatrix& A, TVector& x, TVector& v,
452  const TVector& b, const K& w) {}
453 
454  template <typename TVector, typename TMatrix, typename K>
455  static void bsorb(const TMatrix& A, TVector& x, TVector& v,
456  const TVector& b, const K& w) {}
457 
458  template <typename TVector, typename TMatrix, typename K>
459  static void dbjac(const TMatrix& A, TVector& x, TVector& v,
460  const TVector& b, const K& w) {}
461  };
462 
463 } // end namespace
464 
465 #endif // HAVE_BOOST_FUSION
466 #endif // HAVE_DUNE_BOOST
467 #endif