[PATCH v1] RISC-V: Implement RESOLVE_OVERLOADED_BUILTIN for RVV intrinsic

pan2.li@intel.com pan2.li@intel.com
Mon Sep 11 07:57:27 GMT 2023


From: Pan Li <pan2.li@intel.com>

This patch adds the framework to support the RVV overloaded intrinsic
API in riscv-xxx-xxx-gcc (C), as riscv-xxx-xxx-g++ (C++) already does.

It is mostly built on the hook TARGET_RESOLVE_OVERLOADED_BUILTIN, with
the following steps:

* Register overloaded functions.
* Add function_resolver for overloaded function resolving.
* Add resolve API for function shape with default implementation.
* Implement the hook that maps an overloaded API call to its non-overloaded API.

We validated this framework with the vmv_v intrinsic API(s); support for
more intrinsic APIs will be added in follow-up patches.

gcc/ChangeLog:

	* config/riscv/riscv-c.cc
	(riscv_resolve_overloaded_builtin): New function for the hook.
	(riscv_register_pragmas): Register the hook.
	* config/riscv/riscv-protos.h (resolve_overloaded_builtin): New decl.
	* config/riscv/riscv-vector-builtins-shapes.cc (build_one):
	Register overloaded function.
	(struct overloaded_base): New struct for overloaded shape.
	(struct non_overloaded_base): New struct for non overloaded shape.
	(struct move_def): Inherit overloaded shape.
	* config/riscv/riscv-vector-builtins.cc
	(function_builder::add_function): Add overloaded arg.
	(function_builder::add_overloaded_function): New function impl.
	(function_resolver::function_resolver): New constructor.
	(function_resolver::get_sub_code): New API impl.
	(function_resolver::resolve): New API impl.
	(function_resolver::lookup): New API impl.
	(resolve_overloaded_builtin): New func impl.
	* config/riscv/riscv-vector-builtins.h
	(class function_resolver): New class.

gcc/testsuite/ChangeLog:

	* gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c: New test.
	* gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c: New test.
	* gcc.target/riscv/rvv/base/overloaded_vmv_v.h: New test.

Signed-off-by: Pan Li <pan2.li@intel.com>
---
 gcc/config/riscv/riscv-c.cc                   |  36 +++++
 gcc/config/riscv/riscv-protos.h               |   1 +
 .../riscv/riscv-vector-builtins-shapes.cc     |  22 ++-
 gcc/config/riscv/riscv-vector-builtins.cc     | 138 +++++++++++++++++-
 gcc/config/riscv/riscv-vector-builtins.h      |  30 +++-
 .../riscv/rvv/base/overloaded_rv32_vmv_v.c    |   4 +
 .../riscv/rvv/base/overloaded_rv64_vmv_v.c    |   4 +
 .../riscv/rvv/base/overloaded_vmv_v.h         |  17 +++
 8 files changed, 248 insertions(+), 4 deletions(-)
 create mode 100644 gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c
 create mode 100644 gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c
 create mode 100644 gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_vmv_v.h

diff --git a/gcc/config/riscv/riscv-c.cc b/gcc/config/riscv/riscv-c.cc
index 283052ae313..060edd3129d 100644
--- a/gcc/config/riscv/riscv-c.cc
+++ b/gcc/config/riscv/riscv-c.cc
@@ -220,11 +220,57 @@ riscv_check_builtin_call (location_t loc, vec<location_t> arg_loc, tree fndecl,
   gcc_unreachable ();
 }
 
