From 6246f6a1be04931272aae27eae8715d160216f4a Mon Sep 17 00:00:00 2001 From: Quinten Parker <77176931+quintenp01@users.noreply.github.com> Date: Wed, 5 Jun 2024 10:03:04 -0700 Subject: [PATCH] Restart watch stream on error in WatcherService (#5486) * Restart watch stream on error in WatcherService * Restart watch stream on error in WatcherService * apply scalafmt --- .../core/service/WatcherService.scala | 103 ++++++++++-------- .../core/service/WatcherServiceTests.scala | 24 ++++ 2 files changed, 84 insertions(+), 43 deletions(-) diff --git a/common/scala/src/main/scala/org/apache/openwhisk/core/service/WatcherService.scala b/common/scala/src/main/scala/org/apache/openwhisk/core/service/WatcherService.scala index e5f3397da99..c6277a77f05 100644 --- a/common/scala/src/main/scala/org/apache/openwhisk/core/service/WatcherService.scala +++ b/common/scala/src/main/scala/org/apache/openwhisk/core/service/WatcherService.scala @@ -19,7 +19,7 @@ package org.apache.openwhisk.core.service import akka.actor.{Actor, ActorRef, ActorSystem, Props} import com.ibm.etcd.api.Event.EventType -import com.ibm.etcd.client.kv.WatchUpdate +import com.ibm.etcd.client.kv.KvClient import org.apache.openwhisk.common.{GracefulShutdown, Logging} import org.apache.openwhisk.core.etcd.EtcdClient import org.apache.openwhisk.core.etcd.EtcdType._ @@ -35,6 +35,8 @@ case class WatchEndpoint(key: String, listenEvents: Set[EtcdEvent] = Set.empty) case class UnwatchEndpoint(watchKey: String, isPrefix: Boolean, watchName: String, needFeedback: Boolean = false) +case object RestartWatcher + // the watchKey is the string user want to watch, it can be a prefix, the key is a record's key in Etcd // so if `isPrefix = true`, the `watchKey != key`, else the `watchKey == key` sealed abstract class WatchEndpointOperation(val watchKey: String, @@ -70,49 +72,58 @@ class WatcherService(etcdClient: EtcdClient)(implicit logging: Logging, actorSys private[service] val prefixPutWatchers = TrieMap[WatcherKey, ActorRef]() private[service] val prefixDeleteWatchers = TrieMap[WatcherKey, ActorRef]() - private val watcher = etcdClient.watchAllKeys { res: WatchUpdate => - res.getEvents.asScala.foreach { event => - event.getType match { - case EventType.DELETE => - val key = ByteStringToString(event.getPrevKv.getKey) - val value = ByteStringToString(event.getPrevKv.getValue) - val watchEvent = WatchEndpointRemoved(key, key, value, false) - deleteWatchers - .foreach { watcher => - if (watcher._1.watchKey == key) { - watcher._2 ! watchEvent - } - } - prefixDeleteWatchers - .foreach { watcher => - if (key.startsWith(watcher._1.watchKey)) { - watcher._2 ! WatchEndpointRemoved(watcher._1.watchKey, key, value, true) - } - } - case EventType.PUT => - val key = ByteStringToString(event.getKv.getKey) - val value = ByteStringToString(event.getKv.getValue) - val watchEvent = WatchEndpointInserted(key, key, value, false) - putWatchers - .foreach { watcher => - if (watcher._1.watchKey == key) { - watcher._2 ! watchEvent - } - } - prefixPutWatchers - .foreach { watcher => - if (key.startsWith(watcher._1.watchKey)) { - watcher._2 ! WatchEndpointInserted(watcher._1.watchKey, key, value, true) - } - } - case msg => - logging.debug(this, s"watch event received: $msg.") - } - } - + private def startWatch(): KvClient.Watch = { + etcdClient.watchAllKeys( + res => + res.getEvents.asScala.foreach { event => + event.getType match { + case EventType.DELETE => + val key = ByteStringToString(event.getPrevKv.getKey) + val value = ByteStringToString(event.getPrevKv.getValue) + val watchEvent = WatchEndpointRemoved(key, key, value, false) + deleteWatchers + .foreach { watcher => + if (watcher._1.watchKey == key) { + watcher._2 ! watchEvent + } + } + prefixDeleteWatchers + .foreach { watcher => + if (key.startsWith(watcher._1.watchKey)) { + watcher._2 ! WatchEndpointRemoved(watcher._1.watchKey, key, value, true) + } + } + case EventType.PUT => + val key = ByteStringToString(event.getKv.getKey) + val value = ByteStringToString(event.getKv.getValue) + val watchEvent = WatchEndpointInserted(key, key, value, false) + putWatchers + .foreach { watcher => + if (watcher._1.watchKey == key) { + watcher._2 ! watchEvent + } + } + prefixPutWatchers + .foreach { watcher => + if (key.startsWith(watcher._1.watchKey)) { + watcher._2 ! WatchEndpointInserted(watcher._1.watchKey, key, value, true) + } + } + case msg => + logging.debug(this, s"watch event received: $msg.") + } + }, + error => { + logging.error(this, s"encountered error, restarting watcher service: $error") + self ! RestartWatcher + }, + () => { + logging.warn(this, s"watch stream completed, restarting watcher service") + self ! RestartWatcher + }) } - override def receive: Receive = { + private def watchBehavior(watcher: KvClient.Watch): Receive = { case request: WatchEndpoint => logging.info(this, s"watch endpoint: $request") val watcherKey = WatcherKey(request.key, request.name) @@ -143,6 +154,10 @@ class WatcherService(etcdClient: EtcdClient)(implicit logging: Logging, actorSys if (request.needFeedback) sender ! WatcherClosed(request.watchKey, request.isPrefix) + case RestartWatcher => + watcher.close() + context.become(watchBehavior(startWatch())) + case GracefulShutdown => watcher.close() putWatchers.clear() @@ -150,8 +165,10 @@ class WatcherService(etcdClient: EtcdClient)(implicit logging: Logging, actorSys prefixPutWatchers.clear() prefixDeleteWatchers.clear() } -} + override def receive: Receive = watchBehavior(startWatch()) + +} object WatcherService { def props(etcdClient: EtcdClient)(implicit logging: Logging, actorSystem: ActorSystem): Props = { Props(new WatcherService(etcdClient)) diff --git a/tests/src/test/scala/org/apache/openwhisk/core/service/WatcherServiceTests.scala b/tests/src/test/scala/org/apache/openwhisk/core/service/WatcherServiceTests.scala index e015fec41f8..6c59cfab2ec 100644 --- a/tests/src/test/scala/org/apache/openwhisk/core/service/WatcherServiceTests.scala +++ b/tests/src/test/scala/org/apache/openwhisk/core/service/WatcherServiceTests.scala @@ -243,6 +243,26 @@ class WatcherServiceTests service.underlyingActor.deleteWatchers.size shouldBe 3 } + it should "restart underlying etcd watch if error occurs" in { + val etcdClient = new MockWatchClient(client)(ece) + val key = "testKey" + val value = "testValue" + + val probe = TestProbe() + val service = TestActorRef(new WatcherService(etcdClient)) + + etcdClient.onNext should not be null + etcdClient.onError should not be null + etcdClient.watchAllKeysCallCount shouldBe 1 + + val t = new Throwable("error") + etcdClient.onError(t) + + etcdClient.onNext should not be null + etcdClient.onError should not be null + etcdClient.watchAllKeysCallCount shouldBe 2 + } + } class mockWatchUpdate extends WatchUpdate { @@ -259,9 +279,13 @@ class mockWatchUpdate extends WatchUpdate { class MockWatchClient(client: Client)(ece: ExecutionContextExecutor) extends EtcdClient(client)(ece) { var onNext: WatchUpdate => Unit = null + var onError: Throwable => Unit = null + var watchAllKeysCallCount = 0 override def watchAllKeys(next: WatchUpdate => Unit, error: Throwable => Unit, completed: () => Unit): Watch = { onNext = next + onError = error + watchAllKeysCallCount += 1 new Watch { override def close(): Unit = {}