This is the mail archive of the gcc-patches@gcc.gnu.org mailing list for the GCC project.


Index Nav: [Date Index] [Subject Index] [Author Index] [Thread Index]
Message Nav: [Date Prev] [Date Next] [Thread Prev] [Thread Next]
Other format: [Raw text]

[gfortran,patch] MATMUL using BLAS gemm routines


Hi all,

I've come up with another idea to implement the optional use of BLAS
gemm routines for gfortran MATMULtiplication. I still have some
testing and fine-tuning to do (using lots of different vendor BLAS
libraries, like IBM's ESSL, Intel MKL, etc.), but here's the current
version of the patch anyhow, along with some explanations of how it
works. I think this approach is simple, easy to use and understand,
and opens the way for modification of other array routines.

Here's how it works: the library function libgfortran_matmul_r8 (I
included this generated file in the diff, because it's easier to read
than m4 for some of us :) gets 3 new arguments: int try_blas, which
tells it if a BLAS call should be made if possible; int blas_limit,
the size above which BLAS will be used; and gemm, a pointer to the
BLAS function to be called.

The front-end function gfc_conv_function_call() gains an extra
argument, append_args, which is a list of tree args to be appended at
the end of the function call. All calls to gfc_conv_function_call() in
the front-end set append_args = NULL_TREE. The one exception is
gfc_conv_intrinsic_funcall() which, if it's translating a MATMUL call
(recognized by its expr->value.function.isym->generic_id), appends the
three arguments we talked about earlier (try_blas, blas_limit and
function pointer).

One note: I would have liked to special case the MATMUL calls right
inside gfc_conv_function_call(), but I didn't find a way to recognize
them there. The good thing about the new append_args mechanism is that
we could use it for other things, like passing flags about runtime
bounds-checking.

Another note: it is assumed that sgemm is for default real kind, dgemm
for double precision real kind, and likewise for complex routines. If
you feel I should hardcode sgemm to real(kind=4) and dgemm to
real(kind=8), or provide a flag for the user to specify the mapping,
please say so.

I'm open to all suggestions before I finalize the patch and submit it
formally next Saturday. After submission, I unfortunately won't have
time to change it significantly any more.

Thanks for reading this long mail,
FX
Index: libgfortran/m4/matmul.m4
===================================================================
--- libgfortran/m4/matmul.m4	(revision 117190)
+++ libgfortran/m4/matmul.m4	(working copy)
@@ -57,18 +57,28 @@
        DO I=1,M
          S = 0
          DO K=1,COUNT
-           S = S+A(I,K)+B(K,J)
+           S = S+A(I,K)*B(K,J)
          C(I,J) = S
    ENDIF
 */
 
