
Commit 7833e2f

aokolnychyi authored and dongjoon-hyun committed
[SPARK-54424][SQL] Failures during recaching must not fail operations
### What changes were proposed in this pull request?

This PR prevents failures during recaching from failing write/refresh operations.

### Why are the changes needed?

After the recent changes in SPARK-54387, we may now mark write operations as failed even though they successfully committed to the table and only the cache refresh was unsuccessful.

### Does this PR introduce _any_ user-facing change?

Yes, `recacheByXXX` will no longer throw an exception if recaching fails.

### How was this patch tested?

This PR comes with tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #53143 from aokolnychyi/spark-54424.

Authored-by: Anton Okolnychyi <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 8f6c6d6 commit 7833e2f

File tree: 5 files changed (+233, −45 lines)

sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala

Lines changed: 59 additions & 14 deletions

@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution
 
+import scala.util.control.NonFatal
+
 import org.apache.hadoop.fs.{FileSystem, Path}
 
 import org.apache.spark.internal.{Logging, MessageWithContext}
@@ -374,25 +376,68 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper
     }
     needToRecache.foreach { cd =>
       cd.cachedRepresentation.cacheBuilder.clearCache()
-      val sessionWithConfigsOff = getOrCloneSessionWithConfigsOff(spark)
-      val (newKey, newCache) = sessionWithConfigsOff.withActive {
-        val refreshedPlan = V2TableRefreshUtil.refresh(sessionWithConfigsOff, cd.plan)
-        val qe = sessionWithConfigsOff.sessionState.executePlan(refreshedPlan)
-        qe.normalized -> InMemoryRelation(cd.cachedRepresentation.cacheBuilder, qe)
-      }
-      val recomputedPlan = cd.copy(plan = newKey, cachedRepresentation = newCache)
-      this.synchronized {
-        if (lookupCachedDataInternal(recomputedPlan.plan).nonEmpty) {
-          logWarning("While recaching, data was already added to cache.")
-        } else {
-          cachedData = recomputedPlan +: cachedData
-          CacheManager.logCacheOperation(log"Re-cached Dataframe cache entry:" +
-            log"${MDC(DATAFRAME_CACHE_ENTRY, recomputedPlan)}")
+      tryRebuildCacheEntry(spark, cd).foreach { entry =>
+        this.synchronized {
+          if (lookupCachedDataInternal(entry.plan).nonEmpty) {
+            logWarning("While recaching, data was already added to cache.")
+          } else {
+            cachedData = entry +: cachedData
+            CacheManager.logCacheOperation(log"Re-cached Dataframe cache entry:" +
+              log"${MDC(DATAFRAME_CACHE_ENTRY, entry)}")
+          }
         }
       }
     }
   }
 
+  private def tryRebuildCacheEntry(spark: SparkSession, cd: CachedData): Option[CachedData] = {
+    val sessionWithConfigsOff = getOrCloneSessionWithConfigsOff(spark)
+    sessionWithConfigsOff.withActive {
+      tryRefreshPlan(sessionWithConfigsOff, cd.plan).map { refreshedPlan =>
+        val qe = QueryExecution.create(
+          sessionWithConfigsOff,
+          refreshedPlan,
+          refreshPhaseEnabled = false)
+        val newKey = qe.normalized
+        val newCache = InMemoryRelation(cd.cachedRepresentation.cacheBuilder, qe)
+        cd.copy(plan = newKey, cachedRepresentation = newCache)
+      }
+    }
+  }
+
+  /**
+   * Attempts to refresh table metadata loaded through the catalog.
+   *
+   * If the table state is cached (e.g., via `CACHE TABLE t`), the relation is replaced with
+   * updated metadata as long as the table ID still matches, ensuring that all schema changes
+   * are reflected. Otherwise, a new plan is produced using refreshed table metadata but
+   * retaining the original schema, provided the schema changes are still compatible with the
+   * query (e.g., adding new columns should be acceptable).
+   *
+   * Note this logic applies only to V2 tables at the moment.
+   *
+   * @return the refreshed plan if refresh succeeds, None otherwise
+   */
+  private def tryRefreshPlan(spark: SparkSession, plan: LogicalPlan): Option[LogicalPlan] = {
+    try {
+      EliminateSubqueryAliases(plan) match {
+        case r @ ExtractV2CatalogAndIdentifier(catalog, ident) if r.timeTravelSpec.isEmpty =>
+          val table = catalog.loadTable(ident)
+          if (r.table.id == table.id) {
+            Some(DataSourceV2Relation.create(table, Some(catalog), Some(ident)))
+          } else {
+            None
+          }
+        case _ =>
+          Some(V2TableRefreshUtil.refresh(spark, plan))
+      }
+    } catch {
+      case NonFatal(e) =>
+        logWarning(log"Failed to refresh plan while attempting to recache", e)
+        None
+    }
+  }
+
   private[sql] def lookupCachedTable(
       name: Seq[String],
       resolver: Resolver): Option[LogicalPlan] = {
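The scaladoc on `tryRefreshPlan` distinguishes two refresh paths. Below is a compact, self-contained sketch of that decision; the types (`Relation`, `Filter`, a `Map`-backed catalog) are stand-ins rather than Spark's, and details such as time-travel specs and schema-compatibility validation are omitted.

```scala
import scala.util.control.NonFatal

// Stand-in types, not Spark's. The shape mirrors tryRefreshPlan: a directly cached
// relation is reloaded wholesale as long as its table ID is unchanged, any other plan
// goes through a generic refresh of its underlying relation, and non-fatal failures
// yield None instead of throwing.
object RefreshDecisionSketch {
  sealed trait Plan
  final case class Relation(ident: String, tableId: String, columns: Seq[String]) extends Plan
  final case class Filter(condition: String, child: Relation) extends Plan

  // Hypothetical catalog state: table "t" exists with an evolved schema, anything else is gone.
  private val catalog = Map(
    "t" -> Relation("t", tableId = "uuid-1", columns = Seq("id", "value", "category")))

  private def loadTable(ident: String): Relation =
    catalog.getOrElse(ident, throw new NoSuchElementException(ident))

  def tryRefresh(plan: Plan): Option[Plan] =
    try plan match {
      case r: Relation =>
        val reloaded = loadTable(r.ident)
        // recreated table (different ID): give up rather than silently switch tables
        if (reloaded.tableId == r.tableId) Some(reloaded) else None
      case f: Filter =>
        // derived query: keep the original shape, refresh only the underlying relation
        Some(f.copy(child = loadTable(f.child.ident)))
    } catch {
      case NonFatal(_) => None
    }

  def main(args: Array[String]): Unit = {
    println(tryRefresh(Relation("t", "uuid-1", Seq("id", "value"))))       // Some(reloaded relation)
    println(tryRefresh(Relation("t", "uuid-0", Seq("id", "value"))))       // None: table was recreated
    println(tryRefresh(Filter("id < 100", Relation("gone", "x", Nil))))    // None: load fails, swallowed
  }
}
```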

sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala

Lines changed: 20 additions & 3 deletions

@@ -66,7 +66,8 @@ class QueryExecution(
     val logical: LogicalPlan,
     val tracker: QueryPlanningTracker = new QueryPlanningTracker,
     val mode: CommandExecutionMode.Value = CommandExecutionMode.ALL,
-    val shuffleCleanupMode: ShuffleCleanupMode = DoNotCleanup) extends Logging {
+    val shuffleCleanupMode: ShuffleCleanupMode = DoNotCleanup,
+    val refreshPhaseEnabled: Boolean = true) extends Logging {
 
   val id: Long = QueryExecution.nextExecutionId
 
@@ -178,7 +179,7 @@
       // for eagerly executed commands we mark this place as beginning of execution.
       tracker.setReadyForExecution()
       val qe = new QueryExecution(sparkSession, p, mode = mode,
-        shuffleCleanupMode = shuffleCleanupMode)
+        shuffleCleanupMode = shuffleCleanupMode, refreshPhaseEnabled = refreshPhaseEnabled)
       val result = QueryExecution.withInternalError(s"Eagerly executed $name failed.") {
         SQLExecution.withNewExecutionId(qe, Some(name)) {
           qe.executedPlan.executeCollect()
@@ -207,7 +208,11 @@
   // there may be delay between analysis and subsequent phases
   // therefore, refresh captured table versions to reflect latest data
   private val lazyTableVersionsRefreshed = LazyTry {
-    V2TableRefreshUtil.refresh(sparkSession, commandExecuted, versionedOnly = true)
+    if (refreshPhaseEnabled) {
+      V2TableRefreshUtil.refresh(sparkSession, commandExecuted, versionedOnly = true)
+    } else {
+      commandExecuted
+    }
   }
 
   private[sql] def tableVersionsRefreshed: LogicalPlan = lazyTableVersionsRefreshed.get
@@ -569,6 +574,18 @@ object QueryExecution {
 
   private def nextExecutionId: Long = _nextExecutionId.getAndIncrement
 
+  private[execution] def create(
+      sparkSession: SparkSession,
+      logical: LogicalPlan,
+      refreshPhaseEnabled: Boolean = true): QueryExecution = {
+    new QueryExecution(
+      sparkSession,
+      logical,
+      mode = CommandExecutionMode.ALL,
+      shuffleCleanupMode = determineShuffleCleanupMode(sparkSession.sessionState.conf),
+      refreshPhaseEnabled = refreshPhaseEnabled)
+  }
+
   /**
    * Construct a sequence of rules that are used to prepare a planned [[SparkPlan]] for execution.
   * These rules will make sure subqueries are planned, make sure the data partitioning and ordering
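A short sketch of what the new `refreshPhaseEnabled` flag controls, using a stand-in class rather than the real `QueryExecution`: the table-version refresh phase is computed lazily at most once and, when the flag is off, the analyzed plan is passed through untouched.

```scala
// Stand-in for the real QueryExecution: the only point illustrated is the
// refreshPhaseEnabled guard around the lazily computed refresh phase.
final class ExecutionSketch(analyzedPlan: String, refreshPhaseEnabled: Boolean = true) {

  // Stands in for V2TableRefreshUtil.refresh(..., versionedOnly = true).
  private def refreshTableVersions(plan: String): String = s"$plan [table versions refreshed]"

  // Mirrors lazyTableVersionsRefreshed: evaluated at most once, skipped entirely when disabled.
  private lazy val tableVersionsRefreshed: String =
    if (refreshPhaseEnabled) refreshTableVersions(analyzedPlan) else analyzedPlan

  def executedPlanDescription: String = tableVersionsRefreshed
}

object ExecutionSketchDemo extends App {
  // Default path: the refresh phase runs before execution.
  assert(new ExecutionSketch("Append").executedPlanDescription.contains("refreshed"))
  // CacheManager / RTAS path: the plan handed over was already refreshed, so the phase is skipped.
  assert(new ExecutionSketch("Append", refreshPhaseEnabled = false).executedPlanDescription == "Append")
}
```

In the diff, `CacheManager.tryRebuildCacheEntry` and the RTAS write path appear to be the only callers that disable the flag, in both cases because the plan they pass in has already been refreshed.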

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2TableRefreshUtil.scala

Lines changed: 0 additions & 22 deletions

@@ -21,7 +21,6 @@ import scala.collection.mutable
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.SQLConfHelper
-import org.apache.spark.sql.catalyst.analysis.AsOfVersion
 import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
 import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.connector.catalog.{Identifier, Table, TableCatalog, V2TableUtil}
@@ -32,27 +31,6 @@ import org.apache.spark.sql.util.SchemaValidationMode.ALLOW_NEW_FIELDS
 import org.apache.spark.sql.util.SchemaValidationMode.PROHIBIT_CHANGES
 
 private[sql] object V2TableRefreshUtil extends SQLConfHelper with Logging {
-  /**
-   * Pins table versions for all versioned tables in the plan.
-   *
-   * This method captures the current version of each versioned table by adding time travel
-   * specifications. Tables that already have time travel specifications or are not versioned
-   * are left unchanged.
-   *
-   * @param plan the logical plan to pin versions for
-   * @return plan with pinned table versions
-   */
-  def pinVersions(plan: LogicalPlan): LogicalPlan = {
-    plan transform {
-      case r @ ExtractV2CatalogAndIdentifier(catalog, ident)
-          if r.isVersioned && r.timeTravelSpec.isEmpty =>
-        val tableName = V2TableUtil.toQualifiedName(catalog, ident)
-        val version = r.table.version
-        logDebug(s"Pinning table version for $tableName to $version")
-        r.copy(timeTravelSpec = Some(AsOfVersion(version)))
-    }
-  }
-
   /**
    * Refreshes table metadata for tables in the plan.
   *

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala

Lines changed: 8 additions & 6 deletions

@@ -33,7 +33,7 @@ import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.connector.metric.CustomMetric
 import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, DeltaWrite, DeltaWriter, MergeSummaryImpl, PhysicalWriteInfoImpl, RowLevelOperationTable, Write, WriterCommitMessage, WriteSummary}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
-import org.apache.spark.sql.execution.{SparkPlan, SQLExecution, UnaryExecNode}
+import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SQLExecution, UnaryExecNode}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.joins.BaseJoinExec
 import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric, SQLMetrics}
@@ -177,21 +177,22 @@ case class ReplaceTableAsSelectExec(
       query,
       versionedOnly = true,
       schemaValidationMode = PROHIBIT_CHANGES)
-    val pinnedQuery = V2TableRefreshUtil.pinVersions(refreshedQuery)
     if (catalog.tableExists(ident)) {
       invalidateCache(catalog, ident)
       catalog.dropTable(ident)
     } else if (!orCreate) {
       throw QueryCompilationErrors.cannotReplaceMissingTableError(ident)
     }
     val tableInfo = new TableInfo.Builder()
-      .withColumns(getV2Columns(pinnedQuery.schema, catalog.useNullableQuerySchema))
+      .withColumns(getV2Columns(refreshedQuery.schema, catalog.useNullableQuerySchema))
      .withPartitions(partitioning.toArray)
      .withProperties(properties.asJava)
      .build()
     val table = Option(catalog.createTable(ident, tableInfo))
       .getOrElse(catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava))
-    writeToTable(catalog, table, writeOptions, ident, pinnedQuery, overwrite = true)
+    writeToTable(
+      catalog, table, writeOptions, ident, refreshedQuery,
+      overwrite = true, refreshPhaseEnabled = false)
   }
 }
 
@@ -764,15 +765,16 @@ private[v2] trait V2CreateTableAsSelectBaseExec extends LeafV2CommandExec {
       writeOptions: Map[String, String],
       ident: Identifier,
       query: LogicalPlan,
-      overwrite: Boolean): Seq[InternalRow] = {
+      overwrite: Boolean,
+      refreshPhaseEnabled: Boolean = true): Seq[InternalRow] = {
     Utils.tryWithSafeFinallyAndFailureCallbacks({
       val relation = DataSourceV2Relation.create(table, Some(catalog), Some(ident))
       val writeCommand = if (overwrite) {
        OverwriteByExpression.byPosition(relation, query, Literal.TrueLiteral, writeOptions)
      } else {
        AppendData.byPosition(relation, query, writeOptions)
      }
-      val qe = session.sessionState.executePlan(writeCommand)
+      val qe = QueryExecution.create(session, writeCommand, refreshPhaseEnabled)
      qe.assertCommandExecuted()
      DataSourceV2Utils.commitStagedChanges(sparkContext, table, metrics)
      Nil
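With `pinVersions` removed, `ReplaceTableAsSelectExec` now derives the new table's columns from `refreshedQuery` and hands that same plan to `writeToTable` with `refreshPhaseEnabled = false`, presumably so the nested execution does not try to refresh table versions again after the table has just been dropped and recreated. A simplified, self-contained sketch of that flow follows; all names are stand-ins for the exec node's internals, not real Spark signatures.

```scala
object RtasFlowSketch {
  // Stands in for V2TableRefreshUtil.refresh(..., schemaValidationMode = PROHIBIT_CHANGES).
  def refresh(query: String): String = s"$query [refreshed]"

  // Stands in for catalog.dropTable + catalog.createTable using the refreshed query's schema.
  def recreateTable(schemaSource: String): Unit =
    println(s"recreating table using the schema of: $schemaSource")

  // Stands in for writeToTable: the refresh phase is disabled for the nested execution.
  def writeToTable(query: String, refreshPhaseEnabled: Boolean): Unit =
    println(s"committing write of '$query' (refreshPhaseEnabled=$refreshPhaseEnabled)")

  def main(args: Array[String]): Unit = {
    val refreshedQuery = refresh("SELECT * FROM src")          // refreshed exactly once, up front
    recreateTable(refreshedQuery)                              // table is dropped and recreated
    writeToTable(refreshedQuery, refreshPhaseEnabled = false)  // no second refresh mid-commit
  }
}
```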

sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala

Lines changed: 146 additions & 0 deletions

@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSel
 import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue, DefaultValue, Identifier, InMemoryTableCatalog, SupportsV1OverwriteWithSaveAsTable, TableInfo}
 import org.apache.spark.sql.connector.catalog.BasicInMemoryTableCatalog
 import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, UpdateColumnDefaultValue}
+import org.apache.spark.sql.connector.catalog.TableChange
 import org.apache.spark.sql.connector.catalog.TableWritePrivilege
 import org.apache.spark.sql.connector.catalog.TruncatableTable
 import org.apache.spark.sql.connector.expressions.{ApplyTransform, GeneralScalarExpression, LiteralValue, Transform}
@@ -1894,6 +1895,151 @@ class DataSourceV2DataFrameSuite
     }
   }
 
+  test("SPARK-54424: refresh table cache on schema changes (column removed)") {
+    val t = "testcat.ns1.ns2.tbl"
+    val ident = Identifier.of(Array("ns1", "ns2"), "tbl")
+    withTable(t) {
+      sql(s"CREATE TABLE $t (id INT, value INT, category STRING) USING foo")
+      sql(s"INSERT INTO $t VALUES (1, 10, 'A'), (2, 20, 'B'), (3, 30, 'A')")
+
+      // cache table
+      spark.table(t).cache()
+
+      // verify caching works as expected
+      assertCached(spark.table(t))
+      checkAnswer(
+        spark.table(t),
+        Seq(Row(1, 10, "A"), Row(2, 20, "B"), Row(3, 30, "A")))
+
+      // evolve table directly to mimic external changes
+      // these external changes make cached plan invalid (column is no longer there)
+      val change = TableChange.deleteColumn(Array("category"), false)
+      catalog("testcat").alterTable(ident, change)
+
+      // refresh table is supposed to trigger recaching
+      spark.sql(s"REFRESH TABLE $t")
+
+      // recaching is expected to succeed
+      assert(spark.sharedState.cacheManager.numCachedEntries == 1)
+
+      // verify cache reflects latest schema and data
+      assertCached(spark.table(t))
+      checkAnswer(spark.table(t), Seq(Row(1, 10), Row(2, 20), Row(3, 30)))
+    }
+  }
+
+  test("SPARK-54424: refresh table cache on schema changes (column added)") {
+    val t = "testcat.ns1.ns2.tbl"
+    val ident = Identifier.of(Array("ns1", "ns2"), "tbl")
+    withTable(t) {
+      sql(s"CREATE TABLE $t (id INT, value INT) USING foo")
+      sql(s"INSERT INTO $t VALUES (1, 10), (2, 20), (3, 30)")
+
+      // cache table
+      spark.table(t).cache()
+
+      // verify caching works as expected
+      assertCached(spark.table(t))
+      checkAnswer(
+        spark.table(t),
+        Seq(Row(1, 10), Row(2, 20), Row(3, 30)))
+
+      // evolve table directly to mimic external changes
+      // these external changes make cached plan invalid (table state has changed)
+      val change = TableChange.addColumn(Array("category"), StringType, true)
+      catalog("testcat").alterTable(ident, change)
+
+      // refresh table is supposed to trigger recaching
+      spark.sql(s"REFRESH TABLE $t")
+
+      // recaching is expected to succeed
+      assert(spark.sharedState.cacheManager.numCachedEntries == 1)
+
+      // verify cache reflects latest schema and data
+      assertCached(spark.table(t))
+      checkAnswer(spark.table(t), Seq(Row(1, 10, null), Row(2, 20, null), Row(3, 30, null)))
+    }
+  }
+
+  test("SPARK-54424: successfully refresh cache with compatible schema changes") {
+    val t = "testcat.ns1.ns2.tbl"
+    val ident = Identifier.of(Array("ns1", "ns2"), "tbl")
+    withTable(t) {
+      sql(s"CREATE TABLE $t (id INT, value INT) USING foo")
+      sql(s"INSERT INTO $t VALUES (1, 10), (2, 20), (3, 30)")
+
+      // cache query
+      val df = spark.table(t).filter("id < 100")
+      df.cache()
+
+      // verify caching works as expected
+      assertCached(spark.table(t).filter("id < 100"))
+      checkAnswer(
+        spark.table(t).filter("id < 100"),
+        Seq(Row(1, 10), Row(2, 20), Row(3, 30)))
+
+      // evolve table directly to mimic external changes
+      // adding columns should be OK
+      val change = TableChange.addColumn(Array("category"), StringType, true)
+      catalog("testcat").alterTable(ident, change)
+
+      // refresh table is supposed to trigger recaching
+      spark.sql(s"REFRESH TABLE $t")
+
+      // recaching is expected to succeed
+      assert(spark.sharedState.cacheManager.numCachedEntries == 1)
+
+      // verify derived queries still benefit from refreshed cache
+      assertCached(df.filter("id > 0"))
+      checkAnswer(df.filter("id > 0"), Seq(Row(1, 10), Row(2, 20), Row(3, 30)))
+
+      // add more data
+      sql(s"INSERT INTO $t VALUES (4, 40, '40')")
+
+      // verify derived queries still benefit from refreshed cache
+      assertCached(df.filter("id > 0"))
+      checkAnswer(df.filter("id > 0"), Seq(Row(1, 10), Row(2, 20), Row(3, 30), Row(4, 40)))
+
+      // verify latest schema is propagated (new column has NULL values for existing rows)
+      checkAnswer(
+        spark.table(t),
+        Seq(Row(1, 10, null), Row(2, 20, null), Row(3, 30, null), Row(4, 40, "40")))
+    }
+  }
+
+  test("SPARK-54424: inability to refresh cache shouldn't fail operations") {
+    val t = "testcat.ns1.ns2.tbl"
+    val ident = Identifier.of(Array("ns1", "ns2"), "tbl")
+    withTable(t) {
+      sql(s"CREATE TABLE $t (id INT, value INT) USING foo")
+      sql(s"INSERT INTO $t VALUES (1, 10), (2, 20), (3, 30)")
+
+      // cache query
+      val df = spark.table(t).filter("id < 100")
+      df.cache()
+
+      // verify caching works as expected
+      assertCached(spark.table(t).filter("id < 100"))
+      checkAnswer(
+        spark.table(t).filter("id < 100"),
+        Seq(Row(1, 10), Row(2, 20), Row(3, 30)))
+
+      // evolve table directly to mimic external changes
+      // removing columns should make cached plan invalid
+      val change = TableChange.deleteColumn(Array("value"), false)
+      catalog("testcat").alterTable(ident, change)
+
+      // refresh table is supposed to trigger recaching
+      spark.sql(s"REFRESH TABLE $t")
+
+      // recaching is expected to fail
+      assert(spark.sharedState.cacheManager.isEmpty)
+
+      // verify latest schema is propagated
+      checkAnswer(spark.table(t), Seq(Row(1), Row(2), Row(3)))
+    }
+  }
+
   private def pinTable(catalogName: String, ident: Identifier, version: String): Unit = {
     catalog(catalogName) match {
       case inMemory: BasicInMemoryTableCatalog => inMemory.pinTable(ident, version)
