Add gECC submodule and rename the multiply_compare table function to cudasp_scan
This commit is contained in:
parent
494e0c6736
commit
9b15950c4d
1
.gitignore
vendored
1
.gitignore
vendored
@ -6,3 +6,4 @@ duckdb_unittest_tempdir/
|
||||
testext
|
||||
test/python/__pycache__/
|
||||
.Rhistory
|
||||
duckdb-faiss-ext-0.12.0
|
||||
|
||||
5
.gitmodules
vendored
5
.gitmodules
vendored
@ -5,4 +5,7 @@
|
||||
[submodule "extension-ci-tools"]
|
||||
path = extension-ci-tools
|
||||
url = https://github.com/duckdb/extension-ci-tools
|
||||
branch = main
|
||||
branch = main
|
||||
[submodule "gECC"]
|
||||
path = gECC
|
||||
url = git@github.com:sparrowwallet/gECC.git
|
||||
|
||||
1
gECC
Submodule
1
gECC
Submodule
@ -0,0 +1 @@
|
||||
Subproject commit 0194307d631aa6edde9f469d7b6d155a5b5be00b
|
||||
@ -12,14 +12,14 @@
|
||||
|
||||
namespace duckdb {
|
||||
|
||||
struct MultiplyCompareBindData : public TableFunctionData {
|
||||
MultiplyCompareBindData() {
|
||||
struct CudaspScanBindData : public TableFunctionData {
|
||||
CudaspScanBindData() {
|
||||
}
|
||||
static constexpr idx_t BATCH_SIZE = 10000; // Accumulate 10K rows before processing
|
||||
};
|
||||
|
||||
struct MultiplyCompareLocalState : public LocalTableFunctionState {
|
||||
MultiplyCompareLocalState() : finalized(false), output_position(0) {
|
||||
struct CudaspScanLocalState : public LocalTableFunctionState {
|
||||
CudaspScanLocalState() : finalized(false), output_position(0) {
|
||||
}
|
||||
bool finalized;
|
||||
|
||||
@ -35,8 +35,8 @@ struct MultiplyCompareLocalState : public LocalTableFunctionState {
|
||||
idx_t output_position;
|
||||
};
|
||||
|
||||
struct MultiplyCompareState : public GlobalTableFunctionState {
|
||||
MultiplyCompareState() : currently_adding(0) {
|
||||
struct CudaspScanState : public GlobalTableFunctionState {
|
||||
CudaspScanState() : currently_adding(0) {
|
||||
finalize_lock = make_uniq<std::mutex>();
|
||||
}
|
||||
|
||||
@ -45,7 +45,7 @@ struct MultiplyCompareState : public GlobalTableFunctionState {
|
||||
unique_ptr<std::mutex> finalize_lock;
|
||||
};
|
||||
|
||||
static void AccumulateInput(MultiplyCompareLocalState &local_state, DataChunk &input) {
|
||||
static void AccumulateInput(CudaspScanLocalState &local_state, DataChunk &input) {
|
||||
idx_t count = input.size();
|
||||
auto &value_column = input.data[0];
|
||||
auto &list_column = input.data[1];
|
||||
@ -93,7 +93,7 @@ static void AccumulateInput(MultiplyCompareLocalState &local_state, DataChunk &i
|
||||
}
|
||||
}
|
||||
|
||||
static void ProcessBatch(MultiplyCompareLocalState &local_state) {
|
||||
static void ProcessBatch(CudaspScanLocalState &local_state) {
|
||||
// This is where GPU processing would happen
|
||||
// For now, we process on CPU
|
||||
|
||||
@ -135,38 +135,38 @@ static void ProcessBatch(MultiplyCompareLocalState &local_state) {
|
||||
local_state.accumulated_list_lengths.clear();
|
||||
}
|
||||
|
||||
static bool HasOutput(const MultiplyCompareLocalState &local_state) {
|
||||
static bool HasOutput(const CudaspScanLocalState &local_state) {
|
||||
return local_state.output_position < local_state.output_multiplied.size();
|
||||
}
|
||||
|
||||
static bool ShouldProcessBatch(const MultiplyCompareLocalState &local_state) {
|
||||
return local_state.accumulated_values.size() >= MultiplyCompareBindData::BATCH_SIZE;
|
||||
static bool ShouldProcessBatch(const CudaspScanLocalState &local_state) {
|
||||
return local_state.accumulated_values.size() >= CudaspScanBindData::BATCH_SIZE;
|
||||
}
|
||||
|
||||
static unique_ptr<FunctionData> MultiplyCompareBind(ClientContext &context, TableFunctionBindInput &input,
|
||||
static unique_ptr<FunctionData> CudaspScanBind(ClientContext &context, TableFunctionBindInput &input,
|
||||
vector<LogicalType> &return_types, vector<string> &names) {
|
||||
return_types.push_back(LogicalType::DOUBLE);
|
||||
return_types.push_back(LogicalType::BOOLEAN);
|
||||
names.push_back("multiplied_value");
|
||||
names.push_back("is_in_list");
|
||||
|
||||
return make_uniq<MultiplyCompareBindData>();
|
||||
return make_uniq<CudaspScanBindData>();
|
||||
}
|
||||
|
||||
static unique_ptr<GlobalTableFunctionState> MultiplyCompareInit(ClientContext &context, TableFunctionInitInput &input) {
|
||||
return make_uniq<MultiplyCompareState>();
|
||||
static unique_ptr<GlobalTableFunctionState> CudaspScanInit(ClientContext &context, TableFunctionInitInput &input) {
|
||||
return make_uniq<CudaspScanState>();
|
||||
}
|
||||
|
||||
static unique_ptr<LocalTableFunctionState> MultiplyCompareLocalInit(ExecutionContext &context, TableFunctionInitInput &input,
|
||||
static unique_ptr<LocalTableFunctionState> CudaspScanLocalInit(ExecutionContext &context, TableFunctionInitInput &input,
|
||||
GlobalTableFunctionState *global_state) {
|
||||
auto &state = global_state->Cast<MultiplyCompareState>();
|
||||
auto &state = global_state->Cast<CudaspScanState>();
|
||||
state.currently_adding++;
|
||||
return make_uniq<MultiplyCompareLocalState>();
|
||||
return make_uniq<CudaspScanLocalState>();
|
||||
}
|
||||
|
||||
static OperatorResultType MultiplyCompareFunction(ExecutionContext &context, TableFunctionInput &data_p,
|
||||
static OperatorResultType CudaspScanFunction(ExecutionContext &context, TableFunctionInput &data_p,
|
||||
DataChunk &input, DataChunk &output) {
|
||||
auto &local_state = data_p.local_state->Cast<MultiplyCompareLocalState>();
|
||||
auto &local_state = data_p.local_state->Cast<CudaspScanLocalState>();
|
||||
|
||||
// If we have pending output from a previous batch, return it first
|
||||
if (HasOutput(local_state)) {
|
||||
@ -239,10 +239,10 @@ static OperatorResultType MultiplyCompareFunction(ExecutionContext &context, Tab
|
||||
return OperatorResultType::NEED_MORE_INPUT;
|
||||
}
|
||||
|
||||
static OperatorFinalizeResultType MultiplyCompareFinalFunction(ExecutionContext &context, TableFunctionInput &data_p,
|
||||
static OperatorFinalizeResultType CudaspScanFinalFunction(ExecutionContext &context, TableFunctionInput &data_p,
|
||||
DataChunk &output) {
|
||||
auto &state = data_p.global_state->Cast<MultiplyCompareState>();
|
||||
auto &local_state = data_p.local_state->Cast<MultiplyCompareLocalState>();
|
||||
auto &state = data_p.global_state->Cast<CudaspScanState>();
|
||||
auto &local_state = data_p.local_state->Cast<CudaspScanLocalState>();
|
||||
|
||||
// If we still have pending output from previous batch, return it
|
||||
if (HasOutput(local_state)) {
|
||||
@ -311,14 +311,14 @@ static OperatorFinalizeResultType MultiplyCompareFinalFunction(ExecutionContext
|
||||
}
|
||||
|
||||
static void LoadInternal(ExtensionLoader &loader) {
|
||||
TableFunctionSet multiply_compare("multiply_compare");
|
||||
TableFunctionSet cudasp_scan("cudasp_scan");
|
||||
|
||||
TableFunction func({LogicalType::TABLE}, nullptr, MultiplyCompareBind, MultiplyCompareInit, MultiplyCompareLocalInit);
|
||||
func.in_out_function = MultiplyCompareFunction;
|
||||
func.in_out_function_final = MultiplyCompareFinalFunction;
|
||||
TableFunction func({LogicalType::TABLE}, nullptr, CudaspScanBind, CudaspScanInit, CudaspScanLocalInit);
|
||||
func.in_out_function = CudaspScanFunction;
|
||||
func.in_out_function_final = CudaspScanFinalFunction;
|
||||
|
||||
multiply_compare.AddFunction(func);
|
||||
loader.RegisterFunction(multiply_compare);
|
||||
cudasp_scan.AddFunction(func);
|
||||
loader.RegisterFunction(cudasp_scan);
|
||||
}
|
||||
|
||||
void CudaspExtension::Load(ExtensionLoader &loader) {
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
# Require statement will ensure this test is run with this extension loaded
|
||||
require cudasp
|
||||
|
||||
# Test the multiply_compare table function with small dataset
|
||||
# Test the cudasp_scan table function with a small dataset
|
||||
statement ok
|
||||
CREATE TABLE test_data(value DOUBLE, search_list DOUBLE[]);
|
||||
|
||||
@ -13,7 +13,7 @@ statement ok
|
||||
INSERT INTO test_data VALUES (5.0, [10.0, 20.0, 15.0]), (3.0, [5.0, 7.0, 9.0]);
|
||||
|
||||
query I
|
||||
SELECT multiplied_value FROM multiply_compare((SELECT value, search_list FROM test_data));
|
||||
SELECT multiplied_value FROM cudasp_scan((SELECT value, search_list FROM test_data));
|
||||
----
|
||||
10.0
|
||||
|
||||
@ -24,12 +24,12 @@ SELECT CAST(column0 AS DOUBLE) AS value, CAST(column1 AS DOUBLE[]) AS search_lis
|
||||
FROM repeat_row(5.0, [10.0, 20.0, 15.0], num_rows := 20000);
|
||||
|
||||
query I
|
||||
SELECT COUNT(*) FROM multiply_compare((SELECT value, search_list FROM large_test));
|
||||
SELECT COUNT(*) FROM cudasp_scan((SELECT value, search_list FROM large_test));
|
||||
----
|
||||
20000
|
||||
|
||||
query I
|
||||
SELECT DISTINCT multiplied_value FROM multiply_compare((SELECT value, search_list FROM large_test));
|
||||
SELECT DISTINCT multiplied_value FROM cudasp_scan((SELECT value, search_list FROM large_test));
|
||||
----
|
||||
10.0
|
||||
|
||||
@ -40,11 +40,11 @@ SELECT CAST(column0 AS DOUBLE) AS value, CAST(column1 AS DOUBLE[]) AS search_lis
|
||||
FROM repeat_row(5.0, [10.0, 20.0, 15.0], num_rows := 200000);
|
||||
|
||||
query I
|
||||
SELECT COUNT(*) FROM multiply_compare((SELECT value, search_list FROM very_large_test));
|
||||
SELECT COUNT(*) FROM cudasp_scan((SELECT value, search_list FROM very_large_test));
|
||||
----
|
||||
200000
|
||||
|
||||
query I
|
||||
SELECT DISTINCT multiplied_value FROM multiply_compare((SELECT value, search_list FROM very_large_test));
|
||||
SELECT DISTINCT multiplied_value FROM cudasp_scan((SELECT value, search_list FROM very_large_test));
|
||||
----
|
||||
10.0
|
||||
|
||||
Loading…
Reference in New Issue
Block a user