beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From al...@apache.org
Subject [beam] branch master updated: [BEAM-7195] BQ BatchLoads doesn't always create new tables (#14238)
Date Tue, 27 Jul 2021 17:20:19 GMT
This is an automated email from the ASF dual-hosted git repository.

altay pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 69822e4  [BEAM-7195] BQ BatchLoads doesn't always create new tables (#14238)
69822e4 is described below

commit 69822e417fe4b582a73c45c5780f4ff69841d5db
Author: reuvenlax <relax@google.com>
AuthorDate: Tue Jul 27 10:19:30 2021 -0700

    [BEAM-7195] BQ BatchLoads doesn't always create new tables (#14238)
---
 .../beam/sdk/io/gcp/bigquery/BatchLoads.java       | 67 +++++++++--------
 .../io/gcp/bigquery/BigQueryResourceNaming.java    |  1 +
 .../beam/sdk/io/gcp/bigquery/WritePartition.java   | 58 ++++++++++++---
 .../beam/sdk/io/gcp/bigquery/WriteRename.java      | 34 ++++++---
 .../beam/sdk/io/gcp/bigquery/WriteTables.java      | 87 +++++++++++++++++-----
 .../sdk/io/gcp/bigquery/BigQueryIOWriteTest.java   | 60 ++++++++++-----
 6 files changed, 222 insertions(+), 85 deletions(-)

diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoads.java
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoads.java
index 15ea5c0..16b96bf 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoads.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoads.java
@@ -30,10 +30,8 @@ import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.IterableCoder;
 import org.apache.beam.sdk.coders.KvCoder;
-import org.apache.beam.sdk.coders.ListCoder;
 import org.apache.beam.sdk.coders.NullableCoder;
 import org.apache.beam.sdk.coders.ShardedKeyCoder;
-import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.coders.VoidCoder;
 import org.apache.beam.sdk.extensions.gcp.util.gcsfs.GcsPath;
 import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition;
@@ -147,8 +145,8 @@ class BatchLoads<DestinationT, ElementT>
   private ValueProvider<String> loadJobProjectId;
   private final Coder<ElementT> elementCoder;
   private final RowWriterFactory<ElementT, DestinationT> rowWriterFactory;
-  private String kmsKey;
-  private boolean clusteringEnabled;
+  private final String kmsKey;
+  private final boolean clusteringEnabled;
 
   // The maximum number of times to retry failed load or copy jobs.
   private int maxRetryJobs = DEFAULT_MAX_RETRY_JOBS;
@@ -274,6 +272,8 @@ class BatchLoads<DestinationT, ElementT>
   private WriteResult expandTriggered(PCollection<KV<DestinationT, ElementT>>
input) {
     Pipeline p = input.getPipeline();
     final PCollectionView<String> loadJobIdPrefixView = createJobIdPrefixView(p, JobType.LOAD);
+    final PCollectionView<String> tempLoadJobIdPrefixView =
+        createJobIdPrefixView(p, JobType.TEMP_TABLE_LOAD);
     final PCollectionView<String> copyJobIdPrefixView = createJobIdPrefixView(p, JobType.COPY);
     final PCollectionView<String> tempFilePrefixView =
         createTempFilePrefixView(p, loadJobIdPrefixView);
@@ -321,9 +321,9 @@ class BatchLoads<DestinationT, ElementT>
                             .plusDelayOf(triggeringFrequency)))
                 .discardingFiredPanes());
 
-    TupleTag<KV<ShardedKey<DestinationT>, List<String>>> multiPartitionsTag
=
+    TupleTag<KV<ShardedKey<DestinationT>, WritePartition.Result>> multiPartitionsTag
=
         new TupleTag<>("multiPartitionsTag");
-    TupleTag<KV<ShardedKey<DestinationT>, List<String>>> singlePartitionTag
=
+    TupleTag<KV<ShardedKey<DestinationT>, WritePartition.Result>> singlePartitionTag
=
         new TupleTag<>("singlePartitionTag");
 
     // If we have non-default triggered output, we can't use the side-input technique used in
@@ -331,10 +331,10 @@ class BatchLoads<DestinationT, ElementT>
     // determinism.
     PCollectionTuple partitions =
         results
-            .apply("AttachSingletonKey", WithKeys.of((Void) null))
+            .apply("AttachDestinationKey", WithKeys.of(result -> result.destination))
             .setCoder(
-                KvCoder.of(VoidCoder.of(), WriteBundlesToFiles.ResultCoder.of(destinationCoder)))
-            .apply("GroupOntoSingleton", GroupByKey.create())
+                KvCoder.of(destinationCoder, WriteBundlesToFiles.ResultCoder.of(destinationCoder)))
+            .apply("GroupFilesByDestination", GroupByKey.create())
             .apply("ExtractResultValues", Values.create())
             .apply(
                 "WritePartitionTriggered",
@@ -350,14 +350,14 @@ class BatchLoads<DestinationT, ElementT>
                             rowWriterFactory))
                     .withSideInputs(tempFilePrefixView)
                     .withOutputTags(multiPartitionsTag, TupleTagList.of(singlePartitionTag)));
