postgresql/contrib/tsearch2/ts_stat.c

422 lines
9.7 KiB
C
Raw Normal View History

2003-07-21 18:27:44 +08:00
/*
* stat functions
*/
#include "tsvector.h"
#include "ts_stat.h"
#include "funcapi.h"
#include "catalog/pg_type.h"
#include "executor/spi.h"
#include "common.h"
PG_FUNCTION_INFO_V1(tsstat_in);
Datum		tsstat_in(PG_FUNCTION_ARGS);

/*
 * tsstat_in - input function for the stat type.
 *
 * The textual argument is not read: every call produces an empty stat
 * value consisting only of the header (zero entries, no string data).
 */
Datum
tsstat_in(PG_FUNCTION_ARGS)
{
	tsstat	   *result = palloc(STATHDRSIZE);

	result->size = 0;
	result->len = STATHDRSIZE;

	PG_RETURN_POINTER(result);
}
PG_FUNCTION_INFO_V1(tsstat_out);
Datum		tsstat_out(PG_FUNCTION_ARGS);

/*
 * tsstat_out - output function for the stat type.
 *
 * Converting a stat value to text is not supported: this always raises an
 * error with ERRCODE_FEATURE_NOT_SUPPORTED.  The trailing PG_RETURN_NULL()
 * exists only so the function has a return statement; ereport(ERROR, ...)
 * is not expected to return control here.
 */
Datum
tsstat_out(PG_FUNCTION_ARGS)
{
	ereport(ERROR,
			(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
			 errmsg("tsstat_out not implemented")));

	PG_RETURN_NULL();
}
/*
 * SEI_realloc - grow an array of WordEntry pointers.
 *
 * When the array has no capacity yet (*len == 0) or has not been allocated
 * (in == NULL), a fresh array of 8 slots is palloc'd.  Otherwise the
 * capacity is doubled with repalloc.  *len is updated to the new capacity
 * and the (possibly moved) array is returned.
 */
static WordEntry **
SEI_realloc(WordEntry **in, uint32 *len)
{
	if (in != NULL && *len != 0)
	{
		*len *= 2;
		return repalloc(in, sizeof(WordEntry *) * (*len));
	}

	*len = 8;
	return palloc(sizeof(WordEntry *) * (*len));
}
/*
 * compareStatWord - order a stat entry "a" against a tsvector word "b".
 *
 * Shorter words sort first.  Words of equal length are compared with
 * strncmp over their stored bytes.  This is the ordering that the merge and
 * binary searches in ts_accum and formstat rely on.  Returns <0, 0 or >0
 * as "a" sorts before, equal to, or after "b".
 */
static int
compareStatWord(StatEntry *a, WordEntry *b, tsstat *stat, tsvector *txt)
{
	if (a->len != b->len)
		return (a->len > b->len) ? 1 : -1;

	return strncmp(STATSTRPTR(stat) + a->pos,
				   STRPTR(txt) + b->pos,
				   a->len);
}
/*
 * formstat - return a new stat value holding all entries of "stat" plus one
 * new entry for each of the "len" words entry[0..len-1] taken from "txt".
 *
 * The value is laid out as a header, then an array of StatEntry (->size of
 * them), then a string area holding the word bytes each entry references via
 * pos/len.  The old string area is copied verbatim to the start of the new
 * one, so existing entries keep their pos offsets.  The new words' bytes are
 * appended after it.
 *
 * Each new word gets ndoc = 1 and nentry = POSDATALEN (or 1 if the word has
 * no positions).  The entry array is kept in compareStatWord order, which
 * assumes two things.  NOTE(review): neither is checked here.
 *  - None of the new words already occurs in "stat".
 *  - entry[] is itself in compareStatWord order; ts_accum collects the words
 *    while scanning txt's entries in array order.
 *
 * "stat" is only read, never modified or freed.  The result is a freshly
 * palloc'd value.
 */
