summaryrefslogtreecommitdiff
path: root/gnu/packages/patches/onnxruntime-1.22.0-splittosequence-bool.patch
blob: 373471f48445aaabc0cbb876f32747bfb3d5954b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
From 3b18e8ac863f817cddf0a2fc1ff89aa05fe60557 Mon Sep 17 00:00:00 2001
From: Mauricio Cortazar <mcortazar@truora.com>
Date: Mon, 2 Jun 2025 12:51:41 -0500
Subject: [PATCH] Add support for bool type in SplitToSequence.

Add support for `bool` type to address the below issue.
This PR fixes <https://github.com/microsoft/onnxruntime/issues/12286>
---
 docs/OperatorKernels.md                          |  2 +-
 .../core/providers/cpu/sequence/sequence_ops.cc  |  2 +-
 .../providers/cpu/sequence/sequence_ops_test.cc  | 16 +++++++++++++++-
 3 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 8c1ab002bce67..fc7b1ed21575c 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -437,7 +437,7 @@ Do not modify directly.*
 |||[13, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |||[2, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|SplitToSequence|*in* input:**T**<br> *in* split:**I**<br> *out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))<br/> **T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(string)|
+|SplitToSequence|*in* input:**T**<br> *in* split:**I**<br> *out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))<br/> **T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(string)|
 |Sqrt|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float)|
 |||[6, 12]|**T** = tensor(double), tensor(float)|
 |Squeeze|*in* data:**T**<br> *in* axes:**tensor(int64)**<br> *out* squeezed:**T**<br><br>or<br><br>*in* data:**T**<br> *out* squeezed:**T**|23+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
diff --git a/onnxruntime/core/providers/cpu/sequence/sequence_ops.cc b/onnxruntime/core/providers/cpu/sequence/sequence_ops.cc
index 7a27b04ece7cf..d2541cf3d35ce 100644
--- a/onnxruntime/core/providers/cpu/sequence/sequence_ops.cc
+++ b/onnxruntime/core/providers/cpu/sequence/sequence_ops.cc
@@ -339,7 +339,7 @@ ONNX_CPU_OPERATOR_KERNEL(
     11,
     KernelDefBuilder()
         .TypeConstraint("T",
-                        BuildKernelDefConstraints<float, MLFloat16, double, int32_t, int64_t, std::string>())
+                        BuildKernelDefConstraints<float, MLFloat16, double, int32_t, int64_t, bool, std::string>())
         .TypeConstraint("S", DataTypeImpl::AllSequenceTensorTypes())
         .TypeConstraint("I", BuildKernelDefConstraints<int32_t, int64_t>()),
     SplitToSequence);
diff --git a/onnxruntime/test/providers/cpu/sequence/sequence_ops_test.cc b/onnxruntime/test/providers/cpu/sequence/sequence_ops_test.cc
index c2d64b8e5ee4a..d819b0973adc2 100644
--- a/onnxruntime/test/providers/cpu/sequence/sequence_ops_test.cc
+++ b/onnxruntime/test/providers/cpu/sequence/sequence_ops_test.cc
@@ -536,5 +536,19 @@ TEST(SequenceOpsTest, SplitToSequence_PositiveAxisDontKeepDims) {
   test.AddSeqOutput("S2", output);
   test.Run();
 }
+
+TEST(SequenceOpsTest, SplitToSequence_BoolSplit) {
+  OpTester test("SplitToSequence", 11);
+  test.AddInput<bool>("input", {4, 2}, std::initializer_list<bool>({1, 1, 1, 1, 0, 0, 0, 0}));
+  int64_t axis = 0;
+  test.AddAttribute("axis", axis);
+  SeqTensors<bool> output;
+  output.AddTensor({1, 2}, {1, 1});
+  output.AddTensor({1, 2}, {1, 1});
+  output.AddTensor({1, 2}, {0, 0});
+  output.AddTensor({1, 2}, {0, 0});
+  test.AddSeqOutput("S2", output);
+  test.Run();
+}
 }  // namespace test
-}  // namespace onnxruntime
\ No newline at end of file
+}  // namespace onnxruntime