source: mmcs/armadillo_bits/blas_wrapper.hpp @ 8ad4484

matrices
Last change on this file since 8ad4484 was 9dd61b1, checked in by rboet <rboet@…>, 9 years ago

Avance del proyecto 60%

  • Property mode set to 100644
File size: 6.7 KB
Line 
1// Copyright (C) 2008-2013 Conrad Sanderson
2// Copyright (C) 2008-2013 NICTA (www.nicta.com.au)
3//
4// This Source Code Form is subject to the terms of the Mozilla Public
5// License, v. 2.0. If a copy of the MPL was not distributed with this
6// file, You can obtain one at http://mozilla.org/MPL/2.0/.
7
8
9
10#ifdef ARMA_USE_BLAS
11
12
13//! \namespace blas namespace for BLAS functions
14namespace blas
15  {
16 
17 
18  template<typename eT>
19  inline
20  void
21  gemv(const char* transA, const blas_int* m, const blas_int* n, const eT* alpha, const eT* A, const blas_int* ldA, const eT* x, const blas_int* incx, const eT* beta, eT* y, const blas_int* incy)
22    {
23    arma_type_check((is_supported_blas_type<eT>::value == false));
24   
25    if(is_float<eT>::value == true)
26      {
27      typedef float T;
28      arma_fortran(arma_sgemv)(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy);
29      }
30    else
31    if(is_double<eT>::value == true)
32      {
33      typedef double T;
34      arma_fortran(arma_dgemv)(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy);
35      }
36    else
37    if(is_supported_complex_float<eT>::value == true)
38      {
39      typedef std::complex<float> T;
40      arma_fortran(arma_cgemv)(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy);
41      }
42    else
43    if(is_supported_complex_double<eT>::value == true)
44      {
45      typedef std::complex<double> T;
46      arma_fortran(arma_zgemv)(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy);
47      }
48   
49    }
50 
51 
52 
53  template<typename eT>
54  inline
55  void
56  gemm(const char* transA, const char* transB, const blas_int* m, const blas_int* n, const blas_int* k, const eT* alpha, const eT* A, const blas_int* ldA, const eT* B, const blas_int* ldB, const eT* beta, eT* C, const blas_int* ldC)
57    {
58    arma_type_check((is_supported_blas_type<eT>::value == false));
59   
60    if(is_float<eT>::value == true)
61      {
62      typedef float T;
63      arma_fortran(arma_sgemm)(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC);
64      }
65    else
66    if(is_double<eT>::value == true)
67      {
68      typedef double T;
69      arma_fortran(arma_dgemm)(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC);
70      }
71    else
72    if(is_supported_complex_float<eT>::value == true)
73      {
74      typedef std::complex<float> T;
75      arma_fortran(arma_cgemm)(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC);
76      }
77    else
78    if(is_supported_complex_double<eT>::value == true)
79      {
80      typedef std::complex<double> T;
81      arma_fortran(arma_zgemm)(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC);
82      }
83   
84    }
85 
86 
87 
88  template<typename eT>
89  inline
90  void
91  syrk(const char* uplo, const char* transA, const blas_int* n, const blas_int* k, const eT* alpha, const eT* A, const blas_int* ldA, const eT* beta, eT* C, const blas_int* ldC)
92    {
93    arma_type_check((is_supported_blas_type<eT>::value == false));
94   
95    if(is_float<eT>::value == true)
96      {
97      typedef float T;
98      arma_fortran(arma_ssyrk)(uplo, transA, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)beta, (T*)C, ldC);
99      }
100    else
101    if(is_double<eT>::value == true)
102      {
103      typedef double T;
104      arma_fortran(arma_dsyrk)(uplo, transA, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)beta, (T*)C, ldC);
105      }
106    }
107 
108 
109 
110  template<typename T>
111  inline
112  void
113  herk(const char* uplo, const char* transA, const blas_int* n, const blas_int* k, const T* alpha, const std::complex<T>* A, const blas_int* ldA, const T* beta, std::complex<T>* C, const blas_int* ldC)
114    {
115    arma_type_check((is_supported_blas_type<T>::value == false));
116   
117    if(is_float<T>::value == true)
118      {
119      typedef float                  TT;
120      typedef std::complex<float> cx_TT;
121     
122      arma_fortran(arma_cherk)(uplo, transA, n, k, (const TT*)alpha, (const cx_TT*)A, ldA, (const TT*)beta, (cx_TT*)C, ldC);
123      }
124    else
125    if(is_double<T>::value == true)
126      {
127      typedef double                  TT;
128      typedef std::complex<double> cx_TT;
129     
130      arma_fortran(arma_zherk)(uplo, transA, n, k, (const TT*)alpha, (const cx_TT*)A, ldA, (const TT*)beta, (cx_TT*)C, ldC);
131      }
132    }
133 
134 
135 
136  template<typename eT>
137  inline
138  eT
139  dot(const uword n_elem, const eT* x, const eT* y)
140    {
141    arma_type_check((is_supported_blas_type<eT>::value == false));
142   
143    if(is_float<eT>::value == true)
144      {
145      #if defined(ARMA_BLAS_SDOT_BUG)
146        {
147        if(n_elem == 0)  { return eT(0); }
148       
149        const char trans   = 'T';
150       
151        const blas_int m   = blas_int(n_elem);
152        const blas_int n   = 1;
153        //const blas_int lda = (n_elem > 0) ? blas_int(n_elem) : blas_int(1);
154        const blas_int inc = 1;
155       
156        const eT alpha     = eT(1);
157        const eT beta      = eT(0);
158       
159        eT result[2];  // paranoia: using two elements instead of one
160       
161        //blas::gemv(&trans, &m, &n, &alpha, x, &lda, y, &inc, &beta, &result[0], &inc);
162        blas::gemv(&trans, &m, &n, &alpha, x, &m, y, &inc, &beta, &result[0], &inc);
163       
164        return result[0];
165        }
166      #else
167        {
168        blas_int n   = blas_int(n_elem);
169        blas_int inc = 1;
170       
171        typedef float T;
172        return arma_fortran(arma_sdot)(&n, (const T*)x, &inc, (const T*)y, &inc);
173        }
174      #endif
175      }
176    else
177    if(is_double<eT>::value == true)
178      {
179      blas_int n   = blas_int(n_elem);
180      blas_int inc = 1;
181     
182      typedef double T;
183      return arma_fortran(arma_ddot)(&n, (const T*)x, &inc, (const T*)y, &inc);
184      }
185    else
186    if( (is_supported_complex_float<eT>::value == true) || (is_supported_complex_double<eT>::value == true) )
187      {
188      if(n_elem == 0)  { return eT(0); }
189     
190      // using gemv() workaround due to compatibility issues with cdotu() and zdotu()
191     
192      const char trans   = 'T';
193     
194      const blas_int m   = blas_int(n_elem);
195      const blas_int n   = 1;
196      //const blas_int lda = (n_elem > 0) ? blas_int(n_elem) : blas_int(1);
197      const blas_int inc = 1;
198     
199      const eT alpha     = eT(1);
200      const eT beta      = eT(0);
201     
202      eT result[2];  // paranoia: using two elements instead of one
203     
204      //blas::gemv(&trans, &m, &n, &alpha, x, &lda, y, &inc, &beta, &result[0], &inc);
205      blas::gemv(&trans, &m, &n, &alpha, x, &m, y, &inc, &beta, &result[0], &inc);
206     
207      return result[0];
208      }
209    else
210      {
211      return eT(0);
212      }
213    }
214  }
215
216
217#endif
Note: See TracBrowser for help on using the repository browser.