Skip to content

Commit 14ae470

Browse files
committed
re-enable debugging for dml
1 parent 5f4fa86 commit 14ae470

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

src/models/debugging.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -83,22 +83,22 @@ void DumpTensor(const Model& model, std::ostream& stream, OrtValue* value, bool
8383
stream << SGR::Fg_Green << " Location: " << SGR::Reset;
8484

8585
const auto& memory_info = value->GetTensorMemoryInfo();
86-
switch (memory_info.GetDeviceType()) {
87-
case OrtMemoryInfoDeviceType_CPU:
88-
stream << "CPU\r\n";
86+
if (memory_info.GetDeviceType() == OrtMemoryInfoDeviceType_CPU) {
87+
stream << "CPU\r\n";
88+
if (dump_value) {
8989
DumpValues(stream, type_info->GetElementType(), value->GetTensorRawData(), element_count);
90-
break;
91-
case OrtMemoryInfoDeviceType_GPU: {
92-
stream << "GPU\r\n";
93-
auto type = type_info->GetElementType();
94-
auto tensor_span = std::span<uint8_t>{const_cast<OrtValue*>(value)->GetTensorMutableData<uint8_t>(), SizeOf(type) * element_count};
95-
auto device_span = model.p_device_->WrapMemory<uint8_t>(tensor_span);
96-
DumpValues(stream, type, device_span.CopyDeviceToCpu().data(), element_count);
97-
break;
9890
}
99-
default:
100-
stream << "Unhandled device type: " << static_cast<int>(memory_info.GetDeviceType()) << "\r\n";
101-
break;
91+
} else if (memory_info.GetDeviceType() == OrtMemoryInfoDeviceType_GPU || memory_info.GetDeviceType() == 4) {
92+
if (memory_info.GetDeviceType() == OrtMemoryInfoDeviceType_GPU)
93+
stream << "GPU\r\n";
94+
else
95+
stream << "DML\r\n";
96+
auto type = type_info->GetElementType();
97+
auto tensor_span = std::span<uint8_t>{const_cast<OrtValue*>(value)->GetTensorMutableData<uint8_t>(), SizeOf(type) * element_count};
98+
auto device_span = model.p_device_->WrapMemory<uint8_t>(tensor_span);
99+
DumpValues(stream, type, device_span.CopyDeviceToCpu().data(), element_count);
100+
} else {
101+
stream << "Unhandled device type: " << static_cast<int>(memory_info.GetDeviceType()) << "\r\n";
102102
}
103103
}
104104

0 commit comments

Comments
 (0)