Skip to content

Commit

Permalink
BGV: support user specifying the scheme param
Browse files Browse the repository at this point in the history
  • Loading branch information
ZenithalHourlyRate committed Feb 26, 2025
1 parent fa56d41 commit d1ab3c9
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 14 deletions.
1 change: 1 addition & 0 deletions lib/Parameters/BGV/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ cc_library(
srcs = ["Params.cpp"],
hdrs = ["Params.h"],
deps = [
"//lib/Dialect/BGV/IR:Dialect",
"@heir//lib/Parameters:RLWEParams",
"@llvm-project//llvm:Support",
],
Expand Down
26 changes: 26 additions & 0 deletions lib/Parameters/BGV/Params.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "lib/Parameters/BGV/Params.h"

#include <cassert>
#include <cmath>
#include <cstdint>
#include <utility>
#include <vector>
Expand All @@ -25,6 +26,31 @@ SchemeParam SchemeParam::getConcreteSchemeParam(std::vector<double> logqi,
plaintextModulus);
}

SchemeParam SchemeParam::getSchemeParamFromAttr(SchemeParamAttr attr) {
auto logN = attr.getLogN();
auto ringDim = pow(2, logN);
auto plaintextModulus = attr.getPlaintextModulus();
auto Q = attr.getQ();
auto P = attr.getP();
std::vector<int64_t> qiImpl;
std::vector<int64_t> piImpl;
std::vector<double> logqi;
std::vector<double> logpi;
for (auto qi : Q.asArrayRef()) {
qiImpl.push_back(qi);
logqi.push_back(log2(qi));
}
for (auto pi : P.asArrayRef()) {
piImpl.push_back(pi);
logpi.push_back(log2(pi));
}
auto level = logqi.size() - 1;
auto dnum = ceil(static_cast<double>(qiImpl.size()) / piImpl.size());
return SchemeParam(
RLWESchemeParam(ringDim, level, logqi, qiImpl, dnum, logpi, piImpl),
plaintextModulus);
}

void SchemeParam::print(llvm::raw_ostream &os) const {
os << "plaintextModulus: " << plaintextModulus << "\n";
RLWESchemeParam::print(os);
Expand Down
3 changes: 3 additions & 0 deletions lib/Parameters/BGV/Params.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <cstdint>
#include <vector>

#include "lib/Dialect/BGV/IR/BGVAttributes.h"
#include "lib/Parameters/RLWEParams.h"
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project

Expand All @@ -30,6 +31,8 @@ class SchemeParam : public RLWESchemeParam {

static SchemeParam getConcreteSchemeParam(std::vector<double> logqi,
int64_t plaintextModulus);

static SchemeParam getSchemeParamFromAttr(SchemeParamAttr attr);
};

// Parameter for each BGV ciphertext SSA value.
Expand Down
65 changes: 51 additions & 14 deletions lib/Transforms/ValidateNoise/ValidateNoise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,19 @@ struct ValidateNoise : impl::ValidateNoiseBase<ValidateNoise> {
llvm::dbgs() << "Noise Bound: " << boundString
<< " Budget: " << budgetString << " Total: " << totalString
<< " for value: " << value << " " << "\n";
// annotate the bound when debugging
auto boundStringAttr = StringAttr::get(&getContext(), boundString);
if (auto blockArg = mlir::dyn_cast<BlockArgument>(value)) {
auto *parentOp = blockArg.getOwner()->getParentOp();
auto genericOp = dyn_cast<secret::GenericOp>(parentOp);
if (genericOp) {
genericOp.setArgAttr(blockArg.getArgNumber(), "noise.bound",
boundStringAttr);
}
} else {
auto *parentOp = value.getDefiningOp();
parentOp->setAttr("noise.bound", boundStringAttr);
}
});

