From e4d5630698b067df179dc2cdadf96ff1377d2125 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 16 Feb 2024 22:13:55 -0800 Subject: [PATCH] Basic CircleCI (#449) * basic style checks for circleci * format * fix config --- .circleci/config.yml | 35 +++++++++++++++++++++++++++++++++++ llms/mlx_lm/tuner/trainer.py | 17 +++++++++++++---- 2 files changed, 48 insertions(+), 4 deletions(-) create mode 100644 .circleci/config.yml diff --git a/.circleci/config.yml b/.circleci/config.yml new file mode 100644 index 000000000..7279e6525 --- /dev/null +++ b/.circleci/config.yml @@ -0,0 +1,35 @@ +version: 2.1 + +jobs: + linux_build_and_test: + docker: + - image: cimg/python:3.9 + + steps: + - checkout + - run: + name: Run style checks + command: | + pip install pre-commit + pre-commit run --all + if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi + +workflows: + build_and_test: + when: + matches: + pattern: "^(?!pull/)[-\\w]+$" + value: << pipeline.git.branch >> + jobs: + - linux_build_and_test + + prb: + when: + matches: + pattern: "^pull/\\d+(/head)?$" + value: << pipeline.git.branch >> + jobs: + - hold: + type: approval + - linux_build_and_test: + requires: [ hold ] diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index ae7e1fc7a..c0fb55362 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -99,7 +99,7 @@ def evaluate( num_batches, max_seq_length=2048, loss: callable = default_loss, - iterate_batches: callable = iterate_batches + iterate_batches: callable = iterate_batches, ): all_losses = [] ntokens = 0 @@ -121,7 +121,14 @@ def evaluate( class TrainingCallback: - def on_train_loss_report(self, steps: int, loss: float, it_sec: float, tokens_sec: float, trained_tokens: int): + def on_train_loss_report( + self, + steps: int, + loss: float, + it_sec: float, + tokens_sec: float, + trained_tokens: int, + ): """Called to report training loss at specified intervals.""" pass @@ -193,7 +200,9 @@ def train( ) if training_callback is not None: - training_callback.on_train_loss_report(it + 1, train_loss, it_sec, tokens_sec, trained_tokens) + training_callback.on_train_loss_report( + it + 1, train_loss, it_sec, tokens_sec, trained_tokens + ) losses = [] n_tokens = 0 @@ -210,7 +219,7 @@ def train( batch_size=args.batch_size, num_batches=args.val_batches, max_seq_length=args.max_seq_length, - iterate_batches=iterate_batches + iterate_batches=iterate_batches, ) val_time = time.perf_counter() - stop print(