Skip to content
This repository has been archived by the owner on Jan 22, 2025. It is now read-only.

Commit

Permalink
Merge pull request #83 from youngsofun/fix
Browse files Browse the repository at this point in the history
feat: support cookie and temp table
  • Loading branch information
hantmac authored Dec 11, 2024
2 parents dec102e + 0fd3aa0 commit d29c4a3
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 6 deletions.
25 changes: 22 additions & 3 deletions databend_py/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import time
import uuid

from http.cookiejar import Cookie
from requests.auth import HTTPBasicAuth
from requests.cookies import RequestsCookieJar

import environs
import requests
Expand Down Expand Up @@ -75,6 +77,17 @@ def get_error(response):
return ServerException(response["error"]["message"], response["error"]["code"])


class GlobalCookieJar(RequestsCookieJar):

def __init__(self):
super().__init__()

def set_cookie(self, cookie: Cookie, *args, **kwargs):
cookie.domain = ""
cookie.path = "/"
super().set_cookie(cookie, *args, **kwargs)


class Connection(object):
# Databend http handler doc: https://databend.rs/doc/reference/api/rest

Expand Down Expand Up @@ -120,6 +133,10 @@ def __init__(
self.context = Context()
self.requests_session = requests.Session()
self.schema = "http"
cookie_jar = GlobalCookieJar()
cookie_jar.set("cookie_enabled", "true")
self.requests_session.cookies = cookie_jar
self.schema = 'http'
if self.secure:
self.schema = "https"
e = environs.Env()
Expand Down Expand Up @@ -223,7 +240,9 @@ def query(self, statement):
log.logger.debug(f"http headers {self.make_headers()}")
try:
resp_dict = self.do_query(url, query_sql)
self.client_session = resp_dict.get("session", self.default_session())
new_session_state = resp_dict.get("session", self.default_session())
if new_session_state:
self.client_session = new_session_state
if self.additional_headers:
self.additional_headers.update(
{XDatabendQueryIDHeader: resp_dict.get(QueryID)}
Expand Down Expand Up @@ -286,15 +305,15 @@ def query_with_session(self, statement):
response_list.append(response)
start_time = time.time()
time_limit = 12
session = response.get("session", self.default_session())
session = response.get("session")
if session:
self.client_session = session
while response["next_uri"] is not None:
resp = self.next_page(response["next_uri"])
response = json.loads(resp.content)
log.logger.debug(f"Sql in progress, fetch next_uri content: {response}")
self.check_error(response)
session = response.get("session", self.default_session())
session = response.get("session")
if session:
self.client_session = session
response_list.append(response)
Expand Down
8 changes: 8 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,14 @@ def test_cast_bool(self):
_, data = client.execute("select 'False'::boolean union select 'True'::boolean")
self.assertEqual(data, [(True,), (False,)])

def test_temp_table(self):
client = Client.from_url(self.databend_url)
client.execute("create temp table t1(a int)")
client.execute("insert into t1 values (1)")
_, data = client.execute("select * from t1")
self.assertEqual(data, [(1,)])
client.execute("drop table t1")


if __name__ == "__main__":
unittest.main()
6 changes: 3 additions & 3 deletions tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ def __setattr__(self, key, value):


class TestDict(unittest.TestCase):
databend_url = None # 使用类属性来存储 databend_url
databend_url = None

@classmethod
def setUpClass(cls):
cls.databend_url = "test_url" # 在类级别设置 databend_url
cls.databend_url = "test_url"

def test_init(self):
d = Dict(a=1, b="test")
self.assertEqual(self.databend_url, "test_url") # 使用类属性
self.assertEqual(self.databend_url, "test_url")
self.assertEqual(d.a, 1)
self.assertEqual(d.b, "test")
self.assertTrue(isinstance(d, dict))
Expand Down

0 comments on commit d29c4a3

Please sign in to comment.