https://github.com/klausler updated https://github.com/llvm/llvm-project/pull/72176
>From e2e0e60c5cf01b5e99cb2494e2444b91d1f6605d Mon Sep 17 00:00:00 2001 From: Peter Klausler <pklaus...@nvidia.com> Date: Fri, 3 Nov 2023 13:04:01 -0700 Subject: [PATCH] [flang] Fold MATMUL() Implements constant folding for matrix multiplication for all four accepted type categories. --- flang/lib/Evaluate/fold-complex.cpp | 4 +- flang/lib/Evaluate/fold-integer.cpp | 4 +- flang/lib/Evaluate/fold-logical.cpp | 4 +- flang/lib/Evaluate/fold-matmul.h | 103 ++++++++++++++++++++++++++++ flang/lib/Evaluate/fold-real.cpp | 4 +- flang/lib/Evaluate/fold-reduction.h | 4 +- flang/test/Evaluate/fold-matmul.f90 | 41 +++++++++++ 7 files changed, 158 insertions(+), 6 deletions(-) create mode 100644 flang/lib/Evaluate/fold-matmul.h create mode 100644 flang/test/Evaluate/fold-matmul.f90 diff --git a/flang/lib/Evaluate/fold-complex.cpp b/flang/lib/Evaluate/fold-complex.cpp index e40e3a37df14948..3260f82ffe8d734 100644 --- a/flang/lib/Evaluate/fold-complex.cpp +++ b/flang/lib/Evaluate/fold-complex.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "fold-implementation.h" +#include "fold-matmul.h" #include "fold-reduction.h" namespace Fortran::evaluate { @@ -64,13 +65,14 @@ Expr<Type<TypeCategory::Complex, KIND>> FoldIntrinsicFunction( } } else if (name == "dot_product") { return FoldDotProduct<T>(context, std::move(funcRef)); + } else if (name == "matmul") { + return FoldMatmul(context, std::move(funcRef)); } else if (name == "product") { auto one{Scalar<Part>::FromInteger(value::Integer<8>{1}).value}; return FoldProduct<T>(context, std::move(funcRef), Scalar<T>{one}); } else if (name == "sum") { return FoldSum<T>(context, std::move(funcRef)); } - // TODO: matmul return Expr<T>{std::move(funcRef)}; } diff --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp index dedfc20a491cd88..2882369105f6626 100644 --- a/flang/lib/Evaluate/fold-integer.cpp +++ b/flang/lib/Evaluate/fold-integer.cpp @@ -7,6 +7,7 
@@ //===----------------------------------------------------------------------===// #include "fold-implementation.h" +#include "fold-matmul.h" #include "fold-reduction.h" #include "flang/Evaluate/check-expression.h" @@ -1042,6 +1043,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction( ScalarFunc<T, Int4>([&fptr](const Scalar<Int4> &places) -> Scalar<T> { return fptr(static_cast<int>(places.ToInt64())); })); + } else if (name == "matmul") { + return FoldMatmul(context, std::move(funcRef)); } else if (name == "max") { return FoldMINorMAX(context, std::move(funcRef), Ordering::Greater); } else if (name == "max0" || name == "max1") { @@ -1279,7 +1282,6 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction( } else if (name == "ubound") { return UBOUND(context, std::move(funcRef)); } - // TODO: dot_product, matmul, sign return Expr<T>{std::move(funcRef)}; } diff --git a/flang/lib/Evaluate/fold-logical.cpp b/flang/lib/Evaluate/fold-logical.cpp index bfedc32a33a8bad..82a5cb20db9e409 100644 --- a/flang/lib/Evaluate/fold-logical.cpp +++ b/flang/lib/Evaluate/fold-logical.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "fold-implementation.h" +#include "fold-matmul.h" #include "fold-reduction.h" #include "flang/Evaluate/check-expression.h" #include "flang/Runtime/magic-numbers.h" @@ -231,6 +232,8 @@ Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction( if (auto *expr{UnwrapExpr<Expr<SomeLogical>>(args[0])}) { return Fold(context, ConvertToType<T>(std::move(*expr))); } + } else if (name == "matmul") { + return FoldMatmul(context, std::move(funcRef)); } else if (name == "out_of_range") { if (Expr<SomeType> * cx{UnwrapExpr<Expr<SomeType>>(args[0])}) { auto restorer{context.messages().DiscardMessages()}; @@ -367,7 +370,6 @@ Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction( name == "__builtin_ieee_support_underflow_control") { return Expr<T>{true}; } - // TODO: logical, 
matmul, parity return Expr<T>{std::move(funcRef)}; } diff --git a/flang/lib/Evaluate/fold-matmul.h b/flang/lib/Evaluate/fold-matmul.h new file mode 100644 index 000000000000000..27b6db1fd8bf025 --- /dev/null +++ b/flang/lib/Evaluate/fold-matmul.h @@ -0,0 +1,103 @@ +//===-- lib/Evaluate/fold-matmul.h ----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef FORTRAN_EVALUATE_FOLD_MATMUL_H_ +#define FORTRAN_EVALUATE_FOLD_MATMUL_H_ + +#include "fold-implementation.h" + +namespace Fortran::evaluate { +// Folds MATMUL(A, B) when both arguments fold to constants. +template <typename T> +static Expr<T> FoldMatmul(FoldingContext &context, FunctionRef<T> &&funcRef) { + using Element = typename Constant<T>::Element; + auto args{funcRef.arguments()}; + CHECK(args.size() == 2); + Folder<T> folder{context}; + Constant<T> *ma{folder.Folding(args[0])}; + Constant<T> *mb{folder.Folding(args[1])}; + if (!ma || !mb) { + return Expr<T>{std::move(funcRef)}; + } + CHECK(ma->Rank() >= 1 && ma->Rank() <= 2 && mb->Rank() >= 1 && + mb->Rank() <= 2 && (ma->Rank() == 2 || mb->Rank() == 2)); + ConstantSubscript commonExtent{ma->shape().back()}; + if (mb->shape().front() != commonExtent) { + context.messages().Say( + "Arguments to MATMUL have distinct extents %zd and %zd on their last and first dimensions"_err_en_US, + commonExtent, mb->shape().front()); + return MakeInvalidIntrinsic(std::move(funcRef)); + } + ConstantSubscript rows{ma->Rank() == 1 ? 1 : ma->shape()[0]}; + ConstantSubscript columns{mb->Rank() == 1 ?
1 : mb->shape()[1]}; + std::vector<Element> elements; + elements.reserve(rows * columns); + bool overflow{false}; + [[maybe_unused]] const auto &rounding{ + context.targetCharacteristics().roundingMode()}; + // result(j,k) = SUM(A(j,:) * B(:,k)) + for (ConstantSubscript ci{0}; ci < columns; ++ci) { + for (ConstantSubscript ri{0}; ri < rows; ++ri) { + ConstantSubscripts aAt{ma->lbounds()}; + if (ma->Rank() == 2) { + aAt[0] += ri; + } + ConstantSubscripts bAt{mb->lbounds()}; + if (mb->Rank() == 2) { + bAt[1] += ci; + } + Element sum{}; + [[maybe_unused]] Element correction{}; // Kahan compensation term + for (ConstantSubscript j{0}; j < commonExtent; ++j) { + Element aElt{ma->At(aAt)}; + Element bElt{mb->At(bAt)}; + if constexpr (T::category == TypeCategory::Real || + T::category == TypeCategory::Complex) { + // Kahan summation: next = product - correction; sum += next + auto product{aElt.Multiply(bElt, rounding)}; + overflow |= product.flags.test(RealFlag::Overflow); + auto next{product.value.Subtract(correction, rounding)}; + overflow |= next.flags.test(RealFlag::Overflow); + auto added{sum.Add(next.value, rounding)}; + overflow |= added.flags.test(RealFlag::Overflow); + correction = added.value.Subtract(sum, rounding) + .value.Subtract(next.value, rounding) + .value; + sum = std::move(added.value); + } else if constexpr (T::category == TypeCategory::Integer) { + auto product{aElt.MultiplySigned(bElt)}; + overflow |= product.SignedMultiplicationOverflowed(); + auto added{sum.AddSigned(product.lower)}; + overflow |= added.overflow; + sum = std::move(added.value); + } else { + static_assert(T::category == TypeCategory::Logical); + sum = sum.OR(aElt.AND(bElt)); + } + ++aAt.back(); + ++bAt.front(); + } + elements.push_back(sum); + } + } + if (overflow) { + context.messages().Say( + "MATMUL of %s data overflowed during computation"_warn_en_US, + T::AsFortran()); + } + ConstantSubscripts shape; + if (ma->Rank() == 2) { + shape.push_back(rows); + } + if (mb->Rank() == 2) { + shape.push_back(columns); + } + return
Expr<T>{Constant<T>{std::move(elements), std::move(shape)}}; +} +} // namespace Fortran::evaluate +#endif // FORTRAN_EVALUATE_FOLD_MATMUL_H_ diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp index 6bcc3ec73982157..6ae069df5d7a425 100644 --- a/flang/lib/Evaluate/fold-real.cpp +++ b/flang/lib/Evaluate/fold-real.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "fold-implementation.h" +#include "fold-matmul.h" #include "fold-reduction.h" namespace Fortran::evaluate { @@ -269,6 +270,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction( } return result.value; })); + } else if (name == "matmul") { + return FoldMatmul(context, std::move(funcRef)); } else if (name == "max") { return FoldMINorMAX(context, std::move(funcRef), Ordering::Greater); } else if (name == "maxval") { @@ -446,7 +449,6 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction( return result.value; })); } - // TODO: matmul return Expr<T>{std::move(funcRef)}; } diff --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h index 0dd55124e6a512e..60c757dc3f4fa8e 100644 --- a/flang/lib/Evaluate/fold-reduction.h +++ b/flang/lib/Evaluate/fold-reduction.h @@ -43,7 +43,7 @@ static Expr<T> FoldDotProduct( Expr<T> products{Fold( context, Expr<T>{std::move(conjgA)} * Expr<T>{Constant<T>{*vb}})}; Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))}; - Element correction; // Use Kahan summation for greater precision. + Element correction{}; // Use Kahan summation for greater precision.
const auto &rounding{context.targetCharacteristics().roundingMode()}; for (const Element &x : cProducts.values()) { auto next{correction.Add(x, rounding)}; @@ -80,7 +80,7 @@ static Expr<T> FoldDotProduct( Expr<T> products{ Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})}; Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))}; - Element correction; // Use Kahan summation for greater precision. + Element correction{}; // Use Kahan summation for greater precision. const auto &rounding{context.targetCharacteristics().roundingMode()}; for (const Element &x : cProducts.values()) { auto next{correction.Add(x, rounding)}; diff --git a/flang/test/Evaluate/fold-matmul.f90 b/flang/test/Evaluate/fold-matmul.f90 new file mode 100644 index 000000000000000..dce90197e1f1fdd --- /dev/null +++ b/flang/test/Evaluate/fold-matmul.f90 @@ -0,0 +1,41 @@ +! RUN: %python %S/test_folding.py %s %flang_fc1 +! Tests folding of MATMUL() for INTEGER, REAL, COMPLEX, and LOGICAL arguments +module m + integer, parameter :: ia(2,3) = reshape([1, 2, 2, 3, 3, 4], shape(ia)) + integer, parameter :: ib(3,2) = reshape([1, 2, 3, 2, 3, 4], shape(ib)) + integer, parameter :: ix(*) = [1, 2] + integer, parameter :: iy(*) = [1, 2, 3] + integer, parameter :: iab(*,*) = matmul(ia, ib) ! matrix x matrix + integer, parameter :: ixa(*) = matmul(ix, ia) ! vector x matrix + integer, parameter :: iay(*) = matmul(ia, iy) ! matrix x vector + logical, parameter :: test_iab = all([iab] == [14, 20, 20, 29]) + logical, parameter :: test_ixa = all(ixa == [5, 8, 11]) + logical, parameter :: test_iay = all(iay == [14, 20]) + ! REAL copies of the INTEGER operands; results must equal the INTEGER ones + real, parameter :: ra(*,*) = ia + real, parameter :: rb(*,*) = ib + real, parameter :: rx(*) = ix + real, parameter :: ry(*) = iy + real, parameter :: rab(*,*) = matmul(ra, rb) + real, parameter :: rxa(*) = matmul(rx, ra) + real, parameter :: ray(*) = matmul(ra, ry) + logical, parameter :: test_rab = all(rab == iab) + logical, parameter :: test_rxa = all(rxa == ixa) + logical, parameter :: test_ray = all(ray == iay) + ! COMPLEX operands: the REAL values as real parts, every imaginary part -1 + complex, parameter :: za(*,*) = cmplx(ra, -1.)
+ complex, parameter :: zb(*,*) = cmplx(rb, -1.) + complex, parameter :: zx(*) = cmplx(rx, -1.) + complex, parameter :: zy(*) = cmplx(ry, -1.) + complex, parameter :: zab(*,*) = matmul(za, zb) + complex, parameter :: zxa(*) = matmul(zx, za) + complex, parameter :: zay(*) = matmul(za, zy) + logical, parameter :: test_zab = all([zab] == [(11,-12),(17,-15),(17,-15),(26,-18)]) + logical, parameter :: test_zxa = all(zxa == [(3,-6),(6,-8),(9,-10)]) + logical, parameter :: test_zay = all(zay == [(11,-12),(17,-15)]) + ! la(j+1,k+1) is bit k of j, so lab(i+1,j+1) must equal IAND(i,j) /= 0 + logical, parameter :: la(16, 4) = reshape([((iand(shiftr(j,k),1)/=0, j=0,15), k=0,3)], shape(la)) + logical, parameter :: lb(4, 16) = transpose(la) + logical, parameter :: lab(16, 16) = matmul(la, lb) + logical, parameter :: test_lab = all([lab] .eqv. [((iand(k,j)/=0, k=0,15), j=0,15)]) +end _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits