diff --git a/.github/workflows/ecr_sagemaker_publish.yml b/.github/workflows/ecr_sagemaker_publish.yml
new file mode 100644
index 0000000..365013c
--- /dev/null
+++ b/.github/workflows/ecr_sagemaker_publish.yml
@@ -0,0 +1,82 @@
+name: Sagemaker ECR Publish (RC)
+
+on:
+  push:
+    branches:
+      - main
+  workflow_dispatch:
+    inputs:
+      is_release_candidate:
+        description: 'Is this a release candidate?'
+        required: true
+        default: 'true'
+
+# Needed for OIDC / assume role
+permissions:
+  id-token: write
+  contents: read
+
+jobs:
+  publish_image:
+    name: Publish Sagemaker Image (Release Candidate)
+    runs-on: ubuntu-latest
+    env:
+      VALIDATOR_TAG_NAME: restrict2topic
+      AWS_REGION: us-east-1
+      WORKING_DIR: "./"
+      AWS_CI_ROLE__PROD: ${{ secrets.AWS_CI_ROLE__PROD }}
+      AWS_ECR_RELEASE_CANDIDATE: ${{ inputs.is_release_candidate || 'true' }}
+    steps:
+
+      - name: Check out head
+        uses: actions/checkout@v3
+        with:
+          persist-credentials: false
+
+      - name: Set ECR Tag
+        id: set-ecr-tag
+        run: |
+          if [ "${{ env.AWS_ECR_RELEASE_CANDIDATE }}" = "true" ]; then
+            echo "This is a release candidate."
+            echo "Setting tag to -rc"
+            ECR_TAG=$VALIDATOR_TAG_NAME-rc
+          else
+            echo "This is a production image."
+            ECR_TAG=$VALIDATOR_TAG_NAME
+          fi
+          echo "Setting ECR tag to $ECR_TAG"
+          echo "ECR_TAG=$ECR_TAG" >> "$GITHUB_OUTPUT"
+
+      - name: Set up QEMU
+        uses: docker/setup-qemu-action@master
+        with:
+          platforms: linux/amd64
+
+      - name: Set up Docker Buildx
+        id: buildx
+        uses: docker/setup-buildx-action@master
+        with:
+          platforms: linux/amd64
+
+      - name: Configure AWS credentials
+        uses: aws-actions/configure-aws-credentials@v4
+        with:
+          aws-region: ${{ env.AWS_REGION }}
+          role-to-assume: ${{ env.AWS_CI_ROLE__PROD }}
+
+      - name: Login to Amazon ECR
+        id: login-ecr
+        uses: aws-actions/amazon-ecr-login@v2
+        with:
+          mask-password: 'true'
+
+      - name: Build & Push ECR Image
+        uses: docker/build-push-action@v2
+        with:
+          builder: ${{ steps.buildx.outputs.name }}
+          context: ${{ env.WORKING_DIR }}
+          platforms: linux/amd64
+          cache-from: type=gha
+          cache-to: type=gha,mode=max
+          push: true
+          tags: 064852979926.dkr.ecr.us-east-1.amazonaws.com/gr-sagemaker-validator-images-prod:${{ steps.set-ecr-tag.outputs.ECR_TAG }}
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..35c79b5
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,37 @@
+# Use an official PyTorch image with CUDA support
+FROM pytorch/pytorch:2.1.2-cuda11.8-cudnn8-runtime
+
+# Set the working directory
+WORKDIR /app
+
+# Copy the pyproject.toml and any other necessary files (e.g., README, LICENSE)
+COPY pyproject.toml .
+COPY README.md .
+COPY LICENSE .
+
+# Install dependencies from the pyproject.toml file
+RUN pip install --upgrade pip setuptools wheel
+RUN pip install .
+
+ENV HF_HUB_ENABLE_HF_TRANSFER=1
+
+# Install the necessary packages for the FastAPI app
+RUN pip install fastapi "uvicorn[standard]" gunicorn transformers accelerate huggingface_hub hf-transfer "jinja2>=3.1.0"
+
+# Copy the entire project code into the container
+COPY . /app
+
+# Copy the serve script into the container
+COPY serve /usr/local/bin/serve
+
+# Make the serve script executable
+RUN chmod +x /usr/local/bin/serve
+
+# Mark this image as a production deployment (device selection happens at runtime in app.py)
+ENV env=prod
+
+# Expose the port that the FastAPI app will run on
+EXPOSE 8080
+
+# Set the entrypoint for SageMaker to the serve script
+ENTRYPOINT ["serve"]
diff --git a/app.py b/app.py
index 4f7914a..88cc819 100644
--- a/app.py
+++ b/app.py
@@ -1,34 +1,107 @@
-import json
+import os
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+from typing import List, Union
+from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline, ZeroShotClassificationPipeline
 import torch
-import nltk
-from typing import Any, Dict, List
-from transformers import pipeline
-
-class InferlessPythonModel:
-
-    def initialize(self):
-        self._classifier = pipeline(
-            "zero-shot-classification",
-            model="facebook/bart-large-mnli",
-            device="cuda",
-            hypothesis_template="This example has to do with topic {}.",
-            multi_label=True,
-        )
-        #self._classifier.to("cuda")
-
-    def infer(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
-        result = self._classifier(inputs["text"], inputs["candidate_topics"])
-        topics = result["labels"]
-        scores = result["scores"]
-        found_topics = []
-        for topic, score in zip(topics, scores):
-            if score > inputs["zero_shot_threshold"]:
-                found_topics.append(topic)
-        if not found_topics:
-            return {"results": ["No valid topic found."]}
-        return {"results": found_topics}
-
-    def finalize(self):
-        pass
+
+app = FastAPI()
+
+# Initialize the zero-shot classification pipeline
+model_save_directory = "/opt/ml/model"
+torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+print(f"Using torch device: {torch_device}")
+
+if os.path.exists(model_save_directory):
+    print(f"Using cached model in {model_save_directory}...")
+    model = AutoModelForSequenceClassification.from_pretrained(model_save_directory)
+    tokenizer = AutoTokenizer.from_pretrained(model_save_directory)
+    classifier = ZeroShotClassificationPipeline(
+        model=model,
+        tokenizer=tokenizer,
+        device=torch.device(torch_device),
+        hypothesis_template="This example has to do with topic {}.",
+        multi_label=True
+    )
+else:
+    print("Downloading model from Hugging Face...")
+    classifier = pipeline(
+        "zero-shot-classification",
+        model="facebook/bart-large-mnli",
+        device=torch.device(torch_device),
+        hypothesis_template="This example has to do with topic {}.",
+        multi_label=True,
+    )
+
+
+class InferenceData(BaseModel):
+    name: str
+    shape: List[int]
+    data: Union[List[str], List[float]]
+    datatype: str
+
+class InputRequest(BaseModel):
+    inputs: List[InferenceData]
+
+class OutputResponse(BaseModel):
+    modelname: str
+    modelversion: str
+    outputs: List[InferenceData]
+
+@app.post("/validate", response_model=OutputResponse)
+async def restrict_to_topic(input_request: InputRequest):
+    print("Received /validate request")
+    text = None
+    candidate_topics = None
+    zero_shot_threshold = 0.5
+    for inp in input_request.inputs:
+        if inp.name == "text":
+            text = inp.data[0]
+        elif inp.name == "candidate_topics":
+            candidate_topics = inp.data
+        elif inp.name == "zero_shot_threshold":
+            zero_shot_threshold = float(inp.data[0])
+    if text is None or candidate_topics is None:
+        raise HTTPException(status_code=400, detail="Invalid input format")
+
+    # Perform zero-shot classification
+    result = classifier(text, candidate_topics)
+    topics = result["labels"]
+    scores = result["scores"]
+    found_topics = [topic for topic, score in zip(topics, scores) if score > zero_shot_threshold]
+
+    if not found_topics:
+        found_topics = ["No valid topic found."]
+
+    output_data = OutputResponse(
+        modelname="RestrictToTopicModel",
+        modelversion="1",
+        outputs=[
+            InferenceData(
+                name="results",
+                datatype="BYTES",
+                shape=[len(found_topics)],
+                data=found_topics
+            )
+        ]
+    )
+
+    print(f"Output data: {output_data}")
+    return output_data
+
+
+# SageMaker-specific endpoints
+@app.get("/ping")
+async def healthcheck():
+    return {"status": "ok"}
+
+@app.post("/invocations", response_model=OutputResponse)
+async def restrict_to_topic_sagemaker(input_request: InputRequest):
+    return await restrict_to_topic(input_request)
+
+
+# For local development, run the app with uvicorn:
+#   uvicorn app:app --reload
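For reviewers who want to exercise the new /validate route by hand, the request body must follow the InferenceData schema defined above. Below is a minimal sketch (not part of this change) that assumes the container is already built, running, and listening on localhost:8080, and that the requests package is installed; all field values are illustrative.

# Hypothetical local smoke test for the /validate endpoint (illustrative only).
import requests  # assumed to be available in the calling environment

payload = {
    "inputs": [
        {"name": "text", "shape": [1], "datatype": "BYTES",
         "data": ["The Beatles released Abbey Road in 1969."]},
        {"name": "candidate_topics", "shape": [2], "datatype": "BYTES",
         "data": ["music", "sports"]},
        {"name": "zero_shot_threshold", "shape": [1], "datatype": "FP32",
         "data": [0.5]},
    ]
}

response = requests.post("http://localhost:8080/validate", json=payload)
response.raise_for_status()
print(response.json())  # OutputResponse listing the topics that cleared the threshold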
diff --git a/serve b/serve
new file mode 100644
index 0000000..f0e117a
--- /dev/null
+++ b/serve
@@ -0,0 +1,90 @@
+#!/usr/bin/env python
+
+import multiprocessing
+import os
+import signal
+import subprocess
+import sys
+
+from huggingface_hub import snapshot_download
+
+cpu_count = multiprocessing.cpu_count()
+default_worker_count = max(cpu_count // 8, 1)
+
+model_server_timeout = os.environ.get('MODEL_SERVER_TIMEOUT', '60')
+model_server_workers = int(os.environ.get('MODEL_SERVER_WORKERS', default_worker_count))
+model_save_directory = os.environ.get('MODEL_SAVE_DIRECTORY', '/opt/ml/model')
+
+MODEL_NAME = "facebook/bart-large-mnli"
+DEFAULT_REVISION = "d7645e127eaf1aefc7862fd59a17a5aa8558b8ce"
+
+print(f'Model server workers: {model_server_workers}')
+print(f'Model save directory: {model_save_directory}')
+print(f'Model server timeout: {model_server_timeout}')
+
+print(f'CPU count: {cpu_count}')
+
+def sigterm_handler(gunicorn_pid):
+    # Forward SIGTERM to the Gunicorn master so workers shut down cleanly
+    try:
+        os.kill(gunicorn_pid, signal.SIGTERM)
+    except OSError:
+        pass
+    sys.exit(0)
+
+def load_and_save_model():
+    try:
+        print('Loading the model...')
+        # Download the model only if it has not been saved locally yet
+        if not os.path.exists(model_save_directory):
+            os.makedirs(model_save_directory)
+
+            print("Downloading the model...")
+
+            snapshot_download(
+                MODEL_NAME,
+                local_dir=model_save_directory,
+                ignore_patterns=[
+                    "*.pt",
+                    "*.bin",
+                    "*.pth",
+                    "original/*",
+                ],  # Ensure safetensors
+                revision=DEFAULT_REVISION,
+                force_download=False,
+            )
+        else:
+            print("Model already downloaded.")
+
+        print('Model loaded and saved successfully.')
+    except Exception as e:
+        print(f'Error loading and saving the model: {e}')
+        sys.exit(1)
+
+def start_server():
+    print(f'Starting the inference server with {model_server_workers} workers.')
+
+    load_and_save_model()
+
+    try:
+        # Start Gunicorn to serve the FastAPI app
+        gunicorn = subprocess.Popen(['gunicorn',
+                                     '--timeout', str(model_server_timeout),
+                                     '-k', 'uvicorn.workers.UvicornWorker',
+                                     '-b', '0.0.0.0:8080',
+                                     '-w', str(model_server_workers),
+                                     'app:app'])
+
+        signal.signal(signal.SIGTERM, lambda a, b: sigterm_handler(gunicorn.pid))
+
+        # Wait for the Gunicorn process to exit
+        gunicorn.wait()
+
+    except Exception as e:
+        print(f'Error starting the inference server: {e}')
+        sys.exit(1)
+
+    print('Inference server exiting')
+
+if __name__ == '__main__':
+    start_server()
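Once the published image is running behind a SageMaker endpoint, SageMaker routes traffic to the /invocations handler in app.py. The sketch below shows a client-side call under stated assumptions: boto3 credentials are configured for the target account, and an endpoint backed by this image already exists. The endpoint name restrict-to-topic-rc is a placeholder, not something defined in this change.

# Illustrative call to a SageMaker endpoint backed by this image (not part of the change).
# "restrict-to-topic-rc" is a placeholder endpoint name; substitute your own.
import json
import boto3

runtime = boto3.client("sagemaker-runtime", region_name="us-east-1")

payload = {
    "inputs": [
        {"name": "text", "shape": [1], "datatype": "BYTES",
         "data": ["Let's talk about interest rates and the stock market."]},
        {"name": "candidate_topics", "shape": [2], "datatype": "BYTES",
         "data": ["finance", "cooking"]},
        {"name": "zero_shot_threshold", "shape": [1], "datatype": "FP32",
         "data": [0.5]},
    ]
}

response = runtime.invoke_endpoint(
    EndpointName="restrict-to-topic-rc",   # placeholder
    ContentType="application/json",
    Body=json.dumps(payload),
)
print(json.loads(response["Body"].read()))  # same OutputResponse shape as /validate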