From d1ab3c9120f8675711db301d26e4303ea4280ba8 Mon Sep 17 00:00:00 2001 From: Zenithal Date: Tue, 18 Feb 2025 15:46:19 +0000 Subject: [PATCH] BGV: support user specifying the scheme param --- lib/Parameters/BGV/BUILD | 1 + lib/Parameters/BGV/Params.cpp | 26 ++++++++ lib/Parameters/BGV/Params.h | 3 + .../ValidateNoise/ValidateNoise.cpp | 65 +++++++++++++++---- .../validate_noise_preserve_user_param.mlir | 9 +++ ...lidate_noise_preserve_user_param_fail.mlir | 19 ++++++ 6 files changed, 109 insertions(+), 14 deletions(-) create mode 100644 tests/Transforms/validate_noise/validate_noise_preserve_user_param.mlir create mode 100644 tests/Transforms/validate_noise/validate_noise_preserve_user_param_fail.mlir diff --git a/lib/Parameters/BGV/BUILD b/lib/Parameters/BGV/BUILD index 6d80bd1a2..7283895c0 100644 --- a/lib/Parameters/BGV/BUILD +++ b/lib/Parameters/BGV/BUILD @@ -8,6 +8,7 @@ cc_library( srcs = ["Params.cpp"], hdrs = ["Params.h"], deps = [ + "//lib/Dialect/BGV/IR:Dialect", "@heir//lib/Parameters:RLWEParams", "@llvm-project//llvm:Support", ], diff --git a/lib/Parameters/BGV/Params.cpp b/lib/Parameters/BGV/Params.cpp index b1ad23fa9..e49ed9b87 100644 --- a/lib/Parameters/BGV/Params.cpp +++ b/lib/Parameters/BGV/Params.cpp @@ -1,6 +1,7 @@ #include "lib/Parameters/BGV/Params.h" #include +#include #include #include #include @@ -25,6 +26,31 @@ SchemeParam SchemeParam::getConcreteSchemeParam(std::vector logqi, plaintextModulus); } +SchemeParam SchemeParam::getSchemeParamFromAttr(SchemeParamAttr attr) { + auto logN = attr.getLogN(); + auto ringDim = pow(2, logN); + auto plaintextModulus = attr.getPlaintextModulus(); + auto Q = attr.getQ(); + auto P = attr.getP(); + std::vector qiImpl; + std::vector piImpl; + std::vector logqi; + std::vector logpi; + for (auto qi : Q.asArrayRef()) { + qiImpl.push_back(qi); + logqi.push_back(log2(qi)); + } + for (auto pi : P.asArrayRef()) { + piImpl.push_back(pi); + logpi.push_back(log2(pi)); + } + auto level = logqi.size() - 1; + auto dnum = ceil(static_cast(qiImpl.size()) / piImpl.size()); + return SchemeParam( + RLWESchemeParam(ringDim, level, logqi, qiImpl, dnum, logpi, piImpl), + plaintextModulus); +} + void SchemeParam::print(llvm::raw_ostream &os) const { os << "plaintextModulus: " << plaintextModulus << "\n"; RLWESchemeParam::print(os); diff --git a/lib/Parameters/BGV/Params.h b/lib/Parameters/BGV/Params.h index 67a62d0a2..8783e33bc 100644 --- a/lib/Parameters/BGV/Params.h +++ b/lib/Parameters/BGV/Params.h @@ -4,6 +4,7 @@ #include #include +#include "lib/Dialect/BGV/IR/BGVAttributes.h" #include "lib/Parameters/RLWEParams.h" #include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project @@ -30,6 +31,8 @@ class SchemeParam : public RLWESchemeParam { static SchemeParam getConcreteSchemeParam(std::vector logqi, int64_t plaintextModulus); + + static SchemeParam getSchemeParamFromAttr(SchemeParamAttr attr); }; // Parameter for each BGV ciphertext SSA value. diff --git a/lib/Transforms/ValidateNoise/ValidateNoise.cpp b/lib/Transforms/ValidateNoise/ValidateNoise.cpp index 084030d96..8f7f631ef 100644 --- a/lib/Transforms/ValidateNoise/ValidateNoise.cpp +++ b/lib/Transforms/ValidateNoise/ValidateNoise.cpp @@ -83,6 +83,19 @@ struct ValidateNoise : impl::ValidateNoiseBase { llvm::dbgs() << "Noise Bound: " << boundString << " Budget: " << budgetString << " Total: " << totalString << " for value: " << value << " " << "\n"; + // annotate the bound when debugging + auto boundStringAttr = StringAttr::get(&getContext(), boundString); + if (auto blockArg = mlir::dyn_cast(value)) { + auto *parentOp = blockArg.getOwner()->getParentOp(); + auto genericOp = dyn_cast(parentOp); + if (genericOp) { + genericOp.setArgAttr(blockArg.getArgNumber(), "noise.bound", + boundStringAttr); + } + } else { + auto *parentOp = value.getDefiningOp(); + parentOp->setAttr("noise.bound", boundStringAttr); + } }); if (budget < 0) { @@ -96,6 +109,18 @@ struct ValidateNoise : impl::ValidateNoiseBase { LogicalResult validate( DataFlowSolver *solver, const typename NoiseAnalysis::SchemeParamType &schemeParam) { + solver->load(); + solver->load(); + // NoiseAnalysis depends on SecretnessAnalysis + solver->load(); + + solver->load(schemeParam); + + if (failed(solver->initializeAndRun(getOperation()))) { + getOperation()->emitOpError() << "Failed to run the analysis.\n"; + signalPassFailure(); + } + auto res = getOperation()->walk([&](secret::GenericOp genericOp) { // check arguments for (Value arg : genericOp.getBody()->getArguments()) { @@ -184,7 +209,9 @@ struct ValidateNoise : impl::ValidateNoiseBase { qiSize[0] = firstModSize; for (auto &[level, gap] : levelToGap) { - qiSize[level] = int(ceil(gap)); + // the prime size should be larger than the gap to ensure after mod reduce + // the noise is still within the bound + qiSize[level] = 1 + int(ceil(gap)); } LLVM_DEBUG({ @@ -208,13 +235,28 @@ struct ValidateNoise : impl::ValidateNoiseBase { template void run() { DataFlowSolver solver; - solver.load(); - solver.load(); - // NoiseAnalysis depends on SecretnessAnalysis - solver.load(); int maxLevel = getMaxLevel(); + // if bgv.schemeParam is already set, use it + if (auto schemeParamAttr = + getOperation()->getAttrOfType( + bgv::BGVDialect::kSchemeParamAttrName)) { + auto schemeParam = NoiseAnalysis::SchemeParamType::getSchemeParamFromAttr( + schemeParamAttr); + if (schemeParam.getLevel() < maxLevel) { + getOperation()->emitOpError() + << "The level in the scheme param is smaller than the max level.\n"; + signalPassFailure(); + return; + } + if (failed(validate(&solver, schemeParam))) { + getOperation()->emitOpError() << "Noise validation failed.\n"; + signalPassFailure(); + } + return; + } + // plaintext modulus from command line option auto schemeParam = NoiseAnalysis::SchemeParamType::getConservativeSchemeParam( @@ -223,24 +265,19 @@ struct ValidateNoise : impl::ValidateNoiseBase { LLVM_DEBUG(llvm::dbgs() << "Conservative Scheme Param:\n" << schemeParam << "\n"); - solver.load(schemeParam); - - if (failed(solver.initializeAndRun(getOperation()))) { - getOperation()->emitOpError() << "Failed to run the analysis.\n"; - signalPassFailure(); - return; - } - if (failed(validate(&solver, schemeParam))) { getOperation()->emitOpError() << "Noise validation failed.\n"; signalPassFailure(); return; } + // use previous analysis result to generate concrete scheme param auto concreteSchemeParam = generateParamByGap(&solver, schemeParam); - if (failed(validate(&solver, concreteSchemeParam))) { + // new solver as the NoiseAnalysis need to load a new schemeParam + DataFlowSolver solver2; + if (failed(validate(&solver2, concreteSchemeParam))) { getOperation()->emitOpError() << "Noise validation failed for generated param.\n"; signalPassFailure(); diff --git a/tests/Transforms/validate_noise/validate_noise_preserve_user_param.mlir b/tests/Transforms/validate_noise/validate_noise_preserve_user_param.mlir new file mode 100644 index 000000000..bdcdb8615 --- /dev/null +++ b/tests/Transforms/validate_noise/validate_noise_preserve_user_param.mlir @@ -0,0 +1,9 @@ +// RUN: heir-opt --mlir-to-secret-arithmetic --secret-insert-mgmt-bgv --validate-noise=model=bgv-noise-by-bound-coeff-average-case-pk %s | FileCheck %s + +// CHECK: module attributes {bgv.schemeParam = #bgv.scheme_param, scheme.bgv} { +module attributes {bgv.schemeParam = #bgv.scheme_param, scheme.bgv} { + // CHECK-LABEL: @return + func.func @return(%arg0: i16 {secret.secret}) -> i16 { + return %arg0 : i16 + } +} diff --git a/tests/Transforms/validate_noise/validate_noise_preserve_user_param_fail.mlir b/tests/Transforms/validate_noise/validate_noise_preserve_user_param_fail.mlir new file mode 100644 index 000000000..762057d2c --- /dev/null +++ b/tests/Transforms/validate_noise/validate_noise_preserve_user_param_fail.mlir @@ -0,0 +1,19 @@ +// RUN: heir-opt --mlir-to-secret-arithmetic --secret-insert-mgmt-bgv --validate-noise=model=bgv-noise-by-bound-coeff-average-case-pk %s --verify-diagnostics --split-input-file + +// expected-error@below {{'builtin.module' op The level in the scheme param is smaller than the max level.}} +module attributes {bgv.schemeParam = #bgv.scheme_param} { + func.func @return(%arg0: i16 {secret.secret}) -> i16 { + %1 = arith.muli %arg0, %arg0 : i16 + return %1 : i16 + } +} + +// ----- + +// expected-error@below {{'builtin.module' op Noise validation failed.}} +module attributes {bgv.schemeParam = #bgv.scheme_param} { + func.func @return(%arg0: i16 {secret.secret}) -> i16 { + %1 = arith.muli %arg0, %arg0 : i16 + return %1 : i16 + } +}