Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Copyright © 2018-2026 Commonwealth Scientific and Industrial Research
* Organisation (CSIRO) ABN 41 687 119 230.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package au.csiro.pathling.encoders;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;

/**
 * A per-thread position counter used while walking recursive element trees. The value lives in a
 * {@link ThreadLocal}, so every thread observes and mutates only its own count; Spark tasks that
 * process different partitions in parallel therefore cannot disturb one another.
 *
 * <p>In codegen mode a single instance is handed to all partitions through Spark's {@code
 * addReferenceObj()} mechanism. Because such reference objects are shared within an executor, the
 * mutable state has to be isolated per task thread, which is what the {@link ThreadLocal} does.
 *
 * <p>The class is {@link Serializable} so that it survives Spark plan serialization to executors.
 * The {@link ThreadLocal} itself is transient: it is created eagerly on construction and created
 * again after deserialization by {@link #readObject(ObjectInputStream)}.
 *
 * <p>Note: {@link ThreadLocal#remove()} is deliberately never invoked. Each thread stores only a
 * single {@code int[1]} (16 bytes), which {@link #reset()} sets back to zero for every row. Once
 * this object is unreachable, the weakly referenced {@link ThreadLocal} key is collected and the
 * stale entry is purged lazily by later {@link ThreadLocal} operations on the same thread.
 *
 * @author Piotr Szul
 */
@SuppressWarnings("java:S5164") // ThreadLocal.remove() not needed — see class Javadoc.
public class RowIndexCounter implements Serializable {

  private static final long serialVersionUID = 1L;

  /** Per-thread single-element holder for the count; rebuilt after deserialization. */
  private transient ThreadLocal<int[]> counter = newCounter();

  /**
   * Builds the thread-local holder whose per-thread initial value is a one-element array
   * containing zero.
   *
   * @return a fresh thread-local counter holder
   */
  private static ThreadLocal<int[]> newCounter() {
    return ThreadLocal.withInitial(() -> new int[1]);
  }

  /**
   * Deserialization hook: restores the non-transient state and then recreates the transient
   * thread-local holder, so a deserialized instance starts from zero on every thread.
   *
   * @param in the stream the object is being read from
   * @throws IOException if reading from the stream fails
   * @throws ClassNotFoundException if a serialized class cannot be resolved
   */
  private void readObject(final ObjectInputStream in) throws IOException, ClassNotFoundException {
    in.defaultReadObject();
    counter = newCounter();
  }

  /**
   * Reads the calling thread's current count and leaves it untouched. Repeated calls between two
   * increments yield the same value, so the counter may be referenced several times per element.
   *
   * @return the current counter value
   */
  public int get() {
    final int[] cell = counter.get();
    return cell[0];
  }

  /**
   * Advances the calling thread's count by one and returns nothing. Intended to be called once all
   * references to the current value have been evaluated.
   */
  public void increment() {
    final int[] cell = counter.get();
    cell[0] += 1;
  }