static tsstat*
formstat(tsstat *stat, tsvector *txt, WordEntry** entry, uint32 len) {
tsstat *newstat;
uint32 totallen, nentry;
uint32 slen=0;
WordEntry **ptr=entry;
char *curptr;
StatEntry *sptr,*nptr;
/* string bytes needed by the new words */
while(ptr-entry<len) {
slen += (*ptr)->len;
ptr++;
}
nentry=stat->size + len;
slen+=STATSTRSIZE(stat);
totallen=CALCSTATSIZE(nentry,slen);
newstat=palloc(totallen);
newstat->len=totallen;
/*
 * size must be set before STATSTRPTR(newstat) is used.  In this file
 * STATSTRPTR(x) marks the end of x's StatEntry array (see the binary-search
 * bounds below), so it depends on x->size.
 */
newstat->size=nentry;
/* old strings first, so existing entries' pos offsets stay valid */
memcpy(STATSTRPTR(newstat), STATSTRPTR(stat), STATSTRSIZE(stat));
/* the new words' bytes are appended from here */
curptr=STATSTRPTR(newstat) + STATSTRSIZE(stat);
ptr=entry;
sptr=STATPTR(stat);
nptr=STATPTR(newstat);
if ( len == 1 ) {
/*
 * Single new word: binary-search its insertion point in the old array.
 * Then copy the entries before it, write the new entry, and copy the
 * entries after it.
 */
StatEntry *StopLow = STATPTR(stat);
StatEntry *StopHigh = (StatEntry*)STATSTRPTR(stat);
while (StopLow < StopHigh) {
sptr=StopLow + (StopHigh - StopLow) / 2;
if ( compareStatWord(sptr,*ptr,stat,txt) < 0 )
StopLow = sptr + 1;
else
StopHigh = sptr;
}
/* StopLow is the first old entry that sorts after the new word */
nptr =STATPTR(newstat) + (StopLow-STATPTR(stat));
memcpy( STATPTR(newstat), STATPTR(stat), sizeof(StatEntry) * (StopLow-STATPTR(stat)) );
nptr->nentry=POSDATALEN(txt,*ptr);
if ( nptr->nentry==0 )
nptr->nentry=1;
nptr->ndoc=1;
nptr->len=(*ptr)->len;
/* curptr is not advanced afterwards: this is the only new word */
memcpy(curptr, STRPTR(txt) + (*ptr)->pos, nptr->len);
nptr->pos = curptr - STATSTRPTR(newstat);
/* remaining old entries go right after the inserted one */
memcpy( nptr+1, StopLow, sizeof(StatEntry) * ( ((StatEntry*)STATSTRPTR(stat))-StopLow ) );
} else {
/*
 * Several new words: merge the two ordered lists.  An old entry is copied
 * while it sorts before the current new word.  Otherwise the new word is
 * emitted; new words are assumed never to equal an old one.
 */
while( sptr-STATPTR(stat) < stat->size && ptr-entry<len) {
if ( compareStatWord(sptr,*ptr,stat,txt) < 0 ) {
memcpy(nptr, sptr, sizeof(StatEntry));
sptr++;
} else {
nptr->nentry=POSDATALEN(txt,*ptr);
if ( nptr->nentry==0 )
nptr->nentry=1;
nptr->ndoc=1;
nptr->len=(*ptr)->len;
memcpy(curptr, STRPTR(txt) + (*ptr)->pos, nptr->len);
nptr->pos = curptr - STATSTRPTR(newstat);
curptr += nptr->len;
ptr++;
}
nptr++;
}
/*
 * At most one source has leftovers here.
 *  - If old entries remain, the new words are exhausted: this memcpy copies
 *    the old entries and the loop below does nothing.
 *  - If new words remain, this memcpy copies zero entries.
 * That is why nptr need not be advanced past this copy.
 */
memcpy( nptr, sptr, sizeof(StatEntry)*( stat->size - (sptr-STATPTR(stat)) ) );
while(ptr-entry<len) {
nptr->nentry=POSDATALEN(txt,*ptr);
if ( nptr->nentry==0 )
nptr->nentry=1;
nptr->ndoc=1;
nptr->len=(*ptr)->len;
memcpy(curptr, STRPTR(txt) + (*ptr)->pos, nptr->len);
nptr->pos = curptr - STATSTRPTR(newstat);
curptr += nptr->len;
ptr++; nptr++;
}
}
return newstat;
}
PG_FUNCTION_INFO_V1(ts_accum);
Datum		ts_accum(PG_FUNCTION_ARGS);

