created FFT::fwd and FFT::inv with ReturnByValue

This commit is contained in:
Mark Borgerding 2010-03-07 21:31:06 -05:00
parent b528b747c1
commit 5b2c8b77df
2 changed files with 80 additions and 3 deletions

View File

@ -110,12 +110,42 @@
namespace Eigen {
template <typename _Scalar,
typename _Impl=default_fft_impl<_Scalar> >
//
template<typename T_SrcMat,typename T_FftIfc,bool T_ForwardTransform>
struct fft_result_proxy;
template<typename T_SrcMat,typename T_FftIfc,bool T_ForwardTransform>
struct ei_traits< fft_result_proxy<T_SrcMat,T_FftIfc,T_ForwardTransform> >
{
typedef typename T_SrcMat::PlainObject ReturnType;
};
template<typename T_SrcMat,typename T_FftIfc,bool T_ForwardTransform>
struct fft_result_proxy
: public ReturnByValue<fft_result_proxy<T_SrcMat,T_FftIfc,T_ForwardTransform> >
{
fft_result_proxy(const T_SrcMat& src,T_FftIfc & fft,int nfft) : m_src(src),m_ifc(fft), m_nfft(nfft) {}
template<typename T_DestMat> void evalTo(T_DestMat& dst) const;
int rows() const { return m_src.rows(); }
int cols() const { return m_src.cols(); }
protected:
const T_SrcMat & m_src;
T_FftIfc & m_ifc;
int m_nfft;
};
template <typename T_Scalar,
typename T_Impl=default_fft_impl<T_Scalar> >
class FFT
{
public:
typedef _Impl impl_type;
typedef T_Impl impl_type;
typedef typename impl_type::Scalar Scalar;
typedef typename impl_type::Complex Complex;
@ -204,6 +234,22 @@ class FFT
fwd( &dst[0],&src[0],nfft );
}
}
template<typename InputDerived>
inline
fft_result_proxy< MatrixBase<InputDerived>, FFT<T_Scalar,T_Impl> ,true>
fwd( const MatrixBase<InputDerived> & src,int nfft=-1)
{
return fft_result_proxy< MatrixBase<InputDerived> ,FFT<T_Scalar,T_Impl>,true>( src, *this,nfft );
}
template<typename InputDerived>
inline
fft_result_proxy< MatrixBase<InputDerived>, FFT<T_Scalar,T_Impl> ,false>
inv( const MatrixBase<InputDerived> & src,int nfft=-1)
{
return fft_result_proxy< MatrixBase<InputDerived> ,FFT<T_Scalar,T_Impl>,false>( src, *this,nfft );
}
inline
void inv( Complex * dst, const Complex * src, int nfft)
@ -335,6 +381,17 @@ class FFT
impl_type m_impl;
int m_flag;
};
template<typename T_SrcMat,typename T_FftIfc,bool T_ForwardTransform>
template<typename T_DestMat> inline
void fft_result_proxy<T_SrcMat,T_FftIfc,T_ForwardTransform>::evalTo(T_DestMat& dst) const
{
if (T_ForwardTransform)
m_ifc.fwd( dst, m_src, m_nfft);
else
m_ifc.inv( dst, m_src, m_nfft);
}
}
#endif
/* vim: set filetype=cpp et sw=2 ts=2 ai: */

View File

@ -225,8 +225,28 @@ void test_complex2d()
}
*/
template <typename T,int nrows,int ncols>
void test_return_by_value()
{
Matrix<complex<T>,nrows,ncols> in;
Matrix<complex<T>,nrows,ncols> in1;
in.Random();
Matrix<complex<T>,nrows,ncols> out1;
Matrix<complex<T>,nrows,ncols> out2;
FFT<T> fft;
fft.fwd(out1,in);
out2 = fft.fwd(in);
VERIFY( (out1-out2).norm() < test_precision<T>() );
in1 = fft.inv(out1);
VERIFY( (in1-in).norm() < test_precision<T>() );
}
void test_FFTW()
{
test_return_by_value<float,1,32>();
test_return_by_value<double,1,32>();
//test_return_by_value<long double,1,32>();
//CALL_SUBTEST( ( test_complex2d<float,4,8> () ) ); CALL_SUBTEST( ( test_complex2d<double,4,8> () ) );
//CALL_SUBTEST( ( test_complex2d<long double,4,8> () ) );
CALL_SUBTEST( test_complex<float>(32) ); CALL_SUBTEST( test_complex<double>(32) ); CALL_SUBTEST( test_complex<long double>(32) );