moved scaling to Eigen::FFT

This commit is contained in:
Mark Borgerding 2009-10-30 19:50:11 -04:00
parent 0fa68b9e50
commit a26b729cc9
3 changed files with 61 additions and 33 deletions

View File

@ -28,6 +28,7 @@
#include <complex>
#include <vector>
#include <map>
#include <Eigen/Core>
#ifdef EIGEN_FFTW_DEFAULT
// FFTW: faster, GPL -- incompatible with Eigen in LGPL form, bigger code size
@ -65,10 +66,31 @@ class FFT
typedef typename impl_type::Scalar Scalar;
typedef typename impl_type::Complex Complex;
FFT(const impl_type & impl=impl_type() ) :m_impl(impl) { }
enum Flag {
Default=0, // goof proof
Unscaled=1,
HalfSpectrum=2,
// SomeOtherSpeedOptimization=4
Speedy=32767
};
template <typename _Input>
void fwd( Complex * dst, const _Input * src, int nfft)
FFT( const impl_type & impl=impl_type() , Flag flags=Default ) :m_impl(impl),m_flag(flags) { }
inline
bool HasFlag(Flag f) const { return (m_flag & (int)f) == f;}
inline
void SetFlag(Flag f) { m_flag |= (int)f;}
inline
void ClearFlag(Flag f) { m_flag &= (~(int)f);}
void fwd( Complex * dst, const Scalar * src, int nfft)
{
m_impl.fwd(dst,src,nfft);
}
void fwd( Complex * dst, const Complex * src, int nfft)
{
m_impl.fwd(dst,src,nfft);
}
@ -76,8 +98,11 @@ class FFT
template <typename _Input>
void fwd( std::vector<Complex> & dst, const std::vector<_Input> & src)
{
dst.resize( src.size() );
fwd( &dst[0],&src[0],src.size() );
if ( NumTraits<_Input>::IsComplex == 0 && HasFlag(HalfSpectrum) )
dst.resize( (src.size()>>1)+1);
else
dst.resize(src.size());
fwd(&dst[0],&src[0],src.size());
}
template<typename InputDerived, typename ComplexDerived>
@ -94,17 +119,18 @@ class FFT
fwd( &dst[0],&src[0],src.size() );
}
template <typename _Output>
void inv( _Output * dst, const Complex * src, int nfft)
void inv( Complex * dst, const Complex * src, int nfft)
{
m_impl.inv( dst,src,nfft );
if ( HasFlag( Unscaled ) == false)
scale(dst,1./nfft,nfft);
}
template <typename _Output>
void inv( std::vector<_Output> & dst, const std::vector<Complex> & src)
void inv( Scalar * dst, const Complex * src, int nfft)
{
dst.resize( src.size() );
inv( &dst[0],&src[0],src.size() );
m_impl.inv( dst,src,nfft );
if ( HasFlag( Unscaled ) == false)
scale(dst,1./nfft,nfft);
}
template<typename OutputDerived, typename ComplexDerived>
@ -117,10 +143,24 @@ class FFT
YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY)
EIGEN_STATIC_ASSERT(int(OutputDerived::Flags)&int(ComplexDerived::Flags)&DirectAccessBit,
THIS_METHOD_IS_ONLY_FOR_EXPRESSIONS_WITH_DIRECT_MEMORY_ACCESS_SUCH_AS_MAP_OR_PLAIN_MATRICES)
dst.derived().resize( src.size() );
int nfft = src.size();
int nout = HasFlag(HalfSpectrum) ? ((nfft>>1)+1) : nfft;
dst.derived().resize( nout );
inv( &dst[0],&src[0],src.size() );
}
template <typename _Output>
void inv( std::vector<_Output> & dst, const std::vector<Complex> & src)
{
if ( NumTraits<_Output>::IsComplex == 0 && HasFlag(HalfSpectrum) )
dst.resize( 2*(src.size()-1) );
else
dst.resize( src.size() );
inv( &dst[0],&src[0],dst.size() );
}
// TODO: multi-dimensional FFTs
// TODO: handle Eigen MatrixBase
@ -128,7 +168,16 @@ class FFT
impl_type & impl() {return m_impl;}
private:
template <typename _It,typename _Val>
void scale(_It x,_Val s,int nx)
{
for (int k=0;k<nx;++k)
*x++ *= s;
}
impl_type m_impl;
int m_flag;
};
}
#endif

View File

@ -187,12 +187,6 @@
void inv(Complex * dst,const Complex *src,int nfft)
{
get_plan(nfft,true,dst,src).inv(ei_fftw_cast(dst), ei_fftw_cast(src),nfft );
//TODO move scaling to Eigen::FFT
// scaling
Scalar s = Scalar(1.)/nfft;
for (int k=0;k<nfft;++k)
dst[k] *= s;
}
// half-complex to scalar
@ -200,11 +194,6 @@
void inv( Scalar * dst,const Complex * src,int nfft)
{
get_plan(nfft,true,dst,src).inv(ei_fftw_cast(dst), ei_fftw_cast(src),nfft );
//TODO move scaling to Eigen::FFT
Scalar s = Scalar(1.)/nfft;
for (int k=0;k<nfft;++k)
dst[k] *= s;
}
protected:

View File

@ -334,7 +334,6 @@
void inv(Complex * dst,const Complex *src,int nfft)
{
get_plan(nfft,true).work(0, dst, src, 1,1);
scale(dst, nfft, Scalar(1)/nfft );
}
// half-complex to scalar
@ -362,7 +361,6 @@
m_tmpBuf[k] = fek + fok;
m_tmpBuf[ncfft-k] = conj(fek - fok);
}
scale(&m_tmpBuf[0], ncfft, Scalar(1)/nfft );
get_plan(ncfft,true).work(0, reinterpret_cast<Complex*>(dst), &m_tmpBuf[0], 1,1);
}
}
@ -403,12 +401,4 @@
}
return &twidref[0];
}
// TODO move scaling up into Eigen::FFT
inline
void scale(Complex *dst,int n,Scalar s)
{
for (int k=0;k<n;++k)
dst[k] *= s;
}
};