cassandra-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From alek...@apache.org
Subject [1/6] cassandra git commit: (cqlsh) further optimise COPY FROM
Date Tue, 15 Dec 2015 21:40:42 GMT
Repository: cassandra
Updated Branches:
  refs/heads/trunk a018bcb7d -> bab66dd1a


(cqlsh) further optimise COPY FROM

patch by Stefania Alborghetti; reviewed by Adam Holmberg for
CASSANDRA-9302


Project: http://git-wip-us.apache.org/repos/asf/cassandra/repo
Commit: http://git-wip-us.apache.org/repos/asf/cassandra/commit/124f1bd2
Tree: http://git-wip-us.apache.org/repos/asf/cassandra/tree/124f1bd2
Diff: http://git-wip-us.apache.org/repos/asf/cassandra/diff/124f1bd2

Branch: refs/heads/trunk
Commit: 124f1bd2613e400f69f8369ada0ad15c28738530
Parents: 994250c
Author: Stefania Alborghetti <stefania.alborghetti@datastax.com>
Authored: Thu Oct 22 17:16:50 2015 +0800
Committer: Aleksey Yeschenko <aleksey@apache.org>
Committed: Tue Dec 15 21:03:31 2015 +0000

----------------------------------------------------------------------
 CHANGES.txt                |   4 +-
 bin/cqlsh                  | 285 ++-----------
 pylib/cqlshlib/copyutil.py | 910 ++++++++++++++++++++++++++++++++++------
 pylib/cqlshlib/util.py     |  19 +
 4 files changed, 838 insertions(+), 380 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/cassandra/blob/124f1bd2/CHANGES.txt
----------------------------------------------------------------------
diff --git a/CHANGES.txt b/CHANGES.txt
index 8e58703..90f1bca 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,12 +1,10 @@
 2.1.13
-<<<<<<< HEAD
+ * (cqlsh) further optimise COPY FROM (CASSANDRA-9302)
  * Allow CREATE TABLE WITH ID (CASSANDRA-9179)
  * Make Stress compiles within eclipse (CASSANDRA-10807)
  * Cassandra Daemon should print JVM arguments (CASSANDRA-10764)
  * Allow cancellation of index summary redistribution (CASSANDRA-8805)
-=======
  * sstableloader will fail if there are collections in the schema tables (CASSANDRA-10700)
->>>>>>> 5377183... stableloader will fail if there are collections in the schema tables
  * Disable reloading of GossipingPropertyFileSnitch (CASSANDRA-9474)
  * Fix Stress profile parsing on Windows (CASSANDRA-10808)
 

http://git-wip-us.apache.org/repos/asf/cassandra/blob/124f1bd2/bin/cqlsh
----------------------------------------------------------------------
diff --git a/bin/cqlsh b/bin/cqlsh
index e72624a..651420d 100755
--- a/bin/cqlsh
+++ b/bin/cqlsh
@@ -37,7 +37,6 @@ import ConfigParser
 import csv
 import getpass
 import locale
-import multiprocessing as mp
 import optparse
 import os
 import platform
@@ -48,7 +47,6 @@ import warnings
 
 from StringIO import StringIO
 from contextlib import contextmanager
-from functools import partial
 from glob import glob
 from uuid import UUID
 
@@ -110,10 +108,10 @@ except ImportError, e:
 
 from cassandra.auth import PlainTextAuthProvider
 from cassandra.cluster import Cluster, PagedResult
-from cassandra.metadata import protect_name, protect_names, protect_value
+from cassandra.metadata import protect_name, protect_names
 from cassandra.policies import WhiteListRoundRobinPolicy
-from cassandra.protocol import QueryMessage, ResultMessage
-from cassandra.query import SimpleStatement, ordered_dict_factory
+from cassandra.protocol import ResultMessage
+from cassandra.query import SimpleStatement, ordered_dict_factory, tuple_factory
 
 # cqlsh should run correctly when run out of a Cassandra source tree,
 # out of an unpacked Cassandra tarball, and after a proper package install.
