Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support with query #2895

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
goinception去重语法解析后的表,去掉with后的虚拟名
  • Loading branch information
woshiyanghai committed Jan 17, 2025
commit fd4b73803c8154edaa049dd0ae406da4e12224b8
36 changes: 25 additions & 11 deletions sql/engines/goinception.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,14 @@ def get_table_ref(query_tree, db_name=None):
小子树才是效率较高的算法,但是就这样吧,反正它能运行 :)
"""
table_ref = []
temporary_tables = set() # 用于存储临时表名

# 首先识别所有的临时表名
if "With" in query_tree:
logger.warning(query_tree["With"])
with_definitions = query_tree["With"].get("CTEs", []) # 假设临时表定义在CTEs键下
for definition in with_definitions:
temporary_tables.add(definition["Name"]["O"]) # 获取临时表的名称

find_queue = [query_tree]
for tree in find_queue:
Expand All @@ -326,20 +334,16 @@ def get_table_ref(query_tree, db_name=None):
else:
snodes = tree.find_max_tree("Source")
if snodes:
table_ref.extend(
[
{
"schema": snode["Source"].get("Schema", {}).get("O")
or db_name,
"name": snode["Source"].get("Name", {}).get("O", ""),
}
for snode in snodes
]
)
for snode in snodes:
schema_name = snode["Source"].get("Schema", {}).get("O") or db_name
table_name = snode["Source"].get("Name", {}).get("O", "")
# 检查表名是否为临时表,如果不是则添加到结果中
if table_name not in temporary_tables:
table_ref.append({"schema": schema_name, "name": table_name})
# assert: source node must exists if table_refs node exists.
# else:
# raise Exception("GoInception Error: not found source node")
return table_ref
return remove_duplicates(table_ref)

def close(self):
if self.conn:
Expand Down Expand Up @@ -380,3 +384,13 @@ def get_session_variables(instance):
for k, v in variables.items():
set_session_sql += f"inception set session {k} = '{v}';\n"
return variables, set_session_sql

def remove_duplicates(table_list):
unique_tables = []
seen = set()
for table in table_list:
identifier = (table['schema'], table['name'])
if identifier not in seen:
seen.add(identifier)
unique_tables.append(table)
return unique_tables