Skip to content

Commit

Permalink
some small refactors.
Browse files Browse the repository at this point in the history
  • Loading branch information
youngsofun committed Jul 22, 2022
1 parent 542c3a0 commit e3fb42f
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 37 deletions.
3 changes: 1 addition & 2 deletions tests/logictest/complete.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@

target_dir = "./"

http_client = HttpConnector()
http_client.connect(**http_config)
http_client = HttpConnector(**http_config)


def run(source_file, target_path="."):
Expand Down
3 changes: 1 addition & 2 deletions tests/logictest/gen_suites.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,7 @@ def get_all_tests_under_dir_recursive(suite_dir):

def parse_cases(sql_file):
# New session every case file
http_client = HttpConnector()
http_client.connect(**http_config)
http_client = HttpConnector(**http_config)
cnx = mysql.connector.connect(**mysql_config)
mysql_client = cnx.cursor()

Expand Down
57 changes: 27 additions & 30 deletions tests/logictest/http_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,19 @@ def format_result(results):
return ""

for line in results:
lineTmp = ""
buf = ""
for item in line:
if isinstance(item, bool):
item = str.lower(str(item))
if lineTmp == "":
lineTmp = str(item)
if buf == "":
buf = str(item)
else:
lineTmp = lineTmp + " " + str(
buf = buf + " " + str(
item) # every item seperate by space
if len(lineTmp) == 0:
if len(buf) == 0:
# empty line in results will replace with tab
lineTmp = "\t"
res = res + lineTmp + "\n"
buf = "\t"
res = res + buf + "\n"
return res


Expand All @@ -71,16 +71,16 @@ def get_data_type(field):

def get_query_options(response):
ret = ""
if get_error(response) != None:
if get_error(response) is not None:
return ret
for field in response['schema']['fields']:
type = str.lower(get_data_type(field))
log.debug(f"type:{type}")
if "int" in type:
typ = str.lower(get_data_type(field))
log.debug(f"type:{typ}")
if "int" in typ:
ret = ret + "I"
elif "float" in type or "double" in type:
elif "float" in typ or "double" in typ:
ret = ret + "F"
elif "bool" in type:
elif "bool" in typ:
ret = ret + "B"
else:
ret = ret + "T"
Expand All @@ -103,12 +103,12 @@ def get_error(response):
return None

# Wrap errno into msg, for result check
wrapMsg = f"errno:{response['error']['code']},msg:{response['error']['message']}"
return Error(msg=wrapMsg, errno=response['error']['code'])
wrap_msg = f"errno:{response['error']['code']},msg:{response['error']['message']}"
return Error(msg=wrap_msg, errno=response['error']['code'])


class HttpConnector():
# Databend http hander doc: https://databend.rs/doc/reference/api/rest
class HttpConnector(object):
# Databend http handler doc: https://databend.rs/doc/reference/api/rest

# Call connect(**driver)
# driver is a dict contains:
Expand All @@ -118,30 +118,28 @@ class HttpConnector():
# 'port': 3307,
# 'database': 'default'
# }
def __init__(self):
self._session = None

def connect(self, host, port, user="root", database=default_database):
def __init__(self, host, port, user="root", database=default_database):
self._host = host
self._port = port
self._user = user
self._database = database
self._session_max_idle_time = 30
self._session = ClientSession()
self._additonal_headers = dict()
self._additional_headers = dict()
self._query_option = None
e = environs.Env()
if os.getenv("ADDITIONAL_HEADERS") is not None:
self._additonal_headers = e.dict("ADDITIONAL_HEADERS")
self._additional_headers = e.dict("ADDITIONAL_HEADERS")

def make_headers(self):
if "Authorization" not in self._additonal_headers:
if "Authorization" not in self._additional_headers:
return {
**headers, "Authorization":
"Basic " + base64.b64encode("{}:{}".format(
self._user, "").encode(encoding="utf-8")).decode()
}
else:
return {**headers, **self._additonal_headers}
return {**headers, **self._additional_headers}

def query(self, statement, session):
url = f"http://{self._host}:{self._port}/v1/query/"
Expand All @@ -152,10 +150,10 @@ def parseSQL(sql):
# SELECT parse_json('"false"')::boolean; => SELECT parse_json('\"false\"')::boolean;
if '"' in sql:
if '\'' in sql:
return str.replace(sql, '"', '\\\"') # " -> \"
return str.replace(sql, "\"", "'") # " -> '
return str.replace(sql, '"', '\\\"') # " -> \"
return str.replace(sql, "\"", "'") # " -> '
else:
return sql # do nothing
return sql # do nothing

log.debug(f"http sql: {parseSQL(statement)}")
query_sql = {'sql': parseSQL(statement), "string_fields": True}
Expand Down Expand Up @@ -191,7 +189,7 @@ def query_with_session(self, statement):
try:
resp = requests.get(url="http://{}:{}{}".format(
self._host, self._port, response['next_uri']),
headers=self.make_headers())
headers=self.make_headers())
response = json.loads(resp.content)
log.debug(
f"Sql in progress, fetch next_uri content: {response}")
Expand Down Expand Up @@ -227,7 +225,6 @@ def fetch_all(self, statement):
def get_query_option(self):
return self._query_option


# if __name__ == '__main__':
# from config import http_config
# connector = HttpConnector()
Expand Down
4 changes: 1 addition & 3 deletions tests/logictest/http_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import logictest
import http_connector
from log import log


class TestHttp(logictest.SuiteRunner, ABC):
Expand All @@ -13,8 +12,7 @@ def __init__(self, kind, pattern):

def get_connection(self):
if self._http is None:
self._http = http_connector.HttpConnector()
self._http.connect(**self.driver)
self._http = http_connector.HttpConnector(**self.driver)
return self._http

def reset_connection(self):
Expand Down

0 comments on commit e3fb42f

Please sign in to comment.