38 #if defined(GMM_USES_BLAS) || defined(GMM_USES_LAPACK)
40 #ifndef GMM_BLAS_INTERFACE_H
41 #define GMM_BLAS_INTERFACE_H
51 #define GMMLAPACK_TRACE(f)
54 #if defined(WeirdNEC) || defined(GMM_USE_BLAS64_INTERFACE)
/* Scalar type aliases used throughout this interface. The suffix letter is
   the one that prefixes the corresponding BLAS routine names
   (S: float, D: double, C: complex<float>, Z: complex<double>). */
150 # define BLAS_S float
151 # define BLAS_D double
152 # define BLAS_C std::complex<float>
153 # define BLAS_Z std::complex<double>
/* BLAS_CPLX_FUNC_CALL(blasname, res, ...) hides how a complex-valued BLAS
   function delivers its result:
     - with GMM_BLAS_RETURN_COMPLEX_AS_ARGUMENT defined, the address of res
       is passed as an extra first argument;
     - with the other definition, res is assigned from the return value.
   NOTE(review): the directives separating and closing the two definitions
   are not visible in this excerpt. */
156 #if defined(GMM_BLAS_RETURN_COMPLEX_AS_ARGUMENT)
157 # define BLAS_CPLX_FUNC_CALL(blasname, res, ...) blasname(&res, __VA_ARGS__)
159 # define BLAS_CPLX_FUNC_CALL(blasname, res, ...) res = blasname(__VA_ARGS__)
/* Prototypes of the BLAS routines called by the interface macros below
   (every name carries a trailing underscore).
   Only daxpy_ and dgemm_ are declared with a full parameter list, in which
   each argument is passed by pointer. All other routines are declared with
   an unchecked (...) parameter list, so the compiler does not verify the
   arguments supplied at their call sites.
   NOTE(review): the enclosing linkage specification is not visible in this
   excerpt. */
166 void daxpy_(
const BLAS_INT *n,
const double *alpha,
const double *x,
167 const BLAS_INT *incx,
double *y,
const BLAS_INT *incy);
168 void saxpy_(...);
void caxpy_(...);
void zaxpy_(...);
169 void dgemm_(
const char *tA,
const char *tB,
const BLAS_INT *m,
170 const BLAS_INT *n,
const BLAS_INT *k,
const BLAS_D *alpha,
171 const BLAS_D *A,
const BLAS_INT *ldA,
const BLAS_D *B,
172 const BLAS_INT *ldB,
const BLAS_D *beta, BLAS_D *C,
173 const BLAS_INT *ldC);
174 void sgemm_(...);
void cgemm_(...);
void zgemm_(...);
175 void sgemv_(...);
void dgemv_(...);
void cgemv_(...);
void zgemv_(...);
176 void strsv_(...);
void dtrsv_(...);
void ctrsv_(...);
void ztrsv_(...);
177 BLAS_S sdot_ (...); BLAS_D ddot_ (...);
178 BLAS_C cdotu_(...); BLAS_Z zdotu_(...);
180 BLAS_C cdotc_(...); BLAS_Z zdotc_(...);
181 BLAS_S snrm2_(...); BLAS_D dnrm2_(...);
182 BLAS_S scnrm2_(...); BLAS_D dznrm2_(...);
183 void sger_(...);
void dger_(...);
void cgerc_(...);
void zgerc_(...);
/* nrm2_interface(blas_name, base_type)
   Generates a vect_norm2 overload for std::vector<base_type> that forwards
   to blas_name with the vector length converted to BLAS_INT and a stride
   of 1. The return type is number_traits<base_type>::magnitude_type.
   Instantiated below for the four scalar types; the complex instantiations
   use scnrm2_/dznrm2_, which are declared above as returning BLAS_S/BLAS_D.
   NOTE(review): &x[0] is evaluated without any visible guard for an empty
   vector. */
