diff --git a/.github/workflows/continuous_integration.yml b/.github/workflows/continuous_integration.yml index 36cbdae..9975083 100644 --- a/.github/workflows/continuous_integration.yml +++ b/.github/workflows/continuous_integration.yml @@ -47,7 +47,11 @@ jobs: if: ${{ github.event_name == 'pull_request_target' }} uses: actions/checkout@v3 with: - ref: "refs/pull/${{ github.event.number }}/merge" + ref: ${{ github.sha }} + - name: Fetch base branch + run: git fetch origin ${{ github.event.pull_request.base.ref }} + - name: Merge base branch into PR branch + run: git merge origin/${{ github.event.pull_request.base.ref }} - uses: actions/setup-python@v4 with: python-version: '3.10' @@ -77,7 +81,11 @@ jobs: if: ${{ github.event_name == 'pull_request_target' }} uses: actions/checkout@v3 with: - ref: "refs/pull/${{ github.event.number }}/merge" + ref: ${{ github.sha }} + - name: Fetch base branch + run: git fetch origin ${{ github.event.pull_request.base.ref }} + - name: Merge base branch into PR branch + run: git merge origin/${{ github.event.pull_request.base.ref }} - uses: actions/setup-python@v4 with: python-version: '3.10' @@ -109,7 +117,11 @@ jobs: if: ${{ github.event_name == 'pull_request_target' }} uses: actions/checkout@v3 with: - ref: "refs/pull/${{ github.event.number }}/merge" + ref: ${{ github.sha }} + - name: Fetch base branch + run: git fetch origin ${{ github.event.pull_request.base.ref }} + - name: Merge base branch into PR branch + run: git merge origin/${{ github.event.pull_request.base.ref }} - uses: actions/setup-python@v4 with: python-version: '3.10' @@ -141,7 +153,11 @@ jobs: if: ${{ github.event_name == 'pull_request_target' }} uses: actions/checkout@v3 with: - ref: "refs/pull/${{ github.event.number }}/merge" + ref: ${{ github.sha }} + - name: Fetch base branch + run: git fetch origin ${{ github.event.pull_request.base.ref }} + - name: Merge base branch into PR branch + run: git merge origin/${{ github.event.pull_request.base.ref }} - uses: actions/setup-python@v4 with: python-version: '3.10' @@ -173,7 +189,11 @@ jobs: if: ${{ github.event_name == 'pull_request_target' }} uses: actions/checkout@v3 with: - ref: "refs/pull/${{ github.event.number }}/merge" + ref: ${{ github.sha }} + - name: Fetch base branch + run: git fetch origin ${{ github.event.pull_request.base.ref }} + - name: Merge base branch into PR branch + run: git merge origin/${{ github.event.pull_request.base.ref }} - uses: actions/setup-python@v4 with: python-version: '3.10' @@ -205,7 +225,11 @@ jobs: if: ${{ github.event_name == 'pull_request_target' }} uses: actions/checkout@v3 with: - ref: "refs/pull/${{ github.event.number }}/merge" + ref: ${{ github.sha }} + - name: Fetch base branch + run: git fetch origin ${{ github.event.pull_request.base.ref }} + - name: Merge base branch into PR branch + run: git merge origin/${{ github.event.pull_request.base.ref }} - uses: actions/setup-python@v4 with: python-version: '3.10' @@ -237,7 +261,11 @@ jobs: if: ${{ github.event_name == 'pull_request_target' }} uses: actions/checkout@v3 with: - ref: "refs/pull/${{ github.event.number }}/merge" + ref: ${{ github.sha }} + - name: Fetch base branch + run: git fetch origin ${{ github.event.pull_request.base.ref }} + - name: Merge base branch into PR branch + run: git merge origin/${{ github.event.pull_request.base.ref }} - uses: actions/setup-python@v4 with: python-version: '3.8' @@ -265,7 +293,11 @@ jobs: if: ${{ github.event_name == 'pull_request_target' }} uses: actions/checkout@v3 with: - ref: "refs/pull/${{ github.event.number }}/merge" + ref: ${{ github.sha }} + - name: Fetch base branch + run: git fetch origin ${{ github.event.pull_request.base.ref }} + - name: Merge base branch into PR branch + run: git merge origin/${{ github.event.pull_request.base.ref }} - uses: actions/setup-python@v4 with: python-version: '3.8' @@ -299,7 +331,11 @@ jobs: if: ${{ github.event_name == 'pull_request_target' }} uses: actions/checkout@v3 with: - ref: "refs/pull/${{ github.event.number }}/merge" + ref: ${{ github.sha }} + - name: Fetch base branch + run: git fetch origin ${{ github.event.pull_request.base.ref }} + - name: Merge base branch into PR branch + run: git merge origin/${{ github.event.pull_request.base.ref }} - name: Setup Env Vars uses: ./.github/actions/setup-env-vars - uses: actions/setup-python@v4 diff --git a/setup.py b/setup.py index 5a6a518..76fb20f 100644 --- a/setup.py +++ b/setup.py @@ -68,7 +68,7 @@ def default_setup_args(*, version): AUTOGLUON: [ "LICENSE", ], - 'autogluon.cloud': ['default_cluster_configs/*.yaml'], + "autogluon.cloud": ["default_cluster_configs/*.yaml"], }, classifiers=[ "Development Status :: 4 - Beta", diff --git a/src/autogluon/cloud/utils/ec2.py b/src/autogluon/cloud/utils/ec2.py index a0b7889..ab8888a 100644 --- a/src/autogluon/cloud/utils/ec2.py +++ b/src/autogluon/cloud/utils/ec2.py @@ -1,8 +1,10 @@ import os +import re from functools import partial from typing import Any, Dict, List, Optional import boto3 +import dateutil.parser as parser from botocore.exceptions import ClientError @@ -61,43 +63,46 @@ def delete_key_pair(key_name: str, local_path: Optional[str]): os.remove(local_path) -def get_latest_ami(ami_name: str = "Deep Learning AMI GPU PyTorch*Ubuntu*") -> str: - """ - Get the latest ami id +def parse_pytorch_version(name: str): + """Extracts the PyTorch version from the AMI name.""" + match = re.search(r"PyTorch (\d+\.\d+(\.\d+)?)", name) + return tuple(map(int, match.group(1).split("."))) if match else (0, 0, 0) - Parameter - --------- - ami_name: str, default = Deep Learning AMI GPU PyTorch*Ubuntu* - Name of the ami. Could be regex. - Return - ------ - str, - The latest ami id of the ami name being specified - """ - from dateutil import parser +def latest_torch_image(images: List[Dict[str, Any]]): + """Finds the newest image based on PyTorch version and then creation date.""" - def newest_image(list_of_images: List[Dict[str, Any]]): - latest = None + def image_key(image): + version = parse_pytorch_version(image["Name"]) + creation_date = parser.parse(image["CreationDate"]) + return version, creation_date - for image in list_of_images: - if not latest: - latest = image - continue + return max(images, key=image_key) - if parser.parse(image["CreationDate"]) > parser.parse(latest["CreationDate"]): - latest = image - return latest +def get_latest_ami(ami_name: str = "Deep Learning AMI GPU PyTorch*Ubuntu*") -> str: + """ + Get the latest AMI ID based on PyTorch version and creation date. - ec2 = boto3.client("ec2") + Parameters + ---------- + ami_name : str, default="Deep Learning AMI GPU PyTorch*Ubuntu*" + Name of the AMI. Could be a regex pattern. + Returns + ------- + str + The latest AMI ID of the specified AMI name. + """ + ec2 = boto3.client("ec2") filters = [ {"Name": "name", "Values": [ami_name]}, {"Name": "owner-alias", "Values": ["amazon"]}, {"Name": "architecture", "Values": ["x86_64"]}, {"Name": "state", "Values": ["available"]}, ] + response = ec2.describe_images(Owners=["amazon"], Filters=filters) - source_image = newest_image(response["Images"]) - return source_image["ImageId"] + latest_image = latest_torch_image(response["Images"]) + + return latest_image["ImageId"]