1 | // Copyright (C) 2008-2012 Conrad Sanderson |
---|
2 | // Copyright (C) 2008-2012 NICTA (www.nicta.com.au) |
---|
3 | // Copyright (C) 2012 Ryan Curtin |
---|
4 | // |
---|
5 | // This Source Code Form is subject to the terms of the Mozilla Public |
---|
6 | // License, v. 2.0. If a copy of the MPL was not distributed with this |
---|
7 | // file, You can obtain one at http://mozilla.org/MPL/2.0/. |
---|
8 | |
---|
9 | |
---|
10 | //! \addtogroup operator_schur |
---|
11 | //! @{ |
---|
12 | |
---|
13 | |
---|
14 | // operator %, which we define to perform a Schur product (element-wise multiplication) |
---|
15 | |
---|
16 | |
---|
17 | //! element-wise multiplication of user-accessible Armadillo objects with same element type |
---|
18 | template<typename T1, typename T2> |
---|
19 | arma_inline |
---|
20 | typename |
---|
21 | enable_if2 |
---|
22 | < |
---|
23 | is_arma_type<T1>::value && is_arma_type<T2>::value && is_same_type<typename T1::elem_type, typename T2::elem_type>::value, |
---|
24 | const eGlue<T1, T2, eglue_schur> |
---|
25 | >::result |
---|
26 | operator% |
---|
27 | ( |
---|
28 | const T1& X, |
---|
29 | const T2& Y |
---|
30 | ) |
---|
31 | { |
---|
32 | arma_extra_debug_sigprint(); |
---|
33 | |
---|
34 | return eGlue<T1, T2, eglue_schur>(X, Y); |
---|
35 | } |
---|
36 | |
---|
37 | |
---|
38 | |
---|
39 | //! element-wise multiplication of user-accessible Armadillo objects with different element types |
---|
40 | template<typename T1, typename T2> |
---|
41 | inline |
---|
42 | typename |
---|
43 | enable_if2 |
---|
44 | < |
---|
45 | (is_arma_type<T1>::value && is_arma_type<T2>::value && (is_same_type<typename T1::elem_type, typename T2::elem_type>::no)), |
---|
46 | const mtGlue<typename promote_type<typename T1::elem_type, typename T2::elem_type>::result, T1, T2, glue_mixed_schur> |
---|
47 | >::result |
---|
48 | operator% |
---|
49 | ( |
---|
50 | const T1& X, |
---|
51 | const T2& Y |
---|
52 | ) |
---|
53 | { |
---|
54 | arma_extra_debug_sigprint(); |
---|
55 | |
---|
56 | typedef typename T1::elem_type eT1; |
---|
57 | typedef typename T2::elem_type eT2; |
---|
58 | |
---|
59 | typedef typename promote_type<eT1,eT2>::result out_eT; |
---|
60 | |
---|
61 | promote_type<eT1,eT2>::check(); |
---|
62 | |
---|
63 | return mtGlue<out_eT, T1, T2, glue_mixed_schur>( X, Y ); |
---|
64 | } |
---|
65 | |
---|
66 | |
---|
67 | |
---|
68 | //! element-wise multiplication of two sparse matrices |
---|
69 | template<typename T1, typename T2> |
---|
70 | inline |
---|
71 | typename |
---|
72 | enable_if2 |
---|
73 | < |
---|
74 | (is_arma_sparse_type<T1>::value && is_arma_sparse_type<T2>::value && is_same_type<typename T1::elem_type, typename T2::elem_type>::value), |
---|
75 | SpMat<typename T1::elem_type> |
---|
76 | >::result |
---|
77 | operator% |
---|
78 | ( |
---|
79 | const SpBase<typename T1::elem_type, T1>& x, |
---|
80 | const SpBase<typename T2::elem_type, T2>& y |
---|
81 | ) |
---|
82 | { |
---|
83 | arma_extra_debug_sigprint(); |
---|
84 | |
---|
85 | typedef typename T1::elem_type eT; |
---|
86 | |
---|
87 | const SpProxy<T1> pa(x.get_ref()); |
---|
88 | const SpProxy<T2> pb(y.get_ref()); |
---|
89 | |
---|
90 | arma_debug_assert_same_size(pa.get_n_rows(), pa.get_n_cols(), pb.get_n_rows(), pb.get_n_cols(), "element-wise multiplication"); |
---|
91 | |
---|
92 | SpMat<typename T1::elem_type> result(pa.get_n_rows(), pa.get_n_cols()); |
---|
93 | |
---|
94 | if( (pa.get_n_nonzero() != 0) && (pb.get_n_nonzero() != 0) ) |
---|
95 | { |
---|
96 | // Resize memory to correct size. |
---|
97 | result.mem_resize(n_unique(x, y, op_n_unique_mul())); |
---|
98 | |
---|
99 | // Now iterate across both matrices. |
---|
100 | typename SpProxy<T1>::const_iterator_type x_it = pa.begin(); |
---|
101 | typename SpProxy<T2>::const_iterator_type y_it = pb.begin(); |
---|
102 | |
---|
103 | typename SpProxy<T1>::const_iterator_type x_end = pa.end(); |
---|
104 | typename SpProxy<T2>::const_iterator_type y_end = pb.end(); |
---|
105 | |
---|
106 | uword cur_val = 0; |
---|
107 | while((x_it != x_end) || (y_it != y_end)) |
---|
108 | { |
---|
109 | if(x_it == y_it) |
---|
110 | { |
---|
111 | const eT val = (*x_it) * (*y_it); |
---|
112 | |
---|
113 | if (val != eT(0)) |
---|
114 | { |
---|
115 | access::rw(result.values[cur_val]) = val; |
---|
116 | access::rw(result.row_indices[cur_val]) = x_it.row(); |
---|
117 | ++access::rw(result.col_ptrs[x_it.col() + 1]); |
---|
118 | ++cur_val; |
---|
119 | } |
---|
120 | |
---|
121 | ++x_it; |
---|
122 | ++y_it; |
---|
123 | } |
---|
124 | else |
---|
125 | { |
---|
126 | const uword x_it_row = x_it.row(); |
---|
127 | const uword x_it_col = x_it.col(); |
---|
128 | |
---|
129 | const uword y_it_row = y_it.row(); |
---|
130 | const uword y_it_col = y_it.col(); |
---|
131 | |
---|
132 | if((x_it_col < y_it_col) || ((x_it_col == y_it_col) && (x_it_row < y_it_row))) // if y is closer to the end |
---|
133 | { |
---|
134 | ++x_it; |
---|
135 | } |
---|
136 | else |
---|
137 | { |
---|
138 | ++y_it; |
---|
139 | } |
---|
140 | } |
---|
141 | } |
---|
142 | |
---|
143 | // Fix column pointers to be cumulative. |
---|
144 | for(uword c = 1; c <= result.n_cols; ++c) |
---|
145 | { |
---|
146 | access::rw(result.col_ptrs[c]) += result.col_ptrs[c - 1]; |
---|
147 | } |
---|
148 | } |
---|
149 | |
---|
150 | return result; |
---|
151 | } |
---|
152 | |
---|
153 | |
---|
154 | |
---|
155 | //! element-wise multiplication of one dense and one sparse object |
---|
156 | template<typename T1, typename T2> |
---|
157 | inline |
---|
158 | typename |
---|
159 | enable_if2 |
---|
160 | < |
---|
161 | (is_arma_type<T1>::value && is_arma_sparse_type<T2>::value && is_same_type<typename T1::elem_type, typename T2::elem_type>::value), |
---|
162 | SpMat<typename T1::elem_type> |
---|
163 | >::result |
---|
164 | operator% |
---|
165 | ( |
---|
166 | const T1& x, |
---|
167 | const T2& y |
---|
168 | ) |
---|
169 | { |
---|
170 | arma_extra_debug_sigprint(); |
---|
171 | |
---|
172 | typedef typename T1::elem_type eT; |
---|
173 | |
---|
174 | const Proxy<T1> pa(x); |
---|
175 | const SpProxy<T2> pb(y); |
---|
176 | |
---|
177 | arma_debug_assert_same_size(pa.get_n_rows(), pa.get_n_cols(), pb.get_n_rows(), pb.get_n_cols(), "element-wise multiplication"); |
---|
178 | |
---|
179 | SpMat<eT> result(pa.get_n_rows(), pa.get_n_cols()); |
---|
180 | |
---|
181 | // count new size |
---|
182 | uword new_n_nonzero = 0; |
---|
183 | |
---|
184 | typename SpProxy<T2>::const_iterator_type it = pb.begin(); |
---|
185 | typename SpProxy<T2>::const_iterator_type it_end = pb.end(); |
---|
186 | |
---|
187 | while(it != it_end) |
---|
188 | { |
---|
189 | if( ((*it) * pa.at(it.row(), it.col())) != eT(0) ) |
---|
190 | { |
---|
191 | ++new_n_nonzero; |
---|
192 | } |
---|
193 | |
---|
194 | ++it; |
---|
195 | } |
---|
196 | |
---|
197 | // Resize memory accordingly. |
---|
198 | result.mem_resize(new_n_nonzero); |
---|
199 | |
---|
200 | uword cur_val = 0; |
---|
201 | |
---|
202 | typename SpProxy<T2>::const_iterator_type it2 = pb.begin(); |
---|
203 | |
---|
204 | while(it2 != it_end) |
---|
205 | { |
---|
206 | const uword it2_row = it2.row(); |
---|
207 | const uword it2_col = it2.col(); |
---|
208 | |
---|
209 | const eT val = (*it2) * pa.at(it2_row, it2_col); |
---|
210 | |
---|
211 | if(val != eT(0)) |
---|
212 | { |
---|
213 | access::rw(result.values[cur_val]) = val; |
---|
214 | access::rw(result.row_indices[cur_val]) = it2_row; |
---|
215 | ++access::rw(result.col_ptrs[it2_col + 1]); |
---|
216 | ++cur_val; |
---|
217 | } |
---|
218 | |
---|
219 | ++it2; |
---|
220 | } |
---|
221 | |
---|
222 | // Fix column pointers. |
---|
223 | for(uword c = 1; c <= result.n_cols; ++c) |
---|
224 | { |
---|
225 | access::rw(result.col_ptrs[c]) += result.col_ptrs[c - 1]; |
---|
226 | } |
---|
227 | |
---|
228 | return result; |
---|
229 | } |
---|
230 | |
---|
231 | |
---|
232 | |
---|
233 | //! element-wise multiplication of one sparse and one dense object |
---|
234 | template<typename T1, typename T2> |
---|
235 | inline |
---|
236 | typename |
---|
237 | enable_if2 |
---|
238 | < |
---|
239 | (is_arma_sparse_type<T1>::value && is_arma_type<T2>::value && is_same_type<typename T1::elem_type, typename T2::elem_type>::value), |
---|
240 | SpMat<typename T1::elem_type> |
---|
241 | >::result |
---|
242 | operator% |
---|
243 | ( |
---|
244 | const T1& x, |
---|
245 | const T2& y |
---|
246 | ) |
---|
247 | { |
---|
248 | arma_extra_debug_sigprint(); |
---|
249 | |
---|
250 | // This operation is commutative. |
---|
251 | return (y % x); |
---|
252 | } |
---|
253 | |
---|
254 | |
---|
255 | |
---|
256 | //! @} |
---|