GetFEM  5.4.4
gmm_blas_interface.h
Go to the documentation of this file.
1 /* -*- c++ -*- (enables emacs c++ mode) */
2 /*===========================================================================
3 
4  Copyright (C) 2003-2020 Yves Renard
5 
6  This file is a part of GetFEM
7 
8  GetFEM is free software; you can redistribute it and/or modify it
9  under the terms of the GNU Lesser General Public License as published
10  by the Free Software Foundation; either version 3 of the License, or
11  (at your option) any later version along with the GCC Runtime Library
12  Exception either version 3.1 or (at your option) any later version.
13  This program is distributed in the hope that it will be useful, but
14  WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
15  or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
16  License and GCC Runtime Library Exception for more details.
17  You should have received a copy of the GNU Lesser General Public License
18  along with this program; if not, write to the Free Software Foundation,
19  Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA.
20 
21  As a special exception, you may use this file as it is a part of a free
22  software library without restriction. Specifically, if other files
23  instantiate templates or use macros or inline functions from this file,
24  or you compile this file and link it with other files to produce an
25  executable, this file does not by itself cause the resulting executable
26  to be covered by the GNU Lesser General Public License. This exception
27  does not however invalidate any other reasons why the executable file
28  might be covered by the GNU Lesser General Public License.
29 
30 ===========================================================================*/
31 
32 /**@file gmm_blas_interface.h
33  @author Yves Renard <Yves.Renard@insa-lyon.fr>
34  @date October 7, 2003.
35  @brief gmm interface for fortran BLAS.
36 */
37 
38 #if defined(GMM_USES_BLAS) || defined(GMM_USES_LAPACK)
39 
40 #ifndef GMM_BLAS_INTERFACE_H
41 #define GMM_BLAS_INTERFACE_H
42 
43 #include "gmm_blas.h"
44 #include "gmm_interface.h"
45 #include "gmm_matrix.h"
46 
47 namespace gmm {
48 
49  // Use ./configure --enable-blas-interface to activate this interface.
50 
// Tracing hook invoked at the top of the wrappers defined below. It expands
// to nothing by default; swap in the commented-out definition to print the
// name of each interface function as it is entered.
#define GMMLAPACK_TRACE(f)
// #define GMMLAPACK_TRACE(f) cout << "function " << f << " called" << endl;

// Integer type used for every size, leading dimension and increment whose
// address is handed to the Fortran routines.
// NOTE(review): `long` is presumably meant to be a 64-bit integer here; that
// does not hold on LLP64 platforms -- confirm before relying on it.
#if defined(WeirdNEC) || defined(GMM_USE_BLAS64_INTERFACE)
  #define BLAS_INT long
#else // By default BLAS_INT will just be int in C
  #define BLAS_INT int
#endif
59 
60  /* ********************************************************************* */
61  /* Operations interfaced for T = float, double, std::complex<float> */
62  /* or std::complex<double> : */
63  /* */
64  /* vect_norm2(std::vector<T>) */
65  /* */
66  /* vect_sp(std::vector<T>, std::vector<T>) */
67  /* vect_sp(scaled(std::vector<T>), std::vector<T>) */
68  /* vect_sp(std::vector<T>, scaled(std::vector<T>)) */
69  /* vect_sp(scaled(std::vector<T>), scaled(std::vector<T>)) */
70  /* */
71  /* vect_hp(std::vector<T>, std::vector<T>) */
72  /* vect_hp(scaled(std::vector<T>), std::vector<T>) */
73  /* vect_hp(std::vector<T>, scaled(std::vector<T>)) */
74  /* vect_hp(scaled(std::vector<T>), scaled(std::vector<T>)) */
75  /* */
76  /* add(std::vector<T>, std::vector<T>) */
77  /* add(scaled(std::vector<T>, a), std::vector<T>) */
78  /* */
79  /* mult(dense_matrix<T>, dense_matrix<T>, dense_matrix<T>) */
80  /* mult(transposed(dense_matrix<T>), dense_matrix<T>, dense_matrix<T>) */
81  /* mult(dense_matrix<T>, transposed(dense_matrix<T>), dense_matrix<T>) */
82  /* mult(transposed(dense_matrix<T>), transposed(dense_matrix<T>), */
83  /* dense_matrix<T>) */
84  /* mult(conjugated(dense_matrix<T>), dense_matrix<T>, dense_matrix<T>) */
85  /* mult(dense_matrix<T>, conjugated(dense_matrix<T>), dense_matrix<T>) */
86  /* mult(conjugated(dense_matrix<T>), conjugated(dense_matrix<T>), */
87  /* dense_matrix<T>) */
88  /* */
89  /* mult(dense_matrix<T>, std::vector<T>, std::vector<T>) */
90  /* mult(transposed(dense_matrix<T>), std::vector<T>, std::vector<T>) */
91  /* mult(conjugated(dense_matrix<T>), std::vector<T>, std::vector<T>) */
92  /* mult(dense_matrix<T>, scaled(std::vector<T>), std::vector<T>) */
93  /* mult(transposed(dense_matrix<T>), scaled(std::vector<T>), */
94  /* std::vector<T>) */
95  /* mult(conjugated(dense_matrix<T>), scaled(std::vector<T>), */
96  /* std::vector<T>) */
97  /* */
98  /* mult_add(dense_matrix<T>, std::vector<T>, std::vector<T>) */
99  /* mult_add(transposed(dense_matrix<T>), std::vector<T>, std::vector<T>) */
100  /* mult_add(conjugated(dense_matrix<T>), std::vector<T>, std::vector<T>) */
101  /* mult_add(dense_matrix<T>, scaled(std::vector<T>), std::vector<T>) */
102  /* mult_add(transposed(dense_matrix<T>), scaled(std::vector<T>), */
103  /* std::vector<T>) */
104  /* mult_add(conjugated(dense_matrix<T>), scaled(std::vector<T>), */
105  /* std::vector<T>) */
106  /* */
107  /* mult(dense_matrix<T>, std::vector<T>, std::vector<T>, std::vector<T>) */
108  /* mult(transposed(dense_matrix<T>), std::vector<T>, std::vector<T>, */
109  /* std::vector<T>) */
110  /* mult(conjugated(dense_matrix<T>), std::vector<T>, std::vector<T>, */
111  /* std::vector<T>) */
112  /* mult(dense_matrix<T>, scaled(std::vector<T>), std::vector<T>, */
113  /* std::vector<T>) */
114  /* mult(transposed(dense_matrix<T>), scaled(std::vector<T>), */
115  /* std::vector<T>, std::vector<T>) */
116  /* mult(conjugated(dense_matrix<T>), scaled(std::vector<T>), */
117  /* std::vector<T>, std::vector<T>) */
118  /* mult(dense_matrix<T>, std::vector<T>, scaled(std::vector<T>), */
119  /* std::vector<T>) */
120  /* mult(transposed(dense_matrix<T>), std::vector<T>, */
121  /* scaled(std::vector<T>), std::vector<T>) */
122  /* mult(conjugated(dense_matrix<T>), std::vector<T>, */
123  /* scaled(std::vector<T>), std::vector<T>) */
124  /* mult(dense_matrix<T>, scaled(std::vector<T>), scaled(std::vector<T>), */
125  /* std::vector<T>) */
126  /* mult(transposed(dense_matrix<T>), scaled(std::vector<T>), */
127  /* scaled(std::vector<T>), std::vector<T>) */
128  /* mult(conjugated(dense_matrix<T>), scaled(std::vector<T>), */
129  /* scaled(std::vector<T>), std::vector<T>) */
130  /* */
131  /* lower_tri_solve(dense_matrix<T>, std::vector<T>, k, b) */
132  /* upper_tri_solve(dense_matrix<T>, std::vector<T>, k, b) */
133  /* lower_tri_solve(transposed(dense_matrix<T>), std::vector<T>, k, b) */
134  /* upper_tri_solve(transposed(dense_matrix<T>), std::vector<T>, k, b) */
135  /* lower_tri_solve(conjugated(dense_matrix<T>), std::vector<T>, k, b) */
136  /* upper_tri_solve(conjugated(dense_matrix<T>), std::vector<T>, k, b) */
137  /* */
138  /* rank_one_update(dense_matrix<T>, std::vector<T>, std::vector<T>) */
139  /* rank_one_update(dense_matrix<T>, scaled(std::vector<T>), */
140  /* std::vector<T>) */
141  /* rank_one_update(dense_matrix<T>, std::vector<T>, */
142  /* scaled(std::vector<T>)) */
143  /* */
144  /* ********************************************************************* */
145 
146  /* ********************************************************************* */
147  /* Basic defines. */
148  /* ********************************************************************* */
149 
// Short names for the four scalar types handled by this interface. The
// S/D/C/Z suffix is the same letter that prefixes the matching BLAS routine
// names used in this file (sdot_, ddot_, cdotu_, zdotu_, ...).
# define BLAS_S float
# define BLAS_D double
# define BLAS_C std::complex<float>
# define BLAS_Z std::complex<double>

// Hack due to BLAS ABI mess
// BLAS_CPLX_FUNC_CALL(blasname, res, args...) stores the complex result of
// `blasname(args...)` in `res`. With GMM_BLAS_RETURN_COMPLEX_AS_ARGUMENT the
// routine is called with `&res` prepended to the argument list; otherwise the
// value returned by the routine is assigned to `res`.
#if defined(GMM_BLAS_RETURN_COMPLEX_AS_ARGUMENT)
# define BLAS_CPLX_FUNC_CALL(blasname, res, ...) blasname(&res, __VA_ARGS__)
#else
# define BLAS_CPLX_FUNC_CALL(blasname, res, ...) res = blasname(__VA_ARGS__)
#endif
161 
162  /* ********************************************************************* */
163  /* BLAS functions used. */
164  /* ********************************************************************* */
// Declarations of the Fortran BLAS entry points (C linkage, trailing
// underscore). Only daxpy_ and dgemm_ carry full prototypes; every other
// routine is declared with an ellipsis, so the compiler does not check the
// arguments passed by the wrappers below. All arguments are passed by address.
extern "C" {
  void daxpy_(const BLAS_INT *n, const double *alpha, const double *x,
              const BLAS_INT *incx, double *y, const BLAS_INT *incy);
  void saxpy_(...); /*void daxpy_(...);*/ void caxpy_(...); void zaxpy_(...);
  void dgemm_(const char *tA, const char *tB, const BLAS_INT *m,
              const BLAS_INT *n, const BLAS_INT *k, const BLAS_D *alpha,
              const BLAS_D *A, const BLAS_INT *ldA, const BLAS_D *B,
              const BLAS_INT *ldB, const BLAS_D *beta, BLAS_D *C,
              const BLAS_INT *ldC);
  void sgemm_(...); /*void dgemm_(...);*/ void cgemm_(...); void zgemm_(...);
  void sgemv_(...); void dgemv_(...); void cgemv_(...); void zgemv_(...);
  void strsv_(...); void dtrsv_(...); void ctrsv_(...); void ztrsv_(...);
  BLAS_S sdot_ (...); BLAS_D ddot_ (...);
  BLAS_C cdotu_(...); BLAS_Z zdotu_(...);
  // The Hermitian product computed by {c,z}dotc uses the reverse operand
  // order compared with gmm's vect_hp(x, y) = x.conj(y); the wrappers swap
  // x and y when calling these routines.
  BLAS_C cdotc_(...); BLAS_Z zdotc_(...);
  BLAS_S snrm2_(...); BLAS_D dnrm2_(...);
  // The complex norms return a real value (BLAS_S / BLAS_D).
  BLAS_S scnrm2_(...); BLAS_D dznrm2_(...);
  void sger_(...); void dger_(...); void cgerc_(...); void zgerc_(...);
}
185 
186 
187  /* ********************************************************************* */
188  /* vect_norm2(x). */
189  /* ********************************************************************* */
190 
191 # define nrm2_interface(blas_name, base_type) \
192  inline number_traits<base_type>::magnitude_type \
193  vect_norm2(const std::vector<base_type> &x) { \
194  GMMLAPACK_TRACE("nrm2_interface"); \
195  const BLAS_INT n=BLAS_INT(vect_size(x)), inc(1); \
196  return blas_name(&n, &x[0], &inc); \
197  }
198 
199  nrm2_interface(snrm2_, BLAS_S)
200  nrm2_interface(dnrm2_, BLAS_D)
201  nrm2_interface(scnrm2_, BLAS_C)
202  nrm2_interface(dznrm2_, BLAS_Z)
203 
204  /* ********************************************************************* */
205  /* vect_sp(x,y) = vect_hp(x,y) for real vectors */
206  /* ********************************************************************* */
207 
208 # define dot_interface(funcname, msg, blas_name, base_type) \
209  inline base_type funcname(const std::vector<base_type> &x, \
210  const std::vector<base_type> &y) { \
211  GMMLAPACK_TRACE(msg); \
212  const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); \
213  return blas_name(&n, &x[0], &inc, &y[0], &inc); \
214  } \
215  inline base_type funcname \
216  (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_, \
217  const std::vector<base_type> &y) { \
218  GMMLAPACK_TRACE(msg); \
219  const std::vector<base_type> &x = *(linalg_origin(x_)); \
220  base_type a(x_.r); \
221  const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); \
222  return a * blas_name(&n, &x[0], &inc, &y[0], &inc); \
223  } \
224  inline base_type funcname \
225  (const std::vector<base_type> &x, \
226  const scaled_vector_const_ref<std::vector<base_type>,base_type> &y_) {\
227  GMMLAPACK_TRACE(msg); \
228  const std::vector<base_type> &y = *(linalg_origin(y_)); \
229  base_type b(y_.r); \
230  const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); \
231  return b * blas_name(&n, &x[0], &inc, &y[0], &inc); \
232  } \
233  inline base_type funcname \
234  (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_, \
235  const scaled_vector_const_ref<std::vector<base_type>,base_type> &y_) {\
236  GMMLAPACK_TRACE(msg); \
237  const std::vector<base_type> &x = *(linalg_origin(x_)); \
238  const std::vector<base_type> &y = *(linalg_origin(y_)); \
239  base_type a(x_.r), b(y_.r); \
240  const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); \
241  return a*b * blas_name(&n, &x[0], &inc, &y[0], &inc); \
242  }
243 
244  dot_interface(vect_sp, "dot_interface", sdot_, BLAS_S)
245  dot_interface(vect_sp, "dot_interface", ddot_, BLAS_D)
246  dot_interface(vect_hp, "dotc_interface", sdot_, BLAS_S)
247  dot_interface(vect_hp, "dotc_interface", ddot_, BLAS_D)
248 
249  /* ********************************************************************* */
250  /* vect_sp(x,y) and vect_hp(x,y) for complex vectors */
251  /* vect_hp(x, y) = x.conj(y) (different order than in BLAS) */
252  /* switching x,y before passed to BLAS is important only for vect_hp */
253  /* ********************************************************************* */
254 
255 # define dot_interface_cplx(funcname, msg, blas_name, base_type, bdef) \
256  inline base_type funcname(const std::vector<base_type> &x, \
257  const std::vector<base_type> &y) { \
258  GMMLAPACK_TRACE(msg); \
259  base_type res; \
260  const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); \
261  BLAS_CPLX_FUNC_CALL(blas_name, res, &n, &y[0], &inc, &x[0], &inc); \
262  return res; \
263  } \
264  inline base_type funcname \
265  (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_, \
266  const std::vector<base_type> &y) { \
267  GMMLAPACK_TRACE(msg); \
268  const std::vector<base_type> &x = *(linalg_origin(x_)); \
269  base_type res, a(x_.r); \
270  const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); \
271  BLAS_CPLX_FUNC_CALL(blas_name, res, &n, &y[0], &inc, &x[0], &inc); \
272  return a*res; \
273  } \
274  inline base_type funcname \
275  (const std::vector<base_type> &x, \
276  const scaled_vector_const_ref<std::vector<base_type>,base_type> &y_) {\
277  GMMLAPACK_TRACE(msg); \
278  const std::vector<base_type> &y = *(linalg_origin(y_)); \
279  base_type res, b(bdef); \
280  const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); \
281  BLAS_CPLX_FUNC_CALL(blas_name, res, &n, &y[0], &inc, &x[0], &inc); \
282  return b*res; \
283  } \
284  inline base_type funcname \
285  (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_, \
286  const scaled_vector_const_ref<std::vector<base_type>,base_type> &y_) {\
287  GMMLAPACK_TRACE(msg); \
288  const std::vector<base_type> &x = *(linalg_origin(x_)); \
289  const std::vector<base_type> &y = *(linalg_origin(y_)); \
290  base_type res, a(x_.r), b(bdef); \
291  const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); \
292  BLAS_CPLX_FUNC_CALL(blas_name, res, &n, &y[0], &inc, &x[0], &inc); \
293  return a*b*res; \
294  }
295 
296  dot_interface_cplx(vect_sp, "dot_interface", cdotu_, BLAS_C, y_.r)
297  dot_interface_cplx(vect_sp, "dot_interface", zdotu_, BLAS_Z, y_.r)
298  dot_interface_cplx(vect_hp, "dotc_interface", cdotc_, BLAS_C, gmm::conj(y_.r))
299  dot_interface_cplx(vect_hp, "dotc_interface", zdotc_, BLAS_Z, gmm::conj(y_.r))
300 
301 
302  /* ********************************************************************* */
303  /* add(x, y). */
304  /* ********************************************************************* */
305  template<size_type N, class V1, class V2>
306  inline void add_fixed(const V1 &x, V2 &y)
307  {
308  for(size_type i = 0; i != N; ++i) y[i] += x[i];
309  }
310 
311  template<class V1, class V2>
312  inline void add_for_short_vectors(const V1 &x, V2 &y, size_type n)
313  {
314  switch(n)
315  {
316  case 1: add_fixed<1>(x, y); break;
317  case 2: add_fixed<2>(x, y); break;
318  case 3: add_fixed<3>(x, y); break;
319  case 4: add_fixed<4>(x, y); break;
320  case 5: add_fixed<5>(x, y); break;
321  case 6: add_fixed<6>(x, y); break;
322  case 7: add_fixed<7>(x, y); break;
323  case 8: add_fixed<8>(x, y); break;
324  case 9: add_fixed<9>(x, y); break;
325  case 10: add_fixed<10>(x, y); break;
326  case 11: add_fixed<11>(x, y); break;
327  case 12: add_fixed<12>(x, y); break;
328  case 13: add_fixed<13>(x, y); break;
329  case 14: add_fixed<14>(x, y); break;
330  case 15: add_fixed<15>(x, y); break;
331  case 16: add_fixed<16>(x, y); break;
332  case 17: add_fixed<17>(x, y); break;
333  case 18: add_fixed<18>(x, y); break;
334  case 19: add_fixed<19>(x, y); break;
335  case 20: add_fixed<20>(x, y); break;
336  case 21: add_fixed<21>(x, y); break;
337  case 22: add_fixed<22>(x, y); break;
338  case 23: add_fixed<23>(x, y); break;
339  case 24: add_fixed<24>(x, y); break;
340  default:
341  GMM_ASSERT2(false, "add_for_short_vectors used with unsupported size");
342  break;
343  }
344  }
345 
346  template<size_type N, class V1, class V2, class T>
347  inline void add_fixed(const V1 &x, V2 &y, const T &a)
348  {
349  for(size_type i = 0; i != N; ++i) y[i] += a*x[i];
350  }
351 
352  template<class V1, class V2, class T>
353  inline void add_for_short_vectors(const V1 &x, V2 &y, const T &a, size_type n)
354  {
355  switch(n)
356  {
357  case 1: add_fixed<1>(x, y, a); break;
358  case 2: add_fixed<2>(x, y, a); break;
359  case 3: add_fixed<3>(x, y, a); break;
360  case 4: add_fixed<4>(x, y, a); break;
361  case 5: add_fixed<5>(x, y, a); break;
362  case 6: add_fixed<6>(x, y, a); break;
363  case 7: add_fixed<7>(x, y, a); break;
364  case 8: add_fixed<8>(x, y, a); break;
365  case 9: add_fixed<9>(x, y, a); break;
366  case 10: add_fixed<10>(x, y, a); break;
367  case 11: add_fixed<11>(x, y, a); break;
368  case 12: add_fixed<12>(x, y, a); break;
369  case 13: add_fixed<13>(x, y, a); break;
370  case 14: add_fixed<14>(x, y, a); break;
371  case 15: add_fixed<15>(x, y, a); break;
372  case 16: add_fixed<16>(x, y, a); break;
373  case 17: add_fixed<17>(x, y, a); break;
374  case 18: add_fixed<18>(x, y, a); break;
375  case 19: add_fixed<19>(x, y, a); break;
376  case 20: add_fixed<20>(x, y, a); break;
377  case 21: add_fixed<21>(x, y, a); break;
378  case 22: add_fixed<22>(x, y, a); break;
379  case 23: add_fixed<23>(x, y, a); break;
380  case 24: add_fixed<24>(x, y, a); break;
381  default:
382  GMM_ASSERT2(false, "add_for_short_vectors used with unsupported size");
383  break;
384  }
385  }
386 
387 
388 # define axpy_interface(blas_name, base_type) \
389  inline void add(const std::vector<base_type> &x, \
390  std::vector<base_type> &y) { \
391  GMMLAPACK_TRACE("axpy_interface"); \
392  const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); base_type a(1); \
393  if (n == 0) return; \
394  else if (n < 25) add_for_short_vectors(x, y, n); \
395  else blas_name(&n, &a, &x[0], &inc, &y[0], &inc); \
396  }
397 
398  axpy_interface(saxpy_, BLAS_S)
399  axpy_interface(daxpy_, BLAS_D)
400  axpy_interface(caxpy_, BLAS_C)
401  axpy_interface(zaxpy_, BLAS_Z)
402 
403 
404 # define axpy2_interface(blas_name, base_type) \
405  inline void add \
406  (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_, \
407  std::vector<base_type> &y) { \
408  GMMLAPACK_TRACE("axpy_interface"); \
409  const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); \
410  const std::vector<base_type>& x = *(linalg_origin(x_)); \
411  const base_type a(x_.r); \
412  if (n == 0) return; \
413  else if (n < 25) add_for_short_vectors(x, y, a, n); \
414  else blas_name(&n, &a, &x[0], &inc, &y[0], &inc); \
415  }
416 
417  axpy2_interface(saxpy_, BLAS_S)
418  axpy2_interface(daxpy_, BLAS_D)
419  axpy2_interface(caxpy_, BLAS_C)
420  axpy2_interface(zaxpy_, BLAS_Z)
421 
422 
423  /* ********************************************************************* */
424  /* mult_add(A, x, z). */
425  /* ********************************************************************* */
426 
// gemv_interface(param1, trans1, param2, trans2, blas_name, base_type, orien)
// defines mult_add_spec(<matrix>, <vector>, z, orien), i.e. z <- op(A) x + z.
//   param1 / param2 : macros producing the matrix / vector parameter
//                     declarations;
//   trans1 / trans2 : macros producing the locals `A` and `t` (transposition
//                     flag), and `x` and `alpha` (scaling factor);
//   blas_name       : BLAS xGEMV routine, called with beta = 1 so that the
//                     product is accumulated into z.
//
// When A has no rows or no columns the product contributes nothing and z
// must be left as it is. The historical version cleared z in that case,
// which destroyed the accumulator.
# define gemv_interface(param1, trans1, param2, trans2, blas_name,         \
                        base_type, orien)                                  \
  inline void mult_add_spec(param1(base_type), param2(base_type),          \
                            std::vector<base_type> &z, orien) {            \
    GMMLAPACK_TRACE("gemv_interface");                                     \
    trans1(base_type); trans2(base_type); base_type beta(1);               \
    const BLAS_INT m=BLAS_INT(mat_nrows(A)), lda(m),                       \
                   n=BLAS_INT(mat_ncols(A)), inc(1);                       \
    if (m && n) blas_name(&t, &m, &n, &alpha, &A(0,0), &lda, &x[0], &inc,  \
                          &beta, &z[0], &inc);                             \
  }
