fix synchronization issue
fix synchronization issue 2 fix synchronization issue 3 fix synchronization issue 4 fix synchronization issue 5 fix synchronization issue 6 fix synchronization issue 7 fix synchronization issue 8 fix synchronization issue 9 fix synchronization issue 10 fix synchronization issue 11 fix synchronization issue 12 fix synchronization issue 13 fix synchronization issue 14 fix synchronization issue 15 fix synchronization issue 16 fix synchronization issue 17 fix synchronization issue 18 fix synchronization issue 19 fix synchronization issue 20 fix synchronization issue 21 fix synchronization issue 22 fix synchronization issue 23
This commit is contained in:
parent
56f5ec8872
commit
3b0708bc72
2
gECC
2
gECC
@ -1 +1 @@
|
||||
Subproject commit f3ab474f24d0e375bc2fa41dd480525506ceaa8a
|
||||
Subproject commit 5622411e7f62c90a75d4c69b8e3ff8c8b24f2c39
|
||||
@ -104,9 +104,10 @@ struct CudaspScanBindData : public TableFunctionData {
|
||||
};
|
||||
|
||||
struct CudaspScanLocalState : public LocalTableFunctionState {
|
||||
CudaspScanLocalState() : finalized(false), output_position(0) {
|
||||
CudaspScanLocalState() : finalized(false), is_output_thread(false) {
|
||||
}
|
||||
bool finalized;
|
||||
bool is_output_thread; // True if this thread is responsible for returning output
|
||||
|
||||
// Per-thread accumulated input data
|
||||
vector<std::string> accumulated_txids; // Transaction IDs (BLOB) - owned copies
|
||||
@ -115,28 +116,27 @@ struct CudaspScanLocalState : public LocalTableFunctionState {
|
||||
vector<int64_t> accumulated_outputs; // Flattened output values (BIGINT)
|
||||
vector<idx_t> accumulated_output_offsets; // Offset into accumulated_outputs for each row
|
||||
vector<idx_t> accumulated_output_lengths; // Length of each outputs list
|
||||
|
||||
// Per-thread processed output data (only rows with matches)
|
||||
vector<std::string> output_txids; // Owned copies
|
||||
vector<int32_t> output_heights;
|
||||
vector<std::string> output_tweak_keys; // Owned copies
|
||||
idx_t output_position;
|
||||
};
|
||||
|
||||
struct CudaspScanState : public GlobalTableFunctionState {
|
||||
CudaspScanState() : currently_adding(0) {
|
||||
CudaspScanState() : currently_adding(0), output_position(0), output_thread_claimed(false) {
|
||||
finalize_lock = make_uniq<std::mutex>();
|
||||
}
|
||||
|
||||
// Limit to single thread to prevent duplicate results from parallel execution
|
||||
// GPU parallelism with thousands of CUDA threads provides sufficient performance
|
||||
idx_t MaxThreads() const override {
|
||||
return 1;
|
||||
output_lock = make_uniq<std::mutex>();
|
||||
}
|
||||
|
||||
// Thread synchronization
|
||||
std::atomic_uint64_t currently_adding;
|
||||
unique_ptr<std::mutex> finalize_lock;
|
||||
|
||||
// Global output storage - all threads write here
|
||||
unique_ptr<std::mutex> output_lock;
|
||||
vector<string> output_txids;
|
||||
vector<int32_t> output_heights;
|
||||
vector<string> output_tweak_keys;
|
||||
idx_t output_position;
|
||||
|
||||
// Only one thread returns output to avoid batch index conflicts
|
||||
std::atomic<bool> output_thread_claimed;
|
||||
};
|
||||
|
||||
static void AccumulateInput(CudaspScanLocalState &local_state, DataChunk &input) {
|
||||
@ -211,13 +211,7 @@ static void AccumulateInput(CudaspScanLocalState &local_state, DataChunk &input)
|
||||
}
|
||||
}
|
||||
|
||||
static void ProcessBatch(CudaspScanLocalState &local_state, const CudaspScanBindData &bind_data) {
|
||||
// Clear any previous output
|
||||
local_state.output_txids.clear();
|
||||
local_state.output_heights.clear();
|
||||
local_state.output_tweak_keys.clear();
|
||||
local_state.output_position = 0;
|
||||
|
||||
static void ProcessBatch(CudaspScanLocalState &local_state, const CudaspScanBindData &bind_data, CudaspScanState &global_state) {
|
||||
idx_t batch_size = local_state.accumulated_txids.size();
|
||||
if (batch_size == 0) {
|
||||
return;
|
||||
@ -362,21 +356,19 @@ static void ProcessBatch(CudaspScanLocalState &local_state, const CudaspScanBind
|
||||
);
|
||||
|
||||
if (kernel_result == 0) {
|
||||
// Ensure all GPU writes to managed_match_flags are visible to CPU
|
||||
cudaDeviceSynchronize();
|
||||
// RunBatchScanKernels already synchronizes the stream before returning
|
||||
// managed_match_flags is now safe to read from CPU
|
||||
|
||||
// Build output for matching rows
|
||||
idx_t match_count = 0;
|
||||
std::vector<int32_t> matched_heights;
|
||||
// Build output for matching rows - write to global state with locking
|
||||
global_state.output_lock->lock();
|
||||
for (idx_t i = 0; i < batch_size; i++) {
|
||||
if (managed_match_flags[i]) {
|
||||
local_state.output_txids.push_back(local_state.accumulated_txids[i]);
|
||||
local_state.output_heights.push_back(local_state.accumulated_heights[i]);
|
||||
local_state.output_tweak_keys.push_back(local_state.accumulated_tweak_keys[i]);
|
||||
matched_heights.push_back(local_state.accumulated_heights[i]);
|
||||
match_count++;
|
||||
global_state.output_txids.push_back(local_state.accumulated_txids[i]);
|
||||
global_state.output_heights.push_back(local_state.accumulated_heights[i]);
|
||||
global_state.output_tweak_keys.push_back(local_state.accumulated_tweak_keys[i]);
|
||||
}
|
||||
}
|
||||
global_state.output_lock->unlock();
|
||||
}
|
||||
|
||||
// Cleanup managed memory
|
||||
@ -397,8 +389,8 @@ static void ProcessBatch(CudaspScanLocalState &local_state, const CudaspScanBind
|
||||
local_state.accumulated_output_lengths.clear();
|
||||
}
|
||||
|
||||
static bool HasOutput(const CudaspScanLocalState &local_state) {
|
||||
return local_state.output_position < local_state.output_txids.size();
|
||||
static bool HasOutput(const CudaspScanState &global_state) {
|
||||
return global_state.output_position < global_state.output_txids.size();
|
||||
}
|
||||
|
||||
static bool ShouldProcessBatch(const CudaspScanLocalState &local_state, const CudaspScanBindData &bind_data) {
|
||||
@ -510,90 +502,20 @@ static unique_ptr<LocalTableFunctionState> CudaspScanLocalInit(ExecutionContext
|
||||
static OperatorResultType CudaspScanFunction(ExecutionContext &context, TableFunctionInput &data_p,
|
||||
DataChunk &input, DataChunk &output) {
|
||||
auto &bind_data = data_p.bind_data->Cast<CudaspScanBindData>();
|
||||
auto &global_state = data_p.global_state->Cast<CudaspScanState>();
|
||||
auto &local_state = data_p.local_state->Cast<CudaspScanLocalState>();
|
||||
|
||||
// If we have pending output from a previous batch, return it first
|
||||
if (HasOutput(local_state)) {
|
||||
auto &txid_result = output.data[0];
|
||||
auto &height_result = output.data[1];
|
||||
auto &tweak_key_result = output.data[2];
|
||||
|
||||
idx_t output_count = MinValue<idx_t>(STANDARD_VECTOR_SIZE,
|
||||
local_state.output_txids.size() - local_state.output_position);
|
||||
|
||||
auto txid_data = FlatVector::GetData<string_t>(txid_result);
|
||||
auto height_data = FlatVector::GetData<int32_t>(height_result);
|
||||
auto tweak_key_data = FlatVector::GetData<string_t>(tweak_key_result);
|
||||
|
||||
for (idx_t i = 0; i < output_count; i++) {
|
||||
auto &txid = local_state.output_txids[local_state.output_position + i];
|
||||
auto &tweak_key = local_state.output_tweak_keys[local_state.output_position + i];
|
||||
txid_data[i] = StringVector::AddStringOrBlob(txid_result, string_t(txid.data(), txid.size()));
|
||||
height_data[i] = local_state.output_heights[local_state.output_position + i];
|
||||
tweak_key_data[i] = StringVector::AddStringOrBlob(tweak_key_result, string_t(tweak_key.data(), tweak_key.size()));
|
||||
}
|
||||
|
||||
output.SetCardinality(output_count);
|
||||
local_state.output_position += output_count;
|
||||
|
||||
// If we still have more output, signal that
|
||||
if (HasOutput(local_state)) {
|
||||
return OperatorResultType::HAVE_MORE_OUTPUT;
|
||||
}
|
||||
|
||||
// All output returned, clear buffers
|
||||
local_state.output_txids.clear();
|
||||
local_state.output_heights.clear();
|
||||
local_state.output_tweak_keys.clear();
|
||||
local_state.output_position = 0;
|
||||
|
||||
// Otherwise keep accepting input
|
||||
return OperatorResultType::NEED_MORE_INPUT;
|
||||
}
|
||||
|
||||
// Process new input
|
||||
// Accumulate input
|
||||
if (input.size() > 0) {
|
||||
AccumulateInput(local_state, input);
|
||||
|
||||
// Signal that we've consumed the input
|
||||
input.SetCardinality(0);
|
||||
|
||||
// Process batch if we've accumulated enough data
|
||||
// Process batch immediately when full, but don't return output
|
||||
// Output goes to global state for finalize to return
|
||||
if (ShouldProcessBatch(local_state, bind_data)) {
|
||||
ProcessBatch(local_state, bind_data);
|
||||
local_state.output_position = 0;
|
||||
|
||||
// Write output immediately
|
||||
auto &txid_result = output.data[0];
|
||||
auto &height_result = output.data[1];
|
||||
auto &tweak_key_result = output.data[2];
|
||||
|
||||
idx_t output_count = MinValue<idx_t>(STANDARD_VECTOR_SIZE, local_state.output_txids.size());
|
||||
|
||||
auto txid_data = FlatVector::GetData<string_t>(txid_result);
|
||||
auto height_data = FlatVector::GetData<int32_t>(height_result);
|
||||
auto tweak_key_data = FlatVector::GetData<string_t>(tweak_key_result);
|
||||
|
||||
for (idx_t i = 0; i < output_count; i++) {
|
||||
auto &txid = local_state.output_txids[i];
|
||||
auto &tweak_key = local_state.output_tweak_keys[i];
|
||||
txid_data[i] = StringVector::AddStringOrBlob(txid_result, string_t(txid.data(), txid.size()));
|
||||
height_data[i] = local_state.output_heights[i];
|
||||
tweak_key_data[i] = StringVector::AddStringOrBlob(tweak_key_result, string_t(tweak_key.data(), tweak_key.size()));
|
||||
}
|
||||
|
||||
output.SetCardinality(output_count);
|
||||
local_state.output_position = output_count;
|
||||
|
||||
// We just wrote data to output, so we MUST return HAVE_MORE_OUTPUT
|
||||
return OperatorResultType::HAVE_MORE_OUTPUT;
|
||||
ProcessBatch(local_state, bind_data, global_state);
|
||||
}
|
||||
|
||||
// Keep accumulating
|
||||
return OperatorResultType::NEED_MORE_INPUT;
|
||||
}
|
||||
|
||||
// No more input - should not reach here, finalize handles remaining data
|
||||
return OperatorResultType::NEED_MORE_INPUT;
|
||||
}
|
||||
|
||||
@ -603,39 +525,9 @@ static OperatorFinalizeResultType CudaspScanFinalFunction(ExecutionContext &cont
|
||||
auto &state = data_p.global_state->Cast<CudaspScanState>();
|
||||
auto &local_state = data_p.local_state->Cast<CudaspScanLocalState>();
|
||||
|
||||
// If we still have pending output from previous batch, return it
|
||||
if (HasOutput(local_state)) {
|
||||
auto &txid_result = output.data[0];
|
||||
auto &height_result = output.data[1];
|
||||
auto &tweak_key_result = output.data[2];
|
||||
|
||||
idx_t output_count = MinValue<idx_t>(STANDARD_VECTOR_SIZE,
|
||||
local_state.output_txids.size() - local_state.output_position);
|
||||
|
||||
auto txid_data = FlatVector::GetData<string_t>(txid_result);
|
||||
auto height_data = FlatVector::GetData<int32_t>(height_result);
|
||||
auto tweak_key_data = FlatVector::GetData<string_t>(tweak_key_result);
|
||||
|
||||
for (idx_t i = 0; i < output_count; i++) {
|
||||
auto &txid = local_state.output_txids[local_state.output_position + i];
|
||||
auto &tweak_key = local_state.output_tweak_keys[local_state.output_position + i];
|
||||
txid_data[i] = StringVector::AddStringOrBlob(txid_result, string_t(txid.data(), txid.size()));
|
||||
height_data[i] = local_state.output_heights[local_state.output_position + i];
|
||||
tweak_key_data[i] = StringVector::AddStringOrBlob(tweak_key_result, string_t(tweak_key.data(), tweak_key.size()));
|
||||
}
|
||||
|
||||
output.SetCardinality(output_count);
|
||||
local_state.output_position += output_count;
|
||||
|
||||
if (HasOutput(local_state)) {
|
||||
return OperatorFinalizeResultType::HAVE_MORE_OUTPUT;
|
||||
}
|
||||
|
||||
// Clear output buffers after returning all data
|
||||
local_state.output_txids.clear();
|
||||
local_state.output_heights.clear();
|
||||
local_state.output_tweak_keys.clear();
|
||||
local_state.output_position = 0;
|
||||
// First, process any remaining accumulated data from this thread
|
||||
if (!local_state.finalized && !local_state.accumulated_txids.empty()) {
|
||||
ProcessBatch(local_state, bind_data, state);
|
||||
}
|
||||
|
||||
// Decrement thread counter only once per thread
|
||||
@ -646,31 +538,52 @@ static OperatorFinalizeResultType CudaspScanFinalFunction(ExecutionContext &cont
|
||||
state.finalize_lock->unlock();
|
||||
}
|
||||
|
||||
// Process any remaining accumulated data for this thread
|
||||
if (!local_state.accumulated_txids.empty()) {
|
||||
ProcessBatch(local_state, bind_data);
|
||||
local_state.output_position = 0;
|
||||
// If this thread is not the output thread, check if we can become it
|
||||
if (!local_state.is_output_thread) {
|
||||
// Wait for ALL threads to finish processing before claiming output
|
||||
if (state.currently_adding != 0) {
|
||||
return OperatorFinalizeResultType::FINISHED;
|
||||
}
|
||||
|
||||
// Try to claim output responsibility - only one thread should return output
|
||||
// to avoid batch index conflicts in DuckDB
|
||||
bool expected = false;
|
||||
if (!state.output_thread_claimed.compare_exchange_strong(expected, true)) {
|
||||
// Another thread is handling output, we're done
|
||||
return OperatorFinalizeResultType::FINISHED;
|
||||
}
|
||||
|
||||
// We successfully claimed output responsibility
|
||||
local_state.is_output_thread = true;
|
||||
}
|
||||
|
||||
// This thread is responsible for returning all output
|
||||
// All threads have finished processing, so global state is complete
|
||||
// Return output from global state in chunks
|
||||
if (HasOutput(state)) {
|
||||
auto &txid_result = output.data[0];
|
||||
auto &height_result = output.data[1];
|
||||
auto &tweak_key_result = output.data[2];
|
||||
|
||||
idx_t output_count = MinValue<idx_t>(STANDARD_VECTOR_SIZE, local_state.output_txids.size());
|
||||
idx_t output_count = MinValue<idx_t>(STANDARD_VECTOR_SIZE,
|
||||
state.output_txids.size() - state.output_position);
|
||||
|
||||
auto txid_data = FlatVector::GetData<string_t>(txid_result);
|
||||
auto height_data = FlatVector::GetData<int32_t>(height_result);
|
||||
auto tweak_key_data = FlatVector::GetData<string_t>(tweak_key_result);
|
||||
|
||||
for (idx_t i = 0; i < output_count; i++) {
|
||||
txid_data[i] = StringVector::AddStringOrBlob(txid_result, local_state.output_txids[i]);
|
||||
height_data[i] = local_state.output_heights[i];
|
||||
tweak_key_data[i] = StringVector::AddStringOrBlob(tweak_key_result, local_state.output_tweak_keys[i]);
|
||||
auto &txid = state.output_txids[state.output_position + i];
|
||||
auto &tweak_key = state.output_tweak_keys[state.output_position + i];
|
||||
txid_data[i] = StringVector::AddStringOrBlob(txid_result, string_t(txid.data(), txid.size()));
|
||||
height_data[i] = state.output_heights[state.output_position + i];
|
||||
tweak_key_data[i] = StringVector::AddStringOrBlob(tweak_key_result, string_t(tweak_key.data(), tweak_key.size()));
|
||||
}
|
||||
|
||||
output.SetCardinality(output_count);
|
||||
local_state.output_position = output_count;
|
||||
state.output_position += output_count;
|
||||
|
||||
if (HasOutput(local_state)) {
|
||||
if (HasOutput(state)) {
|
||||
return OperatorFinalizeResultType::HAVE_MORE_OUTPUT;
|
||||
}
|
||||
}
|
||||
|
||||
@ -433,6 +433,7 @@ struct BatchScanState {
|
||||
Solver *solver; // ECDSA solver instance
|
||||
uint32_t count;
|
||||
uint64_t batch_id; // Unique batch identifier for debugging
|
||||
cudaStream_t stream; // CUDA stream for concurrent batch execution
|
||||
};
|
||||
|
||||
// Host function to initialize solver and prepare for EC multiplication
|
||||
@ -478,8 +479,16 @@ extern "C" void* LaunchBatchScan(
|
||||
// Generate unique batch ID for debugging (use pointer address as unique ID)
|
||||
state->batch_id = reinterpret_cast<uint64_t>(state);
|
||||
|
||||
// Allocate managed memory for point coordinates (caller will fill these)
|
||||
// Create CUDA stream for this batch to enable concurrent execution
|
||||
cudaError_t err;
|
||||
err = cudaStreamCreate(&state->stream);
|
||||
if (err != cudaSuccess) {
|
||||
printf("cudaStreamCreate error: %s\n", cudaGetErrorString(err));
|
||||
delete state;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Allocate managed memory for point coordinates (caller will fill these)
|
||||
err = cudaMallocManaged(managed_points_x, Field::SIZE * count);
|
||||
if (err != cudaSuccess) {
|
||||
delete state;
|
||||
@ -568,31 +577,24 @@ extern "C" void* LaunchBatchScan(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Copy outputs metadata using cudaMemcpy to ensure proper coherency
|
||||
// CRITICAL: Use cudaMemcpy instead of memcpy for unified memory to ensure
|
||||
// data is properly synchronized between host and device in concurrent execution
|
||||
err = cudaMemcpy(state->d_outputs, h_outputs, outputs_size * sizeof(int64_t), cudaMemcpyHostToDevice);
|
||||
// Copy outputs metadata using cudaMemcpyAsync with stream for concurrent execution
|
||||
// This allows multiple batches to copy data concurrently without blocking
|
||||
err = cudaMemcpyAsync(state->d_outputs, h_outputs, outputs_size * sizeof(int64_t), cudaMemcpyHostToDevice, state->stream);
|
||||
if (err != cudaSuccess) {
|
||||
printf("cudaMemcpy d_outputs error: %s\n", cudaGetErrorString(err));
|
||||
printf("cudaMemcpyAsync d_outputs error: %s\n", cudaGetErrorString(err));
|
||||
}
|
||||
err = cudaMemcpy(state->d_output_offsets, h_output_offsets, count * sizeof(uint32_t), cudaMemcpyHostToDevice);
|
||||
err = cudaMemcpyAsync(state->d_output_offsets, h_output_offsets, count * sizeof(uint32_t), cudaMemcpyHostToDevice, state->stream);
|
||||
if (err != cudaSuccess) {
|
||||
printf("cudaMemcpy d_output_offsets error: %s\n", cudaGetErrorString(err));
|
||||
printf("cudaMemcpyAsync d_output_offsets error: %s\n", cudaGetErrorString(err));
|
||||
}
|
||||
err = cudaMemcpy(state->d_output_lengths, h_output_lengths, count * sizeof(uint32_t), cudaMemcpyHostToDevice);
|
||||
err = cudaMemcpyAsync(state->d_output_lengths, h_output_lengths, count * sizeof(uint32_t), cudaMemcpyHostToDevice, state->stream);
|
||||
if (err != cudaSuccess) {
|
||||
printf("cudaMemcpy d_output_lengths error: %s\n", cudaGetErrorString(err));
|
||||
}
|
||||
|
||||
// Synchronize to ensure all data is copied to device before returning
|
||||
err = cudaDeviceSynchronize();
|
||||
if (err != cudaSuccess) {
|
||||
printf("cudaDeviceSynchronize after memcpy error: %s\n", cudaGetErrorString(err));
|
||||
printf("cudaMemcpyAsync d_output_lengths error: %s\n", cudaGetErrorString(err));
|
||||
}
|
||||
|
||||
// Copy spend public key
|
||||
cudaMemcpy(state->d_spend_pubkey_x, h_spend_pubkey_x, field_limbs * sizeof(uint32_t), cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(state->d_spend_pubkey_y, h_spend_pubkey_y, field_limbs * sizeof(uint32_t), cudaMemcpyHostToDevice);
|
||||
cudaMemcpyAsync(state->d_spend_pubkey_x, h_spend_pubkey_x, field_limbs * sizeof(uint32_t), cudaMemcpyHostToDevice, state->stream);
|
||||
cudaMemcpyAsync(state->d_spend_pubkey_y, h_spend_pubkey_y, field_limbs * sizeof(uint32_t), cudaMemcpyHostToDevice, state->stream);
|
||||
|
||||
// Allocate and copy label keys (if any)
|
||||
if (label_count > 0) {
|
||||
@ -625,8 +627,8 @@ extern "C" void* LaunchBatchScan(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
cudaMemcpy(state->d_label_keys_x, h_label_keys_x, label_count * field_limbs * sizeof(uint32_t), cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(state->d_label_keys_y, h_label_keys_y, label_count * field_limbs * sizeof(uint32_t), cudaMemcpyHostToDevice);
|
||||
cudaMemcpyAsync(state->d_label_keys_x, h_label_keys_x, label_count * field_limbs * sizeof(uint32_t), cudaMemcpyHostToDevice, state->stream);
|
||||
cudaMemcpyAsync(state->d_label_keys_y, h_label_keys_y, label_count * field_limbs * sizeof(uint32_t), cudaMemcpyHostToDevice, state->stream);
|
||||
}
|
||||
|
||||
// Create fresh solver for this batch
|
||||
@ -658,8 +660,6 @@ extern "C" int RunBatchScanKernels(
|
||||
|
||||
Solver *solver = state->solver;
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
// Prepare data in the format expected by ec_pmul_init
|
||||
// MAX_LIMBS is defined in gECC as 64 (maximum array size)
|
||||
// For secp256k1 (256-bit), we use 4 u64 limbs, but arrays must be size MAX_LIMBS
|
||||
@ -721,8 +721,8 @@ extern "C" int RunBatchScanKernels(
|
||||
}
|
||||
#endif
|
||||
|
||||
// Call ec_pmul_init with our specific data
|
||||
solver->ec_pmul_init(h_scalars, h_keys_x, h_keys_y, count);
|
||||
// Call ec_pmul_init with our specific data and stream
|
||||
solver->ec_pmul_init(h_scalars, h_keys_x, h_keys_y, count, state->stream);
|
||||
|
||||
// Free host arrays
|
||||
delete[] h_scalars;
|
||||
@ -742,7 +742,7 @@ extern "C" int RunBatchScanKernels(
|
||||
u32 max_thread_per_block = 256;
|
||||
u32 block_num = MAX_SM_NUMS; // Use SM count like gECC test for optimal occupancy
|
||||
|
||||
solver->ecdsa_ec_pmul(block_num, max_thread_per_block, true); // true = unknown points
|
||||
solver->ecdsa_ec_pmul(block_num, max_thread_per_block, true, state->stream); // true = unknown points
|
||||
|
||||
// Check for multiplication errors
|
||||
err = cudaPeekAtLastError();
|
||||
@ -751,7 +751,7 @@ extern "C" int RunBatchScanKernels(
|
||||
return -1;
|
||||
}
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
// ecdsa_ec_pmul() already synchronizes internally, no need to sync again
|
||||
|
||||
// === BIP-352 Silent Payment Pipeline ===
|
||||
// Step 1: Serialize shared secrets to compressed SEC1 format + 4 zero bytes (37 bytes each)
|
||||
@ -767,11 +767,11 @@ extern "C" int RunBatchScanKernels(
|
||||
// Kernels will process multiple elements per thread when count > blocks * threads
|
||||
int num_blocks = MAX_SM_NUMS; // Use SM count for good occupancy
|
||||
|
||||
SerializeToCompressedSEC1Kernel<<<num_blocks, threads_per_block>>>(
|
||||
SerializeToCompressedSEC1Kernel<<<num_blocks, threads_per_block, 0, state->stream>>>(
|
||||
solver->R0, d_serialized, count, state->batch_id
|
||||
);
|
||||
|
||||
err = cudaDeviceSynchronize();
|
||||
err = cudaStreamSynchronize(state->stream);
|
||||
if (err != cudaSuccess) {
|
||||
printf("SerializeToCompressedSEC1Kernel error: %s\n", cudaGetErrorString(err));
|
||||
cudaFree(d_serialized);
|
||||
@ -787,32 +787,35 @@ extern "C" int RunBatchScanKernels(
|
||||
return -1;
|
||||
}
|
||||
|
||||
ComputeTaggedHashesKernel<<<num_blocks, threads_per_block>>>(
|
||||
ComputeTaggedHashesKernel<<<num_blocks, threads_per_block, 0, state->stream>>>(
|
||||
d_serialized, d_hashes, count, state->batch_id
|
||||
);
|
||||
|
||||
err = cudaDeviceSynchronize();
|
||||
// Step 3: Copy hashes to host and convert to Order scalars for fixed-point multiplication
|
||||
// Copy hashes to host memory first to avoid coherency issues
|
||||
// Use async memcpy with stream, then sync the stream
|
||||
uint8_t *h_hashes = new uint8_t[count * 32];
|
||||
err = cudaMemcpyAsync(h_hashes, d_hashes, count * 32, cudaMemcpyDeviceToHost, state->stream);
|
||||
if (err != cudaSuccess) {
|
||||
printf("ComputeTaggedHashesKernel error: %s\n", cudaGetErrorString(err));
|
||||
printf("cudaMemcpyAsync for h_hashes error: %s\n", cudaGetErrorString(err));
|
||||
cudaFree(d_serialized);
|
||||
cudaFree(d_hashes);
|
||||
delete[] h_hashes;
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Debug: Log hash result
|
||||
cudaFree(d_serialized); // No longer needed
|
||||
|
||||
// Step 3: Convert hashes to Order scalars for fixed-point multiplication
|
||||
// Copy hashes to host memory first to avoid coherency issues
|
||||
uint8_t *h_hashes = new uint8_t[count * 32];
|
||||
err = cudaMemcpy(h_hashes, d_hashes, count * 32, cudaMemcpyDeviceToHost);
|
||||
err = cudaStreamSynchronize(state->stream);
|
||||
if (err != cudaSuccess) {
|
||||
printf("cudaMemcpy for h_hashes error: %s\n", cudaGetErrorString(err));
|
||||
printf("cudaStreamSynchronize after memcpy error: %s\n", cudaGetErrorString(err));
|
||||
cudaFree(d_serialized);
|
||||
cudaFree(d_hashes);
|
||||
delete[] h_hashes;
|
||||
return -1;
|
||||
}
|
||||
|
||||
cudaFree(d_hashes); // No longer needed
|
||||
// Free device memory after memcpy completes
|
||||
cudaFree(d_serialized);
|
||||
cudaFree(d_hashes);
|
||||
|
||||
// Allocate host buffer for conversion
|
||||
Order::Base *h_hash_scalars = new Order::Base[Order::SIZE * count];
|
||||
@ -858,15 +861,15 @@ extern "C" int RunBatchScanKernels(
|
||||
return -1;
|
||||
}
|
||||
|
||||
err = cudaMemcpy(d_hash_scalars, h_hash_scalars, Order::SIZE * count, cudaMemcpyHostToDevice);
|
||||
err = cudaMemcpyAsync(d_hash_scalars, h_hash_scalars, Order::SIZE * count, cudaMemcpyHostToDevice, state->stream);
|
||||
if (err != cudaSuccess) {
|
||||
printf("cudaMemcpy for d_hash_scalars error: %s\n", cudaGetErrorString(err));
|
||||
printf("cudaMemcpyAsync for d_hash_scalars error: %s\n", cudaGetErrorString(err));
|
||||
delete[] h_hash_scalars;
|
||||
cudaFree(d_hash_scalars);
|
||||
return -1;
|
||||
}
|
||||
|
||||
delete[] h_hash_scalars; // No longer needed
|
||||
delete[] h_hash_scalars; // Can be freed after async copy is queued
|
||||
|
||||
// Step 4: Allocate output buffer for fixed-point multiply results
|
||||
ECPoint::Base *d_fpm_results;
|
||||
@ -878,11 +881,11 @@ extern "C" int RunBatchScanKernels(
|
||||
}
|
||||
|
||||
// Step 5: Fixed-point multiply: hash × G using precomputed table
|
||||
FixedPointMultiplyKernel<<<num_blocks, threads_per_block>>>(
|
||||
FixedPointMultiplyKernel<<<num_blocks, threads_per_block, 0, state->stream>>>(
|
||||
count, d_hash_scalars, d_fpm_results, state->batch_id
|
||||
);
|
||||
|
||||
err = cudaDeviceSynchronize();
|
||||
err = cudaStreamSynchronize(state->stream);
|
||||
if (err != cudaSuccess) {
|
||||
printf("FixedPointMultiplyKernel error: %s\n", cudaGetErrorString(err));
|
||||
cudaFree(d_hash_scalars);
|
||||
@ -896,7 +899,7 @@ extern "C" int RunBatchScanKernels(
|
||||
// This will: (1) try base case: output_point + spend_pubkey
|
||||
// (2) for each label: try output_point + label_key (and negated)
|
||||
|
||||
CheckMatchesWithLabelsKernel<<<num_blocks, threads_per_block>>>(
|
||||
CheckMatchesWithLabelsKernel<<<num_blocks, threads_per_block, 0, state->stream>>>(
|
||||
d_fpm_results,
|
||||
state->d_spend_pubkey_x,
|
||||
state->d_spend_pubkey_y,
|
||||
@ -911,7 +914,7 @@ extern "C" int RunBatchScanKernels(
|
||||
state->batch_id
|
||||
);
|
||||
|
||||
err = cudaDeviceSynchronize();
|
||||
err = cudaStreamSynchronize(state->stream);
|
||||
if (err != cudaSuccess) {
|
||||
printf("CheckMatchesWithLabelsKernel error: %s\n", cudaGetErrorString(err));
|
||||
cudaFree(d_fpm_results);
|
||||
@ -929,6 +932,24 @@ extern "C" void FreeBatchScanState(void *state_handle) {
|
||||
|
||||
BatchScanState *state = static_cast<BatchScanState*>(state_handle);
|
||||
|
||||
// CRITICAL: Synchronize stream before cleanup to ensure all operations complete
|
||||
// Without this, destroying the stream or freeing memory can cause deadlocks
|
||||
if (state->stream) {
|
||||
cudaError_t sync_err = cudaStreamSynchronize(state->stream);
|
||||
if (sync_err != cudaSuccess) {
|
||||
printf("WARNING: cudaStreamSynchronize in FreeBatchScanState failed: %s\n", cudaGetErrorString(sync_err));
|
||||
}
|
||||
|
||||
// Double-check stream is idle
|
||||
cudaError_t query_err = cudaStreamQuery(state->stream);
|
||||
if (query_err == cudaErrorNotReady) {
|
||||
printf("WARNING: Stream not ready after synchronize, forcing device sync\n");
|
||||
cudaDeviceSynchronize();
|
||||
} else if (query_err != cudaSuccess) {
|
||||
printf("WARNING: cudaStreamQuery failed: %s\n", cudaGetErrorString(query_err));
|
||||
}
|
||||
}
|
||||
|
||||
// Free CUDA buffers
|
||||
if (state->d_outputs) cudaFree(state->d_outputs);
|
||||
if (state->d_output_offsets) cudaFree(state->d_output_offsets);
|
||||
@ -939,8 +960,16 @@ extern "C" void FreeBatchScanState(void *state_handle) {
|
||||
if (state->d_label_keys_y) cudaFree(state->d_label_keys_y);
|
||||
if (state->d_fpm_results_backup) cudaFree(state->d_fpm_results_backup);
|
||||
|
||||
// Delete solver
|
||||
if (state->solver) delete state->solver;
|
||||
// Clean up solver resources (frees managed memory allocated in ec_pmul_init)
|
||||
if (state->solver) {
|
||||
state->solver->ec_pmul_close();
|
||||
delete state->solver;
|
||||
}
|
||||
|
||||
// Destroy CUDA stream (safe now after synchronization)
|
||||
if (state->stream) {
|
||||
cudaStreamDestroy(state->stream);
|
||||
}
|
||||
|
||||
// Free state struct
|
||||
delete state;
|
||||
|
||||
Loading…
Reference in New Issue
Block a user