-
Notifications
You must be signed in to change notification settings - Fork 61
/
Copy pathSKLearnProcessor_local_processing.py
50 lines (41 loc) · 2.22 KB
/
SKLearnProcessor_local_processing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# This is a sample Python program that runs a simple scikit-learn processing job using SKLearnProcessor.
# This implementation will work on your *local computer*.
#
# Prerequisites:
# 1. Install required Python packages:
# pip install boto3 sagemaker pandas scikit-learn
# pip install 'sagemaker[local]'
# 2. Docker Desktop must be installed and running on your computer; verify with:
# `docker ps`
# 3. AWS credentials must be configured on your local machine
# so that the Docker image can be pulled from ECR.
########################################################################################################################
from sagemaker.local import LocalSession
from sagemaker.processing import ProcessingInput, ProcessingOutput
from sagemaker.sklearn.processing import SKLearnProcessor
# Local-mode session with the 'local' -> 'local_code' option switched on.
# NOTE(review): this session is not handed to SKLearnProcessor below (there is
# no `sagemaker_session=` argument), so confirm whether this config actually
# reaches the processing job.
local_mode_options = {'local_code': True}
sagemaker_session = LocalSession()
sagemaker_session.config = {'local': local_mode_options}

# For local processing a dummy (placeholder) role ARN will be sufficient.
role = 'arn:aws:iam::111111111111:role/service-role/AmazonSageMaker-ExecutionRole-20200101T000001'

# One scikit-learn (framework version 0.20.0) container, run in local mode.
processor = SKLearnProcessor(
    role=role,
    instance_type='local',
    instance_count=1,
    framework_version='0.20.0',
)
print('Starting processing job.')
print('Note: if launching for the first time in local mode, container image download might take a few minutes to complete.')

# Files under ./input_data/ are made available to the container at
# /opt/ml/processing/input_data/; whatever the script writes under
# /opt/ml/processing/processed_data/ is collected as the 'word_count_data' output.
job_inputs = [
    ProcessingInput(source='./input_data/', destination='/opt/ml/processing/input_data/'),
]
job_outputs = [
    ProcessingOutput(output_name='word_count_data', source='/opt/ml/processing/processed_data/'),
]

# NOTE(review): the script receives the two bare tokens 'job-type' and
# 'word-count'. If processing_script.py parses a `--job-type` flag, this should
# probably be ['--job-type', 'word-count'] -- confirm against that script.
processor.run(
    code='processing_script.py',
    inputs=job_inputs,
    outputs=job_outputs,
    arguments=['job-type', 'word-count'],
)
# Describe the most recent job launched by `processor` and show its output config.
preprocessing_job_description = processor.jobs[-1].describe()
output_config = preprocessing_job_description['ProcessingOutputConfig']
print(output_config)

# Find the S3 URI of the output registered as 'word_count_data' in processor.run().
# Using next() with a default avoids leaving `word_count_data_file` unbound
# (a NameError on the print below) when no output carries that name.
word_count_data_file = next(
    (
        output['S3Output']['S3Uri']
        for output in output_config['Outputs']
        if output['OutputName'] == 'word_count_data'
    ),
    None,
)
if word_count_data_file is None:
    raise RuntimeError(
        "No output named 'word_count_data' was found in the processing job description."
    )
print('Output file is located on: {}'.format(word_count_data_file))