From 12cccd31cf63b026e27d55f37fa29e842f6d5ea1 Mon Sep 17 00:00:00 2001
From: "wuqingfu.528"
Date: Tue, 10 Mar 2026 09:21:53 +0800
Subject: [PATCH 1/2] fix: support clickhouse in vannatoolset

---
Review notes (text between "---" and the diffstat is ignored by git am):
 * Line structure and indentation were restored; the series had been
   collapsed onto a few long lines and could not be applied.
 * The ClickHouse branch now falls back to port 9000 when the URI carries
   no port. ClickHouseRunner uses clickhouse_driver (native TCP protocol,
   default port 9000); the previous fallback of 8123 is the HTTP interface,
   which a native-protocol client cannot talk to.
 * The "index" blob ids are those of the original commit and no longer
   match the edited vanna_toolset.py content.

 .../tools/vanna_tools/clickhouse/__init__.py  | 17 ++++
 .../vanna_tools/clickhouse/sql_runner.py      | 96 +++++++++++++++++++
 veadk/tools/vanna_tools/vanna_toolset.py      | 59 +++++++++++-
 3 files changed, 171 insertions(+), 1 deletion(-)
 create mode 100644 veadk/tools/vanna_tools/clickhouse/__init__.py
 create mode 100644 veadk/tools/vanna_tools/clickhouse/sql_runner.py

diff --git a/veadk/tools/vanna_tools/clickhouse/__init__.py b/veadk/tools/vanna_tools/clickhouse/__init__.py
new file mode 100644
index 00000000..b3d327ed
--- /dev/null
+++ b/veadk/tools/vanna_tools/clickhouse/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .sql_runner import ClickHouseRunner
+
+__all__ = ["ClickHouseRunner"]
diff --git a/veadk/tools/vanna_tools/clickhouse/sql_runner.py b/veadk/tools/vanna_tools/clickhouse/sql_runner.py
new file mode 100644
index 00000000..755dc512
--- /dev/null
+++ b/veadk/tools/vanna_tools/clickhouse/sql_runner.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pandas as pd
+
+from vanna.capabilities.sql_runner import SqlRunner, RunSqlToolArgs
+from vanna.core.tool import ToolContext
+
+
+class ClickHouseRunner(SqlRunner):
+    """ClickHouse implementation of the SqlRunner interface."""
+
+    def __init__(
+        self,
+        host: str,
+        database: str,
+        user: str,
+        password: str,
+        port: int = 9000,
+        **kwargs,
+    ):
+        """Initialize with ClickHouse connection parameters.
+
+        Args:
+            host: Database host address
+            database: Database name
+            user: Database user
+            password: Database password
+            port: Database port (default: 9000 for native protocol)
+            **kwargs: Additional clickhouse_driver connection parameters
+        """
+        try:
+            from clickhouse_driver import Client
+
+            self.Client = Client
+        except ImportError as e:
+            raise ImportError(
+                "clickhouse-driver package is required. "
+                "Install with: pip install clickhouse-driver"
+            ) from e
+
+        self.host = host
+        self.port = port
+        self.user = user
+        self.password = password
+        self.database = database
+        self.kwargs = kwargs
+
+    async def run_sql(self, args: RunSqlToolArgs, context: ToolContext) -> pd.DataFrame:
+        """Execute SQL query against ClickHouse database and return results as DataFrame.
+
+        Args:
+            args: SQL query arguments
+            context: Tool execution context
+
+        Returns:
+            DataFrame with query results
+
+        Raises:
+            Exception: If query execution fails
+        """
+        # Connect to the database
+        client = self.Client(
+            host=self.host,
+            port=self.port,
+            user=self.user,
+            password=self.password,
+            database=self.database,
+            **self.kwargs,
+        )
+
+        try:
+            # Execute the query
+            result = client.execute(args.sql, with_column_types=True)
+
+            # result is a tuple: (data, [(column_name, column_type), ...])
+            data = result[0]
+            columns = [col[0] for col in result[1]]
+
+            # Create a pandas dataframe from the results
+            df = pd.DataFrame(data, columns=columns)
+            return df
+
+        finally:
+            client.disconnect()
diff --git a/veadk/tools/vanna_tools/vanna_toolset.py b/veadk/tools/vanna_tools/vanna_toolset.py
index ab19171f..111b6327 100644
--- a/veadk/tools/vanna_tools/vanna_toolset.py
+++ b/veadk/tools/vanna_tools/vanna_toolset.py
@@ -63,6 +63,7 @@ def _post_init(self):
         from vanna.integrations.sqlite import SqliteRunner
         from vanna.integrations.postgres import PostgresRunner
         from vanna.integrations.mysql import MySQLRunner
+        from .clickhouse.sql_runner import ClickHouseRunner
         from vanna.tools import LocalFileSystem
         from vanna.integrations.local.agent_memory import DemoAgentMemory
 
