UltrafastSecp256k1/cpu/tests/test_multiscalar_batch.cpp

261 lines
9.0 KiB
C++

// ============================================================================
// Test: Multi-Scalar Multiplication + Batch Verification
// ============================================================================
#include "secp256k1/multiscalar.hpp"
#include "secp256k1/batch_verify.hpp"
#include "secp256k1/ecdsa.hpp"
#include "secp256k1/schnorr.hpp"
#include "secp256k1/sha256.hpp"
#include <cstdio>
#include <cstring>
#include <vector>
using namespace secp256k1;
using fast::Scalar;
using fast::Point;
static int tests_run = 0;
static int tests_passed = 0;
#define CHECK(cond, msg) do { \
++tests_run; \
if (cond) { ++tests_passed; printf(" [PASS] %s\n", msg); } \
else { printf(" [FAIL] %s\n", msg); } \
} while(0)
// -- Shamir's Trick -----------------------------------------------------------
static void test_shamir_trick() {
printf("\n--- Shamir's Trick ---\n");
auto G = Point::generator();
auto a = Scalar::from_uint64(7);
auto b = Scalar::from_uint64(13);
auto P = G;
auto Q = G.scalar_mul(Scalar::from_uint64(5));
// Expected: a*P + b*Q = 7*G + 13*5G = 7G + 65G = 72G
auto expected = G.scalar_mul(Scalar::from_uint64(72));
auto result = shamir_trick(a, P, b, Q);
CHECK(result.x().to_bytes() == expected.x().to_bytes(),
"shamir_trick(7, G, 13, 5G) == 72G");
// Edge: a=0
auto r2 = shamir_trick(Scalar::zero(), P, b, Q);
auto e2 = Q.scalar_mul(b);
CHECK(r2.x().to_bytes() == e2.x().to_bytes(),
"shamir_trick(0, P, b, Q) == b*Q");
// Edge: b=0
auto r3 = shamir_trick(a, P, Scalar::zero(), Q);
auto e3 = P.scalar_mul(a);
CHECK(r3.x().to_bytes() == e3.x().to_bytes(),
"shamir_trick(a, P, 0, Q) == a*P");
}
// -- Multi-Scalar Multiplication ----------------------------------------------
static void test_multi_scalar_mul() {
printf("\n--- Multi-Scalar Multiplication ---\n");
auto G = Point::generator();
// Test with 1 point
{
Scalar const s = Scalar::from_uint64(42);
Point const p = G;
auto result = multi_scalar_mul(&s, &p, 1);
auto expected = G.scalar_mul(Scalar::from_uint64(42));
CHECK(result.x().to_bytes() == expected.x().to_bytes(),
"multi_scalar_mul: 1 point");
}
// Test with 3 points: 2*G + 3*2G + 5*3G = 2G + 6G + 15G = 23G
{
std::vector<Scalar> const scalars = {
Scalar::from_uint64(2),
Scalar::from_uint64(3),
Scalar::from_uint64(5)
};
std::vector<Point> const points = {
G,
G.scalar_mul(Scalar::from_uint64(2)),
G.scalar_mul(Scalar::from_uint64(3))
};
auto result = multi_scalar_mul(scalars, points);
auto expected = G.scalar_mul(Scalar::from_uint64(23));
CHECK(result.x().to_bytes() == expected.x().to_bytes(),
"multi_scalar_mul: 3 points (2G+6G+15G=23G)");
}
// Test with 0 points
{
auto result = multi_scalar_mul(nullptr, nullptr, 0);
CHECK(result.is_infinity(), "multi_scalar_mul: 0 points = infinity");
}
// Test: sum cancels to infinity (P + (-P) = O)
{
Scalar const s1 = Scalar::from_uint64(1);
Scalar const s2 = Scalar::from_uint64(1);
Point const P1 = G;
Point const P2 = G.negate();
std::vector<Scalar> const scalars = {s1, s2};
std::vector<Point> const pts = {P1, P2};
auto result = multi_scalar_mul(scalars, pts);
CHECK(result.is_infinity(), "multi_scalar_mul: G + (-G) = infinity");
}
}
// -- Schnorr Batch Verification -----------------------------------------------
static void test_schnorr_batch_verify() {
printf("\n--- Schnorr Batch Verification ---\n");
// Create 5 valid Schnorr signatures
constexpr std::size_t N = 5;
std::vector<SchnorrBatchEntry> entries(N);
std::vector<Scalar> keys(N);
for (std::size_t i = 0; i < N; ++i) {
keys[i] = Scalar::from_uint64(100 + i);
entries[i].pubkey_x = schnorr_pubkey(keys[i]);
// Message: SHA256(i)
uint8_t ibuf[4] = {
static_cast<uint8_t>(i), static_cast<uint8_t>(i >> 8),
static_cast<uint8_t>(i >> 16), static_cast<uint8_t>(i >> 24)
};
auto msg_h = SHA256::hash(ibuf, 4);
entries[i].message = msg_h;
// Sign
std::array<uint8_t, 32> const aux{};
entries[i].signature = schnorr_sign(keys[i], msg_h, aux);
}
// All valid
bool const all_valid = schnorr_batch_verify(entries);
CHECK(all_valid, "Schnorr batch: 5 valid signatures pass");
// Individual verify should also pass
bool individual_ok = true;
for (std::size_t i = 0; i < N; ++i) {
if (!schnorr_verify(entries[i].pubkey_x, entries[i].message,
entries[i].signature)) {
individual_ok = false;
}
}
CHECK(individual_ok, "Schnorr batch: individual verification agrees");
// Corrupt one signature -- batch should fail
auto corrupted = entries;
corrupted[2].signature.s = corrupted[2].signature.s + Scalar::one();
bool const should_fail = schnorr_batch_verify(corrupted);
CHECK(!should_fail, "Schnorr batch: corrupted sig #2 detected");
// Identify invalid
auto invalid = schnorr_batch_identify_invalid(corrupted.data(), corrupted.size());
CHECK(invalid.size() == 1 && invalid[0] == 2,
"Schnorr batch identify: correctly finds sig #2");
// Empty batch
CHECK(schnorr_batch_verify(static_cast<const SchnorrBatchEntry*>(nullptr), 0),
"Schnorr batch: empty = true");
// Single entry
std::vector<SchnorrBatchEntry> const single = {entries[0]};
CHECK(schnorr_batch_verify(single), "Schnorr batch: single entry pass");
std::vector<SchnorrXonlyPubkey> cached_pubkeys(N);
std::vector<SchnorrBatchCachedEntry> cached_entries(N);
bool cache_parse_ok = true;
for (std::size_t i = 0; i < N; ++i) {
if (!schnorr_xonly_pubkey_parse(cached_pubkeys[i], entries[i].pubkey_x)) {
cache_parse_ok = false;
break;
}
cached_entries[i] = {&cached_pubkeys[i], entries[i].message,
entries[i].signature};
}
CHECK(cache_parse_ok, "Schnorr batch cached: parse x-only pubkeys");
CHECK(schnorr_batch_verify(cached_entries),
"Schnorr batch cached: valid signatures pass");
auto cached_corrupted = cached_entries;
cached_corrupted[2].signature.s = cached_corrupted[2].signature.s + Scalar::one();
CHECK(!schnorr_batch_verify(cached_corrupted),
"Schnorr batch cached: corrupted sig #2 detected");
auto cached_invalid = schnorr_batch_identify_invalid(
cached_corrupted.data(), cached_corrupted.size());
CHECK(cached_invalid.size() == 1 && cached_invalid[0] == 2,
"Schnorr batch cached identify: correctly finds sig #2");
auto cached_missing = cached_entries;
cached_missing[1].pubkey = nullptr;
CHECK(!schnorr_batch_verify(cached_missing),
"Schnorr batch cached: null pubkey rejected");
auto cached_missing_invalid = schnorr_batch_identify_invalid(
cached_missing.data(), cached_missing.size());
CHECK(cached_missing_invalid.size() == 1 && cached_missing_invalid[0] == 1,
"Schnorr batch cached identify: null pubkey reported invalid");
}
// -- ECDSA Batch Verification -------------------------------------------------
static void test_ecdsa_batch_verify() {
printf("\n--- ECDSA Batch Verification ---\n");
constexpr std::size_t N = 4;
std::vector<ECDSABatchEntry> entries(N);
std::vector<Scalar> keys(N);
auto G = Point::generator();
for (std::size_t i = 0; i < N; ++i) {
keys[i] = Scalar::from_uint64(200 + i);
entries[i].public_key = G.scalar_mul(keys[i]);
// Message hash
uint8_t ibuf[4] = {
static_cast<uint8_t>(i + 50), 0, 0, 0
};
entries[i].msg_hash = SHA256::hash(ibuf, 4);
// Sign
entries[i].signature = ecdsa_sign(entries[i].msg_hash, keys[i]);
}
// All valid
CHECK(ecdsa_batch_verify(entries), "ECDSA batch: 4 valid signatures pass");
// Corrupt one
auto corrupted = entries;
corrupted[1].signature.s = corrupted[1].signature.s + Scalar::one();
CHECK(!ecdsa_batch_verify(corrupted), "ECDSA batch: corrupted sig #1 detected");
// Identify
auto invalid = ecdsa_batch_identify_invalid(corrupted.data(), corrupted.size());
CHECK(invalid.size() == 1 && invalid[0] == 1,
"ECDSA batch identify: correctly finds sig #1");
}
// -- Main ---------------------------------------------------------------------
int test_multiscalar_batch_run() {
printf("=== Multi-Scalar Multiplication & Batch Verification Tests ===\n");
test_shamir_trick();
test_multi_scalar_mul();
test_schnorr_batch_verify();
test_ecdsa_batch_verify();
printf("\n=== Results: %d/%d passed ===\n", tests_passed, tests_run);
return (tests_passed == tests_run) ? 0 : 1;
}