source: mmcs/armadillo_bits/op_trimat_meat.hpp @ 8daa049

matrices
Last change on this file since 8daa049 was 9dd61b1, checked in by rboet <rboet@…>, 9 years ago

Avance del proyecto 60%

  • Property mode set to 100644
File size: 5.4 KB
Line 
1// Copyright (C) 2010-2012 Conrad Sanderson
2// Copyright (C) 2010-2012 NICTA (www.nicta.com.au)
3// Copyright (C) 2011      Ryan Curtin
4//
5// This Source Code Form is subject to the terms of the Mozilla Public
6// License, v. 2.0. If a copy of the MPL was not distributed with this
7// file, You can obtain one at http://mozilla.org/MPL/2.0/.
8
9
10//! \addtogroup op_trimat
11//! @{
12
13
14
15template<typename eT>
16inline
17void
18op_trimat::fill_zeros(Mat<eT>& out, const bool upper)
19  {
20  arma_extra_debug_sigprint();
21 
22  const uword N = out.n_rows;
23 
24  if(upper)
25    {
26    // upper triangular: set all elements below the diagonal to zero
27   
28    for(uword i=0; i<N; ++i)
29      {
30      eT* data = out.colptr(i);
31     
32      arrayops::inplace_set( &data[i+1], eT(0), (N-(i+1)) );
33      }
34    }
35  else
36    {
37    // lower triangular: set all elements above the diagonal to zero
38   
39    for(uword i=1; i<N; ++i)
40      {
41      eT* data = out.colptr(i);
42     
43      arrayops::inplace_set( data, eT(0), i );
44      }
45    }
46  }
47
48
49
50template<typename T1>
51inline
52void
53op_trimat::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_trimat>& in)
54  {
55  arma_extra_debug_sigprint();
56 
57  typedef typename T1::elem_type eT;
58 
59  const unwrap<T1>   tmp(in.m);
60  const Mat<eT>& A = tmp.M;
61 
62  arma_debug_check( (A.is_square() == false), "trimatu()/trimatl(): given matrix must be square" );
63 
64  const uword N     = A.n_rows;
65  const bool  upper = (in.aux_uword_a == 0);
66 
67  if(&out != &A)
68    {
69    out.copy_size(A);
70   
71    if(upper)
72      {
73      // upper triangular: copy the diagonal and the elements above the diagonal
74      for(uword i=0; i<N; ++i)
75        {
76        const eT* A_data   = A.colptr(i);
77              eT* out_data = out.colptr(i);
78       
79        arrayops::copy( out_data, A_data, i+1 );
80        }
81      }
82    else
83      {
84      // lower triangular: copy the diagonal and the elements below the diagonal
85      for(uword i=0; i<N; ++i)
86        {
87        const eT* A_data   = A.colptr(i);
88              eT* out_data = out.colptr(i);
89       
90        arrayops::copy( &out_data[i], &A_data[i], N-i );
91        }
92      }
93    }
94 
95  op_trimat::fill_zeros(out, upper);
96  }
97
98
99
100template<typename T1>
101inline
102void
103op_trimat::apply(Mat<typename T1::elem_type>& out, const Op<Op<T1, op_htrans>, op_trimat>& in)
104  {
105  arma_extra_debug_sigprint();
106 
107  typedef typename T1::elem_type eT;
108 
109  const unwrap<T1>   tmp(in.m.m);
110  const Mat<eT>& A = tmp.M;
111 
112  const bool upper = (in.aux_uword_a == 0);
113 
114  op_trimat::apply_htrans(out, A, upper);
115  }
116
117
118
119template<typename eT>
120inline
121void
122op_trimat::apply_htrans
123  (
124        Mat<eT>& out,
125  const Mat<eT>& A,
126  const bool     upper,
127  const typename arma_not_cx<eT>::result* junk
128  )
129  {
130  arma_extra_debug_sigprint();
131  arma_ignore(junk);
132 
133  // This specialisation is for trimatl(trans(X)) = trans(trimatu(X)) and also
134  // trimatu(trans(X)) = trans(trimatl(X)).  We want to avoid the creation of an
135  // extra temporary.
136 
137  // It doesn't matter if the input and output matrices are the same; we will
138  // pull data from the upper or lower triangular to the lower or upper
139  // triangular (respectively) and then set the rest to 0, so overwriting issues
140  // aren't present.
141 
142  arma_debug_check( (A.is_square() == false), "trimatu()/trimatl(): given matrix must be square" );
143 
144  const uword N = A.n_rows;
145 
146  if(&out != &A)
147    {
148    out.copy_size(A);
149    }
150 
151  // We can't really get away with any array copy operations here,
152  // unfortunately...
153 
154  if(upper)
155    {
156    // Upper triangular: but since we're transposing, we're taking the lower
157    // triangular and putting it in the upper half.
158    for(uword row = 0; row < N; ++row)
159      {
160      eT* out_colptr = out.colptr(row);
161     
162      for(uword col = 0; col <= row; ++col)
163        {
164        //out.at(col, row) = A.at(row, col);
165        out_colptr[col] = A.at(row, col);
166        }
167      }
168    }
169  else
170    {
171    // Lower triangular: but since we're transposing, we're taking the upper
172    // triangular and putting it in the lower half.
173    for(uword row = 0; row < N; ++row)
174      {
175      for(uword col = row; col < N; ++col)
176        {
177        out.at(col, row) = A.at(row, col);
178        }
179      }
180    }
181 
182  op_trimat::fill_zeros(out, upper);
183  }
184
185
186
187template<typename eT>
188inline
189void
190op_trimat::apply_htrans
191  (
192        Mat<eT>& out,
193  const Mat<eT>& A,
194  const bool     upper,
195  const typename arma_cx_only<eT>::result* junk
196  )
197  {
198  arma_extra_debug_sigprint();
199  arma_ignore(junk);
200 
201  arma_debug_check( (A.is_square() == false), "trimatu()/trimatl(): given matrix must be square" );
202 
203  const uword N = A.n_rows;
204 
205  if(&out != &A)
206    {
207    out.copy_size(A);
208    }
209 
210  if(upper)
211    {
212    // Upper triangular: but since we're transposing, we're taking the lower
213    // triangular and putting it in the upper half.
214    for(uword row = 0; row < N; ++row)
215      {
216      eT* out_colptr = out.colptr(row);
217     
218      for(uword col = 0; col <= row; ++col)
219        {
220        //out.at(col, row) = std::conj( A.at(row, col) );
221        out_colptr[col] = std::conj( A.at(row, col) );
222        }
223      }
224    }
225  else
226    {
227    // Lower triangular: but since we're transposing, we're taking the upper
228    // triangular and putting it in the lower half.
229    for(uword row = 0; row < N; ++row)
230      {
231      for(uword col = row; col < N; ++col)
232        {
233        out.at(col, row) = std::conj( A.at(row, col) );
234        }
235      }
236    }
237 
238  op_trimat::fill_zeros(out, upper);
239  }
240
241
242
243//! @}
Note: See TracBrowser for help on using the repository browser.