+/* Implement TARGET_RESOLVE_OVERLOADED_BUILTIN.
+
+   FNDECL is the builtin called at UNCAST_LOCATION and UNCAST_ARGLIST is
+   its vec<tree, va_gc> of call arguments, which may be null.  Return the
+   call rebuilt against the resolved non-overloaded function, or NULL_TREE
+   if no resolution was made.  */
+static tree
+riscv_resolve_overloaded_builtin (unsigned int uncast_location, tree fndecl,
+				  void *uncast_arglist)
+{
+  /* Stands in for a null argument list so that the vector resolver can
+     always dereference ARGLIST.  */
+  vec<tree, va_gc> empty = {};
+  location_t loc = (location_t) uncast_location;
+  vec<tree, va_gc> *arglist = (vec<tree, va_gc> *) uncast_arglist;
+  unsigned int code = DECL_MD_FUNCTION_CODE (fndecl);
+  unsigned int subcode = code >> RISCV_BUILTIN_SHIFT;
+  tree new_fndecl = NULL_TREE;
+
+  if (!arglist)
+    arglist = &empty;
+
+  switch (code & RISCV_BUILTIN_CLASS)
+    {
+    case RISCV_BUILTIN_GENERAL:
+      /* General builtins are left unresolved.  */
+      break;
+    case RISCV_BUILTIN_VECTOR:
+      new_fndecl = riscv_vector::resolve_overloaded_builtin (loc, subcode,
+							     arglist);
+      break;
+    default:
+      gcc_unreachable ();
+    }
+
+  if (new_fndecl == NULL_TREE)
+    return new_fndecl;
+
+  /* Build the call to the resolved decl, passing FNDECL as the original
+     (pre-resolution) function.  */
+  return build_function_call_vec (loc, vNULL, new_fndecl, arglist, NULL,
+				  fndecl);
+}
+
 /* Implement REGISTER_TARGET_PRAGMAS.  */
 
 void
 riscv_register_pragmas (void)
 {
+  targetm.resolve_overloaded_builtin = riscv_resolve_overloaded_builtin;
   targetm.check_builtin_call = riscv_check_builtin_call;
+
   c_register_pragma ("riscv", "intrinsic", riscv_pragma_intrinsic);
 }
diff --git a/gcc/config/riscv/riscv-protos.h b/gcc/config/riscv/riscv-protos.h
index 6dbf6b9f943..5d2492dd031 100644
--- a/gcc/config/riscv/riscv-protos.h
+++ b/gcc/config/riscv/riscv-protos.h
@@ -381,6 +381,7 @@ gimple *gimple_fold_builtin (unsigned int, gimple_stmt_iterator *, gcall *);
 rtx expand_builtin (unsigned int, tree, rtx);
 bool check_builtin_call (location_t, vec<location_t>, unsigned int,
 			   tree, unsigned int, tree *);
+tree resolve_overloaded_builtin (location_t, unsigned int, vec<tree, va_gc> *);
 bool const_vec_all_same_in_range_p (rtx, HOST_WIDE_INT, HOST_WIDE_INT);
 bool legitimize_move (rtx, rtx);
 void emit_vlmax_vsetvl (machine_mode, rtx);
diff --git a/gcc/config/riscv/riscv-vector-builtins-shapes.cc b/gcc/config/riscv/riscv-vector-builtins-shapes.cc
index f8fdec863e6..6091016fa42 100644
--- a/gcc/config/riscv/riscv-vector-builtins-shapes.cc
+++ b/gcc/config/riscv/riscv-vector-builtins-shapes.cc
@@ -49,6 +49,8 @@ build_one (function_builder &b, const function_group_info &group,
     group.ops_infos.types[vec_type_idx].index);
   b.allocate_argument_types (function_instance, argument_types);
   b.apply_predication (function_instance, return_type, argument_types);
+
+  b.add_overloaded_function (function_instance, *group.shape);
   b.add_unique_function (function_instance, (*group.shape), return_type,
 			 argument_types);
 }
@@ -87,6 +89,26 @@ struct build_base : public function_shape
   }
 };
 
