Skip to content

Commit

Permalink
Merge pull request #53 from hatappi/feature/resnet-18
Browse files Browse the repository at this point in the history
Add ResNet-18
  • Loading branch information
hatappi authored May 29, 2018
2 parents 2cdf192 + b331a37 commit 6b39f64
Show file tree
Hide file tree
Showing 6 changed files with 249 additions and 22 deletions.
90 changes: 90 additions & 0 deletions examples/cifar/models/resnet18.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
module ResNet18
class Plain < Chainer::Chain
include Chainer::Functions::Activation
include Chainer::Initializers
include Chainer::Links::Connection
include Chainer::Links::Normalization

def initialize(ch, stride, use_conv: false)
super()

@use_conv = use_conv
w = HeNormal.new

init_scope do
@conv1 = Convolution2D.new(nil, ch, 3, stride: stride, pad: 1, nobias: true, initial_w: w)
@bn1 = BatchNormalization.new(ch)
@conv2 = Convolution2D.new(nil, ch, 3, stride: 1, pad: 1, nobias: true, initial_w: w)
@bn2 = BatchNormalization.new(ch)
if @use_conv
@conv3 = Convolution2D.new(nil, ch, 3, stride: stride, pad: 1, nobias: true, initial_w: w)
@bn3 = BatchNormalization.new(ch)
end
end
end

def call(x)
h = Relu.relu(@bn1.(@conv1.(x)))
h = @bn2.(@conv2.(h))
if @use_conv
h2 = @bn3.(@conv3.(x))
Relu.relu(h + h2)
else
Relu.relu(h + x)
end
end
end

class Block < Chainer::ChainList
def initialize(layer, ch, stride=2)
super()
add_link(Plain.new(ch, stride, use_conv: true))
(layer-1).times do
add_link(Plain.new(ch, 1))
end
end

def call(x)
@children.each do |f|
x = f.(x)
end
x
end
end

class Model < Chainer::Chain
include Chainer::Functions::Activation
include Chainer::Functions::Evaluation
include Chainer::Functions::Loss
include Chainer::Functions::Pooling
include Chainer::Initializers
include Chainer::Links::Connection
include Chainer::Links::Normalization

def initialize(n_classes: 10)
super()
initial_w = HeNormal.new

init_scope do
@conv = Convolution2D.new(3, 64, 7, stride: 2, pad: 3, initial_w: initial_w)
@bn = BatchNormalization.new(64)

@res2 = Block.new(2, 64, 1)
@res3 = Block.new(2, 128)
@res4 = Block.new(2, 256)
@res5 = Block.new(2, 512)
@fc = Linear.new(nil, out_size: n_classes)
end
end

def call(x)
h = Relu.relu(@bn.(@conv.(x)))
h = @res2.(h)
h = @res3.(h)
h = @res4.(h)
h = @res5.(h)
h = AveragePooling2D.average_pooling_2d(h, h.shape[2..-1])
@fc.(h)
end
end
end
4 changes: 2 additions & 2 deletions examples/cifar/models/vgg.rb
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def call(x)
end

class VGG < Chainer::Chain
def initialize(class_labels: 10)
def initialize(n_classes: 10)
super()
init_scope do
@block1_1 = Block.new(64, 3)
Expand All @@ -33,7 +33,7 @@ def initialize(class_labels: 10)
@block5_3 = Block.new(512, 3)
@fc1 = Chainer::Links::Connection::Linear.new(nil, out_size: 512, nobias: true)
@bn_fc1 = Chainer::Links::Normalization::BatchNormalization.new(512)
@fc2 = Chainer::Links::Connection::Linear.new(nil, out_size: class_labels, nobias: true)
@fc2 = Chainer::Links::Connection::Linear.new(nil, out_size: n_classes, nobias: true)
end
end

Expand Down
17 changes: 13 additions & 4 deletions examples/cifar/train_cifar.rb
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
require 'chainer'
require __dir__ + '/models/vgg'
require __dir__ + '/models/resnet18'
require 'optparse'

args = {
Expand All @@ -9,7 +10,8 @@
learnrate: 0.05,
epoch: 300,
out: 'result',
resume: nil
resume: nil,
model: 'vgg',
}


Expand All @@ -21,6 +23,7 @@
opt.on('-e', '--epoch VALUE', "Number of sweeps over the dataset to train (default: #{args[:epoch]})") { |v| args[:epoch] = v.to_i }
opt.on('-o', '--out VALUE', "Directory to output the result (default: #{args[:out]})") { |v| args[:out] = v }
opt.on('-r', '--resume VALUE', "Resume the training from snapshot") { |v| args[:resume] = v }
opt.on('-m', '--model VALUE', "Use model") { |v| args[:model] = v }
opt.parse!(ARGV)

# Set up a neural network to train.
Expand All @@ -38,9 +41,15 @@
raise 'Invalid dataset choice.'
end

puts "setup..."
if args[:model] == 'vgg'
puts 'Using VGG model'
model_class = VGG
elsif args[:model] == 'resnet18'
puts 'Using ResNet-18 model'
model_class = ResNet18::Model
end

model = Chainer::Links::Model::Classifier.new(VGG.new(class_labels: class_labels))
model = Chainer::Links::Model::Classifier.new(model_class.new(n_classes: class_labels))

