Skip to content

Commit c5c889b

Browse files
committed
chore: add gpu dockerfile
1 parent 098bc29 commit c5c889b

File tree

3 files changed

+31
-23
lines changed

3 files changed

+31
-23
lines changed

Dockerfile

+1-22
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,3 @@
1-
# FROM python:3.6
2-
3-
# RUN apt-get update
4-
# RUN apt-get install -y git libsm6 libxrender1 libfontconfig1
5-
6-
# WORKDIR /workspace
7-
8-
# COPY ./requirements_container.txt /workspace
9-
# # Install python package dependices
10-
# RUN pip install -r requirements_container.txt
11-
12-
# WORKDIR /workspace/
13-
14-
15-
# COPY ./dextr_pb2.py /workspace
16-
# COPY ./dextr_pb2_grpc.py /workspace
17-
# COPY ./server.py /workspace
18-
# COPY ./dextr.proto /workspace
19-
20-
# ENTRYPOINT [ "python", "server.py" ]
21-
221
FROM python:3.8.5
232

243
RUN apt-get update
@@ -36,6 +15,6 @@ RUN pip install gunicorn==20.0.4
3615
COPY server.py /workspace
3716

3817
EXPOSE 8000
39-
18+
ENV DEVICE=cpu
4019
WORKDIR /workspace
4120
CMD [ "gunicorn", "-w 6", "-b 0.0.0.0:8000", "server:app" ]

Dockerfile.gpu

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
FROM pytorch/pytorch:1.7.0-cuda11.0-cudnn8-runtime
2+
3+
4+
RUN apt-get update
5+
RUN apt-get install 'ffmpeg'\
6+
'libsm6'\
7+
'libxext6' -y
8+
9+
RUN apt install liblzma-dev
10+
11+
WORKDIR /workspace
12+
COPY requirements_container.txt /workspace
13+
RUN pip install -r requirements_container.txt
14+
RUN python -c "from dextr.model import DextrModel; DextrModel.pascalvoc_resunet101()"
15+
RUN pip install gunicorn==20.0.4
16+
COPY server.py /workspace
17+
18+
EXPOSE 8000
19+
ENV DEVICE=cuda:0
20+
WORKDIR /workspace
21+
CMD [ "gunicorn", "-w 6", "-b 0.0.0.0:8000", "server:app" ]

server.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,17 @@
44
from PIL import Image
55
import numpy as np
66
import time
7+
import torch
8+
import os
79

810
from imantics import Mask
911

12+
device = os.getenv("DEVICE", "cpu")
13+
torch_device = torch.device(device)
1014

1115
app = Flask(__name__)
1216
model = DextrModel.pascalvoc_resunet101()
17+
model.eval()
1318

1419

1520
@app.route("/", methods=["GET", "POST"])
@@ -25,8 +30,11 @@ def hello_world():
2530
image = Image.open(path)
2631
print(f"Image Size: {image.size}", flush=True)
2732

33+
# points come in [x,y] order; this must be flipped
34+
points = points[:, ::-1]
2835
mask = model.predict([image], [points])[0]
29-
polygons = Mask(mask).polygons().points
36+
mask_bin = mask >= 0.5
37+
polygons = Mask(mask_bin).polygons().points
3038
polygons = [polygon.tolist() for polygon in polygons if len(polygon) > 2]
3139
print(f"Result: {polygons}", flush=True)
3240

0 commit comments

Comments
 (0)