https://gcc.gnu.org/g:c4253d6a170f40725ce3a11ce7a3e236b6e4842f

commit r14-10737-gc4253d6a170f40725ce3a11ce7a3e236b6e4842f
Author: Jonathan Wakely <jwak...@redhat.com>
Date:   Tue Jun 11 16:45:43 2024 +0100

    libstdc++: Fix std::codecvt<wchar_t, char, mbstate_t> for empty dest [PR37475]
    
    For the GNU locale model, codecvt::do_out and codecvt::do_in incorrectly
    return 'ok' when the destination range is empty. That happens because
    detecting incomplete output is done in the loop body, and the loop is
    never even entered if to == to_end.
    
    By restructuring the loop condition so that we check the output range
    separately, we can ensure that for a non-empty source range, we always
    enter the loop at least once, and detect if the destination range is too
    small.
    
    The loops also seem easier to reason about if we return immediately on
    any error, instead of checking the result twice on every iteration. We
    can use an RAII type to restore the locale before returning, which also
    simplifies all the other member functions.
    
    libstdc++-v3/ChangeLog:
    
            PR libstdc++/37475
            * config/locale/gnu/codecvt_members.cc (Guard): New RAII type.
            (do_out, do_in): Return partial if the destination is empty but
            the source is not. Use Guard to restore locale on scope exit.
            Return immediately on any conversion error.
            (do_encoding, do_max_length, do_length): Use Guard.
            * testsuite/22_locale/codecvt/in/char/37475.cc: New test.
            * testsuite/22_locale/codecvt/in/wchar_t/37475.cc: New test.
            * testsuite/22_locale/codecvt/out/char/37475.cc: New test.
            * testsuite/22_locale/codecvt/out/wchar_t/37475.cc: New test.
    
    (cherry picked from commit 73ad57c244c283bf6da0c16630212f11b945eda5)

Diff:
---
 libstdc++-v3/config/locale/gnu/codecvt_members.cc  | 117 +++++++++------------
 .../testsuite/22_locale/codecvt/in/char/37475.cc   |  23 ++++
 .../22_locale/codecvt/in/wchar_t/37475.cc          |  23 ++++
 .../testsuite/22_locale/codecvt/out/char/37475.cc  |  23 ++++
 .../22_locale/codecvt/out/wchar_t/37475.cc         |  23 ++++
 5 files changed, 142 insertions(+), 67 deletions(-)

