Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 8993daa

Browse filesBrowse files
committed
CSHARP-4453: Support Bucket and BucketAuto stages in LINQ3.
1 parent ec46c34 commit 8993daa
Copy full SHA for 8993daa

20 files changed

+741
-514
lines changed
+103Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using MongoDB.Bson;
17+
using MongoDB.Bson.IO;
18+
using MongoDB.Bson.Serialization;
19+
using MongoDB.Bson.Serialization.Serializers;
20+
using MongoDB.Driver.Core.Misc;
21+
22+
namespace MongoDB.Driver
23+
{
24+
/// <summary>
25+
/// Static factory class for AggregateBucketAutoResultIdSerializer.
26+
/// </summary>
27+
public static class AggregateBucketAutoResultIdSerializer
28+
{
29+
/// <summary>
30+
/// Creates an instance of AggregateBucketAutoResultIdSerializer.
31+
/// </summary>
32+
/// <typeparam name="TValue">The value type.</typeparam>
33+
/// <param name="valueSerializer">The value serializer.</param>
34+
/// <returns>A AggregateBucketAutoResultIdSerializer.</returns>
35+
public static IBsonSerializer<AggregateBucketAutoResultId<TValue>> Create<TValue>(IBsonSerializer<TValue> valueSerializer)
36+
{
37+
return new AggregateBucketAutoResultIdSerializer<TValue>(valueSerializer);
38+
}
39+
}
40+
41+
/// <summary>
42+
/// A serializer for AggregateBucketAutoResultId.
43+
/// </summary>
44+
/// <typeparam name="TValue">The type of the values.</typeparam>
45+
public class AggregateBucketAutoResultIdSerializer<TValue> : ClassSerializerBase<AggregateBucketAutoResultId<TValue>>, IBsonDocumentSerializer
46+
{
47+
private readonly IBsonSerializer<TValue> _valueSerializer;
48+
49+
/// <summary>
50+
/// Initializes a new instance of the <see cref="AggregateBucketAutoResultIdSerializer{TValue}"/> class.
51+
/// </summary>
52+
/// <param name="valueSerializer">The value serializer.</param>
53+
public AggregateBucketAutoResultIdSerializer(IBsonSerializer<TValue> valueSerializer)
54+
{
55+
_valueSerializer = Ensure.IsNotNull(valueSerializer, nameof(valueSerializer));
56+
}
57+
58+
/// <inheritdoc/>
59+
protected override AggregateBucketAutoResultId<TValue> DeserializeValue(BsonDeserializationContext context, BsonDeserializationArgs args)
60+
{
61+
var reader = context.Reader;
62+
reader.ReadStartDocument();
63+
TValue min = default;
64+
TValue max = default;
65+
while (reader.ReadBsonType() != 0)
66+
{
67+
var name = reader.ReadName();
68+
switch (name)
69+
{
70+
case "min": min = _valueSerializer.Deserialize(context); break;
71+
case "max": max = _valueSerializer.Deserialize(context); break;
72+
default: throw new BsonSerializationException($"Invalid element name for AggregateBucketAutoResultId: {name}.");
73+
}
74+
}
75+
reader.ReadEndDocument();
76+
return new AggregateBucketAutoResultId<TValue>(min, max);
77+
}
78+
79+
/// <inheritdoc/>
80+
protected override void SerializeValue(BsonSerializationContext context, BsonSerializationArgs args, AggregateBucketAutoResultId<TValue> value)
81+
{
82+
var writer = context.Writer;
83+
writer.WriteStartDocument();
84+
writer.WriteName("min");
85+
_valueSerializer.Serialize(context, value.Min);
86+
writer.WriteName("max");
87+
_valueSerializer.Serialize(context, value.Max);
88+
writer.WriteEndDocument();
89+
}
90+
91+
/// <inheritdoc/>
92+
public bool TryGetMemberSerializationInfo(string memberName, out BsonSerializationInfo serializationInfo)
93+
{
94+
serializationInfo = memberName switch
95+
{
96+
"Min" => new BsonSerializationInfo("min", _valueSerializer, _valueSerializer.ValueType),
97+
"Max" => new BsonSerializationInfo("max", _valueSerializer, _valueSerializer.ValueType),
98+
_ => null
99+
};
100+
return serializationInfo != null;
101+
}
102+
}
103+
}

