21 changes: 21 additions & 0 deletions CMakeLists.txt
@@ -25,6 +25,27 @@ message(STATUS "Build type (CMAKE_BUILD_TYPE): ${CMAKE_BUILD_TYPE}")

project(SEAL VERSION 4.1.2 LANGUAGES CXX C)

# Option to enable the RISC-V Vector (RVV) extension
option(RISCVRVV "Enable RISC-V RVV extension support" OFF)

# Detect RISC-V targets and set architecture-specific compiler flags
if(CMAKE_SYSTEM_PROCESSOR MATCHES "riscv")
message(STATUS "Detected RISC-V architecture. Adding custom flags...")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")
# RVV support requested (RISCVRVV=ON): target rv64gcv
if(RISCVRVV)
message(STATUS "RISC-V RVV flag detected. Adding rv64gcv flags.")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=rv64gcv -mabi=lp64d")
else()
# RVV support not requested: target rv64gc
message(STATUS "RISC-V RVV flag not detected. Adding rv64gc flags.")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=rv64gc -mabi=lp64d")
endif()
endif()

# Print the final C++ compiler flags
message(STATUS "Final CMake CXX_FLAGS: ${CMAKE_CXX_FLAGS}")

########################
# Global configuration #
########################
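Usage note (the exact command is illustrative, not part of the diff): on a RISC-V host the vector path is selected at configure time with something like cmake -S . -B build -DRISCVRVV=ON; configuring without the option keeps the plain rv64gc flags chosen above.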
273 changes: 241 additions & 32 deletions native/src/seal/util/dwthandler.h

Large diffs are not rendered by default.
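The dwthandler.h diff is not shown here, but ntt.cpp below calls new transform_to_rev_rvv / transform_from_rev_rvv entry points, and ntt.h gains RVV helpers (guard_vector_rvv, add_vector_rvv, sub_vector_rvv, mul_vector_rvv). As a rough orientation only — the variable names x_ptr, y_ptr, gap, w, and arithmetic are illustrative assumptions, not the actual patch — one vectorized forward butterfly stage could be sketched as:

// Sketch only: strip-mined Harvey butterfly over one gap, assuming the helper
// methods added to the Arithmetic specialization in ntt.h below.
size_t remaining = gap;                     // elements left in this butterfly half
while (remaining > 0)
{
    size_t vl = __riscv_vsetvl_e64m4(remaining);
    vuint64m4_t vx = __riscv_vle64_v_u64m4(x_ptr, vl);
    vuint64m4_t vy = __riscv_vle64_v_u64m4(y_ptr, vl);

    vuint64m4_t u = arithmetic.guard_vector_rvv(vx, vl);                      // reduce X to [0, 2p)
    vuint64m4_t v = arithmetic.mul_vector_rvv(vy, w.quotient, w.operand, vl); // Y * W (Shoup, lazy)

    __riscv_vse64_v_u64m4(x_ptr, arithmetic.add_vector_rvv(u, v, vl), vl);    // X' = u + v
    __riscv_vse64_v_u64m4(y_ptr, arithmetic.sub_vector_rvv(u, v, vl), vl);    // Y' = u - v + 2p

    x_ptr += vl;
    y_ptr += vl;
    remaining -= vl;
}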

133 changes: 126 additions & 7 deletions native/src/seal/util/ntt.cpp
@@ -5,6 +5,7 @@
#include "seal/util/uintarith.h"
#include "seal/util/uintarithsmallmod.h"
#include <algorithm>
#include <vector>
#ifdef SEAL_USE_INTEL_HEXL
#include "seal/memorymanager.h"
#include "seal/util/iterator.h"
@@ -13,6 +14,9 @@
#include <unordered_map>
#include "hexl/hexl.hpp"
#endif
#ifdef __riscv_vector
#include <riscv_vector.h>
#endif

using namespace std;

@@ -226,6 +230,41 @@ namespace seal
{
namespace util
{

#if defined(__riscv_v_intrinsic)
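// Bit-serial long division, one quotient bit per iteration, all lanes in parallel:
// each element of the result is floor(((num_hi << 64) | num_lo) / den). The quotient
// must fit in 64 bits, which holds for the call below since num_hi (a reduced operand)
// is always smaller than den (the modulus).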
vuint64m4_t parallel_128bit_div_4_rvv(vuint64m4_t num_hi, vuint64m4_t num_lo, vuint64m4_t den, size_t vl) {
vuint64m4_t v_quo = __riscv_vmv_v_x_u64m4(0, vl);
vuint64m4_t v_rem = __riscv_vmv_v_x_u64m4(0, vl);

// Process upper 64 bits from num_hi
for (int j = 0; j < 64; j++) {
v_rem = __riscv_vsll_vx_u64m4(v_rem, 1, vl);
vuint64m4_t next_bit = __riscv_vsrl_vx_u64m4(num_hi, 63, vl);
num_hi = __riscv_vsll_vx_u64m4(num_hi, 1, vl);
v_rem = __riscv_vor_vv_u64m4(v_rem, next_bit, vl);
v_quo = __riscv_vsll_vx_u64m4(v_quo, 1, vl);

vbool16_t mask = __riscv_vmsgeu_vv_u64m4_b16(v_rem, den, vl);
v_rem = __riscv_vsub_vv_u64m4_mu(mask, v_rem, v_rem, den, vl);
v_quo = __riscv_vor_vx_u64m4_mu(mask, v_quo, v_quo, 1, vl);
}
// Process lower 64 bits from num_lo

for (int j = 0; j < 64; j++) {
v_rem = __riscv_vsll_vx_u64m4(v_rem, 1, vl);
vuint64m4_t next_bit = __riscv_vsrl_vx_u64m4(num_lo, 63, vl);
num_lo = __riscv_vsll_vx_u64m4(num_lo, 1, vl);
v_rem = __riscv_vor_vv_u64m4(v_rem, next_bit, vl);
v_quo = __riscv_vsll_vx_u64m4(v_quo, 1, vl);

vbool16_t mask = __riscv_vmsgeu_vv_u64m4_b16(v_rem, den, vl);
v_rem = __riscv_vsub_vv_u64m4_mu(mask, v_rem, v_rem, den, vl);
v_quo = __riscv_vor_vx_u64m4_mu(mask, v_quo, v_quo, 1, vl);
}
return v_quo;
}
#endif
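For reference, each lane above computes the same value as the following scalar expression; this is a sketch using the compiler-provided unsigned __int128 (not part of the patch), valid here because num_hi < den, so the quotient fits in 64 bits:

// Scalar reference for one lane: floor(((hi << 64) | lo) / den).
inline uint64_t div_128_by_64_reference(uint64_t hi, uint64_t lo, uint64_t den)
{
    unsigned __int128 num = ((unsigned __int128)hi << 64) | lo;
    return (uint64_t)(num / den);
}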

NTTTables::NTTTables(int coeff_count_power, const Modulus &modulus, MemoryPoolHandle pool)
: pool_(std::move(pool))
{
@@ -264,29 +303,103 @@ namespace seal
// Pre-compute HEXL NTT object
intel::seal_ext::get_ntt(coeff_count_, modulus.value(), root_);
#endif

// Populate tables with powers of root in specific orders.
root_powers_ = allocate<MultiplyUIntModOperand>(coeff_count_, pool_);
MultiplyUIntModOperand root;
root.set(root_, modulus_);
uint64_t power = root_;

#if defined(__riscv_v_intrinsic)

// Computes a block of root powers and their Shoup quotients
// (quotient = floor(operand * 2^64 / modulus)), reusing scratch buffers across calls.
auto compute_powers_vectorized = [&](uint64_t initial_power, MultiplyUIntModOperand* target_array, bool is_inverse) -> void {

// Thread-local buffers - reused across calls to avoid repeated allocation
static thread_local std::vector<uint64_t> num_buffer;
static thread_local std::vector<uint64_t> quot_buffer;

// Resize buffers only if needed
if (num_buffer.size() < coeff_count_) {
num_buffer.resize(coeff_count_);
quot_buffer.resize(coeff_count_);
}

// Generate powers
num_buffer[0] = initial_power;
for (size_t i = 1; i < coeff_count_; i++) {
num_buffer[i] = multiply_uint_mod(num_buffer[i-1], root, modulus_);
}

// Vectorized 128-by-64-bit division to obtain the Shoup quotients
uint64_t denom = modulus_.value();
size_t processed = 0;

size_t vl = __riscv_vsetvl_e64m4(coeff_count_-1 - processed);
vuint64m4_t den_vec = __riscv_vmv_v_x_u64m4(denom, vl);
vuint64m4_t num_lo = __riscv_vmv_v_x_u64m4(0, vl); // numerator is operand << 64, so the low 64 bits are zero

while (processed < coeff_count_-1) {
vl = __riscv_vsetvl_e64m4(coeff_count_-1 - processed);
vuint64m4_t num_hi = __riscv_vle64_v_u64m4(num_buffer.data() + processed, vl);
vuint64m4_t quo_vec = parallel_128bit_div_4_rvv(num_hi, num_lo, den_vec, vl);
__riscv_vse64_v_u64m4(quot_buffer.data() + processed, quo_vec, vl);
processed += vl;
}

// Store results
if (is_inverse) {
for(size_t i = 1; i < coeff_count_; i++){
size_t rev = reverse_bits(i-1, coeff_count_power_) + 1;
target_array[rev].operand = num_buffer[i - 1];
target_array[rev].quotient = quot_buffer[i - 1];
}
} else {
for(size_t i = 1; i < coeff_count_; i++){
size_t rev = reverse_bits(i, coeff_count_power_);
target_array[rev].operand = num_buffer[i - 1];
target_array[rev].quotient = quot_buffer[i - 1];
}
}
};

// Compute root powers using unified function
compute_powers_vectorized(power, root_powers_.get(), false);

#else

// Original scalar fallback
for (size_t i = 1; i < coeff_count_; i++)
{
root_powers_[reverse_bits(i, coeff_count_power_)].set(power, modulus_);
power = multiply_uint_mod(power, root, modulus_);
}

#endif

root_powers_[0].set(static_cast<uint64_t>(1), modulus_);


// Inverse root powers
inv_root_powers_ = allocate<MultiplyUIntModOperand>(coeff_count_, pool_);
root.set(inv_root_, modulus_);
power = inv_root_;

#if defined(__riscv_v_intrinsic)
// Reuse the same function and buffers for inverse powers
compute_powers_vectorized(power, inv_root_powers_.get(), true);
#else

// Original scalar fallback for inverse
for (size_t i = 1; i < coeff_count_; i++)
{
inv_root_powers_[reverse_bits(i - 1, coeff_count_power_) + 1].set(power, modulus_);
power = multiply_uint_mod(power, root, modulus_);
}

#endif

inv_root_powers_[0].set(static_cast<uint64_t>(1), modulus_);

// Compute n^(-1) modulo q.
uint64_t degree_uint = static_cast<uint64_t>(coeff_count_);
if (!try_invert_uint_mod(degree_uint, modulus_, inv_degree_modulo_.operand))
@@ -400,8 +513,11 @@ namespace seal

intel::seal_ext::compute_forward_ntt(operand, N, p, root, 4, 4);
#else
tables.ntt_handler().transform_to_rev(
operand.ptr(), tables.coeff_count_power(), tables.get_from_root_powers());
#if defined(__riscv_v_intrinsic)
tables.ntt_handler().transform_to_rev_rvv(operand.ptr(), tables.coeff_count_power(), tables.get_from_root_powers());
#else
tables.ntt_handler().transform_to_rev(operand.ptr(), tables.coeff_count_power(), tables.get_from_root_powers());
#endif
#endif
}

@@ -445,8 +561,11 @@
intel::seal_ext::compute_inverse_ntt(operand, N, p, root, 2, 2);
#else
MultiplyUIntModOperand inv_degree_modulo = tables.inv_degree_modulo();
tables.ntt_handler().transform_from_rev(
operand.ptr(), tables.coeff_count_power(), tables.get_from_inv_root_powers(), &inv_degree_modulo);
#if defined(__riscv_v_intrinsic)
tables.ntt_handler().transform_from_rev_rvv(operand.ptr(), tables.coeff_count_power(), tables.get_from_inv_root_powers(), &inv_degree_modulo);
#else
tables.ntt_handler().transform_from_rev(operand.ptr(), tables.coeff_count_power(), tables.get_from_inv_root_powers(), &inv_degree_modulo);
#endif
#endif
}

44 changes: 44 additions & 0 deletions native/src/seal/util/ntt.h
@@ -13,6 +13,10 @@
#include "seal/util/uintcore.h"
#include <stdexcept>

#ifdef __riscv_vector
#include <riscv_vector.h>
#endif

namespace seal
{
namespace util
@@ -60,6 +64,46 @@ namespace seal
return SEAL_COND_SELECT(a >= two_times_modulus_, a - two_times_modulus_, a);
}

#if defined(__riscv_v_intrinsic)

// Conditional subtraction of 2 * modulus: reduces each lane from [0, 4p) to [0, 2p).
inline vuint64m4_t guard_vector_rvv(const vuint64m4_t in, size_t vl) const {
vuint64m4_t modulus_vec = __riscv_vmv_v_x_u64m4(two_times_modulus_, vl);
vbool16_t mask = __riscv_vmsgeu_vx_u64m4_b16(in, two_times_modulus_, vl);
return __riscv_vsub_vv_u64m4_mu(mask, in, in, modulus_vec, vl);
}

inline vuint64m4_t add_vector_rvv(const vuint64m4_t a, const vuint64m4_t b, size_t vl) const {
return __riscv_vadd_vv_u64m4(a, b, vl);
}

inline vuint64m4_t sub_vector_rvv(const vuint64m4_t a, const vuint64m4_t b, size_t vl) const {
// Broadcast vector with all elements = two_times_modulus_
vuint64m4_t modulus_vec = __riscv_vmv_v_x_u64m4(two_times_modulus_, vl);
// result = a + (2 * modulus) - b
vuint64m4_t tmp = __riscv_vadd_vv_u64m4(a, modulus_vec, vl);
vuint64m4_t result = __riscv_vsub_vv_u64m4(tmp, b, vl);
return result;
}

// Shoup's lazy modular multiplication: computes a * yop mod p with the result left
// in [0, 2p), where yquot is the precomputed floor(yop * 2^64 / p).
inline vuint64m4_t mul_vector_rvv(const vuint64m4_t a, const uint64_t yquot, const uint64_t yop, size_t vl) const {
const uint64_t p = modulus_.value();

// Replicate scalars across vector registers
vuint64m4_t vb = __riscv_vmv_v_x_u64m4(yquot, vl);
vuint64m4_t vp = __riscv_vmv_v_x_u64m4(p, vl);
vuint64m4_t vop = __riscv_vmv_v_x_u64m4(yop, vl);
// Unsigned high part of a * yquot
vuint64m4_t vhi = __riscv_vmulhu_vv_u64m4(a, vb, vl);
// a * yop
vuint64m4_t vmul1 = __riscv_vmul_vv_u64m4(a, vop, vl);
// vhi * p
vuint64m4_t vmul2 = __riscv_vmul_vv_u64m4(vhi, vp, vl);
// (a * yop) - (vhi * p)
vuint64m4_t vres = __riscv_vsub_vv_u64m4(vmul1, vmul2, vl);
return vres;
}
#endif

private:
Modulus modulus_;
