needed different proxy return types for fwd,inv to work around static asserts

This commit is contained in:
Mark Borgerding 2010-03-07 22:05:48 -05:00
parent 5b2c8b77df
commit e31929337e
2 changed files with 48 additions and 21 deletions

View File

@ -110,23 +110,27 @@
namespace Eigen {
//
template<typename T_SrcMat,typename T_FftIfc,bool T_ForwardTransform>
struct fft_result_proxy;
template<typename T_SrcMat,typename T_FftIfc> struct fft_fwd_proxy;
template<typename T_SrcMat,typename T_FftIfc> struct fft_inv_proxy;
template<typename T_SrcMat,typename T_FftIfc,bool T_ForwardTransform>
struct ei_traits< fft_result_proxy<T_SrcMat,T_FftIfc,T_ForwardTransform> >
template<typename T_SrcMat,typename T_FftIfc>
struct ei_traits< fft_fwd_proxy<T_SrcMat,T_FftIfc> >
{
typedef typename T_SrcMat::PlainObject ReturnType;
};
template<typename T_SrcMat,typename T_FftIfc>
struct ei_traits< fft_inv_proxy<T_SrcMat,T_FftIfc> >
{
typedef typename T_SrcMat::PlainObject ReturnType;
};
template<typename T_SrcMat,typename T_FftIfc,bool T_ForwardTransform>
struct fft_result_proxy
: public ReturnByValue<fft_result_proxy<T_SrcMat,T_FftIfc,T_ForwardTransform> >
template<typename T_SrcMat,typename T_FftIfc>
struct fft_fwd_proxy
: public ReturnByValue<fft_fwd_proxy<T_SrcMat,T_FftIfc> >
{
fft_result_proxy(const T_SrcMat& src,T_FftIfc & fft,int nfft) : m_src(src),m_ifc(fft), m_nfft(nfft) {}
fft_fwd_proxy(const T_SrcMat& src,T_FftIfc & fft,int nfft) : m_src(src),m_ifc(fft), m_nfft(nfft) {}
template<typename T_DestMat> void evalTo(T_DestMat& dst) const;
@ -138,6 +142,21 @@ struct fft_result_proxy
int m_nfft;
};
template<typename T_SrcMat,typename T_FftIfc>
struct fft_inv_proxy
: public ReturnByValue<fft_inv_proxy<T_SrcMat,T_FftIfc> >
{
fft_inv_proxy(const T_SrcMat& src,T_FftIfc & fft,int nfft) : m_src(src),m_ifc(fft), m_nfft(nfft) {}
template<typename T_DestMat> void evalTo(T_DestMat& dst) const;
int rows() const { return m_src.rows(); }
int cols() const { return m_src.cols(); }
protected:
const T_SrcMat & m_src;
T_FftIfc & m_ifc;
int m_nfft;
};
template <typename T_Scalar,
@ -205,10 +224,12 @@ class FFT
inline
void fwd( MatrixBase<ComplexDerived> & dst, const MatrixBase<InputDerived> & src,int nfft=-1)
{
typedef typename ComplexDerived::Scalar dst_type;
typedef typename InputDerived::Scalar src_type;
EIGEN_STATIC_ASSERT_VECTOR_ONLY(InputDerived)
EIGEN_STATIC_ASSERT_VECTOR_ONLY(ComplexDerived)
EIGEN_STATIC_ASSERT_SAME_VECTOR_SIZE(ComplexDerived,InputDerived) // size at compile-time
EIGEN_STATIC_ASSERT((ei_is_same_type<typename ComplexDerived::Scalar, Complex>::ret),
EIGEN_STATIC_ASSERT((ei_is_same_type<dst_type, Complex>::ret),
YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY)
EIGEN_STATIC_ASSERT(int(InputDerived::Flags)&int(ComplexDerived::Flags)&DirectAccessBit,
THIS_METHOD_IS_ONLY_FOR_EXPRESSIONS_WITH_DIRECT_MEMORY_ACCESS_SUCH_AS_MAP_OR_PLAIN_MATRICES)
@ -216,13 +237,13 @@ class FFT
if (nfft<1)
nfft = src.size();
if ( NumTraits< typename InputDerived::Scalar >::IsComplex == 0 && HasFlag(HalfSpectrum) )
if ( NumTraits< src_type >::IsComplex == 0 && HasFlag(HalfSpectrum) )
dst.derived().resize( (nfft>>1)+1);
else
dst.derived().resize(nfft);
if ( src.stride() != 1 || src.size() < nfft ) {
Matrix<typename InputDerived::Scalar,1,Dynamic> tmp;
Matrix<src_type,1,Dynamic> tmp;
if (src.size()<nfft) {
tmp.setZero(nfft);
tmp.block(0,0,src.size(),1 ) = src;
@ -237,18 +258,18 @@ class FFT
template<typename InputDerived>
inline
fft_result_proxy< MatrixBase<InputDerived>, FFT<T_Scalar,T_Impl> ,true>
fft_fwd_proxy< MatrixBase<InputDerived>, FFT<T_Scalar,T_Impl> >
fwd( const MatrixBase<InputDerived> & src,int nfft=-1)
{
return fft_result_proxy< MatrixBase<InputDerived> ,FFT<T_Scalar,T_Impl>,true>( src, *this,nfft );
return fft_fwd_proxy< MatrixBase<InputDerived> ,FFT<T_Scalar,T_Impl> >( src, *this,nfft );
}
template<typename InputDerived>
inline
fft_result_proxy< MatrixBase<InputDerived>, FFT<T_Scalar,T_Impl> ,false>
fft_inv_proxy< MatrixBase<InputDerived>, FFT<T_Scalar,T_Impl> >
inv( const MatrixBase<InputDerived> & src,int nfft=-1)
{
return fft_result_proxy< MatrixBase<InputDerived> ,FFT<T_Scalar,T_Impl>,false>( src, *this,nfft );
return fft_inv_proxy< MatrixBase<InputDerived> ,FFT<T_Scalar,T_Impl> >( src, *this,nfft );
}
inline
@ -382,13 +403,17 @@ class FFT
int m_flag;
};
template<typename T_SrcMat,typename T_FftIfc,bool T_ForwardTransform>
template<typename T_SrcMat,typename T_FftIfc>
template<typename T_DestMat> inline
void fft_result_proxy<T_SrcMat,T_FftIfc,T_ForwardTransform>::evalTo(T_DestMat& dst) const
void fft_fwd_proxy<T_SrcMat,T_FftIfc>::evalTo(T_DestMat& dst) const
{
if (T_ForwardTransform)
m_ifc.fwd( dst, m_src, m_nfft);
else
}
template<typename T_SrcMat,typename T_FftIfc>
template<typename T_DestMat> inline
void fft_inv_proxy<T_SrcMat,T_FftIfc>::evalTo(T_DestMat& dst) const
{
m_ifc.inv( dst, m_src, m_nfft);
}

View File

@ -235,6 +235,9 @@ void test_return_by_value()
Matrix<complex<T>,nrows,ncols> out1;
Matrix<complex<T>,nrows,ncols> out2;
FFT<T> fft;
fft.SetFlag(fft.HalfSpectrum );
fft.fwd(out1,in);
out2 = fft.fwd(in);
VERIFY( (out1-out2).norm() < test_precision<T>() );
@ -246,7 +249,6 @@ void test_FFTW()
{
test_return_by_value<float,1,32>();
test_return_by_value<double,1,32>();
//test_return_by_value<long double,1,32>();
//CALL_SUBTEST( ( test_complex2d<float,4,8> () ) ); CALL_SUBTEST( ( test_complex2d<double,4,8> () ) );
//CALL_SUBTEST( ( test_complex2d<long double,4,8> () ) );
CALL_SUBTEST( test_complex<float>(32) ); CALL_SUBTEST( test_complex<double>(32) ); CALL_SUBTEST( test_complex<long double>(32) );