|
13 | 13 | */ |
14 | 14 | package io.trino.gateway.ha.handler; |
15 | 15 |
|
| 16 | +import com.google.common.collect.ImmutableSet; |
16 | 17 | import com.google.common.io.CharStreams; |
17 | 18 | import io.airlift.log.Logger; |
18 | 19 | import io.trino.gateway.ha.router.TrinoQueryProperties; |
@@ -52,6 +53,8 @@ public final class ProxyUtils |
52 | 53 | * capitalization. |
53 | 54 | */ |
54 | 55 | private static final Pattern QUERY_ID_PARAM_PATTERN = Pattern.compile(".*(?:%2F|(?i)query_?id(?-i)=|^)(\\d+_\\d+_\\d+_\\w+).*"); |
| 56 | + public static final ImmutableSet<String> QUERY_STATE_PATH = ImmutableSet.of("queued", "scheduled", "executing"); |
| 57 | + public static final String PARTIAL_CANCEL_PATH = "partialCancel"; |
55 | 58 |
|
56 | 59 | private ProxyUtils() {} |
57 | 60 |
|
@@ -100,10 +103,10 @@ public static Optional<String> extractQueryIdIfPresent(String path, String query |
100 | 103 | path = path.replace(matchingStatementPath.orElse(V1_QUERY_PATH), ""); |
101 | 104 | String[] tokens = path.split("/"); |
102 | 105 | if (tokens.length >= 2) { |
103 | | - if (tokens[1].equals("queued") |
104 | | - || tokens[1].equals("scheduled") |
105 | | - || tokens[1].equals("executing") |
106 | | - || tokens[1].equals("partialCancel")) { |
| 106 | + if (tokens.length >= 3 && QUERY_STATE_PATH.contains(tokens[1])) { |
| 107 | + if (tokens.length >= 4 && tokens[2].equals(PARTIAL_CANCEL_PATH)) { |
| 108 | + return Optional.of(tokens[3]); |
| 109 | + } |
107 | 110 | return Optional.of(tokens[2]); |
108 | 111 | } |
109 | 112 | return Optional.of(tokens[1]); |
|
0 commit comments