Skip to content

Commit

Permalink
Enable disablement of cost catalog, only keep google client open as n…
Browse files Browse the repository at this point in the history
…eeded
  • Loading branch information
jgainerdewar committed Oct 2, 2024
1 parent e93b078 commit de345c2
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 31 deletions.
2 changes: 2 additions & 0 deletions core/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -607,9 +607,11 @@ services {
}
}

// When enabled, Cromwell will store vmCostPerHour metadata for GCP tasks
GcpCostCatalogService {
class = "cromwell.services.cost.GcpCostCatalogService"
config {
enabled = false
catalogExpirySeconds = 86400
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ package cromwell.services.cost
import com.typesafe.config.Config
import net.ceedubs.ficus.Ficus._

final case class CostCatalogConfig(catalogExpirySeconds: Int)
final case class CostCatalogConfig(enabled: Boolean, catalogExpirySeconds: Int)

object CostCatalogConfig {
def apply(config: Config): CostCatalogConfig = CostCatalogConfig(config.as[Int]("catalogExpirySeconds"))
def apply(config: Config): CostCatalogConfig =
CostCatalogConfig(config.as[Boolean]("enabled"), config.as[Int]("catalogExpirySeconds"))
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import cromwell.util.GracefulShutdownHelper.ShutdownCommand
import java.time.{Duration, Instant}
import scala.jdk.CollectionConverters.IterableHasAsScala
import java.time.temporal.ChronoUnit.SECONDS
import scala.util.Using

case class CostCatalogKey(machineType: MachineType,
usageType: UsageType,
Expand Down Expand Up @@ -67,6 +68,9 @@ case class GcpCostLookupRequest(vmInfo: InstantiatedVmInfo, replyTo: ActorRef) e
case class GcpCostLookupResponse(vmInfo: InstantiatedVmInfo, calculatedCost: ErrorOr[BigDecimal])
case class CostCatalogValue(catalogObject: Sku)
case class ExpiringGcpCostCatalog(catalog: Map[CostCatalogKey, CostCatalogValue], fetchTime: Instant)
object ExpiringGcpCostCatalog {
def empty: ExpiringGcpCostCatalog = ExpiringGcpCostCatalog(Map.empty, Instant.MIN)
}

object GcpCostCatalogService {
// Can be gleaned by using googleClient.listServices
Expand Down Expand Up @@ -122,37 +126,38 @@ class GcpCostCatalogService(serviceConfig: Config, globalConfig: Config, service
extends Actor
with LazyLogging {

private val maxCatalogLifetime: Duration =
Duration.of(CostCatalogConfig(serviceConfig).catalogExpirySeconds.longValue, SECONDS)
private val costCatalogConfig = CostCatalogConfig(serviceConfig)

private var googleClient: Option[CloudCatalogClient] = Option.empty
private val maxCatalogLifetime: Duration =
Duration.of(costCatalogConfig.catalogExpirySeconds.longValue, SECONDS)

// Cached catalog. Refreshed lazily when older than maxCatalogLifetime.
private var costCatalog: Option[ExpiringGcpCostCatalog] = Option.empty
private var costCatalog: ExpiringGcpCostCatalog = ExpiringGcpCostCatalog.empty

/**
* Returns the SKU for a given key, if it exists
*/
def getSku(key: CostCatalogKey): Option[CostCatalogValue] = getOrFetchCachedCatalog().get(key)

protected def fetchNewCatalog: Iterable[Sku] = {
if (googleClient.isEmpty) {
// We use option rather than lazy here so that the client isn't created when it is told to shutdown (see receive override)
googleClient = Some(CloudCatalogClient.create)
protected def fetchSkuIterable(googleClient: CloudCatalogClient): Iterable[Sku] =
makeInitialWebRequest(googleClient).iterateAll().asScala

Check warning on line 143 in services/src/main/scala/cromwell/services/cost/GcpCostCatalogService.scala

View check run for this annotation

Codecov / codecov/patch

services/src/main/scala/cromwell/services/cost/GcpCostCatalogService.scala#L143

Added line #L143 was not covered by tests

private def fetchNewCatalog: ExpiringGcpCostCatalog =
Using.resource(CloudCatalogClient.create) { googleClient =>
val skus = fetchSkuIterable(googleClient)
ExpiringGcpCostCatalog(processCostCatalog(skus), Instant.now())

Check warning on line 148 in services/src/main/scala/cromwell/services/cost/GcpCostCatalogService.scala

View check run for this annotation

Codecov / codecov/patch

services/src/main/scala/cromwell/services/cost/GcpCostCatalogService.scala#L146-L148

Added lines #L146 - L148 were not covered by tests
}
makeInitialWebRequest(googleClient.get).iterateAll().asScala
}

def getCatalogAge: Duration =
Duration.between(costCatalog.map(c => c.fetchTime).getOrElse(Instant.ofEpochMilli(0)), Instant.now())
private def isCurrentCatalogExpired: Boolean = getCatalogAge.toNanos > maxCatalogLifetime.toNanos
def getCatalogAge: Duration = Duration.between(costCatalog.fetchTime, Instant.now())

Check warning on line 151 in services/src/main/scala/cromwell/services/cost/GcpCostCatalogService.scala

View check run for this annotation

Codecov / codecov/patch

services/src/main/scala/cromwell/services/cost/GcpCostCatalogService.scala#L151

Added line #L151 was not covered by tests

private def isCurrentCatalogExpired: Boolean = getCatalogAge.toSeconds > maxCatalogLifetime.toSeconds

Check warning on line 153 in services/src/main/scala/cromwell/services/cost/GcpCostCatalogService.scala

View check run for this annotation

Codecov / codecov/patch

services/src/main/scala/cromwell/services/cost/GcpCostCatalogService.scala#L153

Added line #L153 was not covered by tests

private def getOrFetchCachedCatalog(): Map[CostCatalogKey, CostCatalogValue] = {
if (costCatalog.isEmpty || isCurrentCatalogExpired) {
if (isCurrentCatalogExpired) {

Check warning on line 156 in services/src/main/scala/cromwell/services/cost/GcpCostCatalogService.scala

View check run for this annotation

Codecov / codecov/patch

services/src/main/scala/cromwell/services/cost/GcpCostCatalogService.scala#L156

Added line #L156 was not covered by tests
logger.info("Fetching a new GCP public cost catalog.")
costCatalog = Some(ExpiringGcpCostCatalog(processCostCatalog(fetchNewCatalog), Instant.now()))
costCatalog = fetchNewCatalog

Check warning on line 158 in services/src/main/scala/cromwell/services/cost/GcpCostCatalogService.scala

View check run for this annotation

Codecov / codecov/patch

services/src/main/scala/cromwell/services/cost/GcpCostCatalogService.scala#L158

Added line #L158 was not covered by tests
}
costCatalog.map(expiringCatalog => expiringCatalog.catalog).getOrElse(Map.empty)
costCatalog.catalog

Check warning on line 160 in services/src/main/scala/cromwell/services/cost/GcpCostCatalogService.scala

View check run for this annotation

Codecov / codecov/patch

services/src/main/scala/cromwell/services/cost/GcpCostCatalogService.scala#L160

Added line #L160 was not covered by tests
}

/**
Expand Down Expand Up @@ -225,12 +230,12 @@ class GcpCostCatalogService(serviceConfig: Config, globalConfig: Config, service

def serviceRegistryActor: ActorRef = serviceRegistry
override def receive: Receive = {
case GcpCostLookupRequest(vmInfo, replyTo) =>
case GcpCostLookupRequest(vmInfo, replyTo) if costCatalogConfig.enabled =>
val calculatedCost = calculateVmCostPerHour(vmInfo)
val response = GcpCostLookupResponse(vmInfo, calculatedCost)
replyTo ! response

Check warning on line 236 in services/src/main/scala/cromwell/services/cost/GcpCostCatalogService.scala

View check run for this annotation

Codecov / codecov/patch

services/src/main/scala/cromwell/services/cost/GcpCostCatalogService.scala#L234-L236

Added lines #L234 - L236 were not covered by tests
case GcpCostLookupRequest(_, _) => // do nothing if we're disabled
case ShutdownCommand =>
googleClient.foreach(client => client.shutdownNow())
context stop self
case other =>
logger.error(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package cromwell.services.cost

import akka.actor.ActorRef
import akka.testkit.{ImplicitSender, TestActorRef, TestProbe}
import com.google.cloud.billing.v1.Sku
import com.google.cloud.billing.v1.{CloudCatalogClient, Sku}
import com.typesafe.config.{Config, ConfigFactory}
import cromwell.core.TestKitSuite
import cromwell.util.GracefulShutdownHelper.ShutdownCommand
Expand All @@ -14,10 +14,13 @@ import org.scalatest.prop.TableDrivenPropertyChecks
import java.time.Duration
import java.io.{File, FileInputStream, FileOutputStream}
import scala.collection.mutable.ListBuffer
import scala.util.Using

object GcpCostCatalogServiceSpec {
val catalogExpirySeconds: Long = 1 // Short duration so we can do a cache expiry test
val config: Config = ConfigFactory.parseString(s"catalogExpirySeconds = $catalogExpirySeconds")
val config: Config = ConfigFactory.parseString(
s"catalogExpirySeconds = $catalogExpirySeconds, enabled = true"
)
val mockTestDataFilePath: String = "services/src/test/scala/cromwell/services/cost/serializedSkuList.testData"
}
class GcpCostCatalogServiceTestActor(serviceConfig: Config, globalConfig: Config, serviceRegistry: ActorRef)
Expand All @@ -35,18 +38,21 @@ class GcpCostCatalogServiceTestActor(serviceConfig: Config, globalConfig: Config
fis.close()
skuList.toSeq
}
def saveMockData(): Unit = {
val fetchedData = super.fetchNewCatalog
val fos = new FileOutputStream(new File(GcpCostCatalogServiceSpec.mockTestDataFilePath))
fetchedData.foreach { sku =>
sku.writeDelimitedTo(fos)
def saveMockData(): Unit =
Using.resources(CloudCatalogClient.create,
new FileOutputStream(new File(GcpCostCatalogServiceSpec.mockTestDataFilePath))
) { (googleClient, fos) =>
val skus = super.fetchSkuIterable(googleClient)
skus.foreach { sku =>
sku.writeDelimitedTo(fos)
}
}
fos.close()
}

override def receive: Receive = { case ShutdownCommand =>
context stop self
}
override def fetchNewCatalog: Iterable[Sku] = loadMockData

override def fetchSkuIterable(client: CloudCatalogClient): Iterable[Sku] = loadMockData
}

class GcpCostCatalogServiceSpec
Expand Down Expand Up @@ -90,7 +96,7 @@ class GcpCostCatalogServiceSpec
freshActor.getCatalogAge.toNanos should (be < shortDuration.toNanos)

// Simulate the cached catalog living longer than its lifetime
Thread.sleep(shortDuration.toMillis)
Thread.sleep(shortDuration.plus(shortDuration).toMillis)

// Confirm that the catalog is old
freshActor.getCatalogAge.toNanos should (be > shortDuration.toNanos)
Expand Down

0 comments on commit de345c2

Please sign in to comment.