‎src/MongoDB.Driver/GroupForLinq3Result.cs

Copy file name to clipboardExpand all lines: src/MongoDB.Driver/GroupForLinq3Result.cs
-57Lines changed: 0 additions & 57 deletions
This file was deleted.

‎src/MongoDB.Driver/IAggregateFluentExtensions.cs

Copy file name to clipboardExpand all lines: src/MongoDB.Driver/IAggregateFluentExtensions.cs
+36-11Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ public static IAggregateFluent<AggregateBucketAutoResult<TValue>> BucketAuto<TRe
9595
}
9696

9797
/// <summary>
98-
/// Appends a $bucketAuto stage to the pipeline.
98+
/// Appends a $bucketAuto stage to the pipeline (this overload can only be used with LINQ3).
9999
/// </summary>
100100
/// <typeparam name="TResult">The type of the result.</typeparam>
101101
/// <typeparam name="TValue">The type of the value.</typeparam>
@@ -110,13 +110,46 @@ public static IAggregateFluent<TNewResult> BucketAuto<TResult, TValue, TNewResul
110110
this IAggregateFluent<TResult> aggregate,
111111
Expression<Func<TResult, TValue>> groupBy,
112112
int buckets,
113-
Expression<Func<IGrouping<TValue, TResult>, TNewResult>> output,
113+
Expression<Func<IGrouping<AggregateBucketAutoResultId<TValue>, TResult>, TNewResult>> output,
114114
AggregateBucketAutoOptions options = null)
115115
{
116116
Ensure.IsNotNull(aggregate, nameof(aggregate));
117+
if (aggregate.Database.Client.Settings.LinqProvider != LinqProvider.V3)
118+
{
119+
throw new InvalidOperationException("This overload of BucketAuto can only be used with LINQ3.");
120+
}
121+
117122
return aggregate.AppendStage(PipelineStageDefinitionBuilder.BucketAuto(groupBy, buckets, output, options));
118123
}
119124

