I found two errors in the code that measures duration in analyzer.py:
- PyTorch executes CUDA operations asynchronously. If the hooks simply record a start and an end time, the measured duration is far too short, because the end time is taken without waiting for the GPU to finish the computation.
- If a module in a CNN runs forward propagation multiple times, the current code keeps only the duration of the last forward pass, not the duration of each pass.
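The first point can be illustrated without a GPU. The sketch below is only a CPU analogy, not TensorWatch or PyTorch code: a single-worker thread pool stands in for the CUDA queue, `fake_kernel` is a hypothetical placeholder for GPU work, and `future.result()` plays the role of `torch.cuda.synchronize()`.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def fake_kernel(seconds: float) -> str:
    """Hypothetical stand-in for GPU work: it only sleeps."""
    time.sleep(seconds)
    return "done"


def time_without_wait(pool: ThreadPoolExecutor, seconds: float) -> float:
    """Stop the clock right after submitting, like timing an async launch."""
    start = time.perf_counter()
    future = pool.submit(fake_kernel, seconds)  # returns immediately
    elapsed = time.perf_counter() - start       # work is still running here
    future.result()                             # drain the queue afterwards
    return elapsed


def time_with_wait(pool: ThreadPoolExecutor, seconds: float) -> float:
    """Wait for the work to finish before stopping the clock."""
    start = time.perf_counter()
    future = pool.submit(fake_kernel, seconds)
    future.result()                             # analogous to synchronize()
    return time.perf_counter() - start


with ThreadPoolExecutor(max_workers=1) as pool:
    unsynced = time_without_wait(pool, 0.05)
    synced = time_with_wait(pool, 0.05)

print(f"without waiting: {unsynced:.4f}s, with waiting: {synced:.4f}s")
```

The unsynchronized measurement only covers the cost of queuing the work, which is the same effect that makes the recorded durations too short.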
Here is my solution:
# tensorwatch\tensorwatch\model_graph\torchstat\analyzer.py
class ModuleStats:
    def __init__(self, name) -> None:
        # self.duration = 0.0
        self.duration = []

def _forward_pre_hook(module_stats:ModuleStats, module:nn.Module, input):
    assert not module_stats.done
    torch.cuda.synchronize()
    module_stats.start_time = time.time()

def _forward_post_hook(module_stats:ModuleStats, module:nn.Module, input, output):
    assert not module_stats.done
    torch.cuda.synchronize()
    module_stats.end_time = time.time()
    # Using a list to store the duration of each forward propagation.
    # module_stats.duration = module_stats.end_time-module_stats.start_time
    module_stats.duration.append(module_stats.end_time - module_stats.start_time)
    # other code

# tensorwatch\tensorwatch\model_graph\torchstat\stat_tree.py
class StatNode(object):
    def __init__(self, name=str(), parent=None):
        # self.duration = 0
        self._duration = []

    @property
    def duration(self):
        # total_duration = self._duration
        total_duration = sum(self._duration)
        for child in self.children:
            total_duration += child.duration
        return total_duration
        # or, to expose the per-call durations instead of their sum:
        # return self._duration
I also provide a simple comparison. In the Bottleneck of the ResNet backbone, the same relu function is called three times, so there should be three corresponding durations. In the TensorWatch statistics, however, only one relu record appears for the Bottleneck.
With my modified code, the durations of all three relu calls are recorded.
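The difference between overwriting and appending can be reproduced with a small pure-Python sketch. It does not use TensorWatch or PyTorch: `CallStats`, `timed`, and the simplified `bottleneck` function are hypothetical names for illustration, and the wrapper only mimics what a pre/post forward hook pair would do.

```python
import time


class CallStats:
    """Hypothetical stand-in for ModuleStats."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.last_duration = 0.0  # original behaviour: overwritten on each call
        self.durations = []       # proposed behaviour: one entry per call


def timed(stats: CallStats, fn):
    """Wrap fn so that every call is timed, like a pre/post hook pair."""
    def wrapper(x):
        start = time.perf_counter()
        out = fn(x)
        elapsed = time.perf_counter() - start
        stats.last_duration = elapsed    # keeps only the latest call
        stats.durations.append(elapsed)  # keeps every call
        return out
    return wrapper


def relu(x):
    return [max(0.0, v) for v in x]


relu_stats = CallStats("relu")
shared_relu = timed(relu_stats, relu)


def bottleneck(x):
    # Simplified: the same relu object is reused after each of three stages.
    x = shared_relu([v - 1.0 for v in x])
    x = shared_relu([v * 2.0 for v in x])
    x = shared_relu([v + 0.5 for v in x])
    return x


out = bottleneck([0.5, 2.0, 3.0])
print(len(relu_stats.durations), "relu calls recorded")
```

With a scalar field, only `last_duration` survives after the forward pass; with the list, each of the three calls has its own entry and the total is still available through `sum(relu_stats.durations)`.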