Skip to content

Commit 4cde537

Browse files
committed
Pubsub source v2 cancels streaming pull when blocked
1 parent 66292e1 commit 4cde537

File tree

6 files changed

+115
-79
lines changed

6 files changed

+115
-79
lines changed

.github/workflows/ci.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ on:
77

88
jobs:
99
test:
10-
runs-on: ubuntu-latest
10+
runs-on: ubuntu-22.04
1111
steps:
1212
- uses: actions/checkout@v2
1313
- uses: coursier/cache-action@v6
@@ -22,7 +22,7 @@ jobs:
2222

2323
publish_docker:
2424
needs: test
25-
runs-on: ubuntu-latest
25+
runs-on: ubuntu-22.04
2626
strategy:
2727
matrix:
2828
sbtProject:

modules/gcp/src/main/resources/application.conf

+3
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
// V2 defaults
1515
"durationPerAckExtension": "10 minutes"
1616
"minRemainingDeadline": 0.1
17+
"progressTimeout": "10 seconds"
18+
"modackOnProgressTimeout": true
19+
"cancelOnProgressTimeout": false
1720
}
1821
"output": {
1922
"bad": ${snowplow.defaults.sinks.pubsub}

modules/gcp/src/main/scala/common-streams-extensions/v2/PubsubCheckpointer.scala

+3-10
Original file line numberDiff line numberDiff line change
@@ -88,16 +88,9 @@ class PubsubCheckpointer[F[_]: Async](
8888
ackDatas <- refAckIds.modify(m => (m.removedAll(c), c.flatMap(m.get)))
8989
grouped = ackDatas.groupBy(_.channelAffinity)
9090
_ <- grouped.toVector.parTraverse_ { case (channelAffinity, ackDatas) =>
91-
ackDatas.flatMap(_.ackIds).grouped(1000).toVector.traverse_ { ackIds =>
92-
// A nack is just a modack with zero duration
93-
Utils
94-
.modAck[F](subscription, stub, ackIds, Duration.Zero, channelAffinity)
95-
.retryingOnTransientGrpcFailures
96-
.recoveringOnGrpcInvalidArgument { s =>
97-
// This can happen if ack IDs have expired before we acked
98-
Logger[F].info(s"Ignoring error from GRPC when acking: ${s.getDescription}")
99-
}
100-
}
91+
val ackIds = ackDatas.flatMap(_.ackIds)
92+
// A nack is just a modack with zero duration
93+
Utils.modAck[F](subscription, stub, ackIds, Duration.Zero, channelAffinity)
10194
}
10295
} yield ()
10396
}

modules/gcp/src/main/scala/common-streams-extensions/v2/PubsubSourceConfigV2.scala

+4-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@ case class PubsubSourceConfigV2(
2323
durationPerAckExtension: FiniteDuration,
2424
minRemainingDeadline: Double,
2525
gcpUserAgent: GcpUserAgent,
26-
maxPullsPerTransportChannel: Int
26+
maxPullsPerTransportChannel: Int,
27+
progressTimeout: FiniteDuration,
28+
modackOnProgressTimeout: Boolean,
29+
cancelOnProgressTimeout: Boolean
2730
)
2831