if (budget < 0) {
Expand All @@ -96,6 +109,18 @@ struct ValidateNoise : impl::ValidateNoiseBase<ValidateNoise> {
LogicalResult validate(
DataFlowSolver *solver,
const typename NoiseAnalysis::SchemeParamType &schemeParam) {
solver->load<dataflow::DeadCodeAnalysis>();
solver->load<dataflow::SparseConstantPropagation>();
// NoiseAnalysis depends on SecretnessAnalysis
solver->load<SecretnessAnalysis>();

solver->load<NoiseAnalysis>(schemeParam);

if (failed(solver->initializeAndRun(getOperation()))) {
getOperation()->emitOpError() << "Failed to run the analysis.\n";
signalPassFailure();
}

auto res = getOperation()->walk([&](secret::GenericOp genericOp) {
// check arguments
for (Value arg : genericOp.getBody()->getArguments()) {
Expand Down Expand Up @@ -184,7 +209,9 @@ struct ValidateNoise : impl::ValidateNoiseBase<ValidateNoise> {
qiSize[0] = firstModSize;

for (auto &[level, gap] : levelToGap) {
qiSize[level] = int(ceil(gap));
// the prime size should be larger than the gap to ensure after mod reduce
// the noise is still within the bound
qiSize[level] = 1 + int(ceil(gap));
}

LLVM_DEBUG({
Expand All @@ -208,13 +235,28 @@ struct ValidateNoise : impl::ValidateNoiseBase<ValidateNoise> {
template <typename NoiseAnalysis>
void run() {
DataFlowSolver solver;
solver.load<dataflow::DeadCodeAnalysis>();
solver.load<dataflow::SparseConstantPropagation>();
// NoiseAnalysis depends on SecretnessAnalysis
solver.load<SecretnessAnalysis>();

int maxLevel = getMaxLevel();

// if bgv.schemeParam is already set, use it
if (auto schemeParamAttr =
getOperation()->getAttrOfType<bgv::SchemeParamAttr>(
bgv::BGVDialect::kSchemeParamAttrName)) {
auto schemeParam = NoiseAnalysis::SchemeParamType::getSchemeParamFromAttr(
schemeParamAttr);
if (schemeParam.getLevel() < maxLevel) {
getOperation()->emitOpError()
<< "The level in the scheme param is smaller than the max level.\n";
signalPassFailure();
return;
}
if (failed(validate<NoiseAnalysis>(&solver, schemeParam))) {
getOperation()->emitOpError() << "Noise validation failed.\n";
signalPassFailure();
}
return;
}

// plaintext modulus from command line option
auto schemeParam =
NoiseAnalysis::SchemeParamType::getConservativeSchemeParam(
Expand All @@ -223,24 +265,19 @@ struct ValidateNoise : impl::ValidateNoiseBase<ValidateNoise> {
LLVM_DEBUG(llvm::dbgs() << "Conservative Scheme Param:\n"
<< schemeParam << "\n");

solver.load<NoiseAnalysis>(schemeParam);

if (failed(solver.initializeAndRun(getOperation()))) {
getOperation()->emitOpError() << "Failed to run the analysis.\n";
signalPassFailure();
return;
}

if (failed(validate<NoiseAnalysis>(&solver, schemeParam))) {
getOperation()->emitOpError() << "Noise validation failed.\n";
signalPassFailure();
return;
}

// use previous analysis result to generate concrete scheme param
auto concreteSchemeParam =
generateParamByGap<NoiseAnalysis>(&solver, schemeParam);

if (failed(validate<NoiseAnalysis>(&solver, concreteSchemeParam))) {
// new solver as the NoiseAnalysis need to load a new schemeParam
DataFlowSolver solver2;
if (failed(validate<NoiseAnalysis>(&solver2, concreteSchemeParam))) {
getOperation()->emitOpError()
<< "Noise validation failed for generated param.\n";
signalPassFailure();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RUN: heir-opt --mlir-to-secret-arithmetic --secret-insert-mgmt-bgv --validate-noise=model=bgv-noise-by-bound-coeff-average-case-pk %s | FileCheck %s

// CHECK: module attributes {bgv.schemeParam = #bgv.scheme_param<logN = 14, Q = [2148728833, 2148794369, 1152921504607338497], P = [1152921504608747521, 1152921504609239041], plaintextModulus = 65537>, scheme.bgv} {
module attributes {bgv.schemeParam = #bgv.scheme_param<logN = 14, Q = [2148728833, 2148794369, 1152921504607338497], P = [1152921504608747521, 1152921504609239041], plaintextModulus = 65537>, scheme.bgv} {
// CHECK-LABEL: @return
func.func @return(%arg0: i16 {secret.secret}) -> i16 {
return %arg0 : i16
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// RUN: heir-opt --mlir-to-secret-arithmetic --secret-insert-mgmt-bgv --validate-noise=model=bgv-noise-by-bound-coeff-average-case-pk %s --verify-diagnostics --split-input-file

// expected-error@below {{'builtin.module' op The level in the scheme param is smaller than the max level.}}
module attributes {bgv.schemeParam = #bgv.scheme_param<logN = 11, Q = [17], P = [1093633], plaintextModulus = 65537>} {
func.func @return(%arg0: i16 {secret.secret}) -> i16 {
%1 = arith.muli %arg0, %arg0 : i16
return %1 : i16
}
}

// -----

// expected-error@below {{'builtin.module' op Noise validation failed.}}
module attributes {bgv.schemeParam = #bgv.scheme_param<logN = 11, Q = [17, 23], P = [1093633], plaintextModulus = 65537>} {
func.func @return(%arg0: i16 {secret.secret}) -> i16 {
%1 = arith.muli %arg0, %arg0 : i16
return %1 : i16
}
}

0 comments on commit d1ab3c9

Please sign in to comment.