Skip to content

Commit

Permalink
restrict stash context only for stop words system index (#2283) (#2285)
Browse files Browse the repository at this point in the history
Signed-off-by: Jing Zhang <[email protected]>
(cherry picked from commit 5f9026d)

Co-authored-by: Jing Zhang <[email protected]>
  • Loading branch information
opensearch-trigger-bot[bot] and jngz-es authored Mar 27, 2024
1 parent 2365afc commit 4dbeb28
Showing 1 changed file with 33 additions and 14 deletions.
47 changes: 33 additions & 14 deletions common/src/main/java/org/opensearch/ml/common/model/MLGuard.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@

package org.opensearch.ml.common.model;

import com.google.common.collect.ImmutableSet;
import lombok.Getter;
import lombok.NonNull;
import lombok.extern.log4j.Log4j2;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.LatchedActionListener;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
Expand All @@ -20,7 +19,6 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;

import java.security.AccessController;
Expand All @@ -30,15 +28,14 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import static java.util.concurrent.TimeUnit.SECONDS;
import static org.opensearch.ml.common.CommonValue.MASTER_KEY;
import static org.opensearch.ml.common.utils.StringUtils.gson;

@Log4j2
Expand All @@ -52,6 +49,7 @@ public class MLGuard {
private List<Pattern> outputRegexPattern;
private NamedXContentRegistry xContentRegistry;
private Client client;
private Set<String> stopWordsIndices = ImmutableSet.of(".plugins-ml-stop-words");

public MLGuard(Guardrails guardrails, NamedXContentRegistry xContentRegistry, Client client) {
this.xContentRegistry = xContentRegistry;
Expand Down Expand Up @@ -128,27 +126,44 @@ public Boolean validateStopWordsSingleIndex(String input, String indexName, List
Map<String, Object> queryBodyMap = Map
.of("query", Map.of("percolate", Map.of("field", "query", "document", documentMap)));
CountDownLatch latch = new CountDownLatch(1);
ThreadContext.StoredContext context = null;

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
try {
queryBody = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(queryBodyMap));
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
XContentParser queryParser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, queryBody);
searchSourceBuilder.parseXContent(queryParser);
searchSourceBuilder.size(1); //Only need 1 doc returned, if hit.
searchRequest = new SearchRequest().source(searchSourceBuilder).indices(indexName);
context.restore();
client.search(searchRequest, ActionListener.runBefore(new LatchedActionListener(ActionListener.<SearchResponse>wrap(r -> {
if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) {
if (isStopWordsSystemIndex(indexName)) {
context = client.threadPool().getThreadContext().stashContext();
ThreadContext.StoredContext finalContext = context;
client.search(searchRequest, ActionListener.runBefore(new LatchedActionListener(ActionListener.<SearchResponse>wrap(r -> {
if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) {
hitStopWords.set(true);
}
}, e -> {
log.error("Failed to search stop words index {}", indexName, e);
hitStopWords.set(true);
}), latch), () -> finalContext.restore()));
} else {
client.search(searchRequest, new LatchedActionListener(ActionListener.<SearchResponse>wrap(r -> {
if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) {
hitStopWords.set(true);
}
}, e -> {
log.error("Failed to search stop words index {}", indexName, e);
hitStopWords.set(true);
}
}, e -> {
log.error("Failed to search stop words index {}", indexName, e);
hitStopWords.set(true);
}), latch), () -> context.restore()));
}), latch));
}
} catch (Exception e) {
log.error("[validateStopWords] Searching stop words index failed.", e);
latch.countDown();
hitStopWords.set(true);
} finally {
if (context != null) {
context.close();
}
}

try {
Expand All @@ -160,6 +175,10 @@ public Boolean validateStopWordsSingleIndex(String input, String indexName, List
return hitStopWords.get();
}

private boolean isStopWordsSystemIndex(String index) {
return stopWordsIndices.contains(index);
}

public enum Type {
INPUT,
OUTPUT
Expand Down

0 comments on commit 4dbeb28

Please sign in to comment.