This is an automated email from the ASF dual-hosted git repository.

xyz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/pulsar-client-python.git


The following commit(s) were added to refs/heads/main by this push:
     new df6871e  Add AutoClusterFailover and ServiceInfoProvider (#295)
df6871e is described below

commit df6871eb0dd22e8e3959e0e7924a78067f16bfc9
Author: Yunze Xu <[email protected]>
AuthorDate: Fri Apr 10 16:15:44 2026 +0800

    Add AutoClusterFailover and ServiceInfoProvider (#295)
---
 pulsar/__init__.py                  | 204 +++++++++++++++++++++++++++-
 src/client.cc                       |  98 +++++++++++++-
 src/config.cc                       |  43 ++++--
 tests/auto_cluster_failover_test.py | 260 ++++++++++++++++++++++++++++++++++++
 tests/pulsar_test.py                |  74 ++++++++++
 tests/run-unit-tests.sh             |   3 +
 6 files changed, 661 insertions(+), 21 deletions(-)

diff --git a/pulsar/__init__.py b/pulsar/__init__.py
index 0f60552..afcb634 100644
--- a/pulsar/__init__.py
+++ b/pulsar/__init__.py
@@ -569,6 +569,159 @@ class AuthenticationBasic(Authentication):
             _check_type(str, method, 'method')
             self.auth = _pulsar.AuthenticationBasic.create(username, password, method)
 
+
+class ServiceInfoProvider:
+    """
+    Base class for Python-defined service discovery and failover providers.
+
+    Subclasses must return the initial :class:`ServiceInfo` and may keep the
+    provided update callback to push later service changes into the client.
+    """
+
+    def initial_service_info(self) -> "ServiceInfo":
+        raise NotImplementedError
+
+    def initialize(self, on_service_info_update: Callable[["ServiceInfo"], None]) -> None:
+        raise NotImplementedError
+
+    def close(self) -> None:
+        """
+        Stop background work and release resources.
+
+        This is invoked when the underlying C++ client destroys the provider,
+        typically during :meth:`Client.close`.
+        """
+        return None
+
+
+class ServiceInfo:
+    """
+    Connection information for one Pulsar cluster endpoint.
+
+    This is primarily used with :class:`AutoClusterFailover`.
+    """
+
+    def __init__(self,
+                 service_url: str,
+                 authentication: Optional[Authentication] = None,
+                 tls_trust_certs_file_path: Optional[str] = None):
+        """
+        Create a service info entry.
+
+        Parameters
+        ----------
+        service_url: str
+            The Pulsar service URL for this cluster.
+        authentication: Authentication, optional
+            Authentication to use when connecting to this cluster.
+        tls_trust_certs_file_path: str, optional
+            Trust store path for TLS connections to this cluster.
+        """
+        _check_type(str, service_url, 'service_url')
+        _check_type_or_none(Authentication, authentication, 'authentication')
+        _check_type_or_none(str, tls_trust_certs_file_path, 'tls_trust_certs_file_path')
+
+        self._authentication = authentication
+        self._service_info = _pulsar.ServiceInfo(
+            service_url,
+            authentication.auth if authentication else None,
+            tls_trust_certs_file_path,
+        )
+
+    @property
+    def service_url(self) -> str:
+        return self._service_info.service_url
+
+    @property
+    def use_tls(self) -> bool:
+        return self._service_info.use_tls
+
+    @property
+    def tls_trust_certs_file_path(self) -> Optional[str]:
+        return self._service_info.tls_trust_certs_file_path
+
+    def __repr__(self) -> str:
+        return (
+            "ServiceInfo("
+            f"service_url={self.service_url!r}, "
+            f"use_tls={self.use_tls!r}, "
+            f"tls_trust_certs_file_path={self.tls_trust_certs_file_path!r})"
+        )
+
+    @classmethod
+    def wrap(cls, service_info: _pulsar.ServiceInfo):
+        self = cls.__new__(cls)
+        self._authentication = None
+        self._service_info = service_info
+        return self
+
+
+class AutoClusterFailover:
+    """
+    Cluster-level automatic failover configuration for :class:`Client`.
+    """
+
+    def __init__(self,
+                 primary: ServiceInfo,
+                 secondary: List[ServiceInfo],
+                 check_interval_ms: int = 5000,
+                 failover_threshold: int = 1,
+                 switch_back_threshold: int = 1):
+        """
+        Create an automatic failover configuration.
+
+        Parameters
+        ----------
+        primary: ServiceInfo
+            The preferred cluster to use.
+        secondary: list[ServiceInfo]
+            Ordered fallback clusters to probe when the primary becomes unavailable.
+        check_interval_ms: int, default=5000
+            Probe interval in milliseconds.
+        failover_threshold: int, default=1
+            Number of consecutive probe failures required before failover.
+        switch_back_threshold: int, default=1
+            Number of consecutive successful primary probes required before switching back.
+        """
+        _check_type(ServiceInfo, primary, 'primary')
+        _check_type(list, secondary, 'secondary')
+        _check_type(int, check_interval_ms, 'check_interval_ms')
+        _check_type(int, failover_threshold, 'failover_threshold')
+        _check_type(int, switch_back_threshold, 'switch_back_threshold')
+
+        if not secondary:
+            raise ValueError("Argument secondary is expected to contain at least one ServiceInfo")
+
+        for index, service_info in enumerate(secondary):
+            if not isinstance(service_info, ServiceInfo):
+                raise ValueError(
+                    "Argument secondary[%d] is expected to be of type 'ServiceInfo' and not '%s'"
+                    % (index, type(service_info).__name__)
+                )
+
+        if check_interval_ms <= 0:
+            raise ValueError("Argument check_interval_ms is expected to be greater than 0")
+        if failover_threshold <= 0:
+            raise ValueError("Argument failover_threshold is expected to be greater than 0")
+        if switch_back_threshold <= 0:
+            raise ValueError("Argument switch_back_threshold is expected to be greater than 0")
+
+        self.primary = primary
+        self.secondary = list(secondary)
+        self.check_interval_ms = check_interval_ms
+        self.failover_threshold = failover_threshold
+        self.switch_back_threshold = switch_back_threshold
+
+    def __repr__(self) -> str:
+        return (
+            "AutoClusterFailover("
+            f"primary={self.primary!r}, "
+            f"secondary={self.secondary!r}, "
+            f"check_interval_ms={self.check_interval_ms!r}, "
+            f"failover_threshold={self.failover_threshold!r}, "
+            f"switch_back_threshold={self.switch_back_threshold!r})"
+        )
+
 class ConsumerDeadLetterPolicy:
     """
     Configuration for the "dead letter queue" feature in consumer.
@@ -681,8 +834,9 @@ class Client:
         Parameters
         ----------
 
-        service_url: str
-            The Pulsar service url eg: pulsar://my-broker.com:6650/
+        service_url: str or AutoClusterFailover or ServiceInfoProvider
+            The Pulsar service URL, for example ``pulsar://my-broker.com:6650/``, or an
+            :class:`AutoClusterFailover` or :class:`ServiceInfoProvider` configuration.
         authentication: Authentication, optional
             Set the authentication provider to be used with the broker. 
Supported methods:
 
@@ -743,7 +897,26 @@ class Client:
         tls_certificate_file_path: str, optional
             The path to the TLS certificate file.
         """
-        _check_type(str, service_url, 'service_url')
+        if not isinstance(service_url, (str, AutoClusterFailover, ServiceInfoProvider)):
+            raise ValueError(
+                "Argument service_url is expected to be of type 'str', 'AutoClusterFailover' or "
+                "'ServiceInfoProvider'"
+            )
+
+        if isinstance(service_url, (AutoClusterFailover, ServiceInfoProvider)) and authentication is not None:
+            raise ValueError(
+                "Argument authentication is not supported when service_url is an AutoClusterFailover or "
+                "ServiceInfoProvider; set authentication on each ServiceInfo instead"
+            )
+
+        if isinstance(service_url, (AutoClusterFailover, ServiceInfoProvider)) and \
+                tls_trust_certs_file_path is not None:
+            raise ValueError(
+                "Argument tls_trust_certs_file_path is not supported when service_url is an "
+                "AutoClusterFailover or ServiceInfoProvider; set tls_trust_certs_file_path on each "
+                "ServiceInfo instead"
+            )
+
         _check_type_or_none(Authentication, authentication, 'authentication')
         _check_type(int, operation_timeout_seconds, 
'operation_timeout_seconds')
         _check_type(int, connection_timeout_ms, 'connection_timeout_ms')
@@ -792,7 +965,24 @@ class Client:
             conf.tls_private_key_file_path(tls_private_key_file_path)
         if tls_certificate_file_path is not None:
             conf.tls_certificate_file_path(tls_certificate_file_path)
-        self._client = _pulsar.Client(service_url, conf)
+        if isinstance(service_url, AutoClusterFailover):
+            self._client = _pulsar.Client.create_auto_cluster_failover(
+                service_url.primary._service_info,
+                [service_info._service_info for service_info in service_url.secondary],
+                service_url.check_interval_ms,
+                service_url.failover_threshold,
+                service_url.switch_back_threshold,
+                conf,
+            )
+        elif isinstance(service_url, ServiceInfoProvider):
+            try:
+                self._client = _pulsar.Client.create_service_info_provider(service_url, conf)
+            except RuntimeError as e:
+                if str(e) == "Expected a pulsar.ServiceInfo or _pulsar.ServiceInfo instance":
+                    raise ValueError(str(e))
+                raise
+        else:
+            self._client = _pulsar.Client(service_url, conf)
         self._consumers = []
 
     @staticmethod
@@ -1417,6 +1607,12 @@ class Client:
         _check_type(str, topic, 'topic')
         return self._client.get_topic_partitions(topic)
 
+    def get_service_info(self) -> ServiceInfo:
+        """
+        Get the current service info used by this client.
+        """
+        return ServiceInfo.wrap(self._client.get_service_info())
+
     def shutdown(self):
         """
         Perform immediate shutdown of Pulsar client.
diff --git a/src/client.cc b/src/client.cc
index d77938f..64a8e7b 100644
--- a/src/client.cc
+++ b/src/client.cc
@@ -18,12 +18,70 @@
  */
 #include "utils.h"
 
+#include <pulsar/AutoClusterFailover.h>
+#include <pulsar/ServiceInfoProvider.h>
+#include <chrono>
+#include <memory>
 #include <pybind11/functional.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
 
 namespace py = pybind11;
 
+static ServiceInfo unwrapPythonServiceInfo(const py::handle& object) {
+    auto serviceInfoObject = py::reinterpret_borrow<py::object>(object);
+
+    try {
+        return serviceInfoObject.cast<ServiceInfo>();
+    } catch (const py::cast_error&) {
+    }
+
+    if (py::hasattr(serviceInfoObject, "_service_info")) {
+        try {
+            return serviceInfoObject.attr("_service_info").cast<ServiceInfo>();
+        } catch (const py::cast_error&) {
+        }
+    }
+
+    throw py::value_error("Expected a pulsar.ServiceInfo or _pulsar.ServiceInfo instance");
+}
+
+class PythonServiceInfoProvider : public ServiceInfoProvider {
+   public:
+    explicit PythonServiceInfoProvider(py::object provider) : 
provider_(std::move(provider)) {}
+
+    ~PythonServiceInfoProvider() override {
+        if (!Py_IsInitialized()) {
+            return;
+        }
+
+        py::gil_scoped_acquire acquire;
+        try {
+            if (py::hasattr(provider_, "close")) {
+                provider_.attr("close")();
+            }
+        } catch (const py::error_already_set&) {
+            PyErr_Print();
+        }
+    }
+
+    ServiceInfo initialServiceInfo() override {
+        py::gil_scoped_acquire acquire;
+        return 
unwrapPythonServiceInfo(provider_.attr("initial_service_info")());
+    }
+
+    void initialize(std::function<void(ServiceInfo)> onServiceInfoUpdate) 
override {
+        py::gil_scoped_acquire acquire;
+        provider_.attr("initialize")(py::cpp_function(
+            [onServiceInfoUpdate = std::move(onServiceInfoUpdate)](py::object 
serviceInfo) mutable {
+                onServiceInfoUpdate(unwrapPythonServiceInfo(serviceInfo));
+            }));
+    }
+
+   private:
+    py::object provider_;
+};
+
 Producer Client_createProducer(Client& client, const std::string& topic, const 
ProducerConfiguration& conf) {
     return waitForAsyncValue<Producer>(
         [&](CreateProducerCallback callback) { 
client.createProducerAsync(topic, conf, callback); });
@@ -65,7 +123,8 @@ std::vector<std::string> Client_getTopicPartitions(Client& 
client, const std::st
         [&](GetPartitionsCallback callback) { 
client.getPartitionsForTopicAsync(topic, callback); });
 }
 
-void Client_getTopicPartitionsAsync(Client &client, const std::string& topic, 
GetPartitionsCallback callback) {
+void Client_getTopicPartitionsAsync(Client& client, const std::string& topic,
+                                    GetPartitionsCallback callback) {
     py::gil_scoped_release release;
     client.getPartitionsForTopicAsync(topic, callback);
 }
@@ -76,6 +135,25 @@ SchemaInfo Client_getSchemaInfo(Client& client, const 
std::string& topic, int64_
     });
 }
 
+std::shared_ptr<Client> Client_createAutoClusterFailover(ServiceInfo primary,
+                                                         
std::vector<ServiceInfo> secondary,
+                                                         int64_t 
checkIntervalMs, uint32_t failoverThreshold,
+                                                         uint32_t 
switchBackThreshold,
+                                                         const 
ClientConfiguration& conf) {
+    AutoClusterFailover::Config autoClusterFailoverConfig(std::move(primary), 
std::move(secondary));
+    autoClusterFailoverConfig.checkInterval = 
std::chrono::milliseconds(checkIntervalMs);
+    autoClusterFailoverConfig.failoverThreshold = failoverThreshold;
+    autoClusterFailoverConfig.switchBackThreshold = switchBackThreshold;
+    return std::make_shared<Client>(
+        
Client::create(std::make_unique<AutoClusterFailover>(std::move(autoClusterFailoverConfig)),
 conf));
+}
+
+std::shared_ptr<Client> Client_createServiceInfoProvider(py::object provider,
+                                                         const 
ClientConfiguration& conf) {
+    return std::make_shared<Client>(
+        
Client::create(std::make_unique<PythonServiceInfoProvider>(std::move(provider)),
 conf));
+}
+
 void Client_close(Client& client) {
     waitForAsyncResult([&](ResultCallback callback) { 
client.closeAsync(callback); });
 }
@@ -108,19 +186,25 @@ void Client_subscribeAsync_pattern(Client& client, const 
std::string& topic_patt
 void export_client(py::module_& m) {
     py::class_<Client, std::shared_ptr<Client>>(m, "Client")
         .def(py::init<const std::string&, const ClientConfiguration&>())
+        .def_static("create_auto_cluster_failover", 
&Client_createAutoClusterFailover, py::arg("primary"),
+                    py::arg("secondary"), py::arg("check_interval_ms"), 
py::arg("failover_threshold"),
+                    py::arg("switch_back_threshold"), 
py::arg("client_configuration"))
+        .def_static("create_service_info_provider", 
&Client_createServiceInfoProvider, py::arg("provider"),
+                    py::arg("client_configuration"))
         .def("create_producer", &Client_createProducer)
         .def("create_producer_async", &Client_createProducerAsync)
         .def("subscribe", &Client_subscribe)
         .def("subscribe_topics", &Client_subscribe_topics)
         .def("subscribe_pattern", &Client_subscribe_pattern)
         .def("create_reader", &Client_createReader)
-        .def("create_table_view", [](Client& client, const std::string& topic,
-                                     const TableViewConfiguration& config) {
-            return waitForAsyncValue<TableView>([&](TableViewCallback 
callback) {
-                client.createTableViewAsync(topic, config, callback);
-            });
-        })
+        .def("create_table_view",
+             [](Client& client, const std::string& topic, const 
TableViewConfiguration& config) {
+                 return waitForAsyncValue<TableView>([&](TableViewCallback 
callback) {
+                     client.createTableViewAsync(topic, config, callback);
+                 });
+             })
         .def("get_topic_partitions", &Client_getTopicPartitions)
+        .def("get_service_info", &Client::getServiceInfo)
         .def("get_schema_info", &Client_getSchemaInfo)
         .def("close", &Client_close)
         .def("close_async", &Client_closeAsync)
diff --git a/src/config.cc b/src/config.cc
index 4c5557d..ec16b52 100644
--- a/src/config.cc
+++ b/src/config.cc
@@ -23,6 +23,7 @@
 #include <pulsar/ProducerConfiguration.h>
 #include <pulsar/KeySharedPolicy.h>
 #include <pulsar/DeadLetterPolicyBuilder.h>
+#include <pulsar/ServiceInfo.h>
 #include <pybind11/functional.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
@@ -135,16 +136,36 @@ static ClientConfiguration& 
ClientConfiguration_setFileLogger(ClientConfiguratio
     return conf;
 }
 
+static ServiceInfo ServiceInfo_init(const std::string& serviceUrl, 
AuthenticationPtr authentication,
+                                    std::optional<std::string> 
tlsTrustCertsFilePath) {
+    return ServiceInfo(serviceUrl, authentication ? std::move(authentication) 
: AuthFactory::Disabled(),
+                       std::move(tlsTrustCertsFilePath));
+}
+
 void export_config(py::module_& m) {
     using namespace py;
 
+    class_<ServiceInfo>(m, "ServiceInfo")
+        .def(init(&ServiceInfo_init), arg("service_url"), 
arg("authentication") = nullptr,
+             arg("tls_trust_certs_file_path") = py::none())
+        .def_property_readonly("service_url",
+                               [](const ServiceInfo& serviceInfo) { return 
serviceInfo.serviceUrl(); })
+        .def_property_readonly("use_tls", &ServiceInfo::useTls)
+        .def_property_readonly("tls_trust_certs_file_path", [](const 
ServiceInfo& serviceInfo) {
+            return serviceInfo.tlsTrustCertsFilePath();
+        });
+
     class_<KeySharedPolicy, std::shared_ptr<KeySharedPolicy>>(m, 
"KeySharedPolicy")
         .def(init<>())
         .def("set_key_shared_mode", &KeySharedPolicy::setKeySharedMode, 
return_value_policy::reference)
         .def("get_key_shared_mode", &KeySharedPolicy::getKeySharedMode)
-        .def("set_allow_out_of_order_delivery", 
&KeySharedPolicy::setAllowOutOfOrderDelivery, return_value_policy::reference)
+        .def("set_allow_out_of_order_delivery", 
&KeySharedPolicy::setAllowOutOfOrderDelivery,
+             return_value_policy::reference)
         .def("is_allow_out_of_order_delivery", 
&KeySharedPolicy::isAllowOutOfOrderDelivery)
-        .def("set_sticky_ranges", static_cast<KeySharedPolicy& 
(KeySharedPolicy::*)(const StickyRanges&)>(&KeySharedPolicy::setStickyRanges), 
return_value_policy::reference)
+        .def("set_sticky_ranges",
+             static_cast<KeySharedPolicy& (KeySharedPolicy::*)(const 
StickyRanges&)>(
+                 &KeySharedPolicy::setStickyRanges),
+             return_value_policy::reference)
         .def("get_sticky_ranges", &KeySharedPolicy::getStickyRanges);
 
     class_<CryptoKeyReader, std::shared_ptr<CryptoKeyReader>>(m, 
"AbstractCryptoKeyReader")
@@ -266,7 +287,8 @@ void export_config(py::module_& m) {
         .def(init<>())
         .def("deadLetterTopic", &DeadLetterPolicyBuilder::deadLetterTopic, 
return_value_policy::reference)
         .def("maxRedeliverCount", &DeadLetterPolicyBuilder::maxRedeliverCount, 
return_value_policy::reference)
-        .def("initialSubscriptionName", 
&DeadLetterPolicyBuilder::initialSubscriptionName, 
return_value_policy::reference)
+        .def("initialSubscriptionName", 
&DeadLetterPolicyBuilder::initialSubscriptionName,
+             return_value_policy::reference)
         .def("build", &DeadLetterPolicyBuilder::build, 
return_value_policy::reference)
         .def("build", &DeadLetterPolicyBuilder::build, 
return_value_policy::reference);
 
@@ -305,7 +327,8 @@ void export_config(py::module_& m) {
         .def("subscription_initial_position", 
&ConsumerConfiguration::getSubscriptionInitialPosition)
         .def("subscription_initial_position", 
&ConsumerConfiguration::setSubscriptionInitialPosition)
         .def("regex_subscription_mode", 
&ConsumerConfiguration::setRegexSubscriptionMode)
-        .def("regex_subscription_mode", 
&ConsumerConfiguration::getRegexSubscriptionMode, 
return_value_policy::reference)
+        .def("regex_subscription_mode", 
&ConsumerConfiguration::getRegexSubscriptionMode,
+             return_value_policy::reference)
         .def("crypto_key_reader", &ConsumerConfiguration::setCryptoKeyReader, 
return_value_policy::reference)
         .def("replicate_subscription_state_enabled",
              &ConsumerConfiguration::setReplicateSubscriptionStateEnabled)
@@ -328,9 +351,9 @@ void export_config(py::module_& m) {
         .def("dead_letter_policy", &ConsumerConfiguration::setDeadLetterPolicy)
         .def("dead_letter_policy", 
&ConsumerConfiguration::getDeadLetterPolicy, return_value_policy::copy)
         .def("crypto_failure_action", 
&ConsumerConfiguration::getCryptoFailureAction,
-            return_value_policy::copy)
+             return_value_policy::copy)
         .def("crypto_failure_action", 
&ConsumerConfiguration::setCryptoFailureAction,
-            return_value_policy::reference);
+             return_value_policy::reference);
 
     class_<ReaderConfiguration, std::shared_ptr<ReaderConfiguration>>(m, 
"ReaderConfiguration")
         .def(init<>())
@@ -348,9 +371,9 @@ void export_config(py::module_& m) {
         .def("read_compacted", &ReaderConfiguration::setReadCompacted)
         .def("crypto_key_reader", &ReaderConfiguration::setCryptoKeyReader, 
return_value_policy::reference)
         .def("start_message_id_inclusive", 
&ReaderConfiguration::isStartMessageIdInclusive)
-        .def("start_message_id_inclusive", 
&ReaderConfiguration::setStartMessageIdInclusive, 
return_value_policy::reference)
-        .def("crypto_failure_action", 
&ReaderConfiguration::getCryptoFailureAction,
-            return_value_policy::copy)
+        .def("start_message_id_inclusive", 
&ReaderConfiguration::setStartMessageIdInclusive,
+             return_value_policy::reference)
+        .def("crypto_failure_action", 
&ReaderConfiguration::getCryptoFailureAction, return_value_policy::copy)
         .def("crypto_failure_action", 
&ReaderConfiguration::setCryptoFailureAction,
-            return_value_policy::reference);
+             return_value_policy::reference);
 }
diff --git a/tests/auto_cluster_failover_test.py b/tests/auto_cluster_failover_test.py
new file mode 100644
index 0000000..168ae0d
--- /dev/null
+++ b/tests/auto_cluster_failover_test.py
@@ -0,0 +1,260 @@
+#!/usr/bin/env python3
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import shutil
+import time
+from unittest import SkipTest, TestCase, main
+from urllib.error import URLError
+from urllib.request import Request, urlopen
+
+import pulsar
+from pulsar import AutoClusterFailover, Client, MessageId, ServiceInfo
+
+try:
+    from testcontainers.core.container import DockerContainer
+except ImportError:
+    DockerContainer = None
+
+
+class AutoClusterFailoverDockerTest(TestCase):
+
+    PRIMARY_URL = "pulsar://localhost:16650"
+    PRIMARY_ADMIN_URL = "http://localhost:18080"
+    SECONDARY_URL = "pulsar://localhost:26650"
+    SECONDARY_ADMIN_URL = "http://localhost:28080"
+    RECEIVE_TIMEOUT_MS = 10000
+    FAILOVER_WAIT_SECONDS = 30
+
+    @classmethod
+    def setUpClass(cls):
+        if shutil.which("docker") is None:
+            raise SkipTest("docker is required for auto_cluster_failover_test")
+        if DockerContainer is None:
+            raise SkipTest("testcontainers is required for auto_cluster_failover_test")
+
+        try:
+            cls.primary_container = cls._create_container(
+                service_port=16650,
+                admin_port=18080,
+                cluster_name="standalone-0",
+            )
+            cls.secondary_container = cls._create_container(
+                service_port=26650,
+                admin_port=28080,
+                cluster_name="standalone-1",
+            )
+            cls.primary_container.start()
+            cls.secondary_container.start()
+            cls._wait_for_http(cls.PRIMARY_ADMIN_URL + "/metrics")
+            cls._wait_for_http(cls.SECONDARY_ADMIN_URL + "/metrics")
+            cls._configure_cluster(
+                cls.PRIMARY_ADMIN_URL,
+                cls.PRIMARY_URL,
+                "standalone-0",
+            )
+            cls._configure_cluster(
+                cls.SECONDARY_ADMIN_URL,
+                cls.SECONDARY_URL,
+                "standalone-1",
+            )
+        except Exception:
+            cls._print_container_logs()
+            raise
+
+    @classmethod
+    def tearDownClass(cls):
+        for container in (
+            getattr(cls, "primary_container", None),
+            getattr(cls, "secondary_container", None),
+        ):
+            if container is not None:
+                container.stop()
+
+    @classmethod
+    def _create_container(cls, service_port, admin_port, cluster_name):
+        return (
+            DockerContainer("apachepulsar/pulsar:latest")
+            .with_env("clusterName", cluster_name)
+            .with_env("advertisedAddress", "localhost")
+            .with_env("advertisedListeners", 
f"external:pulsar://localhost:{service_port}")
+            .with_env("PULSAR_MEM", "-Xms512m -Xmx512m 
-XX:MaxDirectMemorySize=256m")
+            .with_bind_ports(6650, service_port)
+            .with_bind_ports(8080, admin_port)
+            .with_command(
+                'bash -c "bin/apply-config-from-env.py conf/standalone.conf && '
+                'exec bin/pulsar standalone -nss -nfw"'
+            )
+        )
+
+    @classmethod
+    def _print_container_logs(cls):
+        for container in (
+            getattr(cls, "primary_container", None),
+            getattr(cls, "secondary_container", None),
+        ):
+            if container is None:
+                continue
+            wrapped = container.get_wrapped_container()
+            try:
+                print(wrapped.logs().decode("utf-8", errors="replace"))
+            except Exception:
+                pass
+
+    @classmethod
+    def _wait_for_http(cls, url, timeout_seconds=180):
+        deadline = time.time() + timeout_seconds
+        last_error = None
+        while time.time() < deadline:
+            try:
+                with urlopen(url, timeout=5) as response:
+                    if response.status == 200:
+                        return
+            except (URLError, OSError) as e:
+                last_error = e
+            time.sleep(1)
+        raise AssertionError(f"Timed out waiting for {url}: {last_error}")
+
+    @staticmethod
+    def _http_put(url, data):
+        request = Request(url, data.encode("utf-8"))
+        request.add_header("Content-Type", "application/json")
+        request.get_method = lambda: "PUT"
+        try:
+            with urlopen(request, timeout=10):
+                return
+        except URLError as e:
+            if "409" in str(e):
+                return
+            raise
+
+    @classmethod
+    def _configure_cluster(cls, admin_url, service_url, cluster_name):
+        cls._http_put(
+            f"{admin_url}/admin/v2/clusters/{cluster_name}",
+            """
+            {
+              "serviceUrl": "%s/",
+              "brokerServiceUrl": "%s/"
+            }
+            """ % (admin_url, service_url),
+        )
+        cls._http_put(
+            f"{admin_url}/admin/v2/tenants/public",
+            """
+            {
+              "adminRoles": ["anonymous"],
+              "allowedClusters": ["%s"]
+            }
+            """ % cluster_name,
+        )
+        cls._http_put(
+            f"{admin_url}/admin/v2/namespaces/public/default",
+            """
+            {
+              "replication_clusters": ["%s"]
+            }
+            """ % cluster_name,
+        )
+
+    @staticmethod
+    def _wait_until(predicate, timeout_seconds, description):
+        deadline = time.time() + timeout_seconds
+        while time.time() < deadline:
+            if predicate():
+                return
+            time.sleep(0.2)
+        raise AssertionError(f"Timed out waiting for {description}")
+
+    @staticmethod
+    def _ensure_topic_exists(service_url, topic):
+        client = Client(service_url)
+        producer = client.create_producer(topic)
+        producer.close()
+        client.close()
+
+    def test_producer_failover_between_two_standalones(self):
+        topic = "test-auto-cluster-failover-%d" % int(time.time() * 1000)
+        message_before_failover = b"before-failover"
+        message_after_failover = b"after-failover"
+
+        self._ensure_topic_exists(self.PRIMARY_URL, topic)
+        self._ensure_topic_exists(self.SECONDARY_URL, topic)
+
+        primary_client = Client(self.PRIMARY_URL)
+        primary_reader = primary_client.create_reader(topic, 
MessageId.earliest)
+
+        secondary_client = Client(self.SECONDARY_URL)
+        secondary_reader = secondary_client.create_reader(topic, 
MessageId.earliest)
+
+        failover_client = Client(
+            AutoClusterFailover(
+                ServiceInfo(self.PRIMARY_URL),
+                [ServiceInfo(self.SECONDARY_URL)],
+                check_interval_ms=200,
+                failover_threshold=1,
+                switch_back_threshold=1,
+            ),
+            operation_timeout_seconds=10,
+        )
+        producer = failover_client.create_producer(
+            topic,
+            send_timeout_millis=3000,
+            batching_enabled=False,
+        )
+
+        self.assertEqual(failover_client.get_service_info().service_url, 
self.PRIMARY_URL)
+
+        producer.send(message_before_failover)
+        
self.assertEqual(primary_reader.read_next(self.RECEIVE_TIMEOUT_MS).data(), 
message_before_failover)
+
+        primary_reader.close()
+        primary_client.close()
+
+        self.primary_container.get_wrapped_container().kill(signal="SIGTERM")
+
+        self._wait_until(
+            lambda: failover_client.get_service_info().service_url == 
self.SECONDARY_URL,
+            self.FAILOVER_WAIT_SECONDS,
+            "client service info to switch to the secondary broker",
+        )
+
+        last_error = None
+        deadline = time.time() + self.FAILOVER_WAIT_SECONDS
+        while time.time() < deadline:
+            try:
+                producer.send(message_after_failover)
+                break
+            except pulsar.PulsarException as e:
+                last_error = e
+                time.sleep(0.5)
+        else:
+            raise AssertionError(f"Producer did not recover after failover: {last_error}")
+
+        
self.assertEqual(secondary_reader.read_next(self.RECEIVE_TIMEOUT_MS).data(), 
message_after_failover)
+        self.assertEqual(failover_client.get_service_info().service_url, 
self.SECONDARY_URL)
+
+        producer.close()
+        failover_client.close()
+        secondary_reader.close()
+        secondary_client.close()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/pulsar_test.py b/tests/pulsar_test.py
index b7f38ed..be817c5 100755
--- a/tests/pulsar_test.py
+++ b/tests/pulsar_test.py
@@ -45,6 +45,9 @@ from pulsar import (
     ConsumerBatchReceivePolicy,
     ProducerAccessMode,
     ConsumerDeadLetterPolicy,
+    ServiceInfoProvider,
+    ServiceInfo,
+    AutoClusterFailover,
 )
 from pulsar.schema import JsonSchema, Record, Integer
 
@@ -95,6 +98,30 @@ class PulsarTest(TestCase):
 
     serviceUrlTls = "pulsar+ssl://localhost:6651"
 
+    class StaticServiceInfoProvider(ServiceInfoProvider):
+
+        def __init__(self, initial_service_info):
+            self.initial = initial_service_info
+            self.callback = None
+            self.closed = False
+
+        def initial_service_info(self):
+            return self.initial
+
+        def initialize(self, on_service_info_update):
+            self.callback = on_service_info_update
+
+        def close(self):
+            self.closed = True
+
+    class InvalidServiceInfoProvider(ServiceInfoProvider):
+
+        def initial_service_info(self):
+            return "invalid"
+
+        def initialize(self, on_service_info_update):
+            pass
+
     def test_producer_config(self):
         conf = ProducerConfiguration()
         conf.send_timeout_millis(12)
@@ -934,6 +961,53 @@ class PulsarTest(TestCase):
         self._check_value_error(lambda: Client(self.serviceUrl, 
tls_trust_certs_file_path=5))
         self._check_value_error(lambda: Client(self.serviceUrl, 
tls_allow_insecure_connection="test"))
 
+    def test_service_info_argument_errors(self):
+        self._check_value_error(lambda: ServiceInfo(None))
+        self._check_value_error(lambda: ServiceInfo(self.serviceUrl, 
authentication="test"))
+        self._check_value_error(lambda: ServiceInfo(self.serviceUrl, 
tls_trust_certs_file_path=5))
+
+    def test_auto_cluster_failover_argument_errors(self):
+        primary = ServiceInfo(self.serviceUrl)
+        secondary = [ServiceInfo("pulsar://192.0.2.1:6650")]
+
+        self._check_value_error(lambda: AutoClusterFailover("test", secondary))
+        self._check_value_error(lambda: AutoClusterFailover(primary, "test"))
+        self._check_value_error(lambda: AutoClusterFailover(primary, []))
+        self._check_value_error(lambda: AutoClusterFailover(primary, ["test"]))
+        self._check_value_error(lambda: AutoClusterFailover(primary, 
secondary, check_interval_ms=0))
+        self._check_value_error(lambda: AutoClusterFailover(primary, 
secondary, failover_threshold=0))
+        self._check_value_error(lambda: AutoClusterFailover(primary, 
secondary, switch_back_threshold=0))
+        self._check_value_error(lambda: Client(AutoClusterFailover(primary, 
secondary),
+                                               
authentication=AuthenticationToken("token")))
+        self._check_value_error(lambda: Client(AutoClusterFailover(primary, 
secondary),
+                                               
tls_trust_certs_file_path=CERTS_DIR + "cacert.pem"))
+        self._check_value_error(lambda: 
Client(self.StaticServiceInfoProvider(primary),
+                                               
authentication=AuthenticationToken("token")))
+        self._check_value_error(lambda: 
Client(self.StaticServiceInfoProvider(primary),
+                                               
tls_trust_certs_file_path=CERTS_DIR + "cacert.pem"))
+        self._check_value_error(lambda: 
Client(self.InvalidServiceInfoProvider()))
+
+    def test_auto_cluster_failover_client(self):
+        primary = ServiceInfo(self.serviceUrl)
+        secondary = [ServiceInfo("pulsar://192.0.2.1:6650")]
+        client = Client(AutoClusterFailover(primary, secondary, 
check_interval_ms=100))
+        self.assertEqual(client.get_service_info().service_url, 
self.serviceUrl)
+        client.close()
+
+    def test_service_info_provider_client(self):
+        primary = ServiceInfo(self.serviceUrl)
+        secondary = ServiceInfo("pulsar://192.0.2.1:6650")
+        provider = self.StaticServiceInfoProvider(primary)
+
+        client = Client(provider)
+        self.assertEqual(client.get_service_info().service_url, 
self.serviceUrl)
+
+        provider.callback(secondary)
+        self.assertEqual(client.get_service_info().service_url, 
secondary.service_url)
+
+        client.close()
+        self.assertTrue(provider.closed)
+
     def test_producer_argument_errors(self):
         client = Client(self.serviceUrl)
 
diff --git a/tests/run-unit-tests.sh b/tests/run-unit-tests.sh
index 8d7600d..587d6c6 100755
--- a/tests/run-unit-tests.sh
+++ b/tests/run-unit-tests.sh
@@ -23,9 +23,12 @@ set -e -x
 ROOT_DIR=$(git rev-parse --show-toplevel)
 cd $ROOT_DIR/tests
 
+python3 -m pip install testcontainers
+
 python3 custom_logger_test.py
 python3 debug_logger_test.py
 python3 interrupted_test.py
+python3 auto_cluster_failover_test.py
 python3 pulsar_test.py
 python3 schema_test.py
 python3 table_view_test.py


Reply via email to