2932
object PubsubSourceConfigV2 {

modules/gcp/src/main/scala/common-streams-extensions/v2/PubsubSourceV2.scala

+81-53
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88
package com.snowplowanalytics.snowplow.sources.pubsub.v2
99

1010
import cats.effect.{Async, Deferred, Ref, Resource, Sync}
11-
import cats.effect.std.Hotswap
11+
import cats.effect.std.{Hotswap, Queue, QueueSink}
1212
import cats.effect.kernel.Unique
13+
import cats.effect.implicits._
1314
import cats.implicits._
14-
import fs2.{Chunk, Stream}
15+
import fs2.{Chunk, Pipe, Stream}
1516
import org.typelevel.log4cats.Logger
1617
import org.typelevel.log4cats.slf4j.Slf4jLogger
1718

@@ -32,9 +33,8 @@ import org.threeten.bp.{Duration => ThreetenDuration}
3233
import com.snowplowanalytics.snowplow.pubsub.GcpUserAgent
3334
import com.snowplowanalytics.snowplow.sources.SourceAndAck
3435
import com.snowplowanalytics.snowplow.sources.internal.{Checkpointer, LowLevelEvents, LowLevelSource}
35-
import com.snowplowanalytics.snowplow.sources.pubsub.v2.PubsubRetryOps.implicits._
3636

37-
import scala.concurrent.duration.{DurationDouble, FiniteDuration}
37+
import scala.concurrent.duration.{Duration, DurationDouble, FiniteDuration}
3838
import scala.jdk.CollectionConverters._
3939

4040
import java.util.concurrent.{ExecutorService, Executors, LinkedBlockingQueue}
@@ -75,19 +75,51 @@ object PubsubSourceV2 {
7575
_ <- Stream.eval(deferredResources.complete(PubsubCheckpointer.Resources(stub, refStates)))
7676
} yield Stream
7777
.range(0, parallelPullCount)
78-
.map { i =>
79-
val actionQueue = new LinkedBlockingQueue[SubscriberAction]()
80-
val clientId = UUID.randomUUID
81-
val resource = initializeStreamingPull(config, stub, actionQueue, i, clientId)
82-
Stream.resource(Hotswap(resource)).flatMap { case (hotswap, _) =>
83-
Stream
84-
.eval(pullFromQueue(config, actionQueue, refStates, hotswap, resource, i))
85-
.repeat
86-
.concurrently(extendDeadlines(config, stub, refStates, i))
87-
}
88-
}
78+
.map(i => miniPubsubStream(config, stub, refStates, i))
8979
.parJoinUnbounded
9080

81+
private def miniPubsubStream[F[_]: Async](
82+
config: PubsubSourceConfigV2,
83+
stub: SubscriberStub,
84+
refStates: Ref[F, Map[Unique.Token, PubsubBatchState]],
85+
channelAffinity: Int
86+
): Stream[F, LowLevelEvents[Vector[Unique.Token]]] = {
87+
val jQueue = new LinkedBlockingQueue[SubscriberAction]()
88+
val clientId = UUID.randomUUID
89+
val resource = initializeStreamingPull[F](config, stub, jQueue, channelAffinity, clientId)
90+
91+
for {
92+
(hotswap, _) <- Stream.resource(Hotswap(resource))
93+
fs2Queue <- Stream.eval(Queue.synchronous[F, SubscriberAction])
94+
_ <- extendDeadlines(config, stub, refStates, channelAffinity).spawn
95+
_ <- Stream.eval(queueToQueue(config, jQueue, fs2Queue, stub, channelAffinity)).repeat.spawn
96+
lle <- Stream
97+
.fromQueueUnterminated(fs2Queue)
98+
.through(toLowLevelEvents(config, refStates, hotswap, resource, channelAffinity))
99+
} yield lle
100+
}
101+
102+
private def queueToQueue[F[_]: Async](
103+
config: PubsubSourceConfigV2,
104+
jQueue: LinkedBlockingQueue[SubscriberAction],
105+
fs2Queue: QueueSink[F, SubscriberAction],
106+
stub: SubscriberStub,
107+
channelAffinity: Int
108+
): F[Unit] =
109+
resolveNextAction(jQueue).flatMap {
110+
case action @ SubscriberAction.ProcessRecords(records, controller, _) =>
111+
val fallback = for {
112+
_ <- Logger[F].debug(s"Cancelling Pubsub channel $channelAffinity for not making progress")
113+
ackIds = records.map(_.getAckId)
114+
_ <- if (config.cancelOnProgressTimeout) Sync[F].delay(controller.cancel()) else Sync[F].unit
115+
_ <- if (config.modackOnProgressTimeout) Utils.modAck(config.subscription, stub, ackIds, Duration.Zero, channelAffinity)
116+
else fs2Queue.offer(action)
117+
} yield ()
118+
fs2Queue.offer(action).timeoutTo(config.progressTimeout, fallback)
119+
case action: SubscriberAction.SubscriberError =>
120+
fs2Queue.offer(action)
121+
}
122+
91123
/**
92124
* Modify ack deadlines if we need more time to process the messages
93125
*
@@ -126,27 +158,19 @@ object PubsubSourceV2 {
126158
.evalMap { toExtend =>
127159
if (toExtend.isEmpty)
128160
Sync[F].sleep(0.5 * config.minRemainingDeadline * config.durationPerAckExtension)
129-
else
130-
toExtend.sortBy(_.currentDeadline).flatMap(_.ackIds).grouped(1000).toVector.traverse_ { ackIds =>
131-
Utils
132-
.modAck[F](config.subscription, stub, ackIds, config.durationPerAckExtension, channelAffinity)
133-
.retryingOnTransientGrpcFailures
134-
.recoveringOnGrpcInvalidArgument { s =>
135-
// This can happen if ack IDs were acked before we modAcked
136-
Logger[F].info(s"Ignoring error from GRPC when modifying ack IDs: ${s.getDescription}")
137-
}
138-
}
161+
else {
162+
val ackIds = toExtend.sortBy(_.currentDeadline).flatMap(_.ackIds)
163+
Utils.modAck[F](config.subscription, stub, ackIds, config.durationPerAckExtension, channelAffinity)
164+
}
139165
}
140166
.repeat
141167
.drain
142168

143169
/**
144-
* Pulls a SubscriberAction from a queue when one becomes available
170+
* Pipe from SubscriberAction to LowLevelEvents TODO: Say what else this does
145171
*
146172
* @param config
147173
* The source configuration
148-
* @param queue
149-
* The queue from which to pull a SubscriberAction
150174
* @param refStates
151175
* A map from tokens to the data held about a batch of messages received from pubsub. This
152176
* function must update the state to add new batches.
@@ -159,44 +183,48 @@ object PubsubSourceV2 {
159183
* Identifies the GRPC channel (TCP connection) creating these Actions. Each GRPC channel has
160184
* its own queue, observer, and puller.
161185
*/
162-
private def pullFromQueue[F[_]: Async](
186+
private def toLowLevelEvents[F[_]: Async](
163187
config: PubsubSourceConfigV2,
164-
queue: LinkedBlockingQueue[SubscriberAction],
165188
refStates: Ref[F, Map[Unique.Token, PubsubBatchState]],
166189
hotswap: Hotswap[F, Unit],
167190
toSwap: Resource[F, Unit],
168191
channelAffinity: Int
169-
): F[LowLevelEvents[Vector[Unique.Token]]] = {
170-
def go(delayOnSubscriberError: FiniteDuration): F[LowLevelEvents[Vector[Unique.Token]]] =
171-
resolveNextAction[F, SubscriberAction](queue).flatMap {
172-
case SubscriberAction.ProcessRecords(records, controller, timeReceived) =>
173-
val chunk = Chunk.from(records.map(_.getMessage.getData.asReadOnlyByteBuffer()))
174-
val (tstampSeconds, tstampNanos) =
175-
records.map(r => (r.getMessage.getPublishTime.getSeconds, r.getMessage.getPublishTime.getNanos)).min
176-
val ackIds = records.map(_.getAckId)
192+
): Pipe[F, SubscriberAction, LowLevelEvents[Vector[Unique.Token]]] =
193+
_.flatMap {
194+
case SubscriberAction.ProcessRecords(records, controller, timeReceived) =>
195+
val chunk = Chunk.from(records.map(_.getMessage.getData.asReadOnlyByteBuffer()))
196+
val (tstampSeconds, tstampNanos) =
197+
records.map(r => (r.getMessage.getPublishTime.getSeconds, r.getMessage.getPublishTime.getNanos)).min
198+
val ackIds = records.map(_.getAckId)
199+
Stream.eval {
177200
for {
178201
token <- Unique[F].unique
179202
currentDeadline = timeReceived.plusMillis(config.durationPerAckExtension.toMillis)
180203
_ <- refStates.update(_ + (token -> PubsubBatchState(currentDeadline, ackIds, channelAffinity)))
181204
_ <- Sync[F].delay(controller.request(1))
182205
} yield LowLevelEvents(chunk, Vector(token), Some(Instant.ofEpochSecond(tstampSeconds, tstampNanos.toLong)))
183-
case SubscriberAction.SubscriberError(t) =>
184-
if (PubsubRetryOps.isRetryableException(t)) {
185-
val nextDelay = (2 * delayOnSubscriberError).min((10 + scala.util.Random.nextDouble()).second)
186-
// Log at debug level because retryable errors are very frequent.
187-
// In particular, if the pubsub subscription is empty then a streaming pull returns UNAVAILABLE
206+
}
207+
case SubscriberAction.SubscriberError(t) =>
208+
if (PubsubRetryOps.isRetryableException(t)) {
209+
// val nextDelay = (2 * delayOnSubscriberError).min((10 + scala.util.Random.nextDouble()).second)
210+
// Log at debug level because retryable errors are very frequent.
211+
// In particular, if the pubsub subscription is empty then a streaming pull returns UNAVAILABLE
212+
Stream.eval {
188213
Logger[F].debug(s"Retryable error on PubSub channel $channelAffinity: ${t.getMessage}") *>
189214
hotswap.clear *>
190-
Async[F].sleep(delayOnSubscriberError) *>
191-
hotswap.swap(toSwap) *>
192-
go(nextDelay)
193-
} else {
194-
Logger[F].error(t)("Exception from PubSub source") *> Sync[F].raiseError(t)
195-
}
196-
}
197-
198-
go(delayOnSubscriberError = (1.0 + scala.util.Random.nextDouble()).second)
199-
}
215+
Async[F].sleep((1.0 + scala.util.Random.nextDouble()).second) *> // TODO expotential backoff
216+
hotswap.swap(toSwap)
217+
}.drain
218+
} else if (t.isInstanceOf[java.util.concurrent.CancellationException]) {
219+
Stream.eval {
220+
Logger[F].debug("Cancellation exception on PubSub channel") *>
221+
hotswap.clear *>
222+
hotswap.swap(toSwap)
223+
}.drain
224+
} else {
225+
Stream.eval(Logger[F].error(t)("Exception from PubSub source")) *> Stream.raiseError[F](t)
226+
}
227+
}
200228

201229
private def resolveNextAction[F[_]: Sync, A](queue: LinkedBlockingQueue[A]): F[A] =
202230
Sync[F].delay(Option[A](queue.poll)).flatMap {

modules/gcp/src/main/scala/common-streams-extensions/v2/Utils.scala

+22-13
Original file line numberDiff line numberDiff line change
@@ -9,34 +9,43 @@ package com.snowplowanalytics.snowplow.sources.pubsub.v2
99

1010
import cats.effect.{Async, Sync}
1111
import cats.implicits._
12+
import org.typelevel.log4cats.Logger
1213

1314
import com.google.api.gax.grpc.GrpcCallContext
1415
import com.google.cloud.pubsub.v1.stub.SubscriberStub
1516
import com.google.pubsub.v1.ModifyAckDeadlineRequest
1617
import com.snowplowanalytics.snowplow.pubsub.FutureInterop
18+
import com.snowplowanalytics.snowplow.sources.pubsub.v2.PubsubRetryOps.implicits._
1719

1820
import scala.concurrent.duration.FiniteDuration
1921
import scala.jdk.CollectionConverters._
2022

2123
private object Utils {
2224

23-
def modAck[F[_]: Async](
25+
def modAck[F[_]: Async: Logger](
2426
subscription: PubsubSourceConfigV2.Subscription,
2527
stub: SubscriberStub,
2628
ackIds: Vector[String],
2729
duration: FiniteDuration,
2830
channelAffinity: Int
29-
): F[Unit] = {
30-
val request = ModifyAckDeadlineRequest.newBuilder
31-
.setSubscription(subscription.show)
32-
.addAllAckIds(ackIds.asJava)
33-
.setAckDeadlineSeconds(duration.toSeconds.toInt)
34-
.build
35-
val context = GrpcCallContext.createDefault.withChannelAffinity(channelAffinity)
36-
for {
37-
apiFuture <- Sync[F].delay(stub.modifyAckDeadlineCallable.futureCall(request, context))
38-
_ <- FutureInterop.fromFuture(apiFuture)
39-
} yield ()
40-
}
31+
): F[Unit] =
32+
ackIds.grouped(1000).toVector.traverse_ { group =>
33+
val request = ModifyAckDeadlineRequest.newBuilder
34+
.setSubscription(subscription.show)
35+
.addAllAckIds(group.asJava)
36+
.setAckDeadlineSeconds(duration.toSeconds.toInt)
37+
.build
38+
val context = GrpcCallContext.createDefault.withChannelAffinity(channelAffinity)
39+
val io = for {
40+
apiFuture <- Sync[F].delay(stub.modifyAckDeadlineCallable.futureCall(request, context))
41+
_ <- FutureInterop.fromFuture(apiFuture)
42+
} yield ()
43+
44+
io.retryingOnTransientGrpcFailures
45+
.recoveringOnGrpcInvalidArgument { s =>
46+
// This can happen if ack IDs were acked before we modAcked
47+
Logger[F].info(s"Ignoring error from GRPC when modifying ack IDs: ${s.getDescription}")
48+
}
49+
}
4150

4251
}

0 commit comments

Comments
 (0)