diff --git a/src/Microsoft.Data.Analysis/GroupBy.cs b/src/Microsoft.Data.Analysis/GroupBy.cs
index 357fa80a63..93aec05b49 100644
--- a/src/Microsoft.Data.Analysis/GroupBy.cs
+++ b/src/Microsoft.Data.Analysis/GroupBy.cs
@@ -9,6 +9,27 @@
 
 namespace Microsoft.Data.Analysis
 {
+    /// <summary>
+    /// Identifies the row that is being aggregated so that a predicate can decide whether or not to include it in the aggregation. The built-in CountIf implementation reuses a single instance across the rows of a group and only updates RowValue between calls, so a predicate must copy any values it wants to keep instead of holding on to the instance.
+    /// </summary>
+    public record GroupByPredicateInput
+    {
+        /// <summary>
+        /// The name of the column that is being aggregated
+        /// </summary>
+        public string ColumnName { get; set; }
+
+        /// <summary>
+        /// The value from the GroupBy column that this group is grouped on
+        /// </summary>
+        public object GroupKey { get; set; }
+
+        /// <summary>
+        /// The value of this row within the column that is being aggregated
+        /// </summary>
+        public object RowValue { get; set; }
+    }
+
     /// <summary>
     /// A GroupBy class that is typically the result of a DataFrame.GroupBy call.
     /// It holds information to perform typical aggregation ops on it.
@@ -16,14 +37,31 @@ namespace Microsoft.Data.Analysis
     public abstract class GroupBy
     {
         /// <summary>
-        /// Compute the number of non-null values in each group
+        /// Compute the number of non-null values in each group
         /// </summary>
+        /// <param name="columnNames">The columns within which to compute the number of non-null values in each group. A default value includes all columns.</param>
         /// <returns></returns>
         public abstract DataFrame Count(params string[] columnNames);
 
+        /// <summary>
+        /// Compute the number of values in each group that match a custom predicate
+        /// </summary>
+        /// <param name="predicate">A function that takes in the column name, group key, and row value and returns true to include that row in the group count or false to exclude it.</param>
+        /// <param name="columnNames">The columns within which to compute the number of values in each group that match the predicate. A default value includes all columns.</param>
+        /// <returns></returns>
+        public abstract DataFrame CountIf(Func<GroupByPredicateInput, bool> predicate, params string[] columnNames);
+
+        /// <summary>
+        /// Compute the number of distinct non-null values in each group
+        /// </summary>
+        /// <param name="columnNames">The columns within which to compute the number of distinct non-null values in each group. A default value includes all columns.</param>
+        /// <returns></returns>
+        public abstract DataFrame CountDistinct(params string[] columnNames);
+
         /// <summary>
         /// Return the first value in each group
         /// </summary>
+        /// <param name="columnNames">Names of the columns to aggregate</param>
         /// <returns></returns>
         public abstract DataFrame First(params string[] columnNames);
 
@@ -140,6 +178,16 @@ private void EnumerateColumnsWithRows(GroupByColumnDelegate groupByColumnDelegat
         }
 
         public override DataFrame Count(params string[] columnNames)
+        {
+            return CountIf(input => input.RowValue != null, columnNames);
+        }
+
+        public override DataFrame CountIf(Func<GroupByPredicateInput, bool> predicate, params string[] columnNames)
         {
+            if (predicate == null)
+            {
+                throw new ArgumentNullException(nameof(predicate));
+            }
+
             DataFrame ret = new DataFrame();
             PrimitiveDataFrameColumn<long> empty = new PrimitiveDataFrameColumn<long>("Empty");
@@ -156,10 +204,19 @@ public override DataFrame Count(params string[] columnNames)
                     return;
                 DataFrameColumn column = _dataFrame.Columns[columnIndex];
                 long count = 0;
+                var groupByPredicateInput = new GroupByPredicateInput
+                {
+                    ColumnName = column.Name,
+                    GroupKey = firstColumn[rowIndex]
+                };
                 foreach (long row in rowEnumerable)
                 {
-                    if (column[row] != null)
+                    groupByPredicateInput.RowValue = column[row];
+
+                    if (predicate(groupByPredicateInput))
+                    {
                         count++;
+                    }
                 }
                 DataFrameColumn retColumn;
                 if (firstGroup)
@@ -182,6 +239,22 @@ public override DataFrame Count(params string[] columnNames)
             return ret;
         }
 
+        public override DataFrame CountDistinct(params string[] columnNames)
+        {
+            // CountIf hands the predicate one GroupByPredicateInput per group/column and only
+            // swaps RowValue between rows, so that instance must never be stored in the set:
+            // every stored entry would be the same reference and any hash-code match would be
+            // reported as a duplicate. Snapshot the identifying values into a tuple instead.
+            HashSet<(string ColumnName, object GroupKey, object RowValue)> seenValues = [];
+
+            // HashSet.Add returns false when the value is already present, which is exactly
+            // the "not distinct" answer, so one call both tests and records the value.
+            return CountIf(
+                input => input.RowValue != null
+                    && seenValues.Add((input.ColumnName, input.GroupKey, input.RowValue)),
+                columnNames);
+        }
+
         public override DataFrame First(params string[] columnNames)
         {
             DataFrame ret = new DataFrame();
diff --git 
a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs
index 2d75caef72..6320d15aba 100644
--- a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs
+++ b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs
@@ -429,6 +429,50 @@ public void TestGroupBy()
             Assert.Equal(2, firstDecimalColumn.Rows.Count);
             Assert.Equal((decimal)0, firstDecimalColumn.Columns["Decimal"][0]);
             Assert.Equal((decimal)1, firstDecimalColumn.Columns["Decimal"][1]);
+
+            var dfWithDuplicates = new DataFrame(
+                new Int32DataFrameColumn("Group", [1, 1, 1, 1, 1, 2, 2, 2, 2, 2]),
+                new Int32DataFrameColumn("Int", [1, 2, 3, 4, null, 1, 1, 2, 3, 4]),
+                new DoubleDataFrameColumn("Double", [1, 2, 3, 4, null, 1, 1, 2, 3, 4]),
+                new StringDataFrameColumn("String", ["1", "2", "3", "4", null, "1", "1", "2", "3", "4"]),
+                new DateTimeDataFrameColumn("DateTime", [
+                    new DateTime(2026, 1, 1, 0, 0, 0),
+                    new DateTime(2026, 1, 1, 0, 0, 1),
+                    new DateTime(2026, 1, 1, 0, 0, 2),
+                    new DateTime(2026, 1, 1, 0, 0, 3),
+                    null,
+                    new DateTime(2026, 1, 1, 0, 0, 0),
+                    new DateTime(2026, 1, 1, 0, 0, 0),
+                    new DateTime(2026, 1, 1, 0, 0, 1),
+                    new DateTime(2026, 1, 1, 0, 0, 2),
+                    new DateTime(2026, 1, 1, 0, 0, 3)
+                ])
+            );
+
+            DataFrame countDistinct = dfWithDuplicates.GroupBy("Group").CountDistinct(); // no column names passed
+            Assert.Equal(5, countDistinct.Columns.Count); // "Group" plus the four aggregated columns
+            Assert.Equal(2, countDistinct.Rows.Count); // one row per group
+
+            foreach (var columnName in countDistinct.Columns.Select(c => c.Name))
+            {
+                if (columnName == "Group")
+                {
+                    continue;
+                }
+
+                var column = (PrimitiveDataFrameColumn<long>)countDistinct[columnName];
+
+                for (int row = 0; row < countDistinct.Rows.Count; row++)
+                {
+                    Assert.Equal(4L, column[row]); // group 1: null is skipped; group 2: the repeated first value counts once
+                }
+            }
+
+            DataFrame countIf = dfWithDuplicates.GroupBy("Group").CountIf((GroupByPredicateInput input) => input.RowValue is int and < 3, "Int");
+            Assert.Equal(2, countIf.Columns.Count);
+            Assert.Equal(2, countIf.Rows.Count);
+            Assert.Equal(2L, countIf["Int"][0]); // group 1 values below 3: 1, 2
+            Assert.Equal(3L, countIf["Int"][1]); // group 2 values below 3: 1, 1, 2
         }
 
         [Fact]