fix synchronization issue

fix synchronization issue 2

fix synchronization issue 3

fix synchronization issue 4

fix synchronization issue 5

fix synchronization issue 6

fix synchronization issue 7

fix synchronization issue 8

fix synchronization issue 9

fix synchronization issue 10

fix synchronization issue 11

fix synchronization issue 12

fix synchronization issue 13

fix synchronization issue 14

fix synchronization issue 15

fix synchronization issue 16

fix synchronization issue 17

fix synchronization issue 18

fix synchronization issue 19

fix synchronization issue 20

fix synchronization issue 21

fix synchronization issue 22

fix synchronization issue 23
This commit is contained in:
Craig Raw 2025-10-29 15:53:22 +02:00
parent 56f5ec8872
commit 3b0708bc72
3 changed files with 143 additions and 201 deletions

2
gECC

@@ -1 +1 @@
Subproject commit f3ab474f24d0e375bc2fa41dd480525506ceaa8a
Subproject commit 5622411e7f62c90a75d4c69b8e3ff8c8b24f2c39

View File

@@ -104,9 +104,10 @@ struct CudaspScanBindData : public TableFunctionData {
};
struct CudaspScanLocalState : public LocalTableFunctionState {
CudaspScanLocalState() : finalized(false), output_position(0) {
CudaspScanLocalState() : finalized(false), is_output_thread(false) {
}
bool finalized;
bool is_output_thread; // True if this thread is responsible for returning output
// Per-thread accumulated input data
vector<std::string> accumulated_txids; // Transaction IDs (BLOB) - owned copies
@@ -115,28 +116,27 @@ struct CudaspScanLocalState : public LocalTableFunctionState {
vector<int64_t> accumulated_outputs; // Flattened output values (BIGINT)
vector<idx_t> accumulated_output_offsets; // Offset into accumulated_outputs for each row
vector<idx_t> accumulated_output_lengths; // Length of each outputs list
// Per-thread processed output data (only rows with matches)
vector<std::string> output_txids; // Owned copies
vector<int32_t> output_heights;
vector<std::string> output_tweak_keys; // Owned copies
idx_t output_position;
};
struct CudaspScanState : public GlobalTableFunctionState {
CudaspScanState() : currently_adding(0) {
CudaspScanState() : currently_adding(0), output_position(0), output_thread_claimed(false) {
finalize_lock = make_uniq<std::mutex>();
}
// Limit to single thread to prevent duplicate results from parallel execution
// GPU parallelism with thousands of CUDA threads provides sufficient performance
idx_t MaxThreads() const override {
return 1;
output_lock = make_uniq<std::mutex>();
}
// Thread synchronization
std::atomic_uint64_t currently_adding;
unique_ptr<std::mutex> finalize_lock;
// Global output storage - all threads write here
unique_ptr<std::mutex> output_lock;
vector<string> output_txids;
vector<int32_t> output_heights;
vector<string> output_tweak_keys;
idx_t output_position;
// Only one thread returns output to avoid batch index conflicts
std::atomic<bool> output_thread_claimed;
};
static void AccumulateInput(CudaspScanLocalState &local_state, DataChunk &input) {
@@ -211,13 +211,7 @@ static void AccumulateInput(CudaspScanLocalState &local_state, DataChunk &input)
}
}
static void ProcessBatch(CudaspScanLocalState &local_state, const CudaspScanBindData &bind_data) {
// Clear any previous output
local_state.output_txids.clear();
local_state.output_heights.clear();
local_state.output_tweak_keys.clear();
local_state.output_position = 0;
static void ProcessBatch(CudaspScanLocalState &local_state, const CudaspScanBindData &bind_data, CudaspScanState &global_state) {
idx_t batch_size = local_state.accumulated_txids.size();
if (batch_size == 0) {
return;
@@ -362,21 +356,19 @@ static void ProcessBatch(CudaspScanLocalState &local_state, const CudaspScanBind
);
if (kernel_result == 0) {
// Ensure all GPU writes to managed_match_flags are visible to CPU
cudaDeviceSynchronize();
// RunBatchScanKernels already synchronizes the stream before returning
// managed_match_flags is now safe to read from CPU
// Build output for matching rows
idx_t match_count = 0;
std::vector<int32_t> matched_heights;
// Build output for matching rows - write to global state with locking
global_state.output_lock->lock();
for (idx_t i = 0; i < batch_size; i++) {
if (managed_match_flags[i]) {
local_state.output_txids.push_back(local_state.accumulated_txids[i]);
local_state.output_heights.push_back(local_state.accumulated_heights[i]);
local_state.output_tweak_keys.push_back(local_state.accumulated_tweak_keys[i]);
matched_heights.push_back(local_state.accumulated_heights[i]);
match_count++;
global_state.output_txids.push_back(local_state.accumulated_txids[i]);
global_state.output_heights.push_back(local_state.accumulated_heights[i]);
global_state.output_tweak_keys.push_back(local_state.accumulated_tweak_keys[i]);
}
}
global_state.output_lock->unlock();
}
// Cleanup managed memory
@@ -397,8 +389,8 @@ static void ProcessBatch(CudaspScanLocalState &local_state, const CudaspScanBind
local_state.accumulated_output_lengths.clear();
}
static bool HasOutput(const CudaspScanLocalState &local_state) {
return local_state.output_position < local_state.output_txids.size();
static bool HasOutput(const CudaspScanState &global_state) {
return global_state.output_position < global_state.output_txids.size();
}
static bool ShouldProcessBatch(const CudaspScanLocalState &local_state, const CudaspScanBindData &bind_data) {
@@ -510,90 +502,20 @@ static unique_ptr<LocalTableFunctionState> CudaspScanLocalInit(ExecutionContext
static OperatorResultType CudaspScanFunction(ExecutionContext &context, TableFunctionInput &data_p,
DataChunk &input, DataChunk &output) {
auto &bind_data = data_p.bind_data->Cast<CudaspScanBindData>();
auto &global_state = data_p.global_state->Cast<CudaspScanState>();
auto &local_state = data_p.local_state->Cast<CudaspScanLocalState>();
// If we have pending output from a previous batch, return it first
if (HasOutput(local_state)) {
auto &txid_result = output.data[0];
auto &height_result = output.data[1];
auto &tweak_key_result = output.data[2];
idx_t output_count = MinValue<idx_t>(STANDARD_VECTOR_SIZE,
local_state.output_txids.size() - local_state.output_position);
auto txid_data = FlatVector::GetData<string_t>(txid_result);
auto height_data = FlatVector::GetData<int32_t>(height_result);
auto tweak_key_data = FlatVector::GetData<string_t>(tweak_key_result);
for (idx_t i = 0; i < output_count; i++) {
auto &txid = local_state.output_txids[local_state.output_position + i];
auto &tweak_key = local_state.output_tweak_keys[local_state.output_position + i];
txid_data[i] = StringVector::AddStringOrBlob(txid_result, string_t(txid.data(), txid.size()));
height_data[i] = local_state.output_heights[local_state.output_position + i];
tweak_key_data[i] = StringVector::AddStringOrBlob(tweak_key_result, string_t(tweak_key.data(), tweak_key.size()));
}
output.SetCardinality(output_count);
local_state.output_position += output_count;
// If we still have more output, signal that
if (HasOutput(local_state)) {
return OperatorResultType::HAVE_MORE_OUTPUT;
}
// All output returned, clear buffers
local_state.output_txids.clear();
local_state.output_heights.clear();
local_state.output_tweak_keys.clear();
local_state.output_position = 0;
// Otherwise keep accepting input
return OperatorResultType::NEED_MORE_INPUT;
}
// Process new input
// Accumulate input
if (input.size() > 0) {
AccumulateInput(local_state, input);
// Signal that we've consumed the input
input.SetCardinality(0);
// Process batch if we've accumulated enough data
// Process batch immediately when full, but don't return output
// Output goes to global state for finalize to return
if (ShouldProcessBatch(local_state, bind_data)) {
ProcessBatch(local_state, bind_data);
local_state.output_position = 0;
// Write output immediately
auto &txid_result = output.data[0];
auto &height_result = output.data[1];
auto &tweak_key_result = output.data[2];
idx_t output_count = MinValue<idx_t>(STANDARD_VECTOR_SIZE, local_state.output_txids.size());
auto txid_data = FlatVector::GetData<string_t>(txid_result);
auto height_data = FlatVector::GetData<int32_t>(height_result);
auto tweak_key_data = FlatVector::GetData<string_t>(tweak_key_result);
for (idx_t i = 0; i < output_count; i++) {
auto &txid = local_state.output_txids[i];
auto &tweak_key = local_state.output_tweak_keys[i];
txid_data[i] = StringVector::AddStringOrBlob(txid_result, string_t(txid.data(), txid.size()));
height_data[i] = local_state.output_heights[i];
tweak_key_data[i] = StringVector::AddStringOrBlob(tweak_key_result, string_t(tweak_key.data(), tweak_key.size()));
}
output.SetCardinality(output_count);
local_state.output_position = output_count;
// We just wrote data to output, so we MUST return HAVE_MORE_OUTPUT
return OperatorResultType::HAVE_MORE_OUTPUT;
ProcessBatch(local_state, bind_data, global_state);
}
// Keep accumulating
return OperatorResultType::NEED_MORE_INPUT;
}
// No more input - should not reach here, finalize handles remaining data
return OperatorResultType::NEED_MORE_INPUT;
}
@@ -603,39 +525,9 @@ static OperatorFinalizeResultType CudaspScanFinalFunction(ExecutionContext &cont
auto &state = data_p.global_state->Cast<CudaspScanState>();
auto &local_state = data_p.local_state->Cast<CudaspScanLocalState>();
// If we still have pending output from previous batch, return it
if (HasOutput(local_state)) {
auto &txid_result = output.data[0];
auto &height_result = output.data[1];
auto &tweak_key_result = output.data[2];
idx_t output_count = MinValue<idx_t>(STANDARD_VECTOR_SIZE,
local_state.output_txids.size() - local_state.output_position);
auto txid_data = FlatVector::GetData<string_t>(txid_result);
auto height_data = FlatVector::GetData<int32_t>(height_result);
auto tweak_key_data = FlatVector::GetData<string_t>(tweak_key_result);
for (idx_t i = 0; i < output_count; i++) {
auto &txid = local_state.output_txids[local_state.output_position + i];
auto &tweak_key = local_state.output_tweak_keys[local_state.output_position + i];
txid_data[i] = StringVector::AddStringOrBlob(txid_result, string_t(txid.data(), txid.size()));
height_data[i] = local_state.output_heights[local_state.output_position + i];
tweak_key_data[i] = StringVector::AddStringOrBlob(tweak_key_result, string_t(tweak_key.data(), tweak_key.size()));
}
output.SetCardinality(output_count);
local_state.output_position += output_count;
if (HasOutput(local_state)) {
return OperatorFinalizeResultType::HAVE_MORE_OUTPUT;
}
// Clear output buffers after returning all data
local_state.output_txids.clear();
local_state.output_heights.clear();
local_state.output_tweak_keys.clear();
local_state.output_position = 0;
// First, process any remaining accumulated data from this thread
if (!local_state.finalized && !local_state.accumulated_txids.empty()) {
ProcessBatch(local_state, bind_data, state);
}
// Decrement thread counter only once per thread
@@ -646,31 +538,52 @@ static OperatorFinalizeResultType CudaspScanFinalFunction(ExecutionContext &cont
state.finalize_lock->unlock();
}
// Process any remaining accumulated data for this thread
if (!local_state.accumulated_txids.empty()) {
ProcessBatch(local_state, bind_data);
local_state.output_position = 0;
// If this thread is not the output thread, check if we can become it
if (!local_state.is_output_thread) {
// Wait for ALL threads to finish processing before claiming output
if (state.currently_adding != 0) {
return OperatorFinalizeResultType::FINISHED;
}
// Try to claim output responsibility - only one thread should return output
// to avoid batch index conflicts in DuckDB
bool expected = false;
if (!state.output_thread_claimed.compare_exchange_strong(expected, true)) {
// Another thread is handling output, we're done
return OperatorFinalizeResultType::FINISHED;
}
// We successfully claimed output responsibility
local_state.is_output_thread = true;
}
// This thread is responsible for returning all output
// All threads have finished processing, so global state is complete
// Return output from global state in chunks
if (HasOutput(state)) {
auto &txid_result = output.data[0];
auto &height_result = output.data[1];
auto &tweak_key_result = output.data[2];
idx_t output_count = MinValue<idx_t>(STANDARD_VECTOR_SIZE, local_state.output_txids.size());
idx_t output_count = MinValue<idx_t>(STANDARD_VECTOR_SIZE,
state.output_txids.size() - state.output_position);
auto txid_data = FlatVector::GetData<string_t>(txid_result);
auto height_data = FlatVector::GetData<int32_t>(height_result);
auto tweak_key_data = FlatVector::GetData<string_t>(tweak_key_result);
for (idx_t i = 0; i < output_count; i++) {
txid_data[i] = StringVector::AddStringOrBlob(txid_result, local_state.output_txids[i]);
height_data[i] = local_state.output_heights[i];
tweak_key_data[i] = StringVector::AddStringOrBlob(tweak_key_result, local_state.output_tweak_keys[i]);
auto &txid = state.output_txids[state.output_position + i];
auto &tweak_key = state.output_tweak_keys[state.output_position + i];
txid_data[i] = StringVector::AddStringOrBlob(txid_result, string_t(txid.data(), txid.size()));
height_data[i] = state.output_heights[state.output_position + i];
tweak_key_data[i] = StringVector::AddStringOrBlob(tweak_key_result, string_t(tweak_key.data(), tweak_key.size()));
}
output.SetCardinality(output_count);
local_state.output_position = output_count;
state.output_position += output_count;
if (HasOutput(local_state)) {
if (HasOutput(state)) {
return OperatorFinalizeResultType::HAVE_MORE_OUTPUT;
}
}

View File

@@ -433,6 +433,7 @@ struct BatchScanState {
Solver *solver; // ECDSA solver instance
uint32_t count;
uint64_t batch_id; // Unique batch identifier for debugging
cudaStream_t stream; // CUDA stream for concurrent batch execution
};
// Host function to initialize solver and prepare for EC multiplication
@@ -478,8 +479,16 @@ extern "C" void* LaunchBatchScan(
// Generate unique batch ID for debugging (use pointer address as unique ID)
state->batch_id = reinterpret_cast<uint64_t>(state);
// Allocate managed memory for point coordinates (caller will fill these)
// Create CUDA stream for this batch to enable concurrent execution
cudaError_t err;
err = cudaStreamCreate(&state->stream);
if (err != cudaSuccess) {
printf("cudaStreamCreate error: %s\n", cudaGetErrorString(err));
delete state;
return nullptr;
}
// Allocate managed memory for point coordinates (caller will fill these)
err = cudaMallocManaged(managed_points_x, Field::SIZE * count);
if (err != cudaSuccess) {
delete state;
@@ -568,31 +577,24 @@ extern "C" void* LaunchBatchScan(
return nullptr;
}
// Copy outputs metadata using cudaMemcpy to ensure proper coherency
// CRITICAL: Use cudaMemcpy instead of memcpy for unified memory to ensure
// data is properly synchronized between host and device in concurrent execution
err = cudaMemcpy(state->d_outputs, h_outputs, outputs_size * sizeof(int64_t), cudaMemcpyHostToDevice);
// Copy outputs metadata using cudaMemcpyAsync with stream for concurrent execution
// This allows multiple batches to copy data concurrently without blocking
err = cudaMemcpyAsync(state->d_outputs, h_outputs, outputs_size * sizeof(int64_t), cudaMemcpyHostToDevice, state->stream);
if (err != cudaSuccess) {
printf("cudaMemcpy d_outputs error: %s\n", cudaGetErrorString(err));
printf("cudaMemcpyAsync d_outputs error: %s\n", cudaGetErrorString(err));
}
err = cudaMemcpy(state->d_output_offsets, h_output_offsets, count * sizeof(uint32_t), cudaMemcpyHostToDevice);
err = cudaMemcpyAsync(state->d_output_offsets, h_output_offsets, count * sizeof(uint32_t), cudaMemcpyHostToDevice, state->stream);
if (err != cudaSuccess) {
printf("cudaMemcpy d_output_offsets error: %s\n", cudaGetErrorString(err));
printf("cudaMemcpyAsync d_output_offsets error: %s\n", cudaGetErrorString(err));
}
err = cudaMemcpy(state->d_output_lengths, h_output_lengths, count * sizeof(uint32_t), cudaMemcpyHostToDevice);
err = cudaMemcpyAsync(state->d_output_lengths, h_output_lengths, count * sizeof(uint32_t), cudaMemcpyHostToDevice, state->stream);
if (err != cudaSuccess) {
printf("cudaMemcpy d_output_lengths error: %s\n", cudaGetErrorString(err));
}
// Synchronize to ensure all data is copied to device before returning
err = cudaDeviceSynchronize();
if (err != cudaSuccess) {
printf("cudaDeviceSynchronize after memcpy error: %s\n", cudaGetErrorString(err));
printf("cudaMemcpyAsync d_output_lengths error: %s\n", cudaGetErrorString(err));
}
// Copy spend public key
cudaMemcpy(state->d_spend_pubkey_x, h_spend_pubkey_x, field_limbs * sizeof(uint32_t), cudaMemcpyHostToDevice);
cudaMemcpy(state->d_spend_pubkey_y, h_spend_pubkey_y, field_limbs * sizeof(uint32_t), cudaMemcpyHostToDevice);
cudaMemcpyAsync(state->d_spend_pubkey_x, h_spend_pubkey_x, field_limbs * sizeof(uint32_t), cudaMemcpyHostToDevice, state->stream);
cudaMemcpyAsync(state->d_spend_pubkey_y, h_spend_pubkey_y, field_limbs * sizeof(uint32_t), cudaMemcpyHostToDevice, state->stream);
// Allocate and copy label keys (if any)
if (label_count > 0) {
@@ -625,8 +627,8 @@ extern "C" void* LaunchBatchScan(
return nullptr;
}
cudaMemcpy(state->d_label_keys_x, h_label_keys_x, label_count * field_limbs * sizeof(uint32_t), cudaMemcpyHostToDevice);
cudaMemcpy(state->d_label_keys_y, h_label_keys_y, label_count * field_limbs * sizeof(uint32_t), cudaMemcpyHostToDevice);
cudaMemcpyAsync(state->d_label_keys_x, h_label_keys_x, label_count * field_limbs * sizeof(uint32_t), cudaMemcpyHostToDevice, state->stream);
cudaMemcpyAsync(state->d_label_keys_y, h_label_keys_y, label_count * field_limbs * sizeof(uint32_t), cudaMemcpyHostToDevice, state->stream);
}
// Create fresh solver for this batch
@@ -658,8 +660,6 @@ extern "C" int RunBatchScanKernels(
Solver *solver = state->solver;
cudaDeviceSynchronize();
// Prepare data in the format expected by ec_pmul_init
// MAX_LIMBS is defined in gECC as 64 (maximum array size)
// For secp256k1 (256-bit), we use 4 u64 limbs, but arrays must be size MAX_LIMBS
@@ -721,8 +721,8 @@ extern "C" int RunBatchScanKernels(
}
#endif
// Call ec_pmul_init with our specific data
solver->ec_pmul_init(h_scalars, h_keys_x, h_keys_y, count);
// Call ec_pmul_init with our specific data and stream
solver->ec_pmul_init(h_scalars, h_keys_x, h_keys_y, count, state->stream);
// Free host arrays
delete[] h_scalars;
@@ -742,7 +742,7 @@ extern "C" int RunBatchScanKernels(
u32 max_thread_per_block = 256;
u32 block_num = MAX_SM_NUMS; // Use SM count like gECC test for optimal occupancy
solver->ecdsa_ec_pmul(block_num, max_thread_per_block, true); // true = unknown points
solver->ecdsa_ec_pmul(block_num, max_thread_per_block, true, state->stream); // true = unknown points
// Check for multiplication errors
err = cudaPeekAtLastError();
@@ -751,7 +751,7 @@ extern "C" int RunBatchScanKernels(
return -1;
}
cudaDeviceSynchronize();
// ecdsa_ec_pmul() already synchronizes internally, no need to sync again
// === BIP-352 Silent Payment Pipeline ===
// Step 1: Serialize shared secrets to compressed SEC1 format + 4 zero bytes (37 bytes each)
@@ -767,11 +767,11 @@ extern "C" int RunBatchScanKernels(
// Kernels will process multiple elements per thread when count > blocks * threads
int num_blocks = MAX_SM_NUMS; // Use SM count for good occupancy
SerializeToCompressedSEC1Kernel<<<num_blocks, threads_per_block>>>(
SerializeToCompressedSEC1Kernel<<<num_blocks, threads_per_block, 0, state->stream>>>(
solver->R0, d_serialized, count, state->batch_id
);
err = cudaDeviceSynchronize();
err = cudaStreamSynchronize(state->stream);
if (err != cudaSuccess) {
printf("SerializeToCompressedSEC1Kernel error: %s\n", cudaGetErrorString(err));
cudaFree(d_serialized);
@@ -787,32 +787,35 @@ extern "C" int RunBatchScanKernels(
return -1;
}
ComputeTaggedHashesKernel<<<num_blocks, threads_per_block>>>(
ComputeTaggedHashesKernel<<<num_blocks, threads_per_block, 0, state->stream>>>(
d_serialized, d_hashes, count, state->batch_id
);
err = cudaDeviceSynchronize();
// Step 3: Copy hashes to host and convert to Order scalars for fixed-point multiplication
// Copy hashes to host memory first to avoid coherency issues
// Use async memcpy with stream, then sync the stream
uint8_t *h_hashes = new uint8_t[count * 32];
err = cudaMemcpyAsync(h_hashes, d_hashes, count * 32, cudaMemcpyDeviceToHost, state->stream);
if (err != cudaSuccess) {
printf("ComputeTaggedHashesKernel error: %s\n", cudaGetErrorString(err));
printf("cudaMemcpyAsync for h_hashes error: %s\n", cudaGetErrorString(err));
cudaFree(d_serialized);
cudaFree(d_hashes);
delete[] h_hashes;
return -1;
}
// Debug: Log hash result
cudaFree(d_serialized); // No longer needed
// Step 3: Convert hashes to Order scalars for fixed-point multiplication
// Copy hashes to host memory first to avoid coherency issues
uint8_t *h_hashes = new uint8_t[count * 32];
err = cudaMemcpy(h_hashes, d_hashes, count * 32, cudaMemcpyDeviceToHost);
err = cudaStreamSynchronize(state->stream);
if (err != cudaSuccess) {
printf("cudaMemcpy for h_hashes error: %s\n", cudaGetErrorString(err));
printf("cudaStreamSynchronize after memcpy error: %s\n", cudaGetErrorString(err));
cudaFree(d_serialized);
cudaFree(d_hashes);
delete[] h_hashes;
return -1;
}
cudaFree(d_hashes); // No longer needed
// Free device memory after memcpy completes
cudaFree(d_serialized);
cudaFree(d_hashes);
// Allocate host buffer for conversion
Order::Base *h_hash_scalars = new Order::Base[Order::SIZE * count];
@@ -858,15 +861,15 @@ extern "C" int RunBatchScanKernels(
return -1;
}
err = cudaMemcpy(d_hash_scalars, h_hash_scalars, Order::SIZE * count, cudaMemcpyHostToDevice);
err = cudaMemcpyAsync(d_hash_scalars, h_hash_scalars, Order::SIZE * count, cudaMemcpyHostToDevice, state->stream);
if (err != cudaSuccess) {
printf("cudaMemcpy for d_hash_scalars error: %s\n", cudaGetErrorString(err));
printf("cudaMemcpyAsync for d_hash_scalars error: %s\n", cudaGetErrorString(err));
delete[] h_hash_scalars;
cudaFree(d_hash_scalars);
return -1;
}
delete[] h_hash_scalars; // No longer needed
delete[] h_hash_scalars; // Can be freed after async copy is queued
// Step 4: Allocate output buffer for fixed-point multiply results
ECPoint::Base *d_fpm_results;
@@ -878,11 +881,11 @@ extern "C" int RunBatchScanKernels(
}
// Step 5: Fixed-point multiply: hash × G using precomputed table
FixedPointMultiplyKernel<<<num_blocks, threads_per_block>>>(
FixedPointMultiplyKernel<<<num_blocks, threads_per_block, 0, state->stream>>>(
count, d_hash_scalars, d_fpm_results, state->batch_id
);
err = cudaDeviceSynchronize();
err = cudaStreamSynchronize(state->stream);
if (err != cudaSuccess) {
printf("FixedPointMultiplyKernel error: %s\n", cudaGetErrorString(err));
cudaFree(d_hash_scalars);
@@ -896,7 +899,7 @@ extern "C" int RunBatchScanKernels(
// This will: (1) try base case: output_point + spend_pubkey
// (2) for each label: try output_point + label_key (and negated)
CheckMatchesWithLabelsKernel<<<num_blocks, threads_per_block>>>(
CheckMatchesWithLabelsKernel<<<num_blocks, threads_per_block, 0, state->stream>>>(
d_fpm_results,
state->d_spend_pubkey_x,
state->d_spend_pubkey_y,
@@ -911,7 +914,7 @@ extern "C" int RunBatchScanKernels(
state->batch_id
);
err = cudaDeviceSynchronize();
err = cudaStreamSynchronize(state->stream);
if (err != cudaSuccess) {
printf("CheckMatchesWithLabelsKernel error: %s\n", cudaGetErrorString(err));
cudaFree(d_fpm_results);
@@ -929,6 +932,24 @@ extern "C" void FreeBatchScanState(void *state_handle) {
BatchScanState *state = static_cast<BatchScanState*>(state_handle);
// CRITICAL: Synchronize stream before cleanup to ensure all operations complete
// Without this, destroying the stream or freeing memory can cause deadlocks
if (state->stream) {
cudaError_t sync_err = cudaStreamSynchronize(state->stream);
if (sync_err != cudaSuccess) {
printf("WARNING: cudaStreamSynchronize in FreeBatchScanState failed: %s\n", cudaGetErrorString(sync_err));
}
// Double-check stream is idle
cudaError_t query_err = cudaStreamQuery(state->stream);
if (query_err == cudaErrorNotReady) {
printf("WARNING: Stream not ready after synchronize, forcing device sync\n");
cudaDeviceSynchronize();
} else if (query_err != cudaSuccess) {
printf("WARNING: cudaStreamQuery failed: %s\n", cudaGetErrorString(query_err));
}
}
// Free CUDA buffers
if (state->d_outputs) cudaFree(state->d_outputs);
if (state->d_output_offsets) cudaFree(state->d_output_offsets);
@@ -939,8 +960,16 @@ extern "C" void FreeBatchScanState(void *state_handle) {
if (state->d_label_keys_y) cudaFree(state->d_label_keys_y);
if (state->d_fpm_results_backup) cudaFree(state->d_fpm_results_backup);
// Delete solver
if (state->solver) delete state->solver;
// Clean up solver resources (frees managed memory allocated in ec_pmul_init)
if (state->solver) {
state->solver->ec_pmul_close();
delete state->solver;
}
// Destroy CUDA stream (safe now after synchronization)
if (state->stream) {
cudaStreamDestroy(state->stream);
}
// Free state struct
delete state;