add gECC submodule and rename function

This commit is contained in:
Craig Raw 2025-10-24 12:47:13 +02:00
parent 494e0c6736
commit 9b15950c4d
5 changed files with 41 additions and 36 deletions

1
.gitignore vendored
View File

@ -6,3 +6,4 @@ duckdb_unittest_tempdir/
testext
test/python/__pycache__/
.Rhistory
duckdb-faiss-ext-0.12.0

5
.gitmodules vendored
View File

@ -5,4 +5,7 @@
[submodule "extension-ci-tools"]
path = extension-ci-tools
url = https://github.com/duckdb/extension-ci-tools
branch = main
branch = main
[submodule "gECC"]
path = gECC
url = git@github.com:sparrowwallet/gECC.git

1
gECC Submodule

@ -0,0 +1 @@
Subproject commit 0194307d631aa6edde9f469d7b6d155a5b5be00b

View File

@ -12,14 +12,14 @@
namespace duckdb {
struct MultiplyCompareBindData : public TableFunctionData {
MultiplyCompareBindData() {
struct CudaspScanBindData : public TableFunctionData {
CudaspScanBindData() {
}
static constexpr idx_t BATCH_SIZE = 10000; // Accumulate 10K rows before processing
};
struct MultiplyCompareLocalState : public LocalTableFunctionState {
MultiplyCompareLocalState() : finalized(false), output_position(0) {
struct CudaspScanLocalState : public LocalTableFunctionState {
CudaspScanLocalState() : finalized(false), output_position(0) {
}
bool finalized;
@ -35,8 +35,8 @@ struct MultiplyCompareLocalState : public LocalTableFunctionState {
idx_t output_position;
};
struct MultiplyCompareState : public GlobalTableFunctionState {
MultiplyCompareState() : currently_adding(0) {
struct CudaspScanState : public GlobalTableFunctionState {
CudaspScanState() : currently_adding(0) {
finalize_lock = make_uniq<std::mutex>();
}
@ -45,7 +45,7 @@ struct MultiplyCompareState : public GlobalTableFunctionState {
unique_ptr<std::mutex> finalize_lock;
};
static void AccumulateInput(MultiplyCompareLocalState &local_state, DataChunk &input) {
static void AccumulateInput(CudaspScanLocalState &local_state, DataChunk &input) {
idx_t count = input.size();
auto &value_column = input.data[0];
auto &list_column = input.data[1];
@ -93,7 +93,7 @@ static void AccumulateInput(MultiplyCompareLocalState &local_state, DataChunk &i
}
}
static void ProcessBatch(MultiplyCompareLocalState &local_state) {
static void ProcessBatch(CudaspScanLocalState &local_state) {
// This is where GPU processing would happen
// For now, we process on CPU
@ -135,38 +135,38 @@ static void ProcessBatch(MultiplyCompareLocalState &local_state) {
local_state.accumulated_list_lengths.clear();
}
static bool HasOutput(const MultiplyCompareLocalState &local_state) {
static bool HasOutput(const CudaspScanLocalState &local_state) {
return local_state.output_position < local_state.output_multiplied.size();
}
static bool ShouldProcessBatch(const MultiplyCompareLocalState &local_state) {
return local_state.accumulated_values.size() >= MultiplyCompareBindData::BATCH_SIZE;
static bool ShouldProcessBatch(const CudaspScanLocalState &local_state) {
return local_state.accumulated_values.size() >= CudaspScanBindData::BATCH_SIZE;
}
static unique_ptr<FunctionData> MultiplyCompareBind(ClientContext &context, TableFunctionBindInput &input,
static unique_ptr<FunctionData> CudaspScanBind(ClientContext &context, TableFunctionBindInput &input,
vector<LogicalType> &return_types, vector<string> &names) {
return_types.push_back(LogicalType::DOUBLE);
return_types.push_back(LogicalType::BOOLEAN);
names.push_back("multiplied_value");
names.push_back("is_in_list");
return make_uniq<MultiplyCompareBindData>();
return make_uniq<CudaspScanBindData>();
}
static unique_ptr<GlobalTableFunctionState> MultiplyCompareInit(ClientContext &context, TableFunctionInitInput &input) {
return make_uniq<MultiplyCompareState>();
static unique_ptr<GlobalTableFunctionState> CudaspScanInit(ClientContext &context, TableFunctionInitInput &input) {
return make_uniq<CudaspScanState>();
}
static unique_ptr<LocalTableFunctionState> MultiplyCompareLocalInit(ExecutionContext &context, TableFunctionInitInput &input,
static unique_ptr<LocalTableFunctionState> CudaspScanLocalInit(ExecutionContext &context, TableFunctionInitInput &input,
GlobalTableFunctionState *global_state) {
auto &state = global_state->Cast<MultiplyCompareState>();
auto &state = global_state->Cast<CudaspScanState>();
state.currently_adding++;
return make_uniq<MultiplyCompareLocalState>();
return make_uniq<CudaspScanLocalState>();
}
static OperatorResultType MultiplyCompareFunction(ExecutionContext &context, TableFunctionInput &data_p,
static OperatorResultType CudaspScanFunction(ExecutionContext &context, TableFunctionInput &data_p,
DataChunk &input, DataChunk &output) {
auto &local_state = data_p.local_state->Cast<MultiplyCompareLocalState>();
auto &local_state = data_p.local_state->Cast<CudaspScanLocalState>();
// If we have pending output from a previous batch, return it first
if (HasOutput(local_state)) {
@ -239,10 +239,10 @@ static OperatorResultType MultiplyCompareFunction(ExecutionContext &context, Tab
return OperatorResultType::NEED_MORE_INPUT;
}
static OperatorFinalizeResultType MultiplyCompareFinalFunction(ExecutionContext &context, TableFunctionInput &data_p,
static OperatorFinalizeResultType CudaspScanFinalFunction(ExecutionContext &context, TableFunctionInput &data_p,
DataChunk &output) {
auto &state = data_p.global_state->Cast<MultiplyCompareState>();
auto &local_state = data_p.local_state->Cast<MultiplyCompareLocalState>();
auto &state = data_p.global_state->Cast<CudaspScanState>();
auto &local_state = data_p.local_state->Cast<CudaspScanLocalState>();
// If we still have pending output from previous batch, return it
if (HasOutput(local_state)) {
@ -311,14 +311,14 @@ static OperatorFinalizeResultType MultiplyCompareFinalFunction(ExecutionContext
}
static void LoadInternal(ExtensionLoader &loader) {
TableFunctionSet multiply_compare("multiply_compare");
TableFunctionSet cudasp_scan("cudasp_scan");
TableFunction func({LogicalType::TABLE}, nullptr, MultiplyCompareBind, MultiplyCompareInit, MultiplyCompareLocalInit);
func.in_out_function = MultiplyCompareFunction;
func.in_out_function_final = MultiplyCompareFinalFunction;
TableFunction func({LogicalType::TABLE}, nullptr, CudaspScanBind, CudaspScanInit, CudaspScanLocalInit);
func.in_out_function = CudaspScanFunction;
func.in_out_function_final = CudaspScanFinalFunction;
multiply_compare.AddFunction(func);
loader.RegisterFunction(multiply_compare);
cudasp_scan.AddFunction(func);
loader.RegisterFunction(cudasp_scan);
}
void CudaspExtension::Load(ExtensionLoader &loader) {

View File

@ -5,7 +5,7 @@
# Require statement will ensure this test is run with this extension loaded
require cudasp
# Test the multiply_compare table function with small dataset
# Test the cudasp_scan table function with small dataset
statement ok
CREATE TABLE test_data(value DOUBLE, search_list DOUBLE[]);
@ -13,7 +13,7 @@ statement ok
INSERT INTO test_data VALUES (5.0, [10.0, 20.0, 15.0]), (3.0, [5.0, 7.0, 9.0]);
query I
SELECT multiplied_value FROM multiply_compare((SELECT value, search_list FROM test_data));
SELECT multiplied_value FROM cudasp_scan((SELECT value, search_list FROM test_data));
----
10.0
@ -24,12 +24,12 @@ SELECT CAST(column0 AS DOUBLE) AS value, CAST(column1 AS DOUBLE[]) AS search_lis
FROM repeat_row(5.0, [10.0, 20.0, 15.0], num_rows := 20000);
query I
SELECT COUNT(*) FROM multiply_compare((SELECT value, search_list FROM large_test));
SELECT COUNT(*) FROM cudasp_scan((SELECT value, search_list FROM large_test));
----
20000
query I
SELECT DISTINCT multiplied_value FROM multiply_compare((SELECT value, search_list FROM large_test));
SELECT DISTINCT multiplied_value FROM cudasp_scan((SELECT value, search_list FROM large_test));
----
10.0
@ -40,11 +40,11 @@ SELECT CAST(column0 AS DOUBLE) AS value, CAST(column1 AS DOUBLE[]) AS search_lis
FROM repeat_row(5.0, [10.0, 20.0, 15.0], num_rows := 200000);
query I
SELECT COUNT(*) FROM multiply_compare((SELECT value, search_list FROM very_large_test));
SELECT COUNT(*) FROM cudasp_scan((SELECT value, search_list FROM very_large_test));
----
200000
query I
SELECT DISTINCT multiplied_value FROM multiply_compare((SELECT value, search_list FROM very_large_test));
SELECT DISTINCT multiplied_value FROM cudasp_scan((SELECT value, search_list FROM very_large_test));
----
10.0