diff --git a/stream-android-core/src/main/java/io/getstream/android/core/api/StreamClient.kt b/stream-android-core/src/main/java/io/getstream/android/core/api/StreamClient.kt index 6f55e41..bb8290e 100644 --- a/stream-android-core/src/main/java/io/getstream/android/core/api/StreamClient.kt +++ b/stream-android-core/src/main/java/io/getstream/android/core/api/StreamClient.kt @@ -35,7 +35,8 @@ import io.getstream.android.core.api.model.connection.lifecycle.StreamLifecycleS import io.getstream.android.core.api.model.connection.network.StreamNetworkState import io.getstream.android.core.api.observers.lifecycle.StreamLifecycleMonitor import io.getstream.android.core.api.observers.network.StreamNetworkMonitor -import io.getstream.android.core.api.processing.StreamBatcher +import io.getstream.android.core.api.processing.StreamEventAggregationPolicy +import io.getstream.android.core.api.processing.StreamEventAggregator import io.getstream.android.core.api.processing.StreamSerialProcessingQueue import io.getstream.android.core.api.processing.StreamSingleFlightProcessor import io.getstream.android.core.api.recovery.StreamConnectionRecoveryEvaluator @@ -249,14 +250,7 @@ public fun StreamClient( socketFactory = components.socketFactory ?: StreamWebSocketFactory(logger = logProvider.taggedLogger("SCWebSocketFactory")), - batcher = - components.batcher - ?: StreamBatcher( - scope = scope, - batchSize = socketConfig.batchSize, - initialDelayMs = socketConfig.batchInitialDelayMs, - maxDelayMs = socketConfig.batchMaxDelayMs, - ), + eventAggregator = components.eventAggregator, healthMonitor = components.healthMonitor ?: StreamHealthMonitor( @@ -320,13 +314,7 @@ internal fun createStreamClientInternal( connectionIdHolder: StreamConnectionIdHolder = StreamConnectionIdHolder(), socketFactory: StreamWebSocketFactory = StreamWebSocketFactory(logger = logProvider.taggedLogger("SCWebSocketFactory")), - batcher: StreamBatcher = - StreamBatcher( - scope = scope, - batchSize = socketConfig.batchSize, - initialDelayMs = socketConfig.batchInitialDelayMs, - maxDelayMs = socketConfig.batchMaxDelayMs, - ), + eventAggregator: StreamEventAggregator<*>? 
= null, // Monitoring healthMonitor: StreamHealthMonitor = @@ -427,6 +415,28 @@ internal fun createStreamClientInternal( ), ) + val eventParser = + StreamCompositeEventSerializationImpl( + internal = + serializationConfig.eventParser ?: StreamEventSerialization(compositeSerialization), + external = serializationConfig.productEventSerializers, + ) + + val resolvedAggregator = + eventAggregator + ?: StreamEventAggregator( + scope = clientScope, + policy = + StreamEventAggregationPolicy.from( + typeExtractor = { raw -> eventParser.peekType(raw) }, + deserializer = { raw -> eventParser.deserialize(raw) }, + aggregationThreshold = socketConfig.aggregationThreshold, + maxWindowMs = socketConfig.aggregationMaxWindowMs, + dispatchQueueCapacity = socketConfig.aggregationDispatchQueueCapacity, + ), + logger = logProvider.taggedLogger("SCEventAggregator"), + ) + val mutableConnectionState = MutableStateFlow(StreamConnectionState.Idle) return StreamClientImpl( user = user, @@ -446,15 +456,9 @@ internal fun createStreamClientInternal( products = products, config = socketConfig, jsonSerialization = compositeSerialization, - eventParser = - StreamCompositeEventSerializationImpl( - internal = - serializationConfig.eventParser - ?: StreamEventSerialization(compositeSerialization), - external = serializationConfig.productEventSerializers, - ), + eventParser = eventParser, healthMonitor = healthMonitor, - batcher = batcher, + aggregator = resolvedAggregator, internalSocket = socket, subscriptionManager = StreamSubscriptionManager( diff --git a/stream-android-core/src/main/java/io/getstream/android/core/api/model/config/StreamComponentProvider.kt b/stream-android-core/src/main/java/io/getstream/android/core/api/model/config/StreamComponentProvider.kt index b1c18b9..2872a8b 100644 --- a/stream-android-core/src/main/java/io/getstream/android/core/api/model/config/StreamComponentProvider.kt +++ b/stream-android-core/src/main/java/io/getstream/android/core/api/model/config/StreamComponentProvider.kt @@ -22,7 +22,7 @@ import io.getstream.android.core.api.components.StreamAndroidComponentsProvider import io.getstream.android.core.api.log.StreamLoggerProvider import io.getstream.android.core.api.observers.lifecycle.StreamLifecycleMonitor import io.getstream.android.core.api.observers.network.StreamNetworkMonitor -import io.getstream.android.core.api.processing.StreamBatcher +import io.getstream.android.core.api.processing.StreamEventAggregator import io.getstream.android.core.api.processing.StreamSerialProcessingQueue import io.getstream.android.core.api.processing.StreamSingleFlightProcessor import io.getstream.android.core.api.recovery.StreamConnectionRecoveryEvaluator @@ -63,7 +63,7 @@ import io.getstream.android.core.api.subscribe.StreamSubscriptionManager * @param tokenManager Token lifecycle manager. * @param connectionIdHolder Connection ID storage. * @param socketFactory WebSocket factory. - * @param batcher WebSocket message batcher. + * @param eventAggregator WebSocket event aggregator. * @param healthMonitor Connection health monitor. * @param networkMonitor Network connectivity monitor. * @param lifecycleMonitor App lifecycle monitor. @@ -80,7 +80,7 @@ public data class StreamComponentProvider( val tokenManager: StreamTokenManager? = null, val connectionIdHolder: StreamConnectionIdHolder? = null, val socketFactory: StreamWebSocketFactory? = null, - val batcher: StreamBatcher? = null, + val eventAggregator: StreamEventAggregator<*>? = null, val healthMonitor: StreamHealthMonitor? 
= null, val networkMonitor: StreamNetworkMonitor? = null, val lifecycleMonitor: StreamLifecycleMonitor? = null, diff --git a/stream-android-core/src/main/java/io/getstream/android/core/api/model/config/StreamSocketConfig.kt b/stream-android-core/src/main/java/io/getstream/android/core/api/model/config/StreamSocketConfig.kt index 211a190..64a9215 100644 --- a/stream-android-core/src/main/java/io/getstream/android/core/api/model/config/StreamSocketConfig.kt +++ b/stream-android-core/src/main/java/io/getstream/android/core/api/model/config/StreamSocketConfig.kt @@ -25,8 +25,8 @@ import io.getstream.android.core.api.model.value.StreamWsUrl * Configuration for the Stream WebSocket connection. * * Holds both **identity** (URL, API key, auth type) and **operational** tunables (health check - * timing, batching, connection timeout). Products pass this to the [StreamClient] factory to - * describe their socket. + * timing, event aggregation, connection timeout). Products pass this to the [StreamClient] factory + * to describe their socket. * * ### Usage * @@ -38,7 +38,7 @@ import io.getstream.android.core.api.model.value.StreamWsUrl * clientInfoHeader = clientInfo, * ) * - * // SFU socket — aggressive timing, no batching + * // SFU socket — aggressive timing, low aggregation threshold * val sfuSocket = StreamSocketConfig.jwt( * url = StreamWsUrl.fromString("wss://sfu.stream-io-api.com"), * apiKey = apiKey, @@ -46,7 +46,8 @@ import io.getstream.android.core.api.model.value.StreamWsUrl * healthCheckIntervalMs = 5_000, * livenessThresholdMs = 15_000, * connectionTimeoutMs = 2_000, - * batchSize = 1, + * aggregationThreshold = 10, + * aggregationMaxWindowMs = 200, * ) * ``` * @@ -58,9 +59,12 @@ import io.getstream.android.core.api.model.value.StreamWsUrl * @param livenessThresholdMs Time without a health check ack before the connection is considered * unhealthy in milliseconds. * @param connectionTimeoutMs WebSocket connection timeout in milliseconds. - * @param batchSize Maximum number of WebSocket messages to batch before flushing. - * @param batchInitialDelayMs Initial debounce window for batching in milliseconds. - * @param batchMaxDelayMs Maximum debounce window for batching in milliseconds. + * @param aggregationThreshold Number of accumulated events that triggers aggregated delivery + * instead of individual dispatch. + * @param aggregationMaxWindowMs Maximum time the aggregator collects events before delivering. This + * is the latency ceiling in milliseconds. + * @param aggregationDispatchQueueCapacity Bounded capacity of the dispatch queue between the + * aggregator's collector and dispatcher coroutines. */ @Suppress("LongParameterList") @StreamInternalApi @@ -74,9 +78,9 @@ private constructor( val healthCheckIntervalMs: Long = DEFAULT_HEALTH_INTERVAL_MS, val livenessThresholdMs: Long = DEFAULT_LIVENESS_MS, val connectionTimeoutMs: Long = DEFAULT_CONNECTION_TIMEOUT_MS, - val batchSize: Int = DEFAULT_BATCH_SIZE, - val batchInitialDelayMs: Long = DEFAULT_BATCH_INIT_DELAY_MS, - val batchMaxDelayMs: Long = DEFAULT_BATCH_MAX_DELAY_MS, + val aggregationThreshold: Int = DEFAULT_AGGREGATION_THRESHOLD, + val aggregationMaxWindowMs: Long = DEFAULT_AGGREGATION_MAX_WINDOW_MS, + val aggregationDispatchQueueCapacity: Int = DEFAULT_AGGREGATION_DISPATCH_QUEUE_CAPACITY, ) { /** Default values for [StreamSocketConfig] fields. */ public companion object { @@ -92,14 +96,14 @@ private constructor( /** Default connection timeout: 10 seconds. 
*/ public const val DEFAULT_CONNECTION_TIMEOUT_MS: Long = 10_000L - /** Default batch size: 10 messages. */ - public const val DEFAULT_BATCH_SIZE: Int = 10 + /** Default aggregation threshold: 50 events trigger aggregated delivery. */ + public const val DEFAULT_AGGREGATION_THRESHOLD: Int = 50 - /** Default initial batch delay: 100ms. */ - public const val DEFAULT_BATCH_INIT_DELAY_MS: Long = 100L + /** Default aggregation max window: 500ms latency ceiling. */ + public const val DEFAULT_AGGREGATION_MAX_WINDOW_MS: Long = 500L - /** Default max batch delay: 1 second. */ - public const val DEFAULT_BATCH_MAX_DELAY_MS: Long = 1_000L + /** Default dispatch queue capacity: 16 items. */ + public const val DEFAULT_AGGREGATION_DISPATCH_QUEUE_CAPACITY: Int = 16 /** * Creates a JWT-based [StreamSocketConfig]. @@ -110,9 +114,9 @@ private constructor( * @param healthCheckIntervalMs Interval between health check pings in milliseconds. * @param livenessThresholdMs Liveness threshold in milliseconds. * @param connectionTimeoutMs WebSocket connection timeout in milliseconds. - * @param batchSize Maximum batch size before flush. - * @param batchInitialDelayMs Initial debounce window in milliseconds. - * @param batchMaxDelayMs Maximum debounce window in milliseconds. + * @param aggregationThreshold Events before aggregated delivery triggers. + * @param aggregationMaxWindowMs Maximum collection window in milliseconds. + * @param aggregationDispatchQueueCapacity Dispatch queue capacity. * @return A JWT-based [StreamSocketConfig]. */ @Suppress("LongParameterList") @@ -123,9 +127,9 @@ private constructor( healthCheckIntervalMs: Long = DEFAULT_HEALTH_INTERVAL_MS, livenessThresholdMs: Long = DEFAULT_LIVENESS_MS, connectionTimeoutMs: Long = DEFAULT_CONNECTION_TIMEOUT_MS, - batchSize: Int = DEFAULT_BATCH_SIZE, - batchInitialDelayMs: Long = DEFAULT_BATCH_INIT_DELAY_MS, - batchMaxDelayMs: Long = DEFAULT_BATCH_MAX_DELAY_MS, + aggregationThreshold: Int = DEFAULT_AGGREGATION_THRESHOLD, + aggregationMaxWindowMs: Long = DEFAULT_AGGREGATION_MAX_WINDOW_MS, + aggregationDispatchQueueCapacity: Int = DEFAULT_AGGREGATION_DISPATCH_QUEUE_CAPACITY, ): StreamSocketConfig = StreamSocketConfig( url = url, @@ -135,9 +139,9 @@ private constructor( healthCheckIntervalMs = healthCheckIntervalMs, livenessThresholdMs = livenessThresholdMs, connectionTimeoutMs = connectionTimeoutMs, - batchSize = batchSize, - batchInitialDelayMs = batchInitialDelayMs, - batchMaxDelayMs = batchMaxDelayMs, + aggregationThreshold = aggregationThreshold, + aggregationMaxWindowMs = aggregationMaxWindowMs, + aggregationDispatchQueueCapacity = aggregationDispatchQueueCapacity, ) /** @@ -149,9 +153,9 @@ private constructor( * @param healthCheckIntervalMs Interval between health check pings in milliseconds. * @param livenessThresholdMs Liveness threshold in milliseconds. * @param connectionTimeoutMs WebSocket connection timeout in milliseconds. - * @param batchSize Maximum batch size before flush. - * @param batchInitialDelayMs Initial debounce window in milliseconds. - * @param batchMaxDelayMs Maximum debounce window in milliseconds. + * @param aggregationThreshold Events before aggregated delivery triggers. + * @param aggregationMaxWindowMs Maximum collection window in milliseconds. + * @param aggregationDispatchQueueCapacity Dispatch queue capacity. * @return An anonymous [StreamSocketConfig]. 
*/ @Suppress("LongParameterList") @@ -162,9 +166,9 @@ private constructor( healthCheckIntervalMs: Long = DEFAULT_HEALTH_INTERVAL_MS, livenessThresholdMs: Long = DEFAULT_LIVENESS_MS, connectionTimeoutMs: Long = DEFAULT_CONNECTION_TIMEOUT_MS, - batchSize: Int = DEFAULT_BATCH_SIZE, - batchInitialDelayMs: Long = DEFAULT_BATCH_INIT_DELAY_MS, - batchMaxDelayMs: Long = DEFAULT_BATCH_MAX_DELAY_MS, + aggregationThreshold: Int = DEFAULT_AGGREGATION_THRESHOLD, + aggregationMaxWindowMs: Long = DEFAULT_AGGREGATION_MAX_WINDOW_MS, + aggregationDispatchQueueCapacity: Int = DEFAULT_AGGREGATION_DISPATCH_QUEUE_CAPACITY, ): StreamSocketConfig = StreamSocketConfig( url = url, @@ -174,9 +178,9 @@ private constructor( healthCheckIntervalMs = healthCheckIntervalMs, livenessThresholdMs = livenessThresholdMs, connectionTimeoutMs = connectionTimeoutMs, - batchSize = batchSize, - batchInitialDelayMs = batchInitialDelayMs, - batchMaxDelayMs = batchMaxDelayMs, + aggregationThreshold = aggregationThreshold, + aggregationMaxWindowMs = aggregationMaxWindowMs, + aggregationDispatchQueueCapacity = aggregationDispatchQueueCapacity, ) /** @@ -189,9 +193,9 @@ private constructor( * @param healthCheckIntervalMs Interval between health check pings in milliseconds. * @param livenessThresholdMs Liveness threshold in milliseconds. * @param connectionTimeoutMs WebSocket connection timeout in milliseconds. - * @param batchSize Maximum batch size before flush. - * @param batchInitialDelayMs Initial debounce window in milliseconds. - * @param batchMaxDelayMs Maximum debounce window in milliseconds. + * @param aggregationThreshold Events before aggregated delivery triggers. + * @param aggregationMaxWindowMs Maximum collection window in milliseconds. + * @param aggregationDispatchQueueCapacity Dispatch queue capacity. * @return A custom [StreamSocketConfig]. */ @Suppress("LongParameterList") @@ -203,9 +207,9 @@ private constructor( healthCheckIntervalMs: Long = DEFAULT_HEALTH_INTERVAL_MS, livenessThresholdMs: Long = DEFAULT_LIVENESS_MS, connectionTimeoutMs: Long = DEFAULT_CONNECTION_TIMEOUT_MS, - batchSize: Int = DEFAULT_BATCH_SIZE, - batchInitialDelayMs: Long = DEFAULT_BATCH_INIT_DELAY_MS, - batchMaxDelayMs: Long = DEFAULT_BATCH_MAX_DELAY_MS, + aggregationThreshold: Int = DEFAULT_AGGREGATION_THRESHOLD, + aggregationMaxWindowMs: Long = DEFAULT_AGGREGATION_MAX_WINDOW_MS, + aggregationDispatchQueueCapacity: Int = DEFAULT_AGGREGATION_DISPATCH_QUEUE_CAPACITY, ): StreamSocketConfig { require(authType.isNotBlank()) { "Auth type must not be blank" } return StreamSocketConfig( @@ -216,9 +220,9 @@ private constructor( healthCheckIntervalMs = healthCheckIntervalMs, livenessThresholdMs = livenessThresholdMs, connectionTimeoutMs = connectionTimeoutMs, - batchSize = batchSize, - batchInitialDelayMs = batchInitialDelayMs, - batchMaxDelayMs = batchMaxDelayMs, + aggregationThreshold = aggregationThreshold, + aggregationMaxWindowMs = aggregationMaxWindowMs, + aggregationDispatchQueueCapacity = aggregationDispatchQueueCapacity, ) } } diff --git a/stream-android-core/src/main/java/io/getstream/android/core/api/processing/StreamAggregatedEvent.kt b/stream-android-core/src/main/java/io/getstream/android/core/api/processing/StreamAggregatedEvent.kt new file mode 100644 index 0000000..aa594dd --- /dev/null +++ b/stream-android-core/src/main/java/io/getstream/android/core/api/processing/StreamAggregatedEvent.kt @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2014-2026 Stream.io Inc. All rights reserved. 
+ * + * Licensed under the Stream License; + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://github.com/GetStream/stream-core-android/blob/main/LICENSE + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.getstream.android.core.api.processing + +import io.getstream.android.core.annotations.StreamInternalApi + +/** + * A collection of events grouped by their type, produced by [StreamEventAggregator] during a + * traffic spike. + * + * When the aggregator detects high event throughput, it collects events within a time window and + * delivers them as a single [StreamAggregatedEvent] instead of dispatching each individually. This + * allows product SDKs to apply one state update and one UI recomposition per window instead of one + * per event. + * + * During normal (low) traffic, events flow through individually — this class is only used during + * spikes. + * + * ### Usage + * + * ```kotlin + * when (event) { + * is StreamAggregatedEvent<*> -> { + * event.events.forEach { (type, eventsOfType) -> + * when (type) { + * "channel.updated" -> applyLatest(eventsOfType) + * "message.new" -> processAll(eventsOfType) + * } + * } + * } + * else -> handleSingleEvent(event) + * } + * ``` + * + * @param T The type of the individual events. + * @property events Events grouped by type string. Each list preserves arrival order. + */ +@StreamInternalApi +public class StreamAggregatedEvent<T>(events: Map<String, List<T>>) { + public val events: Map<String, List<T>> = events.mapValues { (_, v) -> v.toList() } +} diff --git a/stream-android-core/src/main/java/io/getstream/android/core/api/processing/StreamEventAggregationPolicy.kt b/stream-android-core/src/main/java/io/getstream/android/core/api/processing/StreamEventAggregationPolicy.kt new file mode 100644 index 0000000..b5947c7 --- /dev/null +++ b/stream-android-core/src/main/java/io/getstream/android/core/api/processing/StreamEventAggregationPolicy.kt @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2014-2026 Stream.io Inc. All rights reserved. + * + * Licensed under the Stream License; + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://github.com/GetStream/stream-core-android/blob/main/LICENSE + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.getstream.android.core.api.processing + +import io.getstream.android.core.annotations.StreamInternalApi + +/** + * Complete aggregation configuration for [StreamEventAggregator]. + * + * Combines **behavior** (how to extract types and deserialize) with **tuning** (thresholds, + * windows, queue capacity). Once created via a factory method, the policy is guaranteed valid — all + * invariants are checked at construction time. + * + * The primary constructor is **private**. Create instances through the [from] factory method.
+ * + * ### Invariants + * * `aggregationThreshold > 0` + * * `maxWindowMs > 0` + * * `dispatchQueueCapacity > 0` + * + * @param T The deserialized event type. + * @param extractType Extracts the event type string from a raw message (typically JSON). Returns + * `null` if the type cannot be determined — events with unknown type are grouped under an empty + * key. + * @param deserialize Deserializes a raw message into `T`. Returns `Result.failure` on parse errors. + * @param aggregationThreshold Number of accumulated events that triggers aggregated delivery + * instead of individual dispatch. + * @param maxWindowMs Maximum time (milliseconds) the collector will wait before packaging and + * delivering whatever has accumulated. This is the latency ceiling. + * @param dispatchQueueCapacity Bounded capacity of the dispatch queue between collector and + * dispatcher. When full, the collector logs a warning. + */ +@StreamInternalApi +@ConsistentCopyVisibility +public data class StreamEventAggregationPolicy<T> +private constructor( + val extractType: (raw: String) -> String?, + val deserialize: (raw: String) -> Result<T>, + val aggregationThreshold: Int, + val maxWindowMs: Long, + val dispatchQueueCapacity: Int, +) { + public companion object { + + /** Default aggregation threshold: 50 events trigger aggregated delivery. */ + public const val DEFAULT_AGGREGATION_THRESHOLD: Int = 50 + + /** Default aggregation max window: 500ms latency ceiling. */ + public const val DEFAULT_MAX_WINDOW_MS: Long = 500L + + /** Default dispatch queue capacity: 16 items. */ + public const val DEFAULT_DISPATCH_QUEUE_CAPACITY: Int = 16 + + /** + * Creates a validated policy from the given parameters. + * + * ### Usage + * + * ```kotlin + * val policy = StreamEventAggregationPolicy.from( + * typeExtractor = { raw -> eventParser.peekType(raw) }, + * deserializer = { raw -> eventParser.deserialize(raw) }, + * aggregationThreshold = 100, + * maxWindowMs = 300, + * ) + * ``` + * + * @param T The deserialized event type. + * @param typeExtractor Extracts the event type from a raw message. + * @param deserializer Deserializes a raw message into `T`. + * @param aggregationThreshold Events before aggregated delivery triggers. + * @param maxWindowMs Maximum collection window in milliseconds. + * @param dispatchQueueCapacity Dispatch queue capacity. + * @return A validated [StreamEventAggregationPolicy]. + * @throws IllegalArgumentException if any numeric parameter is ≤ 0.
*/ + @Suppress("LongParameterList") + public fun <T> from( + typeExtractor: (raw: String) -> String?, + deserializer: (raw: String) -> Result<T>, + aggregationThreshold: Int = DEFAULT_AGGREGATION_THRESHOLD, + maxWindowMs: Long = DEFAULT_MAX_WINDOW_MS, + dispatchQueueCapacity: Int = DEFAULT_DISPATCH_QUEUE_CAPACITY, + ): StreamEventAggregationPolicy<T> { + validate(aggregationThreshold, maxWindowMs, dispatchQueueCapacity) + return StreamEventAggregationPolicy( + extractType = typeExtractor, + deserialize = deserializer, + aggregationThreshold = aggregationThreshold, + maxWindowMs = maxWindowMs, + dispatchQueueCapacity = dispatchQueueCapacity, + ) + } + + private fun validate( + aggregationThreshold: Int, + maxWindowMs: Long, + dispatchQueueCapacity: Int, + ) { + require(aggregationThreshold > 0) { + "aggregationThreshold must be > 0, was $aggregationThreshold" + } + require(maxWindowMs > 0) { "maxWindowMs must be > 0, was $maxWindowMs" } + require(dispatchQueueCapacity > 0) { + "dispatchQueueCapacity must be > 0, was $dispatchQueueCapacity" + } + } + } +} diff --git a/stream-android-core/src/main/java/io/getstream/android/core/api/processing/StreamEventAggregator.kt b/stream-android-core/src/main/java/io/getstream/android/core/api/processing/StreamEventAggregator.kt new file mode 100644 index 0000000..a8155c2 --- /dev/null +++ b/stream-android-core/src/main/java/io/getstream/android/core/api/processing/StreamEventAggregator.kt @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2014-2026 Stream.io Inc. All rights reserved. + * + * Licensed under the Stream License; + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://github.com/GetStream/stream-core-android/blob/main/LICENSE + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.getstream.android.core.api.processing + +import io.getstream.android.core.annotations.StreamInternalApi +import io.getstream.android.core.api.log.StreamLogger +import io.getstream.android.core.internal.processing.StreamEventAggregatorImpl +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.channels.Channel + +/** + * Adaptive event aggregator that switches between individual and aggregated event delivery based on + * traffic volume. + * + * ### Architecture + * + * Two decoupled coroutines: + * - **Collector:** Drains the inbox channel, groups events by type, and packages them into the + * dispatch queue. Collection ends when either [aggregationThreshold] is reached or [maxWindowMs] + * elapses — whichever comes first. This guarantees a latency ceiling regardless of downstream + * processing speed. + * - **Dispatcher:** Takes packaged work from the dispatch queue and invokes the registered handler. + * Runs sequentially — one delivery at a time, preserving order. + * + * ### Adaptive behavior + * - **Low traffic:** When the collector drains the channel and finds few events (below + * [aggregationThreshold]), each event is dispatched individually via the handler. + * - **Spike:** When accumulated events reach or exceed [aggregationThreshold], they are grouped by + * type into a [StreamAggregatedEvent] and dispatched as a single call.
+ * + * ### Backpressure + * - Natural: While the dispatcher is busy, events accumulate in the inbox channel. The next + * collection cycle finds them ready — larger accumulation triggers aggregation automatically. + * - The dispatch queue has a bounded capacity. If the dispatcher can't keep up, the collector logs + * a warning when the queue is full. + * + * @param T The type of the deserialized event. + */ +@StreamInternalApi +public interface StreamEventAggregator<T> { + + /** + * Starts the collector and dispatcher coroutines. + * + * @return `Result.success(Unit)` if started successfully. + */ + public fun start(): Result<Unit> + + /** + * Enqueues a raw event for processing. Non-suspending, suitable for WebSocket callbacks. + * + * Returns `false` if the aggregator has not been [start]ed or if the inbox is full/closed. + * + * @param raw The raw event (typically a JSON string from the WebSocket). + * @return `true` if accepted into the inbox, `false` otherwise. + */ + public fun offer(raw: String): Boolean + + /** + * Registers the handler invoked for each delivery. + * + * The handler receives **either**: + * - An individual event `T` (low traffic), or + * - A [StreamAggregatedEvent] wrapping grouped events (spike). + * + * Product SDKs distinguish using `is StreamAggregatedEvent<*>`. + * + * @param handler A suspending function receiving the event or aggregated event. + */ + public fun onEvent(handler: suspend (Any) -> Unit) + + /** + * Stops both coroutines and releases resources. Idempotent. + * + * @return `Result.success(Unit)` on clean shutdown. + */ + public fun stop(): Result<Unit> +} + +/** + * Creates a new [StreamEventAggregator] instance. + * + * @param T The deserialized event type. + * @param scope Coroutine scope for the collector and dispatcher coroutines. + * @param policy Aggregation policy defining type extraction, deserialization, and tuning + * (thresholds, windows, queue capacity). Validated at construction — once you have a policy, it + * is guaranteed valid. + * @param inboxCapacity Capacity of the raw event inbox channel. + * @param logger Optional tagged logger for diagnostics. + */ +@StreamInternalApi +public fun <T> StreamEventAggregator( + scope: CoroutineScope, + policy: StreamEventAggregationPolicy<T>, + inboxCapacity: Int = Channel.UNLIMITED, + logger: StreamLogger? = null, +): StreamEventAggregator<T> = + StreamEventAggregatorImpl( + scope = scope, + policy = policy, + inboxCapacity = inboxCapacity, + logger = logger, + ) diff --git a/stream-android-core/src/main/java/io/getstream/android/core/internal/processing/StreamEventAggregatorImpl.kt b/stream-android-core/src/main/java/io/getstream/android/core/internal/processing/StreamEventAggregatorImpl.kt new file mode 100644 index 0000000..d9564e7 --- /dev/null +++ b/stream-android-core/src/main/java/io/getstream/android/core/internal/processing/StreamEventAggregatorImpl.kt @@ -0,0 +1,220 @@ +/* + * Copyright (c) 2014-2026 Stream.io Inc. All rights reserved. + * + * Licensed under the Stream License; + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://github.com/GetStream/stream-core-android/blob/main/LICENSE + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
*/ + +package io.getstream.android.core.internal.processing + +import io.getstream.android.core.api.log.StreamLogger +import io.getstream.android.core.api.processing.StreamAggregatedEvent +import io.getstream.android.core.api.processing.StreamEventAggregationPolicy +import io.getstream.android.core.api.processing.StreamEventAggregator +import io.getstream.android.core.api.utils.runCatchingCancellable +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicReference +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Job +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.channels.ClosedReceiveChannelException +import kotlinx.coroutines.isActive +import kotlinx.coroutines.launch + +/** + * Internal implementation of [StreamEventAggregator]. + * + * Uses two decoupled coroutines: + * - **Collector:** Drains [inbox], groups by type, packages into [dispatchQueue]. + * - **Dispatcher:** Takes from [dispatchQueue], calls the registered handler. + * + * Collection stops when [policy.aggregationThreshold] items accumulate or [policy.maxWindowMs] + * elapses. The dispatch queue is bounded by [dispatchQueueCapacity] — if full, the collector drops + * the delivery and logs a warning. + */ +internal class StreamEventAggregatorImpl<T>( + private val scope: CoroutineScope, + private val policy: StreamEventAggregationPolicy<T>, + inboxCapacity: Int = Channel.UNLIMITED, + internal var logger: StreamLogger? = null, +) : StreamEventAggregator<T> { + + private val inbox = StreamRestartableChannel<String>(inboxCapacity) + private val dispatchQueue = + StreamRestartableChannel<DispatchItem<T>>(policy.dispatchQueueCapacity) + private val started = AtomicBoolean(false) + private var collectorJob: Job? = null + private var dispatcherJob: Job? = null + + private val eventHandler = AtomicReference<suspend (Any) -> Unit> { /* no-op until set */ } + + override fun onEvent(handler: suspend (Any) -> Unit) { + eventHandler.set(handler) + } + + override fun start(): Result<Unit> = runCatching { + if (!started.compareAndSet(false, true)) { + return Result.success(Unit) + } + inbox.start() + dispatchQueue.start() + collectorJob = scope.launch { runCollector() } + dispatcherJob = scope.launch { runDispatcher() } + } + + override fun offer(raw: String): Boolean { + if (!started.get()) { + return false + } + return inbox.trySend(raw).isSuccess + } + + override fun stop(): Result<Unit> = runCatching { + if (!started.compareAndSet(true, false)) { + return Result.success(Unit) + } + collectorJob?.cancel() + dispatcherJob?.cancel() + collectorJob = null + dispatcherJob = null + inbox.close() + dispatchQueue.close() + } + + private suspend fun runCollector() { + try { + while (scope.isActive) { + // Wait for the first event — suspends until something arrives + val first = inbox.receive() + val buffer = mutableListOf(first) + + // Collect more until threshold or maxWindow + collectWindow(buffer) + + // Package and send to dispatch queue + val item = packageForDispatch(buffer) + if (item != null) { + val sent = dispatchQueue.trySend(item) + if (sent.isFailure) { + logger?.w { + "[collector] Dispatch queue full (capacity=${policy.dispatchQueueCapacity}). " + + "Dropping ${buffer.size} events. Dispatcher may be too slow." + } + } + } + } + } catch (_: ClosedReceiveChannelException) { + logger?.d { "[collector] Inbox closed, shutting down" } + } + } + + /** + * Collects events from the inbox until [policy.aggregationThreshold] is reached or + * [policy.maxWindowMs] elapses.
Uses [kotlinx.coroutines.withTimeoutOrNull] so virtual-time + * test dispatchers can advance the clock correctly. + */ + private suspend fun collectWindow(buffer: MutableList<String>) { + kotlinx.coroutines.withTimeoutOrNull(policy.maxWindowMs) { + while (buffer.size < policy.aggregationThreshold) { + buffer += inbox.receive() + } + } + } + + /** Deserializes raw messages and decides: individual dispatch or aggregated. */ + private fun packageForDispatch(rawMessages: List<String>): DispatchItem<T>? { + if (rawMessages.isEmpty()) { + return null + } + + if (rawMessages.size < policy.aggregationThreshold) { + // Low traffic — deserialize and dispatch individually + val events = mutableListOf<DeserializedEvent<T>>() + for (raw in rawMessages) { + val parsed = safeDeserialize(raw) ?: continue + events += DeserializedEvent(raw, parsed) + } + return if (events.isEmpty()) null else DispatchItem.Individual(events) + } + + // Spike — group by type + val grouped = LinkedHashMap<String, MutableList<T>>() + for (raw in rawMessages) { + val type = safeExtractType(raw) + val event = safeDeserialize(raw) ?: continue + grouped.getOrPut(type) { mutableListOf() }.add(event) + } + return if (grouped.isEmpty()) null + else DispatchItem.Aggregated(StreamAggregatedEvent(grouped)) + } + + /** Calls [deserializer], catching both Result.failure and thrown exceptions. */ + private fun safeDeserialize(raw: String): T? = + runCatchingCancellable { + policy + .deserialize(raw) + .onFailure { e -> + logger?.e(e) { "[collector] Failed to deserialize event. ${e.message}" } + } + .getOrNull() + } + .onFailure { e -> logger?.e(e) { "[collector] Deserializer threw. ${e.message}" } } + .getOrNull() + + /** Calls [typeExtractor], catching thrown exceptions. Returns empty string on failure. */ + private fun safeExtractType(raw: String): String = + runCatchingCancellable { policy.extractType(raw) ?: "" } + .onFailure { e -> logger?.e(e) { "[collector] Type extractor threw. ${e.message}" } } + .getOrDefault("") + + private suspend fun runDispatcher() { + try { + for (item in dispatchQueue) { + when (item) { + is DispatchItem.Individual -> { + for (event in item.events) { + runCatchingCancellable { + eventHandler.get().invoke(event.parsed as Any) + } + .onFailure { e -> + logger?.e(e) { + "[dispatcher] Handler threw on individual event. ${e.message}" + } + } + } + } + + is DispatchItem.Aggregated -> { + runCatchingCancellable { eventHandler.get().invoke(item.aggregated) } + .onFailure { e -> + logger?.e(e) { + "[dispatcher] Handler threw on aggregated event. ${e.message}" + } + } + } + } + } + } catch (_: ClosedReceiveChannelException) { + logger?.d { "[dispatcher] Dispatch queue closed, shutting down" } + } + } + + /** A single deserialized event with its raw source preserved for logging. */ + private data class DeserializedEvent<T>(val raw: String, val parsed: T) + + /** Work item in the dispatch queue.
*/ + private sealed class DispatchItem<T> { + data class Individual<T>(val events: List<DeserializedEvent<T>>) : DispatchItem<T>() + + data class Aggregated<T>(val aggregated: StreamAggregatedEvent<T>) : DispatchItem<T>() + } +} diff --git a/stream-android-core/src/main/java/io/getstream/android/core/internal/serialization/StreamCompositeEventSerializationImpl.kt b/stream-android-core/src/main/java/io/getstream/android/core/internal/serialization/StreamCompositeEventSerializationImpl.kt index 352fec4..40ba04f 100644 --- a/stream-android-core/src/main/java/io/getstream/android/core/internal/serialization/StreamCompositeEventSerializationImpl.kt +++ b/stream-android-core/src/main/java/io/getstream/android/core/internal/serialization/StreamCompositeEventSerializationImpl.kt @@ -124,7 +124,7 @@ internal class StreamCompositeEventSerializationImpl( }.getOrThrow() } - private fun peekType(raw: String): String? { + internal fun peekType(raw: String): String? { val reader = JsonReader.of(Buffer().writeUtf8(raw)) reader.isLenient = true return try { diff --git a/stream-android-core/src/main/java/io/getstream/android/core/internal/socket/StreamSocketSession.kt b/stream-android-core/src/main/java/io/getstream/android/core/internal/socket/StreamSocketSession.kt index 6022236..094a62a 100644 --- a/stream-android-core/src/main/java/io/getstream/android/core/internal/socket/StreamSocketSession.kt +++ b/stream-android-core/src/main/java/io/getstream/android/core/internal/socket/StreamSocketSession.kt @@ -25,7 +25,8 @@ import io.getstream.android.core.api.model.connection.StreamConnectionState import io.getstream.android.core.api.model.exceptions.StreamClientException import io.getstream.android.core.api.model.exceptions.StreamEndpointErrorData import io.getstream.android.core.api.model.exceptions.StreamEndpointException -import io.getstream.android.core.api.processing.StreamBatcher +import io.getstream.android.core.api.processing.StreamAggregatedEvent +import io.getstream.android.core.api.processing.StreamEventAggregator import io.getstream.android.core.api.serialization.StreamJsonSerialization import io.getstream.android.core.api.socket.StreamWebSocket import io.getstream.android.core.api.socket.listeners.StreamClientListener @@ -57,7 +58,7 @@ internal class StreamSocketSession( private val jsonSerialization: StreamJsonSerialization, private val eventParser: StreamCompositeEventSerializationImpl, private val healthMonitor: StreamHealthMonitor, - private val batcher: StreamBatcher, + private val aggregator: StreamEventAggregator<*>, private val subscriptionManager: StreamSubscriptionManager, private val products: List<String>, ) : StreamSubscriptionManager by subscriptionManager { @@ -84,13 +85,15 @@ internal class StreamSocketSession( override fun onMessage(text: String) { logger.v { "[onMessage] Socket message: $text" } healthMonitor.acknowledgeHeartbeat() - val accepted = batcher.offer(text) + val accepted = aggregator.offer(text) if (!accepted) { val error = IllegalStateException( "Failed to offer message to event aggregator. 
Message dropped (${text.length} bytes)" ) - logger.e(error) { "[onMessage] Message dropped: $text" } + logger.e(error) { + "[onMessage] Message dropped by event aggregator (${text.length} bytes)" + } disconnect(error) } else { logger.v { "[onMessage] Message accepted: $text" } @@ -124,7 +127,7 @@ internal class StreamSocketSession( * * The method emits a `StreamConnectionState.Disconnected` state (embedding [error] when * provided), cancels the active socket subscription, closes the underlying [StreamWebSocket], - * and stops the health monitor and batch processor. Subsequent invocations are idempotent: + * and stops the health monitor and event aggregator. Subsequent invocations are idempotent: * listeners are only notified on the first call while the socket close is still attempted every * time. * @@ -231,57 +234,23 @@ internal class StreamSocketSession( healthMonitor.onUnhealthy { disconnect(StreamClientException("Socket did not receive any events.")) } - // Batch processing of incoming messages - batcher.onBatch { batch, delay, count -> - logger.v { - "[onBatch] Socket batch (delay: $delay ms, buffer size: $count): $batch" - } - - batch.forEach { message -> - eventParser - .deserialize(message) - .onSuccess { event -> - logger.v { "[onBatch] Deserialized event: $event" } - val coreEvent = event.core - val productEvent = event.product - - if ( - coreEvent != null && coreEvent is StreamClientConnectionErrorEvent - ) { - notifyState( - StreamConnectionState.Disconnected( - StreamEndpointException("Connection error", coreEvent.error) - ) - ) - } - subscriptionManager.forEach { listener -> - coreEvent - ?.takeUnless { it is StreamHealthCheckEvent } - ?.let { listener.onEvent(it) } - productEvent?.let { listener.onEvent(it) } - } - } - .onFailure { - logger.e(it) { - "[onBatch] Failed to deserialize socket message. ${it.message}" - } - // Attempt to parse as API error - jsonSerialization - .fromJson(message, StreamEndpointErrorData::class.java) - .onSuccess { apiError -> - logger.e { "[onBatch] Parsed error event: $apiError" } - notifyState( - StreamConnectionState.Disconnected( - StreamEndpointException("Connection error", apiError) - ) - ) - } - .onFailure { logger.i { "[onBatch] Failed to parse $message" } } - } + // Event aggregator handler — receives individual or aggregated events + aggregator.onEvent { event -> + when (event) { + is StreamAggregatedEvent<*> -> handleAggregatedEvent(event) + is StreamCompositeSerializationEvent<*> -> { + @Suppress("UNCHECKED_CAST") + handleSingleCompositeEvent(event as StreamCompositeSerializationEvent) + } } } + aggregator.start().onFailure { throwable -> + completeFailure(throwable) + return@suspendCancellableCoroutine + } + // Success/Failure continuations val success: (StreamConnectedUser, String) -> Unit = success@{ user, connectionId -> handshakeSubscription?.cancel() @@ -300,7 +269,7 @@ internal class StreamSocketSession( // Replay messages buffered during handshake for (message in pendingMessages) { - if (!batcher.offer(message)) { + if (!aggregator.offer(message)) { val err = IllegalStateException( "Failed to replay buffered message during handshake transition" @@ -496,13 +465,96 @@ internal class StreamSocketSession( } } + /** Handles a single deserialized composite event (individual dispatch path — low traffic). 
*/ + private fun handleSingleCompositeEvent(event: StreamCompositeSerializationEvent) { + logger.v { "[onEvent] Individual event: $event" } + val coreEvent = event.core + val productEvent = event.product + + if (coreEvent != null && coreEvent is StreamClientConnectionErrorEvent) { + notifyState( + StreamConnectionState.Disconnected( + StreamEndpointException("Connection error", coreEvent.error) + ) + ) + } + subscriptionManager + .forEach { listener -> + coreEvent?.takeUnless { it is StreamHealthCheckEvent }?.let { listener.onEvent(it) } + + productEvent?.let { listener.onEvent(it) } + } + .onFailure { e -> logger.e(e) { "[onEvent] Listener dispatch failed. ${e.message}" } } + } + + /** + * Handles an aggregated event (spike dispatch path — high traffic). + * + * Core events are processed individually (connection errors need immediate handling). Product + * events are re-grouped into a [StreamAggregatedEvent] keyed by event type and dispatched as a + * single call to listeners. + */ + private fun handleAggregatedEvent(aggregated: StreamAggregatedEvent<*>) { + @Suppress("UNCHECKED_CAST") + val typed = aggregated as StreamAggregatedEvent> + val productEvents = LinkedHashMap>() + + for ((type, compositeEvents) in typed.events) { + for (composite in compositeEvents) { + val coreEvent = composite.core + val productEvent = composite.product + + // Handle core events individually — they're rare and need immediate processing + if (coreEvent != null && coreEvent is StreamClientConnectionErrorEvent) { + notifyState( + StreamConnectionState.Disconnected( + StreamEndpointException("Connection error", coreEvent.error) + ) + ) + } + coreEvent + ?.takeUnless { it is StreamHealthCheckEvent } + ?.let { core -> + subscriptionManager + .forEach { listener -> listener.onEvent(core) } + .onFailure { e -> + logger.e(e) { "[onEvent] Listener dispatch failed. ${e.message}" } + } + } + + // Collect product events for aggregated dispatch + if (productEvent != null) { + productEvents.getOrPut(type) { mutableListOf() }.add(productEvent) + } + } + } + + // Dispatch aggregated product events as a single call + if (productEvents.isNotEmpty()) { + logger.v { + "[onEvent] Aggregated: ${productEvents.size} types, " + + "${productEvents.values.sumOf { it.size }} total events" + } + val productAggregated = StreamAggregatedEvent(productEvents.toMap()) + subscriptionManager + .forEach { listener -> listener.onEvent(productAggregated) } + .onFailure { e -> + logger.e(e) { "[onEvent] Listener dispatch failed. ${e.message}" } + } + } + } + private fun cleanup() { if (!cleaned.compareAndSet(false, true)) { return } logger.d { "[cleanup] Cleaning up socket" } healthMonitor.stop() - batcher.stop() + aggregator.stop().onFailure { throwable -> + logger.e(throwable) { + "[cleanup] Failed to stop event aggregator. 
${throwable.message}" + } + } socketSubscription?.cancel() socketSubscription = null streamClientConnectedEvent = null diff --git a/stream-android-core/src/test/java/io/getstream/android/core/api/StreamClientConfigFactoryTest.kt b/stream-android-core/src/test/java/io/getstream/android/core/api/StreamClientConfigFactoryTest.kt index 51e34ed..a0bd217 100644 --- a/stream-android-core/src/test/java/io/getstream/android/core/api/StreamClientConfigFactoryTest.kt +++ b/stream-android-core/src/test/java/io/getstream/android/core/api/StreamClientConfigFactoryTest.kt @@ -39,7 +39,7 @@ import io.getstream.android.core.api.model.value.StreamHttpClientInfoHeader import io.getstream.android.core.api.model.value.StreamToken import io.getstream.android.core.api.model.value.StreamUserId import io.getstream.android.core.api.model.value.StreamWsUrl -import io.getstream.android.core.api.processing.StreamBatcher +import io.getstream.android.core.api.processing.StreamEventAggregator import io.getstream.android.core.api.processing.StreamSerialProcessingQueue import io.getstream.android.core.api.processing.StreamSingleFlightProcessor import io.getstream.android.core.api.serialization.StreamEventSerialization @@ -189,26 +189,28 @@ internal class StreamClientConfigFactoryTest { } @Test - fun `factory wires custom batch parameters from socketConfig`() { + fun `factory wires custom aggregation parameters from socketConfig`() { val customConfig = StreamSocketConfig.jwt( url = defaultSocketConfig.url, apiKey = defaultSocketConfig.apiKey, clientInfoHeader = defaultSocketConfig.clientInfoHeader, - batchSize = 20, - batchInitialDelayMs = 50L, - batchMaxDelayMs = 500L, + aggregationThreshold = 20, + aggregationMaxWindowMs = 200L, + aggregationDispatchQueueCapacity = 8, ) val client = buildClient(socketConfig = customConfig) val socketSession = (client as StreamClientImpl<*>).readPrivateField("socketSession") as StreamSocketSession<*> - val batcher = socketSession.readPrivateField("batcher") as StreamBatcher<*> - assertNotNull(batcher) - batcher.assertFieldEquals("batchSize", 20) - batcher.assertFieldEquals("initialDelayMs", 50L) - batcher.assertFieldEquals("maxDelayMs", 500L) + val aggregator = socketSession.readPrivateField("aggregator") + assertNotNull(aggregator) + val policy = aggregator!!.readPrivateField("policy") + assertNotNull(policy) + policy!!.assertFieldEquals("aggregationThreshold", 20) + policy.assertFieldEquals("maxWindowMs", 200L) + policy.assertFieldEquals("dispatchQueueCapacity", 8) } // ── StreamComponentProvider overrides ──────────────────────────────────── @@ -318,14 +320,14 @@ internal class StreamClientConfigFactoryTest { } @Test - fun `factory wires injected batcher from components`() { - val batcher = mockk>(relaxed = true) + fun `factory wires injected aggregator from components`() { + val eventAggregator = mockk>(relaxed = true) val client = buildClient( components = StreamComponentProvider( logProvider = logProvider, - batcher = batcher, + eventAggregator = eventAggregator, androidComponentsProvider = fakeAndroidComponents, ) ) @@ -333,7 +335,7 @@ internal class StreamClientConfigFactoryTest { val socketSession = (client as StreamClientImpl<*>).readPrivateField("socketSession") as StreamSocketSession<*> - socketSession.assertFieldEquals("batcher", batcher) + socketSession.assertFieldEquals("aggregator", eventAggregator) } @Test @@ -376,7 +378,7 @@ internal class StreamClientConfigFactoryTest { val socketSession = impl.readPrivateField("socketSession") as StreamSocketSession<*> 
assertNotNull(socketSession.readPrivateField("healthMonitor")) - assertNotNull(socketSession.readPrivateField("batcher")) + assertNotNull(socketSession.readPrivateField("aggregator")) assertNotNull(socketSession.readPrivateField("internalSocket")) } @@ -391,9 +393,9 @@ internal class StreamClientConfigFactoryTest { url = StreamWsUrl.fromString("wss://custom.stream.io"), apiKey = defaultSocketConfig.apiKey, clientInfoHeader = defaultSocketConfig.clientInfoHeader, - batchSize = 5, - batchInitialDelayMs = 25L, - batchMaxDelayMs = 250L, + aggregationThreshold = 10, + aggregationMaxWindowMs = 200L, + aggregationDispatchQueueCapacity = 8, ) val client = @@ -417,10 +419,8 @@ internal class StreamClientConfigFactoryTest { // Injected health monitor takes precedence over socketConfig timing socketSession.assertFieldEquals("healthMonitor", healthMonitor) - // Batcher still created from socketConfig since not injected - val batcher = socketSession.readPrivateField("batcher") as StreamBatcher<*> - batcher.assertFieldEquals("batchSize", 5) - batcher.assertFieldEquals("initialDelayMs", 25L) - batcher.assertFieldEquals("maxDelayMs", 250L) + // Aggregator still created from socketConfig since not injected + val aggregator = socketSession.readPrivateField("aggregator") + assertNotNull(aggregator) } } diff --git a/stream-android-core/src/test/java/io/getstream/android/core/api/StreamClientFactoryTest.kt b/stream-android-core/src/test/java/io/getstream/android/core/api/StreamClientFactoryTest.kt index b6c617f..2bd94c2 100644 --- a/stream-android-core/src/test/java/io/getstream/android/core/api/StreamClientFactoryTest.kt +++ b/stream-android-core/src/test/java/io/getstream/android/core/api/StreamClientFactoryTest.kt @@ -37,7 +37,7 @@ import io.getstream.android.core.api.model.value.StreamUserId import io.getstream.android.core.api.model.value.StreamWsUrl import io.getstream.android.core.api.observers.lifecycle.StreamLifecycleMonitor import io.getstream.android.core.api.observers.network.StreamNetworkMonitor -import io.getstream.android.core.api.processing.StreamBatcher +import io.getstream.android.core.api.processing.StreamEventAggregator import io.getstream.android.core.api.processing.StreamSerialProcessingQueue import io.getstream.android.core.api.processing.StreamSingleFlightProcessor import io.getstream.android.core.api.recovery.StreamConnectionRecoveryEvaluator @@ -101,7 +101,7 @@ internal class StreamClientFactoryTest { val connectionIdHolder: StreamConnectionIdHolder, val socketFactory: StreamWebSocketFactory, val healthMonitor: StreamHealthMonitor, - val batcher: StreamBatcher, + val eventAggregator: StreamEventAggregator, val lifecycleMonitor: StreamLifecycleMonitor, val networkMonitor: StreamNetworkMonitor, val connectionRecoveryEvaluator: StreamConnectionRecoveryEvaluator, @@ -133,7 +133,7 @@ internal class StreamClientFactoryTest { connectionIdHolder = mockk(relaxed = true), socketFactory = mockk(relaxed = true), healthMonitor = mockk(relaxed = true), - batcher = mockk(relaxed = true), + eventAggregator = mockk(relaxed = true), lifecycleMonitor = mockk(relaxed = true), networkMonitor = mockk(relaxed = true), connectionRecoveryEvaluator = mockk(relaxed = true), @@ -157,7 +157,7 @@ internal class StreamClientFactoryTest { connectionIdHolder = deps.connectionIdHolder, socketFactory = deps.socketFactory, healthMonitor = deps.healthMonitor, - batcher = deps.batcher, + eventAggregator = deps.eventAggregator, httpConfig = httpConfig, serializationConfig = serializationConfig, logProvider = logProvider, 
@@ -201,7 +201,7 @@ internal class StreamClientFactoryTest { val socketSession = client.readPrivateField("socketSession") as StreamSocketSession<*> socketSession.assertFieldEquals("config", deps.socketConfig) socketSession.assertFieldEquals("healthMonitor", deps.healthMonitor) - socketSession.assertFieldEquals("batcher", deps.batcher) + socketSession.assertFieldEquals("aggregator", deps.eventAggregator) socketSession.assertFieldEquals("products", listOf("feeds")) val internalSocket = socketSession.readPrivateField("internalSocket") diff --git a/stream-android-core/src/test/java/io/getstream/android/core/api/model/config/StreamSocketConfigTest.kt b/stream-android-core/src/test/java/io/getstream/android/core/api/model/config/StreamSocketConfigTest.kt index 7fd1b15..7d495e0 100644 --- a/stream-android-core/src/test/java/io/getstream/android/core/api/model/config/StreamSocketConfigTest.kt +++ b/stream-android-core/src/test/java/io/getstream/android/core/api/model/config/StreamSocketConfigTest.kt @@ -66,9 +66,15 @@ class StreamSocketConfigTest { assertEquals(StreamSocketConfig.DEFAULT_HEALTH_INTERVAL_MS, config.healthCheckIntervalMs) assertEquals(StreamSocketConfig.DEFAULT_LIVENESS_MS, config.livenessThresholdMs) assertEquals(StreamSocketConfig.DEFAULT_CONNECTION_TIMEOUT_MS, config.connectionTimeoutMs) - assertEquals(StreamSocketConfig.DEFAULT_BATCH_SIZE, config.batchSize) - assertEquals(StreamSocketConfig.DEFAULT_BATCH_INIT_DELAY_MS, config.batchInitialDelayMs) - assertEquals(StreamSocketConfig.DEFAULT_BATCH_MAX_DELAY_MS, config.batchMaxDelayMs) + assertEquals(StreamSocketConfig.DEFAULT_AGGREGATION_THRESHOLD, config.aggregationThreshold) + assertEquals( + StreamSocketConfig.DEFAULT_AGGREGATION_MAX_WINDOW_MS, + config.aggregationMaxWindowMs, + ) + assertEquals( + StreamSocketConfig.DEFAULT_AGGREGATION_DISPATCH_QUEUE_CAPACITY, + config.aggregationDispatchQueueCapacity, + ) } @Test @@ -81,16 +87,16 @@ class StreamSocketConfigTest { healthCheckIntervalMs = 5_000L, livenessThresholdMs = 15_000L, connectionTimeoutMs = 2_000L, - batchSize = 1, - batchInitialDelayMs = 0L, - batchMaxDelayMs = 0L, + aggregationThreshold = 10, + aggregationMaxWindowMs = 200L, + aggregationDispatchQueueCapacity = 8, ) assertEquals(5_000L, config.healthCheckIntervalMs) assertEquals(15_000L, config.livenessThresholdMs) assertEquals(2_000L, config.connectionTimeoutMs) - assertEquals(1, config.batchSize) - assertEquals(0L, config.batchInitialDelayMs) - assertEquals(0L, config.batchMaxDelayMs) + assertEquals(10, config.aggregationThreshold) + assertEquals(200L, config.aggregationMaxWindowMs) + assertEquals(8, config.aggregationDispatchQueueCapacity) } } diff --git a/stream-android-core/src/test/java/io/getstream/android/core/internal/processing/StreamEventAggregatorImplTest.kt b/stream-android-core/src/test/java/io/getstream/android/core/internal/processing/StreamEventAggregatorImplTest.kt new file mode 100644 index 0000000..7b1b4cb --- /dev/null +++ b/stream-android-core/src/test/java/io/getstream/android/core/internal/processing/StreamEventAggregatorImplTest.kt @@ -0,0 +1,925 @@ +/* + * Copyright (c) 2014-2026 Stream.io Inc. All rights reserved. + * + * Licensed under the Stream License; + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://github.com/GetStream/stream-core-android/blob/main/LICENSE + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@file:OptIn(kotlinx.coroutines.ExperimentalCoroutinesApi::class) + +package io.getstream.android.core.internal.processing + +import io.getstream.android.core.api.log.StreamLogger +import io.getstream.android.core.api.processing.StreamAggregatedEvent +import io.getstream.android.core.api.processing.StreamEventAggregationPolicy +import io.getstream.android.core.api.processing.StreamEventAggregator +import java.util.concurrent.CopyOnWriteArrayList +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.cancel +import kotlinx.coroutines.test.StandardTestDispatcher +import kotlinx.coroutines.test.advanceTimeBy +import kotlinx.coroutines.test.advanceUntilIdle +import kotlinx.coroutines.test.runTest +import org.junit.Assert.assertEquals +import org.junit.Assert.assertTrue +import org.junit.Test + +class StreamEventAggregatorImplTest { + + /** Simple type extractor: pulls the first token after "type:" prefix. */ + private val typeExtractor: (String) -> String? = { raw -> + val idx = raw.indexOf("type:") + if (idx >= 0) raw.substring(idx + 5).trim().substringBefore(' ') else null + } + + /** Simple deserializer: wraps raw string in a TestEvent. */ + private val deserializer: (String) -> Result = { raw -> + if (raw.startsWith("BAD")) Result.failure(IllegalStateException("bad event")) + else Result.success(TestEvent(raw)) + } + + data class TestEvent(val raw: String) + + private fun testScope() = CoroutineScope(SupervisorJob() + StandardTestDispatcher()) + + // ── Factory validation ─────────────────────────────────────────────────── + + @Test(expected = IllegalArgumentException::class) + fun `factory rejects aggregationThreshold = 0`() { + StreamEventAggregationPolicy.from( + typeExtractor = typeExtractor, + deserializer = deserializer, + aggregationThreshold = 0, + ) + } + + @Test(expected = IllegalArgumentException::class) + fun `factory rejects negative aggregationThreshold`() { + StreamEventAggregationPolicy.from( + typeExtractor = typeExtractor, + deserializer = deserializer, + aggregationThreshold = -1, + ) + } + + @Test(expected = IllegalArgumentException::class) + fun `factory rejects maxWindowMs = 0`() { + StreamEventAggregationPolicy.from( + typeExtractor = typeExtractor, + deserializer = deserializer, + maxWindowMs = 0, + ) + } + + @Test(expected = IllegalArgumentException::class) + fun `factory rejects negative maxWindowMs`() { + StreamEventAggregationPolicy.from( + typeExtractor = typeExtractor, + deserializer = deserializer, + maxWindowMs = -100, + ) + } + + @Test(expected = IllegalArgumentException::class) + fun `factory rejects dispatchQueueCapacity = 0`() { + StreamEventAggregationPolicy.from( + typeExtractor = typeExtractor, + deserializer = deserializer, + dispatchQueueCapacity = 0, + ) + } + + @Test(expected = IllegalArgumentException::class) + fun `factory rejects negative dispatchQueueCapacity`() { + StreamEventAggregationPolicy.from( + typeExtractor = typeExtractor, + deserializer = deserializer, + dispatchQueueCapacity = -1, + ) + } + + // ── Lifecycle edge cases 
──────────────────────────────────────────────── + + @Test + fun `offer before start returns false`() = runTest { + val scope = CoroutineScope(SupervisorJob() + StandardTestDispatcher(testScheduler)) + + val aggregator = + StreamEventAggregator( + scope = scope, + policy = + StreamEventAggregationPolicy.from( + typeExtractor = typeExtractor, + deserializer = deserializer, + ), + ) + aggregator.onEvent {} + + // Not started — offer should fail + val accepted = aggregator.offer("type:a test") + assertTrue("offer should return false before start", !accepted) + + aggregator.stop() + } + + @Test + fun `double start is idempotent`() = runTest { + val scope = CoroutineScope(SupervisorJob() + StandardTestDispatcher(testScheduler)) + val received = CopyOnWriteArrayList() + + val aggregator = + StreamEventAggregator( + scope = scope, + policy = + StreamEventAggregationPolicy.from( + typeExtractor = typeExtractor, + deserializer = deserializer, + aggregationThreshold = 50, + maxWindowMs = 200, + ), + ) + aggregator.onEvent { received += it } + + // Start twice — should not create duplicate workers + val first = aggregator.start() + val second = aggregator.start() + + assertTrue(first.isSuccess) + assertTrue(second.isSuccess) + + aggregator.offer("type:a event1") + advanceTimeBy(500) + advanceUntilIdle() + + // Should receive exactly 1 event, not duplicates + assertEquals(1, received.size) + + aggregator.stop() + } + + @Test + fun `handler exception on aggregated event does not break dispatcher`() { + val scope = CoroutineScope(SupervisorJob() + kotlinx.coroutines.Dispatchers.Default) + val received = CopyOnWriteArrayList() + val latch = java.util.concurrent.CountDownLatch(1) + var throwOnFirst = true + + val aggregator = + StreamEventAggregator( + scope = scope, + policy = + StreamEventAggregationPolicy.from( + typeExtractor = typeExtractor, + deserializer = deserializer, + aggregationThreshold = 3, + maxWindowMs = 100, + ), + ) + aggregator.onEvent { event -> + if (event is StreamAggregatedEvent<*> && throwOnFirst) { + throwOnFirst = false + throw RuntimeException("aggregated handler boom") + } + received += event + latch.countDown() + } + aggregator.start() + Thread.sleep(50) + + // First batch — triggers aggregation, handler throws + repeat(5) { aggregator.offer("type:a event$it") } + Thread.sleep(300) + + // Second batch — handler should still work + aggregator.offer("type:b after_error") + assertTrue( + "Events should still be delivered after handler error", + latch.await(5, java.util.concurrent.TimeUnit.SECONDS), + ) + assertTrue(received.isNotEmpty()) + + aggregator.stop() + scope.cancel() + } + + // ── Behavior tests ─────────────────────────────────────────────────────── + + @Test + fun `low traffic events are dispatched individually`() = runTest { + val scope = CoroutineScope(SupervisorJob() + StandardTestDispatcher(testScheduler)) + val received = CopyOnWriteArrayList() + + val aggregator = + StreamEventAggregator( + scope = scope, + policy = + StreamEventAggregationPolicy.from( + typeExtractor = typeExtractor, + deserializer = deserializer, + aggregationThreshold = 50, + maxWindowMs = 200, + ), + ) + aggregator.onEvent { received += it } + aggregator.start() + + // Send 3 events — well below threshold + aggregator.offer("type:channel.updated event1") + aggregator.offer("type:message.new event2") + aggregator.offer("type:channel.updated event3") + + advanceTimeBy(500) + advanceUntilIdle() + + // Should receive individual TestEvent instances, not aggregated + assertEquals(3, received.size) 
+        assertTrue(received.all { it is TestEvent })
+        assertEquals("type:channel.updated event1", (received[0] as TestEvent).raw)
+        assertEquals("type:message.new event2", (received[1] as TestEvent).raw)
+        assertEquals("type:channel.updated event3", (received[2] as TestEvent).raw)
+
+        aggregator.stop()
+    }
+
+    @Test
+    fun `spike triggers aggregated delivery grouped by type`() {
+        // Use real dispatchers — TestDispatcher doesn't play well with Channel + withTimeoutOrNull
+        val scope = CoroutineScope(SupervisorJob() + kotlinx.coroutines.Dispatchers.Default)
+        val received = CopyOnWriteArrayList<Any>()
+        val latch = java.util.concurrent.CountDownLatch(1)
+
+        val aggregator =
+            StreamEventAggregator(
+                scope = scope,
+                policy =
+                    StreamEventAggregationPolicy.from(
+                        typeExtractor = typeExtractor,
+                        deserializer = deserializer,
+                        aggregationThreshold = 5, // low threshold for testing
+                        maxWindowMs = 500,
+                    ),
+            )
+        aggregator.onEvent { event ->
+            received += event
+            // Signal once we've received all events (individual or aggregated, total >= 10)
+            var count = 0
+            for (item in received) {
+                when (item) {
+                    is StreamAggregatedEvent<*> -> count += item.events.values.sumOf { it.size }
+                    is TestEvent -> count++
+                }
+            }
+            if (count >= 10) latch.countDown()
+        }
+        aggregator.start()
+        Thread.sleep(50) // let collector suspend on inbox.receive()
+
+        // Buffer all 10 events into the inbox at once
+        repeat(10) { i ->
+            val type = if (i % 2 == 0) "channel.updated" else "message.new"
+            assertTrue(aggregator.offer("type:$type event$i"))
+        }
+        Thread.sleep(50) // let collector process
+
+        // Wait for delivery (max 5s)
+        assertTrue(
+            "Events not delivered in time",
+            latch.await(5, java.util.concurrent.TimeUnit.SECONDS),
+        )
+
+        // Count total events delivered (individual + aggregated)
+        var totalEvents = 0
+        val aggregated = mutableListOf<StreamAggregatedEvent<*>>()
+        for (item in received) {
+            when (item) {
+                is StreamAggregatedEvent<*> -> {
+                    aggregated += item
+                    totalEvents += item.events.values.sumOf { it.size }
+                }
+                is TestEvent -> totalEvents++
+            }
+        }
+
+        // All 10 events should be delivered
+        assertEquals("All 10 events should be delivered", 10, totalEvents)
+
+        // At least one aggregated event should exist (threshold = 5, we sent 10)
+        assertTrue(
+            "Expected aggregated event but got ${received.size} individual.
" + + "Types: ${received.map { it::class.simpleName }}", + aggregated.isNotEmpty(), + ) + + // Both types should appear somewhere across all deliveries + val allTypes = aggregated.flatMap { it.events.keys } + assertTrue("channel.updated" in allTypes) + assertTrue("message.new" in allTypes) + + aggregator.stop() + scope.cancel() + } + + @Test + fun `maxWindow caps collection time`() = runTest { + val scope = CoroutineScope(SupervisorJob() + StandardTestDispatcher(testScheduler)) + val received = CopyOnWriteArrayList() + + val aggregator = + StreamEventAggregator( + scope = scope, + policy = + StreamEventAggregationPolicy.from( + typeExtractor = typeExtractor, + deserializer = deserializer, + aggregationThreshold = 1000, // high threshold — won't be reached + maxWindowMs = 100, // short window + ), + ) + aggregator.onEvent { received += it } + aggregator.start() + + // Send a few events + aggregator.offer("type:a ev1") + aggregator.offer("type:a ev2") + + // After maxWindow, events should be delivered even though threshold not reached + advanceTimeBy(300) + advanceUntilIdle() + + assertTrue("Events should be delivered after maxWindow", received.isNotEmpty()) + + aggregator.stop() + } + + @Test + fun `deserialization failures are skipped and logged`() = runTest { + val scope = CoroutineScope(SupervisorJob() + StandardTestDispatcher(testScheduler)) + val received = CopyOnWriteArrayList() + + val aggregator = + StreamEventAggregator( + scope = scope, + policy = + StreamEventAggregationPolicy.from( + typeExtractor = typeExtractor, + deserializer = deserializer, + aggregationThreshold = 50, + maxWindowMs = 200, + ), + ) + aggregator.onEvent { received += it } + aggregator.start() + + aggregator.offer("type:ok good_event") + aggregator.offer("BAD_EVENT") // will fail deserialization + aggregator.offer("type:ok another_good") + + advanceTimeBy(500) + advanceUntilIdle() + + // Only the 2 good events should be delivered + assertEquals(2, received.size) + assertTrue(received.all { it is TestEvent }) + + aggregator.stop() + } + + @Test + fun `stop is idempotent`() = runTest { + val scope = CoroutineScope(SupervisorJob() + StandardTestDispatcher(testScheduler)) + + val aggregator = + StreamEventAggregator( + scope = scope, + policy = + StreamEventAggregationPolicy.from( + typeExtractor = typeExtractor, + deserializer = deserializer, + ), + ) + aggregator.onEvent {} + aggregator.start() + + val first = aggregator.stop() + val second = aggregator.stop() + + assertTrue(first.isSuccess) + assertTrue(second.isSuccess) + } + + @Test + fun `stop then start resumes processing (restartable)`() = runTest { + val scope = CoroutineScope(SupervisorJob() + StandardTestDispatcher(testScheduler)) + val received = CopyOnWriteArrayList() + + val aggregator = + StreamEventAggregator( + scope = scope, + policy = + StreamEventAggregationPolicy.from( + typeExtractor = typeExtractor, + deserializer = deserializer, + aggregationThreshold = 50, + maxWindowMs = 200, + ), + ) + aggregator.onEvent { received += it } + aggregator.start() + + aggregator.offer("type:a first") + advanceTimeBy(500) + advanceUntilIdle() + assertEquals(1, received.size) + + // Stop and restart + aggregator.stop() + aggregator.start() + + aggregator.offer("type:a second") + advanceTimeBy(500) + advanceUntilIdle() + + // Second event delivered after restart + assertEquals(2, received.size) + assertEquals("type:a second", (received[1] as TestEvent).raw) + + aggregator.stop() + } + + @Test + fun `all events delivered even when many arrive at once`() = 
        runTest {
+        val scope = CoroutineScope(SupervisorJob() + StandardTestDispatcher(testScheduler))
+        val received = CopyOnWriteArrayList<Any>()
+
+        val aggregator =
+            StreamEventAggregator(
+                scope = scope,
+                policy =
+                    StreamEventAggregationPolicy.from(
+                        typeExtractor = typeExtractor,
+                        deserializer = deserializer,
+                        aggregationThreshold = 5,
+                        maxWindowMs = 100,
+                    ),
+            )
+        aggregator.onEvent { event -> received += event }
+        aggregator.start()
+
+        // Pump 11 events rapidly
+        aggregator.offer("type:a first")
+        repeat(10) { i -> aggregator.offer("type:b spike$i") }
+
+        // Let everything settle
+        advanceTimeBy(2000)
+        advanceUntilIdle()
+
+        // Count total events delivered (individual + aggregated)
+        var totalEvents = 0
+        for (item in received) {
+            when (item) {
+                is StreamAggregatedEvent<*> -> totalEvents += item.events.values.sumOf { it.size }
+                is TestEvent -> totalEvents++
+            }
+        }
+        assertEquals("All 11 events should be delivered", 11, totalEvents)
+
+        aggregator.stop()
+    }
+
+    @Test
+    fun `null type from extractor uses empty string key`() = runTest {
+        val scope = CoroutineScope(SupervisorJob() + StandardTestDispatcher(testScheduler))
+        val received = CopyOnWriteArrayList<Any>()
+
+        val aggregator =
+            StreamEventAggregator(
+                scope = scope,
+                policy =
+                    StreamEventAggregationPolicy.from(
+                        typeExtractor = { null }, // always returns null
+                        deserializer = deserializer,
+                        aggregationThreshold = 3,
+                        maxWindowMs = 200,
+                    ),
+            )
+        aggregator.onEvent { received += it }
+        aggregator.start()
+
+        repeat(5) { aggregator.offer("event$it") }
+
+        advanceTimeBy(500)
+        advanceUntilIdle()
+
+        val aggregated = received.filterIsInstance<StreamAggregatedEvent<*>>()
+        assertTrue(aggregated.isNotEmpty())
+        // All events grouped under empty key
+        assertTrue(aggregated.first().events.containsKey(""))
+
+        aggregator.stop()
+    }
+
+    // ── Edge cases: error resilience ─────────────────────────────────────────
+
+    @Test
+    fun `typeExtractor that throws does not kill collector`() = runTest {
+        val scope = CoroutineScope(SupervisorJob() + StandardTestDispatcher(testScheduler))
+        val received = CopyOnWriteArrayList<Any>()
+        var callCount = 0
+
+        val throwingExtractor: (String) -> String? =
+            {
+                callCount++
+                if (callCount == 2) throw RuntimeException("extractor boom")
+                "safe"
+            }
+
+        val aggregator =
+            StreamEventAggregator(
+                scope = scope,
+                policy =
+                    StreamEventAggregationPolicy.from(
+                        typeExtractor = throwingExtractor,
+                        deserializer = deserializer,
+                        aggregationThreshold = 3,
+                        maxWindowMs = 200,
+                    ),
+            )
+        aggregator.onEvent { received += it }
+        aggregator.start()
+
+        // Send enough to trigger aggregation — extractor throws on 2nd call
+        repeat(5) { aggregator.offer("type:a event$it") }
+
+        advanceTimeBy(500)
+        advanceUntilIdle()
+
+        // All 5 events should still be delivered (throwing one grouped under "")
+        var total = 0
+        for (item in received) {
+            when (item) {
+                is StreamAggregatedEvent<*> -> total += item.events.values.sumOf { it.size }
+                is TestEvent -> total++
+            }
+        }
+        assertEquals("All 5 events should be delivered despite extractor throw", 5, total)
+
+        aggregator.stop()
+    }
+
+    @Test
+    fun `deserializer that throws (not Result failure) does not kill collector`() = runTest {
+        val scope = CoroutineScope(SupervisorJob() + StandardTestDispatcher(testScheduler))
+        val received = CopyOnWriteArrayList<Any>()
+
+        val throwingDeserializer: (String) -> Result<TestEvent> = { raw ->
+            if (raw.contains("THROW")) throw RuntimeException("deserializer exploded")
+            Result.success(TestEvent(raw))
+        }
+
+        val aggregator =
+            StreamEventAggregator(
+                scope = scope,
+                policy =
+                    StreamEventAggregationPolicy.from(
+                        typeExtractor = typeExtractor,
+                        deserializer = throwingDeserializer,
+                        aggregationThreshold = 50,
+                        maxWindowMs = 200,
+                    ),
+            )
+        aggregator.onEvent { received += it }
+        aggregator.start()
+
+        aggregator.offer("type:a good1")
+        aggregator.offer("type:a THROW") // throws, not Result.failure
+        aggregator.offer("type:a good2")
+
+        advanceTimeBy(500)
+        advanceUntilIdle()
+
+        // 2 good events delivered, throwing one skipped
+        assertEquals(2, received.size)
+        assertTrue(received.all { it is TestEvent })
+
+        // Verify collector still alive — send more events
+        aggregator.offer("type:a good3")
+        advanceTimeBy(500)
+        advanceUntilIdle()
+
+        assertEquals(3, received.size)
+
+        aggregator.stop()
+    }
+
+    @Test
+    fun `handler exception does not break subsequent event delivery`() = runTest {
+        val scope = CoroutineScope(SupervisorJob() + StandardTestDispatcher(testScheduler))
+        val received = CopyOnWriteArrayList<Any>()
+
+        val aggregator =
+            StreamEventAggregator(
+                scope = scope,
+                policy =
+                    StreamEventAggregationPolicy.from(
+                        typeExtractor = typeExtractor,
+                        deserializer = deserializer,
+                        aggregationThreshold = 50,
+                        maxWindowMs = 100,
+                    ),
+            )
+        aggregator.onEvent { event ->
+            if (event is TestEvent && event.raw.contains("EXPLODE")) {
+                throw RuntimeException("handler boom")
+            }
+            received += event
+        }
+        aggregator.start()
+
+        aggregator.offer("type:a good1")
+        advanceTimeBy(200)
+        advanceUntilIdle()
+
+        aggregator.offer("type:a EXPLODE") // handler throws
+        advanceTimeBy(200)
+        advanceUntilIdle()
+
+        aggregator.offer("type:a good2") // should still be delivered
+        advanceTimeBy(200)
+        advanceUntilIdle()
+
+        assertEquals(2, received.size)
+        assertEquals("type:a good1", (received[0] as TestEvent).raw)
+        assertEquals("type:a good2", (received[1] as TestEvent).raw)
+
+        aggregator.stop()
+    }
+
+    // ── Edge cases: boundary conditions ──────────────────────────────────────
+
+    @Test
+    fun `exactly threshold events triggers aggregation not individual`() {
+        val scope = CoroutineScope(SupervisorJob() + kotlinx.coroutines.Dispatchers.Default)
+        val received = CopyOnWriteArrayList<Any>()
+        val latch = java.util.concurrent.CountDownLatch(1)
+
+        val aggregator =
+            StreamEventAggregator(
+                scope = scope,
+                policy =
+                    StreamEventAggregationPolicy.from(
+                        typeExtractor = typeExtractor,
+                        deserializer = deserializer,
+                        aggregationThreshold = 5,
+                        maxWindowMs = 500,
+                    ),
+            )
+        aggregator.onEvent { event ->
+            received += event
+            var count = 0
+            for (item in received) {
+                when (item) {
+                    is StreamAggregatedEvent<*> -> count += item.events.values.sumOf { it.size }
+                    is TestEvent -> count++
+                }
+            }
+            if (count >= 5) latch.countDown()
+        }
+        aggregator.start()
+        Thread.sleep(50)
+
+        // Exactly 5 events = threshold
+        repeat(5) { aggregator.offer("type:a event$it") }
+
+        assertTrue(
+            "Events not delivered in time",
+            latch.await(5, java.util.concurrent.TimeUnit.SECONDS),
+        )
+
+        // With exactly threshold items, buffer.size == threshold, condition is
+        // `rawMessages.size < aggregationThreshold` = false → aggregated path
+        val aggregated = received.filterIsInstance<StreamAggregatedEvent<*>>()
+        assertTrue(
+            "Exactly threshold should trigger aggregation, got: ${received.map { it::class.simpleName }}",
+            aggregated.isNotEmpty(),
+        )
+
+        aggregator.stop()
+        scope.cancel()
+    }
+
+    @Test
+    fun `all events fail deserialization - batch silently dropped`() = runTest {
+        val scope = CoroutineScope(SupervisorJob() + StandardTestDispatcher(testScheduler))
+        val received = CopyOnWriteArrayList<Any>()
+
+        val aggregator =
+            StreamEventAggregator(
+                scope = scope,
+                policy =
+                    StreamEventAggregationPolicy.from(
+                        typeExtractor = typeExtractor,
+                        deserializer = { Result.failure(IllegalStateException("always fails")) },
+                        aggregationThreshold = 50,
+                        maxWindowMs = 200,
+                    ),
+            )
+        aggregator.onEvent { received += it }
+        aggregator.start()
+
+        repeat(5) { aggregator.offer("type:a event$it") }
+
+        advanceTimeBy(500)
+        advanceUntilIdle()
+
+        // Nothing delivered — all failed
+        assertTrue("No events should be delivered", received.isEmpty())
+
+        // Collector should still be alive — send a good event via a new aggregator
+        // (can't change deserializer on existing one, so just verify no crash)
+        aggregator.stop()
+    }
+
+    @Test
+    fun `dispatch queue full logs warning and drops events`() {
+        val scope = CoroutineScope(SupervisorJob() + kotlinx.coroutines.Dispatchers.Default)
+        val received = CopyOnWriteArrayList<Any>()
+        val latch = java.util.concurrent.CountDownLatch(1)
+
+        val warnings = CopyOnWriteArrayList<String>()
+        val testLogger =
+            object : StreamLogger {
+                override fun log(
+                    level: StreamLogger.LogLevel,
+                    throwable: Throwable?,
+                    message: () -> String,
+                ) {
+                    if (level is StreamLogger.LogLevel.Warning) {
+                        warnings.add(message())
+                    }
+                }
+            }
+
+        val aggregator =
+            StreamEventAggregator(
+                scope = scope,
+                policy =
+                    StreamEventAggregationPolicy.from(
+                        typeExtractor = typeExtractor,
+                        deserializer = deserializer,
+                        aggregationThreshold = 50,
+                        maxWindowMs = 50,
+                        dispatchQueueCapacity = 1, // tiny queue
+                    ),
+            )
+        (aggregator as StreamEventAggregatorImpl<*>).logger = testLogger
+        aggregator.onEvent { event ->
+            // Slow handler — holds dispatch queue slot
+            Thread.sleep(200)
+            received += event
+            latch.countDown()
+        }
+        aggregator.start()
+        Thread.sleep(50)
+
+        // Flood events in rapid bursts — dispatcher is slow, queue will fill
+        repeat(5) { burst ->
+            repeat(3) { aggregator.offer("type:a burst${burst}_event$it") }
+            Thread.sleep(100) // let maxWindow fire between bursts
+        }
+
+        // Wait for at least one delivery
+        assertTrue(
+            "At least one event should be delivered",
+            latch.await(5, java.util.concurrent.TimeUnit.SECONDS),
+        )
+
+        // Some events were
delivered, some may have been dropped (queue full) + // The key assertion: no crash, aggregator survived + assertTrue("At least one event delivered", received.isNotEmpty()) + + // Verify that a warning was logged when the dispatch queue was full + assertTrue("Expected warning about full dispatch queue", warnings.isNotEmpty()) + + aggregator.stop() + scope.cancel() + } + + // ── Stress test ────────────────────────────────────────────────────────── + + @Test + fun `stress - 10K events with realistic type distribution`() { + val scope = CoroutineScope(SupervisorJob() + kotlinx.coroutines.Dispatchers.Default) + val totalEvents = 10_000 + val received = CopyOnWriteArrayList() + val latch = java.util.concurrent.CountDownLatch(1) + val eventTypes = + listOf( + "channel.updated", + "message.new", + "user.watching.start", + "user.presence_changed", + "typing.start", + ) + + val threshold = 50 + val aggregator = + StreamEventAggregator( + scope = scope, + policy = + StreamEventAggregationPolicy.from( + typeExtractor = typeExtractor, + deserializer = deserializer, + aggregationThreshold = threshold, + maxWindowMs = 500, + ), + ) + + val deliveredCount = java.util.concurrent.atomic.AtomicInteger(0) + aggregator.onEvent { event -> + received += event + val count = + when (event) { + is StreamAggregatedEvent<*> -> event.events.values.sumOf { it.size } + is TestEvent -> 1 + else -> 0 + } + if (deliveredCount.addAndGet(count) >= totalEvents) { + latch.countDown() + } + } + + val startNs = System.nanoTime() + aggregator.start() + // No sleep needed — UNLIMITED inbox buffers everything regardless of collector state + + // Flood 10K events with realistic type distribution + repeat(totalEvents) { i -> + val type = eventTypes[i % eventTypes.size] + assertTrue("offer failed for event $i", aggregator.offer("type:$type event$i")) + } + + assertTrue( + "10K events not delivered in time", + latch.await(60, java.util.concurrent.TimeUnit.SECONDS), + ) + val elapsedMs = (System.nanoTime() - startNs) / 1_000_000L + + // Count batches and individual events + var individualCount = 0 + var aggregatedBatches = 0 + var aggregatedEventCount = 0 + for (item in received) { + when (item) { + is StreamAggregatedEvent<*> -> { + aggregatedBatches++ + aggregatedEventCount += item.events.values.sumOf { it.size } + } + is TestEvent -> individualCount++ + } + } + val totalDelivered = individualCount + aggregatedEventCount + + // Print results for visibility + println("=== 10K Event Stress Test Results ===") + println("Total events offered: $totalEvents") + println("Total events delivered: $totalDelivered") + println("Individual dispatches: $individualCount") + println("Aggregated batches: $aggregatedBatches") + println("Events in aggregated: $aggregatedEventCount") + println("Total handler calls: ${received.size}") + println("Elapsed time: ${elapsedMs}ms") + println("====================================") + + // Guarantee 1: all events delivered + assertEquals("All events must be delivered", totalEvents, totalDelivered) + + // Guarantee 2: each aggregated batch has at most `threshold` events + for ((i, item) in received.withIndex()) { + if (item is StreamAggregatedEvent<*>) { + val batchSize = item.events.values.sumOf { it.size } + assertTrue( + "Batch $i has $batchSize events, exceeds threshold $threshold", + batchSize <= threshold, + ) + } + } + + // Guarantee 3: fewer handler calls than raw events (aggregation happened) + assertTrue( + "Expected fewer handler calls (${received.size}) than raw events ($totalEvents)", + 
received.size < totalEvents, + ) + + aggregator.stop() + scope.cancel() + } +} diff --git a/stream-android-core/src/test/java/io/getstream/android/core/internal/socket/StreamSocketSessionTest.kt b/stream-android-core/src/test/java/io/getstream/android/core/internal/socket/StreamSocketSessionTest.kt index 73668fd..963fe8e 100644 --- a/stream-android-core/src/test/java/io/getstream/android/core/internal/socket/StreamSocketSessionTest.kt +++ b/stream-android-core/src/test/java/io/getstream/android/core/internal/socket/StreamSocketSessionTest.kt @@ -25,7 +25,8 @@ import io.getstream.android.core.api.model.event.StreamClientWsEvent import io.getstream.android.core.api.model.exceptions.StreamEndpointErrorData import io.getstream.android.core.api.model.exceptions.StreamEndpointException import io.getstream.android.core.api.model.value.StreamWsUrl -import io.getstream.android.core.api.processing.StreamBatcher +import io.getstream.android.core.api.processing.StreamAggregatedEvent +import io.getstream.android.core.api.processing.StreamEventAggregator import io.getstream.android.core.api.serialization.StreamJsonSerialization import io.getstream.android.core.api.socket.StreamWebSocket import io.getstream.android.core.api.socket.listeners.StreamClientListener @@ -70,7 +71,7 @@ class StreamSocketSessionTest { private lateinit var json: StreamJsonSerialization private lateinit var parser: StreamCompositeEventSerializationImpl private lateinit var health: StreamHealthMonitor - private lateinit var debounce: StreamBatcher + private lateinit var aggregator: StreamEventAggregator private lateinit var subs: StreamSubscriptionManager private lateinit var session: StreamSocketSession @@ -91,7 +92,7 @@ class StreamSocketSessionTest { json = mockk(relaxed = true) parser = mockk(relaxed = true) health = mockk(relaxed = true) - debounce = mockk(relaxed = true) + aggregator = mockk(relaxed = true) subs = mockk(relaxed = true) // default: route notifications to a listener so we can assert state @@ -104,7 +105,7 @@ class StreamSocketSessionTest { } every { socket.close(any(), any()) } returns Result.success(Unit) - every { debounce.stop() } returns Result.success(Unit) + every { aggregator.stop() } returns Result.success(Unit) every { health.stop() } returns Result.success(Unit) session = @@ -115,7 +116,7 @@ class StreamSocketSessionTest { jsonSerialization = json, eventParser = parser, healthMonitor = health, - batcher = debounce, + aggregator = aggregator, subscriptionManager = subs, products = listOf("feeds"), ) @@ -140,7 +141,7 @@ class StreamSocketSessionTest { verify { socket.close(SocketConstants.CLOSE_SOCKET_CODE, SocketConstants.CLOSE_SOCKET_REASON) } - verify { debounce.stop() } + verify { aggregator.stop() } verify { health.stop() } } @@ -164,7 +165,7 @@ class StreamSocketSessionTest { verify { socket.close(SocketConstants.CLOSE_SOCKET_CODE, SocketConstants.CLOSE_SOCKET_REASON) } - verify { debounce.stop() } + verify { aggregator.stop() } verify { health.stop() } } @@ -177,7 +178,7 @@ class StreamSocketSessionTest { assertTrue(result.isSuccess) verify { socket.close(customCode, customReason) } - verify { debounce.stop() } + verify { aggregator.stop() } verify { health.stop() } } @@ -199,7 +200,7 @@ class StreamSocketSessionTest { verify(exactly = 2) { socket.close(SocketConstants.CLOSE_SOCKET_CODE, SocketConstants.CLOSE_SOCKET_REASON) } - verify(atLeast = 1) { debounce.stop() } + verify(atLeast = 1) { aggregator.stop() } verify(atLeast = 1) { health.stop() } } @@ -212,7 +213,7 @@ class 
StreamSocketSessionTest { assertTrue(res.isSuccess) verify { socket.close(any(), any()) } - verify { debounce.stop() } + verify { aggregator.stop() } verify { health.stop() } verify { logger.e(notifyFailure, any()) } } @@ -238,7 +239,7 @@ class StreamSocketSessionTest { verify { clientListener.onState(StreamConnectionState.Disconnected()) } verify { health.stop() } - verify { debounce.stop() } + verify { aggregator.stop() } } @Test @@ -261,7 +262,7 @@ class StreamSocketSessionTest { verify { client.onState(StreamConnectionState.Disconnected()) } verify { health.stop() } - verify { debounce.stop() } + verify { aggregator.stop() } } @Test @@ -273,7 +274,7 @@ class StreamSocketSessionTest { Result.success(Unit) } - every { debounce.offer(any()) } returns false + every { aggregator.offer(any()) } returns false every { socket.close(any(), any()) } returns Result.success(Unit) val f = @@ -287,7 +288,7 @@ class StreamSocketSessionTest { verify { health.acknowledgeHeartbeat() } verify { client.onState(match { it is StreamConnectionState.Disconnected }) } verify { health.stop() } - verify { debounce.stop() } + verify { aggregator.stop() } verify { socket.close(any(), any()) } } @@ -299,7 +300,7 @@ class StreamSocketSessionTest { firstArg<(StreamClientListener) -> Unit>().invoke(client) Result.success(Unit) } - every { debounce.offer(any()) } returns true + every { aggregator.offer(any()) } returns true every { socket.close(any(), any()) } returns Result.success(Unit) val f = @@ -323,7 +324,7 @@ class StreamSocketSessionTest { firstArg<(StreamClientListener) -> Unit>().invoke(client) Result.success(Unit) } - every { debounce.offer(any()) } returns false + every { aggregator.offer(any()) } returns false every { socket.close(any(), any()) } returns Result.success(Unit) val f = @@ -337,7 +338,7 @@ class StreamSocketSessionTest { verify { socket.close(any(), any()) } verify { client.onState(any()) } verify { health.stop() } - verify { debounce.stop() } + verify { aggregator.stop() } } @Test @@ -360,7 +361,7 @@ class StreamSocketSessionTest { verify { client.onState(StreamConnectionState.Disconnected()) } verify { health.stop() } - verify { debounce.stop() } + verify { aggregator.stop() } } @Test @@ -382,7 +383,7 @@ class StreamSocketSessionTest { verify { client.onState(ofType()) } verify { health.stop() } - verify { debounce.stop() } + verify { aggregator.stop() } } @Test @@ -410,7 +411,7 @@ class StreamSocketSessionTest { verify(exactly = 0) { client.onState(any()) } verify { health.stop() } - verify { debounce.stop() } + verify { aggregator.stop() } } @Test @@ -430,7 +431,7 @@ class StreamSocketSessionTest { verify { client.onState(ofType()) } verify { socket.close(any(), any()) } verify { health.stop() } - verify { debounce.stop() } + verify { aggregator.stop() } } @Test @@ -448,7 +449,7 @@ class StreamSocketSessionTest { assertTrue(res.isFailure) verify { client.onState(StreamConnectionState.Disconnected()) } verify { health.stop() } - verify { debounce.stop() } + verify { aggregator.stop() } } @Test @@ -616,7 +617,7 @@ class StreamSocketSessionTest { verify(atLeast = 1) { sub.cancel() } verify { socket.close(any(), any()) } verify { health.stop() } - verify { debounce.stop() } + verify { aggregator.stop() } } @Test @@ -763,17 +764,17 @@ class StreamSocketSessionTest { verify(exactly = 1) { socket.close(any(), any()) } verify(exactly = 2) { hsSub.cancel() } verify(exactly = 1) { health.stop() } - verify(exactly = 1) { debounce.stop() } + verify(exactly = 1) { aggregator.stop() } 
assertTrue(emittedStates.any { it is StreamConnectionState.Disconnected }) job.cancelAndJoin() } @Test - fun `onBatch forwards non-health events, ignores health, and emits Disconnected_Error on connection error`() = + fun `onEvent forwards non-health events, ignores health, and emits Disconnected_Error on connection error`() = runTest { - var onBatchCb: (suspend (List, Long, Int) -> Unit)? = null - every { debounce.onBatch(any()) } answers { onBatchCb = arg(0) } + var onEventCb: (suspend (Any) -> Unit)? = null + every { aggregator.onEvent(any()) } answers { onEventCb = arg(0) } every { health.onHeartbeat(any()) } just Runs every { health.onUnhealthy(any()) } just Runs @@ -809,12 +810,6 @@ class StreamSocketSessionTest { mockk(relaxed = true).also { every { it.error } returns mockk(relaxed = true) } - every { parser.deserialize("E1") } returns - Result.success(StreamCompositeSerializationEvent.internal(normalEvent)) - every { parser.deserialize("H") } returns - Result.success(StreamCompositeSerializationEvent.internal(healthEvent)) - every { parser.deserialize("ERR") } returns - Result.success(StreamCompositeSerializationEvent.internal(errorEvent)) val job = async { session.connect( @@ -831,10 +826,14 @@ class StreamSocketSessionTest { } advanceUntilIdle() - val cb = requireNotNull(onBatchCb) { "onBatch not registered" } - cb.invoke(listOf("E1", "H", "ERR"), 100L, 3) + val cb = requireNotNull(onEventCb) { "onEvent not registered" } + // Simulate individual event delivery (low traffic path) + cb.invoke(StreamCompositeSerializationEvent.internal(normalEvent)) + cb.invoke(StreamCompositeSerializationEvent.internal(healthEvent)) + cb.invoke(StreamCompositeSerializationEvent.internal(errorEvent)) advanceUntilIdle() + // normalEvent + errorEvent dispatched; healthEvent filtered assertEquals(2, seenEvents.size) assertTrue(seenEvents.contains(normalEvent)) assertTrue(seenEvents.contains(errorEvent)) @@ -844,15 +843,15 @@ class StreamSocketSessionTest { } @Test - fun `onBatch - deserialize fails then fallback parses api error and emits Disconnected_Error`() = + fun `onEvent aggregated - core events handled individually, product events grouped`() = runTest { - var onBatchCb: (suspend (List, Long, Int) -> Unit)? = null - every { debounce.onBatch(any()) } answers { onBatchCb = arg(0) } + var onEventCb: (suspend (Any) -> Unit)? 
= null + every { aggregator.onEvent(any()) } answers { onEventCb = arg(0) } every { health.onHeartbeat(any()) } just Runs every { health.onUnhealthy(any()) } just Runs - val seenStates = mutableListOf() val seenEvents = mutableListOf() + val seenStates = mutableListOf() every { subs.forEach(any()) } answers { val consumer = arg<(StreamClientListener) -> Unit>(0) @@ -877,17 +876,16 @@ class StreamSocketSessionTest { every { socket.open(config) } returns Result.success(Unit) every { socket.close(any(), any()) } returns Result.success(Unit) - val apiError = mockk(relaxed = true) - every { parser.deserialize("BAD_JSON") } returns - Result.failure(IllegalStateException("boom")) - every { json.fromJson("BAD_JSON", StreamEndpointErrorData::class.java) } returns - Result.success(apiError) + val errorEvent = + mockk(relaxed = true).also { + every { it.error } returns mockk(relaxed = true) + } val job = async { session.connect( ConnectUserData( - userId = "u", - token = "t", + userId = "user-1", + token = "tok", image = null, invisible = false, language = null, @@ -898,13 +896,34 @@ class StreamSocketSessionTest { } advanceUntilIdle() - val cb = requireNotNull(onBatchCb) { "onBatch not registered" } - cb.invoke(listOf("BAD_JSON"), 100L, 1) + val cb = requireNotNull(onEventCb) { "onEvent not registered" } + // Simulate aggregated event delivery (spike path) + val aggregatedEvent = + StreamAggregatedEvent( + mapOf( + "connection.error" to + listOf(StreamCompositeSerializationEvent.internal(errorEvent)), + "channel.updated" to + listOf( + StreamCompositeSerializationEvent.external("product1"), + StreamCompositeSerializationEvent.external("product2"), + ), + ) + ) + cb.invoke(aggregatedEvent) advanceUntilIdle() - assertTrue(seenEvents.isEmpty()) + // Connection error causes Disconnected state assertTrue(seenStates.any { it is StreamConnectionState.Disconnected }) - verify { json.fromJson("BAD_JSON", StreamEndpointErrorData::class.java) } + // Core error event dispatched individually + product aggregated event dispatched + assertTrue(seenEvents.any { it is StreamClientConnectionErrorEvent }) + val forwardedAggregate = + seenEvents.filterIsInstance>().single() + assertFalse(forwardedAggregate.events.containsKey("connection.error")) + assertEquals( + listOf("product1", "product2"), + forwardedAggregate.events["channel.updated"], + ) job.cancelAndJoin() } @@ -1423,7 +1442,7 @@ class StreamSocketSessionTest { fun `handshake buffers non-auth message and replays it exactly once after eventListener installed`() = runTest { val offeredMessages = mutableListOf() - every { debounce.offer(any()) } answers + every { aggregator.offer(any()) } answers { offeredMessages.add(firstArg()) true @@ -1431,7 +1450,7 @@ class StreamSocketSessionTest { every { health.onHeartbeat(any()) } just Runs every { health.onUnhealthy(any()) } just Runs - every { debounce.onBatch(any()) } just Runs + every { aggregator.onEvent(any()) } just Runs val hsSub = mockk(relaxed = true) val eventSub = mockk(relaxed = true) @@ -1500,7 +1519,7 @@ class StreamSocketSessionTest { fun `handshake acknowledges heartbeat for all messages including non-auth`() = runTest { every { health.onHeartbeat(any()) } just Runs every { health.onUnhealthy(any()) } just Runs - every { debounce.onBatch(any()) } just Runs + every { aggregator.onEvent(any()) } just Runs var hsListener: StreamWebSocketListener? 
= null every { socket.subscribe(any(), any()) } answers @@ -1532,7 +1551,7 @@ class StreamSocketSessionTest { fun `connect fails when eventListener subscribe fails after handshake`() = runTest { every { health.onHeartbeat(any()) } just Runs every { health.onUnhealthy(any()) } just Runs - every { debounce.onBatch(any()) } just Runs + every { aggregator.onEvent(any()) } just Runs val hsSub = mockk(relaxed = true) var hsListener: StreamWebSocketListener? = null @@ -1572,8 +1591,8 @@ class StreamSocketSessionTest { fun `connect fails when buffered message replay fails`() = runTest { every { health.onHeartbeat(any()) } just Runs every { health.onUnhealthy(any()) } just Runs - every { debounce.onBatch(any()) } just Runs - every { debounce.offer(any()) } returns false // replay will fail + every { aggregator.onEvent(any()) } just Runs + every { aggregator.offer(any()) } returns false // replay will fail val hsSub = mockk(relaxed = true) val eventSub = mockk(relaxed = true)