Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Finished shell validation #40

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions scripts/dafny_shell_validator.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!/bin/bash

var1=$1;
var2=$2;
var3=$3;

echo "$var1" > $var3;
~/Nagini-Convertion/Binaries/Dafny validate $var3;
echo "module Gen {"
echo -e "$var2"
echo "}"
cat $var3
2 changes: 1 addition & 1 deletion verified_cogen/experiments/incremental_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def main():
)
rewriter = construct_rewriter(extension_from_file_list([file]), args.manual_rewriters)
runner = make_runner_cls(args.bench_type, extension_from_file_list([file]), config)(
llm, logger, verifier, rewriter
llm, logger, verifier, rewriter, None
)
display_name = rename_file(file)
marker_name = str(file.relative_to(directory))
Expand Down
15 changes: 10 additions & 5 deletions verified_cogen/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
NaginiRewriterFixingAST,
)
from verified_cogen.runners.step_by_step import StepByStepRunner
from verified_cogen.runners.validating import ValidatingRunner
from verified_cogen.runners.validating import ValidatingRunner, Validator
from verified_cogen.tools import (
ext_glob,
extension_from_file_list,
Expand All @@ -39,7 +39,7 @@
def run_once(
files: list[Path],
args: ProgramArgs,
runner_cls: Callable[[LLM, Logger, Verifier, Optional[Rewriter]], Runner],
runner_cls: Callable[[LLM, Logger, Verifier, Optional[Rewriter], Optional[Validator]], Runner],
verifier: Verifier,
mode: Mode,
rewriter: Optional[Rewriter],
Expand All @@ -58,7 +58,7 @@ def run_once(
args.temperature,
)

runner = runner_cls(llm, logger, verifier, rewriter)
runner = runner_cls(llm, logger, verifier, rewriter, None)

retries = args.retries + 1
tries = None
Expand Down Expand Up @@ -125,12 +125,13 @@ def construct_rewriter(extension: str, runner_types: List[str]) -> Optional[Rewr

def make_runner_cls(
bench_type: str, extension: str, config: RunnerConfig
) -> Callable[[LLM, Logger, Verifier, Optional[Rewriter]], Runner]:
) -> Callable[[LLM, Logger, Verifier, Optional[Rewriter], Optional[Validator]], Runner]:
def runner_cls(
llm: LLM,
logger: Logger,
verifier: Verifier,
rewriter: Optional[Rewriter] = None,
validator: Optional[Validator] = None,
):
if bench_type == "invariants":
return InvariantRunner(llm, logger, verifier, config)
Expand All @@ -142,6 +143,7 @@ def runner_cls(
return ValidatingRunner(
InvariantRunner(llm, logger, verifier, config, rewriter),
LanguageDatabase().get(extension),
validator,
)
elif bench_type == "step-by-step":
return ValidatingRunner(
Expand Down Expand Up @@ -206,6 +208,7 @@ def main():
logger,
verifier,
rewriter,
None,
)
for file in files:
with open(file) as f:
Expand Down Expand Up @@ -248,7 +251,9 @@ def main():
args.prompts_directory,
args.temperature,
)
runner = make_runner_cls(args.bench_type, Path(args.input).suffix[1:], config)(llm, logger, verifier, None)
runner = make_runner_cls(args.bench_type, Path(args.input).suffix[1:], config)(
llm, logger, verifier, None, None
)
tries = runner.run_on_file(mode, args.tries, args.input)
if tries == 0:
print("Verified without modification")
Expand Down
43 changes: 35 additions & 8 deletions verified_cogen/runners/validating.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import pathlib
from abc import ABC, abstractmethod
from subprocess import CalledProcessError, run
from typing import List, Optional

from verified_cogen import get_cache_dir
from verified_cogen.llm import prompts
from verified_cogen.llm.llm import LLM
from verified_cogen.runners import Runner
Expand All @@ -10,18 +12,18 @@


class Validator(ABC):
@abstractmethod
def add_validators(self, prg: str, inv_prg: str) -> str: ...


class LanguageValidator(Validator):
language: Language
remove_helpers: bool

def __init__(self, language: Language, remove_helpers: bool):
self.language = language
self.remove_helpers = remove_helpers

@abstractmethod
def add_validators(self, prg: str, inv_prg: str) -> str: ...


class LanguageValidator(Validator):
def add_validators(self, prg: str, inv_prg: str) -> str:
validators = self.language.generate_validators(prg, not self.remove_helpers)
comment = self.language.simple_comment
Expand All @@ -31,14 +33,39 @@ def add_validators(self, prg: str, inv_prg: str) -> str:

class ShellValidator(Validator):
cli_command: list[str]
LLM_GENERATED_CODE_DIR = pathlib.Path(get_cache_dir()) / "llm-generated-code"
LLM_GENERATED_VAL_DIR = pathlib.Path(get_cache_dir()) / "llm-generated-val"
tries = 0
cur_name: str

def __init__(self, cli_command: list[str]):
def __init__(self, cli_command: list[str], language: Language, remove_helpers: bool, cur_name: str):
super().__init__(language, remove_helpers)
self.cli_command = cli_command
self.LLM_GENERATED_CODE_DIR.mkdir(parents=True, exist_ok=True)
self.LLM_GENERATED_VAL_DIR.mkdir(parents=True, exist_ok=True)
self.cur_name = cur_name

def _code_file(self, name: str, try_n: int) -> pathlib.Path:
base, extension = name.rsplit(".", 1)
return self.LLM_GENERATED_CODE_DIR / f"{base}_{try_n}.{extension}"

def _validation_file(self, name: str, try_n: int) -> pathlib.Path:
base, extension = name.rsplit(".", 1)
return self.LLM_GENERATED_VAL_DIR / f"{base}_{try_n}.{extension}"

def add_validators(self, prg: str, inv_prg: str) -> str:
try:
command = self.cli_command + [prg, inv_prg]
result = run(command, capture_output=True, timeout=10, check=True).stdout.decode()
output: pathlib.Path = self._validation_file(self.cur_name, self.tries)
with open(output, "w") as f:
f.write(prg)

code: pathlib.Path = self._code_file(self.cur_name, self.tries)
with open(code, "w") as f:
f.write(inv_prg)

self.tries += 1
command = self.cli_command + [str(code), str(output)]
result = run(command, capture_output=True, timeout=30, check=True).stdout.decode()
return result
except (CalledProcessError, TimeoutError):
return inv_prg
Expand Down
8 changes: 8 additions & 0 deletions verified_cogen/several_modes/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class ProgramArgsMultiple:
output_logging: bool
manual_rewriters: List[str]
max_jobs: int
shell_validator: List[str]

@no_type_check
def __init__(self, args):
Expand All @@ -52,6 +53,7 @@ def __init__(self, args):
self.modes = args.modes
self.skip_failed = args.skip_failed
self.max_jobs = args.max_jobs
self.shell_validator = args.shell_validator


def get_default_parser_multiple():
Expand Down Expand Up @@ -127,6 +129,12 @@ def get_default_parser_multiple():
default=[],
nargs="+",
)
parser.add_argument(
"--shell-validator",
help="Shell arguments to run validator",
default=[],
nargs="+",
)
parser.add_argument(
"--modes",
help="modes",
Expand Down
15 changes: 13 additions & 2 deletions verified_cogen/several_modes/several_modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
from verified_cogen.main import construct_rewriter, make_runner_cls
from verified_cogen.runners import RunnerConfig
from verified_cogen.runners.languages import register_basic_languages
from verified_cogen.runners.languages.language import AnnotationType
from verified_cogen.runners.languages.language import AnnotationType, LanguageDatabase
from verified_cogen.runners.rewriters import Rewriter
from verified_cogen.runners.validating import ShellValidator
from verified_cogen.several_modes.args import ProgramArgsMultiple, get_args
from verified_cogen.several_modes.constants import (
MODE_MAPPING,
Expand Down Expand Up @@ -65,8 +66,18 @@ def process_file(
config.args.temperature,
history=config.history_dir / f"{file.stem}.txt",
)
validator = (
ShellValidator(
config.args.shell_validator,
LanguageDatabase().get(file.suffix.lstrip(".")),
runner_config.remove_helpers,
file.stem + file.suffix,
)
if config.args.shell_validator
else None
)
runner = make_runner_cls(config.args.bench_types[idx], config.extension, runner_config)(
llm, logger, verifier, rewriter
llm, logger, verifier, rewriter, validator
)
try:
mode = Mode(config.args.insert_conditions_mode)
Expand Down