439 
// Building blocks for the gemv/ger interface macros. Each gem_p1_* / gemv_p2_*
// macro expands to a parameter declaration; the matching *_trans* macro
// expands to the local declarations the interface body relies on
// (`A` and `t` for the matrix, `x` and `alpha` for the vector).

// First parameter
// Plain dense matrix: passed directly as `A`, BLAS flag 'N'.
# define gem_p1_n(base_type)  const dense_matrix<base_type> &A
# define gem_trans1_n(base_type) const char t = 'N'
// transposed(dense matrix): received as `A_`; `A` is bound to the matrix
// returned by linalg_origin and the BLAS flag is 'T'.
# define gem_p1_t(base_type)                                         \
  const transposed_col_ref<dense_matrix<base_type> *> &A_
# define gem_trans1_t(base_type) const dense_matrix<base_type> &A =  \
  *(linalg_origin(A_));                                              \
  const char t = 'T'
// transposed(const dense matrix): same as above for a const origin; it is
// paired with gem_trans1_t in the instantiations.
# define gem_p1_tc(base_type)                                        \
  const transposed_col_ref<const dense_matrix<base_type> *> &A_
// conjugated(dense matrix): received as `A_`; `A` is bound to the matrix
// returned by linalg_origin and the BLAS flag is 'C'.
# define gem_p1_c(base_type)                                         \
  const conjugated_col_matrix_const_ref<dense_matrix<base_type> > &A_