/*
 * ts_accum - fold one tsvector into a stat value.
 *
 * arg0: current stat value.  NULL / SQL NULL starts from an empty one.
 * arg1: tsvector to add.  SQL NULL or an empty tsvector leaves the stat
 *       unchanged.
 *
 * Words already present in the stat have ndoc incremented by one and nentry
 * increased by their number of positions (at least 1).  These counters are
 * updated in place in the arg0 value.  Words not yet present are collected
 * and added by formstat(), which returns a new value.  If there are no new
 * words, the (possibly updated) input value is returned.
 *
 * Existing words are found in one of two ways:
 *  - a linear merge of the two ordered lists, when the stat has fewer than
 *    100 entries per tsvector word;
 *  - otherwise, a binary search of the stat for each word.
 */
Datum
ts_accum(PG_FUNCTION_ARGS)
{
	tsstat	   *newstat,
			   *stat = (tsstat *) PG_GETARG_POINTER(0);
	tsvector   *txt;
	WordEntry **newentry = NULL;
	uint32		len = 0,
				cur = 0;
	StatEntry  *sptr;
	WordEntry  *wptr;

	if (stat == NULL || PG_ARGISNULL(0))
	{
		/* first call: start from an empty, header-only stat */
		stat = palloc(STATHDRSIZE);
		stat->len = STATHDRSIZE;
		stat->size = 0;
	}

	/*
	 * Test for SQL NULL before detoasting.  PG_DETOAST_DATUM reads through
	 * the datum's pointer, so detoasting a NULL argument would dereference a
	 * null pointer.
	 */
	if (PG_ARGISNULL(1))
		PG_RETURN_POINTER(stat);

	txt = (tsvector *) PG_DETOAST_DATUM(PG_GETARG_DATUM(1));

	/* simple check of correctness: nothing to add from an empty tsvector */
	if (txt == NULL || txt->size == 0)
	{
		PG_FREE_IF_COPY(txt, 1);
		PG_RETURN_POINTER(stat);
	}

	sptr = STATPTR(stat);
	wptr = ARRPTR(txt);

	if (stat->size < 100 * txt->size)
	{
		/*
		 * merge: walk both ordered lists in step.  Matching words bump the
		 * counters.  A word that sorts before the current stat entry is new.
		 */
		while (sptr - STATPTR(stat) < stat->size && wptr - ARRPTR(txt) < txt->size)
		{
			int			cmp = compareStatWord(sptr, wptr, stat, txt);

			if (cmp < 0)
				sptr++;
			else if (cmp == 0)
			{
				int			n = POSDATALEN(txt, wptr);

				if (n == 0)
					n = 1;
				sptr->ndoc++;
				sptr->nentry += n;
				sptr++;
				wptr++;
			}
			else
			{
				if (cur == len)
					newentry = SEI_realloc(newentry, &len);
				newentry[cur] = wptr;
				wptr++;
				cur++;
			}
		}
		/* stat exhausted: every remaining word is new */
		while (wptr - ARRPTR(txt) < txt->size)
		{
			if (cur == len)
				newentry = SEI_realloc(newentry, &len);
			newentry[cur] = wptr;
			wptr++;
			cur++;
		}
	}
	else
	{
		/* search: binary-search each word in the much larger stat */
		while (wptr - ARRPTR(txt) < txt->size)
		{
			StatEntry  *StopLow = STATPTR(stat);
			StatEntry  *StopHigh = (StatEntry *) STATSTRPTR(stat);
			int			cmp;

			while (StopLow < StopHigh)
			{
				sptr = StopLow + (StopHigh - StopLow) / 2;
				cmp = compareStatWord(sptr, wptr, stat, txt);
				if (cmp == 0)
				{
					int			n = POSDATALEN(txt, wptr);

					if (n == 0)
						n = 1;
					sptr->ndoc++;
					sptr->nentry += n;
					break;
				}
				else if (cmp < 0)
					StopLow = sptr + 1;
				else
					StopHigh = sptr;
			}
			/* the search ends with StopLow >= StopHigh only when not found */
			if (StopLow >= StopHigh)
			{
				if (cur == len)
					newentry = SEI_realloc(newentry, &len);
				newentry[cur] = wptr;
				cur++;
			}
			wptr++;
		}
	}

	if (cur == 0)
	{
		/* no new words: counters were updated in place */
		PG_FREE_IF_COPY(txt, 1);
		PG_RETURN_POINTER(stat);
	}

	newstat = formstat(stat, txt, newentry, cur);
	pfree(newentry);
	PG_FREE_IF_COPY(txt, 1);

	/*
	 * The previous value is not freed here.  Callers that own it (e.g.
	 * ts_stat_sql) pfree it themselves when a different value comes back.
	 */
	PG_RETURN_POINTER(newstat);
}
typedef struct {
uint32 cur;
tsvector *stat;
} StatStorage;
/*
 * ts_setup_firstcall - first-call setup shared by the stat set-returning
 * functions.
 *
 * Inside funcctx->multi_call_memory_ctx, this:
 *  - copies "stat" (stat->len bytes), so the rows can be produced across
 *    calls even after the caller releases its own memory (ts_stat calls
 *    SPI_finish right after this);
 *  - stores the copy with a zero cursor in funcctx->user_fctx;
 *  - builds the result slot and attribute metadata from the "statinfo"
 *    composite type.
 *
 * The caller's memory context is restored before returning.
 */