+typedef void (*blas_call)(char *, char *, int *, int *, int *,
+			  rtype_name *, rtype_name *, int *, rtype_name *,
+			  int *, rtype_name *, rtype_name *, int *);
+
+/* If try_blas is set to a nonzero value, then the matmul function will
+   see if there is a way to perform the matrix multiplication by a BLAS
+   call to the gemm function.  */
+
 extern void matmul_`'rtype_code (rtype * const restrict retarray, 
-	rtype * const restrict a, rtype * const restrict b);
+	rtype * const restrict a, rtype * const restrict b, int try_blas,
+        int blas_limit, blas_call gemm);
 export_proto(matmul_`'rtype_code);
 
 void
 matmul_`'rtype_code (rtype * const restrict retarray, 
-	rtype * const restrict a, rtype * const restrict b)
+	rtype * const restrict a, rtype * const restrict b, int try_blas,
+	int blas_limit, blas_call gemm)
 {
   const rtype_name * restrict abase;
   const rtype_name * restrict bbase;
@@ -179,37 +189,56 @@
   bbase = b->data;
   dest = retarray->data;
 
-  if (rxstride == 1 && axstride == 1 && bxstride == 1)
+#define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
+
+  if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
+      && (bxstride == 1 || bystride == 1)
+      && (xcount > POW3(blas_limit) || xcount > POW3(blas_limit)
+	  || count > POW3(blas_limit)
+	  || ((float) xcount) * ((float) ycount) * ((float) count)
+	     > POW3(blas_limit)))
     {
-      const rtype_name * restrict bbase_y;
-      rtype_name * restrict dest_y;
-      const rtype_name * restrict abase_n;
-      rtype_name bbase_yn;
+      const int m = xcount, n = ycount, k = count, ldc = rystride;
+      const double one = 1, zero = 0;
+      const int lda = (axstride == 1) ? aystride : axstride,
+		ldb = (bxstride == 1) ? bystride : bxstride;
 
-      if (rystride == xcount)
-	memset (dest, 0, (sizeof (rtype_name) * xcount * ycount));
-      else
+      assert (gemm != NULL);
+      gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m, &n, &k,
+	    &one, abase, &lda, bbase, &ldb, &zero, dest, &ldc);
+    }
+  else
+    {
+
+      if (rxstride == 1 && axstride == 1 && bxstride == 1)
 	{
-	  for (y = 0; y < ycount; y++)
-	    for (x = 0; x < xcount; x++)
-	      dest[x + y*rystride] = (rtype_name)0;
-	}
+	  const rtype_name * restrict bbase_y;
+	  rtype_name * restrict dest_y;
+	  const rtype_name * restrict abase_n;
+	  rtype_name bbase_yn;
 
-      for (y = 0; y < ycount; y++)
-	{
-	  bbase_y = bbase + y*bystride;
-	  dest_y = dest + y*rystride;
-	  for (n = 0; n < count; n++)
+	  if (rystride == xcount)
+	    memset (dest, 0, (sizeof (rtype_name) * size0((array_t *) retarray)));
+	  else
 	    {
-	      abase_n = abase + n*aystride;
-	      bbase_yn = bbase_y[n];
-	      for (x = 0; x < xcount; x++)
+	      for (y = 0; y < ycount; y++)
+		for (x = 0; x < xcount; x++)
+		  dest[x + y*rystride] = (rtype_name)0;
+	    }
+
+	  for (y = 0; y < ycount; y++)
+	    {
+	      bbase_y = bbase + y*bystride;
+	      dest_y = dest + y*rystride;
+	      for (n = 0; n < count; n++)
 		{
-		  dest_y[x] += abase_n[x] * bbase_yn;
+		  abase_n = abase + n*aystride;
+		  bbase_yn = bbase_y[n];
+		  for (x = 0; x < xcount; x++)
+		    dest_y[x] += abase_n[x] * bbase_yn;
 		}
 	    }
 	}
-    }
   else if (rxstride == 1 && aystride == 1 && bxstride == 1)
     {
       if (GFC_DESCRIPTOR_RANK (a) != 1)
@@ -295,6 +324,7 @@
 	    }
 	}
     }
+  }
 }
 
 #endif
Index: libgfortran/generated/matmul_r8.c
===================================================================
--- libgfortran/generated/matmul_r8.c	(revision 117190)
+++ libgfortran/generated/matmul_r8.c	(working copy)
@@ -56,18 +56,28 @@
        DO I=1,M
          S = 0
          DO K=1,COUNT
-           S = S+A(I,K)+B(K,J)
+           S = S+A(I,K)*B(K,J)
          C(I,J) = S
    ENDIF
 */
 
+typedef void (*blas_call)(char *, char *, int *, int *, int *,
+			  GFC_REAL_8 *, GFC_REAL_8 *, int *, GFC_REAL_8 *,
+			  int *, GFC_REAL_8 *, GFC_REAL_8 *, int *);
+
+/* If try_blas is set to a nonzero value, then the matmul function will
+   see if there is a way to perform the matrix multiplication by a BLAS
+   call to the gemm function.  */
+
 extern void matmul_r8 (gfc_array_r8 * const restrict retarray, 
-	gfc_array_r8 * const restrict a, gfc_array_r8 * const restrict b);
+	gfc_array_r8 * const restrict a, gfc_array_r8 * const restrict b, int try_blas,
+        int blas_limit, blas_call gemm);
 export_proto(matmul_r8);
 
 void
 matmul_r8 (gfc_array_r8 * const restrict retarray, 
-	gfc_array_r8 * const restrict a, gfc_array_r8 * const restrict b)
+	gfc_array_r8 * const restrict a, gfc_array_r8 * const restrict b, int try_blas,
+	int blas_limit, blas_call gemm)
 {
   const GFC_REAL_8 * restrict abase;
   const GFC_REAL_8 * restrict bbase;
@@ -177,37 +187,56 @@
   bbase = b->data;
   dest = retarray->data;
 
-  if (rxstride == 1 && axstride == 1 && bxstride == 1)
+#define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
+
+  if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
+      && (bxstride == 1 || bystride == 1)
+      && (xcount > POW3(blas_limit) || xcount > POW3(blas_limit)
+	  || count > POW3(blas_limit)
+	  || ((float) xcount) * ((float) ycount) * ((float) count)
+	     > POW3(blas_limit)))
     {
-      const GFC_REAL_8 * restrict bbase_y;
-      GFC_REAL_8 * restrict dest_y;
-      const GFC_REAL_8 * restrict abase_n;
-      GFC_REAL_8 bbase_yn;
+      const int m = xcount, n = ycount, k = count, ldc = rystride;
+      const double one = 1, zero = 0;
+      const int lda = (axstride == 1) ? aystride : axstride,
+		ldb = (bxstride == 1) ? bystride : bxstride;
 
-      if (rystride == xcount)
-	memset (dest, 0, (sizeof (GFC_REAL_8) * xcount * ycount));
-      else
+      assert (gemm != NULL);
+      gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m, &n, &k,
+	    &one, abase, &lda, bbase, &ldb, &zero, dest, &ldc);
+    }
+  else
+    {
+
+      if (rxstride == 1 && axstride == 1 && bxstride == 1)
 	{
-	  for (y = 0; y < ycount; y++)
-	    for (x = 0; x < xcount; x++)
-	      dest[x + y*rystride] = (GFC_REAL_8)0;
-	}
+	  const GFC_REAL_8 * restrict bbase_y;
+	  GFC_REAL_8 * restrict dest_y;
+	  const GFC_REAL_8 * restrict abase_n;
+	  GFC_REAL_8 bbase_yn;
 
-      for (y = 0; y < ycount; y++)
-	{
-	  bbase_y = bbase + y*bystride;
-	  dest_y = dest + y*rystride;
-	  for (n = 0; n < count; n++)
+	  if (rystride == xcount)
+	    memset (dest, 0, (sizeof (GFC_REAL_8) * size0((array_t *) retarray)));
+	  else
 	    {
-	      abase_n = abase + n*aystride;
-	      bbase_yn = bbase_y[n];
-	      for (x = 0; x < xcount; x++)
+	      for (y = 0; y < ycount; y++)
+		for (x = 0; x < xcount; x++)
+		  dest[x + y*rystride] = (GFC_REAL_8)0;
+	    }
+
+	  for (y = 0; y < ycount; y++)
+	    {
+	      bbase_y = bbase + y*bystride;
+	      dest_y = dest + y*rystride;
+	      for (n = 0; n < count; n++)
 		{
-		  dest_y[x] += abase_n[x] * bbase_yn;
+		  abase_n = abase + n*aystride;
+		  bbase_yn = bbase_y[n];
+		  for (x = 0; x < xcount; x++)
+		    dest_y[x] += abase_n[x] * bbase_yn;
 		}
 	    }
 	}
-    }
   else if (rxstride == 1 && aystride == 1 && bxstride == 1)
     {
       if (GFC_DESCRIPTOR_RANK (a) != 1)
@@ -293,6 +322,7 @@
 	    }
 	}
     }
+  }
 }
 
 #endif
Index: gcc/fortran/trans-expr.c
===================================================================
--- gcc/fortran/trans-expr.c	(revision 117190)
+++ gcc/fortran/trans-expr.c	(working copy)
@@ -1853,7 +1853,7 @@
 
 int
 gfc_conv_function_call (gfc_se * se, gfc_symbol * sym,
-			gfc_actual_arglist * arg)
+			gfc_actual_arglist * arg, tree append_args)
 {
   gfc_interface_mapping mapping;
   tree arglist;
@@ -2166,6 +2166,11 @@
   /* Add the hidden string length parameters to the arguments.  */
   arglist = chainon (arglist, stringargs);
 
+  /* We may want to append extra arguments here.  This is used e.g. for
+     calls to libgfortran_matmul_??, which need extra information.  */
+  if (append_args != NULL_TREE)
+    arglist = chainon (arglist, append_args);
+
   /* Generate the actual call.  */
   gfc_conv_function_val (se, sym);
   /* If there are alternate return labels, function type should be
@@ -2485,7 +2490,7 @@
   sym = expr->value.function.esym;
   if (!sym)
     sym = expr->symtree->n.sym;
-  gfc_conv_function_call (se, sym, expr->value.function.actual);
+  gfc_conv_function_call (se, sym, expr->value.function.actual, NULL_TREE);
 }
 
 
Index: gcc/fortran/gfortran.h
===================================================================
--- gcc/fortran/gfortran.h	(revision 117190)
+++ gcc/fortran/gfortran.h	(working copy)
@@ -1644,6 +1644,8 @@
   int flag_f2c;
   int flag_automatic;
   int flag_backslash;
+  int flag_external_blas;
+  int blas_matmul_limit;
   int flag_cray_pointer;
   int flag_d_lines;
   int flag_openmp;
Index: gcc/fortran/lang.opt
===================================================================
--- gcc/fortran/lang.opt	(revision 117190)
+++ gcc/fortran/lang.opt	(working copy)
@@ -89,6 +89,14 @@
 Fortran
 Specify that backslash in string introduces an escape character
 
+fexternal-blas
+Fortran
+Specify that an external BLAS library should be used for matmul calls on large-size arrays
+
+fblas-matmul-limit=
+Fortran RejectNegative Joined UInteger
+-fblas-matmul-limit=<n>        Size of the smallest matrix for which matmul will use BLAS
+
 fdefault-double-8
 Fortran
 Set the default double precision kind to an 8 byte wide type
Index: gcc/fortran/trans-stmt.c
===================================================================
--- gcc/fortran/trans-stmt.c	(revision 117190)
+++ gcc/fortran/trans-stmt.c	(working copy)
@@ -334,7 +334,8 @@
 
       /* Translate the call.  */
       has_alternate_specifier
-	= gfc_conv_function_call (&se, code->resolved_sym, code->ext.actual);
+	= gfc_conv_function_call (&se, code->resolved_sym, code->ext.actual,
+				  NULL_TREE);
 
       /* A subroutine without side-effect, by definition, does nothing!  */
       TREE_SIDE_EFFECTS (se.expr) = 1;
@@ -399,7 +400,8 @@
       gfc_init_block (&block);
 
       /* Add the subroutine call to the block.  */
-      gfc_conv_function_call (&loopse, code->resolved_sym, code->ext.actual);
+      gfc_conv_function_call (&loopse, code->resolved_sym, code->ext.actual,
+			      NULL_TREE);
       gfc_add_expr_to_block (&loopse.pre, loopse.expr);
 
       gfc_add_block_to_block (&block, &loopse.pre);
Index: gcc/fortran/trans.h
===================================================================
--- gcc/fortran/trans.h	(revision 117190)
+++ gcc/fortran/trans.h	(working copy)
@@ -303,7 +303,8 @@
 int gfc_is_intrinsic_libcall (gfc_expr *);
 
 /* Also used to CALL subroutines.  */
-int gfc_conv_function_call (gfc_se *, gfc_symbol *, gfc_actual_arglist *);
+int gfc_conv_function_call (gfc_se *, gfc_symbol *, gfc_actual_arglist *,
+			    tree);
 /* gfc_trans_* shouldn't call push/poplevel, use gfc_push/pop_scope */
 
 /* Generate code for a scalar assignment.  */
@@ -507,6 +508,12 @@
 extern GTY(()) tree gfor_fndecl_math_exponent10;
 extern GTY(()) tree gfor_fndecl_math_exponent16;
 
+/* BLAS functions.  */
+extern GTY(()) tree gfor_fndecl_sgemm;
+extern GTY(()) tree gfor_fndecl_dgemm;
+extern GTY(()) tree gfor_fndecl_cgemm;
+extern GTY(()) tree gfor_fndecl_zgemm;
+
 /* String functions.  */
 extern GTY(()) tree gfor_fndecl_compare_string;
 extern GTY(()) tree gfor_fndecl_concat_string;
Index: gcc/fortran/trans-decl.c
===================================================================
--- gcc/fortran/trans-decl.c	(revision 117190)
+++ gcc/fortran/trans-decl.c	(working copy)
@@ -143,7 +143,13 @@
 tree gfor_fndecl_si_kind;
 tree gfor_fndecl_sr_kind;
 
+/* BLAS gemm functions.  */
+tree gfor_fndecl_sgemm;
+tree gfor_fndecl_dgemm;
+tree gfor_fndecl_cgemm;
+tree gfor_fndecl_zgemm;
 
+
 static void
 gfc_add_decl_to_parent_function (tree decl)
 {
@@ -2170,6 +2176,45 @@
 				       gfc_int4_type_node, 1,
 				       gfc_real16_type_node);
 
+  /* BLAS functions.  */
+  {
+    tree pint = build_pointer_type (gfc_c_int_type_node);
+    tree ps = build_pointer_type (gfc_get_real_type (gfc_default_real_kind));
+    tree pd = build_pointer_type (gfc_get_real_type (gfc_default_double_kind));
+    tree pc = build_pointer_type (gfc_get_complex_type (gfc_default_real_kind));
+    tree pz = build_pointer_type
+		(gfc_get_complex_type (gfc_default_double_kind));
+
+    gfor_fndecl_sgemm = gfc_build_library_function_decl
+			  (get_identifier
+			     (gfc_option.flag_underscoring ? "sgemm_"
+							   : "sgemm"),
+			   void_type_node, 13,
+			   pchar_type_node, pchar_type_node, pint, pint,
+			   pint, ps, ps, pint, ps, pint, ps, ps, pint);
+    gfor_fndecl_dgemm = gfc_build_library_function_decl
+			  (get_identifier
+			     (gfc_option.flag_underscoring ? "dgemm_"
+							   : "dgemm"),
+			   void_type_node, 13,
+			   pchar_type_node, pchar_type_node, pint, pint,
+			   pint, pd, pd, pint, pd, pint, pd, pd, pint);
+    gfor_fndecl_cgemm = gfc_build_library_function_decl
+			  (get_identifier
+			     (gfc_option.flag_underscoring ? "cgemm_"
+							   : "cgemm"),
+			   void_type_node, 13,
+			   pchar_type_node, pchar_type_node, pint, pint,
+			   pint, pc, pc, pint, pc, pint, pc, pc, pint);
+    gfor_fndecl_zgemm = gfc_build_library_function_decl
+			  (get_identifier
+			     (gfc_option.flag_underscoring ? "zgemm_"
+							   : "zgemm"),
+			   void_type_node, 13,
+			   pchar_type_node, pchar_type_node, pint, pint,
+			   pint, pz, pz, pint, pz, pint, pz, pz, pint);
+  }
+
   /* Other functions.  */
   gfor_fndecl_size0 =
     gfc_build_library_function_decl (get_identifier (PREFIX("size0")),
Index: gcc/fortran/trans-intrinsic.c
===================================================================
--- gcc/fortran/trans-intrinsic.c	(revision 117190)
+++ gcc/fortran/trans-intrinsic.c	(working copy)
@@ -1265,6 +1265,7 @@
 gfc_conv_intrinsic_funcall (gfc_se * se, gfc_expr * expr)
 {
   gfc_symbol *sym;
+  tree append_args;
 
   gcc_assert (!se->ss || se->ss->expr == expr);
 
@@ -1274,7 +1275,54 @@
     gcc_assert (expr->rank == 0);
 
   sym = gfc_get_symbol_for_expr (expr);
-  gfc_conv_function_call (se, sym, expr->value.function.actual);
+
+  /* Calls to libgfortran_matmul need to be appended special arguments,
+     to be able to call the BLAS ?gemm functions if required and possible.  */
+  append_args = NULL_TREE;
+  if (expr->value.function.isym->generic_id == GFC_ISYM_MATMUL
+      && sym->ts.type != BT_LOGICAL)
+    {
+      tree cint = gfc_get_int_type (gfc_c_int_kind);
+
+      if (gfc_option.flag_external_blas
+	  && (sym->ts.type == BT_REAL || sym->ts.type == BT_COMPLEX)
+	  && (sym->ts.kind == gfc_default_real_kind
+	      || sym->ts.kind == gfc_default_double_kind))
+	{
+	  tree gemm_fndecl;
+
+	  if (sym->ts.type == BT_REAL)
+	    {
+	      if (sym->ts.kind == gfc_default_real_kind)
+		gemm_fndecl = gfor_fndecl_sgemm;
+	      else
+		gemm_fndecl = gfor_fndecl_dgemm;
+	    }
+	  else
+	    {
+	      if (sym->ts.kind == gfc_default_real_kind)
+		gemm_fndecl = gfor_fndecl_cgemm;
+	      else
+		gemm_fndecl = gfor_fndecl_zgemm;
+	    }
+
+	  append_args = gfc_chainon_list (NULL_TREE, build_int_cst (cint, 1));
+	  append_args = gfc_chainon_list
+			  (append_args, build_int_cst
+					  (cint, gfc_option.blas_matmul_limit));
+	  append_args = gfc_chainon_list (append_args,
+					  gfc_build_addr_expr (NULL_TREE,
+							       gemm_fndecl));
+	}
+      else
+	{
+	  append_args = gfc_chainon_list (NULL_TREE, build_int_cst (cint, 0));
+	  append_args = gfc_chainon_list (append_args, build_int_cst (cint, 0));
+	  append_args = gfc_chainon_list (append_args, null_pointer_node);
+	}
+    }
+
+  gfc_conv_function_call (se, sym, expr->value.function.actual, append_args);
   gfc_free (sym);
 }
 
Index: gcc/fortran/options.c
===================================================================
--- gcc/fortran/options.c	(revision 117190)
+++ gcc/fortran/options.c	(working copy)
@@ -79,6 +79,8 @@
   gfc_option.flag_preprocessed = 0;
   gfc_option.flag_automatic = 1;
   gfc_option.flag_backslash = 1;
+  gfc_option.flag_external_blas = 0;
+  gfc_option.blas_matmul_limit = 32;
   gfc_option.flag_cray_pointer = 0;
   gfc_option.flag_d_lines = -1;
   gfc_option.flag_openmp = 0;
@@ -454,6 +456,14 @@
       gfc_option.flag_dollar_ok = value;
       break;
 
+    case OPT_fexternal_blas:
+      gfc_option.flag_external_blas = value;
+      break;
+
+    case OPT_fblas_matmul_limit_:
+      gfc_option.blas_matmul_limit = value;
+      break;
+
     case OPT_fd_lines_as_code:
       gfc_option.flag_d_lines = 1;
       break;

Index Nav: [Date Index] [Subject Index] [Author Index] [Thread Index]
Message Nav: [Date Prev] [Date Next] [Thread Prev] [Thread Next]