# define gem_trans1_c(base_type) const dense_matrix<base_type> &A =  \
  *(linalg_origin(A_));                                              \
  const char t = 'C'

// second parameter
// Plain vector: passed directly as `x`, factor alpha = 1.
# define gemv_p2_n(base_type)  const std::vector<base_type> &x
# define gemv_trans2_n(base_type) base_type alpha(1)
// scaled(vector): received as `x_`; `x` is bound to the vector returned by
// linalg_origin and alpha is the scaling factor x_.r.
# define gemv_p2_s(base_type)                                        \
  const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_
# define gemv_trans2_s(base_type) const std::vector<base_type> &x =  \
  (*(linalg_origin(x_)));                                            \
  base_type alpha(x_.r)
464 
465  // Z <- AX + Z.
466  gemv_interface(gem_p1_n, gem_trans1_n,
467  gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, col_major)
468  gemv_interface(gem_p1_n, gem_trans1_n,
469  gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, col_major)
470  gemv_interface(gem_p1_n, gem_trans1_n,
471  gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, col_major)
472  gemv_interface(gem_p1_n, gem_trans1_n,
473  gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, col_major)
474 
475  // Z <- transposed(A)X + Z.
476  gemv_interface(gem_p1_t, gem_trans1_t,
477  gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, row_major)
478  gemv_interface(gem_p1_t, gem_trans1_t,
479  gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, row_major)
480  gemv_interface(gem_p1_t, gem_trans1_t,
481  gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, row_major)
482  gemv_interface(gem_p1_t, gem_trans1_t,
483  gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, row_major)
484 
485  // Z <- transposed(const A)X + Z.
486  gemv_interface(gem_p1_tc, gem_trans1_t,
487  gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, row_major)
488  gemv_interface(gem_p1_tc, gem_trans1_t,
489  gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, row_major)
490  gemv_interface(gem_p1_tc, gem_trans1_t,
491  gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, row_major)
492  gemv_interface(gem_p1_tc, gem_trans1_t,
493  gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, row_major)
494 
495  // Z <- conjugated(A)X + Z.
496  gemv_interface(gem_p1_c, gem_trans1_c,
497  gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, row_major)
498  gemv_interface(gem_p1_c, gem_trans1_c,
499  gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, row_major)
500  gemv_interface(gem_p1_c, gem_trans1_c,
501  gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, row_major)
502  gemv_interface(gem_p1_c, gem_trans1_c,
503  gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, row_major)
504 
505  // Z <- A scaled(X) + Z.
506  gemv_interface(gem_p1_n, gem_trans1_n,
507  gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, col_major)
508  gemv_interface(gem_p1_n, gem_trans1_n,
509  gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, col_major)
510  gemv_interface(gem_p1_n, gem_trans1_n,
511  gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, col_major)
512  gemv_interface(gem_p1_n, gem_trans1_n,
513  gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, col_major)
514 
515  // Z <- transposed(A) scaled(X) + Z.
516  gemv_interface(gem_p1_t, gem_trans1_t,
517  gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, row_major)
518  gemv_interface(gem_p1_t, gem_trans1_t,
519  gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, row_major)
520  gemv_interface(gem_p1_t, gem_trans1_t,
521  gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, row_major)
522  gemv_interface(gem_p1_t, gem_trans1_t,
523  gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, row_major)
524 
525  // Z <- transposed(const A) scaled(X) + Z.
526  gemv_interface(gem_p1_tc, gem_trans1_t,
527  gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, row_major)
528  gemv_interface(gem_p1_tc, gem_trans1_t,
529  gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, row_major)
530  gemv_interface(gem_p1_tc, gem_trans1_t,
531  gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, row_major)
532  gemv_interface(gem_p1_tc, gem_trans1_t,
533  gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, row_major)
534 
535  // Z <- conjugated(A) scaled(X) + Z.
536  gemv_interface(gem_p1_c, gem_trans1_c,
537  gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, row_major)
538  gemv_interface(gem_p1_c, gem_trans1_c,
539  gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, row_major)
540  gemv_interface(gem_p1_c, gem_trans1_c,
541  gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, row_major)
542  gemv_interface(gem_p1_c, gem_trans1_c,
543  gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, row_major)
544 
545 
546  /* ********************************************************************* */
547  /* mult(A, x, y). */
548  /* ********************************************************************* */
549 
// gemv_interface2(param1, trans1, param2, trans2, blas_name, base_type, orien)
// defines mult_spec(<matrix>, <vector>, z, orien), i.e. z <- op(A) x.
//   param1 / param2 : macros producing the matrix / vector parameter
//                     declarations;
//   trans1 / trans2 : macros producing the locals `A` and `t` (transposition
//                     flag), and `x` and `alpha` (scaling factor);
//   blas_name       : BLAS xGEMV routine, called with beta = 0.
// If A has no rows or no columns, z is cleared and BLAS is not called.
# define gemv_interface2(param1, trans1, param2, trans2, blas_name,        \
                         base_type, orien)                                 \
  inline void mult_spec(param1(base_type), param2(base_type),              \
                        std::vector<base_type> &z, orien) {                \
    GMMLAPACK_TRACE("gemv_interface2");                                    \
    trans1(base_type);                                                     \
    trans2(base_type);                                                     \
    const BLAS_INT nrows = BLAS_INT(mat_nrows(A));                         \
    const BLAS_INT ncols = BLAS_INT(mat_ncols(A));                         \
    if (!nrows || !ncols) { gmm::clear(z); return; }                       \
    const BLAS_INT lead_dim(nrows), stride(1);                             \
    base_type beta_zero(0);                                                \
    blas_name(&t, &nrows, &ncols, &alpha, &A(0,0), &lead_dim,              \
              &x[0], &stride, &beta_zero, &z[0], &stride);                 \
  }