-    PCollection<KV<TableDestination, String>> tempTables =
-        writeTempTables(partitions.get(multiPartitionsTag), loadJobIdPrefixView);
+    PCollection<KV<TableDestination, WriteTables.Result>> tempTables =
+        writeTempTables(partitions.get(multiPartitionsTag), tempLoadJobIdPrefixView);
 
     tempTables
         // Now that the load job has happened, we want the rename to happen immediately.
         .apply(
             "Window Into Global Windows",
-            Window.<KV<TableDestination, String>>into(new GlobalWindows())
+            Window.<KV<TableDestination, WriteTables.Result>>into(new GlobalWindows())
                 .triggering(Repeatedly.forever(AfterPane.elementCountAtLeast(1))))
         .apply("Add Void Key", WithKeys.of((Void) null))
         .setCoder(KvCoder.of(VoidCoder.of(), tempTables.getCoder()))
@@ -382,6 +382,9 @@ class BatchLoads<DestinationT, ElementT>
   public WriteResult expandUntriggered(PCollection<KV<DestinationT, ElementT>>
input) {
     Pipeline p = input.getPipeline();
     final PCollectionView<String> loadJobIdPrefixView = createJobIdPrefixView(p, JobType.LOAD);
+    final PCollectionView<String> tempLoadJobIdPrefixView =
+        createJobIdPrefixView(p, JobType.TEMP_TABLE_LOAD);
+    final PCollectionView<String> copyJobIdPrefixView = createJobIdPrefixView(p, JobType.COPY);
     final PCollectionView<String> tempFilePrefixView =
         createTempFilePrefixView(p, loadJobIdPrefixView);
     PCollection<KV<DestinationT, ElementT>> inputInGlobalWindow =
@@ -395,10 +398,10 @@ class BatchLoads<DestinationT, ElementT>
             ? writeDynamicallyShardedFilesUntriggered(inputInGlobalWindow, tempFilePrefixView)
             : writeStaticallyShardedFiles(inputInGlobalWindow, tempFilePrefixView);
 
-    TupleTag<KV<ShardedKey<DestinationT>, List<String>>> multiPartitionsTag
=
-        new TupleTag<KV<ShardedKey<DestinationT>, List<String>>>("multiPartitionsTag")
{};
-    TupleTag<KV<ShardedKey<DestinationT>, List<String>>> singlePartitionTag
=
-        new TupleTag<KV<ShardedKey<DestinationT>, List<String>>>("singlePartitionTag")
{};
+    TupleTag<KV<ShardedKey<DestinationT>, WritePartition.Result>> multiPartitionsTag
=
+        new TupleTag<KV<ShardedKey<DestinationT>, WritePartition.Result>>("multiPartitionsTag")
{};
+    TupleTag<KV<ShardedKey<DestinationT>, WritePartition.Result>> singlePartitionTag
=
+        new TupleTag<KV<ShardedKey<DestinationT>, WritePartition.Result>>("singlePartitionTag")
{};
 
     // This transform will look at the set of files written for each table, and if any table has
     // too many files or bytes, will partition that table's files into multiple partitions for
@@ -421,8 +424,8 @@ class BatchLoads<DestinationT, ElementT>
                             rowWriterFactory))
                     .withSideInputs(tempFilePrefixView)
                     .withOutputTags(multiPartitionsTag, TupleTagList.of(singlePartitionTag)));
-    PCollection<KV<TableDestination, String>> tempTables =
-        writeTempTables(partitions.get(multiPartitionsTag), loadJobIdPrefixView);
+    PCollection<KV<TableDestination, WriteTables.Result>> tempTables =
+        writeTempTables(partitions.get(multiPartitionsTag), tempLoadJobIdPrefixView);
 
     tempTables
         .apply("ReifyRenameInput", new ReifyAsIterable<>())
@@ -431,7 +434,7 @@ class BatchLoads<DestinationT, ElementT>
             ParDo.of(
                     new WriteRename(
                         bigQueryServices,
-                        loadJobIdPrefixView,
+                        copyJobIdPrefixView,
                         writeDisposition,
                         createDisposition,
                         maxRetryJobs,
@@ -637,23 +640,22 @@ class BatchLoads<DestinationT, ElementT>
         .apply(
             "WriteGroupedRecords",
             ParDo.of(
-                    new WriteGroupedRecordsToFiles<DestinationT, ElementT>(
-                        tempFilePrefix, maxFileSize, rowWriterFactory))
+                    new WriteGroupedRecordsToFiles<>(tempFilePrefix, maxFileSize, rowWriterFactory))
                 .withSideInputs(tempFilePrefix))
         .setCoder(WriteBundlesToFiles.ResultCoder.of(destinationCoder));
   }
 
   // Take in a list of files and write them to temporary tables.
