This is the mail archive of the gcc-patches@gcc.gnu.org mailing list for the GCC project.


Index Nav: [Date Index] [Subject Index] [Author Index] [Thread Index]
Message Nav: [Date Prev] [Date Next] [Thread Prev] [Thread Next]
Other format: [Raw text]

Re: [gfortran,patch] MATMUL using BLAS gemm routines


And here is the real patch, and not an old version of it.

Sorry,
FX
Index: libgfortran/m4/matmul.m4
===================================================================
--- libgfortran/m4/matmul.m4	(revision 117190)
+++ libgfortran/m4/matmul.m4	(working copy)
@@ -57,18 +57,34 @@
        DO I=1,M
          S = 0
          DO K=1,COUNT
-           S = S+A(I,K)+B(K,J)
+           S = S+A(I,K)*B(K,J)
          C(I,J) = S
    ENDIF
 */
 
+/* Prototype of the BLAS ?gemm routines.  Following the Fortran calling
+   convention every scalar is passed by reference, and the two trailing
+   int arguments are the hidden lengths of the TRANSA and TRANSB strings.  */
+typedef void (*blas_call)(const char *, const char *, const int *, const int *,
+			  const int *, const rtype_name *, const rtype_name *,
+			  const int *, const rtype_name *, const int *,
+			  const rtype_name *, rtype_name *, const int *,
+			  int, int);
+
+/* If try_blas is set to a nonzero value, then the matmul function will
+   see if there is a way to perform the matrix multiplication by a BLAS
+   call to the gemm function.  This is only done when the estimated work
+   xcount * ycount * count exceeds blas_limit cubed.  */
+
 extern void matmul_`'rtype_code (rtype * const restrict retarray, 
-	rtype * const restrict a, rtype * const restrict b);
+	rtype * const restrict a, rtype * const restrict b, int try_blas,
+	int blas_limit, blas_call gemm);
 export_proto(matmul_`'rtype_code);
 
 void
 matmul_`'rtype_code (rtype * const restrict retarray, 
-	rtype * const restrict a, rtype * const restrict b)
+	rtype * const restrict a, rtype * const restrict b, int try_blas,
+	int blas_limit, blas_call gemm)
 {
   const rtype_name * restrict abase;
   const rtype_name * restrict bbase;
@@ -179,37 +195,58 @@
   bbase = b->data;
   dest = retarray->data;

-  if (rxstride == 1 && axstride == 1 && bxstride == 1)
+#define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
+
+  /* Hand the product to BLAS only when the estimated work
+     xcount * ycount * count exceeds blas_limit cubed.  Testing only the
+     product keeps operands with a zero extent on the inline path, where
+     ?gemm might otherwise be given invalid leading dimensions.  */
+  if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
+      && (bxstride == 1 || bystride == 1)
+      && (((float) xcount) * ((float) ycount) * ((float) count)
+	  > POW3(blas_limit)))
     {
-      const rtype_name * restrict bbase_y;
-      rtype_name * restrict dest_y;
-      const rtype_name * restrict abase_n;
-      rtype_name bbase_yn;
+      const int m = xcount, n = ycount, k = count, ldc = rystride;
+      const rtype_name one = 1, zero = 0;
+      const int lda = (axstride == 1) ? aystride : axstride,
+		ldb = (bxstride == 1) ? bystride : bxstride;

-      if (rystride == xcount)
-	memset (dest, 0, (sizeof (rtype_name) * xcount * ycount));
-      else
+      assert (gemm != NULL);
+      gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m, &n, &k,
+	    &one, abase, &lda, bbase, &ldb, &zero, dest, &ldc, 1, 1);
+    }
+  else
+    {
+
+      if (rxstride == 1 && axstride == 1 && bxstride == 1)
	{
-	  for (y = 0; y < ycount; y++)
-	    for (x = 0; x < xcount; x++)
-	      dest[x + y*rystride] = (rtype_name)0;
-	}
+	  const rtype_name * restrict bbase_y;
+	  rtype_name * restrict dest_y;
+	  const rtype_name * restrict abase_n;
+	  rtype_name bbase_yn;

-      for (y = 0; y < ycount; y++)
-	{
-	  bbase_y = bbase + y*bystride;
-	  dest_y = dest + y*rystride;
-	  for (n = 0; n < count; n++)
+	  if (rystride == xcount)
+	    memset (dest, 0, (sizeof (rtype_name) * size0((array_t *) retarray)));
+	  else
	    {
-	      abase_n = abase + n*aystride;
-	      bbase_yn = bbase_y[n];
-	      for (x = 0; x < xcount; x++)
+	      for (y = 0; y < ycount; y++)
+		for (x = 0; x < xcount; x++)
+		  dest[x + y*rystride] = (rtype_name)0;
+	    }
+
+	  for (y = 0; y < ycount; y++)
+	    {
+	      bbase_y = bbase + y*bystride;
+	      dest_y = dest + y*rystride;
+	      for (n = 0; n < count; n++)
		{
-		  dest_y[x] += abase_n[x] * bbase_yn;
+		  abase_n = abase + n*aystride;
+		  bbase_yn = bbase_y[n];
+		  for (x = 0; x < xcount; x++)
+		    dest_y[x] += abase_n[x] * bbase_yn;
		}
	    }
	}
-    }
   else if (rxstride == 1 && aystride == 1 && bxstride == 1)
     {
       if (GFC_DESCRIPTOR_RANK (a) != 1)
@@ -295,6 +332,7 @@
	    }
	}
     }
+    }
 }

 #endif
Index: libgfortran/generated/matmul_r8.c
===================================================================
--- libgfortran/generated/matmul_r8.c	(revision 117190)
+++ libgfortran/generated/matmul_r8.c	(working copy)
@@ -56,18 +56,34 @@
        DO I=1,M
          S = 0
          DO K=1,COUNT
-           S = S+A(I,K)+B(K,J)
+           S = S+A(I,K)*B(K,J)
          C(I,J) = S
    ENDIF
 */
 
+/* Prototype of the BLAS ?gemm routines.  Following the Fortran calling
+   convention every scalar is passed by reference, and the two trailing
+   int arguments are the hidden lengths of the TRANSA and TRANSB strings.  */
+typedef void (*blas_call)(const char *, const char *, const int *, const int *,
+			  const int *, const GFC_REAL_8 *, const GFC_REAL_8 *,
+			  const int *, const GFC_REAL_8 *, const int *,
+			  const GFC_REAL_8 *, GFC_REAL_8 *, const int *,
+			  int, int);
+
+/* If try_blas is set to a nonzero value, then the matmul function will
+   see if there is a way to perform the matrix multiplication by a BLAS
+   call to the gemm function.  This is only done when the estimated work
+   xcount * ycount * count exceeds blas_limit cubed.  */
+
 extern void matmul_r8 (gfc_array_r8 * const restrict retarray, 
-	gfc_array_r8 * const restrict a, gfc_array_r8 * const restrict b);
+	gfc_array_r8 * const restrict a, gfc_array_r8 * const restrict b, int try_blas,
+	int blas_limit, blas_call gemm);
 export_proto(matmul_r8);
 
 void
 matmul_r8 (gfc_array_r8 * const restrict retarray, 
-	gfc_array_r8 * const restrict a, gfc_array_r8 * const restrict b)
+	gfc_array_r8 * const restrict a, gfc_array_r8 * const restrict b, int try_blas,
+	int blas_limit, blas_call gemm)
 {
   const GFC_REAL_8 * restrict abase;
   const GFC_REAL_8 * restrict bbase;
@@ -177,37 +193,58 @@
   bbase = b->data;
   dest = retarray->data;

-  if (rxstride == 1 && axstride == 1 && bxstride == 1)
+#define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
+
+  /* Hand the product to BLAS only when the estimated work
+     xcount * ycount * count exceeds blas_limit cubed.  Testing only the
+     product keeps operands with a zero extent on the inline path, where
+     ?gemm might otherwise be given invalid leading dimensions.  */
+  if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
+      && (bxstride == 1 || bystride == 1)
+      && (((float) xcount) * ((float) ycount) * ((float) count)
+	  > POW3(blas_limit)))
     {
-      const GFC_REAL_8 * restrict bbase_y;
-      GFC_REAL_8 * restrict dest_y;
-      const GFC_REAL_8 * restrict abase_n;
-      GFC_REAL_8 bbase_yn;
+      const int m = xcount, n = ycount, k = count, ldc = rystride;
+      const GFC_REAL_8 one = 1, zero = 0;
+      const int lda = (axstride == 1) ? aystride : axstride,
+		ldb = (bxstride == 1) ? bystride : bxstride;

-      if (rystride == xcount)
-	memset (dest, 0, (sizeof (GFC_REAL_8) * xcount * ycount));
-      else
+      assert (gemm != NULL);
+      gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m, &n, &k,
+	    &one, abase, &lda, bbase, &ldb, &zero, dest, &ldc, 1, 1);
+    }
+  else
+    {
+
+      if (rxstride == 1 && axstride == 1 && bxstride == 1)
	{
-	  for (y = 0; y < ycount; y++)
-	    for (x = 0; x < xcount; x++)
-	      dest[x + y*rystride] = (GFC_REAL_8)0;
-	}
+	  const GFC_REAL_8 * restrict bbase_y;
+	  GFC_REAL_8 * restrict dest_y;
+	  const GFC_REAL_8 * restrict abase_n;
+	  GFC_REAL_8 bbase_yn;

-      for (y = 0; y < ycount; y++)
-	{
-	  bbase_y = bbase + y*bystride;
-	  dest_y = dest + y*rystride;
-	  for (n = 0; n < count; n++)
+	  if (rystride == xcount)
+	    memset (dest, 0, (sizeof (GFC_REAL_8) * size0((array_t *) retarray)));
+	  else
	    {
-	      abase_n = abase + n*aystride;
-	      bbase_yn = bbase_y[n];
-	      for (x = 0; x < xcount; x++)
+	      for (y = 0; y < ycount; y++)
+		for (x = 0; x < xcount; x++)
+		  dest[x + y*rystride] = (GFC_REAL_8)0;
+	    }
+
+	  for (y = 0; y < ycount; y++)
+	    {
+	      bbase_y = bbase + y*bystride;
+	      dest_y = dest + y*rystride;
+	      for (n = 0; n < count; n++)
		{
-		  dest_y[x] += abase_n[x] * bbase_yn;
+		  abase_n = abase + n*aystride;
+		  bbase_yn = bbase_y[n];
+		  for (x = 0; x < xcount; x++)
+		    dest_y[x] += abase_n[x] * bbase_yn;
		}
	    }
	}
-    }
   else if (rxstride == 1 && aystride == 1 && bxstride == 1)
     {
       if (GFC_DESCRIPTOR_RANK (a) != 1)
@@ -293,6 +330,7 @@
	    }
	}
     }
+    }
 }

 #endif
Index: gcc/fortran/trans-expr.c
===================================================================
--- gcc/fortran/trans-expr.c	(revision 117190)
+++ gcc/fortran/trans-expr.c	(working copy)
@@ -1853,7 +1853,7 @@
 
 int
 gfc_conv_function_call (gfc_se * se, gfc_symbol * sym,
-			gfc_actual_arglist * arg)
+			gfc_actual_arglist * arg, tree append_args)
 {
   gfc_interface_mapping mapping;
   tree arglist;
@@ -2166,6 +2166,11 @@
   /* Add the hidden string length parameters to the arguments.  */
   arglist = chainon (arglist, stringargs);
 
+  /* We may want to append extra arguments here.  This is used e.g. for
+     calls to libgfortran_matmul_??, which need extra information.  */
+  if (append_args != NULL_TREE)
+    arglist = chainon (arglist, append_args);
+
   /* Generate the actual call.  */
   gfc_conv_function_val (se, sym);
   /* If there are alternate return labels, function type should be
@@ -2485,7 +2490,7 @@
   sym = expr->value.function.esym;
   if (!sym)
     sym = expr->symtree->n.sym;
-  gfc_conv_function_call (se, sym, expr->value.function.actual);
+  gfc_conv_function_call (se, sym, expr->value.function.actual, NULL_TREE);
 }
 
 
Index: gcc/fortran/gfortran.h
===================================================================
--- gcc/fortran/gfortran.h	(revision 117190)
+++ gcc/fortran/gfortran.h	(working copy)
@@ -1644,6 +1644,8 @@
   int flag_f2c;
   int flag_automatic;
   int flag_backslash;
+  int flag_external_blas;
+  int blas_matmul_limit;
   int flag_cray_pointer;
   int flag_d_lines;
   int flag_openmp;
Index: gcc/fortran/lang.opt
===================================================================
--- gcc/fortran/lang.opt	(revision 117190)
+++ gcc/fortran/lang.opt	(working copy)
@@ -89,6 +89,14 @@
 Fortran
 Specify that backslash in string introduces an escape character
 
+fexternal-blas
+Fortran
+Specify that an external BLAS library should be used for matmul calls on large-size arrays
+
+fblas-matmul-limit=
+Fortran RejectNegative Joined UInteger
+-fblas-matmul-limit=<n>        Size of the smallest matrix for which matmul will use BLAS
+
 fdefault-double-8
 Fortran
 Set the default double precision kind to an 8 byte wide type
Index: gcc/fortran/trans-stmt.c
===================================================================
--- gcc/fortran/trans-stmt.c	(revision 117190)
+++ gcc/fortran/trans-stmt.c	(working copy)
@@ -334,7 +334,8 @@
 
       /* Translate the call.  */
       has_alternate_specifier
-	= gfc_conv_function_call (&se, code->resolved_sym, code->ext.actual);
+	= gfc_conv_function_call (&se, code->resolved_sym, code->ext.actual,
+				  NULL_TREE);
 
       /* A subroutine without side-effect, by definition, does nothing!  */
       TREE_SIDE_EFFECTS (se.expr) = 1;
@@ -399,7 +400,8 @@
       gfc_init_block (&block);
 
       /* Add the subroutine call to the block.  */
-      gfc_conv_function_call (&loopse, code->resolved_sym, code->ext.actual);
+      gfc_conv_function_call (&loopse, code->resolved_sym, code->ext.actual,
+			      NULL_TREE);
       gfc_add_expr_to_block (&loopse.pre, loopse.expr);
 
       gfc_add_block_to_block (&block, &loopse.pre);
Index: gcc/fortran/trans.h
===================================================================
--- gcc/fortran/trans.h	(revision 117190)
+++ gcc/fortran/trans.h	(working copy)
@@ -303,7 +303,8 @@
 int gfc_is_intrinsic_libcall (gfc_expr *);
 
 /* Also used to CALL subroutines.  */
-int gfc_conv_function_call (gfc_se *, gfc_symbol *, gfc_actual_arglist *);
+int gfc_conv_function_call (gfc_se *, gfc_symbol *, gfc_actual_arglist *,
+			    tree);
 /* gfc_trans_* shouldn't call push/poplevel, use gfc_push/pop_scope */
 
 /* Generate code for a scalar assignment.  */
@@ -507,6 +508,12 @@
 extern GTY(()) tree gfor_fndecl_math_exponent10;
 extern GTY(()) tree gfor_fndecl_math_exponent16;
 
+/* BLAS functions.  */
+extern GTY(()) tree gfor_fndecl_sgemm;
+extern GTY(()) tree gfor_fndecl_dgemm;
+extern GTY(()) tree gfor_fndecl_cgemm;
+extern GTY(()) tree gfor_fndecl_zgemm;
+
 /* String functions.  */
 extern GTY(()) tree gfor_fndecl_compare_string;
 extern GTY(()) tree gfor_fndecl_concat_string;
Index: gcc/fortran/trans-decl.c
===================================================================
--- gcc/fortran/trans-decl.c	(revision 117190)
+++ gcc/fortran/trans-decl.c	(working copy)
@@ -143,7 +143,13 @@
 tree gfor_fndecl_si_kind;
 tree gfor_fndecl_sr_kind;
 
+/* BLAS gemm functions.  */
+tree gfor_fndecl_sgemm;
+tree gfor_fndecl_dgemm;
+tree gfor_fndecl_cgemm;
+tree gfor_fndecl_zgemm;
 
+
 static void
 gfc_add_decl_to_parent_function (tree decl)
 {
@@ -2170,6 +2176,49 @@
 				       gfc_int4_type_node, 1,
 				       gfc_real16_type_node);
 
+  /* BLAS ?gemm functions.  The pointee types follow the fixed BLAS
+     precisions (kind 4 for s/c, kind 8 for d/z), not the default kinds.  */
+  {
+    tree pint = build_pointer_type (gfc_c_int_type_node);
+    tree ps = build_pointer_type (gfc_get_real_type (4));
+    tree pd = build_pointer_type (gfc_get_real_type (8));
+    tree pc = build_pointer_type (gfc_get_complex_type (4));
+    tree pz = build_pointer_type (gfc_get_complex_type (8));
+
+    gfor_fndecl_sgemm = gfc_build_library_function_decl
+			  (get_identifier
+			     (gfc_option.flag_underscoring ? "sgemm_"
+							   : "sgemm"),
+			   void_type_node, 15, pchar_type_node,
+			   pchar_type_node, pint, pint, pint, ps, ps, pint,
+			   ps, pint, ps, ps, pint, gfc_c_int_type_node,
+			   gfc_c_int_type_node);
+    gfor_fndecl_dgemm = gfc_build_library_function_decl
+			  (get_identifier
+			     (gfc_option.flag_underscoring ? "dgemm_"
+							   : "dgemm"),
+			   void_type_node, 15, pchar_type_node,
+			   pchar_type_node, pint, pint, pint, pd, pd, pint,
+			   pd, pint, pd, pd, pint, gfc_c_int_type_node,
+			   gfc_c_int_type_node);
+    gfor_fndecl_cgemm = gfc_build_library_function_decl
+			  (get_identifier
+			     (gfc_option.flag_underscoring ? "cgemm_"
+							   : "cgemm"),
+			   void_type_node, 15, pchar_type_node,
+			   pchar_type_node, pint, pint, pint, pc, pc, pint,
+			   pc, pint, pc, pc, pint, gfc_c_int_type_node,
+			   gfc_c_int_type_node);
+    gfor_fndecl_zgemm = gfc_build_library_function_decl
+			  (get_identifier
+			     (gfc_option.flag_underscoring ? "zgemm_"
+							   : "zgemm"),
+			   void_type_node, 15, pchar_type_node,
+			   pchar_type_node, pint, pint, pint, pz, pz, pint,
+			   pz, pint, pz, pz, pint, gfc_c_int_type_node,
+			   gfc_c_int_type_node);
+  }
+
   /* Other functions.  */
   gfor_fndecl_size0 =
     gfc_build_library_function_decl (get_identifier (PREFIX("size0")),
Index: gcc/fortran/trans-intrinsic.c
===================================================================
--- gcc/fortran/trans-intrinsic.c	(revision 117190)
+++ gcc/fortran/trans-intrinsic.c	(working copy)
@@ -1265,6 +1265,7 @@
 gfc_conv_intrinsic_funcall (gfc_se * se, gfc_expr * expr)
 {
   gfc_symbol *sym;
+  tree append_args;
 
   gcc_assert (!se->ss || se->ss->expr == expr);
 
@@ -1274,7 +1275,56 @@
     gcc_assert (expr->rank == 0);

   sym = gfc_get_symbol_for_expr (expr);
-  gfc_conv_function_call (se, sym, expr->value.function.actual);
+
+  /* Calls to libgfortran_matmul need to be appended special arguments,
+     to be able to call the BLAS ?gemm functions if required and possible.  */
+  append_args = NULL_TREE;
+  if (expr->value.function.isym->generic_id == GFC_ISYM_MATMUL
+      && sym->ts.type != BT_LOGICAL)
+    {
+      tree cint = gfc_get_int_type (gfc_c_int_kind);
+
+      /* Pick the ?gemm routine from the actual kind of the result, not from
+	 the default kinds, which are configurable (e.g. -fdefault-real-8):
+	 sgemm/cgemm take kind 4 data and dgemm/zgemm take kind 8 data.  */
+      if (gfc_option.flag_external_blas
+	  && (sym->ts.type == BT_REAL || sym->ts.type == BT_COMPLEX)
+	  && (sym->ts.kind == 4 || sym->ts.kind == 8))
+	{
+	  tree gemm_fndecl;
+
+	  if (sym->ts.type == BT_REAL)
+	    {
+	      if (sym->ts.kind == 4)
+		gemm_fndecl = gfor_fndecl_sgemm;
+	      else
+		gemm_fndecl = gfor_fndecl_dgemm;
+	    }
+	  else
+	    {
+	      if (sym->ts.kind == 4)
+		gemm_fndecl = gfor_fndecl_cgemm;
+	      else
+		gemm_fndecl = gfor_fndecl_zgemm;
+	    }
+
+	  append_args = gfc_chainon_list (NULL_TREE, build_int_cst (cint, 1));
+	  append_args = gfc_chainon_list
+			  (append_args, build_int_cst
+					  (cint, gfc_option.blas_matmul_limit));
+	  append_args = gfc_chainon_list (append_args,
+					  gfc_build_addr_expr (NULL_TREE,
+							       gemm_fndecl));
+	}
+      else
+	{
+	  append_args = gfc_chainon_list (NULL_TREE, build_int_cst (cint, 0));
+	  append_args = gfc_chainon_list (append_args, build_int_cst (cint, 0));
+	  append_args = gfc_chainon_list (append_args, null_pointer_node);
+	}
+    }
+
+  gfc_conv_function_call (se, sym, expr->value.function.actual, append_args);
   gfc_free (sym);
 }
 
Index: gcc/fortran/options.c
===================================================================
--- gcc/fortran/options.c	(revision 117190)
+++ gcc/fortran/options.c	(working copy)
@@ -79,6 +79,8 @@
   gfc_option.flag_preprocessed = 0;
   gfc_option.flag_automatic = 1;
   gfc_option.flag_backslash = 1;
+  gfc_option.flag_external_blas = 0;
+  gfc_option.blas_matmul_limit = 32;
   gfc_option.flag_cray_pointer = 0;
   gfc_option.flag_d_lines = -1;
   gfc_option.flag_openmp = 0;
@@ -454,6 +456,14 @@
       gfc_option.flag_dollar_ok = value;
       break;
 
+    case OPT_fexternal_blas:
+      gfc_option.flag_external_blas = value;
+      break;
+
+    case OPT_fblas_matmul_limit_:
+      gfc_option.blas_matmul_limit = value;
+      break;
+
     case OPT_fd_lines_as_code:
       gfc_option.flag_d_lines = 1;
       break;

Index Nav: [Date Index] [Subject Index] [Author Index] [Thread Index]
Message Nav: [Date Prev] [Date Next] [Thread Prev] [Thread Next]