191 # define nrm2_interface(blas_name, base_type) \
192 inline number_traits<base_type>::magnitude_type \
193 vect_norm2(const std::vector<base_type> &x) { \
194 GMMLAPACK_TRACE("nrm2_interface"); \
195 const BLAS_INT n=BLAS_INT(vect_size(x)), inc(1); \
196 return blas_name(&n, &x[0], &inc); \
199 nrm2_interface(snrm2_, BLAS_S)
200 nrm2_interface(dnrm2_, BLAS_D)
201 nrm2_interface(scnrm2_, BLAS_C)
202 nrm2_interface(dznrm2_, BLAS_Z)
/* dot_interface(funcname, msg, blas_name, base_type)
   Generates four overloads of funcname(x, y) for real scalar types, one for
   each combination of plain std::vector and scaled_vector_const_ref
   operands. A scaled operand is unwrapped with linalg_origin and its scale
   factor is multiplied onto the value returned by blas_name, so the BLAS
   routine always works on the unscaled storage with stride 1.
   The length passed to BLAS is taken from y only.
   NOTE(review): no size comparison between x and y is visible here, and the
   declarations of `a` and `b` in the second and third overloads are not
   visible in this excerpt.
   msg is the string handed to GMMLAPACK_TRACE.
   For real types both vect_sp and vect_hp are instantiated on the same
   routines (sdot_/ddot_). */
208 # define dot_interface(funcname, msg, blas_name, base_type) \
209 inline base_type funcname(const std::vector<base_type> &x, \
210 const std::vector<base_type> &y) { \
211 GMMLAPACK_TRACE(msg); \
212 const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); \
213 return blas_name(&n, &x[0], &inc, &y[0], &inc); \
215 inline base_type funcname \
216 (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_, \
217 const std::vector<base_type> &y) { \
218 GMMLAPACK_TRACE(msg); \
219 const std::vector<base_type> &x = *(linalg_origin(x_)); \
221 const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); \
222 return a * blas_name(&n, &x[0], &inc, &y[0], &inc); \
224 inline base_type funcname \
225 (const std::vector<base_type> &x, \
226 const scaled_vector_const_ref<std::vector<base_type>,base_type> &y_) {\
227 GMMLAPACK_TRACE(msg); \
228 const std::vector<base_type> &y = *(linalg_origin(y_)); \
230 const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); \
231 return b * blas_name(&n, &x[0], &inc, &y[0], &inc); \
233 inline base_type funcname \
234 (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_, \
235 const scaled_vector_const_ref<std::vector<base_type>,base_type> &y_) {\
236 GMMLAPACK_TRACE(msg); \
237 const std::vector<base_type> &x = *(linalg_origin(x_)); \
238 const std::vector<base_type> &y = *(linalg_origin(y_)); \
239 base_type a(x_.r), b(y_.r); \
240 const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); \
241 return a*b * blas_name(&n, &x[0], &inc, &y[0], &inc); \
244 dot_interface(vect_sp,
"dot_interface", sdot_, BLAS_S)
245 dot_interface(vect_sp,
"dot_interface", ddot_, BLAS_D)
246 dot_interface(vect_hp,
"dotc_interface", sdot_, BLAS_S)
247 dot_interface(vect_hp,
"dotc_interface", ddot_, BLAS_D)
/* dot_interface_cplx(funcname, msg, blas_name, base_type, bdef)
   Complex counterpart of dot_interface: generates four overloads of
   funcname(x, y) for plain / scaled std::vector operands.
   The BLAS routine is invoked through BLAS_CPLX_FUNC_CALL, which stores the
   value in the local `res`. Note the operand order of the call: y is passed
   first and x second.
   bdef is the expression used to initialise the factor `b` of a scaled y:
   the instantiations below pass y_.r for vect_sp (cdotu_/zdotu_) and
   gmm::conj(y_.r) for vect_hp (cdotc_/zdotc_).
   NOTE(review): the statements that combine res with a / b and return it,
   and the declaration of res in the first overload, are not visible in this
   excerpt. */
255 # define dot_interface_cplx(funcname, msg, blas_name, base_type, bdef) \
256 inline base_type funcname(const std::vector<base_type> &x, \
257 const std::vector<base_type> &y) { \
258 GMMLAPACK_TRACE(msg); \
260 const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); \
261 BLAS_CPLX_FUNC_CALL(blas_name, res, &n, &y[0], &inc, &x[0], &inc); \
264 inline base_type funcname \
265 (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_, \
266 const std::vector<base_type> &y) { \
267 GMMLAPACK_TRACE(msg); \
268 const std::vector<base_type> &x = *(linalg_origin(x_)); \
269 base_type res, a(x_.r); \
270 const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); \
271 BLAS_CPLX_FUNC_CALL(blas_name, res, &n, &y[0], &inc, &x[0], &inc); \
274 inline base_type funcname \
275 (const std::vector<base_type> &x, \
276 const scaled_vector_const_ref<std::vector<base_type>,base_type> &y_) {\
277 GMMLAPACK_TRACE(msg); \
278 const std::vector<base_type> &y = *(linalg_origin(y_)); \
279 base_type res, b(bdef); \
280 const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); \
281 BLAS_CPLX_FUNC_CALL(blas_name, res, &n, &y[0], &inc, &x[0], &inc); \
284 inline base_type funcname \
285 (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_, \
286 const scaled_vector_const_ref<std::vector<base_type>,base_type> &y_) {\
287 GMMLAPACK_TRACE(msg); \
288 const std::vector<base_type> &x = *(linalg_origin(x_)); \
289 const std::vector<base_type> &y = *(linalg_origin(y_)); \
290 base_type res, a(x_.r), b(bdef); \
291 const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); \
292 BLAS_CPLX_FUNC_CALL(blas_name, res, &n, &y[0], &inc, &x[0], &inc); \
296 dot_interface_cplx(vect_sp,
"dot_interface", cdotu_, BLAS_C, y_.r)
297 dot_interface_cplx(vect_sp, "dot_interface", zdotu_, BLAS_Z, y_.r)
298 dot_interface_cplx(vect_hp, "dotc_interface", cdotc_, BLAS_C, gmm::conj(y_.r))
299 dot_interface_cplx(vect_hp, "dotc_interface", zdotc_, BLAS_Z, gmm::conj(y_.r))
305 template<
size_type N, class V1, class V2>
306 inline
void add_fixed(const V1 &x, V2 &y)
308 for(
size_type i = 0; i != N; ++i) y[i] += x[i];
311 template<
class V1,
class V2>
312 inline void add_for_short_vectors(
const V1 &x, V2 &y,
size_type n)
316 case 1: add_fixed<1>(x, y);
break;
317 case 2: add_fixed<2>(x, y);
break;
318 case 3: add_fixed<3>(x, y);
break;
319 case 4: add_fixed<4>(x, y);
break;
320 case 5: add_fixed<5>(x, y);
break;
321 case 6: add_fixed<6>(x, y);
break;
322 case 7: add_fixed<7>(x, y);
break;
323 case 8: add_fixed<8>(x, y);
break;
324 case 9: add_fixed<9>(x, y);
break;
325 case 10: add_fixed<10>(x, y);
break;
326 case 11: add_fixed<11>(x, y);
break;
327 case 12: add_fixed<12>(x, y);
break;
328 case 13: add_fixed<13>(x, y);
break;
329 case 14: add_fixed<14>(x, y);
break;
330 case 15: add_fixed<15>(x, y);
break;
331 case 16: add_fixed<16>(x, y);
break;
332 case 17: add_fixed<17>(x, y);
break;
333 case 18: add_fixed<18>(x, y);
break;
334 case 19: add_fixed<19>(x, y);
break;
335 case 20: add_fixed<20>(x, y);
break;
336 case 21: add_fixed<21>(x, y);
break;
337 case 22: add_fixed<22>(x, y);
break;
338 case 23: add_fixed<23>(x, y);
break;
339 case 24: add_fixed<24>(x, y);
break;
341 GMM_ASSERT2(
false,
"add_for_short_vectors used with unsupported size");
346 template<
size_type N,
class V1,
class V2,
class T>
347 inline void add_fixed(
const V1 &x, V2 &y,
const T &a)
349 for(
size_type i = 0; i != N; ++i) y[i] += a*x[i];
352 template<
class V1,
class V2,
class T>
353 inline void add_for_short_vectors(
const V1 &x, V2 &y,
const T &a,
size_type n)
357 case 1: add_fixed<1>(x, y, a);
break;
358 case 2: add_fixed<2>(x, y, a);
break;
359 case 3: add_fixed<3>(x, y, a);
break;
360 case 4: add_fixed<4>(x, y, a);
break;
361 case 5: add_fixed<5>(x, y, a);
break;
362 case 6: add_fixed<6>(x, y, a);
break;
363 case 7: add_fixed<7>(x, y, a);
break;
364 case 8: add_fixed<8>(x, y, a);
break;
365 case 9: add_fixed<9>(x, y, a);
break;
366 case 10: add_fixed<10>(x, y, a);
break;
367 case 11: add_fixed<11>(x, y, a);
break;
368 case 12: add_fixed<12>(x, y, a);
break;
369 case 13: add_fixed<13>(x, y, a);
break;
370 case 14: add_fixed<14>(x, y, a);
break;
371 case 15: add_fixed<15>(x, y, a);
break;
372 case 16: add_fixed<16>(x, y, a);
break;
373 case 17: add_fixed<17>(x, y, a);
break;
374 case 18: add_fixed<18>(x, y, a);
break;
375 case 19: add_fixed<19>(x, y, a);
break;
376 case 20: add_fixed<20>(x, y, a);
break;
377 case 21: add_fixed<21>(x, y, a);
break;
378 case 22: add_fixed<22>(x, y, a);
break;
379 case 23: add_fixed<23>(x, y, a);
break;
380 case 24: add_fixed<24>(x, y, a);
break;
382 GMM_ASSERT2(
false,
"add_for_short_vectors used with unsupported size");
/* axpy_interface(blas_name, base_type)
   Generates add(x, y) for two std::vector<base_type>, i.e. y += x.
   The length is taken from y. Three paths:
     - n == 0 : nothing to do;
     - n < 25 : add_for_short_vectors(x, y, n), no BLAS call;
     - else   : blas_name with factor a = 1 and unit strides. */
388 # define axpy_interface(blas_name, base_type) \
389 inline void add(const std::vector<base_type> &x, \
390 std::vector<base_type> &y) { \
391 GMMLAPACK_TRACE("axpy_interface"); \
392 const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); base_type a(1); \
393 if (n == 0) return; \
394 else if (n < 25) add_for_short_vectors(x, y, n); \
395 else blas_name(&n, &a, &x[0], &inc, &y[0], &inc); \
398 axpy_interface(saxpy_, BLAS_S)
399 axpy_interface(daxpy_, BLAS_D)
400 axpy_interface(caxpy_, BLAS_C)
401 axpy_interface(zaxpy_, BLAS_Z)
/* axpy2_interface(blas_name, base_type)
   Variant of axpy_interface whose first operand is a
   scaled_vector_const_ref: the underlying vector is obtained with
   linalg_origin and the scale factor x_.r is used as the factor a, giving
   y += a * x. Same three paths as axpy_interface (n == 0, n < 25, BLAS).
   The trace string is "axpy_interface", identical to the unscaled variant.
   NOTE(review): the line carrying the generated function's name and return
   type is not visible in this excerpt. */
404 # define axpy2_interface(blas_name, base_type) \
406 (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_, \
407 std::vector<base_type> &y) { \
408 GMMLAPACK_TRACE("axpy_interface"); \
409 const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); \
410 const std::vector<base_type>& x = *(linalg_origin(x_)); \
411 const base_type a(x_.r); \
412 if (n == 0) return; \
413 else if (n < 25) add_for_short_vectors(x, y, a, n); \
414 else blas_name(&n, &a, &x[0], &inc, &y[0], &inc); \
417 axpy2_interface(saxpy_, BLAS_S)
418 axpy2_interface(daxpy_, BLAS_D)
419 axpy2_interface(caxpy_, BLAS_C)
420 axpy2_interface(zaxpy_, BLAS_Z)
/* gemv_interface(param1, trans1, param2, trans2, blas_name, base_type, orien)
   Generates a mult_add_spec overload computing z += op(A) * (alpha * x)
   through the BLAS gemv routine blas_name (called with beta = 1).
     param1 / trans1 : macros producing the matrix parameter declaration and
                       the statements that define the dense matrix `A` and
                       the transposition flag `t` in the body;
     param2 / trans2 : same for the vector operand `x` and the factor `alpha`;
     base_type       : scalar type of matrix and vectors;
     orien           : tag type used to select this overload.
   The dimensions and the leading dimension are those of the stored matrix A.
   When A has no rows or no columns the product contributes nothing, so z is
   deliberately left untouched: this is an accumulation, and clearing z in
   that case would discard the values already stored in it. */
# define gemv_interface(param1, trans1, param2, trans2, blas_name,        \
                        base_type, orien)                                 \
  inline void mult_add_spec(param1(base_type), param2(base_type),         \
                            std::vector<base_type> &z, orien) {           \
    GMMLAPACK_TRACE("gemv_interface");                                    \
    trans1(base_type); trans2(base_type); base_type beta(1);              \
    const BLAS_INT m=BLAS_INT(mat_nrows(A)), lda(m),                      \
                   n=BLAS_INT(mat_ncols(A)), inc(1);                      \
    /* empty product: z += 0, nothing to do */                            \
    if (m && n) blas_name(&t, &m, &n, &alpha, &A(0,0), &lda, &x[0], &inc, \
                          &beta, &z[0], &inc);                            \
  }
