Skip to content

Commit 227b818

Browse files
authored
feat: SUM / AVG (#1218)
* feat: SUM / AVG (real files) * Remove aggregate field duplicates (if any). * Clean up and fixes. * Clean up comments, and add Nonnull where possible. * Add more public docs. * More cleanup. * Update hashCode and equals for AggregateQuery. * Address code review comments. more to come. * fix test name. * Better comment. * Fix alias encoding. * Remove TODO. * Revert the way alias is constructed. * Backport test updates.
1 parent 6190d4a commit 227b818

File tree

10 files changed

+1364
-120
lines changed

10 files changed

+1364
-120
lines changed
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
/*
2+
* Copyright 2023 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.cloud.firestore;
18+
19+
import java.util.Objects;
20+
import javax.annotation.Nonnull;
21+
import javax.annotation.Nullable;
22+
23+
public abstract class AggregateField {
24+
@Nonnull
25+
public static CountAggregateField count() {
26+
return new CountAggregateField();
27+
}
28+
29+
@Nonnull
30+
public static SumAggregateField sum(@Nonnull String field) {
31+
return new SumAggregateField(FieldPath.fromDotSeparatedString(field));
32+
}
33+
34+
@Nonnull
35+
public static SumAggregateField sum(@Nonnull FieldPath fieldPath) {
36+
return new SumAggregateField(fieldPath);
37+
}
38+
39+
@Nonnull
40+
public static AverageAggregateField average(@Nonnull String field) {
41+
return new AverageAggregateField(FieldPath.fromDotSeparatedString(field));
42+
}
43+
44+
@Nonnull
45+
public static AverageAggregateField average(@Nonnull FieldPath fieldPath) {
46+
return new AverageAggregateField(fieldPath);
47+
}
48+
49+
@Nullable FieldPath fieldPath;
50+
51+
/** Returns the alias used internally for this aggregate field. */
52+
@Nonnull
53+
String getAlias() {
54+
// Use $operator_$field format if it's an aggregation of a specific field. For example: sum_foo.
55+
// Use $operator format if there's no field. For example: count.
56+
return getOperator() + (fieldPath == null ? "" : "_" + fieldPath.getEncodedPath());
57+
}
58+
59+
/**
60+
* Returns the field on which the aggregation takes place. Returns an empty string if there's no
61+
* field (e.g. for count).
62+
*/
63+
@Nonnull
64+
String getFieldPath() {
65+
return fieldPath == null ? "" : fieldPath.getEncodedPath();
66+
}
67+
68+
/** Returns a string representation of this aggregation's operator. For example: "sum" */
69+
abstract @Nonnull String getOperator();
70+
71+
/**
72+
* Returns true if the given object is equal to this object. Two `AggregateField` objects are
73+
* considered equal if they have the same operator and operate on the same field.
74+
*/
75+
@Override
76+
public boolean equals(Object other) {
77+
if (this == other) {
78+
return true;
79+
}
80+
if (!(other instanceof AggregateField)) {
81+
return false;
82+
}
83+
AggregateField otherAggregateField = (AggregateField) other;
84+
return getOperator().equals(otherAggregateField.getOperator())
85+
&& getFieldPath().equals(otherAggregateField.getFieldPath());
86+
}
87+
88+
/** Calculates and returns the hash code for this object. */
89+
@Override
90+
public int hashCode() {
91+
return Objects.hash(getOperator(), getFieldPath());
92+
}
93+
94+
public static class SumAggregateField extends AggregateField {
95+
private SumAggregateField(@Nonnull FieldPath field) {
96+
fieldPath = field;
97+
}
98+
99+
@Override
100+
@Nonnull
101+
public String getOperator() {
102+
return "sum";
103+
}
104+
}
105+
106+
public static class AverageAggregateField extends AggregateField {
107+
private AverageAggregateField(@Nonnull FieldPath field) {
108+
fieldPath = field;
109+
}
110+
111+
@Override
112+
@Nonnull
113+
public String getOperator() {
114+
return "average";
115+
}
116+
}
117+
118+
public static class CountAggregateField extends AggregateField {
119+
private CountAggregateField() {}
120+
121+
@Override
122+
@Nonnull
123+
public String getOperator() {
124+
return "count";
125+
}
126+
}
127+
}

