diff --git a/README.md b/README.md index 545003cc..b24cef75 100644 --- a/README.md +++ b/README.md @@ -244,7 +244,7 @@ Throttle state is stored in a [configurable cache](#cache-store-configuration) ( #### `throttle(name, options, &block)` -Name your custom throttle, provide `limit` and `period` as options, and make your ruby-block argument return the __discriminator__. This discriminator is how you tell rack-attack whether you're limiting per IP address, per user email or any other. +Name your custom throttle, provide `limit`, `period` and `weight` as options, and make your ruby-block argument return the __discriminator__. This discriminator is how you tell rack-attack whether you're limiting per IP address, per user email or any other. The request object is a [Rack::Request](http://www.rubydoc.info/gems/rack/Rack/Request). @@ -275,6 +275,14 @@ period_proc = proc { |req| req.env["REMOTE_USER"] == "admin" ? 1 : 60 } Rack::Attack.throttle('request per ip', limit: limit_proc, period: period_proc) do |request| request.ip end + +# Weight can be used to make some requests cost more than others while +# sharing the same limit. +weight_proc = proc { |req| req.path.start_with?('/search') ? 10 : 1 } + +Rack::Attack.throttle('request per ip', limit: 10, period: 1, weight: weight_proc) do |request| + request.ip +end ``` ### Tracks diff --git a/lib/rack/attack/cache.rb b/lib/rack/attack/cache.rb index ecbd3368..b9f4e1d9 100644 --- a/lib/rack/attack/cache.rb +++ b/lib/rack/attack/cache.rb @@ -28,9 +28,9 @@ def store=(store) end end - def count(unprefixed_key, period) + def count(unprefixed_key, period, weight = 1) key, expires_in = key_and_expiry(unprefixed_key, period) - do_count(key, expires_in) + do_count(key, expires_in, weight) end def read(unprefixed_key) @@ -73,19 +73,19 @@ def key_and_expiry(unprefixed_key, period) ["#{prefix}:#{(@last_epoch_time / period).to_i}:#{unprefixed_key}", expires_in] end - def do_count(key, expires_in) + def do_count(key, expires_in, weight) enforce_store_presence! enforce_store_method_presence!(:increment) - result = store.increment(key, 1, expires_in: expires_in) + result = store.increment(key, weight, expires_in: expires_in) # NB: Some stores return nil when incrementing uninitialized values if result.nil? enforce_store_method_presence!(:write) - store.write(key, 1, expires_in: expires_in) + store.write(key, weight, expires_in: expires_in) end - result || 1 + result || weight end def enforce_store_presence! diff --git a/lib/rack/attack/throttle.rb b/lib/rack/attack/throttle.rb index 0ec5f7aa..878dba2d 100644 --- a/lib/rack/attack/throttle.rb +++ b/lib/rack/attack/throttle.rb @@ -5,7 +5,7 @@ class Attack class Throttle MANDATORY_OPTIONS = [:limit, :period].freeze - attr_reader :name, :limit, :period, :block, :type + attr_reader :name, :limit, :period, :weight, :block, :type def initialize(name, options, &block) @name = name @@ -15,6 +15,7 @@ def initialize(name, options, &block) end @limit = options[:limit] @period = options[:period].respond_to?(:call) ? options[:period] : options[:period].to_i + @weight = options[:weight].respond_to?(:call) ? options[:weight] : (options[:weight] || 1).to_i @type = options.fetch(:type, :throttle) end @@ -28,7 +29,8 @@ def matched_by?(request) current_period = period_for(request) current_limit = limit_for(request) - count = cache.count("#{name}:#{discriminator}", current_period) + current_weight = weight_for(request) + count = cache.count("#{name}:#{discriminator}", current_period, current_weight) data = { discriminator: discriminator, @@ -66,6 +68,10 @@ def limit_for(request) limit.respond_to?(:call) ? limit.call(request) : limit end + def weight_for(request) + weight.respond_to?(:call) ? weight.call(request) : weight + end + def annotate_request_with_throttle_data(request, data) (request.env['rack.attack.throttle_data'] ||= {})[name] = data end diff --git a/spec/acceptance/throttling_spec.rb b/spec/acceptance/throttling_spec.rb index 0db89dd6..09dd39d1 100644 --- a/spec/acceptance/throttling_spec.rb +++ b/spec/acceptance/throttling_spec.rb @@ -34,6 +34,79 @@ end end + it "supports a non-1 constant weight" do + Rack::Attack.throttle("by ip", limit: 4, period: 60, weight: 2) do |request| + request.ip + end + + get "/", {}, "REMOTE_ADDR" => "1.2.3.4" + + assert_equal 200, last_response.status + + get "/", {}, "REMOTE_ADDR" => "1.2.3.4" + + assert_equal 200, last_response.status + + get "/", {}, "REMOTE_ADDR" => "1.2.3.4" + + assert_equal 429, last_response.status + assert_nil last_response.headers["Retry-After"] + assert_equal "Retry later\n", last_response.body + + get "/", {}, "REMOTE_ADDR" => "5.6.7.8" + + assert_equal 200, last_response.status + + Timecop.travel(60) do + get "/", {}, "REMOTE_ADDR" => "1.2.3.4" + + assert_equal 200, last_response.status + end + end + + it "supports a dynamic weight" do + weight_proc = lambda do |request| + if request.env["X-APIKey"] == "private-secret" + 3 + else + 2 + end + end + Rack::Attack.throttle("by ip", limit: 4, period: 60, weight: weight_proc) do |request| + request.ip + end + + get "/", {}, "REMOTE_ADDR" => "1.2.3.4" + + assert_equal 200, last_response.status + + get "/", {}, "REMOTE_ADDR" => "1.2.3.4" + + assert_equal 200, last_response.status + + get "/", {}, "REMOTE_ADDR" => "1.2.3.4" + + assert_equal 429, last_response.status + assert_nil last_response.headers["Retry-After"] + assert_equal "Retry later\n", last_response.body + + get "/", {}, "REMOTE_ADDR" => "5.6.7.8", "X-APIKey" => "private-secret" + + assert_equal 200, last_response.status + + get "/", {}, "REMOTE_ADDR" => "5.6.7.8", "X-APIKey" => "private-secret" + + assert_equal 429, last_response.status + assert_nil last_response.headers["Retry-After"] + assert_equal "Retry later\n", last_response.body + + Timecop.travel(60) do + get "/", {}, "REMOTE_ADDR" => "1.2.3.4" + + assert_equal 200, last_response.status + end + end + it "returns correct Retry-After header if enabled" do Rack::Attack.throttled_response_retry_after_header = true