/* Building blocks plugged into the gemv / ger / trsv interface macros.
     gem_p1_*      : declaration of the matrix parameter
                       _n  : const dense_matrix<base_type> &A
                       _t  : transposed_col_ref on a pointer to dense_matrix
                       _tc : transposed_col_ref on a pointer to const
                             dense_matrix
                       _c  : conjugated_col_matrix_const_ref of a dense_matrix
     gem_trans1_*  : body statements that make a dense matrix reference `A`
                     available (via linalg_origin for the wrapped forms) and
                     define the flag `t` ('N' for the plain form).
     gemv_p2_*     : declaration of the vector parameter
                       _n : plain std::vector named x
                       _s : scaled_vector_const_ref named x_
     gemv_trans2_* : body statements defining `x` and the factor `alpha`
                     (1 for the plain form, x_.r for the scaled form).
   NOTE(review): the continuation lines that end gem_trans1_t and
   gem_trans1_c are not visible in this excerpt. */
441 # define gem_p1_n(base_type) const dense_matrix<base_type> &A
442 # define gem_trans1_n(base_type) const char t = 'N'
443 # define gem_p1_t(base_type) \
444 const transposed_col_ref<dense_matrix<base_type> *> &A_
445 # define gem_trans1_t(base_type) const dense_matrix<base_type> &A = \
446 *(linalg_origin(A_)); \
448 # define gem_p1_tc(base_type) \
449 const transposed_col_ref<const dense_matrix<base_type> *> &A_
450 # define gem_p1_c(base_type) \
451 const conjugated_col_matrix_const_ref<dense_matrix<base_type> > &A_
452 # define gem_trans1_c(base_type) const dense_matrix<base_type> &A = \
453 *(linalg_origin(A_)); \
457 # define gemv_p2_n(base_type) const std::vector<base_type> &x
458 # define gemv_trans2_n(base_type) base_type alpha(1)
459 # define gemv_p2_s(base_type) \
460 const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_
461 # define gemv_trans2_s(base_type) const std::vector<base_type> &x = \
462 (*(linalg_origin(x_))); \
463 base_type alpha(x_.r)
/* mult_add_spec overloads generated with gemv_interface, for the four scalar
   types and every combination of matrix form and vector form. */