@@ -334,7 +332,7 @@ cqlsh_extra_syntax_rules = r'''
 
 <copyOptionVal> ::= <identifier>
                   | <reserved_identifier>
-                  | <stringLiteral>
+                  | <term>
                   ;
 
 # avoiding just "DEBUG" so that this rule doesn't get treated as a terminal
@@ -412,17 +410,20 @@ def complete_copy_column_names(ctxt, cqlsh):
     return set(colnames[1:]) - set(existcols)
 
 
-COPY_OPTIONS = ['DELIMITER', 'QUOTE', 'ESCAPE', 'HEADER', 'NULL', 'ENCODING',
-                'TIMEFORMAT', 'JOBS', 'PAGESIZE', 'PAGETIMEOUT', 'MAXATTEMPTS']
+COPY_COMMON_OPTIONS = ['DELIMITER', 'QUOTE', 'ESCAPE', 'HEADER', 'NULL',
+                       'MAXATTEMPTS', 'REPORTFREQUENCY']
+COPY_FROM_OPTIONS = ['CHUNKSIZE', 'INGESTRATE', 'MAXBATCHSIZE', 'MINBATCHSIZE']
+COPY_TO_OPTIONS = ['ENCODING', 'TIMEFORMAT', 'PAGESIZE', 'PAGETIMEOUT', 'MAXREQUESTS']
 
 
 @cqlsh_syntax_completer('copyOption', 'optnames')
 def complete_copy_options(ctxt, cqlsh):
     optnames = map(str.upper, ctxt.get_binding('optnames', ()))
     direction = ctxt.get_binding('dir').upper()
-    opts = set(COPY_OPTIONS) - set(optnames)
     if direction == 'FROM':
-        opts -= set(['ENCODING', 'TIMEFORMAT', 'JOBS', 'PAGESIZE', 'PAGETIMEOUT', 'MAXATTEMPTS'])
+        opts = set(COPY_COMMON_OPTIONS + COPY_FROM_OPTIONS) - set(optnames)
+    elif direction == 'TO':
+        opts = set(COPY_COMMON_OPTIONS + COPY_TO_OPTIONS) - set(optnames)
     return opts
 
 
@@ -1520,12 +1521,18 @@ class Shell(cmd.Cmd):
           ESCAPE='\'              - character to appear before the QUOTE char when quoted
           HEADER=false            - whether to ignore the first line
           NULL=''                 - string that represents a null value
-          ENCODING='utf8'         - encoding for CSV output (COPY TO only)
-          TIMEFORMAT=             - timestamp strftime format (COPY TO only)
+          ENCODING='utf8'         - encoding for CSV output (COPY TO)
+          TIMEFORMAT=             - timestamp strftime format (COPY TO)
             '%Y-%m-%d %H:%M:%S%z'   defaults to time_format value in cqlshrc
-          PAGESIZE='1000'         - the page size for fetching results (COPY TO only)
-          PAGETIMEOUT=10          - the page timeout for fetching results (COPY TO only)
-          MAXATTEMPTS='5'         - the maximum number of attempts for errors (COPY TO only)
+          MAXREQUESTS=6           - the maximum number of requests each worker process can work on in parallel (COPY TO)
+          PAGESIZE=1000           - the page size for fetching results (COPY TO)
+          PAGETIMEOUT=10          - the page timeout for fetching results (COPY TO)
+          MAXATTEMPTS=5           - the maximum number of attempts for errors
+          CHUNKSIZE=1000          - the size of chunks passed to worker processes (COPY FROM)
+          INGESTRATE=100000       - an approximate ingest rate in rows per second (COPY FROM)
+          MAXBATCHSIZE=20         - the maximum size of an import batch (COPY FROM)
+          MINBATCHSIZE=2          - the minimum size of an import batch (COPY FROM)
+          REPORTFREQUENCY=0.25    - the frequency with which we display status updates in seconds
 
         When entering CSV data on STDIN, you can use the sequence "\."
         on a line by itself to end the data input.
@@ -1571,253 +1578,11 @@ class Shell(cmd.Cmd):
     def perform_csv_import(self, ks, cf, columns, fname, opts):
         csv_options, dialect_options, unrecognized_options = copyutil.parse_options(self, opts)
         if unrecognized_options:
-            self.printerr('Unrecognized COPY FROM options: %s'
-                          % ', '.join(unrecognized_options.keys()))
+            self.printerr('Unrecognized COPY FROM options: %s' % ', '.join(unrecognized_options.keys()))
             return 0
-        nullval, header = csv_options['nullval'], csv_options['header']
 
-        if fname is None:
-            do_close = False
-            print "[Use \. on a line by itself to end input]"
-            linesource = self.use_stdin_reader(prompt='[copy] ', until=r'\.')
-        else:
-            do_close = True
-            try:
-                linesource = open(fname, 'rb')
-            except IOError, e:
-                self.printerr("Can't open %r for reading: %s" % (fname, e))
-                return 0
-
-        current_record = None
-        processes, pipes = [], [],
-        try:
-            if header:
-                linesource.next()
-            reader = csv.reader(linesource, **dialect_options)
-
-            num_processes = copyutil.get_num_processes(cap=4)
-
-            for i in range(num_processes):
-                parent_conn, child_conn = mp.Pipe()
-                pipes.append(parent_conn)
-                proc_args = (child_conn, ks, cf, columns, nullval)
-                processes.append(mp.Process(target=self.multiproc_import, args=proc_args))
-
-            for process in processes:
-                process.start()
-
-            meter = copyutil.RateMeter(10000)
-            for current_record, row in enumerate(reader, start=1):
-                # write to the child process
-                pipes[current_record % num_processes].send((current_record, row))
-
-                # update the progress and current rate periodically
-                meter.increment()
-
-                # check for any errors reported by the children
-                if (current_record % 100) == 0:
-                    if self._check_import_processes(current_record, pipes):
-                        # no errors seen, continue with outer loop
-                        continue
-                    else:
-                        # errors seen, break out of outer loop
-                        break
-        except Exception, exc:
-            if current_record is None:
-                # we failed before we started
-                self.printerr("\nError starting import process:\n")
-                self.printerr(str(exc))
-                if self.debug:
-                    traceback.print_exc()
-            else:
-                self.printerr("\n" + str(exc))
-                self.printerr("\nAborting import at record #%d. "
-                              "Previously inserted records and some records after "
-                              "this number may be present."
-                              % (current_record,))
-                if self.debug:
-                    traceback.print_exc()
-        finally:
-            # send a message that indicates we're done
-            for pipe in pipes:
-                pipe.send((None, None))
-
-            for process in processes:
-                process.join()
-
-            self._check_import_processes(current_record, pipes)
-
-            for pipe in pipes:
-                pipe.close()
-
-            if do_close:
-                linesource.close()
-            elif self.tty:
-                print
-
-        return current_record
-
-    def _check_import_processes(self, current_record, pipes):
-        for pipe in pipes:
-            if pipe.poll():
-                try:
-                    (record_num, error) = pipe.recv()
-                    self.printerr("\n" + str(error))
-                    self.printerr(
-                        "Aborting import at record #%d. "
-                        "Previously inserted records are still present, "
-                        "and some records after that may be present as well."
-                        % (record_num,))
-                    return False
-                except EOFError:
-                    # pipe is closed, nothing to read
-                    self.printerr("\nChild process died without notification, "
-                                  "aborting import at record #%d. Previously "
-                                  "inserted records are probably still present, "
-                                  "and some records after that may be present "
-                                  "as well." % (current_record,))
-                    return False
-        return True
-
-    def multiproc_import(self, pipe, ks, cf, columns, nullval):
-        """
-        This method is where child processes start when doing a COPY FROM
-        operation.  The child process will open one connection to the node and
-        interact directly with the connection, bypassing most of the driver
-        code.  Because we don't need retries, connection pooling, thread safety,
-        and other fancy features, this is okay.
-        """
-
-        # open a new connection for this subprocess
-        new_cluster = Cluster(
-            contact_points=(self.hostname,),
-            port=self.port,
-            cql_version=self.conn.cql_version,
-            protocol_version=DEFAULT_PROTOCOL_VERSION,
-            auth_provider=self.auth_provider,
-            ssl_options=sslhandling.ssl_settings(self.hostname, CONFIG_FILE) if self.ssl else None,
-            load_balancing_policy=WhiteListRoundRobinPolicy([self.hostname]),
-            compression=None,
-            connect_timeout=self.conn.connect_timeout)
-        session = new_cluster.connect(self.keyspace)
-        conn = session._pools.values()[0]._connection
-
-        # pre-build as much of the query as we can
-        table_meta = self.get_table_meta(ks, cf)
-        pk_cols = [col.name for col in table_meta.primary_key]
-        cqltypes = [table_meta.columns[name].typestring for name in columns]
-        pk_indexes = [columns.index(col.name) for col in table_meta.primary_key]
-        is_counter_table = ("counter" in cqltypes)
-        if is_counter_table:
-            query = 'Update %s.%s SET %%s WHERE %%s' % (
-                protect_name(ks),
-                protect_name(cf))
-        else:
-            query = 'INSERT INTO %s.%s (%s) VALUES (%%s)' % (
-                protect_name(ks),
-                protect_name(cf),
-                ', '.join(protect_names(columns)))
-
-        # we need to handle some types specially
-        should_escape = [t in ('ascii', 'text', 'timestamp', 'date', 'time', 'inet') for t in cqltypes]
-
-        insert_timestamp = int(time.time() * 1e6)
-
-        def callback(record_num, response):
-            # This is the callback we register for all inserts.  Because this
-            # is run on the event-loop thread, we need to hold a lock when
-            # adjusting in_flight.
-            with conn.lock:
-                conn.in_flight -= 1
-
-            if not isinstance(response, ResultMessage):
-                # It's an error. Notify the parent process and let it send
-                # a stop signal to all child processes (including this one).
-                pipe.send((record_num, str(response)))
-                if isinstance(response, Exception) and self.debug:
-                    traceback.print_exc(response)
-
-        current_record = 0
-        insert_num = 0
-        try:
-            while True:
-                # To avoid totally maxing out the connection,
-                # defer to the reactor thread when we're close
-                # to capacity
-                if conn.in_flight > (conn.max_request_id * 0.9):
-                    conn._readable = True
-                    time.sleep(0.05)
-                    continue
-
-                try:
-                    (current_record, row) = pipe.recv()
-                except EOFError:
-                    # the pipe was closed and there's nothing to receive
-                    sys.stdout.write('Failed to read from pipe:\n\n')
-                    sys.stdout.flush()
-                    conn._writable = True
-                    conn._readable = True
-                    break
-
-                # see if the parent process has signaled that we are done
-                if (current_record, row) == (None, None):
-                    conn._writable = True
-                    conn._readable = True
-                    pipe.close()
-                    break
-
-                # format the values in the row
-                for i, value in enumerate(row):
-                    if value != nullval:
-                        if should_escape[i]:
-                            row[i] = protect_value(value)
-                    elif i in pk_indexes:
-                        # By default, nullval is an empty string. See CASSANDRA-7792 for details.
-                        message = "Cannot insert null value for primary key column '%s'." % (pk_cols[i],)
-                        if nullval == '':
-                            message += " If you want to insert empty strings, consider using " \
-                                       "the WITH NULL=<marker> option for COPY."
-                        pipe.send((current_record, message))
-                        return
-                    else:
-                        row[i] = 'null'
-                if is_counter_table:
-                    where_clause = []
-                    set_clause = []
-                    for i, value in enumerate(row):
-                        if i in pk_indexes:
-                            where_clause.append("%s=%s" % (columns[i], value))
-                        else:
-                            set_clause.append("%s=%s+%s" % (columns[i], columns[i], value))
-                    full_query = query % (','.join(set_clause), ' AND '.join(where_clause))
-                else:
-                    full_query = query % (','.join(row),)
-                query_message = QueryMessage(
-                    full_query, self.consistency_level, serial_consistency_level=None,
-                    fetch_size=None, paging_state=None, timestamp=insert_timestamp)
-
-                request_id = conn.get_request_id()
-                conn.send_msg(query_message, request_id=request_id, cb=partial(callback, current_record))
-
-                with conn.lock:
-                    conn.in_flight += 1
-
-                # every 50 records, clear the pending writes queue and read
-                # any responses we have
-                if insert_num % 50 == 0:
-                    conn._writable = True
-                    conn._readable = True
-
-                insert_num += 1
-        except Exception, exc:
-            pipe.send((current_record, exc))
-        finally:
-            # wait for any pending requests to finish
-            while conn.in_flight > 0:
-                conn._readable = True
-                time.sleep(0.01)
-
-            new_cluster.shutdown()
+        return copyutil.ImportTask(self, ks, cf, columns, fname, csv_options, dialect_options,
+                                   DEFAULT_PROTOCOL_VERSION, CONFIG_FILE).run()
 
     def perform_csv_export(self, ks, cf, columns, fname, opts):
         csv_options, dialect_options, unrecognized_options = copyutil.parse_options(self, opts)

http://git-wip-us.apache.org/repos/asf/cassandra/blob/124f1bd2/pylib/cqlshlib/copyutil.py
----------------------------------------------------------------------
diff --git a/pylib/cqlshlib/copyutil.py b/pylib/cqlshlib/copyutil.py
index 8534b98..f699e64 100644
--- a/pylib/cqlshlib/copyutil.py
+++ b/pylib/cqlshlib/copyutil.py
@@ -19,23 +19,32 @@ import json
 import multiprocessing as mp
 import os
 import Queue
+import random
+import re
+import struct
 import sys
 import time
 import traceback
 
-from StringIO import StringIO
+from calendar import timegm
+from collections import defaultdict, deque, namedtuple
+from decimal import Decimal
 from random import randrange
+from StringIO import StringIO
 from threading import Lock
+from uuid import UUID
 
 from cassandra.cluster import Cluster
+from cassandra.cqltypes import ReversedType, UserType
 from cassandra.metadata import protect_name, protect_names
-from cassandra.policies import RetryPolicy, WhiteListRoundRobinPolicy, TokenAwarePolicy
-from cassandra.query import tuple_factory
+from cassandra.policies import RetryPolicy, WhiteListRoundRobinPolicy, TokenAwarePolicy, DCAwareRoundRobinPolicy
+from cassandra.query import BatchStatement, BatchType, SimpleStatement, tuple_factory
+from cassandra.util import Date, Time
 
-
-import sslhandling
+from cql3handling import CqlRuleSet
 from displaying import NO_COLOR_MAP
 from formatting import format_value_default, EMPTY, get_formatter
+from sslhandling import ssl_settings
 
 
 def parse_options(shell, opts):
@@ -60,13 +69,18 @@ def parse_options(shell, opts):
     csv_options['nullval'] = opts.pop('null', '')
     csv_options['header'] = bool(opts.pop('header', '').lower() == 'true')
     csv_options['encoding'] = opts.pop('encoding', 'utf8')
-    csv_options['jobs'] = int(opts.pop('jobs', 12))
+    csv_options['maxrequests'] = int(opts.pop('maxrequests', 6))
     csv_options['pagesize'] = int(opts.pop('pagesize', 1000))
     # by default the page timeout is 10 seconds per 1000 entries in the page size or 10 seconds if pagesize is smaller
     csv_options['pagetimeout'] = int(opts.pop('pagetimeout', max(10, 10 * (csv_options['pagesize'] / 1000))))
     csv_options['maxattempts'] = int(opts.pop('maxattempts', 5))
     csv_options['dtformats'] = opts.pop('timeformat', shell.display_time_format)
     csv_options['float_precision'] = shell.display_float_precision
+    csv_options['chunksize'] = int(opts.pop('chunksize', 1000))
+    csv_options['ingestrate'] = int(opts.pop('ingestrate', 100000))
+    csv_options['maxbatchsize'] = int(opts.pop('maxbatchsize', 20))
+    csv_options['minbatchsize'] = int(opts.pop('minbatchsize', 2))
+    csv_options['reportfrequency'] = float(opts.pop('reportfrequency', 0.25))
 
     return csv_options, dialect_options, opts
 
@@ -86,9 +100,9 @@ def get_num_processes(cap):
         return 1
 
 
-class ExportTask(object):
+class CopyTask(object):
     """
-    A class that exports data to .csv by instantiating one or more processes that work in parallel (ExportProcess).
+    A base class for ImportTask and ExportTask
     """
     def __init__(self, shell, ks, cf, columns, fname, csv_options, dialect_options, protocol_version, config_file):
         self.shell = shell
@@ -101,6 +115,55 @@ class ExportTask(object):
         self.protocol_version = protocol_version
         self.config_file = config_file
 
+        self.processes = []
+        self.inmsg = mp.Queue()
+        self.outmsg = mp.Queue()
+
+    def close(self):
+        for process in self.processes:
+            process.terminate()
+
+        self.inmsg.close()
+        self.outmsg.close()
+
+    def num_live_processes(self):
+        return sum(1 for p in self.processes if p.is_alive())
+
+    def make_params(self):
+        """
+        Return a dictionary of parameters to be used by the worker processes.
+        On Windows this dictionary must be pickle-able.
+
+        inmsg is the message queue flowing from parent to child process, so outmsg from the parent point
+        of view and, vice-versa,  outmsg is the message queue flowing from child to parent, so inmsg
+        from the parent point of view, hence the two are swapped below.
+        """
+        shell = self.shell
+        return dict(inmsg=self.outmsg,  # see comment above
+                    outmsg=self.inmsg,  # see comment above
+                    ks=self.ks,
+                    cf=self.cf,
+                    columns=self.columns,
+                    csv_options=self.csv_options,
+                    dialect_options=self.dialect_options,
+                    consistency_level=shell.consistency_level,
+                    connect_timeout=shell.conn.connect_timeout,
+                    hostname=shell.hostname,
+                    port=shell.port,
+                    ssl=shell.ssl,
+                    auth_provider=shell.auth_provider,
+                    cql_version=shell.conn.cql_version,
+                    config_file=self.config_file,
+                    protocol_version=self.protocol_version,
+                    debug=shell.debug
+                    )
+
+
+class ExportTask(CopyTask):
+    """
+    A class that exports data to .csv by instantiating one or more processes that work in parallel (ExportProcess).
+    """
+
     def run(self):
         """
         Initiates the export by creating the processes.
@@ -125,25 +188,18 @@ class ExportTask(object):
 
         ranges = self.get_ranges()
         num_processes = get_num_processes(cap=min(16, len(ranges)))
+        params = self.make_params()
 
-        inmsg = mp.Queue()
-        outmsg = mp.Queue()
-        processes = []
         for i in xrange(num_processes):
-            process = ExportProcess(outmsg, inmsg, self.ks, self.cf, self.columns, self.dialect_options,
-                                    self.csv_options, shell.debug, shell.port, shell.conn.cql_version,
-                                    shell.auth_provider, shell.ssl, self.protocol_version, self.config_file)
+            self.processes.append(ExportProcess(params))
+
+        for process in self.processes:
             process.start()
-            processes.append(process)
 
         try:
-            return self.check_processes(csvdest, ranges, inmsg, outmsg, processes)
+            return self.check_processes(csvdest, ranges)
         finally:
-            for process in processes:
-                process.terminate()
-
-            inmsg.close()
-            outmsg.close()
+            self.close()
             if do_close:
                 csvdest.close()
 
@@ -183,9 +239,9 @@ class ExportTask(object):
 
             hosts = []
             for host in replicas:
-                if host.datacenter == local_dc:
+                if host.is_up and host.datacenter == local_dc:
                     hosts.append(host.address)
-            if len(hosts) == 0:
+            if not hosts:
                 hosts.append(hostname)  # fallback to default host if no replicas in current dc
             ranges[(previous, token.value)] = make_range(hosts)
             previous_previous = previous
@@ -194,7 +250,7 @@ class ExportTask(object):
         #  If the ring is empty we get the entire ring from the
         #  host we are currently connected to, otherwise for the last ring interval
         #  we query the same replicas that hold the last token in the ring
-        if len(ranges) == 0:
+        if not ranges:
             ranges[(None, None)] = make_range([hostname])
         else:
             ranges[(previous, None)] = ranges[(previous_previous, previous)].copy()
@@ -217,32 +273,32 @@ class ExportTask(object):
         else:
             return None
 
-    @staticmethod
-    def send_work(ranges, tokens_to_send, queue):
+    def send_work(self, ranges, tokens_to_send):
         for token_range in tokens_to_send:
-            queue.put((token_range, ranges[token_range]))
+            self.outmsg.put((token_range, ranges[token_range]))
             ranges[token_range]['attempts'] += 1
 
-    def check_processes(self, csvdest, ranges, inmsg, outmsg, processes):
+    def check_processes(self, csvdest, ranges):
         """
         Here we monitor all child processes by collecting their results
         or any errors. We terminate when we have processed all the ranges or when there
         are no more processes.
         """
         shell = self.shell
-        meter = RateMeter(10000)
-        total_jobs = len(ranges)
+        processes = self.processes
+        meter = RateMeter(update_interval=self.csv_options['reportfrequency'])
+        total_requests = len(ranges)
         max_attempts = self.csv_options['maxattempts']
 
-        self.send_work(ranges, ranges.keys(), outmsg)
+        self.send_work(ranges, ranges.keys())
 
         num_processes = len(processes)
         succeeded = 0
         failed = 0
-        while (failed + succeeded) < total_jobs and self.num_live_processes(processes) == num_processes:
+        while (failed + succeeded) < total_requests and self.num_live_processes() == num_processes:
             try:
-                token_range, result = inmsg.get(timeout=1.0)
-                if token_range is None and result is None:  # a job has finished
+                token_range, result = self.inmsg.get(timeout=1.0)
+                if token_range is None and result is None:  # a request has finished
                     succeeded += 1
                 elif isinstance(result, Exception):  # an error occurred
                     if token_range is None:  # the entire process failed
@@ -253,7 +309,7 @@ class ExportTask(object):
                         if ranges[token_range]['attempts'] < max_attempts and ranges[token_range]['rows'] == 0:
                             shell.printerr('Error for %s: %s (will try again later attempt %d of %d)'
                                            % (token_range, result, ranges[token_range]['attempts'], max_attempts))
-                            self.send_work(ranges, [token_range], outmsg)
+                            self.send_work(ranges, [token_range])
                         else:
                             shell.printerr('Error for %s: %s (permanently given up after %d rows and %d attempts)'
                                            % (token_range, result, ranges[token_range]['rows'],
@@ -267,34 +323,257 @@ class ExportTask(object):
             except Queue.Empty:
                 pass
 
-        if self.num_live_processes(processes) < len(processes):
+        if self.num_live_processes() < len(processes):
             for process in processes:
                 if not process.is_alive():
                     shell.printerr('Child process %d died with exit code %d' % (process.pid, process.exitcode))
 
-        if succeeded < total_jobs:
+        if succeeded < total_requests:
             shell.printerr('Exported %d ranges out of %d total ranges, some records might be missing'
-                           % (succeeded, total_jobs))
+                           % (succeeded, total_requests))
 
         return meter.get_total_records()
 
+
+class ImportReader(object):
+    """
+    A wrapper around a csv reader to keep track of when we have
+    exhausted reading input records.
+    """
+    def __init__(self, linesource, chunksize, dialect_options):
+        self.linesource = linesource
+        self.chunksize = chunksize
+        self.reader = csv.reader(linesource, **dialect_options)
+        self.exhausted = False
+
+    def read_rows(self):
+        if self.exhausted:
+            return []
+
+        rows = list(next(self.reader) for _ in xrange(self.chunksize))
+        self.exhausted = len(rows) < self.chunksize
+        return rows
+
+
+class ImportTask(CopyTask):
+    """
+    A class to import data from .csv by instantiating one or more processes
+    that work in parallel (ImportProcess).
+    """
+    def __init__(self, shell, ks, cf, columns, fname, csv_options, dialect_options, protocol_version, config_file):
+        CopyTask.__init__(self, shell, ks, cf, columns, fname,
+                          csv_options, dialect_options, protocol_version, config_file)
+
+        self.num_processes = get_num_processes(cap=4)
+        self.chunk_size = csv_options['chunksize']
+        self.ingest_rate = csv_options['ingestrate']
+        self.max_attempts = csv_options['maxattempts']
+        self.header = self.csv_options['header']
+        self.table_meta = self.shell.get_table_meta(self.ks, self.cf)
+        self.batch_id = 0
+        self.receive_meter = RateMeter(update_interval=csv_options['reportfrequency'])
+        self.send_meter = RateMeter(update_interval=1, log=False)
+        self.retries = deque([])
+        self.failed = 0
+        self.succeeded = 0
+        self.sent = 0
+
+    def run(self):
+        shell = self.shell
+
+        if self.fname is None:
+            do_close = False
+            print "[Use \. on a line by itself to end input]"
+            linesource = shell.use_stdin_reader(prompt='[copy] ', until=r'\.')
+        else:
+            do_close = True
+            try:
+                linesource = open(self.fname, 'rb')
+            except IOError, e:
+                shell.printerr("Can't open %r for reading: %s" % (self.fname, e))
+                return 0
+
+        try:
+            if self.header:
+                linesource.next()
+
+            reader = ImportReader(linesource, self.chunk_size, self.dialect_options)
+            params = self.make_params()
+
+            for i in range(self.num_processes):
+                self.processes.append(ImportProcess(params))
+
+            for process in self.processes:
+                process.start()
+
+            return self.process_records(reader)
+
+        except Exception, exc:
+            shell.printerr(str(exc))
+            if shell.debug:
+                traceback.print_exc()
+            return 0
+        finally:
+            self.close()
+            if do_close:
+                linesource.close()
+            elif shell.tty:
+                print
+
+    def process_records(self, reader):
+        """
+        Keep on running while we have stuff to receive or send and while all processes are running.
+        Send data (batches or retries) up to the max ingest rate. If we are waiting for stuff to
+        receive, check the incoming queue.
+        """
+        while (self.has_more_to_send(reader) or self.has_more_to_receive()) and self.all_processes_running():
+            if self.has_more_to_send(reader):
+                if self.send_meter.current_record <= self.ingest_rate:
+                    self.send_batches(reader)
+                else:
+                    self.send_meter.maybe_update()
+
+            if self.has_more_to_receive():
+                self.receive()
+
+        if self.succeeded < self.sent:
+            self.shell.printerr("Failed to process %d batches" % (self.sent - self.succeeded))
+
+        return self.receive_meter.get_total_records()
+
+    def has_more_to_receive(self):
+        return (self.succeeded + self.failed) < self.sent
+
+    def has_more_to_send(self, reader):
+        return (not reader.exhausted) or self.retries
+
+    def all_processes_running(self):
+        return self.num_live_processes() == self.num_processes
+
+    def receive(self):
+        shell = self.shell
+        start_time = time.time()
+
+        while time.time() - start_time < 0.01:  # 10 millis
+            try:
+                batch, err = self.inmsg.get(timeout=0.001)  # 1 millisecond
+
+                if err is None:
+                    self.succeeded += batch['imported']
+                    self.receive_meter.increment(batch['imported'])
+                else:
+                    err = str(err)
+
+                    if err.startswith('ValueError') or err.startswith('TypeError') or err.startswith('IndexError') \
+                            or batch['attempts'] >= self.max_attempts:
+                        shell.printerr("Failed to import %d rows: %s -  given up after %d attempts"
+                                       % (len(batch['rows']), err, batch['attempts']))
+                        self.failed += len(batch['rows'])
+                    else:
+                        shell.printerr("Failed to import %d rows: %s -  will retry later, attempt %d of %d"
+                                       % (len(batch['rows']), err, batch['attempts'],
+                                          self.max_attempts))
+                        self.retries.append(self.reset_batch(batch))
+            except Queue.Empty:
+                break
+
+    def send_batches(self, reader):
+        """
+        Send batches to the queue until we have exceeded the ingest rate. In the export case we queue
+        everything and let the worker processes throttle using max_requests, here we throttle
+        in the parent process because of memory usage concerns.
+
+        When we have finished reading the csv file, then send any retries.
+        """
+        while self.send_meter.current_record <= self.ingest_rate:
+            if not reader.exhausted:
+                rows = reader.read_rows()
+                if rows:
+                    self.sent += self.send_batch(self.new_batch(rows))
+            elif self.retries:
+                batch = self.retries.popleft()
+                self.send_batch(batch)
+            else:
+                break
+
+    def send_batch(self, batch):
+        batch['attempts'] += 1
+        num_rows = len(batch['rows'])
+        self.send_meter.increment(num_rows)
+        self.outmsg.put(batch)
+        return num_rows
+
+    def new_batch(self, rows):
+        self.batch_id += 1
+        return self.make_batch(self.batch_id, rows, 0)
+
+    @staticmethod
+    def reset_batch(batch):
+        batch['imported'] = 0
+        return batch
+
     @staticmethod
-    def num_live_processes(processes):
-        return sum(1 for p in processes if p.is_alive())
+    def make_batch(batch_id, rows, attempts):
+        return {'id': batch_id, 'rows': rows, 'attempts': attempts, 'imported': 0}
+
+
+class ChildProcess(mp.Process):
+    """
+    A child worker process; this holds the functionality common to ImportProcess and ExportProcess.
+    """
+
+    def __init__(self, params, target):
+        mp.Process.__init__(self, target=target)
+        self.inmsg = params['inmsg']
+        self.outmsg = params['outmsg']
+        self.ks = params['ks']
+        self.cf = params['cf']
+        self.columns = params['columns']
+        self.debug = params['debug']
+        self.port = params['port']
+        self.hostname = params['hostname']
+        self.consistency_level = params['consistency_level']
+        self.connect_timeout = params['connect_timeout']
+        self.cql_version = params['cql_version']
+        self.auth_provider = params['auth_provider']
+        self.ssl = params['ssl']
+        self.protocol_version = params['protocol_version']
+        self.config_file = params['config_file']
+
+        # Here we inject some failures for testing purposes, only if this environment variable is set
+        if os.environ.get('CQLSH_COPY_TEST_FAILURES', ''):
+            self.test_failures = json.loads(os.environ.get('CQLSH_COPY_TEST_FAILURES', ''))
+        else:
+            self.test_failures = None
+
+    def printmsg(self, text):
+        if self.debug:
+            sys.stderr.write(text + os.linesep)
+
+    def close(self):
+        self.printmsg("Closing queues...")
+        self.inmsg.close()
+        self.outmsg.close()
 
 
 class ExpBackoffRetryPolicy(RetryPolicy):
     """
-    A retry policy with exponential back-off for read timeouts,
-    see ExportProcess.
+    A retry policy with exponential back-off for read timeouts and write timeouts
     """
-    def __init__(self, export_process):
+    def __init__(self, parent_process):
         RetryPolicy.__init__(self)
-        self.max_attempts = export_process.csv_options['maxattempts']
-        self.printmsg = lambda txt: export_process.printmsg(txt)
+        self.max_attempts = parent_process.max_attempts
+        self.printmsg = parent_process.printmsg
 
     def on_read_timeout(self, query, consistency, required_responses,
                         received_responses, data_retrieved, retry_num):
+        return self._handle_timeout(consistency, retry_num)
+
+    def on_write_timeout(self, query, consistency, write_type,
+                         required_responses, received_responses, retry_num):
+        return self._handle_timeout(consistency, retry_num)
+
+    def _handle_timeout(self, consistency, retry_num):
         delay = self.backoff(retry_num)
         if delay > 0:
             self.printmsg("Timeout received, retrying after %d seconds" % (delay))
@@ -327,7 +606,7 @@ class ExpBackoffRetryPolicy(RetryPolicy):
 class ExportSession(object):
     """
     A class for connecting to a cluster and storing the number
-    of jobs that this connection is processing. It wraps the methods
+    of requests that this connection is processing. It wraps the methods
     for executing a query asynchronously and for shutting down the
     connection to the cluster.
     """
@@ -342,20 +621,20 @@ class ExportSession(object):
 
         self.cluster = cluster
         self.session = session
-        self.jobs = 1
+        self.requests = 1
         self.lock = Lock()
 
-    def add_job(self):
+    def add_request(self):
         with self.lock:
-            self.jobs += 1
+            self.requests += 1
 
-    def complete_job(self):
+    def complete_request(self):
         with self.lock:
-            self.jobs -= 1
+            self.requests -= 1
 
-    def num_jobs(self):
+    def num_requests(self):
         with self.lock:
-            return self.jobs
+            return self.requests
 
     def execute_async(self, query):
         return self.session.execute_async(query)
@@ -364,48 +643,26 @@ class ExportSession(object):
         self.cluster.shutdown()
 
 
-class ExportProcess(mp.Process):
+class ExportProcess(ChildProcess):
     """
     An child worker process for the export task, ExportTask.
     """
 
-    def __init__(self, inmsg, outmsg, ks, cf, columns, dialect_options, csv_options,
-                 debug, port, cql_version, auth_provider, ssl, protocol_version, config_file):
-        mp.Process.__init__(self, target=self.run)
-        self.inmsg = inmsg
-        self.outmsg = outmsg
-        self.ks = ks
-        self.cf = cf
-        self.columns = columns
-        self.dialect_options = dialect_options
+    def __init__(self, params):
+        ChildProcess.__init__(self, params=params, target=self.run)
+        self.dialect_options = params['dialect_options']
         self.hosts_to_sessions = dict()
 
-        self.debug = debug
-        self.port = port
-        self.cql_version = cql_version
-        self.auth_provider = auth_provider
-        self.ssl = ssl
-        self.protocol_version = protocol_version
-        self.config_file = config_file
-
+        csv_options = params['csv_options']
         self.encoding = csv_options['encoding']
         self.time_format = csv_options['dtformats']
         self.float_precision = csv_options['float_precision']
         self.nullval = csv_options['nullval']
-        self.maxjobs = csv_options['jobs']
+        self.max_attempts = csv_options['maxattempts']
+        self.max_requests = csv_options['maxrequests']
         self.csv_options = csv_options
         self.formatters = dict()
 
-        # Here we inject some failures for testing purposes, only if this environment variable is set
-        if os.environ.get('CQLSH_COPY_TEST_FAILURES', ''):
-            self.test_failures = json.loads(os.environ.get('CQLSH_COPY_TEST_FAILURES', ''))
-        else:
-            self.test_failures = None
-
-    def printmsg(self, text):
-        if self.debug:
-            sys.stderr.write(text + os.linesep)
-
     def run(self):
         try:
             self.inner_run()
@@ -423,12 +680,12 @@ class ExportProcess(mp.Process):
         We terminate when the inbound queue is closed.
         """
         while True:
-            if self.num_jobs() > self.maxjobs:
+            if self.num_requests() > self.max_requests:
                 time.sleep(0.001)  # 1 millisecond
                 continue
 
             token_range, info = self.inmsg.get()
-            self.start_job(token_range, info)
+            self.start_request(token_range, info)
 
     def report_error(self, err, token_range=None):
         if isinstance(err, str):
@@ -443,7 +700,7 @@ class ExportProcess(mp.Process):
         self.printmsg(msg)
         self.outmsg.put((token_range, Exception(msg)))
 
-    def start_job(self, token_range, info):
+    def start_request(self, token_range, info):
         """
         Begin querying a range by executing an async query that
         will later on invoke the callbacks attached in attach_callbacks.
@@ -454,14 +711,14 @@ class ExportProcess(mp.Process):
         future = session.execute_async(query)
         self.attach_callbacks(token_range, future, session)
 
-    def num_jobs(self):
-        return sum(session.num_jobs() for session in self.hosts_to_sessions.values())
+    def num_requests(self):
+        return sum(session.num_requests() for session in self.hosts_to_sessions.values())
 
     def get_session(self, hosts):
         """
         We select a host to connect to. If we have no connections to one of the hosts
         yet then we select this host, else we pick the one with the smallest number
-        of jobs.
+        of requests.
 
         :return: An ExportSession connected to the chosen host.
         """
@@ -474,19 +731,18 @@ class ExportProcess(mp.Process):
                 cql_version=self.cql_version,
                 protocol_version=self.protocol_version,
                 auth_provider=self.auth_provider,
-                ssl_options=sslhandling.ssl_settings(host, self.config_file) if self.ssl else None,
+                ssl_options=ssl_settings(host, self.config_file) if self.ssl else None,
                 load_balancing_policy=TokenAwarePolicy(WhiteListRoundRobinPolicy(hosts)),
                 default_retry_policy=ExpBackoffRetryPolicy(self),
-                compression=None,
-                executor_threads=max(2, self.csv_options['jobs'] / 2))
+                compression=None)
 
             session = ExportSession(new_cluster, self)
             self.hosts_to_sessions[host] = session
             return session
         else:
-            host = min(hosts, key=lambda h: self.hosts_to_sessions[h].jobs)
+            host = min(hosts, key=lambda h: self.hosts_to_sessions[h].requests)
             session = self.hosts_to_sessions[host]
-            session.add_job()
+            session.add_request()
             return session
 
     def attach_callbacks(self, token_range, future, session):
@@ -497,16 +753,16 @@ class ExportProcess(mp.Process):
             else:
                 self.write_rows_to_csv(token_range, rows)
                 self.outmsg.put((None, None))
-                session.complete_job()
+                session.complete_request()
 
         def err_callback(err):
             self.report_error(err, token_range)
-            session.complete_job()
+            session.complete_request()
 
         future.add_callbacks(callback=result_callback, errback=err_callback)
 
     def write_rows_to_csv(self, token_range, rows):
-        if len(rows) == 0:
+        if not rows:
             return  # no rows in this range
 
         try:
@@ -537,12 +793,9 @@ class ExportProcess(mp.Process):
                          float_precision=self.float_precision, nullval=self.nullval, quote=False)
 
     def close(self):
-        self.printmsg("Export process terminating...")
-        self.inmsg.close()
-        self.outmsg.close()
+        ChildProcess.close(self)
         for session in self.hosts_to_sessions.values():
             session.shutdown()
-        self.printmsg("Export process terminated")
 
     def prepare_query(self, partition_key, token_range, attempts):
         """
@@ -598,26 +851,439 @@ class ExportProcess(mp.Process):
         return query
 
 
+class ImportConversion(object):
+    """
+    A class for converting strings to values when importing from csv, used by ImportProcess,
+    the parent.
+    """
+    def __init__(self, parent, table_meta, statement):
+        self.ks = parent.ks
+        self.cf = parent.cf
+        self.columns = parent.columns
+        self.nullval = parent.nullval
+        self.printmsg = parent.printmsg
+        self.table_meta = table_meta
+        self.primary_key_indexes = [self.columns.index(col.name) for col in self.table_meta.primary_key]
+        self.partition_key_indexes = [self.columns.index(col.name) for col in self.table_meta.partition_key]
+
+        self.proto_version = statement.protocol_version
+        self.cqltypes = dict([(c.name, c.type) for c in statement.column_metadata])
+        self.converters = dict([(c.name, self._get_converter(c.type)) for c in statement.column_metadata])
+
+    def _get_converter(self, cql_type):
+        """
+        Return a function that converts a string into a value that can be passed
+        into BoundStatement.bind() for the given cql type. See cassandra.cqltypes
+        for more details.
+        """
+        def unprotect(v):
+            if v is not None:
+                return CqlRuleSet.dequote_value(v)
+
+        def convert(t, v):
+            return converters.get(t.typename, convert_unknown)(unprotect(v), ct=t)
+
+        def split(val, sep=','):
+            """
+            Split into a list of values whenever we encounter a separator but
+            ignore separators inside parentheses or single quotes, except for the two
+            outermost parentheses, which will be ignored. We expect val to be at least
+            2 characters long (the two outer parentheses).
+            """
+            ret = []
+            last = 1
+            level = 0
+            quote = False
+            for i, c in enumerate(val):
+                if c == '{' or c == '[' or c == '(':
+                    level += 1
+                elif c == '}' or c == ']' or c == ')':
+                    level -= 1
+                elif c == '\'':
+                    quote = not quote
+                elif c == sep and level == 1 and not quote:
+                    ret.append(val[last:i])
+                    last = i + 1
+            else:
+                if last < len(val) - 1:
+                    ret.append(val[last:-1])
+
+            return ret
+
+        # this should match all possible CQL datetime formats
+        p = re.compile("(\d{4})\-(\d{2})\-(\d{2})\s?(?:'T')?" +  # YYYY-MM-DD[( |'T')]
+                       "(?:(\d{2}):(\d{2})(?::(\d{2}))?)?" +  # [HH:MM[:SS]]
+                       "(?:([+\-])(\d{2}):?(\d{2}))?")  # [(+|-)HH[:]MM]]
+
+        def convert_date(val, **_):
+            m = p.match(val)
+            if not m:
+                raise ValueError("can't interpret %r as a date" % (val,))
+
+            # https://docs.python.org/2/library/time.html#time.struct_time
+            tval = time.struct_time((int(m.group(1)), int(m.group(2)), int(m.group(3)),  # year, month, day
+                                     int(m.group(4)) if m.group(4) else 0,  # hour
+                                     int(m.group(5)) if m.group(5) else 0,  # minute
+                                     int(m.group(6)) if m.group(6) else 0,  # second
+                                     0, 1, -1))  # day of week, day of year, dst-flag
+
+            if m.group(7):
+                offset = (int(m.group(8)) * 3600 + int(m.group(9)) * 60) * int(m.group(7) + '1')
+            else:
+                offset = -time.timezone
+
+            # scale seconds to millis for the raw value
+            return (timegm(tval) + offset) * 1e3
+
+        def convert_tuple(val, ct=cql_type):
+            return tuple(convert(t, v) for t, v in zip(ct.subtypes, split(val)))
+
+        def convert_list(val, ct=cql_type):
+            return list(convert(ct.subtypes[0], v) for v in split(val))
+
+        def convert_set(val, ct=cql_type):
+            return frozenset(convert(ct.subtypes[0], v) for v in split(val))
+
+        def convert_map(val, ct=cql_type):
+            """
+            We need to pass to BoundStatement.bind() a dict() because it calls iteritems(),
+            except we can't create a dict with another dict as the key, hence we use a class
+            that adds iteritems to a frozen set of tuples (which is how dicts are normally made
+            immutable in python).
+            """
+            class ImmutableDict(frozenset):
+                iteritems = frozenset.__iter__
+
+            return ImmutableDict(frozenset((convert(ct.subtypes[0], v[0]), convert(ct.subtypes[1], v[1]))
+                                 for v in [split('{%s}' % vv, sep=':') for vv in split(val)]))
+
+        def convert_user_type(val, ct=cql_type):
+            """
+            A user type is a dictionary except that we must convert each key into
+            an attribute, so we are using named tuples. It must also be hashable,
+            so we cannot use dictionaries. Maybe there is a way to instantiate ct
+            directly but I could not work it out.
+            """
+            vals = [v for v in [split('{%s}' % vv, sep=':') for vv in split(val)]]
+            ret_type = namedtuple(ct.typename, [unprotect(v[0]) for v in vals])
+            return ret_type(*tuple(convert(t, v[1]) for t, v in zip(ct.subtypes, vals)))
+
+        def convert_single_subtype(val, ct=cql_type):
+            return converters.get(ct.subtypes[0].typename, convert_unknown)(val, ct=ct.subtypes[0])
+
+        def convert_unknown(val, ct=cql_type):
+            if issubclass(ct, UserType):
+                return convert_user_type(val, ct=ct)
+            elif issubclass(ct, ReversedType):
+                return convert_single_subtype(val, ct=ct)
+
+            self.printmsg("Unknown type %s (%s) for val %s" % (ct, ct.typename, val))
+            return val
+
+        converters = {
+            'blob': (lambda v, ct=cql_type: bytearray.fromhex(v[2:])),
+            'decimal': (lambda v, ct=cql_type: Decimal(v)),
+            'uuid': (lambda v, ct=cql_type: UUID(v)),
+            'boolean': (lambda v, ct=cql_type: bool(v)),
+            'tinyint': (lambda v, ct=cql_type: int(v)),
+            'ascii': (lambda v, ct=cql_type: v),
+            'float': (lambda v, ct=cql_type: float(v)),
+            'double': (lambda v, ct=cql_type: float(v)),
+            'bigint': (lambda v, ct=cql_type: long(v)),
+            'int': (lambda v, ct=cql_type: int(v)),
+            'varint': (lambda v, ct=cql_type: int(v)),
+            'inet': (lambda v, ct=cql_type: v),
+            'counter': (lambda v, ct=cql_type: long(v)),
+            'timestamp': convert_date,
+            'timeuuid': (lambda v, ct=cql_type: UUID(v)),
+            'date': (lambda v, ct=cql_type: Date(v)),
+            'smallint': (lambda v, ct=cql_type: int(v)),
+            'time': (lambda v, ct=cql_type: Time(v)),
+            'text': (lambda v, ct=cql_type: v),
+            'varchar': (lambda v, ct=cql_type: v),
+            'list': convert_list,
+            'set': convert_set,
+            'map': convert_map,
+            'tuple': convert_tuple,
+            'frozen': convert_single_subtype,
+        }
+
+        return converters.get(cql_type.typename, convert_unknown)
+
+    def get_row_values(self, row):
+        """
+        Parse the row into a list of row values to be returned
+        """
+        ret = [None] * len(row)
+        for i, val in enumerate(row):
+            if val != self.nullval:
+                ret[i] = self.converters[self.columns[i]](val)
+            else:
+                if i in self.primary_key_indexes:
+                    message = "Cannot insert null value for primary key column '%s'." % (self.columns[i],)
+                    if self.nullval == '':
+                        message += " If you want to insert empty strings, consider using" \
+                                   " the WITH NULL=<marker> option for COPY."
+                    raise Exception(message=message)
+
+                ret[i] = None
+
+        return ret
+
+    def get_row_partition_key_values(self, row):
+        """
+        Return a string composed of the partition key values, serialized and binary packed -
+        as expected by metadata.get_replicas(), see also BoundStatement.routing_key.
+        """
+        def serialize(n):
+            c, v = self.columns[n], row[n]
+            return self.cqltypes[c].serialize(self.converters[c](v), self.proto_version)
+
+        partition_key_indexes = self.partition_key_indexes
+        if len(partition_key_indexes) == 1:
+            return serialize(partition_key_indexes[0])
+        else:
+            pk_values = []
+            for i in partition_key_indexes:
+                val = serialize(i)
+                l = len(val)
+                pk_values.append(struct.pack(">H%dsB" % l, l, val, 0))
+            return b"".join(pk_values)
+
+
+class ImportProcess(ChildProcess):
+
+    def __init__(self, params):
+        ChildProcess.__init__(self, params=params, target=self.run)
+
+        csv_options = params['csv_options']
+        self.nullval = csv_options['nullval']
+        self.max_attempts = csv_options['maxattempts']
+        self.min_batch_size = csv_options['minbatchsize']
+        self.max_batch_size = csv_options['maxbatchsize']
+        self._session = None
+
+    @property
+    def session(self):
+        if not self._session:
+            cluster = Cluster(
+                contact_points=(self.hostname,),
+                port=self.port,
+                cql_version=self.cql_version,
+                protocol_version=self.protocol_version,
+                auth_provider=self.auth_provider,
+                load_balancing_policy=TokenAwarePolicy(DCAwareRoundRobinPolicy()),
+                ssl_options=ssl_settings(self.hostname, self.config_file) if self.ssl else None,
+                default_retry_policy=ExpBackoffRetryPolicy(self),
+                compression=None,
+                connect_timeout=self.connect_timeout)
+
+            self._session = cluster.connect(self.ks)
+            self._session.default_timeout = None
+        return self._session
+
+    def run(self):
+        try:
+            table_meta = self.session.cluster.metadata.keyspaces[self.ks].tables[self.cf]
+            is_counter = ("counter" in [table_meta.columns[name].typestring for name in self.columns])
+
+            if is_counter:
+                self.run_counter(table_meta)
+            else:
+                self.run_normal(table_meta)
+
+        except Exception, exc:
+            if self.debug:
+                traceback.print_exc(exc)
+
+        finally:
+            self.close()
+
+    def close(self):
+        if self._session:
+            self._session.cluster.shutdown()
+        ChildProcess.close(self)
+
+    def run_counter(self, table_meta):
+        """
+        Main run method for tables that contain counter columns.
+        """
+        query = 'UPDATE %s.%s SET %%s WHERE %%s' % (protect_name(self.ks), protect_name(self.cf))
+
+        # We prepare a query statement to find out the types of the partition key columns so we can
+        # route the update query to the correct replicas. As far as I understand, this is the easiest
+        # way to find out the types of the partition columns; we will never execute this prepared statement.
+        where_clause = ' AND '.join(['%s = ?' % (protect_name(c.name)) for c in table_meta.partition_key])
+        select_query = 'SELECT * FROM %s.%s WHERE %s' % (protect_name(self.ks), protect_name(self.cf), where_clause)
+        conv = ImportConversion(self, table_meta, self.session.prepare(select_query))
+
+        while True:
+            try:
+                batch = self.inmsg.get()
+
+                for batches in self.split_batches(batch, conv):
+                    for b in batches:
+                        self.send_counter_batch(query, conv, b)
+
+            except Exception, exc:
+                self.outmsg.put((batch, '%s - %s' % (exc.__class__.__name__, exc.message)))
+                if self.debug:
+                    traceback.print_exc(exc)
+
+    def run_normal(self, table_meta):
+        """
+        Main run method for normal tables, i.e. tables that do not contain counter columns.
+        """
+        query = 'INSERT INTO %s.%s (%s) VALUES (%s)' % (protect_name(self.ks),
+                                                        protect_name(self.cf),
+                                                        ', '.join(protect_names(self.columns),),
+                                                        ', '.join(['?' for _ in self.columns]))
+        query_statement = self.session.prepare(query)
+        conv = ImportConversion(self, table_meta, query_statement)
+
+        while True:
+            try:
+                batch = self.inmsg.get()
+
+                for batches in self.split_batches(batch, conv):
+                    for b in batches:
+                        self.send_normal_batch(conv, query_statement, b)
+
+            except Exception, exc:
+                self.outmsg.put((batch, '%s - %s' % (exc.__class__.__name__, exc.message)))
+                if self.debug:
+                    traceback.print_exc(exc)
+
+    def send_counter_batch(self, query_text, conv, batch):
+        if self.test_failures and self.maybe_inject_failures(batch):
+            return
+
+        columns = self.columns
+        batch_statement = BatchStatement(batch_type=BatchType.COUNTER, consistency_level=self.consistency_level)
+        for row in batch['rows']:
+            where_clause = []
+            set_clause = []
+            for i, value in enumerate(row):
+                if i in conv.primary_key_indexes:
+                    where_clause.append("%s=%s" % (columns[i], value))
+                else:
+                    set_clause.append("%s=%s+%s" % (columns[i], columns[i], value))
+
+            full_query_text = query_text % (','.join(set_clause), ' AND '.join(where_clause))
+            batch_statement.add(full_query_text)
+
+        self.execute_statement(batch_statement, batch)
+
+    def send_normal_batch(self, conv, query_statement, batch):
+        try:
+            if self.test_failures and self.maybe_inject_failures(batch):
+                return
+
+            batch_statement = BatchStatement(batch_type=BatchType.UNLOGGED, consistency_level=self.consistency_level)
+            for row in batch['rows']:
+                batch_statement.add(query_statement, conv.get_row_values(row))
+
+            self.execute_statement(batch_statement, batch)
+
+        except Exception, exc:
+            self.err_callback(exc, batch)
+
+    def maybe_inject_failures(self, batch):
+        """
+        Examine self.test_failures and see if token_range is either a token range
+        supposed to cause a failure (failing_range) or to terminate the worker process
+        (exit_range). If not then call prepare_export_query(), which implements the
+        normal behavior.
+        """
+        if 'failing_batch' in self.test_failures:
+            failing_batch = self.test_failures['failing_batch']
+            if failing_batch['id'] == batch['id']:
+                if batch['attempts'] < failing_batch['failures']:
+                    statement = SimpleStatement("INSERT INTO badtable (a, b) VALUES (1, 2)",
+                                                consistency_level=self.consistency_level)
+                    self.execute_statement(statement, batch)
+                    return True
+
+        if 'exit_batch' in self.test_failures:
+            exit_batch = self.test_failures['exit_batch']
+            if exit_batch['id'] == batch['id']:
+                sys.exit(1)
+
+        return False  # carry on as normal
+
+    def execute_statement(self, statement, batch):
+        future = self.session.execute_async(statement)
+        future.add_callbacks(callback=self.result_callback, callback_args=(batch, ),
+                             errback=self.err_callback, errback_args=(batch, ))
+
+    def split_batches(self, batch, conv):
+        """
+        Split a batch into sub-batches with the same
+        partition key, if possible. If there are at least
+        batch_size rows with the same partition key value then
+        create a sub-batch with that partition key value, else
+        aggregate all remaining rows in a single 'left-overs' batch
+        """
+        rows_by_pk = defaultdict(list)
+
+        for row in batch['rows']:
+            pk = conv.get_row_partition_key_values(row)
+            rows_by_pk[pk].append(row)
+
+        ret = dict()
+        remaining_rows = []
+
+        for pk, rows in rows_by_pk.items():
+            if len(rows) >= self.min_batch_size:
+                ret[pk] = self.batches(rows, batch)
+            else:
+                remaining_rows.extend(rows)
+
+        if remaining_rows:
+            ret[self.hostname] = self.batches(remaining_rows, batch)
+
+        return ret.itervalues()
+
+    def batches(self, rows, batch):
+        for i in xrange(0, len(rows), self.max_batch_size):
+            yield ImportTask.make_batch(batch['id'], rows[i:i + self.max_batch_size], batch['attempts'])
+
+    def result_callback(self, result, batch):
+        batch['imported'] = len(batch['rows'])
+        batch['rows'] = []  # no need to resend these
+        self.outmsg.put((batch, None))
+
+    def err_callback(self, response, batch):
+        batch['imported'] = len(batch['rows'])
+        self.outmsg.put((batch, '%s - %s' % (response.__class__.__name__, response.message)))
+        if self.debug:
+            traceback.print_exc(response)
+
+
 class RateMeter(object):
 
-    def __init__(self, log_threshold):
-        self.log_threshold = log_threshold  # number of records after which we log
-        self.last_checkpoint_time = time.time()  # last time we logged
+    def __init__(self, update_interval=0.25, log=True):
+        self.log = log  # true if we should log
+        self.update_interval = update_interval  # how often we update in seconds
+        self.start_time = time.time()  # the start time
+        self.last_checkpoint_time = self.start_time  # last time we logged
         self.current_rate = 0.0  # rows per second
-        self.current_record = 0  # number of records since we last logged
+        self.current_record = 0  # number of records since we last updated
         self.total_records = 0   # total number of records
 
     def increment(self, n=1):
         self.current_record += n
+        self.maybe_update()
 
-        if self.current_record >= self.log_threshold:
-            self.update()
-            self.log()
-
-    def update(self):
+    def maybe_update(self):
         new_checkpoint_time = time.time()
+        if new_checkpoint_time - self.last_checkpoint_time >= self.update_interval:
+            self.update(new_checkpoint_time)
+            self.log_message()
+
+    def update(self, new_checkpoint_time):
         time_difference = new_checkpoint_time - self.last_checkpoint_time
-        if time_difference != 0.0:
+        if time_difference >= 1e-09:
             self.current_rate = self.get_new_rate(self.current_record / time_difference)
 
         self.last_checkpoint_time = new_checkpoint_time
@@ -626,19 +1292,29 @@ class RateMeter(object):
 
     def get_new_rate(self, new_rate):
         """
-         return the previous rate averaged with the new rate to smooth a bit
+         return the rate of the last period: this is the new rate but
+         averaged with the last rate to smooth a bit
         """
         if self.current_rate == 0.0:
             return new_rate
         else:
             return (self.current_rate + new_rate) / 2.0
 
-    def log(self):
-        output = 'Processed %d rows; Written: %f rows/s\r' % (self.total_records, self.current_rate,)
-        sys.stdout.write(output)
-        sys.stdout.flush()
+    def get_avg_rate(self):
+        """
+         return the average rate since we started measuring
+        """
+        time_difference = time.time() - self.start_time
+        return self.total_records / time_difference if time_difference >= 1e-09 else 0
+
+    def log_message(self):
+        if self.log:
+            output = 'Processed: %d rows; Rate: %7.0f rows/s; Avg. rage: %7.0f rows/s\r' % \
+                     (self.total_records, self.current_rate, self.get_avg_rate())
+            sys.stdout.write(output)
+            sys.stdout.flush()
 
     def get_total_records(self):
-        self.update()
-        self.log()
+        self.update(time.time())
+        self.log_message()
         return self.total_records

http://git-wip-us.apache.org/repos/asf/cassandra/blob/124f1bd2/pylib/cqlshlib/util.py
----------------------------------------------------------------------
diff --git a/pylib/cqlshlib/util.py b/pylib/cqlshlib/util.py
index 4d6cf8a..281aad6 100644
--- a/pylib/cqlshlib/util.py
+++ b/pylib/cqlshlib/util.py
@@ -14,9 +14,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+
+import cProfile
 import codecs
+import pstats
+
 from itertools import izip
 from datetime import timedelta, tzinfo
+from StringIO import StringIO
 
 ZERO = timedelta(0)
 
@@ -122,3 +127,17 @@ def get_file_encoding_bomsize(filename):
         file_encoding, size = "utf-8", 0
 
     return (file_encoding, size)
+
+
+def profile_on():
+    pr = cProfile.Profile()
+    pr.enable()
+    return pr
+
+
+def profile_off(pr):
+    pr.disable()
+    s = StringIO()
+    ps = pstats.Stats(pr, stream=s).sort_stats('cumulative')
+    ps.print_stats()
+    print s.getvalue()


Mime
View raw message