Skip to content

Commit 2cdf192

Browse files
authored
Merge pull request #52 from hatappi/feature/average_pooling_2d
Support average pooling 2d
2 parents e79f4d9 + f45d768 commit 2cdf192

File tree

3 files changed

+167
-0
lines changed

3 files changed

+167
-0
lines changed

lib/chainer.rb

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
require 'chainer/functions/noise/dropout'
4444
require 'chainer/functions/normalization/batch_normalization'
4545
require 'chainer/functions/pooling/pooling_2d'
46+
require 'chainer/functions/pooling/average_pooling_2d'
4647
require 'chainer/functions/pooling/max_pooling_2d'
4748
require 'chainer/testing/array'
4849
require 'chainer/training/extension'
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
module Chainer
2+
module Functions
3+
module Pooling
4+
class AveragePooling2D < Pooling2D
5+
# Spatial average pooling function.
6+
#
7+
# This function acts similarly to :class:`Convolution2D`,
8+
# but it computes the average of input spatial patch for each channel
9+
# without any parameter instead of computing the inner products.
10+
# @param [Chainer::Variable] x Input variable.
11+
# @param [integer] ksize Size of pooling window. `ksize=k` and `ksize=[k, k]` are equivalent.
12+
# @param [integer] stride Stride of pooling applications. `stride=s` and `stride=[s, s]` are equivalent.
13+
# If `nil` is specified, then it uses same stride as the pooling window size.
14+
# @param [integer] pad Spatial padding width for the input array. `pad=p` and `pad=[p, p]` are equivalent.
15+
# @return [Chainer::Variable] Output variable
16+
def self.average_pooling_2d(x, ksize, stride: nil, pad: 0)
17+
self.new(ksize, stride: stride, pad: pad, cover_all: false).(x)
18+
end
19+
20+
# Average pooling over a set of 2d planes.
21+
def forward_cpu(x)
22+
retain_inputs([])
23+
@in_shape = x[0].shape
24+
@in_dtype = x[0].class
25+
26+
col = Chainer::Utils::Conv.im2col_cpu(x[0], @kh, @kw, @sy, @sx, @ph, @pw)
27+
y = col.mean(axis: [2, 3])
28+
29+
[y]
30+
end
31+
32+
def backward_cpu(x, gy)
33+
h, w = @in_shape[2..-1]
34+
shape = gy[0].shape
35+
shape.insert(2, 1, 1)
36+
gcol = gy[0].reshape(*shape).tile(1, 1, @kh, @kw, 1, 1)
37+
38+
gx = Chainer::Utils::Conv.col2im_cpu(gcol, @sy, @sx, @ph, @pw, h, w)
39+
gx /= @kh * @kw
40+
[gx]
41+
end
42+
end
43+
end
44+
end
45+
end
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
class Chainer::Functions::Pooling::AveragePooling2DTest < Test::Unit::TestCase
2+
data(
3+
test1: {
4+
case: {
5+
x: Numo::SFloat.new(1, 3, 4, 6).seq,
6+
ksize: 2,
7+
options: {}
8+
},
9+
expected: Numo::SFloat[[[[ 3.5, 5.5, 7.5],
10+
[15.5, 17.5, 19.5]],
11+
[[27.5, 29.5, 31.5],
12+
[39.5, 41.5, 43.5]],
13+
[[51.5, 53.5, 55.5],
14+
[63.5, 65.5, 67.5]]]]
15+
},
16+
test2: {
17+
case: {
18+
x: Numo::SFloat.new(1, 3, 4, 4).seq,
19+
ksize: 2,
20+
options: { stride: 2 }
21+
},
22+
expected: Numo::SFloat[[[[ 2.5, 4.5],
23+
[10.5, 12.5]],
24+
[[18.5, 20.5],
25+
[26.5, 28.5]],
26+
[[34.5, 36.5],
27+
[42.5, 44.5]]]]
28+
},
29+
test3: {
30+
case: {
31+
x: Numo::SFloat.new(1, 3, 4, 4).seq,
32+
ksize: 4,
33+
options: { stride: 2, pad: 1 }
34+
},
35+
expected: Numo::SFloat[[[[ 2.8125, 3.375 ],
36+
[ 5.0625, 5.625 ]],
37+
[[11.8125, 12.375 ],
38+
[14.0625, 14.625 ]],
39+
[[20.8125, 21.375 ],
40+
[23.0625, 23.625 ]]]]
41+
},
42+
)
43+
def test_average_pooling_2d(data)
44+
test_case = data[:case]
45+
actual = Chainer::Functions::Pooling::AveragePooling2D.average_pooling_2d(test_case[:x], test_case[:ksize], **test_case[:options])
46+
assert_equal(data[:expected], actual.data)
47+
end
48+
49+
data({
50+
test1: {
51+
case: {
52+
x: Numo::SFloat.new(2, 3, 2, 2).seq,
53+
gy: [Numo::SFloat.new(2, 3, 1, 1).seq],
54+
ksize: 2,
55+
stride: 2,
56+
pad: 0,
57+
cover_all: false
58+
},
59+
expected: Numo::SFloat[[[[0.0 , 0.0 ],
60+
[0.0 , 0.0 ]],
61+
[[0.25, 0.25],
62+
[0.25, 0.25]],
63+
[[0.5 , 0.5 ],
64+
[0.5 , 0.5 ]]],
65+
[[[0.75, 0.75],
66+
[0.75, 0.75]],
67+
[[1.0 , 1.0 ],
68+
[1.0 , 1.0 ]],
69+
[[1.25, 1.25],
70+
[1.25, 1.25]]]]
71+
},
72+
test2: {
73+
case: {
74+
x: Numo::SFloat.new(2, 2, 4, 4).seq,
75+
gy: [Numo::SFloat.new(2, 3 ,1, 1).seq],
76+
ksize: 6,
77+
stride: 8,
78+
pad: 1,
79+
cover_all: false
80+
},
81+
expected: Numo::SFloat[[[[0.0 , 0.0 , 0.0 , 0.0 ],
82+
[0.0 , 0.0 , 0.0 , 0.0 ],
83+
[0.0 , 0.0 , 0.0 , 0.0 ],
84+
[0.0 , 0.0 , 0.0 , 0.0 ]],
85+
[[0.0277778, 0.0277778, 0.0277778, 0.0277778],
86+
[0.0277778, 0.0277778, 0.0277778, 0.0277778],
87+
[0.0277778, 0.0277778, 0.0277778, 0.0277778],
88+
[0.0277778, 0.0277778, 0.0277778, 0.0277778]],
89+
[[0.0555556, 0.0555556, 0.0555556, 0.0555556],
90+
[0.0555556, 0.0555556, 0.0555556, 0.0555556],
91+
[0.0555556, 0.0555556, 0.0555556, 0.0555556],
92+
[0.0555556, 0.0555556, 0.0555556, 0.0555556]]],
93+
[[[0.0833333, 0.0833333, 0.0833333, 0.0833333],
94+
[0.0833333, 0.0833333, 0.0833333, 0.0833333],
95+
[0.0833333, 0.0833333, 0.0833333, 0.0833333],
96+
[0.0833333, 0.0833333, 0.0833333, 0.0833333]],
97+
[[0.1111111, 0.1111111, 0.1111111, 0.1111111],
98+
[0.1111111, 0.1111111, 0.1111111, 0.1111111],
99+
[0.1111111, 0.1111111, 0.1111111, 0.1111111],
100+
[0.1111111, 0.1111111, 0.1111111, 0.1111111]],
101+
[[0.1388889, 0.1388889, 0.1388889, 0.1388889 ],
102+
[0.1388889, 0.1388889, 0.1388889, 0.1388889 ],
103+
[0.1388889, 0.1388889, 0.1388889, 0.1388889 ],
104+
[0.1388889, 0.1388889, 0.1388889, 0.1388889 ]]]]
105+
}
106+
})
107+
def test_backward(data)
108+
c = data[:case]
109+
pooling = Chainer::Functions::Pooling::AveragePooling2D.new(c[:ksize], stride: c[:stride], pad: c[:pad], cover_all: c[:cover_all])
110+
pooling.(c[:x])
111+
gy = pooling.backward_cpu(c[:x], c[:gy])
112+
d = 7
113+
assert_equal(round(data[:expected], 7), round(gy[0], 7))
114+
end
115+
116+
def round(x, decimals)
117+
return nil if x.nil?
118+
t = 10 ** decimals
119+
(x * t).round / t
120+
end
121+
end

0 commit comments

Comments
 (0)