optimizer = Chainer::Optimizers::MomentumSGD.new(lr: args[:learnrate])
optimizer.setup(model)
Expand All @@ -58,7 +67,7 @@
frequency = args[:frequency] == -1 ? args[:epoch] : [1, args[:frequency]].max
trainer.extend(Chainer::Training::Extensions::Snapshot.new, trigger: [frequency, 'epoch'])

trainer.extend(Chainer::Training::Extensions::LogReport.new)
trainer.extend(Chainer::Training::Extensions::LogReport.new)
trainer.extend(Chainer::Training::Extensions::PrintReport.new(['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))
trainer.extend(Chainer::Training::Extensions::ProgressBar.new)

Expand Down
2 changes: 1 addition & 1 deletion lib/chainer/datasets/tuple_dataset.rb
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def [](index)
end
if index.kind_of?(Enumerable)
length = batches[0].shape[0]
length.times.map {|i| batches.map { |m| m[i] } }
length.times.map {|i| batches.map { |m| m.ndim > 1 ? m[i, false] : m[i] } }
else
batches
end
Expand Down
4 changes: 2 additions & 2 deletions lib/chainer/initializers/init.rb
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ def self.get_initializer(initializer)
return HeNormal.new(scale: 1 / Numo::NMath.sqrt(2)) if initializer.nil?
return Constant.new(initializer) if initializer.kind_of?(Numeric)
return Constant.new(initializer) if initializer.kind_of?(Numo::NArray)
unless initializer.method_defined?(:call)

unless initializer.respond_to?(:call)
raise TypeError, "invalid type of initializer: #{initializer.class}"
end

Expand Down
154 changes: 141 additions & 13 deletions lib/chainer/link.rb
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
module Chainer
class Link
attr_accessor :name

def initialize
@params = []
@persistent = []
Expand All @@ -17,22 +19,28 @@ def init_scope

begin
yield
set_attr
self.instance_variables.each do |name|
set_attr(name, self.instance_variable_get(name))
end
ensure
@within_init_scope = old_flag
end
end

def set_attr
self.instance_variables.each do |name|
value = self.instance_variable_get(name)
if value.instance_of?(Chainer::Parameter)
@params << name
@persistent.delete(name)
end
def set_attr(name, value)
if within_init_scope && value.kind_of?(Chainer::Parameter)
value.name = name
@params << name
@persistent.delete(name)
end
end

def del_attr(name)
@params.delete(name)
@persistent.delete(name)
self.remove_instance_variable(name)
end

def cleargrads
params do |param|
param.cleargrad
Expand Down Expand Up @@ -99,16 +107,22 @@ def initialize
@children = []
end

def set_attr
self.instance_variables.each do |name|
value = self.instance_variable_get(name)
if value.kind_of?(Chainer::Link)
@children << name
def set_attr(name, value)
if within_init_scope && value.kind_of?(Chainer::Link)
if self.respond_to?(name)
raise TypeError, "cannot register a new link #{name}: attribute exists"
end
value.name = name
@children << name
end
super
end

def del_attr(name)
@children.delete(name)
super
end

def params(include_uninit: true)
super(include_uninit: include_uninit) do |param|
yield param
Expand Down Expand Up @@ -155,4 +169,118 @@ def serialize(serializer)
end
end
end


# Composable link with list-like interface.
#
# This is another example of compositional link. Unlike :class:`Chainer::Chain`,
# this class can be used like a list of child links.
# Each child link is indexed by a non-negative integer,
# and it maintains the current number of registered child links.
# The :meth:`add_link` method inserts a new link at the end of the list.
# It is useful to write a chain with arbitrary number of child links,
# e.g. an arbitrarily deep multi-layer perceptron.
class ChainList < Link
attr_reader :children

def initialize(*links)
super()
@children = []

links.each do |link|
add_link(link)
end
end

def set_attr(name, value)
if within_init_scope && value.kind_of?(Chainer::Link)
raise TypeError, 'cannot register a new link within a "with chainlist.init_scope:" block.'
end
super
end

def [](index)
@children[index]
end

def each(&block)
@children.each(&block)
end

def size
@children.size
end

def <<(link)
add_link(link)
end

def add_link(link)
link.name = @children.size.to_s
@children << link
end

def params(include_uninit: true)
super(include_uninit: include_uninit) do |param|
yield param
end

@children.each do |link|
link.params(include_uninit: include_uninit) do |param|
yield param
end
end
end

def namedparams(include_uninit: true)
super(include_uninit: include_uninit) do |ret|
yield ret
end
@children.each_with_index do |link, idx|
prefix = "/#{idx}"
link.namedparams(include_uninit: include_uninit) do |path, param|
yield [prefix + path, param]
end
end
end

def links(skipself: false)
unless skipself
yield self
end

@children.each do |child|
child.links do |link|
yield link
end
end
end

def namedlinks(skipself: false)
unless skipself
yield '/', self
end

@children.each_with_index do |child, idx|
prefix = "/#{idx}"
yield prefix, child
child.namedlinks(skipself: true) do |path, link|
yield [prefix + path, link]
end
end
end

def children
@children.each do |child|
yield child
end
end

def serialize(serializer)
super
@children.each_with_index do |child, idx|
child.serialize(serializer[idx.to_s])
end
end
end
end

0 comments on commit 6b39f64

Please sign in to comment.