This is the mail archive of the gcc-patches@gcc.gnu.org mailing list for the GCC project.
Re: [gfortran,patch] MATMUL using BLAS gemm routines
- From: "François-Xavier Coudert" <fxcoudert at gmail dot com>
- To: gfortran <fortran at gcc dot gnu dot org>, gcc-patches <gcc-patches at gcc dot gnu dot org>
- Date: Mon, 25 Sep 2006 12:41:48 +0200
- Subject: Re: [gfortran,patch] MATMUL using BLAS gemm routines
- References: <19c433eb0609250156u5cacf9aw853ef407d89d0489@mail.gmail.com>
And here is the real patch, and not an old version of it.
Sorry,
FX
Index: libgfortran/m4/matmul.m4
===================================================================
--- libgfortran/m4/matmul.m4 (revision 117190)
+++ libgfortran/m4/matmul.m4 (working copy)
@@ -57,18 +57,28 @@
DO I=1,M
S = 0
DO K=1,COUNT
- S = S+A(I,K)+B(K,J)
+ S = S+A(I,K)*B(K,J)
C(I,J) = S
ENDIF
*/
+typedef void (*blas_call)(char *, char *, int *, int *, int *,
+ rtype_name *, rtype_name *, int *, rtype_name *,
+ int *, rtype_name *, rtype_name *, int *, int, int);
+
+/* If try_blas is set to a nonzero value, then the matmul function will
+ see if there is a way to perform the matrix multiplication by a BLAS
+ call to the gemm function. */
+
extern void matmul_`'rtype_code (rtype * const restrict retarray,
- rtype * const restrict a, rtype * const restrict b);
+ rtype * const restrict a, rtype * const restrict b, int try_blas,
+ int blas_limit, blas_call gemm);
export_proto(matmul_`'rtype_code);
void
matmul_`'rtype_code (rtype * const restrict retarray,
- rtype * const restrict a, rtype * const restrict b)
+ rtype * const restrict a, rtype * const restrict b, int try_blas,
+ int blas_limit, blas_call gemm)
{
const rtype_name * restrict abase;
const rtype_name * restrict bbase;
@@ -179,37 +189,56 @@
bbase = b->data;
dest = retarray->data;
- if (rxstride == 1 && axstride == 1 && bxstride == 1)
+#define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
+
+ if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
+ && (bxstride == 1 || bystride == 1)
+ && (xcount > POW3(blas_limit) || xcount > POW3(blas_limit)
+ || count > POW3(blas_limit)
+ || ((float) xcount) * ((float) ycount) * ((float) count)
+ > POW3(blas_limit)))
{
- const rtype_name * restrict bbase_y;
- rtype_name * restrict dest_y;
- const rtype_name * restrict abase_n;
- rtype_name bbase_yn;
+ const int m = xcount, n = ycount, k = count, ldc = rystride;
+ const rtype_name one = 1, zero = 0;
+ const int lda = (axstride == 1) ? aystride : axstride,
+ ldb = (bxstride == 1) ? bystride : bxstride;
- if (rystride == xcount)
- memset (dest, 0, (sizeof (rtype_name) * xcount * ycount));
- else
+ assert (gemm != NULL);
+ gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m, &n, &k,
+ &one, abase, &lda, bbase, &ldb, &zero, dest, &ldc, 1, 1);
+ }
+ else
+ {
+
+ if (rxstride == 1 && axstride == 1 && bxstride == 1)
{
- for (y = 0; y < ycount; y++)
- for (x = 0; x < xcount; x++)
- dest[x + y*rystride] = (rtype_name)0;
- }
+ const rtype_name * restrict bbase_y;
+ rtype_name * restrict dest_y;
+ const rtype_name * restrict abase_n;
+ rtype_name bbase_yn;
- for (y = 0; y < ycount; y++)
- {
- bbase_y = bbase + y*bystride;
- dest_y = dest + y*rystride;
- for (n = 0; n < count; n++)
+ if (rystride == xcount)
+ memset (dest, 0, (sizeof (rtype_name) * size0((array_t *) retarray)));
+ else
{
- abase_n = abase + n*aystride;
- bbase_yn = bbase_y[n];
- for (x = 0; x < xcount; x++)
+ for (y = 0; y < ycount; y++)
+ for (x = 0; x < xcount; x++)
+ dest[x + y*rystride] = (rtype_name)0;
+ }
+
+ for (y = 0; y < ycount; y++)
+ {
+ bbase_y = bbase + y*bystride;
+ dest_y = dest + y*rystride;
+ for (n = 0; n < count; n++)
{
- dest_y[x] += abase_n[x] * bbase_yn;
+ abase_n = abase + n*aystride;
+ bbase_yn = bbase_y[n];
+ for (x = 0; x < xcount; x++)
+ dest_y[x] += abase_n[x] * bbase_yn;
}
}
}
- }
else if (rxstride == 1 && aystride == 1 && bxstride == 1)
{
if (GFC_DESCRIPTOR_RANK (a) != 1)
@@ -295,6 +324,7 @@
}
}
}
+ }
}
#endif
Index: libgfortran/generated/matmul_r8.c
===================================================================
--- libgfortran/generated/matmul_r8.c (revision 117190)
+++ libgfortran/generated/matmul_r8.c (working copy)
@@ -56,18 +56,28 @@
DO I=1,M
S = 0
DO K=1,COUNT
- S = S+A(I,K)+B(K,J)
+ S = S+A(I,K)*B(K,J)
C(I,J) = S
ENDIF
*/
+typedef void (*blas_call)(char *, char *, int *, int *, int *,
+ GFC_REAL_8 *, GFC_REAL_8 *, int *, GFC_REAL_8 *,
+ int *, GFC_REAL_8 *, GFC_REAL_8 *, int *, int, int);
+
+/* If try_blas is set to a nonzero value, then the matmul function will
+ see if there is a way to perform the matrix multiplication by a BLAS
+ call to the gemm function. */
+
extern void matmul_r8 (gfc_array_r8 * const restrict retarray,
- gfc_array_r8 * const restrict a, gfc_array_r8 * const restrict b);
+ gfc_array_r8 * const restrict a, gfc_array_r8 * const restrict b, int try_blas,
+ int blas_limit, blas_call gemm);
export_proto(matmul_r8);
void
matmul_r8 (gfc_array_r8 * const restrict retarray,
- gfc_array_r8 * const restrict a, gfc_array_r8 * const restrict b)
+ gfc_array_r8 * const restrict a, gfc_array_r8 * const restrict b, int try_blas,
+ int blas_limit, blas_call gemm)
{
const GFC_REAL_8 * restrict abase;
const GFC_REAL_8 * restrict bbase;
@@ -177,37 +187,56 @@
bbase = b->data;
dest = retarray->data;
- if (rxstride == 1 && axstride == 1 && bxstride == 1)
+#define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
+
+ if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
+ && (bxstride == 1 || bystride == 1)
+ && (xcount > POW3(blas_limit) || xcount > POW3(blas_limit)
+ || count > POW3(blas_limit)
+ || ((float) xcount) * ((float) ycount) * ((float) count)
+ > POW3(blas_limit)))
{
- const GFC_REAL_8 * restrict bbase_y;
- GFC_REAL_8 * restrict dest_y;
- const GFC_REAL_8 * restrict abase_n;
- GFC_REAL_8 bbase_yn;
+ const int m = xcount, n = ycount, k = count, ldc = rystride;
+ const GFC_REAL_8 one = 1, zero = 0;
+ const int lda = (axstride == 1) ? aystride : axstride,
+ ldb = (bxstride == 1) ? bystride : bxstride;
- if (rystride == xcount)
- memset (dest, 0, (sizeof (GFC_REAL_8) * xcount * ycount));
- else
+ assert (gemm != NULL);
+ gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m, &n, &k,
+ &one, abase, &lda, bbase, &ldb, &zero, dest, &ldc, 1, 1);
+ }
+ else
+ {
+
+ if (rxstride == 1 && axstride == 1 && bxstride == 1)
{
- for (y = 0; y < ycount; y++)
- for (x = 0; x < xcount; x++)
- dest[x + y*rystride] = (GFC_REAL_8)0;
- }
+ const GFC_REAL_8 * restrict bbase_y;
+ GFC_REAL_8 * restrict dest_y;
+ const GFC_REAL_8 * restrict abase_n;
+ GFC_REAL_8 bbase_yn;
- for (y = 0; y < ycount; y++)
- {
- bbase_y = bbase + y*bystride;
- dest_y = dest + y*rystride;
- for (n = 0; n < count; n++)
+ if (rystride == xcount)
+ memset (dest, 0, (sizeof (GFC_REAL_8) * size0((array_t *) retarray)));
+ else
{
- abase_n = abase + n*aystride;
- bbase_yn = bbase_y[n];
- for (x = 0; x < xcount; x++)
+ for (y = 0; y < ycount; y++)
+ for (x = 0; x < xcount; x++)
+ dest[x + y*rystride] = (GFC_REAL_8)0;
+ }
+
+ for (y = 0; y < ycount; y++)
+ {
+ bbase_y = bbase + y*bystride;
+ dest_y = dest + y*rystride;
+ for (n = 0; n < count; n++)
{
- dest_y[x] += abase_n[x] * bbase_yn;
+ abase_n = abase + n*aystride;
+ bbase_yn = bbase_y[n];
+ for (x = 0; x < xcount; x++)
+ dest_y[x] += abase_n[x] * bbase_yn;
}
}
}
- }
else if (rxstride == 1 && aystride == 1 && bxstride == 1)
{
if (GFC_DESCRIPTOR_RANK (a) != 1)
@@ -293,6 +322,7 @@
}
}
}
+ }
}
#endif
Index: gcc/fortran/trans-expr.c
===================================================================
--- gcc/fortran/trans-expr.c (revision 117190)
+++ gcc/fortran/trans-expr.c (working copy)
@@ -1853,7 +1853,7 @@
int
gfc_conv_function_call (gfc_se * se, gfc_symbol * sym,
- gfc_actual_arglist * arg)
+ gfc_actual_arglist * arg, tree append_args)
{
gfc_interface_mapping mapping;
tree arglist;
@@ -2166,6 +2166,11 @@
/* Add the hidden string length parameters to the arguments. */
arglist = chainon (arglist, stringargs);
+ /* We may want to append extra arguments here. This is used e.g. for
+ calls to libgfortran_matmul_??, which need extra information. */
+ if (append_args != NULL_TREE)
+ arglist = chainon (arglist, append_args);
+
/* Generate the actual call. */
gfc_conv_function_val (se, sym);
/* If there are alternate return labels, function type should be
@@ -2485,7 +2490,7 @@
sym = expr->value.function.esym;
if (!sym)
sym = expr->symtree->n.sym;
- gfc_conv_function_call (se, sym, expr->value.function.actual);
+ gfc_conv_function_call (se, sym, expr->value.function.actual, NULL_TREE);
}
Index: gcc/fortran/gfortran.h
===================================================================
--- gcc/fortran/gfortran.h (revision 117190)
+++ gcc/fortran/gfortran.h (working copy)
@@ -1644,6 +1644,8 @@
int flag_f2c;
int flag_automatic;
int flag_backslash;
+ int flag_external_blas;
+ int blas_matmul_limit;
int flag_cray_pointer;
int flag_d_lines;
int flag_openmp;
Index: gcc/fortran/lang.opt
===================================================================
--- gcc/fortran/lang.opt (revision 117190)
+++ gcc/fortran/lang.opt (working copy)
@@ -89,6 +89,14 @@
Fortran
Specify that backslash in string introduces an escape character
+fexternal-blas
+Fortran
+Specify that an external BLAS library should be used for matmul calls on large-size arrays
+
+fblas-matmul-limit=
+Fortran RejectNegative Joined UInteger
+-fblas-matmul-limit=<n> Size of the smallest matrix for which matmul will use BLAS
+
fdefault-double-8
Fortran
Set the default double precision kind to an 8 byte wide type
Index: gcc/fortran/trans-stmt.c
===================================================================
--- gcc/fortran/trans-stmt.c (revision 117190)
+++ gcc/fortran/trans-stmt.c (working copy)
@@ -334,7 +334,8 @@
/* Translate the call. */
has_alternate_specifier
- = gfc_conv_function_call (&se, code->resolved_sym, code->ext.actual);
+ = gfc_conv_function_call (&se, code->resolved_sym, code->ext.actual,
+ NULL_TREE);
/* A subroutine without side-effect, by definition, does nothing! */
TREE_SIDE_EFFECTS (se.expr) = 1;
@@ -399,7 +400,8 @@
gfc_init_block (&block);
/* Add the subroutine call to the block. */
- gfc_conv_function_call (&loopse, code->resolved_sym, code->ext.actual);
+ gfc_conv_function_call (&loopse, code->resolved_sym, code->ext.actual,
+ NULL_TREE);
gfc_add_expr_to_block (&loopse.pre, loopse.expr);
gfc_add_block_to_block (&block, &loopse.pre);
Index: gcc/fortran/trans.h
===================================================================
--- gcc/fortran/trans.h (revision 117190)
+++ gcc/fortran/trans.h (working copy)
@@ -303,7 +303,8 @@
int gfc_is_intrinsic_libcall (gfc_expr *);
/* Also used to CALL subroutines. */
-int gfc_conv_function_call (gfc_se *, gfc_symbol *, gfc_actual_arglist *);
+int gfc_conv_function_call (gfc_se *, gfc_symbol *, gfc_actual_arglist *,
+ tree);
/* gfc_trans_* shouldn't call push/poplevel, use gfc_push/pop_scope */
/* Generate code for a scalar assignment. */
@@ -507,6 +508,12 @@
extern GTY(()) tree gfor_fndecl_math_exponent10;
extern GTY(()) tree gfor_fndecl_math_exponent16;
+/* BLAS functions. */
+extern GTY(()) tree gfor_fndecl_sgemm;
+extern GTY(()) tree gfor_fndecl_dgemm;
+extern GTY(()) tree gfor_fndecl_cgemm;
+extern GTY(()) tree gfor_fndecl_zgemm;
+
/* String functions. */
extern GTY(()) tree gfor_fndecl_compare_string;
extern GTY(()) tree gfor_fndecl_concat_string;
Index: gcc/fortran/trans-decl.c
===================================================================
--- gcc/fortran/trans-decl.c (revision 117190)
+++ gcc/fortran/trans-decl.c (working copy)
@@ -143,7 +143,13 @@
tree gfor_fndecl_si_kind;
tree gfor_fndecl_sr_kind;
+/* BLAS gemm functions. */
+tree gfor_fndecl_sgemm;
+tree gfor_fndecl_dgemm;
+tree gfor_fndecl_cgemm;
+tree gfor_fndecl_zgemm;
+
static void
gfc_add_decl_to_parent_function (tree decl)
{
@@ -2170,6 +2176,49 @@
gfc_int4_type_node, 1,
gfc_real16_type_node);
+ /* BLAS functions. */
+ {
+ tree pint = build_pointer_type (gfc_c_int_type_node);
+ tree ps = build_pointer_type (gfc_get_real_type (gfc_default_real_kind));
+ tree pd = build_pointer_type (gfc_get_real_type (gfc_default_double_kind));
+ tree pc = build_pointer_type (gfc_get_complex_type (gfc_default_real_kind));
+ tree pz = build_pointer_type
+ (gfc_get_complex_type (gfc_default_double_kind));
+
+ gfor_fndecl_sgemm = gfc_build_library_function_decl
+ (get_identifier
+ (gfc_option.flag_underscoring ? "sgemm_"
+ : "sgemm"),
+ void_type_node, 15, pchar_type_node,
+ pchar_type_node, pint, pint, pint, ps, ps, pint,
+ ps, pint, ps, ps, pint, gfc_c_int_type_node,
+ gfc_c_int_type_node);
+ gfor_fndecl_dgemm = gfc_build_library_function_decl
+ (get_identifier
+ (gfc_option.flag_underscoring ? "dgemm_"
+ : "dgemm"),
+ void_type_node, 15, pchar_type_node,
+ pchar_type_node, pint, pint, pint, pd, pd, pint,
+ pd, pint, pd, pd, pint, gfc_c_int_type_node,
+ gfc_c_int_type_node);
+ gfor_fndecl_cgemm = gfc_build_library_function_decl
+ (get_identifier
+ (gfc_option.flag_underscoring ? "cgemm_"
+ : "cgemm"),
+ void_type_node, 15, pchar_type_node,
+ pchar_type_node, pint, pint, pint, pc, pc, pint,
+ pc, pint, pc, pc, pint, gfc_c_int_type_node,
+ gfc_c_int_type_node);
+ gfor_fndecl_zgemm = gfc_build_library_function_decl
+ (get_identifier
+ (gfc_option.flag_underscoring ? "zgemm_"
+ : "zgemm"),
+ void_type_node, 15, pchar_type_node,
+ pchar_type_node, pint, pint, pint, pz, pz, pint,
+ pz, pint, pz, pz, pint, gfc_c_int_type_node,
+ gfc_c_int_type_node);
+ }
+
/* Other functions. */
gfor_fndecl_size0 =
gfc_build_library_function_decl (get_identifier (PREFIX("size0")),
Index: gcc/fortran/trans-intrinsic.c
===================================================================
--- gcc/fortran/trans-intrinsic.c (revision 117190)
+++ gcc/fortran/trans-intrinsic.c (working copy)
@@ -1265,6 +1265,7 @@
gfc_conv_intrinsic_funcall (gfc_se * se, gfc_expr * expr)
{
gfc_symbol *sym;
+ tree append_args;
gcc_assert (!se->ss || se->ss->expr == expr);
@@ -1274,7 +1275,54 @@
gcc_assert (expr->rank == 0);
sym = gfc_get_symbol_for_expr (expr);
- gfc_conv_function_call (se, sym, expr->value.function.actual);
+
+ /* Calls to libgfortran_matmul need extra arguments appended, so that
+ the library can call the BLAS ?gemm functions when required and possible. */
+ append_args = NULL_TREE;
+ if (expr->value.function.isym->generic_id == GFC_ISYM_MATMUL
+ && sym->ts.type != BT_LOGICAL)
+ {
+ tree cint = gfc_get_int_type (gfc_c_int_kind);
+
+ if (gfc_option.flag_external_blas
+ && (sym->ts.type == BT_REAL || sym->ts.type == BT_COMPLEX)
+ && (sym->ts.kind == gfc_default_real_kind
+ || sym->ts.kind == gfc_default_double_kind))
+ {
+ tree gemm_fndecl;
+
+ if (sym->ts.type == BT_REAL)
+ {
+ if (sym->ts.kind == gfc_default_real_kind)
+ gemm_fndecl = gfor_fndecl_sgemm;
+ else
+ gemm_fndecl = gfor_fndecl_dgemm;
+ }
+ else
+ {
+ if (sym->ts.kind == gfc_default_real_kind)
+ gemm_fndecl = gfor_fndecl_cgemm;
+ else
+ gemm_fndecl = gfor_fndecl_zgemm;
+ }
+
+ append_args = gfc_chainon_list (NULL_TREE, build_int_cst (cint, 1));
+ append_args = gfc_chainon_list
+ (append_args, build_int_cst
+ (cint, gfc_option.blas_matmul_limit));
+ append_args = gfc_chainon_list (append_args,
+ gfc_build_addr_expr (NULL_TREE,
+ gemm_fndecl));
+ }
+ else
+ {
+ append_args = gfc_chainon_list (NULL_TREE, build_int_cst (cint, 0));
+ append_args = gfc_chainon_list (append_args, build_int_cst (cint, 0));
+ append_args = gfc_chainon_list (append_args, null_pointer_node);
+ }
+ }
+
+ gfc_conv_function_call (se, sym, expr->value.function.actual, append_args);
gfc_free (sym);
}
Index: gcc/fortran/options.c
===================================================================
--- gcc/fortran/options.c (revision 117190)
+++ gcc/fortran/options.c (working copy)
@@ -79,6 +79,8 @@
gfc_option.flag_preprocessed = 0;
gfc_option.flag_automatic = 1;
gfc_option.flag_backslash = 1;
+ gfc_option.flag_external_blas = 0;
+ gfc_option.blas_matmul_limit = 32;
gfc_option.flag_cray_pointer = 0;
gfc_option.flag_d_lines = -1;
gfc_option.flag_openmp = 0;
@@ -454,6 +456,14 @@
gfc_option.flag_dollar_ok = value;
break;
+ case OPT_fexternal_blas:
+ gfc_option.flag_external_blas = value;
+ break;
+
+ case OPT_fblas_matmul_limit_:
+ gfc_option.blas_matmul_limit = value;
+ break;
+
case OPT_fd_lines_as_code:
gfc_option.flag_d_lines = 1;
break;