diff --git a/.gitignore b/.gitignore index 2ba823c2..7d1026e7 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ cover/ .cache/ *.iml /scripts/.thrift_gen +build/ \ No newline at end of file diff --git a/TCLIService/TCLIService-remote b/TCLIService/TCLIService-remote index 8d875fa7..482d25bb 100755 --- a/TCLIService/TCLIService-remote +++ b/TCLIService/TCLIService-remote @@ -21,7 +21,7 @@ from TCLIService.ttypes import * if len(sys.argv) <= 1 or sys.argv[1] == '--help': print('') - print('Usage: ' + sys.argv[0] + ' [-h host[:port]] [-u url] [-f[ramed]] [-s[sl]] [-novalidate] [-ca_certs certs] [-keyfile keyfile] [-certfile certfile] function [arg1 [arg2...]]') + print('Usage: ' + sys.argv[0] + ' [-h host[:port]] [-u url] [-f[ramed]] [-s[sl]] [-novalidate] [-keepalive] [-ca_certs certs] [-keyfile keyfile] [-certfile certfile] function [arg1 [arg2...]]') print('') print('Functions:') print(' TOpenSessionResp OpenSession(TOpenSessionReq req)') @@ -56,6 +56,7 @@ uri = '' framed = False ssl = False validate = True +keepalive = False ca_certs = None keyfile = None certfile = None @@ -95,6 +96,10 @@ if sys.argv[argi] == '-novalidate': validate = False argi += 1 +if sys.argv[argi] == '-keepalive': + keepalive = True + argi += 1 + if sys.argv[argi] == '-ca_certs': ca_certs = sys.argv[argi+1] argi += 2 @@ -114,9 +119,9 @@ if http: transport = THttpClient.THttpClient(host, port, uri) else: if ssl: - socket = TSSLSocket.TSSLSocket(host, port, validate=validate, ca_certs=ca_certs, keyfile=keyfile, certfile=certfile) + socket = TSSLSocket.TSSLSocket(host, port, validate=validate, ca_certs=ca_certs, keyfile=keyfile, certfile=certfile, socket_keepalive=keepalive) else: - socket = TSocket.TSocket(host, port) + socket = TSocket.TSocket(host, port, socket_keepalive=keepalive) if framed: transport = TTransport.TFramedTransport(socket) else: diff --git a/dev_requirements.txt b/dev_requirements.txt index 40bb605a..d61f9336 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -14,6 +14,6 @@ requests_kerberos>=0.12.0 sasl>=0.2.1 pure-sasl>=0.6.2 kerberos>=1.3.0 -thrift>=0.10.0 +thrift>=0.13.0 #thrift_sasl>=0.1.0 git+https://github.com/cloudera/thrift_sasl # Using master branch in order to get Python 3 SASL patches diff --git a/pyhive/__init__.py b/pyhive/__init__.py index 0a6bb1f6..aa1b9a92 100644 --- a/pyhive/__init__.py +++ b/pyhive/__init__.py @@ -1,3 +1,3 @@ from __future__ import absolute_import from __future__ import unicode_literals -__version__ = '0.7.0' +__version__ = '0.7.1' diff --git a/pyhive/hive.py b/pyhive/hive.py index c1287488..e3cbd739 100644 --- a/pyhive/hive.py +++ b/pyhive/hive.py @@ -32,6 +32,7 @@ import thrift.protocol.TBinaryProtocol import thrift.transport.TSocket import thrift.transport.TTransport +import thrift.transport.TSSLSocket # PEP 249 module globals apilevel = '2.0' @@ -159,7 +160,10 @@ def __init__( password=None, check_hostname=None, ssl_cert=None, - thrift_transport=None + thrift_transport=None, + timeout=None, + query_timeout=None, + thrift_keepalive=False ): """Connect to HiveServer2 @@ -191,6 +195,7 @@ def __init__( ), ssl_context=ssl_context, ) + thrift_transport.setTimeout(timeout) if auth in ("BASIC", "NOSASL", "NONE", None): # Always needs the Authorization header @@ -233,7 +238,20 @@ def __init__( port = 10000 if auth is None: auth = 'NONE' - socket = thrift.transport.TSocket.TSocket(host, port) + if configuration.get('use_ssl', False): + _logger.info("Using SSL for Hive connection") + hive_ssl_context = create_default_context() + hive_ssl_context.load_verify_locations(capath=configuration.get('ca_certs_dir', '/etc/ssl/certs/')) + hive_ssl_context.check_hostname = check_hostname == configuration.get('ssl_check_hostname', "true") + socket = thrift.transport.TSSLSocket.TSSLSocket(host, port, ssl_context=hive_ssl_context, + socket_keepalive=thrift_keepalive) + configuration.pop("use_ssl", None) + configuration.pop("ca_certs_dir", None) + configuration.pop("ssl_check_hostname", None) + else: + _logger.info("Using Non-SSL for Hive connection") + socket = thrift.transport.TSocket.TSocket(host, port, socket_keepalive=thrift_keepalive) + socket.setTimeout(timeout) if auth == 'NOSASL': # NOSASL corresponds to hive.server2.authentication=NOSASL in hive-site.xml self._transport = thrift.transport.TTransport.TBufferedTransport(socket) @@ -278,6 +296,7 @@ def __init__( self._sessionHandle = response.sessionHandle assert response.serverProtocolVersion == protocol_version, \ "Unable to handle protocol version {}".format(response.serverProtocolVersion) + socket.setTimeout(query_timeout) with contextlib.closing(self.cursor()) as cursor: cursor.execute('USE `{}`'.format(database)) except: diff --git a/pyhive/tests/test_hive.py b/pyhive/tests/test_hive.py index b49fc190..dd638252 100644 --- a/pyhive/tests/test_hive.py +++ b/pyhive/tests/test_hive.py @@ -199,6 +199,8 @@ def test_invalid_transport(self): lambda: hive.connect(_HOST, thrift_transport=transport) ) + # TODO test keepalive + def test_custom_transport(self): socket = thrift.transport.TSocket.TSocket('localhost', 10000) sasl_auth = 'PLAIN' diff --git a/setup.py b/setup.py index d141ea1b..f5cdd5d1 100755 --- a/setup.py +++ b/setup.py @@ -45,8 +45,8 @@ def run_tests(self): extras_require={ 'presto': ['requests>=1.0.0'], 'trino': ['requests>=1.0.0'], - 'hive': ['sasl>=0.2.1', 'thrift>=0.10.0', 'thrift_sasl>=0.1.0'], - 'hive_pure_sasl': ['pure-sasl>=0.6.2', 'thrift>=0.10.0', 'thrift_sasl>=0.1.0'], + 'hive': ['sasl>=0.2.1', 'thrift>=0.13.0', 'thrift_sasl>=0.1.0'], + 'hive_pure_sasl': ['pure-sasl>=0.6.2', 'thrift>=0.13.0', 'thrift_sasl>=0.1.0'], 'sqlalchemy': ['sqlalchemy>=1.3.0'], 'kerberos': ['requests_kerberos>=0.12.0'], }, @@ -60,7 +60,7 @@ def run_tests(self): 'pure-sasl>=0.6.2', 'kerberos>=1.3.0', 'sqlalchemy>=1.3.0', - 'thrift>=0.10.0', + 'thrift>=0.13.0', ], cmdclass={'test': PyTest}, package_data={