static void
ts_setup_firstcall(FuncCallContext *funcctx, tsstat *stat)
{
	MemoryContext saved_ctx = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);
	StatStorage *storage = palloc(sizeof(StatStorage));
	TupleDesc	desc;

	storage->cur = 0;
	storage->stat = palloc(stat->len);
	memcpy(storage->stat, stat, stat->len);
	funcctx->user_fctx = (void *) storage;

	desc = RelationNameGetTupleDesc("statinfo");
	funcctx->slot = TupleDescGetSlot(desc);
	funcctx->attinmeta = TupleDescGetAttInMetadata(desc);

	MemoryContextSwitchTo(saved_ctx);
}
/*
 * ts_process_call - produce the next result row of a stat SRF.
 *
 * The row is (word, ndoc, nentry) for the StatEntry at index st->cur.  All
 * three columns are passed as C strings to BuildTupleFromCStrings.
 *
 * Returns the tuple Datum.  Once every entry has been emitted it returns
 * (Datum) 0 instead, after freeing the private stat copy and the
 * StatStorage.
 */
static Datum
ts_process_call(FuncCallContext *funcctx)
{
	StatStorage *st = (StatStorage *) funcctx->user_fctx;
	StatEntry  *entry;
	HeapTuple	tuple;
	Datum		result;
	char	   *values[3];
	char		ndoc[16];
	char		nentry[16];

	if (st->cur >= st->stat->size)
	{
		/* all entries returned: release the per-call state */
		pfree(st->stat);
		pfree(st);
		return (Datum) 0;
	}

	entry = STATPTR(st->stat) + st->cur;

	/* column 0: NUL-terminated copy of the word's stored bytes */
	values[0] = palloc(entry->len + 1);
	memcpy(values[0], STATSTRPTR(st->stat) + entry->pos, entry->len);
	values[0][entry->len] = '\0';

	/* columns 1 and 2: the ndoc and nentry counters */
	sprintf(ndoc, "%d", entry->ndoc);
	values[1] = ndoc;
	sprintf(nentry, "%d", entry->nentry);
	values[2] = nentry;

	tuple = BuildTupleFromCStrings(funcctx->attinmeta, values);
	result = TupleGetDatum(funcctx->slot, tuple);

	pfree(values[0]);
	st->cur++;

	return result;
}
PG_FUNCTION_INFO_V1(ts_accum_finish);
Datum		ts_accum_finish(PG_FUNCTION_ARGS);

/*
 * ts_accum_finish - set-returning function that emits one (word, ndoc,
 * nentry) row per entry of the stat value passed as arg0.
 *
 * On the first call the stat is copied into the SRF's multi-call context.
 * Each call then returns the next row until ts_process_call reports that
 * the entries are exhausted.
 */
