Skip to content

Commit

Permalink
Basic CircleCI (ml-explore#449)
Browse files Browse the repository at this point in the history
* basic style checks for circleci

* format

* fix config
  • Loading branch information
awni authored Feb 17, 2024
1 parent 21e19b5 commit e4d5630
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 4 deletions.
35 changes: 35 additions & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
version: 2.1

jobs:
linux_build_and_test:
docker:
- image: cimg/python:3.9

steps:
- checkout
- run:
name: Run style checks
command: |
pip install pre-commit
pre-commit run --all
if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
workflows:
build_and_test:
when:
matches:
pattern: "^(?!pull/)[-\\w]+$"
value: << pipeline.git.branch >>
jobs:
- linux_build_and_test

prb:
when:
matches:
pattern: "^pull/\\d+(/head)?$"
value: << pipeline.git.branch >>
jobs:
- hold:
type: approval
- linux_build_and_test:
requires: [ hold ]
17 changes: 13 additions & 4 deletions llms/mlx_lm/tuner/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def evaluate(
num_batches,
max_seq_length=2048,
loss: callable = default_loss,
iterate_batches: callable = iterate_batches
iterate_batches: callable = iterate_batches,
):
all_losses = []
ntokens = 0
Expand All @@ -121,7 +121,14 @@ def evaluate(

class TrainingCallback:

def on_train_loss_report(self, steps: int, loss: float, it_sec: float, tokens_sec: float, trained_tokens: int):
def on_train_loss_report(
self,
steps: int,
loss: float,
it_sec: float,
tokens_sec: float,
trained_tokens: int,
):
"""Called to report training loss at specified intervals."""
pass

Expand Down Expand Up @@ -193,7 +200,9 @@ def train(
)

if training_callback is not None:
training_callback.on_train_loss_report(it + 1, train_loss, it_sec, tokens_sec, trained_tokens)
training_callback.on_train_loss_report(
it + 1, train_loss, it_sec, tokens_sec, trained_tokens
)

losses = []
n_tokens = 0
Expand All @@ -210,7 +219,7 @@ def train(
batch_size=args.batch_size,
num_batches=args.val_batches,
max_seq_length=args.max_seq_length,
iterate_batches=iterate_batches
iterate_batches=iterate_batches,
)
val_time = time.perf_counter() - stop
print(
Expand Down

0 comments on commit e4d5630

Please sign in to comment.