Skip to content

Commit

Permalink
Merge pull request #1407 from ZenithalHourlyRate:lattigo-bgv-inplace
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 731389580
  • Loading branch information
copybara-github committed Feb 26, 2025
2 parents 17bc538 + 651256b commit 6b8eac7
Show file tree
Hide file tree
Showing 30 changed files with 600 additions and 68 deletions.
24 changes: 12 additions & 12 deletions lib/Dialect/LWE/Conversions/LWEToLattigo/LWEToLattigo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -415,33 +415,33 @@ struct ConvertLWEReinterpretUnderlyingType
} // namespace

// BGV
using ConvertBGVAddOp =
ConvertRlweBinOp<lattigo::BGVEvaluatorType, lwe::RAddOp, lattigo::BGVAddOp>;
using ConvertBGVSubOp =
ConvertRlweBinOp<lattigo::BGVEvaluatorType, lwe::RSubOp, lattigo::BGVSubOp>;
using ConvertBGVMulOp =
ConvertRlweBinOp<lattigo::BGVEvaluatorType, lwe::RMulOp, lattigo::BGVMulOp>;
using ConvertBGVAddOp = ConvertRlweBinOp<lattigo::BGVEvaluatorType, lwe::RAddOp,
lattigo::BGVAddNewOp>;
using ConvertBGVSubOp = ConvertRlweBinOp<lattigo::BGVEvaluatorType, lwe::RSubOp,
lattigo::BGVSubNewOp>;
using ConvertBGVMulOp = ConvertRlweBinOp<lattigo::BGVEvaluatorType, lwe::RMulOp,
lattigo::BGVMulNewOp>;
using ConvertBGVAddPlainOp =
ConvertRlwePlainOp<lattigo::BGVEvaluatorType, bgv::AddPlainOp,
lattigo::BGVAddOp>;
lattigo::BGVAddNewOp>;
using ConvertBGVSubPlainOp =
ConvertRlwePlainOp<lattigo::BGVEvaluatorType, bgv::SubPlainOp,
lattigo::BGVSubOp>;
lattigo::BGVSubNewOp>;
using ConvertBGVMulPlainOp =
ConvertRlwePlainOp<lattigo::BGVEvaluatorType, bgv::MulPlainOp,
lattigo::BGVMulOp>;
lattigo::BGVMulNewOp>;

using ConvertBGVRelinOp =
ConvertRlweUnaryOp<lattigo::BGVEvaluatorType, bgv::RelinearizeOp,
lattigo::BGVRelinearizeOp>;
lattigo::BGVRelinearizeNewOp>;
using ConvertBGVModulusSwitchOp =
ConvertRlweUnaryOp<lattigo::BGVEvaluatorType, bgv::ModulusSwitchOp,
lattigo::BGVRescaleOp>;
lattigo::BGVRescaleNewOp>;

// TODO(#1186): figure out generic rotating using BGVRotateColumns/RowsOp
using ConvertBGVRotateOp =
ConvertRlweRotateOp<lattigo::BGVEvaluatorType, bgv::RotateOp,
lattigo::BGVRotateColumnsOp>;
lattigo::BGVRotateColumnsNewOp>;

