fix(core): validate Plan.Root output name count #812
Changes from all commits
0eab080
5a31567
0043ab2
b7fb3c9
c50bbfe
97ddfff
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -3,9 +3,12 @@ | |||||||||
| import io.substrait.SubstraitVersion; | ||||||||||
| import io.substrait.extension.AdvancedExtension; | ||||||||||
| import io.substrait.relation.Rel; | ||||||||||
| import io.substrait.type.NamedFieldCountingTypeVisitor; | ||||||||||
| import java.util.List; | ||||||||||
| import java.util.Optional; | ||||||||||
| import org.immutables.value.Value; | ||||||||||
| import org.slf4j.Logger; | ||||||||||
| import org.slf4j.LoggerFactory; | ||||||||||
|
|
||||||||||
| /** A complete Substrait plan: a set of root relations together with version and metadata. */ | ||||||||||
| @Value.Immutable | ||||||||||
|
|
@@ -153,6 +156,8 @@ private static Version loadVersion() { | |||||||||
| /** A root relation of a plan together with the output field names it exposes. */ | ||||||||||
| @Value.Immutable | ||||||||||
| public abstract static class Root { | ||||||||||
| private static final Logger LOGGER = LoggerFactory.getLogger(Root.class); | ||||||||||
|
|
||||||||||
| /** | ||||||||||
| * Returns the relation producing this root's output. | ||||||||||
| * | ||||||||||
|
|
@@ -167,6 +172,26 @@ public abstract static class Root { | |||||||||
| */ | ||||||||||
| public abstract List<String> getNames(); | ||||||||||
|
|
||||||||||
| /** Validates that the root output names match the input record type. */ | ||||||||||
| @Value.Check | ||||||||||
| protected void check() { | ||||||||||
| final int actualNameCount = getNames().size(); | ||||||||||
| if (actualNameCount == 0) { | ||||||||||
| LOGGER.warn( | ||||||||||
| "Plan.Root built without output names; this will be an error in the next release"); | ||||||||||
| return; | ||||||||||
| } | ||||||||||
|
Comment on lines +179 to +183
Member
I have a feeling that we need to make the spec clearer on names being required for […]. We should probably also flesh out the Plan Root documentation a bit more to describe the structure and the reference to the […]. What do you think?
Member
Author
I think you are right that the spec is a bit unclear. Maybe a more obvious validation per the current state of the spec is […]
What do you think about that?
Member
We discussed this a little outside of this issue. Initially, I thought it was a good idea to validate this, but that we might want to sharpen the specification first on what the correct behavior is. After some more consideration, we found that always requiring names to be set does not work, since we have exceptions like […]. Implementing conditional logic on the type of […]. I think what we could do is the second option you suggested: […]
Member
Author
I am happy to alter the PR. The current version has everything required. I think this is actually okay because […]
Member
You're right. We have output schemas defined in the spec, and all POJO Rels in substrait-java inherit the record type from […].
This even includes the extension Rels which I thought might have been another gap. Then we can just always enforce it. |
||||||||||
|
|
||||||||||
| final int expectedFieldCount = | ||||||||||
| NamedFieldCountingTypeVisitor.countNames(getInput().getRecordType()); | ||||||||||
| if (actualNameCount != expectedFieldCount) { | ||||||||||
|
benbellick marked this conversation as resolved.
|
||||||||||
| throw new IllegalArgumentException( | ||||||||||
| String.format( | ||||||||||
| "Plan.Root names count (%d) must match input record type depth-first named-field count (%d)", | ||||||||||
| actualNameCount, expectedFieldCount)); | ||||||||||
| } | ||||||||||
| } | ||||||||||
|
|
||||||||||
| /** | ||||||||||
| * Creates a builder for {@link Root}. | ||||||||||
| * | ||||||||||
|
|
||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,8 +1,8 @@ | ||
| package io.substrait.relation; | ||
|
|
||
| import io.substrait.expression.Expression; | ||
| import io.substrait.type.NamedFieldCountingTypeVisitor; | ||
| import io.substrait.type.Type; | ||
| import io.substrait.type.TypeVisitor; | ||
| import io.substrait.util.VisitationContext; | ||
| import java.util.List; | ||
| import java.util.Objects; | ||
|
|
@@ -101,147 +101,4 @@ public <O, C extends VisitationContext, E extends Exception> O accept( | |
| public static ImmutableVirtualTableScan.Builder builder() { | ||
| return ImmutableVirtualTableScan.builder(); | ||
| } | ||
|
|
||
| private static class NamedFieldCountingTypeVisitor | ||
|
Member
Nice, I had no idea we already had such a visitor in the code base. |
||
| implements TypeVisitor<Integer, RuntimeException> { | ||
|
|
||
| private static final NamedFieldCountingTypeVisitor VISITOR = | ||
| new NamedFieldCountingTypeVisitor(); | ||
|
|
||
| private static Integer countNames(Type type) { | ||
| return type.accept(VISITOR); | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.Bool type) throws RuntimeException { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.I8 type) throws RuntimeException { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.I16 type) throws RuntimeException { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.I32 type) throws RuntimeException { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.I64 type) throws RuntimeException { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.FP32 type) throws RuntimeException { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.FP64 type) throws RuntimeException { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.Str type) throws RuntimeException { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.Binary type) throws RuntimeException { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.Date type) throws RuntimeException { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.PrecisionTimestamp type) throws RuntimeException { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.PrecisionTime type) throws RuntimeException { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.PrecisionTimestampTZ type) throws RuntimeException { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.IntervalYear type) throws RuntimeException { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.IntervalDay type) throws RuntimeException { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.IntervalCompound type) throws RuntimeException { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.UUID type) throws RuntimeException { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.FixedChar type) throws RuntimeException { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.VarChar type) throws RuntimeException { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.FixedBinary type) throws RuntimeException { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.Decimal type) throws RuntimeException { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.Struct type) throws RuntimeException { | ||
| // Only struct fields have names - the top level column names are also | ||
| // captured by this since the whole schema is wrapped in a Struct type | ||
| return type.fields().stream().mapToInt(field -> 1 + field.accept(this)).sum(); | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.ListType type) throws RuntimeException { | ||
| return type.elementType().accept(this); | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.Map type) throws RuntimeException { | ||
| return type.key().accept(this) + type.value().accept(this); | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.UserDefined type) throws RuntimeException { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.Func type) throws RuntimeException { | ||
| return 0; | ||
| } | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,177 @@ | ||
| package io.substrait.type; | ||
|
|
||
| /** | ||
| * Counts the number of field names required for a {@link Type} using Substrait's depth-first naming | ||
| * rules. | ||
| * | ||
| * <p>This is the same counting scheme used by {@link NamedStruct#names()}: top-level struct fields | ||
| * contribute one name each, and nested struct fields inside structs, lists, and maps also | ||
| * contribute names in depth-first order. Scalar types and other non-structural types do not | ||
| * contribute additional names. | ||
| * | ||
| * <p>Examples: | ||
| * | ||
| * <ul> | ||
| * <li>{@code struct<i64, i64>} requires 2 names | ||
| * <li>{@code list<struct<i64, i64>>} requires 2 names | ||
| * <li>{@code map<struct<i64, i64>, struct<i64, i64, i64>>} requires 5 names | ||
| * </ul> | ||
| * | ||
| * <p>This utility is used anywhere the library needs to validate or reason about name counts | ||
| * without carrying the names themselves, such as {@code Plan.Root} and {@code VirtualTableScan} | ||
| * validation. | ||
| */ | ||
| public final class NamedFieldCountingTypeVisitor implements TypeVisitor<Integer, RuntimeException> { | ||
|
|
||
| private static final NamedFieldCountingTypeVisitor VISITOR = new NamedFieldCountingTypeVisitor(); | ||
|
|
||
| private NamedFieldCountingTypeVisitor() {} | ||
|
|
||
| /** | ||
| * Returns the number of names required to describe {@code type} in Substrait's depth-first naming | ||
| * order. | ||
| * | ||
| * <p>For a top-level struct, this includes both the top-level field names and any nested struct | ||
| * field names required by compound child types. | ||
| * | ||
| * @param type the type to inspect | ||
| * @return the number of required names | ||
| */ | ||
| public static int countNames(Type type) { | ||
| return type.accept(VISITOR); | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.Bool type) { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.I8 type) { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.I16 type) { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.I32 type) { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.I64 type) { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.FP32 type) { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.FP64 type) { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.Str type) { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.Binary type) { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.Date type) { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.PrecisionTime type) { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.PrecisionTimestamp type) { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.PrecisionTimestampTZ type) { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.IntervalYear type) { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.IntervalDay type) { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.IntervalCompound type) { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.UUID type) { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.FixedChar type) { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.VarChar type) { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.FixedBinary type) { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.Decimal type) { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.Func type) { | ||
| return 0; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.Struct type) { | ||
| // Each struct field contributes its own name, plus any nested names required by that field's | ||
| // type. | ||
| return type.fields().stream().mapToInt(field -> 1 + countNames(field)).sum(); | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.ListType type) { | ||
| // Lists do not add a name themselves, but list elements may contain nested structs. | ||
| return countNames(type.elementType()); | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.Map type) { | ||
| // Maps do not add names themselves; any required names come from struct keys and/or values. | ||
| return countNames(type.key()) + countNames(type.value()); | ||
| } | ||
|
|
||
| @Override | ||
| public Integer visit(Type.UserDefined type) { | ||
| return 0; | ||
| } | ||
| } |
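The depth-first counting rule described in the Javadoc above can be checked by hand with a toy model. The sketch below is not the substrait-java API: `NameCountSketch`, `SketchType`, and the `struct`/`list`/`map` helpers are hypothetical stand-ins invented for illustration, and only the counting rule itself (each struct field contributes one name plus whatever its own type needs, while lists and maps contribute nothing themselves) is taken from the diff.

```java
import java.util.Arrays;
import java.util.List;

/** Toy model of depth-first name counting; NOT the substrait-java API. */
final class NameCountSketch {

  /** Hypothetical stand-in for io.substrait.type.Type. */
  interface SketchType {
    int nameCount();
  }

  /** Scalars such as i64 need no names of their own. */
  static final SketchType I64 = () -> 0;

  /** Each struct field contributes one name plus the names its own type needs. */
  static SketchType struct(SketchType... fields) {
    List<SketchType> copy = Arrays.asList(fields.clone());
    return () -> copy.stream().mapToInt(f -> 1 + f.nameCount()).sum();
  }

  /** A list adds no name itself; only its element type may need names. */
  static SketchType list(SketchType element) {
    return element::nameCount;
  }

  /** A map adds no name itself; names come from struct keys and/or values. */
  static SketchType map(SketchType key, SketchType value) {
    return () -> key.nameCount() + value.nameCount();
  }

  public static void main(String[] args) {
    // struct<i64, i64>
    System.out.println(struct(I64, I64).nameCount()); // 2
    // list<struct<i64, i64>>
    System.out.println(list(struct(I64, I64)).nameCount()); // 2
    // map<struct<i64, i64>, struct<i64, i64, i64>>
    System.out.println(map(struct(I64, I64), struct(I64, I64, I64)).nameCount()); // 5
    // struct<i64, struct<i64, i64>>: 1 + (1 + 2)
    System.out.println(struct(I64, struct(I64, I64)).nameCount()); // 4
  }
}
```

The first three calls mirror the examples listed in the class Javadoc; the fourth adds a nested struct to show that the outer field's own name is counted in addition to the names of its inner fields.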
I don't think this is a great solution. But considering that this update will result in stricter plan enforcement, I want to figure out a way where we can just warn on incorrect behavior for now, and then make it an actual error in the future.
But I'm not sure introducing a logger here where one isn't used anywhere else makes sense.
I see. Left a comment on whether we want to make the spec more clear first. If the spec clearly says names are required then we can fire the exception straight away and we handle the change in behavior as a breaking change commit so it gets announced to consumers of substrait-java. That would then be sufficient in my opinion. Logging a warning to communicate the upcoming change is a nice gesture but who knows if consumers notice the warning message.
Opened a PR in upstream: substrait-io/substrait#1101
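To make the two options from this thread concrete (warn on missing names for now, or throw straight away as a breaking change), here is a minimal, self-contained sketch. It is not the PR's code: `RootNameCheckSketch` and its `strict` flag are hypothetical, the expected count is passed in as a plain `int` rather than computed from a record type, and `java.util.logging` is used only so the example runs without extra dependencies.

```java
import java.util.List;
import java.util.logging.Logger;

/** Illustrative only; the class name and the strict flag are assumptions. */
final class RootNameCheckSketch {
  private static final Logger LOGGER =
      Logger.getLogger(RootNameCheckSketch.class.getName());

  /**
   * Validates a name list against an expected depth-first name count.
   *
   * @param strict when false, an empty name list only logs a warning (the
   *     transitional behavior); when true, it is checked like any other list
   */
  static void check(List<String> names, int expectedFieldCount, boolean strict) {
    final int actualNameCount = names.size();
    if (actualNameCount == 0 && !strict) {
      LOGGER.warning(
          "Root built without output names; this will be an error in the next release");
      return;
    }
    if (actualNameCount != expectedFieldCount) {
      throw new IllegalArgumentException(
          String.format(
              "names count (%d) must match depth-first named-field count (%d)",
              actualNameCount, expectedFieldCount));
    }
  }
}
```

With `strict` set to `false`, `check(Collections.emptyList(), 2, false)` returns after logging, while a non-empty list of the wrong size still throws. With `strict` set to `true`, the empty list is rejected as well, which corresponds to the "always enforce it" outcome discussed in the earlier thread.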