Skip to content

Commit a692a7a

Browse files
committed
Clean it up like Ryan wants it
1 parent 14ae470 commit a692a7a

File tree

1 file changed

+17
-5
lines changed

1 file changed

+17
-5
lines changed

src/models/debugging.cpp

Lines changed: 17 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -88,11 +88,23 @@ void DumpTensor(const Model& model, std::ostream& stream, OrtValue* value, bool
8888
if (dump_value) {
8989
DumpValues(stream, type_info->GetElementType(), value->GetTensorRawData(), element_count);
9090
}
91-
} else if (memory_info.GetDeviceType() == OrtMemoryInfoDeviceType_GPU || memory_info.GetDeviceType() == 4) {
92-
if (memory_info.GetDeviceType() == OrtMemoryInfoDeviceType_GPU)
93-
stream << "GPU\r\n";
94-
else
95-
stream << "DML\r\n";
91+
// Internally there are 5 device types defined in onnxruntime but only 3 are exposed in the public API
92+
// https://github.com/microsoft/onnxruntime/blob/9dbfee91ca9c2ba2074d19805bb6dedccedbcfe3/include/onnxruntime/core/framework/ortdevice.h#L15
93+
} else if (memory_info.GetDeviceType() < 5) {
94+
switch (model.p_device_->GetType()) {
95+
case DeviceType::CUDA:
96+
stream << "CUDA\r\n";
97+
break;
98+
case DeviceType::DML:
99+
stream << "DML\r\n";
100+
break;
101+
case DeviceType::QNN:
102+
stream << "QNN\r\n";
103+
break;
104+
default:
105+
stream << "Unknown\r\n";
106+
break;
107+
}
96108
auto type = type_info->GetElementType();
97109
auto tensor_span = std::span<uint8_t>{const_cast<OrtValue*>(value)->GetTensorMutableData<uint8_t>(), SizeOf(type) * element_count};
98110
auto device_span = model.p_device_->WrapMemory<uint8_t>(tensor_span);

0 commit comments

Comments (0)