@@ -128,9 +129,65 @@ def _post_init(self):
                 )
             except (IndexError, ValueError) as e:
                 raise ValueError(f"Invalid MySQL connection string format: {e}") from e
+        elif self.connection_string.startswith("clickhouse://"):
+            try:
+                from urllib.parse import urlparse, parse_qs
+
+                # 解析 URI
+                parsed = urlparse(self.connection_string)
+
+                # 提取基本信息
+                user = parsed.username
+                password = parsed.password
+                host = parsed.hostname
+                port = parsed.port or 9000  # 默认端口
+                database = parsed.path.lstrip("/")
+
+                # 解析查询参数
+                query_params = parse_qs(parsed.query)
+                kwargs = {}
+
+                # 处理所有查询参数
+                for key, values in query_params.items():
+                    if not values:
+                        continue
+
+                    value = values[0]  # 取第一个值
+
+                    # 处理布尔值参数
+                    if value.lower() in ("true", "false", "1", "0", "yes", "no"):
+                        kwargs[key] = value.lower() in ("true", "1", "yes")
+                    # 处理数字参数
+                    elif value.isdigit():
+                        kwargs[key] = int(value)
+                    # 处理浮点数参数
+                    elif value.replace(".", "", 1).isdigit():
+                        kwargs[key] = float(value)
+                    # 其他参数保持字符串
+                    else:
+                        kwargs[key] = value
+
+                # 验证必需参数
+                if not all([user, password, host, database]):
+                    raise ValueError(
+                        "Missing required connection parameters (user, password, host, database)"
+                    )
+
+                self.runner = ClickHouseRunner(
+                    host=host,
+                    database=database,
+                    user=user,
+                    password=password,
+                    port=port,
+                    **kwargs,
+                )
+            except (IndexError, ValueError, AttributeError) as e:
+                raise ValueError(
+                    f"Invalid ClickHouse connection string format: {e}"
+                ) from e
         else:
             raise ValueError(
-                "Unsupported connection string format. Please use sqlite://, postgresql://, or mysql://"
+                "Unsupported connection string format. Please use sqlite://, postgresql://, mysql://, or clickhouse://"
             )
 
         if not os.path.exists(self.file_storage):

From 166cf9008d9cfb86dfdce51b98a97792da95db3c Mon Sep 17 00:00:00 2001
From: "wuqingfu.528"
Date: Tue, 10 Mar 2026 09:24:49 +0800
Subject: [PATCH 2/2] fix: chinese

---
Review notes (ignored by git am):
 * This commit only strips the Chinese comments from vanna_toolset.py.
 * The port line pair below carries the 9000 fallback so that it matches
   the content produced by patch 1/2 of this series.

 veadk/tools/vanna_tools/vanna_toolset.py | 15 ++-------------
 1 file changed, 2 insertions(+), 13 deletions(-)

diff --git a/veadk/tools/vanna_tools/vanna_toolset.py b/veadk/tools/vanna_tools/vanna_toolset.py
index 111b6327..dad4a53e 100644
--- a/veadk/tools/vanna_tools/vanna_toolset.py
+++ b/veadk/tools/vanna_tools/vanna_toolset.py
@@ -67,11 +67,9 @@ def _post_init(self):
         from vanna.tools import LocalFileSystem
         from vanna.integrations.local.agent_memory import DemoAgentMemory
 
-        # 验证连接字符串格式
         if not self.connection_string:
             raise ValueError("Connection string cannot be empty")
 
-        # 检查连接字符串格式
         if self.connection_string.startswith("sqlite://"):
             if len(self.connection_string) <= len("sqlite://"):
                 raise ValueError(
@@ -133,41 +131,32 @@ def _post_init(self):
             try:
                 from urllib.parse import urlparse, parse_qs
 
-                # 解析 URI
                 parsed = urlparse(self.connection_string)
 
-                # 提取基本信息
                 user = parsed.username
                 password = parsed.password
                 host = parsed.hostname
-                port = parsed.port or 9000  # 默认端口
+                port = parsed.port or 9000
                 database = parsed.path.lstrip("/")
 
-                # 解析查询参数
                 query_params = parse_qs(parsed.query)
                 kwargs = {}
 
-                # 处理所有查询参数
                 for key, values in query_params.items():
                     if not values:
                         continue
 
-                    value = values[0]  # 取第一个值
+                    value = values[0]
 
-                    # 处理布尔值参数
                     if value.lower() in ("true", "false", "1", "0", "yes", "no"):
                         kwargs[key] = value.lower() in ("true", "1", "yes")
-                    # 处理数字参数
                     elif value.isdigit():
                         kwargs[key] = int(value)
-                    # 处理浮点数参数
                     elif value.replace(".", "", 1).isdigit():
                         kwargs[key] = float(value)
-                    # 其他参数保持字符串
                     else:
                         kwargs[key] = value
 
-                # 验证必需参数
                 if not all([user, password, host, database]):
                     raise ValueError(
                         "Missing required connection parameters (user, password, host, database)"