From: sseth@apache.org
To: commits@tez.incubator.apache.org
Date: Wed, 25 Sep 2013 07:31:40 -0000
Message-Id: <3b6bdae789d74d03b1b367c080c26121@git.apache.org>
In-Reply-To: <951a7e7fa257470e83418fce839114b5@git.apache.org>
Subject: [33/50] [abbrv] Rename tez-engine-api to tez-runtime-api and tez-engine is split into 2: - tez-engine-library for user-visible Input/Output/Processor implementations - tez-engine-internals for framework internals

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/b212ca1d/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/Fetcher.java ---------------------------------------------------------------------- diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/Fetcher.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/Fetcher.java new file mode 100644 index 0000000..f5d1802 --- /dev/null +++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/Fetcher.java @@ -0,0 +1,624 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.tez.runtime.library.common.shuffle.impl; + +import java.io.DataInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.HttpURLConnection; +import java.net.MalformedURLException; +import java.net.URL; +import java.net.URLConnection; +import java.security.GeneralSecurityException; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import javax.crypto.SecretKey; +import javax.net.ssl.HttpsURLConnection; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.io.IOUtils; +import org.apache.hadoop.io.compress.CodecPool; +import org.apache.hadoop.io.compress.CompressionCodec; +import org.apache.hadoop.io.compress.Decompressor; +import org.apache.hadoop.io.compress.DefaultCodec; +import org.apache.hadoop.security.ssl.SSLFactory; +import org.apache.hadoop.util.ReflectionUtils; +import org.apache.tez.common.TezJobConfig; +import org.apache.tez.common.counters.TezCounter; +import org.apache.tez.runtime.api.TezInputContext; +import org.apache.tez.runtime.library.common.ConfigUtils; +import org.apache.tez.runtime.library.common.InputAttemptIdentifier; +import org.apache.tez.runtime.library.common.security.SecureShuffleUtils; +import org.apache.tez.runtime.library.common.shuffle.impl.MapOutput.Type; +import org.apache.tez.runtime.library.common.sort.impl.IFileInputStream; + +import com.google.common.annotations.VisibleForTesting; + +class Fetcher extends Thread { + + private static final Log LOG = LogFactory.getLog(Fetcher.class); + + /** Basic/unit connection timeout (in milliseconds) */ + private final static int UNIT_CONNECT_TIMEOUT = 60 * 1000; + + private static enum ShuffleErrors{IO_ERROR, WRONG_LENGTH, BAD_ID, WRONG_MAP, + CONNECTION, WRONG_REDUCE} + + private final static String SHUFFLE_ERR_GRP_NAME = "Shuffle Errors"; + private final TezCounter connectionErrs; + private final TezCounter ioErrs; + private final TezCounter wrongLengthErrs; + private final TezCounter badIdErrs; + private final TezCounter wrongMapErrs; + private final TezCounter wrongReduceErrs; + private final MergeManager merger; + private final ShuffleScheduler scheduler; + private final ShuffleClientMetrics metrics; + private final Shuffle shuffle; + private final int id; + private static int nextId = 0; + + private final int connectionTimeout; + private final int readTimeout; + + // Decompression of map-outputs + private final CompressionCodec codec; + private final Decompressor decompressor; + private final SecretKey jobTokenSecret; + + private volatile boolean stopped = false; + + private Configuration job; + + private static boolean sslShuffle; + private static SSLFactory sslFactory; + + public Fetcher(Configuration job, + ShuffleScheduler scheduler, MergeManager merger, + ShuffleClientMetrics metrics, + Shuffle shuffle, SecretKey jobTokenSecret, TezInputContext inputContext) throws IOException { + this.job = job; + this.scheduler = scheduler; + this.merger = merger; + 
this.metrics = metrics; + this.shuffle = shuffle; + this.id = ++nextId; + this.jobTokenSecret = jobTokenSecret; + ioErrs = inputContext.getCounters().findCounter(SHUFFLE_ERR_GRP_NAME, + ShuffleErrors.IO_ERROR.toString()); + wrongLengthErrs = inputContext.getCounters().findCounter(SHUFFLE_ERR_GRP_NAME, + ShuffleErrors.WRONG_LENGTH.toString()); + badIdErrs = inputContext.getCounters().findCounter(SHUFFLE_ERR_GRP_NAME, + ShuffleErrors.BAD_ID.toString()); + wrongMapErrs = inputContext.getCounters().findCounter(SHUFFLE_ERR_GRP_NAME, + ShuffleErrors.WRONG_MAP.toString()); + connectionErrs = inputContext.getCounters().findCounter(SHUFFLE_ERR_GRP_NAME, + ShuffleErrors.CONNECTION.toString()); + wrongReduceErrs = inputContext.getCounters().findCounter(SHUFFLE_ERR_GRP_NAME, + ShuffleErrors.WRONG_REDUCE.toString()); + + if (ConfigUtils.isIntermediateInputCompressed(job)) { + Class codecClass = + ConfigUtils.getIntermediateInputCompressorClass(job, DefaultCodec.class); + codec = ReflectionUtils.newInstance(codecClass, job); + decompressor = CodecPool.getDecompressor(codec); + } else { + codec = null; + decompressor = null; + } + + this.connectionTimeout = + job.getInt(TezJobConfig.TEZ_RUNTIME_SHUFFLE_CONNECT_TIMEOUT, + TezJobConfig.DEFAULT_TEZ_RUNTIME_SHUFFLE_STALLED_COPY_TIMEOUT); + this.readTimeout = + job.getInt(TezJobConfig.TEZ_RUNTIME_SHUFFLE_READ_TIMEOUT, + TezJobConfig.DEFAULT_TEZ_RUNTIME_SHUFFLE_READ_TIMEOUT); + + setName("fetcher#" + id); + setDaemon(true); + + synchronized (Fetcher.class) { + sslShuffle = job.getBoolean(TezJobConfig.TEZ_RUNTIME_SHUFFLE_ENABLE_SSL, + TezJobConfig.DEFAULT_TEZ_RUNTIME_SHUFFLE_ENABLE_SSL); + if (sslShuffle && sslFactory == null) { + sslFactory = new SSLFactory(SSLFactory.Mode.CLIENT, job); + try { + sslFactory.init(); + } catch (Exception ex) { + sslFactory.destroy(); + throw new RuntimeException(ex); + } + } + } + } + + public void run() { + try { + while (!stopped && !Thread.currentThread().isInterrupted()) { + MapHost host = null; + try { + // If merge is on, block + merger.waitForInMemoryMerge(); + + // Get a host to shuffle from + host = scheduler.getHost(); + metrics.threadBusy(); + + // Shuffle + copyFromHost(host); + } finally { + if (host != null) { + scheduler.freeHost(host); + metrics.threadFree(); + } + } + } + } catch (InterruptedException ie) { + return; + } catch (Throwable t) { + shuffle.reportException(t); + } + } + + public void shutDown() throws InterruptedException { + this.stopped = true; + interrupt(); + try { + join(5000); + } catch (InterruptedException ie) { + LOG.warn("Got interrupt while joining " + getName(), ie); + } + if (sslFactory != null) { + sslFactory.destroy(); + } + } + + @VisibleForTesting + protected HttpURLConnection openConnection(URL url) throws IOException { + HttpURLConnection conn = (HttpURLConnection) url.openConnection(); + if (sslShuffle) { + HttpsURLConnection httpsConn = (HttpsURLConnection) conn; + try { + httpsConn.setSSLSocketFactory(sslFactory.createSSLSocketFactory()); + } catch (GeneralSecurityException ex) { + throw new IOException(ex); + } + httpsConn.setHostnameVerifier(sslFactory.getHostnameVerifier()); + } + return conn; + } + + /** + * The crux of the matter... + * + * @param host {@link MapHost} from which we need to + * shuffle available map-outputs. 
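+ * A single HTTP connection streams every pending output for the host in
+ * sequence; a failed connect penalizes all of them, while a mid-stream
+ * error penalizes only the first attempt and re-queues the rest.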
+ */ + @VisibleForTesting + protected void copyFromHost(MapHost host) throws IOException { + // Get completed maps on 'host' + List srcAttempts = scheduler.getMapsForHost(host); + + // Sanity check to catch hosts with only 'OBSOLETE' maps, + // especially at the tail of large jobs + if (srcAttempts.size() == 0) { + return; + } + + if(LOG.isDebugEnabled()) { + LOG.debug("Fetcher " + id + " going to fetch from " + host + " for: " + + srcAttempts); + } + + // List of maps to be fetched yet + Set remaining = new HashSet(srcAttempts); + + // Construct the url and connect + DataInputStream input; + boolean connectSucceeded = false; + + try { + URL url = getMapOutputURL(host, srcAttempts); + HttpURLConnection connection = openConnection(url); + + // generate hash of the url + String msgToEncode = SecureShuffleUtils.buildMsgFrom(url); + String encHash = SecureShuffleUtils.hashFromString(msgToEncode, jobTokenSecret); + + // put url hash into http header + connection.addRequestProperty( + SecureShuffleUtils.HTTP_HEADER_URL_HASH, encHash); + // set the read timeout + connection.setReadTimeout(readTimeout); + // put shuffle version into http header + connection.addRequestProperty(ShuffleHeader.HTTP_HEADER_NAME, + ShuffleHeader.DEFAULT_HTTP_HEADER_NAME); + connection.addRequestProperty(ShuffleHeader.HTTP_HEADER_VERSION, + ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION); + connect(connection, connectionTimeout); + connectSucceeded = true; + input = new DataInputStream(connection.getInputStream()); + + // Validate response code + int rc = connection.getResponseCode(); + if (rc != HttpURLConnection.HTTP_OK) { + throw new IOException( + "Got invalid response code " + rc + " from " + url + + ": " + connection.getResponseMessage()); + } + // get the shuffle version + if (!ShuffleHeader.DEFAULT_HTTP_HEADER_NAME.equals( + connection.getHeaderField(ShuffleHeader.HTTP_HEADER_NAME)) + || !ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION.equals( + connection.getHeaderField(ShuffleHeader.HTTP_HEADER_VERSION))) { + throw new IOException("Incompatible shuffle response version"); + } + // get the replyHash which is HMac of the encHash we sent to the server + String replyHash = connection.getHeaderField(SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH); + if(replyHash==null) { + throw new IOException("security validation of TT Map output failed"); + } + LOG.debug("url="+msgToEncode+";encHash="+encHash+";replyHash="+replyHash); + // verify that replyHash is HMac of encHash + SecureShuffleUtils.verifyReply(replyHash, encHash, jobTokenSecret); + LOG.info("for url="+msgToEncode+" sent hash and received reply"); + } catch (IOException ie) { + ioErrs.increment(1); + LOG.warn("Failed to connect to " + host + " with " + remaining.size() + + " map outputs", ie); + + // If connect did not succeed, just mark all the maps as failed, + // indirectly penalizing the host + if (!connectSucceeded) { + for(InputAttemptIdentifier left: remaining) { + scheduler.copyFailed(left, host, connectSucceeded); + } + } else { + // If we got a read error at this stage, it implies there was a problem + // with the first map, typically a lost map. So, penalize only that map + // and add the rest + InputAttemptIdentifier firstMap = srcAttempts.get(0); + scheduler.copyFailed(firstMap, host, connectSucceeded); + } + + // Add back all the remaining maps, WITHOUT marking them as failed + for(InputAttemptIdentifier left: remaining) { + // TODO Should the first one be skipped ?
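+ // (Re-queuing the already-penalized first attempt is not fatal: copyFailed
+ // above only records the failure, and putting the attempt back simply
+ // schedules it for a later retry instead of dropping it.)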
+ scheduler.putBackKnownMapOutput(host, left); + } + + return; + } + + try { + // Loop through available map-outputs and fetch them + // On any error, failedTasks is not null and we exit + // after putting back the remaining maps to the + // yet_to_be_fetched list and marking the failed tasks. + InputAttemptIdentifier[] failedTasks = null; + while (!remaining.isEmpty() && failedTasks == null) { + failedTasks = copyMapOutput(host, input, remaining); + } + + if(failedTasks != null && failedTasks.length > 0) { + LOG.warn("copyMapOutput failed for tasks "+Arrays.toString(failedTasks)); + for(InputAttemptIdentifier left: failedTasks) { + scheduler.copyFailed(left, host, true); + } + } + + IOUtils.cleanup(LOG, input); + + // Sanity check + if (failedTasks == null && !remaining.isEmpty()) { + throw new IOException("server didn't return all expected map outputs: " + + remaining.size() + " left."); + } + } finally { + for (InputAttemptIdentifier left : remaining) { + scheduler.putBackKnownMapOutput(host, left); + } + } + } + + private static InputAttemptIdentifier[] EMPTY_ATTEMPT_ID_ARRAY = new InputAttemptIdentifier[0]; + + private InputAttemptIdentifier[] copyMapOutput(MapHost host, + DataInputStream input, + Set remaining) { + MapOutput mapOutput = null; + InputAttemptIdentifier srcAttemptId = null; + long decompressedLength = -1; + long compressedLength = -1; + + try { + long startTime = System.currentTimeMillis(); + int forReduce = -1; + //Read the shuffle header + try { + ShuffleHeader header = new ShuffleHeader(); + header.readFields(input); + String pathComponent = header.mapId; + srcAttemptId = scheduler.getIdentifierForPathComponent(pathComponent); + compressedLength = header.compressedLength; + decompressedLength = header.uncompressedLength; + forReduce = header.forReduce; + } catch (IllegalArgumentException e) { + badIdErrs.increment(1); + LOG.warn("Invalid map id ", e); + //Don't know which one was bad, so consider all of them as bad + return remaining.toArray(new InputAttemptIdentifier[remaining.size()]); + } + + + // Do some basic sanity verification + if (!verifySanity(compressedLength, decompressedLength, forReduce, + remaining, srcAttemptId)) { + return new InputAttemptIdentifier[] {srcAttemptId}; + } + + if(LOG.isDebugEnabled()) { + LOG.debug("header: " + srcAttemptId + ", len: " + compressedLength + + ", decomp len: " + decompressedLength); + } + + // Get the location for the map output - either in-memory or on-disk + mapOutput = merger.reserve(srcAttemptId, decompressedLength, id); + + // Check if we can shuffle *now* ... + if (mapOutput.getType() == Type.WAIT) { + LOG.info("fetcher#" + id + " - MergeManager returned Status.WAIT ..."); + //Not an error but wait to process data. + return EMPTY_ATTEMPT_ID_ARRAY; + } + + // Go!
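+ // The target was chosen by MergeManager.reserve() above: outputs smaller
+ // than maxSingleShuffleLimit were given an in-memory buffer, anything
+ // larger was handed a temporary on-disk file.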
+ LOG.info("fetcher#" + id + " about to shuffle output of map " + + mapOutput.getAttemptIdentifier() + " decomp: " + + decompressedLength + " len: " + compressedLength + " to " + + mapOutput.getType()); + if (mapOutput.getType() == Type.MEMORY) { + shuffleToMemory(host, mapOutput, input, + (int) decompressedLength, (int) compressedLength); + } else { + shuffleToDisk(host, mapOutput, input, compressedLength); + } + + // Inform the shuffle scheduler + long endTime = System.currentTimeMillis(); + scheduler.copySucceeded(srcAttemptId, host, compressedLength, + endTime - startTime, mapOutput); + // Note successful shuffle + remaining.remove(srcAttemptId); + metrics.successFetch(); + return null; + } catch (IOException ioe) { + ioErrs.increment(1); + if (srcAttemptId == null || mapOutput == null) { + LOG.info("fetcher#" + id + " failed to read map header" + + srcAttemptId + " decomp: " + + decompressedLength + ", " + compressedLength, ioe); + if(srcAttemptId == null) { + return remaining.toArray(new InputAttemptIdentifier[remaining.size()]); + } else { + return new InputAttemptIdentifier[] {srcAttemptId}; + } + } + + LOG.warn("Failed to shuffle output of " + srcAttemptId + + " from " + host.getHostName(), ioe); + + // Inform the shuffle-scheduler + mapOutput.abort(); + metrics.failedFetch(); + return new InputAttemptIdentifier[] {srcAttemptId}; + } + + } + + /** + * Do some basic verification on the input received -- Being defensive + * @param compressedLength + * @param decompressedLength + * @param forReduce + * @param remaining + * @param mapId + * @return true/false, based on if the verification succeeded or not + */ + private boolean verifySanity(long compressedLength, long decompressedLength, + int forReduce, Set remaining, InputAttemptIdentifier srcAttemptId) { + if (compressedLength < 0 || decompressedLength < 0) { + wrongLengthErrs.increment(1); + LOG.warn(getName() + " invalid lengths in map output header: id: " + + srcAttemptId + " len: " + compressedLength + ", decomp len: " + + decompressedLength); + return false; + } + + int reduceStartId = shuffle.getReduceStartId(); + int reduceRange = shuffle.getReduceRange(); + if (forReduce < reduceStartId || forReduce >= reduceStartId+reduceRange) { + wrongReduceErrs.increment(1); + LOG.warn(getName() + " data for the wrong reduce map: " + + srcAttemptId + " len: " + compressedLength + " decomp len: " + + decompressedLength + " for reduce " + forReduce); + return false; + } + + // Sanity check + if (!remaining.contains(srcAttemptId)) { + wrongMapErrs.increment(1); + LOG.warn("Invalid map-output! Received output for " + srcAttemptId); + return false; + } + + return true; + } + + /** + * Create the map-output-url. This will contain all the map ids + * separated by commas + * @param host + * @param maps + * @return + * @throws MalformedURLException + */ + private URL getMapOutputURL(MapHost host, List srcAttempts + ) throws MalformedURLException { + // Get the base url + StringBuffer url = new StringBuffer(host.getBaseUrl()); + + boolean first = true; + for (InputAttemptIdentifier mapId : srcAttempts) { + if (!first) { + url.append(","); + } + url.append(mapId.getPathComponent()); + first = false; + } + + if (LOG.isDebugEnabled()) { + LOG.debug("MapOutput URL for " + host + " -> " + url.toString()); + } + return new URL(url.toString()); + } + + /** + * The connection establishment is attempted multiple times and is given up + * only on the last failure. 
Instead of connecting with a timeout of + * X, we try connecting with a timeout of x < X but multiple times. + * For example, with a total timeout of 180000 ms and a 60000 ms + * UNIT_CONNECT_TIMEOUT, up to three 60-second connect attempts are made + * before the last IOException is rethrown. + */ + private void connect(URLConnection connection, int connectionTimeout) + throws IOException { + int unit = 0; + if (connectionTimeout < 0) { + throw new IOException("Invalid timeout " + + "[timeout = " + connectionTimeout + " ms]"); + } else if (connectionTimeout > 0) { + unit = Math.min(UNIT_CONNECT_TIMEOUT, connectionTimeout); + } + // set the connect timeout to the unit-connect-timeout + connection.setConnectTimeout(unit); + while (true) { + try { + connection.connect(); + break; + } catch (IOException ioe) { + // update the total remaining connect-timeout + connectionTimeout -= unit; + + // throw an exception if we have waited for timeout amount of time + // note that the updated value of timeout is used here + if (connectionTimeout == 0) { + throw ioe; + } + + // reset the connect timeout for the last try + if (connectionTimeout < unit) { + unit = connectionTimeout; + // reset the connect time out for the final connect + connection.setConnectTimeout(unit); + } + } + } + } + + private void shuffleToMemory(MapHost host, MapOutput mapOutput, + InputStream input, + int decompressedLength, + int compressedLength) throws IOException { + IFileInputStream checksumIn = + new IFileInputStream(input, compressedLength, job); + + input = checksumIn; + + // Are map-outputs compressed? + if (codec != null) { + decompressor.reset(); + input = codec.createInputStream(input, decompressor); + } + + // Copy map-output into an in-memory buffer + byte[] shuffleData = mapOutput.getMemory(); + + try { + IOUtils.readFully(input, shuffleData, 0, shuffleData.length); + metrics.inputBytes(shuffleData.length); + LOG.info("Read " + shuffleData.length + " bytes from map-output for " + + mapOutput.getAttemptIdentifier()); + } catch (IOException ioe) { + // Close the streams + IOUtils.cleanup(LOG, input); + + // Re-throw + throw ioe; + } + + } + + private void shuffleToDisk(MapHost host, MapOutput mapOutput, + InputStream input, + long compressedLength) + throws IOException { + // Copy data to local-disk + OutputStream output = mapOutput.getDisk(); + long bytesLeft = compressedLength; + try { + final int BYTES_TO_READ = 64 * 1024; + byte[] buf = new byte[BYTES_TO_READ]; + while (bytesLeft > 0) { + int n = input.read(buf, 0, (int) Math.min(bytesLeft, BYTES_TO_READ)); + if (n < 0) { + throw new IOException("read past end of stream reading " + + mapOutput.getAttemptIdentifier()); + } + output.write(buf, 0, n); + bytesLeft -= n; + metrics.inputBytes(n); + } + + LOG.info("Read " + (compressedLength - bytesLeft) + + " bytes from map-output for " + + mapOutput.getAttemptIdentifier()); + + output.close(); + } catch (IOException ioe) { + // Close the streams + IOUtils.cleanup(LOG, input, output); + + // Re-throw + throw ioe; + } + + // Sanity check + if (bytesLeft != 0) { + throw new IOException("Incomplete map output received for " + + mapOutput.getAttemptIdentifier() + " from " + + host.getHostName() + " (" + + bytesLeft + " bytes missing of " + + compressedLength + ")" + ); + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/b212ca1d/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/InMemoryReader.java ---------------------------------------------------------------------- diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/InMemoryReader.java
b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/InMemoryReader.java new file mode 100644 index 0000000..ae95268 --- /dev/null +++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/InMemoryReader.java @@ -0,0 +1,156 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.tez.runtime.library.common.shuffle.impl; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; + +import org.apache.hadoop.classification.InterfaceAudience; +import org.apache.hadoop.classification.InterfaceStability; +import org.apache.hadoop.io.DataInputBuffer; +import org.apache.tez.runtime.library.common.InputAttemptIdentifier; +import org.apache.tez.runtime.library.common.sort.impl.IFile; +import org.apache.tez.runtime.library.common.sort.impl.IFile.Reader; + +/** + * IFile.InMemoryReader to read map-outputs present in-memory. + */ +@InterfaceAudience.Private +@InterfaceStability.Unstable +public class InMemoryReader extends Reader { + private final InputAttemptIdentifier taskAttemptId; + private final MergeManager merger; + DataInputBuffer memDataIn = new DataInputBuffer(); + private int start; + private int length; + private int prevKeyPos; + + public InMemoryReader(MergeManager merger, InputAttemptIdentifier taskAttemptId, + byte[] data, int start, int length) + throws IOException { + super(null, null, length - start, null, null); + this.merger = merger; + this.taskAttemptId = taskAttemptId; + + buffer = data; + bufferSize = (int)fileLength; + memDataIn.reset(buffer, start, length); + this.start = start; + this.length = length; + } + + @Override + public void reset(int offset) { + memDataIn.reset(buffer, start + offset, length); + bytesRead = offset; + eof = false; + } + + @Override + public long getPosition() throws IOException { + // InMemoryReader does not initialize streams like Reader, so in.getPos() + // would not work. Instead, return the number of uncompressed bytes read, + // which will be correct since in-memory data is not compressed. 
+ return bytesRead; + } + + @Override + public long getLength() { + return fileLength; + } + + private void dumpOnError() { + File dumpFile = new File("../output/" + taskAttemptId + ".dump"); + System.err.println("Dumping corrupt map-output of " + taskAttemptId + + " to " + dumpFile.getAbsolutePath()); + try { + FileOutputStream fos = new FileOutputStream(dumpFile); + fos.write(buffer, 0, bufferSize); + fos.close(); + } catch (IOException ioe) { + System.err.println("Failed to dump map-output of " + taskAttemptId); + } + } + + public KeyState readRawKey(DataInputBuffer key) throws IOException { + try { + if (!positionToNextRecord(memDataIn)) { + return KeyState.NO_KEY; + } + // Setup the key + int pos = memDataIn.getPosition(); + byte[] data = memDataIn.getData(); + if(currentKeyLength == IFile.RLE_MARKER) { + key.reset(data, prevKeyPos, prevKeyLength); + currentKeyLength = prevKeyLength; + return KeyState.SAME_KEY; + } + key.reset(data, pos, currentKeyLength); + prevKeyPos = pos; + // Position for the next value + long skipped = memDataIn.skip(currentKeyLength); + if (skipped != currentKeyLength) { + throw new IOException("Rec# " + recNo + + ": Failed to skip past key of length: " + + currentKeyLength); + } + + // Record the byte + bytesRead += currentKeyLength; + return KeyState.NEW_KEY; + } catch (IOException ioe) { + dumpOnError(); + throw ioe; + } + } + + public void nextRawValue(DataInputBuffer value) throws IOException { + try { + int pos = memDataIn.getPosition(); + byte[] data = memDataIn.getData(); + value.reset(data, pos, currentValueLength); + + // Position for the next record + long skipped = memDataIn.skip(currentValueLength); + if (skipped != currentValueLength) { + throw new IOException("Rec# " + recNo + + ": Failed to skip past value of length: " + + currentValueLength); + } + // Record the byte + bytesRead += currentValueLength; + + ++recNo; + } catch (IOException ioe) { + dumpOnError(); + throw ioe; + } + } + + public void close() { + // Release + dataIn = null; + buffer = null; + // Inform the MergeManager + if (merger != null) { + merger.unreserve(bufferSize); + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/b212ca1d/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/InMemoryWriter.java ---------------------------------------------------------------------- diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/InMemoryWriter.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/InMemoryWriter.java new file mode 100644 index 0000000..f81b28e --- /dev/null +++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/InMemoryWriter.java @@ -0,0 +1,100 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.tez.runtime.library.common.shuffle.impl; + +import java.io.DataOutputStream; +import java.io.IOException; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.classification.InterfaceAudience; +import org.apache.hadoop.classification.InterfaceStability; +import org.apache.hadoop.io.BoundedByteArrayOutputStream; +import org.apache.hadoop.io.DataInputBuffer; +import org.apache.hadoop.io.WritableUtils; +import org.apache.tez.runtime.library.common.sort.impl.IFile; +import org.apache.tez.runtime.library.common.sort.impl.IFileOutputStream; +import org.apache.tez.runtime.library.common.sort.impl.IFile.Writer; + +@InterfaceAudience.Private +@InterfaceStability.Unstable +public class InMemoryWriter extends Writer { + private static final Log LOG = LogFactory.getLog(InMemoryWriter.class); + + private DataOutputStream out; + + public InMemoryWriter(BoundedByteArrayOutputStream arrayStream) { + super(null); + this.out = + new DataOutputStream(new IFileOutputStream(arrayStream)); + } + + public void append(Object key, Object value) throws IOException { + throw new UnsupportedOperationException + ("InMemoryWriter.append(K key, V value)"); + } + + public void append(DataInputBuffer key, DataInputBuffer value) + throws IOException { + int keyLength = key.getLength() - key.getPosition(); + if (keyLength < 0) { + throw new IOException("Negative key-length not allowed: " + keyLength + + " for " + key); + } + + boolean sameKey = (key == IFile.REPEAT_KEY); + + int valueLength = value.getLength() - value.getPosition(); + if (valueLength < 0) { + throw new IOException("Negative value-length not allowed: " + + valueLength + " for " + value); + } + + if(sameKey) { + WritableUtils.writeVInt(out, IFile.RLE_MARKER); + WritableUtils.writeVInt(out, valueLength); + out.write(value.getData(), value.getPosition(), valueLength); + } else { + if (LOG.isDebugEnabled()) { + LOG.debug("InMemWriter.append" + + " key.data=" + key.getData() + + " key.pos=" + key.getPosition() + + " key.len=" +key.getLength() + + " val.data=" + value.getData() + + " val.pos=" + value.getPosition() + + " val.len=" + value.getLength()); + } + WritableUtils.writeVInt(out, keyLength); + WritableUtils.writeVInt(out, valueLength); + out.write(key.getData(), key.getPosition(), keyLength); + out.write(value.getData(), value.getPosition(), valueLength); + } + + } + + public void close() throws IOException { + // Write EOF_MARKER for key/value length + WritableUtils.writeVInt(out, IFile.EOF_MARKER); + WritableUtils.writeVInt(out, IFile.EOF_MARKER); + + // Close the stream + out.close(); + out = null; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/b212ca1d/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/MapHost.java ---------------------------------------------------------------------- diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/MapHost.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/MapHost.java new file mode 100644 index 0000000..b8be657 --- /dev/null +++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/MapHost.java @@ -0,0 +1,124 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.tez.runtime.library.common.shuffle.impl; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.hadoop.classification.InterfaceAudience.Private; +import org.apache.tez.runtime.library.common.InputAttemptIdentifier; + +@Private +class MapHost { + + public static enum State { + IDLE, // No map outputs available + BUSY, // Map outputs are being fetched + PENDING, // Known map outputs which need to be fetched + PENALIZED // Host penalized due to shuffle failures + } + + private State state = State.IDLE; + private final String hostName; + private final int partitionId; + private final String baseUrl; + private final String identifier; + // Tracks attempt IDs + private List maps = new ArrayList(); + + public MapHost(int partitionId, String hostName, String baseUrl) { + this.partitionId = partitionId; + this.hostName = hostName; + this.baseUrl = baseUrl; + this.identifier = createIdentifier(hostName, partitionId); + } + + public static String createIdentifier(String hostName, int partitionId) { + return hostName + ":" + Integer.toString(partitionId); + } + + public String getIdentifier() { + return identifier; + } + + public int getPartitionId() { + return partitionId; + } + + public State getState() { + return state; + } + + public String getHostName() { + return hostName; + } + + public String getBaseUrl() { + return baseUrl; + } + + public synchronized void addKnownMap(InputAttemptIdentifier srcAttempt) { + maps.add(srcAttempt); + if (state == State.IDLE) { + state = State.PENDING; + } + } + + public synchronized List getAndClearKnownMaps() { + List currentKnownMaps = maps; + maps = new ArrayList(); + return currentKnownMaps; + } + + public synchronized void markBusy() { + state = State.BUSY; + } + + public synchronized void markPenalized() { + state = State.PENALIZED; + } + + public synchronized int getNumKnownMapOutputs() { + return maps.size(); + } + + /** + * Called when the node is done with its penalty or done copying. 
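+ * Transitions the host back to PENDING if it still has known map outputs,
+ * or to IDLE if none remain.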
+ * @return the host's new state + */ + public synchronized State markAvailable() { + if (maps.isEmpty()) { + state = State.IDLE; + } else { + state = State.PENDING; + } + return state; + } + + @Override + public String toString() { + return hostName; + } + + /** + * Mark the host as penalized + */ + public synchronized void penalize() { + state = State.PENALIZED; + } +} http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/b212ca1d/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/MapOutput.java ---------------------------------------------------------------------- diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/MapOutput.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/MapOutput.java new file mode 100644 index 0000000..9f673a0 --- /dev/null +++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/MapOutput.java @@ -0,0 +1,227 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.tez.runtime.library.common.shuffle.impl; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.Comparator; +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.LocalDirAllocator; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.BoundedByteArrayOutputStream; +import org.apache.tez.runtime.library.common.InputAttemptIdentifier; +import org.apache.tez.runtime.library.common.task.local.output.TezTaskOutputFiles; + + +class MapOutput { + private static final Log LOG = LogFactory.getLog(MapOutput.class); + private static AtomicInteger ID = new AtomicInteger(0); + + public static enum Type { + WAIT, + MEMORY, + DISK + } + + private InputAttemptIdentifier attemptIdentifier; + private final int id; + + private final MergeManager merger; + + private final long size; + + private final byte[] memory; + private BoundedByteArrayOutputStream byteStream; + + private final FileSystem localFS; + private final Path tmpOutputPath; + private final Path outputPath; + private final OutputStream disk; + + private final Type type; + + private final boolean primaryMapOutput; + + MapOutput(InputAttemptIdentifier attemptIdentifier, MergeManager merger, long size, + Configuration conf, LocalDirAllocator localDirAllocator, + int fetcher, boolean primaryMapOutput, + TezTaskOutputFiles mapOutputFile) + throws IOException { + this.id = ID.incrementAndGet(); + this.attemptIdentifier = attemptIdentifier; + this.merger = merger; + + type = Type.DISK; + + memory = null; + byteStream = null; + + this.size = size; + + this.localFS = FileSystem.getLocal(conf); + outputPath = + mapOutputFile.getInputFileForWrite(this.attemptIdentifier.getInputIdentifier().getSrcTaskIndex(), size); + tmpOutputPath = outputPath.suffix(String.valueOf(fetcher)); + + disk = localFS.create(tmpOutputPath); + + this.primaryMapOutput = primaryMapOutput; + } + + MapOutput(InputAttemptIdentifier attemptIdentifier, MergeManager merger, int size, + boolean primaryMapOutput) { + this.id = ID.incrementAndGet(); + this.attemptIdentifier = attemptIdentifier; + this.merger = merger; + + type = Type.MEMORY; + byteStream = new BoundedByteArrayOutputStream(size); + memory = byteStream.getBuffer(); + + this.size = size; + + localFS = null; + disk = null; + outputPath = null; + tmpOutputPath = null; + + this.primaryMapOutput = primaryMapOutput; + } + + public MapOutput(InputAttemptIdentifier attemptIdentifier) { + this.id = ID.incrementAndGet(); + this.attemptIdentifier = attemptIdentifier; + + type = Type.WAIT; + merger = null; + memory = null; + byteStream = null; + + size = -1; + + localFS = null; + disk = null; + outputPath = null; + tmpOutputPath = null; + + this.primaryMapOutput = false; +} + + public boolean isPrimaryMapOutput() { + return primaryMapOutput; + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof MapOutput) { + return id == ((MapOutput)obj).id; + } + return false; + } + + @Override + public int hashCode() { + return id; + } + + public Path getOutputPath() { + return outputPath; + } + + public byte[] getMemory() { + return memory; + } + + public BoundedByteArrayOutputStream getArrayStream() { + return byteStream; + } + + public OutputStream getDisk() { + return disk; + } + + public InputAttemptIdentifier getAttemptIdentifier() { + return this.attemptIdentifier; + 
} + + public Type getType() { + return type; + } + + public long getSize() { + return size; + } + + public void commit() throws IOException { + if (type == Type.MEMORY) { + merger.closeInMemoryFile(this); + } else if (type == Type.DISK) { + localFS.rename(tmpOutputPath, outputPath); + merger.closeOnDiskFile(outputPath); + } else { + throw new IOException("Cannot commit MapOutput of type WAIT!"); + } + } + + public void abort() { + if (type == Type.MEMORY) { + merger.unreserve(memory.length); + } else if (type == Type.DISK) { + try { + localFS.delete(tmpOutputPath, false); + } catch (IOException ie) { + LOG.info("failure to clean up " + tmpOutputPath, ie); + } + } else { + throw new IllegalArgumentException + ("Cannot abort MapOutput of type WAIT!"); + } + } + + public String toString() { + return "MapOutput( AttemptIdentifier: " + attemptIdentifier + ", Type: " + type + ")"; + } + + public static class MapOutputComparator + implements Comparator { + public int compare(MapOutput o1, MapOutput o2) { + if (o1.id == o2.id) { + return 0; + } + + if (o1.size < o2.size) { + return -1; + } else if (o1.size > o2.size) { + return 1; + } + + if (o1.id < o2.id) { + return -1; + } else { + return 1; + } + } + } + +} http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/b212ca1d/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/MergeManager.java ---------------------------------------------------------------------- diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/MergeManager.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/MergeManager.java new file mode 100644 index 0000000..0abe530 --- /dev/null +++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/MergeManager.java @@ -0,0 +1,782 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.tez.runtime.library.common.shuffle.impl; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Set; +import java.util.TreeSet; + + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.classification.InterfaceAudience; +import org.apache.hadoop.classification.InterfaceStability; +import org.apache.hadoop.classification.InterfaceAudience.Private; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.ChecksumFileSystem; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.LocalDirAllocator; +import org.apache.hadoop.fs.LocalFileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.DataInputBuffer; +import org.apache.hadoop.io.RawComparator; +import org.apache.hadoop.io.compress.CompressionCodec; +import org.apache.hadoop.io.compress.DefaultCodec; +import org.apache.hadoop.util.Progressable; +import org.apache.hadoop.util.ReflectionUtils; +import org.apache.tez.common.TezJobConfig; +import org.apache.tez.common.counters.TezCounter; +import org.apache.tez.runtime.api.TezInputContext; +import org.apache.tez.runtime.library.common.ConfigUtils; +import org.apache.tez.runtime.library.common.Constants; +import org.apache.tez.runtime.library.common.InputAttemptIdentifier; +import org.apache.tez.runtime.library.common.combine.Combiner; +import org.apache.tez.runtime.library.common.sort.impl.IFile; +import org.apache.tez.runtime.library.common.sort.impl.TezMerger; +import org.apache.tez.runtime.library.common.sort.impl.TezRawKeyValueIterator; +import org.apache.tez.runtime.library.common.sort.impl.IFile.Writer; +import org.apache.tez.runtime.library.common.sort.impl.TezMerger.Segment; +import org.apache.tez.runtime.library.common.task.local.output.TezTaskOutputFiles; +import org.apache.tez.runtime.library.hadoop.compat.NullProgressable; + +@InterfaceAudience.Private +@InterfaceStability.Unstable +@SuppressWarnings(value={"rawtypes"}) +public class MergeManager { + + private static final Log LOG = LogFactory.getLog(MergeManager.class); + + private final Configuration conf; + private final FileSystem localFS; + private final FileSystem rfs; + private final LocalDirAllocator localDirAllocator; + + private final TezTaskOutputFiles mapOutputFile; + private final Progressable nullProgressable = new NullProgressable(); + private final Combiner combiner; + + Set inMemoryMergedMapOutputs = + new TreeSet(new MapOutput.MapOutputComparator()); + private final IntermediateMemoryToMemoryMerger memToMemMerger; + + Set inMemoryMapOutputs = + new TreeSet(new MapOutput.MapOutputComparator()); + private final InMemoryMerger inMemoryMerger; + + Set onDiskMapOutputs = new TreeSet(); + private final OnDiskMerger onDiskMerger; + + private final long memoryLimit; + private long usedMemory; + private long commitMemory; + private final long maxSingleShuffleLimit; + + private final int memToMemMergeOutputsThreshold; + private final long mergeThreshold; + + private final int ioSortFactor; + + private final ExceptionReporter exceptionReporter; + + private final TezInputContext inputContext; + + private final TezCounter spilledRecordsCounter; + + private final TezCounter reduceCombineInputCounter; + + private final TezCounter mergedMapOutputsCounter; + + private final CompressionCodec codec; + + private volatile boolean finalMergeComplete = false; + + public MergeManager(Configuration 
conf, + FileSystem localFS, + LocalDirAllocator localDirAllocator, + TezInputContext inputContext, + Combiner combiner, + TezCounter spilledRecordsCounter, + TezCounter reduceCombineInputCounter, + TezCounter mergedMapOutputsCounter, + ExceptionReporter exceptionReporter) { + this.inputContext = inputContext; + this.conf = conf; + this.localDirAllocator = localDirAllocator; + this.exceptionReporter = exceptionReporter; + + this.combiner = combiner; + + this.reduceCombineInputCounter = reduceCombineInputCounter; + this.spilledRecordsCounter = spilledRecordsCounter; + this.mergedMapOutputsCounter = mergedMapOutputsCounter; + this.mapOutputFile = new TezTaskOutputFiles(conf, inputContext.getUniqueIdentifier()); + + this.localFS = localFS; + this.rfs = ((LocalFileSystem)localFS).getRaw(); + + if (ConfigUtils.isIntermediateInputCompressed(conf)) { + Class codecClass = + ConfigUtils.getIntermediateInputCompressorClass(conf, DefaultCodec.class); + codec = ReflectionUtils.newInstance(codecClass, conf); + } else { + codec = null; + } + + final float maxInMemCopyUse = + conf.getFloat( + TezJobConfig.TEZ_RUNTIME_SHUFFLE_INPUT_BUFFER_PERCENT, + TezJobConfig.DEFAULT_TEZ_RUNTIME_SHUFFLE_INPUT_BUFFER_PERCENT); + if (maxInMemCopyUse > 1.0 || maxInMemCopyUse < 0.0) { + throw new IllegalArgumentException("Invalid value for " + + TezJobConfig.TEZ_RUNTIME_SHUFFLE_INPUT_BUFFER_PERCENT + ": " + + maxInMemCopyUse); + } + + // Allow unit tests to fix Runtime memory + this.memoryLimit = + (long)(conf.getLong(Constants.TEZ_RUNTIME_TASK_MEMORY, + Math.min(Runtime.getRuntime().maxMemory(), Integer.MAX_VALUE)) + * maxInMemCopyUse); + + this.ioSortFactor = + conf.getInt( + TezJobConfig.TEZ_RUNTIME_IO_SORT_FACTOR, + TezJobConfig.DEFAULT_TEZ_RUNTIME_IO_SORT_FACTOR); + + final float singleShuffleMemoryLimitPercent = + conf.getFloat( + TezJobConfig.TEZ_RUNTIME_SHUFFLE_MEMORY_LIMIT_PERCENT, + TezJobConfig.DEFAULT_TEZ_RUNTIME_SHUFFLE_MEMORY_LIMIT_PERCENT); + if (singleShuffleMemoryLimitPercent <= 0.0f + || singleShuffleMemoryLimitPercent > 1.0f) { + throw new IllegalArgumentException("Invalid value for " + + TezJobConfig.TEZ_RUNTIME_SHUFFLE_MEMORY_LIMIT_PERCENT + ": " + + singleShuffleMemoryLimitPercent); + } + + this.maxSingleShuffleLimit = + (long)(memoryLimit * singleShuffleMemoryLimitPercent); + this.memToMemMergeOutputsThreshold = + conf.getInt( + TezJobConfig.TEZ_RUNTIME_SHUFFLE_MEMTOMEM_SEGMENTS, + ioSortFactor); + this.mergeThreshold = + (long)(this.memoryLimit * + conf.getFloat( + TezJobConfig.TEZ_RUNTIME_SHUFFLE_MERGE_PERCENT, + TezJobConfig.DEFAULT_TEZ_RUNTIME_SHUFFLE_MERGE_PERCENT)); + LOG.info("MergeManager: memoryLimit=" + memoryLimit + ", " + + "maxSingleShuffleLimit=" + maxSingleShuffleLimit + ", " + + "mergeThreshold=" + mergeThreshold + ", " + + "ioSortFactor=" + ioSortFactor + ", " + + "memToMemMergeOutputsThreshold=" + memToMemMergeOutputsThreshold); + + if (this.maxSingleShuffleLimit >= this.mergeThreshold) { + throw new RuntimeException("Invalid configuration: " + + "maxSingleShuffleLimit should be less than mergeThreshold. " + + "maxSingleShuffleLimit: " + this.maxSingleShuffleLimit + + ", mergeThreshold: " + this.mergeThreshold); + } + + boolean allowMemToMemMerge = + conf.getBoolean( + TezJobConfig.TEZ_RUNTIME_SHUFFLE_ENABLE_MEMTOMEM, + TezJobConfig.DEFAULT_TEZ_RUNTIME_SHUFFLE_ENABLE_MEMTOMEM); + if (allowMemToMemMerge) { + this.memToMemMerger = + new IntermediateMemoryToMemoryMerger(this, + memToMemMergeOutputsThreshold); + this.memToMemMerger.start(); + } else { + this.memToMemMerger = null; + } + +
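+ // The two mergers below always run: InMemoryMerger spills in-memory
+ // segments to disk once commitMemory crosses mergeThreshold, and
+ // OnDiskMerger re-merges on-disk files once 2 * ioSortFactor - 1 of them
+ // accumulate.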
this.inMemoryMerger = new InMemoryMerger(this); + this.inMemoryMerger.start(); + + this.onDiskMerger = new OnDiskMerger(this); + this.onDiskMerger.start(); + } + + public void waitForInMemoryMerge() throws InterruptedException { + inMemoryMerger.waitForMerge(); + } + + private boolean canShuffleToMemory(long requestedSize) { + return (requestedSize < maxSingleShuffleLimit); + } + + final private MapOutput stallShuffle = new MapOutput(null); + + public synchronized MapOutput reserve(InputAttemptIdentifier srcAttemptIdentifier, + long requestedSize, + int fetcher + ) throws IOException { + if (!canShuffleToMemory(requestedSize)) { + LOG.info(srcAttemptIdentifier + ": Shuffling to disk since " + requestedSize + + " is greater than maxSingleShuffleLimit (" + + maxSingleShuffleLimit + ")"); + return new MapOutput(srcAttemptIdentifier, this, requestedSize, conf, + localDirAllocator, fetcher, true, + mapOutputFile); + } + + // Stall shuffle if we are above the memory limit + + // It is possible that all threads could just be stalling and not make + // progress at all. This could happen when: + // + // requested size is causing the used memory to go above limit && + // requested size < singleShuffleLimit && + // current used size < mergeThreshold (merge will not get triggered) + // + // To avoid this from happening, we allow exactly one thread to go past + // the memory limit. We check (usedMemory > memoryLimit) and not + // (usedMemory + requestedSize > memoryLimit). When this thread is done + // fetching, this will automatically trigger a merge thereby unlocking + // all the stalled threads + + if (usedMemory > memoryLimit) { + LOG.debug(srcAttemptIdentifier + ": Stalling shuffle since usedMemory (" + usedMemory + + ") is greater than memoryLimit (" + memoryLimit + ")." + + " CommitMemory is (" + commitMemory + ")"); + return stallShuffle; + } + + // Allow the in-memory shuffle to progress + LOG.debug(srcAttemptIdentifier + ": Proceeding with shuffle since usedMemory (" + + usedMemory + ") is less than memoryLimit (" + memoryLimit + ")." + + " CommitMemory is (" + commitMemory + ")"); + return unconditionalReserve(srcAttemptIdentifier, requestedSize, true); + } + + /** + * Unconditional Reserve is used by the Memory-to-Memory thread + * @return the reserved MapOutput + */ + private synchronized MapOutput unconditionalReserve( + InputAttemptIdentifier srcAttemptIdentifier, long requestedSize, boolean primaryMapOutput) { + usedMemory += requestedSize; + return new MapOutput(srcAttemptIdentifier, this, (int)requestedSize, + primaryMapOutput); + } + + synchronized void unreserve(long size) { + commitMemory -= size; + usedMemory -= size; + } + + public synchronized void closeInMemoryFile(MapOutput mapOutput) { + inMemoryMapOutputs.add(mapOutput); + LOG.info("closeInMemoryFile -> map-output of size: " + mapOutput.getSize() + + ", inMemoryMapOutputs.size() -> " + inMemoryMapOutputs.size() + + ", commitMemory -> " + commitMemory + ", usedMemory -> " + usedMemory); + + commitMemory += mapOutput.getSize(); + + synchronized (inMemoryMerger) { + // Can hang if mergeThreshold is really low. + if (!inMemoryMerger.isInProgress() && commitMemory >= mergeThreshold) { + LOG.info("Starting inMemoryMerger's merge since commitMemory=" + + commitMemory + " > mergeThreshold=" + mergeThreshold + + ". Current usedMemory=" + usedMemory); + inMemoryMapOutputs.addAll(inMemoryMergedMapOutputs); + inMemoryMergedMapOutputs.clear(); + inMemoryMerger.startMerge(inMemoryMapOutputs); + } + } + + if (memToMemMerger != null) { + synchronized (memToMemMerger) { + if (!memToMemMerger.isInProgress() && + inMemoryMapOutputs.size() >= memToMemMergeOutputsThreshold) { + memToMemMerger.startMerge(inMemoryMapOutputs); + } + } + } + } + + + public synchronized void closeInMemoryMergedFile(MapOutput mapOutput) { + inMemoryMergedMapOutputs.add(mapOutput); + LOG.info("closeInMemoryMergedFile -> size: " + mapOutput.getSize() + + ", inMemoryMergedMapOutputs.size() -> " + + inMemoryMergedMapOutputs.size()); + } + + public synchronized void closeOnDiskFile(Path file) { + onDiskMapOutputs.add(file); + + synchronized (onDiskMerger) { + if (!onDiskMerger.isInProgress() && + onDiskMapOutputs.size() >= (2 * ioSortFactor - 1)) { + onDiskMerger.startMerge(onDiskMapOutputs); + } + } + } + + /** + * Should only be used after the Shuffle phase is complete, otherwise it can + * return an invalid state since a merge may not be in progress due to + * inadequate inputs + * + * @return true if the merge process is complete, otherwise false + */ + @Private + public boolean isMergeComplete() { + return finalMergeComplete; + } + + public TezRawKeyValueIterator close() throws Throwable { + // Wait for on-going merges to complete + if (memToMemMerger != null) { + memToMemMerger.close(); + } + inMemoryMerger.close(); + onDiskMerger.close(); + + List memory = + new ArrayList(inMemoryMergedMapOutputs); + memory.addAll(inMemoryMapOutputs); + List disk = new ArrayList(onDiskMapOutputs); + TezRawKeyValueIterator kvIter = finalMerge(conf, rfs, memory, disk); + this.finalMergeComplete = true; + return kvIter; + } + + void runCombineProcessor(TezRawKeyValueIterator kvIter, Writer writer) + throws IOException, InterruptedException { + combiner.combine(kvIter, writer); + } + + private class IntermediateMemoryToMemoryMerger + extends MergeThread { + + public IntermediateMemoryToMemoryMerger(MergeManager manager, + int mergeFactor) { + super(manager, mergeFactor, exceptionReporter); + setName("MemToMemMerger - Thread to do in-memory merge of in-memory " + + "shuffled map-outputs"); + setDaemon(true); + } + + @Override + public void merge(List inputs) throws IOException { + if (inputs == null || inputs.size() == 0) { + return; + } + + InputAttemptIdentifier dummyMapId = inputs.get(0).getAttemptIdentifier(); + List inMemorySegments = new ArrayList(); + long mergeOutputSize = + createInMemorySegments(inputs, inMemorySegments, 0); + int noInMemorySegments = inMemorySegments.size(); + + MapOutput mergedMapOutputs = + unconditionalReserve(dummyMapId, mergeOutputSize, false); + + Writer writer = + new InMemoryWriter(mergedMapOutputs.getArrayStream()); + + LOG.info("Initiating Memory-to-Memory merge with " + noInMemorySegments + + " segments of total-size: " + mergeOutputSize); + + TezRawKeyValueIterator rIter = + TezMerger.merge(conf, rfs, + ConfigUtils.getIntermediateInputKeyClass(conf), + ConfigUtils.getIntermediateInputValueClass(conf), + inMemorySegments, inMemorySegments.size(), + new Path(inputContext.getUniqueIdentifier()), + (RawComparator)ConfigUtils.getIntermediateInputKeyComparator(conf), + nullProgressable, null, null, null); + TezMerger.writeFile(rIter, writer, nullProgressable, conf); + writer.close(); + + LOG.info(inputContext.getUniqueIdentifier() + + " Memory-to-Memory merge of the " + noInMemorySegments + + " files in-memory complete.");
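+ // The merged buffer was charged to usedMemory via unconditionalReserve()
+ // above; the source segments give their memory back through unreserve()
+ // as their InMemoryReaders are closed by the merge.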
+
+      // Note the output of the merge
+      closeInMemoryMergedFile(mergedMapOutputs);
+    }
+  }
+
+  private class InMemoryMerger extends MergeThread<MapOutput> {
+
+    public InMemoryMerger(MergeManager manager) {
+      super(manager, Integer.MAX_VALUE, exceptionReporter);
+      setName("InMemoryMerger - Thread to merge in-memory shuffled map-outputs");
+      setDaemon(true);
+    }
+
+    @Override
+    public void merge(List<MapOutput> inputs)
+        throws IOException, InterruptedException {
+      if (inputs == null || inputs.size() == 0) {
+        return;
+      }
+
+      // Name this output file the same as the first file in the current
+      // list of in-memory files (this is guaranteed to be absent on disk
+      // currently, so we don't overwrite a previously created spill).
+      // Also, create the output file now, since it is not guaranteed that
+      // this file will be present after merge is called (we delete empty
+      // files as soon as we see them in the merge method).
+
+      // figure out the mapId
+      InputAttemptIdentifier srcTaskIdentifier = inputs.get(0).getAttemptIdentifier();
+
+      List<Segment> inMemorySegments = new ArrayList<Segment>();
+      long mergeOutputSize =
+        createInMemorySegments(inputs, inMemorySegments, 0);
+      int noInMemorySegments = inMemorySegments.size();
+
+      Path outputPath = mapOutputFile.getInputFileForWrite(
+          srcTaskIdentifier.getInputIdentifier().getSrcTaskIndex(),
+          mergeOutputSize).suffix(Constants.MERGED_OUTPUT_PREFIX);
+
+      Writer writer = null;
+      try {
+        writer =
+            new Writer(conf, rfs, outputPath,
+                (Class)ConfigUtils.getIntermediateInputKeyClass(conf),
+                (Class)ConfigUtils.getIntermediateInputValueClass(conf),
+                codec, null);
+
+        TezRawKeyValueIterator rIter = null;
+        LOG.info("Initiating in-memory merge with " + noInMemorySegments +
+            " segments...");
+
+        rIter = TezMerger.merge(conf, rfs,
+            (Class)ConfigUtils.getIntermediateInputKeyClass(conf),
+            (Class)ConfigUtils.getIntermediateInputValueClass(conf),
+            inMemorySegments, inMemorySegments.size(),
+            new Path(inputContext.getUniqueIdentifier()),
+            (RawComparator)ConfigUtils.getIntermediateInputKeyComparator(conf),
+            nullProgressable, spilledRecordsCounter, null, null);
+
+        if (null == combiner) {
+          TezMerger.writeFile(rIter, writer, nullProgressable, conf);
+        } else {
+          runCombineProcessor(rIter, writer);
+        }
+        writer.close();
+        writer = null;
+
+        LOG.info(inputContext.getUniqueIdentifier() +
+            " Merge of the " + noInMemorySegments +
+            " files in-memory complete." +
+            " Local file is " + outputPath + " of size " +
+            localFS.getFileStatus(outputPath).getLen());
+      } catch (IOException e) {
+        // make sure that we delete the on-disk file that we created earlier
+        localFS.delete(outputPath, true);
+        throw e;
+      } finally {
+        if (writer != null) {
+          writer.close();
+        }
+      }
+
+      // Note the output of the merge
+      closeOnDiskFile(outputPath);
+    }
+
+  }
+
+  private class OnDiskMerger extends MergeThread<Path> {
+
+    public OnDiskMerger(MergeManager manager) {
+      super(manager, Integer.MAX_VALUE, exceptionReporter);
+      setName("OnDiskMerger - Thread to merge on-disk map-outputs");
+      setDaemon(true);
+    }
+
+    @Override
+    public void merge(List<Path> inputs) throws IOException {
+      // sanity check
+      if (inputs == null || inputs.isEmpty()) {
+        LOG.info("No ondisk files to merge...");
+        return;
+      }
+
+      long approxOutputSize = 0;
+      int bytesPerSum =
+        conf.getInt("io.bytes.per.checksum", 512);
+
+      LOG.info("OnDiskMerger: We have " + inputs.size() +
+          " map outputs on disk. Triggering merge...");
+
+      // 1. Prepare the list of files to be merged.
+      for (Path file : inputs) {
+        approxOutputSize += localFS.getFileStatus(file).getLen();
+      }
+
+      // add the checksum length
+      approxOutputSize +=
+          ChecksumFileSystem.getChecksumLength(approxOutputSize, bytesPerSum);
+
+      // 2. Start the on-disk merge process
+      Path outputPath =
+          localDirAllocator.getLocalPathForWrite(inputs.get(0).toString(),
+              approxOutputSize, conf).suffix(Constants.MERGED_OUTPUT_PREFIX);
+      Writer writer =
+          new Writer(conf, rfs, outputPath,
+              (Class)ConfigUtils.getIntermediateInputKeyClass(conf),
+              (Class)ConfigUtils.getIntermediateInputValueClass(conf),
+              codec, null);
+      TezRawKeyValueIterator iter = null;
+      Path tmpDir = new Path(inputContext.getUniqueIdentifier());
+      try {
+        iter = TezMerger.merge(conf, rfs,
+            (Class)ConfigUtils.getIntermediateInputKeyClass(conf),
+            (Class)ConfigUtils.getIntermediateInputValueClass(conf),
+            codec, inputs.toArray(new Path[inputs.size()]),
+            true, ioSortFactor, tmpDir,
+            (RawComparator)ConfigUtils.getIntermediateInputKeyComparator(conf),
+            nullProgressable, spilledRecordsCounter, null,
+            mergedMapOutputsCounter, null);
+
+        TezMerger.writeFile(iter, writer, nullProgressable, conf);
+        writer.close();
+      } catch (IOException e) {
+        localFS.delete(outputPath, true);
+        throw e;
+      }
+
+      closeOnDiskFile(outputPath);
+
+      LOG.info(inputContext.getUniqueIdentifier() +
+          " Finished merging " + inputs.size() +
+          " map output files on disk of total-size " +
+          approxOutputSize + "." +
+          " Local output file is " + outputPath + " of size " +
+          localFS.getFileStatus(outputPath).getLen());
+    }
+  }
+
+  private long createInMemorySegments(List<MapOutput> inMemoryMapOutputs,
+                                      List<Segment> inMemorySegments,
+                                      long leaveBytes
+                                      ) throws IOException {
+    long totalSize = 0L;
+    // fullSize could come from the RamManager, but files can be
+    // closed but not yet present in inMemoryMapOutputs
+    long fullSize = 0L;
+    for (MapOutput mo : inMemoryMapOutputs) {
+      fullSize += mo.getMemory().length;
+    }
+    while (fullSize > leaveBytes) {
+      MapOutput mo = inMemoryMapOutputs.remove(0);
+      byte[] data = mo.getMemory();
+      long size = data.length;
+      totalSize += size;
+      fullSize -= size;
+      IFile.Reader reader = new InMemoryReader(MergeManager.this,
+                                               mo.getAttemptIdentifier(),
+                                               data, 0, (int)size);
+      inMemorySegments.add(new Segment(reader, true,
+                                       (mo.isPrimaryMapOutput() ?
+                                        mergedMapOutputsCounter : null)));
+    }
+    return totalSize;
+  }
+
+  class RawKVIteratorReader extends IFile.Reader {
+
+    private final TezRawKeyValueIterator kvIter;
+
+    public RawKVIteratorReader(TezRawKeyValueIterator kvIter, long size)
+        throws IOException {
+      super(null, null, size, null, spilledRecordsCounter);
+      this.kvIter = kvIter;
+    }
+    public boolean nextRawKey(DataInputBuffer key) throws IOException {
+      if (kvIter.next()) {
+        final DataInputBuffer kb = kvIter.getKey();
+        final int kp = kb.getPosition();
+        final int klen = kb.getLength() - kp;
+        key.reset(kb.getData(), kp, klen);
+        bytesRead += klen;
+        return true;
+      }
+      return false;
+    }
+    public void nextRawValue(DataInputBuffer value) throws IOException {
+      final DataInputBuffer vb = kvIter.getValue();
+      final int vp = vb.getPosition();
+      final int vlen = vb.getLength() - vp;
+      value.reset(vb.getData(), vp, vlen);
+      bytesRead += vlen;
+    }
+    public long getPosition() throws IOException {
+      return bytesRead;
+    }
+
+    public void close() throws IOException {
+      kvIter.close();
+    }
+  }
+
+  private TezRawKeyValueIterator finalMerge(Configuration job, FileSystem fs,
+                                       List<MapOutput> inMemoryMapOutputs,
+                                       List<Path> onDiskMapOutputs
+                                       ) throws IOException {
+    LOG.info("finalMerge called with " +
+        inMemoryMapOutputs.size() + " in-memory map-outputs and " +
+        onDiskMapOutputs.size() + " on-disk map-outputs");
+
+    final float maxRedPer =
+      job.getFloat(
+          TezJobConfig.TEZ_RUNTIME_INPUT_BUFFER_PERCENT,
+          TezJobConfig.DEFAULT_TEZ_RUNTIME_INPUT_BUFFER_PERCENT);
+    if (maxRedPer > 1.0 || maxRedPer < 0.0) {
+      throw new IOException(TezJobConfig.TEZ_RUNTIME_INPUT_BUFFER_PERCENT +
+          ": " + maxRedPer);
+    }
+    int maxInMemReduce = (int)Math.min(
+        Runtime.getRuntime().maxMemory() * maxRedPer, Integer.MAX_VALUE);
+
+
+    // merge config params
+    Class keyClass = (Class)ConfigUtils.getIntermediateInputKeyClass(job);
+    Class valueClass = (Class)ConfigUtils.getIntermediateInputValueClass(job);
+    final Path tmpDir = new Path(inputContext.getUniqueIdentifier());
+    final RawComparator comparator =
+      (RawComparator)ConfigUtils.getIntermediateInputKeyComparator(job);
+
+    // segments required to vacate memory
+    List<Segment> memDiskSegments = new ArrayList<Segment>();
+    long inMemToDiskBytes = 0;
+    boolean mergePhaseFinished = false;
+    if (inMemoryMapOutputs.size() > 0) {
+      int srcTaskId = inMemoryMapOutputs.get(0).getAttemptIdentifier().getInputIdentifier().getSrcTaskIndex();
+      inMemToDiskBytes = createInMemorySegments(inMemoryMapOutputs,
+                                                memDiskSegments,
+                                                maxInMemReduce);
+      final int numMemDiskSegments = memDiskSegments.size();
+      if (numMemDiskSegments > 0 &&
+          ioSortFactor > onDiskMapOutputs.size()) {
+
+        // If we reach here, it implies that we have less than io.sort.factor
+        // disk segments and this will be incremented by 1 (result of the
+        // memory segments merge). Since this total would still be
+        // <= io.sort.factor, we will not do any more intermediate merges;
+        // the merge of all these disk segments will be fed directly to the
+        // reduce method
+
+        mergePhaseFinished = true;
+        // must spill to disk, but can't retain in-mem for intermediate merge
+        final Path outputPath =
+          mapOutputFile.getInputFileForWrite(srcTaskId,
+                                             inMemToDiskBytes).suffix(
+                                                 Constants.MERGED_OUTPUT_PREFIX);
+        final TezRawKeyValueIterator rIter = TezMerger.merge(job, fs,
+            keyClass, valueClass, memDiskSegments, numMemDiskSegments,
+            tmpDir, comparator, nullProgressable, spilledRecordsCounter, null, null);
+        final Writer writer = new Writer(job, fs, outputPath,
+            keyClass, valueClass, codec, null);
+        try {
+          TezMerger.writeFile(rIter, writer, nullProgressable, job);
+          // add to list of final disk outputs.
+          onDiskMapOutputs.add(outputPath);
+        } catch (IOException e) {
+          if (null != outputPath) {
+            try {
+              fs.delete(outputPath, true);
+            } catch (IOException ie) {
+              // NOTHING
+            }
+          }
+          throw e;
+        } finally {
+          if (null != writer) {
+            writer.close();
+          }
+        }
+        LOG.info("Merged " + numMemDiskSegments + " segments, " +
+                 inMemToDiskBytes + " bytes to disk to satisfy " +
+                 "reduce memory limit");
+        inMemToDiskBytes = 0;
+        memDiskSegments.clear();
+      } else if (inMemToDiskBytes != 0) {
+        LOG.info("Keeping " + numMemDiskSegments + " segments, " +
+                 inMemToDiskBytes + " bytes in memory for " +
+                 "intermediate, on-disk merge");
+      }
+    }
+
+    // segments on disk
+    List<Segment> diskSegments = new ArrayList<Segment>();
+    long onDiskBytes = inMemToDiskBytes;
+    Path[] onDisk = onDiskMapOutputs.toArray(new Path[onDiskMapOutputs.size()]);
+    for (Path file : onDisk) {
+      onDiskBytes += fs.getFileStatus(file).getLen();
+      LOG.debug("Disk file: " + file + " Length is " +
+          fs.getFileStatus(file).getLen());
+      diskSegments.add(new Segment(job, fs, file, codec, false,
+                                   (file.toString().endsWith(
+                                       Constants.MERGED_OUTPUT_PREFIX) ?
+                                    null : mergedMapOutputsCounter)
+                                   ));
+    }
+    LOG.info("Merging " + onDisk.length + " files, " +
+             onDiskBytes + " bytes from disk");
+    Collections.sort(diskSegments, new Comparator<Segment>() {
+      public int compare(Segment o1, Segment o2) {
+        if (o1.getLength() == o2.getLength()) {
+          return 0;
+        }
+        return o1.getLength() < o2.getLength() ?
+            -1 : 1;
+      }
+    });
+
+    // build final list of segments from merged on-disk and in-memory segments
+    List<Segment> finalSegments = new ArrayList<Segment>();
+    long inMemBytes = createInMemorySegments(inMemoryMapOutputs,
+                                             finalSegments, 0);
+    LOG.info("Merging " + finalSegments.size() + " segments, " +
+             inMemBytes + " bytes from memory into reduce");
+    if (0 != onDiskBytes) {
+      final int numInMemSegments = memDiskSegments.size();
+      diskSegments.addAll(0, memDiskSegments);
+      memDiskSegments.clear();
+      TezRawKeyValueIterator diskMerge = TezMerger.merge(
+          job, fs, keyClass, valueClass, diskSegments,
+          ioSortFactor, numInMemSegments, tmpDir, comparator,
+          nullProgressable, false, spilledRecordsCounter, null, null);
+      diskSegments.clear();
+      if (0 == finalSegments.size()) {
+        return diskMerge;
+      }
+      finalSegments.add(new Segment(
+            new RawKVIteratorReader(diskMerge, onDiskBytes), true));
+    }
+    return TezMerger.merge(job, fs, keyClass, valueClass,
+                 finalSegments, finalSegments.size(), tmpDir,
+                 comparator, nullProgressable, spilledRecordsCounter, null,
+                 null);
+
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/b212ca1d/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/MergeThread.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/MergeThread.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/MergeThread.java
new file mode 100644
index 0000000..d8a7722
--- /dev/null
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/MergeThread.java
@@ -0,0 +1,108 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tez.runtime.library.common.shuffle.impl;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Set;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+abstract class MergeThread<T> extends Thread {
+
+  private static final Log LOG = LogFactory.getLog(MergeThread.class);
+
+  private volatile boolean inProgress = false;
+  private List<T> inputs = new ArrayList<T>();
+  protected final MergeManager manager;
+  private final ExceptionReporter reporter;
+  private boolean closed = false;
+  private final int mergeFactor;
+
+  public MergeThread(MergeManager manager, int mergeFactor,
+                     ExceptionReporter reporter) {
+    this.manager = manager;
+    this.mergeFactor = mergeFactor;
+    this.reporter = reporter;
+  }
+
+  public synchronized void close() throws InterruptedException {
+    closed = true;
+    waitForMerge();
+    interrupt();
+  }
+
+  public synchronized boolean isInProgress() {
+    return inProgress;
+  }
+
+  public synchronized void startMerge(Set<T> inputs) {
+    if (!closed) {
+      inProgress = true;
+      this.inputs = new ArrayList<T>();
+      Iterator<T> iter = inputs.iterator();
+      for (int ctr = 0; iter.hasNext() && ctr < mergeFactor; ++ctr) {
+        this.inputs.add(iter.next());
+        iter.remove();
+      }
+      LOG.info(getName() + ": Starting merge with " + this.inputs.size() +
+               " segments, while ignoring " + inputs.size() + " segments");
+      notifyAll();
+    }
+  }
+
+  public synchronized void waitForMerge() throws InterruptedException {
+    while (inProgress) {
+      wait();
+    }
+  }
+
+  public void run() {
+    while (true) {
+      try {
+        // Wait for notification to start the merge...
+        synchronized (this) {
+          while (!inProgress) {
+            wait();
+          }
+        }
+
+        // Merge
+        merge(inputs);
+      } catch (InterruptedException ie) {
+        return;
+      } catch (Throwable t) {
+        reporter.reportException(t);
+        return;
+      } finally {
+        synchronized (this) {
+          // Clear inputs
+          inputs = null;
+          inProgress = false;
+          notifyAll();
+        }
+      }
+    }
+  }
+
+  public abstract void merge(List<T> inputs)
+      throws IOException, InterruptedException;
+}
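
The MergeThread contract above is a compact producer/consumer handshake: callers hand a batch to startMerge(), the run() loop consumes it through the subclass's merge(), and close() waits for any in-flight merge before interrupting the thread. MergeManager drives this lifecycle directly; for example, closeOnDiskFile() fires startMerge() once 2 * ioSortFactor - 1 on-disk outputs have accumulated, and MergeManager.close() closes each merger before the final merge. A minimal stand-alone sketch of the same handshake, using plain JDK types (the StringMerger class and its main() driver below are illustrative only, not part of this commit):

    import java.util.ArrayList;
    import java.util.HashSet;
    import java.util.List;
    import java.util.Set;

    // Illustrative stand-in for MergeThread's lifecycle: callers hand over a
    // batch via startMerge(), the thread consumes it in merge(), and close()
    // drains the in-flight batch before interrupting the thread.
    class StringMerger extends Thread {
      private volatile boolean inProgress = false;
      private boolean closed = false;
      private List<String> inputs = new ArrayList<String>();

      public synchronized void startMerge(Set<String> batch) {
        if (!closed) {
          inProgress = true;
          inputs = new ArrayList<String>(batch);
          notifyAll();               // wake the run() loop below
        }
      }

      public synchronized void close() throws InterruptedException {
        closed = true;
        while (inProgress) {         // drain the in-flight merge first
          wait();
        }
        interrupt();                 // then stop the idle thread
      }

      @Override
      public void run() {
        while (true) {
          try {
            synchronized (this) {
              while (!inProgress) {  // sleep until startMerge() fires
                wait();
              }
            }
            merge(inputs);           // do the actual work outside the lock
          } catch (InterruptedException ie) {
            return;                  // close() interrupted us: exit cleanly
          } finally {
            synchronized (this) {
              inProgress = false;
              notifyAll();           // unblock close()/waitForMerge()
            }
          }
        }
      }

      void merge(List<String> batch) throws InterruptedException {
        System.out.println("merging " + batch.size() + " segments");
      }

      public static void main(String[] args) throws InterruptedException {
        StringMerger merger = new StringMerger();
        merger.setDaemon(true);
        merger.start();
        Set<String> batch = new HashSet<String>();
        batch.add("segment-1");
        batch.add("segment-2");
        merger.startMerge(batch);    // threshold crossed: hand off a batch
        merger.close();              // waits for the merge, then stops the thread
      }
    }

As in MergeThread itself, the predicate loops around wait() guard against spurious wakeups, and running merge() outside the monitor lets callers check isInProgress() without blocking behind a long merge.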