+/* Base class for shapes whose overloaded functions are resolved by
+   looking up the matching non-overloaded function.  */
+struct overloaded_base : public build_base
+{
+  tree resolve (function_resolver &r) const override
+  {
+    return r.lookup ();
+  }
+};
+
+/* Base class for shapes that are not expected to resolve overloaded
+   calls; reaching resolve is an internal error.  */
+struct non_overloaded_base : public build_base
+{
+  tree resolve (function_resolver &) const override
+  {
+    gcc_unreachable ();
+  }
+};
+
 /* vsetvl_def class.  */
 struct vsetvl_def : public build_base
 {
@@ -525,7 +543,7 @@ struct narrow_alu_def : public build_base
 };
 
 /* move_def class. Handle vmv.v.v/vmv.v.x.  */
-struct move_def : public build_base
+struct move_def : public overloaded_base
 {
   char *get_name (function_builder &b, const function_instance &instance,
 		  bool overloaded_p) const override
@@ -545,7 +563,7 @@ struct move_def : public build_base
 
     /* According to rvv-intrinsic-doc, it does not add "_m" suffix
        for vop_m C++ overloaded API.  */
-    if (overloaded_p && instance.pred == PRED_TYPE_m)
+    if (overloaded_p)
       return b.finish_name ();
     b.append_name (predication_suffixes[instance.pred]);
     return b.finish_name ();
diff --git a/gcc/config/riscv/riscv-vector-builtins.cc b/gcc/config/riscv/riscv-vector-builtins.cc
index 6d99f970ead..cc61e524dc7 100644
--- a/gcc/config/riscv/riscv-vector-builtins.cc
+++ b/gcc/config/riscv/riscv-vector-builtins.cc
@@ -80,6 +80,10 @@ public:
 
   /* The decl itself.  */
   tree GTY ((skip)) decl;
+
+  /* True if the decl represents an overloaded function that needs to be
+     resolved by function_resolver.  */
+  bool overloaded_p;
 };
 
 /* Hash traits for registered_function.  */