/* plain matrix, plain vector */
466 gemv_interface(gem_p1_n, gem_trans1_n,
467 gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, col_major)
468 gemv_interface(gem_p1_n, gem_trans1_n,
469 gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, col_major)
470 gemv_interface(gem_p1_n, gem_trans1_n,
471 gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, col_major)
472 gemv_interface(gem_p1_n, gem_trans1_n,
473 gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, col_major)
/* transposed matrix (pointer to non-const), plain vector */
476 gemv_interface(gem_p1_t, gem_trans1_t,
477 gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, row_major)
478 gemv_interface(gem_p1_t, gem_trans1_t,
479 gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, row_major)
480 gemv_interface(gem_p1_t, gem_trans1_t,
481 gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, row_major)
482 gemv_interface(gem_p1_t, gem_trans1_t,
483 gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, row_major)
/* transposed matrix (pointer to const), plain vector */
486 gemv_interface(gem_p1_tc, gem_trans1_t,
487 gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, row_major)
488 gemv_interface(gem_p1_tc, gem_trans1_t,
489 gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, row_major)
490 gemv_interface(gem_p1_tc, gem_trans1_t,
491 gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, row_major)
492 gemv_interface(gem_p1_tc, gem_trans1_t,
493 gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, row_major)
/* conjugated matrix, plain vector */
496 gemv_interface(gem_p1_c, gem_trans1_c,
497 gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, row_major)
498 gemv_interface(gem_p1_c, gem_trans1_c,
499 gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, row_major)
500 gemv_interface(gem_p1_c, gem_trans1_c,
501 gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, row_major)
502 gemv_interface(gem_p1_c, gem_trans1_c,
503 gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, row_major)
/* plain matrix, scaled vector */
506 gemv_interface(gem_p1_n, gem_trans1_n,
507 gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, col_major)
508 gemv_interface(gem_p1_n, gem_trans1_n,
509 gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, col_major)
510 gemv_interface(gem_p1_n, gem_trans1_n,
511 gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, col_major)
512 gemv_interface(gem_p1_n, gem_trans1_n,
513 gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, col_major)
/* transposed matrix (pointer to non-const), scaled vector */
516 gemv_interface(gem_p1_t, gem_trans1_t,
517 gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, row_major)
518 gemv_interface(gem_p1_t, gem_trans1_t,
519 gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, row_major)
520 gemv_interface(gem_p1_t, gem_trans1_t,
521 gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, row_major)
522 gemv_interface(gem_p1_t, gem_trans1_t,
523 gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, row_major)
/* transposed matrix (pointer to const), scaled vector */
526 gemv_interface(gem_p1_tc, gem_trans1_t,
527 gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, row_major)
528 gemv_interface(gem_p1_tc, gem_trans1_t,
529 gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, row_major)
530 gemv_interface(gem_p1_tc, gem_trans1_t,
531 gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, row_major)
532 gemv_interface(gem_p1_tc, gem_trans1_t,
533 gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, row_major)
/* conjugated matrix, scaled vector */
536 gemv_interface(gem_p1_c, gem_trans1_c,
537 gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, row_major)
538 gemv_interface(gem_p1_c, gem_trans1_c,
539 gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, row_major)
540 gemv_interface(gem_p1_c, gem_trans1_c,
541 gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, row_major)
542 gemv_interface(gem_p1_c, gem_trans1_c,
543 gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, row_major)
/* gemv_interface2(param1, trans1, param2, trans2, blas_name, ...)
   Same construction as gemv_interface but generates mult_spec, i.e. the
   overwriting product: beta is 0, so the previous content of z does not
   contribute to the result. In the fallback branch z is cleared, which is
   the value of an empty product.
   NOTE(review): the second line of the parameter list, the condition
   guarding the BLAS call and the end of the call are not visible in this
   excerpt. */