Datum
ts_accum_finish(PG_FUNCTION_ARGS)
{
	FuncCallContext *fctx;
	Datum		row;

	if (SRF_IS_FIRSTCALL())
	{
		fctx = SRF_FIRSTCALL_INIT();
		ts_setup_firstcall(fctx, (tsstat *) PG_GETARG_POINTER(0));
	}

	fctx = SRF_PERCALL_SETUP();
	row = ts_process_call(fctx);

	if (row == (Datum) 0)
	{
		SRF_RETURN_DONE(fctx);
	}
	SRF_RETURN_NEXT(fctx, row);
}
static Oid tiOid=InvalidOid;
static void
get_ti_Oid(void) {
int ret;
bool isnull;
if ( (ret = SPI_exec("select oid from pg_type where typname='tsvector'",1)) < 0 )
/* internal error */
2003-07-21 18:27:44 +08:00
elog(ERROR, "SPI_exec to get tsvector oid returns %d", ret);
if ( SPI_processed<0 )
/* internal error */
2003-07-21 18:27:44 +08:00
elog(ERROR, "There is no tsvector type");
tiOid = DatumGetObjectId( SPI_getbinval(SPI_tuptable->vals[0], SPI_tuptable->tupdesc, 1, &isnull) );
if ( tiOid==InvalidOid )
/* internal error */
2003-07-21 18:27:44 +08:00
elog(ERROR, "tsvector type has InvalidOid");
}
/*
 * ts_stat_sql - run a query and accumulate a stat value over its result.
 *
 * "txt" is the query text.  The query must produce exactly one column, of
 * type tsvector (checked against tiOid).  Rows are fetched through a cursor
 * 100 at a time.  Each non-NULL tsvector is folded in with ts_accum; NULL
 * values are skipped.
 *
 * Must be called while connected to SPI.  The result is palloc'd while SPI
 * is connected; ts_stat copies it (via ts_setup_firstcall) before calling
 * SPI_finish.
 */
static tsstat *
ts_stat_sql(text *txt)
{
	char	   *query = text2char(txt);
	int			i;
	tsstat	   *newstat,
			   *stat;
	bool		isnull;
	Portal		portal;
	void	   *plan;

	if (tiOid == InvalidOid)
		get_ti_Oid();

	if ((plan = SPI_prepare(query, 0, NULL)) == NULL)
		/* internal error */
		elog(ERROR, "SPI_prepare('%s') returns NULL", query);

	if ((portal = SPI_cursor_open(NULL, plan, NULL, NULL)) == NULL)
		/* internal error */
		elog(ERROR, "SPI_cursor_open('%s') returns NULL", query);

	SPI_cursor_fetch(portal, true, 100);

	if (SPI_tuptable->tupdesc->natts != 1)
		/* internal error */
		elog(ERROR, "number of fields doesn't equal to 1");

	if (SPI_gettypeid(SPI_tuptable->tupdesc, 1) != tiOid)
		/* internal error */
		elog(ERROR, "column isn't of tsvector type");

	/* start from an empty, header-only stat */
	stat = palloc(STATHDRSIZE);
	stat->len = STATHDRSIZE;
	stat->size = 0;

	while (SPI_processed > 0)
	{
		for (i = 0; i < SPI_processed; i++)
		{
			Datum		data = SPI_getbinval(SPI_tuptable->vals[i], SPI_tuptable->tupdesc, 1, &isnull);

			if (!isnull)
			{
				newstat = (tsstat *) DatumGetPointer(DirectFunctionCall2(ts_accum,
													 PointerGetDatum(stat),
																		 data));
				/*
				 * ts_accum returns a fresh value only when new words were
				 * added; free the one it replaced.
				 */
				if (stat != newstat && stat)
					pfree(stat);
				stat = newstat;
			}
		}

		SPI_freetuptable(SPI_tuptable);
		SPI_cursor_fetch(portal, true, 100);
	}

	SPI_freetuptable(SPI_tuptable);
	SPI_cursor_close(portal);
	SPI_freeplan(plan);
	pfree(query);

	return stat;
}
PG_FUNCTION_INFO_V1(ts_stat);
Datum ts_stat(PG_FUNCTION_ARGS);
Datum
ts_stat(PG_FUNCTION_ARGS) {
FuncCallContext *funcctx;
Datum result;
if (SRF_IS_FIRSTCALL()) {
tsstat *stat;
text *txt=PG_GETARG_TEXT_P(0);
funcctx = SRF_FIRSTCALL_INIT();
SPI_connect();
stat = ts_stat_sql(txt);
PG_FREE_IF_COPY(txt,0);
ts_setup_firstcall(funcctx, stat );
SPI_finish();
}
funcctx = SRF_PERCALL_SETUP();
if ( (result=ts_process_call(funcctx)) != (Datum)0 )
SRF_RETURN_NEXT(funcctx, result);
SRF_RETURN_DONE(funcctx);
}