using ConvertBGVEncryptOp =
ConvertRlweUnaryOp<lattigo::RLWEEncryptorType, lwe::RLWEEncryptOp,
Expand Down
3 changes: 3 additions & 0 deletions lib/Dialect/Lattigo/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ cc_library(
":LattigoAttributes",
":LattigoOps",
":LattigoTypes",
"@heir//lib/Utils/Tablegen:InplaceOpInterface",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
Expand Down Expand Up @@ -83,6 +84,7 @@ cc_library(
":dialect_inc_gen",
":ops_inc_gen",
":types_inc_gen",
"@heir//lib/Utils/Tablegen:InplaceOpInterface",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
Expand All @@ -107,6 +109,7 @@ td_library(
# include from the heir-root to enable fully-qualified include-paths
includes = ["../../../.."],
deps = [
"@heir//lib/Utils/Tablegen:td_files",
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
],
Expand Down
131 changes: 123 additions & 8 deletions lib/Dialect/Lattigo/IR/LattigoBGVOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def Lattigo_BGVNewEncoderOp : Lattigo_BGVOp<"new_encoder"> {
let results = (outs Lattigo_BGVEncoder:$encoder);
}

def Lattigo_BGVEncodeOp : Lattigo_BGVOp<"encode"> {
def Lattigo_BGVEncodeOp : Lattigo_BGVOp<"encode", [InplaceOpInterface]> {
let summary = "Encode a plaintext value in the Lattigo BGV dialect";
let description = [{
This operation encodes a plaintext value using the specified encoder in the Lattigo BGV dialect.
Expand All @@ -55,6 +55,8 @@ def Lattigo_BGVEncodeOp : Lattigo_BGVOp<"encode"> {
Lattigo_RLWEPlaintext:$plaintext
);
let results = (outs Lattigo_RLWEPlaintext:$encoded);

let extraClassDeclaration = "int getInplaceOperandIndex() { return 2; }";
}

def Lattigo_BGVDecodeOp : Lattigo_BGVOp<"decode", [AllTypesMatch<["value", "decoded"]>]> {
Expand All @@ -69,6 +71,8 @@ def Lattigo_BGVDecodeOp : Lattigo_BGVOp<"decode", [AllTypesMatch<["value", "deco
Lattigo_RLWEPlaintext:$plaintext,
RankedTensorOf<[AnyInteger]>:$value
);
// although bgv.Decode is also an inplace operation as bgv.Encode, as there are post-processing
// steps in emitter, we mark it as a normal operation.
let results = (outs RankedTensorOf<[AnyInteger]>:$decoded);
}

Expand Down Expand Up @@ -106,27 +110,72 @@ class Lattigo_BGVBinaryOp<string mnemonic> :
let results = (outs Lattigo_RLWECiphertext:$output);
}

def Lattigo_BGVAddOp : Lattigo_BGVBinaryOp<"add"> {
def Lattigo_BGVAddNewOp : Lattigo_BGVBinaryOp<"add_new"> {
let summary = "Add two ciphertexts in the Lattigo BGV dialect";
let description = [{
This operation adds two ciphertext values in the Lattigo BGV dialect.
}];
}

def Lattigo_BGVSubOp : Lattigo_BGVBinaryOp<"sub"> {
def Lattigo_BGVSubNewOp : Lattigo_BGVBinaryOp<"sub_new"> {
let summary = "Subtract two ciphertexts in the Lattigo BGV dialect";
let description = [{
This operation subtracts one ciphertext value from another in the Lattigo BGV dialect.
}];
}

def Lattigo_BGVMulOp : Lattigo_BGVBinaryOp<"mul"> {
def Lattigo_BGVMulNewOp : Lattigo_BGVBinaryOp<"mul_new"> {
let summary = "Multiply two ciphertexts in the Lattigo BGV dialect";
let description = [{
This operation multiplies two ciphertext values in the Lattigo BGV dialect.
}];
}

class Lattigo_BGVBinaryInplaceOp<string mnemonic> :
Lattigo_BGVOp<mnemonic, [InplaceOpInterface]> {
let arguments = (ins
Lattigo_BGVEvaluator:$evaluator,
Lattigo_RLWECiphertext:$lhs,
Lattigo_RLWECiphertextOrPlaintext:$rhs,
// Lattigo API is like bgv.Add(lhs, rhs, out) but for MLIR we need to
// satisfy the SSA form, so we still have a separate output.
Lattigo_RLWECiphertext:$inplace
);
let results = (outs Lattigo_RLWECiphertext:$output);

let extraClassDeclaration = "int getInplaceOperandIndex() { return 3; }";
}

def Lattigo_BGVAddOp : Lattigo_BGVBinaryInplaceOp<"add"> {
let summary = "Add two ciphertexts in the Lattigo BGV dialect";
let description = [{
This operation adds two ciphertext values in the Lattigo BGV dialect.

The result will be written to the `inplace` operand. The `output`result is
a transitive reference to the `inplace` operand for sake of the MLIR SSA form.
}];
}

def Lattigo_BGVSubOp : Lattigo_BGVBinaryInplaceOp<"sub"> {
let summary = "Subtract two ciphertexts in the Lattigo BGV dialect";
let description = [{
This operation subtracts one ciphertext value from another in the Lattigo BGV dialect.

The result will be written to the `inplace` operand. The `output`result is
a transitive reference to the `inplace` operand for sake of the MLIR SSA form.
}];
}

def Lattigo_BGVMulOp : Lattigo_BGVBinaryInplaceOp<"mul"> {
let summary = "Multiply two ciphertexts in the Lattigo BGV dialect";
let description = [{
This operation multiplies two ciphertext values in the Lattigo BGV dialect.

The result will be written to the `inplace` operand. The `output`result is
a transitive reference to the `inplace` operand for sake of the MLIR SSA form.
}];
}

class Lattigo_BGVUnaryOp<string mnemonic> :
Lattigo_BGVOp<mnemonic> {
let arguments = (ins
Expand All @@ -136,43 +185,109 @@ class Lattigo_BGVUnaryOp<string mnemonic> :
let results = (outs Lattigo_RLWECiphertext:$output);
}

def Lattigo_BGVRelinearizeOp : Lattigo_BGVUnaryOp<"relinearize"> {
def Lattigo_BGVRelinearizeNewOp : Lattigo_BGVUnaryOp<"relinearize_new"> {
let summary = "Relinearize a ciphertext in the Lattigo BGV dialect";
let description = [{
This operation relinearizes a ciphertext value in the Lattigo BGV dialect.
}];
}

def Lattigo_BGVRescaleNewOp : Lattigo_BGVUnaryOp<"rescale_new"> {
let summary = "Rescale a ciphertext in the Lattigo BGV dialect";
let description = [{
This operation rescales a ciphertext value in the Lattigo BGV dialect.
}];
}

def Lattigo_BGVRotateColumnsNewOp : Lattigo_BGVOp<"rotate_columns_new"> {
let summary = "Rotate columns of a ciphertext in the Lattigo BGV dialect";
let description = [{
This operation rotates the columns of a ciphertext value in the Lattigo BGV dialect.

Lattigo exposes the SIMD slot of BGV as a N/2 x 2 matrix, where N/2 is the column.

Offset is valid for both positive and negative number.
}];
let arguments = (ins
Lattigo_BGVEvaluator:$evaluator,
Lattigo_RLWECiphertext:$input,
Builtin_IntegerAttr:$offset
);
let results = (outs Lattigo_RLWECiphertext:$output);
}

def Lattigo_BGVRotateRowsNewOp : Lattigo_BGVUnaryOp<"rotate_rows_new"> {
let summary = "Rotate rows of a ciphertext in the Lattigo BGV dialect";
let description = [{
This operation swap the rows of a ciphertext value in the Lattigo BGV dialect.

Lattigo exposes the SIMD slot of BGV as a N/2 x 2 matrix, where 2 is the row.
}];
}

class Lattigo_BGVUnaryInplaceOp<string mnemonic> :
Lattigo_BGVOp<mnemonic, [InplaceOpInterface]> {
let arguments = (ins
Lattigo_BGVEvaluator:$evaluator,
Lattigo_RLWECiphertext:$input,
// see BinaryInplaceOp above
Lattigo_RLWECiphertext:$inplace
);
let results = (outs Lattigo_RLWECiphertext:$output);

let extraClassDeclaration = "int getInplaceOperandIndex() { return 2; }";
}

def Lattigo_BGVRelinearizeOp : Lattigo_BGVUnaryInplaceOp<"relinearize"> {
let summary = "Relinearize a ciphertext in the Lattigo BGV dialect";
let description = [{
This operation relinearizes a ciphertext value in the Lattigo BGV dialect.

The result will be written to the `inplace` operand. The `output`result is
a transitive reference to the `inplace` operand for sake of the MLIR SSA form.
}];
}

def Lattigo_BGVRescaleOp : Lattigo_BGVUnaryOp<"rescale"> {
def Lattigo_BGVRescaleOp : Lattigo_BGVUnaryInplaceOp<"rescale"> {
let summary = "Rescale a ciphertext in the Lattigo BGV dialect";
let description = [{
This operation rescales a ciphertext value in the Lattigo BGV dialect.

The result will be written to the `inplace` operand. The `output`result is
a transitive reference to the `inplace` operand for sake of the MLIR SSA form.
}];
}

def Lattigo_BGVRotateColumnsOp : Lattigo_BGVOp<"rotate_columns"> {
def Lattigo_BGVRotateColumnsOp : Lattigo_BGVUnaryInplaceOp<"rotate_columns"> {
let summary = "Rotate columns of a ciphertext in the Lattigo BGV dialect";
let description = [{
This operation rotates the columns of a ciphertext value in the Lattigo BGV dialect.

Lattigo exposes the SIMD slot of BGV as a N/2 x 2 matrix, where N/2 is the column.

Offset is valid for both positive and negative number.

The result will be written to the `inplace` operand. The `output`result is
a transitive reference to the `inplace` operand for sake of the MLIR SSA form.
}];
let arguments = (ins
Lattigo_BGVEvaluator:$evaluator,
Lattigo_RLWECiphertext:$input,
Lattigo_RLWECiphertext:$inplace,
Builtin_IntegerAttr:$offset
);
let results = (outs Lattigo_RLWECiphertext:$output);
}

def Lattigo_BGVRotateRowsOp : Lattigo_BGVUnaryOp<"rotate_rows"> {
def Lattigo_BGVRotateRowsOp : Lattigo_BGVUnaryInplaceOp<"rotate_rows"> {
let summary = "Rotate rows of a ciphertext in the Lattigo BGV dialect";
let description = [{
This operation swap the rows of a ciphertext value in the Lattigo BGV dialect.

Lattigo exposes the SIMD slot of BGV as a N/2 x 2 matrix, where 2 is the row.

The result will be written to the `inplace` operand. The `output`result is
a transitive reference to the `inplace` operand for sake of the MLIR SSA form.
}];
}

Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Lattigo/IR/LattigoOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "lib/Dialect/Lattigo/IR/LattigoDialect.h"
#include "lib/Dialect/Lattigo/IR/LattigoTypes.h"
#include "lib/Utils/Tablegen/InplaceOpInterface.h"
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project

#define GET_OP_CLASSES
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Lattigo/IR/LattigoOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
include "LattigoDialect.td"
include "LattigoTypes.td"
include "mlir/IR/OpBase.td"
include "lib/Utils/Tablegen/InplaceOpInterface.td"

class Lattigo_Op<string mnemonic, list<Trait> traits = []> :
Op<Lattigo_Dialect, mnemonic, traits> {
Expand Down
Loading

0 comments on commit 6b8eac7

Please sign in to comment.