diff --git a/examples/aws-text-benchmarks/benchmark_deepsparse.py b/examples/aws-text-benchmarks/benchmark_deepsparse.py
new file mode 100644
index 0000000000..42e8335f30
--- /dev/null
+++ b/examples/aws-text-benchmarks/benchmark_deepsparse.py
@@ -0,0 +1,90 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import time
+
+from tqdm import tqdm
+from transformers import AutoTokenizer
+
+from datasets import load_dataset
+from deepsparse import Context, Pipeline
+
+
+os.environ["NM_BIND_THREADS_TO_CORES"] = "1"
+INPUT_COL = "text"
+dataset = load_dataset("ag_news", split="train[:3000]")
+batch_size = 64
+buckets = [64, 128, 256]
+model_path = "./sparse-model/deployment/"
+
+# TOKENIZE DATASET - (used to compute buckets)
+tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+
+def pre_process_fn(examples):
+    return tokenizer(
+        examples[INPUT_COL],
+        add_special_tokens=True,
+        return_tensors="np",
+        padding=False,
+        truncation=False,
+    )
+
+
+dataset = dataset.map(pre_process_fn, batched=True)
+dataset = dataset.add_column("num_tokens", list(map(len, dataset["input_ids"])))
+dataset = dataset.sort("num_tokens")
+max_token_len = dataset[-1]["num_tokens"]
+
+# SPLIT DATA INTO BATCHES - pad the front so the total is an exact multiple of batch_size
+num_pad_items = (batch_size - dataset.num_rows % batch_size) % batch_size
+inputs = ([""] * num_pad_items) + dataset[INPUT_COL]
+batches = []
+
+for b_index_start in range(0, len(inputs), batch_size):
+    batches.append(inputs[b_index_start : b_index_start + batch_size])
+
+# RUN THROUGHPUT TESTING
+print("\nCompiling models:")
+
+tc_pipeline = Pipeline.create(
+    task="zero_shot_text_classification",
+    model_path=model_path,
+    model_scheme="mnli",
+    sequence_length=buckets,
+    batch_size=batch_size,
+    context=Context(num_streams=1),
+)
+print("\nRunning test:")
+# run inferences on the dataset
+start = time.perf_counter()
+
+predictions = []
+for batch in tqdm(batches):
+    predictions.append(
+        tc_pipeline(sequences=batch, labels=["Sports", "Business", "Sci/Tech"])
+    )
+
+end = time.perf_counter()
+
+# flatten and remove the predictions for the padding items
+# (post-processing is kept outside the timed region)
+predictions = [pred for sublist in predictions for pred in sublist.labels]
+predictions = predictions[num_pad_items:]
+
+# compute throughput
+total_time_executing = end - start
+print(f"Total time: {total_time_executing}")
+items_per_sec = len(predictions) / total_time_executing
+
+print(f"Items Per Second: {items_per_sec}")
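+
+# Optionally, derive the average wall-clock time per batch from the same totals.
+# (added sketch: uses only the variables defined above)
+avg_batch_latency = total_time_executing / len(batches)
+print(f"Avg Batch Latency: {avg_batch_latency:.4f}s")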
diff --git a/examples/aws-text-benchmarks/benchmark_huggingface.py b/examples/aws-text-benchmarks/benchmark_huggingface.py
new file mode 100644
index 0000000000..3ccc485856
--- /dev/null
+++ b/examples/aws-text-benchmarks/benchmark_huggingface.py
@@ -0,0 +1,81 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+
+from transformers import AutoTokenizer, pipeline
+from transformers.pipelines.pt_utils import KeyDataset
+
+import torch
+from datasets import load_dataset
+
+
+model_path = "./dense-model/training/"
+batch_size = 64
+
+# SETUP DATASET - in this case, we download ag_news
+print("Setting up the dataset:")
+
+INPUT_COL = "text"
+dataset = load_dataset("ag_news", split="train[:3000]")
+
+# TOKENIZE DATASET - used to sort the dataset by sequence length
+tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+
+def pre_process_fn(examples):
+    return tokenizer(
+        examples[INPUT_COL],
+        add_special_tokens=True,
+        return_tensors="np",
+        padding=False,
+        truncation=False,
+    )
+
+
+dataset = dataset.map(pre_process_fn, batched=True)
+dataset = dataset.add_column("num_tokens", list(map(len, dataset["input_ids"])))
+dataset = dataset.sort("num_tokens")
+
+# WRAP DATASET - the pipeline batches internally based on batch_size
+hf_dataset = KeyDataset(dataset, INPUT_COL)
+
+# RUN THROUGHPUT TESTING
+# load model
+hf_pipeline = pipeline(
+    "zero-shot-classification",
+    model_path,
+    batch_size=batch_size,
+    device=("cuda:0" if torch.cuda.is_available() else "cpu"),
+)
+
+# run inferences
+start = time.perf_counter()
+
+predictions = []
+for prediction in hf_pipeline(
+    hf_dataset, candidate_labels=["Sports", "Business", "Sci/Tech"]
+):
+    predictions.append(prediction)
+
+# wait for any outstanding GPU work before stopping the timer
+if torch.cuda.is_available():
+    torch.cuda.synchronize()
+
+end = time.perf_counter()
+
+# compute throughput
+total_time_executing = end - start
+items_per_sec = len(predictions) / total_time_executing
+
+print(f"Total time: {total_time_executing}")
+print(f"Items Per Second: {items_per_sec}")
diff --git a/examples/aws-text-benchmarks/image.png b/examples/aws-text-benchmarks/image.png
new file mode 100644
index 0000000000..9992593d50
Binary files /dev/null and b/examples/aws-text-benchmarks/image.png differ
diff --git a/examples/aws-text-benchmarks/readme.md b/examples/aws-text-benchmarks/readme.md
new file mode 100644
index 0000000000..b394014e00
--- /dev/null
+++ b/examples/aws-text-benchmarks/readme.md
@@ -0,0 +1,65 @@
+This directory contains example benchmarking scripts for measuring the throughput of DeepSparse with a sparse model (on CPU) and of HuggingFace + PyTorch with a dense model (on GPU).
+
+In this example, we run on the `ag_news` dataset with models downloaded from SparseZoo.
+
+## Sparse Model (DeepSparse)
+
+Install DeepSparse:
+
+```bash
+pip install deepsparse[transformers]
+```
+
+Download Sparse Model:
+
+```bash
+sparsezoo.download zoo:nlp/text_classification/bert-large/pytorch/huggingface/mnli/pruned90_quant-none --save-dir ./sparse-model
+```
+
+Run DeepSparse Benchmark (creates buckets for token lengths 64, 128, and 256):
+
+```bash
+python benchmark_deepsparse.py
+```
+
+Note: DeepSparse uses static input shapes. Since sequence lengths vary across a dataset, we use bucketing: DeepSparse is compiled with several input shapes (buckets), and each input is dynamically routed to the smallest bucket that fits it.
+For `ag_news` (the example dataset used here), the distribution of token lengths looks like the following:
+
+![Histogram](image.png)
+
+As such, we used buckets of length 64, 128, and 256. DeepSparse runs best with sequence lengths that are multiples of 16.
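+
+To choose buckets for a different dataset, inspect its token-length distribution first. Below is a minimal sketch (assuming the same model and dataset as `benchmark_deepsparse.py`; adjust the paths and split to yours):
+
+```python
+import numpy as np
+from datasets import load_dataset
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("./sparse-model/deployment/")
+dataset = load_dataset("ag_news", split="train[:3000]")
+
+# token count per example, without padding or truncation
+lengths = [
+    len(tokenizer(text, add_special_tokens=True, truncation=False)["input_ids"])
+    for text in dataset["text"]
+]
+
+# percentiles suggest bucket boundaries; round them up to multiples of 16
+print("p50/p90/p99:", np.percentile(lengths, [50, 90, 99]), "max:", max(lengths))
+```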
+
+## Dense Model (GPU)
+
+Install `transformers`, `datasets`, and `sparsezoo`:
+
+```bash
+pip install transformers[torch]
+pip install datasets
+pip install sparsezoo
+```
+
+Download Dense Model:
+
+```bash
+sparsezoo.download zoo:nlp/text_classification/bert-large/pytorch/huggingface/mnli/base-none --save-dir ./dense-model
+```
+
+Run HuggingFace Benchmark (on GPU):
+
+```bash
+python benchmark_huggingface.py
+```
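+
+As a quick sanity check before running the full benchmarks, you can exercise the sparse pipeline on a single input. A minimal sketch using the same API as `benchmark_deepsparse.py` (the example text is illustrative):
+
+```python
+from deepsparse import Pipeline
+
+# compiles with default settings; fine for a one-off smoke test
+pipeline = Pipeline.create(
+    task="zero_shot_text_classification",
+    model_path="./sparse-model/deployment/",
+    model_scheme="mnli",
+)
+
+print(pipeline(
+    sequences="The team clinched the title with a last-minute goal.",
+    labels=["Sports", "Business", "Sci/Tech"],
+))
+```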