google-cloud-firestore/src/main/java/com/google/cloud/firestore/AggregateQuery.java

Lines changed: 58 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,10 @@
2525
import com.google.api.gax.rpc.StreamController;
2626
import com.google.cloud.Timestamp;
2727
import com.google.cloud.firestore.v1.FirestoreSettings;
28-
import com.google.firestore.v1.RunAggregationQueryRequest;
29-
import com.google.firestore.v1.RunAggregationQueryResponse;
30-
import com.google.firestore.v1.RunQueryRequest;
31-
import com.google.firestore.v1.StructuredAggregationQuery;
32-
import com.google.firestore.v1.Value;
28+
import com.google.firestore.v1.*;
29+
import com.google.firestore.v1.StructuredAggregationQuery.Aggregation;
3330
import com.google.protobuf.ByteString;
34-
import java.util.Set;
31+
import java.util.*;
3532
import java.util.concurrent.atomic.AtomicBoolean;
3633
import javax.annotation.Nonnull;
3734
import javax.annotation.Nullable;
@@ -49,8 +46,11 @@ public class AggregateQuery {
4946

5047
@Nonnull private final Query query;
5148

52-
AggregateQuery(@Nonnull Query query) {
49+
@Nonnull private List<AggregateField> aggregateFieldList;
50+
51+
AggregateQuery(@Nonnull Query query, @Nonnull List<AggregateField> aggregateFields) {
5352
this.query = query;
53+
this.aggregateFieldList = aggregateFields;
5454
}
5555

5656
/** Returns the query whose aggregations will be calculated by this object. */
@@ -112,9 +112,9 @@ long getStartTimeNanos() {
112112
return startTimeNanos;
113113
}
114114

115-
void deliverResult(long count, Timestamp readTime) {
115+
void deliverResult(@Nonnull Map<String, Value> data, Timestamp readTime) {
116116
if (isFutureCompleted.compareAndSet(false, true)) {
117-
future.set(new AggregateQuerySnapshot(AggregateQuery.this, readTime, count));
117+
future.set(new AggregateQuerySnapshot(AggregateQuery.this, readTime, data));
118118
}
119119
}
120120

@@ -145,26 +145,13 @@ public void onResponse(RunAggregationQueryResponse response) {
145145
// Close the stream to avoid it dangling, since we're not expecting any more responses.
146146
streamController.cancel();
147147

148-
// Extract the count and read time from the RunAggregationQueryResponse.
148+
// Extract the aggregations and read time from the RunAggregationQueryResponse.
149149
Timestamp readTime = Timestamp.fromProto(response.getReadTime());
150-
Value value = response.getResult().getAggregateFieldsMap().get(ALIAS_COUNT);
151-
if (value == null) {
152-
throw new IllegalArgumentException(
153-
"RunAggregationQueryResponse is missing required alias: " + ALIAS_COUNT);
154-
} else if (value.getValueTypeCase() != Value.ValueTypeCase.INTEGER_VALUE) {
155-
throw new IllegalArgumentException(
156-
"RunAggregationQueryResponse alias "
157-
+ ALIAS_COUNT
158-
+ " has incorrect type: "
159-
+ value.getValueTypeCase());
160-
}
161-
long count = value.getIntegerValue();
162150

163151
// Deliver the result; even though the `RunAggregationQuery` RPC is a "streaming" RPC, meaning
164-
// that `onResponse()` can be called multiple times, it _should_ only be called once for count
165-
// queries. But even if it is called more than once, `responseDeliverer` will drop superfluous
166-
// results.
167-
responseDeliverer.deliverResult(count, readTime);
152+
// that `onResponse()` can be called multiple times, it _should_ only be called once. But even
153+
// if it is called more than once, `responseDeliverer` will drop superfluous results.
154+
responseDeliverer.deliverResult(response.getResult().getAggregateFieldsMap(), readTime);
168155
}
169156

170157
@Override
@@ -215,12 +202,32 @@ RunAggregationQueryRequest toProto(@Nullable final ByteString transactionId) {
215202
request.getStructuredAggregationQueryBuilder();
216203
structuredAggregationQuery.setStructuredQuery(runQueryRequest.getStructuredQuery());
217204

218-
StructuredAggregationQuery.Aggregation.Builder aggregation =
219-
StructuredAggregationQuery.Aggregation.newBuilder();
220-
aggregation.setCount(StructuredAggregationQuery.Aggregation.Count.getDefaultInstance());
221-
aggregation.setAlias(ALIAS_COUNT);
222-
structuredAggregationQuery.addAggregations(aggregation);
223-
205+
// We use a Set here to automatically remove duplicates.
206+
Set<StructuredAggregationQuery.Aggregation> aggregations = new HashSet<>();
207+
for (AggregateField aggregateField : aggregateFieldList) {
208+
// If there's a field for this aggregation, build its proto.
209+
StructuredQuery.FieldReference field = null;
210+
if (!aggregateField.getFieldPath().isEmpty()) {
211+
field =
212+
StructuredQuery.FieldReference.newBuilder()
213+
.setFieldPath(aggregateField.getFieldPath())
214+
.build();
215+
}
216+
// Build the aggregation proto.
217+
Aggregation.Builder aggregation = Aggregation.newBuilder();
218+
if (aggregateField instanceof AggregateField.CountAggregateField) {
219+
aggregation.setCount(Aggregation.Count.getDefaultInstance());
220+
} else if (aggregateField instanceof AggregateField.SumAggregateField) {
221+
aggregation.setSum(Aggregation.Sum.newBuilder().setField(field).build());
222+
} else if (aggregateField instanceof AggregateField.AverageAggregateField) {
223+
aggregation.setAvg(Aggregation.Avg.newBuilder().setField(field).build());
224+
} else {
225+
throw new RuntimeException("Unsupported aggregation");
226+
}
227+
aggregation.setAlias(aggregateField.getAlias());
228+
aggregations.add(aggregation.build());
229+
}
230+
structuredAggregationQuery.addAllAggregations(aggregations);
224231
return request.build();
225232
}
226233

@@ -243,7 +250,23 @@ public static AggregateQuery fromProto(Firestore firestore, RunAggregationQueryR
243250
.setStructuredQuery(proto.getStructuredAggregationQuery().getStructuredQuery())
244251
.build();
245252
Query query = Query.fromProto(firestore, runQueryRequest);
246-
return new AggregateQuery(query);
253+
254+
List<AggregateField> aggregateFields = new ArrayList<>();
255+
List<Aggregation> aggregations = proto.getStructuredAggregationQuery().getAggregationsList();
256+
aggregations.forEach(
257+
aggregation -> {
258+
if (aggregation.hasCount()) {
259+
aggregateFields.add(AggregateField.count());
260+
} else if (aggregation.hasAvg()) {
261+
aggregateFields.add(
262+
AggregateField.average(aggregation.getAvg().getField().getFieldPath()));
263+
} else if (aggregation.hasSum()) {
264+
aggregateFields.add(AggregateField.sum(aggregation.getSum().getField().getFieldPath()));
265+
} else {
266+
throw new RuntimeException("Unsupported aggregation.");
267+
}
268+
});
269+
return new AggregateQuery(query, aggregateFields);
247270
}
248271

249272
/**
@@ -253,7 +276,7 @@ public static AggregateQuery fromProto(Firestore firestore, RunAggregationQueryR
253276
*/
254277
@Override
255278
public int hashCode() {
256-
return query.hashCode();
279+
return Objects.hash(query, aggregateFieldList);
257280
}
258281

259282
/**
@@ -280,6 +303,6 @@ public boolean equals(Object object) {
280303
return false;
281304
}
282305
AggregateQuery other = (AggregateQuery) object;
283-
return query.equals(other.query);
306+
return query.equals(other.query) && aggregateFieldList.equals(other.aggregateFieldList);
284307
}
285308
}

0 commit comments

Comments
 (0)