From 569b51213ce326c3dc7d0b97bf174c97fb9e41e2 Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Thu, 15 Jan 2026 16:50:12 -0500 Subject: [PATCH 1/9] feat(isthmus): add struct-based user-defined type literal support --- .../io/substrait/expression/Expression.java | 7 + .../isthmus/SubstraitRelNodeConverter.java | 2 +- .../substrait/isthmus/SubstraitToCalcite.java | 18 +- .../isthmus/expression/CallConverters.java | 78 +++- .../expression/ExpressionRexConverter.java | 47 ++- .../substrait/isthmus/CalciteLiteralTest.java | 72 ++++ .../UserDefinedLiteralRoundtripTest.java | 363 ++++++++++++++++++ .../isthmus/utils/UserTypeFactory.java | 65 +++- 8 files changed, 628 insertions(+), 24 deletions(-) create mode 100644 isthmus/src/test/java/io/substrait/isthmus/UserDefinedLiteralRoundtripTest.java diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index d52abff15..450eac7ca 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -30,6 +30,13 @@ interface Literal extends Expression { default boolean nullable() { return false; } + + /** + * Returns a copy of this literal with the specified nullability. + * + *

This method is implemented by all concrete Literal classes via Immutables code generation. + */ + Literal withNullable(boolean nullable); } interface Nested extends Expression { diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java index 47daf97e2..5a223276a 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java @@ -162,7 +162,7 @@ public SubstraitRelNodeConverter( this.expressionRexConverter.setRelNodeConverter(this); } - private static ScalarFunctionConverter createScalarFunctionConverter( + static ScalarFunctionConverter createScalarFunctionConverter( SimpleExtension.ExtensionCollection extensions, RelDataTypeFactory typeFactory, boolean allowDynamicUdfs) { diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java index a0c5132e4..024301bca 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java @@ -2,6 +2,9 @@ import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.SubstraitRelNodeConverter.Context; +import io.substrait.isthmus.expression.AggregateFunctionConverter; +import io.substrait.isthmus.expression.ScalarFunctionConverter; +import io.substrait.isthmus.expression.WindowFunctionConverter; import io.substrait.plan.Plan; import io.substrait.relation.Rel; import io.substrait.util.EmptyVisitationContext; @@ -104,7 +107,20 @@ protected RelBuilder createRelBuilder(CalciteSchema schema) { *

Override this method to customize the {@link SubstraitRelNodeConverter}. */ protected SubstraitRelNodeConverter createSubstraitRelNodeConverter(RelBuilder relBuilder) { - return new SubstraitRelNodeConverter(extensions, typeFactory, relBuilder, featureBoard); + ScalarFunctionConverter scalarFunctionConverter = + SubstraitRelNodeConverter.createScalarFunctionConverter( + extensions, typeFactory, featureBoard.allowDynamicUdfs()); + AggregateFunctionConverter aggregateFunctionConverter = + new AggregateFunctionConverter(extensions.aggregateFunctions(), typeFactory); + WindowFunctionConverter windowFunctionConverter = + new WindowFunctionConverter(extensions.windowFunctions(), typeFactory); + return new SubstraitRelNodeConverter( + typeFactory, + relBuilder, + scalarFunctionConverter, + aggregateFunctionConverter, + windowFunctionConverter, + typeConverter); } /** diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java index 17db2143f..f6c5c5d97 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java @@ -44,16 +44,18 @@ public class CallConverters { * {@link SqlKind#REINTERPRET} is utilized by Isthmus to represent and store {@link * Expression.UserDefinedLiteral}s within Calcite. * - *

When converting from Substrait to Calcite, the {@link - * Expression.UserDefinedAnyLiteral#value()} is stored within a {@link - * org.apache.calcite.sql.type.SqlTypeName#BINARY} {@link org.apache.calcite.rex.RexLiteral} and - * then re-interpreted to have the correct type. + *

When converting from Substrait to Calcite, the user-defined literal value is stored either + * as a {@link org.apache.calcite.sql.type.SqlTypeName#BINARY} {@link + * org.apache.calcite.rex.RexLiteral} (for ANY-encoded values) or a {@link SqlKind#ROW} (for + * struct-encoded values) and then re-interpreted to have the correct user-defined type. * *

See {@link ExpressionRexConverter#visit(Expression.UserDefinedAnyLiteral, + * SubstraitRelNodeConverter.Context)} and {@link + * ExpressionRexConverter#visit(Expression.UserDefinedStructLiteral, * SubstraitRelNodeConverter.Context)} for this conversion. * - *

When converting from Calcite to Substrait, this call converter extracts the {@link - * Expression.UserDefinedAnyLiteral} that was stored. + *

When converting from Calcite to Substrait, this call converter extracts the stored {@link + * Expression.UserDefinedLiteral}. */ public static Function REINTERPRET = typeConverter -> @@ -86,8 +88,24 @@ public class CallConverters { } catch (com.google.protobuf.InvalidProtocolBufferException e) { throw new IllegalStateException("Failed to parse UserDefinedAnyLiteral value", e); } + } else if (operand instanceof Expression.StructLiteral + && type instanceof Type.UserDefined) { + Expression.StructLiteral structLiteral = (Expression.StructLiteral) operand; + Type.UserDefined t = (Type.UserDefined) type; + + return Expression.UserDefinedStructLiteral.builder() + .nullable(t.nullable()) + .urn(t.urn()) + .name(t.name()) + .addAllTypeParameters(t.typeParameters()) + .addAllFields(structLiteral.fields()) + .build(); } - return null; + throw new IllegalStateException( + "Unexpected REINTERPRET operand type: " + + operand.getClass().getSimpleName() + + " with target type: " + + type.getClass().getSimpleName()); }; // public static SimpleCallConverter OrAnd(FunctionConverter c) { @@ -100,6 +118,51 @@ public class CallConverters { // return null; // }; // } + /** + * Converts Calcite ROW constructors into Substrait struct literals. + * + *

ROW values are always concrete (never null themselves) - if a value is actually null, use + * NullLiteral instead of StructLiteral. Therefore, the resulting StructLiteral always has + * nullable=false. The ROW's type may be nullable (for regular structs) or non-nullable (for UDT + * struct encoding), but the value itself is always concrete. + * + *

Field nullability comes from individual field types in the ROW's type definition. When a + * field's type is nullable but the literal operand is not, we update the literal's nullability to + * match. + */ + public static SimpleCallConverter ROW = + (call, visitor) -> { + if (call.getKind() != SqlKind.ROW) { + return null; + } + + List operands = + call.getOperands().stream().map(visitor).collect(java.util.stream.Collectors.toList()); + if (!operands.stream().allMatch(expr -> expr instanceof Expression.Literal)) { + throw new IllegalArgumentException("ROW operands must be literals."); + } + + java.util.List fieldTypes = + call.getType().getFieldList(); + List literals = new java.util.ArrayList<>(); + + for (int i = 0; i < operands.size(); i++) { + Expression.Literal lit = (Expression.Literal) operands.get(i); + boolean fieldIsNullable = fieldTypes.get(i).getType().isNullable(); + + // ROW types are never nullable (struct literals are always concrete values). + // Field nullability comes from individual field types. + if (fieldIsNullable && !lit.nullable()) { + lit = lit.withNullable(true); + } + literals.add(lit); + } + + // Struct literals are always concrete values (never null). + // For UDT struct literals, struct-level nullability is in the REINTERPRET target type. + return ExpressionCreator.struct(false, literals); + }; + /** */ public static SimpleCallConverter CASE = (call, visitor) -> { @@ -150,6 +213,7 @@ public static List defaults(TypeConverter typeConverter) { return ImmutableList.of( new FieldSelectionConverter(typeConverter), CallConverters.CASE, + CallConverters.ROW, CallConverters.CAST.apply(typeConverter), CallConverters.REINTERPRET.apply(typeConverter), new SqlArrayValueConstructorCallConverter(typeConverter), diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index e06f66e8b..3deff1f81 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -120,8 +120,12 @@ public RexNode visit(Expression.UserDefinedAnyLiteral expr, Context context) @Override public RexNode visit(Expression.UserDefinedStructLiteral expr, Context context) throws RuntimeException { - throw new UnsupportedOperationException( - "UserDefinedStructLiteral representation is not yet supported in Isthmus"); + // UserDefinedStructLiteral: Struct is just the ENCODING/REPRESENTATION of a UDT value. + // The ROW is never nullable (it's just encoding). UDT nullability is carried by the + // REINTERPRET target type: REINTERPRET(ROW(...), udt{nullable=true/false}). + RelDataType type = typeConverter.toCalcite(typeFactory, expr.getType()); + RexNode structValue = toStructEncoding(expr.fields(), context); + return rexBuilder.makeReinterpretCast(type, structValue, rexBuilder.makeLiteral(false)); } @Override @@ -320,6 +324,14 @@ public RexNode visit(Expression.DecimalLiteral expr, Context context) throws Run return rexBuilder.makeLiteral(decimal, typeConverter.toCalcite(typeFactory, expr.getType())); } + @Override + public RexNode visit(Expression.StructLiteral expr, Context context) throws RuntimeException { + List fieldNodes = + expr.fields().stream().map(f -> f.accept(this, context)).collect(Collectors.toList()); + RelDataType structType = typeConverter.toCalcite(typeFactory, expr.getType()); + return rexBuilder.makeCall(structType, SqlStdOperatorTable.ROW, fieldNodes); + } + @Override public RexNode visit(Expression.ListLiteral expr, Context context) throws RuntimeException { List args = @@ -723,4 +735,35 @@ public RexNode visit(SetPredicate expr, Context context) throws RuntimeException "Cannot handle SetPredicate when PredicateOp is %s.", expr.predicateOp().name())); } } + + /** + * Helper method to create a Calcite ROW expression for encoding UDT struct literals. + * + *

Used specifically for {@link Expression.UserDefinedStructLiteral} where the struct is just + * the encoding representation of the UDT value. The ROW is never nullable because it's just the + * encoding - nullability is carried by the REINTERPRET target UDT type. + * + *

For regular {@link Expression.StructLiteral}, use the struct's own type via {@code + * expr.getType()} instead. + */ + private RexNode toStructEncoding(List fields, Context context) { + List fieldNodes = + fields.stream().map(f -> f.accept(this, context)).collect(Collectors.toList()); + + // Note: Field names ("field0", "field1", etc.) are dummy values required by Calcite's ROW + // type. These names are discarded during roundtrip conversion back to Substrait, as Substrait + // struct literals are position-based and only the field values are preserved. + // + // The ROW type is never nullable because it's just encoding for the UDT. Field nullability + // comes from individual field types. + RelDataTypeFactory.Builder rowBuilder = typeFactory.builder(); + IntStream.range(0, fields.size()) + .forEach( + i -> { + RelDataType fieldType = typeConverter.toCalcite(typeFactory, fields.get(i).getType()); + rowBuilder.add("field" + i, fieldType); + }); + + return rexBuilder.makeCall(rowBuilder.build(), SqlStdOperatorTable.ROW, fieldNodes); + } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java b/isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java index 94f77c6a9..f731621d2 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java @@ -22,6 +22,7 @@ import java.nio.charset.StandardCharsets; import java.time.LocalDate; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.concurrent.TimeUnit; import org.apache.calcite.rex.RexLiteral; @@ -388,6 +389,77 @@ void tStruct() { false)); } + @Test + void tStructRoundtripNullableFields() { + // Test regular struct with nullable fields roundtrips correctly + Expression.StructLiteral struct = + ExpressionCreator.struct( + false, ExpressionCreator.i32(true, 4), ExpressionCreator.i32(true, -1)); + + RexNode rex = struct.accept(expressionRexConverter, Context.newContext()); + Expression roundtrip = rex.accept(rexExpressionConverter); + + assertEquals(struct, roundtrip); + } + + @Test + void tStructRoundtripMixedFieldNullability() { + // Test regular struct with mixed field nullability roundtrips correctly + Expression.StructLiteral struct = + ExpressionCreator.struct( + false, ExpressionCreator.i32(true, 4), ExpressionCreator.i32(false, -1)); + + RexNode rex = struct.accept(expressionRexConverter, Context.newContext()); + Expression roundtrip = rex.accept(rexExpressionConverter); + + assertEquals(struct, roundtrip); + } + + @Test + void tStructRoundtripWithNullFieldValues() { + // Test struct with actual NULL field values roundtrips correctly + Expression.NullLiteral nullField = + Expression.NullLiteral.builder() + .nullable(true) + .type(io.substrait.type.Type.I32.builder().nullable(true).build()) + .build(); + + Expression.StructLiteral struct = + ExpressionCreator.struct(false, nullField, ExpressionCreator.i32(false, 100)); + + RexNode rex = struct.accept(expressionRexConverter, Context.newContext()); + Expression roundtrip = rex.accept(rexExpressionConverter); + + assertEquals(struct, roundtrip); + } + + @Test + void tStructRoundtripNested() { + // Test nested regular structs roundtrip correctly + Expression.StructLiteral innerStruct = + ExpressionCreator.struct( + false, ExpressionCreator.i32(false, 1), ExpressionCreator.i32(false, 2)); + + Expression.StructLiteral outerStruct = + ExpressionCreator.struct(false, innerStruct, ExpressionCreator.i32(false, 3)); + + RexNode rex = outerStruct.accept(expressionRexConverter, Context.newContext()); + Expression roundtrip = rex.accept(rexExpressionConverter); + + assertEquals(outerStruct, roundtrip); + } + + @Test + void tStructRoundtripEmpty() { + // Test empty struct roundtrips correctly + Expression.StructLiteral struct = ExpressionCreator.struct(false, Collections.emptyList()); + + RexNode rex = struct.accept(expressionRexConverter, Context.newContext()); + Expression roundtrip = rex.accept(rexExpressionConverter); + + assertEquals(struct, roundtrip); + } + @Test void tFixedBinary() { byte[] val = "my test".getBytes(StandardCharsets.UTF_8); diff --git a/isthmus/src/test/java/io/substrait/isthmus/UserDefinedLiteralRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/UserDefinedLiteralRoundtripTest.java new file mode 100644 index 000000000..1aee28cb3 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/UserDefinedLiteralRoundtripTest.java @@ -0,0 +1,363 @@ +package io.substrait.isthmus; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import com.google.protobuf.Any; +import com.google.protobuf.StringValue; +import io.substrait.dsl.SubstraitBuilder; +import io.substrait.expression.Expression; +import io.substrait.expression.ExpressionCreator; +import io.substrait.extension.SimpleExtension; +import io.substrait.isthmus.expression.AggregateFunctionConverter; +import io.substrait.isthmus.expression.ScalarFunctionConverter; +import io.substrait.isthmus.expression.WindowFunctionConverter; +import io.substrait.isthmus.utils.UserTypeFactory; +import io.substrait.relation.Rel; +import io.substrait.type.Type; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.type.RelDataType; +import org.jspecify.annotations.Nullable; +import org.junit.jupiter.api.Test; + +class UserDefinedLiteralRoundtripTest extends PlanTestBase { + + private static final String NESTED_TYPES_URN = "extension:io.substrait:test_nested_types"; + + private static final String NESTED_TYPES_YAML = + "---\n" + + "urn: " + + NESTED_TYPES_URN + + "\n" + + "types:\n" + + " - name: point\n" + + " structure:\n" + + " latitude: i32\n" + + " longitude: i32\n" + + " - name: triangle\n" + + " structure:\n" + + " p1: point\n" + + " p2: point\n" + + " p3: point\n" + + " - name: vector\n" + + " parameters:\n" + + " - name: T\n" + + " type: dataType\n" + + " structure:\n" + + " x: T\n" + + " y: T\n" + + " z: T\n" + + " - name: multi_param\n" + + " parameters:\n" + + " - name: T\n" + + " type: dataType\n" + + " - name: size\n" + + " type: integer\n" + + " - name: nullable\n" + + " type: boolean\n" + + " - name: encoding\n" + + " type: string\n" + + " - name: precision\n" + + " type: dataType\n" + + " - name: mode\n" + + " type: enum\n" + + " structure:\n" + + " value: T\n"; + + private static final SimpleExtension.ExtensionCollection NESTED_TYPES_EXTENSIONS = + SimpleExtension.load("nested_types.yaml", NESTED_TYPES_YAML); + + private final SubstraitBuilder builder = new SubstraitBuilder(NESTED_TYPES_EXTENSIONS); + + private final Map userTypeFactories = + Map.of( + "point", new UserTypeFactory(NESTED_TYPES_URN, "point"), + "triangle", new UserTypeFactory(NESTED_TYPES_URN, "triangle"), + "vector", new UserTypeFactory(NESTED_TYPES_URN, "vector"), + "multi_param", new UserTypeFactory(NESTED_TYPES_URN, "multi_param")); + + private final UserTypeMapper userTypeMapper = + new UserTypeMapper() { + @Override + public @Nullable Type toSubstrait(RelDataType relDataType) { + return userTypeFactories.values().stream() + .filter(factory -> factory.isTypeFromFactory(relDataType)) + .findFirst() + .map( + factory -> + factory.createSubstrait( + relDataType.isNullable(), factory.getTypeParameters(relDataType))) + .orElse(null); + } + + @Override + public @Nullable RelDataType toCalcite(Type.UserDefined type) { + if (!type.urn().equals(NESTED_TYPES_URN)) { + return null; + } + UserTypeFactory factory = userTypeFactories.get(type.name()); + if (factory == null) { + return null; + } + + return factory.createCalcite(type.nullable(), type.typeParameters()); + } + }; + + private final TypeConverter typeConverter = new TypeConverter(userTypeMapper); + + private final ScalarFunctionConverter scalarFunctionConverter = + new ScalarFunctionConverter( + NESTED_TYPES_EXTENSIONS.scalarFunctions(), + Collections.emptyList(), + typeFactory, + typeConverter); + + private final AggregateFunctionConverter aggregateFunctionConverter = + new AggregateFunctionConverter( + NESTED_TYPES_EXTENSIONS.aggregateFunctions(), + Collections.emptyList(), + typeFactory, + typeConverter); + + private final WindowFunctionConverter windowFunctionConverter = + new WindowFunctionConverter(NESTED_TYPES_EXTENSIONS.windowFunctions(), typeFactory); + + private final SubstraitToCalcite substraitToCalcite = + new SubstraitToCalcite(NESTED_TYPES_EXTENSIONS, typeFactory, typeConverter); + + private final SubstraitRelVisitor calciteToSubstrait = + new SubstraitRelVisitor( + typeFactory, + scalarFunctionConverter, + aggregateFunctionConverter, + windowFunctionConverter, + typeConverter, + ImmutableFeatureBoard.builder().build()); + + private void assertRoundTrip(Expression.UserDefinedLiteral literal) { + Rel rel = + builder.project( + input -> List.of(literal), + builder.remap(1), + builder.namedScan(List.of("example"), List.of("udt_col"), List.of(literal.getType()))); + + RelNode calciteRel = substraitToCalcite.convert(rel); + Rel relReturned = calciteToSubstrait.apply(calciteRel); + assertEquals(rel, relReturned); + } + + private Expression.UserDefinedStructLiteral pointStructLiteral(int latitude, int longitude) { + return ExpressionCreator.userDefinedLiteralStruct( + false, + NESTED_TYPES_URN, + "point", + Collections.emptyList(), + Arrays.asList( + ExpressionCreator.i32(false, latitude), ExpressionCreator.i32(false, longitude))); + } + + private Expression.UserDefinedStructLiteral triangleStruct( + Expression.UserDefinedLiteral p1, + Expression.UserDefinedLiteral p2, + Expression.UserDefinedLiteral p3) { + return ExpressionCreator.userDefinedLiteralStruct( + false, NESTED_TYPES_URN, "triangle", Collections.emptyList(), Arrays.asList(p1, p2, p3)); + } + + private Expression.UserDefinedAnyLiteral pointAnyLiteral(String value) { + return ExpressionCreator.userDefinedLiteralAny( + false, NESTED_TYPES_URN, "point", Collections.emptyList(), Any.pack(StringValue.of(value))); + } + + private Expression.UserDefinedStructLiteral vectorStructLiteral( + List params, + Expression.Literal x, + Expression.Literal y, + Expression.Literal z) { + return ExpressionCreator.userDefinedLiteralStruct( + false, NESTED_TYPES_URN, "vector", params, Arrays.asList(x, y, z)); + } + + @Test + void anyEncodedUdtRoundTrip() { + Expression.UserDefinedLiteral literal = + ExpressionCreator.userDefinedLiteralAny( + false, + NESTED_TYPES_URN, + "point", + Collections.emptyList(), + Any.pack(StringValue.of(""))); + + assertRoundTrip(literal); + } + + @Test + void structEncodedUdtRoundTrip() { + assertRoundTrip(pointStructLiteral(42, 100)); + } + + @Test + void nestedStructEncodedUdtRoundTrip() { + assertRoundTrip( + triangleStruct( + pointStructLiteral(0, 0), pointStructLiteral(10, 0), pointStructLiteral(5, 10))); + } + + @Test + void nestedMixedEncodingsRoundTrip() { + // Mix encodings: struct, any, struct. + assertRoundTrip( + triangleStruct( + pointStructLiteral(1, 2), pointAnyLiteral("p2-any"), pointStructLiteral(3, 4))); + } + + @Test + void parameterizedUdtRoundTrip() { + Type.Parameter typeParam = + io.substrait.type.ImmutableType.ParameterDataType.builder() + .type(io.substrait.type.Type.I32.builder().nullable(false).build()) + .build(); + + Expression.UserDefinedLiteral literal = + vectorStructLiteral( + Collections.singletonList(typeParam), + ExpressionCreator.i32(false, 1), + ExpressionCreator.i32(false, 2), + ExpressionCreator.i32(false, 3)); + + assertRoundTrip(literal); + } + + @Test + void parameterizedUdtAllParamKindsRoundTrip() { + Type.Parameter typeParam = + io.substrait.type.ImmutableType.ParameterDataType.builder() + .type(io.substrait.type.Type.I32.builder().nullable(false).build()) + .build(); + + Type.Parameter intParam = + io.substrait.type.ImmutableType.ParameterIntegerValue.builder().value(100L).build(); + + Type.Parameter boolParam = + io.substrait.type.ImmutableType.ParameterBooleanValue.builder().value(true).build(); + + Type.Parameter stringParam = + io.substrait.type.ImmutableType.ParameterStringValue.builder().value("utf8").build(); + + Type.Parameter nullParam = io.substrait.type.Type.ParameterNull.INSTANCE; + + Type.Parameter enumParam = + io.substrait.type.ImmutableType.ParameterEnumValue.builder().value("FAST").build(); + + Expression.UserDefinedLiteral literal = + ExpressionCreator.userDefinedLiteralStruct( + false, + NESTED_TYPES_URN, + "multi_param", + Arrays.asList(typeParam, intParam, boolParam, stringParam, nullParam, enumParam), + Arrays.asList(ExpressionCreator.i32(false, 42))); + + assertRoundTrip(literal); + } + + @Test + void nullableFieldsInStructUdtRoundTrip() { + // Test field-level nullability: struct is non-nullable, but fields are nullable + Expression.UserDefinedStructLiteral literal = + ExpressionCreator.userDefinedLiteralStruct( + false, + NESTED_TYPES_URN, + "point", + Collections.emptyList(), + Arrays.asList( + ExpressionCreator.i32(true, 42), // nullable field + ExpressionCreator.i32(true, 100))); // nullable field + + assertRoundTrip(literal); + } + + @Test + void mixedFieldNullabilityInStructUdtRoundTrip() { + // Test mixed field nullability: struct is non-nullable, first field nullable, second + // non-nullable + Expression.UserDefinedStructLiteral literal = + ExpressionCreator.userDefinedLiteralStruct( + false, + NESTED_TYPES_URN, + "point", + Collections.emptyList(), + Arrays.asList( + ExpressionCreator.i32(true, 42), // nullable field + ExpressionCreator.i32(false, 100))); // non-nullable field + + assertRoundTrip(literal); + } + + @Test + void nullableStructEncodedUdtRoundTrip() { + // Test struct-level nullability: struct is nullable, fields are non-nullable + Expression.UserDefinedStructLiteral literal = + ExpressionCreator.userDefinedLiteralStruct( + true, + NESTED_TYPES_URN, + "point", + Collections.emptyList(), + Arrays.asList( + ExpressionCreator.i32(false, 42), // non-nullable field + ExpressionCreator.i32(false, 100))); // non-nullable field + + assertRoundTrip(literal); + } + + @Test + void nullableStructWithMixedFieldNullabilityRoundTrip() { + // Test the critical case: nullable struct with mixed field nullability + Expression.UserDefinedStructLiteral literal = + ExpressionCreator.userDefinedLiteralStruct( + true, + NESTED_TYPES_URN, + "point", + Collections.emptyList(), + Arrays.asList( + ExpressionCreator.i32(true, 42), // nullable field + ExpressionCreator.i32(false, 100))); // non-nullable field + + assertRoundTrip(literal); + } + + @Test + void multipleParameterizedUdtInstancesRoundTrip() { + Type.Parameter i32Param = + io.substrait.type.ImmutableType.ParameterDataType.builder() + .type(io.substrait.type.Type.I32.builder().nullable(false).build()) + .build(); + Type.Parameter fp64Param = + io.substrait.type.ImmutableType.ParameterDataType.builder() + .type(io.substrait.type.Type.FP64.builder().nullable(false).build()) + .build(); + + Expression.UserDefinedLiteral vecI32 = + vectorStructLiteral( + Collections.singletonList(i32Param), + ExpressionCreator.i32(false, 1), + ExpressionCreator.i32(false, 2), + ExpressionCreator.i32(false, 3)); + + Expression.UserDefinedLiteral vecFp64 = + vectorStructLiteral( + Collections.singletonList(fp64Param), + ExpressionCreator.fp64(false, 1.1), + ExpressionCreator.fp64(false, 2.2), + ExpressionCreator.fp64(false, 3.3)); + + Rel rel = builder.project(input -> Arrays.asList(vecI32, vecFp64), builder.emptyScan()); + + RelNode calciteRel = substraitToCalcite.convert(rel); + Rel relReturned = calciteToSubstrait.apply(calciteRel); + assertEquals(rel, relReturned); + } +} diff --git a/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java b/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java index 2c90f133d..2a70b87f7 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java +++ b/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java @@ -1,7 +1,9 @@ package io.substrait.isthmus.utils; import io.substrait.type.Type; -import io.substrait.type.TypeCreator; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeImpl; import org.apache.calcite.sql.type.SqlTypeName; @@ -19,34 +21,65 @@ public class UserTypeFactory { public UserTypeFactory(String urn, String name) { this.urn = urn; this.name = name; - this.N = new InnerType(true, name); - this.R = new InnerType(false, name); + this.N = new InnerType(urn, name, true, Collections.emptyList()); + this.R = new InnerType(urn, name, false, Collections.emptyList()); } public RelDataType createCalcite(boolean nullable) { - if (nullable) { - return N; - } else { - return R; + return createCalcite(nullable, Collections.emptyList()); + } + + public RelDataType createCalcite(boolean nullable, List typeParameters) { + if (typeParameters.isEmpty()) { + return nullable ? N : R; } + + return new InnerType(urn, name, nullable, typeParameters); } public Type createSubstrait(boolean nullable) { - return TypeCreator.of(nullable).userDefined(urn, name); + return createSubstrait(nullable, Collections.emptyList()); + } + + public Type createSubstrait(boolean nullable, List typeParameters) { + return Type.UserDefined.builder() + .nullable(nullable) + .urn(urn) + .name(name) + .addAllTypeParameters(typeParameters) + .build(); } public boolean isTypeFromFactory(RelDataType type) { - return type == N || type == R; + // We may return cached instances (N/R) or fresh InnerType instances with parameters. + // Use instanceof to recognize any of them and match by urn/name so custom UDT mappings work. + if (type instanceof InnerType) { + InnerType inner = (InnerType) type; + return urn.equals(inner.urn) && name.equals(inner.name); + } + return false; + } + + public List getTypeParameters(RelDataType type) { + if (type instanceof InnerType) { + return ((InnerType) type).typeParameters; + } + return Collections.emptyList(); } private static class InnerType extends RelDataTypeImpl { private final boolean nullable; + private final String urn; private final String name; + private final List typeParameters; - private InnerType(boolean nullable, String name) { - computeDigest(); - this.nullable = nullable; + private InnerType( + String urn, String name, boolean nullable, List typeParameters) { + this.urn = urn; this.name = name; + this.nullable = nullable; + this.typeParameters = Collections.unmodifiableList(typeParameters); + computeDigest(); } @Override @@ -61,7 +94,13 @@ public SqlTypeName getSqlTypeName() { @Override protected void generateTypeString(StringBuilder sb, boolean withDetail) { - sb.append(name); + sb.append(urn).append(":").append(name); + + if (!typeParameters.isEmpty()) { + sb.append("<"); + sb.append(typeParameters.stream().map(Object::toString).collect(Collectors.joining(","))); + sb.append(">"); + } } } } From b7949f116641f1c34c993da06002d9cecd6a9b0f Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Fri, 16 Jan 2026 15:57:33 -0500 Subject: [PATCH 2/9] test(isthmus): add UDT roundtrip test for list and map typed fields Adds test coverage for struct-encoded UDT literals that contain list and map typed fields to ensure proper roundtrip conversion. --- .../UserDefinedLiteralRoundtripTest.java | 40 ++++++++++++++++++- 1 file changed, 38 insertions(+), 2 deletions(-) diff --git a/isthmus/src/test/java/io/substrait/isthmus/UserDefinedLiteralRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/UserDefinedLiteralRoundtripTest.java index 1aee28cb3..10cfffffe 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/UserDefinedLiteralRoundtripTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/UserDefinedLiteralRoundtripTest.java @@ -65,7 +65,12 @@ class UserDefinedLiteralRoundtripTest extends PlanTestBase { + " - name: mode\n" + " type: enum\n" + " structure:\n" - + " value: T\n"; + + " value: T\n" + + " - name: complex_record\n" + + " structure:\n" + + " id: i32\n" + + " tags: list\n" + + " attributes: map\n"; private static final SimpleExtension.ExtensionCollection NESTED_TYPES_EXTENSIONS = SimpleExtension.load("nested_types.yaml", NESTED_TYPES_YAML); @@ -77,7 +82,8 @@ class UserDefinedLiteralRoundtripTest extends PlanTestBase { "point", new UserTypeFactory(NESTED_TYPES_URN, "point"), "triangle", new UserTypeFactory(NESTED_TYPES_URN, "triangle"), "vector", new UserTypeFactory(NESTED_TYPES_URN, "vector"), - "multi_param", new UserTypeFactory(NESTED_TYPES_URN, "multi_param")); + "multi_param", new UserTypeFactory(NESTED_TYPES_URN, "multi_param"), + "complex_record", new UserTypeFactory(NESTED_TYPES_URN, "complex_record")); private final UserTypeMapper userTypeMapper = new UserTypeMapper() { @@ -360,4 +366,34 @@ void multipleParameterizedUdtInstancesRoundTrip() { Rel relReturned = calciteToSubstrait.apply(calciteRel); assertEquals(rel, relReturned); } + + @Test + void listAndMapFieldsInStructUdtRoundTrip() { + // Test UDT with list and map typed fields + Expression.Literal idField = ExpressionCreator.i32(false, 42); + + Expression.Literal tagsField = + ExpressionCreator.list( + false, + ExpressionCreator.string(false, "tag1"), + ExpressionCreator.string(false, "tag2"), + ExpressionCreator.string(false, "tag3")); + + Expression.Literal attributesField = + ExpressionCreator.map( + false, + com.google.common.collect.ImmutableMap.of( + ExpressionCreator.string(false, "key1"), ExpressionCreator.i32(false, 100), + ExpressionCreator.string(false, "key2"), ExpressionCreator.i32(false, 200))); + + Expression.UserDefinedStructLiteral literal = + ExpressionCreator.userDefinedLiteralStruct( + false, + NESTED_TYPES_URN, + "complex_record", + Collections.emptyList(), + Arrays.asList(idField, tagsField, attributesField)); + + assertRoundTrip(literal); + } } From 114ca404be70037c125da0c66a9fe512469caf6b Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Wed, 21 Jan 2026 10:13:19 -0500 Subject: [PATCH 3/9] tweak: use imports over FQN --- .../io/substrait/isthmus/expression/CallConverters.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java index f6c5c5d97..b5df611bc 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java @@ -10,6 +10,7 @@ import java.util.List; import java.util.Optional; import java.util.function.Function; +import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexNode; @@ -142,9 +143,8 @@ public class CallConverters { throw new IllegalArgumentException("ROW operands must be literals."); } - java.util.List fieldTypes = - call.getType().getFieldList(); - List literals = new java.util.ArrayList<>(); + List fieldTypes = call.getType().getFieldList(); + List literals = new ArrayList<>(); for (int i = 0; i < operands.size(); i++) { Expression.Literal lit = (Expression.Literal) operands.get(i); From ee8c121e71dae3f650a467362cd38f6df6b0960b Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Wed, 21 Jan 2026 10:27:07 -0500 Subject: [PATCH 4/9] fix: correctly return null on callconverter that cannot handle --- .../io/substrait/isthmus/expression/CallConverters.java | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java index b5df611bc..5716e6aa0 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java @@ -102,11 +102,7 @@ public class CallConverters { .addAllFields(structLiteral.fields()) .build(); } - throw new IllegalStateException( - "Unexpected REINTERPRET operand type: " - + operand.getClass().getSimpleName() - + " with target type: " - + type.getClass().getSimpleName()); + return null; }; // public static SimpleCallConverter OrAnd(FunctionConverter c) { From 1a849cf8494aee996288f13b462f1968e1b878e7 Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Thu, 22 Jan 2026 09:49:14 -0500 Subject: [PATCH 5/9] tweak: enforce withNullable always makes null literal nullable --- .../src/main/java/io/substrait/expression/Expression.java | 8 ++++++++ .../java/io/substrait/isthmus/CalciteLiteralTest.java | 5 +---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index 1a2a745eb..f8a5c1d3a 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -59,6 +59,14 @@ public boolean nullable() { return true; } + @Override + public NullLiteral withNullable(boolean nullable) { + if (!nullable) { + throw new IllegalArgumentException("NullLiteral cannot be made non-nullable"); + } + return this; + } + @Value.Check protected void check() { if (!type().nullable()) { diff --git a/isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java b/isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java index f731621d2..2b29d89d0 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java @@ -419,10 +419,7 @@ void tStructRoundtripMixedFieldNullability() { void tStructRoundtripWithNullFieldValues() { // Test struct with actual NULL field values roundtrips correctly Expression.NullLiteral nullField = - Expression.NullLiteral.builder() - .nullable(true) - .type(io.substrait.type.Type.I32.builder().nullable(true).build()) - .build(); + ExpressionCreator.typedNull(io.substrait.type.Type.I32.builder().nullable(true).build()); Expression.StructLiteral struct = ExpressionCreator.struct(false, nullField, ExpressionCreator.i32(false, 100)); From f29a88d217fa42e90ee329529cb49863ab6af123 Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Thu, 22 Jan 2026 16:47:16 -0500 Subject: [PATCH 6/9] simplify lit nullability cast in CallConverters --- .../substrait/isthmus/expression/CallConverters.java | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java index 5716e6aa0..62a95017f 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java @@ -123,9 +123,7 @@ public class CallConverters { * nullable=false. The ROW's type may be nullable (for regular structs) or non-nullable (for UDT * struct encoding), but the value itself is always concrete. * - *

Field nullability comes from individual field types in the ROW's type definition. When a - * field's type is nullable but the literal operand is not, we update the literal's nullability to - * match. + *

Each literal's nullability is set to match its field type's nullability. */ public static SimpleCallConverter ROW = (call, visitor) -> { @@ -147,10 +145,9 @@ public class CallConverters { boolean fieldIsNullable = fieldTypes.get(i).getType().isNullable(); // ROW types are never nullable (struct literals are always concrete values). - // Field nullability comes from individual field types. - if (fieldIsNullable && !lit.nullable()) { - lit = lit.withNullable(true); - } + // Field nullability comes from individual field types, so match literal nullability + // to field type nullability. + lit = lit.withNullable(fieldIsNullable); literals.add(lit); } From 41d6f1cefcbb5bec3a2403dcaa9184e2b6fc68bd Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Thu, 22 Jan 2026 16:59:02 -0500 Subject: [PATCH 7/9] codegolf --- .../isthmus/expression/CallConverters.java | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java index 62a95017f..432002c06 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java @@ -137,19 +137,19 @@ public class CallConverters { throw new IllegalArgumentException("ROW operands must be literals."); } + // ROW types are never nullable (struct literals are always concrete values). + // Field nullability comes from individual field types, so match literal nullability + // to field type nullability. List fieldTypes = call.getType().getFieldList(); - List literals = new ArrayList<>(); - - for (int i = 0; i < operands.size(); i++) { - Expression.Literal lit = (Expression.Literal) operands.get(i); - boolean fieldIsNullable = fieldTypes.get(i).getType().isNullable(); - - // ROW types are never nullable (struct literals are always concrete values). - // Field nullability comes from individual field types, so match literal nullability - // to field type nullability. - lit = lit.withNullable(fieldIsNullable); - literals.add(lit); - } + List literals = + java.util.stream.IntStream.range(0, operands.size()) + .mapToObj( + i -> { + Expression.Literal lit = (Expression.Literal) operands.get(i); + boolean fieldIsNullable = fieldTypes.get(i).getType().isNullable(); + return lit.withNullable(fieldIsNullable); + }) + .collect(java.util.stream.Collectors.toList()); // Struct literals are always concrete values (never null). // For UDT struct literals, struct-level nullability is in the REINTERPRET target type. From 70579c6fcf3c2935b7c20173819b4a37bbd07cd4 Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Fri, 23 Jan 2026 17:08:24 -0500 Subject: [PATCH 8/9] drop old builder.emptyScan() --- .../io/substrait/isthmus/UserDefinedLiteralRoundtripTest.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/isthmus/src/test/java/io/substrait/isthmus/UserDefinedLiteralRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/UserDefinedLiteralRoundtripTest.java index 10cfffffe..e381a14a2 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/UserDefinedLiteralRoundtripTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/UserDefinedLiteralRoundtripTest.java @@ -360,7 +360,8 @@ void multipleParameterizedUdtInstancesRoundTrip() { ExpressionCreator.fp64(false, 2.2), ExpressionCreator.fp64(false, 3.3)); - Rel rel = builder.project(input -> Arrays.asList(vecI32, vecFp64), builder.emptyScan()); + Rel rel = + builder.project(input -> Arrays.asList(vecI32, vecFp64), builder.emptyVirtualTableScan()); RelNode calciteRel = substraitToCalcite.convert(rel); Rel relReturned = calciteToSubstrait.apply(calciteRel); From d6337ce4de612c8de1c96908dac79be1e4cd7d5c Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Fri, 23 Jan 2026 17:11:12 -0500 Subject: [PATCH 9/9] drop FQN --- .../io/substrait/isthmus/expression/CallConverters.java | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java index 432002c06..23612fce5 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java @@ -10,6 +10,7 @@ import java.util.List; import java.util.Optional; import java.util.function.Function; +import java.util.stream.Collectors; import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexCall; @@ -132,7 +133,7 @@ public class CallConverters { } List operands = - call.getOperands().stream().map(visitor).collect(java.util.stream.Collectors.toList()); + call.getOperands().stream().map(visitor).collect(Collectors.toList()); if (!operands.stream().allMatch(expr -> expr instanceof Expression.Literal)) { throw new IllegalArgumentException("ROW operands must be literals."); } @@ -149,7 +150,7 @@ public class CallConverters { boolean fieldIsNullable = fieldTypes.get(i).getType().isNullable(); return lit.withNullable(fieldIsNullable); }) - .collect(java.util.stream.Collectors.toList()); + .collect(Collectors.toList()); // Struct literals are always concrete values (never null). // For UDT struct literals, struct-level nullability is in the REINTERPRET target type. @@ -168,7 +169,7 @@ public class CallConverters { assert call.getOperands().size() % 2 == 1; List caseArgs = - call.getOperands().stream().map(visitor).collect(java.util.stream.Collectors.toList()); + call.getOperands().stream().map(visitor).collect(Collectors.toList()); int last = caseArgs.size() - 1; // for if/else, process in reverse to maintain query order