diff --git a/core/src/main/scala/io/chrisdavenport/rediculous/RedisConnection.scala b/core/src/main/scala/io/chrisdavenport/rediculous/RedisConnection.scala index 3631cb4..ed7a028 100644 --- a/core/src/main/scala/io/chrisdavenport/rediculous/RedisConnection.scala +++ b/core/src/main/scala/io/chrisdavenport/rediculous/RedisConnection.scala @@ -20,6 +20,7 @@ import java.time.Instant import _root_.io.chrisdavenport.rediculous.cluster.ClusterCommands.ClusterSlots import fs2.io.net.SocketGroupCompanionPlatform import scodec.bits.ByteVector +import cats.effect.std.Supervisor trait RedisConnection[F[_]]{ def runRequest( @@ -42,11 +43,12 @@ object RedisConnection{ } } private[rediculous] case class PooledConnection[F[_]: Concurrent]( - pool: KeyPool[F, Unit, (Socket[F], F[Unit])] + pool: KeyPool[F, Unit, (Socket[F], F[Unit])], + supervisor: Supervisor[F] ) extends RedisConnection[F]{ def runRequest(inputs: Chunk[NonEmptyList[ByteVector]], key: Option[ByteVector]): F[Chunk[Resp]] = { val chunk = Chunk.seq(inputs.toList.map(Resp.renderRequest)) - def withSocket(socket: Socket[F]): F[Chunk[Resp]] = explicitPipelineRequest[F](socket, chunk) + def withSocket(socket: Socket[F]): F[Chunk[Resp]] = makeSoftCancelable(explicitPipelineRequest[F](socket, chunk), supervisor) Functor[KeyPool[F, Unit, *]].map(pool)(_._1).take(()).use{ m => withSocket(m.value).attempt.flatTap{ case Left(_) => m.canBeReused.set(Reusable.DontReuse) @@ -56,11 +58,11 @@ object RedisConnection{ } } - private[rediculous] case class DirectConnection[F[_]: Concurrent](socket: Socket[F]) extends RedisConnection[F]{ + private[rediculous] case class DirectConnection[F[_]: Concurrent](socket: Socket[F], supervisor: Supervisor[F]) extends RedisConnection[F]{ def runRequest(inputs: Chunk[NonEmptyList[ByteVector]], key: Option[ByteVector]): F[Chunk[Resp]] = { val chunk = Chunk.seq(inputs.toList.map(Resp.renderRequest)) def withSocket(socket: Socket[F]): F[Chunk[Resp]] = explicitPipelineRequest[F](socket, chunk) - withSocket(socket) + makeSoftCancelable(withSocket(socket), supervisor) } } @@ -75,6 +77,11 @@ object RedisConnection{ } } + private def makeSoftCancelable[F[_]: Concurrent, A](fa: F[A], supervisor: Supervisor[F]): F[A] = { + supervisor.supervise(fa) + .flatMap(_.joinWith(Concurrent[F].raiseError(new java.util.concurrent.CancellationException("Outcome was Canceled")))) + } + // Guarantees With Socket That Each Call Receives a Response // Chunk must be non-empty but to do so incurs a penalty private[rediculous] def explicitPipelineRequest[F[_]: Concurrent](socket: Socket[F], calls: Chunk[Resp], maxBytes: Int = 16 * 1024 * 1024, timeout: Option[FiniteDuration] = 5.seconds.some): F[Chunk[Resp]] = { @@ -169,9 +176,10 @@ object RedisConnection{ def build: Resource[F,RedisConnection[F]] = for { + supervisor <- Supervisor[F] socket <- sg.client(SocketAddress(host,port), Nil) out <- elevateSocket(socket, tlsContext, tlsParameters) - } yield RedisConnection.DirectConnection(out) + } yield RedisConnection.DirectConnection(out, supervisor) } def pool[F[_]: Async]: PooledConnectionBuilder[F] = @@ -213,13 +221,16 @@ object RedisConnection{ def withSocketGroup(sg: SocketGroup[F]) = copy(sg = sg) def build: Resource[F,RedisConnection[F]] = + ( KeyPoolBuilder[F, Unit, (Socket[F], F[Unit])]( {_ => sg.client(SocketAddress(host,port), Nil) .flatMap(elevateSocket(_, tlsContext, tlsParameters)) .allocated }, { case (_, shutdown) => shutdown} - ).build.map(PooledConnection[F](_)) + ).build, + Supervisor[F] + ).mapN(PooledConnection[F](_, _)) } def queued[F[_]: Async]: QueuedConnectionBuilder[F] = @@ -392,6 +403,7 @@ object RedisConnection{ def build: Resource[F,RedisConnection[F]] = { for { + supervisor <- Supervisor[F] keypool <- KeyPoolBuilder[F, (Host, Port), (Socket[F], F[Unit])]( {(t: (Host, Port)) => sg.client(SocketAddress(host,port), Nil) .flatMap(elevateSocket(_, tlsContext, tlsParameters)) @@ -401,7 +413,7 @@ object RedisConnection{ ).build // Cluster Topology Acquisition and Management - sockets <- Resource.eval(keypool.take((host, port)).map(_.value._1).map(DirectConnection(_)).use(ClusterCommands.clusterslots[Redis[F, *]].run(_))) + sockets <- Resource.eval(keypool.take((host, port)).map(_.value._1).map(DirectConnection(_, supervisor)).use(ClusterCommands.clusterslots[Redis[F, *]].run(_))) now <- Resource.eval(Temporal[F].realTime.map(_.toMillis)) refreshLock <- Resource.eval(Semaphore[F](1L)) refTopology <- Resource.eval(Ref[F].of((sockets, now))) @@ -420,7 +432,7 @@ object RedisConnection{ case ((_, setAt), now) if setAt >= (now - cacheTopologySeconds.toMillis) => Applicative[F].unit case ((l, _), _) => val nelActions: NonEmptyList[F[ClusterSlots]] = l.map{ case (host, port) => - keypool.take((host, port)).map(_.value._1).map(DirectConnection(_)).use(ClusterCommands.clusterslots[Redis[F, *]].run(_)) + keypool.take((host, port)).map(_.value._1).map(DirectConnection(_, supervisor)).use(ClusterCommands.clusterslots[Redis[F, *]].run(_)) } raceNThrowFirst(nelActions) .flatMap(s => Clock[F].realTime.map(_.toMillis).flatMap(now => refTopology.set((s,now)))) diff --git a/core/src/main/scala/io/chrisdavenport/rediculous/RedisPubSub.scala b/core/src/main/scala/io/chrisdavenport/rediculous/RedisPubSub.scala index 3785e60..318a294 100644 --- a/core/src/main/scala/io/chrisdavenport/rediculous/RedisPubSub.scala +++ b/core/src/main/scala/io/chrisdavenport/rediculous/RedisPubSub.scala @@ -235,7 +235,7 @@ object RedisPubSub { } } } - case RedisConnection.PooledConnection(pool) => + case RedisConnection.PooledConnection(pool, _) => pool.take(()).map(_.map(_._1)).flatMap{managed => val messagesR = Concurrent[F].ref(Map[String, RedisPubSub.PubSubMessage => F[Unit]]()) val onNonMessageR = Concurrent[F].ref((_: PubSubReply) => Applicative[F].unit) @@ -247,7 +247,7 @@ object RedisPubSub { } } } - case RedisConnection.DirectConnection(s) => + case RedisConnection.DirectConnection(s, _) => val messagesR = Concurrent[F].ref(Map[String, RedisPubSub.PubSubMessage => F[Unit]]()) val onNonMessageR = Concurrent[F].ref((_: PubSubReply) => Applicative[F].unit) val onUnhandledMessageR = Concurrent[F].ref((_: PubSubMessage) => Applicative[F].unit) diff --git a/core/src/test/scala/io/chrisdavenport/rediculous/RedisCommandsSpec.scala b/core/src/test/scala/io/chrisdavenport/rediculous/RedisCommandsSpec.scala index 0223cc3..f83313d 100644 --- a/core/src/test/scala/io/chrisdavenport/rediculous/RedisCommandsSpec.scala +++ b/core/src/test/scala/io/chrisdavenport/rediculous/RedisCommandsSpec.scala @@ -182,4 +182,39 @@ class RedisCommandsSpec extends CatsEffectSuite { action.run(connection) } } + + test("early termination"){ + redisConnection().flatMap{ connection => + val msg1 = "msg1" -> "msg1" + val msg2 = "msg2" -> "msg2" + val msg3 = "msg3" -> "msg3" + + val xopts = + RedisCommands.XReadOpts.default + // .copy(blockMillisecond = 0L.some, 1L.some) + .copy(count = 1L.some, blockMillisecond = 1000L.some) + + val offset: Set[RedisCommands.StreamOffset] = Set(RedisCommands.StreamOffset.From("foo", "$")) + + val extract = (resp: Option[List[RedisCommands.XReadResponse]]) => + resp.flatMap(_.headOption).flatMap(_.records.headOption).flatMap(_.keyValues.headOption) + + val action = + for { + _ <- ( + RedisCommands.xadd[RedisIO]("foo", List(msg1)), + RedisCommands.xadd[RedisIO]("foo", List(msg2)), + RedisCommands.xadd[RedisIO]("foo", List(msg3)) + ).tupled.run(connection) + msg1 <- RedisCommands.xread[RedisIO](offset, xopts).run(connection).timeout(50.milli).attempt.map{ + case Right(resp) => extract(resp) + case Left(_) => None + } + empty <- RedisCommands.xread[RedisIO](offset, xopts).run(connection).timeout(50.milli).replicateA(100).attempt + _ <- RedisCommands.del[RedisIO]("foo").run(connection) + } yield msg1 + + action.assertEquals(None) + } + } }