Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 74 additions & 2 deletions src/Microsoft.Data.Analysis/GroupBy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,59 @@

namespace Microsoft.Data.Analysis
{
/// <summary>
/// A record to identify the row that is being aggregated that can be used to decide whether or not to include it in the aggregation.
/// </summary>
public record GroupByPredicateInput
{
/// <summary>
/// The name of the column that is being aggregated
/// </summary>
public string ColumnName { get; set; }

/// <summary>
/// The value from the GroupBy column that this group is grouped on
/// </summary>
public object GroupKey { get; set; }

/// <summary>
/// The value of this row within the column that is being aggregated
/// </summary>
public object RowValue { get; set; }
}

/// <summary>
/// A GroupBy class that is typically the result of a DataFrame.GroupBy call.
/// It holds information to perform typical aggregation ops on it.
/// </summary>
public abstract class GroupBy
{
/// <summary>
/// Compute the number of non-null values in each group
/// Compute the number of non-null values in each group
/// </summary>
/// <param name="columnNames">The columns within which to compute the number of non-null values in each group. A default value includes all columns.</param>
/// <returns></returns>
public abstract DataFrame Count(params string[] columnNames);

/// <summary>
/// Compute the number of values in each group that match a custom predicate
/// </summary>
/// <param name="predicate">A function that takes in the column name, group key, and row value and returns true to include that row in the group count or false to exclude it.</param>
/// <param name="columnNames">The columns within which to compute the number of values in each group that match the predicate. A default value includes all columns.</param>
/// <returns></returns>
public abstract DataFrame CountIf(Func<GroupByPredicateInput, bool> predicate, params string[] columnNames);

/// <summary>
/// Compute the number of distinct non-null values in each group
/// </summary>
/// <param name="columnNames">The columns within which to compute the number of distinct non-null values in each group. A default value includes all columns.</param>
/// <returns></returns>
public abstract DataFrame CountDistinct(params string[] columnNames);

/// <summary>
/// Return the first value in each group
/// </summary>
/// <param name="columnNames">Names of the columns to aggregate</param>
/// <returns></returns>
public abstract DataFrame First(params string[] columnNames);

Expand Down Expand Up @@ -140,6 +178,11 @@ private void EnumerateColumnsWithRows(GroupByColumnDelegate groupByColumnDelegat
}

public override DataFrame Count(params string[] columnNames)
{
return CountIf(input => input.RowValue != null, columnNames);
}

public override DataFrame CountIf(Func<GroupByPredicateInput, bool> predicate, params string[] columnNames)
{
DataFrame ret = new DataFrame();
PrimitiveDataFrameColumn<long> empty = new PrimitiveDataFrameColumn<long>("Empty");
Expand All @@ -156,10 +199,19 @@ public override DataFrame Count(params string[] columnNames)
return;
DataFrameColumn column = _dataFrame.Columns[columnIndex];
long count = 0;
var groupByPredicateInput = new GroupByPredicateInput
{
ColumnName = column.Name,
GroupKey = firstColumn[rowIndex]
};
foreach (long row in rowEnumerable)
{
if (column[row] != null)
groupByPredicateInput.RowValue = column[row];

if (predicate(groupByPredicateInput))
{
count++;
}
}
DataFrameColumn retColumn;
if (firstGroup)
Expand All @@ -182,6 +234,26 @@ public override DataFrame Count(params string[] columnNames)
return ret;
}

public override DataFrame CountDistinct(params string[] columnNames)
{
HashSet<GroupByPredicateInput> seenValues = [];

return CountIf(
input =>
{
if (input.RowValue == null || seenValues.Contains(input))
{
return false;
}

seenValues.Add(input);

return true;
},
columnNames
);
}

public override DataFrame First(params string[] columnNames)
{
DataFrame ret = new DataFrame();
Expand Down
44 changes: 44 additions & 0 deletions test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,50 @@ public void TestGroupBy()
Assert.Equal(2, firstDecimalColumn.Rows.Count);
Assert.Equal((decimal)0, firstDecimalColumn.Columns["Decimal"][0]);
Assert.Equal((decimal)1, firstDecimalColumn.Columns["Decimal"][1]);

var dfWithDuplicates = new DataFrame(
new Int32DataFrameColumn("Group", [1, 1, 1, 1, 1, 2, 2, 2, 2, 2]),
new Int32DataFrameColumn("Int", [1, 2, 3, 4, null, 1, 1, 2, 3, 4]),
new DoubleDataFrameColumn("Double", [1, 2, 3, 4, null, 1, 1, 2, 3, 4]),
new StringDataFrameColumn("String", ["1", "2", "3", "4", null, "1", "1", "2", "3", "4"]),
new DateTimeDataFrameColumn("DateTime", [
new DateTime(2026, 1, 1, 0, 0, 0),
new DateTime(2026, 1, 1, 0, 0, 1),
new DateTime(2026, 1, 1, 0, 0, 2),
new DateTime(2026, 1, 1, 0, 0, 3),
null,
new DateTime(2026, 1, 1, 0, 0, 0),
new DateTime(2026, 1, 1, 0, 0, 0),
new DateTime(2026, 1, 1, 0, 0, 1),
new DateTime(2026, 1, 1, 0, 0, 2),
new DateTime(2026, 1, 1, 0, 0, 3)
])
);

DataFrame countDistinct = dfWithDuplicates.GroupBy("Group").CountDistinct();
Assert.Equal(5, countDistinct.Columns.Count);
Assert.Equal(2, countDistinct.Rows.Count);

foreach (var columnName in countDistinct.Columns.Select(c => c.Name))
{
if (columnName == "Group")
{
continue;
}

var column = (PrimitiveDataFrameColumn<long>)countDistinct[columnName];

for (int row = 0; row < countDistinct.Rows.Count; row++)
{
Assert.Equal(4, column[row]);
}
}

DataFrame countIf = dfWithDuplicates.GroupBy("Group").CountIf((GroupByPredicateInput input) => input.RowValue is int and < 3, "Int");
Assert.Equal(2, countIf.Columns.Count);
Assert.Equal(2, countIf.Rows.Count);
Assert.Equal(2L, countIf["Int"][0]);
Assert.Equal(3L, countIf["Int"][1]);
}

[Fact]
Expand Down
Loading