2525import com .google .api .gax .rpc .StreamController ;
2626import com .google .cloud .Timestamp ;
2727import com .google .cloud .firestore .v1 .FirestoreSettings ;
28- import com .google .firestore .v1 .RunAggregationQueryRequest ;
29- import com .google .firestore .v1 .RunAggregationQueryResponse ;
30- import com .google .firestore .v1 .RunQueryRequest ;
31- import com .google .firestore .v1 .StructuredAggregationQuery ;
32- import com .google .firestore .v1 .Value ;
28+ import com .google .firestore .v1 .*;
29+ import com .google .firestore .v1 .StructuredAggregationQuery .Aggregation ;
3330import com .google .protobuf .ByteString ;
34- import java .util .Set ;
31+ import java .util .* ;
3532import java .util .concurrent .atomic .AtomicBoolean ;
3633import javax .annotation .Nonnull ;
3734import javax .annotation .Nullable ;
@@ -49,8 +46,11 @@ public class AggregateQuery {
4946
5047 @ Nonnull private final Query query ;
5148
52- AggregateQuery (@ Nonnull Query query ) {
49+ @ Nonnull private List <AggregateField > aggregateFieldList ;
50+
51+ AggregateQuery (@ Nonnull Query query , @ Nonnull List <AggregateField > aggregateFields ) {
5352 this .query = query ;
53+ this .aggregateFieldList = aggregateFields ;
5454 }
5555
5656 /** Returns the query whose aggregations will be calculated by this object. */
@@ -112,9 +112,9 @@ long getStartTimeNanos() {
112112 return startTimeNanos ;
113113 }
114114
115- void deliverResult (long count , Timestamp readTime ) {
115+ void deliverResult (@ Nonnull Map < String , Value > data , Timestamp readTime ) {
116116 if (isFutureCompleted .compareAndSet (false , true )) {
117- future .set (new AggregateQuerySnapshot (AggregateQuery .this , readTime , count ));
117+ future .set (new AggregateQuerySnapshot (AggregateQuery .this , readTime , data ));
118118 }
119119 }
120120
@@ -145,26 +145,13 @@ public void onResponse(RunAggregationQueryResponse response) {
145145 // Close the stream to avoid it dangling, since we're not expecting any more responses.
146146 streamController .cancel ();
147147
148- // Extract the count and read time from the RunAggregationQueryResponse.
148+ // Extract the aggregations and read time from the RunAggregationQueryResponse.
149149 Timestamp readTime = Timestamp .fromProto (response .getReadTime ());
150- Value value = response .getResult ().getAggregateFieldsMap ().get (ALIAS_COUNT );
151- if (value == null ) {
152- throw new IllegalArgumentException (
153- "RunAggregationQueryResponse is missing required alias: " + ALIAS_COUNT );
154- } else if (value .getValueTypeCase () != Value .ValueTypeCase .INTEGER_VALUE ) {
155- throw new IllegalArgumentException (
156- "RunAggregationQueryResponse alias "
157- + ALIAS_COUNT
158- + " has incorrect type: "
159- + value .getValueTypeCase ());
160- }
161- long count = value .getIntegerValue ();
162150
163151 // Deliver the result; even though the `RunAggregationQuery` RPC is a "streaming" RPC, meaning
164- // that `onResponse()` can be called multiple times, it _should_ only be called once for count
165- // queries. But even if it is called more than once, `responseDeliverer` will drop superfluous
166- // results.
167- responseDeliverer .deliverResult (count , readTime );
152+ // that `onResponse()` can be called multiple times, it _should_ only be called once. But even
153+ // if it is called more than once, `responseDeliverer` will drop superfluous results.
154+ responseDeliverer .deliverResult (response .getResult ().getAggregateFieldsMap (), readTime );
168155 }
169156
170157 @ Override
@@ -215,12 +202,32 @@ RunAggregationQueryRequest toProto(@Nullable final ByteString transactionId) {
215202 request .getStructuredAggregationQueryBuilder ();
216203 structuredAggregationQuery .setStructuredQuery (runQueryRequest .getStructuredQuery ());
217204
218- StructuredAggregationQuery .Aggregation .Builder aggregation =
219- StructuredAggregationQuery .Aggregation .newBuilder ();
220- aggregation .setCount (StructuredAggregationQuery .Aggregation .Count .getDefaultInstance ());
221- aggregation .setAlias (ALIAS_COUNT );
222- structuredAggregationQuery .addAggregations (aggregation );
223-
205+ // We use a Set here to automatically remove duplicates.
206+ Set <StructuredAggregationQuery .Aggregation > aggregations = new HashSet <>();
207+ for (AggregateField aggregateField : aggregateFieldList ) {
208+ // If there's a field for this aggregation, build its proto.
209+ StructuredQuery .FieldReference field = null ;
210+ if (!aggregateField .getFieldPath ().isEmpty ()) {
211+ field =
212+ StructuredQuery .FieldReference .newBuilder ()
213+ .setFieldPath (aggregateField .getFieldPath ())
214+ .build ();
215+ }
216+ // Build the aggregation proto.
217+ Aggregation .Builder aggregation = Aggregation .newBuilder ();
218+ if (aggregateField instanceof AggregateField .CountAggregateField ) {
219+ aggregation .setCount (Aggregation .Count .getDefaultInstance ());
220+ } else if (aggregateField instanceof AggregateField .SumAggregateField ) {
221+ aggregation .setSum (Aggregation .Sum .newBuilder ().setField (field ).build ());
222+ } else if (aggregateField instanceof AggregateField .AverageAggregateField ) {
223+ aggregation .setAvg (Aggregation .Avg .newBuilder ().setField (field ).build ());
224+ } else {
225+ throw new RuntimeException ("Unsupported aggregation" );
226+ }
227+ aggregation .setAlias (aggregateField .getAlias ());
228+ aggregations .add (aggregation .build ());
229+ }
230+ structuredAggregationQuery .addAllAggregations (aggregations );
224231 return request .build ();
225232 }
226233
@@ -243,7 +250,23 @@ public static AggregateQuery fromProto(Firestore firestore, RunAggregationQueryR
243250 .setStructuredQuery (proto .getStructuredAggregationQuery ().getStructuredQuery ())
244251 .build ();
245252 Query query = Query .fromProto (firestore , runQueryRequest );
246- return new AggregateQuery (query );
253+
254+ List <AggregateField > aggregateFields = new ArrayList <>();
255+ List <Aggregation > aggregations = proto .getStructuredAggregationQuery ().getAggregationsList ();
256+ aggregations .forEach (
257+ aggregation -> {
258+ if (aggregation .hasCount ()) {
259+ aggregateFields .add (AggregateField .count ());
260+ } else if (aggregation .hasAvg ()) {
261+ aggregateFields .add (
262+ AggregateField .average (aggregation .getAvg ().getField ().getFieldPath ()));
263+ } else if (aggregation .hasSum ()) {
264+ aggregateFields .add (AggregateField .sum (aggregation .getSum ().getField ().getFieldPath ()));
265+ } else {
266+ throw new RuntimeException ("Unsupported aggregation." );
267+ }
268+ });
269+ return new AggregateQuery (query , aggregateFields );
247270 }
248271
249272 /**
@@ -253,7 +276,7 @@ public static AggregateQuery fromProto(Firestore firestore, RunAggregationQueryR
253276 */
254277 @ Override
255278 public int hashCode () {
256- return query . hashCode ( );
279+ return Objects . hash ( query , aggregateFieldList );
257280 }
258281
259282 /**
@@ -280,6 +303,6 @@ public boolean equals(Object object) {
280303 return false ;
281304 }
282305 AggregateQuery other = (AggregateQuery ) object ;
283- return query .equals (other .query );
306+ return query .equals (other .query ) && aggregateFieldList . equals ( other . aggregateFieldList ) ;
284307 }
285308}
0 commit comments