Updated for SageMaker endpoints compatibility #9

Open · wants to merge 16 commits into main
81 changes: 81 additions & 0 deletions .github/workflows/ecr_sagemaker_publish.yml
name: Sagemaker ECR Publish (RC)

on:
push:
branches:
- main
workflow_dispatch:
inputs:
is_release_candidate:
description: 'Is this a release candidate?'
required: true
default: 'true'

# Needed for OIDC / assume role
permissions:
id-token: write
contents: read

jobs:
publish_image:
name: Publish Sagemaker Image (Release Candidate)
runs-on: ubuntu-latest
env:
VALIDATOR_TAG_NAME: restrict2topic
AWS_REGION: us-east-1
WORKING_DIR: "./"
AWS_CI_ROLE__PROD: ${{ secrets.AWS_CI_ROLE__PROD }}
AWS_ECR_RELEASE_CANDIDATE: ${{ inputs.is_release_candidate || 'true' }}
steps:

- name: Check out head
uses: actions/checkout@v3
with:
persist-credentials: false

- name: Set ECR Tag
id: set-ecr-tag
run: |
if [ ${{ env.AWS_ECR_RELEASE_CANDIDATE }} == 'true' ]; then
echo "This is a release candidate."
echo "Setting tag to -rc"
ECR_TAG=$VALIDATOR_TAG_NAME-rc
else
echo "This is a production image."
ECR_TAG=$VALIDATOR_TAG_NAME
fi
echo "Setting ECR tag to $ECR_TAG"
echo "ECR_TAG=$ECR_TAG" >> "$GITHUB_OUTPUT"

- name: Set up QEMU
uses: docker/setup-qemu-action@master
with:
platforms: linux/amd64

      - name: Set up Docker Buildx
        id: buildx
        uses: docker/setup-buildx-action@master
        with:
          platforms: linux/amd64

- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
with:
aws-region: ${{ env.AWS_REGION }}
role-to-assume: ${{ env.AWS_CI_ROLE__PROD}}

- name: Login to Amazon ECR
id: login-ecr
uses: aws-actions/amazon-ecr-login@v2
with:
mask-password: 'true'

- name: Build & Push ECR Image
uses: docker/build-push-action@v2
with:
builder: ${{ steps.buildx.outputs.name }}
context: ${{ env.WORKING_DIR }}
platforms: linux/amd64
cache-from: type=gha
cache-to: type=gha,mode=max
push: true
tags: 064852979926.dkr.ecr.us-east-1.amazonaws.com/gr-sagemaker-validator-images-prod:${{ steps.set-ecr-tag.outputs.ECR_TAG }}
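For reference, wiring the image this workflow publishes into an actual SageMaker endpoint would typically be done with boto3. The sketch below is illustrative only: the model and endpoint names, instance type, and execution role ARN are placeholders, not something this PR creates.

# Sketch: create a SageMaker endpoint from the image pushed by this workflow.
# The names, instance type, and role ARN are illustrative placeholders.
import boto3

IMAGE_URI = (
    "064852979926.dkr.ecr.us-east-1.amazonaws.com/"
    "gr-sagemaker-validator-images-prod:restrict2topic-rc"
)
ROLE_ARN = "arn:aws:iam::<account-id>:role/<sagemaker-execution-role>"  # placeholder

sm = boto3.client("sagemaker", region_name="us-east-1")

sm.create_model(
    ModelName="restrict-to-topic-rc",
    PrimaryContainer={"Image": IMAGE_URI},
    ExecutionRoleArn=ROLE_ARN,
)
sm.create_endpoint_config(
    EndpointConfigName="restrict-to-topic-rc-config",
    ProductionVariants=[
        {
            "VariantName": "AllTraffic",
            "ModelName": "restrict-to-topic-rc",
            "InstanceType": "ml.g4dn.xlarge",
            "InitialInstanceCount": 1,
        }
    ],
)
sm.create_endpoint(
    EndpointName="restrict-to-topic-rc",
    EndpointConfigName="restrict-to-topic-rc-config",
)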
37 changes: 37 additions & 0 deletions Dockerfile
# Use an official PyTorch image with CUDA support
FROM pytorch/pytorch:2.1.2-cuda11.8-cudnn8-runtime

# Set the working directory
WORKDIR /app

# Copy the pyproject.toml and any other necessary files (e.g., README, LICENSE)
COPY pyproject.toml .
COPY README.md .
COPY LICENSE .

# Install dependencies from the pyproject.toml file
RUN pip install --upgrade pip setuptools wheel
RUN pip install .

ENV HF_HUB_ENABLE_HF_TRANSFER=1

# Install the necessary packages for the FastAPI app
RUN pip install fastapi "uvicorn[standard]" gunicorn transformers accelerate huggingface_hub hf-transfer "jinja2>=3.1.0"

# Copy the entire project code into the container
COPY . /app

# Copy the serve script into the container
COPY serve /usr/local/bin/serve

# Make the serve script executable
RUN chmod +x /usr/local/bin/serve

# Flag the container as a production deployment
ENV env=prod

# Expose the port that the FastAPI app will run on
EXPOSE 8080

# Set the entrypoint for SageMaker to the serve script
ENTRYPOINT ["serve"]
135 changes: 104 additions & 31 deletions app.py
import json
import os
from typing import Any, Dict, List, Union

import nltk
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    ZeroShotClassificationPipeline,
    pipeline,
)

# Removed by this PR (replaced by the FastAPI app below): the previous Inferless entry point.
class InferlessPythonModel:

def initialize(self):
self._classifier = pipeline(
"zero-shot-classification",
model="facebook/bart-large-mnli",
device="cuda",
hypothesis_template="This example has to do with topic {}.",
multi_label=True,
)
#self._classifier.to("cuda")

def infer(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
result = self._classifier(inputs["text"], inputs["candidate_topics"])
topics = result["labels"]
scores = result["scores"]
found_topics = []
for topic, score in zip(topics, scores):
if score > inputs["zero_shot_threshold"]:
found_topics.append(topic)
if not found_topics:
return {"results": ["No valid topic found."]}
return {"results": found_topics}

def finalize(self):
pass

app = FastAPI()

# Initialize the zero-shot classification pipeline
model_save_directory = "/opt/ml/model"
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Using torch device: {torch_device}")

# Load from the local cache (populated by the serve script) when it exists
if os.path.exists(model_save_directory):
print(f"Using cached model in {model_save_directory}...")
model = AutoModelForSequenceClassification.from_pretrained(model_save_directory)
tokenizer = AutoTokenizer.from_pretrained(model_save_directory)
classifier = ZeroShotClassificationPipeline(
model=model,
tokenizer=tokenizer,
device=torch.device(torch_device),
hypothesis_template="This example has to do with topic {}.",
multi_label=True
)
else:
print("Downloading model from Hugging Face...")
classifier = pipeline(
"zero-shot-classification",
model="facebook/bart-large-mnli",
device=torch.device(torch_device),
hypothesis_template="This example has to do with topic {}.",
multi_label=True,
)


class InferenceData(BaseModel):
name: str
shape: List[int]
data: Union[List[str], List[float]]
datatype: str

class InputRequest(BaseModel):
inputs: List[InferenceData]

class OutputResponse(BaseModel):
modelname: str
modelversion: str
outputs: List[InferenceData]

@app.post("/validate", response_model=OutputResponse)
async def restrict_to_topic(input_request: InputRequest):
    print("Handling /validate request")
text = None
candidate_topics = None
zero_shot_threshold = 0.5

for inp in input_request.inputs:
if inp.name == "text":
text = inp.data[0]
elif inp.name == "candidate_topics":
candidate_topics = inp.data
elif inp.name == "zero_shot_threshold":
zero_shot_threshold = float(inp.data[0])

if text is None or candidate_topics is None:
raise HTTPException(status_code=400, detail="Invalid input format")

# Perform zero-shot classification
result = classifier(text, candidate_topics)
topics = result["labels"]
scores = result["scores"]
found_topics = [topic for topic, score in zip(topics, scores) if score > zero_shot_threshold]

if not found_topics:
found_topics = ["No valid topic found."]

output_data = OutputResponse(
modelname="RestrictToTopicModel",
modelversion="1",
outputs=[
InferenceData(
name="results",
datatype="BYTES",
shape=[len(found_topics)],
data=found_topics
)
]
)

print(f"Output data: {output_data}")
return output_data


# Sagemaker specific endpoints
@app.get("/ping")
async def healtchcheck():
return {"status": "ok"}

@app.post("/invocations", response_model=OutputResponse)
async def retrict_to_topic_sagemaker(input_request: InputRequest):
return await restrict_to_topic(input_request)


# Run the app with uvicorn
# Save this script as app.py and run with: uvicorn app:app --reload
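The /validate and /invocations routes both take the Triton-style inputs list defined by InputRequest above. A minimal client sketch against a locally running container; the host, port, and datatype strings are assumptions for illustration.

# Sketch: call /validate with the payload shape the InputRequest model expects.
import requests

payload = {
    "inputs": [
        {"name": "text", "shape": [1], "datatype": "BYTES",
         "data": ["Let's talk about cooking pasta"]},
        {"name": "candidate_topics", "shape": [2], "datatype": "BYTES",
         "data": ["food", "sports"]},
        {"name": "zero_shot_threshold", "shape": [1], "datatype": "FP32",
         "data": [0.5]},
    ]
}

resp = requests.post("http://localhost:8080/validate", json=payload, timeout=60)
resp.raise_for_status()
body = resp.json()
# OutputResponse shape: {"modelname": ..., "modelversion": ..., "outputs": [{"name": "results", "data": [...]}]}
print(body["outputs"][0]["data"])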
92 changes: 92 additions & 0 deletions serve
#!/usr/bin/env python

import multiprocessing
import os
import signal
import subprocess
import sys
import math

import torch
from huggingface_hub import snapshot_download

cpu_count = multiprocessing.cpu_count()
default_worker_count = max(cpu_count // 8, 1)

model_server_timeout = os.environ.get('MODEL_SERVER_TIMEOUT', '60')
model_server_workers = int(os.environ.get('MODEL_SERVER_WORKERS', default_worker_count))
model_save_directory = os.environ.get('MODEL_SAVE_DIRECTORY', '/opt/ml/model')

MODEL_NAME = "facebook/bart-large-mnli"
DEFAULT_REVISION = "d7645e127eaf1aefc7862fd59a17a5aa8558b8ce"

print(f'Model server workers: {model_server_workers}')
print(f'Model save directory: {model_save_directory}')
print(f'Model server timeout: {model_server_timeout}')

print(f'CPU count: {cpu_count}')

def sigterm_handler(gunicorn_pid):
try:
os.kill(gunicorn_pid, signal.SIGTERM)
except OSError:
pass
sys.exit(0)

def load_and_save_model():
try:

print('Loading the model...')
# Ensure the save directory exists
if not os.path.exists(model_save_directory):
os.makedirs(model_save_directory)

print("Downloading the model...")

snapshot_download(
MODEL_NAME,
local_dir=model_save_directory,
ignore_patterns=[
"*.pt",
"*.bin",
"*.pth",
"original/*",
], # Ensure safetensors
revision=DEFAULT_REVISION,
force_download=False,
)
else:
print("Model already downloaded.")

print('Model loaded and saved successfully.')
except Exception as e:
print(f'Error loading and saving the model: {e}')
sys.exit(1)

def start_server():
print(f'Starting the inference server with {model_server_workers} workers.')

load_and_save_model()

try:
# Start Gunicorn to serve the FastAPI app
gunicorn = subprocess.Popen(['gunicorn',
'--timeout', str(model_server_timeout),
'-k', 'uvicorn.workers.UvicornWorker',
'-b', '0.0.0.0:8080',
'-w', str(model_server_workers),
'app:app'])

signal.signal(signal.SIGTERM, lambda a, b: sigterm_handler(gunicorn.pid))

# Wait for the Gunicorn process to exit
gunicorn.wait()

except Exception as e:
print(f'Error starting the inference server: {e}')
sys.exit(1)

print('Inference server exiting')

if __name__ == '__main__':
start_server()
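Once the container is deployed behind a SageMaker endpoint, the /invocations route is what SageMaker forwards requests to. A hedged boto3 sketch, reusing the payload shape above; the endpoint name restrict-to-topic-rc is a placeholder, not something this PR provisions.

# Sketch: invoke the deployed SageMaker endpoint; the endpoint name is a placeholder.
import json

import boto3

runtime = boto3.client("sagemaker-runtime", region_name="us-east-1")

payload = {
    "inputs": [
        {"name": "text", "shape": [1], "datatype": "BYTES",
         "data": ["Let's talk about cooking pasta"]},
        {"name": "candidate_topics", "shape": [2], "datatype": "BYTES",
         "data": ["food", "sports"]},
        {"name": "zero_shot_threshold", "shape": [1], "datatype": "FP32",
         "data": [0.5]},
    ]
}

response = runtime.invoke_endpoint(
    EndpointName="restrict-to-topic-rc",
    ContentType="application/json",
    Body=json.dumps(payload),
)
print(json.loads(response["Body"].read()))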