  /**
   * Sets the calling thread's count back to zero. Intended to be called before each top-level row
   * is evaluated so that the index sequence starts afresh.
   */
  public void reset() {
    final int[] cell = counter.get();
    cell[0] = 0;
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -429,4 +429,50 @@ public static Column variantUnwrap(
public static Column pruneAnnotations(@Nonnull final Column col) {
  // Unwrap the column to its expression, wrap that in PruneSyntheticFields, then re-wrap it.
  final PruneSyntheticFields pruned = new PruneSyntheticFields(expression(col));
  return column(pruned);
}

/**
 * Builds a column that reads a shared {@link RowIndexCounter} without advancing it. Every
 * evaluation yields the counter's current value, so several references inside the same element
 * evaluation all observe the same number.
 *
 * <p>Advancing the counter is a separate step: apply {@link #rowCounterIncrement(Column,
 * RowIndexCounter)} once all references for a given element have been evaluated.
 *
 * @param state the shared counter instance
 * @return a Column that reads the current counter value without incrementing
 */
@Nonnull
public static Column rowCounterGet(@Nonnull final RowIndexCounter state) {
  final RowCounterGet reader = new RowCounterGet(state);
  return column(reader);
}

/**
 * Decorates a column expression so that the shared row counter is advanced after the expression
 * has been evaluated. Apply it to the extractor result in a repeat projection so that the counter
 * moves forward exactly once per element.
 *
 * @param child the expression to evaluate before incrementing
 * @param state the shared counter instance to increment
 * @return a Column that evaluates the child and then increments the counter
 */
@Nonnull
public static Column rowCounterIncrement(
    @Nonnull final Column child, @Nonnull final RowIndexCounter state) {
  final RowCounterIncrement incrementing = new RowCounterIncrement(expression(child), state);
  return column(incrementing);
}

/**
 * Decorates a column expression so that the shared row counter is set back to zero before the
 * expression is evaluated. Apply it at the outermost level of a repeat projection so that the
 * counter starts afresh for each resource row.
 *
 * @param child the expression to evaluate after resetting
 * @param state the shared counter instance to reset
 * @return a Column that resets the counter and then evaluates the child
 */
@Nonnull
public static Column resetCounter(
    @Nonnull final Column child, @Nonnull final RowIndexCounter state) {
  final ResetCounter resetting = new ResetCounter(expression(child), state);
  return column(resetting);
}
}
120 changes: 120 additions & 0 deletions encoders/src/main/scala/au/csiro/pathling/encoders/Expressions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -947,4 +947,124 @@ case class UnresolvedVariantUnwrap(inner: Expression, schemaRef: Expression,
override def toString: String = s"VariantUnwrap($inner)"
}

/**
 * Leaf expression that exposes the current value of a [[RowIndexCounter]] and never advances it.
 * Because reading has no effect on the counter, `%rowIndex` may be referenced several times within
 * one element evaluation and every reference yields the same value.
 *
 * Advancing the counter is the job of [[RowCounterIncrement]], which must run once all references
 * for a given element have been evaluated.
 *
 * @param state the shared thread-safe counter
 */
case class RowCounterGet(state: RowIndexCounter)
  extends LeafExpression with Nondeterministic {

  override def dataType: DataType = IntegerType

  override def nullable: Boolean = false

  override def stateful: Boolean = true

  // Nothing to set up per partition: resetting happens per row through ResetCounter.
  override protected def initializeInternal(partitionIndex: Int): Unit = ()

  // Interpreted path: read the counter value without modifying it.
  override protected def evalInternal(input: InternalRow): Int = state.get()

  // Codegen path: register the counter as a reference object and emit a single read of it.
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val counterTerm = ctx.addReferenceObj("rowCounter", state, classOf[RowIndexCounter].getName)
    val readCode = code"""
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $counterTerm.get();"""
    ev.copy(code = readCode, isNull = FalseLiteral)
  }

  // There are no children to substitute; a new node over the same counter is returned.
  override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
    RowCounterGet(state)
}

/**
 * Unary expression that evaluates its child and afterwards advances a [[RowIndexCounter]] by one.
 * It passes the child's value, type and nullability through unchanged, and exists so the counter
 * moves forward exactly once per element, after every `%rowIndex` reference (via
 * [[RowCounterGet]]) has been read.
 *
 * @param child the expression to evaluate before incrementing
 * @param state the shared thread-safe counter to increment
 */
case class RowCounterIncrement(child: Expression, state: RowIndexCounter)
  extends UnaryExpression with NonSQLExpression {

  override def nullable: Boolean = child.nullable

  override def dataType: DataType = child.dataType

  // Interpreted path: the child is evaluated first, then the counter is bumped.
  override def eval(input: InternalRow): Any = {
    val childValue = child.eval(input)
    state.increment()
    childValue
  }

  // Not expected to be called, because eval is overridden directly above.
  override protected def nullSafeEval(input: Any): Any =
    throw new UnsupportedOperationException(ExpressionConstants.CODEGEN_ONLY_MSG)

  // Codegen path: emit the child's code, copy its result, then emit the increment call.
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val counterTerm = ctx.addReferenceObj("rowCounter", state, classOf[RowIndexCounter].getName)
    val childCode = child.genCode(ctx)
    val generated = code"""
${childCode.code}
final boolean ${ev.isNull} = ${childCode.isNull};
final ${CodeGenerator.javaType(dataType)} ${ev.value} = ${childCode.value};
$counterTerm.increment();"""
    ev.copy(code = generated)
  }

  override protected def withNewChildInternal(newChild: Expression): Expression =
    RowCounterIncrement(newChild, state)
}

/**
 * Unary expression that sets a [[RowIndexCounter]] back to zero and only then evaluates its child.
 * It passes the child's value, type and nullability through unchanged, and exists so the counter
 * starts afresh for each row when used inside per-row array transformations.
 *
 * @param child the expression to evaluate after resetting
 * @param state the shared thread-safe counter to reset
 */
