In this blog, we are going to take a look at TorchServe, a feature-rich framework for serving machine learning models. We will go through a variety of experiments to test performance across different operating systems, thread settings, and numbers of workers, to discover the best ways to optimize our models for a given environment configuration.
The inspiration for this exercise came from a Fortune 500 customer who is a member of the PyTorch Enterprise Support Program. They recently ran into a performance discrepancy between macOS, Linux, and Windows. Setting the number of threads in the PyTorch code to 1 improved performance significantly on Linux and Windows but made no difference to macOS performance. Let’s jump in and take a closer look at why this happens.
Why the difference in performance based on the OS?
To dig deeper and run performance tests, we need to look at two kinds of parameters: threads, and workers for autoscaling. The three groups of parameters that can be adjusted to fine-tune TorchServe performance are: the pool size in Netty, the number of workers in TorchServe, and the number of threads in PyTorch. TorchServe uses multiple threads on both the frontend and the backend to speed up the overall inference process. PyTorch also has its own threading setting during model inference. More details on PyTorch threading can be found in the PyTorch documentation.
Threading can impact overall inference performance significantly. Our theory was that the default setting on Linux (and Windows) creates too many threads. When the PyTorch model is small and the inference cost for a single input is low, the time spent on context switches between threads increases inference time and degrades performance. The TorchServe documentation shows that we can set properties to tell TorchServe how many threads to use for both the frontend and the backend. Additionally, we can set the number of threads PyTorch uses. We will adjust these parameters in our testing to see how we can improve performance and find the root cause of the issue reported by the customer.
- TorchServe Settings:
  - number_of_netty_threads: Number of frontend Netty threads.
  - netty_client_threads: Number of backend Netty threads.
  - default_workers_per_model: Number of workers to create for each model loaded at startup time.
- PyTorch Settings:
  - PyTorch number of threads: There are several ways to set the number of threads used by PyTorch.
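The three TorchServe properties above live in a `config.properties` file. As a minimal sketch (the helper name `render_ts_config` is hypothetical, not part of TorchServe), the file contents can be generated from Python so that each test run uses a known combination:

```python
def render_ts_config(number_of_netty_threads, netty_client_threads,
                     default_workers_per_model, **extra):
    """Return config.properties text for the TorchServe settings listed above.

    Hypothetical helper: extra keyword arguments (for example model_store or
    job_queue_size) are appended as additional key=value properties.
    """
    props = {
        "number_of_netty_threads": number_of_netty_threads,      # frontend Netty threads
        "netty_client_threads": netty_client_threads,            # backend Netty threads
        "default_workers_per_model": default_workers_per_model,  # workers per model at startup
        **extra,
    }
    # One key=value pair per line, in insertion order
    return "".join(f"{key}={value}\n" for key, value in props.items())
```

Write the returned text to a file such as `config.properties` and pass it to `torchserve --start --ts-config config.properties`, the same flag the Experiment 3 script uses.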
Testing TorchServe Configurations to Optimize Performance
Experiment 1: Performance test on serving PyTorch models directly without TorchServe on Windows.
To simulate how TorchServe loads the model and runs inference on it, we wrote some wrapper code that calls the model's TorchServe handler directly, without the TorchServe server. This allows us to isolate where the performance loss occurs. When we run inference directly on the model with the wrapper code, there is no difference between setting the number of threads to 1 and leaving it at the default value. This suggests that there is sufficient compute capacity to handle the workload and that these values work well for this model size. However, a high thread count could still cause problems, as we will discuss in more detail later. You can verify this result by running the script below against your model (the script works on both Linux and Windows).
```python
import sys
from uuid import uuid4

from ts.metrics.metrics_store import MetricsStore


class ModelContext:
    """Minimal stand-in for the context object TorchServe passes to a handler."""

    def __init__(self):
        self.manifest = {
            'model': {
                'modelName': 'ptclassifier',
                'serializedFile': '<ADD MODEL NAME HERE>',
                'modelFile': 'model_ph.py'
            }
        }
        self.system_properties = {
            'model_dir': '<ADD COMPLETE MODEL PATH HERE>'
        }
        self.explain = False
        # Collects the metrics the handler records so they can be printed at the end
        self.metrics = MetricsStore(uuid4(), self.manifest['model']['modelName'])

    def get_request_header(self, idx, exp):
        if exp == 'explain':
            return self.explain
        return False


def main():
    # Pass 'fast' to load the ptclassifier handler; anything else loads ptclassifiernotr
    variant = sys.argv[1] if len(sys.argv) > 1 else ''
    if variant == 'fast':
        from ptclassifier.TransformerSeqClassificationHandler import TransformersSeqClassifierHandler as Classifier
    else:
        from ptclassifiernotr.TransformerSeqClassificationHandler import TransformersSeqClassifierHandler as Classifier

    ctxt = ModelContext()
    handler = Classifier()
    handler.initialize(ctxt)

    data = [{'data': 'To be or not to be, that is the question.'}]
    # Run the same request 1000 times so the handler accumulates latency metrics
    for i in range(1000):
        processed = handler.handle(data, ctxt)
        # print(processed)

    for m in ctxt.metrics.store:
        print(f'{m.name}: {m.value} {m.unit}')


if __name__ == '__main__':
    main()
```
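If you also want wall-clock latencies next to the handler metrics, a small timing wrapper can be dropped into the script above. This is a minimal sketch; `time_calls` and `summarize` are hypothetical helpers. You could wrap the loop body as `time_calls(lambda: handler.handle(data, ctxt))` and run it once after calling `torch.set_num_threads(1)` and once with the default setting to compare the two cases.

```python
import statistics
import time


def time_calls(fn, iterations=1000, warmup=10):
    """Call fn() repeatedly and return per-call latencies in milliseconds."""
    for _ in range(warmup):
        fn()  # warm-up calls are not measured
    latencies = []
    for _ in range(iterations):
        start = time.perf_counter()
        fn()
        latencies.append((time.perf_counter() - start) * 1000.0)
    return latencies


def summarize(latencies):
    """Return the mean and an approximate 95th-percentile latency (ms)."""
    ordered = sorted(latencies)
    p95 = ordered[min(len(ordered) - 1, int(0.95 * len(ordered)))]
    return {"mean_ms": statistics.mean(latencies), "p95_ms": p95}
```

`time.perf_counter` is used because it is a monotonic, high-resolution clock suited to measuring short intervals.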
Experiment 2: Performance test on a larger model on Linux.
We used a different model, from the official TorchServe Hugging Face sample, with the same testing script as above to get insights on Linux. Since the Hugging Face model is much heavier than our customer's model, it allows us to test inference performance on a longer-running workload.
For a larger model, the cost of context switching is small compared to the cost of inference, so the performance difference is smaller. The test results below show the performance difference (3X vs 10X):
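To make this reasoning concrete, here is a back-of-the-envelope sketch. It assumes threading overhead adds a roughly fixed cost per request, and the numbers in it are made up for illustration; they are not measurements from these experiments. Under that assumption, the overhead's share of total latency shrinks as inference time grows.

```python
def overhead_fraction(inference_ms, overhead_ms):
    """Share of per-request latency spent on threading overhead (e.g. context switches)."""
    return overhead_ms / (inference_ms + overhead_ms)


# Hypothetical numbers for illustration only, not results from this post
OVERHEAD_MS = 1.0
small = overhead_fraction(inference_ms=0.25, overhead_ms=OVERHEAD_MS)
large = overhead_fraction(inference_ms=20.0, overhead_ms=OVERHEAD_MS)
print(f"small model: {small:.0%} of latency is overhead")
print(f"large model: {large:.0%} of latency is overhead")
```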
Experiment 3: Performance test on different combinations of thread settings on Linux.
This experiment shows how thread settings affect overall inference performance. We tested on WSL (Windows Subsystem for Linux) with 4 physical cores and 8 logical cores. The results show that sufficient threading improves performance, but over-threading substantially hurts inference performance and overall throughput. The best result from the experiment shows a 50X inference speedup and an 18X throughput improvement over the least effective settings. The metrics for the best and worst settings are shown in the tables below:
Here is the script used to create these test results:
```python
import itertools
import os
import subprocess
import time


def do_test(number_of_netty_threads=1, netty_client_threads=1,
            default_workers_per_model=1, job_queue_size=100,
            MKL_NUM_THREADS=1, test_parallelism=8):
    # Generate a config.properties file for this combination of parameters
    os.makedirs("config_file", exist_ok=True)
    config_file_name = ("config_file/"
                        f"config_{number_of_netty_threads}_{netty_client_threads}_"
                        f"{default_workers_per_model}_{job_queue_size}.properties")
    with open(config_file_name, "w") as f:
        f.write("load_models=all\n")
        f.write("inference_address=http://0.0.0.0:8080\n")
        f.write("management_address=http://0.0.0.0:8081\n")
        f.write("metrics_address=http://0.0.0.0:8082\n")
        f.write("model_store=<ADD COMPLETE MODEL PATH HERE>\n")
        f.write(f"number_of_netty_threads={number_of_netty_threads}\n")
        f.write(f"netty_client_threads={netty_client_threads}\n")
        f.write(f"default_workers_per_model={default_workers_per_model}\n")
        f.write(f"job_queue_size={job_queue_size}\n")

    # Start TorchServe with this config and the MKL thread setting
    subprocess.call(
        f"MKL_NUM_THREADS={MKL_NUM_THREADS} torchserve --start --model-store model-store "
        f"--models model=<ADD MODEL NAME HERE> --ncs --ts-config {config_file_name}",
        shell=True, stdout=subprocess.DEVNULL)
    time.sleep(3)  # give the workers time to load the model

    # Send 1000 requests to the inference API, test_parallelism at a time
    print("start to send test request...")
    start_time = time.time()
    print(time.ctime())
    subprocess.run(
        f"seq 1 1000 | xargs -n 1 -P {test_parallelism} bash -c "
        "'url=\"http://127.0.0.1:8080/predictions/model\"; curl -X POST $url -T input.txt'",
        shell=True, capture_output=True, text=True)
    total_time = int((time.time() - start_time) * 1e6)  # microseconds
    print("total time in us:", total_time)

    # Read TorchServe's inference latency and queue latency from the metrics API
    output = subprocess.run("curl http://127.0.0.1:8082/metrics",
                            shell=True, capture_output=True, text=True)
    inference_time = 0
    queue_time = 0
    for line in output.stdout.split('\n'):
        if line.startswith('ts_inference_latency_microseconds'):
            inference_time = line.rsplit(' ', 1)[1]
        if line.startswith('ts_queue_latency_microseconds'):
            queue_time = line.rsplit(' ', 1)[1]

    # Throughput in requests per second (total_time is in microseconds)
    throughput = 1000 / total_time * 1000000

    # Append one CSV row per combination for later comparison
    with open("test_result_short.csv", "a") as f:
        f.write(f"{number_of_netty_threads},{netty_client_threads},{default_workers_per_model},"
                f"{MKL_NUM_THREADS},{job_queue_size},{test_parallelism},{total_time},"
                f"{inference_time},{queue_time},{throughput}\n")

    # Stop TorchServe before testing the next combination
    stop_result = os.system("torchserve --stop")
    print(stop_result)


def main():
    # Candidate values for each parameter
    number_of_netty_threads = [1, 2, 4, 8]
    netty_client_threads = [1, 2, 4, 8]
    default_workers_per_model = [1, 2, 4, 8]
    MKL_NUM_THREADS = [1, 2, 4, 8]
    job_queue_size = [1000]  # [100, 200, 500, 1000]
    test_parallelism = [32]  # [8, 16, 32, 64]

    # One test per combination; the product order matches do_test's parameter order
    for a, b, c, d, e, f in itertools.product(number_of_netty_threads, netty_client_threads,
                                              default_workers_per_model, job_queue_size,
                                              MKL_NUM_THREADS, test_parallelism):
        do_test(a, b, c, d, e, f)


if __name__ == "__main__":
    main()
```
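Before running the full grid, it can help to count how many inference threads a combination implies. The sketch below is a rough model under stated assumptions: each TorchServe worker is a separate process with its own PyTorch/MKL thread pool, and Netty threads and PyTorch's inter-op threads are ignored. The helper names are hypothetical. It also confirms the size of the grid in the script above and flags worker/thread pairs that exceed the 8 logical cores of the WSL test machine.

```python
import itertools
import os


def compute_threads(workers, threads_per_worker):
    """Rough count of inference threads: each worker runs its own thread pool."""
    return workers * threads_per_worker


def oversubscription(workers, threads_per_worker, logical_cores=None):
    """Ratio of inference threads to logical cores; above 1 means threads compete for cores."""
    cores = logical_cores or os.cpu_count() or 1
    return compute_threads(workers, threads_per_worker) / cores


# Same candidate values as the grid above (job_queue_size and test_parallelism fixed)
values = [1, 2, 4, 8]
grid = list(itertools.product(values, values, values, values))
print(len(grid), "combinations")

# Flag worker/thread pairs that oversubscribe an 8-logical-core machine
for workers, threads in itertools.product(values, values):
    ratio = oversubscription(workers, threads, logical_cores=8)
    if ratio > 1:
        print(f"workers={workers}, threads={threads}: {ratio:.1f}x oversubscribed")
```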
Experiment 4: Testing on Windows using Windows Performance Recorder.
We were able to replicate the performance issue on Windows as well. We then used Windows Performance Recorder and Windows Performance Analyzer to analyze the overall system performance while running the tests on the model.
Both figures below show the total number of context switches by process and thread in the system. All the Python processes (worker processes created by TorchServe) are colored green.
The figure above shows the number of context switches plotted against time for the slow case (when the number of threads is set to the default).
The figure above shows the same data when the number of threads is set to 1.
We can clearly see that the performance difference comes from the significantly higher number of context switches when the number of threads is left at the default value.
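Windows Performance Recorder is Windows-only. On Linux or macOS, a rough per-process analogue is the context-switch counters returned by `getrusage`, as in the minimal sketch below (Unix-only; helper names are hypothetical). To watch a running TorchServe worker from outside on Linux, `/proc/<pid>/status` reports the same counters as `voluntary_ctxt_switches` and `nonvoluntary_ctxt_switches`. Note that the toy workload uses Python threads, which contend on the GIL rather than on cores the way PyTorch's native threads do, so treat it as a demonstration of the measurement, not a model of TorchServe.

```python
import resource
import threading


def context_switches(fn):
    """Run fn() and return (voluntary, involuntary) context switches for this process.

    getrusage(RUSAGE_SELF) is Unix-only and counts all threads of the process.
    """
    before = resource.getrusage(resource.RUSAGE_SELF)
    fn()
    after = resource.getrusage(resource.RUSAGE_SELF)
    return (after.ru_nvcsw - before.ru_nvcsw,
            after.ru_nivcsw - before.ru_nivcsw)


def busy_work(n_threads, total_iterations=1_000_000):
    """Split a fixed amount of CPU work across n_threads Python threads."""
    def worker(count):
        x = 0
        for i in range(count):
            x += i

    per_thread = total_iterations // n_threads
    threads = [threading.Thread(target=worker, args=(per_thread,)) for _ in range(n_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()


# Same total work, split across 1 thread and 8 threads
for n in (1, 8):
    voluntary, involuntary = context_switches(lambda: busy_work(n))
    print(f"{n} thread(s): voluntary={voluntary}, involuntary={involuntary}")
```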
Experiment 5: The macOS Finding
On macOS, PyTorch has a known issue, "Only one thread is used on macOS": the default number of threads is 1 instead of the number of logical cores. In this scenario, that actually results in better TorchServe performance, because it avoids the context-switch cost of running many threads. Therefore, the performance issue was not present on macOS.
Summary
This performance issue turned out to be a matter of thread settings rather than of the operating system itself. The main takeaway is that setting the number of threads to 1 in PyTorch reduces the total number of threads running on the system, and thus reduces the overall number of context switches on Linux and Windows. On macOS, the call to set the number of threads has no impact because the default is already 1, so we see no performance difference.
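One way to apply this takeaway without changing model code is to cap the thread pools through environment variables when launching TorchServe, as the Experiment 3 script does with `MKL_NUM_THREADS`. The sketch below also sets `OMP_NUM_THREADS`, which PyTorch's intra-op thread pool honors; calling `torch.set_num_threads(1)` inside a custom handler is the in-process alternative. It assumes, as the Experiment 3 script does, that TorchServe's worker processes inherit the environment. `thread_limited_env` and `start_torchserve` are hypothetical helpers.

```python
import os
import subprocess

# Environment variables read by the OpenMP and MKL thread pools PyTorch uses
THREAD_ENV_VARS = ("OMP_NUM_THREADS", "MKL_NUM_THREADS")


def thread_limited_env(num_threads=1, base=None):
    """Return a copy of an environment with OpenMP/MKL thread counts capped."""
    env = dict(os.environ if base is None else base)
    for name in THREAD_ENV_VARS:
        env[name] = str(num_threads)
    return env


def start_torchserve(extra_args, num_threads=1):
    """Hypothetical launcher: start TorchServe so its workers inherit the thread caps."""
    cmd = ["torchserve", "--start", "--ncs", *extra_args]
    return subprocess.run(cmd, env=thread_limited_env(num_threads), check=True)
```

For example, `start_torchserve(["--model-store", "model-store", "--ts-config", "config.properties"])` reuses the flags from the Experiment 3 script.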
Many factors can affect the optimal combination for inference, such as the number of available cores, the number of models served, and the size of the models. The best way to find the optimal combination is through experimentation. There is no existing tool or framework that automatically sets the best value for each parameter to balance performance and throughput. As explained in the TorchServe documentation, an investigation is often needed to fine-tune thread settings for the target hardware when serving models with TorchServe. We have shared the test scripts from this blog so you can run these performance tests on your own model. Use them to experiment and find the settings that best balance inference latency and overall throughput for your model.
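Once the Experiment 3 script has filled `test_result_short.csv`, picking a candidate configuration is a small post-processing step. This minimal sketch assumes the column order written by `do_test()` (the file has no header row); the helper names and column labels are hypothetical. It selects the highest-throughput row, optionally skipping rows whose raw inference-latency metric exceeds a cap, which is one simple way to trade latency against throughput.

```python
import csv

# Column order written by do_test() in the Experiment 3 script
COLUMNS = ["netty_threads", "netty_client_threads", "workers", "mkl_threads",
           "job_queue_size", "test_parallelism", "total_time_us",
           "inference_latency_us", "queue_latency_us", "throughput_rps"]


def load_results(lines):
    """Parse result rows (any iterable of CSV lines) into dicts keyed by COLUMNS."""
    rows = []
    for record in csv.reader(lines):
        if not record:
            continue  # skip blank lines
        row = dict(zip(COLUMNS, (field.strip() for field in record)))
        row["throughput_rps"] = float(row["throughput_rps"])
        rows.append(row)
    return rows


def best_by_throughput(rows, max_inference_latency_us=None):
    """Return the highest-throughput row, optionally ignoring rows above a latency cap.

    Raises ValueError if no row satisfies the cap.
    """
    candidates = [r for r in rows
                  if max_inference_latency_us is None
                  or float(r["inference_latency_us"]) <= max_inference_latency_us]
    return max(candidates, key=lambda r: r["throughput_rps"])
```

Usage: `with open("test_result_short.csv", newline="") as f: print(best_by_throughput(load_results(f)))`.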
References
- For full code samples, check out this GitHub repo
- TorchServe Documentation
- TorchServe Management API
- TorchServe Hugging Face Sample
Article by Frank Dong & Cassie Breviu of Microsoft.