This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 7084429f42 Fixing template_fields for WeaviateIngestOperator (#36359)
7084429f42 is described below
commit 7084429f42d0a006e777612c07b3471100f953c9
Author: vatsrahul1001 <[email protected]>
AuthorDate: Fri Dec 22 03:26:02 2023 +0530
Fixing template_fields for WeaviateIngestOperator (#36359)
* fixing template fields for WeaviateIngestOperator
* removing fixtures which are not required
* marking the test as db_test because it accesses the DB
---
airflow/providers/weaviate/operators/weaviate.py | 4 ++--
tests/providers/weaviate/operators/test_weaviate.py | 17 +++++++++++++++++
2 files changed, 19 insertions(+), 2 deletions(-)
diff --git a/airflow/providers/weaviate/operators/weaviate.py b/airflow/providers/weaviate/operators/weaviate.py
index 4e07a59edb..d12a2c2e6c 100644
--- a/airflow/providers/weaviate/operators/weaviate.py
+++ b/airflow/providers/weaviate/operators/weaviate.py
@@ -51,7 +51,7 @@ class WeaviateIngestOperator(BaseOperator):
:param vector_col: key/column name in which the vectors are stored.
"""
- template_fields: Sequence[str] = ("input_json",)
+ template_fields: Sequence[str] = ("input_json", "input_data")
def __init__(
self,
@@ -69,7 +69,7 @@ class WeaviateIngestOperator(BaseOperator):
self.class_name = class_name
self.conn_id = conn_id
self.vector_col = vector_col
-
+ self.input_json = input_json
if input_data is not None:
self.input_data = input_data
elif input_json is not None:
diff --git a/tests/providers/weaviate/operators/test_weaviate.py b/tests/providers/weaviate/operators/test_weaviate.py
index 7490b64dc6..7a2c362494 100644
--- a/tests/providers/weaviate/operators/test_weaviate.py
+++ b/tests/providers/weaviate/operators/test_weaviate.py
@@ -50,3 +50,20 @@ class TestWeaviateIngestOperator:
"my_class", {"data": "sample_data"}, vector_col="Vector", **{}
)
mock_log.debug.assert_called_once_with("Input data: %s", {"data": "sample_data"})
+
+ @pytest.mark.db_test
+ def test_templates(self, create_task_instance_of_operator):
+ dag_id = "TestWeaviateIngestOperator"
+ ti = create_task_instance_of_operator(
+ WeaviateIngestOperator,
+ dag_id=dag_id,
+ task_id="task-id",
+ conn_id="weaviate_conn",
+ class_name="my_class",
+ input_json="{{ dag.dag_id }}",
+ input_data="{{ dag.dag_id }}",
+ )
+ ti.render_templates()
+
+ assert dag_id == ti.task.input_json
+ assert dag_id == ti.task.input_data