Implement SDSDOT with DSDOT and avoid allocating buffers in DSDOT.

This commit is contained in:
Chen-Pang He 2012-09-08 02:06:45 +08:00
parent b0b9b4d6b2
commit 1b61aadcbe
2 changed files with 10 additions and 32 deletions

View File

@ -19,25 +19,15 @@
#include "level2_real_impl.h"
#include "level3_impl.h"
// currently used by DSDOT only
double* cast_vector_to_double(float* x, int n, int incx)
double BLASFUNC(dsdot)(int* n, float* x, int* incx, float* y, int* incy)
{
double* ret = new double[n];
if(incx<0) vector(ret,n) = vector(x,n,-incx).reverse().cast<double>();
else vector(ret,n) = vector(x,n, incx).cast<double>();
return ret;
}
double BLASFUNC(dsdot)(int* n, float* px, int* incx, float* py, int* incy)
{
if(*n <= 0) return 0;
double* x = cast_vector_to_double(px, *n, *incx);
double* y = cast_vector_to_double(py, *n, *incy);
double res = vector(x,*n).cwiseProduct(vector(y,*n)).sum();
delete[] x;
delete[] y;
return res;
if(*n<=0) return 0;
if(*incx==1 && *incy==1) return (vector(x,*n).cast<double>().cwiseProduct(vector(y,*n).cast<double>())).sum();
else if(*incx>0 && *incy>0) return (vector(x,*n,*incx).cast<double>().cwiseProduct(vector(y,*n,*incy).cast<double>())).sum();
else if(*incx<0 && *incy>0) return (vector(x,*n,-*incx).reverse().cast<double>().cwiseProduct(vector(y,*n,*incy).cast<double>())).sum();
else if(*incx>0 && *incy<0) return (vector(x,*n,*incx).cast<double>().cwiseProduct(vector(y,*n,-*incy).reverse().cast<double>())).sum();
else if(*incx<0 && *incy<0) return (vector(x,*n,-*incx).reverse().cast<double>().cwiseProduct(vector(y,*n,-*incy).reverse().cast<double>())).sum();
else return 0;
}

View File

@ -2,7 +2,6 @@
// for linear algebra.
//
// Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr>
// Copyright (C) 2012 Chen-Pang He <jdh8@ms63.hinet.net>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
@ -20,15 +19,4 @@
#include "level3_impl.h"
float BLASFUNC(sdsdot)(int* n, float* alpha, float* x, int* incx, float* y, int* incy)
{
float res = *alpha;
if(*n>0) {
if(*incx==1 && *incy==1) res += (vector(x,*n).cwiseProduct(vector(y,*n))).sum();
else if(*incx>0 && *incy>0) res += (vector(x,*n,*incx).cwiseProduct(vector(y,*n,*incy))).sum();
else if(*incx<0 && *incy>0) res += (vector(x,*n,-*incx).reverse().cwiseProduct(vector(y,*n,*incy))).sum();
else if(*incx>0 && *incy<0) res += (vector(x,*n,*incx).cwiseProduct(vector(y,*n,-*incy).reverse())).sum();
else if(*incx<0 && *incy<0) res += (vector(x,*n,-*incx).reverse().cwiseProduct(vector(y,*n,-*incy).reverse())).sum();
}
return res;
}
{ return *alpha + BLASFUNC(dsdot)(n, x, incx, y, incy); }