125+
/// <summary>
126+
/// Appends a $bucketAuto stage to the pipeline (this method can only be used with LINQ2).
127+
/// </summary>
128+
/// <typeparam name="TResult">The type of the result.</typeparam>
129+
/// <typeparam name="TValue">The type of the value.</typeparam>
130+
/// <typeparam name="TNewResult">The type of the new result.</typeparam>
131+
/// <param name="aggregate">The aggregate.</param>
132+
/// <param name="groupBy">The expression providing the value to group by.</param>
133+
/// <param name="buckets">The number of buckets.</param>
134+
/// <param name="output">The output projection.</param>
135+
/// <param name="options">The options (optional).</param>
136+
/// <returns>The fluent aggregate interface.</returns>
137+
public static IAggregateFluent<TNewResult> BucketAutoForLinq2<TResult, TValue, TNewResult>(
138+
this IAggregateFluent<TResult> aggregate,
139+
Expression<Func<TResult, TValue>> groupBy,
140+
int buckets,
141+
Expression<Func<IGrouping<TValue, TResult>, TNewResult>> output, // the IGrouping for BucketAuto has been wrong all along, only fixing it for LINQ3
142+
AggregateBucketAutoOptions options = null)
143+
{
144+
Ensure.IsNotNull(aggregate, nameof(aggregate));
145+
if (aggregate.Database.Client.Settings.LinqProvider != LinqProvider.V2)
146+
{
147+
throw new InvalidOperationException("The BucketAutoForLinq2 method can only be used with LINQ2.");
148+
}
149+
150+
return aggregate.AppendStage(PipelineStageDefinitionBuilder.BucketAutoForLinq2(groupBy, buckets, output, options));
151+
}
152+
120153
/// <summary>
121154
/// Appends a $densify stage to the pipeline.
122155
/// </summary>
@@ -396,15 +429,7 @@ public static IAggregateFluent<BsonDocument> Group<TResult>(this IAggregateFluen
396429
public static IAggregateFluent<TNewResult> Group<TResult, TKey, TNewResult>(this IAggregateFluent<TResult> aggregate, Expression<Func<TResult, TKey>> id, Expression<Func<IGrouping<TKey, TResult>, TNewResult>> group)
397430
{
398431
Ensure.IsNotNull(aggregate, nameof(aggregate));
399-
if (aggregate.Database.Client.Settings.LinqProvider == LinqProvider.V2)
400-
{
401-
return aggregate.AppendStage(PipelineStageDefinitionBuilder.Group(id, group));
402-
}
403-
else
404-
{
405-
var (groupStage, projectStage) = PipelineStageDefinitionBuilder.GroupForLinq3(id, group);
406-
return aggregate.AppendStage(groupStage).AppendStage(projectStage);
407-
}
432+
return aggregate.AppendStage(PipelineStageDefinitionBuilder.Group(id, group));
408433
}
409434

410435
/// <summary>

‎src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstGroupPipelineOptimizer.cs renamed to ‎src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstGroupingPipelineOptimizer.cs

Copy file name to clipboardExpand all lines: src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstGroupingPipelineOptimizer.cs
+72-18Lines changed: 72 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,41 +24,50 @@
2424

2525
namespace MongoDB.Driver.Linq.Linq3Implementation.Ast.Optimizers
2626
{
27-
internal class AstGroupPipelineOptimizer
27+
internal class AstGroupingPipelineOptimizer
2828
{
2929
#region static
3030
public static AstPipeline Optimize(AstPipeline pipeline)
3131
{
32-
var optimizer = new AstGroupPipelineOptimizer();
32+
var optimizer = new AstGroupingPipelineOptimizer();
3333
for (var i = 0; i < pipeline.Stages.Count; i++)
3434
{
3535
var stage = pipeline.Stages[i];
36-
if (stage is AstGroupStage groupStage)
36+
if (IsGroupingStage(stage))
3737
{
38-
pipeline = optimizer.OptimizeGroupStage(pipeline, i, groupStage);
38+
pipeline = optimizer.OptimizeGroupingStage(pipeline, i, stage);
3939
}
4040
}
4141

4242
return pipeline;
43+
44+
static bool IsGroupingStage(AstStage stage)
45+
{
46+
return stage.NodeType switch
47+
{
48+
AstNodeType.GroupStage or AstNodeType.BucketStage or AstNodeType.BucketAutoStage => true,
49+
_ => false
50+
};
51+
}
4352
}
4453
#endregion
4554

4655
private readonly AccumulatorSet _accumulators = new AccumulatorSet();
4756
private AstExpression _element; // normally either "$$ROOT" or "$_v"
4857

49-
private AstPipeline OptimizeGroupStage(AstPipeline pipeline, int i, AstGroupStage groupStage)
58+
private AstPipeline OptimizeGroupingStage(AstPipeline pipeline, int i, AstStage groupingStage)
5059
{
5160
try
5261
{
53-
if (IsOptimizableGroupStage(groupStage, out _element))
62+
if (IsOptimizableGroupingStage(groupingStage, out _element))
5463
{
5564
var followingStages = GetFollowingStagesToOptimize(pipeline, i + 1);
5665
if (followingStages == null)
5766
{
5867
return pipeline;
5968
}
6069

61-
var mappings = OptimizeGroupAndFollowingStages(groupStage, followingStages);
70+
var mappings = OptimizeGroupingAndFollowingStages(groupingStage, followingStages);
6271
if (mappings.Length > 0)
6372
{
6473
return (AstPipeline)AstNodeReplacer.Replace(pipeline, mappings);
@@ -72,23 +81,57 @@ private AstPipeline OptimizeGroupStage(AstPipeline pipeline, int i, AstGroupStag
7281

7382
return pipeline;
7483

75-
static bool IsOptimizableGroupStage(AstGroupStage groupStage, out AstExpression element)
84+
static bool IsOptimizableGroupingStage(AstStage groupingStage, out AstExpression element)
7685
{
77-
// { $group : { _id : ?, _elements : { $push : element } } }
78-
if (groupStage.Fields.Count == 1)
86+
if (groupingStage is AstGroupStage groupStage)
87+
{
88+
// { $group : { _id : ?, _elements : { $push : element } } }
89+
if (groupStage.Fields.Count == 1)
90+
{
91+
var field = groupStage.Fields[0];
92+
return IsElementsPush(field, out element);
93+
}
94+
}
95+
96+
if (groupingStage is AstBucketStage bucketStage)
97+
{
98+
// { $bucket : { groupBy : ?, boundaries : ?, default : ?, output : { _elements : { $push : element } } } }
99+
if (bucketStage.Output.Count == 1)
100+
{
101+
var output = bucketStage.Output[0];
102+
return IsElementsPush(output, out element);
103+
}
104+
}
105+
106+
if (groupingStage is AstBucketAutoStage bucketAutoStage)
79107
{
80-
var field = groupStage.Fields[0];
81-
if (field.Path == "_elements" &&
108+
// { $bucketAuto : { groupBy : ?, buckets : ?, granularity : ?, output : { _elements : { $push : element } } } }
109+
if (bucketAutoStage.Output.Count == 1)
110+
{
111+
var output = bucketAutoStage.Output[0];
112+
return IsElementsPush(output, out element);
113+
}
114+
}
115+
116+
element = null;
117+
return false;
118+
119+
static bool IsElementsPush(AstAccumulatorField field, out AstExpression element)
120+
{
121+
if (
122+
field.Path == "_elements" &&
82123
field.Value is AstUnaryAccumulatorExpression unaryAccumulatorExpression &&
83124
unaryAccumulatorExpression.Operator == AstUnaryAccumulatorOperator.Push)
84125
{
85126
element = unaryAccumulatorExpression.Arg;
86127
return true;
87128
}
129+
else
130+
{
131+
element = null;
132+
return false;
133+
}
88134
}
89-
90-
element = null;
91-
return false;
92135
}
93136

94137
static List<AstStage> GetFollowingStagesToOptimize(AstPipeline pipeline, int from)
@@ -135,7 +178,7 @@ static bool IsLastStageThatCanBeOptimized(AstStage stage)
135178
}
136179
}
137180

138-
private (AstNode, AstNode)[] OptimizeGroupAndFollowingStages(AstGroupStage groupStage, List<AstStage> followingStages)
181+
private (AstNode, AstNode)[] OptimizeGroupingAndFollowingStages(AstStage groupingStage, List<AstStage> followingStages)
139182
{
140183
var mappings = new List<(AstNode, AstNode)>();
141184

@@ -148,10 +191,21 @@ static bool IsLastStageThatCanBeOptimized(AstStage stage)
148191
}
149192
}
150193

151-
var newGroupStage = AstStage.Group(groupStage.Id, _accumulators);
152-
mappings.Add((groupStage, newGroupStage));
194+
var newGroupingStage = CreateNewGroupingStage(groupingStage, _accumulators);
195+
mappings.Add((groupingStage, newGroupingStage));
153196

154197
return mappings.ToArray();
198+
199+
static AstStage CreateNewGroupingStage(AstStage groupingStage, AccumulatorSet accumulators)
200+
{
201+
return groupingStage switch
202+
{
203+
AstGroupStage groupStage => AstStage.Group(groupStage.Id, accumulators),
204+
AstBucketStage bucketStage => AstStage.Bucket(bucketStage.GroupBy, bucketStage.Boundaries, bucketStage.Default, accumulators),
205+
AstBucketAutoStage bucketAutoStage => AstStage.BucketAuto(bucketAutoStage.GroupBy, bucketAutoStage.Buckets, bucketAutoStage.Granularity, accumulators),
206+
_ => throw new Exception($"Unexpected {nameof(groupingStage)} node type: {groupingStage.NodeType}.")
207+
};
208+
}
155209
}
156210

157211
private AstStage OptimizeFollowingStage(AstStage stage)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.