@@ -68,8 +68,8 @@ impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
68
68
fn batch (& self , items : Vec <MnistItem >) -> MnistBatch <B > {
69
69
let images = items
70
70
. iter ()
71
- . map (| item | TensorData :: from (item . image))
72
- . map (| data | Tensor :: <B , 2 >:: from_data (data . convert () , & self . device))
71
+ . map (| item | TensorData :: from (item . image). convert :: < B :: FloatElem >() )
72
+ . map (| data | Tensor :: <B , 2 >:: from_data (data , & self . device))
73
73
. map (| tensor | tensor . reshape ([1 , 28 , 28 ]))
74
74
// Normalize: make between [0,1] and make the mean=0 and std=1
75
75
// values mean=0.1307,std=0.3081 are from the PyTorch MNIST example
@@ -119,8 +119,8 @@ images.
119
119
``` rust, ignore
120
120
let images = items // take items Vec<MnistItem>
121
121
.iter() // create an iterator over it
122
- .map(|item| TensorData::from(item.image)) // for each item, convert the image to float32 data struct
123
- .map(|data| Tensor::<B, 2>::from_data(data.convert() , &self.device)) // for each data struct, create a tensor on the device
122
+ .map(|item| TensorData::from(item.image).convert::<B::FloatElem>()) // for each item, convert the image to float data struct
123
+ .map(|data| Tensor::<B, 2>::from_data(data, &self.device)) // for each data struct, create a tensor on the device
124
124
.map(|tensor| tensor.reshape([1, 28, 28])) // for each tensor, reshape to the image dimensions [C, H, W]
125
125
.map(|tensor| ((tensor / 255) - 0.1307) / 0.3081) // for each image tensor, apply normalization
126
126
.collect(); // consume the resulting iterator & collect the values into a new vector
@@ -138,5 +138,6 @@ a targets tensor that contains the indexes of the correct digit class. The first
138
138
the image array into a ` TensorData ` struct. Burn provides the ` TensorData ` struct to encapsulate
139
139
tensor storage information without being specific for a backend. When creating a tensor from data,
140
140
we often need to convert the data precision to the current backend in use. This can be done with the
141
- ` .convert() ` method. While importing the ` burn::tensor::ElementConversion ` trait, you can call
142
- ` .elem() ` on a specific number to convert it to the current backend element type in use.
141
+ ` .convert() ` method (in this example, the data is converted backend's float element type
142
+ ` B::FloatElem ` ). While importing the ` burn::tensor::ElementConversion ` trait, you can call ` .elem() `
143
+ on a specific number to convert it to the current backend element type in use.
0 commit comments