In a project I am working on, I had to push a model trained on a GPU into production. The problem started when I wrote an inference script for the model and ran it on a CPU-only device to test it.
Used to the `.to(device)` method, I loaded the model the same way:

```python
model = torch.load(model_path).to(device)
```
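However, this failed on the CPU-only device: `torch.load()` raised a `RuntimeError` saying it was attempting to deserialize an object on a CUDA device while `torch.cuda.is_available()` was `False`. That's when I learned that we must pass the `map_location` argument to the `.load()` method instead of relying on `.to(device)`:

```python
model = torch.load(model_path, map_location=device)
```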
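A minimal, device-agnostic version of that inference-loading step might look like the sketch below. It assumes the checkpoint holds a `state_dict` rather than a whole pickled model (recent PyTorch releases default `torch.load` to `weights_only=True`, which loads plain tensor dictionaries like this without complaint). `TinyNet`, `load_for_inference`, and `checkpoint.pt` are hypothetical stand-ins for your own model class, helper, and file.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real model architecture."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


def load_for_inference(path: str) -> nn.Module:
    # Use whatever device this machine actually has.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TinyNet().to(device)
    # map_location remaps every saved storage to `device` during unpickling,
    # so a GPU-saved checkpoint also loads on a CPU-only machine.
    state_dict = torch.load(path, map_location=device)
    model.load_state_dict(state_dict)  # same-device copy into the parameters
    return model.eval()


if __name__ == "__main__":
    ckpt = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
    torch.save(TinyNet().state_dict(), ckpt)  # stands in for the training side
    model = load_for_inference(ckpt)
    x = torch.randn(1, 4, device=next(model.parameters()).device)
    with torch.no_grad():
        print(model(x).shape)  # torch.Size([1, 2])
```

Creating the model on `device` before calling `load_state_dict` keeps the final copy on a single device; if you build the model on the CPU instead, `map_location="cpu"` followed by `.to(device)` is the matching choice.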
Why can't I keep using `.to(device)`? What difference does `map_location` make? When can I safely use it instead of `.to(device)`? When is `.to(device)` more useful than `map_location`? Let's discuss the answers to these questions now.
The "technical" difference between the two.
Without `map_location`, `torch.load()` restores every tensor to the device it was saved on (the location tag stored in the checkpoint); only afterwards does `.to(device)` move the model to the GPU or CPU you asked for. It's a two-step process. In contrast, `torch.load(model_path, map_location=device)` is a single step: each saved storage is remapped to the specified device as the checkpoint is deserialized, so it never visits the original device. The following illustration might help us understand what is happening at a high level.
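One way to watch `map_location` act during loading is to pass it a callable. Per the `torch.load` documentation, the callable is invoked once per serialized storage with the CPU-deserialized storage and its location tag (such as `'cpu'` or `'cuda:0'`), and whatever it returns is where that storage ends up. The sketch below uses a hypothetical `record_location` hook and a toy `nn.Linear` checkpoint, and also checks that the one-step and two-step paths produce identical weights.

```python
import os
import tempfile

import torch
import torch.nn as nn

seen_locations = []


def record_location(storage, location):
    # Called once per saved storage; `location` is the device tag stored in
    # the checkpoint, e.g. "cpu" or "cuda:0".
    seen_locations.append(location)
    return storage  # keep the storage on the CPU


path = os.path.join(tempfile.mkdtemp(), "linear.pt")
torch.save(nn.Linear(3, 3).state_dict(), path)

# Two-step path: restore to the saved device, then move.
two_step = {k: v.to("cpu") for k, v in torch.load(path).items()}

# One-step path: storages are remapped while the checkpoint is unpickled.
one_step = torch.load(path, map_location=record_location)

print(seen_locations)  # location tags recorded during loading
print(all(torch.equal(two_step[k], one_step[k]) for k in two_step))  # True
```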
Why does this difference matter?
Imagine we removed the GPU from the setup shown in the illustration above. Which of the two paths would fail? The `torch.load()` + `.to("cpu")` path would, because `torch.load()` first tries to restore the tensors onto a GPU that no longer exists, so `.to("cpu")` never even runs. And even when another GPU is available to hold the model, using `map_location` saves you the extra transfer from one device to another, along with the temporary memory the model would occupy on the original device.
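When a checkpoint was saved on a GPU index that doesn't exist (or isn't free) on the target machine, `map_location` also accepts a dict that renames location tags, so storages go straight to the device you want instead of landing on the original GPU first. The `build_location_map` helper below is a hypothetical sketch: it assumes checkpoints come from at most `max_gpus` devices and sends all of them to `cuda:0` when a GPU is present, otherwise to the CPU.

```python
import os
import tempfile

import torch


def build_location_map(max_gpus: int = 8) -> dict:
    # Target the first GPU if one exists, otherwise the CPU.
    target = "cuda:0" if torch.cuda.is_available() else "cpu"
    # Rename every plausible saved CUDA tag (and "cpu") to the target device.
    mapping = {f"cuda:{i}": target for i in range(max_gpus)}
    mapping["cpu"] = target
    return mapping


path = os.path.join(tempfile.mkdtemp(), "weights.pt")
torch.save({"w": torch.ones(2, 2)}, path)

loaded = torch.load(path, map_location=build_location_map())
print(loaded["w"].device)
```

Tags missing from the dict are left as they are, which is why the helper lists every index it expects to encounter.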
How can this difference be utilized?
There are multiple ways to make the most of this difference. For example, `map_location` lets us sidestep problems caused by a CUDA version mismatch or a different GPU setup on the target machine: remapping the storages to the CPU (or to a device that exists) means loading no longer depends on the original GPU. `.to(device)`, on the other hand, lets us dynamically choose which parts of the model to push to the device of our choice.
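The `.to(device)` side of that trade-off can be sketched with a toy two-branch model. `TwoTowerModel`, `image_encoder`, `text_encoder`, and `encode_text` are hypothetical names; the point is that `.to()` works on any sub-module, so a text-only inference job can push just the branch it needs onto the accelerator and leave the other one in CPU memory.

```python
import torch
import torch.nn as nn


class TwoTowerModel(nn.Module):
    """Hypothetical multi-modal model with independent branches."""

    def __init__(self) -> None:
        super().__init__()
        self.image_encoder = nn.Linear(16, 8)
        self.text_encoder = nn.Linear(10, 8)

    def encode_text(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.text_encoder(tokens)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TwoTowerModel()  # imagine its weights were loaded with map_location="cpu"

# Only the branch used for this job goes to the accelerator.
model.text_encoder.to(device)

with torch.no_grad():
    out = model.encode_text(torch.randn(2, 10, device=device))

print(out.device, next(model.image_encoder.parameters()).device)
```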
The table below summarizes the major differences and where one approach can be more useful than the other.
| Aspect | `.to(device)` | `map_location` |
|---|---|---|
| When to use | After loading the model or tensor. | While loading the model or tensor from a checkpoint. |
| Large models | Slower, since the whole model is moved after it has been loaded. | More efficient, since it loads directly to the target device. |
| Memory transfers | Less efficient if you keep swapping devices, as every move involves additional memory transfers. | More efficient, since memory is allocated on the target device during loading. |
| Flexibility | Better when batch sizes vary, and lets you choose which parts of the model go on the GPU (useful if you only use some parts of the model, as is common in multi-modal approaches). | Only applies when loading from disk, so it plays no role in dynamic batch processing. |
| Multi-GPU | Easier to manage when using `DataParallel` or `DistributedDataParallel`. | Usually set once per load call, so it is less suited to dynamic multi-GPU scenarios (though a per-rank mapping such as `{"cuda:0": f"cuda:{rank}"}` is common when each DDP process loads a checkpoint). |
| Debugging | Easier to debug, because you can isolate the `.to(device)` call and examine what happens (though it can get confusing if you keep moving different parameters around). | Harder to debug in isolation, since it is tied to loading (though sometimes easier, because you know exactly where the model lives). |
| Fine-tuning | Easy to use when you want to fine-tune only specific layers on a different device. | Less flexible for fine-tuning specific layers, as placement is largely all-or-nothing at load time. |
| Runtime logic | Allows conditional logic to decide device placement at runtime. | Limited to rules fixed when loading starts (a string, device, dict, or callable); after that it is set-and-forget. |
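To make the fine-tuning entry in the table above concrete, here is a hedged sketch that combines both tools: the checkpoint is loaded onto the CPU with `map_location`, the backbone is frozen and left there, and only the trainable head is moved with `.to(device)`. `FineTuneNet`, its layer sizes, and the learning rate are illustrative assumptions; the forward pass moves features to the head's device so the split works whether or not a GPU is present.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.nn.functional as F


class FineTuneNet(nn.Module):
    """Hypothetical backbone + head model used only for illustration."""

    def __init__(self) -> None:
        super().__init__()
        self.backbone = nn.Linear(8, 16)
        self.head = nn.Linear(16, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = self.backbone(x)  # runs wherever the backbone lives
        head_device = next(self.head.parameters()).device
        return self.head(feats.to(head_device))  # hop to the head's device


# Simulate a pretrained checkpoint on disk.
path = os.path.join(tempfile.mkdtemp(), "pretrained.pt")
torch.save(FineTuneNet().state_dict(), path)

# Step 1 (map_location): load everything onto the CPU, whatever the saved device.
model = FineTuneNet()
model.load_state_dict(torch.load(path, map_location="cpu"))

# Step 2 (.to): decide at runtime where the trainable part should live.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for p in model.backbone.parameters():
    p.requires_grad_(False)  # frozen backbone stays on the CPU
model.head.to(device)

optimizer = torch.optim.SGD(model.head.parameters(), lr=0.01)
x = torch.randn(4, 8)  # inputs live where the backbone lives (CPU)
target = torch.randint(0, 2, (4,), device=device)

loss = F.cross_entropy(model(x), target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"loss={loss.item():.4f}, head on {device}")
```

Keeping the frozen backbone on the CPU saves accelerator memory at the cost of copying the features to the head's device on every step; whether that pays off depends on your model sizes.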
To avoid all these issues, one robust habit is to move the model to the CPU before saving it with `torch.save()`: a system might not have a TPU or GPU, but it can't function without a CPU.
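A minimal sketch of that habit, assuming you save a `state_dict`: copy every tensor to the CPU before calling `torch.save()`, so the checkpoint carries only `'cpu'` location tags and a plain `torch.load()` works on any machine. The helper name `save_cpu_checkpoint` is hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_cpu_checkpoint(model: nn.Module, path: str) -> None:
    # .cpu() returns the tensor itself if it is already on the CPU and a copy
    # otherwise, so the live model can stay on the GPU for further training.
    cpu_state = {name: t.cpu() for name, t in model.state_dict().items()}
    torch.save(cpu_state, path)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(5, 3).to(device)

path = os.path.join(tempfile.mkdtemp(), "portable.pt")
save_cpu_checkpoint(model, path)

state = torch.load(path)  # no map_location needed: everything was saved on CPU
print({t.device.type for t in state.values()})  # {'cpu'}
```

Copying the state dict, rather than calling `model.to("cpu")` on the live model, is a design choice here: it leaves the training model where it is while still producing a portable file.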
I hope you enjoyed reading this article. Let's meet again with more detailed articles. Till then, stay tuned to Neuronuts! ✌️ P.S. In the meantime, you can check out my other articles where I discussed RegEx or learn more about the different phases involved in a machine learning project lifecycle here.