diff --git a/libstdc++-v3/config/locale/gnu/codecvt_members.cc b/libstdc++-v3/config/locale/gnu/codecvt_members.cc
index 034713d236ef..794f25a5f356 100644
--- a/libstdc++-v3/config/locale/gnu/codecvt_members.cc
+++ b/libstdc++-v3/config/locale/gnu/codecvt_members.cc
@@ -37,8 +37,23 @@ namespace std _GLIBCXX_VISIBILITY(default)
 {
 _GLIBCXX_BEGIN_NAMESPACE_VERSION
 
-  // Specializations.
 #ifdef _GLIBCXX_USE_WCHAR_T
+namespace
+{
+  // RAII type for changing and restoring the current thread's locale.
+  struct Guard
+  {
+#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
+    explicit Guard(__c_locale loc) : old(__uselocale(loc)) { }
+    ~Guard() { __uselocale(old); }
+#else
+    explicit Guard(__c_locale) { }
+#endif
+    __c_locale old;
+  };
+}
+
+  // Specializations.
   codecvt_base::result
   codecvt<wchar_t, char, mbstate_t>::
   do_out(state_type& __state, const intern_type* __from,
@@ -46,22 +61,21 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
         extern_type* __to, extern_type* __to_end,
         extern_type*& __to_next) const
   {
-    result __ret = ok;
     state_type __tmp_state(__state);
-
-#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
-    __c_locale __old = __uselocale(_M_c_locale_codecvt);
-#endif
+    Guard g(_M_c_locale_codecvt);
 
     // wcsnrtombs is *very* fast but stops if encounters NUL characters:
     // in case we fall back to wcrtomb and then continue, in a loop.
     // NB: wcsnrtombs is a GNU extension
-    for (__from_next = __from, __to_next = __to;
-        __from_next < __from_end && __to_next < __to_end
-        && __ret == ok;)
+    __from_next = __from;
+    __to_next = __to;
+    while (__from_next < __from_end)
       {
-       const intern_type* __from_chunk_end = wmemchr(__from_next, L'\0',
-                                                     __from_end - __from_next);
+       if (__to_next >= __to_end)
+         return partial;
+
+       const intern_type* __from_chunk_end
+         = wmemchr(__from_next, L'\0', __from_end - __from_next);
        if (!__from_chunk_end)
          __from_chunk_end = __from_end;
 
@@ -77,12 +91,12 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
            for (; __from < __from_next; ++__from)
              __to_next += wcrtomb(__to_next, *__from, &__tmp_state);
            __state = __tmp_state;
-           __ret = error;
+           return error;
          }
        else if (__from_next && __from_next < __from_chunk_end)
          {
            __to_next += __conv;
-           __ret = partial;
+           return partial;
          }
        else
          {
@@ -90,13 +104,13 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
            __to_next += __conv;
          }
 
-       if (__from_next < __from_end && __ret == ok)
+       if (__from_next < __from_end)
          {
            extern_type __buf[MB_LEN_MAX];
            __tmp_state = __state;
            const size_t __conv2 = wcrtomb(__buf, *__from_next, &__tmp_state);
            if (__conv2 > static_cast<size_t>(__to_end - __to_next))
-             __ret = partial;
+             return partial;
            else
              {
                memcpy(__to_next, __buf, __conv2);
@@ -107,11 +121,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
          }
       }
 
-#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
-    __uselocale(__old);
-#endif
-
-    return __ret;
+    return ok;
   }
 
   codecvt_base::result
@@ -121,24 +131,22 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
        intern_type* __to, intern_type* __to_end,
        intern_type*& __to_next) const
   {
-    result __ret = ok;
     state_type __tmp_state(__state);
-
-#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
-    __c_locale __old = __uselocale(_M_c_locale_codecvt);
-#endif
+    Guard g(_M_c_locale_codecvt);
 
     // mbsnrtowcs is *very* fast but stops if encounters NUL characters:
     // in case we store a L'\0' and then continue, in a loop.
     // NB: mbsnrtowcs is a GNU extension
-    for (__from_next = __from, __to_next = __to;
-        __from_next < __from_end && __to_next < __to_end
-        && __ret == ok;)
+    __from_next = __from;
+    __to_next = __to;
+    while (__from_next < __from_end)
       {
-       const extern_type* __from_chunk_end;
-       __from_chunk_end = static_cast<const extern_type*>(memchr(__from_next, '\0',
-                                                                 __from_end
-                                                                 - __from_next));
+       if (__to_next >= __to_end)
+         return partial;
+
+       const extern_type* __from_chunk_end
+         = static_cast<const extern_type*>(memchr(__from_next, '\0',
+                                                  __from_end - __from_next));
        if (!__from_chunk_end)
          __from_chunk_end = __from_end;
 
@@ -161,13 +169,13 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
              }
            __from_next = __from;
            __state = __tmp_state;
-           __ret = error;
+           return error;
          }
        else if (__from_next && __from_next < __from_chunk_end)
          {
            // It is unclear what to return in this case (see DR 382).
            __to_next += __conv;
-           __ret = partial;
+           return partial;
          }
        else
          {
@@ -175,7 +183,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
            __to_next += __conv;
          }
 
-       if (__from_next < __from_end && __ret == ok)
+       if (__from_next < __from_end)
          {
            if (__to_next < __to_end)
              {
@@ -185,48 +193,30 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
                *__to_next++ = L'\0';
              }
            else
-             __ret = partial;
+             return partial;
          }
       }
 
-#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
-    __uselocale(__old);
-#endif
-
-    return __ret;
+    return ok;
   }
 
   int
   codecvt<wchar_t, char, mbstate_t>::
   do_encoding() const throw()
   {
+    Guard g(_M_c_locale_codecvt);
     // XXX This implementation assumes that the encoding is
     // stateless and is either single-byte or variable-width.
-    int __ret = 0;
-#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
-    __c_locale __old = __uselocale(_M_c_locale_codecvt);
-#endif
-    if (MB_CUR_MAX == 1)
-      __ret = 1;
-#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
-    __uselocale(__old);
-#endif
-    return __ret;
+    return MB_CUR_MAX == 1;
   }
 
   int
   codecvt<wchar_t, char, mbstate_t>::
   do_max_length() const throw()
   {
-#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
-    __c_locale __old = __uselocale(_M_c_locale_codecvt);
-#endif
+    Guard g(_M_c_locale_codecvt);
     // XXX Probably wrong for stateful encodings.
-    int __ret = MB_CUR_MAX;
-#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
-    __uselocale(__old);
-#endif
-    return __ret;
+    return MB_CUR_MAX;
   }
 
   int
@@ -236,10 +226,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
   {
     int __ret = 0;
     state_type __tmp_state(__state);
-
-#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
-    __c_locale __old = __uselocale(_M_c_locale_codecvt);
-#endif
+    Guard g(_M_c_locale_codecvt);
 
     // mbsnrtowcs is *very* fast but stops if encounters NUL characters:
     // in case we advance past it and then continue, in a loop.
@@ -295,10 +282,6 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
          }
       }
 