-  private PCollection<KV<TableDestination, String>> writeTempTables(
-      PCollection<KV<ShardedKey<DestinationT>, List<String>>> input,
+  private PCollection<KV<TableDestination, WriteTables.Result>> writeTempTables(
+      PCollection<KV<ShardedKey<DestinationT>, WritePartition.Result>>
input,
       PCollectionView<String> jobIdTokenView) {
     List<PCollectionView<?>> sideInputs = Lists.newArrayList(jobIdTokenView);
     sideInputs.addAll(dynamicDestinations.getSideInputs());
 
-    Coder<KV<ShardedKey<DestinationT>, List<String>>> partitionsCoder
=
+    Coder<KV<ShardedKey<DestinationT>, WritePartition.Result>> partitionsCoder
=
         KvCoder.of(
             ShardedKeyCoder.of(NullableCoder.of(destinationCoder)),
-            ListCoder.of(StringUtf8Coder.of()));
+            WritePartition.ResultCoder.INSTANCE);
 
     // If the final destination table exists already (and we're appending to it), then the temp
     // tables must exactly match schema, partitioning, etc. Wrap the DynamicDestinations object
@@ -695,20 +697,24 @@ class BatchLoads<DestinationT, ElementT>
                 rowWriterFactory.getSourceFormat(),
                 useAvroLogicalTypes,
                 schemaUpdateOptions))
-        .setCoder(KvCoder.of(tableDestinationCoder, StringUtf8Coder.of()));
+        .setCoder(KvCoder.of(tableDestinationCoder, WriteTables.ResultCoder.INSTANCE));
   }
 
   // In the case where the files fit into a single load job, there's no need to write temporary
   // tables and rename. We can load these files directly into the target BigQuery table.
   void writeSinglePartition(
-      PCollection<KV<ShardedKey<DestinationT>, List<String>>> input,
+      PCollection<KV<ShardedKey<DestinationT>, WritePartition.Result>>
input,
       PCollectionView<String> loadJobIdPrefixView) {
     List<PCollectionView<?>> sideInputs = Lists.newArrayList(loadJobIdPrefixView);
     sideInputs.addAll(dynamicDestinations.getSideInputs());
-    Coder<KV<ShardedKey<DestinationT>, List<String>>> partitionsCoder
=
+
+    Coder<TableDestination> tableDestinationCoder =
+        clusteringEnabled ? TableDestinationCoderV3.of() : TableDestinationCoderV2.of();
+
+    Coder<KV<ShardedKey<DestinationT>, WritePartition.Result>> partitionsCoder
=
         KvCoder.of(
             ShardedKeyCoder.of(NullableCoder.of(destinationCoder)),
-            ListCoder.of(StringUtf8Coder.of()));
+            WritePartition.ResultCoder.INSTANCE);
     // Write single partition to final table
     input
         .setCoder(partitionsCoder)
@@ -731,7 +737,8 @@ class BatchLoads<DestinationT, ElementT>
                 kmsKey,
                 rowWriterFactory.getSourceFormat(),
                 useAvroLogicalTypes,
