diff --git a/core/src/main/scala/org/typelevel/keypool/KeyPool.scala b/core/src/main/scala/org/typelevel/keypool/KeyPool.scala index fcad9842..aa550f27 100644 --- a/core/src/main/scala/org/typelevel/keypool/KeyPool.scala +++ b/core/src/main/scala/org/typelevel/keypool/KeyPool.scala @@ -95,13 +95,14 @@ object KeyPool { m <- kpVar.getAndSet(PoolMap.closed[A, (B, F[Unit])]) _ <- m match { case PoolClosed() => Applicative[F].unit - case PoolOpen(_, m2) => - m2.toList.traverse_ { case (_, pl) => - pl.toList - .traverse_ { case (_, r) => - r._2.attempt.void - } - } + case PoolOpen(_, borrowed, m2) => + borrowed.toList.traverse { case (_, (_, destroy)) => destroy.attempt.void } >> + m2.toList.traverse_ { case (_, pl) => + pl.toList + .traverse_ { case (_, r) => + r._2.attempt.void + } + } } } yield () @@ -120,7 +121,7 @@ object KeyPool { now: FiniteDuration, idleCount: Int, m: Map[A, PoolList[(B, F[Unit])]] - ): (PoolMap[A, (B, F[Unit])], List[(A, (B, F[Unit]))]) = { + ): (Int, Map[A, PoolList[(B, F[Unit])]], List[(A, (B, F[Unit]))]) = { val isNotStale: FiniteDuration => Boolean = time => time + idleTimeAllowedInPoolNanos >= now // Time value is alright inside the KeyPool in nanos. @@ -155,7 +156,7 @@ object KeyPool { // May be able to use Span eventually val (toKeep, toDestroy) = findStale_(identity, identity, m.toList) val idleCount_ = idleCount - toDestroy.length - (PoolMap.open(idleCount_, toKeep), toDestroy) + (idleCount_, toKeep, toDestroy) } val sleep = Temporal[F].sleep(5.seconds).void @@ -166,13 +167,13 @@ object KeyPool { _ <- { kpVar.tryModify { case p @ PoolClosed() => (p, F.unit) - case p @ PoolOpen(idleCount, m) => + case p @ PoolOpen(idleCount, borrowed, m) => if (m.isEmpty) (p, F.unit) // Not worth it to introduce deadlock concerns when hot loop is 5 seconds else { - val (m_, toDestroy) = findStale(now, idleCount, m) + val (idleCount_, toKeep, toDestroy) = findStale(now, idleCount, m) ( - m_, + PoolMap.open(idleCount_, borrowed, toKeep), toDestroy.traverse_(_._2._2).attempt.flatMap { case Left(t) => onReaperException(t) // .handleErrorWith(t => F.delay(t.printStackTrace())) // CHEATING? @@ -196,7 +197,7 @@ object KeyPool { kpVar.get.map(pm => pm match { case PoolClosed() => (0, Map.empty) - case PoolOpen(idleCount, m) => + case PoolOpen(idleCount, _, m) => val modified = m.map { case (k, pl) => pl match { case One(_, _) => (k, 1) @@ -231,18 +232,18 @@ object KeyPool { def go(now: FiniteDuration, pc: PoolMap[A, (B, F[Unit])]): (PoolMap[A, (B, F[Unit])], F[Unit]) = pc match { case p @ PoolClosed() => (p, destroy) - case p @ PoolOpen(idleCount, m) => + case p @ PoolOpen(idleCount, borrowed, m) => if (idleCount > kp.kpMaxTotal) (p, destroy) else m.get(k) match { case None => val cnt_ = idleCount + 1 - val m_ = PoolMap.open(cnt_, m + (k -> One((r, destroy), now))) + val m_ = PoolMap.open(cnt_, borrowed, m + (k -> One((r, destroy), now))) (m_, Applicative[F].pure(())) case Some(l) => val (l_, mx) = addToList(now, kp.kpMaxPerKey(k), (r, destroy), l) val cnt_ = idleCount + mx.fold(1)(_ => 0) - val m_ = PoolMap.open(cnt_, m + (k -> l_)) + val m_ = PoolMap.open(cnt_, borrowed, m + (k -> l_)) (m_, mx.fold(Applicative[F].unit)(_ => destroy)) } } @@ -259,28 +260,43 @@ object KeyPool { def go(pm: PoolMap[A, (B, F[Unit])]): (PoolMap[A, (B, F[Unit])], Option[(B, F[Unit])]) = pm match { case p @ PoolClosed() => (p, None) - case pOrig @ PoolOpen(idleCount, m) => + case pOrig @ PoolOpen(idleCount, borrowed, m) => m.get(k) match { case None => (pOrig, None) case Some(One(a, _)) => - (PoolMap.open(idleCount - 1, m - (k)), Some(a)) + (PoolMap.open(idleCount - 1, borrowed, m - (k)), Some(a)) case Some(Cons(a, _, _, rest)) => - (PoolMap.open(idleCount - 1, m + (k -> rest)), Some(a)) + (PoolMap.open(idleCount - 1, borrowed, m + (k -> rest)), Some(a)) } } + def updateBorrowed( + pm: PoolMap[A, (B, F[Unit])], + update: Map[Unique.Token, (B, F[Unit])] => Map[Unique.Token, (B, F[Unit])] + ): PoolMap[A, (B, F[Unit])] = + pm match { + case p @ PoolClosed() => p + case PoolOpen(idleCount, borrowed, m) => + PoolMap.open(idleCount, update(borrowed), m) + } + for { optR <- Resource.eval(kp.kpVar.modify(go)) releasedState <- Resource.eval(Ref[F].of[Reusable](kp.kpDefaultReuseState)) - resource <- Resource.make(optR.fold(kp.kpRes(k).allocated)(r => Applicative[F].pure(r))) { - resource => - for { - reusable <- releasedState.get - out <- reusable match { - case Reusable.Reuse => put(kp, k, resource._1, resource._2).attempt.void - case Reusable.DontReuse => resource._2.attempt.void - } - } yield out + token <- Resource.eval(Temporal[F].unique) + resource <- Resource.make { + optR + .fold(kp.kpRes(k).allocated)(Applicative[F].pure) + .flatTap(r => kp.kpVar.update(pm => updateBorrowed(pm, _ + (token -> r)))) + } { resource => + for { + reusable <- releasedState.get + out <- reusable match { + case Reusable.Reuse => put(kp, k, resource._1, resource._2).attempt.void + case Reusable.DontReuse => resource._2.attempt.void + } + _ <- kp.kpVar.update(pm => updateBorrowed(pm, _ - token)) + } yield out } } yield new Managed(resource._1, optR.isDefined, releasedState) } @@ -337,7 +353,9 @@ object KeyPool { fa.onError { case e => onReaperException(e) }.attempt >> keepRunning(fa) for { kpVar <- Resource.make( - Ref[F].of[PoolMap[A, (B, F[Unit])]](PoolMap.open(0, Map.empty[A, PoolList[(B, F[Unit])]])) + Ref[F].of[PoolMap[A, (B, F[Unit])]]( + PoolMap.open(0, Map.empty, Map.empty[A, PoolList[(B, F[Unit])]]) + ) )(kpVar => KeyPool.destroy(kpVar)) _ <- idleTimeAllowedInPool match { case fd: FiniteDuration => diff --git a/core/src/main/scala/org/typelevel/keypool/KeyPoolBuilder.scala b/core/src/main/scala/org/typelevel/keypool/KeyPoolBuilder.scala index 2607c47d..4b94b1fa 100644 --- a/core/src/main/scala/org/typelevel/keypool/KeyPoolBuilder.scala +++ b/core/src/main/scala/org/typelevel/keypool/KeyPoolBuilder.scala @@ -60,7 +60,9 @@ final class KeyPoolBuilder[F[_]: Temporal, A, B] private ( fa.onError { case e => onReaperException(e) }.attempt >> keepRunning(fa) for { kpVar <- Resource.make( - Ref[F].of[PoolMap[A, (B, F[Unit])]](PoolMap.open(0, Map.empty[A, PoolList[(B, F[Unit])]])) + Ref[F].of[PoolMap[A, (B, F[Unit])]]( + PoolMap.open(0, Map.empty, Map.empty[A, PoolList[(B, F[Unit])]]) + ) )(kpVar => KeyPool.destroy(kpVar)) _ <- idleTimeAllowedInPool match { case fd: FiniteDuration => diff --git a/core/src/main/scala/org/typelevel/keypool/internal/PoolMap.scala b/core/src/main/scala/org/typelevel/keypool/internal/PoolMap.scala index 1f330a46..690c4795 100644 --- a/core/src/main/scala/org/typelevel/keypool/internal/PoolMap.scala +++ b/core/src/main/scala/org/typelevel/keypool/internal/PoolMap.scala @@ -1,16 +1,17 @@ package org.typelevel.keypool.internal import cats._ +import cats.effect.kernel.Unique import cats.syntax.all._ private[keypool] sealed trait PoolMap[Key, Rezource] extends Product with Serializable { def foldLeft[B](b: B)(f: (B, Rezource) => B): B = this match { case PoolClosed() => b - case PoolOpen(_, m) => m.foldLeft(b) { case (b, (_, pl)) => pl.foldLeft(b)(f) } + case PoolOpen(_, _, m) => m.foldLeft(b) { case (b, (_, pl)) => pl.foldLeft(b)(f) } } def foldRight[B](lb: Eval[B])(f: (Rezource, Eval[B]) => Eval[B]): Eval[B] = this match { case PoolClosed() => lb - case PoolOpen(_, m) => + case PoolOpen(_, _, m) => Foldable.iterateRight(m.values, lb) { case (pl, b) => pl.foldRight(b)(f) } } } @@ -21,11 +22,13 @@ private[keypool] object PoolMap { fa.foldRight(lb)(f) } def closed[K, R]: PoolMap[K, R] = PoolClosed() - def open[K, R](n: Int, m: Map[K, PoolList[R]]): PoolMap[K, R] = PoolOpen(n, m) + def open[K, R](n: Int, borrowed: Map[Unique.Token, R], m: Map[K, PoolList[R]]): PoolMap[K, R] = + PoolOpen(n, borrowed, m) } private[keypool] final case class PoolClosed[Key, Rezource]() extends PoolMap[Key, Rezource] private[keypool] final case class PoolOpen[Key, Rezource]( idleCount: Int, + borrowed: Map[Unique.Token, Rezource], m: Map[Key, PoolList[Rezource]] ) extends PoolMap[Key, Rezource] diff --git a/core/src/test/scala/org/typelevel/keypool/KeyPoolSpec.scala b/core/src/test/scala/org/typelevel/keypool/KeyPoolSpec.scala index 7cb97046..1883430f 100644 --- a/core/src/test/scala/org/typelevel/keypool/KeyPoolSpec.scala +++ b/core/src/test/scala/org/typelevel/keypool/KeyPoolSpec.scala @@ -133,4 +133,44 @@ class KeypoolSpec extends CatsEffectSuite { } yield assert(init === 1 && later === 1) } } + + // see https://github.com/typelevel/keypool/issues/291 + test("Borrowed Resource destroyed during cleanup") { + val pool = KeyPool + .Builder( + (_: Int) => IO.ref(true), + (r: Ref[IO, Boolean]) => r.set(false) + ) + .build + + def escapedResource(outerAwait: Deferred[IO, Unit]): IO[FiberIO[(Boolean, Boolean)]] = + pool.use { p => + def job(innerAwait: Deferred[IO, Unit]): IO[(Boolean, Boolean)] = + p.take(1).use { managed => + val value = managed.value + + for { + status1 <- value.get + _ <- innerAwait.complete(()) + _ <- outerAwait.get + status2 <- value.get + } yield (status1, status2) + } + + for { + await <- IO.deferred[Unit] + fiber <- job(await).start + _ <- await.get // make sure the first part happens inside of the open pool + } yield fiber + } + + for { + await <- IO.deferred[Unit] + fiber <- escapedResource(await) + _ <- await.complete(()) // pool is closed and a fiber can continue its execution + outcome <- fiber.join + result <- outcome.embedNever + (status1, status2) = result + } yield assert(status1 === true && status2 === false) + } } diff --git a/core/src/test/scala/org/typelevel/keypool/PoolSpec.scala b/core/src/test/scala/org/typelevel/keypool/PoolSpec.scala index ba6c5a76..3d41a7cb 100644 --- a/core/src/test/scala/org/typelevel/keypool/PoolSpec.scala +++ b/core/src/test/scala/org/typelevel/keypool/PoolSpec.scala @@ -115,6 +115,46 @@ class PoolSpec extends CatsEffectSuite { } } + // see https://github.com/typelevel/keypool/issues/291 + test("Borrowed Resource destroyed during cleanup") { + val pool = Pool + .Builder( + IO.ref(true), + (r: Ref[IO, Boolean]) => r.set(false) + ) + .build + + def escapedResource(outerAwait: Deferred[IO, Unit]): IO[FiberIO[(Boolean, Boolean)]] = + pool.use { p => + def job(innerAwait: Deferred[IO, Unit]): IO[(Boolean, Boolean)] = + p.take.use { managed => + val value = managed.value + + for { + status1 <- value.get + _ <- innerAwait.complete(()) + _ <- outerAwait.get + status2 <- value.get + } yield (status1, status2) + } + + for { + await <- IO.deferred[Unit] + fiber <- job(await).start + _ <- await.get // make sure the first part happens inside of the open pool + } yield fiber + } + + for { + await <- IO.deferred[Unit] + fiber <- escapedResource(await) + _ <- await.complete(()) // pool is closed and a fiber can continue its execution + outcome <- fiber.join + result <- outcome.embedNever + (status1, status2) = result + } yield assert(status1 === true && status2 === false) + } + private def nothing(ref: Ref[IO, Int]): IO[Unit] = ref.get.void }