@@ -3357,7 +3361,8 @@ function_builder::get_attributes (const function_instance &instance)
 registered_function &
 function_builder::add_function (const function_instance &instance,
 				const char *name, tree fntype, tree attrs,
-				bool placeholder_p)
+				bool placeholder_p,
+				bool overloaded_p = false)
 {
   unsigned int code = vec_safe_length (registered_functions);
   code = (code << RISCV_BUILTIN_SHIFT) + RISCV_BUILTIN_VECTOR;
@@ -3383,6 +3388,7 @@ function_builder::add_function (const function_instance &instance,
   registered_function &rfn = *ggc_alloc<registered_function> ();
   rfn.instance = instance;
   rfn.decl = decl;
+  rfn.overloaded_p = overloaded_p;
   vec_safe_push (registered_functions, &rfn);
 
   return rfn;
@@ -3432,6 +3438,33 @@ function_builder::add_unique_function (const function_instance &instance,
   obstack_free (&m_string_obstack, name);
 }
 
+/* Register the overloaded alias of INSTANCE, named by SHAPE, so that a
+   call to it can later be resolved by function_resolver.  Nothing is
+   registered if check_required_extensions rejects INSTANCE or if SHAPE
+   provides no overloaded name for it.  */
+void
+function_builder::add_overloaded_function (const function_instance &instance,
+					   const function_shape *shape)
+{
+  if (!check_required_extensions (instance))
+    return;
+
+  char *name = shape->get_name (*this, instance, true);
+
+  if (name)
+    {
+      /* To avoid conflicting with the real APIs, register the overloaded
+	 one with a void return type and no arguments, like aarch64-sve.  */
+      tree fntype = build_function_type (void_type_node, void_list_node);
+      /* PLACEHOLDER_P is M_DIRECT_OVERLOADS: when each instance already
+	 gets its own overloaded decl, this entry is only a placeholder;
+	 otherwise it is the decl that the resolve hook receives.  */
+      add_function (instance, name, fntype, NULL_TREE, m_direct_overloads,
+		    true);
+      obstack_free (&m_string_obstack, name);
+    }
+}
+
 function_call_info::function_call_info (location_t location_in,
 					const function_instance &instance_in,
 					tree fndecl_in)
@@ -3852,6 +3878,16 @@ function_checker::function_checker (location_t location,
     m_nargs (nargs), m_args (args)
 {}
 
+/* Create a resolver for the call at LOCATION to FNDECL (registered for
+   INSTANCE) with arguments ARGLIST.  ARGLIST is held by reference, so it
+   must outlive the resolver.  */
+function_resolver::function_resolver (location_t location,
+				      const function_instance &instance,
+				      tree fndecl,
+				      vec<tree, va_gc> &arglist)
+  : function_call_info (location, instance, fndecl), m_arglist (arglist)
+{}
+
 /* Report that LOCATION has a call to FNDECL in which argument ARGNO
    was not an integer constant expression.  ARGNO counts from zero.  */
 void
@@ -3967,6 +4000,124 @@ function_checker::check ()
   return shape->check (*this);
 }
 
+/* Return FNDECL's function code with the builtin class bits shifted out,
+   i.e. its index in registered_functions.  */
+unsigned int
+function_resolver::get_sub_code ()
+{
+  unsigned int fun_code = DECL_MD_FUNCTION_CODE (fndecl);
+
+  return fun_code >> RISCV_BUILTIN_SHIFT;
+}
+
+/* Resolve the overloaded call by delegating to the function's shape.
+   Return the decl of the non-overloaded function to call instead, or
+   NULL_TREE if the shape cannot resolve the call.  */
+tree
+function_resolver::resolve ()
+{
+  return shape->resolve (*this);
+}
+
+/* Perform the lookup from the registered functions.
+   After the overloaded functions have been registered, the registered
+   functions table may look like:
+
+   +--------+---------------------------+-------------------+
+   | index  | name                      | kind              |
+   +--------+---------------------------+-------------------+
+   | 124733 | __riscv_vmv_v             | Overloaded        | <- Hook fun code
+   +--------+---------------------------+-------------------+
+   | 124735 | __riscv_vmv_v_v_i8mf8     | Non-overloaded    |
+   +--------+---------------------------+-------------------+
+   | 124737 | __riscv_vmv_v             | Placeholder       |
+   +--------+---------------------------+-------------------+
+   | 124739 | __riscv_vmv_v             | Overloaded        |
+   +--------+---------------------------+-------------------+
+   | 124741 | __riscv_vmv_v_v_i8mf4     | Non-overloaded    |
+   +--------+---------------------------+-------------------+
+   | 124743 | __riscv_vmv_v             | Placeholder       |
+   +--------+---------------------------+-------------------+
+   | 124745 | __riscv_vmv_v             | Overloaded        |
+   +--------+---------------------------+-------------------+
+   | 124747 | __riscv_vmv_v_v_i8mf2     | Non-overloaded    |
+   +--------+---------------------------+-------------------+
+   | 124749 | __riscv_vmv_v             | Placeholder       |
+   +--------+---------------------------+-------------------+
+   | 124751 | __riscv_vmv_v             | Overloaded        |
+   +--------+---------------------------+-------------------+
+   | 124753 | __riscv_vmv_v_v_i8m1      | Non-overloaded    |
+   +--------+---------------------------+-------------------+
+   | 124755 | __riscv_vmv_v             | Placeholder       |
+   +--------+---------------------------+-------------------+
+
+   When we resolve the overloaded API from the hook, we always get the first
+   function code of one API group (aka vmv_v as above table).  We search
+   forward from that index, over the entries sharing its base name, for the
+   first non-overloaded API whose parameters match the call's arguments and
+   return its decl; otherwise NULL_TREE is returned.
+
+   Arguments are compared by machine mode and, for vector types, by element
+   signedness, because signed and unsigned integer vector types (e.g.
+   vint32m1_t and vuint32m1_t) share a mode.  The call must also supply
+   exactly as many arguments as the candidate takes.  */
+tree
+function_resolver::lookup ()
+{
+  unsigned int code_limit = vec_safe_length (registered_functions);
+  unsigned int nargs = m_arglist.length ();
+
+  for (unsigned int code = get_sub_code () + 1; code < code_limit; code++)
+    {
+      registered_function *rfun = (*registered_functions)[code];
+      const function_instance &instance = rfun->instance;
+
+      /* Stop at the first entry outside this API group.  */
+      if (strcmp (base_name, instance.base_name) != 0)
+	break;
+
+      /* Only non-overloaded entries are resolution candidates.  */
+      if (rfun->overloaded_p)
+	continue;
+
+      const rvv_arg_type_info *args = instance.op_info->args;
+      unsigned int k;
+
+      for (k = 0; args[k].base_type != NUM_BASE_TYPES; k++)
+	{
+	  /* The call has too few arguments for this candidate.  */
+	  if (k >= nargs)
+	    break;
+
+	  tree param_type = instance.get_arg_type (k);
+	  tree arg_type = TREE_TYPE (m_arglist[k]);
+
+	  /* An erroneous argument cannot match any candidate, and TYPE_MODE
+	     must not be applied to error_mark_node.  */
+	  if (arg_type == error_mark_node)
+	    return NULL_TREE;
+
+	  if (TYPE_MODE (param_type) != TYPE_MODE (arg_type))
+	    break;
+
+	  /* The mode cannot tell signed and unsigned integer vectors apart,
+	     so compare the signedness of the elements too.  */
+	  if (VECTOR_TYPE_P (param_type)
+	      && (!VECTOR_TYPE_P (arg_type)
+		  || (TYPE_UNSIGNED (TREE_TYPE (param_type))
+		      != TYPE_UNSIGNED (TREE_TYPE (arg_type)))))
+	    break;
+	}
+
+      /* Accept the candidate only if all of its parameters matched and
+	 the call supplies no extra arguments.  */
+      if (args[k].base_type == NUM_BASE_TYPES && k == nargs)
+	return rfun->decl;
+    }
+
+  return NULL_TREE;
+}
+
 inline hashval_t
 registered_function_hasher::hash (value_type value)
 {
@@ -4196,6 +4316,27 @@ check_builtin_call (location_t location, vec<location_t>, unsigned int code,
 			   TREE_TYPE (rfn.decl), nargs, args).check ();
 }
 
+/* Implement TARGET_RESOLVE_OVERLOADED_BUILTIN for the RVV function whose
+   index in registered_functions is CODE, called at LOC with arguments
+   ARGLIST (which must be non-null).  Return the decl of the matching
+   non-overloaded function, or NULL_TREE if CODE is not an overloaded
+   function or no match is found.  */
+tree
+resolve_overloaded_builtin (location_t loc, unsigned int code,
+			    vec<tree, va_gc> *arglist)
+{
+  if (code >= vec_safe_length (registered_functions))
+    return NULL_TREE;
+
+  const registered_function *rfun = (*registered_functions)[code];
+
+  if (!rfun || !rfun->overloaded_p)
+    return NULL_TREE;
+
+  return function_resolver (loc, rfun->instance, rfun->decl, *arglist)
+    .resolve ();
+}
+
 function_instance
 get_read_vl_instance (void)
 {
diff --git a/gcc/config/riscv/riscv-vector-builtins.h b/gcc/config/riscv/riscv-vector-builtins.h
index e358a8e4d91..3a466a99770 100644
--- a/gcc/config/riscv/riscv-vector-builtins.h
+++ b/gcc/config/riscv/riscv-vector-builtins.h
@@ -277,6 +277,8 @@ public:
   void apply_predication (const function_instance &, tree, vec<tree> &) const;
   void add_unique_function (const function_instance &, const function_shape *,
 			    tree, vec<tree> &);
+  void add_overloaded_function (const function_instance &,
+				const function_shape *);
   void register_function_group (const function_group_info &);
   void append_name (const char *);
   void append_base_name (const char *);
@@ -288,7 +290,7 @@ private:
   tree get_attributes (const function_instance &);
 
   registered_function &add_function (const function_instance &, const char *,
-				     tree, tree, bool);
+				     tree, tree, bool, bool);
 
   /* True if we should create a separate decl for each instance of an
      overloaded function, instead of using function_builder.  */
@@ -462,6 +464,28 @@ private:
   tree *m_args;
 };
 
+/* A class that resolves a call to an overloaded RVV function.  */
+class function_resolver : public function_call_info
+{
+public:
+  function_resolver (location_t, const function_instance &, tree,
+		     vec<tree, va_gc> &);
+
+  /* Look up the non-overloaded function in registered_functions whose
+     parameters match M_ARGLIST; return its decl or NULL_TREE.  */
+  tree lookup ();
+
+  /* Resolve the call by delegating to the function's shape.  */
+  tree resolve ();
+
+private:
+  /* Return FNDECL's index in registered_functions.  */
+  unsigned int get_sub_code ();
+
+  /* The arguments of the call being resolved, held by reference.  */
+  vec<tree, va_gc> &m_arglist;
+};
+
 /* Classifies functions into "shapes" base on:
 
    - Base name of the intrinsic function.
@@ -486,6 +510,10 @@ public:
   /* Check whether the given call is semantically valid.  Return true
    if it is, otherwise report an error and return false.  */
   virtual bool check (function_checker &) const { return true; }
+
+  /* Try to resolve the overloaded call.  Return the non-overloaded
+     function decl on success and NULL_TREE (the default) on failure.  */
+  virtual tree resolve (function_resolver &) const { return NULL_TREE; }
 };
 
 extern const char *const operand_suffixes[NUM_OP_TYPES];