case class ResetCounter(child: Expression, state: RowIndexCounter)
  extends UnaryExpression with NonSQLExpression {

  override def nullable: Boolean = child.nullable

  override def dataType: DataType = child.dataType

  // Interpreted path: the counter is zeroed before the child is evaluated.
  override def eval(input: InternalRow): Any = {
    state.reset()
    val childValue = child.eval(input)
    childValue
  }

  // Not expected to be called, because eval is overridden directly above.
  override protected def nullSafeEval(input: Any): Any =
    throw new UnsupportedOperationException(ExpressionConstants.CODEGEN_ONLY_MSG)

  // Codegen path: emit the reset call, then the child's code, then copy its result.
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val counterTerm = ctx.addReferenceObj("rowCounter", state, classOf[RowIndexCounter].getName)
    val childCode = child.genCode(ctx)
    val generated = code"""
$counterTerm.reset();
${childCode.code}
final boolean ${ev.isNull} = ${childCode.isNull};
final ${CodeGenerator.javaType(dataType)} ${ev.value} = ${childCode.value};"""
    ev.copy(code = generated)
  }

  override protected def withNewChildInternal(newChild: Expression): Expression =
    ResetCounter(newChild, state)
}

// ColumnFunctions has been moved to a Java class to access package-private Spark methods
Original file line number Diff line number Diff line change
Expand Up @@ -292,4 +292,69 @@ void testStructProductInlineWithUnsafeRowData() {
assertEquals(expected.get(i), actual.get(i), "Row " + i + " mismatch");
}
}

/**
 * Tests that {@link ValueFunctions#rowCounterGet}, {@link ValueFunctions#rowCounterIncrement},
 * and {@link ValueFunctions#resetCounter} work together to assign sequential indices within an
 * array transform and reset between rows.
 *
 * <p>Rows are identified by their {@code id} column rather than by their position in the
 * collected list, because Spark does not guarantee row order after {@code repartition}.
 */
@Test
void testRowCounterExpressions() {
  // Create a dataset with two rows, each containing an array of structs.
  final StructType itemType =
      DataTypes.createStructType(
          new StructField[] {
            new StructField("value", DataTypes.StringType, true, Metadata.empty())
          });
  final StructType schema =
      DataTypes.createStructType(
          new StructField[] {
            new StructField("id", DataTypes.StringType, true, Metadata.empty()),
            new StructField("items", DataTypes.createArrayType(itemType), true, Metadata.empty())
          });

  final List<Row> data =
      Arrays.asList(
          RowFactory.create(
              "r1",
              Arrays.asList(
                  RowFactory.create("a"), RowFactory.create("b"), RowFactory.create("c"))),
          RowFactory.create("r2", Arrays.asList(RowFactory.create("x"), RowFactory.create("y"))));

  // Repartition so that the rows are collected from a single partition.
  final Dataset<Row> ds = spark.createDataFrame(data, schema).repartition(1);

  // Build a transform that assigns a row index to each array element using the counter
  // expressions.
  final RowIndexCounter counter = new RowIndexCounter();
  final Column indexCol = ValueFunctions.rowCounterGet(counter);

  // Transform each item: struct(value, index), then increment the counter.
  final Column transformed =
      functions.transform(
          col("items"),
          item ->
              ValueFunctions.rowCounterIncrement(
                  struct(item.getField("value").alias("value"), indexCol.alias("idx")), counter));

  // Wrap with resetCounter so the index restarts at zero for each row.
  final Column withReset = ValueFunctions.resetCounter(transformed, counter);

  final Dataset<Row> result = ds.select(col("id"), withReset.alias("indexed_items"));
  final List<Row> rows = result.collectAsList();

  assertEquals(2, rows.size());

  // Verify each row by id so the assertions do not depend on the collected row order.
  int r1Rows = 0;
  int r2Rows = 0;
  for (final Row row : rows) {
    final String id = row.getString(0);
    final List<Row> items = row.getList(1);

    if ("r1".equals(id)) {
      r1Rows++;
      assertEquals(3, items.size(), "Row r1 should contain three items");
    } else {
      r2Rows++;
      assertEquals("r2", id, "Unexpected row id");
      assertEquals(2, items.size(), "Row r2 should contain two items");
    }

    // Indices must run 0..n-1 within every row, which also proves the counter was reset.
    for (int i = 0; i < items.size(); i++) {
      assertEquals(i, items.get(i).getInt(1), "Row " + id + " item " + i + " index mismatch");
    }
  }

  assertEquals(1, r1Rows, "Row r1 should appear exactly once");
  assertEquals(1, r2Rows, "Row r2 should appear exactly once");
}
}
Loading
Loading