This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new e3d7f7c8d8 [feature](Nereids) add test framework for cost model 
(#17071)
e3d7f7c8d8 is described below

commit e3d7f7c8d8bf7e665b1479cfc64691141f0ccc0c
Author: 谢健 <jianx...@gmail.com>
AuthorDate: Tue Feb 28 20:59:07 2023 +0800

    [feature](Nereids) add test framework for cost model (#17071)
    
    Add a test framework for the cost model, based on the paper "Testing the
    Accuracy of Query Optimizers".
---
 tools/cost_model_evaluate/README.MD           | 23 +++++++
 tools/cost_model_evaluate/config.py           | 40 ++++++++++++
 tools/cost_model_evaluate/evaluator.py        | 91 +++++++++++++++++++++++++++
 tools/cost_model_evaluate/index_calculator.py | 69 ++++++++++++++++++++
 tools/cost_model_evaluate/main.py             | 61 ++++++++++++++++++
 tools/cost_model_evaluate/requirements.txt    | 17 +++++
 tools/cost_model_evaluate/sql_executor.py     | 69 ++++++++++++++++++++
 7 files changed, 370 insertions(+)

diff --git a/tools/cost_model_evaluate/README.MD 
b/tools/cost_model_evaluate/README.MD
new file mode 100644
index 0000000000..2cc9daf45d
--- /dev/null
+++ b/tools/cost_model_evaluate/README.MD
@@ -0,0 +1,23 @@
+<!-- 
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+This tool evaluates the accuracy of the cost model in Doris.
+You can configure the connection settings and the query to evaluate in main.py.
+
+Before running, install the dependencies listed in requirements.txt (e.g. `pip install -r requirements.txt`).
\ No newline at end of file
diff --git a/tools/cost_model_evaluate/config.py 
b/tools/cost_model_evaluate/config.py
new file mode 100644
index 0000000000..bfe378a8d2
--- /dev/null
+++ b/tools/cost_model_evaluate/config.py
@@ -0,0 +1,40 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from dataclasses import dataclass
+
@dataclass
class Config:
    """Settings for one cost-model evaluation run.

    Fields are positional, in this order: user, password, host, port,
    database, execute_times, plan_number, plot, cold_run.
    """

    # user name for the MySQL client connection
    user: str
    # password of that user
    password: str
    # host the MySQL client connects to
    host: str
    # port the MySQL client connects to
    port: int
    # database that the evaluated query runs against
    database: str
    # how many times each plan of the query is executed
    # (a single query can produce several plans)
    execute_times: int
    # how many plans to generate for one query; if this exceeds the number
    # of possible plans, only the valid (distinct) plans are used
    plan_number: int
    # whether to plot the relation between cost and execution time
    plot: bool
    # how many times to run the query before the real evaluation,
    # so that measurements are not taken on a cold run
    cold_run: int
\ No newline at end of file
diff --git a/tools/cost_model_evaluate/evaluator.py 
b/tools/cost_model_evaluate/evaluator.py
new file mode 100644
index 0000000000..e963a183ed
--- /dev/null
+++ b/tools/cost_model_evaluate/evaluator.py
@@ -0,0 +1,91 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from distutils.command.config import config
+from config import Config
+from index_calculator import IndexCalculator
+from sql_executor import SQLExecutor
+import matplotlib.pyplot as plt
+
+
class Evaluator:
    """Evaluate how well the Nereids cost model ranks the plans of one query.

    The query is re-planned with the ``nth_optimized_plan`` session variable
    (injected as a ``/*+set_var(...)*/`` hint) for n = 1, 2, ...; each
    distinct plan's estimated cost is paired with its measured execution
    time, and the (cost, time) pairs are scored by ``IndexCalculator``.
    """

    def __init__(self, config: "Config", query: str) -> None:
        """
        Args:
            config: connection and run settings.
            query: the SQL statement to evaluate; it is used verbatim.
        """
        self.config = config
        # Keep the query verbatim: lower-casing it would also rewrite string
        # literals (e.g. 'ASIA' -> 'asia') and change what the query returns.
        self.query = query
        # Force the Nereids planner without fallback, and turn on profiling
        # so execution times can be read back from the query profile.
        self.setup_queries = [
            "set enable_nereids_planner=true;",
            "set enable_fallback_to_original_planner=false;",
            "set enable_profile=true;"
        ]
        self.sql_executor = SQLExecutor(
            config.user,
            config.password,
            config.host,
            config.port,
            config.database)

    def cold_run(self):
        """Execute the unmodified query ``config.cold_run`` times as warm-up."""
        for _ in range(self.config.cold_run):
            self.sql_executor.execute_query(self.query, None)

    def evaluate(self):
        """Run the whole evaluation and return the accuracy score.

        NOTE(review): ``config.execute_times`` is not used here; each plan is
        timed exactly once.

        Returns:
            The value of ``IndexCalculator.calculate`` over the (cost, time)
            pairs of all distinct plans.

        Raises:
            ValueError: if no plan was generated (e.g. ``plan_number < 1``).
        """
        self.setup()
        self.cold_run()
        plans = self.extract_all_plans()
        res: list[tuple[float, float]] = []
        for hinted_query, cost in plans.values():
            elapsed = self.sql_executor.get_execute_time(hinted_query)
            res.append((cost, elapsed))
        if not res:
            raise ValueError(
                "no plans were generated for the query; check plan_number")
        if self.config.plot:
            self.plot(res)
        print(res)
        index_calculator = IndexCalculator(res)
        return index_calculator.calculate()

    def plot(self, data):
        """Show a scatter plot of estimated cost (x) against time (y)."""
        x_values = [t[0] for t in data]
        y_values = [t[1] for t in data]
        _, ax = plt.subplots()
        ax.scatter(x_values, y_values)
        ax.set_xlabel('Cost')
        ax.set_ylabel('Time')
        plt.show()

    def setup(self):
        """Apply the session variables in ``setup_queries``."""
        for q in self.setup_queries:
            self.sql_executor.execute_query(q, None)

    def extract_all_plans(self):
        """Collect up to ``config.plan_number`` distinct plans for the query.

        Returns:
            dict mapping the 1-based plan rank n to (hinted query, plan cost).
        """
        plan_set = set()
        plan_map: dict[int, tuple[str, float]] = {}
        # nth_optimized_plan is 1-based: ranks 1..plan_number inclusive.
        for n in range(1, self.config.plan_number + 1):
            query = self.inject_nth_optimized_hint(n)
            plan, cost = self.sql_executor.get_plan_with_cost(query)
            # A repeated plan means there are no further valid plans
            # (presumably the planner repeats a plan once n exceeds the
            # number of candidates - TODO confirm), so stop here.
            if plan in plan_set:
                break
            plan_set.add(plan)
            plan_map[n] = (query, cost)
        return plan_map

    def inject_nth_optimized_hint(self, n: int) -> str:
        """Return the query with ``nth_optimized_plan=n`` added as a hint.

        If the query already has a ``/*+set_var(...)*/`` hint, the variable is
        prepended to it; otherwise a new hint is inserted after the FIRST
        ``select`` keyword only (matched case-insensitively, so subqueries
        and identifiers containing "select" are left alone).
        """
        import re

        setting = f"nth_optimized_plan={n}"
        existing_hint = re.compile(r"/\*\+\s*set_var\(", re.IGNORECASE)
        if existing_hint.search(self.query):
            return existing_hint.sub(
                lambda m: f"{m.group(0)}{setting}, ", self.query, count=1)
        return re.sub(
            r"\bselect\b",
            lambda m: f"{m.group(0)} /*+set_var({setting})*/",
            self.query,
            count=1,
            flags=re.IGNORECASE)
diff --git a/tools/cost_model_evaluate/index_calculator.py 
b/tools/cost_model_evaluate/index_calculator.py
new file mode 100644
index 0000000000..8422146c9b
--- /dev/null
+++ b/tools/cost_model_evaluate/index_calculator.py
@@ -0,0 +1,69 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import math
+from typing import List, Tuple
+import unittest
+
+# The index is based on the paper "Testing the Accuracy of Query Optimizers".
+
+
class IndexCalculator:
    """Score how consistently estimated costs order plans by execution time.

    Plans are ordered by ascending cost. Every pair (i, j) with i < j
    contributes ``weight(i) * weight(j) * distance(i, j) * sgn(i, j)``, which
    is positive when the cheaper plan i also ran no slower than plan j and
    negative otherwise, so higher scores mean a better-ordering cost model.
    """

    # Added to each normalisation range so that identical costs (or times)
    # do not cause a division by zero.
    _EPSILON = 0.00001

    def __init__(self, cost_time_list: List[Tuple[float, float]]) -> None:
        """
        Args:
            cost_time_list: one (estimated cost, measured time) pair per plan,
                in any order. The list itself is not modified.

        Raises:
            ValueError: if cost_time_list is empty.
        """
        if not cost_time_list:
            raise ValueError(
                "cost_time_list must contain at least one (cost, time) pair")
        # The index is defined over plans sorted by ascending cost: weight()
        # divides by the cheapest cost (index 0) and sgn() treats plan i as
        # the cheaper plan of each pair i < j.
        self.cost_time_list = sorted(cost_time_list, key=lambda ct: ct[0])
        costs = [ct[0] for ct in self.cost_time_list]
        times = [ct[1] for ct in self.cost_time_list]
        self.max_c = max(costs)
        self.min_c = min(costs)
        self.max_t = max(times)
        self.min_t = min(times)

    def calculate(self) -> float:
        """Return the accuracy score summed over all plan pairs i < j."""
        count = len(self.cost_time_list)
        score = 0.0
        for j in range(count):
            for i in range(j):
                score += self.weight(i) * self.weight(j) * \
                    self.distance(i, j) * self.sgn(i, j)
        return score

    def weight(self, i: int) -> float:
        """Return c_0 / c_i: 1.0 for the cheapest plan, lower for costlier ones
        (for positive costs)."""
        cost_i = self.cost_time_list[i][0]
        if cost_i == 0:
            # The list is sorted ascending, so (assuming non-negative costs)
            # the cheapest cost is 0 as well; treat equal costs as full weight
            # instead of dividing by zero.
            return 1.0
        return self.cost_time_list[0][0] / cost_i

    def distance(self, i: int, j: int) -> float:
        """Euclidean distance of plans i and j in min-max normalised
        (cost, time) space."""
        d0 = (self.cost_time_list[i][0] - self.cost_time_list[j][0]) \
            / (self.max_c - self.min_c + self._EPSILON)
        d1 = (self.cost_time_list[i][1] - self.cost_time_list[j][1]) \
            / (self.max_t - self.min_t + self._EPSILON)
        return math.sqrt(d0 * d0 + d1 * d1)

    def sgn(self, i: int, j: int) -> float:
        """Return 1 if plan j took at least as long as plan i, else -1."""
        if self.cost_time_list[j][1] - self.cost_time_list[i][1] >= 0:
            return 1
        return -1
+
+
class Test(unittest.TestCase):
    """Sanity check of IndexCalculator on a small two-plan input."""

    def test(self):
        calculator = IndexCalculator([(1, 2), (2, 3)])
        score = calculator.calculate()
        self.assertEqual(round(score, 2), 0.71)


if __name__ == '__main__':
    unittest.main()
diff --git a/tools/cost_model_evaluate/main.py 
b/tools/cost_model_evaluate/main.py
new file mode 100644
index 0000000000..3103fb2316
--- /dev/null
+++ b/tools/cost_model_evaluate/main.py
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from config import Config
+from evaluator import Evaluator
+
+
# Connection and run settings; see config.Config for what each field means.
config = Config(
    user="root",
    password="",
    host="127.0.0.1",
    port=9030,
    database="regression_test_nereids_tpch_p0",
    execute_times=2,
    plan_number=50,
    plot=True,
    cold_run=3,
)

# The query whose candidate plans are evaluated (TPC-H Q5).
sql = """
select
    n_name,
    sum(l_extendedprice * (1 - l_discount)) as revenue
from
    customer,
    orders,
    lineitem,
    supplier,
    nation,
    region
where
    c_custkey = o_custkey
    and l_orderkey = o_orderkey
    and l_suppkey = s_suppkey
    and c_nationkey = s_nationkey
    and s_nationkey = n_nationkey
    and n_regionkey = r_regionkey
    and r_name = 'ASIA'
    and o_orderdate >= date '1994-01-01'
    and o_orderdate < date '1994-01-01' + interval '1' year
group by
    n_name
order by
    revenue desc;
"""

# Only connect to the database and run the evaluation when executed as a
# script, not when this module is imported.
if __name__ == "__main__":
    print(Evaluator(config, sql).evaluate())
diff --git a/tools/cost_model_evaluate/requirements.txt 
b/tools/cost_model_evaluate/requirements.txt
new file mode 100644
index 0000000000..ffbc650c26
--- /dev/null
+++ b/tools/cost_model_evaluate/requirements.txt
@@ -0,0 +1,17 @@
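+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information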
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+matplotlib==3.7.0
+mysql_connector_repackaged==0.3.1
diff --git a/tools/cost_model_evaluate/sql_executor.py 
b/tools/cost_model_evaluate/sql_executor.py
new file mode 100644
index 0000000000..511e12c8ad
--- /dev/null
+++ b/tools/cost_model_evaluate/sql_executor.py
@@ -0,0 +1,69 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from unittest import result
+import mysql.connector
+from typing import List, Tuple
+
+
+class SQLExecutor:
+    def __init__(self, user: str, password: str, host: str, port: int, 
database: str) -> None:
+        self.connection = mysql.connector.connect(
+            user=user,
+            password=password,
+            host=host,
+            port=port,
+            database=database
+        )
+        self.cursor = self.connection.cursor()
+        self.wait_fetch_time_index = 16
+
+    def execute_query(self, query: str, parameters: Tuple | None) -> 
List[Tuple]:
+        if parameters:
+            self.cursor.execute(query, parameters)
+        else:
+            self.cursor.execute(query)
+        results = self.cursor.fetchall()
+        return results
+
+    def get_execute_time(self, query: str) -> float:
+        self.execute_query(query, None)
+        profile = self.execute_query("show query profile\"\"", None)
+        return float(profile[0][self.wait_fetch_time_index].replace("ms", ""))
+
+    def execute_many_queries(self, queries: List[Tuple[str, Tuple]]) -> 
List[List[Tuple]]:
+        results = []
+        for query, parameters in queries:
+            result = self.execute_query(query, parameters)
+            results.append(result)
+        return results
+
+    def get_plan_with_cost(self, query: str):
+        result = self.execute_query(f"explain optimized plan {query}", None)
+        cost = float(result[0][0].replace("cost = ", ""))
+        plan = "".join([s[0] for s in result[1:]])
+        return plan, cost
+
+    def commit(self) -> None:
+        self.connection.commit()
+
+    def rollback(self) -> None:
+        self.connection.rollback()
+
+    def close(self) -> None:
+        self.cursor.close()
+        self.connection.close()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to