Skip to content

Optimize IAST Vulnerability Detection #8885

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Open
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,6 @@ public void acquireReleaseRequestNoSampling() {

@Benchmark
public void consumeQuota() {
overheadController.consumeQuota(Operations.REPORT_VULNERABILITY, null);
overheadController.consumeQuota(Operations.REPORT_VULNERABILITY, null, null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public IastContext resolve() {

@Override
public IastContext buildRequestContext() {
return new IastRequestContext(globalContext.getTaintedObjects());
return new IastRequestContext((TaintedObjects) globalContext.getTaintedObjects());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this need a cast? getTaintedObjects() should return TaintedObjects or a subclass of it already?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added this new constructor

  /**
   * Use this constructor only when you want to create a new context with a fresh overhead context
   * (e.g. for testing purposes).
   *
   * @param overheadContext the overhead context to use
   */
  public IastRequestContext(final OverheadContext overheadContext) {
    this.vulnerabilityBatch = new VulnerabilityBatch();
    this.overheadContext = overheadContext;
    this.taintedObjects = TaintedObjects.build(TaintedMap.build(MAP_SIZE));
  }

in IastContext we have

  /**
   * Get the tainted objects dictionary linked to the context, since we have no visibility over the
   * {@code TaintedObject} class from here, we use a dirty generics hack.
   */
  @Nonnull
  <TO> TO getTaintedObjects();

with public IastRequestContext(final TaintedObjects taintedObjects) and public IastRequestContext(final OverheadContext overheadContext) we need to cast to specify which constructor should be use

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rollback those changes thanks to a different approach

}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public IastContext resolve() {

@Override
public IastContext buildRequestContext() {
return new IastRequestContext(optOutContext.getTaintedObjects());
return new IastRequestContext((TaintedObjects) optOutContext.getTaintedObjects());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,18 @@ public IastRequestContext(final TaintedObjects taintedObjects) {
this.taintedObjects = taintedObjects;
}

/**
* Use this constructor only when you want to create a new context with a fresh overhead context
* (e.g. for testing purposes).
*
* @param overheadContext the overhead context to use
*/
public IastRequestContext(final OverheadContext overheadContext) {
this.vulnerabilityBatch = new VulnerabilityBatch();
this.overheadContext = overheadContext;
this.taintedObjects = TaintedObjects.build(TaintedMap.build(MAP_SIZE));
}

public VulnerabilityBatch getVulnerabilityBatch() {
return vulnerabilityBatch;
}
Expand Down Expand Up @@ -188,6 +200,7 @@ public void releaseRequestContext(@Nonnull final IastContext context) {
pool.offer(unwrapped);
iastCtx.setTaintedObjects(TaintedObjects.NoOp.INSTANCE);
}
iastCtx.overheadContext.resetMaps();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,53 @@

import static datadog.trace.api.iast.IastDetectionMode.UNLIMITED;

import com.datadog.iast.model.VulnerabilityType;
import com.datadog.iast.util.NonBlockingSemaphore;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import javax.annotation.Nullable;

public class OverheadContext {

/**
* Maximum number of distinct endpoints to remember in the global cache (LRU eviction beyond this
* size).
*/
private static final int GLOBAL_MAP_MAX_SIZE = 4096;

/**
* Global LRU cache mapping each “method + path” key to its historical vulnerabilityCounts map.
* Key: HTTP_METHOD + " " + HTTP_PATH Value: Map<vulnerabilityType, count>
*/
static final Map<String, Map<VulnerabilityType, Integer>> globalMap =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally we shouldn't be using Map<VulnerabilityType, Integer> for the inner map, but fixed-size arrays indexed by VulnerabilityType index. Way more efficient. Also if we use AtomicIntegerArray for the inner array, we can perform atomic increments of the counters, which will be needed here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I’ve updated all the collections as we agreed offline

new LinkedHashMap<String, Map<VulnerabilityType, Integer>>(GLOBAL_MAP_MAX_SIZE, 0.75f, true) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ConcurrentHashMap if this is going to be accessed concurrently.

@Override
protected boolean removeEldestEntry(
Map.Entry<String, Map<VulnerabilityType, Integer>> eldest) {
return size() > GLOBAL_MAP_MAX_SIZE;
}
};

@Nullable final Map<String, Map<VulnerabilityType, Integer>> copyMap;
@Nullable final Map<String, Map<VulnerabilityType, Integer>> requestMap;

private final NonBlockingSemaphore availableVulnerabilities;
private final boolean isGlobal;

public OverheadContext(final int vulnerabilitiesPerRequest) {
this(vulnerabilitiesPerRequest, false);
}

public OverheadContext(final int vulnerabilitiesPerRequest, final boolean isGlobal) {
availableVulnerabilities =
vulnerabilitiesPerRequest == UNLIMITED
? NonBlockingSemaphore.unlimited()
: NonBlockingSemaphore.withPermitCount(vulnerabilitiesPerRequest);
this.isGlobal = isGlobal;
this.requestMap = isGlobal ? null : new HashMap<>();
this.copyMap = isGlobal ? null : new HashMap<>();
}

public int getAvailableQuota() {
Expand All @@ -26,4 +62,49 @@ public boolean consumeQuota(final int delta) {
public void reset() {
availableVulnerabilities.reset();
}

public void resetMaps() {
if (isGlobal || requestMap == null || copyMap == null) {
return;
}
// If the budget is not consumed, we can reset the maps
Set<String> keys = requestMap.keySet();
if (getAvailableQuota() > 0) {
keys.forEach(globalMap::remove);
keys.clear();
requestMap.clear();
copyMap.clear();
return;
}
keys.forEach(
key -> {
Map<VulnerabilityType, Integer> countMap = requestMap.get(key);
// should not happen, but just in case
if (countMap == null || countMap.isEmpty()) {
globalMap.remove(key);
return;
}
countMap.forEach(
(key1, counter) -> {
Map<VulnerabilityType, Integer> globalCountMap = globalMap.get(key);
if (globalCountMap != null) {
Integer globalCounter = globalCountMap.getOrDefault(key1, 0);
if (counter > globalCounter) {
globalCountMap.put(key1, counter);
}
} else {
globalCountMap = new HashMap<>();
globalCountMap.put(key1, counter);
globalMap.put(key, globalCountMap);
}
});
});
keys.clear();
requestMap.clear();
copyMap.clear();
}

public boolean isGlobal() {
return isGlobal;
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package com.datadog.iast.overhead;

import static com.datadog.iast.overhead.OverheadContext.globalMap;
import static datadog.trace.api.iast.IastDetectionMode.UNLIMITED;

import com.datadog.iast.IastRequestContext;
import com.datadog.iast.IastSystem;
import com.datadog.iast.model.VulnerabilityType;
import com.datadog.iast.util.NonBlockingSemaphore;
import datadog.trace.api.Config;
import datadog.trace.api.gateway.RequestContext;
Expand All @@ -12,7 +14,11 @@
import datadog.trace.api.telemetry.LogCollector;
import datadog.trace.bootstrap.instrumentation.api.AgentSpan;
import datadog.trace.bootstrap.instrumentation.api.AgentTracer;
import datadog.trace.bootstrap.instrumentation.api.Tags;
import datadog.trace.util.AgentTaskScheduler;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import javax.annotation.Nullable;
Expand All @@ -29,7 +35,10 @@ public interface OverheadController {

boolean hasQuota(final Operation operation, @Nullable final AgentSpan span);

boolean consumeQuota(final Operation operation, @Nullable final AgentSpan span);
boolean consumeQuota(
Operation operation,
@Nullable final AgentSpan span,
@Nullable final VulnerabilityType type);

static OverheadController build(final Config config, final AgentTaskScheduler scheduler) {
return build(
Expand Down Expand Up @@ -99,15 +108,19 @@ public boolean hasQuota(final Operation operation, @Nullable final AgentSpan spa
}

@Override
public boolean consumeQuota(final Operation operation, @Nullable final AgentSpan span) {
final boolean result = delegate.consumeQuota(operation, span);
public boolean consumeQuota(
final Operation operation,
@Nullable final AgentSpan span,
@Nullable final VulnerabilityType type) {
final boolean result = delegate.consumeQuota(operation, span, type);
if (LOGGER.isDebugEnabled()) {
LOGGER.debug(
"consumeQuota: operation={}, result={}, availableQuota={}, span={}",
"consumeQuota: operation={}, result={}, availableQuota={}, span={}, type={}",
operation,
result,
getAvailableQuote(span),
span);
span,
type);
}
return result;
}
Expand Down Expand Up @@ -147,7 +160,7 @@ class OverheadControllerImpl implements OverheadController {
private volatile long lastAcquiredTimestamp = Long.MAX_VALUE;

final OverheadContext globalContext =
new OverheadContext(Config.get().getIastVulnerabilitiesPerRequest());
new OverheadContext(Config.get().getIastVulnerabilitiesPerRequest(), true);

public OverheadControllerImpl(
final float requestSampling,
Expand Down Expand Up @@ -191,8 +204,70 @@ public boolean hasQuota(final Operation operation, @Nullable final AgentSpan spa
}

@Override
public boolean consumeQuota(final Operation operation, @Nullable final AgentSpan span) {
return operation.consumeQuota(getContext(span));
public boolean consumeQuota(
final Operation operation,
@Nullable final AgentSpan span,
@Nullable final VulnerabilityType type) {

OverheadContext ctx = getContext(span);
if (ctx == null) {
return false;
}
if (ctx.isGlobal()) {
return operation.consumeQuota(ctx);
}
if (operation.hasQuota(ctx)) {
String method = null;
String path = null;
if (span != null) {
Object methodTag = span.getLocalRootSpan().getTag(Tags.HTTP_METHOD);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Get local root span to a variable first, reuse that from there.

method = (methodTag == null) ? "" : methodTag.toString();
Object routeTag = span.getLocalRootSpan().getTag(Tags.HTTP_ROUTE);
path = (routeTag == null) ? "" : routeTag.toString();
}
if (!maybeSkipVulnerability(ctx, type, method, path)) {
return operation.consumeQuota(ctx);
}
}
return false;
}

/**
* Method to be called when a vulnerability of a certain type is detected. Implements the
* RFC-1029 algorithm.
*
* @param type the type of vulnerability detected
* @return true if the vulnerability should be skipped, false otherwise
*/
private boolean maybeSkipVulnerability(
@Nullable final OverheadContext ctx,
@Nullable final VulnerabilityType type,
@Nullable final String httpMethod,
@Nullable final String httpPath) {

if (ctx == null || type == null || ctx.requestMap == null || ctx.copyMap == null) {
return false;
}

String currentKey = httpMethod + " " + httpPath;
Set<String> keys = ctx.requestMap.keySet();

if (!keys.contains(currentKey)) {
ctx.copyMap.put(currentKey, globalMap.getOrDefault(currentKey, new HashMap<>()));
}

ctx.requestMap.computeIfAbsent(currentKey, k -> new HashMap<>());

Integer counter = ctx.requestMap.get(currentKey).getOrDefault(type, 0);
ctx.requestMap.get(currentKey).put(type, counter + 1);

Integer storedCounter = 0;
Map<VulnerabilityType, Integer> copyCountMap = ctx.copyMap.get(currentKey);
if (copyCountMap != null) {
storedCounter = copyCountMap.getOrDefault(type, 0);
}

return counter < storedCounter;
}

@Nullable
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.datadog.iast.sink;

import static com.datadog.iast.model.VulnerabilityType.INSECURE_COOKIE;
import static com.datadog.iast.util.HttpHeader.SET_COOKIE;
import static com.datadog.iast.util.HttpHeader.SET_COOKIE2;
import static java.util.Collections.singletonList;
Expand Down Expand Up @@ -65,7 +66,11 @@ private void onCookies(final List<Cookie> cookies) {
return;
}
final AgentSpan span = AgentTracer.activeSpan();
if (!overheadController.consumeQuota(Operations.REPORT_VULNERABILITY, span)) {
// TODO decide if we remove this one quota for all vulnerabilities as new IAST sampling
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@smola What do you think about this?

We had previously implemented it so that only one quota would be consumed for all header/cookie-related vulnerabilities reported here.

The new algorithm would fix this behavior, but maybe it’s worth keeping the existing logic to ensure these vulns (and any future ones like them) are reported sooner.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think we should probably preserve the behavior that all header/cookie vulns are a single one for quota purposes.

// algorithm is able to report all endpoint vulnerabilities
if (!overheadController.consumeQuota(
Operations.REPORT_VULNERABILITY, span, INSECURE_COOKIE // we need a type to check quota
)) {
return;
}
final Location location = Location.forSpanAndStack(span, getCurrentStackTrace());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ protected void report(final Vulnerability vulnerability) {
}

protected void report(@Nullable final AgentSpan span, final Vulnerability vulnerability) {
if (!overheadController.consumeQuota(Operations.REPORT_VULNERABILITY, span)) {
if (!overheadController.consumeQuota(
Operations.REPORT_VULNERABILITY, span, vulnerability.getType())) {
return;
}
reporter.report(span, vulnerability);
Expand All @@ -70,7 +71,7 @@ protected void report(final VulnerabilityType type, final Evidence evidence) {

protected void report(
@Nullable final AgentSpan span, final VulnerabilityType type, final Evidence evidence) {
if (!overheadController.consumeQuota(Operations.REPORT_VULNERABILITY, span)) {
if (!overheadController.consumeQuota(Operations.REPORT_VULNERABILITY, span, type)) {
return;
}
final Vulnerability vulnerability =
Expand Down Expand Up @@ -170,7 +171,7 @@ protected final Evidence checkInjection(
}

final AgentSpan span = AgentTracer.activeSpan();
if (!overheadController.consumeQuota(Operations.REPORT_VULNERABILITY, span)) {
if (!overheadController.consumeQuota(Operations.REPORT_VULNERABILITY, span, type)) {
return null;
}

Expand Down Expand Up @@ -251,7 +252,7 @@ protected final Evidence checkInjection(
if (!spanFetched && valueRanges != null && valueRanges.length > 0) {
span = AgentTracer.activeSpan();
spanFetched = true;
if (!overheadController.consumeQuota(Operations.REPORT_VULNERABILITY, span)) {
if (!overheadController.consumeQuota(Operations.REPORT_VULNERABILITY, span, type)) {
return null;
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.datadog.iast

import com.datadog.iast.model.Range
import com.datadog.iast.overhead.OverheadContext
import com.datadog.iast.taint.TaintedObjects
import datadog.trace.api.Config
import datadog.trace.api.gateway.RequestContext
Expand Down Expand Up @@ -120,4 +121,16 @@ class IastRequestContextTest extends DDSpecification {
then:
ctx.taintedObjects.count() == 0
}

void 'on release context overheadContext reset is called'() {
setup:
final overheadCtx = Mock(OverheadContext)
final ctx = new IastRequestContext(overheadCtx)

when:
provider.releaseRequestContext(ctx)

then:
1 * overheadCtx.resetMaps()
}
}
Loading
Loading