550 # define gemv_interface2(param1, trans1, param2, trans2, blas_name, \
552 inline void mult_spec(param1(base_type), param2(base_type), \
553 std::vector<base_type> &z, orien) { \
554 GMMLAPACK_TRACE("gemv_interface2"); \
555 trans1(base_type); trans2(base_type); base_type beta(0); \
556 const BLAS_INT m=BLAS_INT(mat_nrows(A)), lda(m), \
557 n=BLAS_INT(mat_ncols(A)), inc(1); \
559 blas_name(&t, &m, &n, &alpha, &A(0,0), &lda, &x[0], &inc, &beta, \
561 else gmm::clear(z); \
/* mult_spec overloads generated with gemv_interface2, for the four scalar
   types and every combination of matrix form and vector form. */
/* plain matrix, plain vector */
565 gemv_interface2(gem_p1_n, gem_trans1_n,
566 gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, col_major)
567 gemv_interface2(gem_p1_n, gem_trans1_n,
568 gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, col_major)
569 gemv_interface2(gem_p1_n, gem_trans1_n,
570 gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, col_major)
571 gemv_interface2(gem_p1_n, gem_trans1_n,
572 gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, col_major)
/* transposed matrix (pointer to non-const), plain vector */
575 gemv_interface2(gem_p1_t, gem_trans1_t,
576 gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, row_major)
577 gemv_interface2(gem_p1_t, gem_trans1_t,
578 gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, row_major)
579 gemv_interface2(gem_p1_t, gem_trans1_t,
580 gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, row_major)
581 gemv_interface2(gem_p1_t, gem_trans1_t,
582 gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, row_major)
/* transposed matrix (pointer to const), plain vector */
585 gemv_interface2(gem_p1_tc, gem_trans1_t,
586 gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, row_major)
587 gemv_interface2(gem_p1_tc, gem_trans1_t,
588 gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, row_major)
589 gemv_interface2(gem_p1_tc, gem_trans1_t,
590 gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, row_major)
591 gemv_interface2(gem_p1_tc, gem_trans1_t,
592 gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, row_major)
/* conjugated matrix, plain vector */
595 gemv_interface2(gem_p1_c, gem_trans1_c,
596 gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, row_major)
597 gemv_interface2(gem_p1_c, gem_trans1_c,
598 gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, row_major)
599 gemv_interface2(gem_p1_c, gem_trans1_c,
600 gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, row_major)
601 gemv_interface2(gem_p1_c, gem_trans1_c,
602 gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, row_major)
/* plain matrix, scaled vector */
605 gemv_interface2(gem_p1_n, gem_trans1_n,
606 gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, col_major)
607 gemv_interface2(gem_p1_n, gem_trans1_n,
608 gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, col_major)
609 gemv_interface2(gem_p1_n, gem_trans1_n,
610 gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, col_major)
611 gemv_interface2(gem_p1_n, gem_trans1_n,
612 gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, col_major)
/* transposed matrix (pointer to non-const), scaled vector */
615 gemv_interface2(gem_p1_t, gem_trans1_t,
616 gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, row_major)
617 gemv_interface2(gem_p1_t, gem_trans1_t,
618 gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, row_major)
619 gemv_interface2(gem_p1_t, gem_trans1_t,
620 gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, row_major)
621 gemv_interface2(gem_p1_t, gem_trans1_t,
622 gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, row_major)
/* transposed matrix (pointer to const), scaled vector */
625 gemv_interface2(gem_p1_tc, gem_trans1_t,
626 gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, row_major)
627 gemv_interface2(gem_p1_tc, gem_trans1_t,
628 gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, row_major)
629 gemv_interface2(gem_p1_tc, gem_trans1_t,
630 gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, row_major)
631 gemv_interface2(gem_p1_tc, gem_trans1_t,
632 gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, row_major)
/* conjugated matrix, scaled vector */
635 gemv_interface2(gem_p1_c, gem_trans1_c,
636 gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, row_major)
637 gemv_interface2(gem_p1_c, gem_trans1_c,
638 gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, row_major)
639 gemv_interface2(gem_p1_c, gem_trans1_c,
640 gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, row_major)
641 gemv_interface2(gem_p1_c, gem_trans1_c,
642 gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, row_major)
/* ger_interface(blas_name, base_type)
   Generates rank_one_update(A, V, W) for a dense matrix and two plain
   std::vector operands, forwarding to blas_name with alpha = 1, unit
   strides and the row count of A as leading dimension.
   The real instantiations use sger_/dger_, the complex ones cgerc_/zgerc_.
   NOTE(review): A is received by const reference although &A(0,0) is handed
   to the BLAS routine as the matrix argument; confirm this is intended.
   The statement preceding the BLAS call is not visible in this excerpt. */
649 # define ger_interface(blas_name, base_type) \
650 inline void rank_one_update(const dense_matrix<base_type> &A, \
651 const std::vector<base_type> &V, \
652 const std::vector<base_type> &W) { \
653 GMMLAPACK_TRACE("ger_interface"); \
654 const BLAS_INT m=BLAS_INT(mat_nrows(A)), lda(m), \
655 n=BLAS_INT(mat_ncols(A)), inc(1); \
656 base_type alpha(1); \
658 blas_name(&m, &n, &alpha, &V[0], &inc, &W[0], &inc, &A(0,0), &lda); \
661 ger_interface(sger_, BLAS_S)
662 ger_interface(dger_, BLAS_D)
663 ger_interface(cgerc_, BLAS_C)
664 ger_interface(zgerc_, BLAS_Z)
/* ger_interface_sn(blas_name, base_type)
   rank_one_update overload whose first vector is scaled: the parameter is
   declared through gemv_p2_s and unpacked by gemv_trans2_s, which provides
   the underlying vector `x` and the factor `alpha` (= x_.r) passed to
   blas_name. The second vector W is a plain std::vector.
   The trace string is "ger_interface", as in the unscaled variant. */
666 # define ger_interface_sn(blas_name, base_type) \
667 inline void rank_one_update(const dense_matrix<base_type> &A, \
668 gemv_p2_s(base_type), \
669 const std::vector<base_type> &W) { \
670 GMMLAPACK_TRACE("ger_interface"); \
671 gemv_trans2_s(base_type); \
672 const BLAS_INT m=BLAS_INT(mat_nrows(A)), lda(m), \
673 n=BLAS_INT(mat_ncols(A)), inc(1); \
675 blas_name(&m, &n, &alpha, &x[0], &inc, &W[0], &inc, &A(0,0), &lda); \
678 ger_interface_sn(sger_, BLAS_S)
679 ger_interface_sn(dger_, BLAS_D)
680 ger_interface_sn(cgerc_, BLAS_C)
681 ger_interface_sn(zgerc_, BLAS_Z)
/* ger_interface_ns(blas_name, base_type)
   rank_one_update overload whose second vector is scaled: the parameter is
   declared through gemv_p2_s and unpacked by gemv_trans2_s into `x` and
   `alpha` (= x_.r). The factor handed to blas_name is gmm::conj(alpha)
   rather than alpha itself, and x is passed in the second vector position.
   The trace string is "ger_interface", as in the unscaled variant. */
683 # define ger_interface_ns(blas_name, base_type) \
684 inline void rank_one_update(const dense_matrix<base_type> &A, \
685 const std::vector<base_type> &V, \
686 gemv_p2_s(base_type)) { \
687 GMMLAPACK_TRACE("ger_interface"); \
688 gemv_trans2_s(base_type); \
689 const BLAS_INT m=BLAS_INT(mat_nrows(A)), lda(m), \
690 n=BLAS_INT(mat_ncols(A)), inc(1); \
691 base_type al2 = gmm::conj(alpha); \
693 blas_name(&m, &n, &al2, &V[0], &inc, &x[0], &inc, &A(0,0), &lda); \
696 ger_interface_ns(sger_, BLAS_S)
697 ger_interface_ns(dger_, BLAS_D)
698 ger_interface_ns(cgerc_, BLAS_C)
699 ger_interface_ns(zgerc_, BLAS_Z)
/* gemm_interface_nn(blas_name, base_type)
   Generates mult_spec(A, B, C, c_mult) for three dense matrices, calling
   blas_name with both operation flags 'N', alpha = 1 and beta = 0.
   Dimensions: m = rows of A, k = columns of A, n = columns of B; the leading
   dimensions are m (A and C) and k (B).
   In the fallback branch C is cleared.
   NOTE(review): the condition guarding the BLAS call is not visible in this
   excerpt. */
705 # define gemm_interface_nn(blas_name, base_type) \
706 inline void mult_spec(const dense_matrix<base_type> &A, \
707 const dense_matrix<base_type> &B, \
708 dense_matrix<base_type> &C, c_mult) { \
709 GMMLAPACK_TRACE("gemm_interface_nn"); \
710 const char t = 'N'; \
711 const BLAS_INT m=BLAS_INT(mat_nrows(A)), lda(m), \
712 k=BLAS_INT(mat_ncols(A)), ldb(k), \
713 n=BLAS_INT(mat_ncols(B)), ldc(m); \
714 const base_type alpha(1), beta(0); \
716 blas_name(&t, &t, &m, &n, &k, &alpha, \
717 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
718 else gmm::clear(C); \
721 gemm_interface_nn(sgemm_, BLAS_S)
722 gemm_interface_nn(dgemm_, BLAS_D)
723 gemm_interface_nn(cgemm_, BLAS_C)
724 gemm_interface_nn(zgemm_, BLAS_Z)
/* gemm_interface_tn(blas_name, base_type, mat_type)
   Generates mult_spec(A_, B, C, rcmult) where A_ is a transposed_col_ref on
   a dense matrix; the stored matrix A is obtained with linalg_origin and
   blas_name is called with flags 'T' (A) and 'N' (B), alpha = 1, beta = 0.
   Dimensions: m = columns of A, k = rows of A, n = columns of B; leading
   dimensions are k (A and B) and m (C).
   mat_type is the matrix template name spelled inside the reference type;
   the instantiations below pass both `dense_matrix` and `const dense_matrix`
   so that references to const and non-const matrices are both covered.
   In the fallback branch C is cleared.
   NOTE(review): the condition guarding the BLAS call is not visible in this
   excerpt. */
730 # define gemm_interface_tn(blas_name, base_type, mat_type) \
731 inline void mult_spec( \
732 const transposed_col_ref<mat_type<base_type> *> &A_, \
733 const dense_matrix<base_type> &B, \
734 dense_matrix<base_type> &C, rcmult) { \
735 GMMLAPACK_TRACE("gemm_interface_tn"); \
736 const dense_matrix<base_type> &A = *(linalg_origin(A_)); \
737 const char t = 'T', u = 'N'; \
738 const BLAS_INT m=BLAS_INT(mat_ncols(A)), k=BLAS_INT(mat_nrows(A)), \
739 n=BLAS_INT(mat_ncols(B)), lda(k), ldb(k), ldc(m); \
740 const base_type alpha(1), beta(0); \
742 blas_name(&t, &u, &m, &n, &k, &alpha, \
743 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
744 else gmm::clear(C); \
747 gemm_interface_tn(sgemm_, BLAS_S, dense_matrix)
748 gemm_interface_tn(dgemm_, BLAS_D, dense_matrix)
749 gemm_interface_tn(cgemm_, BLAS_C, dense_matrix)
750 gemm_interface_tn(zgemm_, BLAS_Z, dense_matrix)
751 gemm_interface_tn(sgemm_, BLAS_S,
const dense_matrix)
752 gemm_interface_tn(dgemm_, BLAS_D,
const dense_matrix)
753 gemm_interface_tn(cgemm_, BLAS_C,
const dense_matrix)
754 gemm_interface_tn(zgemm_, BLAS_Z,
const dense_matrix)
/* gemm_interface_nt(blas_name, base_type, mat_type)
   Generates mult_spec(A, B_, C, r_mult) where B_ is a transposed_col_ref on
   a dense matrix; the stored matrix B is obtained with linalg_origin and
   blas_name is called with flags 'N' (A) and 'T' (B), alpha = 1, beta = 0.
   Dimensions: m = rows of A, k = columns of A, n = rows of B; leading
   dimensions are m (A and C) and n (B).
   mat_type is instantiated below with both `dense_matrix` and
   `const dense_matrix`.
   In the fallback branch C is cleared.
   NOTE(review): the line carrying the return type and the condition
   guarding the BLAS call are not visible in this excerpt. */
760 # define gemm_interface_nt(blas_name, base_type, mat_type) \
762 mult_spec(const dense_matrix<base_type> &A, \
763 const transposed_col_ref<mat_type<base_type> *> &B_, \
764 dense_matrix<base_type> &C, r_mult) { \
765 GMMLAPACK_TRACE("gemm_interface_nt"); \
766 const dense_matrix<base_type> &B = *(linalg_origin(B_)); \
767 const char t = 'N', u = 'T'; \
768 const BLAS_INT m=BLAS_INT(mat_nrows(A)), k=BLAS_INT(mat_ncols(A)), \
769 n=BLAS_INT(mat_nrows(B)), lda(m), ldb(n), ldc(m); \
770 const base_type alpha(1), beta(0); \
772 blas_name(&t, &u, &m, &n, &k, &alpha, \
773 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
774 else gmm::clear(C); \
777 gemm_interface_nt(sgemm_, BLAS_S, dense_matrix)
778 gemm_interface_nt(dgemm_, BLAS_D, dense_matrix)
779 gemm_interface_nt(cgemm_, BLAS_C, dense_matrix)
780 gemm_interface_nt(zgemm_, BLAS_Z, dense_matrix)
781 gemm_interface_nt(sgemm_, BLAS_S,
const dense_matrix)
782 gemm_interface_nt(dgemm_, BLAS_D,
const dense_matrix)
783 gemm_interface_nt(cgemm_, BLAS_C,
const dense_matrix)
784 gemm_interface_nt(zgemm_, BLAS_Z,
const dense_matrix)
/* gemm_interface_tt(blas_name, base_type, matA_type, matB_type)
   Generates mult_spec(A_, B_, C, r_mult) where both operands are
   transposed_col_ref on dense matrices; the stored matrices are obtained
   with linalg_origin and blas_name is called with flags 'T' and 'T',
   alpha = 1, beta = 0.
   Dimensions: m = columns of A, k = rows of A, n = rows of B; leading
   dimensions are k (A), n (B) and m (C).
   matA_type / matB_type are instantiated below with every combination of
   `dense_matrix` and `const dense_matrix`.
   In the fallback branch C is cleared.
   NOTE(review): the line carrying the return type and the condition
   guarding the BLAS call are not visible in this excerpt. */
790 # define gemm_interface_tt(blas_name, base_type, matA_type, matB_type) \
792 mult_spec(const transposed_col_ref<matA_type<base_type> *> &A_, \
793 const transposed_col_ref<matB_type<base_type> *> &B_, \
794 dense_matrix<base_type> &C, r_mult) { \
795 GMMLAPACK_TRACE("gemm_interface_tt"); \
796 const dense_matrix<base_type> &A = *(linalg_origin(A_)); \
797 const dense_matrix<base_type> &B = *(linalg_origin(B_)); \
798 const char t = 'T', u = 'T'; \
799 const BLAS_INT m=BLAS_INT(mat_ncols(A)), k=BLAS_INT(mat_nrows(A)), \
800 n=BLAS_INT(mat_nrows(B)), lda(k), ldb(n), ldc(m); \
801 base_type alpha(1), beta(0); \
803 blas_name(&t, &u, &m, &n, &k, &alpha, \
804 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
805 else gmm::clear(C); \
808 gemm_interface_tt(sgemm_, BLAS_S, dense_matrix, dense_matrix)
809 gemm_interface_tt(dgemm_, BLAS_D, dense_matrix, dense_matrix)
810 gemm_interface_tt(cgemm_, BLAS_C, dense_matrix, dense_matrix)
811 gemm_interface_tt(zgemm_, BLAS_Z, dense_matrix, dense_matrix)
812 gemm_interface_tt(sgemm_, BLAS_S,
const dense_matrix, dense_matrix)
813 gemm_interface_tt(dgemm_, BLAS_D,
const dense_matrix, dense_matrix)
814 gemm_interface_tt(cgemm_, BLAS_C,
const dense_matrix, dense_matrix)
815 gemm_interface_tt(zgemm_, BLAS_Z,
const dense_matrix, dense_matrix)
816 gemm_interface_tt(sgemm_, BLAS_S, dense_matrix,
const dense_matrix)
817 gemm_interface_tt(dgemm_, BLAS_D, dense_matrix,
const dense_matrix)
818 gemm_interface_tt(cgemm_, BLAS_C, dense_matrix,
const dense_matrix)
819 gemm_interface_tt(zgemm_, BLAS_Z, dense_matrix,
const dense_matrix)
820 gemm_interface_tt(sgemm_, BLAS_S,
const dense_matrix,
const dense_matrix)
821 gemm_interface_tt(dgemm_, BLAS_D,
const dense_matrix,
const dense_matrix)
822 gemm_interface_tt(cgemm_, BLAS_C,
const dense_matrix,
const dense_matrix)
823 gemm_interface_tt(zgemm_, BLAS_Z,
const dense_matrix,
const dense_matrix)
/* gemm_interface_cn(blas_name, base_type)
   Generates mult_spec(A_, B, C, rcmult) where A_ is a
   conjugated_col_matrix_const_ref of a dense matrix; the stored matrix A is
   obtained with linalg_origin and blas_name is called with flags 'C' (A)
   and 'N' (B), alpha = 1, beta = 0.
   Dimensions: m = columns of A, k = rows of A, n = columns of B; leading
   dimensions are k (A and B) and m (C).
   In the fallback branch C is cleared.
   NOTE(review): the condition guarding the BLAS call is not visible in this
   excerpt. */
830 # define gemm_interface_cn(blas_name, base_type) \
831 inline void mult_spec( \
832 const conjugated_col_matrix_const_ref<dense_matrix<base_type> > &A_, \
833 const dense_matrix<base_type> &B, \
834 dense_matrix<base_type> &C, rcmult) { \
835 GMMLAPACK_TRACE("gemm_interface_cn"); \
836 const dense_matrix<base_type> &A = *(linalg_origin(A_)); \
837 const char t = 'C', u = 'N'; \
838 const BLAS_INT m=BLAS_INT(mat_ncols(A)), k=BLAS_INT(mat_nrows(A)), \
839 n=BLAS_INT(mat_ncols(B)), lda(k), ldb(k), ldc(m); \
840 const base_type alpha(1), beta(0); \
842 blas_name(&t, &u, &m, &n, &k, &alpha, \
843 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
844 else gmm::clear(C); \
847 gemm_interface_cn(sgemm_, BLAS_S)
848 gemm_interface_cn(dgemm_, BLAS_D)
849 gemm_interface_cn(cgemm_, BLAS_C)
850 gemm_interface_cn(zgemm_, BLAS_Z)
/* gemm_interface_nc(blas_name, base_type)
   Generates mult_spec(A, B_, C, c_mult, row_major) where B_ is a
   conjugated_col_matrix_const_ref of a dense matrix; the stored matrix B is
   obtained with linalg_origin and blas_name is called with flags 'N' (A)
   and 'C' (B), alpha = 1, beta = 0.
   Unlike the other gemm wrappers in this file, this overload takes two tag
   parameters (c_mult and row_major).
   Dimensions: m = rows of A, k = columns of A, n = rows of B; leading
   dimensions are m (A and C) and n (B).
   In the fallback branch C is cleared.
   NOTE(review): the condition guarding the BLAS call is not visible in this
   excerpt. */
856 # define gemm_interface_nc(blas_name, base_type) \
857 inline void mult_spec( \
858 const dense_matrix<base_type> &A, \
859 const conjugated_col_matrix_const_ref<dense_matrix<base_type> > &B_, \
860 dense_matrix<base_type> &C, c_mult, row_major) { \
861 GMMLAPACK_TRACE("gemm_interface_nc"); \
862 const dense_matrix<base_type> &B = *(linalg_origin(B_)); \
863 const char t = 'N', u = 'C'; \
864 const BLAS_INT m=BLAS_INT(mat_nrows(A)), k=BLAS_INT(mat_ncols(A)), \
865 n=BLAS_INT(mat_nrows(B)), lda(m), ldb(n), ldc(m); \
866 const base_type alpha(1), beta(0); \
868 blas_name(&t, &u, &m, &n, &k, &alpha, \
869 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
870 else gmm::clear(C); \
873 gemm_interface_nc(sgemm_, BLAS_S)
874 gemm_interface_nc(dgemm_, BLAS_D)
875 gemm_interface_nc(cgemm_, BLAS_C)
876 gemm_interface_nc(zgemm_, BLAS_Z)
/* gemm_interface_cc(blas_name, base_type)
   Generates mult_spec(A_, B_, C, r_mult) where both operands are
   conjugated_col_matrix_const_ref of dense matrices; the stored matrices
   are obtained with linalg_origin and blas_name is called with flags 'C'
   and 'C', alpha = 1, beta = 0.
   Dimensions: m = columns of A, k = rows of A, n = rows of B; leading
   dimensions are k (A), n (B) and m (C).
   In the fallback branch C is cleared.
   NOTE(review): the condition guarding the BLAS call is not visible in this
   excerpt. */
882 # define gemm_interface_cc(blas_name, base_type) \
883 inline void mult_spec( \
884 const conjugated_col_matrix_const_ref<dense_matrix<base_type> > &A_, \
885 const conjugated_col_matrix_const_ref<dense_matrix<base_type> > &B_, \
886 dense_matrix<base_type> &C, r_mult) { \
887 GMMLAPACK_TRACE("gemm_interface_cc"); \
888 const dense_matrix<base_type> &A = *(linalg_origin(A_)); \
889 const dense_matrix<base_type> &B = *(linalg_origin(B_)); \
890 const char t = 'C', u = 'C'; \
891 const BLAS_INT m=BLAS_INT(mat_ncols(A)), k=BLAS_INT(mat_nrows(A)), \
892 n=BLAS_INT(mat_nrows(B)), lda(k), ldb(n), ldc(m); \
893 const base_type alpha(1), beta(0); \
895 blas_name(&t, &u, &m, &n, &k, &alpha, \
896 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
897 else gmm::clear(C); \
900 gemm_interface_cc(sgemm_, BLAS_S)
901 gemm_interface_cc(dgemm_, BLAS_D)
902 gemm_interface_cc(cgemm_, BLAS_C)
903 gemm_interface_cc(zgemm_, BLAS_Z)
/* trsv_interface(f_name, LorU, param1, trans1, blas_name, base_type)
   Generates a triangular solve f_name(A-operand, x, k, is_unit) that works
   in place on the std::vector x through the BLAS routine blas_name.
     f_name  : name of the generated function;
     LorU    : character passed to BLAS as the triangle selector of the
               stored matrix;
     param1 / trans1 : matrix parameter declaration and the statements that
               define `A` and the flag `t` (see the gem_p1_* / gem_trans1_*
               macros);
     k       : order passed to BLAS as n;
     is_unit : selects the diagonal flag 'U' (true) or 'N' (false).
   The leading dimension is the row count of A; the BLAS call is skipped
   when that row count is zero. */
909 # define trsv_interface(f_name, LorU, param1, trans1, blas_name, base_type)\
910 inline void f_name(param1(base_type), std::vector<base_type> &x, \
911 size_type k, bool is_unit) { \
912 GMMLAPACK_TRACE("trsv_interface"); \
913 const char l = LorU; trans1(base_type); char d = is_unit ? 'U' : 'N'; \
914 const BLAS_INT lda=BLAS_INT(mat_nrows(A)), inc(1), n=BLAS_INT(k); \
915 if (lda) blas_name(&l, &t, &d, &n, &A(0,0), &lda, &x[0], &inc); \
/* lower_tri_solve / upper_tri_solve overloads generated with trsv_interface.
   For the plain matrix form the triangle selector matches the function name
   ('L' for lower, 'U' for upper). For the transposed and conjugated forms
   the selector is swapped, because it designates the triangle of the stored
   matrix. */
/* plain matrix, lower */
919 trsv_interface(lower_tri_solve,
'L', gem_p1_n, gem_trans1_n, strsv_, BLAS_S)
920 trsv_interface(lower_tri_solve,
'L', gem_p1_n, gem_trans1_n, dtrsv_, BLAS_D)
921 trsv_interface(lower_tri_solve,
'L', gem_p1_n, gem_trans1_n, ctrsv_, BLAS_C)
922 trsv_interface(lower_tri_solve,
'L', gem_p1_n, gem_trans1_n, ztrsv_, BLAS_Z)
/* plain matrix, upper */
925 trsv_interface(upper_tri_solve,
'U', gem_p1_n, gem_trans1_n, strsv_, BLAS_S)
926 trsv_interface(upper_tri_solve,
'U', gem_p1_n, gem_trans1_n, dtrsv_, BLAS_D)
927 trsv_interface(upper_tri_solve,
'U', gem_p1_n, gem_trans1_n, ctrsv_, BLAS_C)
928 trsv_interface(upper_tri_solve,
'U', gem_p1_n, gem_trans1_n, ztrsv_, BLAS_Z)
/* transposed matrix (pointer to non-const), lower */
931 trsv_interface(lower_tri_solve,
'U', gem_p1_t, gem_trans1_t, strsv_, BLAS_S)
932 trsv_interface(lower_tri_solve,
'U', gem_p1_t, gem_trans1_t, dtrsv_, BLAS_D)
933 trsv_interface(lower_tri_solve,
'U', gem_p1_t, gem_trans1_t, ctrsv_, BLAS_C)
934 trsv_interface(lower_tri_solve,
'U', gem_p1_t, gem_trans1_t, ztrsv_, BLAS_Z)
/* transposed matrix (pointer to non-const), upper */
937 trsv_interface(upper_tri_solve,
'L', gem_p1_t, gem_trans1_t, strsv_, BLAS_S)
938 trsv_interface(upper_tri_solve,
'L', gem_p1_t, gem_trans1_t, dtrsv_, BLAS_D)
939 trsv_interface(upper_tri_solve,
'L', gem_p1_t, gem_trans1_t, ctrsv_, BLAS_C)
940 trsv_interface(upper_tri_solve,
'L', gem_p1_t, gem_trans1_t, ztrsv_, BLAS_Z)
/* transposed matrix (pointer to const), lower */
943 trsv_interface(lower_tri_solve,
'U', gem_p1_tc, gem_trans1_t, strsv_, BLAS_S)
944 trsv_interface(lower_tri_solve,
'U', gem_p1_tc, gem_trans1_t, dtrsv_, BLAS_D)
945 trsv_interface(lower_tri_solve,
'U', gem_p1_tc, gem_trans1_t, ctrsv_, BLAS_C)
946 trsv_interface(lower_tri_solve,
'U', gem_p1_tc, gem_trans1_t, ztrsv_, BLAS_Z)
/* transposed matrix (pointer to const), upper */
949 trsv_interface(upper_tri_solve,
'L', gem_p1_tc, gem_trans1_t, strsv_, BLAS_S)
950 trsv_interface(upper_tri_solve,
'L', gem_p1_tc, gem_trans1_t, dtrsv_, BLAS_D)
951 trsv_interface(upper_tri_solve,
'L', gem_p1_tc, gem_trans1_t, ctrsv_, BLAS_C)
952 trsv_interface(upper_tri_solve,
'L', gem_p1_tc, gem_trans1_t, ztrsv_, BLAS_Z)
/* conjugated matrix, lower */
955 trsv_interface(lower_tri_solve,
'U', gem_p1_c, gem_trans1_c, strsv_, BLAS_S)
956 trsv_interface(lower_tri_solve,
'U', gem_p1_c, gem_trans1_c, dtrsv_, BLAS_D)
957 trsv_interface(lower_tri_solve,
'U', gem_p1_c, gem_trans1_c, ctrsv_, BLAS_C)
958 trsv_interface(lower_tri_solve,
'U', gem_p1_c, gem_trans1_c, ztrsv_, BLAS_Z)
/* conjugated matrix, upper */
961 trsv_interface(upper_tri_solve,
'L', gem_p1_c, gem_trans1_c, strsv_, BLAS_S)
962 trsv_interface(upper_tri_solve,
'L', gem_p1_c, gem_trans1_c, dtrsv_, BLAS_D)
963 trsv_interface(upper_tri_solve,
'L', gem_p1_c, gem_trans1_c, ctrsv_, BLAS_C)
964 trsv_interface(upper_tri_solve,
'L', gem_p1_c, gem_trans1_c, ztrsv_, BLAS_Z)
// Basic linear algebra functions.
// gmm interface for STL vectors.
// Declaration of some matrix types (gmm::dense_matrix, gmm::row_matrix, gmm::col_matrix, ...).
// size_t size_type:
// used as the common size type in the library.