Master ci py312 new #5072

Open · wants to merge 61 commits into base branch master-ci-py312-pr-test

Changes from 1 commit. All 61 commits:
4aa6c7c  Test py312 ci support (mollyheamazon, Feb 26, 2025)
280424d  Test py312 ci support (mollyheamazon, Feb 27, 2025)
1374184  first commit (mollyheamazon, Dec 18, 2024)
49de844  test to point to personal stack (mollyheamazon, Jan 2, 2025)
696cb47  changed tox.ini (mollyheamazon, Feb 27, 2025)
5d35bd0  Revert "changed tox.ini" (mollyheamazon, Feb 27, 2025)
2ab95fb  Revert "Revert "changed tox.ini"" (mollyheamazon, Feb 27, 2025)
dc2fcd5  Revert "Revert "Revert "changed tox.ini""" (mollyheamazon, Feb 27, 2025)
f143ca7  Revert "Revert "changed tox.ini"" (mollyheamazon, Feb 27, 2025)
31d7478  Revert "changed tox.ini" (mollyheamazon, Feb 27, 2025)
f78febe  Revert "test to point to personal stack" (mollyheamazon, Feb 27, 2025)
7669550  Revert "first commit" (mollyheamazon, Feb 27, 2025)
598447f  add pyproject.toml (mollyheamazon, Feb 27, 2025)
1f0f608  Merge branch 'master-ci-py312-pr-test' into master-ci-py312 (mollyheamazon, Feb 27, 2025)
332389c  bump numpy version (mollyheamazon, Feb 27, 2025)
041c739  numpy version change (mollyheamazon, Feb 27, 2025)
88fbe0c  add py312 (mollyheamazon, Feb 27, 2025)
5c20eee  upgrade pip version (mollyheamazon, Feb 27, 2025)
3feadc5  add setuptools wheel to tox.ini (mollyheamazon, Feb 28, 2025)
6b57762  Fix key error in _send_metrics() (#5068) (pintaoz-aws, Feb 28, 2025)
f941b39  fix: Added check for the presence of model package group before creat… (keshav-chandak, Feb 28, 2025)
a27527d  add pyyaml version constraint, remove py312 from docstring because th… (mollyheamazon, Feb 28, 2025)
340ab3b  update pyyaml version constraint (mollyheamazon, Feb 28, 2025)
5bfbb55  update pyyaml version constraint (mollyheamazon, Feb 28, 2025)
5084e6f  deprecate py38 (mollyheamazon, Feb 28, 2025)
f679f2c  bump scipy (mollyheamazon, Feb 28, 2025)
6655760  bump tensorflow and tensorboard (mollyheamazon, Mar 1, 2025)
4116094  bump dill (mollyheamazon, Mar 1, 2025)
1a23c2a  bump apache-airflow to ensure dill and greenlet version (mollyheamazon, Mar 1, 2025)
868894c  Use sagemaker session's s3_resource in download_folder (#5064) (pintaoz-aws, Mar 3, 2025)
9b9eb32  remove constraint for apache-airflow (mollyheamazon, Mar 3, 2025)
b5da915  remove constraint for apache-airflow (mollyheamazon, Mar 3, 2025)
e0caf76  bump torch version (mollyheamazon, Mar 3, 2025)
c69de67  bump torchvision version (mollyheamazon, Mar 3, 2025)
76efed5  new tests (mollyheamazon, Mar 4, 2025)
d7560cc  try changing some tests (mollyheamazon, Mar 6, 2025)
cbe7340  update model trainer test (mollyheamazon, Mar 10, 2025)
50ab9eb  fix test_pipeline (mollyheamazon, Mar 12, 2025)
10f4687  add constraint to apache in tox (mollyheamazon, Mar 14, 2025)
8af122b  Fix key error in _send_metrics() (#5068) (pintaoz-aws, Feb 28, 2025)
c6c6b64  fix: Added check for the presence of model package group before creat… (keshav-chandak, Feb 28, 2025)
09f0285  Use sagemaker session's s3_resource in download_folder (#5064) (pintaoz-aws, Mar 3, 2025)
604fae7  Fix error when there is no session to call _create_model_request() (#… (pintaoz-aws, Mar 5, 2025)
7c29b96  Ensure Model.is_repack() returns a boolean (#5060) (pintaoz-aws, Mar 5, 2025)
ed2c7e7  feat: Allow ModelTrainer to accept hyperparameters file (#5059) (benieric, Mar 5, 2025)
3f1d2de  feature: support training for JumpStart model references as part of C… (Narrohag, Mar 5, 2025)
cb2f1b2  feat: Make DistributedConfig Extensible (#5039) (benieric, Mar 5, 2025)
f186104  Skip tests with deprecated instance type (#5077) (pintaoz-aws, Mar 6, 2025)
1df4f38  prepare release v2.241.0 (Mar 6, 2025)
7b9013b  update development version to v2.241.1.dev0 (Mar 6, 2025)
23674fe  pipeline definition function doc update (#5074) (rohangujarathi, Mar 6, 2025)
e266baa  feat: add integ tests for training JumpStart models in private hub (#… (Narrohag, Mar 10, 2025)
3717e4d  fix: resolve infinite loop in _find_config on Windows systems (#4970) (Julfried, Mar 10, 2025)
11aac41  change: update image_uri_configs 03-11-2025 07:18:09 PST (sagemaker-bot, Mar 11, 2025)
b872f3e  Fixing Pytorch training python version in tests (#5084) (nargokul, Mar 12, 2025)
5e8e894  remove s3 output location requirement from hub class init (#5081) (bencrabtree, Mar 12, 2025)
a7458b9  fix: Prevent RunContext overlap between test_run tests (#5083) (rrrkharse, Mar 12, 2025)
5850515  Merge branch 'master-ci-py312-pr-test' into master-ci-py312-new (mollyheamazon, Mar 17, 2025)
3d0027e  fix integ test by bumping py38 to py39 for PyTorch (mollyheamazon, Mar 17, 2025)
bf24681  Merge branch 'aws:master-ci-py312-new' into master-ci-py312-new (mollyheamazon, Mar 17, 2025)
ae7d6e1  change framework_version that supports py39 in integ tests (mollyheamazon, Mar 17, 2025)
feat: Allow ModelTrainer to accept hyperparameters file (#5059)
* Allow ModelTrainer to accept hyperparameter file and create Hyperparameter class

* pylint

* Detect hyperparameters from contents rather than file extension

* pylint

* change: add integs

* change: add integs

* change: remove custom hyperparameter tooling

* Add tests for hp contracts

* change: add unit tests and remove unreachable condition

* fix integs

* doc check fix

* fix tests

* fix tox.ini

* add unit test
benieric authored and mollyheamazon committed Mar 14, 2025
commit ed2c7e7590488ba2fa53a757710d7ffad90a5cb5
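For context, the net effect of this commit: ModelTrainer's hyperparameters argument now accepts either a dict or a path to a JSON/YAML file, which is parsed into a dict at init time. A minimal usage sketch; the image URI, role, and file path below are illustrative placeholders, not values from this PR:

from sagemaker.modules.train import ModelTrainer

# Hypothetical values for illustration only.
trainer = ModelTrainer(
    training_image="<training-image-uri>",
    role="<execution-role-arn>",
    hyperparameters="configs/hyperparameters.yaml",  # a dict or a path to a JSON/YAML file
)
# After init, the file has been parsed and hyperparameters is a plain dict.
assert isinstance(trainer.hyperparameters, dict)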
32 changes: 28 additions & 4 deletions src/sagemaker/modules/train/model_trainer.py
@@ -18,8 +18,8 @@
import json
import shutil
from tempfile import TemporaryDirectory

from typing import Optional, List, Union, Dict, Any, ClassVar
import yaml

from graphene.utils.str_converters import to_camel_case, to_snake_case

@@ -195,8 +195,9 @@ class ModelTrainer(BaseModel):
            Defaults to "File".
        environment (Optional[Dict[str, str]]):
            The environment variables for the training job.
        hyperparameters (Optional[Dict[str, Any]]):
            The hyperparameters for the training job.
        hyperparameters (Optional[Union[Dict[str, Any], str]]):
            The hyperparameters for the training job. Can be a dictionary of hyperparameters
            or a path to a hyperparameters JSON or YAML file.
        tags (Optional[List[Tag]]):
            An array of key-value pairs. You can use tags to categorize your AWS resources
            in different ways, for example, by purpose, owner, or environment.
@@ -226,7 +227,7 @@ class ModelTrainer(BaseModel):
    checkpoint_config: Optional[CheckpointConfig] = None
    training_input_mode: Optional[str] = "File"
    environment: Optional[Dict[str, str]] = {}
    hyperparameters: Optional[Dict[str, Any]] = {}
    hyperparameters: Optional[Union[Dict[str, Any], str]] = {}
    tags: Optional[List[Tag]] = None
    local_container_root: Optional[str] = os.getcwd()

@@ -470,6 +471,29 @@ def model_post_init(self, __context: Any):
                f"StoppingCondition not provided. Using default:\n{self.stopping_condition}"
            )

        if self.hyperparameters and isinstance(self.hyperparameters, str):
            if not os.path.exists(self.hyperparameters):
                raise ValueError(f"Hyperparameters file not found: {self.hyperparameters}")
            logger.info(f"Loading hyperparameters from file: {self.hyperparameters}")
            with open(self.hyperparameters, "r") as f:
                contents = f.read()
            try:
                self.hyperparameters = json.loads(contents)
                logger.debug("Hyperparameters loaded as JSON")
            except json.JSONDecodeError:
                try:
                    logger.info(f"contents: {contents}")
                    self.hyperparameters = yaml.safe_load(contents)
                    if not isinstance(self.hyperparameters, dict):
                        raise ValueError("YAML contents must be a valid mapping")
                    logger.info(f"hyperparameters: {self.hyperparameters}")
                    logger.debug("Hyperparameters loaded as YAML")
                except (yaml.YAMLError, ValueError):
                    raise ValueError(
                        f"Invalid hyperparameters file: {self.hyperparameters}. "
                        "Must be a valid JSON or YAML file."
                    )

        if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB and self.output_data_config is None:
            session = self.sagemaker_session
            base_job_name = self.base_job_name
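The loading logic above is content-based rather than extension-based: json.loads is attempted first, yaml.safe_load is the fallback, and anything that does not parse to a mapping is rejected. A standalone sketch of the same flow, with the helper name ours rather than the SDK's:

import json
import os

import yaml  # PyYAML


def load_hyperparameters(path: str) -> dict:
    """Illustrative helper mirroring the JSON-then-YAML fallback above."""
    if not os.path.exists(path):
        raise ValueError(f"Hyperparameters file not found: {path}")
    with open(path, "r") as f:
        contents = f.read()
    try:
        # Try JSON first; a valid JSON document would usually also parse as YAML.
        return json.loads(contents)
    except json.JSONDecodeError:
        try:
            data = yaml.safe_load(contents)
        except yaml.YAMLError:
            data = None
        if not isinstance(data, dict):
            raise ValueError(f"Invalid hyperparameters file: {path}. Must be a valid JSON or YAML file.")
        return data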
15 changes: 15 additions & 0 deletions tests/data/modules/params_script/hyperparameters.json
@@ -0,0 +1,15 @@
{
    "integer": 1,
    "boolean": true,
    "float": 3.14,
    "string": "Hello World",
    "list": [1, 2, 3],
    "dict": {
        "string": "value",
        "integer": 3,
        "float": 3.14,
        "list": [1, 2, 3],
        "dict": {"key": "value"},
        "boolean": true
    }
}
19 changes: 19 additions & 0 deletions tests/data/modules/params_script/hyperparameters.yaml
@@ -0,0 +1,19 @@
integer: 1
boolean: true
float: 3.14
string: "Hello World"
list:
  - 1
  - 2
  - 3
dict:
  string: value
  integer: 3
  float: 3.14
  list:
    - 1
    - 2
    - 3
  dict:
    key: value
  boolean: true
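The JSON and YAML fixtures above are intended to be semantically identical. A quick sanity check, assuming PyYAML is installed and both files are in the working directory:

import json

import yaml

with open("hyperparameters.json") as jf, open("hyperparameters.yaml") as yf:
    # Both fixtures should parse to the same nested dict.
    assert json.load(jf) == yaml.safe_load(yf)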
1 change: 1 addition & 0 deletions tests/data/modules/params_script/requirements.txt
@@ -0,0 +1 @@
omegaconf
97 changes: 94 additions & 3 deletions tests/data/modules/params_script/train.py
@@ -16,6 +16,9 @@
import argparse
import json
import os
from typing import List, Dict, Any
from dataclasses import dataclass
from omegaconf import OmegaConf

EXPECTED_HYPERPARAMETERS = {
    "integer": 1,
@@ -26,6 +29,7 @@
    "dict": {
        "string": "value",
        "integer": 3,
        "float": 3.14,
        "list": [1, 2, 3],
        "dict": {"key": "value"},
        "boolean": True,
@@ -117,7 +121,7 @@ def main():
    assert isinstance(params["dict"], dict)

    params = json.loads(os.environ["SM_TRAINING_ENV"])["hyperparameters"]
    print(params)
    print(f"SM_TRAINING_ENV -> hyperparameters: {params}")
    assert params["string"] == EXPECTED_HYPERPARAMETERS["string"]
    assert params["integer"] == EXPECTED_HYPERPARAMETERS["integer"]
    assert params["boolean"] == EXPECTED_HYPERPARAMETERS["boolean"]
@@ -132,9 +136,96 @@ def main():
    assert isinstance(params["float"], float)
    assert isinstance(params["list"], list)
    assert isinstance(params["dict"], dict)
    print(f"SM_TRAINING_ENV -> hyperparameters: {params}")

    print("Test passed.")
    # Local JSON - DictConfig OmegaConf
    params = OmegaConf.load("hyperparameters.json")

    print(f"Local hyperparameters.json: {params}")
    assert params.string == EXPECTED_HYPERPARAMETERS["string"]
    assert params.integer == EXPECTED_HYPERPARAMETERS["integer"]
    assert params.boolean == EXPECTED_HYPERPARAMETERS["boolean"]
    assert params.float == EXPECTED_HYPERPARAMETERS["float"]
    assert params.list == EXPECTED_HYPERPARAMETERS["list"]
    assert params.dict == EXPECTED_HYPERPARAMETERS["dict"]
    assert params.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"]
    assert params.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"]
    assert params.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"]
    assert params.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"]
    assert params.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"]
    assert params.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"]

    @dataclass
    class DictConfig:
        string: str
        integer: int
        boolean: bool
        float: float
        list: List[int]
        dict: Dict[str, Any]

    @dataclass
    class HPConfig:
        string: str
        integer: int
        boolean: bool
        float: float
        list: List[int]
        dict: DictConfig

    # Local JSON - Structured OmegaConf
    hp_config: HPConfig = OmegaConf.merge(
        OmegaConf.structured(HPConfig), OmegaConf.load("hyperparameters.json")
    )
    print(f"Local hyperparameters.json - Structured: {hp_config}")
    assert hp_config.string == EXPECTED_HYPERPARAMETERS["string"]
    assert hp_config.integer == EXPECTED_HYPERPARAMETERS["integer"]
    assert hp_config.boolean == EXPECTED_HYPERPARAMETERS["boolean"]
    assert hp_config.float == EXPECTED_HYPERPARAMETERS["float"]
    assert hp_config.list == EXPECTED_HYPERPARAMETERS["list"]
    assert hp_config.dict == EXPECTED_HYPERPARAMETERS["dict"]
    assert hp_config.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"]
    assert hp_config.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"]
    assert hp_config.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"]
    assert hp_config.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"]
    assert hp_config.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"]
    assert hp_config.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"]

    # Local YAML - Structured OmegaConf
    hp_config: HPConfig = OmegaConf.merge(
        OmegaConf.structured(HPConfig), OmegaConf.load("hyperparameters.yaml")
    )
    print(f"Local hyperparameters.yaml - Structured: {hp_config}")
    assert hp_config.string == EXPECTED_HYPERPARAMETERS["string"]
    assert hp_config.integer == EXPECTED_HYPERPARAMETERS["integer"]
    assert hp_config.boolean == EXPECTED_HYPERPARAMETERS["boolean"]
    assert hp_config.float == EXPECTED_HYPERPARAMETERS["float"]
    assert hp_config.list == EXPECTED_HYPERPARAMETERS["list"]
    assert hp_config.dict == EXPECTED_HYPERPARAMETERS["dict"]
    assert hp_config.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"]
    assert hp_config.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"]
    assert hp_config.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"]
    assert hp_config.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"]
    assert hp_config.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"]
    assert hp_config.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"]
    print(f"hyperparameters.yaml -> hyperparameters: {hp_config}")

    # HP Dict - Structured OmegaConf
    hp_dict = json.loads(os.environ["SM_HPS"])
    hp_config: HPConfig = OmegaConf.merge(OmegaConf.structured(HPConfig), OmegaConf.create(hp_dict))
    print(f"SM_HPS - Structured: {hp_config}")
    assert hp_config.string == EXPECTED_HYPERPARAMETERS["string"]
    assert hp_config.integer == EXPECTED_HYPERPARAMETERS["integer"]
    assert hp_config.boolean == EXPECTED_HYPERPARAMETERS["boolean"]
    assert hp_config.float == EXPECTED_HYPERPARAMETERS["float"]
    assert hp_config.list == EXPECTED_HYPERPARAMETERS["list"]
    assert hp_config.dict == EXPECTED_HYPERPARAMETERS["dict"]
    assert hp_config.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"]
    assert hp_config.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"]
    assert hp_config.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"]
    assert hp_config.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"]
    assert hp_config.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"]
    assert hp_config.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"]
    print(f"SM_HPS -> hyperparameters: {hp_config}")


if __name__ == "__main__":
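The structured-config pattern used in the script above merges a typed schema into the loaded values, so type mismatches fail at load time rather than mid-training. A reduced sketch of the same OmegaConf pattern; the schema and values here are illustrative, not from the PR:

from dataclasses import dataclass, field
from typing import List

from omegaconf import OmegaConf


@dataclass
class Schema:
    integer: int = 0
    string: str = ""
    list: List[int] = field(default_factory=list)


# Merging into a structured config validates types:
# {"integer": "oops"} would raise a validation error here.
cfg = OmegaConf.merge(
    OmegaConf.structured(Schema),
    OmegaConf.create({"integer": 1, "string": "hi", "list": [1, 2]}),
)
print(cfg.integer, cfg.string, cfg.list)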
52 changes: 36 additions & 16 deletions tests/integ/sagemaker/modules/train/test_model_trainer.py
@@ -28,26 +28,29 @@
    "dict": {
        "string": "value",
        "integer": 3,
        "float": 3.14,
        "list": [1, 2, 3],
        "dict": {"key": "value"},
        "boolean": True,
    },
}

PARAM_SCRIPT_SOURCE_DIR = f"{DATA_DIR}/modules/params_script"
PARAM_SCRIPT_SOURCE_CODE = SourceCode(
    source_dir=PARAM_SCRIPT_SOURCE_DIR,
    requirements="requirements.txt",
    entry_script="train.py",
)

DEFAULT_CPU_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310"


def test_hp_contract_basic_py_script(modules_sagemaker_session):
    source_code = SourceCode(
        source_dir=f"{DATA_DIR}/modules/params_script",
        entry_script="train.py",
    )

    model_trainer = ModelTrainer(
        sagemaker_session=modules_sagemaker_session,
        training_image=DEFAULT_CPU_IMAGE,
        hyperparameters=EXPECTED_HYPERPARAMETERS,
        source_code=source_code,
        source_code=PARAM_SCRIPT_SOURCE_CODE,
        base_job_name="hp-contract-basic-py-script",
    )

@@ -57,6 +60,7 @@ def test_hp_contract_basic_py_script(modules_sagemaker_session):
def test_hp_contract_basic_sh_script(modules_sagemaker_session):
    source_code = SourceCode(
        source_dir=f"{DATA_DIR}/modules/params_script",
        requirements="requirements.txt",
        entry_script="train.sh",
    )
    model_trainer = ModelTrainer(
@@ -71,17 +75,13 @@ def test_hp_contract_basic_sh_script(modules_sagemaker_session):


def test_hp_contract_mpi_script(modules_sagemaker_session):
    source_code = SourceCode(
        source_dir=f"{DATA_DIR}/modules/params_script",
        entry_script="train.py",
    )
    compute = Compute(instance_type="ml.m5.xlarge", instance_count=2)
    model_trainer = ModelTrainer(
        sagemaker_session=modules_sagemaker_session,
        training_image=DEFAULT_CPU_IMAGE,
        compute=compute,
        hyperparameters=EXPECTED_HYPERPARAMETERS,
        source_code=source_code,
        source_code=PARAM_SCRIPT_SOURCE_CODE,
        distributed=MPI(),
        base_job_name="hp-contract-mpi-script",
    )
@@ -90,19 +90,39 @@ def test_hp_contract_mpi_script(modules_sagemaker_session):


def test_hp_contract_torchrun_script(modules_sagemaker_session):
    source_code = SourceCode(
        source_dir=f"{DATA_DIR}/modules/params_script",
        entry_script="train.py",
    )
    compute = Compute(instance_type="ml.m5.xlarge", instance_count=2)
    model_trainer = ModelTrainer(
        sagemaker_session=modules_sagemaker_session,
        training_image=DEFAULT_CPU_IMAGE,
        compute=compute,
        hyperparameters=EXPECTED_HYPERPARAMETERS,
        source_code=source_code,
        source_code=PARAM_SCRIPT_SOURCE_CODE,
        distributed=Torchrun(),
        base_job_name="hp-contract-torchrun-script",
    )

    model_trainer.train()


def test_hp_contract_hyperparameter_json(modules_sagemaker_session):
    model_trainer = ModelTrainer(
        sagemaker_session=modules_sagemaker_session,
        training_image=DEFAULT_CPU_IMAGE,
        hyperparameters=f"{PARAM_SCRIPT_SOURCE_DIR}/hyperparameters.json",
        source_code=PARAM_SCRIPT_SOURCE_CODE,
        base_job_name="hp-contract-hyperparameter-json",
    )
    assert model_trainer.hyperparameters == EXPECTED_HYPERPARAMETERS
    model_trainer.train()


def test_hp_contract_hyperparameter_yaml(modules_sagemaker_session):
    model_trainer = ModelTrainer(
        sagemaker_session=modules_sagemaker_session,
        training_image=DEFAULT_CPU_IMAGE,
        hyperparameters=f"{PARAM_SCRIPT_SOURCE_DIR}/hyperparameters.yaml",
        source_code=PARAM_SCRIPT_SOURCE_CODE,
        base_job_name="hp-contract-hyperparameter-yaml",
    )
    assert model_trainer.hyperparameters == EXPECTED_HYPERPARAMETERS
    model_trainer.train()
93 changes: 92 additions & 1 deletion tests/unit/sagemaker/modules/train/test_model_trainer.py
@@ -17,9 +17,10 @@
import tempfile
import json
import os
import yaml
import pytest
from pydantic import ValidationError
from unittest.mock import patch, MagicMock, ANY
from unittest.mock import patch, MagicMock, ANY, mock_open

from sagemaker import image_uris
from sagemaker_core.main.resources import TrainingJob
@@ -1094,3 +1095,93 @@ def test_destructor_cleanup(mock_tmp_dir, modules_session):
    mock_tmp_dir.assert_not_called()
    del model_trainer
    mock_tmp_dir.cleanup.assert_called_once()


@patch("os.path.exists")
def test_hyperparameters_valid_json(mock_exists, modules_session):
mock_exists.return_value = True
expected_hyperparameters = {"param1": "value1", "param2": 2}
mock_file_open = mock_open(read_data=json.dumps(expected_hyperparameters))

with patch("builtins.open", mock_file_open):
model_trainer = ModelTrainer(
training_image=DEFAULT_IMAGE,
role=DEFAULT_ROLE,
sagemaker_session=modules_session,
compute=DEFAULT_COMPUTE_CONFIG,
hyperparameters="hyperparameters.json",
)
assert model_trainer.hyperparameters == expected_hyperparameters
mock_file_open.assert_called_once_with("hyperparameters.json", "r")
mock_exists.assert_called_once_with("hyperparameters.json")


@patch("os.path.exists")
def test_hyperparameters_valid_yaml(mock_exists, modules_session):
mock_exists.return_value = True
expected_hyperparameters = {"param1": "value1", "param2": 2}
mock_file_open = mock_open(read_data=yaml.dump(expected_hyperparameters))

with patch("builtins.open", mock_file_open):
model_trainer = ModelTrainer(
training_image=DEFAULT_IMAGE,
role=DEFAULT_ROLE,
sagemaker_session=modules_session,
compute=DEFAULT_COMPUTE_CONFIG,
hyperparameters="hyperparameters.yaml",
)
assert model_trainer.hyperparameters == expected_hyperparameters
mock_file_open.assert_called_once_with("hyperparameters.yaml", "r")
mock_exists.assert_called_once_with("hyperparameters.yaml")


def test_hyperparameters_not_exist(modules_session):
with pytest.raises(ValueError):
ModelTrainer(
training_image=DEFAULT_IMAGE,
role=DEFAULT_ROLE,
sagemaker_session=modules_session,
compute=DEFAULT_COMPUTE_CONFIG,
hyperparameters="nonexistent.json",
)


@patch("os.path.exists")
def test_hyperparameters_invalid(mock_exists, modules_session):
mock_exists.return_value = True

# YAML contents must be a valid mapping
mock_file_open = mock_open(read_data="- item1\n- item2")
with patch("builtins.open", mock_file_open):
with pytest.raises(ValueError, match="Must be a valid JSON or YAML file."):
ModelTrainer(
training_image=DEFAULT_IMAGE,
role=DEFAULT_ROLE,
sagemaker_session=modules_session,
compute=DEFAULT_COMPUTE_CONFIG,
hyperparameters="hyperparameters.yaml",
)

# YAML contents must be a valid mapping
mock_file_open = mock_open(read_data="invalid")
with patch("builtins.open", mock_file_open):
with pytest.raises(ValueError, match="Must be a valid JSON or YAML file."):
ModelTrainer(
training_image=DEFAULT_IMAGE,
role=DEFAULT_ROLE,
sagemaker_session=modules_session,
compute=DEFAULT_COMPUTE_CONFIG,
hyperparameters="hyperparameters.yaml",
)

# Must be valid YAML
mock_file_open = mock_open(read_data="* invalid")
with patch("builtins.open", mock_file_open):
with pytest.raises(ValueError, match="Must be a valid JSON or YAML file."):
ModelTrainer(
training_image=DEFAULT_IMAGE,
role=DEFAULT_ROLE,
sagemaker_session=modules_session,
compute=DEFAULT_COMPUTE_CONFIG,
hyperparameters="hyperparameters.yaml",
)