563 
564  // Y <- AX.
565  gemv_interface2(gem_p1_n, gem_trans1_n,
566  gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, col_major)
567  gemv_interface2(gem_p1_n, gem_trans1_n,
568  gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, col_major)
569  gemv_interface2(gem_p1_n, gem_trans1_n,
570  gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, col_major)
571  gemv_interface2(gem_p1_n, gem_trans1_n,
572  gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, col_major)
573 
574  // Y <- transposed(A)X.
575  gemv_interface2(gem_p1_t, gem_trans1_t,
576  gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, row_major)
577  gemv_interface2(gem_p1_t, gem_trans1_t,
578  gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, row_major)
579  gemv_interface2(gem_p1_t, gem_trans1_t,
580  gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, row_major)
581  gemv_interface2(gem_p1_t, gem_trans1_t,
582  gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, row_major)
583 
584  // Y <- transposed(const A)X.
585  gemv_interface2(gem_p1_tc, gem_trans1_t,
586  gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, row_major)
587  gemv_interface2(gem_p1_tc, gem_trans1_t,
588  gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, row_major)
589  gemv_interface2(gem_p1_tc, gem_trans1_t,
590  gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, row_major)
591  gemv_interface2(gem_p1_tc, gem_trans1_t,
592  gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, row_major)
593 
594  // Y <- conjugated(A)X.
595  gemv_interface2(gem_p1_c, gem_trans1_c,
596  gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, row_major)
597  gemv_interface2(gem_p1_c, gem_trans1_c,
598  gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, row_major)
599  gemv_interface2(gem_p1_c, gem_trans1_c,
600  gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, row_major)
601  gemv_interface2(gem_p1_c, gem_trans1_c,
602  gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, row_major)
603 
604  // Y <- A scaled(X).
605  gemv_interface2(gem_p1_n, gem_trans1_n,
606  gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, col_major)
607  gemv_interface2(gem_p1_n, gem_trans1_n,
608  gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, col_major)
609  gemv_interface2(gem_p1_n, gem_trans1_n,
610  gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, col_major)
611  gemv_interface2(gem_p1_n, gem_trans1_n,
612  gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, col_major)
613 
614  // Y <- transposed(A) scaled(X).
615  gemv_interface2(gem_p1_t, gem_trans1_t,
616  gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, row_major)
617  gemv_interface2(gem_p1_t, gem_trans1_t,
618  gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, row_major)
619  gemv_interface2(gem_p1_t, gem_trans1_t,
620  gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, row_major)
621  gemv_interface2(gem_p1_t, gem_trans1_t,
622  gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, row_major)
623 
624  // Y <- transposed(const A) scaled(X).
625  gemv_interface2(gem_p1_tc, gem_trans1_t,
626  gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, row_major)
627  gemv_interface2(gem_p1_tc, gem_trans1_t,
628  gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, row_major)
629  gemv_interface2(gem_p1_tc, gem_trans1_t,
630  gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, row_major)
631  gemv_interface2(gem_p1_tc, gem_trans1_t,
632  gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, row_major)
633 
634  // Y <- conjugated(A) scaled(X).
635  gemv_interface2(gem_p1_c, gem_trans1_c,
636  gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, row_major)
637  gemv_interface2(gem_p1_c, gem_trans1_c,
638  gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, row_major)
639  gemv_interface2(gem_p1_c, gem_trans1_c,
640  gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, row_major)
641  gemv_interface2(gem_p1_c, gem_trans1_c,
642  gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, row_major)
643 
644 
645  /* ********************************************************************* */
646  /* Rank one update. */
647  /* ********************************************************************* */
648 
649 # define ger_interface(blas_name, base_type) \
650  inline void rank_one_update(const dense_matrix<base_type> &A, \
651  const std::vector<base_type> &V, \
652  const std::vector<base_type> &W) { \
653  GMMLAPACK_TRACE("ger_interface"); \
654  const BLAS_INT m=BLAS_INT(mat_nrows(A)), lda(m), \
655  n=BLAS_INT(mat_ncols(A)), inc(1); \
656  base_type alpha(1); \
657  if (m && n) \
658  blas_name(&m, &n, &alpha, &V[0], &inc, &W[0], &inc, &A(0,0), &lda); \
659  }
660 
661  ger_interface(sger_, BLAS_S)
662  ger_interface(dger_, BLAS_D)
663  ger_interface(cgerc_, BLAS_C)
664  ger_interface(zgerc_, BLAS_Z)
665 
666 # define ger_interface_sn(blas_name, base_type) \
667  inline void rank_one_update(const dense_matrix<base_type> &A, \
668  gemv_p2_s(base_type), \
669  const std::vector<base_type> &W) { \
670  GMMLAPACK_TRACE("ger_interface"); \
671  gemv_trans2_s(base_type); \
672  const BLAS_INT m=BLAS_INT(mat_nrows(A)), lda(m), \
673  n=BLAS_INT(mat_ncols(A)), inc(1); \
674  if (m && n) \
675  blas_name(&m, &n, &alpha, &x[0], &inc, &W[0], &inc, &A(0,0), &lda); \
676  }
677 
678  ger_interface_sn(sger_, BLAS_S)
679  ger_interface_sn(dger_, BLAS_D)
680  ger_interface_sn(cgerc_, BLAS_C)
681  ger_interface_sn(zgerc_, BLAS_Z)
682 
683 # define ger_interface_ns(blas_name, base_type) \
684  inline void rank_one_update(const dense_matrix<base_type> &A, \
685  const std::vector<base_type> &V, \
686  gemv_p2_s(base_type)) { \
687  GMMLAPACK_TRACE("ger_interface"); \
688  gemv_trans2_s(base_type); \
689  const BLAS_INT m=BLAS_INT(mat_nrows(A)), lda(m), \
690  n=BLAS_INT(mat_ncols(A)), inc(1); \
691  base_type al2 = gmm::conj(alpha); \
692  if (m && n) \
693  blas_name(&m, &n, &al2, &V[0], &inc, &x[0], &inc, &A(0,0), &lda); \
694  }
695 
696  ger_interface_ns(sger_, BLAS_S)
697  ger_interface_ns(dger_, BLAS_D)
698  ger_interface_ns(cgerc_, BLAS_C)
699  ger_interface_ns(zgerc_, BLAS_Z)
700 
701  /* ********************************************************************* */
702  /* dense matrix x dense matrix multiplication. */
703  /* ********************************************************************* */
704 
705 # define gemm_interface_nn(blas_name, base_type) \
706  inline void mult_spec(const dense_matrix<base_type> &A, \
707  const dense_matrix<base_type> &B, \
708  dense_matrix<base_type> &C, c_mult) { \
709  GMMLAPACK_TRACE("gemm_interface_nn"); \
710  const char t = 'N'; \
711  const BLAS_INT m=BLAS_INT(mat_nrows(A)), lda(m), \
712  k=BLAS_INT(mat_ncols(A)), ldb(k), \
713  n=BLAS_INT(mat_ncols(B)), ldc(m); \
714  const base_type alpha(1), beta(0); \
715  if (m && k && n) \
716  blas_name(&t, &t, &m, &n, &k, &alpha, \
717  &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
718  else gmm::clear(C); \
719  }
720 
721  gemm_interface_nn(sgemm_, BLAS_S)
722  gemm_interface_nn(dgemm_, BLAS_D)
723  gemm_interface_nn(cgemm_, BLAS_C)
724  gemm_interface_nn(zgemm_, BLAS_Z)
725 
726  /* ********************************************************************* */
727  /* transposed(dense matrix) x dense matrix product through BLAS xGEMM. */
728  /* ********************************************************************* */
     /* A_ is a transposed view; the stored matrix behind it is fetched with */
     /* linalg_origin and handed to BLAS with the 'T' flag, B with 'N'.      */
     /* The result has as many rows as the stored matrix has columns. When   */
     /* any dimension is zero, C is cleared and BLAS is not called.          */
     /* mat_type selects the const / non-const flavour of the viewed matrix. */
729 
730 # define gemm_interface_tn(blas_name, base_type, mat_type) \
731  inline void mult_spec( \
732  const transposed_col_ref<mat_type<base_type> *> &A_, \
733  const dense_matrix<base_type> &B, \
734  dense_matrix<base_type> &C, rcmult) { \
735  GMMLAPACK_TRACE("gemm_interface_tn"); \
736  const dense_matrix<base_type> &stored_a = *(linalg_origin(A_)); \
737  const char op_a = 'T', op_b = 'N'; \
738  const BLAS_INT inner = BLAS_INT(mat_nrows(stored_a)), \
739  rows = BLAS_INT(mat_ncols(stored_a)), cols = BLAS_INT(mat_ncols(B)); \
740  const BLAS_INT ld_a = inner, ld_b = inner, ld_c = rows; \
741  const base_type scale_ab(1), scale_c(0); \
742  if (!(rows && inner && cols)) { gmm::clear(C); return; } \
743  blas_name(&op_a, &op_b, &rows, &cols, &inner, &scale_ab, \
744  &stored_a(0,0), &ld_a, &B(0,0), &ld_b, &scale_c, &C(0,0), &ld_c); \
745  }
746 
747  gemm_interface_tn(sgemm_, BLAS_S, dense_matrix)
748  gemm_interface_tn(dgemm_, BLAS_D, dense_matrix)
749  gemm_interface_tn(cgemm_, BLAS_C, dense_matrix)
750  gemm_interface_tn(zgemm_, BLAS_Z, dense_matrix)
751  gemm_interface_tn(sgemm_, BLAS_S, const dense_matrix)
752  gemm_interface_tn(dgemm_, BLAS_D, const dense_matrix)
753  gemm_interface_tn(cgemm_, BLAS_C, const dense_matrix)
754  gemm_interface_tn(zgemm_, BLAS_Z, const dense_matrix)
755 
756  /* ********************************************************************* */
757  /* dense matrix x transposed(dense matrix) product through BLAS xGEMM. */
758  /* ********************************************************************* */
     /* B_ is a transposed view; the stored matrix behind it is fetched with */
     /* linalg_origin and handed to BLAS with the 'T' flag, A with 'N'.      */
     /* The result has as many columns as the stored matrix has rows. When   */
     /* any dimension is zero, C is cleared and BLAS is not called.          */
     /* mat_type selects the const / non-const flavour of the viewed matrix. */
759 
760 # define gemm_interface_nt(blas_name, base_type, mat_type) \
761  inline void \
762  mult_spec(const dense_matrix<base_type> &A, \
763  const transposed_col_ref<mat_type<base_type> *> &B_, \
764  dense_matrix<base_type> &C, r_mult) { \
765  GMMLAPACK_TRACE("gemm_interface_nt"); \
766  const dense_matrix<base_type> &stored_b = *(linalg_origin(B_)); \
767  const char op_a = 'N', op_b = 'T'; \
768  const BLAS_INT rows = BLAS_INT(mat_nrows(A)), \
769  inner = BLAS_INT(mat_ncols(A)), cols = BLAS_INT(mat_nrows(stored_b)); \
770  const BLAS_INT ld_a = rows, ld_b = cols, ld_c = rows; \
771  const base_type scale_ab(1), scale_c(0); \
772  if (!(rows && inner && cols)) { gmm::clear(C); return; } \
773  blas_name(&op_a, &op_b, &rows, &cols, &inner, &scale_ab, \
774  &A(0,0), &ld_a, &stored_b(0,0), &ld_b, &scale_c, &C(0,0), &ld_c); \
775  }
776 
777  gemm_interface_nt(sgemm_, BLAS_S, dense_matrix)
778  gemm_interface_nt(dgemm_, BLAS_D, dense_matrix)
779  gemm_interface_nt(cgemm_, BLAS_C, dense_matrix)
780  gemm_interface_nt(zgemm_, BLAS_Z, dense_matrix)
781  gemm_interface_nt(sgemm_, BLAS_S, const dense_matrix)
782  gemm_interface_nt(dgemm_, BLAS_D, const dense_matrix)
783  gemm_interface_nt(cgemm_, BLAS_C, const dense_matrix)
784  gemm_interface_nt(zgemm_, BLAS_Z, const dense_matrix)
785 
786  /* ********************************************************************* */
787  /* transposed(dense matrix) x transposed(dense matrix) multiplication. */
788  /* ********************************************************************* */
     /* Both operands are transposed views; the stored matrices behind them  */
     /* are fetched with linalg_origin and handed to BLAS xGEMM with the     */
     /* flags ('T','T'). m is the column count of stored A, k its row count, */
     /* n the row count of stored B. The scale factors are 1 for the product */
     /* and 0 for the previous content of C; they are only read, so they are */
     /* declared const like in the other gemm wrappers of this file. When    */
     /* any dimension is zero, C is cleared and BLAS is not called.          */
     /* matA_type / matB_type select the const / non-const flavour of each   */
     /* viewed matrix.                                                       */
789 
790 # define gemm_interface_tt(blas_name, base_type, matA_type, matB_type) \
791  inline void \
792  mult_spec(const transposed_col_ref<matA_type<base_type> *> &A_, \
793  const transposed_col_ref<matB_type<base_type> *> &B_, \
794  dense_matrix<base_type> &C, r_mult) { \
795  GMMLAPACK_TRACE("gemm_interface_tt"); \
796  const dense_matrix<base_type> &A = *(linalg_origin(A_)); \
797  const dense_matrix<base_type> &B = *(linalg_origin(B_)); \
798  const char t = 'T', u = 'T'; \
799  const BLAS_INT m=BLAS_INT(mat_ncols(A)), k=BLAS_INT(mat_nrows(A)), \
800  n=BLAS_INT(mat_nrows(B)), lda(k), ldb(n), ldc(m); \
801  const base_type alpha(1), beta(0); \
802  if (m && k && n) \
803  blas_name(&t, &u, &m, &n, &k, &alpha, \
804  &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
805  else gmm::clear(C); \
806  }
807 
808  gemm_interface_tt(sgemm_, BLAS_S, dense_matrix, dense_matrix)
809  gemm_interface_tt(dgemm_, BLAS_D, dense_matrix, dense_matrix)
810  gemm_interface_tt(cgemm_, BLAS_C, dense_matrix, dense_matrix)
811  gemm_interface_tt(zgemm_, BLAS_Z, dense_matrix, dense_matrix)
812  gemm_interface_tt(sgemm_, BLAS_S, const dense_matrix, dense_matrix)
813  gemm_interface_tt(dgemm_, BLAS_D, const dense_matrix, dense_matrix)
814  gemm_interface_tt(cgemm_, BLAS_C, const dense_matrix, dense_matrix)
815  gemm_interface_tt(zgemm_, BLAS_Z, const dense_matrix, dense_matrix)
816  gemm_interface_tt(sgemm_, BLAS_S, dense_matrix, const dense_matrix)
817  gemm_interface_tt(dgemm_, BLAS_D, dense_matrix, const dense_matrix)
818  gemm_interface_tt(cgemm_, BLAS_C, dense_matrix, const dense_matrix)
819  gemm_interface_tt(zgemm_, BLAS_Z, dense_matrix, const dense_matrix)
820  gemm_interface_tt(sgemm_, BLAS_S, const dense_matrix, const dense_matrix)
821  gemm_interface_tt(dgemm_, BLAS_D, const dense_matrix, const dense_matrix)
822  gemm_interface_tt(cgemm_, BLAS_C, const dense_matrix, const dense_matrix)
823  gemm_interface_tt(zgemm_, BLAS_Z, const dense_matrix, const dense_matrix)
824 
825 
826  /* ********************************************************************* */
827  /* conjugated(dense matrix) x dense matrix product through BLAS xGEMM. */
828  /* ********************************************************************* */
     /* A_ is a conjugated view; the stored matrix behind it is fetched with */
     /* linalg_origin and handed to BLAS with the 'C' flag, B with 'N'.      */
     /* The result has as many rows as the stored matrix has columns. When   */
     /* any dimension is zero, C is cleared and BLAS is not called.          */
829 
830 # define gemm_interface_cn(blas_name, base_type) \
831  inline void mult_spec( \
832  const conjugated_col_matrix_const_ref<dense_matrix<base_type> > &A_, \
833  const dense_matrix<base_type> &B, \
834  dense_matrix<base_type> &C, rcmult) { \
835  GMMLAPACK_TRACE("gemm_interface_cn"); \
836  const dense_matrix<base_type> &stored_a = *(linalg_origin(A_)); \
837  const char op_a = 'C', op_b = 'N'; \
838  const BLAS_INT inner = BLAS_INT(mat_nrows(stored_a)), \
839  rows = BLAS_INT(mat_ncols(stored_a)), cols = BLAS_INT(mat_ncols(B)); \
840  const BLAS_INT ld_a = inner, ld_b = inner, ld_c = rows; \
841  const base_type scale_ab(1), scale_c(0); \
842  if (!(rows && inner && cols)) { gmm::clear(C); return; } \
843  blas_name(&op_a, &op_b, &rows, &cols, &inner, &scale_ab, \
844  &stored_a(0,0), &ld_a, &B(0,0), &ld_b, &scale_c, &C(0,0), &ld_c); \
845  }
846 
847  gemm_interface_cn(sgemm_, BLAS_S)
848  gemm_interface_cn(dgemm_, BLAS_D)
849  gemm_interface_cn(cgemm_, BLAS_C)
850  gemm_interface_cn(zgemm_, BLAS_Z)
851 
852  /* ********************************************************************* */
853  /* dense matrix x conjugated(dense matrix) product through BLAS xGEMM. */
854  /* ********************************************************************* */
     /* B_ is a conjugated view; the stored matrix behind it is fetched with */
     /* linalg_origin and handed to BLAS with the 'C' flag, A with 'N'.      */
     /* The result has as many columns as the stored matrix has rows. When   */
     /* any dimension is zero, C is cleared and BLAS is not called.          */
     /* NOTE(review): this overload takes two dispatch tags (c_mult,         */
     /* row_major) where the sibling wrappers take one; confirm against the  */
     /* generic mult_spec dispatcher before changing the signature.          */
855 
856 # define gemm_interface_nc(blas_name, base_type) \
857  inline void mult_spec( \
858  const dense_matrix<base_type> &A, \
859  const conjugated_col_matrix_const_ref<dense_matrix<base_type> > &B_, \
860  dense_matrix<base_type> &C, c_mult, row_major) { \
861  GMMLAPACK_TRACE("gemm_interface_nc"); \
862  const dense_matrix<base_type> &stored_b = *(linalg_origin(B_)); \
863  const char op_a = 'N', op_b = 'C'; \
864  const BLAS_INT rows = BLAS_INT(mat_nrows(A)), \
865  inner = BLAS_INT(mat_ncols(A)), cols = BLAS_INT(mat_nrows(stored_b)); \
866  const BLAS_INT ld_a = rows, ld_b = cols, ld_c = rows; \
867  const base_type scale_ab(1), scale_c(0); \
868  if (!(rows && inner && cols)) { gmm::clear(C); return; } \
869  blas_name(&op_a, &op_b, &rows, &cols, &inner, &scale_ab, \
870  &A(0,0), &ld_a, &stored_b(0,0), &ld_b, &scale_c, &C(0,0), &ld_c); \
871  }
872 
873  gemm_interface_nc(sgemm_, BLAS_S)
874  gemm_interface_nc(dgemm_, BLAS_D)
875  gemm_interface_nc(cgemm_, BLAS_C)
876  gemm_interface_nc(zgemm_, BLAS_Z)
877 
878  /* ********************************************************************* */
879  /* conjugated(dense matrix) x conjugated(dense matrix) product (xGEMM). */
880  /* ********************************************************************* */
     /* Both operands are conjugated views; the stored matrices behind them  */
     /* are fetched with linalg_origin and handed to BLAS with the flags     */
     /* ('C','C'). The result has as many rows as stored A has columns and   */
     /* as many columns as stored B has rows. When any dimension is zero, C  */
     /* is cleared and BLAS is not called.                                   */
881 
882 # define gemm_interface_cc(blas_name, base_type) \
883  inline void mult_spec( \
884  const conjugated_col_matrix_const_ref<dense_matrix<base_type> > &A_, \
885  const conjugated_col_matrix_const_ref<dense_matrix<base_type> > &B_, \
886  dense_matrix<base_type> &C, r_mult) { \
887  GMMLAPACK_TRACE("gemm_interface_cc"); \
888  const dense_matrix<base_type> &stored_a = *(linalg_origin(A_)); \
889  const dense_matrix<base_type> &stored_b = *(linalg_origin(B_)); \
890  const char op_a = 'C', op_b = 'C'; \
891  const BLAS_INT inner = BLAS_INT(mat_nrows(stored_a)), \
892  rows = BLAS_INT(mat_ncols(stored_a)), cols = BLAS_INT(mat_nrows(stored_b)); \
893  const BLAS_INT ld_a = inner, ld_b = cols, ld_c = rows; \
894  const base_type scale_ab(1), scale_c(0); \
895  if (!(rows && inner && cols)) { gmm::clear(C); return; } \
896  blas_name(&op_a, &op_b, &rows, &cols, &inner, &scale_ab, \
897  &stored_a(0,0), &ld_a, &stored_b(0,0), &ld_b, &scale_c, &C(0,0), &ld_c); \
898  }
899 
900  gemm_interface_cc(sgemm_, BLAS_S)
901  gemm_interface_cc(dgemm_, BLAS_D)
902  gemm_interface_cc(cgemm_, BLAS_C)
903  gemm_interface_cc(zgemm_, BLAS_Z)
904 
905  /* ********************************************************************* */
906  /* Triangular solve through BLAS xTRSV; x is overwritten by the result. */
907  /* ********************************************************************* */
     /* trsv_interface(f_name, LorU, param1, trans1, blas_name, base_type)   */
     /* generates                                                            */
     /*   f_name(<matrix operand>, std::vector<base_type> &x,                */
     /*          size_type k, bool is_unit)                                  */
     /* where:                                                               */
     /*  - param1(base_type) declares the matrix parameter and               */
     /*    trans1(base_type) is expanded as a statement in the body; both    */
     /*    helpers are defined outside this part of the file. The names A    */
     /*    and t used below are not declared in the visible lines, so they   */
     /*    presumably come from those helpers.                               */
     /*  - LorU is the triangle flag passed to BLAS ('L' or 'U').            */
     /*  - is_unit selects the diagonal flag: 'U' when true, 'N' otherwise.  */
     /*  - k is forwarded as the order n of the system; lda is the row count */
     /*    of A and the vector increment is 1.                               */
     /*  - BLAS is not called when A has zero rows. The size of x is not     */
     /*    checked here.                                                     */
908 
909 # define trsv_interface(f_name, LorU, param1, trans1, blas_name, base_type)\
910  inline void f_name(param1(base_type), std::vector<base_type> &x, \
911  size_type k, bool is_unit) { \
912  GMMLAPACK_TRACE("trsv_interface"); \
913  const char l = LorU; trans1(base_type); char d = is_unit ? 'U' : 'N'; \
914  const BLAS_INT lda=BLAS_INT(mat_nrows(A)), inc(1), n=BLAS_INT(k); \
915  if (lda) blas_name(&l, &t, &d, &n, &A(0,0), &lda, &x[0], &inc); \
916  }
917 
918  // X <- LOWER(A)^{-1}X.
919  trsv_interface(lower_tri_solve, 'L', gem_p1_n, gem_trans1_n, strsv_, BLAS_S)
920  trsv_interface(lower_tri_solve, 'L', gem_p1_n, gem_trans1_n, dtrsv_, BLAS_D)
921  trsv_interface(lower_tri_solve, 'L', gem_p1_n, gem_trans1_n, ctrsv_, BLAS_C)
922  trsv_interface(lower_tri_solve, 'L', gem_p1_n, gem_trans1_n, ztrsv_, BLAS_Z)
923 
924  // X <- UPPER(A)^{-1}X.
925  trsv_interface(upper_tri_solve, 'U', gem_p1_n, gem_trans1_n, strsv_, BLAS_S)
926  trsv_interface(upper_tri_solve, 'U', gem_p1_n, gem_trans1_n, dtrsv_, BLAS_D)
927  trsv_interface(upper_tri_solve, 'U', gem_p1_n, gem_trans1_n, ctrsv_, BLAS_C)
928  trsv_interface(upper_tri_solve, 'U', gem_p1_n, gem_trans1_n, ztrsv_, BLAS_Z)
929 
930  // X <- LOWER(transposed(A))^{-1}X.
     // The triangle flag is 'U' here: the lower triangle of the transposed
     // view is the upper triangle of the matrix it refers to.
931  trsv_interface(lower_tri_solve, 'U', gem_p1_t, gem_trans1_t, strsv_, BLAS_S)
932  trsv_interface(lower_tri_solve, 'U', gem_p1_t, gem_trans1_t, dtrsv_, BLAS_D)
933  trsv_interface(lower_tri_solve, 'U', gem_p1_t, gem_trans1_t, ctrsv_, BLAS_C)
934  trsv_interface(lower_tri_solve, 'U', gem_p1_t, gem_trans1_t, ztrsv_, BLAS_Z)
935 
936  // X <- UPPER(transposed(A))^{-1}X.
     // Triangle flag 'L' for the same reason, mirrored.
937  trsv_interface(upper_tri_solve, 'L', gem_p1_t, gem_trans1_t, strsv_, BLAS_S)
938  trsv_interface(upper_tri_solve, 'L', gem_p1_t, gem_trans1_t, dtrsv_, BLAS_D)
939  trsv_interface(upper_tri_solve, 'L', gem_p1_t, gem_trans1_t, ctrsv_, BLAS_C)
940  trsv_interface(upper_tri_solve, 'L', gem_p1_t, gem_trans1_t, ztrsv_, BLAS_Z)
941 
942  // X <- LOWER(transposed(const A))^{-1}X.
943  trsv_interface(lower_tri_solve, 'U', gem_p1_tc, gem_trans1_t, strsv_, BLAS_S)
944  trsv_interface(lower_tri_solve, 'U', gem_p1_tc, gem_trans1_t, dtrsv_, BLAS_D)
945  trsv_interface(lower_tri_solve, 'U', gem_p1_tc, gem_trans1_t, ctrsv_, BLAS_C)
946  trsv_interface(lower_tri_solve, 'U', gem_p1_tc, gem_trans1_t, ztrsv_, BLAS_Z)
947 
948  // X <- UPPER(transposed(const A))^{-1}X.
949  trsv_interface(upper_tri_solve, 'L', gem_p1_tc, gem_trans1_t, strsv_, BLAS_S)
950  trsv_interface(upper_tri_solve, 'L', gem_p1_tc, gem_trans1_t, dtrsv_, BLAS_D)
951  trsv_interface(upper_tri_solve, 'L', gem_p1_tc, gem_trans1_t, ctrsv_, BLAS_C)
952  trsv_interface(upper_tri_solve, 'L', gem_p1_tc, gem_trans1_t, ztrsv_, BLAS_Z)
953 
954  // X <- LOWER(conjugated(A))^{-1}X.
     // Triangle flags are swapped as for the transposed overloads;
     // gem_p1_c / gem_trans1_c are defined outside this part of the file.
955  trsv_interface(lower_tri_solve, 'U', gem_p1_c, gem_trans1_c, strsv_, BLAS_S)
956  trsv_interface(lower_tri_solve, 'U', gem_p1_c, gem_trans1_c, dtrsv_, BLAS_D)
957  trsv_interface(lower_tri_solve, 'U', gem_p1_c, gem_trans1_c, ctrsv_, BLAS_C)
958  trsv_interface(lower_tri_solve, 'U', gem_p1_c, gem_trans1_c, ztrsv_, BLAS_Z)
959 
960  // X <- UPPER(conjugated(A))^{-1}X.
961  trsv_interface(upper_tri_solve, 'L', gem_p1_c, gem_trans1_c, strsv_, BLAS_S)
962  trsv_interface(upper_tri_solve, 'L', gem_p1_c, gem_trans1_c, dtrsv_, BLAS_D)
963  trsv_interface(upper_tri_solve, 'L', gem_p1_c, gem_trans1_c, ctrsv_, BLAS_C)
964  trsv_interface(upper_tri_solve, 'L', gem_p1_c, gem_trans1_c, ztrsv_, BLAS_Z)
965 }
966 
967 #endif // GMM_BLAS_INTERFACE_H
968 
969 #endif // GMM_USES_BLAS || GMM_USES_LAPACK
Basic linear algebra functions.
gmm interface for STL vectors.
Declaration of some matrix types (gmm::dense_matrix, gmm::row_matrix, gmm::col_matrix,...
size_t size_type
used as the common size type in the library
Definition: bgeot_poly.h:49