@@ -83,22 +83,22 @@ void DumpTensor(const Model& model, std::ostream& stream, OrtValue* value, bool
83
83
stream << SGR::Fg_Green << " Location: " << SGR::Reset;
84
84
85
85
const auto & memory_info = value->GetTensorMemoryInfo ();
86
- switch (memory_info.GetDeviceType ()) {
87
- case OrtMemoryInfoDeviceType_CPU:
88
- stream << " CPU \r\n " ;
86
+ if (memory_info.GetDeviceType () == OrtMemoryInfoDeviceType_CPU ) {
87
+ stream << " CPU \r\n " ;
88
+ if (dump_value) {
89
89
DumpValues (stream, type_info->GetElementType (), value->GetTensorRawData (), element_count);
90
- break ;
91
- case OrtMemoryInfoDeviceType_GPU: {
92
- stream << " GPU\r\n " ;
93
- auto type = type_info->GetElementType ();
94
- auto tensor_span = std::span<uint8_t >{const_cast <OrtValue*>(value)->GetTensorMutableData <uint8_t >(), SizeOf (type) * element_count};
95
- auto device_span = model.p_device_ ->WrapMemory <uint8_t >(tensor_span);
96
- DumpValues (stream, type, device_span.CopyDeviceToCpu ().data (), element_count);
97
- break ;
98
90
}
99
- default :
100
- stream << " Unhandled device type: " << static_cast <int >(memory_info.GetDeviceType ()) << " \r\n " ;
101
- break ;
91
+ } else if (memory_info.GetDeviceType () == OrtMemoryInfoDeviceType_GPU || memory_info.GetDeviceType () == 4 ) {
92
+ if (memory_info.GetDeviceType () == OrtMemoryInfoDeviceType_GPU)
93
+ stream << " GPU\r\n " ;
94
+ else
95
+ stream << " DML\r\n " ;
96
+ auto type = type_info->GetElementType ();
97
+ auto tensor_span = std::span<uint8_t >{const_cast <OrtValue*>(value)->GetTensorMutableData <uint8_t >(), SizeOf (type) * element_count};
98
+ auto device_span = model.p_device_ ->WrapMemory <uint8_t >(tensor_span);
99
+ DumpValues (stream, type, device_span.CopyDeviceToCpu ().data (), element_count);
100
+ } else {
101
+ stream << " Unhandled device type: " << static_cast <int >(memory_info.GetDeviceType ()) << " \r\n " ;
102
102
}
103
103
}
104
104
0 commit comments