-#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
-    __uselocale(__old);
-#endif
-
     return __ret;
   }
 #endif
diff --git a/libstdc++-v3/testsuite/22_locale/codecvt/in/char/37475.cc b/libstdc++-v3/testsuite/22_locale/codecvt/in/char/37475.cc
new file mode 100644
index 000000000000..6184c3280cbe
--- /dev/null
+++ b/libstdc++-v3/testsuite/22_locale/codecvt/in/char/37475.cc
@@ -0,0 +1,23 @@
+#include <locale>
+#include <testsuite_hooks.h>
+
+void
+test_pr37475()
+{
+  typedef std::codecvt<char, char, std::mbstate_t> test_type;
+  const test_type& cvt = std::use_facet<test_type>(std::locale::classic());
+  const char from = 'a';
+  const char* from_next;
+  char to = 0;
+  char* to_next;
+  std::mbstate_t st = std::mbstate_t();
+  std::codecvt_base::result res
+    = cvt.in(st, &from, &from+1, from_next, &to, &to, to_next);
+
+  VERIFY( res == std::codecvt_base::noconv );
+}
+
+int main()
+{
+  test_pr37475();
+}
diff --git a/libstdc++-v3/testsuite/22_locale/codecvt/in/wchar_t/37475.cc b/libstdc++-v3/testsuite/22_locale/codecvt/in/wchar_t/37475.cc
new file mode 100644
index 000000000000..a0e64847ea90
--- /dev/null
+++ b/libstdc++-v3/testsuite/22_locale/codecvt/in/wchar_t/37475.cc
@@ -0,0 +1,23 @@
+#include <locale>
+#include <testsuite_hooks.h>
+
+void
+test_pr37475()
+{
+  typedef std::codecvt<wchar_t, char, std::mbstate_t> test_type;
+  const test_type& cvt = std::use_facet<test_type>(std::locale::classic());
+  const char from = 'a';
+  const char* from_next;
+  wchar_t to = 0;
+  wchar_t* to_next;
+  std::mbstate_t st = std::mbstate_t();
+  std::codecvt_base::result res
+    = cvt.in(st, &from, &from+1, from_next, &to, &to, to_next);
+
+  VERIFY( res == std::codecvt_base::partial );
+}
+
+int main()
+{
+  test_pr37475();
+}
diff --git a/libstdc++-v3/testsuite/22_locale/codecvt/out/char/37475.cc b/libstdc++-v3/testsuite/22_locale/codecvt/out/char/37475.cc
new file mode 100644
index 000000000000..8736e4b7f3f6
--- /dev/null
+++ b/libstdc++-v3/testsuite/22_locale/codecvt/out/char/37475.cc
@@ -0,0 +1,23 @@
+#include <locale>
+#include <assert.h>
+
+void
+test_pr37475()
+{
+  typedef std::codecvt<char, char, std::mbstate_t> test_type;
+  const test_type& cvt = std::use_facet<test_type>(std::locale::classic());
+  const char from = 'a';
+  const char* from_next;
+  char to;
+  char* to_next;
+  std::mbstate_t st = std::mbstate_t();
+  std::codecvt_base::result res
+    = cvt.out(st, &from, &from+1, from_next, &to, &to, to_next);
+
+  assert( res == std::codecvt_base::noconv );
+}
+
+int main()
+{
+  test_pr37475();
+}
diff --git a/libstdc++-v3/testsuite/22_locale/codecvt/out/wchar_t/37475.cc b/libstdc++-v3/testsuite/22_locale/codecvt/out/wchar_t/37475.cc
new file mode 100644
index 000000000000..2cd2edb74040
--- /dev/null
+++ b/libstdc++-v3/testsuite/22_locale/codecvt/out/wchar_t/37475.cc
@@ -0,0 +1,23 @@
+#include <locale>
+#include <assert.h>
+
+void
+test_pr37475()
+{
+  typedef std::codecvt<wchar_t, char, std::mbstate_t> test_type;
+  const test_type& cvt = std::use_facet<test_type>(std::locale::classic());
+  const wchar_t from = L'a';
+  const wchar_t* from_next;
+  char to;
+  char* to_next;
+  std::mbstate_t st = std::mbstate_t();
+  std::codecvt_base::result res
+    = cvt.out(st, &from, &from+1, from_next, &to, &to, to_next);
+
+  assert( res == std::codecvt_base::partial );
+}
+
+int main()
+{
+  test_pr37475();
+}

Reply via email to