-                schemaUpdateOptions));
+                schemaUpdateOptions))
+        .setCoder(KvCoder.of(tableDestinationCoder, WriteTables.ResultCoder.INSTANCE));
   }
 
   private WriteResult writeResult(Pipeline p) {
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryResourceNaming.java
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryResourceNaming.java
index 7e800fd..7eae6fe 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryResourceNaming.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryResourceNaming.java
@@ -69,6 +69,7 @@ class BigQueryResourceNaming {
 
   public enum JobType {
     LOAD,
+    TEMP_TABLE_LOAD,
     COPY,
     EXPORT,
     QUERY,
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WritePartition.java
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WritePartition.java
index cd4f163..e1e0566 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WritePartition.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WritePartition.java
@@ -17,8 +17,17 @@
  */
 package org.apache.beam.sdk.io.gcp.bigquery;
 
+import com.google.auto.value.AutoValue;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
 import java.util.List;
 import java.util.Map;
+import org.apache.beam.sdk.coders.AtomicCoder;
+import org.apache.beam.sdk.coders.BooleanCoder;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.ListCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.io.gcp.bigquery.WriteBundlesToFiles.Result;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.values.KV;
@@ -39,7 +48,32 @@ import org.checkerframework.checker.nullness.qual.Nullable;
 class WritePartition<DestinationT>
     extends DoFn<
         Iterable<WriteBundlesToFiles.Result<DestinationT>>,
-        KV<ShardedKey<DestinationT>, List<String>>> {
+        KV<ShardedKey<DestinationT>, WritePartition.Result>> {
+  @AutoValue
+  abstract static class Result {
+    public abstract List<String> getFilenames();
+
+    abstract Boolean isFirstPane();
+  }
+
+  static class ResultCoder extends AtomicCoder<Result> {
+    private static final Coder<List<String>> FILENAMES_CODER = ListCoder.of(StringUtf8Coder.of());
+    private static final Coder<Boolean> FIRST_PANE_CODER = BooleanCoder.of();
+    static final ResultCoder INSTANCE = new ResultCoder();
+
+    @Override
+    public void encode(Result value, OutputStream outStream) throws IOException {
+      FILENAMES_CODER.encode(value.getFilenames(), outStream);
+      FIRST_PANE_CODER.encode(value.isFirstPane(), outStream);
+    }
+
+    @Override
+    public Result decode(InputStream inStream) throws IOException {
+      return new AutoValue_WritePartition_Result(
+          FILENAMES_CODER.decode(inStream), FIRST_PANE_CODER.decode(inStream));
+    }
+  }
+
   private final boolean singletonTable;
   private final DynamicDestinations<?, DestinationT> dynamicDestinations;
   private final PCollectionView<String> tempFilePrefix;
@@ -47,8 +81,9 @@ class WritePartition<DestinationT>
   private final long maxSizeBytes;
   private final RowWriterFactory<?, DestinationT> rowWriterFactory;
 
-  private @Nullable TupleTag<KV<ShardedKey<DestinationT>, List<String>>>
multiPartitionsTag;
-  private TupleTag<KV<ShardedKey<DestinationT>, List<String>>> singlePartitionTag;
+  private @Nullable TupleTag<KV<ShardedKey<DestinationT>, WritePartition.Result>>
+      multiPartitionsTag;
+  private TupleTag<KV<ShardedKey<DestinationT>, WritePartition.Result>>
singlePartitionTag;
 
   private static class PartitionData {
     private int numFiles = 0;
@@ -131,8 +166,8 @@ class WritePartition<DestinationT>
       PCollectionView<String> tempFilePrefix,
       int maxNumFiles,
       long maxSizeBytes,
-      TupleTag<KV<ShardedKey<DestinationT>, List<String>>> multiPartitionsTag,
-      TupleTag<KV<ShardedKey<DestinationT>, List<String>>> singlePartitionTag,
+      TupleTag<KV<ShardedKey<DestinationT>, WritePartition.Result>> multiPartitionsTag,
+      TupleTag<KV<ShardedKey<DestinationT>, WritePartition.Result>> singlePartitionTag,
       RowWriterFactory<?, DestinationT> rowWriterFactory) {
     this.singletonTable = singletonTable;
     this.dynamicDestinations = dynamicDestinations;
@@ -147,7 +182,6 @@ class WritePartition<DestinationT>
   @ProcessElement
   public void processElement(ProcessContext c) throws Exception {
     List<WriteBundlesToFiles.Result<DestinationT>> results = Lists.newArrayList(c.element());
-
     // If there are no elements to write _and_ the user specified a constant output table, then
     // generate an empty table of that name.
     if (results.isEmpty() && singletonTable) {
@@ -161,7 +195,8 @@ class WritePartition<DestinationT>
       BigQueryRowWriter.Result writerResult = writer.getResult();
 
       results.add(
-          new Result<>(writerResult.resourceId.toString(), writerResult.byteSize, destination));
+          new WriteBundlesToFiles.Result<>(
+              writerResult.resourceId.toString(), writerResult.byteSize, destination));
     }
 
     Map<DestinationT, DestinationData> currentResults = Maps.newHashMap();
@@ -190,11 +225,16 @@ class WritePartition<DestinationT>
       // In the fast-path case where we only output one table, the transform loads it directly
       // to the final table. In this case, we output on a special TupleTag so the enclosing
       // transform knows to skip the rename step.
-      TupleTag<KV<ShardedKey<DestinationT>, List<String>>> outputTag
=
+      TupleTag<KV<ShardedKey<DestinationT>, WritePartition.Result>> outputTag
=
           (destinationData.getPartitions().size() == 1) ? singlePartitionTag : multiPartitionsTag;
       for (int i = 0; i < destinationData.getPartitions().size(); ++i) {
         PartitionData partitionData = destinationData.getPartitions().get(i);
-        c.output(outputTag, KV.of(ShardedKey.of(destination, i + 1), partitionData.getFilenames()));
+        c.output(
+            outputTag,
+            KV.of(
+                ShardedKey.of(destination, i + 1),
+                new AutoValue_WritePartition_Result(
+                    partitionData.getFilenames(), c.pane().isFirst())));
       }
     }
   }
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteRename.java
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteRename.java
index a45f6f8..80201ff 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteRename.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteRename.java
@@ -39,6 +39,7 @@ import org.apache.beam.sdk.transforms.display.DisplayData;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ArrayListMultimap;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Multimap;
 import org.slf4j.Logger;
@@ -51,7 +52,7 @@ import org.slf4j.LoggerFactory;
 @SuppressWarnings({
   "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
 })
-class WriteRename extends DoFn<Iterable<KV<TableDestination, String>>, Void>
{
+class WriteRename extends DoFn<Iterable<KV<TableDestination, WriteTables.Result>>,
Void> {
   private static final Logger LOG = LoggerFactory.getLogger(WriteRename.class);
 
   private final BigQueryServices bqServices;
@@ -116,12 +117,15 @@ class WriteRename extends DoFn<Iterable<KV<TableDestination,
String>>, Void> {
   }
 
   @ProcessElement
-  public void processElement(ProcessContext c) throws Exception {
-    Multimap<TableDestination, String> tempTables = ArrayListMultimap.create();
-    for (KV<TableDestination, String> entry : c.element()) {
+  public void processElement(
+      @Element Iterable<KV<TableDestination, WriteTables.Result>> element, ProcessContext
c)
+      throws Exception {
+    Multimap<TableDestination, WriteTables.Result> tempTables = ArrayListMultimap.create();
+    for (KV<TableDestination, WriteTables.Result> entry : element) {
       tempTables.put(entry.getKey(), entry.getValue());
     }
-    for (Map.Entry<TableDestination, Collection<String>> entry : tempTables.asMap().entrySet())
{
+    for (Map.Entry<TableDestination, Collection<WriteTables.Result>> entry :
+        tempTables.asMap().entrySet()) {
       // Process each destination table.
       // Do not copy if no temp tables are provided.
       if (!entry.getValue().isEmpty()) {
@@ -165,17 +169,27 @@ class WriteRename extends DoFn<Iterable<KV<TableDestination,
String>>, Void> {
   }
 
   private PendingJobData startWriteRename(
-      TableDestination finalTableDestination, Iterable<String> tempTableNames, ProcessContext
c)
+      TableDestination finalTableDestination,
+      Iterable<WriteTables.Result> tempTableNames,
+      ProcessContext c)
       throws Exception {
+    // The pane may have advanced either here due to triggering or due to an upstream trigger. We
+    // check the upstream
+    // trigger to handle the case where an earlier pane triggered the single-partition path. If this
+    // happened, then the
+    // table will already exist so we want to append to the table.
+    boolean isFirstPane =
+        Iterables.getFirst(tempTableNames, null).isFirstPane() && c.pane().isFirst();
     WriteDisposition writeDisposition =
-        (c.pane().getIndex() == 0) ? firstPaneWriteDisposition : WriteDisposition.WRITE_APPEND;
+        isFirstPane ? firstPaneWriteDisposition : WriteDisposition.WRITE_APPEND;
     CreateDisposition createDisposition =
-        (c.pane().getIndex() == 0) ? firstPaneCreateDisposition : CreateDisposition.CREATE_NEVER;
+        isFirstPane ? firstPaneCreateDisposition : CreateDisposition.CREATE_NEVER;
     List<TableReference> tempTables =
         StreamSupport.stream(tempTableNames.spliterator(), false)
-            .map(table -> BigQueryHelpers.fromJsonString(table, TableReference.class))
+            .map(
+                result ->
+                    BigQueryHelpers.fromJsonString(result.getTableName(), TableReference.class))
             .collect(Collectors.toList());
-    ;
 
     // Make sure each destination table gets a unique job id.
     String jobIdPrefix =
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java
index 465c924..32ed1fe 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java
@@ -26,11 +26,17 @@ import com.google.api.services.bigquery.model.JobReference;
 import com.google.api.services.bigquery.model.TableReference;
 import com.google.api.services.bigquery.model.TableSchema;
 import com.google.api.services.bigquery.model.TimePartitioning;
+import com.google.auto.value.AutoValue;
 import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.stream.Collectors;
+import org.apache.beam.sdk.coders.AtomicCoder;
+import org.apache.beam.sdk.coders.BooleanCoder;
+import org.apache.beam.sdk.coders.CoderException;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.coders.VoidCoder;
@@ -68,7 +74,10 @@ import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Strings;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
+import org.checkerframework.checker.initialization.qual.Initialized;
+import org.checkerframework.checker.nullness.qual.NonNull;
 import org.checkerframework.checker.nullness.qual.Nullable;
+import org.checkerframework.checker.nullness.qual.UnknownKeyFor;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -89,8 +98,35 @@ import org.slf4j.LoggerFactory;
 })
 class WriteTables<DestinationT>
     extends PTransform<
-        PCollection<KV<ShardedKey<DestinationT>, List<String>>>,
-        PCollection<KV<TableDestination, String>>> {
+        PCollection<KV<ShardedKey<DestinationT>, WritePartition.Result>>,
+        PCollection<KV<TableDestination, WriteTables.Result>>> {
+  @AutoValue
+  abstract static class Result {
+    abstract String getTableName();
+
+    abstract Boolean isFirstPane();
+  }
+
+  static class ResultCoder extends AtomicCoder<WriteTables.Result> {
+    static final ResultCoder INSTANCE = new ResultCoder();
+
+    @Override
+    public void encode(Result value, @UnknownKeyFor @NonNull @Initialized OutputStream outStream)
+        throws @UnknownKeyFor @NonNull @Initialized CoderException, @UnknownKeyFor @NonNull
+            @Initialized IOException {
+      StringUtf8Coder.of().encode(value.getTableName(), outStream);
+      BooleanCoder.of().encode(value.isFirstPane(), outStream);
+    }
+
+    @Override
+    public Result decode(@UnknownKeyFor @NonNull @Initialized InputStream inStream)
+        throws @UnknownKeyFor @NonNull @Initialized CoderException, @UnknownKeyFor @NonNull
+            @Initialized IOException {
+      return new AutoValue_WriteTables_Result(
+          StringUtf8Coder.of().decode(inStream), BooleanCoder.of().decode(inStream));
+    }
+  }
+
   private static final Logger LOG = LoggerFactory.getLogger(WriteTables.class);
 
   private final boolean tempTable;
@@ -101,7 +137,7 @@ class WriteTables<DestinationT>
   private final Set<SchemaUpdateOption> schemaUpdateOptions;
   private final DynamicDestinations<?, DestinationT> dynamicDestinations;
   private final List<PCollectionView<?>> sideInputs;
-  private final TupleTag<KV<TableDestination, String>> mainOutputTag;
+  private final TupleTag<KV<TableDestination, WriteTables.Result>> mainOutputTag;
   private final TupleTag<String> temporaryFilesTag;
   private final ValueProvider<String> loadJobProjectId;
   private final int maxRetryJobs;
@@ -113,7 +149,9 @@ class WriteTables<DestinationT>
   private @Nullable JobService jobService;
 
   private class WriteTablesDoFn
-      extends DoFn<KV<ShardedKey<DestinationT>, List<String>>, KV<TableDestination,
String>> {
+      extends DoFn<
+          KV<ShardedKey<DestinationT>, WritePartition.Result>, KV<TableDestination,
Result>> {
+
     private Map<DestinationT, String> jsonSchemas = Maps.newHashMap();
 
     // Represents a pending BigQuery load job.
@@ -123,18 +161,21 @@ class WriteTables<DestinationT>
       final List<String> partitionFiles;
       final TableDestination tableDestination;
       final TableReference tableReference;
+      final boolean isFirstPane;
 
       public PendingJobData(
           BoundedWindow window,
           BigQueryHelpers.PendingJob retryJob,
           List<String> partitionFiles,
           TableDestination tableDestination,
-          TableReference tableReference) {
+          TableReference tableReference,
+          boolean isFirstPane) {
         this.window = window;
         this.retryJob = retryJob;
         this.partitionFiles = partitionFiles;
         this.tableDestination = tableDestination;
         this.tableReference = tableReference;
+        this.isFirstPane = isFirstPane;
       }
     }
     // All pending load jobs.
@@ -149,7 +190,11 @@ class WriteTables<DestinationT>
     }
 
     @ProcessElement
-    public void processElement(ProcessContext c, BoundedWindow window) throws Exception {
+    public void processElement(
+        @Element KV<ShardedKey<DestinationT>, WritePartition.Result> element,
+        ProcessContext c,
+        BoundedWindow window)
+        throws Exception {
       dynamicDestinations.setSideInputAccessorFromProcessContext(c);
       DestinationT destination = c.element().getKey().getKey();
       TableSchema tableSchema;
@@ -199,8 +244,8 @@ class WriteTables<DestinationT>
         tableDestination = tableDestination.withTableReference(tableReference);
       }
 
-      Integer partition = c.element().getKey().getShardNumber();
-      List<String> partitionFiles = Lists.newArrayList(c.element().getValue());
+      Integer partition = element.getKey().getShardNumber();
+      List<String> partitionFiles = Lists.newArrayList(element.getValue().getFilenames());
       String jobIdPrefix =
           BigQueryResourceNaming.createJobIdWithDestination(
               c.sideInput(loadJobIdPrefixView), tableDestination, partition, c.pane().getIndex());
@@ -212,7 +257,7 @@ class WriteTables<DestinationT>
 
       WriteDisposition writeDisposition = firstPaneWriteDisposition;
       CreateDisposition createDisposition = firstPaneCreateDisposition;
-      if (c.pane().getIndex() > 0 && !tempTable) {
+      if (!element.getValue().isFirstPane() && !tempTable) {
         // If writing directly to the destination, then the table is created on the first write
         // and we should change the disposition for subsequent writes.
         writeDisposition = WriteDisposition.WRITE_APPEND;
@@ -238,7 +283,13 @@ class WriteTables<DestinationT>
               createDisposition,
               schemaUpdateOptions);
       pendingJobs.add(
-          new PendingJobData(window, retryJob, partitionFiles, tableDestination, tableReference));
+          new PendingJobData(
+              window,
+              retryJob,
+              partitionFiles,
+              tableDestination,
+              tableReference,
+              element.getValue().isFirstPane()));
     }
 
     @Teardown
@@ -284,7 +335,7 @@ class WriteTables<DestinationT>
           bqServices.getDatasetService(c.getPipelineOptions().as(BigQueryOptions.class));
 
       PendingJobManager jobManager = new PendingJobManager();
-      for (PendingJobData pendingJob : pendingJobs) {
+      for (final PendingJobData pendingJob : pendingJobs) {
         jobManager =
             jobManager.addPendingJob(
                 pendingJob.retryJob,
@@ -299,11 +350,14 @@ class WriteTables<DestinationT>
                                   BigQueryHelpers.stripPartitionDecorator(ref.getTableId())),
                           pendingJob.tableDestination.getTableDescription());
                     }
+
+                    Result result =
+                        new AutoValue_WriteTables_Result(
+                            BigQueryHelpers.toJsonString(pendingJob.tableReference),
+                            pendingJob.isFirstPane);
                     c.output(
                         mainOutputTag,
-                        KV.of(
-                            pendingJob.tableDestination,
-                            BigQueryHelpers.toJsonString(pendingJob.tableReference)),
+                        KV.of(pendingJob.tableDestination, result),
                         pendingJob.window.maxTimestamp(),
                         pendingJob.window);
                     for (String file : pendingJob.partitionFiles) {
@@ -365,8 +419,8 @@ class WriteTables<DestinationT>
   }
 
   @Override
-  public PCollection<KV<TableDestination, String>> expand(
-      PCollection<KV<ShardedKey<DestinationT>, List<String>>> input) {
+  public PCollection<KV<TableDestination, Result>> expand(
+      PCollection<KV<ShardedKey<DestinationT>, WritePartition.Result>> input) {
     PCollectionTuple writeTablesOutputs =
         input.apply(
             ParDo.of(new WriteTablesDoFn())
@@ -391,7 +445,6 @@ class WriteTables<DestinationT>
         .apply(GroupByKey.create())
         .apply(Values.create())
         .apply(ParDo.of(new GarbageCollectTemporaryFiles()));
-
     return writeTablesOutputs.get(mainOutputTag);
   }
 
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java
index bde803d..6799c67 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java
@@ -65,6 +65,7 @@ import java.util.concurrent.ThreadLocalRandom;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
 import java.util.stream.Collectors;
+import java.util.stream.StreamSupport;
 import org.apache.avro.generic.GenericData;
 import org.apache.avro.generic.GenericDatumWriter;
 import org.apache.avro.generic.GenericRecord;
@@ -72,13 +73,17 @@ import org.apache.avro.io.DatumWriter;
 import org.apache.avro.io.Encoder;
 import org.apache.beam.sdk.coders.AtomicCoder;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.SerializableCoder;
+import org.apache.beam.sdk.coders.ShardedKeyCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.io.GenerateSequence;
 import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write;
 import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition;
 import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.Method;
 import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.SchemaUpdateOption;
+import org.apache.beam.sdk.io.gcp.bigquery.WritePartition.ResultCoder;
+import org.apache.beam.sdk.io.gcp.bigquery.WriteTables.Result;
 import org.apache.beam.sdk.io.gcp.testing.FakeBigQueryServices;
 import org.apache.beam.sdk.io.gcp.testing.FakeDatasetService;
 import org.apache.beam.sdk.io.gcp.testing.FakeJobService;
@@ -1758,10 +1763,12 @@ public class BigQueryIOWriteTest implements Serializable {
       }
     }
 
-    TupleTag<KV<ShardedKey<TableDestination>, List<String>>> multiPartitionsTag =
-        new TupleTag<KV<ShardedKey<TableDestination>, List<String>>>("multiPartitionsTag") {};
-    TupleTag<KV<ShardedKey<TableDestination>, List<String>>> singlePartitionTag =
-        new TupleTag<KV<ShardedKey<TableDestination>, List<String>>>("singlePartitionTag") {};
+    TupleTag<KV<ShardedKey<TableDestination>, WritePartition.Result>> multiPartitionsTag =
+        new TupleTag<KV<ShardedKey<TableDestination>, WritePartition.Result>>(
+            "multiPartitionsTag") {};
+    TupleTag<KV<ShardedKey<TableDestination>, WritePartition.Result>> singlePartitionTag =
+        new TupleTag<KV<ShardedKey<TableDestination>, WritePartition.Result>>(
+            "singlePartitionTag") {};
 
     String tempFilePrefix = testFolder.newFolder("BigQueryIOTest").getAbsolutePath();
     PCollectionView<String> tempFilePrefixView =
@@ -1781,12 +1788,12 @@ public class BigQueryIOWriteTest implements Serializable {
 
     DoFnTester<
             Iterable<WriteBundlesToFiles.Result<TableDestination>>,
-            KV<ShardedKey<TableDestination>, List<String>>>
+            KV<ShardedKey<TableDestination>, WritePartition.Result>>
         tester = DoFnTester.of(writePartition);
     tester.setSideInput(tempFilePrefixView, GlobalWindow.INSTANCE, tempFilePrefix);
     tester.processElement(files);
 
-    List<KV<ShardedKey<TableDestination>, List<String>>> partitions;
+    List<KV<ShardedKey<TableDestination>, WritePartition.Result>> partitions;
     if (expectedNumPartitionsPerTable > 1) {
       partitions = tester.takeOutputElements(multiPartitionsTag);
     } else {
@@ -1795,12 +1802,12 @@ public class BigQueryIOWriteTest implements Serializable {
 
     List<ShardedKey<TableDestination>> partitionsResult = Lists.newArrayList();
     Map<String, List<String>> filesPerTableResult = Maps.newHashMap();
-    for (KV<ShardedKey<TableDestination>, List<String>> partition : partitions) {
+    for (KV<ShardedKey<TableDestination>, WritePartition.Result> partition : partitions) {
       String table = partition.getKey().getKey().getTableSpec();
       partitionsResult.add(partition.getKey());
       List<String> tableFilesResult =
           filesPerTableResult.computeIfAbsent(table, k -> Lists.newArrayList());
-      tableFilesResult.addAll(partition.getValue());
+      tableFilesResult.addAll(partition.getValue().getFilenames());
     }
 
     assertThat(
@@ -1847,7 +1854,7 @@ public class BigQueryIOWriteTest implements Serializable {
     String jobIdToken = "jobId";
     final Multimap<TableDestination, String> expectedTempTables = ArrayListMultimap.create();
 
-    List<KV<ShardedKey<String>, List<String>>> partitions = Lists.newArrayList();
+    List<KV<ShardedKey<String>, WritePartition.Result>> partitions = Lists.newArrayList();
     for (int i = 0; i < numTables; ++i) {
       String tableName = String.format("project-id:dataset-id.table%05d", i);
       TableDestination tableDestination = new TableDestination(tableName, tableName);
@@ -1869,7 +1876,10 @@ public class BigQueryIOWriteTest implements Serializable {
           }
           filesPerPartition.add(writer.getResult().resourceId.toString());
         }
-        partitions.add(KV.of(ShardedKey.of(tableDestination.getTableSpec(), j), filesPerPartition));
+        partitions.add(
+            KV.of(
+                ShardedKey.of(tableDestination.getTableSpec(), j),
+                new AutoValue_WritePartition_Result(filesPerPartition, true)));
 
         String json =
             String.format(
@@ -1879,8 +1889,11 @@ public class BigQueryIOWriteTest implements Serializable {
       }
     }
 
-    PCollection<KV<ShardedKey<String>, List<String>>> writeTablesInput =
-        p.apply(Create.of(partitions));
+    PCollection<KV<ShardedKey<String>, WritePartition.Result>> writeTablesInput =
+        p.apply(
+            Create.of(partitions)
+                .withCoder(
+                    KvCoder.of(ShardedKeyCoder.of(StringUtf8Coder.of()), ResultCoder.INSTANCE)));
     PCollectionView<String> jobIdTokenView =
         p.apply("CreateJobId", Create.of("jobId")).apply(View.asSingleton());
     List<PCollectionView<?>> sideInputs = ImmutableList.of(jobIdTokenView);
@@ -1903,18 +1916,25 @@ public class BigQueryIOWriteTest implements Serializable {
             false,
             Collections.emptySet());
 
-    PCollection<KV<TableDestination, String>> writeTablesOutput =
-        writeTablesInput.apply(writeTables);
+    PCollection<KV<TableDestination, WriteTables.Result>> writeTablesOutput =
+        writeTablesInput
+            .apply(writeTables)
+            .setCoder(KvCoder.of(TableDestinationCoderV3.of(), WriteTables.ResultCoder.INSTANCE));
 
     PAssert.thatMultimap(writeTablesOutput)
         .satisfies(
             input -> {
               assertEquals(input.keySet(), expectedTempTables.keySet());
-              for (Map.Entry<TableDestination, Iterable<String>> entry : input.entrySet()) {
+              for (Map.Entry<TableDestination, Iterable<WriteTables.Result>> entry :
+                  input.entrySet()) {
+                Iterable<String> tableNames =
+                    StreamSupport.stream(entry.getValue().spliterator(), false)
+                        .map(Result::getTableName)
+                        .collect(Collectors.toList());
                 @SuppressWarnings("unchecked")
                 String[] expectedValues =
                     Iterables.toArray(expectedTempTables.get(entry.getKey()), String.class);
-                assertThat(entry.getValue(), containsInAnyOrder(expectedValues));
+                assertThat(tableNames, containsInAnyOrder(expectedValues));
               }
               return null;
             });
@@ -1951,7 +1971,7 @@ public class BigQueryIOWriteTest implements Serializable {
     Multimap<TableDestination, TableRow> expectedRowsPerTable = ArrayListMultimap.create();
     String jobIdToken = "jobIdToken";
     Multimap<TableDestination, String> tempTables = ArrayListMultimap.create();
-    List<KV<TableDestination, String>> tempTablesElement = Lists.newArrayList();
+    List<KV<TableDestination, WriteTables.Result>> tempTablesElement = Lists.newArrayList();
     for (int i = 0; i < numFinalTables; ++i) {
       String tableName = "project-id:dataset-id.table_" + i;
       TableDestination tableDestination = new TableDestination(tableName, "table_" + i + "_desc");
@@ -1971,7 +1991,8 @@ public class BigQueryIOWriteTest implements Serializable {
         expectedRowsPerTable.putAll(tableDestination, rows);
         String tableJson = toJsonString(tempTable);
         tempTables.put(tableDestination, tableJson);
-        tempTablesElement.add(KV.of(tableDestination, tableJson));
+        tempTablesElement.add(
+            KV.of(tableDestination, new AutoValue_WriteTables_Result(tableJson, true)));
       }
     }
 
@@ -1987,7 +2008,8 @@ public class BigQueryIOWriteTest implements Serializable {
             3,
             "kms_key");
 
-    DoFnTester<Iterable<KV<TableDestination, String>>, Void> tester = DoFnTester.of(writeRename);
+    DoFnTester<Iterable<KV<TableDestination, WriteTables.Result>>, Void> tester =
+        DoFnTester.of(writeRename);
     tester.setSideInput(jobIdTokenView, GlobalWindow.INSTANCE, jobIdToken);
     tester.processElement(tempTablesElement);
     tester.finishBundle();

Mime
View raw message