diff --git a/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c
new file mode 100644
index 00000000000..913fe678b51
--- /dev/null
+++ b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv32_vmv_v.c
@@ -0,0 +1,4 @@
+/* { dg-do compile } */
+/* { dg-options "-march=rv32gcv_zvfh -mabi=ilp32 -O3 -Wno-psabi" } */
+
+#include "overloaded_vmv_v.h"
diff --git a/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c
new file mode 100644
index 00000000000..52f65e9f8a8
--- /dev/null
+++ b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_rv64_vmv_v.c
@@ -0,0 +1,4 @@
+/* { dg-do compile } */
+/* { dg-options "-march=rv64gcv_zvfh -mabi=lp64 -O3 -Wno-psabi" } */
+
+#include "overloaded_vmv_v.h"
diff --git a/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_vmv_v.h b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_vmv_v.h
new file mode 100644
index 00000000000..dd818320f63
--- /dev/null
+++ b/gcc/testsuite/gcc.target/riscv/rvv/base/overloaded_vmv_v.h
@@ -0,0 +1,27 @@
+#include "riscv_vector.h"
+
+/* Overloaded calls, resolved from the argument types.  */
+vint32m1_t
+test_vmv_overloaded_0 (vint32m1_t src, size_t vl)
+{
+  return __riscv_vmv_v (src, vl);
+}
+
+vfloat16m1_t
+test_vmv_overloaded_1 (vfloat16m1_t src, size_t vl)
+{
+  return __riscv_vmv_v (src, vl);
+}
+
+/* Non-overloaded calls that name the type explicitly.  */
+vint32m1_t
+test_vmv_non_overloaded_0 (vint32m1_t src, size_t vl)
+{
+  return __riscv_vmv_v_v_i32m1 (src, vl);
+}
+
+vfloat16m1_t
+test_vmv_non_overloaded_1 (vfloat16m1_t src, size_t vl)
+{
+  return __riscv_vmv_v_v_f16m1 (src, vl);
+}
-- 
2.34.1



More information about the Gcc-patches mailing list