-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Master #5112
Merged
+39
−11
Merged
Master #5112
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
027415f
fix integ test hub
bencrabtree e262db9
lint
bencrabtree cacb977
fix jumpstart curated hub bugs
bencrabtree 4263904
lint
bencrabtree c844b33
fix tests
bencrabtree d1ee7d2
linting
bencrabtree 2e728f8
lint
bencrabtree a8361cc
rm test file
bencrabtree b3e6fde
fix test
bencrabtree b44f722
fix
bencrabtree 1e69585
lint
bencrabtree ea18f02
remove test
bencrabtree 988b2b5
update for test
bencrabtree File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1632,17 +1632,29 @@ def get_draft_model_content_bucket(provider: Dict, region: str) -> str: | |
return neo_bucket | ||
|
||
|
||
def remove_env_var_from_estimator_kwargs_if_accept_eula_present( | ||
init_kwargs: dict, accept_eula: Optional[bool] | ||
def remove_env_var_from_estimator_kwargs_if_model_access_config_present( | ||
init_kwargs: dict, model_access_config: Optional[dict] | ||
): | ||
"""Remove env vars if access configs are used | ||
"""Remove env vars if ModelAccessConfig is used | ||
|
||
Args: | ||
init_kwargs (dict): Dictionary of kwargs when Estimator is instantiated. | ||
accept_eula (Optional[bool]): Whether or not the EULA was accepted, optionally passed in to Estimator.fit(). | ||
""" | ||
if accept_eula is not None and init_kwargs["environment"]: | ||
del init_kwargs["environment"][constants.SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY] | ||
if ( | ||
model_access_config is not None | ||
and init_kwargs.get("environment") is not None | ||
and init_kwargs.get("model_uri") is not None | ||
): | ||
if ( | ||
constants.SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY | ||
in init_kwargs["environment"] | ||
): | ||
del init_kwargs["environment"][ | ||
constants.SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY | ||
] | ||
if "accept_eula" in init_kwargs["environment"]: | ||
del init_kwargs["environment"]["accept_eula"] | ||
|
||
|
||
def get_hub_access_config(hub_content_arn: Optional[str]): | ||
|
@@ -1659,16 +1671,24 @@ def get_hub_access_config(hub_content_arn: Optional[str]): | |
return hub_access_config | ||
|
||
|
||
def get_model_access_config(accept_eula: Optional[bool]): | ||
def get_model_access_config(accept_eula: Optional[bool], environment: Optional[dict]): | ||
"""Get access configs | ||
|
||
Args: | ||
accept_eula (Optional[bool]): Whether or not the EULA was accepted, optionally passed in to Estimator.fit(). | ||
""" | ||
env_var_eula = environment.get("accept_eula") if environment else None | ||
if env_var_eula is not None and accept_eula is not None: | ||
raise ValueError( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i'd log a warn here instead |
||
"Cannot pass in both accept_eula and environment variables. " | ||
"Please remove the environment variable and pass in the accept_eula parameter." | ||
) | ||
|
||
model_access_config = None | ||
if env_var_eula is not None: | ||
model_access_config = {"AcceptEula": env_var_eula == "true"} | ||
if accept_eula is not None: | ||
model_access_config = {"AcceptEula": accept_eula} | ||
else: | ||
model_access_config = None | ||
|
||
return model_access_config | ||
|
||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: you can do
init_kwargs["environment"].pop(constants.SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY, None)