diff --git a/sdk/src/main/java/com/apphud/sdk/ApphudInternal+Products.kt b/sdk/src/main/java/com/apphud/sdk/ApphudInternal+Products.kt index 96301169..bd890dd0 100644 --- a/sdk/src/main/java/com/apphud/sdk/ApphudInternal+Products.kt +++ b/sdk/src/main/java/com/apphud/sdk/ApphudInternal+Products.kt @@ -6,31 +6,16 @@ import com.android.billingclient.api.ProductDetails import com.apphud.sdk.domain.ApphudGroup import com.apphud.sdk.domain.ApphudPaywall import com.apphud.sdk.domain.ApphudPlacement +import com.apphud.sdk.internal.data.ProductLoadingState import kotlinx.coroutines.async import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.delay import kotlinx.coroutines.launch -import java.util.concurrent.CopyOnWriteArrayList -internal var productsStatus = ApphudProductsStatus.none -internal var respondedWithProducts = false -private var loadingStoreProducts = false -internal var productsResponseCode = BillingClient.BillingResponseCode.OK -private var loadedDetails = CopyOnWriteArrayList() - -// to avoid Google servers spamming if there is no productDetails added at all -internal var currentPoductsLoadingCounts: Int = 0 -internal var totalPoductsLoadingCounts: Int = 0 const val MAX_TOTAL_PRODUCTS_RETRIES: Int = 100 -internal enum class ApphudProductsStatus { - none, - loading, - loaded, - failed -} - internal fun ApphudInternal.finishedLoadingProducts(): Boolean { - return productsStatus == ApphudProductsStatus.loaded || productsStatus == ApphudProductsStatus.failed + return productRepository.state.value.isFinished } internal fun ApphudInternal.shouldLoadProducts(): Boolean { @@ -39,35 +24,36 @@ internal fun ApphudInternal.shouldLoadProducts(): Boolean { return false } - return when (productsStatus) { - ApphudProductsStatus.none -> true - ApphudProductsStatus.loading -> false - else -> { - productDetails.isEmpty() && totalPoductsLoadingCounts < MAX_TOTAL_PRODUCTS_RETRIES + return when (val state = productRepository.state.value) { + is ProductLoadingState.Idle -> true + is ProductLoadingState.Loading -> false + is ProductLoadingState.Success -> false + is ProductLoadingState.Failed -> { + state.cachedProducts.isEmpty() && state.totalRetryCount < MAX_TOTAL_PRODUCTS_RETRIES } } } internal fun ApphudInternal.loadProducts() { if (!shouldLoadProducts()) { - if (totalPoductsLoadingCounts >= MAX_TOTAL_PRODUCTS_RETRIES) { + val state = productRepository.state.value + if (state is ProductLoadingState.Failed && state.totalRetryCount >= MAX_TOTAL_PRODUCTS_RETRIES) { respondWithProducts() } return } - productsStatus = ApphudProductsStatus.loading + productRepository.transitionToLoading() ApphudLog.logI("Loading ProductDetails from the Store") coroutineScope.launch(errorHandler) { - fetchProducts() + val responseCode = fetchProducts() - if (productsResponseCode != APPHUD_NO_REQUEST) { - totalPoductsLoadingCounts += 1 - currentPoductsLoadingCounts += 1 + if (responseCode == APPHUD_NO_REQUEST) { + productRepository.rollbackRetryCounters() } - if (isRetriableProductsRequest() && shouldRetryRequest("billing") && currentPoductsLoadingCounts < APPHUD_DEFAULT_RETRIES) { + if (isRetriableProductsRequest() && shouldRetryRequest("billing")) { retryProductsLoad() } else { ApphudLog.log("Finished Loading Product Details") @@ -77,42 +63,33 @@ internal fun ApphudInternal.loadProducts() { } internal fun respondWithProducts() { - respondedWithProducts = true + ApphudInternal.productRepository.markAsResponded() ApphudInternal.mainScope.launch { - ApphudInternal.notifyLoadingCompleted(null, loadedDetails.toList(), false, false) + ApphudInternal.notifyLoadingCompleted(null, ApphudInternal.productRepository.state.value.products, false, false) } } internal fun isRetriableProductsRequest(): Boolean { - return ApphudInternal.productDetails.isEmpty() && productsStatus == ApphudProductsStatus.failed && isRetriableErrorCode( - productsResponseCode - ) && ApphudInternal.isActive && !ApphudUtils.isEmulator() + val state = ApphudInternal.productRepository.state.value + return state is ProductLoadingState.Failed && + state.isRetriable && + ApphudInternal.isActive && + !ApphudUtils.isEmulator() } -internal fun retryProductsLoad() { - val delay: Long = 300 +internal suspend fun retryProductsLoad() { + val delayMs: Long = 300 + val state = ApphudInternal.productRepository.state.value + val responseCode = if (state is ProductLoadingState.Failed) state.responseCode else BillingResponseCode.OK ApphudLog.logI( "Load products from store status code: (${ - ApphudBillingResponseCodes.getName( - productsResponseCode - ) - }), will retry in $delay ms" + ApphudBillingResponseCodes.getName(responseCode) + }), will retry in $delayMs ms" ) - Thread.sleep(delay) + delay(delayMs) ApphudInternal.loadProducts() } -private fun isRetriableErrorCode(code: Int): Boolean { - return listOf( - BillingClient.BillingResponseCode.NETWORK_ERROR, - BillingClient.BillingResponseCode.SERVICE_TIMEOUT, - BillingClient.BillingResponseCode.SERVICE_DISCONNECTED, - BillingClient.BillingResponseCode.SERVICE_UNAVAILABLE, - BillingClient.BillingResponseCode.BILLING_UNAVAILABLE, - BillingClient.BillingResponseCode.ERROR - ).contains(code) -} - internal suspend fun ApphudInternal.fetchProducts(): Int { val user = userRepository.getCurrentUser() val userPaywalls = user?.paywalls.orEmpty() @@ -146,12 +123,12 @@ private fun allAvailableProductIds( placements.map { pl -> pl.paywall?.products?.map { it.productId } ?: listOf() }.flatten().toMutableList() idsGroups.forEach { - if (!ids.contains(it) && it != null) { + if (!ids.contains(it)) { ids.add(it) } } idsFromPlacements.forEach { - if (!ids.contains(it) && it != null) { + if (!ids.contains(it)) { ids.add(it) } } @@ -163,34 +140,26 @@ internal suspend fun ApphudInternal.fetchDetails( ids: List, loadingAll: Boolean = false, ): Pair?> { - if (loadingAll) { - loadedDetails.clear() - } - // Assuming ProductDetails has a property 'id' that corresponds to the product ID - val existingIds = productDetails.map { it.productId } + val tempLoadedDetails = mutableListOf() + + val existingIds = productRepository.state.value.products.map { it.productId } val idsToFetch = ids.filterNot { existingIds.contains(it) } if (existingIds.isNotEmpty() && idsToFetch.isEmpty()) { - // All Ids already loaded, return OK - if (loadingAll) { - productsStatus = ApphudProductsStatus.loaded - } + // All requested IDs already loaded in state + // Don't call transitionToSuccess - it would replace all products with just the requested subset! return Pair(BillingResponseCode.OK, null) } else if (idsToFetch.isEmpty()) { // If none ids to load, return immediately + // This happens when user/paywall has no products configured ApphudLog.log("NO REQUEST TO FETCH PRODUCT DETAILS") - if (loadingAll) { - productsStatus = ApphudProductsStatus.loaded - } + // Don't transition to Success (requires at least one product) + // Leave state as-is return Pair(APPHUD_NO_REQUEST, null) } ApphudLog.log("Fetching Product Details: ${idsToFetch.toString()}") - loadingStoreProducts = true - if (productsStatus != ApphudProductsStatus.loading && loadingAll) { - productsStatus = ApphudProductsStatus.loading - } val startTime = System.currentTimeMillis() @@ -201,11 +170,9 @@ internal suspend fun ApphudInternal.fetchDetails( val inAppResult = async { billing.detailsEx(BillingClient.ProductType.INAPP, idsToFetch) }.await() subsResult.first?.let { subsDetails -> - // Add new subscription details if they're not already present - // CopyOnWriteArrayList is thread-safe, no synchronization needed subsDetails.forEach { detail -> - if (!loadedDetails.any { it.productId == detail.productId }) { - loadedDetails.add(detail) + if (!tempLoadedDetails.any { it.productId == detail.productId }) { + tempLoadedDetails.add(detail) } } } ?: run { @@ -215,11 +182,9 @@ internal suspend fun ApphudInternal.fetchDetails( } inAppResult.first?.let { inAppDetails -> - // Add new in-app product details if they're not already present - // CopyOnWriteArrayList is thread-safe, no synchronization needed inAppDetails.forEach { detail -> - if (!loadedDetails.any { it.productId == detail.productId }) { - loadedDetails.add(detail) + if (!tempLoadedDetails.any { it.productId == detail.productId }) { + tempLoadedDetails.add(detail) } } } ?: run { @@ -230,14 +195,24 @@ internal suspend fun ApphudInternal.fetchDetails( } val benchmark = System.currentTimeMillis() - startTime - loadingStoreProducts = false ApphudInternal.productsLoadedTime = benchmark if (loadingAll) { - productsResponseCode = responseCode - productsStatus = if (responseCode == BillingClient.BillingResponseCode.OK) ApphudProductsStatus.loaded else - ApphudProductsStatus.failed + if (responseCode == BillingClient.BillingResponseCode.OK) { + if (tempLoadedDetails.isNotEmpty()) { + productRepository.transitionToSuccess(tempLoadedDetails, loadTimeMs = benchmark) + } else { + val currentProducts = productRepository.state.value.products + if (currentProducts.isEmpty()) { + productRepository.transitionToFailed(BillingClient.BillingResponseCode.ITEM_UNAVAILABLE) + } else { + productRepository.transitionToSuccess(currentProducts, loadTimeMs = benchmark) + } + } + } else { + productRepository.transitionToFailed(responseCode) + } } - return Pair(responseCode, loadedDetails) + return Pair(responseCode, tempLoadedDetails) } diff --git a/sdk/src/main/java/com/apphud/sdk/ApphudInternal.kt b/sdk/src/main/java/com/apphud/sdk/ApphudInternal.kt index bcb48b67..60d6451e 100644 --- a/sdk/src/main/java/com/apphud/sdk/ApphudInternal.kt +++ b/sdk/src/main/java/com/apphud/sdk/ApphudInternal.kt @@ -22,6 +22,8 @@ import com.apphud.sdk.domain.PaywallEvent import com.apphud.sdk.domain.PurchaseRecordDetails import com.apphud.sdk.internal.BillingWrapper import com.apphud.sdk.internal.ServiceLocator +import com.apphud.sdk.internal.data.ProductLoadingState +import com.apphud.sdk.internal.data.ProductRepository import com.apphud.sdk.internal.domain.model.ApiKey as ApiKeyModel import com.apphud.sdk.internal.presentation.figma.FigmaWebViewActivity import com.apphud.sdk.internal.util.runCatchingCancellable @@ -69,10 +71,13 @@ internal object ApphudInternal { get() = ServiceLocator.instance.billingWrapper internal val userRepository get() = ServiceLocator.instance.userRepository + internal val productRepository + get() = ServiceLocator.instance.productRepository internal val storage: SharedPreferencesStorage get() = ServiceLocator.instance.storage internal val prevPurchases = CopyOnWriteArraySet() - internal val productDetails = CopyOnWriteArrayList() + internal val productDetails: List + get() = productRepository.state.value.products internal var isRegisteringUser = false @@ -347,7 +352,7 @@ internal object ApphudInternal { forceRegistration() }.onSuccess { if (wasDeferred) { - productsStatus = ApphudProductsStatus.none + productRepository.reset() } loadProducts() }.getOrNull() @@ -361,8 +366,6 @@ internal object ApphudInternal { fromFallback: Boolean = false, customerError: ApphudError? = null, ) { - var paywallsPrepared = true - customerError?.let { ApphudLog.logE("Customer Registration Error: ${it}") latestCustomerLoadError = it @@ -393,14 +396,6 @@ internal object ApphudInternal { customerLoaded?.let { updateUserState(it, fromFallback) - if (!fromCache && !fromFallback && it.paywalls.isEmpty()) { - /* Attention: - * If customer loaded without paywalls, do not reload paywalls from cache! - * If cache time is over, paywall from cache will be NULL - */ - paywallsPrepared = false - } - // TODO: should be called only if something changed coroutineScope.launch { delay(500) @@ -457,10 +452,10 @@ internal object ApphudInternal { private fun hasDataLoadFailed(customerError: ApphudError?) = (customerError != null && (userRepository.getCurrentUser()?.paywalls?.isEmpty() != false)) || isProductsLoadFailed() - private fun isProductsLoadFailed() = - productsStatus != ApphudProductsStatus.loading && - productsResponseCode != BillingClient.BillingResponseCode.OK && - productDetails.isEmpty() + private fun isProductsLoadFailed(): Boolean { + val state = productRepository.state.value + return state is ProductLoadingState.Failed && state.cachedProducts.isEmpty() + } private fun handleSuccessfulLoad() { val user = userRepository.getCurrentUser() @@ -487,10 +482,12 @@ internal object ApphudInternal { } private fun handleError(customerError: ApphudError?) { - val error = latestCustomerLoadError ?: customerError ?: if (productsResponseCode == APPHUD_NO_REQUEST) { - ApphudError("Paywalls load error", errorCode = productsResponseCode) + val state = productRepository.state.value + val responseCode = if (state is ProductLoadingState.Failed) state.responseCode else BillingClient.BillingResponseCode.OK + val error = latestCustomerLoadError ?: customerError ?: if (responseCode == APPHUD_NO_REQUEST) { + ApphudError("Paywalls load error", errorCode = responseCode) } else { - ApphudError("Google Billing error", errorCode = productsResponseCode) + ApphudError("Google Billing error", errorCode = responseCode) } if (offeringsPreparedCallbacks.isNotEmpty()) { @@ -509,7 +506,9 @@ internal object ApphudInternal { private fun logNotReadyState() { val user = userRepository.getCurrentUser() - ApphudLog.log("Not yet ready for callbacks invoke: isRegisteringUser: $isRegisteringUser, currentUserExist: ${user != null}, latestCustomerError: $latestCustomerLoadError, paywallsEmpty: ${user?.paywalls?.isEmpty() != false}, productsResponseCode = $productsResponseCode, productsStatus: $productsStatus, productDetailsEmpty: ${productDetails.isEmpty()}, deferred: $deferPlacements, hasRespondedToPaywallsRequest=$hasRespondedToPaywallsRequest") + val productsState = productRepository.state.value + val productsResponseCode = if (productsState is ProductLoadingState.Failed) productsState.responseCode else BillingClient.BillingResponseCode.OK + ApphudLog.log("Not yet ready for callbacks invoke: isRegisteringUser: $isRegisteringUser, currentUserExist: ${user != null}, latestCustomerError: $latestCustomerLoadError, paywallsEmpty: ${user?.paywalls?.isEmpty() != false}, productsResponseCode = $productsResponseCode, productsStatus: $productsState, productDetailsEmpty: ${productDetails.isEmpty()}, deferred: $deferPlacements, hasRespondedToPaywallsRequest=$hasRespondedToPaywallsRequest") } private fun trackAnalytics(success: Boolean) { @@ -521,7 +520,9 @@ internal object ApphudInternal { val totalLoad = (System.currentTimeMillis() - sdkLaunchedAt) val userLoad = if (firstCustomerLoadedTime != null) (firstCustomerLoadedTime!! - sdkLaunchedAt) else 0 val productsLoaded = productsLoadedTime ?: 0 - ApphudLog.logI("SDK Benchmarks: User ${userLoad}ms, Products: ${productsLoaded}ms, Total: ${totalLoad}ms, Apphud Error: ${latestCustomerLoadError?.message}, Billing Response Code: ${productsResponseCode}, ErrorCode: ${latestCustomerLoadError?.errorCode}") + val state = productRepository.state.value + val responseCode = if (state is ProductLoadingState.Failed) state.responseCode else BillingClient.BillingResponseCode.OK + ApphudLog.logI("SDK Benchmarks: User ${userLoad}ms, Products: ${productsLoaded}ms, Total: ${totalLoad}ms, Apphud Error: ${latestCustomerLoadError?.message}, Billing Response Code: ${responseCode}, ErrorCode: ${latestCustomerLoadError?.errorCode}") coroutineScope.launch { RequestManager.sendPaywallLogs( sdkLaunchedAt, @@ -530,7 +531,7 @@ internal object ApphudInternal { productsLoaded.toDouble(), totalLoad.toDouble(), latestCustomerLoadError, - productsResponseCode, + responseCode, success ) } @@ -707,7 +708,7 @@ internal object ApphudInternal { preferredTimeout?.let { this.preferredTimeout = max(it, APPHUD_DEFAULT_MAX_TIMEOUT) this.offeringsCalledAt = System.currentTimeMillis() - currentPoductsLoadingCounts = 0 + // Retry counts are now managed by state transitions } mainScope.launch { @@ -1293,10 +1294,15 @@ internal object ApphudInternal { ApphudLog.log("ServiceLocator not initialized, skip userRepository.clearUser(): ${e.message}") } + // Reset products state to Idle + runCatching { + productRepository.reset() + }.onFailure { e -> + ApphudLog.log("ServiceLocator not initialized, skip productRepository.reset(): ${e.message}") + } + ServiceLocator.clearInstance() RequestManager.cleanRegistration() - productsStatus = ApphudProductsStatus.none - productsResponseCode = BillingClient.BillingResponseCode.OK customProductsFetchedBlock = null offeringsPreparedCallbacks.clear() purchaseCallbacks.clear() @@ -1307,7 +1313,6 @@ internal object ApphudInternal { ApphudLog.log("SDK not initialized, skip storage.clean()") } prevPurchases.clear() - productDetails.clear() productGroups.set(emptyList()) pendingUserProperties.clear() allowIdentifyUser = true @@ -1393,13 +1398,6 @@ internal object ApphudInternal { } private fun updateProductState(productsLoaded: List) { - synchronized(productDetails) { - productsLoaded.forEach { detail -> - if (!productDetails.map { it.productId }.contains(detail.productId)) { - productDetails.add(detail) - } - } - } val cachedProductGroups = readGroupsFromCache() productGroups.set(cachedProductGroups.toList()) updateGroupsWithProductDetails(productGroups.get()) diff --git a/sdk/src/main/java/com/apphud/sdk/internal/ServiceLocator.kt b/sdk/src/main/java/com/apphud/sdk/internal/ServiceLocator.kt index 41d275ec..33072528 100644 --- a/sdk/src/main/java/com/apphud/sdk/internal/ServiceLocator.kt +++ b/sdk/src/main/java/com/apphud/sdk/internal/ServiceLocator.kt @@ -6,6 +6,7 @@ import com.apphud.sdk.ApphudRuleCallback import com.apphud.sdk.internal.data.local.LifecycleRepository import com.apphud.sdk.internal.data.local.LocalRulesScreenRepository import com.apphud.sdk.internal.data.local.PaywallRepository +import com.apphud.sdk.internal.data.ProductRepository import com.apphud.sdk.internal.data.mapper.CustomerMapper import com.apphud.sdk.internal.data.mapper.PaywallsMapper import com.apphud.sdk.internal.data.mapper.PlacementsMapper @@ -16,6 +17,8 @@ import com.apphud.sdk.internal.data.mapper.SubscriptionMapper import com.apphud.sdk.internal.data.network.HeadersInterceptor import com.apphud.sdk.internal.data.network.HostSwitcherInterceptor import com.apphud.sdk.internal.data.network.HttpRetryInterceptor +import com.apphud.sdk.internal.data.network.PrettyHttpLoggingInterceptor +import com.apphud.sdk.internal.data.network.PrettyJsonFormatter import com.apphud.sdk.internal.data.network.TimeoutInterceptor import com.apphud.sdk.internal.data.network.UrlProvider import com.apphud.sdk.internal.data.remote.PurchaseBodyFactory @@ -41,6 +44,7 @@ import com.apphud.sdk.managers.RequestManager import com.apphud.sdk.mappers.AttributionMapper import com.apphud.sdk.storage.SharedPreferencesStorage import com.google.gson.Gson +import com.google.gson.GsonBuilder import okhttp3.OkHttpClient import okhttp3.logging.HttpLoggingInterceptor @@ -68,6 +72,14 @@ internal class ServiceLocator( private val hostSwitcherInterceptor = HostSwitcherInterceptor(OkHttpClient(), urlProvider) private val hostSwitcherInterceptorWithoutHeaders = HostSwitcherInterceptor(OkHttpClient(), urlProvider) + private val prettyGson: Gson by lazy { + GsonBuilder().setPrettyPrinting().create() + } + private val prettyJsonFormatter: PrettyJsonFormatter by lazy { PrettyJsonFormatter(prettyGson) } + private val prettyLoggingInterceptor: PrettyHttpLoggingInterceptor by lazy { + PrettyHttpLoggingInterceptor(prettyJsonFormatter) + } + private val okHttpClient: OkHttpClient = OkHttpClient.Builder() .addInterceptor( @@ -84,6 +96,7 @@ internal class ServiceLocator( .addInterceptor(TimeoutInterceptor()) .addInterceptor(hostSwitcherInterceptor) .addInterceptor(HttpRetryInterceptor()) + .addInterceptor(prettyLoggingInterceptor) .build() private val okHttpClientWithoutHeaders: OkHttpClient = @@ -101,6 +114,7 @@ internal class ServiceLocator( .addInterceptor(TimeoutInterceptor()) .addInterceptor(hostSwitcherInterceptorWithoutHeaders) .addInterceptor(HttpRetryInterceptor()) + .addInterceptor(prettyLoggingInterceptor) .build() val remoteRepository: RemoteRepository = @@ -188,6 +202,8 @@ internal class ServiceLocator( val userRepository: UserRepository = UserRepository(userDataSource) + val productRepository: ProductRepository = ProductRepository() + val registrationUseCase: RegistrationUseCase = RegistrationUseCase( userRepository = userRepository, diff --git a/sdk/src/main/java/com/apphud/sdk/internal/data/ProductLoadingState.kt b/sdk/src/main/java/com/apphud/sdk/internal/data/ProductLoadingState.kt new file mode 100644 index 00000000..2a2cfdde --- /dev/null +++ b/sdk/src/main/java/com/apphud/sdk/internal/data/ProductLoadingState.kt @@ -0,0 +1,135 @@ +package com.apphud.sdk.internal.data + +import com.android.billingclient.api.BillingClient +import com.android.billingclient.api.ProductDetails +import com.apphud.sdk.APPHUD_DEFAULT_RETRIES +import com.apphud.sdk.MAX_TOTAL_PRODUCTS_RETRIES + +/** + * Represents all possible states of product loading from the billing library. + * This sealed class ensures type-safe state management and prevents invalid state combinations. + * + * State transitions: + * - Idle → Loading + * - Loading → Success | Failed + * - Failed → Loading (retry) + * - Success (terminal state) + */ +sealed class ProductLoadingState { + + /** + * Initial state - no loading attempt yet. + * The repository is initialized but hasn't started loading products. + */ + object Idle : ProductLoadingState() + + /** + * Currently loading products from the Google Play Billing Library. + * + * @param currentRetryCount Current attempt number for this session (0 for first try) + * @param totalRetryCount Total attempts across all app sessions (persisted) + * @param previousProducts Previously loaded products (for retry scenarios) + */ + data class Loading( + val currentRetryCount: Int = 0, + val totalRetryCount: Int = 0, + val previousProducts: List = emptyList() + ) : ProductLoadingState() + + /** + * Successfully loaded products from the billing library. + * + * @param loadedProducts List of loaded ProductDetails (must be non-empty) + * @param loadTimeMs Time taken to load in milliseconds (for analytics) + * @param respondedWithCallback Whether callbacks/listeners have been notified + */ + data class Success( + val loadedProducts: List, + val loadTimeMs: Long? = null, + val respondedWithCallback: Boolean = false + ) : ProductLoadingState() { + init { + require(loadedProducts.isNotEmpty()) { "Success state must have at least one product" } + } + } + + /** + * Failed to load products from the billing library. + * + * @param responseCode Billing response code indicating the error type + * @param cachedProducts Previously loaded products from cache (may be empty) + * @param currentRetryCount Current retry attempt for this session + * @param totalRetryCount Total attempts across all sessions + * @param respondedWithCallback Whether callbacks/listeners have been notified + */ + data class Failed( + val responseCode: Int, + val cachedProducts: List = emptyList(), + val currentRetryCount: Int = 0, + val totalRetryCount: Int = 0, + val respondedWithCallback: Boolean = false + ) : ProductLoadingState() { + + /** + * Determines if this failure can be retried based on: + * - Error code is retriable (network/service errors) + * - No cached products available (if we have products, no need to retry) + * - Haven't exceeded retry limits (per-session and total) + */ + val isRetriable: Boolean + get() = isRetriableErrorCode(responseCode) && + cachedProducts.isEmpty() && + currentRetryCount < APPHUD_DEFAULT_RETRIES && + totalRetryCount < MAX_TOTAL_PRODUCTS_RETRIES + + private fun isRetriableErrorCode(code: Int): Boolean { + return listOf( + BillingClient.BillingResponseCode.NETWORK_ERROR, + BillingClient.BillingResponseCode.SERVICE_TIMEOUT, + BillingClient.BillingResponseCode.SERVICE_DISCONNECTED, + BillingClient.BillingResponseCode.SERVICE_UNAVAILABLE, + BillingClient.BillingResponseCode.BILLING_UNAVAILABLE, + BillingClient.BillingResponseCode.ERROR + ).contains(code) + } + } + + /** + * Whether the loading process has finished (either successfully or with failure). + * Loading and Idle states return false. + */ + val isFinished: Boolean + get() = this is Success || this is Failed + + /** + * Get the product list from any state: + * - Success: returns loaded products + * - Failed: returns cached products (may be empty) + * - Loading: returns previous products (for retry scenarios) + * - Idle: returns empty list + */ + val products: List + get() = when (this) { + is Success -> loadedProducts + is Failed -> cachedProducts + is Loading -> previousProducts + else -> emptyList() + } + + /** + * Whether callbacks/listeners have been notified for this state. + * Only Success and Failed states track this. + */ + val hasRespondedWithCallback: Boolean + get() = when (this) { + is Success -> respondedWithCallback + is Failed -> respondedWithCallback + else -> false + } + + /** + * Whether products are currently being loaded from the billing library. + */ + val isLoading: Boolean + get() = this is Loading +} diff --git a/sdk/src/main/java/com/apphud/sdk/internal/data/ProductRepository.kt b/sdk/src/main/java/com/apphud/sdk/internal/data/ProductRepository.kt new file mode 100644 index 00000000..a6c3bbb2 --- /dev/null +++ b/sdk/src/main/java/com/apphud/sdk/internal/data/ProductRepository.kt @@ -0,0 +1,150 @@ +package com.apphud.sdk.internal.data + +import com.android.billingclient.api.ProductDetails +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.asStateFlow +import kotlinx.coroutines.flow.update + +/** + * Repository for managing product loading state + * Thread-safe with StateFlow for reactive state management + * + * Uses ProductLoadingState sealed class for type-safe state management. + */ +internal class ProductRepository { + private val _state = MutableStateFlow(ProductLoadingState.Idle) + + /** + * Observable state flow for reactive state management. + * + * Use this to: + * - Observe state changes: `state.collect { ... }` + * - Get current state: `state.value` + * - Get current products: `state.value.products` + */ + val state: StateFlow = _state.asStateFlow() + + /** + * Transition to Loading state. + * Derives retry counts from current state (MVI pattern). + * Preserves previous products for access during loading. + */ + fun transitionToLoading() { + _state.update { current -> + val previousProducts = current.products + + val (currentRetry, totalRetry) = when (current) { + is ProductLoadingState.Failed -> { + // Increment retry counters when retrying after failure + (current.currentRetryCount + 1) to (current.totalRetryCount + 1) + } + else -> { + // First attempt or loading from other states + 0 to 0 + } + } + + ProductLoadingState.Loading( + currentRetryCount = currentRetry, + totalRetryCount = totalRetry, + previousProducts = previousProducts + ) + } + } + + /** + * Transition to Success state with loaded products. + * Always merges new products with existing ones by productId (incremental loading). + * New products replace old ones with the same productId. + * To replace all products, call reset() first. + * + * @param products The loaded products to add/update + * @param loadTimeMs Optional load time for analytics + */ + fun transitionToSuccess(products: List, loadTimeMs: Long? = null) { + _state.update { current -> + // New products first, so distinctBy keeps new version on productId collision + val mergedProducts = (products + current.products).distinctBy { it.productId } + ProductLoadingState.Success( + loadedProducts = mergedProducts, + loadTimeMs = loadTimeMs, + respondedWithCallback = false + ) + } + } + + /** + * Transition to Failed state with error information. + * Derives cached products and retry counts from current state (MVI pattern). + * @param responseCode The billing error code + */ + fun transitionToFailed(responseCode: Int) { + _state.update { current -> + val cachedProducts = current.products + + val (currentRetry, totalRetry) = when (current) { + is ProductLoadingState.Loading -> { + // Keep retry counts from loading state + current.currentRetryCount to current.totalRetryCount + } + is ProductLoadingState.Failed -> { + // Keep retry counts from previous failed state + current.currentRetryCount to current.totalRetryCount + } + else -> { + // No retry context + 0 to 0 + } + } + + ProductLoadingState.Failed( + responseCode = responseCode, + cachedProducts = cachedProducts, + currentRetryCount = currentRetry, + totalRetryCount = totalRetry, + respondedWithCallback = false + ) + } + } + + /** + * Mark the current state as having responded to callbacks. + * Only applicable to Success and Failed states. + */ + fun markAsResponded() { + _state.update { current -> + when (current) { + is ProductLoadingState.Success -> current.copy(respondedWithCallback = true) + is ProductLoadingState.Failed -> current.copy(respondedWithCallback = true) + else -> current + } + } + } + + /** + * Rollback retry counters in Loading state. + * Used when a request didn't actually happen (e.g., APPHUD_NO_REQUEST) + * but transitionToLoading() was already called and incremented counters. + */ + fun rollbackRetryCounters() { + _state.update { current -> + when (current) { + is ProductLoadingState.Loading -> { + current.copy( + currentRetryCount = maxOf(0, current.currentRetryCount - 1), + totalRetryCount = maxOf(0, current.totalRetryCount - 1) + ) + } + else -> current + } + } + } + + /** + * Reset state to Idle. + */ + fun reset() { + _state.update { ProductLoadingState.Idle } + } +} diff --git a/sdk/src/main/java/com/apphud/sdk/internal/data/UserRepository.kt b/sdk/src/main/java/com/apphud/sdk/internal/data/UserRepository.kt index d5fb472c..1f915acb 100644 --- a/sdk/src/main/java/com/apphud/sdk/internal/data/UserRepository.kt +++ b/sdk/src/main/java/com/apphud/sdk/internal/data/UserRepository.kt @@ -22,10 +22,24 @@ internal class UserRepository( @Synchronized fun setCurrentUser(user: ApphudUser): Boolean { val userIdChanged = currentUser?.userId != user.userId - currentUser = user - if (user.isTemporary != true) { - dataSource.saveUser(user) + // Preserve paywalls/placements if server returned empty ones. + // This happens when /subscriptions endpoint is called (purchase verification) + // which doesn't return paywalls, only subscription data. + val existing = currentUser + val mergedUser = if (user.paywalls.isEmpty() && existing?.paywalls?.isNotEmpty() == true) { + user.copy( + paywalls = existing.paywalls, + placements = existing.placements + ) + } else { + user + } + + currentUser = mergedUser + + if (mergedUser.isTemporary != true) { + dataSource.saveUser(mergedUser) } return userIdChanged diff --git a/sdk/src/main/java/com/apphud/sdk/internal/data/network/PrettyHttpLoggingInterceptor.kt b/sdk/src/main/java/com/apphud/sdk/internal/data/network/PrettyHttpLoggingInterceptor.kt new file mode 100644 index 00000000..1d274774 --- /dev/null +++ b/sdk/src/main/java/com/apphud/sdk/internal/data/network/PrettyHttpLoggingInterceptor.kt @@ -0,0 +1,86 @@ +package com.apphud.sdk.internal.data.network + +import com.apphud.sdk.ApphudLog +import com.apphud.sdk.ApphudUtils +import okhttp3.Interceptor +import okhttp3.Response +import okio.Buffer +import java.nio.charset.StandardCharsets +import java.util.UUID + +/** + * HTTP logging interceptor that logs requests and responses with pretty-printed JSON bodies. + * + * Each request-response pair is tagged with a unique 8-character traceId (e.g., `[a1b2c3d4]`) + * to easily correlate requests with their corresponding responses in logs. + * + * Example output: + * ``` + * [a1b2c3d4] Start POST request https://api.apphud.com/v2/customers with params: + * { + * "device_id": "xxx" + * } + * + * [a1b2c3d4] Finished POST request https://api.apphud.com/v2/customers with response: 200 + * { + * "data": {...} + * } + * ``` + */ +internal class PrettyHttpLoggingInterceptor( + private val prettyJsonFormatter: PrettyJsonFormatter +) : Interceptor { + + override fun intercept(chain: Interceptor.Chain): Response { + val request = chain.request() + + if (!ApphudUtils.httpLogging) { + return chain.proceed(request) + } + + val traceId = UUID.randomUUID().toString().take(8) + logRequestStart(request, traceId) + val response = chain.proceed(request) + return logResponse(request, response, traceId) + } + + private fun logRequestStart(request: okhttp3.Request, traceId: String) { + val method = request.method + val url = request.url + + val requestBody = request.body + val bodyString = if (requestBody != null) { + Buffer().use { buffer -> + requestBody.writeTo(buffer) + buffer.readString(StandardCharsets.UTF_8) + } + } else { + null + } + + val prettyBody = prettyJsonFormatter.format(bodyString) + val bodyPart = if (prettyBody != null) " with params:\n$prettyBody" else "" + + ApphudLog.logI("[$traceId] Start $method request $url$bodyPart") + } + + private fun logResponse(request: okhttp3.Request, response: Response, traceId: String): Response { + val method = request.method + val url = request.url + val code = response.code + + val responseBody = response.body + val source = responseBody?.source() + source?.request(Long.MAX_VALUE) + val bodyString = source?.buffer?.clone()?.use { clonedBuffer -> + clonedBuffer.readString(StandardCharsets.UTF_8) + } + + val prettyBody = prettyJsonFormatter.format(bodyString) + val bodyPart = if (prettyBody != null) "\n$prettyBody" else "" + + ApphudLog.logI("[$traceId] Finished $method request $url with response: $code$bodyPart") + + return response + } +} diff --git a/sdk/src/main/java/com/apphud/sdk/internal/data/network/PrettyJsonFormatter.kt b/sdk/src/main/java/com/apphud/sdk/internal/data/network/PrettyJsonFormatter.kt new file mode 100644 index 00000000..4a1208da --- /dev/null +++ b/sdk/src/main/java/com/apphud/sdk/internal/data/network/PrettyJsonFormatter.kt @@ -0,0 +1,17 @@ +package com.apphud.sdk.internal.data.network + +import com.google.gson.Gson +import com.google.gson.JsonParser + +internal class PrettyJsonFormatter(private val prettyGson: Gson) { + + fun format(json: String?): String? { + if (json.isNullOrBlank()) return null + return try { + val jsonElement = JsonParser.parseString(json) + prettyGson.toJson(jsonElement) + } catch (_: Exception) { + json + } + } +} diff --git a/sdk/src/test/java/com/apphud/sdk/internal/data/ProductLoadingStateTest.kt b/sdk/src/test/java/com/apphud/sdk/internal/data/ProductLoadingStateTest.kt new file mode 100644 index 00000000..a759bb1d --- /dev/null +++ b/sdk/src/test/java/com/apphud/sdk/internal/data/ProductLoadingStateTest.kt @@ -0,0 +1,447 @@ +package com.apphud.sdk.internal.data + +import com.android.billingclient.api.BillingClient +import com.android.billingclient.api.ProductDetails +import io.mockk.every +import io.mockk.mockk +import org.junit.Assert.assertEquals +import org.junit.Assert.assertFalse +import org.junit.Assert.assertTrue +import org.junit.Before +import org.junit.Test + +class ProductLoadingStateTest { + + private lateinit var mockProduct1: ProductDetails + private lateinit var mockProduct2: ProductDetails + + @Before + fun setup() { + mockProduct1 = mockk(relaxed = true) + mockProduct2 = mockk(relaxed = true) + + every { mockProduct1.productId } returns "product-1" + every { mockProduct2.productId } returns "product-2" + } + + // ======================================== + // Idle State Tests + // ======================================== + + @Test + fun `Idle state should have correct default values`() { + val state = ProductLoadingState.Idle + + assertFalse("Should not be finished", state.isFinished) + assertFalse("Should not have responded", state.hasRespondedWithCallback) + assertFalse("Should not be loading", state.isLoading) + assertTrue("Should have empty products", state.products.isEmpty()) + } + + // ======================================== + // Loading State Tests + // ======================================== + + @Test + fun `Loading state should have correct default values`() { + val state = ProductLoadingState.Loading() + + assertTrue("Should be loading", state.isLoading) + assertFalse("Should not be finished", state.isFinished) + assertFalse("Should not have responded", state.hasRespondedWithCallback) + assertTrue("Should have empty products", state.products.isEmpty()) + assertEquals("Should have 0 current retry count", 0, state.currentRetryCount) + assertEquals("Should have 0 total retry count", 0, state.totalRetryCount) + } + + @Test + fun `Loading state should store retry counts`() { + val state = ProductLoadingState.Loading( + currentRetryCount = 2, + totalRetryCount = 5 + ) + + assertEquals("Should have current retry count 2", 2, state.currentRetryCount) + assertEquals("Should have total retry count 5", 5, state.totalRetryCount) + } + + @Test + fun `Loading state should store previous products`() { + val previousProducts = listOf(mockProduct1, mockProduct2) + val state = ProductLoadingState.Loading( + currentRetryCount = 1, + totalRetryCount = 2, + previousProducts = previousProducts + ) + + assertEquals("Should have 2 previous products", 2, state.previousProducts.size) + assertEquals("Should have correct previous products", previousProducts, state.previousProducts) + } + + @Test + fun `Loading state products property should return previousProducts`() { + val previousProducts = listOf(mockProduct1) + val state = ProductLoadingState.Loading(previousProducts = previousProducts) + + assertEquals("products should return previousProducts", previousProducts, state.products) + } + + // ======================================== + // Success State Tests + // ======================================== + + @Test + fun `Success state should have correct values`() { + val products = listOf(mockProduct1, mockProduct2) + val state = ProductLoadingState.Success( + loadedProducts = products, + loadTimeMs = 1500L, + respondedWithCallback = false + ) + + assertTrue("Should be finished", state.isFinished) + assertFalse("Should be loading", state.isLoading) + assertFalse("Should not have responded initially", state.hasRespondedWithCallback) + assertEquals("Should have 2 products", 2, state.products.size) + assertEquals("Should have correct load time", 1500L, state.loadTimeMs) + } + + @Test + fun `Success state with respondedWithCallback true`() { + val products = listOf(mockProduct1) + val state = ProductLoadingState.Success( + loadedProducts = products, + respondedWithCallback = true + ) + + assertTrue("Should have responded", state.hasRespondedWithCallback) + } + + @Test(expected = IllegalArgumentException::class) + fun `Success state should reject empty products list`() { + ProductLoadingState.Success( + loadedProducts = emptyList(), + loadTimeMs = null, + respondedWithCallback = false + ) + } + + @Test + fun `Success state should allow copy with respondedWithCallback`() { + val products = listOf(mockProduct1) + val state1 = ProductLoadingState.Success(loadedProducts = products, respondedWithCallback = false) + val state2 = state1.copy(respondedWithCallback = true) + + assertFalse("Original should not have responded", state1.respondedWithCallback) + assertTrue("Copy should have responded", state2.respondedWithCallback) + assertEquals("Should have same products", state1.loadedProducts, state2.loadedProducts) + } + + // ======================================== + // Failed State Tests + // ======================================== + + @Test + fun `Failed state should have correct values`() { + val state = ProductLoadingState.Failed( + responseCode = BillingClient.BillingResponseCode.SERVICE_UNAVAILABLE, + cachedProducts = emptyList(), + currentRetryCount = 1, + totalRetryCount = 3, + respondedWithCallback = false + ) + + assertTrue("Should be finished", state.isFinished) + assertFalse("Should be loading", state.isLoading) + assertFalse("Should not have responded", state.hasRespondedWithCallback) + assertTrue("Should have empty cached products", state.products.isEmpty()) + assertEquals("Should have response code SERVICE_UNAVAILABLE", + BillingClient.BillingResponseCode.SERVICE_UNAVAILABLE, state.responseCode) + assertEquals("Should have current retry count 1", 1, state.currentRetryCount) + assertEquals("Should have total retry count 3", 3, state.totalRetryCount) + } + + @Test + fun `Failed state with cached products`() { + val cachedProducts = listOf(mockProduct1) + val state = ProductLoadingState.Failed( + responseCode = BillingClient.BillingResponseCode.SERVICE_UNAVAILABLE, + cachedProducts = cachedProducts + ) + + assertEquals("Should have 1 cached product", 1, state.products.size) + assertEquals("Should have correct product", mockProduct1, state.products[0]) + } + + // ======================================== + // Failed State isRetriable Tests + // ======================================== + + @Test + fun `Failed state isRetriable should be true for network error with no cached products`() { + val state = ProductLoadingState.Failed( + responseCode = BillingClient.BillingResponseCode.NETWORK_ERROR, + cachedProducts = emptyList(), + currentRetryCount = 0, + totalRetryCount = 0 + ) + + assertTrue("Should be retriable for NETWORK_ERROR", state.isRetriable) + } + + @Test + fun `Failed state isRetriable should be true for service timeout`() { + val state = ProductLoadingState.Failed( + responseCode = BillingClient.BillingResponseCode.SERVICE_TIMEOUT, + cachedProducts = emptyList(), + currentRetryCount = 0, + totalRetryCount = 0 + ) + + assertTrue("Should be retriable for SERVICE_TIMEOUT", state.isRetriable) + } + + @Test + fun `Failed state isRetriable should be true for service disconnected`() { + val state = ProductLoadingState.Failed( + responseCode = BillingClient.BillingResponseCode.SERVICE_DISCONNECTED, + cachedProducts = emptyList(), + currentRetryCount = 0, + totalRetryCount = 0 + ) + + assertTrue("Should be retriable for SERVICE_DISCONNECTED", state.isRetriable) + } + + @Test + fun `Failed state isRetriable should be true for service unavailable`() { + val state = ProductLoadingState.Failed( + responseCode = BillingClient.BillingResponseCode.SERVICE_UNAVAILABLE, + cachedProducts = emptyList(), + currentRetryCount = 0, + totalRetryCount = 0 + ) + + assertTrue("Should be retriable for SERVICE_UNAVAILABLE", state.isRetriable) + } + + @Test + fun `Failed state isRetriable should be true for billing unavailable`() { + val state = ProductLoadingState.Failed( + responseCode = BillingClient.BillingResponseCode.BILLING_UNAVAILABLE, + cachedProducts = emptyList(), + currentRetryCount = 0, + totalRetryCount = 0 + ) + + assertTrue("Should be retriable for BILLING_UNAVAILABLE", state.isRetriable) + } + + @Test + fun `Failed state isRetriable should be true for generic error`() { + val state = ProductLoadingState.Failed( + responseCode = BillingClient.BillingResponseCode.ERROR, + cachedProducts = emptyList(), + currentRetryCount = 0, + totalRetryCount = 0 + ) + + assertTrue("Should be retriable for ERROR", state.isRetriable) + } + + @Test + fun `Failed state isRetriable should be false for non-retriable error codes`() { + val nonRetriableErrors = listOf( + BillingClient.BillingResponseCode.OK, + BillingClient.BillingResponseCode.USER_CANCELED, + BillingClient.BillingResponseCode.ITEM_ALREADY_OWNED, + BillingClient.BillingResponseCode.ITEM_NOT_OWNED, + BillingClient.BillingResponseCode.DEVELOPER_ERROR, + BillingClient.BillingResponseCode.ITEM_UNAVAILABLE, + BillingClient.BillingResponseCode.FEATURE_NOT_SUPPORTED + ) + + nonRetriableErrors.forEach { errorCode -> + val state = ProductLoadingState.Failed( + responseCode = errorCode, + cachedProducts = emptyList(), + currentRetryCount = 0, + totalRetryCount = 0 + ) + + assertFalse("Should not be retriable for error code $errorCode", state.isRetriable) + } + } + + @Test + fun `Failed state isRetriable should be false when cached products exist`() { + val state = ProductLoadingState.Failed( + responseCode = BillingClient.BillingResponseCode.NETWORK_ERROR, + cachedProducts = listOf(mockProduct1), + currentRetryCount = 0, + totalRetryCount = 0 + ) + + assertFalse("Should not be retriable when cached products exist", state.isRetriable) + } + + @Test + fun `Failed state isRetriable should be false when current retry count exceeded`() { + val state = ProductLoadingState.Failed( + responseCode = BillingClient.BillingResponseCode.NETWORK_ERROR, + cachedProducts = emptyList(), + currentRetryCount = 3, // APPHUD_DEFAULT_RETRIES = 3 + totalRetryCount = 5 + ) + + assertFalse("Should not be retriable when current retry count >= 3", state.isRetriable) + } + + @Test + fun `Failed state isRetriable should be false when total retry count exceeded`() { + val state = ProductLoadingState.Failed( + responseCode = BillingClient.BillingResponseCode.NETWORK_ERROR, + cachedProducts = emptyList(), + currentRetryCount = 0, + totalRetryCount = 100 // MAX_TOTAL_PRODUCTS_RETRIES = 100 + ) + + assertFalse("Should not be retriable when total retry count >= 100", state.isRetriable) + } + + @Test + fun `Failed state isRetriable should be true just below retry limits`() { + val state = ProductLoadingState.Failed( + responseCode = BillingClient.BillingResponseCode.NETWORK_ERROR, + cachedProducts = emptyList(), + currentRetryCount = 2, // Just below APPHUD_DEFAULT_RETRIES (3) + totalRetryCount = 99 // Just below MAX_TOTAL_PRODUCTS_RETRIES (100) + ) + + assertTrue("Should be retriable just below limits", state.isRetriable) + } + + // ======================================== + // State Properties Tests + // ======================================== + + @Test + fun `isFinished should be false for Idle`() { + val state = ProductLoadingState.Idle + assertFalse("Idle should not be finished", state.isFinished) + } + + @Test + fun `isFinished should be false for Loading`() { + val state = ProductLoadingState.Loading() + assertFalse("Loading should not be finished", state.isFinished) + } + + @Test + fun `isFinished should be true for Success`() { + val state = ProductLoadingState.Success(loadedProducts = listOf(mockProduct1)) + assertTrue("Success should be finished", state.isFinished) + } + + @Test + fun `isFinished should be true for Failed`() { + val state = ProductLoadingState.Failed( + responseCode = BillingClient.BillingResponseCode.ERROR + ) + assertTrue("Failed should be finished", state.isFinished) + } + + @Test + fun `products property should return correct list for all states`() { + val successProducts = listOf(mockProduct1, mockProduct2) + val cachedProducts = listOf(mockProduct1) + val previousProducts = listOf(mockProduct2) + + val idle = ProductLoadingState.Idle + val loadingEmpty = ProductLoadingState.Loading() + val loadingWithPrevious = ProductLoadingState.Loading(previousProducts = previousProducts) + val success = ProductLoadingState.Success(loadedProducts = successProducts) + val failedWithCache = ProductLoadingState.Failed( + responseCode = BillingClient.BillingResponseCode.ERROR, + cachedProducts = cachedProducts + ) + val failedNoCache = ProductLoadingState.Failed( + responseCode = BillingClient.BillingResponseCode.ERROR, + cachedProducts = emptyList() + ) + + assertTrue("Idle should have empty products", idle.products.isEmpty()) + assertTrue("Loading without previous should have empty products", loadingEmpty.products.isEmpty()) + assertEquals("Loading with previous should have 1 product", 1, loadingWithPrevious.products.size) + assertEquals("Loading should return previousProducts", previousProducts, loadingWithPrevious.products) + assertEquals("Success should have 2 products", 2, success.products.size) + assertEquals("Failed with cache should have 1 product", 1, failedWithCache.products.size) + assertTrue("Failed no cache should have empty products", failedNoCache.products.isEmpty()) + } + + @Test + fun `hasRespondedWithCallback should return correct values`() { + val products = listOf(mockProduct1) + + val idle = ProductLoadingState.Idle + val loading = ProductLoadingState.Loading() + val successNotResponded = ProductLoadingState.Success(loadedProducts = products, respondedWithCallback = false) + val successResponded = ProductLoadingState.Success(loadedProducts = products, respondedWithCallback = true) + val failedNotResponded = ProductLoadingState.Failed( + responseCode = BillingClient.BillingResponseCode.ERROR, + respondedWithCallback = false + ) + val failedResponded = ProductLoadingState.Failed( + responseCode = BillingClient.BillingResponseCode.ERROR, + respondedWithCallback = true + ) + + assertFalse("Idle should not have responded", idle.hasRespondedWithCallback) + assertFalse("Loading should not have responded", loading.hasRespondedWithCallback) + assertFalse("Success not responded", successNotResponded.hasRespondedWithCallback) + assertTrue("Success responded", successResponded.hasRespondedWithCallback) + assertFalse("Failed not responded", failedNotResponded.hasRespondedWithCallback) + assertTrue("Failed responded", failedResponded.hasRespondedWithCallback) + } + + @Test + fun `isLoading should only be true for Loading state`() { + val products = listOf(mockProduct1) + + val idle = ProductLoadingState.Idle + val loading = ProductLoadingState.Loading() + val success = ProductLoadingState.Success(loadedProducts = products) + val failed = ProductLoadingState.Failed(responseCode = BillingClient.BillingResponseCode.ERROR) + + assertFalse("Idle should not be loading", idle.isLoading) + assertTrue("Loading should be loading", loading.isLoading) + assertFalse("Success should not be loading", success.isLoading) + assertFalse("Failed should not be loading", failed.isLoading) + } + + // ======================================== + // Type Safety Tests (Exhaustive When) + // ======================================== + + @Test + fun `when expression should be exhaustive for all states`() { + val states: List = listOf( + ProductLoadingState.Idle, + ProductLoadingState.Loading(1, 5), + ProductLoadingState.Success(loadedProducts = listOf(mockProduct1)), + ProductLoadingState.Failed(BillingClient.BillingResponseCode.ERROR) + ) + + states.forEach { state -> + // This when expression must be exhaustive + val result = when (state) { + is ProductLoadingState.Idle -> "idle" + is ProductLoadingState.Loading -> "loading" + is ProductLoadingState.Success -> "success" + is ProductLoadingState.Failed -> "failed" + } + + assertTrue("Result should be non-empty", result.isNotEmpty()) + } + } +} diff --git a/sdk/src/test/java/com/apphud/sdk/internal/data/ProductRepositoryTest.kt b/sdk/src/test/java/com/apphud/sdk/internal/data/ProductRepositoryTest.kt new file mode 100644 index 00000000..183c6cba --- /dev/null +++ b/sdk/src/test/java/com/apphud/sdk/internal/data/ProductRepositoryTest.kt @@ -0,0 +1,502 @@ +package com.apphud.sdk.internal.data + +import com.android.billingclient.api.BillingClient +import com.android.billingclient.api.ProductDetails +import io.mockk.every +import io.mockk.mockk +import org.junit.Assert.assertEquals +import org.junit.Assert.assertFalse +import org.junit.Assert.assertTrue +import org.junit.Before +import org.junit.Test + +class ProductRepositoryTest { + + private lateinit var repository: ProductRepository + private val mockProduct1: ProductDetails = mockk(relaxed = true) + private val mockProduct2: ProductDetails = mockk(relaxed = true) + private val mockProduct3: ProductDetails = mockk(relaxed = true) + + @Before + fun setup() { + repository = ProductRepository() + + every { mockProduct1.productId } returns "product-1" + every { mockProduct2.productId } returns "product-2" + every { mockProduct3.productId } returns "product-3" + } + + @Test + fun `getProducts should return empty list initially`() { + val result = repository.state.value.products + + assertTrue("Should return empty list initially", result.isEmpty()) + } + + @Test + fun `getProducts should return products from state`() { + val products = listOf(mockProduct1, mockProduct2) + repository.transitionToSuccess(products) + + val result = repository.state.value.products + + assertEquals("Should have 2 products", 2, result.size) + assertTrue("Should contain product 1", result.any { it.productId == "product-1" }) + assertTrue("Should contain product 2", result.any { it.productId == "product-2" }) + } + + // ======================================== + // Sealed Class State Tests + // ======================================== + + @Test + fun `getState should return Idle initially`() { + val state = repository.state.value + + assertTrue("Should be Idle state", state is ProductLoadingState.Idle) + } + + @Test + fun `transitionToLoading should update state from Idle`() { + repository.transitionToLoading() + + val state = repository.state.value + assertTrue("Should be Loading state", state is ProductLoadingState.Loading) + assertEquals("Should have current retry 0", 0, (state as ProductLoadingState.Loading).currentRetryCount) + assertEquals("Should have total retry 0", 0, state.totalRetryCount) + } + + @Test + fun `transitionToLoading should increment retries from Failed state`() { + // Create a real scenario: Idle -> Loading -> Failed -> Loading (retry) + repository.transitionToLoading() + repository.transitionToFailed(BillingClient.BillingResponseCode.NETWORK_ERROR) + + val failedState = repository.state.value as ProductLoadingState.Failed + assertEquals("Should have 0 current retry", 0, failedState.currentRetryCount) + assertEquals("Should have 0 total retry", 0, failedState.totalRetryCount) + + // Now retry - should increment + repository.transitionToLoading() + + val state = repository.state.value + assertTrue("Should be Loading state", state is ProductLoadingState.Loading) + assertEquals("Should increment current retry to 1", 1, (state as ProductLoadingState.Loading).currentRetryCount) + assertEquals("Should increment total retry to 1", 1, state.totalRetryCount) + } + + @Test + fun `transitionToLoading should preserve previous products`() { + val products = listOf(mockProduct1, mockProduct2) + repository.transitionToSuccess(products) + + repository.transitionToLoading() + + val state = repository.state.value + assertTrue("Should be Loading state", state is ProductLoadingState.Loading) + assertEquals("Should have preserved 2 products", 2, (state as ProductLoadingState.Loading).previousProducts.size) + assertEquals("getProducts should return previous products during loading", 2, repository.state.value.products.size) + } + + @Test + fun `transitionToSuccess should update state and products`() { + val products = listOf(mockProduct1, mockProduct2) + repository.transitionToSuccess(products, loadTimeMs = 1500L) + + val state = repository.state.value + assertTrue("Should be Success state", state is ProductLoadingState.Success) + assertEquals("Should have 2 products", 2, (state as ProductLoadingState.Success).loadedProducts.size) + assertEquals("Should have load time 1500ms", 1500L, state.loadTimeMs) + assertFalse("Should not be responded yet", state.respondedWithCallback) + assertEquals("getProducts should return products from state", 2, repository.state.value.products.size) + } + + @Test + fun `transitionToFailed should derive cached products from state`() { + val products = listOf(mockProduct1) + repository.transitionToSuccess(products) + + repository.transitionToFailed(BillingClient.BillingResponseCode.SERVICE_UNAVAILABLE) + + val state = repository.state.value + assertTrue("Should be Failed state", state is ProductLoadingState.Failed) + + val failedState = state as ProductLoadingState.Failed + assertEquals("Should have SERVICE_UNAVAILABLE code", BillingClient.BillingResponseCode.SERVICE_UNAVAILABLE, failedState.responseCode) + assertEquals("Should have 1 cached product from previous state", 1, failedState.cachedProducts.size) + assertEquals("Should have current retry 0", 0, failedState.currentRetryCount) + assertEquals("Should have total retry 0", 0, failedState.totalRetryCount) + assertFalse("Should not be responded yet", failedState.respondedWithCallback) + } + + @Test + fun `transitionToFailed should preserve retry counts from Loading state`() { + // Create a real scenario: Idle -> Loading -> Failed -> Loading (retry) + repository.transitionToLoading() + repository.transitionToFailed(BillingClient.BillingResponseCode.NETWORK_ERROR) + repository.transitionToLoading() + + val loadingState = repository.state.value as ProductLoadingState.Loading + assertEquals("Should have retry 1", 1, loadingState.currentRetryCount) + assertEquals("Should have total 1", 1, loadingState.totalRetryCount) + + // Now fail while in Loading - should preserve retry counts + repository.transitionToFailed(BillingClient.BillingResponseCode.NETWORK_ERROR) + + val state = repository.state.value + assertTrue("Should be Failed state", state is ProductLoadingState.Failed) + + val failedState = state as ProductLoadingState.Failed + assertEquals("Should preserve current retry 1", 1, failedState.currentRetryCount) + assertEquals("Should preserve total retry 1", 1, failedState.totalRetryCount) + } + + @Test + fun `markAsResponded should update Success state`() { + val products = listOf(mockProduct1) + repository.transitionToSuccess(products) + + repository.markAsResponded() + + val state = repository.state.value + assertTrue("Should be Success state", state is ProductLoadingState.Success) + assertTrue("Should be responded", (state as ProductLoadingState.Success).respondedWithCallback) + } + + @Test + fun `markAsResponded should update Failed state`() { + repository.transitionToFailed(BillingClient.BillingResponseCode.ERROR) + + repository.markAsResponded() + + val state = repository.state.value + assertTrue("Should be Failed state", state is ProductLoadingState.Failed) + assertTrue("Should be responded", (state as ProductLoadingState.Failed).respondedWithCallback) + } + + @Test + fun `markAsResponded should not affect Loading state`() { + repository.transitionToLoading() + + repository.markAsResponded() + + val state = repository.state.value + assertTrue("Should still be Loading state", state is ProductLoadingState.Loading) + assertFalse("Should not have responded", state.hasRespondedWithCallback) + } + + @Test + fun `reset should transition to Idle`() { + repository.transitionToSuccess(listOf(mockProduct1)) + + repository.reset() + + val state = repository.state.value + assertTrue("Should be Idle state", state is ProductLoadingState.Idle) + } + + @Test + fun `complete product loading lifecycle with MVI pattern`() { + // Initial state + var state = repository.state.value + assertTrue("Should start as Idle", state is ProductLoadingState.Idle) + + // Start loading (first attempt) - derives 0,0 from Idle + repository.transitionToLoading() + state = repository.state.value + assertTrue("Should be Loading", state is ProductLoadingState.Loading) + assertEquals("Should have 0 retries", 0, (state as ProductLoadingState.Loading).currentRetryCount) + assertEquals("Should have 0 total retries", 0, state.totalRetryCount) + + // Simulate failure - derives retry counts from Loading + repository.transitionToFailed(BillingClient.BillingResponseCode.NETWORK_ERROR) + state = repository.state.value + assertTrue("Should be Failed", state is ProductLoadingState.Failed) + assertTrue("Should be retriable", (state as ProductLoadingState.Failed).isRetriable) + assertEquals("Should have 0 retries from Loading", 0, state.currentRetryCount) + assertEquals("Should have 0 total retries from Loading", 0, state.totalRetryCount) + + // Retry loading - increments retry counts from Failed + repository.transitionToLoading() + state = repository.state.value + assertTrue("Should be Loading again", state is ProductLoadingState.Loading) + assertEquals("Should increment to retry 1", 1, (state as ProductLoadingState.Loading).currentRetryCount) + assertEquals("Should increment to total 1", 1, state.totalRetryCount) + + // Success + val products = listOf(mockProduct1, mockProduct2) + repository.transitionToSuccess(products, loadTimeMs = 2000L) + state = repository.state.value + assertTrue("Should be Success", state is ProductLoadingState.Success) + assertEquals("Should have 2 products", 2, (state as ProductLoadingState.Success).loadedProducts.size) + + // Mark as responded + repository.markAsResponded() + state = repository.state.value + assertTrue("Should still be Success", state is ProductLoadingState.Success) + assertTrue("Should be responded", (state as ProductLoadingState.Success).respondedWithCallback) + } + + @Test + fun `transitionToSuccess should preserve existing products when adding new ones`() { + // First load: [product1, product2] + val firstBatch = listOf(mockProduct1, mockProduct2) + repository.transitionToSuccess(firstBatch) + + var state = repository.state.value + assertEquals("Should have 2 products", 2, state.products.size) + + // Second load: add product3, should keep product1 and product2 + // This simulates incremental loading in fetchDetails + val existingProducts = repository.state.value.products + val newProducts = listOf(mockProduct3) + val allProducts = (existingProducts + newProducts).distinctBy { it.productId } + + repository.transitionToSuccess(allProducts) + + state = repository.state.value + assertEquals("Should have 3 products total", 3, state.products.size) + assertTrue("Should contain product 1", state.products.any { it.productId == "product-1" }) + assertTrue("Should contain product 2", state.products.any { it.productId == "product-2" }) + assertTrue("Should contain product 3", state.products.any { it.productId == "product-3" }) + } + + @Test + fun `state should preserve all products when already loaded products are requested again`() { + // БАГ #5: Regression test for partial reload not losing existing products + // Setup: state has [product1, product2, product3] + val allProducts = listOf(mockProduct1, mockProduct2, mockProduct3) + repository.transitionToSuccess(allProducts) + + val stateBefore = repository.state.value + assertEquals("Should have 3 products", 3, stateBefore.products.size) + + // In fetchDetails: when existingIds=[product1,product2,product3] and idsToFetch=[] + // (i.e., all requested IDs are already loaded) + // The function should return early WITHOUT calling transitionToSuccess + // Because calling transitionToSuccess with subset would replace all products with just that subset + + // Verify state is unchanged (this simulates the early return in fetchDetails) + val stateAfter = repository.state.value + assertEquals("State should not change when products already loaded", 3, stateAfter.products.size) + assertTrue("Should contain product1", stateAfter.products.any { it.productId == "product-1" }) + assertTrue("Should contain product2", stateAfter.products.any { it.productId == "product-2" }) + assertTrue("Should contain product3", stateAfter.products.any { it.productId == "product-3" }) + } + + @Test + fun `transitionToSuccess should always merge new products with existing ones`() { + // Initial load: [product1, product2] + repository.transitionToSuccess(listOf(mockProduct1, mockProduct2)) + + var state = repository.state.value + assertEquals("Should have 2 products initially", 2, state.products.size) + + // Add product3 - should automatically merge with existing products + repository.transitionToSuccess(listOf(mockProduct3), loadTimeMs = 1000L) + + state = repository.state.value + assertTrue("Should be Success state", state is ProductLoadingState.Success) + assertEquals("Should have 3 products after merge", 3, state.products.size) + assertEquals("Should have load time", 1000L, (state as ProductLoadingState.Success).loadTimeMs) + assertTrue("Should contain product 1", state.products.any { it.productId == "product-1" }) + assertTrue("Should contain product 2", state.products.any { it.productId == "product-2" }) + assertTrue("Should contain product 3", state.products.any { it.productId == "product-3" }) + } + + @Test + fun `transitionToSuccess should not duplicate existing products`() { + // Initial load: [product1, product2] + repository.transitionToSuccess(listOf(mockProduct1, mockProduct2)) + + // Try to add product1 again (duplicate) + product3 + repository.transitionToSuccess(listOf(mockProduct1, mockProduct3)) + + val state = repository.state.value + assertEquals("Should have 3 unique products (no duplicates)", 3, state.products.size) + assertTrue("Should contain product 1", state.products.any { it.productId == "product-1" }) + assertTrue("Should contain product 2", state.products.any { it.productId == "product-2" }) + assertTrue("Should contain product 3", state.products.any { it.productId == "product-3" }) + } + + @Test + fun `transitionToSuccess after reset should replace all products`() { + // Initial load: [product1, product2] + repository.transitionToSuccess(listOf(mockProduct1, mockProduct2)) + + var state = repository.state.value + assertEquals("Should have 2 products initially", 2, state.products.size) + + // Reset and load only product3 + repository.reset() + repository.transitionToSuccess(listOf(mockProduct3), loadTimeMs = 2000L) + + state = repository.state.value + assertTrue("Should be Success state", state is ProductLoadingState.Success) + assertEquals("Should have 1 product after reset+load", 1, state.products.size) + assertEquals("Should have load time", 2000L, (state as ProductLoadingState.Success).loadTimeMs) + assertTrue("Should contain only product 3", state.products.any { it.productId == "product-3" }) + } + + @Test + fun `transitionToSuccess should work from Idle state`() { + // Start from Idle - no existing products + val state = repository.state.value + assertTrue("Should be Idle initially", state is ProductLoadingState.Idle) + + // Add products (automatic merge with empty state) + repository.transitionToSuccess(listOf(mockProduct1, mockProduct2), loadTimeMs = 500L) + + val newState = repository.state.value + assertTrue("Should be Success state", newState is ProductLoadingState.Success) + assertEquals("Should have 2 products", 2, newState.products.size) + assertEquals("Should have load time", 500L, (newState as ProductLoadingState.Success).loadTimeMs) + } + + @Test + fun `transitionToSuccess should preserve existing products from Failed state`() { + // Setup: Success with products, then transition to Failed + repository.transitionToSuccess(listOf(mockProduct1, mockProduct2)) + repository.transitionToFailed(BillingClient.BillingResponseCode.NETWORK_ERROR) + + val failedState = repository.state.value + assertTrue("Should be Failed state", failedState is ProductLoadingState.Failed) + assertEquals("Should have 2 cached products", 2, failedState.products.size) + + // Add new product - should automatically merge with cached products + repository.transitionToSuccess(listOf(mockProduct3)) + + val state = repository.state.value + assertTrue("Should be Success state", state is ProductLoadingState.Success) + assertEquals("Should have 3 products", 3, state.products.size) + assertTrue("Should contain product 1", state.products.any { it.productId == "product-1" }) + assertTrue("Should contain product 2", state.products.any { it.productId == "product-2" }) + assertTrue("Should contain product 3", state.products.any { it.productId == "product-3" }) + } + + @Test + fun `transitionToSuccess should replace old product with new one when productId matches`() { + val oldProduct: ProductDetails = mockk(relaxed = true) + val newProduct: ProductDetails = mockk(relaxed = true) + + // Same productId but different instances + every { oldProduct.productId } returns "product-1" + every { newProduct.productId } returns "product-1" + + // Load old version + repository.transitionToSuccess(listOf(oldProduct)) + + var state = repository.state.value + assertEquals("Should have 1 product", 1, state.products.size) + assertTrue("Should be old product instance", state.products[0] === oldProduct) + + // Load new version (same productId) - should replace old one + repository.transitionToSuccess(listOf(newProduct)) + + state = repository.state.value + assertEquals("Should still have 1 product (no duplicate)", 1, state.products.size) + assertTrue("Should be NEW product instance", state.products[0] === newProduct) + assertFalse("Should NOT be old product instance", state.products[0] === oldProduct) + } + + @Test + fun `transitionToSuccess should update product and preserve others`() { + val oldProduct1: ProductDetails = mockk(relaxed = true) + val newProduct1: ProductDetails = mockk(relaxed = true) + + every { oldProduct1.productId } returns "product-1" + every { newProduct1.productId } returns "product-1" + + // Load [oldProduct1, product2] + repository.transitionToSuccess(listOf(oldProduct1, mockProduct2)) + + var state = repository.state.value + assertEquals("Should have 2 products", 2, state.products.size) + assertTrue("Should have old product1", state.products.any { it === oldProduct1 }) + + // Update product1, keep product2 + repository.transitionToSuccess(listOf(newProduct1)) + + state = repository.state.value + assertEquals("Should still have 2 products", 2, state.products.size) + assertTrue("Should have NEW product1", state.products.any { it === newProduct1 }) + assertTrue("Should preserve product2", state.products.any { it === mockProduct2 }) + assertFalse("Should NOT have old product1", state.products.any { it === oldProduct1 }) + } + + // ======================================== + // Rollback Retry Counters Tests + // ======================================== + + @Test + fun `rollbackRetryCounters should decrement counters in Loading state`() { + // Setup: Idle -> Loading -> Failed -> Loading (retry with incremented counters) + repository.transitionToLoading() + repository.transitionToFailed(BillingClient.BillingResponseCode.NETWORK_ERROR) + repository.transitionToLoading() + + val loadingState = repository.state.value as ProductLoadingState.Loading + assertEquals("Should have retry 1", 1, loadingState.currentRetryCount) + assertEquals("Should have total 1", 1, loadingState.totalRetryCount) + + // Rollback + repository.rollbackRetryCounters() + + val state = repository.state.value + assertTrue("Should still be Loading", state is ProductLoadingState.Loading) + assertEquals("Should decrement current retry to 0", 0, (state as ProductLoadingState.Loading).currentRetryCount) + assertEquals("Should decrement total retry to 0", 0, state.totalRetryCount) + } + + @Test + fun `rollbackRetryCounters should not decrement below zero`() { + // First attempt: Loading with 0,0 counters + repository.transitionToLoading() + + val loadingState = repository.state.value as ProductLoadingState.Loading + assertEquals("Should have 0 retries", 0, loadingState.currentRetryCount) + assertEquals("Should have 0 total retries", 0, loadingState.totalRetryCount) + + // Rollback should not go below 0 + repository.rollbackRetryCounters() + + val state = repository.state.value + assertTrue("Should still be Loading", state is ProductLoadingState.Loading) + assertEquals("Should stay at 0", 0, (state as ProductLoadingState.Loading).currentRetryCount) + assertEquals("Should stay at 0 total", 0, state.totalRetryCount) + } + + @Test + fun `rollbackRetryCounters should not affect other states`() { + // Test with Idle state + repository.rollbackRetryCounters() + assertTrue("Should still be Idle", repository.state.value is ProductLoadingState.Idle) + + // Test with Success state + repository.transitionToSuccess(listOf(mockProduct1)) + repository.rollbackRetryCounters() + assertTrue("Should still be Success", repository.state.value is ProductLoadingState.Success) + + // Test with Failed state + repository.transitionToFailed(BillingClient.BillingResponseCode.ERROR) + val failedStateBefore = repository.state.value as ProductLoadingState.Failed + repository.rollbackRetryCounters() + val failedStateAfter = repository.state.value as ProductLoadingState.Failed + assertEquals("Should not change Failed state counters", failedStateBefore.currentRetryCount, failedStateAfter.currentRetryCount) + } + + @Test + fun `rollbackRetryCounters should preserve previous products in Loading state`() { + // Setup with products + repository.transitionToSuccess(listOf(mockProduct1, mockProduct2)) + repository.transitionToFailed(BillingClient.BillingResponseCode.NETWORK_ERROR) + repository.transitionToLoading() + + // Rollback + repository.rollbackRetryCounters() + + val state = repository.state.value + assertTrue("Should still be Loading", state is ProductLoadingState.Loading) + assertEquals("Should preserve products", 2, (state as ProductLoadingState.Loading).previousProducts.size) + } +} diff --git a/sdk/src/test/java/com/apphud/sdk/internal/data/UserRepositoryTest.kt b/sdk/src/test/java/com/apphud/sdk/internal/data/UserRepositoryTest.kt index 09243929..fcce7246 100644 --- a/sdk/src/test/java/com/apphud/sdk/internal/data/UserRepositoryTest.kt +++ b/sdk/src/test/java/com/apphud/sdk/internal/data/UserRepositoryTest.kt @@ -1,5 +1,7 @@ package com.apphud.sdk.internal.data +import com.apphud.sdk.domain.ApphudPaywall +import com.apphud.sdk.domain.ApphudPlacement import com.apphud.sdk.domain.ApphudUser import io.mockk.every import io.mockk.mockk @@ -19,6 +21,43 @@ class UserRepositoryTest { private val mockUser2: ApphudUser = mockk(relaxed = true) private val temporaryUser: ApphudUser = mockk(relaxed = true) + private fun createTestPaywall(id: String = "paywall-1") = ApphudPaywall( + id = id, + name = "Test Paywall", + identifier = "test_paywall", + default = false, + json = null, + products = null, + screen = null, + experimentName = null, + variationName = null, + parentPaywallIdentifier = null, + placementIdentifier = null, + placementId = null + ) + + private fun createTestPlacement(id: String = "placement-1", paywall: ApphudPaywall? = null) = ApphudPlacement( + identifier = "test_placement", + paywall = paywall, + id = id + ) + + private fun createTestUser( + userId: String, + paywalls: List = emptyList(), + placements: List = emptyList(), + isTemporary: Boolean = false + ) = ApphudUser( + userId = userId, + currencyCode = null, + countryCode = null, + subscriptions = emptyList(), + purchases = emptyList(), + paywalls = paywalls, + placements = placements, + isTemporary = isTemporary + ) + @Before fun setup() { dataSource = mockk(relaxed = true) @@ -114,4 +153,92 @@ class UserRepositoryTest { every { dataSource.getCachedUser() } returns null assertNull("Should return null after clear", repository.getCurrentUser()) } + + @Test + fun `setCurrentUser should preserve paywalls when new user has empty paywalls`() { + val paywall = createTestPaywall() + val placement = createTestPlacement(paywall = paywall) + val userWithPaywalls = createTestUser( + userId = "user-1", + paywalls = listOf(paywall), + placements = listOf(placement) + ) + val userWithEmptyPaywalls = createTestUser( + userId = "user-1", + paywalls = emptyList(), + placements = emptyList() + ) + + repository.setCurrentUser(userWithPaywalls) + repository.setCurrentUser(userWithEmptyPaywalls) + val result = repository.getCurrentUser() + + assertEquals("Paywalls should be preserved", 1, result?.paywalls?.size) + assertEquals("Placements should be preserved", 1, result?.placements?.size) + assertEquals("Paywall id should match", paywall.id, result?.paywalls?.first()?.id) + } + + @Test + fun `setCurrentUser should replace paywalls when new user has non-empty paywalls`() { + val oldPaywall = createTestPaywall(id = "old-paywall") + val newPaywall = createTestPaywall(id = "new-paywall") + val userWithOldPaywalls = createTestUser( + userId = "user-1", + paywalls = listOf(oldPaywall), + placements = emptyList() + ) + val userWithNewPaywalls = createTestUser( + userId = "user-1", + paywalls = listOf(newPaywall), + placements = emptyList() + ) + + repository.setCurrentUser(userWithOldPaywalls) + repository.setCurrentUser(userWithNewPaywalls) + val result = repository.getCurrentUser() + + assertEquals("Should have 1 paywall", 1, result?.paywalls?.size) + assertEquals("Paywall should be replaced with new one", "new-paywall", result?.paywalls?.first()?.id) + } + + @Test + fun `setCurrentUser should not preserve paywalls when current user has none`() { + val userWithEmptyPaywalls1 = createTestUser( + userId = "user-1", + paywalls = emptyList(), + placements = emptyList() + ) + val userWithEmptyPaywalls2 = createTestUser( + userId = "user-1", + paywalls = emptyList(), + placements = emptyList() + ) + + repository.setCurrentUser(userWithEmptyPaywalls1) + repository.setCurrentUser(userWithEmptyPaywalls2) + val result = repository.getCurrentUser() + + assertTrue("Paywalls should remain empty", result?.paywalls?.isEmpty() == true) + } + + @Test + fun `setCurrentUser should save merged user to dataSource`() { + val paywall = createTestPaywall() + val placement = createTestPlacement(paywall = paywall) + val userWithPaywalls = createTestUser( + userId = "user-1", + paywalls = listOf(paywall), + placements = listOf(placement) + ) + val userWithEmptyPaywalls = createTestUser( + userId = "user-1", + paywalls = emptyList(), + placements = emptyList() + ) + + repository.setCurrentUser(userWithPaywalls) + repository.setCurrentUser(userWithEmptyPaywalls) + + verify(exactly = 2) { dataSource.saveUser(match { it.paywalls.size == 1 }) } + } }